2023-04-06 15:30:57 +02:00
|
|
|
import os
|
|
|
|
|
import argparse
|
2023-05-04 18:33:08 +02:00
|
|
|
import re
|
2023-04-06 15:30:57 +02:00
|
|
|
|
|
|
|
|
from dataclasses import dataclass, field
|
2023-05-04 18:33:08 +02:00
|
|
|
from typing import List
|
2023-04-06 15:30:57 +02:00
|
|
|
|
|
|
|
|
# Based on https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class GptParams:
|
|
|
|
|
seed: int = -1
|
|
|
|
|
n_threads: int = min(4, os.cpu_count() or 1)
|
|
|
|
|
n_predict: int = 128
|
|
|
|
|
n_parts: int = -1
|
|
|
|
|
n_ctx: int = 512
|
|
|
|
|
n_batch: int = 8
|
|
|
|
|
n_keep: int = 0
|
|
|
|
|
|
2023-05-04 18:33:08 +02:00
|
|
|
ignore_eos: bool = False
|
|
|
|
|
logit_bias: dict[int, float] = field(default_factory=dict)
|
2023-04-06 15:30:57 +02:00
|
|
|
top_k: int = 40
|
|
|
|
|
top_p: float = 0.95
|
2023-05-04 18:33:08 +02:00
|
|
|
tfs_z: float = 1.00
|
|
|
|
|
typical_p: float = 1.00
|
2023-04-06 15:30:57 +02:00
|
|
|
temp: float = 0.80
|
|
|
|
|
repeat_penalty: float = 1.10
|
2023-05-04 18:33:08 +02:00
|
|
|
repeat_last_n: int = 64
|
|
|
|
|
frequency_penalty: float = 0.0
|
|
|
|
|
presence_penalty: float = 0.0
|
|
|
|
|
mirostat: int = 0
|
|
|
|
|
mirostat_tau: float = 5.0
|
|
|
|
|
mirostat_eta: float = 0.1
|
2023-04-06 15:30:57 +02:00
|
|
|
|
|
|
|
|
model: str = "./models/llama-7B/ggml-model.bin"
|
|
|
|
|
prompt: str = ""
|
2023-05-04 18:33:08 +02:00
|
|
|
path_session: str = ""
|
2023-04-06 15:30:57 +02:00
|
|
|
input_prefix: str = " "
|
2023-05-06 13:18:25 +02:00
|
|
|
input_suffix: str = ""
|
2023-04-06 15:30:57 +02:00
|
|
|
antiprompt: List[str] = field(default_factory=list)
|
|
|
|
|
|
2023-05-04 18:33:08 +02:00
|
|
|
lora_adapter: str = ""
|
|
|
|
|
lora_base: str = ""
|
|
|
|
|
|
2023-04-06 15:30:57 +02:00
|
|
|
memory_f16: bool = True
|
|
|
|
|
random_prompt: bool = False
|
|
|
|
|
use_color: bool = False
|
|
|
|
|
interactive: bool = False
|
|
|
|
|
|
|
|
|
|
embedding: bool = False
|
|
|
|
|
interactive_start: bool = False
|
|
|
|
|
|
|
|
|
|
instruct: bool = False
|
2023-05-04 18:33:08 +02:00
|
|
|
penalize_nl: bool = True
|
2023-04-06 15:30:57 +02:00
|
|
|
perplexity: bool = False
|
2023-04-10 16:35:38 +02:00
|
|
|
use_mmap: bool = True
|
2023-04-06 15:30:57 +02:00
|
|
|
use_mlock: bool = False
|
|
|
|
|
mem_test: bool = False
|
|
|
|
|
verbose_prompt: bool = False
|
|
|
|
|
|
2023-04-07 13:32:19 +02:00
|
|
|
file: str = None
|
|
|
|
|
|
|
|
|
|
# If chat ended prematurely, append this to the conversation to fix it.
|
|
|
|
|
# Set to "\nUser:" etc.
|
|
|
|
|
# This is an alternative to input_prefix which always adds it, so it potentially duplicates "User:""
|
2023-04-10 16:38:45 +02:00
|
|
|
fix_prefix: str = ""
|
2024-07-09 12:20:17 -04:00
|
|
|
input_echo: bool = (True,)
|
2023-04-07 13:32:19 +02:00
|
|
|
|
2023-04-06 15:30:57 +02:00
|
|
|
# Default instructions for Alpaca
|
|
|
|
|
# switch to "Human" and "Assistant" for Vicuna.
|
2023-04-07 13:32:19 +02:00
|
|
|
# TODO: TBD how they are gonna handle this upstream
|
2024-07-09 12:20:17 -04:00
|
|
|
instruct_inp_prefix: str = "\n\n### Instruction:\n\n"
|
|
|
|
|
instruct_inp_suffix: str = "\n\n### Response:\n\n"
|
2023-04-06 15:30:57 +02:00
|
|
|
|
|
|
|
|
|
2024-07-09 12:20:17 -04:00
|
|
|
def gpt_params_parse(argv=None):
|
|
|
|
|
parser = argparse.ArgumentParser(
|
|
|
|
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"-s",
|
|
|
|
|
"--seed",
|
|
|
|
|
type=int,
|
|
|
|
|
default=-1,
|
|
|
|
|
help="RNG seed (use random seed for <= 0)",
|
|
|
|
|
dest="seed",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"-t",
|
|
|
|
|
"--threads",
|
|
|
|
|
type=int,
|
|
|
|
|
default=min(4, os.cpu_count() or 1),
|
|
|
|
|
help="number of threads to use during computation",
|
|
|
|
|
dest="n_threads",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"-n",
|
|
|
|
|
"--n_predict",
|
|
|
|
|
type=int,
|
|
|
|
|
default=128,
|
|
|
|
|
help="number of tokens to predict (-1 = infinity)",
|
|
|
|
|
dest="n_predict",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--n_parts", type=int, default=-1, help="number of model parts", dest="n_parts"
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"-c",
|
|
|
|
|
"--ctx_size",
|
|
|
|
|
type=int,
|
|
|
|
|
default=512,
|
|
|
|
|
help="size of the prompt context",
|
|
|
|
|
dest="n_ctx",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"-b",
|
|
|
|
|
"--batch_size",
|
|
|
|
|
type=int,
|
|
|
|
|
default=8,
|
|
|
|
|
help="batch size for prompt processing",
|
|
|
|
|
dest="n_batch",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--keep",
|
|
|
|
|
type=int,
|
|
|
|
|
default=0,
|
|
|
|
|
help="number of tokens to keep from the initial prompt",
|
|
|
|
|
dest="n_keep",
|
|
|
|
|
)
|
2023-05-04 18:33:08 +02:00
|
|
|
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"-l",
|
|
|
|
|
"--logit-bias",
|
|
|
|
|
type=str,
|
2024-07-09 12:20:17 -04:00
|
|
|
action="append",
|
2023-05-04 18:33:08 +02:00
|
|
|
help="--logit-bias TOKEN_ID(+/-)BIAS",
|
2024-07-09 12:20:17 -04:00
|
|
|
dest="logit_bias_str",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--ignore-eos",
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="ignore end of stream token and continue generating",
|
|
|
|
|
dest="ignore_eos",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--top_k", type=int, default=40, help="top-k sampling", dest="top_k"
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--top_p", type=float, default=0.95, help="top-p samplin", dest="top_p"
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--tfs",
|
|
|
|
|
type=float,
|
|
|
|
|
default=1.0,
|
|
|
|
|
help="tail free sampling, parameter z (1.0 = disabled)",
|
|
|
|
|
dest="tfs_z",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--temp", type=float, default=0.80, help="temperature", dest="temp"
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--repeat_penalty",
|
|
|
|
|
type=float,
|
|
|
|
|
default=1.10,
|
|
|
|
|
help="penalize repeat sequence of tokens",
|
|
|
|
|
dest="repeat_penalty",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--repeat_last_n",
|
|
|
|
|
type=int,
|
|
|
|
|
default=64,
|
|
|
|
|
help="last n tokens to consider for penalize ",
|
|
|
|
|
dest="repeat_last_n",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--frequency_penalty",
|
|
|
|
|
type=float,
|
|
|
|
|
default=0.0,
|
|
|
|
|
help="repeat alpha frequency penalty (0.0 = disabled)",
|
|
|
|
|
dest="tfs_z",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--presence_penalty",
|
|
|
|
|
type=float,
|
|
|
|
|
default=0.0,
|
|
|
|
|
help="repeat alpha presence penalty (0.0 = disabled)",
|
|
|
|
|
dest="presence_penalty",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--mirostat",
|
|
|
|
|
type=float,
|
|
|
|
|
default=1.0,
|
|
|
|
|
help="use Mirostat sampling.",
|
|
|
|
|
dest="mirostat",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--mirostat_ent",
|
|
|
|
|
type=float,
|
|
|
|
|
default=5.0,
|
|
|
|
|
help="Mirostat target entropy, parameter tau represents the average surprise value",
|
|
|
|
|
dest="mirostat_tau",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--mirostat_lr",
|
|
|
|
|
type=float,
|
|
|
|
|
default=0.1,
|
|
|
|
|
help="Mirostat learning rate, parameter eta",
|
|
|
|
|
dest="mirostat_eta",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"-m",
|
|
|
|
|
"--model",
|
|
|
|
|
type=str,
|
|
|
|
|
default="./models/llama-7B/ggml-model.bin",
|
|
|
|
|
help="model path",
|
|
|
|
|
dest="model",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"-p", "--prompt", type=str, default=None, help="initial prompt", dest="prompt"
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"-f",
|
|
|
|
|
"--file",
|
|
|
|
|
type=str,
|
|
|
|
|
default=None,
|
|
|
|
|
help="file containing initial prompt to load",
|
|
|
|
|
dest="file",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--session",
|
|
|
|
|
type=str,
|
|
|
|
|
default=None,
|
|
|
|
|
help="file to cache model state in (may be large!)",
|
|
|
|
|
dest="path_session",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--in-prefix",
|
|
|
|
|
type=str,
|
|
|
|
|
default="",
|
|
|
|
|
help="string to prefix user inputs with",
|
|
|
|
|
dest="input_prefix",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--in-suffix", type=str, default="", help="append to input", dest="input_suffix"
|
|
|
|
|
)
|
2023-04-06 15:30:57 +02:00
|
|
|
parser.add_argument(
|
2023-05-04 18:33:08 +02:00
|
|
|
"-r",
|
|
|
|
|
"--reverse-prompt",
|
|
|
|
|
type=str,
|
2024-07-09 12:20:17 -04:00
|
|
|
action="append",
|
2023-05-04 18:33:08 +02:00
|
|
|
help="poll user input upon seeing PROMPT (can be\nspecified more than once for multiple prompts).",
|
2024-07-09 12:20:17 -04:00
|
|
|
dest="antiprompt",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--lora",
|
|
|
|
|
type=str,
|
|
|
|
|
default="",
|
|
|
|
|
help="apply LoRA adapter (implies --no-mmap)",
|
|
|
|
|
dest="lora_adapter",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--lora-base",
|
|
|
|
|
type=str,
|
|
|
|
|
default="",
|
|
|
|
|
help="optional model to use as a base for the layers modified by the LoRA adapter",
|
|
|
|
|
dest="lora_base",
|
2023-04-06 15:30:57 +02:00
|
|
|
)
|
2023-05-04 18:33:08 +02:00
|
|
|
|
2024-07-09 12:20:17 -04:00
|
|
|
parser.add_argument(
|
|
|
|
|
"--memory_f32",
|
|
|
|
|
action="store_false",
|
|
|
|
|
help="use f32 instead of f16 for memory key+value",
|
|
|
|
|
dest="memory_f16",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--random-prompt",
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="start with a randomized prompt.",
|
|
|
|
|
dest="random_prompt",
|
|
|
|
|
)
|
2023-04-07 13:32:19 +02:00
|
|
|
parser.add_argument(
|
2023-05-04 18:33:08 +02:00
|
|
|
"--color",
|
2023-04-07 13:32:19 +02:00
|
|
|
action="store_true",
|
2023-05-04 18:33:08 +02:00
|
|
|
help="colorise output to distinguish prompt and user input from generations",
|
2024-07-09 12:20:17 -04:00
|
|
|
dest="use_color",
|
2023-04-07 13:32:19 +02:00
|
|
|
)
|
2023-05-04 18:33:08 +02:00
|
|
|
parser.add_argument(
|
2024-07-09 12:20:17 -04:00
|
|
|
"-i",
|
|
|
|
|
"--interactive",
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="run in interactive mode",
|
|
|
|
|
dest="interactive",
|
2023-05-04 18:33:08 +02:00
|
|
|
)
|
2024-07-09 12:20:17 -04:00
|
|
|
|
2023-05-04 18:33:08 +02:00
|
|
|
parser.add_argument("--embedding", action="store_true", help="", dest="embedding")
|
2023-04-06 15:30:57 +02:00
|
|
|
parser.add_argument(
|
|
|
|
|
"--interactive-first",
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="run in interactive mode and wait for input right away",
|
2024-07-09 12:20:17 -04:00
|
|
|
dest="interactive_start",
|
2023-04-06 15:30:57 +02:00
|
|
|
)
|
2023-05-04 18:33:08 +02:00
|
|
|
|
2023-04-06 15:30:57 +02:00
|
|
|
parser.add_argument(
|
|
|
|
|
"-ins",
|
|
|
|
|
"--instruct",
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="run in instruction mode (use with Alpaca or Vicuna models)",
|
2024-07-09 12:20:17 -04:00
|
|
|
dest="instruct",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--no-penalize-nl",
|
|
|
|
|
action="store_false",
|
|
|
|
|
help="do not penalize newline token",
|
|
|
|
|
dest="penalize_nl",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--perplexity",
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="compute perplexity over the prompt",
|
|
|
|
|
dest="perplexity",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--no-mmap",
|
|
|
|
|
action="store_false",
|
|
|
|
|
help="do not memory-map model (slower load but may reduce pageouts if not using mlock)",
|
|
|
|
|
dest="use_mmap",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--mlock",
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="force system to keep model in RAM rather than swapping or compressing",
|
|
|
|
|
dest="use_mlock",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--mtest",
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="compute maximum memory usage",
|
|
|
|
|
dest="mem_test",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--verbose-prompt",
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="print prompt before generation",
|
|
|
|
|
dest="verbose_prompt",
|
2023-04-06 15:30:57 +02:00
|
|
|
)
|
2023-05-04 18:33:08 +02:00
|
|
|
|
2024-07-09 12:20:17 -04:00
|
|
|
# Custom args
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--fix-prefix",
|
|
|
|
|
type=str,
|
|
|
|
|
default="",
|
|
|
|
|
help="append to input when generated n_predict tokens",
|
|
|
|
|
dest="fix_prefix",
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--input-noecho",
|
|
|
|
|
action="store_false",
|
|
|
|
|
help="dont output the input",
|
|
|
|
|
dest="input_echo",
|
|
|
|
|
)
|
2023-05-04 18:33:08 +02:00
|
|
|
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--interactive-start",
|
|
|
|
|
action="store_true",
|
|
|
|
|
help="run in interactive mode",
|
2024-07-09 12:20:17 -04:00
|
|
|
dest="interactive",
|
2023-05-04 18:33:08 +02:00
|
|
|
)
|
|
|
|
|
|
2023-04-06 15:30:57 +02:00
|
|
|
args = parser.parse_args(argv)
|
2024-07-09 12:20:17 -04:00
|
|
|
|
2023-05-04 18:33:08 +02:00
|
|
|
logit_bias_str = args.logit_bias_str
|
|
|
|
|
delattr(args, "logit_bias_str")
|
|
|
|
|
params = GptParams(**vars(args))
|
|
|
|
|
|
2024-07-09 12:20:17 -04:00
|
|
|
if params.lora_adapter:
|
2023-05-04 18:33:08 +02:00
|
|
|
params.use_mmap = False
|
|
|
|
|
|
2024-07-09 12:20:17 -04:00
|
|
|
if logit_bias_str != None:
|
2023-05-04 18:33:08 +02:00
|
|
|
for i in logit_bias_str:
|
2024-07-09 12:20:17 -04:00
|
|
|
if m := re.match(r"(\d+)([-+]\d+)", i):
|
2023-05-06 13:27:52 +02:00
|
|
|
params.logit_bias[int(m.group(1))] = float(m.group(2))
|
2023-05-04 18:33:08 +02:00
|
|
|
|
|
|
|
|
return params
|
2023-04-06 15:30:57 +02:00
|
|
|
|
2024-07-09 12:20:17 -04:00
|
|
|
|
2023-04-06 15:30:57 +02:00
|
|
|
def gpt_random_prompt(rng):
|
|
|
|
|
return [
|
|
|
|
|
"So",
|
|
|
|
|
"Once upon a time",
|
|
|
|
|
"When",
|
|
|
|
|
"The",
|
|
|
|
|
"After",
|
|
|
|
|
"If",
|
|
|
|
|
"import",
|
|
|
|
|
"He",
|
|
|
|
|
"She",
|
|
|
|
|
"They",
|
|
|
|
|
][rng % 10]
|
|
|
|
|
|
2024-07-09 12:20:17 -04:00
|
|
|
|
2023-04-06 15:30:57 +02:00
|
|
|
if __name__ == "__main__":
|
2023-05-04 18:33:08 +02:00
|
|
|
print(gpt_params_parse())
|