# BitNet Inference Kernel

This repository provides an efficient GEMV kernel for BitNet W2A8 inference (2-bit weights, 8-bit activations). It is tailored to the [BitNet-b1.58-2B-4T](https://arxiv.org/abs/2504.12285) model.

## Features

- Support for W2A8 (2-bit weight × 8-bit activation) GEMV computation
- Custom CUDA kernels with low-latency execution
- Optimizations for memory access, decoding, and compute throughput

## Usage

Installation and kernel performance tests:

```bash
# (Recommended) Create a new conda environment
conda create --name bitnet-gpu "python<3.13"
conda activate bitnet-gpu

# Install dependencies
pip install -r requirements.txt

# Build the kernel
cd bitnet_kernels
bash compile.sh
cd ..

# Run performance tests
python test.py
```

End-to-end inference:

```bash
# Download and convert the BitNet-b1.58-2B-4T model
mkdir -p checkpoints
huggingface-cli download microsoft/bitnet-b1.58-2B-4T-bf16 --local-dir ./checkpoints/bitnet-b1.58-2B-4T-bf16
python ./convert_safetensors.py --safetensors_file ./checkpoints/bitnet-b1.58-2B-4T-bf16/model.safetensors --output checkpoints/model_state.pt --model_name 2B
python ./convert_checkpoint.py --input ./checkpoints/model_state.pt
rm ./checkpoints/model_state.pt

# Inference
python3 ./generate.py ./checkpoints/ --interactive --chat_format
```

## Optimizations

### Weight Permutation

The weight matrix is divided into 16×32 blocks to optimize memory access patterns.

Within each block, values are stored contiguously in memory and permuted to facilitate efficient access and processing.

See `convert_checkpoint.py` for details.
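
As a rough illustration of the block layout (not the repository's actual code), the sketch below reorders a row-major matrix so that each 16×32 tile is contiguous. It assumes tiles are 16 rows by 32 columns and are visited in row-major order. The function name `tile_blocks` is hypothetical, and the extra permutation applied inside each block is omitted because it is defined only in `convert_checkpoint.py`.

```python
def tile_blocks(weight, block_rows=16, block_cols=32):
    """Flatten a 2-D list so that each block_rows x block_cols tile is contiguous.

    Assumption: tiles are visited in row-major order and elements inside a
    tile stay in row-major order (the real intra-block permutation is omitted).
    """
    n_rows, n_cols = len(weight), len(weight[0])
    if n_rows % block_rows or n_cols % block_cols:
        raise ValueError("matrix shape must be a multiple of the block shape")

    flat = []
    for block_r in range(0, n_rows, block_rows):        # tile row
        for block_c in range(0, n_cols, block_cols):    # tile column
            for r in range(block_r, block_r + block_rows):
                # Copy one row segment of the current tile.
                flat.extend(weight[r][block_c:block_c + block_cols])
    return flat


# Toy 32 x 64 matrix whose entries encode their own row-major index.
weight = [[r * 64 + c for c in range(64)] for r in range(32)]
flat = tile_blocks(weight)
```

After tiling, the first 512 entries of `flat` all come from the top-left tile, so code that processes one tile reads a single contiguous region.
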
### Fast Decoding

Every 16 two-bit values are packed into a single 32-bit integer using the following interleaving pattern:

```
[0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]
```

This layout accelerates decoding by allowing 4 values to be extracted at a time into `int8`.
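
To see why the interleaving helps, the following minimal sketch packs and decodes one word in plain Python. It assumes that slot 0 occupies the two least-significant bits, that byte lanes are little-endian, and that values are unsigned codes in 0..3 (how the model's weights map to these codes is not covered here). The names `pack16` and `decode16` are hypothetical, not taken from the repository.

```python
# Slot p of the 32-bit word (bits 2p and 2p+1) holds element PATTERN[p].
PATTERN = [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]


def pack16(codes):
    """Pack sixteen 2-bit codes (0..3) into one 32-bit word."""
    if len(codes) != 16 or any(not 0 <= c <= 3 for c in codes):
        raise ValueError("expected sixteen values in the range 0..3")
    word = 0
    for slot, index in enumerate(PATTERN):
        word |= codes[index] << (2 * slot)
    return word


def decode16(word):
    """Recover the sixteen codes, four per shift-and-mask step."""
    codes = []
    for step in range(4):
        # One shift and one mask leave four codes, one per byte lane.
        lanes = (word >> (2 * step)) & 0x03030303
        codes.extend(lanes.to_bytes(4, "little"))
    return codes


codes = [i % 4 for i in range(16)]
word = pack16(codes)
```

Under these assumptions, step `s` yields elements `4s` through `4s + 3` in order, each already sitting in its own byte, so no per-element shifting is needed after the mask.
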
### `dp4a` Instruction

We use the `dp4a` instruction to accelerate low-precision dot products.

It computes the dot product of two 4-element vectors of 8-bit integers, each packed into a 32-bit word, and accumulates the result into a 32-bit integer.

It significantly improves GEMV throughput when processing quantized weights and activations.
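
The arithmetic performed by `dp4a` can be emulated in a few lines of Python. This is only a sketch of the signed variant, assuming little-endian byte lanes; the helper names `pack_int8x4` and `dp4a` are hypothetical, and the wrap-around of the 32-bit accumulator is not modeled. In CUDA the same operation is exposed as the `__dp4a` intrinsic.

```python
import struct


def pack_int8x4(values):
    """Pack four signed 8-bit integers into one 32-bit word (little-endian lanes)."""
    return struct.unpack("<I", struct.pack("<4b", *values))[0]


def dp4a(a_word, b_word, acc):
    """Emulate a signed dp4a: acc plus the sum of four int8 products."""
    a = struct.unpack("<4b", struct.pack("<I", a_word))
    b = struct.unpack("<4b", struct.pack("<I", b_word))
    return acc + sum(x * y for x, y in zip(a, b))


weights = pack_int8x4([-1, 0, 1, 1])            # example decoded weights
activations = pack_int8x4([10, -20, 30, -40])   # example int8 activations
result = dp4a(weights, activations, 5)
```

Here `result` is `5 + (-10 + 0 + 30 - 40)`, so one call replaces four multiplies and four adds on values that are already packed four to a word.
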
## Performance

Kernel performance (tested on NVIDIA A100 40GB GPU):

| Shape (N×K)  | W2A8 Latency (µs) | BF16 Latency (µs) | Speedup Ratio |
|--------------|-------------------|-------------------|---------------|
| 2560 × 2560  | 13.32             | 18.32             | 1.38          |
| 3840 × 2560  | 14.90             | 18.87             | 1.27          |
| 13824 × 2560 | 18.75             | 59.51             | 3.17          |
| 2560 × 6912  | 14.49             | 37.78             | 2.61          |
| 3200 × 3200  | 14.61             | 19.08             | 1.31          |
| 4800 × 3200  | 13.09             | 21.84             | 1.67          |
| 3200 × 10240 | 19.64             | 60.79             | 3.10          |
| 20480 × 3200 | 30.99             | 112.39            | 3.63          |

Generation throughput:

| BF16 (tokens/s) | W2A8 (tokens/s) | Speedup Ratio |
|---|---|---|
| 10.9 | 213.3 | 19.6 |