diff --git a/llama_cpp/_llava.py b/llama_cpp/_llava.py
new file mode 100644
index 000000000..4bdb7dea4
--- /dev/null
+++ b/llama_cpp/_llava.py
@@ -0,0 +1,77 @@
+from __future__ import annotations
+
+import os
+import ctypes
+import typing
+import contextlib
+
+import numpy as np
+
+import llama_cpp
+import llama_cpp.llava_cpp as llava_cpp
+
+
+class LlavaEmbedding:
+    def __init__(self, embedding: ctypes._Pointer[llava_cpp.llava_image_embed]):
+        self._embedding = embedding
+        self._exit_stack = contextlib.ExitStack()
+
+        def llava_image_embed_free():
+            llava_cpp.llava_image_embed_free(self._embedding)
+
+        self._exit_stack.callback(llava_image_embed_free)
+
+    @property
+    def n_image_pos(self) -> int:
+        return self._embedding.contents.n_image_pos
+
+    def embed(
+        self, llama_ctx: llama_cpp.llama_context_p, n_tokens: int, n_batch: int
+    ) -> int:
+        n_past = ctypes.c_int(n_tokens)
+        n_past_p = ctypes.pointer(n_past)
+        llava_cpp.llava_eval_image_embed(
+            llama_ctx,
+            self._embedding,
+            n_batch,
+            n_past_p,
+        )
+        return n_past.value
+
+    def numpy_view(self, shape: typing.Tuple[int, int]) -> np.ndarray:
+        return np.ctypeslib.as_array(
+            self._embedding.contents.embed, shape=shape
+        )
+
+
+class LlavaModel:
+    def __init__(self, path: str, n_threads: int = 1):
+        self._path = path
+        self._n_threads = n_threads
+        self._exit_stack = contextlib.ExitStack()
+
+        if not os.path.exists(self._path):
+            raise ValueError(f"Clip model path does not exist: {self._path}")
+
+        clip_ctx = llava_cpp.clip_model_load(self._path.encode(), 0)
+
+        if clip_ctx is None:
+            raise ValueError(f"Failed to load clip model: {self._path}")
+
+        self._clip_ctx = clip_ctx
+
+        def clip_free():
+            llava_cpp.clip_free(self._clip_ctx)
+            print("Clip model freed")
+
+        self._exit_stack.callback(clip_free)
+
+    def embed_bytes(self, image_bytes: bytes):
+        embed = llava_cpp.llava_image_embed_make_with_bytes(
+            self._clip_ctx,
+            self._n_threads,
+            (ctypes.c_uint8 * len(image_bytes)).from_buffer(bytearray(image_bytes)),
+            len(image_bytes),
+        )
+        return LlavaEmbedding(embed)
+
diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index dfb0af65e..88105aee6 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -28,9 +28,11 @@
 import numpy as np
 import numpy.typing as npt
 
+import llama_cpp
 import llama_cpp.llama as llama
 import llama_cpp.llama_types as llama_types
 import llama_cpp.llama_grammar as llama_grammar
+import llama_cpp._internals as internals
 
 from ._logger import logger
 from ._utils import suppress_stdout_stderr, Singleton
@@ -3350,6 +3352,204 @@ class MiniCPMv26ChatHandler(Llava15ChatHandler):
     )
 
 
+class PaliGemmaChatHandler(Llava15ChatHandler):
+    def __call__(
+        self,
+        *,
+        llama: llama.Llama,
+        messages: List[llama_types.ChatCompletionRequestMessage],
+        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
+        function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
+        tools: Optional[List[llama_types.ChatCompletionTool]] = None,
+        tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
+        temperature: float = 0.2,
+        top_p: float = 0.95,
+        top_k: int = 40,
+        min_p: float = 0.05,
+        typical_p: float = 1.0,
+        stream: bool = False,
+        stop: Optional[Union[str, List[str]]] = [],
+        seed: Optional[int] = None,
+        response_format: Optional[
+            llama_types.ChatCompletionRequestResponseFormat
+        ] = None,
+        max_tokens: Optional[int] = None,
+        presence_penalty: float = 0.0,
+        frequency_penalty: float = 0.0,
+        repeat_penalty: float = 1.1,
+        tfs_z: float = 1.0,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
+        logits_processor: Optional[llama.LogitsProcessorList] = None,
+        grammar: Optional[llama.LlamaGrammar] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
+        **kwargs,  # type: ignore
+    ) -> Union[
+        llama_types.CreateChatCompletionResponse,
+        Iterator[llama_types.CreateChatCompletionStreamResponse],
+    ]:
+        assert self.clip_ctx is not None
+
+        if len(messages) != 1:
+            raise ValueError("PaligemmaChatHandler only supports single-turn conversations.")
+
+        image_urls = self.get_image_urls(messages)
+
+        if len(image_urls) > 1:
+            raise ValueError("PaligemmaChatHandler only supports single image per turn.")
+
+        text = "answer en "
+        message = messages[0]
+        if isinstance(message["content"], str):
+            text = message["content"]
+        elif isinstance(message["content"], list):
+            for content in message["content"]:
+                if content["type"] == "text":
+                    text += content["text"]
+            text += "\n"
+
+        if self.verbose:
+            print(text, file=sys.stderr)
+
+        tokens = llama.tokenize(text.encode("utf-8"), special=True)
+        embedding_dim = llama_cpp.llama_n_embd(llama.model)
+        tokens_np = np.array(tokens).astype(np.int32)
+        token_embedding = np.empty((len(tokens), embedding_dim), dtype=np.single)
+        llama_cpp.llama_token_inp_embd(
+            llama.ctx,
+            tokens_np.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
+            len(tokens),
+            token_embedding.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
+        )
+
+        if len(image_urls) > 0:
+            image_embedding = self._embed_image_bytes(self.load_image(image_urls[0]))
+            n_image_pos = image_embedding.contents.n_image_pos
+            embeds = np.concatenate([np.ctypeslib.as_array(image_embedding.contents.embed, shape=(n_image_pos, embedding_dim)), token_embedding], axis=0)
+            n_tokens = n_image_pos + len(tokens)
+            llama.input_ids[: n_tokens] = (
+                llama.tokenize(b"", add_bos=False, special=True) * image_embedding.contents.n_image_pos + tokens
+            )
+        else:
+            n_tokens = len(tokens)
+            llama.input_ids[: n_tokens] = tokens
+            embeds = token_embedding
+
+        n_batch = 512
+        batch = internals.LlamaBatch(n_tokens=n_batch, embd=embedding_dim, n_seq_max=1)
+
+        batch.batch.n_tokens = n_tokens
+
+        np.ctypeslib.as_array(batch.batch.embd, shape=(n_batch, embedding_dim))[
+            :n_tokens, :
+        ] = embeds
+        np.ctypeslib.as_array(batch.batch.pos, shape=(n_batch,))[:n_tokens] = np.arange(n_tokens)
+        np.ctypeslib.as_array(batch.batch.n_seq_id, shape=(n_batch,))[:] = 1
+        np.ctypeslib.as_array(batch.batch.logits, shape=(n_batch,))[:] = False
+        np.ctypeslib.as_array(batch.batch.logits, shape=(n_batch,))[n_tokens - 1] = True
+
+        for i in range(n_tokens):
+            batch.batch.seq_id[i][0] = 0
+
+        # Evaluate prompt
+        llama.reset()
+        llama._ctx.kv_cache_clear()
+        llama_cpp.llama_set_causal_attn(llama._ctx.ctx, False)
+        llama._ctx.decode(batch)
+        llama.n_tokens += n_tokens
+        llama_cpp.llama_set_causal_attn(llama._ctx.ctx, True)
+
+        # Get prompt tokens to avoid a cache miss
+        prompt = llama.input_ids[: llama.n_tokens].tolist()
+
+        if response_format is not None and response_format["type"] == "json_object":
+            grammar = _grammar_for_response_format(response_format)
+
+        # Convert legacy functions to tools
+        if functions is not None:
+            tools = [
+                {
+                    "type": "function",
+                    "function": function,
+                }
+                for function in functions
+            ]
+
+        # Convert legacy function_call to tool_choice
+        if function_call is not None:
+            if isinstance(function_call, str) and (
+                function_call == "none" or function_call == "auto"
+            ):
+                tool_choice = function_call
+            if isinstance(function_call, dict) and "name" in function_call:
+                tool_choice = {
+                    "type": "function",
+                    "function": {
+                        "name": function_call["name"],
+                    },
+                }
+
+        tool = None
+        if (
+            tool_choice is not None
+            and isinstance(tool_choice, dict)
+            and tools is not None
+        ):
+            name = tool_choice["function"]["name"]
+            tool = next((t for t in tools if t["function"]["name"] == name), None)
+            if tool is None:
+                raise ValueError(f"Tool choice '{name}' not found in tools.")
+            schema = tool["function"]["parameters"]
+            try:
+                # create grammar from json schema
+                grammar = llama_grammar.LlamaGrammar.from_json_schema(
+                    json.dumps(schema), verbose=llama.verbose
+                )
+            except Exception as e:
+                if llama.verbose:
+                    print(str(e), file=sys.stderr)
+                grammar = llama_grammar.LlamaGrammar.from_string(
+                    llama_grammar.JSON_GBNF, verbose=llama.verbose
+                )
+
+        completion_or_chunks = llama.create_completion(
+            prompt=prompt,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            min_p=min_p,
+            typical_p=typical_p,
+            logprobs=top_logprobs if logprobs else None,
+            stream=stream,
+            stop=stop,
+            seed=seed,
+            max_tokens=max_tokens,
+            presence_penalty=presence_penalty,
+            frequency_penalty=frequency_penalty,
+            repeat_penalty=repeat_penalty,
+            tfs_z=tfs_z,
+            mirostat_mode=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
+            model=model,
+            logits_processor=logits_processor,
+            grammar=grammar,
+            logit_bias=logit_bias,
+        )
+        if tool is not None:
+            tool_name = tool["function"]["name"]
+            return _convert_completion_to_chat_function(
+                tool_name, completion_or_chunks, stream
+            )
+        return _convert_completion_to_chat(completion_or_chunks, stream=stream)
+
+
 @register_chat_completion_handler("chatml-function-calling")
 def chatml_function_calling(
     llama: llama.Llama,
diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index 97c969136..9933869d7 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -2836,6 +2836,47 @@ def llama_detokenize(
     ...
 
 
+# // @details Get the input embeddings for a sequence of tokens
+# // @param tokens The tokens to embed
+# // @param n_tokens The number of tokens
+# // @param embeddings The embeddings pointer must be large enough to hold the resulting embeddings.
+# // @param n_embd The number of embeddings per token
+# // @return Returns a negative number on failure
+# LLAMA_API int32_t llama_token_inp_embd(
+#         struct llama_context * ctx,
+#                  llama_token * tokens,
+#                      int32_t   n_tokens,
+#                        float * embeddings);
+@ctypes_function(
+    "llama_token_inp_embd",
+    [
+        llama_context_p_ctypes,
+        llama_token_p,
+        ctypes.c_int32,
+        ctypes.POINTER(ctypes.c_float),
+    ],
+    ctypes.c_int32,
+)
+def llama_token_inp_embd(
+    ctx: llama_context_p,
+    tokens: CtypesArray[llama_token],
+    n_tokens: Union[ctypes.c_int32, int],
+    embeddings: CtypesArray[ctypes.c_float],
+    /,
+) -> int:
+    """Get the input embeddings for a sequence of tokens
+
+    Args:
+        ctx: The model context.
+        tokens: The tokens to embed.
+        n_tokens: The number of tokens.
+        embeddings: The embeddings pointer must be large enough to hold the resulting embeddings.
+
+    Returns:
+        Returns a negative number on failure"""
+    ...
+
+
 # //
 # // Chat templates
 # //
diff --git a/llama_cpp/server/model.py b/llama_cpp/server/model.py
index c6716f919..3b1184d82 100644
--- a/llama_cpp/server/model.py
+++ b/llama_cpp/server/model.py
@@ -171,6 +171,20 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
         chat_handler = llama_cpp.llama_chat_format.MiniCPMv26ChatHandler(
             clip_model_path=settings.clip_model_path, verbose=settings.verbose
         )
+    elif settings.chat_format == "paligemma":
+        assert settings.clip_model_path is not None, "clip model not found"
+        if settings.hf_model_repo_id is not None:
+            chat_handler = (
+                llama_cpp.llama_chat_format.PaliGemmaChatHandler.from_pretrained(
+                    repo_id=settings.hf_model_repo_id,
+                    filename=settings.clip_model_path,
+                    verbose=settings.verbose,
+                )
+            )
+        else:
+            chat_handler = llama_cpp.llama_chat_format.PaliGemmaChatHandler(
+                clip_model_path=settings.clip_model_path, verbose=settings.verbose
+            )
     elif settings.chat_format == "hf-autotokenizer":
         assert (
             settings.hf_pretrained_model_name_or_path is not None
diff --git a/vendor/llama.cpp b/vendor/llama.cpp
index c919d5db3..c702e5593 160000
--- a/vendor/llama.cpp
+++ b/vendor/llama.cpp
@@ -1 +1 @@
-Subproject commit c919d5db39c8a7fcb64737f008e4b105ee0acd20
+Subproject commit c702e5593086c42b3fb52ad68e04e37ffe29f61f
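Note (not part of the patch): a quick smoke test for the new llama_token_inp_embd binding could look like the sketch below. It assumes the vendored llama.cpp commit referenced above actually exports the symbol (it is not part of upstream llama.cpp), and the model filename is a placeholder.

    import ctypes

    import numpy as np

    import llama_cpp
    from llama_cpp import Llama

    # Sketch only: requires a build against the vendored llama.cpp commit above.
    # The filename is a placeholder for a PaliGemma (or other) GGUF model.
    llm = Llama(model_path="paligemma-3b-mix.gguf", n_ctx=2048)

    tokens = llm.tokenize(b"answer en What is in this picture?", special=True)
    n_embd = llama_cpp.llama_n_embd(llm.model)

    tokens_np = np.array(tokens, dtype=np.int32)
    embeddings = np.empty((len(tokens), n_embd), dtype=np.single)

    # Mirrors the call PaliGemmaChatHandler makes before splicing in the image embedding.
    ret = llama_cpp.llama_token_inp_embd(
        llm.ctx,
        tokens_np.ctypes.data_as(ctypes.POINTER(ctypes.c_int32)),
        len(tokens),
        embeddings.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
    )
    if ret < 0:
        raise RuntimeError("llama_token_inp_embd failed")

    print(embeddings.shape)  # one n_embd-dimensional input embedding per token

A negative return value signals failure, matching the docstring of the binding added above.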