From 9c9c4cde4869a00025cfc9faca7c40fe8f1c23fc Mon Sep 17 00:00:00 2001 From: Taras Kopets Date: Tue, 30 May 2017 13:57:48 +0300 Subject: [PATCH] Added new PostgreSQL auth method using env vars Two new dict options under cli_options[type] are added in settings: * env - always sets the vars to given value * env_optional - sets the env vars only if sucessfully formatted Default settings file is updated to support this new functionality. --- SQLTools.sublime-settings | 3 ++ SQLToolsAPI/Command.py | 22 +++++++++------ SQLToolsAPI/Connection.py | 59 +++++++++++++++++++++++++++++++++------ 3 files changed, 67 insertions(+), 17 deletions(-) diff --git a/SQLTools.sublime-settings b/SQLTools.sublime-settings index 0ca39b8..9b6a489 100644 --- a/SQLTools.sublime-settings +++ b/SQLTools.sublime-settings @@ -61,6 +61,9 @@ "options": ["--no-password"], "before": [], "args": "-h {host} -p {port} -U {username} -d {database}", + "env_optional": { + "PGPASSWORD": "{password}" + }, "queries": { "desc" : { "query": "select '|' || quote_ident(table_schema)||'.'||quote_ident(table_name) ||'|' as tblname from information_schema.tables where table_schema = any(current_schemas(false)) and table_schema not in ('pg_catalog', 'information_schema') order by table_schema = current_schema() desc, table_schema, table_name", diff --git a/SQLToolsAPI/Command.py b/SQLToolsAPI/Command.py index a5b6751..e5ef990 100644 --- a/SQLToolsAPI/Command.py +++ b/SQLToolsAPI/Command.py @@ -10,13 +10,14 @@ class Command(object): timeout = 15 - def __init__(self, args, callback, query=None, encoding='utf-8', + def __init__(self, args, env, callback, query=None, encoding='utf-8', options=None, timeout=15, silenceErrors=False, stream=False): if options is None: options = {} self.stream = stream self.args = args + self.env = env self.callback = callback self.query = query self.encoding = encoding @@ -43,13 +44,18 @@ def run(self): if self.silenceErrors: stderrHandle = subprocess.PIPE + # set the environment + modifiedEnvironment = os.environ.copy() + if (self.env): + modifiedEnvironment.update(self.env) + queryTimerStart = time.time() self.process = subprocess.Popen(self.args, stdout=subprocess.PIPE, stderr=stderrHandle, stdin=subprocess.PIPE, - env=os.environ.copy(), + env=modifiedEnvironment, startupinfo=si) if self.stream: @@ -103,21 +109,21 @@ def _formatShowQuery(query, queryTimeStart, queryTimeEnd): return resultString @staticmethod - def createAndRun(args, query, callback, options=None, timeout=15, silenceErrors=False, stream=False): + def createAndRun(args, env, query, callback, options=None, timeout=15, silenceErrors=False, stream=False): if options is None: options = {} - command = Command(args, callback, query, options=options, + command = Command(args, env, callback, query, options=options, timeout=timeout, silenceErrors=silenceErrors) command.run() class ThreadCommand(Command, Thread): - def __init__(self, args, callback, query=None, encoding='utf-8', + def __init__(self, args, env, callback, query=None, encoding='utf-8', options=None, timeout=Command.timeout, silenceErrors=False, stream=False): if options is None: options = {} - Command.__init__(self, args, callback, query=query, + Command.__init__(self, args, env, callback, query=query, encoding=encoding, options=options, timeout=timeout, silenceErrors=silenceErrors, stream=stream) @@ -143,13 +149,13 @@ def stop(self): pass @staticmethod - def createAndRun(args, query, callback, options=None, + def createAndRun(args, env, query, callback, options=None, timeout=Command.timeout, silenceErrors=False, stream=False): # Don't allow empty dicts or lists as defaults in method signature, # cfr http://nedbatchelder.com/blog/200806/pylint.html if options is None: options = {} - command = ThreadCommand(args, callback, query, options=options, + command = ThreadCommand(args, env, callback, query, options=options, timeout=timeout, silenceErrors=silenceErrors, stream=stream) command.start() killTimeout = Timer(command.timeout, command.stop) diff --git a/SQLToolsAPI/Connection.py b/SQLToolsAPI/Connection.py index f2ba854..7971ec0 100644 --- a/SQLToolsAPI/Connection.py +++ b/SQLToolsAPI/Connection.py @@ -77,7 +77,9 @@ def getTables(self, callback): def cb(result): callback(U.getResultAsList(result)) - self.Command.createAndRun(self.builArgs('desc'), + args = self.buildArgs('desc') + env = self.buildEnv() + self.Command.createAndRun(args, env, query, cb, silenceErrors=True) def getColumns(self, callback): @@ -87,7 +89,9 @@ def cb(result): try: query = self.getOptionsForSgdbCli()['queries']['columns']['query'] - self.Command.createAndRun(self.builArgs('columns'), + args = self.buildArgs('columns') + env = self.buildEnv() + self.Command.createAndRun(args, env, query, cb, silenceErrors=True) except Exception: pass @@ -99,7 +103,9 @@ def cb(result): try: query = self.getOptionsForSgdbCli()['queries']['functions']['query'] - self.Command.createAndRun(self.builArgs('functions'), + args = self.buildArgs('functions') + env = self.buildEnv() + self.Command.createAndRun(args, env, query, cb, silenceErrors=True) except Exception: pass @@ -107,18 +113,24 @@ def cb(result): def getTableRecords(self, tableName, callback): query = self.getOptionsForSgdbCli()['queries']['show records']['query'].format(tableName, self.rowsLimit) queryToRun = '\n'.join(self.getOptionsForSgdbCli()['before'] + [query]) - self.Command.createAndRun(self.builArgs('show records'), queryToRun, callback, timeout=self.timeout) + args = self.buildArgs('show records') + env = self.buildEnv() + self.Command.createAndRun(args, env, queryToRun, callback, timeout=self.timeout) def getTableDescription(self, tableName, callback): query = self.getOptionsForSgdbCli()['queries']['desc table']['query'] % tableName queryToRun = '\n'.join(self.getOptionsForSgdbCli()['before'] + [query]) - self.Command.createAndRun(self.builArgs('desc table'), queryToRun, callback) + args = self.buildArgs('desc table') + env = self.buildEnv() + self.Command.createAndRun(args, env, queryToRun, callback) def getFunctionDescription(self, functionName, callback): query = self.getOptionsForSgdbCli()['queries']['desc function'][ 'query'] % functionName queryToRun = '\n'.join(self.getOptionsForSgdbCli()['before'] + [query]) - self.Command.createAndRun(self.builArgs('desc function'), queryToRun, callback) + args = self.buildArgs('desc function') + env = self.buildEnv() + self.Command.createAndRun(args, env, queryToRun, callback) def explainPlan(self, queries, callback): try: @@ -132,7 +144,9 @@ def explainPlan(self, queries, callback): for query in filter(None, sqlparse.split(rawQuery)) ] queryToRun = '\n'.join(self.getOptionsForSgdbCli()['before'] + stripped_queries) - self.Command.createAndRun(self.builArgs('explain plan'), queryToRun, callback, timeout=self.timeout) + args = self.buildArgs('explain plan') + env = self.buildEnv() + self.Command.createAndRun(args, env, queryToRun, callback, timeout=self.timeout) def execute(self, queries, callback, stream=False): queryToRun = '' @@ -165,9 +179,11 @@ def execute(self, queries, callback, stream=False): if self.history: self.history.add(queryToRun) - self.Command.createAndRun(self.builArgs(), queryToRun, callback, options={'show_query': self.show_query}, timeout=self.timeout, stream=stream) + args = self.buildArgs() + env = self.buildEnv() + self.Command.createAndRun(args, env, queryToRun, callback, options={'show_query': self.show_query}, timeout=self.timeout, stream=stream) - def builArgs(self, queryName=None): + def buildArgs(self, queryName=None): cliOptions = self.getOptionsForSgdbCli() args = [self.cli] @@ -206,6 +222,31 @@ def builArgs(self, queryName=None): Log('Using cli args ' + ' '.join(args)) return args + def buildEnv(self): + cliOptions = self.getOptionsForSgdbCli() + env = dict() + + # append **optional** environment variables dict (if any) + optionalEnv = cliOptions.get('env_optional') + if optionalEnv: # only if we have optional args + if isinstance(optionalEnv, dict): + for var, value in optionalEnv.items(): + formattedValue = self.formatOptionalArgument(value, self.options) + if formattedValue: + env.update({var: formattedValue}) + + # append environment variables dict (if any) + staticEnv = cliOptions.get('env') + if staticEnv: # only if we have optional args + if isinstance(staticEnv, dict): + for var, value in staticEnv.items(): + formattedValue = value.format(**self.options) + if formattedValue: + env.update({var: formattedValue}) + + Log('Environment for command: ' + str(env)) + return env + def getOptionsForSgdbCli(self): return self.settings.get('cli_options', {}).get(self.type)