@@ -1116,6 +1116,56 @@ def decode_batch(seq_sizes: List[int]):
        else:
            return output

+    def _create_chunk(
+        self,
+        completion_id: str,
+        created: int,
+        model_name: str,
+        text: str,
+        logprobs_or_none: Optional[CompletionLogprobs],
+        include_usage: bool,
+        index: int,
+        finish_reason: Optional[str],
+        usage: Optional[Dict[str, Any]] = None,
+    ) -> CreateCompletionStreamResponse:
+        """
+        Create a chunk for the streaming API. When usage reporting is
+        requested, each chunk carries an additional "usage" field.
+        """
+
+        if include_usage:
+            token = {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": index,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    },
+                ],
+                "usage": usage,
+            }
+        else:
+            token = {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": index,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+            }
+        return token
+

    def _create_completion(
        self,
        prompt: Union[str, List[int]],
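
For reference, the chunks _create_chunk builds follow the OpenAI-style text_completion streaming shape. A minimal sketch of the two variants (field values here are illustrative, not taken from the source):

# Variant without include_usage: no "usage" key is present at all.
chunk = {
    "id": "cmpl-xyz",
    "object": "text_completion",
    "created": 1700000000,
    "model": "example-model",
    "choices": [
        {"text": "Hello", "index": 0, "logprobs": None, "finish_reason": None}
    ],
}

# Variant with include_usage: every chunk carries a "usage" key, which stays
# None until the final chunk that reports token counts.
chunk_with_usage = {**chunk, "usage": None}
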
@@ -1133,6 +1183,7 @@ def _create_completion(
        repeat_penalty: float = 1.0,
        top_k: int = 40,
        stream: bool = False,
+        stream_options: Optional[StreamOptions] = None,
        seed: Optional[int] = None,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
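
StreamOptions is referenced here but not defined in this diff; judging from the "include_usage" membership check later in the patch, a plausible definition (an assumption, mirroring OpenAI's stream_options request field) would be:

from typing_extensions import NotRequired, TypedDict

class StreamOptions(TypedDict):
    # When true, one extra final chunk carrying token usage is streamed.
    include_usage: NotRequired[bool]
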
@@ -1363,6 +1414,11 @@ def logit_bias_processor(
                break

            if stream:
+                if stream_options is not None and "include_usage" in stream_options:
+                    include_usage = bool(stream_options["include_usage"])
+                else:
+                    include_usage = False
+
                remaining_tokens = completion_tokens[returned_tokens:]
                remaining_text = self.detokenize(
                    remaining_tokens,
@@ -1442,24 +1498,23 @@ def logit_bias_processor(
                            "top_logprobs": [top_logprob],
                        }
                        returned_tokens += 1
-                        yield {
-                            "id": completion_id,
-                            "object": "text_completion",
-                            "created": created,
-                            "model": model_name,
-                            "choices": [
-                                {
-                                    "text": self.detokenize(
-                                        [token],
-                                        prev_tokens=prompt_tokens
-                                        + completion_tokens[:returned_tokens],
-                                    ).decode("utf-8", errors="ignore"),
-                                    "index": 0,
-                                    "logprobs": logprobs_or_none,
-                                    "finish_reason": None,
-                                }
-                            ],
-                        }
+                        text = self.detokenize(
+                            [token],
+                            prev_tokens=prompt_tokens
+                            + completion_tokens[:returned_tokens],
+                        ).decode("utf-8", errors="ignore")
+                        yield self._create_chunk(
+                            completion_id=completion_id,
+                            created=created,
+                            model_name=model_name,
+                            text=text,
+                            finish_reason=None,
+                            index=0,
+                            logprobs_or_none=logprobs_or_none,
+                            include_usage=include_usage,
+                        )
                else:
                    while len(remaining_tokens) > 0:
                        decode_success = False
@@ -1488,20 +1543,16 @@ def logit_bias_processor(
                        remaining_tokens = remaining_tokens[i:]
                        returned_tokens += i

-                        yield {
-                            "id": completion_id,
-                            "object": "text_completion",
-                            "created": created,
-                            "model": model_name,
-                            "choices": [
-                                {
-                                    "text": ts,
-                                    "index": 0,
-                                    "logprobs": None,
-                                    "finish_reason": None,
-                                }
-                            ],
-                        }
+                        yield self._create_chunk(
+                            index=0,
+                            finish_reason=None,
+                            completion_id=completion_id,
+                            created=created,
+                            model_name=model_name,
+                            text=ts,
+                            logprobs_or_none=None,
+                            include_usage=include_usage,
+                        )

            if len(completion_tokens) >= max_tokens:
                text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
@@ -1580,54 +1631,60 @@ def logit_bias_processor(
                    if token_end_position == end - 1:
                        break
                    returned_tokens += 1
-                    yield {
-                        "id": completion_id,
-                        "object": "text_completion",
-                        "created": created,
-                        "model": model_name,
-                        "choices": [
-                            {
-                                "text": last_text[
-                                    : len(last_text) - (token_end_position - end)
-                                ].decode("utf-8", errors="ignore"),
-                                "index": 0,
-                                "logprobs": logprobs_or_none,
-                                "finish_reason": None,
-                            }
-                        ],
-                    }
+                    text = last_text[
+                        : len(last_text) - (token_end_position - end)
+                    ].decode("utf-8", errors="ignore")
+
+                    yield self._create_chunk(
+                        completion_id=completion_id,
+                        created=created,
+                        model_name=model_name,
+                        text=text,
+                        logprobs_or_none=logprobs_or_none,
+                        include_usage=include_usage,
+                        index=0,
+                        finish_reason=None,
+                    )
                    break
                returned_tokens += 1
-                yield {
-                    "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
-                    "model": model_name,
-                    "choices": [
-                        {
-                            "text": self.detokenize([token]).decode(
-                                "utf-8", errors="ignore"
-                            ),
-                            "index": 0,
-                            "logprobs": logprobs_or_none,
-                            "finish_reason": None,
-                        }
-                    ],
-                }
-            yield {
-                "id": completion_id,
-                "object": "text_completion",
-                "created": created,
-                "model": model_name,
-                "choices": [
-                    {
-                        "text": "",
-                        "index": 0,
-                        "logprobs": None,
-                        "finish_reason": finish_reason,
-                    }
-                ],
-            }
+                text = self.detokenize([token]).decode("utf-8", errors="ignore")
+                yield self._create_chunk(
+                    completion_id=completion_id,
+                    created=created,
+                    model_name=model_name,
+                    text=text,
+                    logprobs_or_none=logprobs_or_none,
+                    include_usage=include_usage,
+                    index=0,
+                    finish_reason=None,
+                )
+            yield self._create_chunk(
+                completion_id=completion_id,
+                created=created,
+                model_name=model_name,
+                text="",
+                index=0,
+                logprobs_or_none=None,
+                include_usage=include_usage,
+                usage=None,
+                finish_reason=finish_reason,
+            )
+
+            if include_usage:
+                yield self._create_chunk(
+                    completion_id=completion_id,
+                    created=created,
+                    model_name=model_name,
+                    text="",
+                    logprobs_or_none=None,
+                    include_usage=include_usage,
+                    index=0,
+                    finish_reason=None,
+                    usage={
+                        "prompt_tokens": len(prompt_tokens),
+                        "completion_tokens": returned_tokens,
+                        "total_tokens": len(prompt_tokens) + returned_tokens,
+                    },
+                )
            if self.cache:
                if self.verbose:
                    print("Llama._create_completion: cache save", file=sys.stderr)
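
With include_usage enabled, the stream now ends with two chunks: the usual finish_reason chunk, followed by one extra chunk whose only payload is the usage block. An illustrative final chunk (values are examples, not from the source):

final_usage_chunk = {
    "id": "cmpl-xyz",
    "object": "text_completion",
    "created": 1700000000,
    "model": "example-model",
    "choices": [
        {"text": "", "index": 0, "logprobs": None, "finish_reason": None}
    ],
    "usage": {
        "prompt_tokens": 12,
        "completion_tokens": 34,
        "total_tokens": 46,
    },
}
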
@@ -1736,6 +1793,7 @@ def logit_bias_processor(
            },
        }

+
    def create_completion(
        self,
        prompt: Union[str, List[int]],
@@ -1753,6 +1811,7 @@ def create_completion(
        repeat_penalty: float = 1.0,
        top_k: int = 40,
        stream: bool = False,
+        stream_options: Optional[StreamOptions] = None,
        seed: Optional[int] = None,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
@@ -1816,6 +1875,7 @@ def create_completion(
            repeat_penalty=repeat_penalty,
            top_k=top_k,
            stream=stream,
+            stream_options=stream_options,
            seed=seed,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
@@ -1850,6 +1910,7 @@ def __call__(
        repeat_penalty: float = 1.0,
        top_k: int = 40,
        stream: bool = False,
+        stream_options: Optional[StreamOptions] = None,
        seed: Optional[int] = None,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
@@ -1913,6 +1974,7 @@ def __call__(
            repeat_penalty=repeat_penalty,
            top_k=top_k,
            stream=stream,
+            stream_options=stream_options,
            seed=seed,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
@@ -1938,6 +2000,7 @@ def create_chat_completion(
        min_p: float = 0.05,
        typical_p: float = 1.0,
        stream: bool = False,
+        stream_options: Optional[StreamOptions] = None,
        stop: Optional[Union[str, List[str]]] = [],
        seed: Optional[int] = None,
        response_format: Optional[ChatCompletionRequestResponseFormat] = None,
@@ -2011,6 +2074,7 @@ def create_chat_completion(
            logprobs=logprobs,
            top_logprobs=top_logprobs,
            stream=stream,
+            stream_options=stream_options,
            stop=stop,
            seed=seed,
            response_format=response_format,
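
Taken together, a caller opts in per request. A minimal end-to-end sketch (the model path and prompt are hypothetical; it assumes the package's usual Llama entry point):

from llama_cpp import Llama

llm = Llama(model_path="./model.gguf")  # hypothetical path

stream = llm.create_completion(
    "Q: Name the planets in the solar system. A:",
    max_tokens=64,
    stream=True,
    stream_options={"include_usage": True},
)

usage = None
for chunk in stream:
    # With include_usage=True every chunk has a "usage" key; it is None for
    # content chunks and populated only on the final extra chunk.
    if chunk.get("usage") is not None:
        usage = chunk["usage"]
    else:
        print(chunk["choices"][0]["text"], end="", flush=True)

print()
print(usage)  # {'prompt_tokens': ..., 'completion_tokens': ..., 'total_tokens': ...}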