# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. from dataclasses import dataclass from typing import Optional, Tuple, Union import torch from torch import nn from torch.nn import functional as F from xformers.ops import RMSNorm, fmha, rope_padded from xformers.ops.fmha.attn_bias import ( BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias, ) import ctypes bitnet_lib = ctypes.CDLL('bitnet_kernels/libbitnet.so') def bitnet_int8xint2_linear(input0, input1, s, ws): out_shape = list(input0.shape) out_shape[-1] = input1.shape[0] stream = torch.cuda.current_stream() M = input0.shape[0] if len(out_shape) == 3: M *= input0.shape[1] N = input1.shape[0] K = input1.shape[1] * 4 ret = torch.zeros(*out_shape, dtype=torch.bfloat16, device=input0.device) bitnet_lib.bitlinear_int8xint2(*[ctypes.c_void_p(input0.data_ptr()), ctypes.c_void_p(input1.data_ptr()), ctypes.c_void_p(ret.data_ptr()), ctypes.c_void_p(s.data_ptr()), ctypes.c_void_p(ws.data_ptr()), ctypes.c_int(M), ctypes.c_int(N), ctypes.c_int(K), ctypes.c_void_p(stream.cuda_stream)]) return ret @dataclass class ModelArgs: dim: int = 2560 n_layers: int = 30 n_heads: int = 20 n_kv_heads: int = 5 vocab_size: int = 128256 ffn_dim: int = 6912 norm_eps: float = 1e-5 rope_theta: float = 500000.0 use_kernel: bool = False LayerCache = Tuple[torch.Tensor, torch.Tensor] class BitLinearKernel(nn.Module): in_features: int out_features: int weight: torch.Tensor weight_scale: torch.Tensor def __init__(self, in_features: int, out_features: int, bias: bool = False): super().__init__() self.in_features = in_features self.out_features = out_features self.weight = torch.nn.Parameter(torch.zeros(out_features, in_features//4, dtype=torch.int8), requires_grad=False) self.weight_scale = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16), requires_grad=False) @torch.compile def quant_input(self, input): s = 127 / input.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5) return (input * s).round().clamp(-128, 127).to(torch.int8), s def forward(self, input): input, s = self.quant_input(input) return bitnet_int8xint2_linear(input, self.weight, s, self.weight_scale) class BitLinear(nn.Linear): @torch.compile def quant_input(self, input): s = 127 / input.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5) return (input * s).round().clamp(-128, 127) / s def forward(self, input): input = self.quant_input(input) return F.linear(input, self.weight) class Attention(nn.Module): def __init__( self, dim: int, head_dim: int, n_heads: int, n_kv_heads: int, rope_theta: float, norm_eps: float, use_kernel: bool, ): super().__init__() self.head_dim = head_dim self.rope_theta = rope_theta self.n_local_heads = n_heads self.n_local_kv_heads = n_kv_heads Linear = BitLinearKernel if use_kernel else BitLinear self.wqkv = Linear( dim, (self.n_local_heads + 2 * self.n_local_kv_heads) * head_dim, bias=False, ) self.wo = Linear( self.n_local_heads * head_dim, dim, bias=False, ) self.attn_sub_norm = RMSNorm(dim, norm_eps) def forward( self, x: torch.Tensor, cache: LayerCache, attn_bias: AttnBias, ) -> torch.Tensor: xqkv = self.wqkv(x) xq = xqkv[:, : (self.n_local_heads * self.head_dim)] xkv = xqkv[:, (self.n_local_heads * self.head_dim) :] xk, xv = xkv.chunk(2, 1) output_shape = xq.shape heads_per_group = self.n_local_heads // self.n_local_kv_heads xq = xq.view( 1, xq.shape[0], self.n_local_kv_heads, heads_per_group, self.head_dim ) xk = xk.view(1, xk.shape[0], self.n_local_kv_heads, 1, self.head_dim) # xq = rearrange(xq, 'b (g h l d) -> 1 b h g (d l)', g=heads_per_group, h=self.n_local_kv_heads, d=self.head_dim // 2, l=2) # xk = rearrange(xk, 'b (g l d) -> 1 b g 1 (d l)', g=self.n_local_kv_heads, d=self.head_dim // 2) xv = xv.view(1, xv.shape[0], self.n_local_kv_heads, 1, self.head_dim) cache_k, cache_v = cache xq = rope_padded( xq=xq, xk=xk, xv=xv, cache_k=cache_k, cache_v=cache_v, attn_bias=attn_bias, theta=self.rope_theta, ) output = fmha.memory_efficient_attention_forward( xq, cache_k, cache_v, attn_bias, op = fmha.flash.FwOp ) output = output.reshape(output_shape) output = self.attn_sub_norm(output) output = self.wo(output) return output @torch.compile def squared_relu(x: torch.Tensor) -> torch.Tensor: return F.relu(x) ** 2 class FeedForward(nn.Module): def __init__( self, dim: int, hidden_dim: int, norm_eps: float, use_kernel: bool, ): super().__init__() Linear = BitLinearKernel if use_kernel else BitLinear self.w13 = Linear( dim, 2 * hidden_dim, bias=False, ) self.w2 = Linear( hidden_dim, dim, bias=False, ) self.ffn_sub_norm = RMSNorm(hidden_dim, norm_eps) def forward(self, x: torch.Tensor) -> torch.Tensor: x13 = self.w13(x) x1, x3 = x13.chunk(2, -1) inner = self.ffn_sub_norm(squared_relu(x1) * x3) output = self.w2(inner) return output class TransformerBlock(nn.Module): def __init__(self, args: ModelArgs): super().__init__() assert args.dim % args.n_heads == 0 head_dim = args.dim // args.n_heads if args.n_kv_heads is not None: n_kv_heads = args.n_kv_heads else: n_kv_heads = args.n_heads assert args.n_heads % n_kv_heads == 0 self.attention = Attention( dim=args.dim, head_dim=head_dim, n_heads=args.n_heads, n_kv_heads=n_kv_heads, rope_theta=args.rope_theta, norm_eps=args.norm_eps, use_kernel=args.use_kernel, ) self.feed_forward = FeedForward( dim=args.dim, hidden_dim=args.ffn_dim, norm_eps=args.norm_eps, use_kernel=args.use_kernel, ) self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) def forward( self, x: torch.Tensor, cache: LayerCache, attn_bias: AttnBias, ) -> torch.Tensor: h = x + self.attention.forward( self.attention_norm(x), cache, attn_bias, ) out = h + self.feed_forward(self.ffn_norm(h)) return out class Transformer(nn.Module): def __init__(self, args: ModelArgs): super().__init__() assert args.vocab_size > 0 self.tok_embeddings = nn.Embedding( num_embeddings=args.vocab_size, embedding_dim=args.dim, ) self.layers = nn.ModuleList() for _ in range(args.n_layers): self.layers.append(TransformerBlock(args)) self.norm = RMSNorm(args.dim, eps=args.norm_eps) self.output = nn.Linear( args.dim, args.vocab_size, bias=False, ) @torch.no_grad() def forward_with_attn_bias( self, token_values: torch.Tensor, attn_bias: AttnBias, cache: list[LayerCache], ) -> torch.Tensor: h = self.tok_embeddings(token_values) for i, layer in enumerate(self.layers): h = layer(h, cache[i], attn_bias) logits = self.output(self.norm(h)) return logits.float() def forward( self, token_values: torch.Tensor, token_lengths: torch.Tensor, start_pos: torch.Tensor, cache: list[LayerCache], kv_padding: int, ) -> torch.Tensor: attn_bias = AttnBias.from_seqlens( q_seqlen=token_lengths.tolist(), kv_seqlen=(start_pos + token_lengths).tolist(), kv_padding=kv_padding, ) return self.forward_with_attn_bias(token_values, attn_bias, cache) def make_cache( args: ModelArgs, length: int, device: Optional[Union[str, torch.device]] = None, n_layers: Optional[int] = None, dtype: Optional[torch.dtype] = None, ) -> list[LayerCache]: """ Allocate a cache to be used with the Transformer module. Args: args (ModelArgs): the model configuration. length (int): per layer cache size. It is usually budgeted as ``max_batch * max_seq`` device (torch.device, optional): the device on which the cache should be allocated. n_layers (int, optional): the number of layers to allocate a cache for (defaults to the model settings). dtype (torch.dtype, optional): the dtype to use for cache entries (defaults to the default dtype). Returns: The cache object to pass to ``Tranformer.forward``. """ head_dim = args.dim // args.n_heads n_kv_heads = args.n_kv_heads if n_kv_heads is None: n_kv_heads = args.n_heads n_local_kv_heads = n_kv_heads if n_layers is None: n_layers = args.n_layers shape = (1, length, n_local_kv_heads, 1, head_dim) heads_per_group = args.n_heads // n_kv_heads expansion = (-1, -1, -1, heads_per_group, -1) return [ ( torch.zeros(shape, device=device, dtype=dtype).expand(expansion), torch.zeros(shape, device=device, dtype=dtype).expand(expansion), ) for _ in range(n_layers) ] def cache_prefix(cache: list[LayerCache], length: int) -> list[LayerCache]: """ Take a prefix view of a larger cache. The original cache object remains of identical size and valid after the shrinked alias has been used. This function is useful when a cache was allocated for a larger batch size than what is necessary. Args: cache: the cache to take a view in. length (int): the desired length Returns: A view in the input cache object. """ if len(cache) > 0: assert cache[0][0].shape[1] >= length return [(ck[:, :length], cv[:, :length]) for ck, cv in cache]