Skip to content

Commit 95b249b

Browse files
committed
sqlite3: handle trailing comments and multiple SQL statements in Queries
This commit fixes *SQLiteConn.Query to properly handle trailing comments after a SQL query statement. Previously, trailing comments could lead to an infinite loop. It also changes Query to error if the provided SQL statement contains multiple queries ("SELECT 1; SELECT 2;") - previously only the last query was executed ("SELECT 1; SELECT 2;" would yield only 2). This may be a breaking change as previously: Query consumed all of its args - despite only using the last query (Query now only uses the args required to satisfy the first query and errors if there is a mismatch); Query used only the last query and there may be code using this library that depends on this behavior. Personally, I believe the behavior introduced by this commit is correct and any code relying on the prior undocumented behavior incorrect, but it could still be a break.
1 parent c88c58a commit 95b249b

File tree

2 files changed

+168
-106
lines changed

2 files changed

+168
-106
lines changed

sqlite3.go

+36-56
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ package sqlite3
3030
#endif
3131
#include <stdlib.h>
3232
#include <string.h>
33-
#include <ctype.h>
3433
3534
#ifdef __CYGWIN__
3635
# include <errno.h>
@@ -91,16 +90,6 @@ _sqlite3_exec(sqlite3* db, const char* pcmd, long long* rowid, long long* change
9190
return rv;
9291
}
9392
94-
static const char *
95-
_trim_leading_spaces(const char *str) {
96-
if (str) {
97-
while (isspace(*str)) {
98-
str++;
99-
}
100-
}
101-
return str;
102-
}
103-
10493
#ifdef SQLITE_ENABLE_UNLOCK_NOTIFY
10594
extern int _sqlite3_step_blocking(sqlite3_stmt *stmt);
10695
extern int _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* changes);
@@ -121,11 +110,7 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan
121110
static int
122111
_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
123112
{
124-
int rv = _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail);
125-
if (pzTail) {
126-
*pzTail = _trim_leading_spaces(*pzTail);
127-
}
128-
return rv;
113+
return _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail);
129114
}
130115
131116
#else
@@ -148,12 +133,9 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan
148133
static int
149134
_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
150135
{
151-
int rv = sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
152-
if (pzTail) {
153-
*pzTail = _trim_leading_spaces(*pzTail);
154-
}
155-
return rv;
136+
return sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
156137
}
138+
157139
#endif
158140
159141
void _sqlite3_result_text(sqlite3_context* ctx, const char* s) {
@@ -950,46 +932,44 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.Name
950932
op := pquery // original pointer
951933
defer C.free(unsafe.Pointer(op))
952934

953-
var stmtArgs []driver.NamedValue
954935
var tail *C.char
955-
s := new(SQLiteStmt) // escapes to the heap so reuse it
956-
start := 0
957-
for {
958-
*s = SQLiteStmt{c: c, cls: true} // reset
959-
rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail)
960-
if rv != C.SQLITE_OK {
961-
return nil, c.lastError()
936+
s := &SQLiteStmt{c: c, cls: true}
937+
rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &s.s, &tail)
938+
if rv != C.SQLITE_OK {
939+
return nil, c.lastError()
940+
}
941+
if s.s == nil {
942+
return &SQLiteRows{s: &SQLiteStmt{closed: true}, closed: true}, nil
943+
}
944+
na := s.NumInput()
945+
if n := len(args); n != na {
946+
s.finalize()
947+
if n < na {
948+
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args))
962949
}
950+
return nil, fmt.Errorf("too many args to execute query: want %d got %d", na, len(args))
951+
}
952+
rows, err := s.query(ctx, args)
953+
if err != nil && err != driver.ErrSkip {
954+
s.finalize()
955+
return rows, err
956+
}
963957

964-
na := s.NumInput()
965-
if len(args)-start < na {
966-
return nil, fmt.Errorf("not enough args to execute query: want %d got %d", na, len(args)-start)
967-
}
968-
// consume the number of arguments used in the current
969-
// statement and append all named arguments not contained
970-
// therein
971-
stmtArgs = append(stmtArgs[:0], args[start:start+na]...)
972-
for i := range args {
973-
if (i < start || i >= na) && args[i].Name != "" {
974-
stmtArgs = append(stmtArgs, args[i])
975-
}
976-
}
977-
for i := range stmtArgs {
978-
stmtArgs[i].Ordinal = i + 1
979-
}
980-
rows, err := s.query(ctx, stmtArgs)
981-
if err != nil && err != driver.ErrSkip {
982-
s.finalize()
983-
return rows, err
958+
// Consume the rest of the query
959+
for pquery = tail; pquery != nil && *pquery != 0; pquery = tail {
960+
var stmt *C.sqlite3_stmt
961+
rv := C._sqlite3_prepare_v2_internal(c.db, pquery, C.int(-1), &stmt, &tail)
962+
if rv != C.SQLITE_OK {
963+
rows.Close()
964+
return nil, c.lastError()
984965
}
985-
start += na
986-
if tail == nil || *tail == '\000' {
987-
return rows, nil
966+
if stmt != nil {
967+
rows.Close()
968+
return nil, errors.New("cannot insert multiple commands into a prepared statement") // WARN
988969
}
989-
rows.Close()
990-
s.finalize()
991-
pquery = tail
992970
}
971+
972+
return rows, nil
993973
}
994974

995975
// Begin transaction.
@@ -2043,7 +2023,7 @@ func (s *SQLiteStmt) Query(args []driver.Value) (driver.Rows, error) {
20432023
return s.query(context.Background(), list)
20442024
}
20452025

2046-
func (s *SQLiteStmt) query(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
2026+
func (s *SQLiteStmt) query(ctx context.Context, args []driver.NamedValue) (*SQLiteRows, error) {
20472027
if err := s.bind(args); err != nil {
20482028
return nil, err
20492029
}

sqlite3_test.go

+132-50
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"math/rand"
1919
"net/url"
2020
"os"
21+
"path/filepath"
2122
"reflect"
2223
"regexp"
2324
"runtime"
@@ -1080,67 +1081,158 @@ func TestExecer(t *testing.T) {
10801081
defer db.Close()
10811082

10821083
_, err = db.Exec(`
1083-
create table foo (id integer); -- one comment
1084-
insert into foo(id) values(?);
1085-
insert into foo(id) values(?);
1086-
insert into foo(id) values(?); -- another comment
1084+
CREATE TABLE foo (id INTEGER); -- one comment
1085+
INSERT INTO foo(id) VALUES(?);
1086+
INSERT INTO foo(id) VALUES(?);
1087+
INSERT INTO foo(id) VALUES(?); -- another comment
10871088
`, 1, 2, 3)
10881089
if err != nil {
10891090
t.Error("Failed to call db.Exec:", err)
10901091
}
10911092
}
10921093

1093-
func TestQueryer(t *testing.T) {
1094-
tempFilename := TempFilename(t)
1095-
defer os.Remove(tempFilename)
1096-
db, err := sql.Open("sqlite3", tempFilename)
1094+
func testQuery(t *testing.T, seed bool, test func(t *testing.T, db *sql.DB)) {
1095+
db, err := sql.Open("sqlite3", filepath.Join(t.TempDir(), "test.sqlite3"))
10971096
if err != nil {
10981097
t.Fatal("Failed to open database:", err)
10991098
}
11001099
defer db.Close()
11011100

1102-
_, err = db.Exec(`
1103-
create table foo (id integer);
1104-
`)
1105-
if err != nil {
1106-
t.Error("Failed to call db.Query:", err)
1101+
if seed {
1102+
if _, err := db.Exec(`create table foo (id integer);`); err != nil {
1103+
t.Fatal(err)
1104+
}
1105+
_, err := db.Exec(`
1106+
INSERT INTO foo(id) VALUES(?);
1107+
INSERT INTO foo(id) VALUES(?);
1108+
INSERT INTO foo(id) VALUES(?);
1109+
`, 3, 2, 1)
1110+
if err != nil {
1111+
t.Fatal(err)
1112+
}
11071113
}
11081114

1109-
_, err = db.Exec(`
1110-
insert into foo(id) values(?);
1111-
insert into foo(id) values(?);
1112-
insert into foo(id) values(?);
1113-
`, 3, 2, 1)
1114-
if err != nil {
1115-
t.Error("Failed to call db.Exec:", err)
1116-
}
1117-
rows, err := db.Query(`
1118-
select id from foo order by id;
1119-
`)
1120-
if err != nil {
1121-
t.Error("Failed to call db.Query:", err)
1122-
}
1123-
defer rows.Close()
1124-
n := 0
1125-
for rows.Next() {
1126-
var id int
1127-
err = rows.Scan(&id)
1115+
// Capture panic so tests can continue
1116+
defer func() {
1117+
if e := recover(); e != nil {
1118+
buf := make([]byte, 32*1024)
1119+
n := runtime.Stack(buf, false)
1120+
t.Fatalf("\npanic: %v\n\n%s\n", e, buf[:n])
1121+
}
1122+
}()
1123+
test(t, db)
1124+
}
1125+
1126+
func testQueryValues(t *testing.T, query string, args ...interface{}) []interface{} {
1127+
var values []interface{}
1128+
testQuery(t, true, func(t *testing.T, db *sql.DB) {
1129+
rows, err := db.Query(query, args...)
11281130
if err != nil {
1129-
t.Error("Failed to db.Query:", err)
1131+
t.Fatal(err)
11301132
}
1131-
if id != n+1 {
1132-
t.Error("Failed to db.Query: not matched results")
1133+
if rows == nil {
1134+
t.Fatal("nil rows")
11331135
}
1134-
n = n + 1
1136+
for i := 0; rows.Next(); i++ {
1137+
if i > 1_000 {
1138+
t.Fatal("To many iterations of rows.Next():", i)
1139+
}
1140+
var v interface{}
1141+
if err := rows.Scan(&v); err != nil {
1142+
t.Fatal(err)
1143+
}
1144+
values = append(values, v)
1145+
}
1146+
if err := rows.Err(); err != nil {
1147+
t.Fatal(err)
1148+
}
1149+
if err := rows.Close(); err != nil {
1150+
t.Fatal(err)
1151+
}
1152+
})
1153+
return values
1154+
}
1155+
1156+
func TestQuery(t *testing.T) {
1157+
queries := []struct {
1158+
query string
1159+
args []interface{}
1160+
}{
1161+
{"SELECT id FROM foo ORDER BY id;", nil},
1162+
{"SELECT id FROM foo WHERE id != ? ORDER BY id;", []interface{}{4}},
1163+
{"SELECT id FROM foo WHERE id IN (?, ?, ?) ORDER BY id;", []interface{}{1, 2, 3}},
1164+
1165+
// Comments
1166+
{"SELECT id FROM foo ORDER BY id; -- comment", nil},
1167+
{"SELECT id FROM foo ORDER BY id;\n -- comment\n", nil},
1168+
{
1169+
`-- FOO
1170+
SELECT id FROM foo ORDER BY id; -- BAR
1171+
/* BAZ */`,
1172+
nil,
1173+
},
11351174
}
1136-
if err := rows.Err(); err != nil {
1137-
t.Errorf("Post-scan failed: %v\n", err)
1175+
want := []interface{}{
1176+
int64(1),
1177+
int64(2),
1178+
int64(3),
1179+
}
1180+
for _, q := range queries {
1181+
t.Run("", func(t *testing.T) {
1182+
got := testQueryValues(t, q.query, q.args...)
1183+
if !reflect.DeepEqual(got, want) {
1184+
t.Fatalf("Query(%q, %v) = %v; want: %v", q.query, q.args, got, want)
1185+
}
1186+
})
11381187
}
1139-
if n != 3 {
1140-
t.Errorf("Expected 3 rows but retrieved %v", n)
1188+
}
1189+
1190+
func TestQueryNoSQL(t *testing.T) {
1191+
got := testQueryValues(t, "")
1192+
if got != nil {
1193+
t.Fatalf("Query(%q, %v) = %v; want: %v", "", nil, got, nil)
11411194
}
11421195
}
11431196

1197+
func testQueryError(t *testing.T, query string, args ...interface{}) {
1198+
testQuery(t, true, func(t *testing.T, db *sql.DB) {
1199+
rows, err := db.Query(query, args...)
1200+
if err == nil {
1201+
t.Error("Expected an error got:", err)
1202+
}
1203+
if rows != nil {
1204+
t.Error("Returned rows should be nil on error!")
1205+
// Attempt to iterate over rows to make sure they don't panic.
1206+
for i := 0; rows.Next(); i++ {
1207+
if i > 1_000 {
1208+
t.Fatal("To many iterations of rows.Next():", i)
1209+
}
1210+
}
1211+
if err := rows.Err(); err != nil {
1212+
t.Error(err)
1213+
}
1214+
rows.Close()
1215+
}
1216+
})
1217+
}
1218+
1219+
func TestQueryNotEnoughArgs(t *testing.T) {
1220+
testQueryError(t, "SELECT FROM foo WHERE id = ? OR id = ?);", 1)
1221+
}
1222+
1223+
func TestQueryTooManyArgs(t *testing.T) {
1224+
// TODO: test error message / kind
1225+
testQueryError(t, "SELECT FROM foo WHERE id = ?);", 1, 2)
1226+
}
1227+
1228+
func TestQueryMultipleStatements(t *testing.T) {
1229+
testQueryError(t, "SELECT 1; SELECT 2;")
1230+
}
1231+
1232+
func TestQueryInvalidTable(t *testing.T) {
1233+
testQueryError(t, "SELECT COUNT(*) FROM does_not_exist;")
1234+
}
1235+
11441236
func TestStress(t *testing.T) {
11451237
tempFilename := TempFilename(t)
11461238
defer os.Remove(tempFilename)
@@ -2112,7 +2204,6 @@ var benchmarks = []testing.InternalBenchmark{
21122204
{Name: "BenchmarkRows", F: benchmarkRows},
21132205
{Name: "BenchmarkStmtRows", F: benchmarkStmtRows},
21142206
{Name: "BenchmarkExecStep", F: benchmarkExecStep},
2115-
{Name: "BenchmarkQueryStep", F: benchmarkQueryStep},
21162207
}
21172208

21182209
func (db *TestDB) mustExec(sql string, args ...any) sql.Result {
@@ -2580,12 +2671,3 @@ func benchmarkExecStep(b *testing.B) {
25802671
}
25812672
}
25822673
}
2583-
2584-
func benchmarkQueryStep(b *testing.B) {
2585-
var i int
2586-
for n := 0; n < b.N; n++ {
2587-
if err := db.QueryRow(largeSelectStmt).Scan(&i); err != nil {
2588-
b.Fatal(err)
2589-
}
2590-
}
2591-
}

0 commit comments

Comments
 (0)