Skip to content

Commit 2987bc6

Browse files
author
zpin
committed
Fixes for llama.cpp changes
1 parent 8789afa commit 2987bc6

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

llama_cpp/_internals.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -810,14 +810,14 @@ def add_xtc(self, probability: float, threshold: float, min_keep: int, seed: int
810810
sampler = llama_cpp.llama_sampler_init_xtc(probability, threshold, min_keep, seed)
811811
self._add_sampler(sampler)
812812

813-
def add_dry(self, model: LlamaModel, multiplier: float, base: float,
813+
def add_dry(self, model: LlamaModel, ctx: LlamaContext, multiplier: float, base: float,
814814
allowed_length: int, penalty_last_n: int, seq_breakers: list[str] = []):
815815

816816
# Convert Python strings to bytes
817817
seq_breakers_bytes = [s.encode('utf-8') for s in seq_breakers]
818818
# Create array of char*
819819
arr = (ctypes.c_char_p * len(seq_breakers_bytes))(*seq_breakers_bytes)
820-
sampler = llama_cpp.llama_sampler_init_dry(model.model, multiplier, base,
820+
sampler = llama_cpp.llama_sampler_init_dry(model.vocab, ctx.n_ctx(), multiplier, base,
821821
allowed_length, penalty_last_n,
822822
arr, len(seq_breakers))
823823
self._add_sampler(sampler)

llama_cpp/llama.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,7 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
754754
else:
755755
n_probs = 0
756756
min_keep = max(1, n_probs)
757-
sampler.add_dry(self._model, dry_multiplier, dry_base, dry_allowed_length, dry_range, dry_seq_breakers)
757+
sampler.add_dry(self._model, self._ctx, dry_multiplier, dry_base, dry_allowed_length, dry_range, dry_seq_breakers)
758758
sampler.add_top_k(top_k)
759759
sampler.add_typical(typical_p, min_keep)
760760
sampler.add_top_p(top_p, min_keep)

llama_cpp/llama_cpp.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -3647,7 +3647,8 @@ def llama_sampler_init_xtc(
36473647
...
36483648

36493649
# LLAMA_API struct llama_sampler * llama_sampler_init_dry(
3650-
# const struct llama_model * model,
3650+
# const struct llama_vocab * vocab,
3651+
# int32_t context_size,
36513652
# float dry_multiplier,
36523653
# float dry_base,
36533654
# int32_t dry_allowed_length,
@@ -3657,7 +3658,8 @@ def llama_sampler_init_xtc(
36573658
@ctypes_function(
36583659
"llama_sampler_init_dry",
36593660
[
3660-
llama_model_p_ctypes,
3661+
llama_vocab_p_ctypes,
3662+
ctypes.c_int32,
36613663
ctypes.c_float,
36623664
ctypes.c_float,
36633665
ctypes.c_int32,
@@ -3668,7 +3670,8 @@ def llama_sampler_init_xtc(
36683670
llama_sampler_p_ctypes,
36693671
)
36703672
def llama_sampler_init_dry(
3671-
model: llama_model_p,
3673+
vocab: llama_vocab_p,
3674+
context_size: int,
36723675
dry_multiplier: float,
36733676
dry_base: float,
36743677
dry_allowed_length: int,

0 commit comments

Comments
 (0)