# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
|
||
|
|
from typing import Optional, Tuple, Union
|
||
|
|
|
||
|
|
import torch
|
||
|
|
from torch import nn
|
||
|
|
from torch.nn import functional as F
|
||
|
|
|
||
|
|
from xformers.ops import RMSNorm, fmha, rope_padded
|
||
|
|
from xformers.ops.fmha.attn_bias import (
|
||
|
|
BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
|
||
|
|
)
|
||
|
|
|
||
|
|
import ctypes
|
||
|
|
bitnet_lib = ctypes.CDLL('bitnet_kernels/libbitnet.so')
|
||
|
|
|
||
|
|
def bitnet_int8xint2_linear(input0, input1, s, ws):
|
||
|
|
out_shape = list(input0.shape)
|
||
|
|
out_shape[-1] = input1.shape[0]
|
||
|
|
|
||
|
|
stream = torch.cuda.current_stream()
|
||
|
|
|
||
|
|
M = input0.shape[0]
|
||
|
|
if len(out_shape) == 3:
|
||
|
|
M *= input0.shape[1]
|
||
|
|
N = input1.shape[0]
|
||
|
|
K = input1.shape[1] * 4
|
||
|
|
|
||
|
|
ret = torch.zeros(*out_shape, dtype=torch.bfloat16, device=input0.device)
|
||
|
|
|
||
|
|
bitnet_lib.bitlinear_int8xint2(*[ctypes.c_void_p(input0.data_ptr()), ctypes.c_void_p(input1.data_ptr()), ctypes.c_void_p(ret.data_ptr()), ctypes.c_void_p(s.data_ptr()), ctypes.c_void_p(ws.data_ptr()), ctypes.c_int(M), ctypes.c_int(N), ctypes.c_int(K), ctypes.c_void_p(stream.cuda_stream)])
|
||
|
|
|
||
|
|
return ret
|
||
|
|
|
||
|
|
@dataclass
|
||
|
|
class ModelArgs:
|
||
|
|
dim: int = 2560
|
||
|
|
n_layers: int = 30
|
||
|
|
n_heads: int = 20
|
||
|
|
n_kv_heads: int = 5
|
||
|
|
vocab_size: int = 128256
|
||
|
|
ffn_dim: int = 6912
|
||
|
|
norm_eps: float = 1e-5
|
||
|
|
rope_theta: float = 500000.0
|
||
|
|
use_kernel: bool = False
|
||
|
|
|
||
|
|
|
||
|
|
LayerCache = Tuple[torch.Tensor, torch.Tensor]
|
||
|
|
|
||
|
|
class BitLinearKernel(nn.Module):
|
||
|
|
in_features: int
|
||
|
|
out_features: int
|
||
|
|
weight: torch.Tensor
|
||
|
|
weight_scale: torch.Tensor
|
||
|
|
|
||
|
|
def __init__(self, in_features: int, out_features: int, bias: bool = False):
|
||
|
|
super().__init__()
|
||
|
|
self.in_features = in_features
|
||
|
|
self.out_features = out_features
|
||
|
|
|
||
|
|
self.weight = torch.nn.Parameter(torch.zeros(out_features, in_features//4, dtype=torch.int8), requires_grad=False)
|
||
|
|
self.weight_scale = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16), requires_grad=False)
|
||
|
|
|
||
|
|
@torch.compile
|
||
|
|
def quant_input(self, input):
|
||
|
|
s = 127 / input.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
|
||
|
|
return (input * s).round().clamp(-128, 127).to(torch.int8), s
|
||
|
|
|
||
|
|
def forward(self, input):
|
||
|
|
input, s = self.quant_input(input)
|
||
|
|
return bitnet_int8xint2_linear(input, self.weight, s, self.weight_scale)
|
||
|
|
|
||
|
|
class BitLinear(nn.Linear):
|
||
|
|
@torch.compile
|
||
|
|
def quant_input(self, input):
|
||
|
|
s = 127 / input.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
|
||
|
|
return (input * s).round().clamp(-128, 127) / s
|
||
|
|
|
||
|
|
def forward(self, input):
|
||
|
|
input = self.quant_input(input)
|
||
|
|
return F.linear(input, self.weight)
|
||
|
|
|
||
|
|
class Attention(nn.Module):
|
||
|
|
def __init__(
|
||
|
|
self,
|
||
|
|
dim: int,
|
||
|
|
head_dim: int,
|
||
|
|
n_heads: int,
|
||
|
|
n_kv_heads: int,
|
||
|
|
rope_theta: float,
|
||
|
|
norm_eps: float,
|
||
|
|
use_kernel: bool,
|
||
|
|
):
|
||
|
|
super().__init__()
|
||
|
|
|
||
|
|
self.head_dim = head_dim
|
||
|
|
self.rope_theta = rope_theta
|
||
|
|
|
||
|
|
self.n_local_heads = n_heads
|
||
|
|
self.n_local_kv_heads = n_kv_heads
|
||
|
|
|
||
|
|
Linear = BitLinearKernel if use_kernel else BitLinear
|
||
|
|
|
||
|
|
self.wqkv = Linear(
|
||
|
|
dim,
|
||
|
|
(self.n_local_heads + 2 * self.n_local_kv_heads) * head_dim,
|
||
|
|
bias=False,
|
||
|
|
)
|
||
|
|
self.wo = Linear(
|
||
|
|
self.n_local_heads * head_dim,
|
||
|
|
dim,
|
||
|
|
bias=False,
|
||
|
|
)
|
||
|
|
|
||
|
|
self.attn_sub_norm = RMSNorm(dim, norm_eps)
|
||
|
|
|
||
|
|
def forward(
|
||
|
|
self,
|
||
|
|
x: torch.Tensor,
|
||
|
|
cache: LayerCache,
|
||
|
|
attn_bias: AttnBias,
|
||
|
|
) -> torch.Tensor:
|
||
|
|
|
||
|
|
xqkv = self.wqkv(x)
|
||
|
|
xq = xqkv[:, : (self.n_local_heads * self.head_dim)]
|
||
|
|
xkv = xqkv[:, (self.n_local_heads * self.head_dim) :]
|
||
|
|
xk, xv = xkv.chunk(2, 1)
|
||
|
|
|
||
|
|
output_shape = xq.shape
|
||
|
|
heads_per_group = self.n_local_heads // self.n_local_kv_heads
|
||
|
|
xq = xq.view(
|
||
|
|
1, xq.shape[0], self.n_local_kv_heads, heads_per_group, self.head_dim
|
||
|
|
)
|
||
|
|
xk = xk.view(1, xk.shape[0], self.n_local_kv_heads, 1, self.head_dim)
|
||
|
|
# xq = rearrange(xq, 'b (g h l d) -> 1 b h g (d l)', g=heads_per_group, h=self.n_local_kv_heads, d=self.head_dim // 2, l=2)
|
||
|
|
# xk = rearrange(xk, 'b (g l d) -> 1 b g 1 (d l)', g=self.n_local_kv_heads, d=self.head_dim // 2)
|
||
|
|
xv = xv.view(1, xv.shape[0], self.n_local_kv_heads, 1, self.head_dim)
|
||
|
|
cache_k, cache_v = cache
|
||
|
|
|
||
|
|
xq = rope_padded(
|
||
|
|
xq=xq,
|
||
|
|
xk=xk,
|
||
|
|
xv=xv,
|
||
|
|
cache_k=cache_k,
|
||
|
|
cache_v=cache_v,
|
||
|
|
attn_bias=attn_bias,
|
||
|
|
theta=self.rope_theta,
|
||
|
|
)
|
||
|
|
|
||
|
|
output = fmha.memory_efficient_attention_forward(
|
||
|
|
xq, cache_k, cache_v, attn_bias, op = fmha.flash.FwOp
|
||
|
|
)
|
||
|
|
|
||
|
|
output = output.reshape(output_shape)
|
||
|
|
output = self.attn_sub_norm(output)
|
||
|
|
output = self.wo(output)
|
||
|
|
|
||
|
|
return output
|
||
|
|
|
||
|
|
@torch.compile
|
||
|
|
def squared_relu(x: torch.Tensor) -> torch.Tensor:
|
||
|
|
return F.relu(x) ** 2
|
||
|
|
|
||
|
|
class FeedForward(nn.Module):
|
||
|
|
def __init__(
|
||
|
|
self,
|
||
|
|
dim: int,
|
||
|
|
hidden_dim: int,
|
||
|
|
norm_eps: float,
|
||
|
|
use_kernel: bool,
|
||
|
|
):
|
||
|
|
super().__init__()
|
||
|
|
|
||
|
|
Linear = BitLinearKernel if use_kernel else BitLinear
|
||
|
|
|
||
|
|
self.w13 = Linear(
|
||
|
|
dim,
|
||
|
|
2 * hidden_dim,
|
||
|
|
bias=False,
|
||
|
|
)
|
||
|
|
self.w2 = Linear(
|
||
|
|
hidden_dim,
|
||
|
|
dim,
|
||
|
|
bias=False,
|
||
|
|
)
|
||
|
|
self.ffn_sub_norm = RMSNorm(hidden_dim, norm_eps)
|
||
|
|
|
||
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||
|
|
x13 = self.w13(x)
|
||
|
|
x1, x3 = x13.chunk(2, -1)
|
||
|
|
inner = self.ffn_sub_norm(squared_relu(x1) * x3)
|
||
|
|
output = self.w2(inner)
|
||
|
|
return output
|
||
|
|
|
||
|
|
|
||
|
|
class TransformerBlock(nn.Module):
|
||
|
|
def __init__(self, args: ModelArgs):
|
||
|
|
super().__init__()
|
||
|
|
|
||
|
|
assert args.dim % args.n_heads == 0
|
||
|
|
head_dim = args.dim // args.n_heads
|
||
|
|
if args.n_kv_heads is not None:
|
||
|
|
n_kv_heads = args.n_kv_heads
|
||
|
|
else:
|
||
|
|
n_kv_heads = args.n_heads
|
||
|
|
|
||
|
|
assert args.n_heads % n_kv_heads == 0
|
||
|
|
|
||
|
|
self.attention = Attention(
|
||
|
|
dim=args.dim,
|
||
|
|
head_dim=head_dim,
|
||
|
|
n_heads=args.n_heads,
|
||
|
|
n_kv_heads=n_kv_heads,
|
||
|
|
rope_theta=args.rope_theta,
|
||
|
|
norm_eps=args.norm_eps,
|
||
|
|
use_kernel=args.use_kernel,
|
||
|
|
)
|
||
|
|
self.feed_forward = FeedForward(
|
||
|
|
dim=args.dim,
|
||
|
|
hidden_dim=args.ffn_dim,
|
||
|
|
norm_eps=args.norm_eps,
|
||
|
|
use_kernel=args.use_kernel,
|
||
|
|
)
|
||
|
|
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||
|
|
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||
|
|
|
||
|
|
def forward(
|
||
|
|
self,
|
||
|
|
x: torch.Tensor,
|
||
|
|
cache: LayerCache,
|
||
|
|
attn_bias: AttnBias,
|
||
|
|
) -> torch.Tensor:
|
||
|
|
h = x + self.attention.forward(
|
||
|
|
self.attention_norm(x),
|
||
|
|
cache,
|
||
|
|
attn_bias,
|
||
|
|
)
|
||
|
|
out = h + self.feed_forward(self.ffn_norm(h))
|
||
|
|
return out
|
||
|
|
|
||
|
|
|
||
|
|
class Transformer(nn.Module):
|
||
|
|
def __init__(self, args: ModelArgs):
|
||
|
|
super().__init__()
|
||
|
|
assert args.vocab_size > 0
|
||
|
|
|
||
|
|
self.tok_embeddings = nn.Embedding(
|
||
|
|
num_embeddings=args.vocab_size,
|
||
|
|
embedding_dim=args.dim,
|
||
|
|
)
|
||
|
|
|
||
|
|
self.layers = nn.ModuleList()
|
||
|
|
for _ in range(args.n_layers):
|
||
|
|
self.layers.append(TransformerBlock(args))
|
||
|
|
|
||
|
|
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
||
|
|
|
||
|
|
self.output = nn.Linear(
|
||
|
|
args.dim,
|
||
|
|
args.vocab_size,
|
||
|
|
bias=False,
|
||
|
|
)
|
||
|
|
|
||
|
|
@torch.no_grad()
|
||
|
|
def forward_with_attn_bias(
|
||
|
|
self,
|
||
|
|
token_values: torch.Tensor,
|
||
|
|
attn_bias: AttnBias,
|
||
|
|
cache: list[LayerCache],
|
||
|
|
) -> torch.Tensor:
|
||
|
|
h = self.tok_embeddings(token_values)
|
||
|
|
|
||
|
|
for i, layer in enumerate(self.layers):
|
||
|
|
h = layer(h, cache[i], attn_bias)
|
||
|
|
|
||
|
|
logits = self.output(self.norm(h))
|
||
|
|
return logits.float()
|
||
|
|
|
||
|
|
def forward(
|
||
|
|
self,
|
||
|
|
token_values: torch.Tensor,
|
||
|
|
token_lengths: torch.Tensor,
|
||
|
|
start_pos: torch.Tensor,
|
||
|
|
cache: list[LayerCache],
|
||
|
|
kv_padding: int,
|
||
|
|
) -> torch.Tensor:
|
||
|
|
attn_bias = AttnBias.from_seqlens(
|
||
|
|
q_seqlen=token_lengths.tolist(),
|
||
|
|
kv_seqlen=(start_pos + token_lengths).tolist(),
|
||
|
|
kv_padding=kv_padding,
|
||
|
|
)
|
||
|
|
return self.forward_with_attn_bias(token_values, attn_bias, cache)
|
||
|
|
|
||
|
|
|
||
|
|
def make_cache(
|
||
|
|
args: ModelArgs,
|
||
|
|
length: int,
|
||
|
|
device: Optional[Union[str, torch.device]] = None,
|
||
|
|
n_layers: Optional[int] = None,
|
||
|
|
dtype: Optional[torch.dtype] = None,
|
||
|
|
) -> list[LayerCache]:
|
||
|
|
"""
|
||
|
|
Allocate a cache to be used with the Transformer module.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
args (ModelArgs): the model configuration.
|
||
|
|
length (int): per layer cache size.
|
||
|
|
It is usually budgeted as ``max_batch * max_seq``
|
||
|
|
device (torch.device, optional): the device on which
|
||
|
|
the cache should be allocated.
|
||
|
|
n_layers (int, optional): the number of layers to
|
||
|
|
allocate a cache for (defaults to the model
|
||
|
|
settings).
|
||
|
|
dtype (torch.dtype, optional): the dtype to use for
|
||
|
|
cache entries (defaults to the default dtype).
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
The cache object to pass to ``Tranformer.forward``.
|
||
|
|
"""
|
||
|
|
|
||
|
|
head_dim = args.dim // args.n_heads
|
||
|
|
n_kv_heads = args.n_kv_heads
|
||
|
|
if n_kv_heads is None:
|
||
|
|
n_kv_heads = args.n_heads
|
||
|
|
n_local_kv_heads = n_kv_heads
|
||
|
|
|
||
|
|
if n_layers is None:
|
||
|
|
n_layers = args.n_layers
|
||
|
|
|
||
|
|
shape = (1, length, n_local_kv_heads, 1, head_dim)
|
||
|
|
heads_per_group = args.n_heads // n_kv_heads
|
||
|
|
expansion = (-1, -1, -1, heads_per_group, -1)
|
||
|
|
return [
|
||
|
|
(
|
||
|
|
torch.zeros(shape, device=device, dtype=dtype).expand(expansion),
|
||
|
|
torch.zeros(shape, device=device, dtype=dtype).expand(expansion),
|
||
|
|
)
|
||
|
|
for _ in range(n_layers)
|
||
|
|
]
|
||
|
|
|
||
|
|
|
||
|
|
def cache_prefix(cache: list[LayerCache], length: int) -> list[LayerCache]:
|
||
|
|
"""
|
||
|
|
Take a prefix view of a larger cache.
|
||
|
|
|
||
|
|
The original cache object remains of identical size and valid
|
||
|
|
after the shrinked alias has been used. This function is useful
|
||
|
|
when a cache was allocated for a larger batch size than what is
|
||
|
|
necessary.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
cache: the cache to take a view in.
|
||
|
|
length (int): the desired length
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
A view in the input cache object.
|
||
|
|
"""
|
||
|
|
|
||
|
|
if len(cache) > 0:
|
||
|
|
assert cache[0][0].shape[1] >= length
|
||
|
|
|
||
|
|
return [(ck[:, :length], cv[:, :length]) for ck, cv in cache]
|