diff options
Diffstat (limited to 'libgo/go/database/sql/sql.go')
-rw-r--r-- | libgo/go/database/sql/sql.go | 305 |
1 files changed, 241 insertions, 64 deletions
diff --git a/libgo/go/database/sql/sql.go b/libgo/go/database/sql/sql.go index c609fe4cc43..9f4fa14534d 100644 --- a/libgo/go/database/sql/sql.go +++ b/libgo/go/database/sql/sql.go @@ -285,7 +285,7 @@ type Scanner interface { // Example usage: // // var outArg string -// _, err := db.ExecContext(ctx, "ProcName", sql.Named("Arg1", Out{Dest: &outArg})) +// _, err := db.ExecContext(ctx, "ProcName", sql.Named("Arg1", sql.Out{Dest: &outArg})) type Out struct { _Named_Fields_Required struct{} @@ -317,8 +317,7 @@ var ErrNoRows = errors.New("sql: no rows in result set") // connection is returned to DB's idle connection pool. The pool size // can be controlled with SetMaxIdleConns. type DB struct { - driver driver.Driver - dsn string + connector driver.Connector // numClosed is an atomic counter which represents a total number of // closed connections. Stmt.openStmt checks it before cleaning closed // connections in Stmt.css. @@ -335,6 +334,7 @@ type DB struct { // It is closed during db.Close(). The close tells the connectionOpener // goroutine to exit. openerCh chan struct{} + resetterCh chan *driverConn closed bool dep map[finalCloser]depSet lastPut map[*driverConn]string // stacktrace of last conn's put; debug only @@ -342,6 +342,8 @@ type DB struct { maxOpen int // <= 0 means unlimited maxLifetime time.Duration // maximum amount of time a connection may be reused cleanerCh chan struct{} + + stop func() // stop cancels the connection opener and the session resetter. } // connReuseStrategy determines how (*DB).conn returns database connections. @@ -369,6 +371,7 @@ type driverConn struct { closed bool finalClosed bool // ci.Close has been called openStmt map[*driverStmt]bool + lastErr error // lastError captures the result of the session resetter. // guarded by db.mu inUse bool @@ -377,7 +380,7 @@ type driverConn struct { } func (dc *driverConn) releaseConn(err error) { - dc.db.putConn(dc, err) + dc.db.putConn(dc, err, true) } func (dc *driverConn) removeOpenStmt(ds *driverStmt) { @@ -418,6 +421,19 @@ func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, que return ds, nil } +// resetSession resets the connection session and sets the lastErr +// that is checked before returning the connection to another query. +// +// resetSession assumes that the embedded mutex is locked when the connection +// was returned to the pool. This unlocks the mutex. +func (dc *driverConn) resetSession(ctx context.Context) { + defer dc.Unlock() // In case of panic. + if dc.closed { // Check if the database has been closed. + return + } + dc.lastErr = dc.ci.(driver.SessionResetter).ResetSession(ctx) +} + // the dc.db's Mutex is held. func (dc *driverConn) closeDBLocked() func() error { dc.Lock() @@ -575,6 +591,52 @@ func (db *DB) removeDepLocked(x finalCloser, dep interface{}) func() error { // to block until the connectionOpener can satisfy the backlog of requests. var connectionRequestQueueSize = 1000000 +type dsnConnector struct { + dsn string + driver driver.Driver +} + +func (t dsnConnector) Connect(_ context.Context) (driver.Conn, error) { + return t.driver.Open(t.dsn) +} + +func (t dsnConnector) Driver() driver.Driver { + return t.driver +} + +// OpenDB opens a database using a Connector, allowing drivers to +// bypass a string based data source name. +// +// Most users will open a database via a driver-specific connection +// helper function that returns a *DB. No database drivers are included +// in the Go standard library. See https://golang.org/s/sqldrivers for +// a list of third-party drivers. +// +// OpenDB may just validate its arguments without creating a connection +// to the database. To verify that the data source name is valid, call +// Ping. +// +// The returned DB is safe for concurrent use by multiple goroutines +// and maintains its own pool of idle connections. Thus, the OpenDB +// function should be called just once. It is rarely necessary to +// close a DB. +func OpenDB(c driver.Connector) *DB { + ctx, cancel := context.WithCancel(context.Background()) + db := &DB{ + connector: c, + openerCh: make(chan struct{}, connectionRequestQueueSize), + resetterCh: make(chan *driverConn, 50), + lastPut: make(map[*driverConn]string), + connRequests: make(map[uint64]chan connRequest), + stop: cancel, + } + + go db.connectionOpener(ctx) + go db.connectionResetter(ctx) + + return db +} + // Open opens a database specified by its database driver name and a // driver-specific data source name, usually consisting of at least a // database name and connection information. @@ -599,15 +661,16 @@ func Open(driverName, dataSourceName string) (*DB, error) { if !ok { return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName) } - db := &DB{ - driver: driveri, - dsn: dataSourceName, - openerCh: make(chan struct{}, connectionRequestQueueSize), - lastPut: make(map[*driverConn]string), - connRequests: make(map[uint64]chan connRequest), + + if driverCtx, ok := driveri.(driver.DriverContext); ok { + connector, err := driverCtx.OpenConnector(dataSourceName) + if err != nil { + return nil, err + } + return OpenDB(connector), nil } - go db.connectionOpener() - return db, nil + + return OpenDB(dsnConnector{dsn: dataSourceName, driver: driveri}), nil } func (db *DB) pingDC(ctx context.Context, dc *driverConn, release func(error)) error { @@ -659,7 +722,6 @@ func (db *DB) Close() error { db.mu.Unlock() return nil } - close(db.openerCh) if db.cleanerCh != nil { close(db.cleanerCh) } @@ -680,6 +742,7 @@ func (db *DB) Close() error { err = err1 } } + db.stop() return err } @@ -867,18 +930,40 @@ func (db *DB) maybeOpenNewConnections() { } // Runs in a separate goroutine, opens new connections when requested. -func (db *DB) connectionOpener() { - for range db.openerCh { - db.openNewConnection() +func (db *DB) connectionOpener(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-db.openerCh: + db.openNewConnection(ctx) + } + } +} + +// connectionResetter runs in a separate goroutine to reset connections async +// to exported API. +func (db *DB) connectionResetter(ctx context.Context) { + for { + select { + case <-ctx.Done(): + close(db.resetterCh) + for dc := range db.resetterCh { + dc.Unlock() + } + return + case dc := <-db.resetterCh: + dc.resetSession(ctx) + } } } // Open one new connection -func (db *DB) openNewConnection() { +func (db *DB) openNewConnection(ctx context.Context) { // maybeOpenNewConnctions has already executed db.numOpen++ before it sent // on db.openerCh. This function must execute db.numOpen-- if the // connection fails or is closed before returning. - ci, err := db.driver.Open(db.dsn) + ci, err := db.connector.Connect(ctx) db.mu.Lock() defer db.mu.Unlock() if db.closed { @@ -953,6 +1038,14 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn conn.Close() return nil, driver.ErrBadConn } + // Lock around reading lastErr to ensure the session resetter finished. + conn.Lock() + err := conn.lastErr + conn.Unlock() + if err == driver.ErrBadConn { + conn.Close() + return nil, driver.ErrBadConn + } return conn, nil } @@ -978,7 +1071,7 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn default: case ret, ok := <-req: if ok { - db.putConn(ret.conn, ret.err) + db.putConn(ret.conn, ret.err, false) } } return nil, ctx.Err() @@ -990,13 +1083,24 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn ret.conn.Close() return nil, driver.ErrBadConn } + if ret.conn == nil { + return nil, ret.err + } + // Lock around reading lastErr to ensure the session resetter finished. + ret.conn.Lock() + err := ret.conn.lastErr + ret.conn.Unlock() + if err == driver.ErrBadConn { + ret.conn.Close() + return nil, driver.ErrBadConn + } return ret.conn, ret.err } } db.numOpen++ // optimistically db.mu.Unlock() - ci, err := db.driver.Open(db.dsn) + ci, err := db.connector.Connect(ctx) if err != nil { db.mu.Lock() db.numOpen-- // correct for earlier optimism @@ -1045,7 +1149,7 @@ const debugGetPut = false // putConn adds a connection to the db's free pool. // err is optionally the last error that occurred on this connection. -func (db *DB) putConn(dc *driverConn, err error) { +func (db *DB) putConn(dc *driverConn, err error, resetSession bool) { db.mu.Lock() if !dc.inUse { if debugGetPut { @@ -1076,11 +1180,40 @@ func (db *DB) putConn(dc *driverConn, err error) { if putConnHook != nil { putConnHook(db, dc) } + if db.closed { + // Connections do not need to be reset if they will be closed. + // Prevents writing to resetterCh after the DB has closed. + resetSession = false + } + if resetSession { + if _, resetSession = dc.ci.(driver.SessionResetter); resetSession { + // Lock the driverConn here so it isn't released until + // the connection is reset. + // The lock must be taken before the connection is put into + // the pool to prevent it from being taken out before it is reset. + dc.Lock() + } + } added := db.putConnDBLocked(dc, nil) db.mu.Unlock() if !added { + if resetSession { + dc.Unlock() + } dc.Close() + return + } + if !resetSession { + return + } + select { + default: + // If the resetterCh is blocking then mark the connection + // as bad and continue on. + dc.lastErr = driver.ErrBadConn + dc.Unlock() + case db.resetterCh <- dc: } } @@ -1242,15 +1375,20 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q defer func() { release(err) }() - if execer, ok := dc.ci.(driver.Execer); ok { - var dargs []driver.NamedValue - dargs, err = driverArgs(dc.ci, nil, args) - if err != nil { - return nil, err - } + execerCtx, ok := dc.ci.(driver.ExecerContext) + var execer driver.Execer + if !ok { + execer, ok = dc.ci.(driver.Execer) + } + if ok { + var nvdargs []driver.NamedValue var resi driver.Result withLock(dc, func() { - resi, err = ctxDriverExec(ctx, execer, query, dargs) + nvdargs, err = driverArgsConnLocked(dc.ci, nil, args) + if err != nil { + return + } + resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs) }) if err != driver.ErrSkip { if err != nil { @@ -1309,15 +1447,21 @@ func (db *DB) query(ctx context.Context, query string, args []interface{}, strat // The ctx context is from a query method and the txctx context is from an // optional transaction context. func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) { - if queryer, ok := dc.ci.(driver.Queryer); ok { - dargs, err := driverArgs(dc.ci, nil, args) - if err != nil { - releaseConn(err) - return nil, err - } + queryerCtx, ok := dc.ci.(driver.QueryerContext) + var queryer driver.Queryer + if !ok { + queryer, ok = dc.ci.(driver.Queryer) + } + if ok { + var nvdargs []driver.NamedValue var rowsi driver.Rows + var err error withLock(dc, func() { - rowsi, err = ctxDriverQuery(ctx, queryer, query, dargs) + nvdargs, err = driverArgsConnLocked(dc.ci, nil, args) + if err != nil { + return + } + rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs) }) if err != driver.ErrSkip { if err != nil { @@ -1454,11 +1598,11 @@ func (db *DB) beginDC(ctx context.Context, dc *driverConn, release func(error), // Driver returns the database's underlying driver. func (db *DB) Driver() driver.Driver { - return db.driver + return db.connector.Driver() } // ErrConnDone is returned by any operation that is performed on a connection -// that has already been committed or rolled back. +// that has already been returned to the connection pool. var ErrConnDone = errors.New("database/sql: connection is already closed") // Conn returns a single connection by either opening a new connection @@ -1493,9 +1637,9 @@ func (db *DB) Conn(ctx context.Context) (*Conn, error) { type releaseConn func(error) -// Conn represents a single database session rather a pool of database -// sessions. Prefer running queries from DB unless there is a specific -// need for a continuous single database session. +// Conn represents a single database connection rather than a pool of database +// connections. Prefer running queries from DB unless there is a specific +// need for a continuous single database connection. // // A Conn must call Close to return the connection to the database pool // and may do so concurrently with a running query. @@ -1769,14 +1913,20 @@ func (tx *Tx) closePrepared() { // Commit commits the transaction. func (tx *Tx) Commit() error { - if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) { - return ErrTxDone - } + // Check context first to avoid transaction leak. + // If put it behind tx.done CompareAndSwap statement, we cant't ensure + // the consistency between tx.done and the real COMMIT operation. select { default: case <-tx.ctx.Done(): + if atomic.LoadInt32(&tx.done) == 1 { + return ErrTxDone + } return tx.ctx.Err() } + if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) { + return ErrTxDone + } var err error withLock(tx.dc, func() { err = tx.txi.Commit() @@ -1859,6 +2009,9 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { // ... // res, err := tx.StmtContext(ctx, updateMoney).Exec(123.45, 98293203) // +// The provided context is used for the preparation of the statement, not for the +// execution of the statement. +// // The returned statement operates within the transaction and will be closed // when the transaction has been committed or rolled back. func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { @@ -1902,11 +2055,14 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { stmt.mu.Unlock() if si == nil { - cs, err := stmt.prepareOnConnLocked(ctx, dc) + withLock(dc, func() { + var ds *driverStmt + ds, err = stmt.prepareOnConnLocked(ctx, dc) + si = ds.si + }) if err != nil { return &Stmt{stickyErr: err} } - si = cs.si } parentStmt = stmt } @@ -2098,13 +2254,20 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { } func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (Result, error) { - dargs, err := driverArgs(ci, ds, args) + ds.Lock() + defer ds.Unlock() + + dargs, err := driverArgsConnLocked(ci, ds, args) if err != nil { return nil, err } - ds.Lock() - defer ds.Unlock() + // -1 means the driver doesn't know how to count the number of + // placeholders, so we won't sanity check input here and instead let the + // driver deal with errors. + if want := ds.si.NumInput(); want >= 0 && want != len(dargs) { + return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(dargs)) + } resi, err := ctxDriverStmtExec(ctx, ds.si, dargs) if err != nil { @@ -2269,25 +2432,20 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { } func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (driver.Rows, error) { - var want int - withLock(ds, func() { - want = ds.si.NumInput() - }) - - // -1 means the driver doesn't know how to count the number of - // placeholders, so we won't sanity check input here and instead let the - // driver deal with errors. - if want != -1 && len(args) != want { - return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args)) - } + ds.Lock() + defer ds.Unlock() - dargs, err := driverArgs(ci, ds, args) + dargs, err := driverArgsConnLocked(ci, ds, args) if err != nil { return nil, err } - ds.Lock() - defer ds.Unlock() + // -1 means the driver doesn't know how to count the number of + // placeholders, so we won't sanity check input here and instead let the + // driver deal with errors. + if want := ds.si.NumInput(); want >= 0 && want != len(dargs) { + return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(dargs)) + } rowsi, err := ctxDriverStmtQuery(ctx, ds.si, dargs) if err != nil { @@ -2451,9 +2609,16 @@ func (rs *Rows) nextLocked() (doClose, ok bool) { if rs.closed { return false, false } + + // Lock the driver connection before calling the driver interface + // rowsi to prevent a Tx from rolling back the connection at the same time. + rs.dc.Lock() + defer rs.dc.Unlock() + if rs.lastcols == nil { rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns())) } + rs.lasterr = rs.rowsi.Next(rs.lastcols) if rs.lasterr != nil { // Close the connection if there is a driver error. @@ -2503,6 +2668,12 @@ func (rs *Rows) NextResultSet() bool { doClose = true return false } + + // Lock the driver connection before calling the driver interface + // rowsi to prevent a Tx from rolling back the connection at the same time. + rs.dc.Lock() + defer rs.dc.Unlock() + rs.lasterr = nextResultSet.NextResultSet() if rs.lasterr != nil { doClose = true @@ -2534,6 +2705,9 @@ func (rs *Rows) Columns() ([]string, error) { if rs.rowsi == nil { return nil, errors.New("sql: no Rows available") } + rs.dc.Lock() + defer rs.dc.Unlock() + return rs.rowsi.Columns(), nil } @@ -2548,7 +2722,10 @@ func (rs *Rows) ColumnTypes() ([]*ColumnType, error) { if rs.rowsi == nil { return nil, errors.New("sql: no Rows available") } - return rowsColumnInfoSetup(rs.rowsi), nil + rs.dc.Lock() + defer rs.dc.Unlock() + + return rowsColumnInfoSetupConnLocked(rs.rowsi), nil } // ColumnType contains the name and type of a column. @@ -2609,7 +2786,7 @@ func (ci *ColumnType) DatabaseTypeName() string { return ci.databaseType } -func rowsColumnInfoSetup(rowsi driver.Rows) []*ColumnType { +func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType { names := rowsi.Columns() list := make([]*ColumnType, len(names)) |