@@ -680,6 +680,13 @@ def _init_sampler(
680
680
mirostat_mode : int = 0 ,
681
681
mirostat_eta : float = 0.1 ,
682
682
mirostat_tau : float = 5.0 ,
683
+ xtc_probability : float = 0.0 ,
684
+ xtc_threshold : float = 0.1 ,
685
+ dry_multiplier : float = 0.0 ,
686
+ dry_allowed_length : int = 2 ,
687
+ dry_base : float = 1.75 ,
688
+ dry_range : int = 0 ,
689
+ dry_seq_breakers : list [str ] = [],
683
690
penalize_nl : bool = True ,
684
691
logits_processor : Optional [LogitsProcessorList ] = None ,
685
692
grammar : Optional [LlamaGrammar ] = None ,
@@ -747,12 +754,14 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
747
754
else :
748
755
n_probs = 0
749
756
min_keep = max (1 , n_probs )
757
+ sampler .add_dry (self ._model , dry_multiplier , dry_base , dry_allowed_length , dry_range , dry_seq_breakers )
750
758
sampler .add_top_k (top_k )
751
759
sampler .add_typical (typical_p , min_keep )
752
760
sampler .add_top_p (top_p , min_keep )
753
761
sampler .add_min_p (min_p , min_keep )
754
762
sampler .add_temp (temp )
755
763
sampler .add_dist (self ._seed )
764
+ sampler .add_xtc (xtc_probability , xtc_threshold , min_keep , self ._seed )
756
765
return sampler
757
766
758
767
def sample (
@@ -769,6 +778,13 @@ def sample(
769
778
mirostat_mode : int = 0 ,
770
779
mirostat_eta : float = 0.1 ,
771
780
mirostat_tau : float = 5.0 ,
781
+ xtc_probability : float = 0.0 ,
782
+ xtc_threshold : float = 0.1 ,
783
+ dry_multiplier : float = 0.0 ,
784
+ dry_allowed_length : int = 2 ,
785
+ dry_base : float = 1.75 ,
786
+ dry_range : int = 0 ,
787
+ dry_seq_breakers : list [str ] = [],
772
788
penalize_nl : bool = True ,
773
789
logits_processor : Optional [LogitsProcessorList ] = None ,
774
790
grammar : Optional [LlamaGrammar ] = None ,
@@ -804,6 +820,13 @@ def sample(
804
820
mirostat_mode = mirostat_mode ,
805
821
mirostat_tau = mirostat_tau ,
806
822
mirostat_eta = mirostat_eta ,
823
+ xtc_probability = xtc_probability ,
824
+ xtc_threshold = xtc_threshold ,
825
+ dry_multiplier = dry_multiplier ,
826
+ dry_allowed_length = dry_allowed_length ,
827
+ dry_base = dry_base ,
828
+ dry_range = dry_range ,
829
+ dry_seq_breakers = dry_seq_breakers ,
807
830
penalize_nl = penalize_nl ,
808
831
logits_processor = logits_processor ,
809
832
grammar = grammar ,
@@ -833,6 +856,13 @@ def generate(
833
856
mirostat_mode : int = 0 ,
834
857
mirostat_tau : float = 5.0 ,
835
858
mirostat_eta : float = 0.1 ,
859
+ xtc_probability : float = 0.0 ,
860
+ xtc_threshold : float = 0.1 ,
861
+ dry_multiplier : float = 0.0 ,
862
+ dry_allowed_length : int = 2 ,
863
+ dry_base : float = 1.75 ,
864
+ dry_range : int = 0 ,
865
+ dry_seq_breakers : list [str ] = [],
836
866
penalize_nl : bool = True ,
837
867
logits_processor : Optional [LogitsProcessorList ] = None ,
838
868
stopping_criteria : Optional [StoppingCriteriaList ] = None ,
@@ -872,6 +902,13 @@ def generate(
872
902
mirostat_mode = mirostat_mode ,
873
903
mirostat_tau = mirostat_tau ,
874
904
mirostat_eta = mirostat_eta ,
905
+ xtc_probability = xtc_probability ,
906
+ xtc_threshold = xtc_threshold ,
907
+ dry_multiplier = dry_multiplier ,
908
+ dry_allowed_length = dry_allowed_length ,
909
+ dry_base = dry_base ,
910
+ dry_range = dry_range ,
911
+ dry_seq_breakers = dry_seq_breakers ,
875
912
penalize_nl = penalize_nl ,
876
913
logits_processor = logits_processor ,
877
914
grammar = grammar ,
@@ -924,6 +961,13 @@ def generate(
924
961
mirostat_mode = mirostat_mode ,
925
962
mirostat_tau = mirostat_tau ,
926
963
mirostat_eta = mirostat_eta ,
964
+ xtc_probability = xtc_probability ,
965
+ xtc_threshold = xtc_threshold ,
966
+ dry_multiplier = dry_multiplier ,
967
+ dry_allowed_length = dry_allowed_length ,
968
+ dry_base = dry_base ,
969
+ dry_range = dry_range ,
970
+ dry_seq_breakers = dry_seq_breakers ,
927
971
logits_processor = logits_processor ,
928
972
grammar = grammar ,
929
973
penalize_nl = penalize_nl ,
@@ -1140,6 +1184,13 @@ def _create_completion(
1140
1184
mirostat_mode : int = 0 ,
1141
1185
mirostat_tau : float = 5.0 ,
1142
1186
mirostat_eta : float = 0.1 ,
1187
+ xtc_probability : float = 0.0 ,
1188
+ xtc_threshold : float = 0.1 ,
1189
+ dry_multiplier : float = 0.0 ,
1190
+ dry_allowed_length : int = 2 ,
1191
+ dry_base : float = 1.75 ,
1192
+ dry_range : int = 0 ,
1193
+ dry_seq_breakers : list [str ] = [],
1143
1194
model : Optional [str ] = None ,
1144
1195
stopping_criteria : Optional [StoppingCriteriaList ] = None ,
1145
1196
logits_processor : Optional [LogitsProcessorList ] = None ,
@@ -1328,6 +1379,13 @@ def logit_bias_processor(
1328
1379
mirostat_mode = mirostat_mode ,
1329
1380
mirostat_tau = mirostat_tau ,
1330
1381
mirostat_eta = mirostat_eta ,
1382
+ xtc_probability = xtc_probability ,
1383
+ xtc_threshold = xtc_threshold ,
1384
+ dry_multiplier = dry_multiplier ,
1385
+ dry_allowed_length = dry_allowed_length ,
1386
+ dry_base = dry_base ,
1387
+ dry_range = dry_range ,
1388
+ dry_seq_breakers = dry_seq_breakers ,
1331
1389
frequency_penalty = frequency_penalty ,
1332
1390
presence_penalty = presence_penalty ,
1333
1391
repeat_penalty = repeat_penalty ,
@@ -1760,6 +1818,13 @@ def create_completion(
1760
1818
mirostat_mode : int = 0 ,
1761
1819
mirostat_tau : float = 5.0 ,
1762
1820
mirostat_eta : float = 0.1 ,
1821
+ xtc_probability : float = 0.0 ,
1822
+ xtc_threshold : float = 0.1 ,
1823
+ dry_multiplier : float = 0.0 ,
1824
+ dry_allowed_length : int = 2 ,
1825
+ dry_base : float = 1.75 ,
1826
+ dry_range : int = 0 ,
1827
+ dry_seq_breakers : list [str ] = [],
1763
1828
model : Optional [str ] = None ,
1764
1829
stopping_criteria : Optional [StoppingCriteriaList ] = None ,
1765
1830
logits_processor : Optional [LogitsProcessorList ] = None ,
@@ -1823,6 +1888,13 @@ def create_completion(
1823
1888
mirostat_mode = mirostat_mode ,
1824
1889
mirostat_tau = mirostat_tau ,
1825
1890
mirostat_eta = mirostat_eta ,
1891
+ xtc_probability = xtc_probability ,
1892
+ xtc_threshold = xtc_threshold ,
1893
+ dry_multiplier = dry_multiplier ,
1894
+ dry_allowed_length = dry_allowed_length ,
1895
+ dry_base = dry_base ,
1896
+ dry_range = dry_range ,
1897
+ dry_seq_breakers = dry_seq_breakers ,
1826
1898
model = model ,
1827
1899
stopping_criteria = stopping_criteria ,
1828
1900
logits_processor = logits_processor ,
@@ -1857,6 +1929,13 @@ def __call__(
1857
1929
mirostat_mode : int = 0 ,
1858
1930
mirostat_tau : float = 5.0 ,
1859
1931
mirostat_eta : float = 0.1 ,
1932
+ xtc_probability : float = 0.0 ,
1933
+ xtc_threshold : float = 0.1 ,
1934
+ dry_multiplier : float = 0.0 ,
1935
+ dry_allowed_length : int = 2 ,
1936
+ dry_base : float = 1.75 ,
1937
+ dry_range : int = 0 ,
1938
+ dry_seq_breakers : list [str ] = [],
1860
1939
model : Optional [str ] = None ,
1861
1940
stopping_criteria : Optional [StoppingCriteriaList ] = None ,
1862
1941
logits_processor : Optional [LogitsProcessorList ] = None ,
@@ -1920,6 +1999,13 @@ def __call__(
1920
1999
mirostat_mode = mirostat_mode ,
1921
2000
mirostat_tau = mirostat_tau ,
1922
2001
mirostat_eta = mirostat_eta ,
2002
+ xtc_probability = xtc_probability ,
2003
+ xtc_threshold = xtc_threshold ,
2004
+ dry_multiplier = dry_multiplier ,
2005
+ dry_allowed_length = dry_allowed_length ,
2006
+ dry_base = dry_base ,
2007
+ dry_range = dry_range ,
2008
+ dry_seq_breakers = dry_seq_breakers ,
1923
2009
model = model ,
1924
2010
stopping_criteria = stopping_criteria ,
1925
2011
logits_processor = logits_processor ,
@@ -1951,6 +2037,13 @@ def create_chat_completion(
1951
2037
mirostat_mode : int = 0 ,
1952
2038
mirostat_tau : float = 5.0 ,
1953
2039
mirostat_eta : float = 0.1 ,
2040
+ xtc_probability : float = 0.0 ,
2041
+ xtc_threshold : float = 0.1 ,
2042
+ dry_multiplier : float = 0.0 ,
2043
+ dry_allowed_length : int = 2 ,
2044
+ dry_base : float = 1.75 ,
2045
+ dry_range : int = 0 ,
2046
+ dry_seq_breakers : list [str ] = [],
1954
2047
model : Optional [str ] = None ,
1955
2048
logits_processor : Optional [LogitsProcessorList ] = None ,
1956
2049
grammar : Optional [LlamaGrammar ] = None ,
@@ -2024,6 +2117,13 @@ def create_chat_completion(
2024
2117
mirostat_mode = mirostat_mode ,
2025
2118
mirostat_tau = mirostat_tau ,
2026
2119
mirostat_eta = mirostat_eta ,
2120
+ xtc_probability = xtc_probability ,
2121
+ xtc_threshold = xtc_threshold ,
2122
+ dry_multiplier = dry_multiplier ,
2123
+ dry_allowed_length = dry_allowed_length ,
2124
+ dry_base = dry_base ,
2125
+ dry_range = dry_range ,
2126
+ dry_seq_breakers = dry_seq_breakers ,
2027
2127
model = model ,
2028
2128
logits_processor = logits_processor ,
2029
2129
grammar = grammar ,
0 commit comments