Skip to content

Commit fa20e91

Browse files
llmixerzpin
llmixer
authored and committed
Added DRY and XTC samplers
1 parent 710e19a commit fa20e91

File tree

3 files changed

+148
-0
lines changed

3 files changed

+148
-0
lines changed

llama_cpp/_internals.py

+16
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,22 @@ def add_mirostat_v2(self, seed: int, tau: float, eta: float):
806806
sampler = llama_cpp.llama_sampler_init_mirostat_v2(seed, tau, eta)
807807
self._add_sampler(sampler)
808808

809+
def add_xtc(self, probability: float, threshold: float, min_keep: int, seed: int):
    """Create an XTC sampler and register it via ``self._add_sampler``.

    The four arguments are passed, unchanged and in the same order, to
    ``llama_cpp.llama_sampler_init_xtc``.
    """
    self._add_sampler(
        llama_cpp.llama_sampler_init_xtc(probability, threshold, min_keep, seed)
    )
812+
813+
def add_dry(self, model: LlamaModel, multiplier: float, base: float,
            allowed_length: int, penalty_last_n: int,
            seq_breakers: "list[str] | None" = None):
    """Create a DRY sampler and register it via ``self._add_sampler``.

    Args:
        model: Wrapper whose ``model`` attribute is handed to
            ``llama_cpp.llama_sampler_init_dry`` as its first argument.
        multiplier: Passed through as the second argument.
        base: Passed through as the third argument.
        allowed_length: Passed through as the fourth argument.
        penalty_last_n: Passed through as the fifth argument.
        seq_breakers: Sequence-breaker strings. ``None`` (the default) is
            treated the same as an empty list.
    """
    # A ``None`` sentinel avoids sharing one mutable default list across calls.
    breakers = [] if seq_breakers is None else seq_breakers
    # Convert Python strings to bytes
    seq_breakers_bytes = [s.encode("utf-8") for s in breakers]
    # Create array of char*
    arr = (ctypes.c_char_p * len(seq_breakers_bytes))(*seq_breakers_bytes)
    sampler = llama_cpp.llama_sampler_init_dry(
        model.model, multiplier, base, allowed_length, penalty_last_n,
        # The count is taken from the list that actually backs ``arr``.
        arr, len(seq_breakers_bytes),
    )
    self._add_sampler(sampler)
824+
809825
def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
810826
sampler = llama_cpp.llama_sampler_init_grammar(
811827
model.vocab, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8")

llama_cpp/llama.py

+100
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,13 @@ def _init_sampler(
680680
mirostat_mode: int = 0,
681681
mirostat_eta: float = 0.1,
682682
mirostat_tau: float = 5.0,
683+
xtc_probability: float = 0.0,
684+
xtc_threshold: float = 0.1,
685+
dry_multiplier: float = 0.0,
686+
dry_allowed_length: int = 2,
687+
dry_base: float = 1.75,
688+
dry_range: int = 0,
689+
dry_seq_breakers: list[str] = [],
683690
penalize_nl: bool = True,
684691
logits_processor: Optional[LogitsProcessorList] = None,
685692
grammar: Optional[LlamaGrammar] = None,
@@ -747,12 +754,14 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
747754
else:
748755
n_probs = 0
749756
min_keep = max(1, n_probs)
757+
sampler.add_dry(self._model, dry_multiplier, dry_base, dry_allowed_length, dry_range, dry_seq_breakers)
750758
sampler.add_top_k(top_k)
751759
sampler.add_typical(typical_p, min_keep)
752760
sampler.add_top_p(top_p, min_keep)
753761
sampler.add_min_p(min_p, min_keep)
754762
sampler.add_temp(temp)
755763
sampler.add_dist(self._seed)
764+
sampler.add_xtc(xtc_probability, xtc_threshold, min_keep, self._seed)
756765
return sampler
757766

758767
def sample(
@@ -769,6 +778,13 @@ def sample(
769778
mirostat_mode: int = 0,
770779
mirostat_eta: float = 0.1,
771780
mirostat_tau: float = 5.0,
781+
xtc_probability: float = 0.0,
782+
xtc_threshold: float = 0.1,
783+
dry_multiplier: float = 0.0,
784+
dry_allowed_length: int = 2,
785+
dry_base: float = 1.75,
786+
dry_range: int = 0,
787+
dry_seq_breakers: list[str] = [],
772788
penalize_nl: bool = True,
773789
logits_processor: Optional[LogitsProcessorList] = None,
774790
grammar: Optional[LlamaGrammar] = None,
@@ -804,6 +820,13 @@ def sample(
804820
mirostat_mode=mirostat_mode,
805821
mirostat_tau=mirostat_tau,
806822
mirostat_eta=mirostat_eta,
823+
xtc_probability=xtc_probability,
824+
xtc_threshold=xtc_threshold,
825+
dry_multiplier=dry_multiplier,
826+
dry_allowed_length=dry_allowed_length,
827+
dry_base=dry_base,
828+
dry_range=dry_range,
829+
dry_seq_breakers=dry_seq_breakers,
807830
penalize_nl=penalize_nl,
808831
logits_processor=logits_processor,
809832
grammar=grammar,
@@ -833,6 +856,13 @@ def generate(
833856
mirostat_mode: int = 0,
834857
mirostat_tau: float = 5.0,
835858
mirostat_eta: float = 0.1,
859+
xtc_probability: float = 0.0,
860+
xtc_threshold: float = 0.1,
861+
dry_multiplier: float = 0.0,
862+
dry_allowed_length: int = 2,
863+
dry_base: float = 1.75,
864+
dry_range: int = 0,
865+
dry_seq_breakers: list[str] = [],
836866
penalize_nl: bool = True,
837867
logits_processor: Optional[LogitsProcessorList] = None,
838868
stopping_criteria: Optional[StoppingCriteriaList] = None,
@@ -872,6 +902,13 @@ def generate(
872902
mirostat_mode=mirostat_mode,
873903
mirostat_tau=mirostat_tau,
874904
mirostat_eta=mirostat_eta,
905+
xtc_probability=xtc_probability,
906+
xtc_threshold=xtc_threshold,
907+
dry_multiplier=dry_multiplier,
908+
dry_allowed_length=dry_allowed_length,
909+
dry_base=dry_base,
910+
dry_range=dry_range,
911+
dry_seq_breakers=dry_seq_breakers,
875912
penalize_nl=penalize_nl,
876913
logits_processor=logits_processor,
877914
grammar=grammar,
@@ -924,6 +961,13 @@ def generate(
924961
mirostat_mode=mirostat_mode,
925962
mirostat_tau=mirostat_tau,
926963
mirostat_eta=mirostat_eta,
964+
xtc_probability=xtc_probability,
965+
xtc_threshold=xtc_threshold,
966+
dry_multiplier=dry_multiplier,
967+
dry_allowed_length=dry_allowed_length,
968+
dry_base=dry_base,
969+
dry_range=dry_range,
970+
dry_seq_breakers=dry_seq_breakers,
927971
logits_processor=logits_processor,
928972
grammar=grammar,
929973
penalize_nl=penalize_nl,
@@ -1140,6 +1184,13 @@ def _create_completion(
11401184
mirostat_mode: int = 0,
11411185
mirostat_tau: float = 5.0,
11421186
mirostat_eta: float = 0.1,
1187+
xtc_probability: float = 0.0,
1188+
xtc_threshold: float = 0.1,
1189+
dry_multiplier: float = 0.0,
1190+
dry_allowed_length: int = 2,
1191+
dry_base: float = 1.75,
1192+
dry_range: int = 0,
1193+
dry_seq_breakers: list[str] = [],
11431194
model: Optional[str] = None,
11441195
stopping_criteria: Optional[StoppingCriteriaList] = None,
11451196
logits_processor: Optional[LogitsProcessorList] = None,
@@ -1328,6 +1379,13 @@ def logit_bias_processor(
13281379
mirostat_mode=mirostat_mode,
13291380
mirostat_tau=mirostat_tau,
13301381
mirostat_eta=mirostat_eta,
1382+
xtc_probability=xtc_probability,
1383+
xtc_threshold=xtc_threshold,
1384+
dry_multiplier=dry_multiplier,
1385+
dry_allowed_length=dry_allowed_length,
1386+
dry_base=dry_base,
1387+
dry_range=dry_range,
1388+
dry_seq_breakers=dry_seq_breakers,
13311389
frequency_penalty=frequency_penalty,
13321390
presence_penalty=presence_penalty,
13331391
repeat_penalty=repeat_penalty,
@@ -1760,6 +1818,13 @@ def create_completion(
17601818
mirostat_mode: int = 0,
17611819
mirostat_tau: float = 5.0,
17621820
mirostat_eta: float = 0.1,
1821+
xtc_probability: float = 0.0,
1822+
xtc_threshold: float = 0.1,
1823+
dry_multiplier: float = 0.0,
1824+
dry_allowed_length: int = 2,
1825+
dry_base: float = 1.75,
1826+
dry_range: int = 0,
1827+
dry_seq_breakers: list[str] = [],
17631828
model: Optional[str] = None,
17641829
stopping_criteria: Optional[StoppingCriteriaList] = None,
17651830
logits_processor: Optional[LogitsProcessorList] = None,
@@ -1823,6 +1888,13 @@ def create_completion(
18231888
mirostat_mode=mirostat_mode,
18241889
mirostat_tau=mirostat_tau,
18251890
mirostat_eta=mirostat_eta,
1891+
xtc_probability=xtc_probability,
1892+
xtc_threshold=xtc_threshold,
1893+
dry_multiplier=dry_multiplier,
1894+
dry_allowed_length=dry_allowed_length,
1895+
dry_base=dry_base,
1896+
dry_range=dry_range,
1897+
dry_seq_breakers=dry_seq_breakers,
18261898
model=model,
18271899
stopping_criteria=stopping_criteria,
18281900
logits_processor=logits_processor,
@@ -1857,6 +1929,13 @@ def __call__(
18571929
mirostat_mode: int = 0,
18581930
mirostat_tau: float = 5.0,
18591931
mirostat_eta: float = 0.1,
1932+
xtc_probability: float = 0.0,
1933+
xtc_threshold: float = 0.1,
1934+
dry_multiplier: float = 0.0,
1935+
dry_allowed_length: int = 2,
1936+
dry_base: float = 1.75,
1937+
dry_range: int = 0,
1938+
dry_seq_breakers: list[str] = [],
18601939
model: Optional[str] = None,
18611940
stopping_criteria: Optional[StoppingCriteriaList] = None,
18621941
logits_processor: Optional[LogitsProcessorList] = None,
@@ -1920,6 +1999,13 @@ def __call__(
19201999
mirostat_mode=mirostat_mode,
19212000
mirostat_tau=mirostat_tau,
19222001
mirostat_eta=mirostat_eta,
2002+
xtc_probability=xtc_probability,
2003+
xtc_threshold=xtc_threshold,
2004+
dry_multiplier=dry_multiplier,
2005+
dry_allowed_length=dry_allowed_length,
2006+
dry_base=dry_base,
2007+
dry_range=dry_range,
2008+
dry_seq_breakers=dry_seq_breakers,
19232009
model=model,
19242010
stopping_criteria=stopping_criteria,
19252011
logits_processor=logits_processor,
@@ -1951,6 +2037,13 @@ def create_chat_completion(
19512037
mirostat_mode: int = 0,
19522038
mirostat_tau: float = 5.0,
19532039
mirostat_eta: float = 0.1,
2040+
xtc_probability: float = 0.0,
2041+
xtc_threshold: float = 0.1,
2042+
dry_multiplier: float = 0.0,
2043+
dry_allowed_length: int = 2,
2044+
dry_base: float = 1.75,
2045+
dry_range: int = 0,
2046+
dry_seq_breakers: list[str] = [],
19542047
model: Optional[str] = None,
19552048
logits_processor: Optional[LogitsProcessorList] = None,
19562049
grammar: Optional[LlamaGrammar] = None,
@@ -2024,6 +2117,13 @@ def create_chat_completion(
20242117
mirostat_mode=mirostat_mode,
20252118
mirostat_tau=mirostat_tau,
20262119
mirostat_eta=mirostat_eta,
2120+
xtc_probability=xtc_probability,
2121+
xtc_threshold=xtc_threshold,
2122+
dry_multiplier=dry_multiplier,
2123+
dry_allowed_length=dry_allowed_length,
2124+
dry_base=dry_base,
2125+
dry_range=dry_range,
2126+
dry_seq_breakers=dry_seq_breakers,
20272127
model=model,
20282128
logits_processor=logits_processor,
20292129
grammar=grammar,

llama_cpp/llama_cpp.py

+32
Original file line numberDiff line numberDiff line change
@@ -3626,6 +3626,38 @@ def llama_sampler_init_xtc(
36263626
) -> llama_sampler_p:
36273627
...
36283628

3629+
# LLAMA_API struct llama_sampler * llama_sampler_init_dry(
3630+
# const struct llama_model * model,
3631+
# float dry_multiplier,
3632+
# float dry_base,
3633+
# int32_t dry_allowed_length,
3634+
# int32_t dry_penalty_last_n,
3635+
# const char ** seq_breakers,
3636+
# size_t num_breakers);
3637+
@ctypes_function(
    "llama_sampler_init_dry",
    [
        llama_model_p_ctypes,
        ctypes.c_float,
        ctypes.c_float,
        ctypes.c_int32,
        ctypes.c_int32,
        ctypes.POINTER(ctypes.c_char_p),
        ctypes.c_size_t,
    ],
    llama_sampler_p_ctypes,
)
def llama_sampler_init_dry(
    model: llama_model_p,
    dry_multiplier: float,
    dry_base: float,
    dry_allowed_length: int,
    dry_penalty_last_n: int,
    seq_breakers: "ctypes.Array[ctypes.c_char_p]",
    num_breakers: int,
) -> llama_sampler_p:
    """Typed stub for the C function ``llama_sampler_init_dry``.

    The argument and return ctypes types are the ones listed in the
    ``ctypes_function`` decorator above.

    ``seq_breakers`` is declared as ``POINTER(c_char_p)`` (C ``const char **``),
    so pass a ctypes array of ``c_char_p`` holding encoded byte strings, not a
    Python list of ``str``. ``num_breakers`` is the matching ``size_t`` count.
    """
    ...
3660+
36293661

36303662
# /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
36313663
# /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.

0 commit comments

Comments
 (0)