Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BeforeConnect Hook #1875

Open
wants to merge 6 commits into
base: v10
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions base.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,21 +102,30 @@ func (db *baseDB) initConn(ctx context.Context, cn *pool.Conn) error {
}
cn.Inited = true

if db.opt.TLSConfig != nil {
err := db.enableSSL(ctx, cn, db.opt.TLSConfig)
opt := db.opt.clone()

if opt.TLSConfig != nil {
err := db.enableSSL(ctx, cn, opt.TLSConfig)
if err != nil {
return err
}
}

err := db.startup(ctx, cn, db.opt.User, db.opt.Password, db.opt.Database, db.opt.ApplicationName)

if opt.BeforeConnect != nil {
if err := opt.BeforeConnect(ctx, opt); err != nil {
return err
}
}

err := db.startup(ctx, cn, opt.User, opt.Password, opt.Database, opt.ApplicationName)
if err != nil {
return err
}

if db.opt.OnConnect != nil {
if opt.OnConnect != nil {
p := pool.NewSingleConnPool(db.pool, cn)
return db.opt.OnConnect(ctx, newConn(ctx, db.withPool(p)))
return opt.OnConnect(ctx, newConn(ctx, db.withPool(p)))
}

return nil
Expand Down
24 changes: 24 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,30 @@ func TestDBConnectWithStartupNotice(t *testing.T) {
require.NoError(t, db.Ping(context.Background()), "must successfully ping database with long application name")
}

func TestBeforeConnect(t *testing.T) {
Copy link
Collaborator

@elliotcourant elliotcourant Apr 21, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add another test that tests the before connect with multiple go routines to make sure that a race condition is not triggered by accident. A race condition in a before connection hook for an ORM could cripple an application.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a good way to do this would be to set the options to have a max pool of 1. And then try to send multiple queries concurrently in a few go routines.

This should trigger a concurrent call of the BeforeConnect method; and given the -race flag should be a sufficient enough smoke test for most use cases.

pwd := "dynamic-passwords-from-xkcd"
opt := pgOptions()
opt.BeforeConnect = func(ctx context.Context, o *pg.Options) error {
o.Password = pwd
return nil
}

db := pg.Connect(opt)
defer db.Close()

var val int
_, err := db.QueryOne(pg.Scan(&val), "SELECT 1")
if err != nil {
t.Fatal(err)
}
if val != 1 {
t.Fatalf(`got %q, wanted 1`, val)
}
if pwd != opt.Password {
t.Fatalf(`got %s, wanted %s`, opt.Password, pwd)
}
}

func TestOnConnect(t *testing.T) {
opt := pgOptions()
opt.OnConnect = func(ctx context.Context, db *pg.Conn) error {
Expand Down
44 changes: 44 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,20 @@ import (
"runtime"
"strconv"
"strings"
"sync"
"time"

"github.com/go-pg/pg/v10/internal/pool"
)

type BeforeConnectOptions struct {
User string
Password string

// TLS config for secure connections.
TLSConfig *tls.Config
}

// Options contains database connection options.
type Options struct {
// Network type, either tcp or unix.
Expand All @@ -28,6 +37,11 @@ type Options struct {
// Network and Addr options.
Dialer func(ctx context.Context, network, addr string) (net.Conn, error)

// BeforeConnect is a hook which is called before a new connection is
// established. Useful for scenarios where dynamic passwords are used.
// Timeout & Retry values set in this hook are not ignored
BeforeConnect func(ctx context.Context, o *Options) error

// Hook that is called after new connection is established
// and user is authenticated.
OnConnect func(ctx context.Context, cn *Conn) error
Expand Down Expand Up @@ -90,6 +104,36 @@ type Options struct {
// but idle connections are still discarded by the client
// if IdleTimeout is set.
IdleCheckFrequency time.Duration

mux sync.Mutex
}

func (opt *Options) clone() *Options {
return &Options{
Network: opt.Network,
Addr: opt.Addr,
Dialer: opt.Dialer,
BeforeConnect: opt.BeforeConnect,
OnConnect: opt.OnConnect,
User: opt.User,
Password: opt.Password,
Database: opt.Database,
ApplicationName: opt.ApplicationName,
TLSConfig: opt.TLSConfig.Clone(),
DialTimeout: opt.DialTimeout,
ReadTimeout: opt.ReadTimeout,
WriteTimeout: opt.WriteTimeout,
MaxRetries: opt.MaxRetries,
RetryStatementTimeout: opt.RetryStatementTimeout,
MinRetryBackoff: opt.MinRetryBackoff,
MaxRetryBackoff: opt.MaxRetryBackoff,
PoolSize: opt.PoolSize,
MinIdleConns: opt.MinIdleConns,
MaxConnAge: opt.MaxConnAge,
PoolTimeout: opt.PoolTimeout,
IdleTimeout: opt.IdleTimeout,
IdleCheckFrequency: opt.IdleCheckFrequency,
Comment on lines +113 to +135
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do these need to be go fmted?

}
}

func (opt *Options) init() {
Expand Down