-
Notifications
You must be signed in to change notification settings - Fork 2.4k
/
Copy pathcreate_openai_tools_agent.int.test.ts
129 lines (120 loc) · 3.85 KB
/
create_openai_tools_agent.int.test.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import { test, expect } from "@jest/globals";
import { ChatOpenAI } from "@langchain/openai";
import type { ChatPromptTemplate } from "@langchain/core/prompts";
import { RunnableLambda } from "@langchain/core/runnables";
import { LangChainTracer } from "@langchain/core/tracers/tracer_langchain";
import { AsyncLocalStorageProviderSingleton } from "@langchain/core/singletons";
import { tool } from "@langchain/core/tools";
import { z } from "zod";
import { AsyncLocalStorage } from "async_hooks";
import { TavilySearchResults } from "../../util/testing/tools/tavily_search.js";
import { pull } from "../../hub/index.js";
import { AgentExecutor, createOpenAIToolsAgent } from "../index.js";
// Shared tool list for the happy-path tests: a single Tavily web-search tool
// capped at one result to keep agent runs short.
const tools = [new TavilySearchResults({ maxResults: 1 })];
test("createOpenAIToolsAgent works", async () => {
  // Pull the canonical OpenAI tools-agent prompt from the LangChain Hub.
  const prompt = await pull<ChatPromptTemplate>("hwchase17/openai-tools-agent");
  const llm = new ChatOpenAI({
    modelName: "gpt-3.5-turbo-1106",
    temperature: 0,
  });
  const agent = await createOpenAIToolsAgent({ llm, tools, prompt });
  const agentExecutor = new AgentExecutor({ agent, tools });

  const input = "what is LangChain?";
  const result = await agentExecutor.invoke({ input });
  // console.log(result);

  // The executor echoes the original input back on the result object.
  expect(result.input).toBe(input);
  expect(typeof result.output).toBe("string");
  // Length greater than 10 because any less than that would warrant
  // an investigation into why such a short generation was returned.
  expect(result.output.length).toBeGreaterThan(10);
});
test("createOpenAIToolsAgent handles errors", async () => {
  // A single tool that always throws, used to exercise error propagation.
  const errorTools = [
    tool(
      async () => {
        throw new Error("Error getting search results");
      },
      {
        name: "search-results",
        description: "Searches the web",
        schema: z.object({
          query: z.string(),
        }),
      }
    ),
  ];
  const prompt = await pull<ChatPromptTemplate>("hwchase17/openai-tools-agent");
  const llm = new ChatOpenAI({
    modelName: "gpt-3.5-turbo-1106",
    temperature: 0,
  });
  const agent = await createOpenAIToolsAgent({
    llm,
    prompt,
    tools: errorTools,
  });
  const agentExecutor = new AgentExecutor({
    agent,
    tools: errorTools,
    // Rethrow tool runtime errors instead of returning them to the model,
    // so the invoke() promise below rejects.
    handleToolRuntimeErrors: (e) => {
      throw e;
    },
  });
  await expect(
    agentExecutor.invoke({ input: "what is LangChain?" })
  ).rejects.toThrowError("Error getting search results");
});
// Verifies that tracing callbacks propagate through a nested RunnableLambda
// into the agent executor. Skipped: requires a live LangSmith project
// ("langchainjs-tracing-2") plus OpenAI/Tavily credentials.
test.skip("createOpenAIToolsAgent tracing works when it is nested in a lambda", async () => {
// Install the global AsyncLocalStorage instance BEFORE any runnables execute,
// so run context (and therefore tracer callbacks) flows across await points.
AsyncLocalStorageProviderSingleton.initializeGlobalInstance(
new AsyncLocalStorage()
);
const prompt = await pull<ChatPromptTemplate>("hwchase17/openai-tools-agent");
const llm = new ChatOpenAI({
modelName: "gpt-3.5-turbo-1106",
temperature: 0,
});
const agent = await createOpenAIToolsAgent({
llm,
tools,
prompt,
});
const agentExecutor = new AgentExecutor({
agent,
tools,
});
// Outer lambda wraps the executor; the tracer is only attached to this
// outer invoke, so nested runs must inherit it implicitly.
const outer = RunnableLambda.from(async (input) => {
// Extra no-op nested runnable to ensure sibling runs are traced too.
const noop = RunnableLambda.from(() => "hi").withConfig({
runName: "nested_testing",
});
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-expect-error unused var
const noopRes = await noop.invoke({ nested: "nested" });
// console.log(noopRes);
const res = await agentExecutor.invoke({
input,
});
return res;
});
const input = "what is LangChain?";
const result = await outer.invoke(input, {
tags: ["test"],
callbacks: [new LangChainTracer({ projectName: "langchainjs-tracing-2" })],
});
// console.log(result);
expect(result.input).toBe(input);
expect(typeof result.output).toBe("string");
// Length greater than 10 because any less than that would warrant
// an investigation into why such a short generation was returned.
expect(result.output.length).toBeGreaterThan(10);
});