microsoft / BitNet

Official inference framework for 1-bit LLMs

2025-05-15 05:55:42 +00:00
# BitNet Inference Kernel
This repository provides a highly efficient GEMV kernel implementation for the BitNet model, optimized for W2A8 inference (2-bit weights and 8-bit activations). It is tailored for use with the [BitNet-b1.58-2B-4T](https://arxiv.org/abs/2504.12285) model.
## Features
- Support for W2A8 (2-bit weight × 8-bit activation) GEMV computation
- Custom CUDA kernels with low-latency execution
- Optimizations for memory access, decoding, and compute throughput
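
As a point of reference for what a W2A8 GEMV computes, here is a minimal pure-Python sketch. It assumes ternary weights in {-1, 0, 1} (as in BitNet b1.58) and signed 8-bit activations, and it leaves out any scaling factors. The function name `gemv_w2a8_reference` is hypothetical and is not part of this repository.

```python
def gemv_w2a8_reference(weights, activations):
    """Reference GEMV: integer accumulation of ternary weights times int8 activations.

    weights:     N rows of K values, each assumed to be in {-1, 0, 1}
    activations: K signed 8-bit values in [-128, 127]
    """
    for a in activations:
        assert -128 <= a <= 127, "activation does not fit in int8"
    out = []
    for row in weights:
        assert len(row) == len(activations)
        assert all(w in (-1, 0, 1) for w in row), "weight is not ternary"
        # One output element per weight row: a plain dot product.
        out.append(sum(w * a for w, a in zip(row, activations)))
    return out


weights = [[1, 0, -1, 1], [-1, -1, 0, 1]]   # N = 2, K = 4
activations = [10, -20, 30, 127]
result = gemv_w2a8_reference(weights, activations)
print(result)
```

A sketch like this can serve as a slow correctness oracle when checking a fast kernel's output on small shapes.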
## Usage
Installation and kernel performance tests:
```bash
# (Recommended) Create a new conda environment
conda create --name bitnet-gpu "python<3.13"
conda activate bitnet-gpu
# Install dependencies
pip install -r requirements.txt
# Build the kernel
cd bitnet_kernels
bash compile.sh
cd ..
# Run performance tests
python test.py
```
End-to-end inference:
```bash
# Download and convert the BitNet-b1.58-2B-4T model
mkdir -p checkpoints
huggingface-cli download microsoft/bitnet-b1.58-2B-4T-bf16 --local-dir ./checkpoints/bitnet-b1.58-2B-4T-bf16
python ./convert_safetensors.py --safetensors_file ./checkpoints/bitnet-b1.58-2B-4T-bf16/model.safetensors --output checkpoints/model_state.pt --model_name 2B
python ./convert_checkpoint.py --input ./checkpoints/model_state.pt
rm ./checkpoints/model_state.pt
# Inference
python ./generate.py ./checkpoints/ --interactive --chat_format
```
## Optimizations
### Weight Permutation
The weight matrix is divided into 16×32 blocks to optimize memory access patterns.
Within each block, values are stored contiguously in memory and permuted to facilitate efficient access and processing.
See `convert_checkpoint.py` for details.
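
The block-contiguous part of this layout can be sketched as follows. This is an illustration only: it assumes the 16 refers to rows and the 32 to columns, and it does not reproduce the in-block permutation, which is defined in `convert_checkpoint.py`. The helper name `to_block_contiguous` is hypothetical.

```python
BLOCK_ROWS, BLOCK_COLS = 16, 32  # block shape from the text; row/column roles are assumed


def to_block_contiguous(matrix):
    """Flatten a row-major matrix so that each 16x32 block occupies one contiguous run."""
    n, k = len(matrix), len(matrix[0])
    assert n % BLOCK_ROWS == 0 and k % BLOCK_COLS == 0, "shape must tile evenly"
    flat = []
    for br in range(0, n, BLOCK_ROWS):          # walk blocks top to bottom
        for bc in range(0, k, BLOCK_COLS):      # then left to right
            for r in range(br, br + BLOCK_ROWS):
                # Copy one 32-wide row slice of the current block.
                flat.extend(matrix[r][bc:bc + BLOCK_COLS])
    return flat


# A 32x64 matrix whose entries encode their own row-major index.
matrix = [[r * 64 + c for c in range(64)] for r in range(32)]
flat = to_block_contiguous(matrix)
block_size = BLOCK_ROWS * BLOCK_COLS
print(flat[0], flat[BLOCK_COLS], flat[block_size])
```

After this reordering, all 512 values of a block sit next to each other in memory, so one block can be read with sequential loads.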
### Fast Decoding
Every 16 two-bit values are packed into a single 32-bit integer using the following interleaving pattern:
```
[0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]
```
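This layout is designed to accelerate decoding: values that are consecutive in the original order occupy the same 2-bit position in consecutive bytes, so a single shift and mask extracts 4 values at a time, each into its own `int8` lane.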
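
The packing and the four-at-a-time extraction can be sketched in plain Python. The sketch assumes that slot `p` of the pattern occupies bits `2p` and `2p+1` of the word (least significant slot first), which the text does not state. The function names are hypothetical, and the mapping from 2-bit codes to signed weight values is left out.

```python
PATTERN = [0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15]


def pack16(codes):
    """Pack 16 two-bit codes (0..3) into one 32-bit word; slot p holds codes[PATTERN[p]]."""
    assert len(codes) == 16 and all(0 <= c <= 3 for c in codes)
    word = 0
    for slot, src in enumerate(PATTERN):
        word |= codes[src] << (2 * slot)
    return word


def unpack16(word):
    """Recover the codes four at a time: one shift and one mask per group."""
    out = []
    for group in range(4):
        # After the shift, the low 2 bits of each byte hold codes 4*group .. 4*group+3.
        lanes = (word >> (2 * group)) & 0x03030303
        # Split the four byte lanes; a GPU kernel would keep them packed instead.
        out.extend((lanes >> (8 * byte)) & 0xFF for byte in range(4))
    return out


codes = [i % 3 for i in range(16)]
word = pack16(codes)
print(hex(word), unpack16(word) == codes)
```

Under that bit-order assumption, each masked word already has the shape of four 8-bit lanes, which is the operand format a packed 8-bit dot product expects.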
### `dp4a` Instruction
We use the `dp4a` instruction to accelerate low-precision dot product operations.
This instruction performs a dot product between two 4-element vectors (each stored in a 32-bit word as 8-bit integers) and accumulates the result into a 32-bit integer.
It significantly improves GEMV throughput when processing quantized weights and activations.
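
The semantics described above can be emulated in a few lines of Python. This is a sketch of the signed variant only; it does not model 32-bit overflow of the accumulator, and the helper names are hypothetical. In CUDA C++ the instruction is available through the `__dp4a` intrinsic; how the kernel in this repository invokes it is not shown here.

```python
import struct


def pack_int8x4(values):
    """Pack four signed 8-bit values into one 32-bit word (lane 0 in the low byte)."""
    return struct.unpack("<I", struct.pack("<4b", *values))[0]


def dp4a(a_word, b_word, acc):
    """Emulate a signed dp4a: multiply four int8 lanes pairwise and add to acc.

    Python integers do not wrap, so 32-bit accumulator overflow is not modeled.
    """
    a = struct.unpack("<4b", struct.pack("<I", a_word & 0xFFFFFFFF))
    b = struct.unpack("<4b", struct.pack("<I", b_word & 0xFFFFFFFF))
    return acc + sum(x * y for x, y in zip(a, b))


weights = pack_int8x4([1, -1, 0, 1])        # four decoded weights
acts = pack_int8x4([100, 50, -7, -128])     # four int8 activations
result = dp4a(weights, acts, 10)            # 100 - 50 + 0 - 128 + 10
print(result)
```

One such call replaces four multiplies and four adds, which is why the decoded weights are kept in packed `int8` form.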
## Performance
Kernel performance (tested on NVIDIA A100 40GB GPU):
| Shape (N×K) | W2A8 Latency (µs) | BF16 Latency (µs) | Speedup (BF16 / W2A8) |
|---------------------|-------------------|-------------------|----------------------|
| 2560 × 2560 | 13.32 | 18.32 | 1.38 |
| 3840 × 2560 | 14.90 | 18.87 | 1.27 |
| 13824 × 2560 | 18.75 | 59.51 | 3.17 |
| 2560 × 6912 | 14.49 | 37.78 | 2.61 |
| 3200 × 3200 | 14.61 | 19.08 | 1.31 |
| 4800 × 3200 | 13.09 | 21.84 | 1.67 |
| 3200 × 10240 | 19.64 | 60.79 | 3.10 |
| 20480 × 3200 | 30.99 | 112.39 | 3.63 |
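
The speedup column appears to be the BF16 latency divided by the W2A8 latency, rounded to two decimals. The short check below recomputes it for three rows copied from the table. The helper name `speedup` is hypothetical.

```python
def speedup(w2a8_us, bf16_us):
    """Speedup of W2A8 over BF16: the ratio of the two latencies, rounded to two decimals."""
    return round(bf16_us / w2a8_us, 2)


# (shape, W2A8 latency in us, BF16 latency in us, reported speedup), copied from the table
rows = [
    ("2560 x 2560", 13.32, 18.32, 1.38),
    ("13824 x 2560", 18.75, 59.51, 3.17),
    ("20480 x 3200", 30.99, 112.39, 3.63),
]

for shape, w2a8_us, bf16_us, reported in rows:
    print(shape, speedup(w2a8_us, bf16_us), reported)
```
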
Generation throughput:
| BF16 (tokens/s) | W2A8 (tokens/s) | Speedup Ratio |
|---|---|---|
| 10.9 | 213.3 | 19.6 |