# Source: microsoft/BitNet — official inference framework for 1-bit LLMs
# (Python source; snapshot retrieved 2025-05-15. Scraped page chrome removed.)
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import readline # type: ignore # noqa
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Optional, Tuple, Union
import fire
import model as fast
import torch
from stats import Stats
from tokenizer import Tokenizer, ChatFormat
import sample_utils
from xformers.ops.fmha.attn_bias import (
BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
)
@dataclass
class GenArgs:
    """Generation hyper-parameters shared by prefill and decode."""

    # Number of tokens to generate per prompt.
    gen_length: int = 32
    # Batch size: number of prompts processed together.
    gen_bsz: int = 1
    # Fixed prompt slot length; shorter prompts are right-padded up to
    # this in FastGen.generate_all.
    prompt_length: int = 64
    # If True, sample with temperature/top-p instead of greedy argmax.
    use_sampling: bool = False
    # NOTE(review): generate_all currently hard-codes temp=0.7 and
    # top_p=0.95, so these two defaults are not actually read — confirm.
    temperature: float = 0.8
    top_p: float = 0.9
class FastGen:
    """Two-phase generator: fp16 prefill model + int2 decode model.

    The prompt batch is processed once by a full-precision model
    ("prefill", built with ``use_kernel=False`` and fp16 weights), then
    tokens are decoded one at a time by a 2-bit model ("decode", built
    with ``use_kernel=True`` and int2 weights).  Both forward passes are
    captured into CUDA graphs at construction time and replayed for
    every request, with inputs copied into the captured tensors.
    """

    # NOTE(review): appears unused within this file — confirm before removing.
    GRAPH_WARMUPS: int = 1
    tokenizer: Tokenizer

    @staticmethod
    def build(
        ckpt_dir: str,
        gen_args: GenArgs,
        device: Union[torch.device, str],
        tokenizer_path: Optional[str] = None,
        num_layers: int = 13,
        use_full_vocab: bool = False,
    ) -> "FastGen":
        """
        Load a Llama or Code Llama checkpoint and return a new
        generator for this model.

        ``ckpt_dir`` must contain ``model_state_fp16.pt`` (prefill
        weights) and ``model_state_int2.pt`` (decode weights).

        NOTE(review): ``tokenizer_path``, ``num_layers`` and
        ``use_full_vocab`` are accepted but never read below — the
        tokenizer path is hard-coded and the layer count comes from
        ``fast.ModelArgs``.  Confirm whether they should be wired in.
        """
        start_time = time.time()
        # Two model configurations: prefill runs without the custom
        # low-bit kernel, decode runs with it.
        model_args_prefill = fast.ModelArgs(use_kernel=False)
        model_args_decode = fast.ModelArgs(use_kernel=True)
        # Tokenizer path is hard-coded (ignores tokenizer_path above).
        tokenizer = Tokenizer("./tokenizer.model")
        torch.set_default_device(device)
        torch.set_default_dtype(torch.bfloat16)
        prefill_model = fast.Transformer(model_args_prefill)
        decode_model = fast.Transformer(model_args_decode)
        # fp16 weights feed the prefill model, int2 weights the decoder.
        fp16_ckpt_path = str(Path(ckpt_dir) / "model_state_fp16.pt")
        fp16_checkpoint = torch.load(fp16_ckpt_path, map_location="cpu")
        int2_ckpt_path = str(Path(ckpt_dir) / "model_state_int2.pt")
        int2_checkpoint = torch.load(int2_ckpt_path, map_location="cpu")
        prefill_model.load_state_dict(fp16_checkpoint, strict=True)
        decode_model.load_state_dict(int2_checkpoint, strict=True)
        torch.cuda.synchronize()
        print(f"loaded model in {time.time() - start_time:.2f} seconds")
        # NOTE(review): this value is never used after assignment.
        start_time = time.time()
        return FastGen(gen_args, model_args_prefill, prefill_model, decode_model, tokenizer)

    def __init__(
        self,
        args: GenArgs,
        model_args: fast.ModelArgs,
        prefill_model: fast.Transformer,
        decode_model: fast.Transformer,
        tokenizer: Tokenizer,
    ):
        self.gen_args = args
        # Per-sequence KV budget: padded prompt slot plus generated tokens.
        self.max_seq_length = args.prompt_length + args.gen_length
        self.model_args = model_args
        self.prefill_model = prefill_model
        self.decode_model = decode_model
        self.tokenizer = tokenizer
        # CUDA-graph state for each phase:
        # (graph, replay callable, captured inputs, captured output logits).
        self._prefill_cuda_graph, self._prefill_compile_model, self._prefill_inputs, self._prefill_logits = None, None, None, None
        self._generate_cuda_graph, self._generate_compile_model, self._generate_inputs, self._generate_logits = None, None, None, None
        # KV cache shared by both phases; allocated lazily in compile_*.
        self._cache = None
        start_time = time.time()
        self._prefill_compile_model = self.compile_prefill()
        self._generate_compile_model = self.compile_generate()
        print(f"compiled model in {time.time() - start_time:.2f} seconds")

    def compile_prefill(self):
        """Capture the prefill forward pass into a CUDA graph.

        Returns a ``replay(tokens, seq_lens=None)`` callable that copies
        the new inputs into the captured tensors, replays the graph,
        synchronizes, and returns the captured logits tensor.
        """
        if self._cache is None:
            self._cache = fast.make_cache(
                args=self.model_args,
                length=self.gen_args.gen_bsz * self.max_seq_length,
            )
        # Capture with every sequence at the fixed padded prompt length.
        seq_lens = [self.gen_args.prompt_length for _ in range(self.gen_args.gen_bsz)]
        bias = AttnBias.from_seqlens(
            q_seqlen=seq_lens,
            kv_seqlen=seq_lens,
            kv_padding=self.max_seq_length,
        )
        bias.q_seqinfo.to("cuda")
        bias.k_seqinfo.to("cuda")
        # Dummy token ids (all 1s) used only for graph capture.
        tokens = torch.IntTensor([1] * self.gen_args.gen_bsz * self.gen_args.prompt_length).cuda()
        self._prefill_inputs = (tokens, bias)
        # Warm-up pass on a side stream before capture so allocations and
        # lazy initialization happen outside the recorded graph.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            _ = self.prefill_model.forward_with_attn_bias(
                token_values=self._prefill_inputs[0],
                attn_bias=self._prefill_inputs[1],
                cache=self._cache,
            )
        torch.cuda.current_stream().wait_stream(s)
        self._prefill_cuda_graph = torch.cuda.CUDAGraph()
        recording_kwargs = {}
        if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
            # In PyTorch 2.1+ and nightlies from late Aug 2023,
            # we can do this to maybe avoid watchdog-related crashes
            recording_kwargs["capture_error_mode"] = "thread_local"
        with torch.cuda.graph(self._prefill_cuda_graph, **recording_kwargs):
            self._prefill_logits = self.prefill_model.forward_with_attn_bias(
                token_values=self._prefill_inputs[0],
                attn_bias=self._prefill_inputs[1],
                cache=self._cache,
            )

        def replay(tokens, seq_lens=None):
            # Copy fresh inputs into the tensors the graph was captured on;
            # the graph reads/writes those fixed addresses on replay.
            self._prefill_inputs[0].copy_(tokens)
            if seq_lens is not None:
                self._prefill_inputs[1].k_seqinfo.seqlen.copy_(seq_lens)
            self._prefill_cuda_graph.replay()
            torch.cuda.synchronize()
            return self._prefill_logits

        return replay

    def compile_generate(self):
        """Capture the single-token decode pass into a CUDA graph.

        Returns a ``replay(tokens, seq_lens)`` callable mirroring the one
        from ``compile_prefill`` (note: no synchronize after replay here,
        unlike the prefill replay).
        """
        if self._cache is None:
            self._cache = fast.make_cache(
                args=self.model_args,
                length=self.gen_args.gen_bsz * self.max_seq_length,
            )
        # Decode step: one query token per sequence, KV length at the
        # padded prompt length for capture.
        seq_lens = [1 for _ in range(self.gen_args.gen_bsz)]
        kv_seq_lens = [self.gen_args.prompt_length for _ in range(self.gen_args.gen_bsz)]
        bias = AttnBias.from_seqlens(
            q_seqlen=seq_lens,
            kv_seqlen=kv_seq_lens,
            kv_padding=self.max_seq_length,
        )
        bias.q_seqinfo.to("cuda")
        bias.k_seqinfo.to("cuda")
        # One dummy token per sequence for graph capture.
        tokens = torch.IntTensor([1] * self.gen_args.gen_bsz).cuda()
        self._generate_inputs = (tokens, bias)
        # Warm-up pass on a side stream before capture (see compile_prefill).
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            _ = self.decode_model.forward_with_attn_bias(
                token_values=self._generate_inputs[0],
                attn_bias=self._generate_inputs[1],
                cache=self._cache,
            )
        torch.cuda.current_stream().wait_stream(s)
        self._generate_cuda_graph = torch.cuda.CUDAGraph()
        recording_kwargs = {}
        if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
            # In PyTorch 2.1+ and nightlies from late Aug 2023,
            # we can do this to maybe avoid watchdog-related crashes
            recording_kwargs["capture_error_mode"] = "thread_local"
        with torch.cuda.graph(self._generate_cuda_graph, **recording_kwargs):
            self._generate_logits = self.decode_model.forward_with_attn_bias(
                token_values=self._generate_inputs[0],
                attn_bias=self._generate_inputs[1],
                cache=self._cache,
            )

        def replay(tokens, seq_lens):
            self._generate_inputs[0].copy_(tokens)
            self._generate_inputs[1].k_seqinfo.seqlen.copy_(seq_lens)
            self._generate_cuda_graph.replay()
            return self._generate_logits

        return replay

    @torch.inference_mode()
    def generate_all(
        self, prompts: list[list[int]], use_cuda_graphs: bool, use_sampling: bool
    ) -> Tuple[Stats, list[list[int]]]:
        """Prefill the prompt batch, then decode token-by-token.

        ``prompts`` is a batch of token-id lists (each assumed to fit in
        ``gen_args.prompt_length`` — TODO confirm, there is no check).
        Returns timing ``Stats`` and one generated-token list per prompt,
        trimmed at the first EOS token.
        """
        bs = len(prompts)
        prompt_lens = [len(p) for p in prompts]
        padded_prompt_lens = [self.gen_args.prompt_length] * bs
        max_prompt_length = max(prompt_lens)
        gen_length = self.gen_args.gen_length
        max_seq_length = max_prompt_length + gen_length
        print(max_prompt_length, gen_length)
        bias = AttnBias.from_seqlens(
            q_seqlen=padded_prompt_lens,
            kv_seqlen=prompt_lens,
            kv_padding=max_seq_length,
        )
        bias.q_seqinfo.to("cuda")
        bias.k_seqinfo.to("cuda")
        # Input tensors to the cuda graph
        kv_seqlen = bias.k_seqinfo.seqlen
        # Right-pad every prompt with token id 1 up to the fixed slot length.
        prompts = [prompt + [1] * (self.gen_args.prompt_length - len(prompt)) for prompt in prompts]
        tokens = torch.IntTensor(sum(prompts, [])).cuda()
        # out_tokens is (step, batch): row 0 holds the first generated token.
        out_tokens = torch.zeros((max_seq_length, bs), dtype=torch.int)
        stats = Stats()
        torch.cuda.synchronize()
        stats.phase("prefill" if use_cuda_graphs else "total")
        # NOTE(review): seq_lens=None here means the captured prefill bias
        # keeps its fixed padded lengths rather than the real prompt_lens
        # — confirm this is intended.
        output = self._prefill_compile_model(tokens, None)
        # Select the logits at each sequence's last real prompt token.
        logits = output[kv_seqlen - 1, :]
        logits = logits.view(bs, self.model_args.vocab_size)
        if use_sampling:
            # NOTE(review): hard-coded values ignore
            # self.gen_args.temperature / top_p — confirm intent.
            temp = 0.7
            top_p = 0.95
            probs = torch.softmax(logits / temp, dim=-1)
            next_token = sample_utils.top_p(probs, top_p)
        else:
            next_token = torch.argmax(logits, dim=-1)
        next_token = next_token.reshape(bs)
        out_tokens[0, :] = next_token
        torch.cuda.synchronize()
        stats.phase("decode" if use_cuda_graphs else "total")
        eos_id = self.tokenizer.eot_id
        for niter in range(1, gen_length):
            # Advance each sequence's KV length by one, capped at the
            # cache size (the comparison yields 0 once the cap is hit).
            kv_seqlen.add_(kv_seqlen < max_seq_length)
            output = self._generate_compile_model(next_token, kv_seqlen)
            logits = output.view(bs, self.model_args.vocab_size)
            if use_sampling:
                temp = 0.7
                top_p = 0.95
                probs = torch.softmax(logits / temp, dim=-1)
                next_token = sample_utils.top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(bs)
            out_tokens[niter, :] = next_token
            # Stop the whole batch as soon as ANY sequence emits EOS.
            if next_token.eq(eos_id).any():
                break
        torch.cuda.synchronize()
        # NOTE(review): if gen_length == 1 the loop never runs and `niter`
        # is unbound here (NameError) — confirm gen_length is always > 1.
        stats.end_phase(tokens=niter * bs)

        def trim_answer(prompt_len, tokens):
            """Trim the answer to end it on an eos token."""
            tokens = tokens[: max_seq_length - prompt_len]
            eos_id = self.tokenizer.eot_id
            if eos_id in tokens:
                return tokens[: tokens.index(eos_id) + 1]
            else:
                return tokens

        answers = [
            trim_answer(prompt_len, answer)
            for prompt_len, answer in zip(prompt_lens, out_tokens.t().tolist())
        ]
        return stats, answers
def get_prompts(interactive: bool) -> Iterable[list[str]]:
    """Yield batches of prompt strings.

    In interactive mode, read lines from stdin forever, exiting the
    process on EOF; otherwise yield a single canned prompt batch.
    """
    if not interactive:
        yield ["Hello, my name is"]
        return
    while True:
        try:
            line = input("enter prompt: ")
        except EOFError:
            print("exiting")
            sys.exit(0)
        yield line.split("\n")
def main(ckpt_dir: str, interactive: bool = False, chat_format: bool = False, sampling: bool = False):
    """CLI entry point: build the generator and run prompts through it."""
    local_rank = 0
    device = f"cuda:{local_rank}"
    torch.cuda.set_device(local_rank)
    g = FastGen.build(ckpt_dir, GenArgs(), device)
    if chat_format:
        g.tokenizer = ChatFormat(g.tokenizer)
    for batch in get_prompts(interactive):
        if chat_format:
            tokens = [
                g.tokenizer.encode_dialog_prompt(
                    dialog=[{"role": "user", "content": text}], completion=True
                )
                for text in batch
            ]
        else:
            tokens = [g.tokenizer.encode(text, bos=False, eos=False) for text in batch]
        print(tokens)
        stats, out_tokens = g.generate_all(
            tokens, use_cuda_graphs="NO_CUDA_GRAPHS" not in os.environ, use_sampling=sampling,
        )
        for idx, text in enumerate(batch):
            print(f"> {text}")
            print(g.tokenizer.decode(out_tokens[idx]))
            print("---------------")
        for phase_stats in stats.phases:
            print(phase_stats.show())
    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
if __name__ == "__main__":
    # Expose `main` as a CLI via python-fire (flags map to keyword args).
    fire.Fire(main)