Skip to content

Commit 122ddb1

Browse files
committed
Move argument converters to callback.go, and optimize return value handling.
A call now doesn't have to do any reflection, it just blindly invokes a bunch of argument and return value handlers to execute the translation, and the safety of the translation is determined at registration time.
1 parent cf8fa0a commit 122ddb1

File tree

4 files changed

+367
-154
lines changed

4 files changed

+367
-154
lines changed

callback.go

+199-1
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,214 @@
55

66
package sqlite3
77

8+
// You can't export a Go function to C and have definitions in the C
9+
// preamble in the same file, so we have to have callbackTrampoline in
10+
// its own file. Because we need a separate file anyway, the support
11+
// code for SQLite custom functions is in here.
12+
813
/*
914
#include <sqlite3-binding.h>
15+
16+
void _sqlite3_result_text(sqlite3_context* ctx, const char* s);
17+
void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l);
1018
*/
1119
import "C"
1220

13-
import "unsafe"
21+
import (
22+
"errors"
23+
"fmt"
24+
"reflect"
25+
"unsafe"
26+
)
1427

1528
//export callbackTrampoline
1629
func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
1730
args := (*[1 << 30]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
1831
fi := (*functionInfo)(unsafe.Pointer(C.sqlite3_user_data(ctx)))
1932
fi.Call(ctx, args)
2033
}
34+
35+
// This is only here so that tests can refer to it.
36+
type callbackArgRaw C.sqlite3_value
37+
38+
type callbackArgConverter func(*C.sqlite3_value) (reflect.Value, error)
39+
40+
type callbackArgCast struct {
41+
f callbackArgConverter
42+
typ reflect.Type
43+
}
44+
45+
func (c callbackArgCast) Run(v *C.sqlite3_value) (reflect.Value, error) {
46+
val, err := c.f(v)
47+
if err != nil {
48+
return reflect.Value{}, err
49+
}
50+
if !val.Type().ConvertibleTo(c.typ) {
51+
return reflect.Value{}, fmt.Errorf("cannot convert %s to %s", val.Type(), c.typ)
52+
}
53+
return val.Convert(c.typ), nil
54+
}
55+
56+
func callbackArgInt64(v *C.sqlite3_value) (reflect.Value, error) {
57+
if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
58+
return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
59+
}
60+
return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil
61+
}
62+
63+
func callbackArgBool(v *C.sqlite3_value) (reflect.Value, error) {
64+
if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
65+
return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
66+
}
67+
i := int64(C.sqlite3_value_int64(v))
68+
val := false
69+
if i != 0 {
70+
val = true
71+
}
72+
return reflect.ValueOf(val), nil
73+
}
74+
75+
func callbackArgFloat64(v *C.sqlite3_value) (reflect.Value, error) {
76+
if C.sqlite3_value_type(v) != C.SQLITE_FLOAT {
77+
return reflect.Value{}, fmt.Errorf("argument must be a FLOAT")
78+
}
79+
return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil
80+
}
81+
82+
func callbackArgBytes(v *C.sqlite3_value) (reflect.Value, error) {
83+
switch C.sqlite3_value_type(v) {
84+
case C.SQLITE_BLOB:
85+
l := C.sqlite3_value_bytes(v)
86+
p := C.sqlite3_value_blob(v)
87+
return reflect.ValueOf(C.GoBytes(p, l)), nil
88+
case C.SQLITE_TEXT:
89+
l := C.sqlite3_value_bytes(v)
90+
c := unsafe.Pointer(C.sqlite3_value_text(v))
91+
return reflect.ValueOf(C.GoBytes(c, l)), nil
92+
default:
93+
return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
94+
}
95+
}
96+
97+
func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) {
98+
switch C.sqlite3_value_type(v) {
99+
case C.SQLITE_BLOB:
100+
l := C.sqlite3_value_bytes(v)
101+
p := (*C.char)(C.sqlite3_value_blob(v))
102+
return reflect.ValueOf(C.GoStringN(p, l)), nil
103+
case C.SQLITE_TEXT:
104+
c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v)))
105+
return reflect.ValueOf(C.GoString(c)), nil
106+
default:
107+
return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
108+
}
109+
}
110+
111+
func callbackArg(typ reflect.Type) (callbackArgConverter, error) {
112+
switch typ.Kind() {
113+
case reflect.Slice:
114+
if typ.Elem().Kind() != reflect.Uint8 {
115+
return nil, errors.New("the only supported slice type is []byte")
116+
}
117+
return callbackArgBytes, nil
118+
case reflect.String:
119+
return callbackArgString, nil
120+
case reflect.Bool:
121+
return callbackArgBool, nil
122+
case reflect.Int64:
123+
return callbackArgInt64, nil
124+
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
125+
c := callbackArgCast{callbackArgInt64, typ}
126+
return c.Run, nil
127+
case reflect.Float64:
128+
return callbackArgFloat64, nil
129+
case reflect.Float32:
130+
c := callbackArgCast{callbackArgFloat64, typ}
131+
return c.Run, nil
132+
default:
133+
return nil, fmt.Errorf("don't know how to convert to %s", typ)
134+
}
135+
}
136+
137+
type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error
138+
139+
func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error {
140+
switch v.Type().Kind() {
141+
case reflect.Int64:
142+
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
143+
v = v.Convert(reflect.TypeOf(int64(0)))
144+
case reflect.Bool:
145+
b := v.Interface().(bool)
146+
if b {
147+
v = reflect.ValueOf(int64(1))
148+
} else {
149+
v = reflect.ValueOf(int64(0))
150+
}
151+
default:
152+
return fmt.Errorf("cannot convert %s to INTEGER", v.Type())
153+
}
154+
155+
C.sqlite3_result_int64(ctx, C.sqlite3_int64(v.Interface().(int64)))
156+
return nil
157+
}
158+
159+
func callbackRetFloat(ctx *C.sqlite3_context, v reflect.Value) error {
160+
switch v.Type().Kind() {
161+
case reflect.Float64:
162+
case reflect.Float32:
163+
v = v.Convert(reflect.TypeOf(float64(0)))
164+
default:
165+
return fmt.Errorf("cannot convert %s to FLOAT", v.Type())
166+
}
167+
168+
C.sqlite3_result_double(ctx, C.double(v.Interface().(float64)))
169+
return nil
170+
}
171+
172+
func callbackRetBlob(ctx *C.sqlite3_context, v reflect.Value) error {
173+
if v.Type().Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 {
174+
return fmt.Errorf("cannot convert %s to BLOB", v.Type())
175+
}
176+
i := v.Interface()
177+
if i == nil || len(i.([]byte)) == 0 {
178+
C.sqlite3_result_null(ctx)
179+
} else {
180+
bs := i.([]byte)
181+
C._sqlite3_result_blob(ctx, unsafe.Pointer(&bs[0]), C.int(len(bs)))
182+
}
183+
return nil
184+
}
185+
186+
func callbackRetText(ctx *C.sqlite3_context, v reflect.Value) error {
187+
if v.Type().Kind() != reflect.String {
188+
return fmt.Errorf("cannot convert %s to TEXT", v.Type())
189+
}
190+
C._sqlite3_result_text(ctx, C.CString(v.Interface().(string)))
191+
return nil
192+
}
193+
194+
func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
195+
switch typ.Kind() {
196+
case reflect.Slice:
197+
if typ.Elem().Kind() != reflect.Uint8 {
198+
return nil, errors.New("the only supported slice type is []byte")
199+
}
200+
return callbackRetBlob, nil
201+
case reflect.String:
202+
return callbackRetText, nil
203+
case reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
204+
return callbackRetInteger, nil
205+
case reflect.Float32, reflect.Float64:
206+
return callbackRetFloat, nil
207+
default:
208+
return nil, fmt.Errorf("don't know how to convert to %s", typ)
209+
}
210+
}
211+
212+
// Test support code. Tests are not allowed to import "C", so we can't
213+
// declare any functions that use C.sqlite3_value.
214+
func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter {
215+
return func(*C.sqlite3_value) (reflect.Value, error) {
216+
return v, err
217+
}
218+
}

callback_test.go

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package sqlite3
2+
3+
import (
4+
"errors"
5+
"math"
6+
"reflect"
7+
"testing"
8+
)
9+
10+
func TestCallbackArgCast(t *testing.T) {
11+
intConv := callbackSyntheticForTests(reflect.ValueOf(int64(math.MaxInt64)), nil)
12+
floatConv := callbackSyntheticForTests(reflect.ValueOf(float64(math.MaxFloat64)), nil)
13+
errConv := callbackSyntheticForTests(reflect.Value{}, errors.New("test"))
14+
15+
tests := []struct {
16+
f callbackArgConverter
17+
o reflect.Value
18+
}{
19+
{intConv, reflect.ValueOf(int8(-1))},
20+
{intConv, reflect.ValueOf(int16(-1))},
21+
{intConv, reflect.ValueOf(int32(-1))},
22+
{intConv, reflect.ValueOf(uint8(math.MaxUint8))},
23+
{intConv, reflect.ValueOf(uint16(math.MaxUint16))},
24+
{intConv, reflect.ValueOf(uint32(math.MaxUint32))},
25+
// Special case, int64->uint64 is only 1<<63 - 1, not 1<<64 - 1
26+
{intConv, reflect.ValueOf(uint64(math.MaxInt64))},
27+
{floatConv, reflect.ValueOf(float32(math.Inf(1)))},
28+
}
29+
30+
for _, test := range tests {
31+
conv := callbackArgCast{test.f, test.o.Type()}
32+
val, err := conv.Run(nil)
33+
if err != nil {
34+
t.Errorf("Couldn't convert to %s: %s", test.o.Type(), err)
35+
} else if !reflect.DeepEqual(val.Interface(), test.o.Interface()) {
36+
t.Errorf("Unexpected result from converting to %s: got %v, want %v", test.o.Type(), val.Interface(), test.o.Interface())
37+
}
38+
}
39+
40+
conv := callbackArgCast{errConv, reflect.TypeOf(int8(0))}
41+
_, err := conv.Run(nil)
42+
if err == nil {
43+
t.Errorf("Expected error during callbackArgCast, but got none")
44+
}
45+
}
46+
47+
func TestCallbackConverters(t *testing.T) {
48+
tests := []struct {
49+
v interface{}
50+
err bool
51+
}{
52+
// Unfortunately, we can't tell which converter was returned,
53+
// but we can at least check which types can be converted.
54+
{[]byte{0}, false},
55+
{"text", false},
56+
{true, false},
57+
{int8(0), false},
58+
{int16(0), false},
59+
{int32(0), false},
60+
{int64(0), false},
61+
{uint8(0), false},
62+
{uint16(0), false},
63+
{uint32(0), false},
64+
{uint64(0), false},
65+
{int(0), false},
66+
{uint(0), false},
67+
{float64(0), false},
68+
{float32(0), false},
69+
70+
{func() {}, true},
71+
{complex64(complex(0, 0)), true},
72+
{complex128(complex(0, 0)), true},
73+
{struct{}{}, true},
74+
{map[string]string{}, true},
75+
{[]string{}, true},
76+
{(*int8)(nil), true},
77+
{make(chan int), true},
78+
}
79+
80+
for _, test := range tests {
81+
_, err := callbackArg(reflect.TypeOf(test.v))
82+
if test.err && err == nil {
83+
t.Errorf("Expected an error when converting %s, got no error", reflect.TypeOf(test.v))
84+
} else if !test.err && err != nil {
85+
t.Errorf("Expected converter when converting %s, got error: %s", reflect.TypeOf(test.v), err)
86+
}
87+
}
88+
89+
for _, test := range tests {
90+
_, err := callbackRet(reflect.TypeOf(test.v))
91+
if test.err && err == nil {
92+
t.Errorf("Expected an error when converting %s, got no error", reflect.TypeOf(test.v))
93+
} else if !test.err && err != nil {
94+
t.Errorf("Expected converter when converting %s, got error: %s", reflect.TypeOf(test.v), err)
95+
}
96+
}
97+
}

0 commit comments

Comments
 (0)