Skip to content

Commit 46e826d

Browse files
authored
Merge pull request #407 from zombiezen/foreignkeys
Add _foreign_keys connection parameter
2 parents c935ccc + c6d43c4 commit 46e826d

File tree

2 files changed

+72
-3
lines changed

2 files changed

+72
-3
lines changed

sqlite3.go

+43-3
Original file line numberDiff line numberDiff line change
@@ -400,14 +400,18 @@ func (c *SQLiteConn) AutoCommit() bool {
400400
}
401401

402402
func (c *SQLiteConn) lastError() error {
403-
rv := C.sqlite3_errcode(c.db)
403+
return lastError(c.db)
404+
}
405+
406+
func lastError(db *C.sqlite3) error {
407+
rv := C.sqlite3_errcode(db)
404408
if rv == C.SQLITE_OK {
405409
return nil
406410
}
407411
return Error{
408412
Code: ErrNo(rv),
409-
ExtendedCode: ErrNoExtended(C.sqlite3_extended_errcode(c.db)),
410-
err: C.GoString(C.sqlite3_errmsg(c.db)),
413+
ExtendedCode: ErrNoExtended(C.sqlite3_extended_errcode(db)),
414+
err: C.GoString(C.sqlite3_errmsg(db)),
411415
}
412416
}
413417

@@ -537,6 +541,8 @@ func errorString(err Error) string {
537541
// _txlock=XXX
538542
// Specify locking behavior for transactions. XXX can be "immediate",
539543
// "deferred", "exclusive".
544+
// _foreign_keys=X
545+
// Enable or disable enforcement of foreign keys. X can be 1 or 0.
540546
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
541547
if C.sqlite3_threadsafe() == 0 {
542548
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
@@ -545,6 +551,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
545551
var loc *time.Location
546552
txlock := "BEGIN"
547553
busyTimeout := 5000
554+
foreignKeys := -1
548555
pos := strings.IndexRune(dsn, '?')
549556
if pos >= 1 {
550557
params, err := url.ParseQuery(dsn[pos+1:])
@@ -587,6 +594,18 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
587594
}
588595
}
589596

597+
// _foreign_keys
598+
if val := params.Get("_foreign_keys"); val != "" {
599+
switch val {
600+
case "1":
601+
foreignKeys = 1
602+
case "0":
603+
foreignKeys = 0
604+
default:
605+
return nil, fmt.Errorf("Invalid _foreign_keys: %v", val)
606+
}
607+
}
608+
590609
if !strings.HasPrefix(dsn, "file:") {
591610
dsn = dsn[:pos]
592611
}
@@ -613,6 +632,27 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
613632
return nil, Error{Code: ErrNo(rv)}
614633
}
615634

635+
exec := func(s string) error {
636+
cs := C.CString(s)
637+
rv := C.sqlite3_exec(db, cs, nil, nil, nil)
638+
C.free(unsafe.Pointer(cs))
639+
if rv != C.SQLITE_OK {
640+
return lastError(db)
641+
}
642+
return nil
643+
}
644+
if foreignKeys == 0 {
645+
if err := exec("PRAGMA foreign_keys = OFF;"); err != nil {
646+
C.sqlite3_close_v2(db)
647+
return nil, err
648+
}
649+
} else if foreignKeys == 1 {
650+
if err := exec("PRAGMA foreign_keys = ON;"); err != nil {
651+
C.sqlite3_close_v2(db)
652+
return nil, err
653+
}
654+
}
655+
616656
conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}
617657

618658
if len(d.Extensions) > 0 {

sqlite3_test.go

+29
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,35 @@ func TestReadonly(t *testing.T) {
107107
}
108108
}
109109

110+
func TestForeignKeys(t *testing.T) {
111+
cases := map[string]bool{
112+
"?_foreign_keys=1": true,
113+
"?_foreign_keys=0": false,
114+
}
115+
for option, want := range cases {
116+
fname := TempFilename(t)
117+
uri := "file:" + fname + option
118+
db, err := sql.Open("sqlite3", uri)
119+
if err != nil {
120+
os.Remove(fname)
121+
t.Errorf("sql.Open(\"sqlite3\", %q): %v", uri, err)
122+
continue
123+
}
124+
var enabled bool
125+
err = db.QueryRow("PRAGMA foreign_keys;").Scan(&enabled)
126+
db.Close()
127+
os.Remove(fname)
128+
if err != nil {
129+
t.Errorf("query foreign_keys for %s: %v", uri, err)
130+
continue
131+
}
132+
if enabled != want {
133+
t.Errorf("\"PRAGMA foreign_keys;\" for %q = %t; want %t", uri, enabled, want)
134+
continue
135+
}
136+
}
137+
}
138+
110139
func TestClose(t *testing.T) {
111140
tempFilename := TempFilename(t)
112141
defer os.Remove(tempFilename)

0 commit comments

Comments
 (0)