SIGN IN SIGN UP
microsoft / BitNet UNCLAIMED

Official inference framework for 1-bit LLMs

36832 0 0 Python
2025-05-15 05:55:42 +00:00
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import functional as F
from xformers.ops import RMSNorm, fmha, rope_padded
from xformers.ops.fmha.attn_bias import (
BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
)
import ctypes
# Load the native BitNet kernel library once, at import time.
# The path contains a slash, so the platform loader treats it as a relative
# pathname: it is resolved against the process's current working directory,
# not against this file's location.
# NOTE(review): importing this module therefore raises OSError unless
# 'bitnet_kernels/libbitnet.so' is reachable from the CWD, even when
# use_kernel=False — confirm that is intended.
bitnet_lib = ctypes.CDLL('bitnet_kernels/libbitnet.so')
def bitnet_int8xint2_linear(input0, input1, s, ws):
    """Multiply int8 activations by packed 2-bit weights with the native kernel.

    Thin ctypes wrapper around ``bitnet_lib.bitlinear_int8xint2``. The kernel is
    launched on the current CUDA stream and writes into a freshly allocated
    bfloat16 output tensor.

    Args:
        input0: int8 activations whose last dimension is the reduction size K.
            Any number of leading dimensions is accepted; they are flattened
            into the kernel's row count M.
        input1: packed int8 weights of shape (N, K // 4). Four 2-bit values
            are stored per byte, which is why K is ``input1.shape[1] * 4``.
        s: activation scale tensor from the caller's quantizer, passed to the
            kernel by pointer.
        ws: weight scale tensor, passed to the kernel by pointer.

    Returns:
        A bfloat16 tensor on ``input0.device`` with shape
        ``input0.shape[:-1] + (N,)``.

    Raises:
        ValueError: if ``input0.shape[-1]`` differs from K.
    """
    n = input1.shape[0]
    k = input1.shape[1] * 4
    if input0.shape[-1] != k:
        raise ValueError(
            f"input0 has {input0.shape[-1]} features but the packed weight "
            f"expects K={k}"
        )
    # Every leading dimension collapses into kernel rows, for any input rank
    # (2-D -> shape[0], 3-D -> shape[0] * shape[1], 1-D -> 1).
    m = input0.shape[:-1].numel()
    ret = torch.zeros(
        *input0.shape[:-1], n, dtype=torch.bfloat16, device=input0.device
    )
    if m == 0:
        # Nothing to compute; do not hand the kernel an empty problem.
        return ret
    # The kernel receives raw pointers, so give it densely packed buffers.
    # contiguous() is a no-op for tensors that already are.
    input0 = input0.contiguous()
    input1 = input1.contiguous()
    s = s.contiguous()
    ws = ws.contiguous()
    stream = torch.cuda.current_stream()
    bitnet_lib.bitlinear_int8xint2(
        ctypes.c_void_p(input0.data_ptr()),
        ctypes.c_void_p(input1.data_ptr()),
        ctypes.c_void_p(ret.data_ptr()),
        ctypes.c_void_p(s.data_ptr()),
        ctypes.c_void_p(ws.data_ptr()),
        ctypes.c_int(m),
        ctypes.c_int(n),
        ctypes.c_int(k),
        ctypes.c_void_p(stream.cuda_stream),
    )
    return ret
@dataclass
class ModelArgs:
    """Hyper-parameters for the Transformer defined in this module.

    The defaults describe one model size; override fields by keyword to build
    other configurations.
    """

    dim: int = 2560  # residual-stream width; head_dim is dim // n_heads
    n_layers: int = 30  # number of TransformerBlocks
    n_heads: int = 20  # number of query heads
    # Key/value heads for grouped-query attention. Consumers in this module
    # treat None as "same as n_heads", so the annotation must admit None.
    n_kv_heads: Optional[int] = 5
    vocab_size: int = 128256  # rows of the token embedding and output head
    ffn_dim: int = 6912  # hidden width of the feed-forward block
    norm_eps: float = 1e-5  # epsilon passed to every RMSNorm
    rope_theta: float = 500000.0  # rotary base, forwarded to rope_padded
    use_kernel: bool = False  # True: BitLinearKernel; False: BitLinear
# Per-layer key/value cache: a (cache_k, cache_v) pair. Attention.forward
# unpacks it and passes both tensors to xformers' rope_padded and fmha.
# make_cache allocates each tensor with shape
# (1, length, n_kv_heads, heads_per_group, head_dim), where the heads_per_group
# axis is an expand()ed (stride-0) view.
LayerCache = Tuple[torch.Tensor, torch.Tensor]
class BitLinearKernel(nn.Module):
    """Bias-free linear layer backed by the native int8 x int2 BitNet kernel.

    The weight is stored pre-packed as int8 with ``in_features // 4`` columns
    (four 2-bit weights per byte). Activations are quantized to int8 on the fly,
    and both are handed to ``bitnet_int8xint2_linear``.
    """

    in_features: int
    out_features: int
    weight: torch.Tensor
    weight_scale: torch.Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = False):
        """Allocate zero-filled packed weights and scales (to be loaded later).

        Args:
            in_features: input width. Must be a multiple of 4 so each packed
                row holds exactly ``in_features`` 2-bit values.
            out_features: output width.
            bias: accepted for signature compatibility with ``nn.Linear``.
                Must be False, because the kernel has no bias term.

        Raises:
            ValueError: if ``bias`` is True or ``in_features`` is not a
                multiple of 4.
        """
        super().__init__()
        if bias:
            raise ValueError("BitLinearKernel does not support bias=True")
        if in_features % 4 != 0:
            raise ValueError(
                f"in_features must be a multiple of 4, got {in_features}"
            )
        self.in_features = in_features
        self.out_features = out_features
        # Integer tensors cannot require gradients, hence requires_grad=False.
        self.weight = torch.nn.Parameter(
            torch.zeros(out_features, in_features // 4, dtype=torch.int8),
            requires_grad=False,
        )
        # NOTE(review): four bfloat16 scale entries. Their per-entry meaning is
        # defined by the native kernel — confirm.
        self.weight_scale = torch.nn.Parameter(
            torch.zeros(4, dtype=torch.bfloat16), requires_grad=False
        )

    @torch.compile
    def quant_input(self, input):
        """Quantize ``input`` to int8 with a per-row absmax scale.

        Returns ``(q, s)``. ``s = 127 / max|input|`` is taken over the last
        dimension (kept as a size-1 axis, with the max floored at 1e-5), and
        ``q = round(input * s)`` is clamped to [-128, 127] and cast to int8.
        """
        s = 127 / input.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
        return (input * s).round().clamp(-128, 127).to(torch.int8), s

    def forward(self, input):
        """Quantize ``input`` and run the packed int8 x int2 matmul."""
        input, s = self.quant_input(input)
        return bitnet_int8xint2_linear(input, self.weight, s, self.weight_scale)
class BitLinear(nn.Linear):
    """``nn.Linear`` that fake-quantizes its input onto the int8 grid first.

    Activations are scaled by a per-row absmax factor, rounded, clamped, and
    scaled back. The matmul then uses ``self.weight`` exactly as stored.
    """

    @torch.compile
    def quant_input(self, input):
        """Return ``input`` after per-row absmax int8 fake quantization.

        ``s = 127 / max|input|`` is taken over the last dimension, with the max
        floored at 1e-5 to avoid division by zero. The result is
        ``clamp(round(input * s), -128, 127) / s``, in the input's dtype.
        """
        s = 127 / input.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
        return (input * s).round().clamp(-128, 127) / s

    def forward(self, input):
        """Fake-quantize ``input``, then apply the linear map (with bias if any)."""
        input = self.quant_input(input)
        # self.bias is None for bias=False layers, so their output is
        # unchanged. For bias=True it is no longer silently dropped.
        return F.linear(input, self.weight, self.bias)
class Attention(nn.Module):
    """Grouped-query self-attention over a packed batch using xformers ops.

    A fused ``wqkv`` projection produces queries, keys and values. The
    ``rope_padded`` op applies rotary embeddings and places the new keys/values
    in the layer cache. Attention runs through xformers' flash forward op, and
    the result goes through ``attn_sub_norm`` and the ``wo`` projection.
    """

    def __init__(
        self,
        dim: int,
        head_dim: int,
        n_heads: int,
        n_kv_heads: int,
        rope_theta: float,
        norm_eps: float,
        use_kernel: bool,
    ):
        """
        Args:
            dim: model width (input size of wqkv, output size of wo).
            head_dim: per-head size.
            n_heads: number of query heads.
            n_kv_heads: number of key/value heads. It must divide n_heads,
                otherwise the grouped .view() in forward cannot succeed.
            rope_theta: rotary embedding base passed to ``rope_padded``.
            norm_eps: epsilon of ``attn_sub_norm``.
            use_kernel: use ``BitLinearKernel`` (True) or ``BitLinear``
                (False) for both projections.
        """
        super().__init__()
        self.head_dim = head_dim
        self.rope_theta = rope_theta
        # No head sharding happens in this module: "local" counts equal the
        # configured counts.
        self.n_local_heads = n_heads
        self.n_local_kv_heads = n_kv_heads
        Linear = BitLinearKernel if use_kernel else BitLinear
        # Fused projection: n_heads query heads, then n_kv_heads key heads,
        # then n_kv_heads value heads (see the slicing in forward).
        self.wqkv = Linear(
            dim,
            (self.n_local_heads + 2 * self.n_local_kv_heads) * head_dim,
            bias=False,
        )
        self.wo = Linear(
            self.n_local_heads * head_dim,
            dim,
            bias=False,
        )
        # NOTE(review): sized by `dim` but applied to the attention output of
        # width n_heads * head_dim. The two agree only when
        # n_heads * head_dim == dim.
        self.attn_sub_norm = RMSNorm(dim, norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        cache: LayerCache,
        attn_bias: AttnBias,
    ) -> torch.Tensor:
        """Attend from the new tokens in ``x`` over this layer's cache.

        Args:
            x: 2-D activations (tokens, dim). The column slicing and the
                ``view(1, xq.shape[0], ...)`` below require a 2-D input.
            cache: (cache_k, cache_v) for this layer. ``rope_padded`` writes
                the new keys/values into it, and the attention call reads
                keys/values only from it.
            attn_bias: xformers mask describing sequence boundaries and the
                padded key/value layout.

        Returns:
            Tensor of shape (tokens, dim): the attention output after
            ``attn_sub_norm`` and ``wo``.
        """
        xqkv = self.wqkv(x)
        # Split the fused projection: queries first, then keys/values as two
        # equal halves of the remainder.
        xq = xqkv[:, : (self.n_local_heads * self.head_dim)]
        xkv = xqkv[:, (self.n_local_heads * self.head_dim) :]
        xk, xv = xkv.chunk(2, 1)
        # (tokens, n_heads * head_dim), used to flatten the attention output.
        output_shape = xq.shape
        heads_per_group = self.n_local_heads // self.n_local_kv_heads
        # 5-D grouped layout (1, tokens, n_kv_heads, heads_per_group, head_dim).
        # Keys/values carry a size-1 group axis.
        xq = xq.view(
            1, xq.shape[0], self.n_local_kv_heads, heads_per_group, self.head_dim
        )
        xk = xk.view(1, xk.shape[0], self.n_local_kv_heads, 1, self.head_dim)
        # NOTE(review): the two commented-out rearranges below are dead code.
        # They describe an alternative (pair-interleaved) q/k element order and
        # are not executed.
        # xq = rearrange(xq, 'b (g h l d) -> 1 b h g (d l)', g=heads_per_group, h=self.n_local_kv_heads, d=self.head_dim // 2, l=2)
        # xk = rearrange(xk, 'b (g l d) -> 1 b g 1 (d l)', g=self.n_local_kv_heads, d=self.head_dim // 2)
        xv = xv.view(1, xv.shape[0], self.n_local_kv_heads, 1, self.head_dim)
        cache_k, cache_v = cache
        # Rotates xq/xk and emplaces the new keys and values into
        # cache_k/cache_v. It returns the rotated queries, and the attention
        # below reads keys/values from the cache, not from xk/xv.
        xq = rope_padded(
            xq=xq,
            xk=xk,
            xv=xv,
            cache_k=cache_k,
            cache_v=cache_v,
            attn_bias=attn_bias,
            theta=self.rope_theta,
        )
        output = fmha.memory_efficient_attention_forward(
            xq, cache_k, cache_v, attn_bias, op = fmha.flash.FwOp
        )
        output = output.reshape(output_shape)
        output = self.attn_sub_norm(output)
        output = self.wo(output)
        return output
@torch.compile
def squared_relu(x: torch.Tensor) -> torch.Tensor:
    """Elementwise ``relu(x) ** 2``; the gate activation of FeedForward."""
    positive_part = F.relu(x)
    return torch.square(positive_part)
class FeedForward(nn.Module):
    """Gated feed-forward block: ``w2(sub_norm(squared_relu(gate) * up))``.

    ``w13`` is one fused projection. Its output is split in half along the last
    dimension: the first half is the activated gate branch, the second half the
    linear branch.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        norm_eps: float,
        use_kernel: bool,
    ):
        super().__init__()
        layer_cls = BitLinear if not use_kernel else BitLinearKernel
        # Registration order (w13, w2, ffn_sub_norm) fixes the state_dict
        # key order; keep it stable.
        self.w13 = layer_cls(dim, hidden_dim * 2, bias=False)
        self.w2 = layer_cls(hidden_dim, dim, bias=False)
        self.ffn_sub_norm = RMSNorm(hidden_dim, norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.w13(x).chunk(2, -1)
        hidden = squared_relu(gate) * up
        return self.w2(self.ffn_sub_norm(hidden))
class TransformerBlock(nn.Module):
    """Pre-norm decoder layer.

    Computes ``h = x + attn(norm(x))`` and then ``h + ffn(norm(h))``.
    """

    def __init__(self, args: ModelArgs):
        """Build the attention and feed-forward sublayers from ``args``.

        Raises:
            ValueError: if ``n_heads`` is not positive or does not divide
                ``dim``, or if the effective ``n_kv_heads`` is not positive or
                does not divide ``n_heads``. (Explicit errors, unlike
                ``assert``, survive ``python -O``.)
        """
        super().__init__()
        if args.n_heads <= 0 or args.dim % args.n_heads != 0:
            raise ValueError(
                f"dim ({args.dim}) must be a multiple of a positive n_heads "
                f"({args.n_heads})"
            )
        head_dim = args.dim // args.n_heads
        # n_kv_heads=None disables grouping: one key/value head per query head.
        n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        if n_kv_heads <= 0 or args.n_heads % n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({args.n_heads}) must be a multiple of a positive "
                f"n_kv_heads ({n_kv_heads})"
            )
        self.attention = Attention(
            dim=args.dim,
            head_dim=head_dim,
            n_heads=args.n_heads,
            n_kv_heads=n_kv_heads,
            rope_theta=args.rope_theta,
            norm_eps=args.norm_eps,
            use_kernel=args.use_kernel,
        )
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=args.ffn_dim,
            norm_eps=args.norm_eps,
            use_kernel=args.use_kernel,
        )
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        cache: LayerCache,
        attn_bias: AttnBias,
    ) -> torch.Tensor:
        """Apply the layer to ``x`` using this layer's key/value ``cache``."""
        # Call the submodule itself (as feed_forward is called below) rather
        # than its .forward(), so hooks registered on it are not bypassed.
        h = x + self.attention(self.attention_norm(x), cache, attn_bias)
        return h + self.feed_forward(self.ffn_norm(h))
class Transformer(nn.Module):
    """BitNet decoder.

    Consists of a token embedding, ``n_layers`` TransformerBlocks, a final
    RMSNorm, and an unquantized ``nn.Linear`` output head.
    """

    def __init__(self, args: ModelArgs):
        """Build all layers from ``args``.

        Raises:
            ValueError: if ``args.vocab_size`` is not positive.
        """
        super().__init__()
        if args.vocab_size <= 0:
            raise ValueError(f"vocab_size must be positive, got {args.vocab_size}")
        self.tok_embeddings = nn.Embedding(
            num_embeddings=args.vocab_size,
            embedding_dim=args.dim,
        )
        self.layers = nn.ModuleList(
            TransformerBlock(args) for _ in range(args.n_layers)
        )
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        # The LM head is a plain nn.Linear, not a BitLinear layer.
        self.output = nn.Linear(
            args.dim,
            args.vocab_size,
            bias=False,
        )

    @torch.no_grad()
    def forward_with_attn_bias(
        self,
        token_values: torch.Tensor,
        attn_bias: AttnBias,
        cache: list[LayerCache],
    ) -> torch.Tensor:
        """Run the model with a prebuilt attention mask (no autograd).

        Args:
            token_values: token ids to embed.
            attn_bias: xformers mask shared by every layer.
            cache: one (cache_k, cache_v) pair per layer. It is indexed by
                layer position, so it needs at least ``len(self.layers)``
                entries.

        Returns:
            float32 logits of shape ``token_values.shape + (vocab_size,)``.
        """
        h = self.tok_embeddings(token_values)
        for i, layer in enumerate(self.layers):
            h = layer(h, cache[i], attn_bias)
        logits = self.output(self.norm(h))
        return logits.float()

    def forward(
        self,
        token_values: torch.Tensor,
        token_lengths: torch.Tensor,
        start_pos: torch.Tensor,
        cache: list[LayerCache],
        kv_padding: int,
    ) -> torch.Tensor:
        """Build the block-diagonal causal mask from lengths and run the model.

        Args:
            token_values: token ids. Presumably the new tokens of every
                sequence concatenated in order — confirm with the caller.
            token_lengths: number of new (query) tokens per sequence.
            start_pos: number of tokens already cached per sequence. Key
                lengths are ``start_pos + token_lengths``.
            cache: per-layer caches, e.g. from ``make_cache``.
            kv_padding: padded per-sequence key/value length, forwarded to
                ``AttnBias.from_seqlens``.

        Returns:
            float32 logits, as from ``forward_with_attn_bias``.
        """
        # .tolist() copies the lengths to host memory to build the mask.
        attn_bias = AttnBias.from_seqlens(
            q_seqlen=token_lengths.tolist(),
            kv_seqlen=(start_pos + token_lengths).tolist(),
            kv_padding=kv_padding,
        )
        return self.forward_with_attn_bias(token_values, attn_bias, cache)
def make_cache(
    args: ModelArgs,
    length: int,
    device: Optional[Union[str, torch.device]] = None,
    n_layers: Optional[int] = None,
    dtype: Optional[torch.dtype] = None,
) -> list[LayerCache]:
    """
    Allocate a cache to be used with the Transformer module.

    Args:
        args (ModelArgs): the model configuration.
        length (int): per layer cache size.
            It is usually budgeted as ``max_batch * max_seq``
        device (torch.device, optional): the device on which
            the cache should be allocated.
        n_layers (int, optional): the number of layers to
            allocate a cache for (defaults to the model
            settings).
        dtype (torch.dtype, optional): the dtype to use for
            cache entries (defaults to the default dtype).

    Returns:
        The cache object to pass to ``Transformer.forward``: one
        ``(cache_k, cache_v)`` pair per layer. Each tensor is zero-filled with
        shape ``(1, length, n_kv_heads, heads_per_group, head_dim)``. The
        ``heads_per_group`` axis is an ``expand``ed view (stride 0), so the
        query heads of one group share a single stored key/value head.
    """
    head_dim = args.dim // args.n_heads
    # n_kv_heads=None means one key/value head per query head.
    n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
    if n_layers is None:
        n_layers = args.n_layers
    heads_per_group = args.n_heads // n_kv_heads
    shape = (1, length, n_kv_heads, 1, head_dim)
    expansion = (-1, -1, -1, heads_per_group, -1)

    def _alloc() -> torch.Tensor:
        # Materialize only the size-1 group axis; expand() broadcasts it.
        return torch.zeros(shape, device=device, dtype=dtype).expand(expansion)

    return [(_alloc(), _alloc()) for _ in range(n_layers)]
def cache_prefix(cache: list[LayerCache], length: int) -> list[LayerCache]:
    """
    Take a prefix view of a larger cache.

    The original cache object keeps its full size and remains valid after the
    shrunk alias has been used; both share storage, because slicing creates
    views. This function is useful when a cache was allocated for a larger
    batch size than what is necessary.

    Args:
        cache: the cache to take a view in.
        length (int): the desired length along the sequence axis (dim 1).

    Returns:
        A view in the input cache object.

    Raises:
        ValueError: if ``length`` exceeds the cache's sequence capacity.
    """
    # Raise explicitly rather than assert, so the check survives ``python -O``.
    if cache and cache[0][0].shape[1] < length:
        raise ValueError(
            f"requested prefix length {length} exceeds cache capacity "
            f"{cache[0][0].shape[1]}"
        )
    return [(ck[:, :length], cv[:, :length]) for ck, cv in cache]