Commit dcca56a

Adding streaming options to llama
1 parent 7c4aead commit dcca56a

File tree

5 files changed, +481 −346 lines


llama_cpp/llama.py

+141 −78
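The hunks below thread a new stream_options parameter through the completion call chain and route every streamed chunk through a _create_chunk helper so that token usage can optionally be reported at the end of a stream. StreamOptions itself comes from one of the other changed files and is not shown in this diff; judging from the stream_options.include_usage accesses below, it only needs to expose an optional include_usage flag. A hypothetical sketch of that shape, for illustration only (not the committed definition):

from dataclasses import dataclass
from typing import Optional

@dataclass
class StreamOptions:
    # When True, the stream ends with an extra chunk reporting
    # prompt, completion, and total token counts.
    include_usage: Optional[bool] = None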
@@ -1116,6 +1116,56 @@ def decode_batch(seq_sizes: List[int]):
         else:
             return output
 
+    def _create_chunk(
+        self,
+        completion_id: str,
+        created: int,
+        model_name: str,
+        text: str,
+        logprobs_or_none: Union[Optional[CompletionLogprobs], None],
+        include_usage: bool,
+        index: int,
+        finish_reason: Union[str, None],
+        usage: Union[Dict[str, Any], None] = None,
+    ) -> CreateChatCompletionStreamResponse:
+        """
+        Create a chunk for the streaming API; depending on whether usage was
+        requested, the chunk carries an additional "usage" field or not.
+        """
+
+        if include_usage:
+            token = {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": index,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    },
+                ],
+                "usage": usage,
+            }
+        else:
+            token = {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": index,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+            }
+        return token
+
     def _create_completion(
         self,
         prompt: Union[str, List[int]],
@@ -1133,6 +1183,7 @@ def _create_completion(
         repeat_penalty: float = 1.0,
         top_k: int = 40,
         stream: bool = False,
+        stream_options: Optional[StreamOptions] = None,
         seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
@@ -1363,6 +1414,10 @@ def logit_bias_processor(
                 break
 
             if stream:
+                if stream_options is None or stream_options.include_usage is None:
+                    include_usage = False
+                else:
+                    include_usage = stream_options.include_usage
                 remaining_tokens = completion_tokens[returned_tokens:]
                 remaining_text = self.detokenize(
                     remaining_tokens,
@@ -1442,24 +1497,23 @@ def logit_bias_processor(
                             "top_logprobs": [top_logprob],
                         }
                         returned_tokens += 1
-                        yield {
-                            "id": completion_id,
-                            "object": "text_completion",
-                            "created": created,
-                            "model": model_name,
-                            "choices": [
-                                {
-                                    "text": self.detokenize(
-                                        [token],
-                                        prev_tokens=prompt_tokens
-                                        + completion_tokens[:returned_tokens],
-                                    ).decode("utf-8", errors="ignore"),
-                                    "index": 0,
-                                    "logprobs": logprobs_or_none,
-                                    "finish_reason": None,
-                                }
-                            ],
-                        }
+                        text = (
+                            self.detokenize(
+                                [token],
+                                prev_tokens=prompt_tokens
+                                + completion_tokens[:returned_tokens],
+                            ).decode("utf-8", errors="ignore")
+                        )
+                        yield self._create_chunk(
+                            completion_id=completion_id,
+                            created=created,
+                            model_name=model_name,
+                            text=text,
+                            finish_reason=None,
+                            index=0,
+                            logprobs_or_none=logprobs_or_none,
+                            include_usage=include_usage,
+                        )
                 else:
                     while len(remaining_tokens) > 0:
                         decode_success = False
@@ -1488,20 +1542,16 @@ def logit_bias_processor(
                         remaining_tokens = remaining_tokens[i:]
                         returned_tokens += i
 
-                        yield {
-                            "id": completion_id,
-                            "object": "text_completion",
-                            "created": created,
-                            "model": model_name,
-                            "choices": [
-                                {
-                                    "text": ts,
-                                    "index": 0,
-                                    "logprobs": None,
-                                    "finish_reason": None,
-                                }
-                            ],
-                        }
+                        yield self._create_chunk(
+                            index=0,
+                            finish_reason=None,
+                            completion_id=completion_id,
+                            created=created,
+                            model_name=model_name,
+                            text=ts,
+                            logprobs_or_none=None,
+                            include_usage=include_usage,
+                        )
 
             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
@@ -1580,54 +1630,60 @@ def logit_bias_processor(
                     if token_end_position == end - 1:
                         break
                     returned_tokens += 1
-                    yield {
-                        "id": completion_id,
-                        "object": "text_completion",
-                        "created": created,
-                        "model": model_name,
-                        "choices": [
-                            {
-                                "text": last_text[
-                                    : len(last_text) - (token_end_position - end)
-                                ].decode("utf-8", errors="ignore"),
-                                "index": 0,
-                                "logprobs": logprobs_or_none,
-                                "finish_reason": None,
-                            }
-                        ],
-                    }
+                    text = last_text[
+                        : len(last_text) - (token_end_position - end)
+                    ].decode("utf-8", errors="ignore")
+
+                    yield self._create_chunk(
+                        completion_id=completion_id,
+                        created=created,
+                        model_name=model_name,
+                        text=text,
+                        logprobs_or_none=logprobs_or_none,
+                        include_usage=include_usage,
+                        index=0,
+                        finish_reason=None,
+                    )
                     break
                 returned_tokens += 1
-                yield {
-                    "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
-                    "model": model_name,
-                    "choices": [
-                        {
-                            "text": self.detokenize([token]).decode(
-                                "utf-8", errors="ignore"
-                            ),
-                            "index": 0,
-                            "logprobs": logprobs_or_none,
-                            "finish_reason": None,
-                        }
-                    ],
-                }
-            yield {
-                "id": completion_id,
-                "object": "text_completion",
-                "created": created,
-                "model": model_name,
-                "choices": [
-                    {
-                        "text": "",
-                        "index": 0,
-                        "logprobs": None,
-                        "finish_reason": finish_reason,
-                    }
-                ],
-            }
+                text = self.detokenize([token]).decode("utf-8", errors="ignore")
+                yield self._create_chunk(
+                    completion_id=completion_id,
+                    created=created,
+                    model_name=model_name,
+                    text=text,
+                    logprobs_or_none=logprobs_or_none,
+                    include_usage=include_usage,
+                    index=0,
+                    finish_reason=None,
+                )
+            yield self._create_chunk(
+                completion_id=completion_id,
+                created=created,
+                model_name=model_name,
+                text="",
+                index=0,
+                logprobs_or_none=None,
+                include_usage=include_usage,
+                usage=None,
+                finish_reason=finish_reason)
+
+            if include_usage:
+                yield self._create_chunk(
+                    completion_id=completion_id,
+                    created=created,
+                    model_name=model_name,
+                    text="",
+                    logprobs_or_none=None,
+                    include_usage=include_usage,
+                    index=0,
+                    finish_reason=None,
+                    usage={
+                        "prompt_tokens": len(prompt_tokens),
+                        "completion_tokens": returned_tokens,
+                        "total_tokens": len(prompt_tokens) + returned_tokens,
+                    },
+                )
             if self.cache:
                 if self.verbose:
                     print("Llama._create_completion: cache save", file=sys.stderr)
@@ -1736,6 +1792,7 @@ def logit_bias_processor(
             },
         }
 
+
     def create_completion(
         self,
         prompt: Union[str, List[int]],
@@ -1753,6 +1810,7 @@ def create_completion(
         repeat_penalty: float = 1.0,
         top_k: int = 40,
         stream: bool = False,
+        stream_options: Optional[StreamOptions] = None,
         seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
@@ -1816,6 +1874,7 @@ def create_completion(
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
+            stream_options=stream_options,
             seed=seed,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
@@ -1850,6 +1909,7 @@ def __call__(
         repeat_penalty: float = 1.0,
         top_k: int = 40,
         stream: bool = False,
+        stream_options: Optional[StreamOptions] = None,
         seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
@@ -1913,6 +1973,7 @@ def __call__(
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
+            stream_options=stream_options,
             seed=seed,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
@@ -1938,6 +1999,7 @@ def create_chat_completion(
         min_p: float = 0.05,
         typical_p: float = 1.0,
         stream: bool = False,
+        stream_options: Optional[StreamOptions] = None,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
         response_format: Optional[ChatCompletionRequestResponseFormat] = None,
@@ -2011,6 +2073,7 @@ def create_chat_completion(
             logprobs=logprobs,
             top_logprobs=top_logprobs,
             stream=stream,
+            stream_options=stream_options,
             stop=stop,
             seed=seed,
             response_format=response_format,
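Taken together, the change lets a caller opt in to usage reporting on a streamed completion: intermediate chunks carry a null "usage" field, and one final chunk reports prompt, completion, and total token counts. A rough caller-side sketch, reusing the hypothetical StreamOptions shape from the note above (the model path and prompt are placeholders, not part of the commit):

from llama_cpp import Llama

llm = Llama(model_path="./models/model.gguf")  # placeholder path

stream = llm.create_completion(
    "Q: Name the planets in the solar system. A: ",
    max_tokens=64,
    stream=True,
    # Assumes StreamOptions exposes an include_usage flag, as the
    # stream_options.include_usage check in the diff implies.
    stream_options=StreamOptions(include_usage=True),
)

for chunk in stream:
    if chunk.get("usage"):
        # Extra final chunk, emitted only when include_usage is set.
        print("usage:", chunk["usage"])
    else:
        print(chunk["choices"][0]["text"], end="", flush=True)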
