diff --git a/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go b/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go index 925a1233..6594b938 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go +++ b/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go @@ -187,114 +187,145 @@ func (d *SyncServerDatabase) IncrementalSync( userID string, fromPos, toPos types.StreamPosition, numRecentEventsPerRoom int, -) (res *types.Response, returnErr error) { - returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { - // Work out which rooms to return in the response. This is done by getting not only the currently - // joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions. - // This works out what the 'state' key should be for each room as well as which membership block - // to put the room into. - deltas, err := d.getStateDeltas(ctx, txn, fromPos, toPos, userID) +) (*types.Response, error) { + txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot) + if err != nil { + return nil, err + } + var succeeded bool + defer common.EndTransaction(txn, &succeeded) + + // Work out which rooms to return in the response. This is done by getting not only the currently + // joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions. + // This works out what the 'state' key should be for each room as well as which membership block + // to put the room into. + deltas, err := d.getStateDeltas(ctx, txn, fromPos, toPos, userID) + if err != nil { + return nil, err + } + + res := types.NewResponse(toPos) + for _, delta := range deltas { + endPos := toPos + if delta.membershipPos > 0 && delta.membership == "leave" { + // make sure we don't leak recent events after the leave event. + // TODO: History visibility makes this somewhat complex to handle correctly. For example: + // TODO: This doesn't work for join -> leave in a single /sync request (see events prior to join). + // TODO: This will fail on join -> leave -> sensitive msg -> join -> leave + // in a single /sync request + // This is all "okay" assuming history_visibility == "shared" which it is by default. + endPos = delta.membershipPos + } + var recentStreamEvents []streamEvent + recentStreamEvents, err = d.events.selectRecentEvents( + ctx, txn, delta.roomID, fromPos, endPos, numRecentEventsPerRoom, + ) if err != nil { - return err + return nil, err } + recentEvents := streamEventsToEvents(recentStreamEvents) + delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back - res = types.NewResponse(toPos) - for _, delta := range deltas { - endPos := toPos - if delta.membershipPos > 0 && delta.membership == "leave" { - // make sure we don't leak recent events after the leave event. - // TODO: History visibility makes this somewhat complex to handle correctly. For example: - // TODO: This doesn't work for join -> leave in a single /sync request (see events prior to join). - // TODO: This will fail on join -> leave -> sensitive msg -> join -> leave - // in a single /sync request - // This is all "okay" assuming history_visibility == "shared" which it is by default. - endPos = delta.membershipPos - } - recentStreamEvents, err := d.events.selectRecentEvents( - ctx, txn, delta.roomID, fromPos, endPos, numRecentEventsPerRoom, - ) - if err != nil { - return err - } - recentEvents := streamEventsToEvents(recentStreamEvents) - delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back - - switch delta.membership { - case "join": - jr := types.NewJoinResponse() - jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) - jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true - jr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) - res.Rooms.Join[delta.roomID] = *jr - case "leave": - fallthrough // transitions to leave are the same as ban - case "ban": - // TODO: recentEvents may contain events that this user is not allowed to see because they are - // no longer in the room. - lr := types.NewLeaveResponse() - lr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) - lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true - lr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) - res.Rooms.Leave[delta.roomID] = *lr - } + switch delta.membership { + case "join": + jr := types.NewJoinResponse() + jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) + jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true + jr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) + res.Rooms.Join[delta.roomID] = *jr + case "leave": + fallthrough // transitions to leave are the same as ban + case "ban": + // TODO: recentEvents may contain events that this user is not allowed to see because they are + // no longer in the room. + lr := types.NewLeaveResponse() + lr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) + lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true + lr.State.Events = gomatrixserverlib.ToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) + res.Rooms.Leave[delta.roomID] = *lr } + } - // TODO: This should be done in getStateDeltas - return d.addInvitesToResponse(ctx, txn, userID, res) - }) - return + // TODO: This should be done in getStateDeltas + if err = d.addInvitesToResponse(ctx, txn, userID, res); err != nil { + return nil, err + } + + succeeded = true + return res, nil } // CompleteSync a complete /sync API response for the given user. func (d *SyncServerDatabase) CompleteSync( ctx context.Context, userID string, numRecentEventsPerRoom int, -) (res *types.Response, returnErr error) { +) (*types.Response, error) { // This needs to be all done in a transaction as we need to do multiple SELECTs, and we need to have // a consistent view of the database throughout. This includes extracting the sync stream position. // This does have the unfortunate side-effect that all the matrixy logic resides in this function, // but it's better to not hide the fact that this is being done in a transaction. - returnErr = common.WithTransaction(d.db, func(txn *sql.Tx) error { - // Get the current stream position which we will base the sync response on. - id, err := d.events.selectMaxID(ctx, txn) + txn, err := d.db.BeginTx(ctx, &txReadOnlySnapshot) + if err != nil { + return nil, err + } + var succeeded bool + defer common.EndTransaction(txn, &succeeded) + + // Get the current stream position which we will base the sync response on. + id, err := d.events.selectMaxID(ctx, txn) + if err != nil { + return nil, err + } + pos := types.StreamPosition(id) + + // Extract room state and recent events for all rooms the user is joined to. + roomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, "join") + if err != nil { + return nil, err + } + + // Build up a /sync response. Add joined rooms. + res := types.NewResponse(pos) + for _, roomID := range roomIDs { + var stateEvents []gomatrixserverlib.Event + stateEvents, err = d.roomstate.selectCurrentState(ctx, txn, roomID) if err != nil { - return err + return nil, err } - pos := types.StreamPosition(id) - - // Extract room state and recent events for all rooms the user is joined to. - roomIDs, err := d.roomstate.selectRoomIDsWithMembership(ctx, txn, userID, "join") + // TODO: When filters are added, we may need to call this multiple times to get enough events. + // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 + var recentStreamEvents []streamEvent + recentStreamEvents, err = d.events.selectRecentEvents( + ctx, txn, roomID, types.StreamPosition(0), pos, numRecentEventsPerRoom, + ) if err != nil { - return err + return nil, err } + recentEvents := streamEventsToEvents(recentStreamEvents) - // Build up a /sync response. Add joined rooms. - res = types.NewResponse(pos) - for _, roomID := range roomIDs { - stateEvents, err := d.roomstate.selectCurrentState(ctx, txn, roomID) - if err != nil { - return err - } - // TODO: When filters are added, we may need to call this multiple times to get enough events. - // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 - recentStreamEvents, err := d.events.selectRecentEvents( - ctx, txn, roomID, types.StreamPosition(0), pos, numRecentEventsPerRoom, - ) - if err != nil { - return err - } - recentEvents := streamEventsToEvents(recentStreamEvents) + stateEvents = removeDuplicates(stateEvents, recentEvents) + jr := types.NewJoinResponse() + jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) + jr.Timeline.Limited = true + jr.State.Events = gomatrixserverlib.ToClientEvents(stateEvents, gomatrixserverlib.FormatSync) + res.Rooms.Join[roomID] = *jr + } - stateEvents = removeDuplicates(stateEvents, recentEvents) - jr := types.NewJoinResponse() - jr.Timeline.Events = gomatrixserverlib.ToClientEvents(recentEvents, gomatrixserverlib.FormatSync) - jr.Timeline.Limited = true - jr.State.Events = gomatrixserverlib.ToClientEvents(stateEvents, gomatrixserverlib.FormatSync) - res.Rooms.Join[roomID] = *jr - } + if err = d.addInvitesToResponse(ctx, txn, userID, res); err != nil { + return nil, err + } - return d.addInvitesToResponse(ctx, txn, userID, res) - }) - return + succeeded = true + return res, err +} + +var txReadOnlySnapshot = sql.TxOptions{ + // Set the isolation level so that we see a snapshot of the database. + // In PostgreSQL repeatable read transactions will see a snapshot taken + // at the first query, and since the transaction is read-only it can't + // run into any serialisation errors. + // https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ + Isolation: sql.LevelRepeatableRead, + ReadOnly: true, } // GetAccountDataInRange returns all account data for a given user inserted or diff --git a/vendor/manifest b/vendor/manifest index 70cc5198..cac40142 100644 --- a/vendor/manifest +++ b/vendor/manifest @@ -98,7 +98,7 @@ { "importpath": "github.com/lib/pq", "repository": "https://github.com/lib/pq", - "revision": "a6657b2386e9b8be76484c08711b02c7cf867ead", + "revision": "23da1db4f16d9658a86ae9b717c245fc078f10f1", "branch": "master" }, { diff --git a/vendor/src/github.com/lib/pq/README.md b/vendor/src/github.com/lib/pq/README.md index 5eb9e144..7670fc87 100644 --- a/vendor/src/github.com/lib/pq/README.md +++ b/vendor/src/github.com/lib/pq/README.md @@ -1,6 +1,6 @@ # pq - A pure Go postgres driver for Go's database/sql package -[![Build Status](https://travis-ci.org/lib/pq.png?branch=master)](https://travis-ci.org/lib/pq) +[![Build Status](https://travis-ci.org/lib/pq.svg?branch=master)](https://travis-ci.org/lib/pq) ## Install diff --git a/vendor/src/github.com/lib/pq/array.go b/vendor/src/github.com/lib/pq/array.go index e7b2145d..e4933e22 100644 --- a/vendor/src/github.com/lib/pq/array.go +++ b/vendor/src/github.com/lib/pq/array.go @@ -13,7 +13,7 @@ import ( var typeByteSlice = reflect.TypeOf([]byte{}) var typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem() -var typeSqlScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() +var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() // Array returns the optimal driver.Valuer and sql.Scanner for an array or // slice of any dimension. @@ -278,7 +278,7 @@ func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]b // TODO calculate the assign function for other types // TODO repeat this section on the element type of arrays or slices (multidimensional) { - if reflect.PtrTo(rt).Implements(typeSqlScanner) { + if reflect.PtrTo(rt).Implements(typeSQLScanner) { // dest is always addressable because it is an element of a slice. assign = func(src []byte, dest reflect.Value) (err error) { ss := dest.Addr().Interface().(sql.Scanner) @@ -587,7 +587,7 @@ func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) { } } - var del string = "," + var del = "," var err error var iv interface{} = rv.Interface() diff --git a/vendor/src/github.com/lib/pq/conn.go b/vendor/src/github.com/lib/pq/conn.go index 3c8f77cb..338a0bc1 100644 --- a/vendor/src/github.com/lib/pq/conn.go +++ b/vendor/src/github.com/lib/pq/conn.go @@ -27,22 +27,22 @@ var ( ErrNotSupported = errors.New("pq: Unsupported command") ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction") ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server") - ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less.") - ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly.") + ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less") + ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly") errUnexpectedReady = errors.New("unexpected ReadyForQuery") errNoRowsAffected = errors.New("no RowsAffected available after the empty statement") - errNoLastInsertId = errors.New("no LastInsertId available after the empty statement") + errNoLastInsertID = errors.New("no LastInsertId available after the empty statement") ) -type drv struct{} +type Driver struct{} -func (d *drv) Open(name string) (driver.Conn, error) { +func (d *Driver) Open(name string) (driver.Conn, error) { return Open(name) } func init() { - sql.Register("postgres", &drv{}) + sql.Register("postgres", &Driver{}) } type parameterStatus struct { @@ -98,7 +98,7 @@ type conn struct { namei int scratch [512]byte txnStatus transactionStatus - txnClosed chan<- struct{} + txnFinish func() // Save connection arguments to use during CancelRequest. dialer Dialer @@ -131,9 +131,9 @@ type conn struct { } // Handle driver-side settings in parsed connection string. -func (c *conn) handleDriverSettings(o values) (err error) { +func (cn *conn) handleDriverSettings(o values) (err error) { boolSetting := func(key string, val *bool) error { - if value := o.Get(key); value != "" { + if value, ok := o[key]; ok { if value == "yes" { *val = true } else if value == "no" { @@ -145,21 +145,20 @@ func (c *conn) handleDriverSettings(o values) (err error) { return nil } - err = boolSetting("disable_prepared_binary_result", &c.disablePreparedBinaryResult) + err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult) if err != nil { return err } - err = boolSetting("binary_parameters", &c.binaryParameters) + err = boolSetting("binary_parameters", &cn.binaryParameters) if err != nil { return err } return nil } -func (c *conn) handlePgpass(o values) { +func (cn *conn) handlePgpass(o values) { // if a password was supplied, do not process .pgpass - _, ok := o["password"] - if ok { + if _, ok := o["password"]; ok { return } filename := os.Getenv("PGPASSFILE") @@ -187,11 +186,11 @@ func (c *conn) handlePgpass(o values) { } defer file.Close() scanner := bufio.NewScanner(io.Reader(file)) - hostname := o.Get("host") + hostname := o["host"] ntw, _ := network(o) - port := o.Get("port") - db := o.Get("dbname") - username := o.Get("user") + port := o["port"] + db := o["dbname"] + username := o["user"] // From: https://github.com/tg/pgpass/blob/master/reader.go getFields := func(s string) []string { fs := make([]string, 0, 5) @@ -230,10 +229,10 @@ func (c *conn) handlePgpass(o values) { } } -func (c *conn) writeBuf(b byte) *writeBuf { - c.scratch[0] = b +func (cn *conn) writeBuf(b byte) *writeBuf { + cn.scratch[0] = b return &writeBuf{ - buf: c.scratch[:5], + buf: cn.scratch[:5], pos: 1, } } @@ -256,13 +255,13 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { // * Very low precedence defaults applied in every situation // * Environment variables // * Explicitly passed connection information - o.Set("host", "localhost") - o.Set("port", "5432") + o["host"] = "localhost" + o["port"] = "5432" // N.B.: Extra float digits should be set to 3, but that breaks // Postgres 8.4 and older, where the max is 2. - o.Set("extra_float_digits", "2") + o["extra_float_digits"] = "2" for k, v := range parseEnviron(os.Environ()) { - o.Set(k, v) + o[k] = v } if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") { @@ -277,9 +276,9 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { } // Use the "fallback" application name if necessary - if fallback := o.Get("fallback_application_name"); fallback != "" { - if !o.Isset("application_name") { - o.Set("application_name", fallback) + if fallback, ok := o["fallback_application_name"]; ok { + if _, ok := o["application_name"]; !ok { + o["application_name"] = fallback } } @@ -290,30 +289,29 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { // parsing its value is not worth it. Instead, we always explicitly send // client_encoding as a separate run-time parameter, which should override // anything set in options. - if enc := o.Get("client_encoding"); enc != "" && !isUTF8(enc) { + if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) { return nil, errors.New("client_encoding must be absent or 'UTF8'") } - o.Set("client_encoding", "UTF8") + o["client_encoding"] = "UTF8" // DateStyle needs a similar treatment. - if datestyle := o.Get("datestyle"); datestyle != "" { + if datestyle, ok := o["datestyle"]; ok { if datestyle != "ISO, MDY" { panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v", "ISO, MDY", datestyle)) } } else { - o.Set("datestyle", "ISO, MDY") + o["datestyle"] = "ISO, MDY" } // If a user is not provided by any other means, the last // resort is to use the current operating system provided user // name. - if o.Get("user") == "" { + if _, ok := o["user"]; !ok { u, err := userCurrent() if err != nil { return nil, err - } else { - o.Set("user", u) } + o["user"] = u } cn := &conn{ @@ -335,7 +333,7 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { cn.startup(o) // reset the deadline, in case one was set (see dial) - if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" { + if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { err = cn.c.SetDeadline(time.Time{}) } return cn, err @@ -349,7 +347,7 @@ func dial(d Dialer, o values) (net.Conn, error) { } // Zero or not specified means wait indefinitely. - if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" { + if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { seconds, err := strconv.ParseInt(timeout, 10, 0) if err != nil { return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err) @@ -371,31 +369,18 @@ func dial(d Dialer, o values) (net.Conn, error) { } func network(o values) (string, string) { - host := o.Get("host") + host := o["host"] if strings.HasPrefix(host, "/") { - sockPath := path.Join(host, ".s.PGSQL."+o.Get("port")) + sockPath := path.Join(host, ".s.PGSQL."+o["port"]) return "unix", sockPath } - return "tcp", net.JoinHostPort(host, o.Get("port")) + return "tcp", net.JoinHostPort(host, o["port"]) } type values map[string]string -func (vs values) Set(k, v string) { - vs[k] = v -} - -func (vs values) Get(k string) (v string) { - return vs[k] -} - -func (vs values) Isset(k string) bool { - _, ok := vs[k] - return ok -} - // scanner implements a tokenizer for libpq-style option strings. type scanner struct { s []rune @@ -466,7 +451,7 @@ func parseOpts(name string, o values) error { // Skip any whitespace after the = if r, ok = s.SkipSpaces(); !ok { // If we reach the end here, the last value is just an empty string as per libpq. - o.Set(string(keyRunes), "") + o[string(keyRunes)] = "" break } @@ -501,7 +486,7 @@ func parseOpts(name string, o values) error { } } - o.Set(string(keyRunes), string(valRunes)) + o[string(keyRunes)] = string(valRunes) } return nil @@ -520,13 +505,17 @@ func (cn *conn) checkIsInTransaction(intxn bool) { } func (cn *conn) Begin() (_ driver.Tx, err error) { + return cn.begin("") +} + +func (cn *conn) begin(mode string) (_ driver.Tx, err error) { if cn.bad { return nil, driver.ErrBadConn } defer cn.errRecover(&err) cn.checkIsInTransaction(false) - _, commandTag, err := cn.simpleExec("BEGIN") + _, commandTag, err := cn.simpleExec("BEGIN" + mode) if err != nil { return nil, err } @@ -542,9 +531,8 @@ func (cn *conn) Begin() (_ driver.Tx, err error) { } func (cn *conn) closeTxn() { - if cn.txnClosed != nil { - close(cn.txnClosed) - cn.txnClosed = nil + if finish := cn.txnFinish; finish != nil { + finish() } } @@ -665,6 +653,12 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { cn: cn, } } + // Set the result and tag to the last command complete if there wasn't a + // query already run. Although queries usually return from here and cede + // control to Next, a query with zero results does not. + if t == 'C' && res.colNames == nil { + res.result, res.tag = cn.parseComplete(r.string()) + } res.done = true case 'Z': cn.processReadyForQuery(r) @@ -703,7 +697,7 @@ var emptyRows noRows var _ driver.Result = noRows{} func (noRows) LastInsertId() (int64, error) { - return 0, errNoLastInsertId + return 0, errNoLastInsertID } func (noRows) RowsAffected() (int64, error) { @@ -712,7 +706,7 @@ func (noRows) RowsAffected() (int64, error) { // Decides which column formats to use for a prepared statement. The input is // an array of type oids, one element per result column. -func decideColumnFormats(colTyps []oid.Oid, forceText bool) (colFmts []format, colFmtData []byte) { +func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) { if len(colTyps) == 0 { return nil, colFmtDataAllText } @@ -724,8 +718,8 @@ func decideColumnFormats(colTyps []oid.Oid, forceText bool) (colFmts []format, c allBinary := true allText := true - for i, o := range colTyps { - switch o { + for i, t := range colTyps { + switch t.OID { // This is the list of types to use binary mode for when receiving them // through a prepared statement. If a type appears in this list, it // must also be implemented in binaryDecode in encode.go. @@ -845,16 +839,15 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse() cn.postExecuteWorkaround() return rows, nil - } else { - st := cn.prepareTo(query, "") - st.exec(args) - return &rows{ - cn: cn, - colNames: st.colNames, - colTyps: st.colTyps, - colFmts: st.colFmts, - }, nil } + st := cn.prepareTo(query, "") + st.exec(args) + return &rows{ + cn: cn, + colNames: st.colNames, + colTyps: st.colTyps, + colFmts: st.colFmts, + }, nil } // Implement the optional "Execer" interface for one-shot queries @@ -881,17 +874,16 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err cn.postExecuteWorkaround() res, _, err = cn.readExecuteResponse("Execute") return res, err - } else { - // Use the unnamed statement to defer planning until bind - // time, or else value-based selectivity estimates cannot be - // used. - st := cn.prepareTo(query, "") - r, err := st.Exec(args) - if err != nil { - panic(err) - } - return r, err } + // Use the unnamed statement to defer planning until bind + // time, or else value-based selectivity estimates cannot be + // used. + st := cn.prepareTo(query, "") + r, err := st.Exec(args) + if err != nil { + panic(err) + } + return r, err } func (cn *conn) send(m *writeBuf) { @@ -901,16 +893,9 @@ func (cn *conn) send(m *writeBuf) { } } -func (cn *conn) sendStartupPacket(m *writeBuf) { - // sanity check - if m.buf[0] != 0 { - panic("oops") - } - +func (cn *conn) sendStartupPacket(m *writeBuf) error { _, err := cn.c.Write((m.wrap())[1:]) - if err != nil { - panic(err) - } + return err } // Send a message of type typ to the server on the other end of cn. The @@ -1032,7 +1017,9 @@ func (cn *conn) ssl(o values) { w := cn.writeBuf(0) w.int32(80877103) - cn.sendStartupPacket(w) + if err := cn.sendStartupPacket(w); err != nil { + panic(err) + } b := cn.scratch[:1] _, err := io.ReadFull(cn.c, b) @@ -1093,7 +1080,9 @@ func (cn *conn) startup(o values) { w.string(v) } w.string("") - cn.sendStartupPacket(w) + if err := cn.sendStartupPacket(w); err != nil { + panic(err) + } for { t, r := cn.recv() @@ -1119,7 +1108,7 @@ func (cn *conn) auth(r *readBuf, o values) { // OK case 3: w := cn.writeBuf('p') - w.string(o.Get("password")) + w.string(o["password"]) cn.send(w) t, r := cn.recv() @@ -1133,7 +1122,7 @@ func (cn *conn) auth(r *readBuf, o values) { case 5: s := string(r.next(4)) w := cn.writeBuf('p') - w.string("md5" + md5s(md5s(o.Get("password")+o.Get("user"))+s)) + w.string("md5" + md5s(md5s(o["password"]+o["user"])+s)) cn.send(w) t, r := cn.recv() @@ -1155,10 +1144,10 @@ const formatText format = 0 const formatBinary format = 1 // One result-column format code with the value 1 (i.e. all binary). -var colFmtDataAllBinary []byte = []byte{0, 1, 0, 1} +var colFmtDataAllBinary = []byte{0, 1, 0, 1} // No result-column format codes (i.e. all text). -var colFmtDataAllText []byte = []byte{0, 0} +var colFmtDataAllText = []byte{0, 0} type stmt struct { cn *conn @@ -1166,7 +1155,7 @@ type stmt struct { colNames []string colFmts []format colFmtData []byte - colTyps []oid.Oid + colTyps []fieldDesc paramTyps []oid.Oid closed bool } @@ -1327,17 +1316,19 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { type rows struct { cn *conn - closed chan<- struct{} + finish func() colNames []string - colTyps []oid.Oid + colTyps []fieldDesc colFmts []format done bool rb readBuf + result driver.Result + tag string } func (rs *rows) Close() error { - if rs.closed != nil { - defer close(rs.closed) + if finish := rs.finish; finish != nil { + defer finish() } // no need to look at cn.bad as Next() will for { @@ -1345,7 +1336,12 @@ func (rs *rows) Close() error { switch err { case nil: case io.EOF: - return nil + // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row + // description, used with HasNextResultSet). We need to fetch messages until + // we hit a 'Z', which is done by waiting for done to be set. + if rs.done { + return nil + } default: return err } @@ -1356,6 +1352,17 @@ func (rs *rows) Columns() []string { return rs.colNames } +func (rs *rows) Result() driver.Result { + if rs.result == nil { + return emptyRows + } + return rs.result +} + +func (rs *rows) Tag() string { + return rs.tag +} + func (rs *rows) Next(dest []driver.Value) (err error) { if rs.done { return io.EOF @@ -1373,6 +1380,9 @@ func (rs *rows) Next(dest []driver.Value) (err error) { case 'E': err = parseError(&rs.rb) case 'C', 'I': + if t == 'C' { + rs.result, rs.tag = conn.parseComplete(rs.rb.string()) + } continue case 'Z': conn.processReadyForQuery(&rs.rb) @@ -1396,7 +1406,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) { dest[i] = nil continue } - dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i], rs.colFmts[i]) + dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i]) } return case 'T': @@ -1502,7 +1512,7 @@ func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) { cn.send(b) } -func (c *conn) processParameterStatus(r *readBuf) { +func (cn *conn) processParameterStatus(r *readBuf) { var err error param := r.string() @@ -1513,13 +1523,13 @@ func (c *conn) processParameterStatus(r *readBuf) { var minor int _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor) if err == nil { - c.parameterStatus.serverVersion = major1*10000 + major2*100 + minor + cn.parameterStatus.serverVersion = major1*10000 + major2*100 + minor } case "TimeZone": - c.parameterStatus.currentLocation, err = time.LoadLocation(r.string()) + cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string()) if err != nil { - c.parameterStatus.currentLocation = nil + cn.parameterStatus.currentLocation = nil } default: @@ -1527,8 +1537,8 @@ func (c *conn) processParameterStatus(r *readBuf) { } } -func (c *conn) processReadyForQuery(r *readBuf) { - c.txnStatus = transactionStatus(r.byte()) +func (cn *conn) processReadyForQuery(r *readBuf) { + cn.txnStatus = transactionStatus(r.byte()) } func (cn *conn) readReadyForQuery() { @@ -1543,9 +1553,9 @@ func (cn *conn) readReadyForQuery() { } } -func (c *conn) processBackendKeyData(r *readBuf) { - c.processID = r.int32() - c.secretKey = r.int32() +func (cn *conn) processBackendKeyData(r *readBuf) { + cn.processID = r.int32() + cn.secretKey = r.int32() } func (cn *conn) readParseResponse() { @@ -1563,7 +1573,7 @@ func (cn *conn) readParseResponse() { } } -func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []oid.Oid) { +func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) { for { t, r := cn.recv1() switch t { @@ -1589,7 +1599,7 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [ } } -func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []oid.Oid) { +func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []fieldDesc) { t, r := cn.recv1() switch t { case 'T': @@ -1685,31 +1695,33 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co } } -func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []oid.Oid) { +func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) { n := r.int16() colNames = make([]string, n) - colTyps = make([]oid.Oid, n) + colTyps = make([]fieldDesc, n) for i := range colNames { colNames[i] = r.string() r.next(6) - colTyps[i] = r.oid() - r.next(6) + colTyps[i].OID = r.oid() + colTyps[i].Len = r.int16() + colTyps[i].Mod = r.int32() // format code not known when describing a statement; always 0 r.next(2) } return } -func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []oid.Oid) { +func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []fieldDesc) { n := r.int16() colNames = make([]string, n) colFmts = make([]format, n) - colTyps = make([]oid.Oid, n) + colTyps = make([]fieldDesc, n) for i := range colNames { colNames[i] = r.string() r.next(6) - colTyps[i] = r.oid() - r.next(6) + colTyps[i].OID = r.oid() + colTyps[i].Len = r.int16() + colTyps[i].Mod = r.int32() colFmts[i] = format(r.int16()) } return diff --git a/vendor/src/github.com/lib/pq/conn_go18.go b/vendor/src/github.com/lib/pq/conn_go18.go index 0aca1d00..ab97a104 100644 --- a/vendor/src/github.com/lib/pq/conn_go18.go +++ b/vendor/src/github.com/lib/pq/conn_go18.go @@ -4,8 +4,11 @@ package pq import ( "context" + "database/sql" "database/sql/driver" - "errors" + "fmt" + "io" + "io/ioutil" ) // Implement the "QueryerContext" interface @@ -14,15 +17,15 @@ func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.Na for i, nv := range args { list[i] = nv.Value } - var closed chan<- struct{} - if ctx.Done() != nil { - closed = watchCancel(ctx, cn.cancel) - } + finish := cn.watchCancel(ctx) r, err := cn.query(query, list) if err != nil { + if finish != nil { + finish() + } return nil, err } - r.closed = closed + r.finish = finish return r, nil } @@ -33,9 +36,8 @@ func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.Nam list[i] = nv.Value } - if ctx.Done() != nil { - closed := watchCancel(ctx, cn.cancel) - defer close(closed) + if finish := cn.watchCancel(ctx); finish != nil { + defer finish() } return cn.Exec(query, list) @@ -43,50 +45,84 @@ func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.Nam // Implement the "ConnBeginTx" interface func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { - if opts.Isolation != 0 { - return nil, errors.New("isolation levels not supported") + var mode string + + switch sql.IsolationLevel(opts.Isolation) { + case sql.LevelDefault: + // Don't touch mode: use the server's default + case sql.LevelReadUncommitted: + mode = " ISOLATION LEVEL READ UNCOMMITTED" + case sql.LevelReadCommitted: + mode = " ISOLATION LEVEL READ COMMITTED" + case sql.LevelRepeatableRead: + mode = " ISOLATION LEVEL REPEATABLE READ" + case sql.LevelSerializable: + mode = " ISOLATION LEVEL SERIALIZABLE" + default: + return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation) } + if opts.ReadOnly { - return nil, errors.New("read-only transactions not supported") + mode += " READ ONLY" + } else { + mode += " READ WRITE" } - tx, err := cn.Begin() + + tx, err := cn.begin(mode) if err != nil { return nil, err } - if ctx.Done() != nil { - cn.txnClosed = watchCancel(ctx, cn.cancel) - } + cn.txnFinish = cn.watchCancel(ctx) return tx, nil } -func watchCancel(ctx context.Context, cancel func()) chan<- struct{} { - closed := make(chan struct{}) - go func() { - select { - case <-ctx.Done(): - cancel() - case <-closed: +func (cn *conn) watchCancel(ctx context.Context) func() { + if done := ctx.Done(); done != nil { + finished := make(chan struct{}) + go func() { + select { + case <-done: + _ = cn.cancel() + finished <- struct{}{} + case <-finished: + } + }() + return func() { + select { + case <-finished: + case finished <- struct{}{}: + } } - }() - return closed -} - -func (cn *conn) cancel() { - var err error - can := &conn{} - can.c, err = dial(cn.dialer, cn.opts) - if err != nil { - return } - can.ssl(cn.opts) - - defer can.errRecover(&err) - - w := can.writeBuf(0) - w.int32(80877102) // cancel request code - w.int32(cn.processID) - w.int32(cn.secretKey) - - can.sendStartupPacket(w) - _ = can.c.Close() + return nil +} + +func (cn *conn) cancel() error { + c, err := dial(cn.dialer, cn.opts) + if err != nil { + return err + } + defer c.Close() + + { + can := conn{ + c: c, + } + can.ssl(cn.opts) + + w := can.writeBuf(0) + w.int32(80877102) // cancel request code + w.int32(cn.processID) + w.int32(cn.secretKey) + + if err := can.sendStartupPacket(w); err != nil { + return err + } + } + + // Read until EOF to ensure that the server received the cancel. + { + _, err := io.Copy(ioutil.Discard, c) + return err + } } diff --git a/vendor/src/github.com/lib/pq/conn_test.go b/vendor/src/github.com/lib/pq/conn_test.go index 183e6dcd..d3c01f34 100644 --- a/vendor/src/github.com/lib/pq/conn_test.go +++ b/vendor/src/github.com/lib/pq/conn_test.go @@ -136,7 +136,7 @@ func TestOpenURL(t *testing.T) { testURL("postgresql://") } -const pgpass_file = "/tmp/pqgotest_pgpass" +const pgpassFile = "/tmp/pqgotest_pgpass" func TestPgpass(t *testing.T) { if os.Getenv("TRAVIS") != "true" { @@ -160,11 +160,11 @@ func TestPgpass(t *testing.T) { rows, err := txn.Query("SELECT USER") if err != nil { txn.Rollback() - rows.Close() if expected != "fail" { t.Fatalf(reason, err) } } else { + rows.Close() if expected != "ok" { t.Fatalf(reason, err) } @@ -172,10 +172,10 @@ func TestPgpass(t *testing.T) { txn.Rollback() } testAssert("", "ok", "missing .pgpass, unexpected error %#v") - os.Setenv("PGPASSFILE", pgpass_file) + os.Setenv("PGPASSFILE", pgpassFile) testAssert("host=/tmp", "fail", ", unexpected error %#v") - os.Remove(pgpass_file) - pgpass, err := os.OpenFile(pgpass_file, os.O_RDWR|os.O_CREATE, 0644) + os.Remove(pgpassFile) + pgpass, err := os.OpenFile(pgpassFile, os.O_RDWR|os.O_CREATE, 0644) if err != nil { t.Fatalf("Unexpected error writing pgpass file %#v", err) } @@ -191,7 +191,7 @@ localhost:*:*:*:pass_C pgpass.Close() assertPassword := func(extra values, expected string) { - o := &values{ + o := values{ "host": "localhost", "sslmode": "disable", "connect_timeout": "20", @@ -203,17 +203,17 @@ localhost:*:*:*:pass_C "datestyle": "ISO, MDY", } for k, v := range extra { - (*o)[k] = v + o[k] = v } - (&conn{}).handlePgpass(*o) - if o.Get("password") != expected { - t.Fatalf("For %v expected %s got %s", extra, expected, o.Get("password")) + (&conn{}).handlePgpass(o) + if pw := o["password"]; pw != expected { + t.Fatalf("For %v expected %s got %s", extra, expected, pw) } } // wrong permissions for the pgpass file means it should be ignored assertPassword(values{"host": "example.com", "user": "foo"}, "") // fix the permissions and check if it has taken effect - os.Chmod(pgpass_file, 0600) + os.Chmod(pgpassFile, 0600) assertPassword(values{"host": "server", "dbname": "some_db", "user": "some_user"}, "pass_A") assertPassword(values{"host": "example.com", "user": "foo"}, "pass_fallback") assertPassword(values{"host": "example.com", "dbname": "some_db", "user": "some_user"}, "pass_B") @@ -221,7 +221,7 @@ localhost:*:*:*:pass_C assertPassword(values{"host": "", "user": "some_user"}, "pass_C") assertPassword(values{"host": "/tmp", "user": "some_user"}, "pass_C") // cleanup - os.Remove(pgpass_file) + os.Remove(pgpassFile) os.Setenv("PGPASSFILE", "") } @@ -393,8 +393,8 @@ func TestEmptyQuery(t *testing.T) { if _, err := res.RowsAffected(); err != errNoRowsAffected { t.Fatalf("expected %s, got %v", errNoRowsAffected, err) } - if _, err := res.LastInsertId(); err != errNoLastInsertId { - t.Fatalf("expected %s, got %v", errNoLastInsertId, err) + if _, err := res.LastInsertId(); err != errNoLastInsertID { + t.Fatalf("expected %s, got %v", errNoLastInsertID, err) } rows, err := db.Query("") if err != nil { @@ -425,8 +425,8 @@ func TestEmptyQuery(t *testing.T) { if _, err := res.RowsAffected(); err != errNoRowsAffected { t.Fatalf("expected %s, got %v", errNoRowsAffected, err) } - if _, err := res.LastInsertId(); err != errNoLastInsertId { - t.Fatalf("expected %s, got %v", errNoLastInsertId, err) + if _, err := res.LastInsertId(); err != errNoLastInsertID { + t.Fatalf("expected %s, got %v", errNoLastInsertID, err) } rows, err = stmt.Query() if err != nil { @@ -686,17 +686,28 @@ func TestCloseBadConn(t *testing.T) { if err := cn.Close(); err != nil { t.Fatal(err) } + + // During the Go 1.9 cycle, https://github.com/golang/go/commit/3792db5 + // changed this error from + // + // net.errClosing = errors.New("use of closed network connection") + // + // to + // + // internal/poll.ErrClosing = errors.New("use of closed file or network connection") + const errClosing = "use of closed" + // Verify write after closing fails. if _, err := nc.Write(nil); err == nil { t.Fatal("expected error") - } else if !strings.Contains(err.Error(), "use of closed network connection") { - t.Fatalf("expected use of closed network connection error, got %s", err) + } else if !strings.Contains(err.Error(), errClosing) { + t.Fatalf("expected %s error, got %s", errClosing, err) } // Verify second close fails. if err := cn.Close(); err == nil { t.Fatal("expected error") - } else if !strings.Contains(err.Error(), "use of closed network connection") { - t.Fatalf("expected use of closed network connection error, got %s", err) + } else if !strings.Contains(err.Error(), errClosing) { + t.Fatalf("expected %s error, got %s", errClosing, err) } } @@ -1042,16 +1053,16 @@ func TestIssue282(t *testing.T) { db := openTestConn(t) defer db.Close() - var search_path string + var searchPath string err := db.QueryRow(` SET LOCAL search_path TO pg_catalog; SET LOCAL search_path TO pg_catalog; - SHOW search_path`).Scan(&search_path) + SHOW search_path`).Scan(&searchPath) if err != nil { t.Fatal(err) } - if search_path != "pg_catalog" { - t.Fatalf("unexpected search_path %s", search_path) + if searchPath != "pg_catalog" { + t.Fatalf("unexpected search_path %s", searchPath) } } @@ -1493,3 +1504,111 @@ func TestQuoteIdentifier(t *testing.T) { } } } + +func TestRowsResultTag(t *testing.T) { + type ResultTag interface { + Result() driver.Result + Tag() string + } + + tests := []struct { + query string + tag string + ra int64 + }{ + { + query: "CREATE TEMP TABLE temp (a int)", + tag: "CREATE TABLE", + }, + { + query: "INSERT INTO temp VALUES (1), (2)", + tag: "INSERT", + ra: 2, + }, + { + query: "SELECT 1", + }, + // A SELECT anywhere should take precedent. + { + query: "SELECT 1; INSERT INTO temp VALUES (1), (2)", + }, + { + query: "INSERT INTO temp VALUES (1), (2); SELECT 1", + }, + // Multiple statements that don't return rows should return the last tag. + { + query: "CREATE TEMP TABLE t (a int); DROP TABLE t", + tag: "DROP TABLE", + }, + // Ensure a rows-returning query in any position among various tags-returing + // statements will prefer the rows. + { + query: "SELECT 1; CREATE TEMP TABLE t (a int); DROP TABLE t", + }, + { + query: "CREATE TEMP TABLE t (a int); SELECT 1; DROP TABLE t", + }, + { + query: "CREATE TEMP TABLE t (a int); DROP TABLE t; SELECT 1", + }, + // Verify that an no-results query doesn't set the tag. + { + query: "CREATE TEMP TABLE t (a int); SELECT 1 WHERE FALSE; DROP TABLE t;", + }, + } + + // If this is the only test run, this will correct the connection string. + openTestConn(t).Close() + + conn, err := Open("") + if err != nil { + t.Fatal(err) + } + defer conn.Close() + q := conn.(driver.Queryer) + + for _, test := range tests { + if rows, err := q.Query(test.query, nil); err != nil { + t.Fatalf("%s: %s", test.query, err) + } else { + r := rows.(ResultTag) + if tag := r.Tag(); tag != test.tag { + t.Fatalf("%s: unexpected tag %q", test.query, tag) + } + res := r.Result() + if ra, _ := res.RowsAffected(); ra != test.ra { + t.Fatalf("%s: unexpected rows affected: %d", test.query, ra) + } + rows.Close() + } + } +} + +// TestQuickClose tests that closing a query early allows a subsequent query to work. +func TestQuickClose(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + rows, err := tx.Query("SELECT 1; SELECT 2;") + if err != nil { + t.Fatal(err) + } + if err := rows.Close(); err != nil { + t.Fatal(err) + } + + var id int + if err := tx.QueryRow("SELECT 3").Scan(&id); err != nil { + t.Fatal(err) + } + if id != 3 { + t.Fatalf("unexpected %d", id) + } + if err := tx.Commit(); err != nil { + t.Fatal(err) + } +} diff --git a/vendor/src/github.com/lib/pq/encode_test.go b/vendor/src/github.com/lib/pq/encode_test.go index b1531ec2..3a0f7286 100644 --- a/vendor/src/github.com/lib/pq/encode_test.go +++ b/vendor/src/github.com/lib/pq/encode_test.go @@ -370,17 +370,17 @@ func TestInfinityTimestamp(t *testing.T) { t.Errorf("Scanning -infinity, expected time %q, got %q", y1500, resultT.String()) } - y_1500 := time.Date(-1500, time.January, 1, 0, 0, 0, 0, time.UTC) + ym1500 := time.Date(-1500, time.January, 1, 0, 0, 0, 0, time.UTC) y11500 := time.Date(11500, time.January, 1, 0, 0, 0, 0, time.UTC) var s string - err = db.QueryRow("SELECT $1::timestamp::text", y_1500).Scan(&s) + err = db.QueryRow("SELECT $1::timestamp::text", ym1500).Scan(&s) if err != nil { t.Errorf("Encoding -infinity, expected no error, got %q", err) } if s != "-infinity" { t.Errorf("Encoding -infinity, expected %q, got %q", "-infinity", s) } - err = db.QueryRow("SELECT $1::timestamptz::text", y_1500).Scan(&s) + err = db.QueryRow("SELECT $1::timestamptz::text", ym1500).Scan(&s) if err != nil { t.Errorf("Encoding -infinity, expected no error, got %q", err) } diff --git a/vendor/src/github.com/lib/pq/go18_test.go b/vendor/src/github.com/lib/pq/go18_test.go index 15546d86..4bf6391e 100644 --- a/vendor/src/github.com/lib/pq/go18_test.go +++ b/vendor/src/github.com/lib/pq/go18_test.go @@ -5,6 +5,8 @@ package pq import ( "context" "database/sql" + "runtime" + "strings" "testing" "time" ) @@ -72,6 +74,8 @@ func TestMultipleSimpleQuery(t *testing.T) { } } +const contextRaceIterations = 100 + func TestContextCancelExec(t *testing.T) { db := openTestConn(t) defer db.Close() @@ -79,10 +83,7 @@ func TestContextCancelExec(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) // Delay execution for just a bit until db.ExecContext has begun. - go func() { - time.Sleep(time.Millisecond * 10) - cancel() - }() + defer time.AfterFunc(time.Millisecond*10, cancel).Stop() // Not canceled until after the exec has started. if _, err := db.ExecContext(ctx, "select pg_sleep(1)"); err == nil { @@ -97,6 +98,20 @@ func TestContextCancelExec(t *testing.T) { } else if err.Error() != "context canceled" { t.Fatalf("unexpected error: %s", err) } + + for i := 0; i < contextRaceIterations; i++ { + func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if _, err := db.ExecContext(ctx, "select 1"); err != nil { + t.Fatal(err) + } + }() + + if _, err := db.Exec("select 1"); err != nil { + t.Fatal(err) + } + } } func TestContextCancelQuery(t *testing.T) { @@ -106,10 +121,7 @@ func TestContextCancelQuery(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) // Delay execution for just a bit until db.QueryContext has begun. - go func() { - time.Sleep(time.Millisecond * 10) - cancel() - }() + defer time.AfterFunc(time.Millisecond*10, cancel).Stop() // Not canceled until after the exec has started. if _, err := db.QueryContext(ctx, "select pg_sleep(1)"); err == nil { @@ -124,6 +136,55 @@ func TestContextCancelQuery(t *testing.T) { } else if err.Error() != "context canceled" { t.Fatalf("unexpected error: %s", err) } + + for i := 0; i < contextRaceIterations; i++ { + func() { + ctx, cancel := context.WithCancel(context.Background()) + rows, err := db.QueryContext(ctx, "select 1") + cancel() + if err != nil { + t.Fatal(err) + } else if err := rows.Close(); err != nil { + t.Fatal(err) + } + }() + + if rows, err := db.Query("select 1"); err != nil { + t.Fatal(err) + } else if err := rows.Close(); err != nil { + t.Fatal(err) + } + } +} + +// TestIssue617 tests that a failed query in QueryContext doesn't lead to a +// goroutine leak. +func TestIssue617(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + const N = 10 + + numGoroutineStart := runtime.NumGoroutine() + for i := 0; i < N; i++ { + func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err := db.QueryContext(ctx, `SELECT * FROM DOESNOTEXIST`) + pqErr, _ := err.(*Error) + // Expecting "pq: relation \"doesnotexist\" does not exist" error. + if err == nil || pqErr == nil || pqErr.Code != "42P01" { + t.Fatalf("expected undefined table error, got %v", err) + } + }() + } + numGoroutineFinish := runtime.NumGoroutine() + + // We use N/2 and not N because the GC and other actors may increase or + // decrease the number of goroutines. + if numGoroutineFinish-numGoroutineStart >= N/2 { + t.Errorf("goroutine leak detected, was %d, now %d", numGoroutineStart, numGoroutineFinish) + } } func TestContextCancelBegin(t *testing.T) { @@ -137,10 +198,7 @@ func TestContextCancelBegin(t *testing.T) { } // Delay execution for just a bit until tx.Exec has begun. - go func() { - time.Sleep(time.Millisecond * 10) - cancel() - }() + defer time.AfterFunc(time.Millisecond*10, cancel).Stop() // Not canceled until after the exec has started. if _, err := tx.Exec("select pg_sleep(1)"); err == nil { @@ -162,4 +220,100 @@ func TestContextCancelBegin(t *testing.T) { } else if err.Error() != "context canceled" { t.Fatalf("unexpected error: %s", err) } + + for i := 0; i < contextRaceIterations; i++ { + func() { + ctx, cancel := context.WithCancel(context.Background()) + tx, err := db.BeginTx(ctx, nil) + cancel() + if err != nil { + t.Fatal(err) + } else if err := tx.Rollback(); err != nil && err != sql.ErrTxDone { + t.Fatal(err) + } + }() + + if tx, err := db.Begin(); err != nil { + t.Fatal(err) + } else if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + } +} + +func TestTxOptions(t *testing.T) { + db := openTestConn(t) + defer db.Close() + ctx := context.Background() + + tests := []struct { + level sql.IsolationLevel + isolation string + }{ + { + level: sql.LevelDefault, + isolation: "", + }, + { + level: sql.LevelReadUncommitted, + isolation: "read uncommitted", + }, + { + level: sql.LevelReadCommitted, + isolation: "read committed", + }, + { + level: sql.LevelRepeatableRead, + isolation: "repeatable read", + }, + { + level: sql.LevelSerializable, + isolation: "serializable", + }, + } + + for _, test := range tests { + for _, ro := range []bool{true, false} { + tx, err := db.BeginTx(ctx, &sql.TxOptions{ + Isolation: test.level, + ReadOnly: ro, + }) + if err != nil { + t.Fatal(err) + } + + var isolation string + err = tx.QueryRow("select current_setting('transaction_isolation')").Scan(&isolation) + if err != nil { + t.Fatal(err) + } + + if test.isolation != "" && isolation != test.isolation { + t.Errorf("wrong isolation level: %s != %s", isolation, test.isolation) + } + + var isRO string + err = tx.QueryRow("select current_setting('transaction_read_only')").Scan(&isRO) + if err != nil { + t.Fatal(err) + } + + if ro != (isRO == "on") { + t.Errorf("read/[write,only] not set: %t != %s for level %s", + ro, isRO, test.isolation) + } + + tx.Rollback() + } + } + + _, err := db.BeginTx(ctx, &sql.TxOptions{ + Isolation: sql.LevelLinearizable, + }) + if err == nil { + t.Fatal("expected LevelLinearizable to fail") + } + if !strings.Contains(err.Error(), "isolation level not supported") { + t.Errorf("Expected error to mention isolation level, got %q", err) + } } diff --git a/vendor/src/github.com/lib/pq/listen_example/doc.go b/vendor/src/github.com/lib/pq/listen_example/doc.go index 5bc99f5c..80f0a9b9 100644 --- a/vendor/src/github.com/lib/pq/listen_example/doc.go +++ b/vendor/src/github.com/lib/pq/listen_example/doc.go @@ -51,21 +51,15 @@ mechanism to avoid polling the database while waiting for more work to arrive. } func waitForNotification(l *pq.Listener) { - for { - select { - case <-l.Notify: - fmt.Println("received notification, new work available") - return - case <-time.After(90 * time.Second): - go func() { - l.Ping() - }() - // Check if there's more work available, just in case it takes - // a while for the Listener to notice connection loss and - // reconnect. - fmt.Println("received no work for 90 seconds, checking for new work") - return - } + select { + case <-l.Notify: + fmt.Println("received notification, new work available") + case <-time.After(90 * time.Second): + go l.Ping() + // Check if there's more work available, just in case it takes + // a while for the Listener to notice connection loss and + // reconnect. + fmt.Println("received no work for 90 seconds, checking for new work") } } diff --git a/vendor/src/github.com/lib/pq/notify_test.go b/vendor/src/github.com/lib/pq/notify_test.go index fe8941a4..82a77e1e 100644 --- a/vendor/src/github.com/lib/pq/notify_test.go +++ b/vendor/src/github.com/lib/pq/notify_test.go @@ -7,7 +7,6 @@ import ( "os" "runtime" "sync" - "sync/atomic" "testing" "time" ) @@ -235,15 +234,10 @@ func TestConnExecDeadlock(t *testing.T) { // calls Close on the net.Conn; equivalent to a network failure l.Close() - var done int32 = 0 - go func() { - time.Sleep(10 * time.Second) - if atomic.LoadInt32(&done) != 1 { - panic("timed out") - } - }() + defer time.AfterFunc(10*time.Second, func() { + panic("timed out") + }).Stop() wg.Wait() - atomic.StoreInt32(&done, 1) } // Test for ListenerConn being closed while a slow query is executing @@ -271,15 +265,11 @@ func TestListenerConnCloseWhileQueryIsExecuting(t *testing.T) { if err != nil { t.Fatal(err) } - var done int32 = 0 - go func() { - time.Sleep(10 * time.Second) - if atomic.LoadInt32(&done) != 1 { - panic("timed out") - } - }() + + defer time.AfterFunc(10*time.Second, func() { + panic("timed out") + }).Stop() wg.Wait() - atomic.StoreInt32(&done, 1) } func TestNotifyExtra(t *testing.T) { diff --git a/vendor/src/github.com/lib/pq/oid/gen.go b/vendor/src/github.com/lib/pq/oid/gen.go index cd4aea80..7c634cdc 100644 --- a/vendor/src/github.com/lib/pq/oid/gen.go +++ b/vendor/src/github.com/lib/pq/oid/gen.go @@ -10,10 +10,22 @@ import ( "log" "os" "os/exec" + "strings" _ "github.com/lib/pq" ) +// OID represent a postgres Object Identifier Type. +type OID struct { + ID int + Type string +} + +// Name returns an upper case version of the oid type. +func (o OID) Name() string { + return strings.ToUpper(o.Type) +} + func main() { datname := os.Getenv("PGDATABASE") sslmode := os.Getenv("PGSSLMODE") @@ -30,6 +42,25 @@ func main() { if err != nil { log.Fatal(err) } + rows, err := db.Query(` + SELECT typname, oid + FROM pg_type WHERE oid < 10000 + ORDER BY oid; + `) + if err != nil { + log.Fatal(err) + } + oids := make([]*OID, 0) + for rows.Next() { + var oid OID + if err = rows.Scan(&oid.Type, &oid.ID); err != nil { + log.Fatal(err) + } + oids = append(oids, &oid) + } + if err = rows.Err(); err != nil { + log.Fatal(err) + } cmd := exec.Command("gofmt") cmd.Stderr = os.Stderr w, err := cmd.StdinPipe() @@ -45,30 +76,18 @@ func main() { if err != nil { log.Fatal(err) } - fmt.Fprintln(w, "// generated by 'go run gen.go'; do not edit") + fmt.Fprintln(w, "// Code generated by gen.go. DO NOT EDIT.") fmt.Fprintln(w, "\npackage oid") fmt.Fprintln(w, "const (") - rows, err := db.Query(` - SELECT typname, oid - FROM pg_type WHERE oid < 10000 - ORDER BY oid; - `) - if err != nil { - log.Fatal(err) - } - var name string - var oid int - for rows.Next() { - err = rows.Scan(&name, &oid) - if err != nil { - log.Fatal(err) - } - fmt.Fprintf(w, "T_%s Oid = %d\n", name, oid) - } - if err = rows.Err(); err != nil { - log.Fatal(err) + for _, oid := range oids { + fmt.Fprintf(w, "T_%s Oid = %d\n", oid.Type, oid.ID) } fmt.Fprintln(w, ")") + fmt.Fprintln(w, "var TypeName = map[Oid]string{") + for _, oid := range oids { + fmt.Fprintf(w, "T_%s: \"%s\",\n", oid.Type, oid.Name()) + } + fmt.Fprintln(w, "}") w.Close() cmd.Wait() } diff --git a/vendor/src/github.com/lib/pq/oid/types.go b/vendor/src/github.com/lib/pq/oid/types.go index 03df05a6..ecc84c2c 100644 --- a/vendor/src/github.com/lib/pq/oid/types.go +++ b/vendor/src/github.com/lib/pq/oid/types.go @@ -1,4 +1,4 @@ -// generated by 'go run gen.go'; do not edit +// Code generated by gen.go. DO NOT EDIT. package oid @@ -18,6 +18,7 @@ const ( T_xid Oid = 28 T_cid Oid = 29 T_oidvector Oid = 30 + T_pg_ddl_command Oid = 32 T_pg_type Oid = 71 T_pg_attribute Oid = 75 T_pg_proc Oid = 81 @@ -28,6 +29,7 @@ const ( T_pg_node_tree Oid = 194 T__json Oid = 199 T_smgr Oid = 210 + T_index_am_handler Oid = 325 T_point Oid = 600 T_lseg Oid = 601 T_path Oid = 602 @@ -133,6 +135,9 @@ const ( T__uuid Oid = 2951 T_txid_snapshot Oid = 2970 T_fdw_handler Oid = 3115 + T_pg_lsn Oid = 3220 + T__pg_lsn Oid = 3221 + T_tsm_handler Oid = 3310 T_anyenum Oid = 3500 T_tsvector Oid = 3614 T_tsquery Oid = 3615 @@ -144,6 +149,8 @@ const ( T__regconfig Oid = 3735 T_regdictionary Oid = 3769 T__regdictionary Oid = 3770 + T_jsonb Oid = 3802 + T__jsonb Oid = 3807 T_anyrange Oid = 3831 T_event_trigger Oid = 3838 T_int4range Oid = 3904 @@ -158,4 +165,179 @@ const ( T__daterange Oid = 3913 T_int8range Oid = 3926 T__int8range Oid = 3927 + T_pg_shseclabel Oid = 4066 + T_regnamespace Oid = 4089 + T__regnamespace Oid = 4090 + T_regrole Oid = 4096 + T__regrole Oid = 4097 ) + +var TypeName = map[Oid]string{ + T_bool: "BOOL", + T_bytea: "BYTEA", + T_char: "CHAR", + T_name: "NAME", + T_int8: "INT8", + T_int2: "INT2", + T_int2vector: "INT2VECTOR", + T_int4: "INT4", + T_regproc: "REGPROC", + T_text: "TEXT", + T_oid: "OID", + T_tid: "TID", + T_xid: "XID", + T_cid: "CID", + T_oidvector: "OIDVECTOR", + T_pg_ddl_command: "PG_DDL_COMMAND", + T_pg_type: "PG_TYPE", + T_pg_attribute: "PG_ATTRIBUTE", + T_pg_proc: "PG_PROC", + T_pg_class: "PG_CLASS", + T_json: "JSON", + T_xml: "XML", + T__xml: "_XML", + T_pg_node_tree: "PG_NODE_TREE", + T__json: "_JSON", + T_smgr: "SMGR", + T_index_am_handler: "INDEX_AM_HANDLER", + T_point: "POINT", + T_lseg: "LSEG", + T_path: "PATH", + T_box: "BOX", + T_polygon: "POLYGON", + T_line: "LINE", + T__line: "_LINE", + T_cidr: "CIDR", + T__cidr: "_CIDR", + T_float4: "FLOAT4", + T_float8: "FLOAT8", + T_abstime: "ABSTIME", + T_reltime: "RELTIME", + T_tinterval: "TINTERVAL", + T_unknown: "UNKNOWN", + T_circle: "CIRCLE", + T__circle: "_CIRCLE", + T_money: "MONEY", + T__money: "_MONEY", + T_macaddr: "MACADDR", + T_inet: "INET", + T__bool: "_BOOL", + T__bytea: "_BYTEA", + T__char: "_CHAR", + T__name: "_NAME", + T__int2: "_INT2", + T__int2vector: "_INT2VECTOR", + T__int4: "_INT4", + T__regproc: "_REGPROC", + T__text: "_TEXT", + T__tid: "_TID", + T__xid: "_XID", + T__cid: "_CID", + T__oidvector: "_OIDVECTOR", + T__bpchar: "_BPCHAR", + T__varchar: "_VARCHAR", + T__int8: "_INT8", + T__point: "_POINT", + T__lseg: "_LSEG", + T__path: "_PATH", + T__box: "_BOX", + T__float4: "_FLOAT4", + T__float8: "_FLOAT8", + T__abstime: "_ABSTIME", + T__reltime: "_RELTIME", + T__tinterval: "_TINTERVAL", + T__polygon: "_POLYGON", + T__oid: "_OID", + T_aclitem: "ACLITEM", + T__aclitem: "_ACLITEM", + T__macaddr: "_MACADDR", + T__inet: "_INET", + T_bpchar: "BPCHAR", + T_varchar: "VARCHAR", + T_date: "DATE", + T_time: "TIME", + T_timestamp: "TIMESTAMP", + T__timestamp: "_TIMESTAMP", + T__date: "_DATE", + T__time: "_TIME", + T_timestamptz: "TIMESTAMPTZ", + T__timestamptz: "_TIMESTAMPTZ", + T_interval: "INTERVAL", + T__interval: "_INTERVAL", + T__numeric: "_NUMERIC", + T_pg_database: "PG_DATABASE", + T__cstring: "_CSTRING", + T_timetz: "TIMETZ", + T__timetz: "_TIMETZ", + T_bit: "BIT", + T__bit: "_BIT", + T_varbit: "VARBIT", + T__varbit: "_VARBIT", + T_numeric: "NUMERIC", + T_refcursor: "REFCURSOR", + T__refcursor: "_REFCURSOR", + T_regprocedure: "REGPROCEDURE", + T_regoper: "REGOPER", + T_regoperator: "REGOPERATOR", + T_regclass: "REGCLASS", + T_regtype: "REGTYPE", + T__regprocedure: "_REGPROCEDURE", + T__regoper: "_REGOPER", + T__regoperator: "_REGOPERATOR", + T__regclass: "_REGCLASS", + T__regtype: "_REGTYPE", + T_record: "RECORD", + T_cstring: "CSTRING", + T_any: "ANY", + T_anyarray: "ANYARRAY", + T_void: "VOID", + T_trigger: "TRIGGER", + T_language_handler: "LANGUAGE_HANDLER", + T_internal: "INTERNAL", + T_opaque: "OPAQUE", + T_anyelement: "ANYELEMENT", + T__record: "_RECORD", + T_anynonarray: "ANYNONARRAY", + T_pg_authid: "PG_AUTHID", + T_pg_auth_members: "PG_AUTH_MEMBERS", + T__txid_snapshot: "_TXID_SNAPSHOT", + T_uuid: "UUID", + T__uuid: "_UUID", + T_txid_snapshot: "TXID_SNAPSHOT", + T_fdw_handler: "FDW_HANDLER", + T_pg_lsn: "PG_LSN", + T__pg_lsn: "_PG_LSN", + T_tsm_handler: "TSM_HANDLER", + T_anyenum: "ANYENUM", + T_tsvector: "TSVECTOR", + T_tsquery: "TSQUERY", + T_gtsvector: "GTSVECTOR", + T__tsvector: "_TSVECTOR", + T__gtsvector: "_GTSVECTOR", + T__tsquery: "_TSQUERY", + T_regconfig: "REGCONFIG", + T__regconfig: "_REGCONFIG", + T_regdictionary: "REGDICTIONARY", + T__regdictionary: "_REGDICTIONARY", + T_jsonb: "JSONB", + T__jsonb: "_JSONB", + T_anyrange: "ANYRANGE", + T_event_trigger: "EVENT_TRIGGER", + T_int4range: "INT4RANGE", + T__int4range: "_INT4RANGE", + T_numrange: "NUMRANGE", + T__numrange: "_NUMRANGE", + T_tsrange: "TSRANGE", + T__tsrange: "_TSRANGE", + T_tstzrange: "TSTZRANGE", + T__tstzrange: "_TSTZRANGE", + T_daterange: "DATERANGE", + T__daterange: "_DATERANGE", + T_int8range: "INT8RANGE", + T__int8range: "_INT8RANGE", + T_pg_shseclabel: "PG_SHSECLABEL", + T_regnamespace: "REGNAMESPACE", + T__regnamespace: "_REGNAMESPACE", + T_regrole: "REGROLE", + T__regrole: "_REGROLE", +} diff --git a/vendor/src/github.com/lib/pq/rows.go b/vendor/src/github.com/lib/pq/rows.go new file mode 100644 index 00000000..c6aa5b9a --- /dev/null +++ b/vendor/src/github.com/lib/pq/rows.go @@ -0,0 +1,93 @@ +package pq + +import ( + "math" + "reflect" + "time" + + "github.com/lib/pq/oid" +) + +const headerSize = 4 + +type fieldDesc struct { + // The object ID of the data type. + OID oid.Oid + // The data type size (see pg_type.typlen). + // Note that negative values denote variable-width types. + Len int + // The type modifier (see pg_attribute.atttypmod). + // The meaning of the modifier is type-specific. + Mod int +} + +func (fd fieldDesc) Type() reflect.Type { + switch fd.OID { + case oid.T_int8: + return reflect.TypeOf(int64(0)) + case oid.T_int4: + return reflect.TypeOf(int32(0)) + case oid.T_int2: + return reflect.TypeOf(int16(0)) + case oid.T_varchar, oid.T_text: + return reflect.TypeOf("") + case oid.T_bool: + return reflect.TypeOf(false) + case oid.T_date, oid.T_time, oid.T_timetz, oid.T_timestamp, oid.T_timestamptz: + return reflect.TypeOf(time.Time{}) + case oid.T_bytea: + return reflect.TypeOf([]byte(nil)) + default: + return reflect.TypeOf(new(interface{})).Elem() + } +} + +func (fd fieldDesc) Name() string { + return oid.TypeName[fd.OID] +} + +func (fd fieldDesc) Length() (length int64, ok bool) { + switch fd.OID { + case oid.T_text, oid.T_bytea: + return math.MaxInt64, true + case oid.T_varchar, oid.T_bpchar: + return int64(fd.Mod - headerSize), true + default: + return 0, false + } +} + +func (fd fieldDesc) PrecisionScale() (precision, scale int64, ok bool) { + switch fd.OID { + case oid.T_numeric, oid.T__numeric: + mod := fd.Mod - headerSize + precision = int64((mod >> 16) & 0xffff) + scale = int64(mod & 0xffff) + return precision, scale, true + default: + return 0, 0, false + } +} + +// ColumnTypeScanType returns the value type that can be used to scan types into. +func (rs *rows) ColumnTypeScanType(index int) reflect.Type { + return rs.colTyps[index].Type() +} + +// ColumnTypeDatabaseTypeName return the database system type name. +func (rs *rows) ColumnTypeDatabaseTypeName(index int) string { + return rs.colTyps[index].Name() +} + +// ColumnTypeLength returns the length of the column type if the column is a +// variable length type. If the column is not a variable length type ok +// should return false. +func (rs *rows) ColumnTypeLength(index int) (length int64, ok bool) { + return rs.colTyps[index].Length() +} + +// ColumnTypePrecisionScale should return the precision and scale for decimal +// types. If not applicable, ok should be false. +func (rs *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { + return rs.colTyps[index].PrecisionScale() +} diff --git a/vendor/src/github.com/lib/pq/rows_test.go b/vendor/src/github.com/lib/pq/rows_test.go new file mode 100644 index 00000000..3033bc01 --- /dev/null +++ b/vendor/src/github.com/lib/pq/rows_test.go @@ -0,0 +1,220 @@ +// +build go1.8 + +package pq + +import ( + "math" + "reflect" + "testing" + + "github.com/lib/pq/oid" +) + +func TestDataTypeName(t *testing.T) { + tts := []struct { + typ oid.Oid + name string + }{ + {oid.T_int8, "INT8"}, + {oid.T_int4, "INT4"}, + {oid.T_int2, "INT2"}, + {oid.T_varchar, "VARCHAR"}, + {oid.T_text, "TEXT"}, + {oid.T_bool, "BOOL"}, + {oid.T_numeric, "NUMERIC"}, + {oid.T_date, "DATE"}, + {oid.T_time, "TIME"}, + {oid.T_timetz, "TIMETZ"}, + {oid.T_timestamp, "TIMESTAMP"}, + {oid.T_timestamptz, "TIMESTAMPTZ"}, + {oid.T_bytea, "BYTEA"}, + } + + for i, tt := range tts { + dt := fieldDesc{OID: tt.typ} + if name := dt.Name(); name != tt.name { + t.Errorf("(%d) got: %s want: %s", i, name, tt.name) + } + } +} + +func TestDataType(t *testing.T) { + tts := []struct { + typ oid.Oid + kind reflect.Kind + }{ + {oid.T_int8, reflect.Int64}, + {oid.T_int4, reflect.Int32}, + {oid.T_int2, reflect.Int16}, + {oid.T_varchar, reflect.String}, + {oid.T_text, reflect.String}, + {oid.T_bool, reflect.Bool}, + {oid.T_date, reflect.Struct}, + {oid.T_time, reflect.Struct}, + {oid.T_timetz, reflect.Struct}, + {oid.T_timestamp, reflect.Struct}, + {oid.T_timestamptz, reflect.Struct}, + {oid.T_bytea, reflect.Slice}, + } + + for i, tt := range tts { + dt := fieldDesc{OID: tt.typ} + if kind := dt.Type().Kind(); kind != tt.kind { + t.Errorf("(%d) got: %s want: %s", i, kind, tt.kind) + } + } +} + +func TestDataTypeLength(t *testing.T) { + tts := []struct { + typ oid.Oid + len int + mod int + length int64 + ok bool + }{ + {oid.T_int4, 0, -1, 0, false}, + {oid.T_varchar, 65535, 9, 5, true}, + {oid.T_text, 65535, -1, math.MaxInt64, true}, + {oid.T_bytea, 65535, -1, math.MaxInt64, true}, + } + + for i, tt := range tts { + dt := fieldDesc{OID: tt.typ, Len: tt.len, Mod: tt.mod} + if l, k := dt.Length(); k != tt.ok || l != tt.length { + t.Errorf("(%d) got: %d, %t want: %d, %t", i, l, k, tt.length, tt.ok) + } + } +} + +func TestDataTypePrecisionScale(t *testing.T) { + tts := []struct { + typ oid.Oid + mod int + precision, scale int64 + ok bool + }{ + {oid.T_int4, -1, 0, 0, false}, + {oid.T_numeric, 589830, 9, 2, true}, + {oid.T_text, -1, 0, 0, false}, + } + + for i, tt := range tts { + dt := fieldDesc{OID: tt.typ, Mod: tt.mod} + p, s, k := dt.PrecisionScale() + if k != tt.ok { + t.Errorf("(%d) got: %t want: %t", i, k, tt.ok) + } + if p != tt.precision { + t.Errorf("(%d) wrong precision got: %d want: %d", i, p, tt.precision) + } + if s != tt.scale { + t.Errorf("(%d) wrong scale got: %d want: %d", i, s, tt.scale) + } + } +} + +func TestRowsColumnTypes(t *testing.T) { + columnTypesTests := []struct { + Name string + TypeName string + Length struct { + Len int64 + OK bool + } + DecimalSize struct { + Precision int64 + Scale int64 + OK bool + } + ScanType reflect.Type + }{ + { + Name: "a", + TypeName: "INT4", + Length: struct { + Len int64 + OK bool + }{ + Len: 0, + OK: false, + }, + DecimalSize: struct { + Precision int64 + Scale int64 + OK bool + }{ + Precision: 0, + Scale: 0, + OK: false, + }, + ScanType: reflect.TypeOf(int32(0)), + }, { + Name: "bar", + TypeName: "TEXT", + Length: struct { + Len int64 + OK bool + }{ + Len: math.MaxInt64, + OK: true, + }, + DecimalSize: struct { + Precision int64 + Scale int64 + OK bool + }{ + Precision: 0, + Scale: 0, + OK: false, + }, + ScanType: reflect.TypeOf(""), + }, + } + + db := openTestConn(t) + defer db.Close() + + rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec") + if err != nil { + t.Fatal(err) + } + + columns, err := rows.ColumnTypes() + if err != nil { + t.Fatal(err) + } + if len(columns) != 3 { + t.Errorf("expected 3 columns found %d", len(columns)) + } + + for i, tt := range columnTypesTests { + c := columns[i] + if c.Name() != tt.Name { + t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name) + } + if c.DatabaseTypeName() != tt.TypeName { + t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName) + } + l, ok := c.Length() + if l != tt.Length.Len { + t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len) + } + if ok != tt.Length.OK { + t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK) + } + p, s, ok := c.DecimalSize() + if p != tt.DecimalSize.Precision { + t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision) + } + if s != tt.DecimalSize.Scale { + t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale) + } + if ok != tt.DecimalSize.OK { + t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK) + } + if c.ScanType() != tt.ScanType { + t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType) + } + } +} diff --git a/vendor/src/github.com/lib/pq/ssl.go b/vendor/src/github.com/lib/pq/ssl.go index b282ebd9..7deb3043 100644 --- a/vendor/src/github.com/lib/pq/ssl.go +++ b/vendor/src/github.com/lib/pq/ssl.go @@ -15,7 +15,7 @@ import ( func ssl(o values) func(net.Conn) net.Conn { verifyCaOnly := false tlsConf := tls.Config{} - switch mode := o.Get("sslmode"); mode { + switch mode := o["sslmode"]; mode { // "require" is the default. case "", "require": // We must skip TLS's own verification since it requires full @@ -23,15 +23,19 @@ func ssl(o values) func(net.Conn) net.Conn { tlsConf.InsecureSkipVerify = true // From http://www.postgresql.org/docs/current/static/libpq-ssl.html: - // Note: For backwards compatibility with earlier versions of PostgreSQL, if a - // root CA file exists, the behavior of sslmode=require will be the same as - // that of verify-ca, meaning the server certificate is validated against the - // CA. Relying on this behavior is discouraged, and applications that need - // certificate validation should always use verify-ca or verify-full. - if _, err := os.Stat(o.Get("sslrootcert")); err == nil { - verifyCaOnly = true - } else { - o.Set("sslrootcert", "") + // + // Note: For backwards compatibility with earlier versions of + // PostgreSQL, if a root CA file exists, the behavior of + // sslmode=require will be the same as that of verify-ca, meaning the + // server certificate is validated against the CA. Relying on this + // behavior is discouraged, and applications that need certificate + // validation should always use verify-ca or verify-full. + if sslrootcert, ok := o["sslrootcert"]; ok { + if _, err := os.Stat(sslrootcert); err == nil { + verifyCaOnly = true + } else { + delete(o, "sslrootcert") + } } case "verify-ca": // We must skip TLS's own verification since it requires full @@ -39,7 +43,7 @@ func ssl(o values) func(net.Conn) net.Conn { tlsConf.InsecureSkipVerify = true verifyCaOnly = true case "verify-full": - tlsConf.ServerName = o.Get("host") + tlsConf.ServerName = o["host"] case "disable": return nil default: @@ -64,37 +68,42 @@ func ssl(o values) func(net.Conn) net.Conn { // in the user's home directory. The configured files must exist and have // the correct permissions. func sslClientCertificates(tlsConf *tls.Config, o values) { - sslkey := o.Get("sslkey") - sslcert := o.Get("sslcert") + // user.Current() might fail when cross-compiling. We have to ignore the + // error and continue without home directory defaults, since we wouldn't + // know from where to load them. + user, _ := user.Current() - var cinfo, kinfo os.FileInfo - var err error - - if sslcert != "" && sslkey != "" { - // Check that both files exist. Note that we don't do any more extensive - // checks than this (such as checking that the paths aren't directories); - // LoadX509KeyPair() will take care of the rest. - cinfo, err = os.Stat(sslcert) - if err != nil { - panic(err) - } - - kinfo, err = os.Stat(sslkey) - if err != nil { - panic(err) - } - } else { - // Automatically find certificates from ~/.postgresql - sslcert, sslkey, cinfo, kinfo = sslHomeCertificates() - - if cinfo == nil || kinfo == nil { - // No certificates to load - return - } + // In libpq, the client certificate is only loaded if the setting is not blank. + // + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1036-L1037 + sslcert := o["sslcert"] + if len(sslcert) == 0 && user != nil { + sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") + } + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1045 + if len(sslcert) == 0 { + return + } + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1050:L1054 + if _, err := os.Stat(sslcert); os.IsNotExist(err) { + return + } else if err != nil { + panic(err) } - // The files must also have the correct permissions - sslCertificatePermissions(cinfo, kinfo) + // In libpq, the ssl key is only loaded if the setting is not blank. + // + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1123-L1222 + sslkey := o["sslkey"] + if len(sslkey) == 0 && user != nil { + sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") + } + + if len(sslkey) > 0 { + if err := sslKeyPermissions(sslkey); err != nil { + panic(err) + } + } cert, err := tls.LoadX509KeyPair(sslcert, sslkey) if err != nil { @@ -105,7 +114,10 @@ func sslClientCertificates(tlsConf *tls.Config, o values) { // sslCertificateAuthority adds the RootCA specified in the "sslrootcert" setting. func sslCertificateAuthority(tlsConf *tls.Config, o values) { - if sslrootcert := o.Get("sslrootcert"); sslrootcert != "" { + // In libpq, the root certificate is only loaded if the setting is not blank. + // + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L950-L951 + if sslrootcert := o["sslrootcert"]; len(sslrootcert) > 0 { tlsConf.RootCAs = x509.NewCertPool() cert, err := ioutil.ReadFile(sslrootcert) @@ -113,41 +125,12 @@ func sslCertificateAuthority(tlsConf *tls.Config, o values) { panic(err) } - ok := tlsConf.RootCAs.AppendCertsFromPEM(cert) - if !ok { + if !tlsConf.RootCAs.AppendCertsFromPEM(cert) { errorf("couldn't parse pem in sslrootcert") } } } -// sslHomeCertificates returns the path and stats of certificates in the current -// user's home directory. -func sslHomeCertificates() (cert, key string, cinfo, kinfo os.FileInfo) { - user, err := user.Current() - - if err != nil { - // user.Current() might fail when cross-compiling. We have to ignore the - // error and continue without client certificates, since we wouldn't know - // from where to load them. - return - } - - cert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") - key = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") - - cinfo, err = os.Stat(cert) - if err != nil { - cinfo = nil - } - - kinfo, err = os.Stat(key) - if err != nil { - kinfo = nil - } - - return -} - // sslVerifyCertificateAuthority carries out a TLS handshake to the server and // verifies the presented certificate against the CA, i.e. the one specified in // sslrootcert or the system CA if sslrootcert was not specified. diff --git a/vendor/src/github.com/lib/pq/ssl_permissions.go b/vendor/src/github.com/lib/pq/ssl_permissions.go index 33076a8d..3b7c3a2a 100644 --- a/vendor/src/github.com/lib/pq/ssl_permissions.go +++ b/vendor/src/github.com/lib/pq/ssl_permissions.go @@ -4,13 +4,17 @@ package pq import "os" -// sslCertificatePermissions checks the permissions on user-supplied certificate -// files. The key file should have very little access. +// sslKeyPermissions checks the permissions on user-supplied ssl key files. +// The key file should have very little access. // // libpq does not check key file permissions on Windows. -func sslCertificatePermissions(cert, key os.FileInfo) { - kmode := key.Mode() - if kmode != kmode&0600 { - panic(ErrSSLKeyHasWorldPermissions) +func sslKeyPermissions(sslkey string) error { + info, err := os.Stat(sslkey) + if err != nil { + return err } + if info.Mode().Perm()&0077 != 0 { + return ErrSSLKeyHasWorldPermissions + } + return nil } diff --git a/vendor/src/github.com/lib/pq/ssl_test.go b/vendor/src/github.com/lib/pq/ssl_test.go index f70a5fd5..3eafbfd2 100644 --- a/vendor/src/github.com/lib/pq/ssl_test.go +++ b/vendor/src/github.com/lib/pq/ssl_test.go @@ -6,7 +6,6 @@ import ( _ "crypto/sha256" "crypto/x509" "database/sql" - "fmt" "os" "path/filepath" "testing" @@ -42,10 +41,13 @@ func openSSLConn(t *testing.T, conninfo string) (*sql.DB, error) { } func checkSSLSetup(t *testing.T, conninfo string) { - db, err := openSSLConn(t, conninfo) - if err == nil { - db.Close() - t.Fatalf("expected error with conninfo=%q", conninfo) + _, err := openSSLConn(t, conninfo) + if pge, ok := err.(*Error); ok { + if pge.Code.Name() != "invalid_authorization_specification" { + t.Fatalf("unexpected error code '%s'", pge.Code.Name()) + } + } else { + t.Fatalf("expected %T, got %v", (*Error)(nil), err) } } @@ -150,120 +152,128 @@ func TestSSLVerifyCA(t *testing.T) { checkSSLSetup(t, "sslmode=disable user=pqgossltest") // Not OK according to the system CA - _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest") - if err == nil { - t.Fatal("expected error") + { + _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest") + if _, ok := err.(x509.UnknownAuthorityError); !ok { + t.Fatalf("expected %T, got %#+v", x509.UnknownAuthorityError{}, err) + } } - _, ok := err.(x509.UnknownAuthorityError) - if !ok { - t.Fatalf("expected x509.UnknownAuthorityError, got %#+v", err) + + // Still not OK according to the system CA; empty sslrootcert is treated as unspecified. + { + _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest sslrootcert=''") + if _, ok := err.(x509.UnknownAuthorityError); !ok { + t.Fatalf("expected %T, got %#+v", x509.UnknownAuthorityError{}, err) + } } rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt") rootCert := "sslrootcert=" + rootCertPath + " " // No match on Common Name, but that's OK - _, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-ca user=pqgossltest") - if err != nil { + if _, err := openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-ca user=pqgossltest"); err != nil { t.Fatal(err) } // Everything OK - _, err = openSSLConn(t, rootCert+"host=postgres sslmode=verify-ca user=pqgossltest") - if err != nil { + if _, err := openSSLConn(t, rootCert+"host=postgres sslmode=verify-ca user=pqgossltest"); err != nil { t.Fatal(err) } } -func getCertConninfo(t *testing.T, source string) string { - var sslkey string - var sslcert string - - certpath := os.Getenv("PQSSLCERTTEST_PATH") - - switch source { - case "missingkey": - sslkey = "/tmp/filedoesnotexist" - sslcert = filepath.Join(certpath, "postgresql.crt") - case "missingcert": - sslkey = filepath.Join(certpath, "postgresql.key") - sslcert = "/tmp/filedoesnotexist" - case "certtwice": - sslkey = filepath.Join(certpath, "postgresql.crt") - sslcert = filepath.Join(certpath, "postgresql.crt") - case "valid": - sslkey = filepath.Join(certpath, "postgresql.key") - sslcert = filepath.Join(certpath, "postgresql.crt") - default: - t.Fatalf("invalid source %q", source) - } - return fmt.Sprintf("sslmode=require user=pqgosslcert sslkey=%s sslcert=%s", sslkey, sslcert) -} - // Authenticate over SSL using client certificates func TestSSLClientCertificates(t *testing.T) { maybeSkipSSLTests(t) // Environment sanity check: should fail without SSL checkSSLSetup(t, "sslmode=disable user=pqgossltest") - // Should also fail without a valid certificate - db, err := openSSLConn(t, "sslmode=require user=pqgosslcert") - if err == nil { - db.Close() - t.Fatal("expected error") + const baseinfo = "sslmode=require user=pqgosslcert" + + // Certificate not specified, should fail + { + _, err := openSSLConn(t, baseinfo) + if pge, ok := err.(*Error); ok { + if pge.Code.Name() != "invalid_authorization_specification" { + t.Fatalf("unexpected error code '%s'", pge.Code.Name()) + } + } else { + t.Fatalf("expected %T, got %v", (*Error)(nil), err) + } } - pge, ok := err.(*Error) + + // Empty certificate specified, should fail + { + _, err := openSSLConn(t, baseinfo+" sslcert=''") + if pge, ok := err.(*Error); ok { + if pge.Code.Name() != "invalid_authorization_specification" { + t.Fatalf("unexpected error code '%s'", pge.Code.Name()) + } + } else { + t.Fatalf("expected %T, got %v", (*Error)(nil), err) + } + } + + // Non-existent certificate specified, should fail + { + _, err := openSSLConn(t, baseinfo+" sslcert=/tmp/filedoesnotexist") + if pge, ok := err.(*Error); ok { + if pge.Code.Name() != "invalid_authorization_specification" { + t.Fatalf("unexpected error code '%s'", pge.Code.Name()) + } + } else { + t.Fatalf("expected %T, got %v", (*Error)(nil), err) + } + } + + certpath, ok := os.LookupEnv("PQSSLCERTTEST_PATH") if !ok { - t.Fatal("expected pq.Error") + t.Fatalf("PQSSLCERTTEST_PATH not present in environment") } - if pge.Code.Name() != "invalid_authorization_specification" { - t.Fatalf("unexpected error code %q", pge.Code.Name()) + + sslcert := filepath.Join(certpath, "postgresql.crt") + + // Cert present, key not specified, should fail + { + _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert) + if _, ok := err.(*os.PathError); !ok { + t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err) + } } + // Cert present, empty key specified, should fail + { + _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey=''") + if _, ok := err.(*os.PathError); !ok { + t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err) + } + } + + // Cert present, non-existent key, should fail + { + _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey=/tmp/filedoesnotexist") + if _, ok := err.(*os.PathError); !ok { + t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err) + } + } + + // Key has wrong permissions (passing the cert as the key), should fail + if _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey="+sslcert); err != ErrSSLKeyHasWorldPermissions { + t.Fatalf("expected %s, got %#+v", ErrSSLKeyHasWorldPermissions, err) + } + + sslkey := filepath.Join(certpath, "postgresql.key") + // Should work - db, err = openSSLConn(t, getCertConninfo(t, "valid")) - if err != nil { + if db, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey="+sslkey); err != nil { t.Fatal(err) - } - rows, err := db.Query("SELECT 1") - if err != nil { - t.Fatal(err) - } - rows.Close() -} - -// Test errors with ssl certificates -func TestSSLClientCertificatesMissingFiles(t *testing.T) { - maybeSkipSSLTests(t) - // Environment sanity check: should fail without SSL - checkSSLSetup(t, "sslmode=disable user=pqgossltest") - - // Key missing, should fail - _, err := openSSLConn(t, getCertConninfo(t, "missingkey")) - if err == nil { - t.Fatal("expected error") - } - // should be a PathError - _, ok := err.(*os.PathError) - if !ok { - t.Fatalf("expected PathError, got %#+v", err) - } - - // Cert missing, should fail - _, err = openSSLConn(t, getCertConninfo(t, "missingcert")) - if err == nil { - t.Fatal("expected error") - } - // should be a PathError - _, ok = err.(*os.PathError) - if !ok { - t.Fatalf("expected PathError, got %#+v", err) - } - - // Key has wrong permissions, should fail - _, err = openSSLConn(t, getCertConninfo(t, "certtwice")) - if err == nil { - t.Fatal("expected error") - } - if err != ErrSSLKeyHasWorldPermissions { - t.Fatalf("expected ErrSSLKeyHasWorldPermissions, got %#+v", err) + } else { + rows, err := db.Query("SELECT 1") + if err != nil { + t.Fatal(err) + } + if err := rows.Close(); err != nil { + t.Fatal(err) + } + if err := db.Close(); err != nil { + t.Fatal(err) + } } } diff --git a/vendor/src/github.com/lib/pq/ssl_windows.go b/vendor/src/github.com/lib/pq/ssl_windows.go index 529daed2..5d2c763c 100644 --- a/vendor/src/github.com/lib/pq/ssl_windows.go +++ b/vendor/src/github.com/lib/pq/ssl_windows.go @@ -2,8 +2,8 @@ package pq -import "os" - -// sslCertificatePermissions checks the permissions on user-supplied certificate -// files. In libpq, this is a no-op on Windows. -func sslCertificatePermissions(cert, key os.FileInfo) {} +// sslKeyPermissions checks the permissions on user-supplied ssl key files. +// The key file should have very little access. +// +// libpq does not check key file permissions on Windows. +func sslKeyPermissions(string) error { return nil } diff --git a/vendor/src/github.com/lib/pq/uuid_test.go b/vendor/src/github.com/lib/pq/uuid_test.go index 9df4a79b..8ecee2fd 100644 --- a/vendor/src/github.com/lib/pq/uuid_test.go +++ b/vendor/src/github.com/lib/pq/uuid_test.go @@ -33,7 +33,7 @@ func TestDecodeUUIDBackend(t *testing.T) { db := openTestConn(t) defer db.Close() - var s string = "a0ecc91d-a13f-4fe4-9fce-7e09777cc70a" + var s = "a0ecc91d-a13f-4fe4-9fce-7e09777cc70a" var scanned interface{} err := db.QueryRow(`SELECT $1::uuid`, s).Scan(&scanned)