@@ -1116,6 +1116,56 @@ def decode_batch(seq_sizes: List[int]):
        else:
            return output

+    def _create_chunk(
+        self,
+        completion_id: str,
+        created: int,
+        model_name: str,
+        text: str,
+        logprobs_or_none: Optional[CompletionLogprobs],
+        include_usage: bool,
+        index: int,
+        finish_reason: Optional[str],
+        usage: Optional[Dict[str, Any]] = None,
+    ) -> CreateCompletionStreamResponse:
+        """
+        Create a chunk for the streaming API; when usage reporting is
+        requested, the chunk carries an additional "usage" field.
+        """
+
+        if include_usage:
+            token = {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": index,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    },
+                ],
+                "usage": usage,
+            }
+        else:
+            token = {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": index,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+            }
+        return token
+
    def _create_completion(
        self,
        prompt: Union[str, List[int]],
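Note: _create_chunk centralizes the chunk payload that the streaming code below previously built inline at every yield site. A minimal sketch of what it returns when usage reporting is on; the id, timestamp, and model name are placeholder values, and llm stands for a Llama instance:

    chunk = llm._create_chunk(
        completion_id="cmpl-xxxxxxxx",  # placeholder id
        created=1700000000,             # placeholder unix timestamp
        model_name="example-model",     # placeholder model name
        text="Hello",
        logprobs_or_none=None,
        include_usage=True,
        index=0,
        finish_reason=None,
        usage={"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
    )
    # chunk == {
    #     "id": "cmpl-xxxxxxxx",
    #     "object": "text_completion",
    #     "created": 1700000000,
    #     "model": "example-model",
    #     "choices": [
    #         {"text": "Hello", "index": 0, "logprobs": None, "finish_reason": None},
    #     ],
    #     "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
    # }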
@@ -1133,6 +1183,7 @@ def _create_completion(
        repeat_penalty: float = 1.0,
        top_k: int = 40,
        stream: bool = False,
+        stream_options: Optional[StreamOptions] = None,
        seed: Optional[int] = None,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
@@ -1363,6 +1414,10 @@ def logit_bias_processor(
                break

            if stream:
+                if stream_options is None or stream_options.include_usage is None:
+                    include_usage = False
+                else:
+                    include_usage = stream_options.include_usage
                remaining_tokens = completion_tokens[returned_tokens:]
                remaining_text = self.detokenize(
                    remaining_tokens,
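Aside: the four added lines above reduce to a single expression; an equivalent one-liner (a sketch, not part of the patch, assuming include_usage is an Optional[bool] field):

    include_usage = bool(stream_options and stream_options.include_usage)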
@@ -1442,24 +1497,23 @@ def logit_bias_processor(
                            "top_logprobs": [top_logprob],
                        }
                        returned_tokens += 1
-                        yield {
-                            "id": completion_id,
-                            "object": "text_completion",
-                            "created": created,
-                            "model": model_name,
-                            "choices": [
-                                {
-                                    "text": self.detokenize(
-                                        [token],
-                                        prev_tokens=prompt_tokens
-                                        + completion_tokens[:returned_tokens],
-                                    ).decode("utf-8", errors="ignore"),
-                                    "index": 0,
-                                    "logprobs": logprobs_or_none,
-                                    "finish_reason": None,
-                                }
-                            ],
-                        }
+                        text = self.detokenize(
+                            [token],
+                            prev_tokens=prompt_tokens
+                            + completion_tokens[:returned_tokens],
+                        ).decode("utf-8", errors="ignore")
+                        yield self._create_chunk(
+                            completion_id=completion_id,
+                            created=created,
+                            model_name=model_name,
+                            text=text,
+                            finish_reason=None,
+                            index=0,
+                            logprobs_or_none=logprobs_or_none,
+                            include_usage=include_usage,
+                        )
                else:
                    while len(remaining_tokens) > 0:
                        decode_success = False
@@ -1488,20 +1542,16 @@ def logit_bias_processor(
                        remaining_tokens = remaining_tokens[i:]
                        returned_tokens += i

-                        yield {
-                            "id": completion_id,
-                            "object": "text_completion",
-                            "created": created,
-                            "model": model_name,
-                            "choices": [
-                                {
-                                    "text": ts,
-                                    "index": 0,
-                                    "logprobs": None,
-                                    "finish_reason": None,
-                                }
-                            ],
-                        }
+                        yield self._create_chunk(
+                            index=0,
+                            finish_reason=None,
+                            completion_id=completion_id,
+                            created=created,
+                            model_name=model_name,
+                            text=ts,
+                            logprobs_or_none=None,
+                            include_usage=include_usage,
+                        )

            if len(completion_tokens) >= max_tokens:
                text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
@@ -1580,54 +1630,60 @@ def logit_bias_processor(
                    if token_end_position == end - 1:
                        break
                    returned_tokens += 1
-                    yield {
-                        "id": completion_id,
-                        "object": "text_completion",
-                        "created": created,
-                        "model": model_name,
-                        "choices": [
-                            {
-                                "text": last_text[
-                                    : len(last_text) - (token_end_position - end)
-                                ].decode("utf-8", errors="ignore"),
-                                "index": 0,
-                                "logprobs": logprobs_or_none,
-                                "finish_reason": None,
-                            }
-                        ],
-                    }
+                    text = last_text[
+                        : len(last_text) - (token_end_position - end)
+                    ].decode("utf-8", errors="ignore")
+
+                    yield self._create_chunk(
+                        completion_id=completion_id,
+                        created=created,
+                        model_name=model_name,
+                        text=text,
+                        logprobs_or_none=logprobs_or_none,
+                        include_usage=include_usage,
+                        index=0,
+                        finish_reason=None,
+                    )
                    break
                returned_tokens += 1
-                yield {
-                    "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
-                    "model": model_name,
-                    "choices": [
-                        {
-                            "text": self.detokenize([token]).decode(
-                                "utf-8", errors="ignore"
-                            ),
-                            "index": 0,
-                            "logprobs": logprobs_or_none,
-                            "finish_reason": None,
-                        }
-                    ],
-                }
-            yield {
-                "id": completion_id,
-                "object": "text_completion",
-                "created": created,
-                "model": model_name,
-                "choices": [
-                    {
-                        "text": "",
-                        "index": 0,
-                        "logprobs": None,
-                        "finish_reason": finish_reason,
-                    }
-                ],
-            }
+                text = self.detokenize([token]).decode("utf-8", errors="ignore")
+                yield self._create_chunk(
+                    completion_id=completion_id,
+                    created=created,
+                    model_name=model_name,
+                    text=text,
+                    logprobs_or_none=logprobs_or_none,
+                    include_usage=include_usage,
+                    index=0,
+                    finish_reason=None,
+                )
+            yield self._create_chunk(
+                completion_id=completion_id,
+                created=created,
+                model_name=model_name,
+                text="",
+                index=0,
+                logprobs_or_none=None,
+                include_usage=include_usage,
+                usage=None,
+                finish_reason=finish_reason,
+            )
+
+            if include_usage:
+                yield self._create_chunk(
+                    completion_id=completion_id,
+                    created=created,
+                    model_name=model_name,
+                    text="",
+                    logprobs_or_none=None,
+                    include_usage=include_usage,
+                    index=0,
+                    finish_reason=None,
+                    usage={
+                        "prompt_tokens": len(prompt_tokens),
+                        "completion_tokens": returned_tokens,
+                        "total_tokens": len(prompt_tokens) + returned_tokens,
+                    },
+                )
        if self.cache:
            if self.verbose:
                print("Llama._create_completion: cache save", file=sys.stderr)
@@ -1736,6 +1792,7 @@ def logit_bias_processor(
            },
        }

+
    def create_completion(
        self,
        prompt: Union[str, List[int]],
@@ -1753,6 +1810,7 @@ def create_completion(
        repeat_penalty: float = 1.0,
        top_k: int = 40,
        stream: bool = False,
+        stream_options: Optional[StreamOptions] = None,
        seed: Optional[int] = None,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
@@ -1816,6 +1874,7 @@ def create_completion(
            repeat_penalty=repeat_penalty,
            top_k=top_k,
            stream=stream,
+            stream_options=stream_options,
            seed=seed,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
@@ -1850,6 +1909,7 @@ def __call__(
        repeat_penalty: float = 1.0,
        top_k: int = 40,
        stream: bool = False,
+        stream_options: Optional[StreamOptions] = None,
        seed: Optional[int] = None,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
@@ -1913,6 +1973,7 @@ def __call__(
            repeat_penalty=repeat_penalty,
            top_k=top_k,
            stream=stream,
+            stream_options=stream_options,
            seed=seed,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
@@ -1938,6 +1999,7 @@ def create_chat_completion(
        min_p: float = 0.05,
        typical_p: float = 1.0,
        stream: bool = False,
+        stream_options: Optional[StreamOptions] = None,
        stop: Optional[Union[str, List[str]]] = [],
        seed: Optional[int] = None,
        response_format: Optional[ChatCompletionRequestResponseFormat] = None,
@@ -2011,6 +2073,7 @@ def create_chat_completion(
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            stream=stream,
+            stream_options=stream_options,
            stop=stop,
            seed=seed,
            response_format=response_format,
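Taken together, the patch lets callers opt in to a final usage chunk when streaming. A caller-side sketch (assumptions: StreamOptions is importable from llama_cpp, which this diff does not show, and ./model.gguf is a placeholder model path):

    from llama_cpp import Llama, StreamOptions  # StreamOptions import path assumed

    llm = Llama(model_path="./model.gguf")  # placeholder model path
    for chunk in llm.create_completion(
        "Q: What is the capital of France? A:",
        max_tokens=16,
        stream=True,
        stream_options=StreamOptions(include_usage=True),
    ):
        if chunk.get("usage"):  # the extra final chunk carries token counts
            print(chunk["usage"])
        else:
            print(chunk["choices"][0]["text"], end="", flush=True)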