|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import os
|
||
|
|
import ctypes
|
||
|
|
import typing
|
||
|
|
import contextlib
|
||
|
|
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
import llama_cpp
|
||
|
|
import llama_cpp.llava_cpp as llava_cpp
|
||
|
|
|
||
|
|
|
||
|
|
class LlavaEmbedding:
|
||
|
|
def __init__(self, embedding: ctypes._Pointer[llava_cpp.llava_image_embed]):
|
||
|
|
self._embedding = embedding
|
||
|
|
self._exit_stack = contextlib.ExitStack()
|
||
|
|
|
||
|
|
def llava_image_embed_free():
|
||
|
|
llava_cpp.llava_image_embed_free(self._embedding)
|
||
|
|
|
||
|
|
self._exit_stack.callback(llava_image_embed_free)
|
||
|
|
|
||
|
|
@property
|
||
|
|
def n_image_pos(self) -> int:
|
||
|
|
return self._embedding.contents.n_image_pos
|
||
|
|
|
||
|
|
def embed(
|
||
|
|
self, llama_ctx: llama_cpp.llama_context_p, n_tokens: int, n_batch: int
|
||
|
|
) -> int:
|
||
|
|
n_past = ctypes.c_int(n_tokens)
|
||
|
|
n_past_p = ctypes.pointer(n_past)
|
||
|
|
llava_cpp.llava_eval_image_embed(
|
||
|
|
llama_ctx,
|
||
|
|
self._embedding,
|
||
|
|
n_batch,
|
||
|
|
n_past_p,
|
||
|
|
)
|
||
|
|
return n_past.value
|
||
|
|
|
||
|
|
def numpy_view(self, shape: typing.Tuple[int, int]) -> np.ndarray:
|
||
|
|
return np.ctypeslib.as_array(
|
||
|
|
self._embedding.contents.embed, shape=shape
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
class LlavaModel:
|
||
|
|
def __init__(self, path: str, n_threads: int = 1):
|
||
|
|
self._path = path
|
||
|
|
self._n_threads = n_threads
|
||
|
|
self._exit_stack = contextlib.ExitStack()
|
||
|
|
|
||
|
|
if not os.path.exists(self._path):
|
||
|
|
raise ValueError(f"Clip model path does not exist: {self._path}")
|
||
|
|
|
||
|
|
clip_ctx = llava_cpp.clip_model_load(self._path.encode(), 0)
|
||
|
|
|
||
|
|
if clip_ctx is None:
|
||
|
|
raise ValueError(f"Failed to load clip model: {self._path}")
|
||
|
|
|
||
|
|
self._clip_ctx = clip_ctx
|
||
|
|
|
||
|
|
def clip_free():
|
||
|
|
llava_cpp.clip_free(self._clip_ctx)
|
||
|
|
print("Clip model freed")
|
||
|
|
|
||
|
|
self._exit_stack.callback(clip_free)
|
||
|
|
|
||
|
|
def embed_bytes(self, image_bytes: bytes):
|
||
|
|
embed = llava_cpp.llava_image_embed_make_with_bytes(
|
||
|
|
self._clip_ctx,
|
||
|
|
self._n_threads,
|
||
|
|
(ctypes.c_uint8 * len(image_bytes)).from_buffer(bytearray(image_bytes)),
|
||
|
|
len(image_bytes),
|
||
|
|
)
|
||
|
|
return LlavaEmbedding(embed)
|
||
|
|
|