Commit aa397ee

Adding stream_options and include_usage to server. Extracting token generation into its own function to avoid replicated statements.
1 parent 7c4aead commit aa397ee
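
The commit wires the OpenAI-style stream_options parameter (and its include_usage flag) through the completion APIs so a streamed response can end with a token-usage chunk. As a rough client-side sketch, not part of this commit, assuming the server-side files in this change expose the option on the OpenAI-compatible endpoints; the base URL, API key and model name below are placeholders:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="sk-no-key-required")

stream = client.chat.completions.create(
    model="local-model",  # placeholder: whatever model the server has loaded
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,
    stream_options={"include_usage": True},  # the option added in this commit
)

for chunk in stream:
    if chunk.choices:
        print(chunk.choices[0].delta.content or "", end="")
    if chunk.usage is not None:  # only populated on the trailing usage chunk
        print("\nusage:", chunk.usage)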

File tree: 5 files changed, +482 −346 lines changed

llama_cpp/llama.py

+142 −78
@@ -1116,6 +1116,56 @@ def decode_batch(seq_sizes: List[int]):
         else:
             return output
 
+    def _create_chunk(
+        self,
+        completion_id: str,
+        created: int,
+        model_name: str,
+        text: str,
+        logprobs_or_none: Union[Optional[CompletionLogprobs], None],
+        include_usage: bool,
+        index: int,
+        finish_reason: Union[str, None],
+        usage: Union[Dict[str, Any], None] = None,
+    ) -> CreateChatCompletionStreamResponse:
+        """
+        Create chunks for streaming API, depending on whether usage is requested or
+        not they need (or don't need) an additional field
+        """
+
+        if include_usage:
+            token = {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": index,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    },
+                ],
+                "usage": usage,
+            }
+        else:
+            token = {
+                "id": completion_id,
+                "object": "text_completion",
+                "created": created,
+                "model": model_name,
+                "choices": [
+                    {
+                        "text": text,
+                        "index": index,
+                        "logprobs": logprobs_or_none,
+                        "finish_reason": finish_reason,
+                    }
+                ],
+            }
+        return token
+
     def _create_completion(
         self,
         prompt: Union[str, List[int]],
@@ -1133,6 +1183,7 @@ def _create_completion(
         repeat_penalty: float = 1.0,
         top_k: int = 40,
         stream: bool = False,
+        stream_options: Optional[StreamOptions] = None,
         seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
@@ -1363,6 +1414,11 @@ def logit_bias_processor(
                 break
 
             if stream:
+                if stream_options is not None and "include_usage" in stream_options:
+                    include_usage = True if stream_options["include_usage"] else False
+                else:
+                    include_usage = False
+
                 remaining_tokens = completion_tokens[returned_tokens:]
                 remaining_text = self.detokenize(
                     remaining_tokens,
@@ -1442,24 +1498,23 @@ def logit_bias_processor(
                             "top_logprobs": [top_logprob],
                         }
                     returned_tokens += 1
-                    yield {
-                        "id": completion_id,
-                        "object": "text_completion",
-                        "created": created,
-                        "model": model_name,
-                        "choices": [
-                            {
-                                "text": self.detokenize(
-                                    [token],
-                                    prev_tokens=prompt_tokens
-                                    + completion_tokens[:returned_tokens],
-                                ).decode("utf-8", errors="ignore"),
-                                "index": 0,
-                                "logprobs": logprobs_or_none,
-                                "finish_reason": None,
-                            }
-                        ],
-                    }
+                    text = (
+                        self.detokenize(
+                            [token],
+                            prev_tokens=prompt_tokens
+                            + completion_tokens[:returned_tokens],
+                        ).decode("utf-8", errors="ignore"),
+                    )
+                    yield self._create_chunk(
+                        completion_id=completion_id,
+                        created=created,
+                        model_name=model_name,
+                        text=text,
+                        finish_reason=None,
+                        index=0,
+                        logprobs_or_none=logprobs_or_none,
+                        include_usage=include_usage,
+                    )
                 else:
                     while len(remaining_tokens) > 0:
                         decode_success = False
@@ -1488,20 +1543,16 @@ def logit_bias_processor(
                         remaining_tokens = remaining_tokens[i:]
                         returned_tokens += i
 
-                        yield {
-                            "id": completion_id,
-                            "object": "text_completion",
-                            "created": created,
-                            "model": model_name,
-                            "choices": [
-                                {
-                                    "text": ts,
-                                    "index": 0,
-                                    "logprobs": None,
-                                    "finish_reason": None,
-                                }
-                            ],
-                        }
+                        yield self._create_chunk(
+                            index=0,
+                            finish_reason=None,
+                            completion_id=completion_id,
+                            created=created,
+                            model_name=model_name,
+                            text=ts,
+                            logprobs_or_none=None,
+                            include_usage=include_usage,
+                        )
 
             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
@@ -1580,54 +1631,60 @@ def logit_bias_processor(
                     if token_end_position == end - 1:
                         break
                     returned_tokens += 1
-                    yield {
-                        "id": completion_id,
-                        "object": "text_completion",
-                        "created": created,
-                        "model": model_name,
-                        "choices": [
-                            {
-                                "text": last_text[
-                                    : len(last_text) - (token_end_position - end)
-                                ].decode("utf-8", errors="ignore"),
-                                "index": 0,
-                                "logprobs": logprobs_or_none,
-                                "finish_reason": None,
-                            }
-                        ],
-                    }
+                    text = last_text[
+                        : len(last_text) - (token_end_position - end)
+                    ].decode("utf-8", errors="ignore")
+
+                    yield self._create_chunk(
+                        completion_id=completion_id,
+                        created=created,
+                        model_name=model_name,
+                        text=text,
+                        logprobs_or_none=logprobs_or_none,
+                        include_usage=include_usage,
+                        index=0,
+                        finish_reason=None,
+                    )
                     break
                 returned_tokens += 1
-                yield {
-                    "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
-                    "model": model_name,
-                    "choices": [
-                        {
-                            "text": self.detokenize([token]).decode(
-                                "utf-8", errors="ignore"
-                            ),
-                            "index": 0,
-                            "logprobs": logprobs_or_none,
-                            "finish_reason": None,
-                        }
-                    ],
-                }
-            yield {
-                "id": completion_id,
-                "object": "text_completion",
-                "created": created,
-                "model": model_name,
-                "choices": [
-                    {
-                        "text": "",
-                        "index": 0,
-                        "logprobs": None,
-                        "finish_reason": finish_reason,
-                    }
-                ],
-            }
+                text = self.detokenize([token]).decode("utf-8", errors="ignore")
+                yield self._create_chunk(
+                    completion_id=completion_id,
+                    created=created,
+                    model_name=model_name,
+                    text=text,
+                    logprobs_or_none=logprobs_or_none,
+                    include_usage=include_usage,
+                    index=0,
+                    finish_reason=None,
+                )
+            yield self._create_chunk(
+                completion_id= completion_id,
+                created= created,
+                model_name=model_name,
+                text="",
+                index=0,
+                logprobs_or_none= None,
+                include_usage=include_usage,
+                usage=None,
+                finish_reason=finish_reason)
+
+            if include_usage:
+                yield self._create_chunk(
+                    completion_id=completion_id,
+                    created=created,
+                    model_name=model_name,
+                    text="",
+                    logprobs_or_none=None,
+                    include_usage=include_usage,
+                    index=0,
+                    finish_reason=None,
+                    usage={
+                        "prompt_tokens": len(prompt_tokens),
+                        "completion_tokens": returned_tokens,
+                        "total_tokens": len(prompt_tokens) + returned_tokens,
+                    },
+                )
             if self.cache:
                 if self.verbose:
                     print("Llama._create_completion: cache save", file=sys.stderr)
@@ -1736,6 +1793,7 @@ def logit_bias_processor(
             },
         }
 
+
     def create_completion(
         self,
         prompt: Union[str, List[int]],
@@ -1753,6 +1811,7 @@ def create_completion(
         repeat_penalty: float = 1.0,
         top_k: int = 40,
         stream: bool = False,
+        stream_options: Optional[StreamOptions] = None,
         seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
@@ -1816,6 +1875,7 @@ def create_completion(
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
+            stream_options=stream_options,
             seed=seed,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
@@ -1850,6 +1910,7 @@ def __call__(
         repeat_penalty: float = 1.0,
         top_k: int = 40,
         stream: bool = False,
+        stream_options: Optional[StreamOptions] = None,
         seed: Optional[int] = None,
         tfs_z: float = 1.0,
         mirostat_mode: int = 0,
@@ -1913,6 +1974,7 @@ def __call__(
             repeat_penalty=repeat_penalty,
             top_k=top_k,
             stream=stream,
+            stream_options=stream_options,
             seed=seed,
             tfs_z=tfs_z,
             mirostat_mode=mirostat_mode,
@@ -1938,6 +2000,7 @@ def create_chat_completion(
         min_p: float = 0.05,
         typical_p: float = 1.0,
         stream: bool = False,
+        stream_options: Optional[StreamOptions] = False,
         stop: Optional[Union[str, List[str]]] = [],
         seed: Optional[int] = None,
         response_format: Optional[ChatCompletionRequestResponseFormat] = None,
@@ -2011,6 +2074,7 @@ def create_chat_completion(
             logprobs=logprobs,
             top_logprobs=top_logprobs,
             stream=stream,
+            stream_options=stream_options,
             stop=stop,
             seed=seed,
             response_format=response_format,
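
Against the Python API itself, a minimal sketch of exercising the new parameter; the model path is a placeholder, and the chunk layout follows the _create_chunk helper added above:

from llama_cpp import Llama

llm = Llama(model_path="./models/model.gguf")  # placeholder: any local GGUF model

last_chunk = None
for chunk in llm.create_completion(
    "Q: Name the planets in the solar system. A:",
    max_tokens=48,
    stream=True,
    stream_options={"include_usage": True},
):
    print(chunk["choices"][0]["text"], end="", flush=True)
    last_chunk = chunk

# When include_usage is requested, every chunk carries a "usage" key; it is
# None until the final chunk, which reports prompt, completion and total tokens.
print()
print(last_chunk["usage"])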
