upgrade version of lib/pq to v1.1.0 (#6640)

Adds SCRAM-SHA-256 authentication
This commit is contained in:
techknowlogick 2019-04-15 16:14:31 -04:00 committed by GitHub
parent 83d6e5e3f8
commit 3fb038c53a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 527 additions and 162 deletions

259
vendor/github.com/lib/pq/conn.go generated vendored
View file

@ -2,7 +2,9 @@ package pq
import (
"bufio"
"context"
"crypto/md5"
"crypto/sha256"
"database/sql"
"database/sql/driver"
"encoding/binary"
@ -20,6 +22,7 @@ import (
"unicode"
"github.com/lib/pq/oid"
"github.com/lib/pq/scram"
)
// Common error types
@ -89,13 +92,24 @@ type Dialer interface {
DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
}
type defaultDialer struct{}
func (d defaultDialer) Dial(ntw, addr string) (net.Conn, error) {
return net.Dial(ntw, addr)
type DialerContext interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
func (d defaultDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout(ntw, addr, timeout)
type defaultDialer struct {
d net.Dialer
}
func (d defaultDialer) Dial(network, address string) (net.Conn, error) {
return d.d.Dial(network, address)
}
func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return d.DialContext(ctx, network, address)
}
func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return d.d.DialContext(ctx, network, address)
}
type conn struct {
@ -244,90 +258,35 @@ func (cn *conn) writeBuf(b byte) *writeBuf {
}
}
// Open opens a new connection to the database. name is a connection string.
// Open opens a new connection to the database. dsn is a connection string.
// Most users should only use it through database/sql package from the standard
// library.
func Open(name string) (_ driver.Conn, err error) {
return DialOpen(defaultDialer{}, name)
func Open(dsn string) (_ driver.Conn, err error) {
return DialOpen(defaultDialer{}, dsn)
}
// DialOpen opens a new connection to the database using a dialer.
func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) {
c, err := NewConnector(dsn)
if err != nil {
return nil, err
}
c.dialer = d
return c.open(context.Background())
}
func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
// Handle any panics during connection initialization. Note that we
// specifically do *not* want to use errRecover(), as that would turn any
// connection errors into ErrBadConns, hiding the real error message from
// the user.
defer errRecoverNoErrBadConn(&err)
o := make(values)
o := c.opts
// A number of defaults are applied here, in this order:
//
// * Very low precedence defaults applied in every situation
// * Environment variables
// * Explicitly passed connection information
o["host"] = "localhost"
o["port"] = "5432"
// N.B.: Extra float digits should be set to 3, but that breaks
// Postgres 8.4 and older, where the max is 2.
o["extra_float_digits"] = "2"
for k, v := range parseEnviron(os.Environ()) {
o[k] = v
}
if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") {
name, err = ParseURL(name)
if err != nil {
return nil, err
}
}
if err := parseOpts(name, o); err != nil {
return nil, err
}
// Use the "fallback" application name if necessary
if fallback, ok := o["fallback_application_name"]; ok {
if _, ok := o["application_name"]; !ok {
o["application_name"] = fallback
}
}
// We can't work with any client_encoding other than UTF-8 currently.
// However, we have historically allowed the user to set it to UTF-8
// explicitly, and there's no reason to break such programs, so allow that.
// Note that the "options" setting could also set client_encoding, but
// parsing its value is not worth it. Instead, we always explicitly send
// client_encoding as a separate run-time parameter, which should override
// anything set in options.
if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) {
return nil, errors.New("client_encoding must be absent or 'UTF8'")
}
o["client_encoding"] = "UTF8"
// DateStyle needs a similar treatment.
if datestyle, ok := o["datestyle"]; ok {
if datestyle != "ISO, MDY" {
panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v",
"ISO, MDY", datestyle))
}
} else {
o["datestyle"] = "ISO, MDY"
}
// If a user is not provided by any other means, the last
// resort is to use the current operating system provided user
// name.
if _, ok := o["user"]; !ok {
u, err := userCurrent()
if err != nil {
return nil, err
}
o["user"] = u
}
cn := &conn{
cn = &conn{
opts: o,
dialer: d,
dialer: c.dialer,
}
err = cn.handleDriverSettings(o)
if err != nil {
@ -335,7 +294,7 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
}
cn.handlePgpass(o)
cn.c, err = dial(d, o)
cn.c, err = dial(ctx, c.dialer, o)
if err != nil {
return nil, err
}
@ -364,10 +323,10 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
return cn, err
}
func dial(d Dialer, o values) (net.Conn, error) {
ntw, addr := network(o)
func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) {
network, address := network(o)
// SSL is not necessary or supported over UNIX domain sockets
if ntw == "unix" {
if network == "unix" {
o["sslmode"] = "disable"
}
@ -378,19 +337,30 @@ func dial(d Dialer, o values) (net.Conn, error) {
return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
}
duration := time.Duration(seconds) * time.Second
// connect_timeout should apply to the entire connection establishment
// procedure, so we both use a timeout for the TCP connection
// establishment and set a deadline for doing the initial handshake.
// The deadline is then reset after startup() is done.
deadline := time.Now().Add(duration)
conn, err := d.DialTimeout(ntw, addr, duration)
var conn net.Conn
if dctx, ok := d.(DialerContext); ok {
ctx, cancel := context.WithTimeout(ctx, duration)
defer cancel()
conn, err = dctx.DialContext(ctx, network, address)
} else {
conn, err = d.DialTimeout(network, address, duration)
}
if err != nil {
return nil, err
}
err = conn.SetDeadline(deadline)
return conn, err
}
return d.Dial(ntw, addr)
if dctx, ok := d.(DialerContext); ok {
return dctx.DialContext(ctx, network, address)
}
return d.Dial(network, address)
}
func network(o values) (string, string) {
@ -704,7 +674,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) {
// res might be non-nil here if we received a previous
// CommandComplete, but that's fine; just overwrite it
res = &rows{cn: cn}
res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r)
res.rowsHeader = parsePortalRowDescribe(r)
// To work around a bug in QueryRow in Go 1.2 and earlier, wait
// until the first DataRow has been received.
@ -861,17 +831,15 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
cn.readParseResponse()
cn.readBindResponse()
rows := &rows{cn: cn}
rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse()
rows.rowsHeader = cn.readPortalDescribeResponse()
cn.postExecuteWorkaround()
return rows, nil
}
st := cn.prepareTo(query, "")
st.exec(args)
return &rows{
cn: cn,
colNames: st.colNames,
colTyps: st.colTyps,
colFmts: st.colFmts,
cn: cn,
rowsHeader: st.rowsHeader,
}, nil
}
@ -992,7 +960,6 @@ func (cn *conn) recv() (t byte, r *readBuf) {
if err != nil {
panic(err)
}
switch t {
case 'E':
panic(parseError(r))
@ -1163,6 +1130,55 @@ func (cn *conn) auth(r *readBuf, o values) {
if r.int32() != 0 {
errorf("unexpected authentication response: %q", t)
}
case 10:
sc := scram.NewClient(sha256.New, o["user"], o["password"])
sc.Step(nil)
if sc.Err() != nil {
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
}
scOut := sc.Out()
w := cn.writeBuf('p')
w.string("SCRAM-SHA-256")
w.int32(len(scOut))
w.bytes(scOut)
cn.send(w)
t, r := cn.recv()
if t != 'R' {
errorf("unexpected password response: %q", t)
}
if r.int32() != 11 {
errorf("unexpected authentication response: %q", t)
}
nextStep := r.next(len(*r))
sc.Step(nextStep)
if sc.Err() != nil {
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
}
scOut = sc.Out()
w = cn.writeBuf('p')
w.bytes(scOut)
cn.send(w)
t, r = cn.recv()
if t != 'R' {
errorf("unexpected password response: %q", t)
}
if r.int32() != 12 {
errorf("unexpected authentication response: %q", t)
}
nextStep = r.next(len(*r))
sc.Step(nextStep)
if sc.Err() != nil {
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
}
default:
errorf("unknown authentication response: %d", code)
}
@ -1180,12 +1196,10 @@ var colFmtDataAllBinary = []byte{0, 1, 0, 1}
var colFmtDataAllText = []byte{0, 0}
type stmt struct {
cn *conn
name string
colNames []string
colFmts []format
cn *conn
name string
rowsHeader
colFmtData []byte
colTyps []fieldDesc
paramTyps []oid.Oid
closed bool
}
@ -1231,10 +1245,8 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
st.exec(v)
return &rows{
cn: st.cn,
colNames: st.colNames,
colTyps: st.colTyps,
colFmts: st.colFmts,
cn: st.cn,
rowsHeader: st.rowsHeader,
}, nil
}
@ -1344,16 +1356,22 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
return driver.RowsAffected(n), commandTag
}
type rows struct {
cn *conn
finish func()
type rowsHeader struct {
colNames []string
colTyps []fieldDesc
colFmts []format
done bool
rb readBuf
result driver.Result
tag string
}
type rows struct {
cn *conn
finish func()
rowsHeader
done bool
rb readBuf
result driver.Result
tag string
next *rowsHeader
}
func (rs *rows) Close() error {
@ -1440,7 +1458,8 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
}
return
case 'T':
rs.colNames, rs.colFmts, rs.colTyps = parsePortalRowDescribe(&rs.rb)
next := parsePortalRowDescribe(&rs.rb)
rs.next = &next
return io.EOF
default:
errorf("unexpected message after execute: %q", t)
@ -1449,10 +1468,16 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
}
func (rs *rows) HasNextResultSet() bool {
return !rs.done
hasNext := rs.next != nil && !rs.done
return hasNext
}
func (rs *rows) NextResultSet() error {
if rs.next == nil {
return io.EOF
}
rs.rowsHeader = *rs.next
rs.next = nil
return nil
}
@ -1630,13 +1655,13 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [
}
}
func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []fieldDesc) {
func (cn *conn) readPortalDescribeResponse() rowsHeader {
t, r := cn.recv1()
switch t {
case 'T':
return parsePortalRowDescribe(r)
case 'n':
return nil, nil, nil
return rowsHeader{}
case 'E':
err := parseError(r)
cn.readReadyForQuery()
@ -1742,11 +1767,11 @@ func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDe
return
}
func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []fieldDesc) {
func parsePortalRowDescribe(r *readBuf) rowsHeader {
n := r.int16()
colNames = make([]string, n)
colFmts = make([]format, n)
colTyps = make([]fieldDesc, n)
colNames := make([]string, n)
colFmts := make([]format, n)
colTyps := make([]fieldDesc, n)
for i := range colNames {
colNames[i] = r.string()
r.next(6)
@ -1755,7 +1780,11 @@ func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, co
colTyps[i].Mod = r.int32()
colFmts[i] = format(r.int16())
}
return
return rowsHeader{
colNames: colNames,
colFmts: colFmts,
colTyps: colTyps,
}
}
// parseEnviron tries to mimic some of libpq's environment handling