
Commit 1df7ace

Authored Sep 29, 2023
Add support for HuggingFace Inference API (text generation and feature extraction) (promptfoo#205)
1 parent c7c0949 commit 1df7ace

File tree

5 files changed: +209 -4 lines changed
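For reference, the providers added in this commit are addressed by provider path strings of the form huggingface:<task>:<model name>, and requests authenticate with the HF_API_TOKEN environment variable when it is set. A minimal, hypothetical usage sketch (the import path and model name are illustrative, not part of the commit):

// Hypothetical usage sketch (not part of this commit). The import path and
// model name are illustrative; any Hugging Face Hub model id should work.
import { loadApiProvider } from '../src/providers';

async function main() {
  // Optional: authenticate against the Hugging Face Inference API.
  process.env.HF_API_TOKEN = 'hf_xxx';

  // Task and model name are encoded in the provider path.
  const provider = await loadApiProvider('huggingface:text-generation:gpt2');
  const result = await provider.callApi('Write a haiku about evals');
  console.log(result.output ?? result.error);
}

main();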

 

‎.jest/setEnvVars.js

+1

@@ -2,3 +2,4 @@ process.env.OPENAI_API_KEY = 'foo';
 process.env.AZURE_OPENAI_API_KEY = 'foo';
 process.env.AZURE_OPENAI_API_HOST = 'azure.openai.host';
 process.env.ANTHROPIC_API_KEY = 'foo';
+process.env.HF_API_TOKEN = 'foo';

‎src/providers.ts

+18 -2

@@ -23,6 +23,7 @@ import {
   AzureOpenAiCompletionProvider,
   AzureOpenAiEmbeddingProvider,
 } from './providers/azureopenai';
+import { HuggingfaceFeatureExtractionProvider, HuggingfaceTextGenerationProvider } from './providers/huggingface';
 
 import type {
   ApiProvider,
@@ -162,11 +163,26 @@ export async function loadApiProvider(
         `Unknown Anthropic model type: ${modelType}. Use one of the following providers: anthropic:completion:<model name>`,
       );
     }
+  } else if (providerPath?.startsWith('huggingface:')) {
+    const splits = providerPath.split(':');
+    if (splits.length < 3) {
+      throw new Error(
+        `Invalid Huggingface provider path: ${providerPath}. Use one of the following providers: huggingface:feature-extraction:<model name>, huggingface:text-generation:<model name>`,
+      );
+    }
+    const modelName = splits.slice(2).join(':');
+    if (splits[1] === 'feature-extraction') {
+      return new HuggingfaceFeatureExtractionProvider(modelName, providerOptions);
+    } else if (splits[1] === 'text-generation') {
+      return new HuggingfaceTextGenerationProvider(modelName, providerOptions);
+    } else {
+      throw new Error(
+        `Invalid Huggingface provider path: ${providerPath}. Use one of the following providers: huggingface:feature-extraction:<model name>, huggingface:text-generation:<model name>`,
+      );
+    }
   } else if (providerPath?.startsWith('replicate:')) {
-    // Load Replicate module
     const splits = providerPath.split(':');
     const modelName = splits.slice(1).join(':');
-
     return new ReplicateProvider(modelName, providerOptions);
   }
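One detail worth noting in the parsing above: the task type is the second path segment, and the model name is everything after it rejoined with ':', so model names that themselves contain colons are preserved. A small illustration with hypothetical values:

// Illustration of the provider-path parsing above (values are hypothetical).
const providerPath = 'huggingface:text-generation:bigscience/bloom';
const splits = providerPath.split(':');
const task = splits[1];                       // 'text-generation' -> HuggingfaceTextGenerationProvider
const modelName = splits.slice(2).join(':');  // 'bigscience/bloom'; any ':' in the name is kept
// Fewer than three segments, or an unrecognized task, throws the
// "Invalid Huggingface provider path" error shown in the diff.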

‎src/providers/huggingface.ts

+146

@@ -0,0 +1,146 @@
+import fetch from 'node-fetch';
+import {fetchWithCache} from '../cache';
+
+import type { ApiProvider, ProviderEmbeddingResponse, ProviderResponse } from '../types';
+import {REQUEST_TIMEOUT_MS} from './shared';
+
+interface HuggingfaceTextGenerationOptions {
+  top_k?: number;
+  top_p?: number;
+  temperature?: number;
+  repetition_penalty?: number;
+  max_new_tokens?: number;
+  max_time?: number;
+  return_full_text?: boolean;
+  num_return_sequences?: number;
+  do_sample?: boolean;
+  use_cache?: boolean;
+  wait_for_model?: boolean;
+}
+
+export class HuggingfaceTextGenerationProvider implements ApiProvider {
+  modelName: string;
+  config: HuggingfaceTextGenerationOptions;
+
+  constructor(modelName: string, options: { id?: string, config?: HuggingfaceTextGenerationOptions } = {}) {
+    const { id, config } = options;
+    this.modelName = modelName;
+    this.id = id ? () => id : this.id;
+    this.config = config || {};
+  }
+
+  id(): string {
+    return `huggingface:text-generation:${this.modelName}`;
+  }
+
+  toString(): string {
+    return `[Huggingface Text Generation Provider ${this.modelName}]`;
+  }
+
+  async callApi(prompt: string): Promise<ProviderResponse> {
+    const params = {
+      inputs: prompt,
+      parameters: {
+        return_full_text: this.config.return_full_text ?? false,
+        ...this.config
+      },
+    };
+
+    let response;
+    try {
+      response = await fetchWithCache(`https://api-inference.huggingface.co/models/${this.modelName}`, {
+        method: 'POST',
+        headers: {
+          'Content-Type': 'application/json',
+          ...(process.env.HF_API_TOKEN ? { 'Authorization': `Bearer ${process.env.HF_API_TOKEN}` } : {}),
+        },
+        body: JSON.stringify(params),
+      }, REQUEST_TIMEOUT_MS);
+
+      if (response.data.error) {
+        return {
+          error: `API call error: ${response.data.error}`,
+        };
+      }
+      if (!response.data[0]) {
+        return {
+          error: `Malformed response data: ${response.data}`,
+        };
+      }
+
+      return {
+        output: response.data[0]?.generated_text,
+      };
+    } catch(err) {
+      return {
+        error: `API call error: ${String(err)}. Output:\n${response?.data}`,
+      };
+    }
+  }
+}
+
+interface HuggingfaceFeatureExtractionOptions {
+  use_cache?: boolean;
+  wait_for_model?: boolean;
+}
+
+export class HuggingfaceFeatureExtractionProvider implements ApiProvider {
+  modelName: string;
+  config: HuggingfaceFeatureExtractionOptions;
+
+  constructor(modelName: string, options: { id?: string, config?: HuggingfaceFeatureExtractionOptions } = {}) {
+    const { id, config } = options;
+    this.modelName = modelName;
+    this.id = id ? () => id : this.id;
+    this.config = config || {};
+  }
+
+  id(): string {
+    return `huggingface:feature-extraction:${this.modelName}`;
+  }
+
+  toString(): string {
+    return `[Huggingface Feature Extraction Provider ${this.modelName}]`;
+  }
+
+  async callApi(): Promise<ProviderResponse> {
+    throw new Error('Cannot use a feature extraction provider for text generation');
+  }
+
+  async callEmbeddingApi(text: string): Promise<ProviderEmbeddingResponse> {
+    const params = {
+      inputs: text,
+    };
+
+    let response;
+    try {
+      response = await fetchWithCache(`https://api-inference.huggingface.co/models/${this.modelName}`, {
+        method: 'POST',
+        headers: {
+          'Content-Type': 'application/json',
+          ...(process.env.HF_API_TOKEN ? { 'Authorization': `Bearer ${process.env.HF_API_TOKEN}` } : {}),
+        },
+        body: JSON.stringify(params),
+      }, REQUEST_TIMEOUT_MS);
+
+      if (response.data.error) {
+        return {
+          error: `API call error: ${response.data.error}`,
+        };
+      }
+      if (!Array.isArray(response.data)) {
+        return {
+          error: `Malformed response data: ${response.data}`,
+        };
+      }
+
+      return {
+        embedding: response.data,
+      };
+    } catch(err) {
+      return {
+        error: `API call error: ${String(err)}. Output:\n${response?.data}`,
+      };
+    }
+  }
+}
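To make the two classes above concrete, here is a hedged usage sketch (not part of the commit); the model names are arbitrary examples and the import path assumes a sibling module:

// Hypothetical usage of the providers defined above (model names are examples).
import {
  HuggingfaceTextGenerationProvider,
  HuggingfaceFeatureExtractionProvider,
} from './huggingface';

async function demo() {
  // POSTs { inputs, parameters } to https://api-inference.huggingface.co/models/gpt2
  // and returns generated_text from the first element of the response array.
  const textGen = new HuggingfaceTextGenerationProvider('gpt2', {
    config: { temperature: 0.7, max_new_tokens: 64 },
  });
  const gen = await textGen.callApi('Once upon a time');
  console.log(gen.output ?? gen.error);

  // POSTs { inputs } to the same endpoint and returns the raw array as the embedding.
  const embedder = new HuggingfaceFeatureExtractionProvider(
    'sentence-transformers/all-MiniLM-L6-v2',
  );
  const emb = await embedder.callEmbeddingApi('Hello world');
  console.log(emb.embedding ? emb.embedding.length : emb.error);
}

demo();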

‎src/providers/ollama.ts

+2 -2

@@ -136,10 +136,10 @@ export class OllamaProvider implements ApiProvider {
 }
 
 export class OllamaEmbeddingProvider extends OllamaProvider {
-  async callEmbeddingApi(prompt: string): Promise<ProviderEmbeddingResponse> {
+  async callEmbeddingApi(text: string): Promise<ProviderEmbeddingResponse> {
     const params = {
       model: this.modelName,
-      prompt,
+      prompt: text,
     };
 
     logger.debug(`Calling Ollama API: ${JSON.stringify(params)}`);

‎test/providers.test.ts

+42

@@ -12,6 +12,7 @@ import {
 } from '../src/providers/azureopenai';
 import { OllamaProvider } from '../src/providers/ollama';
 import { WebhookProvider } from '../src/providers/webhook';
+import { HuggingfaceTextGenerationProvider, HuggingfaceFeatureExtractionProvider } from '../src/providers/huggingface';
 
 import type { ProviderOptionsMap, ProviderFunction } from '../src/types';
 
@@ -225,6 +226,37 @@ describe('providers', () => {
     expect(result.output).toBe('Test output');
   });
 
+  test('HuggingfaceTextGenerationProvider callApi', async () => {
+    const mockResponse = {
+      json: jest.fn().mockResolvedValue([
+        { generated_text: 'Test output' },
+      ]),
+    };
+    (fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);
+
+    const provider = new HuggingfaceTextGenerationProvider('gpt2');
+    const result = await provider.callApi('Test prompt');
+
+    expect(fetch).toHaveBeenCalledTimes(1);
+    expect(result.output).toBe('Test output');
+  });
+
+  test('HuggingfaceFeatureExtractionProvider callEmbeddingApi', async () => {
+    const mockResponse = {
+      json: jest.fn().mockResolvedValue(
+        [0.1, 0.2, 0.3, 0.4, 0.5],
+      ),
+    };
+    (fetch as unknown as jest.Mock).mockResolvedValue(mockResponse);
+
+    const provider = new HuggingfaceFeatureExtractionProvider('distilbert-base-uncased');
+    const result = await provider.callEmbeddingApi('Test text');
+
+    expect(fetch).toHaveBeenCalledTimes(1);
+    expect(result.embedding).toEqual([0.1, 0.2, 0.3, 0.4, 0.5]);
+  });
+
+
   test('loadApiProvider with openai:chat', async () => {
     const provider = await loadApiProvider('openai:chat');
     expect(provider).toBeInstanceOf(OpenAiChatCompletionProvider);
@@ -281,6 +313,16 @@ describe('providers', () => {
     expect(provider).toBeInstanceOf(WebhookProvider);
   });
 
+  test('loadApiProvider with huggingface:text-generation', async () => {
+    const provider = await loadApiProvider('huggingface:text-generation:foobar/baz');
+    expect(provider).toBeInstanceOf(HuggingfaceTextGenerationProvider);
+  });
+
+  test('loadApiProvider with huggingface:feature-extraction', async () => {
+    const provider = await loadApiProvider('huggingface:feature-extraction:foobar/baz');
+    expect(provider).toBeInstanceOf(HuggingfaceFeatureExtractionProvider);
+  });
+
   test('loadApiProvider with RawProviderConfig', async () => {
     const rawProviderConfig = {
       'openai:chat': {
