Skip to content

Commit

Permalink
fix(): allow connection switch when different credentials are detected
Browse files Browse the repository at this point in the history
  • Loading branch information
MatteoGioioso committed Jul 4, 2021
1 parent 71ad266 commit 2950776
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 61 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ const handler = async(event, context) => {
| maxRetries | `Integer` | Maximum number of times to retry a connection before throwing an error. | `3` |
| processCountCacheEnabled | `Boolean` | Enable caching for get process count. | `False` |
| processCountFreqMs | `Integer` | The number of milliseconds to cache lookups of process count. | `6000` |
| allowCredentialsDiffing | `Boolean` | If you are using dynamic credentials, such as IAM, you can set this parameter to `true` and the client will be refreshed | `false` |


## Note
Expand Down
1 change: 1 addition & 0 deletions index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ declare interface Config {
port?: number;
host?: string;
connectionString?: string;
allowCredentialsDiffing?: boolean;
keepAlive?: boolean;
stream?: stream.Duplex;
statement_timeout?: false | number;
Expand Down
150 changes: 89 additions & 61 deletions src/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
* @license MIT
*/

const { isValidStrategy, type, validateNum, isWithinRange } = require("./utils");
const { Client } = require("pg");
const {isValidStrategy, type, validateNum, isWithinRange} = require("./utils");
const {Client} = require("pg");

function ServerlessClient(config) {
this._client = null;
Expand Down Expand Up @@ -40,15 +40,15 @@ ServerlessClient.prototype._setMaxConnections = async (__self) => {

// This strategy arbitrarily (maxIdleConnections) terminates connections starting from the oldest one in idle.
// It is very aggressive and it can cause disruption if a connection was in idle for a short period of time
ServerlessClient.prototype._getIdleProcessesListOrderByDate = async function() {
ServerlessClient.prototype._getIdleProcessesListOrderByDate = async function () {
const query = `
SELECT pid,backend_start,state
FROM pg_stat_activity
WHERE datname=$1
AND state='idle'
AND usename=$2
ORDER BY state_change
LIMIT $3;`
SELECT pid, backend_start, state
FROM pg_stat_activity
WHERE datname = $1
AND state = 'idle'
AND usename = $2
ORDER BY state_change
LIMIT $3;`

const values = [
this._client.database,
Expand All @@ -60,7 +60,7 @@ ServerlessClient.prototype._getIdleProcessesListOrderByDate = async function() {
const result = await this._client.query(query, values);

return result.rows
} catch (e){
} catch (e) {
this._logger("Swallowed internal error", e.message)
// Swallow the error, if this produce an error there is no need to error the function
return []
Expand All @@ -70,21 +70,20 @@ ServerlessClient.prototype._getIdleProcessesListOrderByDate = async function() {
// This strategy select only the connections that have been in idle state for more
// than a minimum amount of seconds, it is very accurate as it only takes the process that have been in idle
// for more than a threshold time (minConnectionTimeoutSec)
ServerlessClient.prototype._getIdleProcessesListByMinimumTimeout = async function(){
ServerlessClient.prototype._getIdleProcessesListByMinimumTimeout = async function () {
const query = `
WITH processes AS(
SELECT
EXTRACT(EPOCH FROM (Now() - state_change)) AS idle_time,
pid
FROM pg_stat_activity
WHERE usename=$1
AND datname=$2
AND state='idle'
)
SELECT pid
FROM processes
WHERE idle_time > $3
LIMIT $4;`
WITH processes AS (
SELECT EXTRACT(EPOCH FROM (Now() - state_change)) AS idle_time,
pid
FROM pg_stat_activity
WHERE usename = $1
AND datname = $2
AND state = 'idle'
)
SELECT pid
FROM processes
WHERE idle_time > $3
LIMIT $4;`

const values = [
this._client.user,
Expand All @@ -104,7 +103,7 @@ ServerlessClient.prototype._getIdleProcessesListByMinimumTimeout = async functio
}
}

ServerlessClient.prototype._getProcessesCount = async function() {
ServerlessClient.prototype._getProcessesCount = async function () {
function isCacheExpiredOrDisabled(__self) {
// If cache is disabled
if (!__self._processCount.cacheEnabled) {
Expand Down Expand Up @@ -142,19 +141,20 @@ ServerlessClient.prototype._getProcessesCount = async function() {
return this._processCount.cache.count
};

ServerlessClient.prototype._killProcesses = async function(processesList) {
ServerlessClient.prototype._killProcesses = async function (processesList) {
const pids = processesList.map(proc => proc.pid);

const query = `
SELECT pg_terminate_backend(pid)
FROM pg_stat_activity
WHERE pid = ANY ($1) AND state='idle';`
SELECT pg_terminate_backend(pid)
FROM pg_stat_activity
WHERE pid = ANY ($1)
AND state = 'idle';`

const values = [pids]

try {
return await this._client.query(query, values)
} catch (e){
} catch (e) {
this._logger("Swallowed internal error", e.message)
// Swallow the error, if this produce an error there is no need to error the function

Expand All @@ -164,7 +164,7 @@ ServerlessClient.prototype._killProcesses = async function(processesList) {
}
};

ServerlessClient.prototype._getStrategy = function(){
ServerlessClient.prototype._getStrategy = function () {
switch (this._strategy.name) {
case "minimum_idle_time":
return this._getIdleProcessesListByMinimumTimeout.bind(this)
Expand All @@ -175,14 +175,14 @@ ServerlessClient.prototype._getStrategy = function(){
}
}

ServerlessClient.prototype._decorrelatedJitter = function(delay){
ServerlessClient.prototype._decorrelatedJitter = function (delay) {
const cap = this._backoff.capMs;
const base = this._backoff.baseMs;
const randRange = (min,max) => Math.floor(Math.random() * (max - min + 1)) + min;
const randRange = (min, max) => Math.floor(Math.random() * (max - min + 1)) + min;
return Math.min(cap, randRange(base, delay * 3));
}

ServerlessClient.prototype.clean = async function() {
ServerlessClient.prototype.clean = async function () {
const processCount = await this._getProcessesCount();
this._logger("Current process count: ", processCount);

Expand All @@ -197,12 +197,29 @@ ServerlessClient.prototype.clean = async function() {
}
};

ServerlessClient.prototype._init = async function(){
if(this._client !== null){
ServerlessClient.prototype._diffCredentials = function (config) {
const keys = ['password', 'host', 'port', 'user', 'database']
for (const key of keys) {
if (this._config[key] !== config[key]) {
this._multipleCredentials.areCredentialsDifferent = true
break;
}
}
}

ServerlessClient.prototype._init = async function () {
if (this._client !== null && !this._multipleCredentials.areCredentialsDifferent) {
return
}

if (this._client !== null && this._multipleCredentials.areCredentialsDifferent) {
// For the time being we close the connection if new credentials are detected to avoid leaking.
// In the future we could use Pool in this case to avoid recreating a client each time
this._client.end()
}

this._client = new Client(this._config)
this._multipleCredentials.areCredentialsDifferent = false

// pg throws an error if we terminate the connection, therefore we need to swallow these errors
// and throw the rest
Expand All @@ -223,14 +240,14 @@ ServerlessClient.prototype._init = async function(){
await this._client.connect();
this._logger("Connected...")

if (this._maxConns.manualMaxConnections){
if (this._maxConns.manualMaxConnections) {
await this._setMaxConnections(this)
}

this._logger("Max connections: ", this._maxConns.cache.total)
}

ServerlessClient.prototype._validateConfig = function(config){
ServerlessClient.prototype._validateConfig = function (config) {
const {
manualMaxConnections,
maxConnsFreqMs,
Expand All @@ -249,58 +266,64 @@ ServerlessClient.prototype._validateConfig = function(config){
if (
manualMaxConnections &&
type(manualMaxConnections) !== "Boolean"
){
) {
throw new Error("manualMaxConnections must be of type Boolean")
}

if (debug && type(debug) !== "Boolean"){
if (debug && type(debug) !== "Boolean") {
throw new Error("debug must be of type Boolean")
}

if (validateNum(maxConnsFreqMs)){
if (validateNum(maxConnsFreqMs)) {
throw new Error("maxConnsFreqMs must be of type Number")
}

if (validateNum(maxConnections)){
if (validateNum(maxConnections)) {
throw new Error("maxConnections must be of type Number")
}

if (strategy && !isValidStrategy(strategy)){
if (strategy && !isValidStrategy(strategy)) {
throw new Error("the provided strategy is invalid")
}

if (validateNum(maxIdleConnectionsToKill)){
if (validateNum(maxIdleConnectionsToKill)) {
throw new Error("maxIdleConnectionsToKill must be of type Number or null")
}

if (validateNum(minConnectionIdleTimeSec)){
if (validateNum(minConnectionIdleTimeSec)) {
throw new Error("minConnectionIdleTimeSec must be of type Number")
}

if (validateNum(connUtilization) || !isWithinRange(connUtilization, 0, 1)){
if (validateNum(connUtilization) || !isWithinRange(connUtilization, 0, 1)) {
throw new Error("connUtilization must be of type Number")
}

if (validateNum(capMs)){
if (validateNum(capMs)) {
throw new Error("capMs must be of type Number")
}

if (validateNum(baseMs)){
if (validateNum(baseMs)) {
throw new Error("baseMs must be of type Number")
}

if (validateNum(delayMs)){
if (validateNum(delayMs)) {
throw new Error("delayMs must be of type Number")
}

if (validateNum(maxRetries)){
if (validateNum(maxRetries)) {
throw new Error("maxRetries must be of type Number")
}
}

ServerlessClient.prototype.setConfig = function (config) {
const prevConfig = this._config;
this._validateConfig(config)
this._config = { ...this._config, ...config };
this._config = {...this._config, ...config};

this._multipleCredentials = {
allowCredentialsDiffing: this._config.allowCredentialsDiffing || false,
areCredentialsDifferent: false
};

this._maxConns = {
// Cache expiration for getting the max connections value in milliseconds
Expand Down Expand Up @@ -353,16 +376,21 @@ ServerlessClient.prototype.setConfig = function (config) {
retries: 0,
queryRetries: 0
}

// Prevent diffing also if client is null
if (this._multipleCredentials.allowCredentialsDiffing && this._client !== null) {
this._diffCredentials(prevConfig, config)
}
}

ServerlessClient.prototype._logger = function(...args) {
if (this._debug){
ServerlessClient.prototype._logger = function (...args) {
if (this._debug) {
const pid = this._client && this._client.processID || 'offline'
console.log('serverless-pg | pid: ', pid, ' | ', ...args)
console.log('serverless-pg | pid: ', pid, ' | ', ...args)
}
}

ServerlessClient.prototype.connect = async function() {
ServerlessClient.prototype.connect = async function () {
try {
await this._init();
} catch (e) {
Expand Down Expand Up @@ -391,7 +419,7 @@ ServerlessClient.prototype.connect = async function() {
}
};

ServerlessClient.prototype.query = async function(...args){
ServerlessClient.prototype.query = async function (...args) {
try {
this._logger("Start query...")
// We fulfill the promise to catch the error
Expand All @@ -401,7 +429,7 @@ ServerlessClient.prototype.query = async function(...args){
e.message === "Client has encountered a connection error and is not queryable" ||
e.message === "terminating connection due to administrator command" ||
e.message === "Connection terminated unexpectedly"
){
) {
// If a client has been terminated by serverless-postgres and try to query again
// we re-initialize it and retry
this._client = null
Expand All @@ -424,15 +452,15 @@ ServerlessClient.prototype.query = async function(...args){
}
}

ServerlessClient.prototype.end = async function(){
ServerlessClient.prototype.end = async function () {
this._backoff.retries = 0
this._backoff.queryRetries = 0
await this._client.end()
this._client = null
}

ServerlessClient.prototype.on = function(...args){
ServerlessClient.prototype.on = function (...args) {
this._client.on(...args)
}

module.exports = { ServerlessClient };
module.exports = {ServerlessClient};

0 comments on commit 2950776

Please sign in to comment.