diff --git a/base.go b/base.go index d1399746..88c9a7fe 100644 --- a/base.go +++ b/base.go @@ -102,21 +102,30 @@ func (db *baseDB) initConn(ctx context.Context, cn *pool.Conn) error { } cn.Inited = true - if db.opt.TLSConfig != nil { - err := db.enableSSL(ctx, cn, db.opt.TLSConfig) + opt := db.opt.clone() + + if opt.TLSConfig != nil { + err := db.enableSSL(ctx, cn, opt.TLSConfig) if err != nil { return err } } - err := db.startup(ctx, cn, db.opt.User, db.opt.Password, db.opt.Database, db.opt.ApplicationName) + + if opt.BeforeConnect != nil { + if err := opt.BeforeConnect(ctx, opt); err != nil { + return err + } + } + + err := db.startup(ctx, cn, opt.User, opt.Password, opt.Database, opt.ApplicationName) if err != nil { return err } - if db.opt.OnConnect != nil { + if opt.OnConnect != nil { p := pool.NewSingleConnPool(db.pool, cn) - return db.opt.OnConnect(ctx, newConn(ctx, db.withPool(p))) + return opt.OnConnect(ctx, newConn(ctx, db.withPool(p))) } return nil diff --git a/db_test.go b/db_test.go index 679340c4..93e72a3a 100644 --- a/db_test.go +++ b/db_test.go @@ -110,6 +110,30 @@ func TestDBConnectWithStartupNotice(t *testing.T) { require.NoError(t, db.Ping(context.Background()), "must successfully ping database with long application name") } +func TestBeforeConnect(t *testing.T) { + pwd := "dynamic-passwords-from-xkcd" + opt := pgOptions() + opt.BeforeConnect = func(ctx context.Context, o *pg.Options) error { + o.Password = pwd + return nil + } + + db := pg.Connect(opt) + defer db.Close() + + var val int + _, err := db.QueryOne(pg.Scan(&val), "SELECT 1") + if err != nil { + t.Fatal(err) + } + if val != 1 { + t.Fatalf(`got %q, wanted 1`, val) + } + if pwd != opt.Password { + t.Fatalf(`got %s, wanted %s`, opt.Password, pwd) + } +} + func TestOnConnect(t *testing.T) { opt := pgOptions() opt.OnConnect = func(ctx context.Context, db *pg.Conn) error { diff --git a/options.go b/options.go index efd634fd..34bcd020 100644 --- a/options.go +++ b/options.go @@ -11,11 +11,20 @@ import ( "runtime" "strconv" "strings" + "sync" "time" "github.com/go-pg/pg/v10/internal/pool" ) +type BeforeConnectOptions struct { + User string + Password string + + // TLS config for secure connections. + TLSConfig *tls.Config +} + // Options contains database connection options. type Options struct { // Network type, either tcp or unix. @@ -28,6 +37,11 @@ type Options struct { // Network and Addr options. Dialer func(ctx context.Context, network, addr string) (net.Conn, error) + // BeforeConnect is a hook which is called before a new connection is + // established. Useful for scenarios where dynamic passwords are used. + // Timeout & Retry values set in this hook are not ignored + BeforeConnect func(ctx context.Context, o *Options) error + // Hook that is called after new connection is established // and user is authenticated. OnConnect func(ctx context.Context, cn *Conn) error @@ -90,6 +104,36 @@ type Options struct { // but idle connections are still discarded by the client // if IdleTimeout is set. IdleCheckFrequency time.Duration + + mux sync.Mutex +} + +func (opt *Options) clone() *Options { + return &Options{ + Network: opt.Network, + Addr: opt.Addr, + Dialer: opt.Dialer, + BeforeConnect: opt.BeforeConnect, + OnConnect: opt.OnConnect, + User: opt.User, + Password: opt.Password, + Database: opt.Database, + ApplicationName: opt.ApplicationName, + TLSConfig: opt.TLSConfig.Clone(), + DialTimeout: opt.DialTimeout, + ReadTimeout: opt.ReadTimeout, + WriteTimeout: opt.WriteTimeout, + MaxRetries: opt.MaxRetries, + RetryStatementTimeout: opt.RetryStatementTimeout, + MinRetryBackoff: opt.MinRetryBackoff, + MaxRetryBackoff: opt.MaxRetryBackoff, + PoolSize: opt.PoolSize, + MinIdleConns: opt.MinIdleConns, + MaxConnAge: opt.MaxConnAge, + PoolTimeout: opt.PoolTimeout, + IdleTimeout: opt.IdleTimeout, + IdleCheckFrequency: opt.IdleCheckFrequency, + } } func (opt *Options) init() {