Skip to content

Commit 56c61da

Browse files
author
Judd
committed
add model id & demo for watt-tool
1 parent 7fd47f1 commit 56c61da

File tree

6 files changed

+163
-4
lines changed

6 files changed

+163
-4
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ pure C++ implementation based on [@ggerganov](https://github.com/ggerganov)'s [g
1313

1414
**What's New:**
1515

16-
* 2025-02-19: MoE CPU offloading
16+
* 2025-02-19: MoE CPU offloading, tool calling with Watt-tool
1717
* 2025-02-17: [ggml updated](https://github.com/ggml-org/llama.cpp/tree/0f2bbe656473177538956d22b6842bcaa0449fab) again
1818
* 2025-02-10: [GPU acceleration](./docs/gpu.md) 🔥
1919
* 2025-01-25: MiniCPM Embedding & ReRanker

docs/models.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
- v3: [Instruct-1B](https://huggingface.co/tiiuae/Falcon3-1B-Instruct), [Instruct-3B](https://huggingface.co/tiiuae/Falcon3-3B-Instruct), [Instruct-7B](https://huggingface.co/tiiuae/Falcon3-7B-Instruct), [Instruct-10B](https://huggingface.co/tiiuae/Falcon3-10B-Instruct)
3636
* [x] DeepSeek-R1-Distill-LlaMA: [8B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-8B), [70B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-70B) (`-a DeepSeek-R1-Distill-LlaMA`)
3737
* [x] DeepHermes-3: [Llama-3-8B-Preview](https://huggingface.co/NousResearch/DeepHermes-3-Llama-3-8B-Preview) (Remember to use `-s ...`)
38+
* [x] Watt-tool: [8B](https://huggingface.co/watt-ai/watt-tool-8B), [70B](https://huggingface.co/watt-ai/watt-tool-70B)
3839

3940
For other models that using `LlamaForCausalLM` architecture, for example, [aiXcoder-7B](https://huggingface.co/aiXcoder/aixcoder-7b-base), try `-a Yi`.
4041

@@ -94,7 +95,7 @@
9495

9596
* [x] Codestral: [22B-v0.1](https://huggingface.co/mistralai/Codestral-22B-v0.1)
9697
* [x] Mistral-Nemo: [Nemo-Instruct-2407](https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407)
97-
98+
* [x] Small: [Instruct-24B](https://huggingface.co/mistralai/Mistral-Small-24B-Instruct-2501)
9899

99100
* Phi (`PhiForCausalLM`, `Phi3ForCausalLM`)
100101
* [x] [Phi-2](https://huggingface.co/microsoft/phi-2/tree/eb8bbd1d37d258ea74fb082c53346d33056a83d4)

docs/tool_calling.md

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Demos of tool calling for these models are provided:
1111
* [NuminaMath](../scripts/tool_numinamath.py)
1212
* [LlaMA3-Groq-Tool-Use](../scripts/tool_groq.py)
1313
* [LlaMA 3.1](../scripts/tool_llama3.1.py)
14+
* [Watt-Tool](../scripts/tool_watt.py)
1415

1516
## Precondition
1617

scripts/model_downloader.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ def get_model_url_on_modelscope(proj: str, fn: str, user: str = 'judd2024') -> s
88
with open(os.path.join(binding.PATH_SCRIPTS, 'models.json'), encoding='utf-8') as f:
99
all_models = json.load(f)
1010

11+
DEF_STORAGE_DIR = '../quantized'
12+
1113
def print_progress_bar (iteration, total, prefix = '', suffix = '', decimals = 1, length = 60, fill = '█', printEnd = "\r", auto_nl = True):
1214
percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total)))
1315
filledLength = int(length * iteration // total)
@@ -126,13 +128,13 @@ def enum_missing():
126128
for q in variants[s]['quantized'].keys():
127129
quantized = variants[s]['quantized'][q]
128130
all.add(quantized['url'].split('/')[1])
129-
f = glob.glob(os.path.join('../quantized', '*.bin'))
131+
f = glob.glob(os.path.join(DEF_STORAGE_DIR, '*.bin'))
130132
l = []
131133
for x in f:
132134
k = os.path.basename(x)
133135
if not k in all:
134136
l.append(k)
135-
print(sorted(l))
137+
print(f'not uploaded models: {sorted(l)}')
136138

137139
def check_default():
138140
for m in all_models.keys():
@@ -141,4 +143,15 @@ def check_default():
141143
print(f"{m} default missing")
142144

143145
if __name__ == '__main__':
    import sys

    # CLI dispatch:
    #   'check'      -> report quantized models not yet uploaded and missing defaults
    #   ':something' -> resolve model-spec args against the default storage dir
    #   (no args)    -> show the model catalogue
    args = sys.argv[1:]
    # NOTE(review): every args[0] access is kept inside the emptiness guard —
    # the diff showed the ':' check dedented outside `if len(args) > 0`, which
    # would raise IndexError when the script is run with no arguments.
    if len(args) > 0:
        if args[0] == 'check':
            enum_missing()
            check_default()
            exit(0)

        if args[0].startswith(':'):
            print(preprocess_args(args, DEF_STORAGE_DIR))
            exit(0)
    show()

scripts/models.json

+16
Original file line numberDiff line numberDiff line change
@@ -2006,5 +2006,21 @@
20062006
}
20072007
}
20082008
}
2009+
},
2010+
"watt-tool": {
2011+
"brief": "watt-tool-8B is a fine-tuned language model based on LLaMa-3.1-8B-Instruct, optimized for tool usage and multi-turn dialogue.",
2012+
"default": "8b",
2013+
"license": "Apache License 2.0",
2014+
"variants": {
2015+
"8b": {
2016+
"default": "q8",
2017+
"quantized": {
2018+
"q8": {
2019+
"size": 8538752192,
2020+
"url": "chatllm_quantized_watt-tool/watt-tool-8b.bin"
2021+
}
2022+
}
2023+
}
2024+
}
20092025
}
20102026
}

scripts/tool_watt.py

+128
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import json
2+
import re, sys
3+
from datetime import datetime
4+
5+
from binding import PATH_BINDS
6+
7+
import tool_definition
8+
from tool_definition import dispatch_tool
9+
10+
# https://huggingface.co/watt-ai/watt-tool-8B
11+
12+
FUNCTION_CALL_START = "["
13+
FUNCTION_CALL_CLOSE = "]"
14+
15+
def convert_tool_def(func: dict) -> dict:
    """Convert a generic tool definition into the watt-tool JSON schema.

    ``func`` is expected to carry 'name', 'description' and a list of
    'parameters', each parameter a dict with 'name', 'type', 'description'
    and an optional boolean 'required' flag.

    Returns a dict in the shape watt-tool expects: arguments grouped under
    a "dict"-typed ``arguments`` object with ``properties`` and a
    ``required`` name list.
    """
    properties = {}
    required = []
    for p in func['parameters']:
        properties[p['name']] = {
            "type": p['type'],
            "description": p['description'],
        }
        # Treat a missing 'required' flag as optional instead of raising
        # KeyError; keep declaration order (the previous list(set(...))
        # produced an arbitrary order in the serialized prompt).
        if p.get('required', False) and p['name'] not in required:
            required.append(p['name'])

    return {
        "name": func['name'],
        "description": func['description'],
        "arguments": {
            "type": "dict",
            "properties": properties,
            "required": required,
        }
    }
35+
36+
SYS_PROMPT_TEMPLATE = """You are an expert in composing functions. You are given a question and a set of possible functions. Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
37+
If none of the function can be used, point it out. If the given question lacks the parameters required by the function, also point it out.
38+
You should only return the function call in tools call sections.
39+
40+
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
41+
You SHOULD NOT include any other text in the response.
42+
Here is a list of functions in JSON format that you can invoke.\n{functions}\n
43+
"""
44+
45+
def build_system_prompt(functions: list[dict]):
    """Render the watt-tool system prompt for the given tool definitions."""
    converted = []
    for fn in functions:
        converted.append(convert_tool_def(fn))
    return SYS_PROMPT_TEMPLATE.format(functions=converted)
48+
49+
import chatllm
50+
from chatllm import ChatLLM
51+
52+
def call_internal_tool(s: str) -> str:
    """Stub for built-in tool execution: log the request, report failure."""
    print("[Use Builtin Tool]" + s)
    return "Error: not implemented"
55+
56+
def call_functions(s: str) -> str:
    """Run every function call parsed from *s* and join the observations.

    Any failure (unparsable input, dispatch error) is reported as a plain
    string so the model can recover, never raised to the caller.
    """
    try:
        results = []
        for name, arguments in parse_function_calls(s):
            print(f'[Use Tool] {name}')
            observation = dispatch_tool(name, arguments)
            results.append(observation.text)
        return '\n\n'.join(results)
    except Exception as e:
        print(f"error occurs: {e}")
        return "failed to call the function"
67+
68+
69+
def parse_function_calls(s: str) -> list[tuple[str, dict]] | None:
70+
try:
71+
matches = re.findall(r'(\w+)\((.*?)\)', s)
72+
if matches is None: return None
73+
74+
def parse_args(s: str) -> str:
75+
r = []
76+
for pair in s.split(', '):
77+
k, v = pair.split('=')
78+
r.append(f'"{k}": {v}')
79+
return f"{{{','.join(r)}}}"
80+
81+
return [(match[0], json.loads(parse_args(match[1]))) for match in matches]
82+
except:
83+
return None
84+
85+
class ToolChatLLM(ChatLLM):
    """ChatLLM subclass that intercepts streamed output looking like a
    watt-tool function-call block (text starting with '[') and routes it
    to the tool dispatcher instead of printing it."""

    # Accumulator for streamed chunks while we are deciding whether the
    # output is a function-call block. Empty string means "not buffering".
    chunk_acc = ''

    def callback_print(self, s: str) -> None:
        # Streaming hook: called once per generated chunk.
        # NOTE(review): order-sensitive buffering logic — chunks are held
        # back only while they could still be the start of a call block.

        if self.chunk_acc == '':
            # Not buffering yet: start buffering only if this chunk is a
            # prefix of the call-block opener "[" (i.e. '' or '[').
            if FUNCTION_CALL_START.startswith(s):
                self.chunk_acc = s
            else:
                super().callback_print(s)

            return

        self.chunk_acc = (self.chunk_acc + s).strip()

        # Too short to decide whether this is a call block; keep buffering.
        if len(self.chunk_acc) < len(FUNCTION_CALL_START):
            return

        # Definitely not a call block: flush the buffer as normal output.
        # (If it IS a call block, we keep accumulating until callback_end.)
        if not self.chunk_acc.startswith(FUNCTION_CALL_START):
            super().callback_print(self.chunk_acc)
            self.chunk_acc = ''
            return

    def callback_end(self) -> None:
        # Generation finished: decide what to do with any buffered text.
        s = self.chunk_acc
        self.chunk_acc = ''
        super().callback_end()

        s = s.strip()
        if len(s) < 1: return

        # Buffered text that parses as function calls is executed and its
        # result fed back to the model; otherwise it is printed verbatim.
        if parse_function_calls(s) is not None:
            rsp = call_functions(s)
            self.tool_input(rsp)
        else:
            super().callback_print(s)

    def call_tool(self, s: str) -> None:
        # Hook for the model's built-in tool channel; presently a stub
        # (call_internal_tool always reports "not implemented").
        rsp = call_internal_tool(s.strip())
        self.tool_input(rsp)
126+
127+
if __name__ == '__main__':
    # Launch the interactive demo: forward CLI args and inject the watt-tool
    # system prompt (built from the shared tool definitions) via `-s`.
    chatllm.demo_simple(sys.argv[1:] + ['-s', build_system_prompt(tool_definition._TOOL_DESCRIPTIONS)], ToolChatLLM, lib_path=PATH_BINDS)

0 commit comments

Comments
 (0)