# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import readline  # type: ignore  # noqa
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Optional, Tuple, Union

import fire
import model as fast
import torch
from stats import Stats
from tokenizer import Tokenizer, ChatFormat

import sample_utils
from xformers.ops.fmha.attn_bias import (
    BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
)


@dataclass
class GenArgs:
    gen_length: int = 32
    gen_bsz: int = 1
    prompt_length: int = 64

    use_sampling: bool = False
    temperature: float = 0.8
    top_p: float = 0.9


class FastGen:
    GRAPH_WARMUPS: int = 1
    tokenizer: Tokenizer

    @staticmethod
    def build(
        ckpt_dir: str,
        gen_args: GenArgs,
        device: Union[torch.device, str],
        tokenizer_path: Optional[str] = None,
        num_layers: int = 13,
        use_full_vocab: bool = False,
    ) -> "FastGen":
        """
        Load a Llama or Code Llama checkpoint and return a new
        generator for this model.
        """
        start_time = time.time()

        # Prefill runs from the fp16 checkpoint with use_kernel disabled;
        # decode runs from the int2 checkpoint with use_kernel enabled.
        model_args_prefill = fast.ModelArgs(use_kernel=False)
        model_args_decode = fast.ModelArgs(use_kernel=True)
        tokenizer = Tokenizer(tokenizer_path or "./tokenizer.model")

        torch.set_default_device(device)
        torch.set_default_dtype(torch.bfloat16)
        prefill_model = fast.Transformer(model_args_prefill)
        decode_model = fast.Transformer(model_args_decode)

        fp16_ckpt_path = str(Path(ckpt_dir) / "model_state_fp16.pt")
        fp16_checkpoint = torch.load(
            fp16_ckpt_path, map_location="cpu", weights_only=True
        )
        int2_ckpt_path = str(Path(ckpt_dir) / "model_state_int2.pt")
        int2_checkpoint = torch.load(
            int2_ckpt_path, map_location="cpu", weights_only=True
        )

        prefill_model.load_state_dict(fp16_checkpoint, strict=True)
        decode_model.load_state_dict(int2_checkpoint, strict=True)
        torch.cuda.synchronize()
        print(f"loaded model in {time.time() - start_time:.2f} seconds")

        return FastGen(
            gen_args, model_args_prefill, prefill_model, decode_model, tokenizer
        )

    def __init__(
        self,
        args: GenArgs,
        model_args: fast.ModelArgs,
        prefill_model: fast.Transformer,
        decode_model: fast.Transformer,
        tokenizer: Tokenizer,
    ):
        self.gen_args = args
        self.max_seq_length = args.prompt_length + args.gen_length
        self.model_args = model_args
        self.prefill_model = prefill_model
        self.decode_model = decode_model
        self.tokenizer = tokenizer

        self._prefill_cuda_graph = None
        self._prefill_compile_model = None
        self._prefill_inputs = None
        self._prefill_logits = None
        self._generate_cuda_graph = None
        self._generate_compile_model = None
        self._generate_inputs = None
        self._generate_logits = None
        self._cache = None

        start_time = time.time()
        self._prefill_compile_model = self.compile_prefill()
        self._generate_compile_model = self.compile_generate()
        print(f"compiled model in {time.time() - start_time:.2f} seconds")

    def compile_prefill(self):
        """Capture the prefill forward pass in a CUDA graph and return a replay closure."""
        if self._cache is None:
            self._cache = fast.make_cache(
                args=self.model_args,
                length=self.gen_args.gen_bsz * self.max_seq_length,
            )

        seq_lens = [self.gen_args.prompt_length for _ in range(self.gen_args.gen_bsz)]
        bias = AttnBias.from_seqlens(
            q_seqlen=seq_lens,
            kv_seqlen=seq_lens,
            kv_padding=self.max_seq_length,
        )
        bias.q_seqinfo.to("cuda")
        bias.k_seqinfo.to("cuda")

        tokens = torch.IntTensor(
            [1] * self.gen_args.gen_bsz * self.gen_args.prompt_length
        ).cuda()
        self._prefill_inputs = (tokens, bias)

        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
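            # Warm-up forward pass on a side stream before graph capture, as
            # recommended by the PyTorch CUDA graphs documentation: it performs
            # allocations and loads kernels outside of the captured graph.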
            _ = self.prefill_model.forward_with_attn_bias(
                token_values=self._prefill_inputs[0],
                attn_bias=self._prefill_inputs[1],
                cache=self._cache,
            )
        torch.cuda.current_stream().wait_stream(s)

        self._prefill_cuda_graph = torch.cuda.CUDAGraph()
        recording_kwargs = {}
        if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
            # In PyTorch 2.1+ and nightlies from late Aug 2023,
            # we can do this to maybe avoid watchdog-related crashes
            recording_kwargs["capture_error_mode"] = "thread_local"
        with torch.cuda.graph(self._prefill_cuda_graph, **recording_kwargs):
            self._prefill_logits = self.prefill_model.forward_with_attn_bias(
                token_values=self._prefill_inputs[0],
                attn_bias=self._prefill_inputs[1],
                cache=self._cache,
            )

        def replay(tokens, seq_lens=None):
            self._prefill_inputs[0].copy_(tokens)
            if seq_lens is not None:
                self._prefill_inputs[1].k_seqinfo.seqlen.copy_(seq_lens)
            self._prefill_cuda_graph.replay()
            torch.cuda.synchronize()
            return self._prefill_logits

        return replay

    def compile_generate(self):
        """Capture the one-token decode step in a CUDA graph and return a replay closure."""
        if self._cache is None:
            self._cache = fast.make_cache(
                args=self.model_args,
                length=self.gen_args.gen_bsz * self.max_seq_length,
            )

        seq_lens = [1 for _ in range(self.gen_args.gen_bsz)]
        kv_seq_lens = [self.gen_args.prompt_length for _ in range(self.gen_args.gen_bsz)]
        bias = AttnBias.from_seqlens(
            q_seqlen=seq_lens,
            kv_seqlen=kv_seq_lens,
            kv_padding=self.max_seq_length,
        )
        bias.q_seqinfo.to("cuda")
        bias.k_seqinfo.to("cuda")

        tokens = torch.IntTensor([1] * self.gen_args.gen_bsz).cuda()
        self._generate_inputs = (tokens, bias)

        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            _ = self.decode_model.forward_with_attn_bias(
                token_values=self._generate_inputs[0],
                attn_bias=self._generate_inputs[1],
                cache=self._cache,
            )
        torch.cuda.current_stream().wait_stream(s)

        self._generate_cuda_graph = torch.cuda.CUDAGraph()
        recording_kwargs = {}
        if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
            # In PyTorch 2.1+ and nightlies from late Aug 2023,
            # we can do this to maybe avoid watchdog-related crashes
            recording_kwargs["capture_error_mode"] = "thread_local"
        with torch.cuda.graph(self._generate_cuda_graph, **recording_kwargs):
            self._generate_logits = self.decode_model.forward_with_attn_bias(
                token_values=self._generate_inputs[0],
                attn_bias=self._generate_inputs[1],
                cache=self._cache,
            )

        def replay(tokens, seq_lens):
            self._generate_inputs[0].copy_(tokens)
            self._generate_inputs[1].k_seqinfo.seqlen.copy_(seq_lens)
            self._generate_cuda_graph.replay()
            return self._generate_logits

        return replay

    @torch.inference_mode()
    def generate_all(
        self, prompts: list[list[int]], use_cuda_graphs: bool, use_sampling: bool
    ) -> Tuple[Stats, list[list[int]]]:
        """Prefill the padded prompts, then decode up to gen_length tokens per prompt."""
        bs = len(prompts)
        prompt_lens = [len(p) for p in prompts]
        padded_prompt_lens = [self.gen_args.prompt_length] * bs
        max_prompt_length = max(prompt_lens)
        gen_length = self.gen_args.gen_length
        max_seq_length = max_prompt_length + gen_length
        print(max_prompt_length, gen_length)

        bias = AttnBias.from_seqlens(
            q_seqlen=padded_prompt_lens,
            kv_seqlen=prompt_lens,
            kv_padding=max_seq_length,
        )
        bias.q_seqinfo.to("cuda")
        bias.k_seqinfo.to("cuda")

        # Input tensors to the cuda graph
        kv_seqlen = bias.k_seqinfo.seqlen
        prompts = [
            prompt + [1] * (self.gen_args.prompt_length - len(prompt))
            for prompt in prompts
        ]
        tokens = torch.IntTensor(sum(prompts, [])).cuda()
        out_tokens = torch.zeros((max_seq_length, bs), dtype=torch.int)

        stats = Stats()
        torch.cuda.synchronize()
        stats.phase("prefill" if use_cuda_graphs else "total")
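        # Replay the captured prefill graph on the padded prompt tokens. seq_lens
        # is left as None, so the kv sequence lengths captured at compile time
        # (prompt_length for every sequence) are reused for this pass.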
        output = self._prefill_compile_model(tokens, None)
        logits = output[kv_seqlen - 1, :]
        logits = logits.view(bs, self.model_args.vocab_size)

        if use_sampling:
            temp = 0.7
            top_p = 0.95
            probs = torch.softmax(logits / temp, dim=-1)
            next_token = sample_utils.top_p(probs, top_p)
        else:
            next_token = torch.argmax(logits, dim=-1)
        next_token = next_token.reshape(bs)
        out_tokens[0, :] = next_token

        torch.cuda.synchronize()
        stats.phase("decode" if use_cuda_graphs else "total")

        eos_id = self.tokenizer.eot_id
        for niter in range(1, gen_length):
            kv_seqlen.add_(kv_seqlen < max_seq_length)
            output = self._generate_compile_model(next_token, kv_seqlen)
            logits = output.view(bs, self.model_args.vocab_size)

            if use_sampling:
                temp = 0.7
                top_p = 0.95
                probs = torch.softmax(logits / temp, dim=-1)
                next_token = sample_utils.top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(bs)
            out_tokens[niter, :] = next_token

            if next_token.eq(eos_id).any():
                break

        torch.cuda.synchronize()
        stats.end_phase(tokens=niter * bs)

        def trim_answer(prompt_len, tokens):
            """Trim the answer to end it on an eos token."""
            tokens = tokens[: max_seq_length - prompt_len]
            eos_id = self.tokenizer.eot_id
            if eos_id in tokens:
                return tokens[: tokens.index(eos_id) + 1]
            else:
                return tokens

        answers = [
            trim_answer(prompt_len, answer)
            for prompt_len, answer in zip(prompt_lens, out_tokens.t().tolist())
        ]
        return stats, answers


def get_prompts(interactive: bool) -> Iterable[list[str]]:
    if interactive:
        while True:
            try:
                prompts = input("enter prompt: ").split("\n")
            except EOFError:
                print("exiting")
                sys.exit(0)
            yield prompts
    else:
        yield [
            "Hello, my name is",
        ]


def main(
    ckpt_dir: str,
    interactive: bool = False,
    chat_format: bool = False,
    sampling: bool = False,
):
    local_rank = 0
    device = f"cuda:{local_rank}"
    torch.cuda.set_device(local_rank)

    g = FastGen.build(ckpt_dir, GenArgs(), device)
    if chat_format:
        g.tokenizer = ChatFormat(g.tokenizer)

    for prompts in get_prompts(interactive):
        # prompts = [f"{prompt}\n" for prompt in prompts]
        if chat_format:
            # prompts = [f'<|begin_of_text|>User: {prompt}<|eot_id|>Assistant: ' for prompt in prompts]
            tokens = [
                g.tokenizer.encode_dialog_prompt(
                    dialog=[{"role": "user", "content": prompt}], completion=True
                )
                for prompt in prompts
            ]
        else:
            tokens = [g.tokenizer.encode(x, bos=False, eos=False) for x in prompts]
        print(tokens)

        stats, out_tokens = g.generate_all(
            tokens,
            use_cuda_graphs="NO_CUDA_GRAPHS" not in os.environ,
            use_sampling=sampling,
        )

        for i, prompt in enumerate(prompts):
            print(f"> {prompt}")
            answer = g.tokenizer.decode(out_tokens[i])
            print(answer)
            print("---------------")

    for phase_stats in stats.phases:
        print(phase_stats.show())
    print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")


if __name__ == "__main__":
    fire.Fire(main)
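
# Example invocation (hypothetical file name; Fire maps the arguments of main()
# onto CLI flags):
#   python generate.py /path/to/ckpt_dir --interactive=True --chat_format=True --sampling=True
# ckpt_dir is expected to contain model_state_fp16.pt and model_state_int2.pt,
# and ./tokenizer.model is expected in the current working directory.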