2023-11-03 02:12:14 -04:00
from __future__ import annotations
2023-11-06 09:07:27 -05:00
import os
2024-07-09 14:06:46 -04:00
import sys
2023-11-10 02:51:58 -05:00
import json
2023-11-08 04:48:51 +01:00
import ctypes
2023-09-29 19:52:04 -04:00
import dataclasses
2024-02-08 09:07:03 +08:00
import random
import string
2024-04-30 01:35:38 -04:00
from contextlib import ExitStack
2024-07-09 12:20:17 -04:00
from typing import (
Any ,
Dict ,
Iterator ,
List ,
Literal ,
Optional ,
Tuple ,
Union ,
Protocol ,
cast ,
)
2023-11-06 09:07:27 -05:00
2024-01-18 21:21:37 -05:00
import jinja2
2024-05-10 12:47:56 +08:00
from jinja2 . sandbox import ImmutableSandboxedEnvironment
2024-01-18 21:21:37 -05:00
2024-04-20 00:00:53 -04:00
import numpy as np
import numpy . typing as npt
2025-07-03 01:57:43 -04:00
import llama_cpp . llama_cpp as llama_cpp
2023-11-08 04:48:51 +01:00
import llama_cpp . llama as llama
2023-11-08 00:07:16 -05:00
import llama_cpp . llama_types as llama_types
import llama_cpp . llama_grammar as llama_grammar
2023-11-03 02:12:14 -04:00
2024-02-23 18:40:52 +09:00
from . _logger import logger
2024-01-18 21:21:37 -05:00
from . _utils import suppress_stdout_stderr , Singleton
2023-11-08 11:05:45 -05:00
2024-01-29 14:22:23 -05:00
### Common Chat Templates and Special Tokens ###
# NOTE: the *_CHAT_TEMPLATE strings in this section are compared with `==`
# against a GGUF file's "tokenizer.chat_template" metadata in
# guess_chat_format_from_gguf_metadata, so any edit to them (even whitespace)
# changes which models are auto-detected.
# Source: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
CHATML_CHAT_TEMPLATE = " { % f or message in messages % } {{ ' <|im_start|> ' + message[ ' role ' ] + ' \n ' + message[ ' content ' ] + ' <|im_end|> ' + ' \n ' }} { % e ndfor % } { % i f add_generation_prompt % } {{ ' <|im_start|>assistant \n ' }} { % e ndif % } "
# ChatML special tokens (detected as the "chatml" format).
CHATML_BOS_TOKEN = " <s> "
CHATML_EOS_TOKEN = " <|im_end|> "
# Source: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
# Detected as the "mistral-instruct" format.
MISTRAL_INSTRUCT_CHAT_TEMPLATE = " {{ bos_token }} { % f or message in messages % } { % i f (message[ ' role ' ] == ' user ' ) != (loop.index0 % 2 == 0) % } {{ raise_exception( ' Conversation roles must alternate user/assistant/user/assistant/... ' ) }} { % e ndif % } { % i f message[ ' role ' ] == ' user ' % } {{ ' [INST] ' + message[ ' content ' ] + ' [/INST] ' }} { % e lif message[ ' role ' ] == ' assistant ' % } {{ message[ ' content ' ] + eos_token + ' ' }} { % e lse % } {{ raise_exception( ' Only user and assistant roles are supported! ' ) }} { % e ndif % } { % e ndfor % } "
MISTRAL_INSTRUCT_BOS_TOKEN = " <s> "
MISTRAL_INSTRUCT_EOS_TOKEN = " </s> "
2024-02-23 16:27:38 +00:00
# Source: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json
# guess_chat_format_from_gguf_metadata maps this template to the same
# "mistral-instruct" format as MISTRAL_INSTRUCT_CHAT_TEMPLATE.
MIXTRAL_INSTRUCT_CHAT_TEMPLATE = " {{ bos_token }} { % f or message in messages % } { % i f (message[ ' role ' ] == ' user ' ) != (loop.index0 % 2 == 0) % } {{ raise_exception( ' Conversation roles must alternate user/assistant/user/assistant/... ' ) }} { % e ndif % } { % i f message[ ' role ' ] == ' user ' % } {{ ' [INST] ' + message[ ' content ' ] + ' [/INST] ' }} { % e lif message[ ' role ' ] == ' assistant ' % } {{ message[ ' content ' ] + eos_token}} { % e lse % } {{ raise_exception( ' Only user and assistant roles are supported! ' ) }} { % e ndif % } { % e ndfor % } "
2024-01-29 14:22:23 -05:00
2024-04-23 06:33:29 +00:00
# Source: https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
# guess_chat_format_from_gguf_metadata maps this template to the "llama-3" format.
LLAMA3_INSTRUCT_CHAT_TEMPLATE = " { % s et loop_messages = messages % } { % f or message in loop_messages % } { % s et content = ' <|start_header_id|> ' + message[ ' role ' ] + ' <|end_header_id|> \n \n ' + message[ ' content ' ] | trim + ' <|eot_id|> ' % } { % i f loop.index0 == 0 % } { % s et content = bos_token + content % } { % e ndif % } {{ content }} { % e ndfor % } { % i f add_generation_prompt % } {{ ' <|start_header_id|>assistant<|end_header_id|> \n \n ' }} { % e ndif % } "
2024-01-29 14:22:23 -05:00
### Chat Completion Handler ###
2023-11-03 02:12:14 -04:00
2024-02-12 15:56:07 -05:00
2023-11-03 02:12:14 -04:00
class LlamaChatCompletionHandler(Protocol):
    """Structural interface that every chat completion handler satisfies.

    A handler is called with a ``llama.Llama`` instance plus OpenAI-style chat
    request parameters (all keyword-only) and may implement any chat format.
    The one hard requirement is its return type: a complete chat completion
    response when ``stream`` is False, or an iterator of chat completion stream
    chunks when ``stream`` is True.
    """

    def __call__(
        self,
        *,
        # The llama.cpp model that runs the completion.
        llama: llama.Llama,
        # OpenAI chat completion API parameters.
        messages: List[llama_types.ChatCompletionRequestMessage],
        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
        temperature: float = 0.2,
        top_p: float = 0.95,
        top_k: int = 40,
        stream: bool = False,
        stop: Optional[Union[str, List[str]]] = [],
        seed: Optional[int] = None,
        response_format: Optional[
            llama_types.ChatCompletionRequestResponseFormat
        ] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        model: Optional[str] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        # llama.cpp-specific sampling parameters.
        min_p: float = 0.05,
        typical_p: float = 1.0,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_tau: float = 5.0,
        mirostat_eta: float = 0.1,
        logits_processor: Optional[llama.LogitsProcessorList] = None,
        grammar: Optional[llama.LlamaGrammar] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        **kwargs,  # type: ignore
    ) -> Union[
        llama_types.CreateChatCompletionResponse,
        Iterator[llama_types.CreateChatCompletionStreamResponse],
    ]:
        ...
2023-11-03 02:12:14 -04:00
2024-01-18 21:21:37 -05:00
class LlamaChatCompletionHandlerNotFoundException(Exception):
    """Raised when no chat completion handler is registered under a requested name."""
class LlamaChatCompletionHandlerRegistry(Singleton):
    """Registry mapping names to chat completion handlers.

    Handlers are stored in a class-level dict, so every instance of the
    registry reads and writes the same registrations.
    """

    # Class attribute: shared by all registry instances.
    _chat_handlers: Dict[str, LlamaChatCompletionHandler] = {}

    def register_chat_completion_handler(
        self,
        name: str,
        chat_handler: LlamaChatCompletionHandler,
        overwrite: bool = False,
    ):
        """Register ``chat_handler`` under ``name``.

        Raises:
            ValueError: if ``name`` is already registered and ``overwrite`` is False.
        """
        if not overwrite and name in self._chat_handlers:
            raise ValueError(
                f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
            )
        self._chat_handlers[name] = chat_handler

    def unregister_chat_handler(self, name: str):
        """Remove the handler registered under ``name``.

        Raises:
            ValueError: if no handler is registered under ``name``.
        """
        try:
            del self._chat_handlers[name]
        except KeyError:
            # The KeyError adds nothing beyond the name already in the message.
            raise ValueError(
                f"No formatter registered under the name '{name}'."
            ) from None

    def get_chat_completion_handler_by_name(
        self, name: str
    ) -> LlamaChatCompletionHandler:
        """Return the handler registered under ``name``.

        Raises:
            LlamaChatCompletionHandlerNotFoundException: if ``name`` is not
                registered; the message lists the registered names.
        """
        try:
            return self._chat_handlers[name]
        except KeyError:
            # Suppress the implicit "During handling of the above exception"
            # KeyError context; the message already names the missing key.
            raise LlamaChatCompletionHandlerNotFoundException(
                f"Invalid chat handler: {name} (valid formats: {list(self._chat_handlers.keys())})"
            ) from None
2023-11-03 02:12:14 -04:00
def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler:
    """Return the chat completion handler registered under ``name``.

    Delegates to the registry, which raises
    LlamaChatCompletionHandlerNotFoundException for unknown names.
    """
    registry = LlamaChatCompletionHandlerRegistry()
    return registry.get_chat_completion_handler_by_name(name)
2023-11-03 02:12:14 -04:00
def register_chat_completion_handler(name: str):
    """Decorator factory that registers the decorated handler under ``name``.

    The handler is returned unchanged, so it stays directly callable.
    Registration uses the registry's default ``overwrite=False``, so reusing
    an existing name raises ValueError.
    """

    def _register(handler: LlamaChatCompletionHandler):
        LlamaChatCompletionHandlerRegistry().register_chat_completion_handler(
            name, handler
        )
        return handler

    return _register
2023-09-29 19:52:04 -04:00
2024-01-18 21:21:37 -05:00
### Chat Formatter ###
2024-02-12 15:56:07 -05:00
2024-01-18 21:21:37 -05:00
@dataclasses.dataclass
class ChatFormatterResponse:
    """Completion parameters produced by a chat format for one request.

    Attributes:
        prompt: the prompt text produced by rendering the messages with the
            chat format.
        stop: stop string or list of stop strings for the format, if any.
        stopping_criteria: optional token-level stopping criteria for the format.
        added_special: whether ``prompt`` already contains the format's special
            tokens; chat_formatter_to_chat_completion_handler tokenizes with
            ``add_bos=not added_special``.
    """

    prompt: str
    stop: Optional[Union[str, List[str]]] = None
    stopping_criteria: Optional[llama.StoppingCriteriaList] = None
    added_special: bool = False
2024-01-18 21:21:37 -05:00
class ChatFormatter(Protocol):
    """Structural interface for chat formatters.

    A chat formatter turns a list of chat messages into a
    ChatFormatterResponse: the prompt to complete, optionally with the stop
    string(s) and stopping criteria the completion should use.
    """

    def __call__(
        self,
        *,
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        ...
2024-01-18 21:21:37 -05:00
2024-01-19 15:04:42 -05:00
class Jinja2ChatFormatter(ChatFormatter):
    """Chat formatter that renders a jinja2 chat template into a prompt.

    The template is compiled once in an ``ImmutableSandboxedEnvironment``
    (``trim_blocks`` and ``lstrip_blocks`` enabled) and rendered on each call.
    """

    def __init__(
        self,
        template: str,
        eos_token: str,
        bos_token: str,
        add_generation_prompt: bool = True,
        stop_token_ids: Optional[List[int]] = None,
    ):
        """A chat formatter that uses jinja2 templates to format the prompt.

        Args:
            template: jinja2 template source.
            eos_token: exposed to the template as ``eos_token``; also returned
                as the stop string of every response.
            bos_token: exposed to the template as ``bos_token``.
            add_generation_prompt: exposed to the template as
                ``add_generation_prompt``.
            stop_token_ids: optional token ids, stored as a set. When given,
                each response carries a stopping criterion that fires when the
                last entry of the token array it receives is one of these ids.
        """
        self.template = template
        self.eos_token = eos_token
        self.bos_token = bos_token
        self.add_generation_prompt = add_generation_prompt
        if stop_token_ids is None:
            self.stop_token_ids = None
        else:
            self.stop_token_ids = set(stop_token_ids)

        sandbox = ImmutableSandboxedEnvironment(
            loader=jinja2.BaseLoader(),
            trim_blocks=True,
            lstrip_blocks=True,
        )
        self._environment = sandbox.from_string(self.template)

    def __call__(
        self,
        *,
        messages: List[llama_types.ChatCompletionRequestMessage],
        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        """Render ``messages`` (plus any function/tool definitions) into a prompt.

        Extra keyword arguments are accepted and ignored. Templates may call
        ``raise_exception(msg)``, which raises ``ValueError(msg)``.
        """

        def raise_exception(message: str):
            raise ValueError(message)

        template_vars = dict(
            messages=messages,
            eos_token=self.eos_token,
            bos_token=self.bos_token,
            raise_exception=raise_exception,
            add_generation_prompt=self.add_generation_prompt,
            functions=functions,
            function_call=function_call,
            tools=tools,
            tool_choice=tool_choice,
        )
        prompt = self._environment.render(**template_vars)

        if self.stop_token_ids is None:
            stopping_criteria = None
        else:

            def stop_on_last_token(
                tokens: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
            ) -> bool:
                # self.stop_token_ids is looked up at call time, not captured here.
                return tokens[-1] in self.stop_token_ids

            stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])

        return ChatFormatterResponse(
            prompt=prompt,
            stop=[self.eos_token],
            stopping_criteria=stopping_criteria,
            added_special=True,
        )

    def to_chat_handler(self) -> LlamaChatCompletionHandler:
        """Wrap this formatter in a full chat completion handler."""
        return chat_formatter_to_chat_completion_handler(self)
2024-01-05 00:12:02 +01:00
2023-09-30 21:01:34 -04:00
2024-12-06 12:35:46 +00:00
def _convert_text_completion_logprobs_to_chat (
logprobs : Optional [ llama_types . CompletionLogprobs ] ,
) - > llama_types . ChatCompletionLogprobs :
if logprobs is None :
return None
return {
" content " : [
{
" token " : token ,
" bytes " : None ,
" logprob " : logprob ,
" top_logprobs " : [
{
" token " : top_token ,
" logprob " : top_logprob ,
" bytes " : None ,
}
for top_token , top_logprob in top_logprobs . items ( )
] ,
} for ( token , logprob , top_logprobs ) in zip ( logprobs [ " tokens " ] , logprobs [ " token_logprobs " ] , logprobs [ " top_logprobs " ] )
] ,
" refusal " : None ,
}
2023-11-03 02:12:14 -04:00
def _convert_text_completion_to_chat(
    completion: llama_types.Completion,
) -> llama_types.ChatCompletion:
    """Wrap a non-streaming text completion as a chat completion.

    Only the first choice is used: its text becomes the assistant message
    content, and its logprobs are converted to the chat shape. The id is
    prefixed with "chat" and ``usage`` is copied through.
    """
    assert "usage" in completion
    first_choice = completion["choices"][0]
    assistant_message = {
        "role": "assistant",
        "content": first_choice["text"],
    }
    chat_choice = {
        "index": 0,
        "message": assistant_message,
        "logprobs": _convert_text_completion_logprobs_to_chat(
            first_choice["logprobs"]
        ),
        "finish_reason": first_choice["finish_reason"],
    }
    return {
        "id": "chat" + completion["id"],
        "object": "chat.completion",
        "created": completion["created"],
        "model": completion["model"],
        "choices": [chat_choice],
        "usage": completion["usage"],
    }
def _convert_text_completion_chunks_to_chat(
    chunks: Iterator[llama_types.CreateCompletionStreamResponse],
) -> Iterator[llama_types.ChatCompletionChunk]:
    """Translate streamed text-completion chunks into chat-completion chunks.

    Before the first input chunk's content, an extra chunk announcing the
    "assistant" role is emitted. Each input chunk then yields one chunk whose
    delta carries its text, except that a chunk with a finish_reason gets an
    empty delta.
    """
    role_announced = False
    for chunk in chunks:
        if not role_announced:
            role_announced = True
            yield {
                "id": "chat" + chunk["id"],
                "model": chunk["model"],
                "created": chunk["created"],
                "object": "chat.completion.chunk",
                "choices": [
                    {
                        "index": 0,
                        "delta": {
                            "role": "assistant",
                        },
                        "logprobs": None,
                        "finish_reason": None,
                    }
                ],
            }

        choice = chunk["choices"][0]
        is_final = choice["finish_reason"] is not None
        delta = {} if is_final else {"content": choice["text"]}
        yield {
            "id": "chat" + chunk["id"],
            "model": chunk["model"],
            "created": chunk["created"],
            "object": "chat.completion.chunk",
            "choices": [
                {
                    "index": 0,
                    "delta": delta,
                    "logprobs": _convert_text_completion_logprobs_to_chat(
                        choice["logprobs"]
                    ),
                    "finish_reason": choice["finish_reason"],
                }
            ],
        }
def _convert_completion_to_chat(
    completion_or_chunks: Union[
        llama_types.CreateCompletionResponse,
        Iterator[llama_types.CreateCompletionStreamResponse],
    ],
    stream: bool = False,
) -> Union[
    llama_types.CreateChatCompletionResponse, Iterator[llama_types.ChatCompletionChunk]
]:
    """Convert a text completion (or stream of chunks) to the chat shape.

    ``stream`` selects how ``completion_or_chunks`` is interpreted: a chunk
    iterator when truthy, a single completion otherwise.
    """
    if not stream:
        completion: llama_types.Completion = completion_or_chunks  # type: ignore
        return _convert_text_completion_to_chat(completion)
    chunks: Iterator[llama_types.CreateCompletionStreamResponse] = completion_or_chunks  # type: ignore
    return _convert_text_completion_chunks_to_chat(chunks)
2024-03-19 04:55:57 -04:00
def _convert_completion_to_chat_function(
    tool_name: str,
    completion_or_chunks: Union[
        llama_types.CreateCompletionResponse,
        Iterator[llama_types.CreateCompletionStreamResponse],
    ],
    stream: bool,
):
    """Present a text completion as a call to the tool ``tool_name``.

    The generated text is used verbatim as the call's arguments and is
    reported both as a legacy ``function_call`` and as a single ``tool_calls``
    entry.

    Non-streaming: returns one chat completion with finish_reason "tool_calls".
    Streaming: returns a generator that emits a role-only chunk first, then one
    argument-delta chunk per input chunk, and finally a closing chunk with
    finish_reason "tool_calls" (only when the first chunk supplied non-None
    created/model values).
    """
    if stream:
        chunks: Iterator[llama_types.CreateCompletionStreamResponse] = completion_or_chunks  # type: ignore

        def _arguments_chunk(chunk, tool_id):
            # One streamed piece of the tool-call arguments.
            text = chunk["choices"][0]["text"]
            return {
                "id": "chat" + chunk["id"],
                "object": "chat.completion.chunk",
                "created": chunk["created"],
                "model": chunk["model"],
                "choices": [
                    {
                        "index": 0,
                        "finish_reason": None,
                        "logprobs": _convert_text_completion_logprobs_to_chat(
                            chunk["choices"][0]["logprobs"]
                        ),
                        "delta": {
                            "role": None,
                            "content": None,
                            "function_call": {
                                "name": tool_name,
                                "arguments": text,
                            },
                            "tool_calls": [
                                {
                                    "index": 0,
                                    "id": tool_id,
                                    "type": "function",
                                    "function": {
                                        "name": tool_name,
                                        "arguments": text,
                                    },
                                }
                            ],
                        },
                    }
                ],
            }

        def _stream_response_to_function_stream(
            chunks: Iterator[llama_types.CreateCompletionStreamResponse],
        ) -> Iterator[llama_types.CreateChatCompletionStreamResponse]:
            is_first = True
            id_ = None
            created = None
            model = None
            tool_id = None
            for chunk in chunks:
                if is_first:
                    is_first = False
                    id_ = "chat" + chunk["id"]
                    created = chunk["created"]
                    model = chunk["model"]
                    tool_id = "call_" + "_0_" + tool_name + "_" + chunk["id"]
                    # Opening chunk: announces the assistant role, no payload yet.
                    yield {
                        "id": id_,
                        "object": "chat.completion.chunk",
                        "created": created,
                        "model": model,
                        "choices": [
                            {
                                "index": 0,
                                "finish_reason": None,
                                "logprobs": None,
                                "delta": {
                                    "role": "assistant",
                                    "content": None,
                                    "function_call": None,
                                    "tool_calls": None,
                                },
                            }
                        ],
                    }
                assert tool_id is not None
                yield _arguments_chunk(chunk, tool_id)

            if id_ is not None and created is not None and model is not None:
                # Closing chunk: signals that the tool call is complete.
                yield {
                    "id": id_,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model,
                    "choices": [
                        {
                            "index": 0,
                            "finish_reason": "tool_calls",
                            "logprobs": None,
                            "delta": {
                                "role": None,
                                "content": None,
                                "function_call": None,
                                "tool_calls": None,
                            },
                        }
                    ],
                }

        return _stream_response_to_function_stream(chunks)

    completion: llama_types.CreateCompletionResponse = completion_or_chunks  # type: ignore
    assert "usage" in completion
    arguments = completion["choices"][0]["text"]
    tool_id = "call_" + "_0_" + tool_name + "_" + completion["id"]
    # TODO: Fix for legacy function calls
    assistant_message = {
        "role": "assistant",
        "content": None,
        "function_call": {
            "name": tool_name,
            "arguments": arguments,
        },
        "tool_calls": [
            {
                "id": tool_id,
                "type": "function",
                "function": {
                    "name": tool_name,
                    "arguments": arguments,
                },
            }
        ],
    }
    chat_completion: llama_types.CreateChatCompletionResponse = {
        "id": "chat" + completion["id"],
        "object": "chat.completion",
        "created": completion["created"],
        "model": completion["model"],
        "choices": [
            {
                "index": 0,
                "message": assistant_message,
                "logprobs": _convert_text_completion_logprobs_to_chat(
                    completion["choices"][0]["logprobs"]
                ),
                "finish_reason": "tool_calls",
            }
        ],
        "usage": completion["usage"],
    }
    return chat_completion
2024-01-18 21:21:37 -05:00
def _tools_from_legacy_function_params(
    functions: Optional[List[llama_types.ChatCompletionFunction]],
    function_call: Optional[llama_types.ChatCompletionRequestFunctionCall],
    tools: Optional[List[llama_types.ChatCompletionTool]],
    tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption],
) -> Tuple[
    Optional[List[llama_types.ChatCompletionTool]],
    Optional[llama_types.ChatCompletionToolChoiceOption],
]:
    """Translate the legacy ``functions``/``function_call`` request fields.

    ``functions`` (when given) replaces ``tools`` with one "function" tool per
    entry. ``function_call`` (when given) overrides ``tool_choice``: the
    strings "none"/"auto" are copied through, and ``{"name": ...}`` becomes a
    forced choice of that function. Returns the resulting ``(tools, tool_choice)``.
    """
    if functions is not None:
        tools = [
            {
                "type": "function",
                "function": function,
            }
            for function in functions
        ]
    if function_call is not None:
        if isinstance(function_call, str) and function_call in ("none", "auto"):
            tool_choice = function_call
        if isinstance(function_call, dict) and "name" in function_call:
            tool_choice = {
                "type": "function",
                "function": {
                    "name": function_call["name"],
                },
            }
    return tools, tool_choice


def _grammar_for_forced_tool(
    tool: llama_types.ChatCompletionTool, verbose: bool
) -> llama_grammar.LlamaGrammar:
    """Build a grammar constraining output to the tool's parameter JSON schema.

    Falls back to the generic JSON grammar when the schema cannot be compiled
    (the error is printed to stderr when ``verbose``).
    """
    schema = tool["function"]["parameters"]
    try:
        return llama_grammar.LlamaGrammar.from_json_schema(
            json.dumps(schema), verbose=verbose
        )
    except Exception as e:
        # Best-effort: an unsupported schema degrades to "any JSON".
        if verbose:
            print(str(e), file=sys.stderr)
        return llama_grammar.LlamaGrammar.from_string(
            llama_grammar.JSON_GBNF, verbose=verbose
        )


def chat_formatter_to_chat_completion_handler(
    chat_formatter: ChatFormatter,
) -> LlamaChatCompletionHandler:
    """Adapt a ChatFormatter into a full LlamaChatCompletionHandler.

    The returned handler:
      1. renders the request with ``chat_formatter`` and tokenizes the prompt
         with ``special=True``, adding BOS only when the formatter reports it
         did not already emit its special tokens (``added_special``);
      2. appends the formatter's stop string(s) to the caller's ``stop``;
      3. constrains output with a grammar for ``response_format`` of type
         "json_object", or for a specifically chosen tool;
      4. runs ``llama.create_completion`` and converts the result into a chat
         completion, or into a tool call when a specific tool was chosen.
    """

    def chat_completion_handler(
        *,
        llama: llama.Llama,
        messages: List[llama_types.ChatCompletionRequestMessage],
        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
        temperature: float = 0.2,
        top_p: float = 0.95,
        top_k: int = 40,
        min_p: float = 0.05,
        typical_p: float = 1.0,
        stream: bool = False,
        stop: Optional[Union[str, List[str]]] = [],
        seed: Optional[int] = None,
        response_format: Optional[
            llama_types.ChatCompletionRequestResponseFormat
        ] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_tau: float = 5.0,
        mirostat_eta: float = 0.1,
        model: Optional[str] = None,
        logits_processor: Optional[llama.LogitsProcessorList] = None,
        grammar: Optional[llama.LlamaGrammar] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        **kwargs,  # type: ignore
    ) -> Union[
        llama_types.CreateChatCompletionResponse,
        Iterator[llama_types.CreateChatCompletionStreamResponse],
    ]:
        result = chat_formatter(
            messages=messages,
            functions=functions,
            function_call=function_call,
            tools=tools,
            tool_choice=tool_choice,
        )
        prompt = llama.tokenize(
            result.prompt.encode("utf-8"),
            add_bos=not result.added_special,
            special=True,
        )
        if result.stop is not None:
            # Normalize both sides to lists (without mutating the caller's list)
            # and append the format's stop strings.
            stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
            rstop = result.stop if isinstance(result.stop, list) else [result.stop]
            stop = stop + rstop

        stopping_criteria = result.stopping_criteria

        if response_format is not None and response_format["type"] == "json_object":
            grammar = _grammar_for_response_format(
                response_format, verbose=llama.verbose
            )

        tools, tool_choice = _tools_from_legacy_function_params(
            functions, function_call, tools, tool_choice
        )

        # A dict tool_choice forces one specific tool: constrain the output to
        # that tool's parameter schema.
        tool = None
        if (
            tool_choice is not None
            and isinstance(tool_choice, dict)
            and tools is not None
        ):
            name = tool_choice["function"]["name"]
            tool = next((t for t in tools if t["function"]["name"] == name), None)
            if tool is None:
                raise ValueError(f"Tool choice '{name}' not found in tools.")
            grammar = _grammar_for_forced_tool(tool, verbose=llama.verbose)

        # OpenAI semantics: logprobs=True alone asks for the log probabilities
        # of the sampled tokens; top_logprobs only adds alternatives. Previously
        # a missing top_logprobs turned into logprobs=None, dropping them.
        # NOTE(review): assumes create_completion treats logprobs=None as
        # "no logprobs" and an int N as "N alternatives per token" — confirm
        # against llama.Llama.create_completion.
        completion_logprobs = (top_logprobs or 0) if logprobs else None

        completion_or_chunks = llama.create_completion(
            prompt=prompt,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            typical_p=typical_p,
            logprobs=completion_logprobs,
            stream=stream,
            stop=stop,
            seed=seed,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            repeat_penalty=repeat_penalty,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
            model=model,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            grammar=grammar,
            logit_bias=logit_bias,
        )
        if tool is not None:
            tool_name = tool["function"]["name"]
            return _convert_completion_to_chat_function(
                tool_name, completion_or_chunks, stream
            )
        return _convert_completion_to_chat(completion_or_chunks, stream=stream)

    return chat_completion_handler
2023-09-29 19:52:04 -04:00
2023-11-08 04:48:51 +01:00
def hf_autotokenizer_to_chat_formatter(
    pretrained_model_name_or_path: Union[str, os.PathLike[str]],
    add_generation_prompt: bool = True,
) -> ChatFormatter:
    """Build a ChatFormatter from a Hugging Face tokenizer's chat template.

    ``transformers`` is imported lazily, so it is only required when this
    function is called.

    Args:
        pretrained_model_name_or_path: passed to ``AutoTokenizer.from_pretrained``.
        add_generation_prompt: forwarded to ``tokenizer.apply_chat_template``.
            When True, the prompt ends with the template's assistant-turn
            header, so the model answers as the assistant instead of having to
            generate that header itself. Defaults to True, matching
            ``hf_tokenizer_config_to_chat_formatter``.

    Returns:
        A formatter whose responses use the tokenizer's ``eos_token`` as the
        stop string and set ``added_special=True``.
    """
    # https://huggingface.co/docs/transformers/main/chat_templating
    # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1#instruction-format
    # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
    from transformers import AutoTokenizer  # type: ignore

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)  # type: ignore
    # Constant setting: apply it once instead of on every formatting call.
    tokenizer.use_default_system_prompt = False  # type: ignore

    def format_autotokenizer(
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        prompt: str = tokenizer.apply_chat_template(  # type: ignore
            messages,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
        )
        assert isinstance(prompt, str)
        # Return formatted prompt and eos token by default
        return ChatFormatterResponse(
            prompt=prompt, stop=tokenizer.eos_token, added_special=True
        )

    return format_autotokenizer
2024-01-18 21:21:37 -05:00
def hf_autotokenizer_to_chat_completion_handler(
    pretrained_model_name_or_path: Union[str, os.PathLike[str]]
) -> LlamaChatCompletionHandler:
    """Build a chat completion handler from a Hugging Face tokenizer's chat template."""
    return chat_formatter_to_chat_completion_handler(
        hf_autotokenizer_to_chat_formatter(pretrained_model_name_or_path)
    )
2024-01-19 15:04:42 -05:00
def hf_tokenizer_config_to_chat_formatter(
    tokenizer_config: Dict[str, Any],
    add_generation_prompt: bool = True,
) -> ChatFormatter:
    """Build a ChatFormatter from a parsed Hugging Face ``tokenizer_config.json``.

    The config must contain string ``chat_template``, ``bos_token`` and
    ``eos_token`` entries (checked with ``assert``). The template is compiled
    in an ImmutableSandboxedEnvironment.

    Args:
        tokenizer_config: the parsed tokenizer config dict.
        add_generation_prompt: passed to the template as
            ``add_generation_prompt``, so templates that support it append the
            assistant-turn header.

    Returns:
        A formatter whose responses stop on ``[eos_token, bos_token]`` and set
        ``added_special=True``.
    """
    assert isinstance(tokenizer_config, dict)

    assert "chat_template" in tokenizer_config
    assert isinstance(tokenizer_config["chat_template"], str)
    chat_template = tokenizer_config["chat_template"]

    assert "bos_token" in tokenizer_config
    assert isinstance(tokenizer_config["bos_token"], str)
    bos_token = tokenizer_config["bos_token"]

    assert "eos_token" in tokenizer_config
    assert isinstance(tokenizer_config["eos_token"], str)
    eos_token = tokenizer_config["eos_token"]

    env = ImmutableSandboxedEnvironment(
        trim_blocks=True,
        lstrip_blocks=True,
    ).from_string(chat_template)

    def raise_exception(message: str):
        # Templates may reject invalid conversations this way (e.g. the Mistral
        # template in this module); surface it as ValueError.
        raise ValueError(message)

    def format_tokenizer_config(
        messages: List[llama_types.ChatCompletionRequestMessage],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        # Let the template open the assistant turn itself. The previous approach,
        # appending an empty assistant message, made templates *close* that
        # turn: e.g. the Mistral template renders content + eos_token, so the
        # prompt ended in EOS, and ChatML emitted "<|im_end|>".
        prompt = env.render(
            messages=messages,
            bos_token=bos_token,
            eos_token=eos_token,
            add_generation_prompt=add_generation_prompt,
            raise_exception=raise_exception,
        )
        return ChatFormatterResponse(
            prompt=prompt, stop=[eos_token, bos_token], added_special=True
        )

    return format_tokenizer_config
2024-01-18 21:21:37 -05:00
def hf_tokenizer_config_to_chat_completion_handler(
    tokenizer_config: Dict[str, Any],
    add_generation_prompt: bool = True,
) -> LlamaChatCompletionHandler:
    """Build a chat completion handler from a parsed HF tokenizer config.

    ``add_generation_prompt`` is forwarded to the underlying chat formatter.
    """
    formatter = hf_tokenizer_config_to_chat_formatter(
        tokenizer_config, add_generation_prompt=add_generation_prompt
    )
    return chat_formatter_to_chat_completion_handler(formatter)
2024-01-29 14:22:23 -05:00
def guess_chat_format_from_gguf_metadata(metadata: Dict[str, str]) -> Optional[str]:
    """Map a GGUF file's embedded chat template to a known chat format name.

    The value stored under ``tokenizer.chat_template`` is compared exactly
    against the bundled templates. Returns None when the key is missing or no
    bundled template matches.
    """
    if "tokenizer.chat_template" not in metadata:
        return None
    template = metadata["tokenizer.chat_template"]

    # Checked in order; Mistral and Mixtral share one format name.
    known_formats = (
        (CHATML_CHAT_TEMPLATE, "chatml"),
        (MISTRAL_INSTRUCT_CHAT_TEMPLATE, "mistral-instruct"),
        (MIXTRAL_INSTRUCT_CHAT_TEMPLATE, "mistral-instruct"),
        (LLAMA3_INSTRUCT_CHAT_TEMPLATE, "llama-3"),
    )
    for known_template, format_name in known_formats:
        if template == known_template:
            return format_name
    return None
2024-02-12 15:56:07 -05:00
2024-01-19 15:04:42 -05:00
### Utility functions for formatting chat prompts ###
2024-01-29 14:22:23 -05:00
# TODO: Replace these with jinja2 templates
2024-01-19 15:04:42 -05:00
def _get_system_message (
messages : List [ llama_types . ChatCompletionRequestMessage ] ,
) - > str :
""" Get the first system message. """
for message in messages :
if message [ " role " ] == " system " :
return message [ " content " ] or " "
return " "
def _map_roles (
messages : List [ llama_types . ChatCompletionRequestMessage ] ,
role_map : Dict [ str , str ] ,
) - > List [ Tuple [ str , Optional [ str ] ] ] :
""" Map the message roles. """
output : List [ Tuple [ str , Optional [ str ] ] ] = [ ]
for message in messages :
role = message [ " role " ]
if role in role_map :
content : str | None = (
message [ " content " ] if isinstance ( message [ " content " ] , str ) else None
)
output . append ( ( role_map [ role ] , content ) )
return output
def _format_llama2 (
system_message : str , messages : List [ Tuple [ str , Optional [ str ] ] ] , sep : str , sep2 : str
) - > str :
""" Format the prompt with the llama2 style. """
seps = [ sep , sep2 ]
ret = system_message + sep
for i , ( role , message ) in enumerate ( messages ) :
if system_message and i == 0 :
m = message or " "
ret + = m + seps [ i % 2 ]
elif message :
ret + = role + message + " " + seps [ i % 2 ]
else :
ret + = role + " "
return ret
def _format_add_colon_single (
system_message : str , messages : List [ Tuple [ str , Optional [ str ] ] ] , sep : str
) - > str :
""" Format the prompt with the add-colon-single style. """
ret = system_message + sep
for role , message in messages :
if message :
ret + = role + " : " + message + sep
else :
ret + = role + " : "
return ret
def _format_add_colon_two (
system_message : str , messages : List [ Tuple [ str , Optional [ str ] ] ] , sep : str , sep2 : str
) - > str :
""" Format the prompt with the add-colon-two style. """
seps = [ sep , sep2 ]
ret = system_message + seps [ 0 ]
for i , ( role , message ) in enumerate ( messages ) :
if message :
ret + = role + " : " + message + seps [ i % 2 ]
else :
ret + = role + " : "
return ret
def _format_no_colon_single (
system_message : str , messages : List [ Tuple [ str , Optional [ str ] ] ] , sep : str
) - > str :
""" Format the prompt with the no-colon-single style. """
ret = system_message
for role , message in messages :
if message :
ret + = role + message + sep
else :
ret + = role
return ret
def _format_add_colon_space_single (
system_message : str , messages : List [ Tuple [ str , Optional [ str ] ] ] , sep : str
) - > str :
""" Format the prompt with the add-colon-space-single style. """
ret = system_message + sep
for role , message in messages :
if message :
ret + = role + " : " + message + sep
else :
ret + = role + " : " # must be end with a space
return ret
def _format_chatml (
system_message : str , messages : List [ Tuple [ str , Optional [ str ] ] ] , sep : str
) - > str :
""" Format the prompt with the chatml style. """
ret = " " if system_message == " " else system_message + sep + " \n "
for role , message in messages :
if message :
ret + = role + " \n " + message + sep + " \n "
else :
ret + = role + " \n "
return ret
def _format_chatglm3 (
system_message : str , messages : List [ Tuple [ str , Optional [ str ] ] ] , sep : str
) - > str :
""" Format the prompt with the chatglm3 style. """
ret = " "
if system_message :
ret + = system_message
for role , message in messages :
if message :
ret + = role + " \n " + " " + message
else :
ret + = role
return ret
2024-07-09 12:20:17 -04:00
def _grammar_for_json(verbose: bool = False):
    """Build a grammar from ``llama_grammar.JSON_GBNF``.

    This constrains output to generic JSON; ``verbose`` is forwarded to the
    grammar constructor.
    """
    grammar_cls = llama_grammar.LlamaGrammar
    return grammar_cls.from_string(llama_grammar.JSON_GBNF, verbose=verbose)
2024-03-15 12:58:34 -04:00
def _grammar_for_json_schema(
    schema: str, verbose: bool = False, fallback_to_json: bool = True
):
    """Build a grammar from a JSON-schema string.

    If conversion raises and ``fallback_to_json`` is true, the generic JSON
    grammar is returned instead. Otherwise the original exception propagates.
    """
    try:
        return llama_grammar.LlamaGrammar.from_json_schema(schema, verbose=verbose)
    except Exception:
        if not fallback_to_json:
            raise
        return _grammar_for_json(verbose=verbose)
2024-07-09 12:20:17 -04:00
2024-03-15 12:58:34 -04:00
def _grammar_for_response_format (
2024-07-09 12:20:17 -04:00
response_format : llama_types . ChatCompletionRequestResponseFormat ,
verbose : bool = False ,
2024-03-15 12:58:34 -04:00
) :
if response_format [ " type " ] != " json_object " :
return None
if " schema " in response_format :
return _grammar_for_json_schema (
json . dumps ( response_format [ " schema " ] ) , verbose = verbose
)
else :
return _grammar_for_json ( verbose = verbose )
2024-01-19 15:04:42 -05:00
2024-07-09 12:20:17 -04:00
2024-01-19 15:04:42 -05:00
### Chat Formats ###
def register_chat_format(name: str):
    """Return a decorator that registers a chat formatter under ``name``.

    The decorator wraps the formatter into a chat-completion handler and
    registers that handler with ``LlamaChatCompletionHandlerRegistry``. It
    returns the undecorated formatter, so the function stays directly callable.
    """

    def decorator(f: ChatFormatter):
        handler = chat_formatter_to_chat_completion_handler(f)
        registry = LlamaChatCompletionHandlerRegistry()
        registry.register_chat_completion_handler(name, handler)
        return f

    return decorator
2023-11-05 17:00:13 -05:00
# see https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py
# system prompt is "embedded" in the first message
@register_chat_format("llama-2")
def format_llama2(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages in the Llama-2 chat style.

    A system message, if any, is wrapped in ``<<SYS>>`` markers. It replaces
    the first turn's role prefix (see ``_format_llama2``). The prompt always
    ends with ``[/INST]``, which opens the assistant turn.
    """
    role_tags = {"user": "<s>[INST]", "assistant": "[/INST]"}
    turns = _map_roles(messages, role_tags)
    system_text = _get_system_message(messages)
    if system_text:
        system_text = "[INST] <<SYS>>\n{system_message}\n<</SYS>>".format(
            system_message=system_text
        )
    prompt = _format_llama2(system_text, turns, " ", "</s>") + "[/INST]"
    return ChatFormatterResponse(prompt=prompt)
2024-04-23 06:33:29 +00:00
# Chat format for Llama-3 models, see more details at:
# https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py#L202-L229
@register_chat_format("llama-3")
def format_llama3(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages in the Llama-3 header/eot style.

    Each system/user/assistant message becomes a role header followed by its
    content and ``<|eot_id|>``. An open assistant header is appended.
    Generation stops at ``<|eot_id|>``. The prompt itself contains no
    begin-of-text token.
    """
    header = "<|start_header_id|>{}<|end_header_id|>\n\n".format
    role_headers = {name: header(name) for name in ("system", "user", "assistant")}
    eot = "<|eot_id|>"
    turns = _map_roles(messages, role_headers)
    turns.append((role_headers["assistant"], None))
    prompt = _format_no_colon_single("", turns, eot)
    return ChatFormatterResponse(prompt=prompt, stop=eot)
2023-09-29 19:52:04 -04:00
@register_chat_format("alpaca")
def format_alpaca(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages in the Alpaca instruction style.

    The first system message (if any) is the preamble. It is followed by
    ``### Instruction: ...`` / ``### Response: ...`` turns, with separators
    alternating between a blank line and ``</s>``. An open
    ``### Response:`` turn is appended so the model answers the last
    instruction.
    """
    _roles = dict(user="### Instruction", assistant="### Response")
    _sep = "\n\n"
    _sep2 = "</s>"
    system_message = _get_system_message(messages)
    _messages = _map_roles(messages, _roles)
    # Like the other formats here, end with an empty assistant turn; without it
    # the prompt stops after the instruction and never cues a response.
    _messages.append((_roles["assistant"], None))
    _prompt = _format_add_colon_two(system_message, _messages, _sep, _sep2)
    return ChatFormatterResponse(prompt=_prompt)
2024-01-18 21:21:37 -05:00
2023-12-14 10:43:43 +08:00
@register_chat_format("qwen")
def format_qwen(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages in Qwen's ChatML dialect.

    The first system message is used, or "You are a helpful assistant." if
    none is given. It is followed by ChatML user/assistant blocks and an open
    ``<|im_start|>assistant`` turn. Generation stops at either the
    end-of-turn marker ``<|im_end|>`` or ``<|endoftext|>``.
    """
    _roles = dict(user="<|im_start|>user", assistant="<|im_start|>assistant")
    system_message = _get_system_message(messages) or "You are a helpful assistant."
    system_template = "<|im_start|>system\n{system_message}"
    system_message = system_template.format(system_message=system_message)
    _messages = _map_roles(messages, _roles)
    _messages.append((_roles["assistant"], None))
    _sep = "<|im_end|>"
    _prompt = _format_chatml(system_message, _messages, _sep)
    _sep2 = "<|endoftext|>"
    # Every turn in the prompt is closed with _sep, so also stop on it (as
    # format_chatml does); otherwise a model that emits it as text runs on into
    # a fabricated next turn.
    return ChatFormatterResponse(prompt=_prompt, stop=[_sep, _sep2])
2023-09-29 19:52:04 -04:00
@register_chat_format("vicuna")
def format(  # NOTE: shadows the builtin ``format`` at module scope; name kept for API compatibility.
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages in the Vicuna style.

    Always uses Vicuna's fixed preamble; system messages in ``messages`` are
    neither read nor mapped. Turns are ``USER: ...`` / ``ASSISTANT: ...``,
    with separators alternating between a space and ``</s>``, and end with
    an open ``ASSISTANT:`` turn.
    """
    preamble = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    )
    turns = _map_roles(messages, {"user": "USER", "assistant": "ASSISTANT"})
    turns.append(("ASSISTANT", None))
    prompt = _format_add_colon_two(preamble, turns, " ", "</s>")
    return ChatFormatterResponse(prompt=prompt)
@register_chat_format("oasst_llama")
def format_oasst_llama(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages in the OpenAssistant-on-Llama style.

    The system message goes inside a Llama-2 ``<<SYS>>`` block. The block is
    emitted even when the system message is empty. Turns are
    ``<|prompter|>`` / ``<|assistant|>`` tags followed by text and ``</s>``,
    ending with an open ``<|assistant|>`` tag.
    """
    system_text = "[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n".format(
        system_message=_get_system_message(messages)
    )
    turns = _map_roles(
        messages, {"user": "<|prompter|>", "assistant": "<|assistant|>"}
    )
    turns.append(("<|assistant|>", None))
    prompt = _format_no_colon_single(system_text, turns, "</s>")
    return ChatFormatterResponse(prompt=prompt)
2023-11-22 19:08:06 +08:00
@register_chat_format("baichuan-2")
def format_baichuan2(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages in the Baichuan-2 style.

    The system message is emitted as-is. Each turn is a reserved role token
    (``<reserved_106>`` for user, ``<reserved_107>`` for assistant) directly
    followed by its text with no separator. The prompt ends with an open
    assistant token.
    """
    role_tokens = {"user": "<reserved_106>", "assistant": "<reserved_107>"}
    system_text = "{system_message}".format(
        system_message=_get_system_message(messages)
    )
    turns = _map_roles(messages, role_tokens)
    turns.append((role_tokens["assistant"], None))
    prompt = _format_no_colon_single(system_text, turns, "")
    return ChatFormatterResponse(prompt=prompt)
2023-11-23 14:19:50 +08:00
@register_chat_format("baichuan")
def format_baichuan(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages in the original Baichuan style.

    This is the same layout as ``baichuan-2`` but uses ``<reserved_102>``
    (user) and ``<reserved_103>`` (assistant) role tokens.
    """
    role_tokens = {"user": "<reserved_102>", "assistant": "<reserved_103>"}
    system_text = "{system_message}".format(
        system_message=_get_system_message(messages)
    )
    turns = _map_roles(messages, role_tokens)
    turns.append((role_tokens["assistant"], None))
    prompt = _format_no_colon_single(system_text, turns, "")
    return ChatFormatterResponse(prompt=prompt)
2023-09-29 19:52:04 -04:00
@register_chat_format("openbuddy")
def format_openbuddy(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages in the OpenBuddy style.

    Always uses the fixed Buddy persona preamble below; system messages in
    ``messages`` are ignored. Turns are ``User: ...`` / ``Assistant: ...``,
    one per line, ending with an open ``Assistant:``.
    """
    preamble = """You are a helpful, respectful and honest INTP-T AI Assistant named Buddy. You are talking to a human User.
Always answer as helpfully and logically as possible, while being safe. Your answers should not include any harmful, political, religious, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
You can speak fluently in many languages, for example: English, Chinese.
You cannot access the internet, but you have vast knowledge, cutoff: 2021-09.
You are trained by OpenBuddy team, (https://openbuddy.ai, https://github.com/OpenBuddy/OpenBuddy), you are based on LLaMA and Falcon transformers model, not related to GPT or OpenAI.
"""
    turns = _map_roles(messages, {"user": "User", "assistant": "Assistant"})
    turns.append(("Assistant", None))
    prompt = _format_add_colon_single(preamble, turns, "\n")
    return ChatFormatterResponse(prompt=prompt)
@register_chat_format("redpajama-incite")
def format_redpajama_incite(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages in the RedPajama-INCITE chat style.

    The system message (possibly empty) is the preamble. It is followed by
    ``<human>: ...`` / ``<bot>: ...`` lines and an open ``<bot>:``.
    Generation stops when the model starts a new ``<human>`` turn.
    """
    preamble = _get_system_message(messages)
    turns = _map_roles(messages, {"user": "<human>", "assistant": "<bot>"})
    turns.append(("<bot>", None))
    prompt = _format_add_colon_single(preamble, turns, "\n")
    return ChatFormatterResponse(prompt=prompt, stop="<human>")
@register_chat_format("snoozy")
def format_snoozy(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages in the GPT4All "snoozy" style.

    The first system message, or a default instruction if none is given, is
    placed under a ``### Instruction:`` header. Turns follow as
    ``### Prompt: ...`` / ``### Response: ...`` lines, ending with an open
    ``### Response:``. Generation stops at the next ``###``.
    """
    system_template = "### Instruction:\n{system_message}"
    default_system_message = "The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response."
    _system_message = _get_system_message(messages) or default_system_message
    # Use the formatted template. Previously this value was overwritten by the
    # raw system text, which dropped the "### Instruction:" header.
    system_message = system_template.format(system_message=_system_message)
    _roles = dict(user="### Prompt", assistant="### Response")
    _sep = "\n"
    _stop = "###"
    _messages = _map_roles(messages, _roles)
    _messages.append((_roles["assistant"], None))
    _prompt = _format_add_colon_single(system_message, _messages, _sep)
    return ChatFormatterResponse(prompt=_prompt, stop=_stop)
@register_chat_format("phind")
def format_phind(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages in the Phind CodeLlama style.

    Always uses the fixed ``### System Prompt`` preamble; system messages in
    ``messages`` are ignored. Turns are ``### User Message: ...`` /
    ``### Assistant: ...`` separated by blank lines, ending with an open
    ``### Assistant:``.
    """
    preamble = "### System Prompt\nYou are an intelligent programming assistant."
    turns = _map_roles(
        messages, {"user": "### User Message", "assistant": "### Assistant"}
    )
    turns.append(("### Assistant", None))
    prompt = _format_add_colon_single(preamble, turns, "\n\n")
    return ChatFormatterResponse(prompt=prompt)
2023-11-21 04:02:20 -05:00
2023-11-20 21:19:25 -08:00
@register_chat_format("intel")
def format_intel(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages for the ``intel`` chat format.

    The first system message (empty if none) goes under a ``### System:``
    header. It is followed by ``### User: ...`` / ``### Assistant: ...``
    lines and an open ``### Assistant:`` turn.
    """
    # No trailing colon on the role names: _format_add_colon_single adds ": "
    # itself, so "### User:" would render as "### User:: ...".
    _roles = dict(user="### User", assistant="### Assistant")
    _sep = "\n"
    _system_template = "### System:\n{system_message}"
    # Fill the placeholder; previously the literal "{system_message}" text
    # ended up in the prompt.
    _system_message = _system_template.format(
        system_message=_get_system_message(messages)
    )
    _messages = _map_roles(messages, _roles)
    _messages.append((_roles["assistant"], None))
    _prompt = _format_add_colon_single(_system_message, _messages, _sep)
    return ChatFormatterResponse(prompt=_prompt)
2023-09-29 19:52:04 -04:00
@register_chat_format("open-orca")
def format_open_orca(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages in the Open-Orca style.

    Uses a fixed expert-persona system prompt; system messages in
    ``messages`` are ignored. Turns are ``User: ...`` / ``Assistant: ...``
    terminated by ``<|end_of_turn|>`` and a newline, ending with an open
    ``Assistant: `` (with a trailing space). Generation stops at ``User``.
    """
    system_template = "{system_message}"
    system_message = (
        "You are a helpful assistant. Please answer truthfully and write out your "
        "thinking step by step to be sure you get the right answer. If you make a mistake or encounter "
        "an error in your thinking, say so out loud and attempt to correct it. If you don't know or "
        "aren't sure about something, say so clearly. You will act as a professional logician, mathematician, "
        "and physicist. You will also act as the most appropriate type of expert to answer any particular "
        "question or solve the relevant problem; state which expert type your are, if so. Also think of "
        "any particular named expert that would be ideal to answer the relevant question or solve the "
        "relevant problem; name and act as them, if appropriate."
    )
    # Keys must be the request's role names ("user"/"assistant"):
    # _map_roles drops every message whose role is not a key. Keying on the
    # display names "User"/"Assistant" dropped the whole conversation.
    _roles = dict(user="User", assistant="Assistant")
    sep = "<|end_of_turn|>\n"
    # stop_token_ids=[32000, 32001], # "<|end_of_turn|>"
    stop_str = "User"
    system_message = system_template.format(system_message=system_message)
    _messages = _map_roles(messages, _roles)
    _messages.append((_roles["assistant"], None))
    _prompt = _format_add_colon_space_single(system_message, _messages, sep)
    return ChatFormatterResponse(prompt=_prompt, stop=stop_str)
2023-09-30 21:01:34 -04:00
2023-11-20 21:19:25 -08:00
@register_chat_format("mistrallite")
def format_mistrallite(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages in the MistralLite style.

    The system message is wrapped as ``<|system|>...</s>``. User turns start
    with ``<|prompter|>``. The assistant tag carries a leading ``</s>`` and a
    newline, which closes the preceding user turn. The prompt ends with an
    open assistant tag.
    """
    system_text = "<|system|>{system_message}</s>".format(
        system_message=_get_system_message(messages)
    )
    role_tags = {"user": "<|prompter|>", "assistant": "</s>\n<|assistant|>"}
    turns = _map_roles(messages, role_tags)
    turns.append((role_tags["assistant"], None))
    prompt = _format_no_colon_single(system_text, turns, " ")
    return ChatFormatterResponse(prompt=prompt)
2024-01-18 21:21:37 -05:00
2023-11-22 22:20:08 -08:00
@register_chat_format("zephyr")
def format_zephyr(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages in the Zephyr style.

    Each block is a tag (``<|system|>``, ``<|user|>`` or ``<|assistant|>``)
    on its own line, followed by the text and ``</s>``. The prompt ends with
    an open ``<|assistant|>`` line. Generation stops at ``</s>``.
    """
    system_template = """<|system|>
{system_message}"""
    system_message = _get_system_message(messages)
    system_message = system_template.format(system_message=system_message)
    # _format_chatml already puts "\n" after each role tag, so the tags must
    # not end with a newline themselves; that produced "<|user|>\n\n...".
    _roles = dict(user="<|user|>", assistant="<|assistant|>")
    _sep = "</s>"
    _messages = _map_roles(messages, _roles)
    _messages.append((_roles["assistant"], None))
    _prompt = _format_chatml(system_message, _messages, _sep)
    return ChatFormatterResponse(prompt=_prompt, stop=_sep)
2023-11-21 04:02:20 -05:00
2023-12-12 09:44:04 +08:00
@register_chat_format("pygmalion")
def format_pygmalion(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages in the Pygmalion style.

    ``<|system|>`` is prefixed to the system text. Turns use ``<|user|>`` /
    ``<|model|>`` tags laid out by the chatml formatter with ``\\n`` as the
    turn separator, ending with an open ``<|model|>`` tag. Generation stops
    at a newline.
    """
    newline = "\n"
    system_text = "<|system|>{system_message}".format(
        system_message=_get_system_message(messages)
    )
    role_tags = {"user": "<|user|>", "assistant": "<|model|>"}
    turns = _map_roles(messages, role_tags)
    turns.append((role_tags["assistant"], None))
    prompt = _format_chatml(system_text, turns, newline)
    return ChatFormatterResponse(prompt=prompt, stop=newline)
2023-09-30 21:01:34 -04:00
@register_chat_format("chatml")
def format_chatml(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages in the ChatML style.

    A ``<|im_start|>system`` block is always emitted, even with an empty
    system message. It is followed by one ``<|im_start|>{role}`` block per
    user/assistant turn, each closed by ``<|im_end|>``, and an open
    ``<|im_start|>assistant`` turn. Generation stops at ``<|im_end|>``.
    """
    end_marker = "<|im_end|>"
    system_block = "<|im_start|>system\n{system_message}".format(
        system_message=_get_system_message(messages)
    )
    role_tags = {"user": "<|im_start|>user", "assistant": "<|im_start|>assistant"}
    turns = _map_roles(messages, role_tags)
    turns.append((role_tags["assistant"], None))
    return ChatFormatterResponse(
        prompt=_format_chatml(system_block, turns, end_marker),
        stop=end_marker,
    )
2023-11-03 02:12:14 -04:00
2024-01-18 21:21:37 -05:00
2024-01-29 02:34:42 -03:00
@register_chat_format("mistral-instruct")
def format_mistral_instruct(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages in the Mistral-Instruct ``[INST]`` style.

    User messages with string content become ``[INST] <text>``. Assistant
    messages with non-``None`` content become `` [/INST]<text></s>``. All
    other messages (including system) are skipped. The prompt always ends
    with `` [/INST]``. Generation stops at ``</s>``.
    """
    eos = "</s>"
    pieces = []
    for message in messages:
        role = message["role"]
        # Content is only read for user/assistant messages.
        if role == "user":
            content = message["content"]
            if isinstance(content, str):
                pieces.append("[INST] " + content)
        elif role == "assistant":
            content = message["content"]
            if content is not None:
                pieces.append(" [/INST]" + content + eos)
    pieces.append(" [/INST]")
    return ChatFormatterResponse(prompt="".join(pieces), stop=eos)
2024-01-29 02:34:42 -03:00
2024-01-05 00:12:02 +01:00
@register_chat_format("chatglm3")
def format_chatglm3(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages in the ChatGLM3 style.

    The prompt opens with a ``<|system|>`` line carrying the system message.
    Turns follow as ``<|user|>`` / ``<|assistant|>`` tags (see
    ``_format_chatglm3``), ending with an open ``<|assistant|>``.
    Generation stops at ``</s>``.
    """
    stop_marker = "</s>"
    system_text = "<|system|>\n{system_message}".format(
        system_message=_get_system_message(messages)
    )
    role_tags = {"user": "<|user|>", "assistant": "<|assistant|>"}
    turns = _map_roles(messages, role_tags)
    turns.append((role_tags["assistant"], None))
    prompt = _format_chatglm3(system_text, turns, stop_marker)
    return ChatFormatterResponse(prompt=prompt, stop=stop_marker)
2023-11-21 04:02:20 -05:00
2023-11-20 21:19:25 -08:00
@register_chat_format("openchat")
def format_openchat(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages in the OpenChat "GPT4 Correct" style.

    Layout, with no newlines between turns:
    an optional system block (``<system><|end_of_turn|>``), then
    ``GPT4 Correct User: <text><|end_of_turn|>`` /
    ``GPT4 Correct Assistant: <text><|end_of_turn|>`` turns, then an open
    ``GPT4 Correct Assistant:`` cue. Generation stops at ``<|end_of_turn|>``.
    """
    _sep = "<|end_of_turn|>"
    system_template = "{system_message}<|end_of_turn|>"
    system_message = _get_system_message(messages)
    # Emit the system block only when there is one; otherwise a lone
    # end-of-turn marker would start the prompt.
    system_message = (
        system_template.format(system_message=system_message) if system_message else ""
    )
    # Each completed turn already ends with _sep (no-colon formatter), so the
    # assistant tag must not carry its own leading end-of-turn marker, and the
    # chatml formatter's newlines are not part of this format.
    _roles = dict(user="GPT4 Correct User: ", assistant="GPT4 Correct Assistant: ")
    _messages = _map_roles(messages, _roles)
    _messages.append(("GPT4 Correct Assistant:", None))
    _prompt = _format_no_colon_single(system_message, _messages, _sep)
    return ChatFormatterResponse(prompt=_prompt, stop=_sep)
2024-02-12 15:56:07 -05:00
2024-01-05 06:12:58 +07:00
# Chat format for Saiga models, see more details and available models:
# https://huggingface.co/collections/IlyaGusev/saiga2-saigamistral-6505d4ccc3d1e53166b636cd
@register_chat_format("saiga")
def format_saiga(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages in the Saiga style.

    Each system/user/assistant message becomes ``<s>{role}\\n{content}</s>``,
    with the assistant role named ``bot``. A message without string content
    contributes only ``<s>{role}`` and a newline. The prompt ends with an
    open ``<s>bot`` turn. Leading and trailing whitespace is stripped.
    """
    _message_template = "<s>{role}\n{content}</s>"
    # Request messages use the role "assistant"; Saiga calls it "bot". Without
    # the "assistant" key, _map_roles silently dropped every assistant turn.
    # The "bot" key is kept for callers that already pass that role.
    _roles = dict(user="user", assistant="bot", bot="bot", system="system")
    _messages = _map_roles(messages, _roles)
    parts: List[str] = []
    for role, content in _messages:
        if content:
            parts.append(_message_template.format(role=role, content=content))
        else:
            parts.append(f"<s>{role}\n")
    # Response template
    parts.append("<s>bot")
    return ChatFormatterResponse(prompt="".join(parts).strip())
2024-02-12 15:56:07 -05:00
2024-02-23 18:40:52 +09:00
# Chat format for Google's Gemma models, see more details and available models:
# https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b
@register_chat_format("gemma")
def format_gemma(
    messages: List[llama_types.ChatCompletionRequestMessage],
    **kwargs: Any,
) -> ChatFormatterResponse:
    """Render messages in the Gemma turn style.

    System messages are not supported: if one is present, a debug message is
    logged and it is left out of the prompt. Turns are
    ``<start_of_turn>user`` / ``<start_of_turn>model`` blocks terminated by
    ``<end_of_turn>`` and a newline, ending with an open model turn.
    Generation stops at the end-of-turn marker.
    """
    if _get_system_message(messages) != "":
        logger.debug(
            "`role='system'` messages are not allowed on Google's Gemma models."
        )
    end_of_turn = "<end_of_turn>\n"
    turn_tags = {
        "user": "<start_of_turn>user\n",
        "assistant": "<start_of_turn>model\n",
    }
    turns = _map_roles(messages, turn_tags)
    turns.append((turn_tags["assistant"], None))
    prompt = _format_no_colon_single(system_message="", messages=turns, sep=end_of_turn)
    return ChatFormatterResponse(prompt=prompt, stop=end_of_turn)
2024-01-29 14:22:23 -05:00
# Tricky chat formats that require custom chat handlers
2024-01-05 06:12:58 +07:00
2024-02-12 15:56:07 -05:00
2023-11-03 02:12:14 -04:00
@register_chat_completion_handler ( " functionary " )
def functionary_chat_handler (
llama : llama . Llama ,
messages : List [ llama_types . ChatCompletionRequestMessage ] ,
functions : Optional [ List [ llama_types . ChatCompletionFunction ] ] = None ,
2023-11-08 04:48:51 +01:00
function_call : Optional [ llama_types . ChatCompletionRequestFunctionCall ] = None ,
2023-11-10 02:51:58 -05:00
tools : Optional [ List [ llama_types . ChatCompletionTool ] ] = None ,
tool_choice : Optional [ llama_types . ChatCompletionToolChoiceOption ] = None ,
2023-11-03 02:12:14 -04:00
temperature : float = 0.2 ,
top_p : float = 0.95 ,
top_k : int = 40 ,
2023-11-21 06:21:33 +02:00
min_p : float = 0.05 ,
typical_p : float = 1.0 ,
2023-11-03 02:12:14 -04:00
stream : bool = False ,
stop : Optional [ Union [ str , List [ str ] ] ] = [ ] ,
2023-11-09 00:55:23 -05:00
response_format : Optional [ llama_types . ChatCompletionRequestResponseFormat ] = None ,
2023-11-10 02:51:58 -05:00
max_tokens : Optional [ int ] = None ,
2023-11-03 02:12:14 -04:00
presence_penalty : float = 0.0 ,
frequency_penalty : float = 0.0 ,
repeat_penalty : float = 1.1 ,
tfs_z : float = 1.0 ,
mirostat_mode : int = 0 ,
mirostat_tau : float = 5.0 ,
mirostat_eta : float = 0.1 ,
model : Optional [ str ] = None ,
logits_processor : Optional [ llama . LogitsProcessorList ] = None ,
grammar : Optional [ llama . LlamaGrammar ] = None ,
2023-11-08 04:48:51 +01:00
* * kwargs , # type: ignore
2023-11-03 02:12:14 -04:00
) - > Union [ llama_types . ChatCompletion , Iterator [ llama_types . ChatCompletionChunk ] ] :
SYSTEM_MESSAGE = """ A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user ' s questions. The assistant calls functions with appropriate input when necessary """
2023-11-21 04:02:20 -05:00
def generate_type_definition (
param : Dict [ str , llama_types . JsonType ] , indent_level : int , shared_defs
) - > str :
indent = " " * indent_level
if " $ref " in param :
2023-11-10 02:51:58 -05:00
# Reference to a shared definition
2023-11-21 04:02:20 -05:00
ref_name = param [ " $ref " ] . split ( " / " ) [
- 1
] # Extract the type name from the reference
2023-11-10 02:51:58 -05:00
return ref_name
2023-11-21 04:02:20 -05:00
elif param . get ( " type " ) == " array " :
items = param . get ( " items " , { } )
2023-11-10 02:51:58 -05:00
item_type = generate_type_definition ( items , indent_level + 1 , shared_defs )
return f " Array< { item_type } > "
2023-11-21 04:02:20 -05:00
elif param . get ( " type " ) == " object " :
properties = param . get ( " properties " , { } )
2023-11-10 02:51:58 -05:00
nested_schema = " { \n "
for nested_param_name , nested_param in properties . items ( ) :
2023-11-21 04:02:20 -05:00
nested_param_type = generate_type_definition (
nested_param , indent_level + 1 , shared_defs
)
nested_schema + = (
f " { indent } { nested_param_name } : { nested_param_type } , \n "
)
2023-11-10 02:51:58 -05:00
nested_schema + = indent + " } "
return nested_schema
2023-11-21 04:02:20 -05:00
elif " enum " in param :
2023-11-10 02:51:58 -05:00
# Enum type
2023-11-21 04:02:20 -05:00
return " | " . join ( [ f ' " { enum_value } " ' for enum_value in param [ " enum " ] ] )
2023-11-10 02:51:58 -05:00
else :
# Simple type
2023-11-21 04:02:20 -05:00
return param . get ( " type " , " any " )
2023-11-10 02:51:58 -05:00
def generate_shared_definitions ( shared_defs , indent_level : int ) - > str :
2023-11-21 04:02:20 -05:00
indent = " " * indent_level
2023-11-10 02:51:58 -05:00
shared_definitions = " "
for def_name , def_properties in shared_defs . items ( ) :
shared_definitions + = f " { indent } type { def_name } = "
2023-11-21 04:02:20 -05:00
if def_properties . get ( " type " ) == " object " :
shared_definitions + = generate_type_definition (
def_properties , indent_level , shared_defs
)
elif " enum " in def_properties :
2023-11-10 02:51:58 -05:00
# Enum type
2023-11-21 04:02:20 -05:00
shared_definitions + = " | " . join (
[ f ' " { enum_value } " ' for enum_value in def_properties [ " enum " ] ]
)
2023-11-10 02:51:58 -05:00
shared_definitions + = " ; \n "
return shared_definitions
def generate_schema_from_functions ( functions , namespace = " functions " ) - > str :
2023-11-21 04:02:20 -05:00
schema = (
" // Supported function definitions that should be called when necessary. \n "
)
2023-11-03 02:12:14 -04:00
schema + = f " namespace { namespace } {{ \n \n "
2023-11-10 02:51:58 -05:00
# Generate shared definitions
shared_definitions = { }
for function in functions :
parameters = function . get ( " parameters " , { } )
shared_definitions . update ( parameters . get ( " $defs " , { } ) )
schema + = generate_shared_definitions ( shared_definitions , 1 )
2023-11-03 02:12:14 -04:00
for function in functions :
function_name = function [ " name " ]
description = function . get ( " description " , " " )
2023-11-10 02:51:58 -05:00
parameters = function . get ( " parameters " , { } )
2023-11-03 02:12:14 -04:00
required_params = parameters . get ( " required " , [ ] )
2023-11-21 04:02:20 -05:00
2023-11-10 02:51:58 -05:00
schema + = f " // { description } \n "
schema + = f " type { function_name } = (_: {{ \n "
2023-11-21 04:02:20 -05:00
2023-11-03 02:12:14 -04:00
for param_name , param in parameters . get ( " properties " , { } ) . items ( ) :
2023-11-10 02:51:58 -05:00
param_description = param . get ( " description " , " " )
param_type = generate_type_definition ( param , 2 , shared_definitions )
optional_indicator = " " if param_name in required_params else " ? "
schema + = f " // { param_description } \n "
schema + = f " { param_name } { optional_indicator } : { param_type } , \n "
schema + = " }) => any; \n \n "
schema + = " }} // namespace {} \n " . format ( namespace )
2023-11-03 02:12:14 -04:00
return schema
def prepare_messages_for_inference(
    messages: List[llama_types.ChatCompletionRequestMessage],
    functions: Optional[List[llama_types.ChatCompletionFunctions]] = None,
    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
):
    """Render a chat request into the plain-text prompt of the "functionary" handler.

    The prompt is assembled from, in order:

    * a system message holding the schema of ``functions`` (if given),
    * a system message holding the schema of the function-type ``tools``
      (if given),
    * the fixed ``SYSTEM_MESSAGE`` of the enclosing handler,
    * the caller's ``messages``,
    * an empty assistant turn.

    Each message is rendered by ``message_to_str`` and the pieces are
    concatenated.

    Function names in ``function`` messages and in ``function_call`` requests
    are prefixed with ``functions.``. This is done on copies, so the caller's
    request objects are left untouched.

    Returns:
        The prompt string. Because the trailing assistant turn has no content,
        the prompt ends with the bare word ``assistant``; the caller appends
        the rest of the assistant header.

    Raises:
        ValueError: If a message has a role that cannot be rendered.
    """
    all_messages: List[llama_types.ChatCompletionRequestMessage] = []
    if functions is not None:
        all_messages.append(
            llama_types.ChatCompletionRequestSystemMessage(
                role="system", content=generate_schema_from_functions(functions)
            )
        )

    if tools is not None:
        all_messages.append(
            llama_types.ChatCompletionRequestSystemMessage(
                role="system",
                content=generate_schema_from_functions(
                    [tool["function"] for tool in tools if tool["type"] == "function"]
                ),
            )
        )

    all_messages.append(
        llama_types.ChatCompletionRequestSystemMessage(
            role="system", content=SYSTEM_MESSAGE
        )
    )

    for message in messages:
        # Rewrite names on copies. Mutating the caller's dicts would leak the
        # "functions." prefix back to them, and re-sending the same history
        # would then produce "functions.functions.<name>".
        message = message.copy()
        # Function call responses
        if message["role"] == "function" and "name" in message:
            message["name"] = f"functions.{message['name']}"
        # Function call requests by assistant
        if "function_call" in message:
            renamed_call = message["function_call"].copy()
            renamed_call["name"] = f"functions.{renamed_call['name']}"
            message["function_call"] = renamed_call
        all_messages.append(message)

    # Open the assistant turn to be generated; it renders as a bare "assistant".
    all_messages.append(
        llama_types.ChatCompletionRequestAssistantMessage(
            role="assistant", content=None
        )
    )

    def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
        """Render a single message in the functionary text format."""
        if msg["role"] == "system":
            return f"system:\n{msg['content']}\n"
        elif msg["role"] == "function" and "name" in msg:
            return f"function name={msg['name']}:\n{msg['content']}\n"
        elif msg["role"] == "function" and "function_call" in msg:
            return f"function name={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n"
        elif msg["role"] == "tool":
            # Tool results are labelled with their tool_call_id; this handler
            # emits tool-call ids equal to the function name.
            if msg["content"] is not None:
                return f"function name={msg['tool_call_id']}:\n{msg['content']}\n"
            else:
                return f"function name={msg['tool_call_id']}\n"
        elif msg["role"] == "user":
            if msg["content"] is None:
                return "user:\n</s></s>\n"
            else:
                return f"user:\n</s>{msg['content']}</s>\n"
        elif msg["role"] == "assistant":
            if msg["content"] is not None and "function_call" in msg:
                return f"assistant:\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>\n"
            elif "function_call" in msg:
                return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>\n"
            elif "tool_calls" in msg and len(msg["tool_calls"]) > 0:
                # Render every requested call; previously only the first one
                # was kept because the loop returned on its first iteration.
                # NOTE: several calls in one turn may not match the format the
                # functionary model was trained on.
                return "".join(
                    f"assistant to={tool_call['id']}:\n{tool_call['function']['arguments']}</s>\n"
                    for tool_call in msg["tool_calls"]
                )
            elif msg["content"] is None:
                return "assistant"
            else:
                return f"assistant:\n{msg['content']}\n"
        else:
            raise ValueError(f"Unsupported role: {msg['role']}")

    return "".join([message_to_str(msg) for msg in all_messages])
2023-11-21 04:02:20 -05:00
2023-11-10 02:51:58 -05:00
if tools is not None :
functions = [ tool [ " function " ] for tool in tools if tool [ " type " ] == " function " ]
2023-11-21 04:02:20 -05:00
2023-11-10 02:51:58 -05:00
if tool_choice is not None :
2023-11-21 04:02:20 -05:00
function_call = (
tool_choice if isinstance ( tool_choice , str ) else tool_choice [ " function " ]
)
2023-11-03 02:12:14 -04:00
2023-11-10 02:51:58 -05:00
prompt = prepare_messages_for_inference ( messages , functions , tools )
2023-11-03 02:12:14 -04:00
if function_call is None and ( functions is None or len ( functions ) == 0 ) :
completion_or_completion_chunks = llama . create_completion (
prompt = prompt + " : \n " ,
temperature = temperature ,
top_p = top_p ,
top_k = top_k ,
2023-11-21 06:21:33 +02:00
min_p = min_p ,
typical_p = typical_p ,
2023-11-03 02:12:14 -04:00
stream = stream ,
stop = [ " user: " , " </s> " ] ,
max_tokens = max_tokens ,
presence_penalty = presence_penalty ,
frequency_penalty = frequency_penalty ,
repeat_penalty = repeat_penalty ,
tfs_z = tfs_z ,
mirostat_mode = mirostat_mode ,
mirostat_tau = mirostat_tau ,
mirostat_eta = mirostat_eta ,
model = model ,
logits_processor = logits_processor ,
grammar = grammar ,
)
return _convert_completion_to_chat ( completion_or_completion_chunks , stream = stream ) # type: ignore
if function_call is None or (
isinstance ( function_call , str ) and function_call == " auto "
) :
stop = " \n "
completion : llama_types . Completion = llama . create_completion (
prompt = prompt , stop = stop , stream = False
) # type: ignore
completion_text = completion [ " choices " ] [ 0 ] [ " text " ]
# strip " to=functions." and ending ":"
2023-11-10 02:51:58 -05:00
function_call = completion_text . split ( " . " ) [ - 1 ] [ : - 1 ]
2023-11-03 02:12:14 -04:00
new_prompt = prompt + completion_text + stop
elif isinstance ( function_call , str ) and function_call != " none " :
2024-07-09 14:06:46 -04:00
new_prompt = prompt + " : \n "
2023-11-03 02:12:14 -04:00
elif isinstance ( function_call , dict ) :
2023-11-10 02:51:58 -05:00
new_prompt = prompt + f " to=functions. { function_call [ ' name ' ] } : \n "
2023-11-03 02:12:14 -04:00
function_call = function_call [ " name " ]
else :
2024-07-09 14:06:46 -04:00
new_prompt = prompt + " : \n "
2023-11-10 02:51:58 -05:00
function_body = None
for function in functions or [ ] :
if function [ " name " ] == function_call :
function_body = function [ " parameters " ]
break
for tool in tools or [ ] :
if tool [ " type " ] == " function " and tool [ " function " ] [ " name " ] == function_call :
function_body = tool [ " function " ] [ " parameters " ]
break
2023-11-21 04:02:20 -05:00
2023-11-10 02:51:58 -05:00
if function_body is not None :
try :
with suppress_stdout_stderr ( disable = llama . verbose ) :
2023-11-21 04:02:20 -05:00
grammar_text = llama_grammar . json_schema_to_gbnf (
json . dumps ( function_body )
)
grammar = llama_grammar . LlamaGrammar . from_string (
2024-02-12 15:56:07 -05:00
llama_grammar . json_schema_to_gbnf ( json . dumps ( function_body ) ) ,
verbose = llama . verbose ,
2023-11-21 04:02:20 -05:00
)
2023-11-10 02:51:58 -05:00
print ( grammar_text )
except Exception as e :
if llama . verbose :
2023-11-21 04:02:20 -05:00
print (
" Failed to parse function body as JSON schema, falling back to default grammar "
)
2023-11-10 02:51:58 -05:00
print ( e )
with suppress_stdout_stderr ( disable = llama . verbose ) :
2023-11-21 04:02:20 -05:00
grammar = llama_grammar . LlamaGrammar . from_string (
2024-02-12 15:56:07 -05:00
llama_grammar . JSON_GBNF ,
verbose = llama . verbose ,
2023-11-21 04:02:20 -05:00
)
2023-11-10 02:51:58 -05:00
else :
with suppress_stdout_stderr ( disable = llama . verbose ) :
2024-02-12 15:56:07 -05:00
grammar = llama_grammar . LlamaGrammar . from_string (
llama_grammar . JSON_GBNF , verbose = llama . verbose
)
2023-11-03 02:12:14 -04:00
completion : llama_types . Completion = llama . create_completion (
2023-11-10 02:51:58 -05:00
prompt = new_prompt ,
stop = [ " user: " , " </s> " ] ,
stream = False ,
grammar = grammar ,
max_tokens = max_tokens ,
temperature = temperature ,
top_p = top_p ,
top_k = top_k ,
2023-11-21 06:21:33 +02:00
min_p = min_p ,
typical_p = typical_p ,
2023-11-10 02:51:58 -05:00
presence_penalty = presence_penalty ,
frequency_penalty = frequency_penalty ,
repeat_penalty = repeat_penalty ,
tfs_z = tfs_z ,
mirostat_mode = mirostat_mode ,
mirostat_tau = mirostat_tau ,
mirostat_eta = mirostat_eta ,
model = model ,
logits_processor = logits_processor ,
2023-11-03 02:12:14 -04:00
) # type: ignore
2023-11-08 04:48:51 +01:00
assert " usage " in completion
assert isinstance ( function_call , str )
2023-11-08 00:07:16 -05:00
assert stream is False # TODO: support stream mode
2023-11-08 04:48:51 +01:00
2023-11-23 20:14:23 -05:00
if llama . verbose :
print ( new_prompt )
print ( completion [ " choices " ] [ 0 ] [ " text " ] )
2023-11-09 00:55:23 -05:00
2023-11-23 20:14:23 -05:00
# TODO: support stream mode
2023-11-03 02:12:14 -04:00
return llama_types . CreateChatCompletionResponse (
id = " chat " + completion [ " id " ] ,
object = " chat.completion " ,
created = completion [ " created " ] ,
model = completion [ " model " ] ,
choices = [
{
" index " : 0 ,
" message " : {
2023-11-10 02:51:58 -05:00
" role " : " assistant " ,
2023-11-03 02:12:14 -04:00
" content " : None ,
" function_call " : {
" name " : function_call ,
" arguments " : completion [ " choices " ] [ 0 ] [ " text " ] ,
} ,
2023-11-10 02:51:58 -05:00
" tool_calls " : [
{
" id " : function_call ,
" type " : " function " ,
" function " : {
" name " : function_call ,
" arguments " : completion [ " choices " ] [ 0 ] [ " text " ] ,
2023-11-21 04:02:20 -05:00
} ,
2023-11-10 02:51:58 -05:00
}
2023-11-21 04:02:20 -05:00
] ,
2023-11-03 02:12:14 -04:00
} ,
2024-12-06 12:35:46 +00:00
" logprobs " : _convert_text_completion_logprobs_to_chat ( completion [ " choices " ] [ 0 ] [ " logprobs " ] ) ,
2023-11-10 02:51:58 -05:00
" finish_reason " : " tool_calls " ,
2023-11-03 02:12:14 -04:00
}
] ,
usage = completion [ " usage " ] ,
)
2023-11-08 04:48:51 +01:00
2024-02-08 09:07:03 +08:00
@register_chat_completion_handler ( " functionary-v1 " )
@register_chat_completion_handler ( " functionary-v2 " )
def functionary_v1_v2_chat_handler (
llama : llama . Llama ,
messages : List [ llama_types . ChatCompletionRequestMessage ] ,
functions : Optional [ List [ llama_types . ChatCompletionFunction ] ] = None ,
function_call : Optional [ llama_types . ChatCompletionRequestFunctionCall ] = None ,
tools : Optional [ List [ llama_types . ChatCompletionTool ] ] = None ,
tool_choice : Optional [ llama_types . ChatCompletionToolChoiceOption ] = None ,
temperature : float = 0.2 ,
top_p : float = 0.95 ,
top_k : int = 40 ,
min_p : float = 0.05 ,
typical_p : float = 1.0 ,
stream : bool = False ,
stop : Optional [ Union [ str , List [ str ] ] ] = [ ] ,
response_format : Optional [ llama_types . ChatCompletionRequestResponseFormat ] = None ,
max_tokens : Optional [ int ] = None ,
presence_penalty : float = 0.0 ,
frequency_penalty : float = 0.0 ,
repeat_penalty : float = 1.1 ,
tfs_z : float = 1.0 ,
mirostat_mode : int = 0 ,
mirostat_tau : float = 5.0 ,
mirostat_eta : float = 0.1 ,
model : Optional [ str ] = None ,
logits_processor : Optional [ llama . LogitsProcessorList ] = None ,
grammar : Optional [ llama . LlamaGrammar ] = None ,
* * kwargs , # type: ignore
) - > Union [ llama_types . ChatCompletion , Iterator [ llama_types . ChatCompletionChunk ] ] :
SYSTEM_MESSAGE = """ A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user ' s questions. The assistant calls functions with appropriate input when necessary """
2024-02-12 15:56:07 -05:00
2024-02-08 09:07:03 +08:00
tokenizer = llama . tokenizer_
2024-02-12 15:56:07 -05:00
assert hasattr (
tokenizer , " hf_tokenizer "
) , " Please provide a valid hf_tokenizer_path from https://huggingface.co/meetkai when initializing the Llama class "
2024-02-08 09:07:03 +08:00
from transformers import AutoTokenizer
2024-02-12 15:56:07 -05:00
2024-02-08 09:07:03 +08:00
if " <|START_OF_FUNCTION_CALL|> " in tokenizer . hf_tokenizer . additional_special_tokens :
version = " v1 "
END_SYSTEM_TOKEN = " <|END_OF_SYSTEM|> "
END_USER_TOKEN = " <|END_OF_USER|> "
END_ASSISTANT_TOKEN = " <|END_OF_ASSISTANT|> "
END_FUNCTION_RESULT_TOKEN = " <|END_OF_FUNCTION_RESULT|> "
START_FUNCTION_CALL_TOKEN = " <|START_OF_FUNCTION_CALL|> "
END_FUNCTION_CALL_TOKEN = " <|END_OF_FUNCTION_CALL|> "
else :
version = " v2 "
RECIPIENT_TOKEN = " <|recipient|> "
FROM_TOKEN = " <|from|> "
STOP_TOKEN = " <|stop|> "
CONTENT_TOKEN = " <|content|> "
def generate_type_definition(
    param: Dict[str, llama_types.JsonType], indent_level: int, shared_defs
) -> str:
    """Translate one JSON-schema node into a TypeScript-like type expression.

    Rules, checked in this order:

    * ``$ref`` nodes become the last path segment of the reference, e.g.
      ``"#/$defs/Color"`` becomes ``Color``.
    * Arrays become ``Array<item type>``; a missing ``items`` yields ``any``.
    * Objects become a ``{ ... }`` literal with one ``name: type,`` line per
      property, indented by ``indent_level`` units.
    * Nodes with an ``enum`` become a ``|``-union of double-quoted values.
    * Anything else falls back to its ``type`` value, or ``"any"``.

    Args:
        param: The schema node to translate.
        indent_level: Nesting depth, used to indent object members.
        shared_defs: Passed through unchanged to recursive calls.

    Returns:
        The rendered type expression.
    """
    if "$ref" in param:
        # Refer to the alias by name instead of inlining the shared definition.
        return param["$ref"].split("/")[-1]

    kind = param.get("type")
    if kind == "array":
        element = generate_type_definition(
            param.get("items", {}), indent_level + 1, shared_defs
        )
        return f"Array<{element}>"

    if kind == "object":
        indent = "  " * indent_level
        members = "".join(
            f"{indent}{name}: {generate_type_definition(sub_schema, indent_level + 1, shared_defs)},\n"
            for name, sub_schema in param.get("properties", {}).items()
        )
        return "{\n" + members + indent + "}"

    if "enum" in param:
        return " | ".join(f'"{value}"' for value in param["enum"])

    return param.get("type", "any")
def generate_shared_definitions(shared_defs, indent_level: int) -> str:
    """Render JSON-schema ``$defs`` entries as ``type <name> = <body>;`` lines.

    Each line is prefixed by ``indent_level`` indentation units. The body
    depends on the definition:

    * object definitions are rendered by ``generate_type_definition``;
    * enum definitions become a ``|``-union of double-quoted values;
    * any other definition gets an empty body.

    Returns:
        The concatenated lines, or ``""`` when ``shared_defs`` is empty.
    """
    pad = "  " * indent_level
    rendered = []
    for name, definition in shared_defs.items():
        if definition.get("type") == "object":
            body = generate_type_definition(definition, indent_level, shared_defs)
        elif "enum" in definition:
            body = " | ".join([f'"{value}"' for value in definition["enum"]])
        else:
            body = ""
        rendered.append(f"{pad}type {name} = {body};\n")
    return "".join(rendered)
def generate_schema_from_functions(functions, namespace="functions") -> str:
    """Render OpenAI-style function specs as a TypeScript-like namespace.

    The output contains, in order:

    * a header comment;
    * ``namespace <namespace> {``;
    * shared ``$defs`` type aliases collected from all functions;
    * one ``type <name> = (_: {...}) => any;`` declaration per function, with
      a ``//`` comment for its description and for each parameter;
    * a closing ``} // namespace <namespace>`` line, with no trailing newline.

    Parameters absent from ``required`` are marked optional with ``?``.
    Keys that are missing or explicitly ``null`` (``parameters``, ``$defs``,
    ``required``, ``properties``, ``description``) are treated as empty.

    Args:
        functions: Function specs, each a mapping with at least ``"name"``.
        namespace: Name of the rendered namespace.

    Returns:
        The schema text.
    """
    schema = (
        "// Supported function definitions that should be called when necessary.\n"
    )
    schema += f"namespace {namespace} {{\n\n"

    # "$defs" from every function are pooled and emitted once at the top.
    shared_definitions = {}
    for function in functions:
        parameters = function.get("parameters") or {}
        shared_definitions.update(parameters.get("$defs") or {})

    schema += generate_shared_definitions(shared_definitions, 1)

    for function in functions:
        function_name = function["name"]
        # ``or`` also covers keys that are present but null, which would
        # otherwise crash (None.get) or render as the text "None".
        description = function.get("description") or ""
        parameters = function.get("parameters") or {}
        required_params = parameters.get("required") or []

        schema += f"// {description}\n"
        schema += f"type {function_name} = (_: {{\n"

        for param_name, param in (parameters.get("properties") or {}).items():
            param_description = param.get("description") or ""
            param_type = generate_type_definition(param, 2, shared_definitions)
            optional_indicator = "" if param_name in required_params else "?"
            schema += f"// {param_description}\n"
            schema += f"{param_name}{optional_indicator}: {param_type},\n"
        schema += "}) => any;\n\n"

    # .format turns "}}" into a single "}".
    schema += "}} // namespace {}".format(namespace)
    return schema
def prepare_messages_for_inference(
    messages: List[llama_types.ChatCompletionRequestMessage],
    tokenizer: AutoTokenizer,
    version: Literal["v1", "v2"],
    functions: Optional[List[llama_types.ChatCompletionFunctions]] = None,
    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
    tool_choice: Union[Dict, str] = "auto",
):
    """Render the conversation into a functionary v1/v2 prompt string.

    A system message with the function schema comes first:

    * an empty namespace when ``tool_choice == "none"``;
    * otherwise the schema of ``functions``;
    * otherwise, failing that, the function-type ``tools``.

    It is followed by the fixed ``SYSTEM_MESSAGE`` of the enclosing handler
    and the caller's ``messages``. Function names in ``function`` messages and
    ``function_call`` requests get a ``functions.`` prefix. This is applied to
    copies, so the caller's request objects are not modified.

    The message list is rendered with the Hugging Face tokenizer's chat
    template. A version-specific assistant header is then appended:
    ``assistant:\\n`` for v1, ``<|from|>assistant\\n<|recipient|>`` for v2.

    Returns:
        The prompt string.
    """
    all_messages: List[llama_types.ChatCompletionRequestMessage] = []
    if tool_choice == "none":
        # Function calling is disabled: advertise an empty namespace.
        all_messages.append(
            llama_types.ChatCompletionRequestSystemMessage(
                role="system", content=generate_schema_from_functions([])
            )
        )
    elif functions is not None:
        all_messages.append(
            llama_types.ChatCompletionRequestSystemMessage(
                role="system", content=generate_schema_from_functions(functions)
            )
        )
    elif tools is not None:
        all_messages.append(
            llama_types.ChatCompletionRequestSystemMessage(
                role="system",
                content=generate_schema_from_functions(
                    [tool["function"] for tool in tools if tool["type"] == "function"]
                ),
            )
        )
    all_messages.append(
        llama_types.ChatCompletionRequestSystemMessage(
            role="system", content=SYSTEM_MESSAGE
        )
    )
    for message in messages:
        # Rewrite names on copies. Mutating the caller's dicts would leak the
        # "functions." prefix back to them, and re-sending the same history
        # would then produce "functions.functions.<name>".
        message = message.copy()
        # Function call responses
        if message["role"] == "function" and "name" in message:
            message["name"] = f"functions.{message['name']}"
        # Function call requests by assistant
        if "function_call" in message:
            renamed_call = message["function_call"].copy()
            renamed_call["name"] = f"functions.{renamed_call['name']}"
            message["function_call"] = renamed_call
        all_messages.append(message)

    if version == "v1":
        suffix = "assistant:\n"
    else:
        suffix = "<|from|>assistant\n<|recipient|>"

    return (
        tokenizer.hf_tokenizer.apply_chat_template(all_messages, tokenize=False)
        + suffix
    )
2024-02-08 09:07:03 +08:00
if tools is not None :
functions = [ tool [ " function " ] for tool in tools if tool [ " type " ] == " function " ]
if tool_choice is not None :
function_call = (
tool_choice if isinstance ( tool_choice , str ) else tool_choice [ " function " ]
)
2024-05-04 22:11:20 +08:00
elif function_call is not None :
pass
2024-03-18 22:40:57 +08:00
else :
function_call = " auto "
2024-02-08 09:07:03 +08:00
2024-02-12 15:56:07 -05:00
prompt = prepare_messages_for_inference (
2024-04-28 08:49:52 +08:00
messages , tokenizer , version , functions , tools , function_call
2024-02-12 15:56:07 -05:00
)
2024-02-08 09:07:03 +08:00
# If no tools/functions are provided
2024-03-18 22:40:57 +08:00
if function_call == " none " or functions is None or len ( functions ) == 0 :
2024-02-08 09:07:03 +08:00
if version == " v1 " :
stop = END_ASSISTANT_TOKEN
else :
stop = STOP_TOKEN
prompt + = " all \n <|content|> "
2024-02-12 15:56:07 -05:00
2024-02-08 09:07:03 +08:00
completion_or_completion_chunks = llama . create_completion (
prompt = prompt ,
temperature = temperature ,
top_p = top_p ,
top_k = top_k ,
min_p = min_p ,
typical_p = typical_p ,
stream = stream ,
stop = stop ,
max_tokens = max_tokens ,
presence_penalty = presence_penalty ,
frequency_penalty = frequency_penalty ,
repeat_penalty = repeat_penalty ,
tfs_z = tfs_z ,
mirostat_mode = mirostat_mode ,
mirostat_tau = mirostat_tau ,
mirostat_eta = mirostat_eta ,
model = model ,
logits_processor = logits_processor ,
grammar = grammar ,
)
2024-05-04 22:11:20 +08:00
if stream is False :
2024-07-09 12:20:17 -04:00
completion_or_completion_chunks [ " choices " ] [ 0 ] [ " text " ] = (
completion_or_completion_chunks [ " choices " ] [ 0 ] [ " text " ] . lstrip ( )
)
2024-02-08 09:07:03 +08:00
return _convert_completion_to_chat ( completion_or_completion_chunks , stream = stream ) # type: ignore
2024-02-12 15:56:07 -05:00
2024-02-08 09:07:03 +08:00
def get_grammar(function_call):
    """Build a grammar that constrains generation to one function's arguments.

    Looks up the JSON-schema ``parameters`` of the function named
    ``function_call``, first in ``functions`` and then in ``tools``; a match
    in ``tools`` overrides one in ``functions``. The schema is converted to
    GBNF and parsed into a ``LlamaGrammar``.

    The generic ``llama_grammar.JSON_GBNF`` grammar is used instead in three
    cases: the function is unknown, it declares no ``parameters``, or the
    conversion/parsing fails.

    Args:
        function_call: Name of the function whose arguments are being generated.

    Returns:
        A ``llama_grammar.LlamaGrammar``.
    """
    function_body = None
    for function in functions or []:
        if function["name"] == function_call:
            # ``parameters`` is optional in OpenAI function specs; a missing
            # one falls through to the generic JSON grammar below.
            function_body = function.get("parameters")
            break
    for tool in tools or []:
        if tool["type"] == "function" and tool["function"]["name"] == function_call:
            function_body = tool["function"].get("parameters")
            break

    if function_body is not None:
        try:
            with suppress_stdout_stderr(disable=llama.verbose):
                # Convert once and reuse the text for parsing and the debug
                # print. The print is presumably only visible when
                # llama.verbose is set, since suppression is disabled exactly then.
                grammar_text = llama_grammar.json_schema_to_gbnf(
                    json.dumps(function_body)
                )
                grammar = llama_grammar.LlamaGrammar.from_string(
                    grammar_text, verbose=llama.verbose
                )
                print(grammar_text)
            return grammar
        except Exception as e:
            # Deliberate best effort: unsupported schemas fall back to JSON_GBNF.
            if llama.verbose:
                print(
                    "Failed to parse function body as JSON schema, falling back to default grammar"
                )
                print(e)

    with suppress_stdout_stderr(disable=llama.verbose):
        return llama_grammar.LlamaGrammar.from_string(
            llama_grammar.JSON_GBNF, verbose=llama.verbose
        )
2024-02-12 15:56:07 -05:00
2024-05-04 22:11:20 +08:00
def create_completion(prompt, stop, grammar):
    """Run one raw ``llama.create_completion`` call with the handler's settings.

    Only ``prompt``, ``stop`` and ``grammar`` vary per call. Every other
    option (``stream``, ``max_tokens``, the sampling and penalty parameters,
    mirostat, ``model`` and ``logits_processor``) is taken from the enclosing
    handler's arguments.

    The ``cast`` only informs the type checker. When ``stream`` is true the
    returned object is iterated chunk by chunk (see ``generate_streaming``).
    """
    shared_options = dict(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        min_p=min_p,
        typical_p=typical_p,
        stream=stream,
        max_tokens=max_tokens,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        repeat_penalty=repeat_penalty,
        tfs_z=tfs_z,
        mirostat_mode=mirostat_mode,
        mirostat_tau=mirostat_tau,
        mirostat_eta=mirostat_eta,
        model=model,
        logits_processor=logits_processor,
    )
    raw_result = llama.create_completion(
        prompt=prompt, stop=stop, grammar=grammar, **shared_options
    )
    return cast(llama_types.Completion, raw_result)
2024-02-12 15:56:07 -05:00
2024-03-18 22:40:57 +08:00
content = " "
2024-02-08 09:07:03 +08:00
function_calls , function_bodies = [ ] , [ ]
2024-04-28 08:49:52 +08:00
completion_tokens = 0
2024-07-09 12:20:17 -04:00
2024-05-04 22:11:20 +08:00
def generate_streaming(tools, functions, function_call, prompt):
    """Stream a functionary v2 response as chat-completion chunks.

    This generator drives one or more raw completions through
    ``create_completion`` (which reads the enclosing handler's sampling
    settings) and yields ``llama_types.CreateChatCompletionStreamResponse``
    objects.

    The mode is selected by ``function_call``:

    * ``dict`` (a specific function is forced): the function name is appended
      to ``prompt``, and its arguments are generated under that function's
      grammar (``get_grammar``). The generator yields one chunk announcing the
      name, then one chunk per non-blank argument fragment, then a final chunk
      whose ``finish_reason`` is ``"tool_calls"`` (when ``tools`` is given) or
      ``"function_call"``.
    * ``"auto"``: loops turn by turn. Each turn first generates the recipient,
      up to ``CONTENT_TOKEN``.
      - Recipient ``"all"`` means plain assistant text, which is streamed as
        ``content``.
      - Any other recipient is treated as a function name; its
        grammar-constrained arguments are streamed as tool/function-call
        deltas.
      - After each turn, the generated text is checked for
        ``<|from|>assistant`` to decide whether another turn follows.

    Any other ``function_call`` value yields nothing, because there is no
    ``else`` branch.

    Args:
        tools: Request tools or ``None``. Selects ``tool_calls`` deltas versus
            legacy ``function_call`` deltas.
        functions: Request functions. Not read directly in this body;
            ``get_grammar`` uses the enclosing handler's ``functions``.
        function_call: ``"auto"`` or a dict with a ``"name"`` key.
        prompt: The prompt built by ``prepare_messages_for_inference``; for v2
            it ends with ``<|recipient|>``.

    Raises:
        AssertionError: If the model version is not ``"v2"``.
    """
    assert version == "v2", "Streaming for v1 is not supported"

    # In "auto" mode, id/created come from the first chunk ever generated and
    # are reused for every later chunk of the response.
    chunk_id, chunk_created = None, None

    # If tool_choice/function_call is provided
    if isinstance(function_call, dict):
        prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
        grammar = get_grammar(function_call["name"])
        stops = [STOP_TOKEN, FROM_TOKEN]
        # 24 random alphanumerics, exposed to the client as "call_<tool_id>".
        tool_id = "".join(
            [random.choice(string.ascii_letters + string.digits) for _ in range(24)]
        )
        completion = create_completion(prompt=prompt, stop=stops, grammar=grammar)
        completion_text = ""
        first = True
        for chunk in completion:
            # Yield the tool/function name first
            if first:
                if tools is not None:
                    func_call_dict = {
                        "tool_calls": [
                            {
                                "index": 0,
                                "id": "call_" + tool_id,
                                "type": "function",
                                "function": {
                                    "name": function_call["name"],
                                    "arguments": "",
                                },
                            }
                        ]
                    }
                else:
                    func_call_dict = {
                        "function_call": {
                            "name": function_call["name"],
                            "arguments": "",
                        }
                    }
                yield llama_types.CreateChatCompletionStreamResponse(
                    id="chat" + chunk["id"],
                    object="chat.completion.chunk",
                    created=chunk["created"],
                    model=chunk["model"],
                    choices=[
                        {
                            "index": 0,
                            "logprobs": None,
                            "delta": {
                                "role": None,
                                "content": None,
                                **func_call_dict,
                            },
                        }
                    ],
                )
                first = False
            # Every chunk, including the first, is then streamed as an
            # argument fragment (right-stripped; blank fragments are skipped).
            if tools is not None:
                func_call_dict = {
                    "tool_calls": [
                        {
                            "index": 0,
                            "id": "call_" + tool_id,
                            "type": "function",
                            "function": {
                                "name": None,
                                "arguments": chunk["choices"][0]["text"].rstrip(),
                            },
                        }
                    ]
                }
            else:
                func_call_dict = {
                    "function_call": {
                        "name": None,
                        "arguments": chunk["choices"][0]["text"].rstrip(),
                    }
                }
            if len(chunk["choices"][0]["text"].rstrip()) > 0:
                yield llama_types.CreateChatCompletionStreamResponse(
                    id="chat" + chunk["id"],
                    object="chat.completion.chunk",
                    created=chunk["created"],
                    model=chunk["model"],
                    choices=[
                        {
                            "index": 0,
                            "logprobs": _convert_text_completion_logprobs_to_chat(chunk["choices"][0]["logprobs"]),
                            "delta": {
                                "role": None,
                                "content": None,
                                **func_call_dict,
                            },
                        }
                    ],
                )
        # Yield tool_call/function_call stop message
        # NOTE(review): `chunk` is only bound inside the loop above; if the
        # completion yields no chunks this raises UnboundLocalError.
        yield llama_types.CreateChatCompletionStreamResponse(
            id="chat" + chunk["id"],
            object="chat.completion.chunk",
            created=chunk["created"],
            model=chunk["model"],
            choices=[
                {
                    "index": 0,
                    "finish_reason": (
                        "tool_calls" if tools is not None else "function_call"
                    ),
                    "logprobs": None,
                    "delta": {
                        "role": None,
                        "content": None,
                        "function_call": None,
                        "tool_calls": None,
                    },
                }
            ],
        )
    # If "auto" or no tool_choice/function_call
    elif isinstance(function_call, str) and function_call == "auto":
        # Position of the next tool call in the streamed "tool_calls" deltas.
        tool_index = 0
        while True:
            # Generate function name first
            # (the recipient, unconstrained, up to CONTENT_TOKEN; not streamed)
            grammar = None
            stops = CONTENT_TOKEN
            completion = create_completion(
                prompt=prompt, stop=stops, grammar=grammar
            )
            completion_text = ""
            for chunk in completion:
                completion_text += chunk["choices"][0]["text"]
                if chunk_id is None:
                    chunk_id = chunk["id"]
                if chunk_created is None:
                    chunk_created = chunk["created"]
            function_name = completion_text.strip()
            if function_name == "all":
                # "all" = plain text addressed to the user.
                prompt += "all\n<|content|>"
                # Yield the first empty message for content
                yield llama_types.CreateChatCompletionStreamResponse(
                    id="chat" + chunk_id,
                    model=chunk["model"],
                    created=chunk_created,
                    object="chat.completion.chunk",
                    choices=[
                        {
                            "index": 0,
                            "delta": {"role": "assistant", "content": ""},
                            "logprobs": None,
                            "finish_reason": None,
                        }
                    ],
                )
            else:
                # Any other recipient is a function call; constrain its
                # arguments with that function's grammar.
                prompt += f"{function_name}\n<|content|>"
                grammar = get_grammar(function_name)
                tool_id = "".join(
                    [
                        random.choice(string.ascii_letters + string.digits)
                        for _ in range(24)
                    ]
                )
                if tools is not None:
                    func_call_dict = {
                        "tool_calls": [
                            {
                                "index": tool_index,
                                "id": "call_" + tool_id,
                                "type": "function",
                                "function": {
                                    "name": function_name,
                                    "arguments": "",
                                },
                            }
                        ]
                    }
                else:
                    func_call_dict = {
                        "function_call": {"name": function_name, "arguments": ""}
                    }
                # Stream function name
                yield llama_types.CreateChatCompletionStreamResponse(
                    id="chat" + chunk_id,
                    object="chat.completion.chunk",
                    created=chunk_created,
                    model=chunk["model"],
                    choices=[
                        {
                            "index": 0,
                            "logprobs": _convert_text_completion_logprobs_to_chat(chunk["choices"][0]["logprobs"]),
                            "delta": {
                                "role": "assistant",
                                "content": None,
                                **func_call_dict,
                            },
                        }
                    ],
                )
            # Generate content
            stops = [RECIPIENT_TOKEN, STOP_TOKEN]
            completion = create_completion(
                prompt=prompt, stop=stops, grammar=grammar
            )
            if function_name == "all":
                completion_text = ""
                # After a bare "\n" chunk, following fragments are held back
                # in `buffer` while they could still be the start of
                # `stop_sequence` (the model opening another assistant turn),
                # so that marker is not streamed as content. Once the prefix
                # stops matching, the held fragments are flushed.
                stop_sequence, buffer, is_end = (
                    "\n<|from|>assistant\n<|recipient|>",
                    [],
                    False,
                )
                for i, chunk in enumerate(completion):
                    completion_text += chunk["choices"][0]["text"]
                    if is_end:
                        buffer.append(chunk["choices"][0]["text"].strip(" "))
                        if stop_sequence.startswith("".join(buffer)):
                            continue
                        else:
                            # Drop the non-matching fragment from the buffer;
                            # it is streamed by the check below after the flush.
                            buffer.pop()
                            while len(buffer) > 0:
                                # NOTE(review): flushed fragments carry the
                                # logprobs of the current chunk, not of the
                                # chunks they came from.
                                yield llama_types.CreateChatCompletionStreamResponse(
                                    id="chat" + chunk_id,
                                    object="chat.completion.chunk",
                                    created=chunk_created,
                                    model=chunk["model"],
                                    choices=[
                                        {
                                            "index": 0,
                                            "logprobs": _convert_text_completion_logprobs_to_chat(chunk["choices"][0]["logprobs"]),
                                            "delta": {
                                                "role": "assistant",
                                                "content": buffer.pop(0),
                                            },
                                        }
                                    ],
                                )
                            is_end = False
                    elif chunk["choices"][0]["text"] == "\n":
                        is_end = True
                        buffer.append(chunk["choices"][0]["text"].strip(" "))
                        continue
                    if len(buffer) == 0 and len(chunk["choices"][0]["text"]) > 0:
                        # Only the very first content chunk is left-stripped.
                        yield llama_types.CreateChatCompletionStreamResponse(
                            id="chat" + chunk_id,
                            object="chat.completion.chunk",
                            created=chunk_created,
                            model=chunk["model"],
                            choices=[
                                {
                                    "index": 0,
                                    "logprobs": _convert_text_completion_logprobs_to_chat(chunk["choices"][0]["logprobs"]),
                                    "delta": {
                                        "role": "assistant",
                                        "content": (
                                            chunk["choices"][0]["text"]
                                            if i > 0
                                            else chunk["choices"][0][
                                                "text"
                                            ].lstrip()
                                        ),
                                    },
                                }
                            ],
                        )
                # Check whether the model wants to generate another turn
                # (both spaced and unspaced spellings of the header are accepted)
                if (
                    "<|from|> assistant" in completion_text
                    or "<|from|>assistant" in completion_text
                ):
                    if completion_text.endswith("\n<|from|>assistant\n"):
                        cleaned_completion_text = completion_text[
                            : -len("\n<|from|>assistant\n")
                        ].strip()
                    elif completion_text.endswith("\n<|from|> assistant\n"):
                        cleaned_completion_text = completion_text[
                            : -len("\n<|from|> assistant\n")
                        ].strip()
                    else:
                        cleaned_completion_text = completion_text.strip()
                    prompt += f"{cleaned_completion_text}\n<|from|>assistant\n<|recipient|>"
                else:
                    # Yield stop message
                    yield llama_types.CreateChatCompletionStreamResponse(
                        id="chat" + chunk_id,
                        model=chunk["model"],
                        created=chunk_created,
                        object="chat.completion.chunk",
                        choices=[
                            {
                                "index": 0,
                                "delta": {},
                                "logprobs": None,
                                "finish_reason": "stop",
                            }
                        ],
                    )
                    break
            else:
                # Check whether the model wants to generate another turn
                # First stream the grammar-constrained arguments of this call.
                completion_text = ""
                for chunk in completion:
                    completion_text += chunk["choices"][0]["text"]
                    if len(chunk["choices"][0]["text"].rstrip()) > 0:
                        if tools is not None:
                            func_call_dict = {
                                "tool_calls": [
                                    {
                                        "index": tool_index,
                                        "id": "call_" + tool_id,
                                        "type": "function",
                                        "function": {
                                            "name": None,
                                            "arguments": chunk["choices"][0][
                                                "text"
                                            ].rstrip(),
                                        },
                                    }
                                ]
                            }
                        else:
                            func_call_dict = {
                                "function_call": {
                                    "name": None,
                                    "arguments": chunk["choices"][0][
                                        "text"
                                    ].rstrip(),
                                }
                            }
                        yield llama_types.CreateChatCompletionStreamResponse(
                            id="chat" + chunk_id,
                            object="chat.completion.chunk",
                            created=chunk_created,
                            model=chunk["model"],
                            choices=[
                                {
                                    "index": 0,
                                    "logprobs": _convert_text_completion_logprobs_to_chat(chunk["choices"][0]["logprobs"]),
                                    "delta": {
                                        "role": None,
                                        "content": None,
                                        **func_call_dict,
                                    },
                                }
                            ],
                        )
                # Then run an unconstrained probe completion. Its text is not
                # streamed; it is only inspected for a new assistant turn.
                prompt += completion_text.strip()
                grammar = None
                completion = create_completion(
                    prompt=prompt, stop=stops, grammar=grammar
                )
                completion_text += "".join(
                    [chunk["choices"][0]["text"] for chunk in completion]
                )
                # Additional calls are only supported with `tools`; legacy
                # `functions` responses end after one call.
                if (
                    "<|from|> assistant" in completion_text
                    or "<|from|>assistant" in completion_text
                ) and tools is not None:
                    prompt += "\n<|from|>assistant\n<|recipient|>"
                    tool_index += 1
                else:
                    # Yield tool_call/function_call stop message
                    yield llama_types.CreateChatCompletionStreamResponse(
                        id="chat" + chunk_id,
                        object="chat.completion.chunk",
                        created=chunk_created,
                        model=chunk["model"],
                        choices=[
                            {
                                "index": 0,
                                "finish_reason": (
                                    "tool_calls"
                                    if tools is not None
                                    else "function_call"
                                ),
                                "logprobs": None,
                                "delta": {
                                    "role": None,
                                    "content": None,
                                    "function_call": None,
                                    "tool_calls": None,
                                },
                            }
                        ],
                    )
                    break
2024-07-09 12:20:17 -04:00
2024-05-04 22:11:20 +08:00
if stream is not False :
return generate_streaming (
tools = tools , functions = functions , function_call = function_call , prompt = prompt
2024-02-08 09:07:03 +08:00
)
2024-05-04 22:11:20 +08:00
else :
if version == " v1 " :
# If no or "auto" tool_choice/function_call
if isinstance ( function_call , str ) and function_call == " auto " :
stops = [ " \n " , END_ASSISTANT_TOKEN ]
# If tool_choice/function_call is provided
elif isinstance ( function_call , dict ) :
prompt + = f " { START_FUNCTION_CALL_TOKEN } { function_call [ ' name ' ] } : \n "
stops = END_FUNCTION_CALL_TOKEN
function_call = function_call [ " name " ]
function_calls . append ( function_call )
grammar = get_grammar ( function_call )
else :
prompt = prompt
stops = [ " \n " , END_ASSISTANT_TOKEN ]
2024-02-08 09:07:03 +08:00
2024-05-08 07:21:27 +01:00
completion = create_completion ( prompt = prompt , stop = stops , grammar = grammar )
2024-05-04 22:11:20 +08:00
completion_text = completion [ " choices " ] [ 0 ] [ " text " ]
completion_tokens + = completion [ " usage " ] [ " completion_tokens " ]
# If the generation does not involve a function call
if (
START_FUNCTION_CALL_TOKEN not in prompt
and START_FUNCTION_CALL_TOKEN not in completion_text
) :
completion [ " usage " ] [ " completion_tokens " ] = completion_tokens
return _convert_completion_to_chat ( completion , stream = stream ) # type: ignore
# If the generation involves a function call in completion, generate the parameters
elif (
START_FUNCTION_CALL_TOKEN not in prompt
and START_FUNCTION_CALL_TOKEN in completion_text
) :
prompt + = (
completion_text . replace (
f " { START_FUNCTION_CALL_TOKEN } " , START_FUNCTION_CALL_TOKEN
)
+ " \n "
)
function_calls . append (
completion_text . split ( START_FUNCTION_CALL_TOKEN ) [ - 1 ] [ : - 1 ] . strip ( )
)
grammar = get_grammar ( function_calls [ - 1 ] )
2024-07-09 12:20:17 -04:00
completion = create_completion (
prompt = prompt , stop = END_FUNCTION_CALL_TOKEN , grammar = grammar
)
2024-05-04 22:11:20 +08:00
completion_tokens + = completion [ " usage " ] [ " completion_tokens " ]
function_bodies . append ( completion [ " choices " ] [ 0 ] [ " text " ] . strip ( ) )
# If the prompt involves a function call, just append generated parameters to function_bodies
else :
function_bodies . append ( completion_text . strip ( ) )
2024-04-28 08:49:52 +08:00
else :
2024-05-04 22:11:20 +08:00
# If tool_choice/function_call is provided
if isinstance ( function_call , dict ) :
prompt + = f " { function_call [ ' name ' ] } \n { CONTENT_TOKEN } "
function_call = function_call [ " name " ]
function_calls . append ( function_call )
grammar = get_grammar ( function_call )
stops = [ STOP_TOKEN , FROM_TOKEN ]
2024-07-09 12:20:17 -04:00
completion = create_completion (
prompt = prompt , stop = stops , grammar = grammar
)
2024-05-04 22:11:20 +08:00
completion_text = completion [ " choices " ] [ 0 ] [ " text " ]
completion_tokens + = completion [ " usage " ] [ " completion_tokens " ]
function_bodies . append ( completion_text . strip ( ) )
# If "auto" or no tool_choice/function_call
elif isinstance ( function_call , str ) and function_call == " auto " :
while True :
# Generate function name first
grammar = None
stops = CONTENT_TOKEN
2024-07-09 12:20:17 -04:00
completion = create_completion (
prompt = prompt , stop = stops , grammar = grammar
)
2024-05-04 22:11:20 +08:00
completion_text = completion [ " choices " ] [ 0 ] [ " text " ]
completion_tokens + = completion [ " usage " ] [ " completion_tokens " ]
function_name = completion_text . strip ( )
if function_name == " all " :
prompt + = " all \n <|content|> "
else :
function_call = completion_text . strip ( )
prompt + = f " { function_call } \n <|content|> "
function_calls . append ( function_call )
grammar = get_grammar ( function_call )
# Generate content
stops = [ RECIPIENT_TOKEN , STOP_TOKEN ]
2024-07-09 12:20:17 -04:00
completion = create_completion (
prompt = prompt , stop = stops , grammar = grammar
)
2024-05-04 22:11:20 +08:00
completion_text = completion [ " choices " ] [ 0 ] [ " text " ]
completion_tokens + = completion [ " usage " ] [ " completion_tokens " ]
if function_name == " all " :
if completion_text . endswith ( " \n <|from|>assistant \n " ) :
2024-07-09 12:20:17 -04:00
content + = completion_text [ : - len ( " \n <|from|>assistant \n " ) ]
2024-05-04 22:11:20 +08:00
if completion_text . endswith ( " \n <|from|> assistant \n " ) :
content + = completion_text [ - len ( " \n <|from|> assistant \n " ) ]
else :
content + = completion_text
content = content . lstrip ( )
# Check whether the model wants to generate another turn
2024-07-09 12:20:17 -04:00
if (
" <|from|> assistant " in completion_text
or " <|from|>assistant " in completion_text
) :
2024-05-04 22:11:20 +08:00
if completion_text . endswith ( " \n <|from|>assistant \n " ) :
2024-07-09 12:20:17 -04:00
cleaned_completion_text = completion_text [
: - len ( " \n <|from|>assistant \n " )
] . strip ( )
2024-05-04 22:11:20 +08:00
elif completion_text . endswith ( " \n <|from|> assistant \n " ) :
2024-07-09 12:20:17 -04:00
cleaned_completion_text = completion_text [
- len ( " \n <|from|> assistant \n " )
] . strip ( )
2024-05-04 22:11:20 +08:00
else :
cleaned_completion_text = completion_text . strip ( )
prompt + = f " { cleaned_completion_text } \n <|from|>assistant \n <|recipient|> "
else :
break
else :
function_bodies . append ( completion_text . strip ( ) )
# Check whether the model wants to generate another turn
prompt + = completion_text . strip ( )
grammar = None
2024-07-09 12:20:17 -04:00
completion = create_completion (
prompt = prompt , stop = stops , grammar = grammar
)
2024-05-04 22:11:20 +08:00
completion_tokens + = completion [ " usage " ] [ " completion_tokens " ]
2024-07-09 12:20:17 -04:00
if (
" <|from|> assistant " in completion [ " choices " ] [ 0 ] [ " text " ]
or " <|from|>assistant " in completion [ " choices " ] [ 0 ] [ " text " ]
) :
2024-05-04 22:11:20 +08:00
prompt + = " \n <|from|>assistant \n <|recipient|> "
else :
break
assert " usage " in completion
assert len ( function_calls ) == len ( function_bodies )
tool_calls : List [ llama_types . ChatCompletionMessageToolCall ] = [ ]
for function_call , function_body in zip ( function_calls , function_bodies ) :
tool_calls . append (
{
" id " : " call_ "
+ " " . join (
[
random . choice ( string . ascii_letters + string . digits )
for _ in range ( 24 )
]
) ,
" type " : " function " ,
" function " : {
" name " : function_call ,
" arguments " : function_body ,
} ,
}
)
# TODO: support stream mode
2024-07-09 12:20:17 -04:00
function_call_dict : Union [
Dict [ str , str ] ,
Dict [
Literal [ " function_call " ] ,
llama_types . ChatCompletionRequestAssistantMessageFunctionCall ,
] ,
] = { }
2024-05-04 22:11:20 +08:00
if len ( tool_calls ) > 0 :
if tools is not None :
function_call_dict [ " tool_calls " ] = tool_calls
else :
function_call_dict [ " function_call " ] = {
" name " : tool_calls [ 0 ] [ " function " ] [ " name " ] ,
" arguments " : tool_calls [ 0 ] [ " function " ] [ " arguments " ] ,
}
completion [ " usage " ] [ " completion_tokens " ] = completion_tokens
return llama_types . CreateChatCompletionResponse (
id = " chat " + completion [ " id " ] ,
object = " chat.completion " ,
created = completion [ " created " ] ,
model = completion [ " model " ] ,
choices = [
{
" index " : 0 ,
2024-12-06 12:35:46 +00:00
" logprobs " : _convert_text_completion_logprobs_to_chat ( completion [ " choices " ] [ 0 ] [ " logprobs " ] ) ,
2024-05-04 22:11:20 +08:00
" message " : {
" role " : " assistant " ,
" content " : None if content == " " else content ,
* * function_call_dict ,
} ,
" finish_reason " : " tool_calls " if len ( tool_calls ) > 0 else " stop " ,
}
] ,
usage = completion [ " usage " ] ,
)
2024-02-08 09:07:03 +08:00
2023-11-08 04:48:51 +01:00
class Llava15ChatHandler:
    """Chat handler for LLaVA-1.5-style vision models backed by llama.cpp's ``mtmd`` library.

    On each call the handler does the following:

    - It renders ``CHAT_FORMAT`` (a Jinja template) over the request messages.
    - It replaces every image URL in the rendered prompt with the mtmd media marker.
    - It tokenizes text and images together with ``mtmd_tokenize``.
    - It evaluates the resulting chunks into the ``Llama`` context.
    - It delegates sampling to ``Llama.create_completion``.

    Subclasses customise the prompt by overriding ``CHAT_FORMAT`` and
    ``DEFAULT_SYSTEM_MESSAGE``.
    """

    # Prepended as a system message when the request has none; ``None`` disables this.
    DEFAULT_SYSTEM_MESSAGE: Optional[str] = (
        "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
    )

    CHAT_FORMAT = (
        "{% for message in messages %}"
        "{% if message.role == 'system' %}"
        "{{ message.content }}"
        "{% endif %}"
        "{% if message.role == 'user' %}"
        "{% if message.content is string %}"
        "\nUSER: {{ message.content }}"
        # ``elif``: Jinja's ``iterable`` test is also true for strings, so a
        # separate ``if`` would emit a second, empty "USER: " turn for string content.
        "{% elif message.content is iterable %}"
        "\nUSER: "

        "{% for content in message.content %}"
        "{% if content.type == 'image_url' and content.image_url is string %}"
        "{{ content.image_url }}"
        "{% endif %}"
        "{% if content.type == 'image_url' and content.image_url is mapping %}"
        "{{ content.image_url.url }}"
        "{% endif %}"
        "{% endfor %}"

        "{% for content in message.content %}"
        "{% if content.type == 'text' %}"
        "{{ content.text }}"
        "{% endif %}"
        "{% endfor %}"

        "{% endif %}"
        "{% endif %}"
        "{% if message.role == 'assistant' and message.content is not none %}"
        "\nASSISTANT: {{ message.content }}"
        "{% endif %}"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        "\nASSISTANT: "
        "{% endif %}"
    )

    def __init__(self, clip_model_path: str, verbose: bool = True):
        """Store configuration. The mtmd context itself is created lazily on first call.

        Args:
            clip_model_path: Path to the multimodal projector (mmproj/CLIP) file.
            verbose: When False, native stdout/stderr output is suppressed.

        Raises:
            ValueError: If ``clip_model_path`` does not exist.
        """
        import llama_cpp.mtmd_cpp as mtmd_cpp

        self.clip_model_path = clip_model_path
        self.verbose = verbose
        self._mtmd_cpp = mtmd_cpp
        # Owns native cleanup callbacks (the mtmd context free) registered later.
        self._exit_stack = ExitStack()
        self.mtmd_ctx: Optional[mtmd_cpp.mtmd_context_p] = None

        if not os.path.exists(clip_model_path):
            raise ValueError(f"Clip model path does not exist: {clip_model_path}")

    def _init_mtmd_context(self, llama_model: llama.Llama):
        """Initialize the mtmd context with the llama model (no-op if already initialized).

        NOTE(review): the context is created against the first model passed in
        and reused afterwards; calling the handler with a different ``Llama``
        instance later would keep the original binding.

        Raises:
            ValueError: If the context cannot be created or the projector has no
                vision support.
        """
        if self.mtmd_ctx is not None:
            return  # Already initialized

        with suppress_stdout_stderr(disable=self.verbose):
            ctx_params = self._mtmd_cpp.mtmd_context_params_default()
            ctx_params.use_gpu = True  # TODO: Make this configurable
            ctx_params.print_timings = self.verbose
            ctx_params.n_threads = llama_model.n_threads
            ctx_params.verbosity = 2 if self.verbose else 0  # GGML_LOG_LEVEL_INFO = 2

            self.mtmd_ctx = self._mtmd_cpp.mtmd_init_from_file(
                self.clip_model_path.encode(),
                llama_model.model,
                ctx_params,
            )

            if self.mtmd_ctx is None:
                raise ValueError(f"Failed to load mtmd context from: {self.clip_model_path}")

            if not self._mtmd_cpp.mtmd_support_vision(self.mtmd_ctx):
                # Release the unusable context and clear the handle: otherwise it
                # leaks, and the early return above would let later calls reuse it.
                self._mtmd_cpp.mtmd_free(self.mtmd_ctx)
                self.mtmd_ctx = None
                raise ValueError("Vision is not supported by this model")

            def mtmd_free():
                with suppress_stdout_stderr(disable=self.verbose):
                    if self.mtmd_ctx is not None:
                        self._mtmd_cpp.mtmd_free(self.mtmd_ctx)
                        self.mtmd_ctx = None

            self._exit_stack.callback(mtmd_free)

    def load_image(self, image_url: str) -> bytes:
        """Return the raw bytes for ``image_url`` (delegates to ``_load_image``)."""
        return self._load_image(image_url)

    def _create_bitmap_from_bytes(self, image_bytes: bytes):
        """Decode raw image bytes into an mtmd bitmap.

        The returned bitmap is released by ``__call__`` via ``mtmd_bitmap_free``.

        Raises:
            ValueError: If the mtmd context is not initialized or decoding fails.
        """
        if self.mtmd_ctx is None:
            raise ValueError("mtmd context not initialized")

        with suppress_stdout_stderr(disable=self.verbose):
            bitmap = self._mtmd_cpp.mtmd_helper_bitmap_init_from_buf(
                self.mtmd_ctx,
                (ctypes.c_uint8 * len(image_bytes)).from_buffer(bytearray(image_bytes)),
                len(image_bytes),
            )

            if bitmap is None:
                raise ValueError("Failed to create bitmap from image bytes")

            return bitmap

    def __call__(
        self,
        *,
        llama: llama.Llama,
        messages: List[llama_types.ChatCompletionRequestMessage],
        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
        temperature: float = 0.2,
        top_p: float = 0.95,
        top_k: int = 40,
        min_p: float = 0.05,
        typical_p: float = 1.0,
        stream: bool = False,
        stop: Optional[Union[str, List[str]]] = [],
        seed: Optional[int] = None,
        response_format: Optional[
            llama_types.ChatCompletionRequestResponseFormat
        ] = None,
        max_tokens: Optional[int] = None,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_tau: float = 5.0,
        mirostat_eta: float = 0.1,
        model: Optional[str] = None,
        logits_processor: Optional[llama.LogitsProcessorList] = None,
        grammar: Optional[llama.LlamaGrammar] = None,
        logit_bias: Optional[Dict[str, float]] = None,
        logprobs: Optional[bool] = None,
        top_logprobs: Optional[int] = None,
        **kwargs,  # type: ignore
    ) -> Union[
        llama_types.CreateChatCompletionResponse,
        Iterator[llama_types.CreateChatCompletionStreamResponse],
    ]:
        """Run a multimodal chat completion.

        The prompt (text plus images) is evaluated into ``llama``'s context here.
        Sampling parameters are forwarded to ``llama.create_completion``.

        Legacy ``functions`` and ``function_call`` are converted to ``tools`` and
        ``tool_choice``. When ``tool_choice`` names a function, output is constrained
        by a grammar built from that tool's JSON schema. If building it fails, the
        generic JSON grammar is used instead. Extra ``kwargs`` are ignored.

        Returns:
            A chat completion, or an iterator of chunks when ``stream`` is true.

        Raises:
            ValueError: On mtmd tokenization/evaluation failure, when the prompt
                exceeds ``n_ctx``, or when ``tool_choice`` names an unknown tool.
        """
        self._init_mtmd_context(llama)
        assert self.mtmd_ctx is not None

        system_prompt = _get_system_message(messages)
        if system_prompt == "" and self.DEFAULT_SYSTEM_MESSAGE is not None:
            messages = [
                llama_types.ChatCompletionRequestSystemMessage(
                    role="system", content=self.DEFAULT_SYSTEM_MESSAGE
                )
            ] + messages

        image_urls = self.get_image_urls(messages)
        # NOTE(review): with trim_blocks=True Jinja drops the first newline right
        # after a block tag, so literals like "\nUSER: " that directly follow a
        # "{% ... %}" render without their leading newline -- confirm intended.
        template = ImmutableSandboxedEnvironment(
            trim_blocks=True,
            lstrip_blocks=True,
        ).from_string(self.CHAT_FORMAT)

        media_marker = self._mtmd_cpp.mtmd_default_marker().decode('utf-8')

        text = template.render(
            messages=messages,
            add_generation_prompt=True,
            eos_token=llama.detokenize([llama.token_eos()]),
            bos_token=llama.detokenize([llama.token_bos()]),
        )

        # The template emits image URLs verbatim; mtmd expects its media marker
        # at each position where an image bitmap should be spliced in.
        for image_url in image_urls:
            text = text.replace(image_url, media_marker)

        if self.verbose:
            print(text, file=sys.stderr)

        # Every bitmap created here is released in the outer ``finally``, even if
        # a later image fails to load.
        bitmaps = []
        try:
            for image_url in image_urls:
                image_bytes = self.load_image(image_url)
                bitmaps.append(self._create_bitmap_from_bytes(image_bytes))

            input_text = self._mtmd_cpp.mtmd_input_text()
            input_text.text = text.encode('utf-8')
            input_text.add_special = True
            input_text.parse_special = True

            chunks = self._mtmd_cpp.mtmd_input_chunks_init()
            if chunks is None:
                raise ValueError("Failed to create input chunks")

            try:
                # Tokenize text and images together.
                bitmap_array = (self._mtmd_cpp.mtmd_bitmap_p_ctypes * len(bitmaps))(*bitmaps)
                result = self._mtmd_cpp.mtmd_tokenize(
                    self.mtmd_ctx,
                    chunks,
                    ctypes.byref(input_text),
                    bitmap_array,
                    len(bitmaps),
                )

                if result != 0:
                    raise ValueError(f"Failed to tokenize input: error code {result}")

                # Start from an empty llama context.
                llama.reset()
                llama._ctx.kv_cache_clear()

                n_chunks = self._mtmd_cpp.mtmd_input_chunks_size(chunks)
                for i in range(n_chunks):
                    chunk = self._mtmd_cpp.mtmd_input_chunks_get(chunks, i)
                    if chunk is None:
                        continue

                    chunk_type = self._mtmd_cpp.mtmd_input_chunk_get_type(chunk)

                    if chunk_type == self._mtmd_cpp.MTMD_INPUT_CHUNK_TYPE_TEXT:
                        n_tokens_out = ctypes.c_size_t()
                        tokens_ptr = self._mtmd_cpp.mtmd_input_chunk_get_tokens_text(
                            chunk, ctypes.byref(n_tokens_out)
                        )

                        if tokens_ptr and n_tokens_out.value > 0:
                            tokens = [tokens_ptr[j] for j in range(n_tokens_out.value)]

                            if llama.n_tokens + len(tokens) > llama.n_ctx():
                                raise ValueError(
                                    f"Prompt exceeds n_ctx: {llama.n_tokens + len(tokens)} > {llama.n_ctx()}"
                                )
                            llama.eval(tokens)

                    elif chunk_type in [self._mtmd_cpp.MTMD_INPUT_CHUNK_TYPE_IMAGE, self._mtmd_cpp.MTMD_INPUT_CHUNK_TYPE_AUDIO]:
                        chunk_n_tokens = self._mtmd_cpp.mtmd_input_chunk_get_n_tokens(chunk)

                        if llama.n_tokens + chunk_n_tokens > llama.n_ctx():
                            raise ValueError(
                                f"Prompt exceeds n_ctx: {llama.n_tokens + chunk_n_tokens} > {llama.n_ctx()}"
                            )

                        new_n_past = llama_cpp.llama_pos(0)
                        result = self._mtmd_cpp.mtmd_helper_eval_chunk_single(
                            self.mtmd_ctx,
                            llama._ctx.ctx,
                            chunk,
                            llama_cpp.llama_pos(llama.n_tokens),
                            llama_cpp.llama_seq_id(0),
                            llama.n_batch,
                            False,  # logits_last
                            ctypes.byref(new_n_past),
                        )

                        if result != 0:
                            raise ValueError(f"Failed to evaluate chunk: error code {result}")

                        # Media chunks are evaluated natively; only the position
                        # counter is advanced on the Python side.
                        llama.n_tokens = new_n_past.value

                # Reuse the ids already in the context so create_completion sees a
                # prefix match instead of re-evaluating the prompt.
                # NOTE(review): ``input_ids`` is not written for media positions, so
                # those entries hold whatever was there before -- confirm harmless.
                prompt = llama.input_ids[: llama.n_tokens].tolist()

            finally:
                self._mtmd_cpp.mtmd_input_chunks_free(chunks)

        finally:
            for bitmap in bitmaps:
                self._mtmd_cpp.mtmd_bitmap_free(bitmap)

        if response_format is not None and response_format["type"] == "json_object":
            grammar = _grammar_for_response_format(response_format)

        # Convert legacy functions to tools
        if functions is not None:
            tools = [
                {
                    "type": "function",
                    "function": function,
                }
                for function in functions
            ]

        # Convert legacy function_call to tool_choice
        if function_call is not None:
            if isinstance(function_call, str) and (
                function_call == "none" or function_call == "auto"
            ):
                tool_choice = function_call
            if isinstance(function_call, dict) and "name" in function_call:
                tool_choice = {
                    "type": "function",
                    "function": {
                        "name": function_call["name"],
                    },
                }

        tool = None
        if (
            tool_choice is not None
            and isinstance(tool_choice, dict)
            and tools is not None
        ):
            name = tool_choice["function"]["name"]
            tool = next((t for t in tools if t["function"]["name"] == name), None)
            if tool is None:
                raise ValueError(f"Tool choice '{name}' not found in tools.")
            schema = tool["function"]["parameters"]
            try:
                # create grammar from json schema
                grammar = llama_grammar.LlamaGrammar.from_json_schema(
                    json.dumps(schema), verbose=llama.verbose
                )
            except Exception as e:
                if llama.verbose:
                    print(str(e), file=sys.stderr)
                grammar = llama_grammar.LlamaGrammar.from_string(
                    llama_grammar.JSON_GBNF, verbose=llama.verbose
                )

        completion_or_chunks = llama.create_completion(
            prompt=prompt,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            typical_p=typical_p,
            logprobs=top_logprobs if logprobs else None,
            stream=stream,
            stop=stop,
            seed=seed,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            repeat_penalty=repeat_penalty,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
            model=model,
            logits_processor=logits_processor,
            grammar=grammar,
            logit_bias=logit_bias,
        )

        if tool is not None:
            tool_name = tool["function"]["name"]
            return _convert_completion_to_chat_function(
                tool_name, completion_or_chunks, stream
            )
        return _convert_completion_to_chat(completion_or_chunks, stream=stream)

    @staticmethod
    def _load_image(image_url: str) -> bytes:
        """Return image bytes from a ``data:`` URL (base64 payload) or by fetching the URL.

        NOTE(review): ``urllib.request.urlopen`` also accepts ``file://`` and other
        schemes; if image URLs can come from untrusted clients, consider
        restricting the allowed schemes.
        """
        # TODO: Add Pillow support for other image formats beyond (jpg, png)
        if image_url.startswith("data:"):
            import base64

            image_bytes = base64.b64decode(image_url.split(",")[1])
            return image_bytes
        else:
            import urllib.request

            with urllib.request.urlopen(image_url) as f:
                image_bytes = f.read()
                return image_bytes

    @staticmethod
    def get_image_urls(messages: List[llama_types.ChatCompletionRequestMessage]):
        """Collect image URLs from user messages, in message order.

        ``image_url`` may be a ``{"url": ...}`` mapping or a bare value. String
        message content is iterated character by character and contributes nothing.
        """
        image_urls: List[str] = []
        for message in messages:
            if message["role"] == "user":
                if message["content"] is None:
                    continue
                for content in message["content"]:
                    if isinstance(content, dict) and "type" in content:
                        if content["type"] == "image_url":
                            if (
                                isinstance(content["image_url"], dict)
                                and "url" in content["image_url"]
                            ):
                                image_urls.append(content["image_url"]["url"])
                            else:
                                image_urls.append(content["image_url"])
        return image_urls

    @staticmethod
    def split_text_on_image_urls(text: str, image_urls: List[str]):
        """Split ``text`` into ``("text", ...)`` / ``("image_url", ...)`` segments.

        Segments are returned in order of appearance. ``__call__`` no longer uses
        this; it substitutes mtmd media markers instead.
        """

        def find_first(s: str, substrs: List[str]):
            # The earliest occurrence in ``s`` wins; ties go to the earlier entry in
            # ``substrs``. Empty entries are skipped because they match at position 0
            # without consuming input, which would loop forever.
            best_pos: Optional[int] = None
            best_i: Optional[int] = None
            for i, substr in enumerate(substrs):
                if not substr:
                    continue
                pos = s.find(substr)
                if pos != -1 and (best_pos is None or pos < best_pos):
                    best_pos, best_i = pos, i
            return best_pos, best_i

        split_text: List[Tuple[Literal["text", "image_url"], str]] = []
        remaining = text
        while remaining:
            pos, i = find_first(remaining, image_urls)
            if pos is not None and i is not None:
                if pos > 0:
                    split_text.append(("text", remaining[:pos]))
                split_text.append(("image_url", image_urls[i]))
                remaining = remaining[pos + len(image_urls[i]) :]
            else:
                split_text.append(("text", remaining))
                remaining = ""
        return split_text

    @classmethod
    def from_pretrained(
        cls,
        repo_id: str,
        filename: Optional[str],
        local_dir: Optional[Union[str, os.PathLike[str]]] = None,
        local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
        cache_dir: Optional[Union[str, os.PathLike[str]]] = None,
        **kwargs: Any,
    ) -> "Llava15ChatHandler":
        """Download the projector file matching glob ``filename`` from a Hugging Face repo and build a handler.

        Extra ``kwargs`` are forwarded to the constructor.

        Raises:
            ImportError: If ``huggingface_hub`` is not installed.
            ValueError: If zero or several repo files match ``filename``.
        """
        import fnmatch
        from pathlib import Path

        try:
            from huggingface_hub import hf_hub_download, HfFileSystem  # type: ignore
            from huggingface_hub.utils import validate_repo_id  # type: ignore
        except ImportError:
            raise ImportError(
                "Llama.from_pretrained requires the huggingface-hub package. "
                "You can install it with `pip install huggingface-hub`."
            )

        validate_repo_id(repo_id)

        hffs = HfFileSystem()

        files = [
            file["name"] if isinstance(file, dict) else file
            for file in hffs.ls(repo_id)  # type: ignore
        ]

        # Paths relative to the repo root, so they can be glob-matched and split
        # into subfolder + filename.
        file_list: List[str] = []
        for file in files:
            rel_path = Path(file).relative_to(repo_id)
            file_list.append(str(rel_path))

        matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)]  # type: ignore

        if len(matching_files) == 0:
            raise ValueError(
                f"No file found in {repo_id} that match {filename}\n\n"
                f"Available Files:\n{json.dumps(file_list)}"
            )

        if len(matching_files) > 1:
            # List the same repo-relative names as the "no match" error above.
            raise ValueError(
                f"Multiple files found in {repo_id} matching {filename}\n\n"
                f"Available Files:\n{json.dumps(file_list)}"
            )

        (matching_file,) = matching_files

        subfolder = str(Path(matching_file).parent)
        filename = Path(matching_file).name

        # download the file
        hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            subfolder=subfolder,
            local_dir=cast(Union[str, Path, None], local_dir),
            local_dir_use_symlinks=local_dir_use_symlinks,
            cache_dir=cast(Union[str, Path, None], cache_dir),
        )

        if local_dir is None:
            # Resolve the cached path without another network round trip.
            model_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                subfolder=subfolder,
                local_dir=local_dir,
                local_dir_use_symlinks=local_dir_use_symlinks,
                cache_dir=cast(Union[str, Path, None], cache_dir),
                local_files_only=True,
            )
        else:
            model_path = os.path.join(local_dir, filename)

        return cls(
            clip_model_path=model_path,
            **kwargs,
        )
2024-07-09 12:20:17 -04:00
2024-04-30 01:35:38 -04:00
class ObsidianChatHandler(Llava15ChatHandler):
    # Chat handler for Obsidian models. Only CHAT_FORMAT is overridden; the
    # default system message and all runtime behaviour come from
    # Llava15ChatHandler.
    #
    # Prompt Format
    # The model follows ChatML format, but uses ### as the separator.

    # <|im_start|>user
    # What is this sign about?\n<image>
    # ###
    # <|im_start|>assistant
    # The sign is about bullying, and it is placed on a black background with a red background.
    # ###
    #
    # NOTE(review): the example above puts the image after the question, but the
    # template below emits image URLs before the text. Also, only the system turn
    # has a newline before "###". Confirm against the model card.

    CHAT_FORMAT = (
        "{% for message in messages %}"
        # System message
        "{% if message.role == 'system' %}"
        "<|im_start|>system\n"
        "{{ message.content }}\n"
        "###\n"
        "{% endif %}"
        # User message
        "{% if message.role == 'user' %}"
        "<|im_start|>user\n"
        "{% if message.content is string %}"
        "{{ message.content }}"
        "{% endif %}"
        "{% if message.content is iterable %}"

        # Image URLs first; the parent handler replaces them with media markers.
        "{% for content in message.content %}"
        "{% if content.type == 'image_url' and content.image_url is string %}"
        "{{ content.image_url }}"
        "{% endif %}"
        "{% if content.type == 'image_url' and content.image_url is mapping %}"
        "{{ content.image_url.url }}"
        "{% endif %}"
        "{% endfor %}"

        # Then the text parts.
        "{% for content in message.content %}"
        "{% if content.type == 'text' %}"
        "{{ content.text }}"
        "{% endif %}"
        "{% endfor %}"

        "{% endif %}"
        "###\n"
        "{% endif %}"
        # Assistant message
        "{% if message.role == 'assistant' %}"
        "<|im_start|>assistant\n"
        "{{ message.content }}"
        "###\n"
        "{% endif %}"
        "{% endfor %}"
        # Generation prompt
        "{% if add_generation_prompt %}"
        "<|im_start|>assistant\n"
        "{% endif %}"
    )
2024-07-09 12:20:17 -04:00
2024-04-30 01:35:38 -04:00
class MoondreamChatHandler(Llava15ChatHandler):
    # Chat handler for Moondream. Only CHAT_FORMAT is overridden. The template
    # has no branch for the "system" role, so the system message that
    # Llava15ChatHandler.__call__ may prepend is not rendered.
    #
    # Chat Format:
    # f"<image>\n\n{chat_history}Question: {question}\n\nAnswer:"
    CHAT_FORMAT = (
        "{% for message in messages %}"
        "{% if message.role == 'user' %}"
        "{% if message.content is iterable %}"
        # <image>
        "{% for content in message.content %}"
        "{% if content.type == 'image_url' %}"
        "{% if content.image_url is string %}"
        "{{ content.image_url }}\n\n"
        "{% endif %}"
        "{% if content.image_url is mapping %}"
        "{{ content.image_url.url }}\n\n"
        "{% endif %}"
        "{% endif %}"
        "{% endfor %}"
        # Question:
        "{% for content in message.content %}"
        "{% if content.type == 'text' %}"
        "Question: {{ content.text }}\n\n"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        # Question (plain-string content). The iterable branch above emits nothing
        # for strings, because a character has no ``type``.
        "{% if message.content is string %}"
        "Question: {{ message.content }}\n\n"
        "{% endif %}"
        "{% endif %}"
        # Answer:
        "{% if message.role == 'assistant' %}"
        "Answer: {{ message.content }}\n\n"
        "{% endif %}"
        "{% endfor %}"
        # Generation prompt
        "{% if add_generation_prompt %}"
        "Answer: "
        "{% endif %}"
    )
2024-07-09 12:20:17 -04:00
2024-04-30 01:35:38 -04:00
class Llava16ChatHandler(Llava15ChatHandler):
    """Chat handler for LLaVA-1.6 models using the Vicuna-style prompt.

    Only the prompt template and default system message differ from
    ``Llava15ChatHandler``. The template renders the documented format:
    system text, then " USER: " with image markers and the question, then
    " ASSISTANT:".
    """

    DEFAULT_SYSTEM_MESSAGE = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."

    # Example prompt
    # "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nWhat is shown in this image? ASSISTANT:"
    #
    # Role prefixes start with a space instead of a newline. With trim_blocks=True
    # (set by the parent handler), Jinja drops a newline that directly follows a
    # block tag.
    CHAT_FORMAT = (
        "{% for message in messages %}"
        "{% if message.role == 'system' %}"
        "{{ message.content }}"
        "{% endif %}"
        "{% if message.role == 'user' %}"
        " USER: "
        # Question (plain-string content). ``elif`` below because Jinja's
        # ``iterable`` test is also true for strings.
        "{% if message.content is string %}"
        "{{ message.content }}"
        "{% elif message.content is iterable %}"
        # <image>
        "{% for content in message.content %}"
        "{% if content.type == 'image_url' %}"
        "{% if content.image_url is string %}"
        "{{ content.image_url }}\n"
        "{% endif %}"
        "{% if content.image_url is mapping %}"
        "{{ content.image_url.url }}\n"
        "{% endif %}"
        "{% endif %}"
        "{% endfor %}"
        # Question:
        "{% for content in message.content %}"
        "{% if content.type == 'text' %}"
        "{{ content.text }}"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        "{% endif %}"
        # Answer:
        "{% if message.role == 'assistant' and message.content is not none %}"
        " ASSISTANT: {{ message.content }}"
        "{% endif %}"
        "{% endfor %}"
        # Generation prompt
        "{% if add_generation_prompt %}"
        " ASSISTANT:"
        "{% endif %}"
    )
2024-07-09 12:20:17 -04:00
2024-04-30 01:35:38 -04:00
class NanoLlavaChatHandler(Llava15ChatHandler):
    # Chat handler for nanoLLaVA. It overrides only the default system message and
    # CHAT_FORMAT; runtime behaviour is inherited from Llava15ChatHandler.
    #
    # Prompt Format
    # The model follows the ChatML standard, but without \n after <|im_end|>:

    # <|im_start|>system
    # Answer the question<|im_end|><|im_start|>user
    # <image>
    # What is the picture about?<|im_end|><|im_start|>assistant
    DEFAULT_SYSTEM_MESSAGE = "Answer the question"

    CHAT_FORMAT = (
        "{% for message in messages %}"
        # System message
        "{% if message.role == 'system' %}"
        "<|im_start|>system\n"
        "{{ message.content }}"
        "<|im_end|>"
        "{% endif %}"
        # User message
        "{% if message.role == 'user' %}"
        "<|im_start|>user\n"
        "{% if message.content is string %}"
        "{{ message.content }}"
        "{% endif %}"
        "{% if message.content is iterable %}"

        # Image URLs first; the parent handler replaces them with media markers.
        "{% for content in message.content %}"
        "{% if content.type == 'image_url' and content.image_url is string %}"
        "{{ content.image_url }}"
        "{% endif %}"
        "{% if content.type == 'image_url' and content.image_url is mapping %}"
        "{{ content.image_url.url }}"
        "{% endif %}"
        "{% endfor %}"

        # Then the text parts.
        "{% for content in message.content %}"
        "{% if content.type == 'text' %}"
        "{{ content.text }}"
        "{% endif %}"
        "{% endfor %}"

        "{% endif %}"
        "<|im_end|>"
        "{% endif %}"
        # Assistant message
        "{% if message.role == 'assistant' %}"
        "<|im_start|>assistant\n"
        "{{ message.content }}"
        "<|im_end|>"
        "{% endif %}"
        "{% endfor %}"
        # Generation prompt
        "{% if add_generation_prompt %}"
        "<|im_start|>assistant\n"
        "{% endif %}"
    )
2024-02-12 15:56:07 -05:00
2024-07-09 12:20:17 -04:00
2024-05-29 02:29:44 -04:00
class Llama3VisionAlphaChatHandler(Llava15ChatHandler):
    """Chat handler for Llama 3 Vision Alpha models.

    Renders the conversation with Llama 3 header tokens: every message becomes
    ``<|start_header_id|>ROLE<|end_header_id|>`` + two newlines + content +
    ``<|eot_id|>``. In user turns with list content, all image URLs are
    emitted before the text parts. No default system prompt is configured.
    """

    # Reference prompt for a single image + question:
    # question = "<image>" + q
    # prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    DEFAULT_SYSTEM_MESSAGE = None
    CHAT_FORMAT = (
        "{% for message in messages %}"
        "<|start_header_id|>"
        # System prompt. Without this branch a system message was rendered as
        # a malformed, role-less "<|start_header_id|><|eot_id|>" turn.
        "{% if message.role == 'system' %}"
        "system<|end_header_id|>\n\n"
        "{{ message.content }}"
        "{% endif %}"
        "{% if message.role == 'user' %}"
        "user<|end_header_id|>\n\n"
        "{% if message.content is iterable %}"
        # <image>: image URLs (plain string or {"url": ...} mapping form)
        "{% for content in message.content %}"
        "{% if content.type == 'image_url' %}"
        "{% if content.image_url is string %}"
        "{{ content.image_url }}"
        "{% endif %}"
        "{% if content.image_url is mapping %}"
        "{{ content.image_url.url }}"
        "{% endif %}"
        "{% endif %}"
        "{% endfor %}"
        # Question: text parts of list content
        "{% for content in message.content %}"
        "{% if content.type == 'text' %}"
        "{{ content.text }}"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        # Question: plain string content
        "{% if message.content is string %}"
        "{{ message.content }}"
        "{% endif %}"
        "{% endif %}"
        # Answer:
        "{% if message.role == 'assistant' %}"
        "assistant<|end_header_id|>\n\n"
        "{{ message.content }}"
        "{% endif %}"
        "<|eot_id|>"
        "{% endfor %}"
        # Generation prompt: open the assistant turn for the model to complete.
        "{% if add_generation_prompt %}"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
        "{% endif %}"
    )
2024-02-12 15:56:07 -05:00
2024-07-09 12:20:17 -04:00
2024-05-29 02:29:44 -04:00
# Alias: `Llama3VisionAlpha` is the very same class object as
# `Llama3VisionAlphaChatHandler` (no subclassing, no behaviour change).
Llama3VisionAlpha = Llama3VisionAlphaChatHandler
2024-08-29 01:26:34 -04:00
class MiniCPMv26ChatHandler(Llava15ChatHandler):
    """Chat handler for MiniCPM-V 2.6 models (ChatML-style prompt).

    Inherits all behaviour from ``Llava15ChatHandler`` and only overrides the
    class-level default system prompt and the Jinja chat template.
    """

    # Class-level default system prompt; how/when it is injected is decided by
    # the parent handler.
    DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."
    CHAT_FORMAT = (
        "{% for message in messages %}"
        # Emit a built-in system turn when the conversation does not start with
        # one. NOTE(review): if the parent handler already prepends
        # DEFAULT_SYSTEM_MESSAGE, this branch never fires — confirm.
        "{% if loop.first and messages[0]['role'] != 'system' %}"
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "{% endif %}"
        # Every message (any role) becomes one ChatML turn.
        "<|im_start|>{{ message['role'] }}\n"
        # List content: all image URLs first, then all text parts.
        # NOTE(review): Jinja's `iterable` test is also true for strings; for
        # string content these loops visit single characters, which have no
        # `.type`, and presumably emit nothing — the string itself is rendered
        # by the `is string` branch further down.
        "{% if message['content'] is iterable %}"
        "{% for content in message['content'] %}"
        "{% if content.type == 'image_url' %}"
        "{% if content.image_url is string %}"
        "{{ content.image_url }}"
        "{% endif %}"
        "{% if content.image_url is mapping %}"
        "{{ content.image_url.url }}"
        "{% endif %}"
        "{% endif %}"
        "{% endfor %}"
        "{% for content in message['content'] %}"
        "{% if content.type == 'text' %}"
        "{{ content.text }}"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        # Plain string content.
        "{% if message['content'] is string %}"
        "{{ message['content'] }}"
        "{% endif %}"
        "<|im_end|>\n"
        "{% endfor %}"
        # Generation prompt: open the assistant turn for the model to complete.
        "{% if add_generation_prompt %}"
        "<|im_start|>assistant\n"
        "{% endif %}"
    )
2025-07-03 01:57:43 -04:00
class Qwen25VLChatHandler(Llava15ChatHandler):
    """Chat handler for Qwen2.5-VL models (ChatML-style prompt).

    Every message, whatever its role, is rendered as one ChatML turn, so
    user-supplied system prompts and earlier assistant replies are part of
    the prompt. The built-in system prompt is emitted only when the
    conversation does not already start with a system message. The prompt
    always ends with an open assistant turn.
    """

    DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."
    CHAT_FORMAT = (
        "{% for message in messages %}"
        # Fall back to the built-in system prompt when none was supplied.
        "{% if loop.first and message['role'] != 'system' %}"
        "<|im_start|>system\n"
        "You are a helpful assistant.<|im_end|>\n"
        "{% endif %}"
        "<|im_start|>{{ message['role'] }}\n"
        "{% if message['content'] is string %}"
        "{{ message['content'] }}"
        # List content: text and images interleaved in their original order.
        # `is iterable` guards against None content (e.g. tool-call replies).
        "{% elif message['content'] is iterable %}"
        "{% for content in message['content'] %}"
        "{% if content['type'] == 'text' %}"
        "{{ content['text'] }}"
        "{% elif content['type'] == 'image_url' %}"
        "{% if content.image_url is string %}"
        "{{ content.image_url }}"
        "{% else %}"
        "{{ content.image_url.url }}"
        "{% endif %}"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        "<|im_end|>\n"
        "{% endfor %}"
        # Generation prompt: open the assistant turn for the model to complete.
        "<|im_start|>assistant\n"
    )

    def __call__(self, **kwargs):
        """Reset the model and cached image state, then run the parent handler.

        All keyword arguments are forwarded unchanged to
        ``Llava15ChatHandler.__call__``; ``kwargs['llama']`` must be present.
        """
        llama = kwargs['llama']
        # Clear state for multiple runs: start each request from an empty
        # context (model reset, KV cache cleared, token count and buffer zeroed).
        llama.reset()
        llama._ctx.kv_cache_clear()
        llama.n_tokens = 0
        if hasattr(llama, 'input_ids'):
            llama.input_ids.fill(0)
        # Clear any handler state (presumably the parent's cached image
        # embedding and its hash) so images are re-embedded on every call.
        if hasattr(self, '_last_image_embed'):
            self._last_image_embed = None
            self._last_image_hash = None
        if self.verbose:
            messages = kwargs.get('messages', [])
            image_count = len(self.get_image_urls(messages))
            print(f"Minimal - Cleared state, processing {image_count} images", file=sys.stderr)
        # Use parent implementation
        return super().__call__(**kwargs)
2024-02-12 15:56:07 -05:00
@register_chat_completion_handler("chatml-function-calling")
def chatml_function_calling(
    llama: llama.Llama,
    messages: List[llama_types.ChatCompletionRequestMessage],
    functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
    function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
    tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    stream: bool = False,
    stop: Optional[Union[str, List[str]]] = None,
    response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    logits_processor: Optional[llama.LogitsProcessorList] = None,
    grammar: Optional[llama.LlamaGrammar] = None,
    logprobs: Optional[bool] = None,
    top_logprobs: Optional[int] = None,
    **kwargs,  # type: ignore
) -> Union[
    llama_types.CreateChatCompletionResponse,
    Iterator[llama_types.CreateChatCompletionStreamResponse],
]:
    """Emulate OpenAI-style tool/function calling on a ChatML model.

    When tools are active, the system message lists every tool's JSON
    parameter schema. It tells the model to answer either with ``message:``
    followed by text, or with one or more ``functions.<name>:`` headers, each
    followed by JSON arguments.

    Legacy ``functions`` / ``function_call`` arguments are first converted to
    ``tools`` / ``tool_choice``. Then one of three cases applies:

    1. No tools, or ``tool_choice`` is ``None``/``"none"``: plain chat
       completion. ``response_format={"type": "json_object"}`` is honoured.
    2. ``tool_choice`` names a function: that call header is forced, and the
       arguments are grammar-constrained to the function's parameter schema.
    3. ``tool_choice == "auto"``: a grammar-constrained first step lets the
       model choose between a message and a function call. Function calls may
       be chained, but only when not streaming.

    All sampling parameters are forwarded to ``llama.create_completion``.
    ``<|im_end|>`` is always added to the stop sequences.

    Returns:
        A chat completion response, or an iterator of chunks when ``stream``.

    Raises:
        ValueError: ``tool_choice`` names a tool not present in ``tools``,
            ``tool_choice`` is an unsupported value, or ``stream=True`` while
            the model chooses to call functions under ``tool_choice="auto"``.
    """
    function_calling_template = (
        "{% for message in messages %}"
        "<|im_start|>{{ message.role }}\n"
        # System message
        "{% if message.role == 'system' %}"
        "{{ message.content }}"
        "{% if tool_calls %}"
        "\n\nYou have access to the following functions:\n"
        "{% for tool in tools %}"
        "\nfunctions.{{ tool.function.name }}:\n"
        "{{ tool.function.parameters | tojson }}"
        "\n{% endfor %}"
        "\n\nYou can respond to users messages with either a single message or one or more function calls."
        "\n\nTo respond with a message begin the message with 'message:', use the following format:"
        "\n\nmessage:"
        "\n<message>"
        "\n\nTo respond with one or more function calls begin the message with 'functions.<function_name>:', use the following format:"
        "\n\nfunctions.<function_name>:"
        '\n{ "arg1": "value1", "arg2": "value2" }'
        "\nfunctions.<function_name>:"
        '\n{ "arg1": "value1", "arg2": "value2" }'
        "{% endif %}"
        "<|im_end|>\n"
        "{% endif %}"
        # User message
        "{% if message.role == 'user' %}"
        "{{ message.content }}"
        "<|im_end|>\n"
        "{% endif %}"
        # Tool (or legacy "function") result message. Without this branch the
        # result was dropped and the "<|im_start|>tool" turn never closed.
        "{% if message.role == 'tool' or message.role == 'function' %}"
        "{{ message.content }}"
        "<|im_end|>\n"
        "{% endif %}"
        # Assistant message
        "{% if message.role == 'assistant' %}"
        ## Regular message
        "{% if message.content and message.content | length > 0 %}"
        "{% if tool_calls %}"
        "message:\n"
        "{% endif %}"
        "{{ message.content }}"
        "<|im_end|>\n"
        "{% endif %}"
        ## Function calls (tolerate an explicit `tool_calls: None`)
        "{% if 'tool_calls' in message and message.tool_calls %}"
        "{% for tool_call in message.tool_calls %}"
        "functions.{{ tool_call.function.name }}:\n"
        "{{ tool_call.function.arguments }}"
        "{% endfor %}"
        "<|im_end|>\n"
        "{% endif %}"
        "{% endif %}"
        "{% endfor %}"
        "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
    )
    # The prompt is plain text, not HTML: autoescaping must stay off. With it
    # on, quotes, "&" and "<"/">" in message contents and in earlier tool-call
    # JSON arguments were HTML-escaped (e.g. '"' -> '&#34;').
    template_renderer = ImmutableSandboxedEnvironment(
        autoescape=False,
        undefined=jinja2.StrictUndefined,
    ).from_string(function_calling_template)

    # Convert legacy functions to tools
    if functions is not None:
        tools = [
            {
                "type": "function",
                "function": function,
            }
            for function in functions
        ]

    # Convert legacy function_call to tool_choice
    if function_call is not None:
        if isinstance(function_call, str) and (
            function_call == "none" or function_call == "auto"
        ):
            tool_choice = function_call
        if isinstance(function_call, dict) and "name" in function_call:
            tool_choice = {
                "type": "function",
                "function": {
                    "name": function_call["name"],
                },
            }

    # User stop sequences (if any) plus the end of the ChatML turn.
    if stop is None:
        user_stops: List[str] = []
    elif isinstance(stop, str):
        user_stops = [stop]
    else:
        user_stops = list(stop)
    stop_sequences: List[str] = user_stops + ["<|im_end|>"]

    # Sampling settings shared by every completion issued below.
    sampling_kwargs: Dict[str, Any] = {
        "top_p": top_p,
        "top_k": top_k,
        "min_p": min_p,
        "typical_p": typical_p,
        "presence_penalty": presence_penalty,
        "frequency_penalty": frequency_penalty,
        "repeat_penalty": repeat_penalty,
        "tfs_z": tfs_z,
        "mirostat_mode": mirostat_mode,
        "mirostat_tau": mirostat_tau,
        "mirostat_eta": mirostat_eta,
        "model": model,
        "logits_processor": logits_processor,
    }

    def _find_tool(
        name: Optional[str],
    ) -> Optional[llama_types.ChatCompletionTool]:
        """Return the tool called *name*, or ``None`` if absent or *name* is None."""
        if name is None:
            return None
        return next(
            (t for t in tools or [] if t["function"]["name"] == name), None
        )

    def _tool_name_from_text(text: str) -> Optional[str]:
        """Extract ``<name>`` from a generated ``functions.<name>`` header.

        Returns ``None`` when *text* is not a function-call header (e.g. the
        model chose ``message:`` or ended its turn). A trailing ``:`` is
        removed; it is present when ``:`` was not a stop sequence.
        """
        prefix = "functions."
        if not text.startswith(prefix):
            return None
        name = text[len(prefix):]
        return name[:-1] if name.endswith(":") else name

    def _grammar_for_tool(
        tool: llama_types.ChatCompletionTool,
    ) -> llama_grammar.LlamaGrammar:
        """Grammar for the tool's JSON-schema parameters (generic JSON on failure)."""
        try:
            return llama_grammar.LlamaGrammar.from_json_schema(
                json.dumps(tool["function"]["parameters"]), verbose=llama.verbose
            )
        except Exception as e:
            fallback = llama_grammar.LlamaGrammar.from_string(
                llama_grammar.JSON_GBNF, verbose=llama.verbose
            )
            if llama.verbose:
                print(
                    "Failed to parse function body as JSON schema, falling back to default grammar"
                )
                print(e)
            return fallback

    # Case 1: No tool choice by user
    if tool_choice is None or tool_choice == "none" or not tools:
        prompt = template_renderer.render(
            messages=messages,
            tools=[],
            tool_calls=None,
            add_generation_prompt=True,
        )

        if response_format is not None and response_format["type"] == "json_object":
            grammar = _grammar_for_response_format(response_format)

        return _convert_completion_to_chat(
            llama.create_completion(
                prompt=prompt,
                temperature=temperature,
                stream=stream,
                stop=stop_sequences,
                max_tokens=max_tokens,
                grammar=grammar,
                logprobs=top_logprobs if logprobs else None,
                **sampling_kwargs,
            ),
            stream=stream,
        )

    # Case 2: Tool choice by user
    if isinstance(tool_choice, dict):
        tool_name = tool_choice["function"]["name"]
        tool = _find_tool(tool_name)
        if tool is None:
            raise ValueError(f"Tool with name '{tool_name}' not found in tools")
        prompt = template_renderer.render(
            messages=messages,
            tools=tools,
            tool_calls=True,
            add_generation_prompt=True,
        )
        # Force the call header; only the JSON arguments are generated.
        prompt += f"functions.{tool_name}:\n"
        completion_or_chunks = llama.create_completion(
            prompt=prompt,
            temperature=temperature,
            stream=stream,
            stop=stop_sequences,
            max_tokens=max_tokens,
            grammar=_grammar_for_tool(tool),
            **sampling_kwargs,
        )
        return _convert_completion_to_chat_function(
            tool_name, completion_or_chunks, stream
        )

    # Case 3: Automatic tool choice
    if tool_choice != "auto":
        raise ValueError(f"Unsupported tool_choice: {tool_choice!r}")

    function_names = " | ".join(
        f'"functions.{t["function"]["name"]}:"' for t in tools
    )
    # First step: message or function call?
    initial_gbnf_tool_grammar = (
        'root ::= functions | "message:"\n'
        f"functions ::= {function_names}\n"
    )
    # After a call: another call, or end of turn?
    follow_up_gbnf_tool_grammar = (
        'root ::= functions | "<|im_end|>"\n'
        f"functions ::= {function_names}\n"
    )
    prompt = template_renderer.render(
        messages=messages,
        tools=tools,
        tool_calls=True,
        add_generation_prompt=True,
    )
    # Deterministic decision step; stops at ":" so the text is either
    # "message" or "functions.<name>".
    completion = cast(
        llama_types.CreateCompletionResponse,
        llama.create_completion(
            prompt=prompt,
            temperature=0,
            stream=False,
            stop=[":"],
            max_tokens=None,
            grammar=llama_grammar.LlamaGrammar.from_string(
                initial_gbnf_tool_grammar, verbose=llama.verbose
            ),
            **sampling_kwargs,
        ),
    )
    text = completion["choices"][0]["text"]
    tool = _find_tool(_tool_name_from_text(text))

    if tool is None:
        # The model chose to answer with a plain message. This is
        # unconstrained text: the tool grammars must not be applied here.
        return _convert_completion_to_chat(
            llama.create_completion(
                prompt=prompt + "message:\n",
                temperature=temperature,
                stream=stream,
                stop=stop_sequences,
                logprobs=top_logprobs if logprobs else None,
                max_tokens=max_tokens,
                grammar=grammar,
                **sampling_kwargs,
            ),
            stream=stream,
        )

    if stream:
        raise ValueError("Automatic streaming tool choice is not supported")

    # One or more function calls
    completions: List[llama_types.CreateCompletionResponse] = []
    completions_tool_name: List[str] = []
    while tool is not None:
        tool_name = tool["function"]["name"]
        prompt += f"functions.{tool_name}:\n"
        call = cast(
            llama_types.CreateCompletionResponse,
            llama.create_completion(
                prompt=prompt,
                temperature=temperature,
                stream=False,
                stop=stop_sequences,
                max_tokens=None,
                grammar=_grammar_for_tool(tool),
                **sampling_kwargs,
            ),
        )
        completions.append(call)
        completions_tool_name.append(tool_name)
        prompt += call["choices"][0]["text"]
        prompt += "\n"
        # Does another call follow, or does the turn end?
        response = cast(
            llama_types.CreateCompletionResponse,
            llama.create_completion(
                prompt=prompt,
                temperature=temperature,
                stream=False,
                stop=stop_sequences,
                max_tokens=None,
                grammar=llama_grammar.LlamaGrammar.from_string(
                    follow_up_gbnf_tool_grammar, verbose=llama.verbose
                ),
                **sampling_kwargs,
            ),
        )
        tool = _find_tool(_tool_name_from_text(response["choices"][0]["text"]))

    # Legacy `function_call` field, set only when exactly one call was made.
    function_call_dict: Union[
        Dict[str, str],
        Dict[
            Literal["function_call"],
            llama_types.ChatCompletionRequestAssistantMessageFunctionCall,
        ],
    ] = (
        {
            "function_call": {
                "name": completions_tool_name[0],
                "arguments": completions[0]["choices"][0]["text"],
            }
        }
        if len(completions) == 1
        else {}
    )
    return {
        "id": "chat" + completion["id"],
        "object": "chat.completion",
        "created": completion["created"],
        "model": completion["model"],
        "choices": [
            {
                "finish_reason": "tool_calls",
                "index": 0,
                "logprobs": _convert_text_completion_logprobs_to_chat(
                    completion["choices"][0]["logprobs"]
                ),
                "message": {
                    "role": "assistant",
                    "content": None,
                    "tool_calls": [
                        {
                            "id": "call_" + f"_{i}_" + name + "_" + call["id"],
                            "type": "function",
                            "function": {
                                "name": name,
                                "arguments": call["choices"][0]["text"],
                            },
                        }
                        for i, (name, call) in enumerate(
                            zip(completions_tool_name, completions)
                        )
                    ],
                    **function_call_dict,
                },
            }
        ],
        "usage": {
            "completion_tokens": sum(
                c["usage"]["completion_tokens"] if "usage" in c else 0
                for c in completions
            ),
            "prompt_tokens": sum(
                c["usage"]["prompt_tokens"] if "usage" in c else 0
                for c in completions
            ),
            "total_tokens": sum(
                c["usage"]["total_tokens"] if "usage" in c else 0
                for c in completions
            ),
        },
    }