diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py
index 26ed40833..abb655621 100644
--- a/llama_cpp/_internals.py
+++ b/llama_cpp/_internals.py
@@ -672,7 +672,12 @@ def reset(self):
         if self.batch is not None:
             self.batch.n_tokens = 0
 
-    def set_batch(self, batch: Sequence[int], n_past: llama_cpp.llama_pos, logits_all: bool):
+    def set_batch(self,
+                  batch: Sequence[int],
+                  n_past: llama_cpp.llama_pos,
+                  logits_all: bool,
+                  logits_last: bool = True
+                  ):
         if len(batch) > self.n_tokens_capacity:
             raise IndexError(f"Input batch size {len(batch)} exceeds capacity {self.n_tokens_capacity}")
 
@@ -684,7 +689,7 @@ def set_batch(self, batch: Sequence[int], n_past: llama_cpp.llama_pos, logits_al
             self.batch.seq_id[i][0] = 0
             self.batch.n_seq_id[i] = 1
             self.batch.logits[i] = logits_all
-        self.batch.logits[n_tokens - 1] = True
+        self.batch.logits[n_tokens - 1] = logits_last
 
     def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
         n_tokens = len(batch)
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index ba3d164f9..ecad18034 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -58,6 +58,9 @@
 from ._logger import set_verbose
 from ._utils import suppress_stdout_stderr
 
+from .mtmd_cpp import mtmd_context_params_default, mtmd_init_from_file
+from .mtmd import MultiModalContext
+
 
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
@@ -130,6 +133,10 @@ def __init__(
         # Misc
         spm_infill: bool = False,
         verbose: bool = True,
+        mmproj_path: Optional[str] = None,
+        mmproj_use_gpu: Optional[bool] = None,
+        image_min_tokens: int = -1,
+        image_max_tokens: int = -1,
         # Extra Params
         **kwargs,  # type: ignore
     ):
@@ -426,6 +433,29 @@ def __init__(
             )
         )
 
+        # Optionally load a multimodal projector (mmproj) so this model can
+        # embed images/audio alongside text via the mtmd helper API.
+        if mmproj_path is not None:
+            mparams = mtmd_context_params_default()
+            # Default GPU offload of the projector to "model fully offloaded".
+            mparams.use_gpu = mmproj_use_gpu if mmproj_use_gpu is not None else n_gpu_layers == -1
+            mparams.print_timings = verbose
+            mparams.n_threads = self.n_threads
+            mparams.flash_attn_type = self.context_params.flash_attn_type
+            mparams.warmup = True
+            if image_min_tokens > 0:
+                mparams.image_min_tokens = image_min_tokens
+            if image_max_tokens > 0:
+                mparams.image_max_tokens = image_max_tokens
+
+            with suppress_stdout_stderr(disable=verbose):
+                mctx = mtmd_init_from_file(mmproj_path.encode("utf-8"), self._model.model, mparams)
+            if mctx is None:
+                raise RuntimeError(f"failed to load multimodal projection '{mmproj_path}'")
+
+            self.mtmd_context = self._stack.enter_context(
+                contextlib.closing(
+                    MultiModalContext(mctx)
+                )
+            )
+
         # Check for Encoder-Decoder architecture
         self._has_encoder = self._model.has_encoder()
         self._has_decoder = self._model.has_decoder()
diff --git a/llama_cpp/llama_embedding.py b/llama_cpp/llama_embedding.py
index 5da97fa19..86faad32a 100644
--- a/llama_cpp/llama_embedding.py
+++ b/llama_cpp/llama_embedding.py
@@ -12,6 +12,8 @@
     LLAMA_POOLING_TYPE_LAST,
     LLAMA_POOLING_TYPE_RANK,  # Specifically for Reranking models
 )
+from .mtmd import MediaChunk, MultimodalTokenList, mtmd_tokenize, mtmd_prefill
+from ._utils import suppress_stdout_stderr
 
 # Normalization modes for embedding vectors
 # See: https://github.com/ggml-org/llama.cpp/tree/master/examples/embedding#--embd-normalize-integer
@@ -168,7 +170,7 @@ def embed(
         if self.verbose:
             llama_cpp.llama_perf_context_reset(ctx)
         self._batch.reset()
-        llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), True)
+        llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)
 
         # Initialize State Variables
         results: List[Any] = []
@@ -219,7 +221,7 @@ def _decode_batch():
             results.append(data)
 
             self._batch.reset()
-            llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), True)
+            llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)
             batch_seq_lens = []
 
     # Main Streaming Loop
@@ -427,3 +429,68 @@ def create_embedding(
             print(f"Warning: Failed to calculate similarity matrix: {e}")
 
         return response
+
+
+    def embed_multimodal(
+        self,
+        prompt: str,
+        files: List[bytes | str] = [],
+
+        normalize: int = NORM_MODE_EUCLIDEAN,
+        return_count: bool = False,
+    ) -> Union[List[float], List[List[float]], Tuple[Any, int]]:
+
+        ctx = self._ctx.ctx
+        mctx = self.mtmd_context.ctx
+
+        # Determine if it is in Rerank mode
+        try:
+            pooling_type = self.pooling_type()
+        except AttributeError:
+            pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED
+        is_rank = (pooling_type == LLAMA_POOLING_TYPE_RANK)
+        is_none = (pooling_type == LLAMA_POOLING_TYPE_NONE)  # Token-level embedding
+
+        out_dim = self.n_embd()
+
+        if self.verbose:
+            type_str = "TOKEN (None)" if is_none else ("RANK (Score)" if is_rank else "SEQ (Vector)")
+            print(f"LlamaEmbedding Debug: Mode={type_str} | Pooling={pooling_type} | Dim={out_dim}")
+
+        # Reset Context and Batch
+        if self.verbose:
+            llama_cpp.llama_perf_context_reset(ctx)
+        self._batch.reset()
+        llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)
+
+        # Initialize State Variables
+        result: Any = None
+
+
+        with suppress_stdout_stderr(disable=self.verbose):
+            tokens: MultimodalTokenList = mtmd_tokenize(mctx, prompt, files)
+
+        n_tokens = len(tokens)
+
+        if n_tokens == 0:
+            result = []
+        else:
+            n_past = mtmd_prefill(self._ctx, mctx, self._batch, tokens)
+
+            # Extract Embeddings
+            ptr = llama_cpp.llama_get_embeddings_ith(ctx, self._batch.n_tokens() - 1)
+            data = ptr[:out_dim]
+            data = self._normalize_vector(data, normalize)
+
+            result = data
+
+        self._batch.reset()
+        llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)
+
+        if self.verbose:
+            llama_cpp.llama_perf_context_print(ctx)
+
+        if return_count:
+            return result, n_tokens
+
+        return result
diff --git a/llama_cpp/mtmd.py b/llama_cpp/mtmd.py
new file mode 100644
index 000000000..d8bdf1ec3
--- /dev/null
+++ b/llama_cpp/mtmd.py
@@ -0,0 +1,204 @@
+import llama_cpp
+from llama_cpp import LLAMA_TOKEN_NULL
+
+import llama_cpp.mtmd_cpp as mtmd
+from .mtmd_cpp import mtmd_input_chunk_type, mtmd_free
+from ._internals import LlamaContext, LlamaBatch
+
+import ctypes
+from typing import Union, List
+
+class TextChunk:
+    def __init__(self, tokens: List[int]):
+        self.tokens = tokens
+        self.n_tokens = len(tokens)
+
+class MediaChunk:
+    def __init__(self, chunk_ptr: ctypes.c_void_p):
+        self.chunk_ptr = mtmd.mtmd_input_chunk_copy(chunk_ptr)
+        self.n_tokens = mtmd.mtmd_input_chunk_get_n_tokens(self.chunk_ptr)
+
+    def __del__(self):
+        if self.chunk_ptr:
+            mtmd.mtmd_input_chunk_free(self.chunk_ptr)
+
+class MultimodalTokenList:
+    def __init__(self):
+        self.chunks: List[Union[TextChunk, MediaChunk]] = []
+        self.total_tokens = 0
+
+    def add(self, chunk_ptr: mtmd.mtmd_input_chunk_p_ctypes):
+        chunk_type = mtmd.mtmd_input_chunk_get_type(chunk_ptr)
+
+        if chunk_type in [
+            mtmd_input_chunk_type.MTMD_INPUT_CHUNK_TYPE_IMAGE,
+            mtmd_input_chunk_type.MTMD_INPUT_CHUNK_TYPE_AUDIO
+        ]:
+            m_chunk = MediaChunk(chunk_ptr)
+            self.chunks.append(m_chunk)
+            self.total_tokens += m_chunk.n_tokens
+
+        elif chunk_type == mtmd_input_chunk_type.MTMD_INPUT_CHUNK_TYPE_TEXT:
+            n_tokens_ref = ctypes.c_size_t()
+            text_tokens_ptr = mtmd.mtmd_input_chunk_get_tokens_text(chunk_ptr, ctypes.byref(n_tokens_ref))
+            tokens = [text_tokens_ptr[j] for j in range(n_tokens_ref.value)]
+            self.add_text(tokens)
+
+        else:
+            raise ValueError(f"Invalid chunk type {chunk_type}")
+
+    def add_text(self, tokens: List[int]):
+        if not tokens: return
+        # combine text nodes
+        if self.chunks and isinstance(self.chunks[-1], TextChunk):
+            self.chunks[-1].tokens.extend(tokens)
+            self.chunks[-1].n_tokens += len(tokens)
+        else:
+            self.chunks.append(TextChunk(tokens))
+        self.total_tokens += len(tokens)
+
+    def __len__(self):
+        return self.total_tokens
+
+
+class MultiModalContext:
+    def __init__(
+        self,
+        ctx
+    ):
+        self.ctx = ctx
+
+    def close(self):
+        if self.ctx is None:
+            return
+        mtmd_free(self.ctx)
+        self.ctx = None
+
+    def __del__(self):
+        self.close()
+
+
+# Simple FNV-1a hash implementation to match fnv_hash in C++
+def fnv_hash(data: bytes) -> str:
+    h = 0x811c9dc5
+    for b in data:
+        h = (h ^ b) * 0x01000193
+        h &= 0xffffffff
+    return f"{h:08x}"
+
+def mtmd_tokenize(
+    mctx: mtmd.mtmd_context_p,
+    prompt: str,
+    files_data: list[bytes | str]) -> MultimodalTokenList:
+
+    bitmaps = []
+    do_hash = False
+
+    for data in files_data:
+
+        bmp = None
+        if isinstance(data, str):
+            bmp = mtmd.mtmd_helper_bitmap_init_from_file(mctx, data.encode("utf-8"))
+        elif isinstance(data, bytes):
+            buf = (ctypes.c_ubyte * len(data)).from_buffer_copy(data)
+            bmp = mtmd.mtmd_helper_bitmap_init_from_buf(mctx, buf, len(buf))
+        elif isinstance(data, bytearray):
+            buf = (ctypes.c_ubyte * len(data)).from_buffer(data)
+            bmp = mtmd.mtmd_helper_bitmap_init_from_buf(mctx, buf, len(buf))
+
+        if bmp is None:
+            raise RuntimeError("Failed to load image or audio file")
+
+        if do_hash:
+            data_ptr = mtmd.mtmd_bitmap_get_data(bmp)
+            data_size = mtmd.mtmd_bitmap_get_n_bytes(bmp)
+
+            raw_node_data = ctypes.string_at(data_ptr, data_size)
+            h = fnv_hash(raw_node_data)
+            mtmd.mtmd_bitmap_set_id(bmp, h.encode('utf-8'))
+
+        bitmaps.append(bmp)
+
+    inp_txt = mtmd.mtmd_input_text(
+        text=prompt.encode('utf-8'),
+        add_special=True,
+        parse_special=True
+    )
+
+    chunks_ptr = mtmd.mtmd_input_chunks_init()
+
+    n_bitmaps = len(bitmaps)
+    if n_bitmaps > 0:
+        BitmapPtr = mtmd.mtmd_bitmap_p_ctypes * n_bitmaps
+        bitmaps_ptr = BitmapPtr(*bitmaps)
+    else:
+        bitmaps_ptr = None
+
+    res = mtmd.mtmd_tokenize(
+        mctx,
+        chunks_ptr,
+        ctypes.pointer(inp_txt),
+        bitmaps_ptr,
+        n_bitmaps
+    )
+
+    # TODO Hash based cache
+    for data in bitmaps:
+        mtmd.mtmd_bitmap_free(data)
+
+    if res != 0:
+        mtmd.mtmd_input_chunks_free(chunks_ptr)
+        raise RuntimeError(f"Tokenization failed with code {res}")
+
+    st = MultimodalTokenList()
+
+    n_chunks = mtmd.mtmd_input_chunks_size(chunks_ptr)
+    for i in range(n_chunks):
+        chunk_ptr = mtmd.mtmd_input_chunks_get(chunks_ptr, i)
+        st.add(chunk_ptr)
+
+    mtmd.mtmd_input_chunks_free(chunks_ptr)
+
+    return st
+
+def mtmd_prefill(
+    ctx: LlamaContext,
+    mctx: mtmd.mtmd_context_p,
+    batch: LlamaBatch,
+    mtmd_tokens: MultimodalTokenList
+) -> int:
+    n_past = 0
+    n_batch = ctx.n_batch()
+    total_chunks = len(mtmd_tokens.chunks)
+
+    for i, chunk in enumerate(mtmd_tokens.chunks):
+        is_last_chunk = (i == total_chunks - 1)
+
+        if isinstance(chunk, TextChunk):
+            batch.set_batch(
+                chunk.tokens,
+                n_past,
+                logits_all=False,
+                logits_last=is_last_chunk
+            )
+            ctx.decode(batch)
+
+            n_past += chunk.n_tokens
+        else:
+            new_n_past = llama_cpp.llama_pos(0)
+            result = mtmd.mtmd_helper_eval_chunk_single(
+                mctx,
+                ctx.ctx,
+                chunk.chunk_ptr,
+                llama_cpp.llama_pos(n_past),
+                llama_cpp.llama_seq_id(0),
+                n_batch,
+                False,  # logits_last
+                ctypes.byref(new_n_past)
+            )
+            if result != 0:
+                raise RuntimeError(f"MTMD eval error: {result}")
+
+            n_past = new_n_past.value
+
+    return n_past