|
| 1 | +import OpenAI from 'openai'; |
| 2 | + |
1 | 3 | import logger from '../logger';
|
2 | 4 | import { fetchWithCache } from '../cache';
|
3 | 5 | import { REQUEST_TIMEOUT_MS, parseChatPrompt } from './shared';
|
@@ -32,6 +34,17 @@ interface OpenAiCompletionOptions {
|
32 | 34 | organization?: string;
|
33 | 35 | }
|
34 | 36 |
|
| 37 | +function failApiCall(err: any) { |
| 38 | + if (err instanceof OpenAI.APIError) { |
| 39 | + return { |
| 40 | + error: `API error: ${err.type} ${err.message}`, |
| 41 | + }; |
| 42 | + } |
| 43 | + return { |
| 44 | + error: `API error: ${String(err)}`, |
| 45 | + }; |
| 46 | +} |
| 47 | + |
35 | 48 | class OpenAiGenericProvider implements ApiProvider {
|
36 | 49 | modelName: string;
|
37 | 50 |
|
@@ -342,6 +355,175 @@ export class OpenAiChatCompletionProvider extends OpenAiGenericProvider {
|
342 | 355 | }
|
343 | 356 | }
|
344 | 357 |
|
// One content element of an assistant thread message. Only `text` content
// carries a value; other content types leave `text` undefined.
// NOTE(review): these two interfaces are not referenced by the code visible
// in this diff (the SDK's own types are used below) — confirm they are used
// elsewhere or remove them.
interface AssistantMessagesResponseDataContent {
  type: string;
  text?: {
    value: string;
  };
}

// Shape of a messages-list response: an array of role-tagged messages, each
// with optional content elements.
interface AssistantMessagesResponseData {
  data: {
    role: string;
    content?: AssistantMessagesResponseDataContent[];
  }[];
}
| 371 | + |
// Per-run overrides applied when invoking a pre-created OpenAI assistant.
// All fields are optional; `undefined` fields fall back to the assistant's
// stored configuration (see how they are spread into the createAndRun body).
interface OpenAiAssistantOptions {
  modelName?: string; // overrides the model the assistant was created with
  instructions?: string; // overrides the assistant's stored instructions
  tools?: OpenAI.Beta.Threads.ThreadCreateAndRunParams['tools'];
  // NOTE(review): OpenAI run metadata is a string->string record, not an
  // array of objects — confirm this type against the SDK's expected shape.
  metadata?: object[];
}
| 378 | + |
| 379 | +function toTitleCase(str: string) { |
| 380 | + return str.replace(/\w\S*/g, (txt) => txt.charAt(0).toUpperCase() + txt.substr(1).toLowerCase()); |
| 381 | +} |
| 382 | + |
| 383 | +export class OpenAiAssistantProvider extends OpenAiGenericProvider { |
| 384 | + assistantId: string; |
| 385 | + assistantConfig: OpenAiAssistantOptions; |
| 386 | + |
| 387 | + constructor( |
| 388 | + assistantId: string, |
| 389 | + options: { config?: OpenAiAssistantOptions; id?: string; env?: EnvOverrides } = {}, |
| 390 | + ) { |
| 391 | + super(assistantId, {}); |
| 392 | + this.assistantConfig = options.config || {}; |
| 393 | + this.assistantId = assistantId; |
| 394 | + } |
| 395 | + |
| 396 | + async callApi(prompt: string): Promise<ProviderResponse> { |
| 397 | + if (!this.getApiKey()) { |
| 398 | + throw new Error( |
| 399 | + 'OpenAI API key is not set. Set the OPENAI_API_KEY environment variable or add `apiKey` to the provider config.', |
| 400 | + ); |
| 401 | + } |
| 402 | + |
| 403 | + const openai = new OpenAI({ |
| 404 | + maxRetries: 3, |
| 405 | + timeout: REQUEST_TIMEOUT_MS, |
| 406 | + }); |
| 407 | + |
| 408 | + const messages = parseChatPrompt(prompt, [ |
| 409 | + { role: 'user', content: prompt }, |
| 410 | + ]) as OpenAI.Beta.Threads.ThreadCreateParams.Message[]; |
| 411 | + const body: OpenAI.Beta.Threads.ThreadCreateAndRunParams = { |
| 412 | + assistant_id: this.assistantId, |
| 413 | + model: this.assistantConfig.modelName || undefined, |
| 414 | + instructions: this.assistantConfig.instructions || undefined, |
| 415 | + tools: this.assistantConfig.tools || undefined, |
| 416 | + metadata: this.assistantConfig.metadata || undefined, |
| 417 | + thread: { |
| 418 | + messages, |
| 419 | + }, |
| 420 | + }; |
| 421 | + |
| 422 | + logger.debug(`Calling OpenAI API, creating thread run: ${JSON.stringify(body)}`); |
| 423 | + let run; |
| 424 | + try { |
| 425 | + run = await openai.beta.threads.createAndRun(body); |
| 426 | + } catch (err) { |
| 427 | + return failApiCall(err); |
| 428 | + } |
| 429 | + |
| 430 | + logger.debug(`\tOpenAI thread run API response: ${JSON.stringify(run)}`); |
| 431 | + |
| 432 | + while (run.status === 'in_progress' || run.status === 'queued') { |
| 433 | + await new Promise((resolve) => setTimeout(resolve, 1000)); |
| 434 | + |
| 435 | + logger.debug(`Calling OpenAI API, getting thread run ${run.id} status`); |
| 436 | + try { |
| 437 | + run = await openai.beta.threads.runs.retrieve(run.thread_id, run.id); |
| 438 | + } catch (err) { |
| 439 | + return failApiCall(err); |
| 440 | + } |
| 441 | + logger.debug(`\tOpenAI thread run API response: ${JSON.stringify(run)}`); |
| 442 | + } |
| 443 | + |
| 444 | + if (run.status !== 'completed') { |
| 445 | + if (run.last_error) { |
| 446 | + return { |
| 447 | + error: `Thread run failed: ${run.last_error.message}`, |
| 448 | + }; |
| 449 | + } |
| 450 | + return { |
| 451 | + error: `Thread run failed: ${run.status}`, |
| 452 | + }; |
| 453 | + } |
| 454 | + |
| 455 | + // Get run steps |
| 456 | + logger.debug(`Calling OpenAI API, getting thread run steps for ${run.thread_id}`); |
| 457 | + let steps; |
| 458 | + try { |
| 459 | + steps = await openai.beta.threads.runs.steps.list(run.thread_id, run.id, { |
| 460 | + order: 'asc', |
| 461 | + }); |
| 462 | + } catch (err) { |
| 463 | + return failApiCall(err); |
| 464 | + } |
| 465 | + logger.debug(`\tOpenAI thread run steps API response: ${JSON.stringify(steps)}`); |
| 466 | + |
| 467 | + const outputBlocks = []; |
| 468 | + for (const step of steps.data) { |
| 469 | + if (step.step_details.type === 'message_creation') { |
| 470 | + logger.debug(`Calling OpenAI API, getting message ${step.id}`); |
| 471 | + let message; |
| 472 | + try { |
| 473 | + message = await openai.beta.threads.messages.retrieve( |
| 474 | + run.thread_id, |
| 475 | + step.step_details.message_creation.message_id, |
| 476 | + ); |
| 477 | + } catch (err) { |
| 478 | + return failApiCall(err); |
| 479 | + } |
| 480 | + logger.debug(`\tOpenAI thread run step message API response: ${JSON.stringify(message)}`); |
| 481 | + |
| 482 | + const content = message.content |
| 483 | + .map((content) => |
| 484 | + content.type === 'text' ? content.text.value : `<${content.type} output>`, |
| 485 | + ) |
| 486 | + .join('\n'); |
| 487 | + outputBlocks.push(`[${toTitleCase(message.role)}] ${content}`); |
| 488 | + } else if (step.step_details.type === 'tool_calls') { |
| 489 | + for (const toolCall of step.step_details.tool_calls) { |
| 490 | + if (toolCall.type === 'function') { |
| 491 | + outputBlocks.push( |
| 492 | + `[Call function ${toolCall.function.name} with arguments ${toolCall.function.arguments}]`, |
| 493 | + ); |
| 494 | + outputBlocks.push(`[Function output: ${toolCall.function.output}]`); |
| 495 | + } else if (toolCall.type === 'retrieval') { |
| 496 | + outputBlocks.push(`[Ran retrieval]`); |
| 497 | + } else if (toolCall.type === 'code_interpreter') { |
| 498 | + const output = toolCall.code_interpreter.outputs |
| 499 | + .map((output) => (output.type === 'logs' ? output.logs : `<${output.type} output>`)) |
| 500 | + .join('\n'); |
| 501 | + outputBlocks.push(`[Code interpreter input]`); |
| 502 | + outputBlocks.push(toolCall.code_interpreter.input); |
| 503 | + outputBlocks.push(`[Code interpreter output]`); |
| 504 | + outputBlocks.push(output); |
| 505 | + } else { |
| 506 | + outputBlocks.push(`[Unknown tool call type: ${(toolCall as any).type}]`); |
| 507 | + } |
| 508 | + } |
| 509 | + } else { |
| 510 | + outputBlocks.push(`[Unknown step type: ${(step.step_details as any).type}]`); |
| 511 | + } |
| 512 | + } |
| 513 | + |
| 514 | + return { |
| 515 | + output: outputBlocks.join('\n\n').trim(), |
| 516 | + /* |
| 517 | + tokenUsage: { |
| 518 | + total: data.usage.total_tokens, |
| 519 | + prompt: data.usage.prompt_tokens, |
| 520 | + completion: data.usage.completion_tokens, |
| 521 | + }, |
| 522 | + */ |
| 523 | + }; |
| 524 | + } |
| 525 | +} |
| 526 | + |
345 | 527 | export const DefaultEmbeddingProvider = new OpenAiEmbeddingProvider('text-embedding-ada-002');
|
346 | 528 | export const DefaultGradingProvider = new OpenAiChatCompletionProvider('gpt-4-0613');
|
347 | 529 | export const DefaultSuggestionsProvider = new OpenAiChatCompletionProvider('gpt-4-0613');
|
0 commit comments