-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathchat_test.go
126 lines (120 loc) · 3.38 KB
/
chat_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
package gollama
import (
"context"
"encoding/json"
"testing"
)
// TestGollama_Chat is a table-driven test for Gollama.Chat.
//
// Each case builds a client with New, sends a prompt (optionally with an
// extra option value such as an image, a structured output format, or a tool
// definition) and checks up to three aspects of the result:
//
//   - wantContent:    the exact Content string of the returned output;
//   - wantFormatJson: Content, decoded as a JSON object and re-encoded, must
//     equal this string (re-encoding normalizes key order and whitespace);
//   - wantToolJson:   the JSON encoding of the returned ToolCalls.
//
// NOTE(review): the expected values are exact model answers, so this test
// presumably needs a live model backend and deterministic responses — confirm
// before relying on it in CI.
func TestGollama_Chat(t *testing.T) {
	type args struct {
		Prompt  string
		Options interface{}
	}
	type outs struct {
		wantContent    *ChatOuput
		wantToolJson   string
		wantFormatJson string
	}
	tests := []struct {
		name    string
		c       *Gollama
		args    args
		want    *outs
		wantErr bool
	}{
		{
			name:    "Vision",
			c:       New("llama3.2-vision"),
			args:    args{Prompt: "what is on the road?", Options: PromptImage{Filename: "./test/road.png"}},
			want:    &outs{wantContent: &ChatOuput{Content: "There is a llama on the road."}},
			wantErr: false,
		},
		{
			name:    "Math",
			c:       New("llama3.2"),
			args:    args{Prompt: "what is 2 + 2? only answer in number"},
			want:    &outs{wantContent: &ChatOuput{Content: "4"}},
			wantErr: false,
		},
		{
			name: "JSON Output",
			c:    New("llama3.2"),
			args: args{Prompt: "Tell me about Argentina. Response in JSON", Options: StructuredFormat{
				Type: "object",
				Properties: map[string]FormatProperty{
					"capital": {
						Type: "string",
					},
					"language": {
						Type: "array",
						Items: ItemProperty{
							Type: "string",
						},
					}},
				Required: []string{"capital", "language"},
			}},
			want:    &outs{wantFormatJson: `{"capital":"Buenos Aires","language":["Spanish"]}`},
			wantErr: false,
		},
		{
			name: "Tool",
			c:    New("llama3.2"),
			args: args{Prompt: "what is the weather in New York?", Options: Tool{
				Type: "function",
				Function: ToolFunction{
					Name:        "get_current_weather",
					Description: "Get the current weather in a specific city",
					Parameters: StructuredFormat{
						Type: "object",
						Properties: map[string]FormatProperty{
							"city": {
								Type:        "string",
								Description: "The name of the city",
							},
						},
						Required: []string{"city"},
					}},
			},
			},
			want:    &outs{wantContent: &ChatOuput{Content: ""}, wantToolJson: `[{"function":{"name":"get_current_weather","arguments":{"city":"New York"}}}]`},
			wantErr: false,
		},
		{
			name:    "Invalid model",
			c:       New("invalid"),
			args:    args{Prompt: "hello"},
			want:    &outs{wantContent: &ChatOuput{Content: ""}},
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tt.c.Verbose = true
			got, err := tt.c.Chat(context.Background(), tt.args.Prompt, tt.args.Options)
			if (err != nil) != tt.wantErr {
				t.Errorf("Gollama.Chat() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if got == nil {
				// A nil result is only acceptable when an error was expected;
				// otherwise every assertion below would be skipped and the
				// case would pass without checking anything.
				if !tt.wantErr {
					t.Errorf("Gollama.Chat() returned nil output without an error")
				}
				return
			}
			if tt.want == nil {
				return
			}

			// Structured output: round-trip through a map so the comparison
			// does not depend on the key order or spacing of the raw content.
			if tt.want.wantFormatJson != "" {
				var data map[string]interface{}
				if err := json.Unmarshal([]byte(got.Content), &data); err != nil {
					t.Errorf("Gollama.Chat() content is not a JSON object: %v (content = %q)", err, got.Content)
				} else if jsonString, err := json.Marshal(data); err != nil {
					t.Errorf("Gollama.Chat() re-encoding content failed: %v", err)
				} else if string(jsonString) != tt.want.wantFormatJson {
					t.Errorf("Gollama.Chat() = %v, want %v", string(jsonString), tt.want.wantFormatJson)
				}
			}

			// Plain content: report the strings themselves rather than the
			// pointer-valued structs, whose %v form hides the expected text.
			if tt.want.wantContent != nil && got.Content != tt.want.wantContent.Content {
				t.Errorf("Gollama.Chat() content = %q, want %q", got.Content, tt.want.wantContent.Content)
			}

			// Tool calls: compared via their JSON encoding. This check is
			// independent of wantContent so a case may assert tool calls only.
			if tt.want.wantToolJson != "" {
				toolJson, err := json.Marshal(got.ToolCalls)
				if err != nil {
					t.Errorf("Gollama.Chat() encoding tool calls failed: %v", err)
				} else if string(toolJson) != tt.want.wantToolJson {
					t.Errorf("Gollama.Chat() tool calls = %v, want %v", string(toolJson), tt.want.wantToolJson)
				}
			}
		})
	}
}