diff --git a/packages/pg/lib/client.js b/packages/pg/lib/client.js index c6aa3dabe..09aa8054b 100644 --- a/packages/pg/lib/client.js +++ b/packages/pg/lib/client.js @@ -51,20 +51,14 @@ class Client extends EventEmitter { keepAlive: c.keepAlive || false, keepAliveInitialDelayMillis: c.keepAliveInitialDelayMillis || 0, encoding: this.connectionParameters.client_encoding || 'utf8', + Promise: this._Promise, }) this.queryQueue = [] this.binary = c.binary || defaults.binary this.processID = null this.secretKey = null + // TODO: remove in next major release? this.ssl = this.connectionParameters.ssl || false - // As with Password, make SSL->Key (the private key) non-enumerable. - // It won't show up in stack traces - // or if the client is console.logged - if (this.ssl && this.ssl.key) { - Object.defineProperty(this.ssl, 'key', { - enumerable: false, - }) - } this._connectionTimeoutMillis = c.connectionTimeoutMillis || 0 } @@ -115,14 +109,6 @@ class Client extends EventEmitter { // once connection is established send startup message con.on('connect', function () { - if (self.ssl) { - con.requestSsl() - } else { - con.startup(self.getStartupConf()) - } - }) - - con.on('sslconnect', function () { con.startup(self.getStartupConf()) }) diff --git a/packages/pg/lib/connection.js b/packages/pg/lib/connection.js index af4b8f13b..b07fc11e1 100644 --- a/packages/pg/lib/connection.js +++ b/packages/pg/lib/connection.js @@ -1,6 +1,5 @@ 'use strict' -var net = require('net') var EventEmitter = require('events').EventEmitter const { parse, serialize } = require('pg-protocol') @@ -16,6 +15,15 @@ class Connection extends EventEmitter { super() config = config || {} + // As with Password, make SSL->Key (the private key) non-enumerable. + // It won't show up in stack traces + // or if the client is console.logged + if (config.ssl && config.ssl.key) { + Object.defineProperty(config.ssl, 'key', { + enumerable: false, + }) + } + this.stream = config.stream || getStream(config.ssl) if (typeof this.stream === 'function') { this.stream = this.stream(config) @@ -34,11 +42,40 @@ class Connection extends EventEmitter { self._emitMessage = true } }) + + this._config = config + this._backendData = null + this._remote = null + } + + cancelWithClone() { + const config = this._config + const Promise = config.Promise || global.Promise + + return new Promise((resolve, reject) => { + const { processID, secretKey } = this._backendData + let { host, port, notIP } = this._remote + if (host && notIP && config.ssl && this.stream.remoteAddress) { + if (config.ssl === true) { + config.ssl = {} + } + config.ssl.servername = host + host = this.stream.remoteAddress + } + + const con = new Connection(config) + con + .on('connect', () => con.cancel(processID, secretKey)) + .on('error', reject) + .on('end', resolve) + .connect(port, host) + }) } connect(port, host) { var self = this + this._remote = { host, port } this._connecting = true this.stream.setNoDelay(true) this.stream.connect(port, host) @@ -47,7 +84,11 @@ class Connection extends EventEmitter { if (self._keepAlive) { self.stream.setKeepAlive(true, self._keepAliveInitialDelayMillis) } - self.emit('connect') + if (self.ssl) { + self.requestSsl() + } else { + self.emit('connect') + } }) const reportStreamError = function (error) { @@ -95,6 +136,7 @@ class Connection extends EventEmitter { var net = require('net') if (net.isIP && net.isIP(host) === 0) { options.servername = host + self._remote.notIP = true } try { self.stream = getSecureStream(options) @@ -104,7 +146,7 @@ class Connection extends EventEmitter { self.attachListeners(self.stream) self.stream.on('error', reportStreamError) - self.emit('sslconnect') + self.emit('connect') }) } @@ -115,6 +157,9 @@ class Connection extends EventEmitter { this.emit('message', msg) } this.emit(eventName, msg) + if (msg.name === 'backendKeyData') { + this._backendData = msg + } }) } diff --git a/packages/pg/lib/query.js b/packages/pg/lib/query.js index fac4d86e3..5e6a4bb38 100644 --- a/packages/pg/lib/query.js +++ b/packages/pg/lib/query.js @@ -5,6 +5,31 @@ const { EventEmitter } = require('events') const Result = require('./result') const utils = require('./utils') +function setupCancellation(cancelSignal, connection) { + let cancellation = null + + function cancelRequest() { + cancellation = connection.cancelWithClone().catch(() => { + // We could still have a cancel request in flight targeting this connection. + // Better safe than sorry? + connection.stream.destroy() + }) + } + + cancelSignal.addEventListener('abort', cancelRequest, { once: true }) + + return { + cleanup() { + if (cancellation) { + // Must wait out connection.cancelWithClone + return cancellation + } + cancelSignal.removeEventListener('abort', cancelRequest) + return Promise.resolve() + }, + } +} + class Query extends EventEmitter { constructor(config, values, callback) { super() @@ -29,6 +54,8 @@ class Query extends EventEmitter { // potential for multiple results this._results = this._result this._canceledDueToError = false + + this._cancelSignal = config.signal } requiresPreparation() { @@ -114,34 +141,53 @@ class Query extends EventEmitter { } } + _handleQueryComplete(fn) { + if (!this._cancellation) { + fn() + return + } + this._cancellation + .cleanup() + .then(fn) + .finally(() => { + this._cancellation = null + }) + } + handleError(err, connection) { // need to sync after error during a prepared statement if (this._canceledDueToError) { err = this._canceledDueToError this._canceledDueToError = false } - // if callback supplied do not emit error event as uncaught error - // events will bubble up to node process - if (this.callback) { - return this.callback(err) - } - this.emit('error', err) + + this._handleQueryComplete(() => { + // if callback supplied do not emit error event as uncaught error + // events will bubble up to node process + if (this.callback) { + return this.callback(err) + } + this.emit('error', err) + }) } handleReadyForQuery(con) { if (this._canceledDueToError) { return this.handleError(this._canceledDueToError, con) } - if (this.callback) { - try { - this.callback(null, this._results) - } catch (err) { - process.nextTick(() => { - throw err - }) + + this._handleQueryComplete(() => { + if (this.callback) { + try { + this.callback(null, this._results) + } catch (err) { + process.nextTick(() => { + throw err + }) + } } - } - this.emit('end', this._results) + this.emit('end', this._results) + }) } submit(connection) { @@ -155,6 +201,12 @@ class Query extends EventEmitter { if (this.values && !Array.isArray(this.values)) { return new Error('Query values must be an array') } + if (this._cancelSignal) { + if (this._cancelSignal.aborted) { + return this._cancelSignal.reason || Object.assign(new Error(), { name: 'AbortError' }) + } + this._cancellation = setupCancellation(this._cancelSignal, connection) + } if (this.requiresPreparation()) { this.prepare(connection) } else { diff --git a/packages/pg/test/integration/cancel/cancel-query-with-abort-signal-tests.js b/packages/pg/test/integration/cancel/cancel-query-with-abort-signal-tests.js new file mode 100644 index 000000000..69d93ec7c --- /dev/null +++ b/packages/pg/test/integration/cancel/cancel-query-with-abort-signal-tests.js @@ -0,0 +1,114 @@ +var helper = require('./../test-helper') + +var pg = helper.pg +const Client = pg.Client +const DatabaseError = pg.DatabaseError + +if (!global.AbortController) { + // Skip these tests if AbortController is not available + return +} + +const suite = new helper.Suite('query cancellation with abort signal') + +suite.test('query with signal succeeds if not aborted', function (done) { + const client = new Client() + const { signal } = new AbortController() + + client.connect( + assert.success(() => { + client.query( + new pg.Query({ text: 'select pg_sleep(0.1)', signal }), + assert.success((result) => { + assert.equal(result.rows[0].pg_sleep, '') + client.end(done) + }) + ) + }) + ) +}) + +if (helper.config.native) { + // Skip these tests if native bindings are enabled + return +} + +suite.test('query with signal is not submitted if the signal is already aborted', function (done) { + const client = new Client() + const signal = AbortSignal.abort() + + let counter = 0 + + client.query( + new pg.Query({ text: 'INVALID SQL...' }), + assert.calls((err) => { + assert(err instanceof DatabaseError) + counter++ + }) + ) + + client.query( + new pg.Query({ text: 'begin' }), + assert.success(() => { + counter++ + }) + ) + + client.query( + new pg.Query({ text: 'INVALID SQL...', signal }), + assert.calls((err) => { + assert.equal(err.name, 'AbortError') + counter++ + }) + ) + + client.query( + new pg.Query({ text: 'select 1' }), + assert.success(() => { + counter++ + assert.equal(counter, 4) + client.end(done) + }) + ) + + client.connect(assert.success(() => {})) +}) + +suite.test('query can be canceled with abort signal', function (done) { + const client = new Client() + const ac = new AbortController() + const { signal } = ac + + client.query( + new pg.Query({ text: 'SELECT pg_sleep(0.5)', signal }), + assert.calls((err) => { + assert(err instanceof DatabaseError) + assert.equal(err.code, '57014') + client.end(done) + }) + ) + + client.connect( + assert.success(() => { + setTimeout(() => { + ac.abort() + }, 50) + }) + ) +}) + +suite.test('long abort signal timeout does not keep the query / connection going', function (done) { + const client = new Client() + const ac = new AbortController() + setTimeout(() => ac.abort(), 10000).unref() + + client.query( + new pg.Query({ text: 'SELECT pg_sleep(0.1)', signal: ac.signal }), + assert.success((result) => { + assert.equal(result.rows[0].pg_sleep, '') + client.end(done) + }) + ) + + client.connect(assert.success(() => {})) +}) diff --git a/packages/pg/test/unit/connection/error-tests.js b/packages/pg/test/unit/connection/error-tests.js index 091c13e2c..95953c431 100644 --- a/packages/pg/test/unit/connection/error-tests.js +++ b/packages/pg/test/unit/connection/error-tests.js @@ -42,7 +42,7 @@ var SSLNegotiationPacketTests = [ testName: 'connection does not emit ECONNRESET errors during disconnect also when using SSL', errorMessage: null, response: 'S', - responseType: 'sslconnect', + responseType: 'connect', }, { testName: 'connection emits an error when SSL is not supported',