configs.py
import json
from dataclasses import dataclass, asdict
from enum import Enum
from typing import Any, Dict, Optional


@dataclass
class Config:
    """Base class for configurations; supports dict-style attribute access."""

    def __getitem__(self, key: str) -> Any:
        return getattr(self, key)


class Usecase(Enum):
    """
    Specification for the use case
    - DUMMY: the dummy use case, for basic testing
    - PREFIX_REUSE: the prefix reuse use case
    - RAG: the RAG use case
    - MULTI: the multi-turn conversation use case
    - VARY: the variable length use case
    """
    DUMMY = 1
    PREFIX_REUSE = 2
    RAG = 3
    MULTI = 4
    VARY = 5
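
# Illustrative note (not in the original file): Usecase is a plain Enum, so a
# use case can be looked up by name, e.g. when parsing a CLI argument:
#   Usecase["RAG"]       # -> Usecase.RAG
#   Usecase.RAG.value    # -> 3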


@dataclass
class WorkloadConfig(Config):
    """
    - QPS: queries per second
    - Duration: number of seconds the workload runs
    - Context length: approximate number of tokens in the request context
    - Query length: number of tokens in the suffix question
    - Offset: offset of the request timestamps in seconds
    """
    # Number of queries per second
    qps: int

    # Total duration of the workload in seconds
    duration: float

    # Number of tokens in the context (approximate number)
    context_length: int

    # Number of tokens in the suffix question
    query_length: int

    # Offset of the timestamps
    offset: float

    def desc(self) -> str:
        return json.dumps(asdict(self))
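
# Illustrative sketch (values are made up, not from the original file):
#   cfg = WorkloadConfig(qps=10, duration=60.0, context_length=4096,
#                        query_length=32, offset=0.0)
#   cfg.desc()   # -> '{"qps": 10, "duration": 60.0, "context_length": 4096,
#                #      "query_length": 32, "offset": 0.0}'
#   cfg["qps"]   # -> 10, via Config.__getitem__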


@dataclass
class LMCacheConfig(Config):
    # Path to the lmcache configuration
    config_path: str

    # Remote storage device for LMCache, if any
    remote_device: Optional[str] = None

    def cmdargs(self) -> str:
        # The --lmcache-config-file flag is currently disabled; emit a blank
        # placeholder instead of the real argument.
        return " " if self.config_path is not None else ""
        # return f"--lmcache-config-file {self.config_path}" if self.config_path is not None else ""
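
# Illustrative sketch (not in the original file): with the commented-out flag
# re-enabled, cmdargs() would produce something like
#   LMCacheConfig(config_path="configs/lmcache_config.yaml").cmdargs()
#   # -> "--lmcache-config-file configs/lmcache_config.yaml"
# As written above, it returns " " whenever config_path is set.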


@dataclass
class VLLMConfig(Config):
    # Which model is used
    model: str

    # vLLM engine's port
    port: int

    # Memory limit for the vLLM engine
    gpu_memory_utilization: float

    # Tensor parallelism
    tensor_parallel_size: Optional[int]

    def cmdargs(self) -> str:
        args = []
        for key, value in self.__dict__.items():
            if value is None:
                continue
            if key == "model":
                # The model is a positional argument, not a --flag
                args.append(f"{value}")
                continue
            modified_key = key.replace("_", "-")
            args.append(f"--{modified_key} {value}")
        return " ".join(args)


@dataclass
class VLLMOptionalConfig(Config):
    """
    Optional cmdline configuration for the vLLM engine, stored as arbitrary
    key-value pairs.
    """
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __setitem__(self, key: str, value: Any) -> None:
        setattr(self, key, value)

    def __str__(self):
        return "VLLMOptionalConfig({})".format(
            ", ".join([f"{k}={v}" for k, v in self.__dict__.items()]))

    def __repr__(self):
        return self.__str__()

    def cmdargs(self) -> str:
        args = []
        for key, value in self.__dict__.items():
            if value is None:
                continue
            modified_key = key.replace("_", "-")
            args.append(f"--{modified_key} {value}")
        return " ".join(args)


class EngineType(Enum):
    """
    What kind of engine will be bootstrapped
    """
    LOCAL = 1
    DOCKER = 2


@dataclass
class BootstrapConfig:
    # What kind of engine will be bootstrapped
    # TODO: this should be in the test case specification
    engine_type: EngineType

    # Required vLLM configurations
    vllm_config: VLLMConfig

    # Optional vLLM configurations
    vllm_optional_config: VLLMOptionalConfig

    # LMCache configurations
    lmcache_config: LMCacheConfig

    # Extra environment variables
    envs: Dict[str, str]
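
# Illustrative sketch (values are made up, not from the original file):
#   bootstrap = BootstrapConfig(
#       engine_type=EngineType.LOCAL,
#       vllm_config=VLLMConfig(model="gpt2", port=8000,
#                              gpu_memory_utilization=0.5,
#                              tensor_parallel_size=1),
#       vllm_optional_config=VLLMOptionalConfig(),
#       lmcache_config=LMCacheConfig(config_path="configs/lmcache_config.yaml"),
#       envs={"CUDA_VISIBLE_DEVICES": "0"},
#   )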

# TODO: configuration loader


if __name__ == "__main__":
    # Example usage
    workload_config = WorkloadConfig(qps=100, duration=10, context_length=100,
                                     query_length=10, offset=0)
    lmcache_config = LMCacheConfig(config_path="configs/lmcache_config.yaml")
    vllm_config = VLLMConfig(port=8000, model="gpt2",
                             gpu_memory_utilization=0.5,
                             tensor_parallel_size=1)
    vllm_optional_config = VLLMOptionalConfig(**{"key1": 1, "key2": 2})
    vllm_optional_config["key3"] = 3

    print(workload_config)
    print(lmcache_config)
    print(vllm_config)
    print(vllm_optional_config)
    print(vllm_config.cmdargs())
    print(vllm_optional_config.cmdargs())