Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ReAct function default args #2021

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions dspy/predict/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,17 @@
from dspy.signatures.signature import ensure_signature
from dspy.utils.callback import with_callbacks


class Tool:
def __init__(self, func: Callable, name: str = None, desc: str = None, args: dict[str, Any] = None):

def __init__(
self,
func: Callable,
name: str = None,
desc: str = None,
args: dict[str, Any] = None,
defaults: dict[str, Any] = None,
private_defaults: dict[str, Any] = None,
):
annotations_func = func if inspect.isfunction(func) or inspect.ismethod(func) else func.__call__
self.func = func
self.name = name or getattr(func, "__name__", type(func).__name__)
Expand All @@ -23,6 +31,8 @@ def __init__(self, func: Callable, name: str = None, desc: str = None, args: dic
for k, v in (args or get_type_hints(annotations_func)).items()
if k != "return"
}
self.defaults = defaults
self.private_defaults = private_defaults

@with_callbacks
def __call__(self, *args, **kwargs):
Expand Down Expand Up @@ -63,6 +73,10 @@ def __init__(self, signature, tools: list[Callable], max_iters=5):
args = tool.args if hasattr(tool, "args") else str({tool.input_variable: str})
desc = (f", whose description is <desc>{tool.desc}</desc>." if tool.desc else ".").replace("\n", " ")
desc += f" It takes arguments {args} in JSON format."
if tool.defaults:
desc += f" Default arguments are {tool.defaults}."
if tool.private_defaults:
desc += f" Assume the following function arguments will be provided at function execution time: {tool.private_defaults.keys()}. Therefore do not propose these arguments in the `next_tool_args`."
instr.append(f"({idx+1}) {tool.name}{desc}")

react_signature = (
Expand Down Expand Up @@ -91,13 +105,25 @@ def format(trajectory: dict[str, Any], last_iteration: bool):
for idx in range(self.max_iters):
pred = self.react(**input_args, trajectory=format(trajectory, last_iteration=(idx == self.max_iters - 1)))

# extract private defaults from the tool and supply them to the next tool call
# do not assign the private defaults to the next_tool_args as this will be captured in the trajectory logs, which is not what we want
private_defaults = (
self.tools[pred.next_tool_name].private_defaults
if pred.next_tool_name in self.tools
and self.tools[pred.next_tool_name].private_defaults
else {}
)

trajectory[f"thought_{idx}"] = pred.next_thought
trajectory[f"tool_name_{idx}"] = pred.next_tool_name
trajectory[f"tool_args_{idx}"] = pred.next_tool_args

try:
trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**pred.next_tool_args)
trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](
**pred.next_tool_args, **private_defaults
)
except Exception as e:
# risk that the error log will capture the private defaults?
trajectory[f"observation_{idx}"] = f"Failed to execute: {e}"

if pred.next_tool_name == "finish":
Expand Down
Loading