@@ -30,6 +30,7 @@ package sqlite3
30
30
#endif
31
31
#include <stdlib.h>
32
32
#include <string.h>
33
+ #include <ctype.h>
33
34
34
35
#ifdef __CYGWIN__
35
36
# include <errno.h>
@@ -90,6 +91,16 @@ _sqlite3_exec(sqlite3* db, const char* pcmd, long long* rowid, long long* change
90
91
return rv;
91
92
}
92
93
94
+ static const char *
95
+ _trim_leading_spaces(const char *str) {
96
+ if (str) {
97
+ while (isspace(*str)) {
98
+ str++;
99
+ }
100
+ }
101
+ return str;
102
+ }
103
+
93
104
#ifdef SQLITE_ENABLE_UNLOCK_NOTIFY
94
105
extern int _sqlite3_step_blocking(sqlite3_stmt *stmt);
95
106
extern int _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* changes);
@@ -110,7 +121,11 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan
110
121
static int
111
122
_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
112
123
{
113
- return _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail);
124
+ int rv = _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail);
125
+ if (pzTail) {
126
+ *pzTail = _trim_leading_spaces(*pzTail);
127
+ }
128
+ return rv;
114
129
}
115
130
116
131
#else
@@ -133,7 +148,11 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan
133
148
static int
134
149
_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
135
150
{
136
- return sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
151
+ int rv = sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
152
+ if (pzTail) {
153
+ *pzTail = _trim_leading_spaces(*pzTail);
154
+ }
155
+ return rv;
137
156
}
138
157
#endif
139
158
@@ -858,25 +877,33 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err
858
877
}
859
878
860
879
func (c * SQLiteConn ) exec (ctx context.Context , query string , args []driver.NamedValue ) (driver.Result , error ) {
880
+ pquery := C .CString (query )
881
+ op := pquery // original pointer
882
+ defer C .free (unsafe .Pointer (op ))
883
+
884
+ var stmtArgs []driver.NamedValue
885
+ var tail * C.char
886
+ s := new (SQLiteStmt ) // escapes to the heap so reuse it
861
887
start := 0
862
888
for {
863
- s , err := c .prepare (ctx , query )
864
- if err != nil {
865
- return nil , err
889
+ * s = SQLiteStmt {c : c } // reset
890
+ rv := C ._sqlite3_prepare_v2_internal (c .db , pquery , C .int (- 1 ), & s .s , & tail )
891
+ if rv != C .SQLITE_OK {
892
+ return nil , c .lastError ()
866
893
}
894
+
867
895
var res driver.Result
868
- if s .(* SQLiteStmt ).s != nil {
869
- stmtArgs := make ([]driver.NamedValue , 0 , len (args ))
896
+ if s .s != nil {
870
897
na := s .NumInput ()
871
898
if len (args )- start < na {
872
- s .Close ()
899
+ s .finalize ()
873
900
return nil , fmt .Errorf ("not enough args to execute query: want %d got %d" , na , len (args ))
874
901
}
875
902
// consume the number of arguments used in the current
876
903
// statement and append all named arguments not
877
904
// contained therein
878
905
if len (args [start :start + na ]) > 0 {
879
- stmtArgs = append (stmtArgs , args [start :start + na ]... )
906
+ stmtArgs = append (stmtArgs [: 0 ] , args [start :start + na ]... )
880
907
for i := range args {
881
908
if (i < start || i >= na ) && args [i ].Name != "" {
882
909
stmtArgs = append (stmtArgs , args [i ])
@@ -886,23 +913,23 @@ func (c *SQLiteConn) exec(ctx context.Context, query string, args []driver.Named
886
913
stmtArgs [i ].Ordinal = i + 1
887
914
}
888
915
}
889
- res , err = s .(* SQLiteStmt ).exec (ctx , stmtArgs )
916
+ var err error
917
+ res , err = s .exec (ctx , stmtArgs )
890
918
if err != nil && err != driver .ErrSkip {
891
- s .Close ()
919
+ s .finalize ()
892
920
return nil , err
893
921
}
894
922
start += na
895
923
}
896
- tail := s .(* SQLiteStmt ).t
897
- s .Close ()
898
- if tail == "" {
924
+ s .finalize ()
925
+ if tail == nil || * tail == '\000' {
899
926
if res == nil {
900
927
// https://github.com/mattn/go-sqlite3/issues/963
901
928
res = & SQLiteResult {0 , 0 }
902
929
}
903
930
return res , nil
904
931
}
905
- query = tail
932
+ pquery = tail
906
933
}
907
934
}
908
935
@@ -919,22 +946,29 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
919
946
}
920
947
921
948
func (c * SQLiteConn ) query (ctx context.Context , query string , args []driver.NamedValue ) (driver.Rows , error ) {
949
+ pquery := C .CString (query )
950
+ op := pquery // original pointer
951
+ defer C .free (unsafe .Pointer (op ))
952
+
953
+ var stmtArgs []driver.NamedValue
954
+ var tail * C.char
955
+ s := new (SQLiteStmt ) // escapes to the heap so reuse it
922
956
start := 0
923
957
for {
924
- stmtArgs := make ([]driver. NamedValue , 0 , len ( args ))
925
- s , err := c . prepare ( ctx , query )
926
- if err != nil {
927
- return nil , err
958
+ * s = SQLiteStmt { c : c , cls : true } // reset
959
+ rv := C . _sqlite3_prepare_v2_internal ( c . db , pquery , C . int ( - 1 ), & s . s , & tail )
960
+ if rv != C . SQLITE_OK {
961
+ return nil , c . lastError ()
928
962
}
929
- s .( * SQLiteStmt ). cls = true
963
+
930
964
na := s .NumInput ()
931
965
if len (args )- start < na {
932
966
return nil , fmt .Errorf ("not enough args to execute query: want %d got %d" , na , len (args )- start )
933
967
}
934
968
// consume the number of arguments used in the current
935
969
// statement and append all named arguments not contained
936
970
// therein
937
- stmtArgs = append (stmtArgs , args [start :start + na ]... )
971
+ stmtArgs = append (stmtArgs [: 0 ] , args [start :start + na ]... )
938
972
for i := range args {
939
973
if (i < start || i >= na ) && args [i ].Name != "" {
940
974
stmtArgs = append (stmtArgs , args [i ])
@@ -943,19 +977,18 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.Name
943
977
for i := range stmtArgs {
944
978
stmtArgs [i ].Ordinal = i + 1
945
979
}
946
- rows , err := s .( * SQLiteStmt ). query (ctx , stmtArgs )
980
+ rows , err := s .query (ctx , stmtArgs )
947
981
if err != nil && err != driver .ErrSkip {
948
- s .Close ()
982
+ s .finalize ()
949
983
return rows , err
950
984
}
951
985
start += na
952
- tail := s .(* SQLiteStmt ).t
953
- if tail == "" {
986
+ if tail == nil || * tail == '\000' {
954
987
return rows , nil
955
988
}
956
989
rows .Close ()
957
- s .Close ()
958
- query = tail
990
+ s .finalize ()
991
+ pquery = tail
959
992
}
960
993
}
961
994
@@ -1818,8 +1851,11 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er
1818
1851
return nil , c .lastError ()
1819
1852
}
1820
1853
var t string
1821
- if tail != nil && * tail != '\000' {
1822
- t = strings .TrimSpace (C .GoString (tail ))
1854
+ if tail != nil && * tail != 0 {
1855
+ n := int (uintptr (unsafe .Pointer (tail ))) - int (uintptr (unsafe .Pointer (pquery )))
1856
+ if 0 <= n && n < len (query ) {
1857
+ t = strings .TrimSpace (query [n :])
1858
+ }
1823
1859
}
1824
1860
ss := & SQLiteStmt {c : c , s : s , t : t }
1825
1861
runtime .SetFinalizer (ss , (* SQLiteStmt ).Close )
@@ -1913,6 +1949,13 @@ func (s *SQLiteStmt) Close() error {
1913
1949
return nil
1914
1950
}
1915
1951
1952
+ func (s * SQLiteStmt ) finalize () {
1953
+ if s .s != nil {
1954
+ C .sqlite3_finalize (s .s )
1955
+ s .s = nil
1956
+ }
1957
+ }
1958
+
1916
1959
// NumInput return a number of parameters.
1917
1960
func (s * SQLiteStmt ) NumInput () int {
1918
1961
return int (C .sqlite3_bind_parameter_count (s .s ))
0 commit comments