Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions llama_cpp/_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,12 @@ def reset(self):
if self.batch is not None:
self.batch.n_tokens = 0

def set_batch(self, batch: Sequence[int], n_past: llama_cpp.llama_pos, logits_all: bool):
def set_batch(self,
batch: Sequence[int],
n_past: llama_cpp.llama_pos,
logits_all: bool,
logits_last: bool = True
):
if len(batch) > self.n_tokens_capacity:
raise IndexError(f"Input batch size {len(batch)} exceeds capacity {self.n_tokens_capacity}")

Expand All @@ -684,7 +689,7 @@ def set_batch(self, batch: Sequence[int], n_past: llama_cpp.llama_pos, logits_al
self.batch.seq_id[i][0] = 0
self.batch.n_seq_id[i] = 1
self.batch.logits[i] = logits_all
self.batch.logits[n_tokens - 1] = True
self.batch.logits[n_tokens - 1] = logits_last

def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
n_tokens = len(batch)
Expand Down
30 changes: 30 additions & 0 deletions llama_cpp/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@
from ._logger import set_verbose
from ._utils import suppress_stdout_stderr

from .mtmd_cpp import mtmd_context_params_default, mtmd_init_from_file
from .mtmd import MultiModalContext


class Llama:
"""High-level Python wrapper for a llama.cpp model."""
Expand Down Expand Up @@ -130,6 +133,10 @@ def __init__(
# Misc
spm_infill: bool = False,
verbose: bool = True,
mmproj_path: str = None,
mmproj_use_gpu: Optional[bool] = None,
image_min_tokens: int = -1,
image_max_tokens: int = -1,
# Extra Params
**kwargs, # type: ignore
):
Expand Down Expand Up @@ -426,6 +433,29 @@ def __init__(
)
)

if mmproj_path != None:
mparams = mtmd_context_params_default();
mparams.use_gpu = mmproj_use_gpu if mmproj_use_gpu != None else n_gpu_layers == -1
mparams.print_timings = verbose
mparams.n_threads = self.n_threads
mparams.flash_attn_type = self.context_params.flash_attn_type
mparams.warmup = True
if image_min_tokens > 0:
mparams.image_min_tokens = image_min_tokens
if image_max_tokens > 0:
mparams.image_max_tokens = image_max_tokens

with suppress_stdout_stderr(disable=verbose):
mctx = mtmd_init_from_file(mmproj_path.encode("utf-8"), self._model.model, mparams)
if mctx is None:
raise RuntimeError(f"failed to load multimodal projection '{mmproj_path}'")

self.mtmd_context = self._stack.enter_context(
contextlib.closing(
MultiModalContext(mctx)
)
)

# Check for Encoder-Decoder architecture
self._has_encoder = self._model.has_encoder()
self._has_decoder = self._model.has_decoder()
Expand Down
71 changes: 69 additions & 2 deletions llama_cpp/llama_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
LLAMA_POOLING_TYPE_LAST,
LLAMA_POOLING_TYPE_RANK, # Specifically for Reranking models
)
from .mtmd import MediaChunk, mtmd_tokenize, mtmd_prefill
from ._utils import suppress_stdout_stderr

# Normalization modes for embedding vectors
# See: https://github.com/ggml-org/llama.cpp/tree/master/examples/embedding#--embd-normalize-integer
Expand Down Expand Up @@ -168,7 +170,7 @@ def embed(
if self.verbose:
llama_cpp.llama_perf_context_reset(ctx)
self._batch.reset()
llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), True)
llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)

# Initialize State Variables
results: List[Any] = []
Expand Down Expand Up @@ -219,7 +221,7 @@ def _decode_batch():
results.append(data)

self._batch.reset()
llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), True)
llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)
batch_seq_lens = []

# Main Streaming Loop
Expand Down Expand Up @@ -427,3 +429,68 @@ def create_embedding(
print(f"Warning: Failed to calculate similarity matrix: {e}")

return response


def embed_multimodal(
    self,
    prompt: str,
    files: List[bytes | str] | None = None,
    normalize: int = NORM_MODE_EUCLIDEAN,
    return_count: bool = False,
) -> Union[List[float], List[List[float]], Tuple[Any, int]]:
    """Embed a multimodal prompt (text plus optional image/audio inputs).

    Args:
        prompt: Text prompt; media marker tokens are resolved by mtmd.
        files: Media inputs, each a file path (str) or raw bytes.
            Defaults to no media.  (BUGFIX: was a mutable default ``[]``.)
        normalize: Normalization mode applied to the embedding vector.
        return_count: When True, also return the total multimodal token count.

    Returns:
        The normalized embedding vector, or ``(vector, n_tokens)`` when
        ``return_count`` is True.  An empty prompt yields ``[]``.
    """
    files = [] if files is None else files

    ctx = self._ctx.ctx
    mctx = self.mtmd_context.ctx

    # Determine if it is in Rerank mode
    try:
        pooling_type = self.pooling_type()
    except AttributeError:
        pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED
    is_rank = (pooling_type == LLAMA_POOLING_TYPE_RANK)
    is_none = (pooling_type == LLAMA_POOLING_TYPE_NONE)  # Token-level embedding

    out_dim = self.n_embd()

    if self.verbose:
        type_str = "TOKEN (None)" if is_none else ("RANK (Score)" if is_rank else "SEQ (Vector)")
        print(f"LlamaEmbedding Debug: Mode={type_str} | Pooling={pooling_type} | Dim={out_dim}")

    # Reset Context and Batch
    if self.verbose:
        llama_cpp.llama_perf_context_reset(ctx)
    self._batch.reset()
    llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)

    # Initialize State Variables
    result: Any = None

    with suppress_stdout_stderr(disable=self.verbose):
        # Returns a MultimodalTokenList (defined in llama_cpp.mtmd).
        tokens = mtmd_tokenize(mctx, prompt, files)

    n_tokens = len(tokens)

    if n_tokens == 0:
        result = []
    else:
        n_past = mtmd_prefill(self._ctx, mctx, self._batch, tokens)

        # Extract Embeddings
        # NOTE(review): reads the embedding at the last position of the final
        # text batch — assumes pooling places the sequence vector there;
        # confirm for mean/none pooling modes.
        ptr = llama_cpp.llama_get_embeddings_ith(ctx, self._batch.n_tokens() - 1)
        data = ptr[:out_dim]
        data = self._normalize_vector(data, normalize)

        result = data

    self._batch.reset()
    llama_cpp.llama_memory_clear(llama_cpp.llama_get_memory(ctx), False)

    if self.verbose:
        llama_cpp.llama_perf_context_print(ctx)

    if return_count:
        return result, n_tokens

    return result
202 changes: 202 additions & 0 deletions llama_cpp/mtmd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
import llama_cpp
from llama_cpp import LLAMA_TOKEN_NULL

import llama_cpp.mtmd_cpp as mtmd
from .mtmd_cpp import mtmd_input_chunk_type, mtmd_free
from ._internals import LlamaContext, LlamaBatch

import ctypes
from typing import Union, List

class TextChunk:
    """A run of plain text token ids within a multimodal prompt."""

    def __init__(self, tokens: List[int]):
        # Kept as a mutable list so adjacent text chunks can be merged
        # in place by MultimodalTokenList.add_text.
        self.tokens = tokens
        self.n_tokens = len(self.tokens)

class MediaChunk:
    """Owning wrapper around a copied native image/audio input chunk."""

    def __init__(self, chunk_ptr: ctypes.c_void_p):
        # Deep-copy so this wrapper owns the chunk independently of the
        # mtmd_input_chunks container it was read from (freed by the caller).
        self.chunk_ptr = mtmd.mtmd_input_chunk_copy(chunk_ptr)
        self.n_tokens = mtmd.mtmd_input_chunk_get_n_tokens(self.chunk_ptr)

    def __del__(self):
        # BUGFIX: use getattr — __del__ can run on a partially constructed
        # instance if mtmd_input_chunk_copy raised in __init__, which would
        # otherwise raise AttributeError here.
        ptr = getattr(self, "chunk_ptr", None)
        if ptr:
            mtmd.mtmd_input_chunk_free(ptr)
            # Clear the pointer to guard against a double free.
            self.chunk_ptr = None

class MultimodalTokenList:
    """Ordered list of text/media chunks produced by mtmd tokenization."""

    def __init__(self):
        self.chunks: List[Union[TextChunk, MediaChunk]] = []
        self.total_tokens = 0

    def add(self, chunk_ptr: mtmd.mtmd_input_chunk_p_ctypes):
        """Append a native chunk, dispatching on its reported type."""
        kind = mtmd.mtmd_input_chunk_get_type(chunk_ptr)

        media_kinds = (
            mtmd_input_chunk_type.MTMD_INPUT_CHUNK_TYPE_IMAGE,
            mtmd_input_chunk_type.MTMD_INPUT_CHUNK_TYPE_AUDIO,
        )
        if kind in media_kinds:
            media = MediaChunk(chunk_ptr)
            self.chunks.append(media)
            self.total_tokens += media.n_tokens
        elif kind == mtmd_input_chunk_type.MTMD_INPUT_CHUNK_TYPE_TEXT:
            count = ctypes.c_size_t()
            toks = mtmd.mtmd_input_chunk_get_tokens_text(chunk_ptr, ctypes.byref(count))
            self.add_text([toks[idx] for idx in range(count.value)])
        else:
            raise ValueError(f"Invalid chunk type {kind}")

    def add_text(self, tokens: List[int]):
        """Append text tokens, merging into a trailing TextChunk when possible."""
        if not tokens:
            return
        tail = self.chunks[-1] if self.chunks else None
        if isinstance(tail, TextChunk):
            # combine text nodes
            tail.tokens.extend(tokens)
            tail.n_tokens += len(tokens)
        else:
            self.chunks.append(TextChunk(tokens))
        self.total_tokens += len(tokens)

    def __len__(self):
        # Length is the total token count across all chunks, not chunk count.
        return self.total_tokens


class MultiModalContext:
    """Thin owner of a native mtmd context; releases it via close()."""

    def __init__(self, ctx):
        self.ctx = ctx

    def close(self):
        """Free the native context.  Safe to call more than once."""
        if self.ctx is not None:
            mtmd_free(self.ctx)
            self.ctx = None

    def __del__(self):
        self.close()


def fnv_hash(data: bytes) -> str:
    """Return the 32-bit FNV-1a hash of *data* as 8 lowercase hex digits.

    Mirrors the ``fnv_hash`` helper in the llama.cpp C++ code so bitmap
    ids computed here match upstream.
    """
    FNV_OFFSET_BASIS = 0x811C9DC5
    FNV_PRIME = 0x01000193
    acc = FNV_OFFSET_BASIS
    for byte in data:
        # Truncate to 32 bits after each multiply, matching C uint32_t.
        acc = ((acc ^ byte) * FNV_PRIME) & 0xFFFFFFFF
    return format(acc, "08x")

def mtmd_tokenize(
    mctx: mtmd.mtmd_context_p,
    prompt: str,
    files_data: list[bytes | str],
    do_hash: bool = False) -> MultimodalTokenList:
    """Tokenize *prompt* together with media files into a chunk list.

    Args:
        mctx: Native mtmd context.
        prompt: Prompt text; media marker tokens are inserted by mtmd.
        files_data: Media inputs, each a file path (str) or raw
            bytes/bytearray image/audio data.
        do_hash: When True, assign each bitmap a content-derived FNV-1a id
            (generalizes the previously hard-coded ``do_hash = False``).

    Returns:
        A MultimodalTokenList mixing TextChunk and MediaChunk entries.

    Raises:
        RuntimeError: if a media input cannot be loaded, or tokenization fails.
    """
    bitmaps = []
    try:
        for data in files_data:
            bmp = None
            if isinstance(data, str):
                # Strings are treated as file paths.
                bmp = mtmd.mtmd_helper_bitmap_init_from_file(mctx, data.encode("utf-8"))
            elif isinstance(data, bytes):
                buf = (ctypes.c_ubyte * len(data)).from_buffer_copy(data)
                bmp = mtmd.mtmd_helper_bitmap_init_from_buf(mctx, buf, len(buf))
            elif isinstance(data, bytearray):
                # Zero-copy view into the caller's bytearray.
                buf = (ctypes.c_ubyte * len(data)).from_buffer(data)
                bmp = mtmd.mtmd_helper_bitmap_init_from_buf(mctx, buf, len(buf))

            if bmp is None:
                raise RuntimeError("Failed to load image or audio file")

            if do_hash:
                # Mirror the C++ server's fnv_hash-based bitmap ids so
                # identical media can be recognized across calls.
                data_ptr = mtmd.mtmd_bitmap_get_data(bmp)
                data_size = mtmd.mtmd_bitmap_get_n_bytes(bmp)
                raw_node_data = ctypes.string_at(data_ptr, data_size)
                mtmd.mtmd_bitmap_set_id(bmp, fnv_hash(raw_node_data).encode('utf-8'))

            bitmaps.append(bmp)

        inp_txt = mtmd.mtmd_input_text(
            text=prompt.encode('utf-8'),
            add_special=True,
            parse_special=True
        )

        chunks_ptr = mtmd.mtmd_input_chunks_init()

        n_bitmaps = len(bitmaps)
        if n_bitmaps > 0:
            BitmapPtr = mtmd.mtmd_bitmap_p_ctypes * n_bitmaps
            bitmaps_ptr = BitmapPtr(*bitmaps)
        else:
            bitmaps_ptr = None

        res = mtmd.mtmd_tokenize(
            mctx,
            chunks_ptr,
            ctypes.pointer(inp_txt),
            bitmaps_ptr,
            n_bitmaps
        )
    finally:
        # BUGFIX: the original loop freed `bmp` (the last bitmap) once per
        # entry — double-freeing it and leaking all the others.  Free each
        # bitmap exactly once; `finally` also covers mid-loop load failures.
        # TODO Hash based cache
        for b in bitmaps:
            mtmd.mtmd_bitmap_free(b)

    if res != 0:
        mtmd.mtmd_input_chunks_free(chunks_ptr)
        raise RuntimeError(f"Tokenization failed with code {res}")

    st = MultimodalTokenList()

    n_chunks = mtmd.mtmd_input_chunks_size(chunks_ptr)
    for i in range(n_chunks):
        chunk_ptr = mtmd.mtmd_input_chunks_get(chunks_ptr, i)
        st.add(chunk_ptr)

    mtmd.mtmd_input_chunks_free(chunks_ptr)

    return st

def mtmd_prefill(
ctx: LlamaContext,
mctx: mtmd.mtmd_context_p,
batch: LlamaBatch,
mtmd_tokens: MultimodalTokenList
) -> int:
n_past = 0
n_batch = ctx.n_batch()
total_chunks = len(mtmd_tokens.chunks)

for i, chunk in enumerate(mtmd_tokens.chunks):
is_last_chunk = (i == total_chunks - 1)

if isinstance(chunk, TextChunk):
batch.set_batch(
chunk.tokens,
n_past,
logits_all=False,
logits_last=is_last_chunk
)
ctx.decode(batch)

n_past += chunk.n_tokens
else:
new_n_past = llama_cpp.llama_pos(0)
result = mtmd.mtmd_helper_eval_chunk_single(
mctx,
ctx.ctx,
chunk.chunk_ptr,
llama_cpp.llama_pos(n_past),
llama_cpp.llama_seq_id(0),
n_batch,
False, # logits_last
ctypes.byref(new_n_past)
)
if result != 0:
raise RuntimeError(f"MTMD eval error: {result}")

n_past = new_n_past.value
Loading