summaryrefslogtreecommitdiff
path: root/libgo/go/database/sql/sql.go
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/database/sql/sql.go')
-rw-r--r--libgo/go/database/sql/sql.go305
1 files changed, 241 insertions, 64 deletions
diff --git a/libgo/go/database/sql/sql.go b/libgo/go/database/sql/sql.go
index c609fe4cc43..9f4fa14534d 100644
--- a/libgo/go/database/sql/sql.go
+++ b/libgo/go/database/sql/sql.go
@@ -285,7 +285,7 @@ type Scanner interface {
// Example usage:
//
// var outArg string
-// _, err := db.ExecContext(ctx, "ProcName", sql.Named("Arg1", Out{Dest: &outArg}))
+// _, err := db.ExecContext(ctx, "ProcName", sql.Named("Arg1", sql.Out{Dest: &outArg}))
type Out struct {
_Named_Fields_Required struct{}
@@ -317,8 +317,7 @@ var ErrNoRows = errors.New("sql: no rows in result set")
// connection is returned to DB's idle connection pool. The pool size
// can be controlled with SetMaxIdleConns.
type DB struct {
- driver driver.Driver
- dsn string
+ connector driver.Connector
// numClosed is an atomic counter which represents a total number of
// closed connections. Stmt.openStmt checks it before cleaning closed
// connections in Stmt.css.
@@ -335,6 +334,7 @@ type DB struct {
// It is closed during db.Close(). The close tells the connectionOpener
// goroutine to exit.
openerCh chan struct{}
+ resetterCh chan *driverConn
closed bool
dep map[finalCloser]depSet
lastPut map[*driverConn]string // stacktrace of last conn's put; debug only
@@ -342,6 +342,8 @@ type DB struct {
maxOpen int // <= 0 means unlimited
maxLifetime time.Duration // maximum amount of time a connection may be reused
cleanerCh chan struct{}
+
+ stop func() // stop cancels the connection opener and the session resetter.
}
// connReuseStrategy determines how (*DB).conn returns database connections.
@@ -369,6 +371,7 @@ type driverConn struct {
closed bool
finalClosed bool // ci.Close has been called
openStmt map[*driverStmt]bool
+ lastErr error // lastError captures the result of the session resetter.
// guarded by db.mu
inUse bool
@@ -377,7 +380,7 @@ type driverConn struct {
}
func (dc *driverConn) releaseConn(err error) {
- dc.db.putConn(dc, err)
+ dc.db.putConn(dc, err, true)
}
func (dc *driverConn) removeOpenStmt(ds *driverStmt) {
@@ -418,6 +421,19 @@ func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, que
return ds, nil
}
+// resetSession resets the connection session and sets the lastErr
+// that is checked before returning the connection to another query.
+//
+// resetSession assumes that the embedded mutex is locked when the connection
+// was returned to the pool. This unlocks the mutex.
+func (dc *driverConn) resetSession(ctx context.Context) {
+ defer dc.Unlock() // In case of panic.
+ if dc.closed { // Check if the database has been closed.
+ return
+ }
+ dc.lastErr = dc.ci.(driver.SessionResetter).ResetSession(ctx)
+}
+
// the dc.db's Mutex is held.
func (dc *driverConn) closeDBLocked() func() error {
dc.Lock()
@@ -575,6 +591,52 @@ func (db *DB) removeDepLocked(x finalCloser, dep interface{}) func() error {
// to block until the connectionOpener can satisfy the backlog of requests.
var connectionRequestQueueSize = 1000000
+type dsnConnector struct {
+ dsn string
+ driver driver.Driver
+}
+
+func (t dsnConnector) Connect(_ context.Context) (driver.Conn, error) {
+ return t.driver.Open(t.dsn)
+}
+
+func (t dsnConnector) Driver() driver.Driver {
+ return t.driver
+}
+
+// OpenDB opens a database using a Connector, allowing drivers to
+// bypass a string based data source name.
+//
+// Most users will open a database via a driver-specific connection
+// helper function that returns a *DB. No database drivers are included
+// in the Go standard library. See https://golang.org/s/sqldrivers for
+// a list of third-party drivers.
+//
+// OpenDB may just validate its arguments without creating a connection
+// to the database. To verify that the data source name is valid, call
+// Ping.
+//
+// The returned DB is safe for concurrent use by multiple goroutines
+// and maintains its own pool of idle connections. Thus, the OpenDB
+// function should be called just once. It is rarely necessary to
+// close a DB.
+func OpenDB(c driver.Connector) *DB {
+ ctx, cancel := context.WithCancel(context.Background())
+ db := &DB{
+ connector: c,
+ openerCh: make(chan struct{}, connectionRequestQueueSize),
+ resetterCh: make(chan *driverConn, 50),
+ lastPut: make(map[*driverConn]string),
+ connRequests: make(map[uint64]chan connRequest),
+ stop: cancel,
+ }
+
+ go db.connectionOpener(ctx)
+ go db.connectionResetter(ctx)
+
+ return db
+}
+
// Open opens a database specified by its database driver name and a
// driver-specific data source name, usually consisting of at least a
// database name and connection information.
@@ -599,15 +661,16 @@ func Open(driverName, dataSourceName string) (*DB, error) {
if !ok {
return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
}
- db := &DB{
- driver: driveri,
- dsn: dataSourceName,
- openerCh: make(chan struct{}, connectionRequestQueueSize),
- lastPut: make(map[*driverConn]string),
- connRequests: make(map[uint64]chan connRequest),
+
+ if driverCtx, ok := driveri.(driver.DriverContext); ok {
+ connector, err := driverCtx.OpenConnector(dataSourceName)
+ if err != nil {
+ return nil, err
+ }
+ return OpenDB(connector), nil
}
- go db.connectionOpener()
- return db, nil
+
+ return OpenDB(dsnConnector{dsn: dataSourceName, driver: driveri}), nil
}
func (db *DB) pingDC(ctx context.Context, dc *driverConn, release func(error)) error {
@@ -659,7 +722,6 @@ func (db *DB) Close() error {
db.mu.Unlock()
return nil
}
- close(db.openerCh)
if db.cleanerCh != nil {
close(db.cleanerCh)
}
@@ -680,6 +742,7 @@ func (db *DB) Close() error {
err = err1
}
}
+ db.stop()
return err
}
@@ -867,18 +930,40 @@ func (db *DB) maybeOpenNewConnections() {
}
// Runs in a separate goroutine, opens new connections when requested.
-func (db *DB) connectionOpener() {
- for range db.openerCh {
- db.openNewConnection()
+func (db *DB) connectionOpener(ctx context.Context) {
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-db.openerCh:
+ db.openNewConnection(ctx)
+ }
+ }
+}
+
+// connectionResetter runs in a separate goroutine to reset connections async
+// to exported API.
+func (db *DB) connectionResetter(ctx context.Context) {
+ for {
+ select {
+ case <-ctx.Done():
+ close(db.resetterCh)
+ for dc := range db.resetterCh {
+ dc.Unlock()
+ }
+ return
+ case dc := <-db.resetterCh:
+ dc.resetSession(ctx)
+ }
}
}
// Open one new connection
-func (db *DB) openNewConnection() {
+func (db *DB) openNewConnection(ctx context.Context) {
// maybeOpenNewConnctions has already executed db.numOpen++ before it sent
// on db.openerCh. This function must execute db.numOpen-- if the
// connection fails or is closed before returning.
- ci, err := db.driver.Open(db.dsn)
+ ci, err := db.connector.Connect(ctx)
db.mu.Lock()
defer db.mu.Unlock()
if db.closed {
@@ -953,6 +1038,14 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
conn.Close()
return nil, driver.ErrBadConn
}
+ // Lock around reading lastErr to ensure the session resetter finished.
+ conn.Lock()
+ err := conn.lastErr
+ conn.Unlock()
+ if err == driver.ErrBadConn {
+ conn.Close()
+ return nil, driver.ErrBadConn
+ }
return conn, nil
}
@@ -978,7 +1071,7 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
default:
case ret, ok := <-req:
if ok {
- db.putConn(ret.conn, ret.err)
+ db.putConn(ret.conn, ret.err, false)
}
}
return nil, ctx.Err()
@@ -990,13 +1083,24 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
ret.conn.Close()
return nil, driver.ErrBadConn
}
+ if ret.conn == nil {
+ return nil, ret.err
+ }
+ // Lock around reading lastErr to ensure the session resetter finished.
+ ret.conn.Lock()
+ err := ret.conn.lastErr
+ ret.conn.Unlock()
+ if err == driver.ErrBadConn {
+ ret.conn.Close()
+ return nil, driver.ErrBadConn
+ }
return ret.conn, ret.err
}
}
db.numOpen++ // optimistically
db.mu.Unlock()
- ci, err := db.driver.Open(db.dsn)
+ ci, err := db.connector.Connect(ctx)
if err != nil {
db.mu.Lock()
db.numOpen-- // correct for earlier optimism
@@ -1045,7 +1149,7 @@ const debugGetPut = false
// putConn adds a connection to the db's free pool.
// err is optionally the last error that occurred on this connection.
-func (db *DB) putConn(dc *driverConn, err error) {
+func (db *DB) putConn(dc *driverConn, err error, resetSession bool) {
db.mu.Lock()
if !dc.inUse {
if debugGetPut {
@@ -1076,11 +1180,40 @@ func (db *DB) putConn(dc *driverConn, err error) {
if putConnHook != nil {
putConnHook(db, dc)
}
+ if db.closed {
+ // Connections do not need to be reset if they will be closed.
+ // Prevents writing to resetterCh after the DB has closed.
+ resetSession = false
+ }
+ if resetSession {
+ if _, resetSession = dc.ci.(driver.SessionResetter); resetSession {
+ // Lock the driverConn here so it isn't released until
+ // the connection is reset.
+ // The lock must be taken before the connection is put into
+ // the pool to prevent it from being taken out before it is reset.
+ dc.Lock()
+ }
+ }
added := db.putConnDBLocked(dc, nil)
db.mu.Unlock()
if !added {
+ if resetSession {
+ dc.Unlock()
+ }
dc.Close()
+ return
+ }
+ if !resetSession {
+ return
+ }
+ select {
+ default:
+ // If the resetterCh is blocking then mark the connection
+ // as bad and continue on.
+ dc.lastErr = driver.ErrBadConn
+ dc.Unlock()
+ case db.resetterCh <- dc:
}
}
@@ -1242,15 +1375,20 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q
defer func() {
release(err)
}()
- if execer, ok := dc.ci.(driver.Execer); ok {
- var dargs []driver.NamedValue
- dargs, err = driverArgs(dc.ci, nil, args)
- if err != nil {
- return nil, err
- }
+ execerCtx, ok := dc.ci.(driver.ExecerContext)
+ var execer driver.Execer
+ if !ok {
+ execer, ok = dc.ci.(driver.Execer)
+ }
+ if ok {
+ var nvdargs []driver.NamedValue
var resi driver.Result
withLock(dc, func() {
- resi, err = ctxDriverExec(ctx, execer, query, dargs)
+ nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
+ if err != nil {
+ return
+ }
+ resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs)
})
if err != driver.ErrSkip {
if err != nil {
@@ -1309,15 +1447,21 @@ func (db *DB) query(ctx context.Context, query string, args []interface{}, strat
// The ctx context is from a query method and the txctx context is from an
// optional transaction context.
func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
- if queryer, ok := dc.ci.(driver.Queryer); ok {
- dargs, err := driverArgs(dc.ci, nil, args)
- if err != nil {
- releaseConn(err)
- return nil, err
- }
+ queryerCtx, ok := dc.ci.(driver.QueryerContext)
+ var queryer driver.Queryer
+ if !ok {
+ queryer, ok = dc.ci.(driver.Queryer)
+ }
+ if ok {
+ var nvdargs []driver.NamedValue
var rowsi driver.Rows
+ var err error
withLock(dc, func() {
- rowsi, err = ctxDriverQuery(ctx, queryer, query, dargs)
+ nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
+ if err != nil {
+ return
+ }
+ rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs)
})
if err != driver.ErrSkip {
if err != nil {
@@ -1454,11 +1598,11 @@ func (db *DB) beginDC(ctx context.Context, dc *driverConn, release func(error),
// Driver returns the database's underlying driver.
func (db *DB) Driver() driver.Driver {
- return db.driver
+ return db.connector.Driver()
}
// ErrConnDone is returned by any operation that is performed on a connection
-// that has already been committed or rolled back.
+// that has already been returned to the connection pool.
var ErrConnDone = errors.New("database/sql: connection is already closed")
// Conn returns a single connection by either opening a new connection
@@ -1493,9 +1637,9 @@ func (db *DB) Conn(ctx context.Context) (*Conn, error) {
type releaseConn func(error)
-// Conn represents a single database session rather a pool of database
-// sessions. Prefer running queries from DB unless there is a specific
-// need for a continuous single database session.
+// Conn represents a single database connection rather than a pool of database
+// connections. Prefer running queries from DB unless there is a specific
+// need for a continuous single database connection.
//
// A Conn must call Close to return the connection to the database pool
// and may do so concurrently with a running query.
@@ -1769,14 +1913,20 @@ func (tx *Tx) closePrepared() {
// Commit commits the transaction.
func (tx *Tx) Commit() error {
- if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
- return ErrTxDone
- }
+ // Check context first to avoid transaction leak.
+ // If put it behind tx.done CompareAndSwap statement, we cant't ensure
+ // the consistency between tx.done and the real COMMIT operation.
select {
default:
case <-tx.ctx.Done():
+ if atomic.LoadInt32(&tx.done) == 1 {
+ return ErrTxDone
+ }
return tx.ctx.Err()
}
+ if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
+ return ErrTxDone
+ }
var err error
withLock(tx.dc, func() {
err = tx.txi.Commit()
@@ -1859,6 +2009,9 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
// ...
// res, err := tx.StmtContext(ctx, updateMoney).Exec(123.45, 98293203)
//
+// The provided context is used for the preparation of the statement, not for the
+// execution of the statement.
+//
// The returned statement operates within the transaction and will be closed
// when the transaction has been committed or rolled back.
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
@@ -1902,11 +2055,14 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
stmt.mu.Unlock()
if si == nil {
- cs, err := stmt.prepareOnConnLocked(ctx, dc)
+ withLock(dc, func() {
+ var ds *driverStmt
+ ds, err = stmt.prepareOnConnLocked(ctx, dc)
+ si = ds.si
+ })
if err != nil {
return &Stmt{stickyErr: err}
}
- si = cs.si
}
parentStmt = stmt
}
@@ -2098,13 +2254,20 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
}
func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (Result, error) {
- dargs, err := driverArgs(ci, ds, args)
+ ds.Lock()
+ defer ds.Unlock()
+
+ dargs, err := driverArgsConnLocked(ci, ds, args)
if err != nil {
return nil, err
}
- ds.Lock()
- defer ds.Unlock()
+ // -1 means the driver doesn't know how to count the number of
+ // placeholders, so we won't sanity check input here and instead let the
+ // driver deal with errors.
+ if want := ds.si.NumInput(); want >= 0 && want != len(dargs) {
+ return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(dargs))
+ }
resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
if err != nil {
@@ -2269,25 +2432,20 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
}
func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (driver.Rows, error) {
- var want int
- withLock(ds, func() {
- want = ds.si.NumInput()
- })
-
- // -1 means the driver doesn't know how to count the number of
- // placeholders, so we won't sanity check input here and instead let the
- // driver deal with errors.
- if want != -1 && len(args) != want {
- return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args))
- }
+ ds.Lock()
+ defer ds.Unlock()
- dargs, err := driverArgs(ci, ds, args)
+ dargs, err := driverArgsConnLocked(ci, ds, args)
if err != nil {
return nil, err
}
- ds.Lock()
- defer ds.Unlock()
+ // -1 means the driver doesn't know how to count the number of
+ // placeholders, so we won't sanity check input here and instead let the
+ // driver deal with errors.
+ if want := ds.si.NumInput(); want >= 0 && want != len(dargs) {
+ return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(dargs))
+ }
rowsi, err := ctxDriverStmtQuery(ctx, ds.si, dargs)
if err != nil {
@@ -2451,9 +2609,16 @@ func (rs *Rows) nextLocked() (doClose, ok bool) {
if rs.closed {
return false, false
}
+
+ // Lock the driver connection before calling the driver interface
+ // rowsi to prevent a Tx from rolling back the connection at the same time.
+ rs.dc.Lock()
+ defer rs.dc.Unlock()
+
if rs.lastcols == nil {
rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
}
+
rs.lasterr = rs.rowsi.Next(rs.lastcols)
if rs.lasterr != nil {
// Close the connection if there is a driver error.
@@ -2503,6 +2668,12 @@ func (rs *Rows) NextResultSet() bool {
doClose = true
return false
}
+
+ // Lock the driver connection before calling the driver interface
+ // rowsi to prevent a Tx from rolling back the connection at the same time.
+ rs.dc.Lock()
+ defer rs.dc.Unlock()
+
rs.lasterr = nextResultSet.NextResultSet()
if rs.lasterr != nil {
doClose = true
@@ -2534,6 +2705,9 @@ func (rs *Rows) Columns() ([]string, error) {
if rs.rowsi == nil {
return nil, errors.New("sql: no Rows available")
}
+ rs.dc.Lock()
+ defer rs.dc.Unlock()
+
return rs.rowsi.Columns(), nil
}
@@ -2548,7 +2722,10 @@ func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
if rs.rowsi == nil {
return nil, errors.New("sql: no Rows available")
}
- return rowsColumnInfoSetup(rs.rowsi), nil
+ rs.dc.Lock()
+ defer rs.dc.Unlock()
+
+ return rowsColumnInfoSetupConnLocked(rs.rowsi), nil
}
// ColumnType contains the name and type of a column.
@@ -2609,7 +2786,7 @@ func (ci *ColumnType) DatabaseTypeName() string {
return ci.databaseType
}
-func rowsColumnInfoSetup(rowsi driver.Rows) []*ColumnType {
+func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
names := rowsi.Columns()
list := make([]*ColumnType, len(names))