enable tgi and hf endpoint #23

Merged 1 commit on Apr 14, 2025
README.md: 1 addition, 1 deletion
@@ -107,7 +107,7 @@ sbatch scripts/run_short_slurm.sh # 8k-64k
 # for the API models, note that API results may vary due to the randomness in the API calls
 bash scripts/run_api.sh
 ```
-### Run on Intel Gaudi
+### Run on Intel Gaudi Accelerators
 If you want to enable the evaluation on vLLM with Intel Gaudi, you can use the following commands:
 ```bash
 ## Build vllm docker image
arguments.py: 3 additions, 1 deletion
@@ -19,8 +19,10 @@ def parse_arguments():
     parser.add_argument("--model_name_or_path", type=str, default=None)
     parser.add_argument("--use_vllm", action="store_true", help="whether to use vllm engine")
     parser.add_argument("--use_sglang", action="store_true", help="whether to use sglang engine")
-    parser.add_argument("--use_tgi_or_vllm_serving", action="store_true", help="whether to use tgi or vllm serving engine")
+    parser.add_argument("--use_vllm_serving", action="store_true", help="whether to use vllm serving engine")
+    parser.add_argument("--use_tgi_serving", action="store_true", help="whether to use tgi serving engine")
     parser.add_argument("--endpoint_url", type=str, default="http://localhost:8080/v1/", help="endpoint url for tgi or vllm serving engine")
+    parser.add_argument("--api_key", type=str, default="EMPTY", help="api key for model endpoint")

     # data settings
     parser.add_argument("--datasets", type=str, default=None, help="comma separated list of dataset names")
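The old `--use_tgi_or_vllm_serving` switch is split into `--use_tgi_serving` and `--use_vllm_serving`, and `--api_key` is added for authenticated endpoints. A hedged sketch of how the new flags combine on the command line (ports are illustrative; the configs are the ones touched in this PR):

```bash
# vLLM OpenAI-compatible serving backend
python eval.py --config configs/rag_vllm.yaml \
    --use_vllm_serving --endpoint_url http://localhost:8000/v1/

# TGI backend; local servers usually accept the default --api_key EMPTY
python eval.py --config configs/recall_demo.yaml \
    --use_tgi_serving --endpoint_url http://localhost:8080/v1/ --api_key EMPTY
```

The YAML configs below already set these booleans, so in practice only `--endpoint_url` and `--api_key` need to be overridden, which is what the new scripts do.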
configs/rag_vllm.yaml: 2 additions, 2 deletions
@@ -3,10 +3,10 @@ datasets: kilt_nq,kilt_triviaqa,kilt_hotpotqa,kilt_popqa_3
 generation_max_length: 20,20,20,20
 test_files: data/kilt/nq-dev-multikilt_1000_k1000_dep6.jsonl,data/kilt/triviaqa-dev-multikilt_1000_k1000_dep6.jsonl,data/kilt/hotpotqa-dev-multikilt_1000_k1000_dep3.jsonl,data/kilt/popqa_test_1000_k1000_dep6.jsonl
 demo_files: data/kilt/nq-train-multikilt_1000_k3_dep6.jsonl,data/kilt/triviaqa-train-multikilt_1000_k3_dep6.jsonl,data/kilt/hotpotqa-train-multikilt_1000_k3_dep3.jsonl,data/kilt/popqa_test_1000_k3_dep6.jsonl
-use_chat_template: false
+use_chat_template: true
 max_test_samples: 100
 shots: 2
 stop_new_line: true
 model_name_or_path: meta-llama/Llama-3.3-70B-Instruct
 output_dir: output/vllm-gaudi/Llama-3.3-70B-Instruct
-use_tgi_or_vllm_serving: true
+use_vllm_serving: true
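With `use_vllm_serving: true` the model is no longer loaded in-process; the harness expects an OpenAI-compatible vLLM server to already be running. A minimal sketch assuming a plain `vllm serve` launch (the Gaudi docker route in the README is the intended path, so the launch details and port here are illustrative only):

```bash
# Serve the model named in the config over the OpenAI-compatible API
vllm serve meta-llama/Llama-3.3-70B-Instruct --port 8000 &

# Once the server is ready, run the RAG tasks against it
python eval.py --config configs/rag_vllm.yaml --endpoint_url http://localhost:8000/v1/
```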
configs/recall_demo.yaml: 13 additions, 0 deletions (new file)
@@ -0,0 +1,13 @@
input_max_length: 8192
datasets: ruler_niah_mk_2
generation_max_length: 50
test_files: data/ruler/niah_multikey_2/validation_8192.jsonl
demo_files: ''
use_chat_template: true
max_test_samples: 5
shots: 2
top_p: 0.95 # need to be >0 and <1
stop_new_line: false
model_name_or_path: tgi:meta-llama/Llama-3.2-1B-Instruct
output_dir: output/tgi/meta-llama/Llama-3.2-1B-Instruct
use_tgi_serving: true
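The `tgi:` prefix on `model_name_or_path` tells the client wrapper in `model_utils.py` (further down in this diff) to strip the prefix when loading the tokenizer and to send requests with model id `tgi`. A hedged sketch of a local TGI server this config could talk to; the image tag and token limits are assumptions, not part of the PR:

```bash
model=meta-llama/Llama-3.2-1B-Instruct
docker run --gpus all --shm-size 1g -p 8080:80 \
    -e HF_TOKEN=$HF_TOKEN \
    ghcr.io/huggingface/text-generation-inference:latest \
    --model-id $model --max-input-tokens 8192 --max-total-tokens 8242
```

With the server on port 8080 (the default `--endpoint_url`), `python eval.py --config configs/recall_demo.yaml` should be all that is needed.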
configs/recall_vllm.yaml: 2 additions, 2 deletions
@@ -3,10 +3,10 @@ datasets: ruler_niah_mk_2,ruler_niah_mk_3,ruler_niah_mv,json_kv
 generation_max_length: 50,100,50,100
 test_files: data/ruler/niah_multikey_2/validation_131072.jsonl,data/ruler/niah_multikey_3/validation_131072.jsonl,data/ruler/niah_multivalue/validation_131072.jsonl,data/json_kv/test_k1800_dep6.jsonl
 demo_files: ',,,'
-use_chat_template: false
+use_chat_template: true
 max_test_samples: 100
 shots: 2
 stop_new_line: false
 model_name_or_path: meta-llama/Llama-3.3-70B-Instruct
 output_dir: output/vllm-gaudi/Llama-3.3-70B-Instruct
-use_tgi_or_vllm_serving: true
+use_vllm_serving: true
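These recall tasks feed 131072-token contexts plus up to 100 generated tokens through the endpoint, so the served model's context window has to be raised to match or requests will be rejected. A hedged example of the relevant vLLM flags; the values are assumptions for a multi-card setup, not part of this PR:

```bash
vllm serve meta-llama/Llama-3.3-70B-Instruct --port 8000 \
    --max-model-len 131200 --tensor-parallel-size 8
```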
model_utils.py: 13 additions, 4 deletions
@@ -363,10 +363,18 @@ def __init__(

         self.model = OpenAI(
             base_url=endpoint_url,
-            api_key="EMPTY_KEY"
+            api_key=kwargs["api_key"],
         )
-        self.model_name = model_name
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        if "tgi" in model_name:
+            # remove the tgi: prefix
+            model_name = model_name[model_name.index(":")+1:]
+            print(f"** Model: {model_name}")
+            self.model_name = "tgi"
+            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        else:
+            print(f"** Model: {model_name}")
+            self.model_name = model_name
+            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.seed = seed
         self.API_MAX_LENGTH = float('inf')

@@ -1261,10 +1269,11 @@ def load_LLM(args):
     elif args.use_vllm:
         model_cls = VLLMModel
         kwargs['seed'] = args.seed
-    elif args.use_tgi_or_vllm_serving:
+    elif args.use_tgi_serving or args.use_vllm_serving:
         model_cls = TgiVllmModel
         kwargs['seed'] = args.seed
         kwargs["endpoint_url"] = args.endpoint_url
+        kwargs["api_key"] = args.api_key
     elif args.use_sglang:
         model_cls = SGLangModel
         kwargs['seed'] = args.seed
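The new branch above sends requests with model id `tgi` when the `tgi:` prefix is present and keeps the full model name for vLLM serving; both paths now go through the `openai` client pointed at `endpoint_url` with the supplied `api_key`. A quick hedged check that an endpoint speaks the chat-completions dialect the client expects (URL and port are placeholders):

```bash
curl http://localhost:8080/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "tgi", "messages": [{"role": "user", "content": "ping"}], "max_tokens": 8}'
```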
requirements.txt: 1 addition, 0 deletions
@@ -8,3 +8,4 @@ accelerate
 sentencepiece
 pytrec_eval
 rouge_score
+openai
scripts/run_eval_hf_endpoint.sh: 5 additions, 0 deletions (new file)
@@ -0,0 +1,5 @@

LLM_ENDPOINT="https://${hf_inference_point_url}/v1" # fill in your endpoint url
API_KEY=$HF_TOKEN

python eval.py --config configs/recall_demo.yaml --endpoint_url $LLM_ENDPOINT --api_key $API_KEY
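The script assumes `hf_inference_point_url` (the endpoint hostname without the scheme) and `HF_TOKEN` are already exported; Hugging Face Inference Endpoints typically serve LLMs through TGI, which is why it reuses the TGI demo config. A hypothetical invocation with placeholder values:

```bash
export hf_inference_point_url="<your-endpoint>.endpoints.huggingface.cloud"
export HF_TOKEN=hf_xxxxxxxxxxxxxxxx   # token with access to the endpoint
bash scripts/run_eval_hf_endpoint.sh
```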
scripts/run_eval_tgi.sh: 5 additions, 0 deletions (new file)
@@ -0,0 +1,5 @@
export host_ip=$(hostname -I | awk '{print $1}')
export LLM_ENDPOINT_PORT=8085 # change this to the port you want to use
export LLM_ENDPOINT="http://${host_ip}:${LLM_ENDPOINT_PORT}/v1"

python eval.py --config configs/recall_demo.yaml --endpoint_url $LLM_ENDPOINT
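This script only derives the endpoint URL; it assumes a TGI container is already listening on `LLM_ENDPOINT_PORT` (see the docker sketch under `configs/recall_demo.yaml`, adjusting `-p 8085:80`). A hedged pre-flight check before running the eval:

```bash
export host_ip=$(hostname -I | awk '{print $1}')
export LLM_ENDPOINT_PORT=8085

# TGI exposes /health (200 when ready) and /info (model metadata)
curl -fsS http://${host_ip}:${LLM_ENDPOINT_PORT}/health && echo "TGI is ready"

bash scripts/run_eval_tgi.sh
```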