Skip to content

Commit 3f3208d

Browse files
authored
feat: Support for OpenAI assistants API (promptfoo#283)
1 parent d0f3d6c commit 3f3208d

File tree

6 files changed

+240
-2
lines changed

6 files changed

+240
-2
lines changed

package-lock.json

+41
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

+1
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
"js-yaml": "^4.1.0",
8787
"node-fetch": "^2.6.7",
8888
"nunjucks": "^3.2.4",
89+
"openai": "^4.19.0",
8990
"opener": "^1.5.2",
9091
"replicate": "^0.12.3",
9192
"rouge": "^1.0.3",

src/cache.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,9 @@ export async function fetchWithCache(
5050
options: RequestInit = {},
5151
timeout: number,
5252
format: 'json' | 'text' = 'json',
53+
bust: boolean = false,
5354
): Promise<{ data: any; cached: boolean }> {
54-
if (!enabled) {
55+
if (!enabled || bust) {
5556
const resp = await fetchWithRetries(url, options, timeout);
5657
return {
5758
cached: false,

src/providers.ts

+4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import path from 'path';
22

33
import {
4+
OpenAiAssistantProvider,
45
OpenAiCompletionProvider,
56
OpenAiChatCompletionProvider,
67
OpenAiEmbeddingProvider,
@@ -127,6 +128,8 @@ export async function loadApiProvider(
127128
return new OpenAiChatCompletionProvider(modelType, providerOptions);
128129
} else if (OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS.includes(modelType)) {
129130
return new OpenAiCompletionProvider(modelType, providerOptions);
131+
} else if (modelType === 'assistant') {
132+
return new OpenAiAssistantProvider(modelName, providerOptions);
130133
} else {
131134
throw new Error(
132135
`Unknown OpenAI model type: ${modelType}. Use one of the following providers: openai:chat:<model name>, openai:completion:<model name>`,
@@ -237,6 +240,7 @@ export async function loadApiProvider(
237240
export default {
238241
OpenAiCompletionProvider,
239242
OpenAiChatCompletionProvider,
243+
OpenAiAssistantProvider,
240244
AnthropicCompletionProvider,
241245
ReplicateProvider,
242246
LocalAiCompletionProvider,

src/providers/openai.ts

+182
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import OpenAI from 'openai';
2+
13
import logger from '../logger';
24
import { fetchWithCache } from '../cache';
35
import { REQUEST_TIMEOUT_MS, parseChatPrompt } from './shared';
@@ -32,6 +34,17 @@ interface OpenAiCompletionOptions {
3234
organization?: string;
3335
}
3436

37+
function failApiCall(err: any) {
38+
if (err instanceof OpenAI.APIError) {
39+
return {
40+
error: `API error: ${err.type} ${err.message}`,
41+
};
42+
}
43+
return {
44+
error: `API error: ${String(err)}`,
45+
};
46+
}
47+
3548
class OpenAiGenericProvider implements ApiProvider {
3649
modelName: string;
3750

@@ -342,6 +355,175 @@ export class OpenAiChatCompletionProvider extends OpenAiGenericProvider {
342355
}
343356
}
344357

358+
// One content entry of an assistant message; `text` is only expected to be
// present when `type` is 'text'.
interface AssistantMessagesResponseDataContent {
  type: string;
  text?: {
    value: string;
  };
}

// Minimal shape of an assistant messages list response.
// NOTE(review): neither of these two interfaces is referenced by the code
// visible in this diff — confirm they are used elsewhere before removing.
interface AssistantMessagesResponseData {
  data: {
    role: string;
    content?: AssistantMessagesResponseDataContent[];
  }[];
}

// Per-provider configuration for the OpenAI Assistants API. Every field
// overrides the corresponding assistant default when set.
interface OpenAiAssistantOptions {
  modelName?: string;
  instructions?: string;
  tools?: OpenAI.Beta.Threads.ThreadCreateAndRunParams['tools'];
  // NOTE(review): the OpenAI API treats run metadata as a string key/value
  // map; `object[]` looks too loose — verify against the SDK's own types.
  metadata?: object[];
}
378+
379+
function toTitleCase(str: string) {
380+
return str.replace(/\w\S*/g, (txt) => txt.charAt(0).toUpperCase() + txt.substr(1).toLowerCase());
381+
}
382+
383+
export class OpenAiAssistantProvider extends OpenAiGenericProvider {
384+
assistantId: string;
385+
assistantConfig: OpenAiAssistantOptions;
386+
387+
constructor(
388+
assistantId: string,
389+
options: { config?: OpenAiAssistantOptions; id?: string; env?: EnvOverrides } = {},
390+
) {
391+
super(assistantId, {});
392+
this.assistantConfig = options.config || {};
393+
this.assistantId = assistantId;
394+
}
395+
396+
async callApi(prompt: string): Promise<ProviderResponse> {
397+
if (!this.getApiKey()) {
398+
throw new Error(
399+
'OpenAI API key is not set. Set the OPENAI_API_KEY environment variable or add `apiKey` to the provider config.',
400+
);
401+
}
402+
403+
const openai = new OpenAI({
404+
maxRetries: 3,
405+
timeout: REQUEST_TIMEOUT_MS,
406+
});
407+
408+
const messages = parseChatPrompt(prompt, [
409+
{ role: 'user', content: prompt },
410+
]) as OpenAI.Beta.Threads.ThreadCreateParams.Message[];
411+
const body: OpenAI.Beta.Threads.ThreadCreateAndRunParams = {
412+
assistant_id: this.assistantId,
413+
model: this.assistantConfig.modelName || undefined,
414+
instructions: this.assistantConfig.instructions || undefined,
415+
tools: this.assistantConfig.tools || undefined,
416+
metadata: this.assistantConfig.metadata || undefined,
417+
thread: {
418+
messages,
419+
},
420+
};
421+
422+
logger.debug(`Calling OpenAI API, creating thread run: ${JSON.stringify(body)}`);
423+
let run;
424+
try {
425+
run = await openai.beta.threads.createAndRun(body);
426+
} catch (err) {
427+
return failApiCall(err);
428+
}
429+
430+
logger.debug(`\tOpenAI thread run API response: ${JSON.stringify(run)}`);
431+
432+
while (run.status === 'in_progress' || run.status === 'queued') {
433+
await new Promise((resolve) => setTimeout(resolve, 1000));
434+
435+
logger.debug(`Calling OpenAI API, getting thread run ${run.id} status`);
436+
try {
437+
run = await openai.beta.threads.runs.retrieve(run.thread_id, run.id);
438+
} catch (err) {
439+
return failApiCall(err);
440+
}
441+
logger.debug(`\tOpenAI thread run API response: ${JSON.stringify(run)}`);
442+
}
443+
444+
if (run.status !== 'completed') {
445+
if (run.last_error) {
446+
return {
447+
error: `Thread run failed: ${run.last_error.message}`,
448+
};
449+
}
450+
return {
451+
error: `Thread run failed: ${run.status}`,
452+
};
453+
}
454+
455+
// Get run steps
456+
logger.debug(`Calling OpenAI API, getting thread run steps for ${run.thread_id}`);
457+
let steps;
458+
try {
459+
steps = await openai.beta.threads.runs.steps.list(run.thread_id, run.id, {
460+
order: 'asc',
461+
});
462+
} catch (err) {
463+
return failApiCall(err);
464+
}
465+
logger.debug(`\tOpenAI thread run steps API response: ${JSON.stringify(steps)}`);
466+
467+
const outputBlocks = [];
468+
for (const step of steps.data) {
469+
if (step.step_details.type === 'message_creation') {
470+
logger.debug(`Calling OpenAI API, getting message ${step.id}`);
471+
let message;
472+
try {
473+
message = await openai.beta.threads.messages.retrieve(
474+
run.thread_id,
475+
step.step_details.message_creation.message_id,
476+
);
477+
} catch (err) {
478+
return failApiCall(err);
479+
}
480+
logger.debug(`\tOpenAI thread run step message API response: ${JSON.stringify(message)}`);
481+
482+
const content = message.content
483+
.map((content) =>
484+
content.type === 'text' ? content.text.value : `<${content.type} output>`,
485+
)
486+
.join('\n');
487+
outputBlocks.push(`[${toTitleCase(message.role)}] ${content}`);
488+
} else if (step.step_details.type === 'tool_calls') {
489+
for (const toolCall of step.step_details.tool_calls) {
490+
if (toolCall.type === 'function') {
491+
outputBlocks.push(
492+
`[Call function ${toolCall.function.name} with arguments ${toolCall.function.arguments}]`,
493+
);
494+
outputBlocks.push(`[Function output: ${toolCall.function.output}]`);
495+
} else if (toolCall.type === 'retrieval') {
496+
outputBlocks.push(`[Ran retrieval]`);
497+
} else if (toolCall.type === 'code_interpreter') {
498+
const output = toolCall.code_interpreter.outputs
499+
.map((output) => (output.type === 'logs' ? output.logs : `<${output.type} output>`))
500+
.join('\n');
501+
outputBlocks.push(`[Code interpreter input]`);
502+
outputBlocks.push(toolCall.code_interpreter.input);
503+
outputBlocks.push(`[Code interpreter output]`);
504+
outputBlocks.push(output);
505+
} else {
506+
outputBlocks.push(`[Unknown tool call type: ${(toolCall as any).type}]`);
507+
}
508+
}
509+
} else {
510+
outputBlocks.push(`[Unknown step type: ${(step.step_details as any).type}]`);
511+
}
512+
}
513+
514+
return {
515+
output: outputBlocks.join('\n\n').trim(),
516+
/*
517+
tokenUsage: {
518+
total: data.usage.total_tokens,
519+
prompt: data.usage.prompt_tokens,
520+
completion: data.usage.completion_tokens,
521+
},
522+
*/
523+
};
524+
}
525+
}
526+
345527
export const DefaultEmbeddingProvider = new OpenAiEmbeddingProvider('text-embedding-ada-002');
346528
export const DefaultGradingProvider = new OpenAiChatCompletionProvider('gpt-4-0613');
347529
export const DefaultSuggestionsProvider = new OpenAiChatCompletionProvider('gpt-4-0613');

test/providers.test.ts

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import fetch from 'node-fetch';
22

3-
import { OpenAiCompletionProvider, OpenAiChatCompletionProvider } from '../src/providers/openai';
3+
import {
4+
OpenAiAssistantProvider,
5+
OpenAiCompletionProvider,
6+
OpenAiChatCompletionProvider,
7+
} from '../src/providers/openai';
48
import { AnthropicCompletionProvider } from '../src/providers/anthropic';
59
import { LlamaProvider } from '../src/providers/llama';
610

@@ -374,6 +378,11 @@ describe('providers', () => {
374378
expect(provider).toBeInstanceOf(OpenAiCompletionProvider);
375379
});
376380

381+
test('loadApiProvider with openai:assistant', async () => {
382+
const provider = await loadApiProvider('openai:assistant:foobar');
383+
expect(provider).toBeInstanceOf(OpenAiAssistantProvider);
384+
});
385+
377386
test('loadApiProvider with openai:chat:modelName', async () => {
378387
const provider = await loadApiProvider('openai:chat:gpt-3.5-turbo');
379388
expect(provider).toBeInstanceOf(OpenAiChatCompletionProvider);

0 commit comments

Comments
 (0)