Skip to content

Commit

Permalink
Added new PostgreSQL auth method using env vars
Browse files Browse the repository at this point in the history
Two new dict options under cli_options[type] are added in settings:
 * env - always sets the vars to given value
 * env_optional - sets the env vars only if sucessfully formatted

Default settings file is updated to support this new functionality.
  • Loading branch information
tkopets committed May 30, 2017
1 parent 90e8086 commit 9c9c4cd
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 17 deletions.
3 changes: 3 additions & 0 deletions SQLTools.sublime-settings
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@
"options": ["--no-password"],
"before": [],
"args": "-h {host} -p {port} -U {username} -d {database}",
"env_optional": {
"PGPASSWORD": "{password}"
},
"queries": {
"desc" : {
"query": "select '|' || quote_ident(table_schema)||'.'||quote_ident(table_name) ||'|' as tblname from information_schema.tables where table_schema = any(current_schemas(false)) and table_schema not in ('pg_catalog', 'information_schema') order by table_schema = current_schema() desc, table_schema, table_name",
Expand Down
22 changes: 14 additions & 8 deletions SQLToolsAPI/Command.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
class Command(object):
timeout = 15

def __init__(self, args, callback, query=None, encoding='utf-8',
def __init__(self, args, env, callback, query=None, encoding='utf-8',
options=None, timeout=15, silenceErrors=False, stream=False):
if options is None:
options = {}

self.stream = stream
self.args = args
self.env = env
self.callback = callback
self.query = query
self.encoding = encoding
Expand All @@ -43,13 +44,18 @@ def run(self):
if self.silenceErrors:
stderrHandle = subprocess.PIPE

# set the environment
modifiedEnvironment = os.environ.copy()
if (self.env):
modifiedEnvironment.update(self.env)

queryTimerStart = time.time()

self.process = subprocess.Popen(self.args,
stdout=subprocess.PIPE,
stderr=stderrHandle,
stdin=subprocess.PIPE,
env=os.environ.copy(),
env=modifiedEnvironment,
startupinfo=si)

if self.stream:
Expand Down Expand Up @@ -103,21 +109,21 @@ def _formatShowQuery(query, queryTimeStart, queryTimeEnd):
return resultString

@staticmethod
def createAndRun(args, query, callback, options=None, timeout=15, silenceErrors=False, stream=False):
def createAndRun(args, env, query, callback, options=None, timeout=15, silenceErrors=False, stream=False):
if options is None:
options = {}
command = Command(args, callback, query, options=options,
command = Command(args, env, callback, query, options=options,
timeout=timeout, silenceErrors=silenceErrors)
command.run()


class ThreadCommand(Command, Thread):
def __init__(self, args, callback, query=None, encoding='utf-8',
def __init__(self, args, env, callback, query=None, encoding='utf-8',
options=None, timeout=Command.timeout, silenceErrors=False, stream=False):
if options is None:
options = {}

Command.__init__(self, args, callback, query=query,
Command.__init__(self, args, env, callback, query=query,
encoding=encoding, options=options,
timeout=timeout, silenceErrors=silenceErrors,
stream=stream)
Expand All @@ -143,13 +149,13 @@ def stop(self):
pass

@staticmethod
def createAndRun(args, query, callback, options=None,
def createAndRun(args, env, query, callback, options=None,
timeout=Command.timeout, silenceErrors=False, stream=False):
# Don't allow empty dicts or lists as defaults in method signature,
# cfr http://nedbatchelder.com/blog/200806/pylint.html
if options is None:
options = {}
command = ThreadCommand(args, callback, query, options=options,
command = ThreadCommand(args, env, callback, query, options=options,
timeout=timeout, silenceErrors=silenceErrors, stream=stream)
command.start()
killTimeout = Timer(command.timeout, command.stop)
Expand Down
59 changes: 50 additions & 9 deletions SQLToolsAPI/Connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ def getTables(self, callback):
def cb(result):
callback(U.getResultAsList(result))

self.Command.createAndRun(self.builArgs('desc'),
args = self.buildArgs('desc')
env = self.buildEnv()
self.Command.createAndRun(args, env,
query, cb, silenceErrors=True)

def getColumns(self, callback):
Expand All @@ -87,7 +89,9 @@ def cb(result):

try:
query = self.getOptionsForSgdbCli()['queries']['columns']['query']
self.Command.createAndRun(self.builArgs('columns'),
args = self.buildArgs('columns')
env = self.buildEnv()
self.Command.createAndRun(args, env,
query, cb, silenceErrors=True)
except Exception:
pass
Expand All @@ -99,26 +103,34 @@ def cb(result):

try:
query = self.getOptionsForSgdbCli()['queries']['functions']['query']
self.Command.createAndRun(self.builArgs('functions'),
args = self.buildArgs('functions')
env = self.buildEnv()
self.Command.createAndRun(args, env,
query, cb, silenceErrors=True)
except Exception:
pass

def getTableRecords(self, tableName, callback):
query = self.getOptionsForSgdbCli()['queries']['show records']['query'].format(tableName, self.rowsLimit)
queryToRun = '\n'.join(self.getOptionsForSgdbCli()['before'] + [query])
self.Command.createAndRun(self.builArgs('show records'), queryToRun, callback, timeout=self.timeout)
args = self.buildArgs('show records')
env = self.buildEnv()
self.Command.createAndRun(args, env, queryToRun, callback, timeout=self.timeout)

def getTableDescription(self, tableName, callback):
query = self.getOptionsForSgdbCli()['queries']['desc table']['query'] % tableName
queryToRun = '\n'.join(self.getOptionsForSgdbCli()['before'] + [query])
self.Command.createAndRun(self.builArgs('desc table'), queryToRun, callback)
args = self.buildArgs('desc table')
env = self.buildEnv()
self.Command.createAndRun(args, env, queryToRun, callback)

def getFunctionDescription(self, functionName, callback):
query = self.getOptionsForSgdbCli()['queries']['desc function'][
'query'] % functionName
queryToRun = '\n'.join(self.getOptionsForSgdbCli()['before'] + [query])
self.Command.createAndRun(self.builArgs('desc function'), queryToRun, callback)
args = self.buildArgs('desc function')
env = self.buildEnv()
self.Command.createAndRun(args, env, queryToRun, callback)

def explainPlan(self, queries, callback):
try:
Expand All @@ -132,7 +144,9 @@ def explainPlan(self, queries, callback):
for query in filter(None, sqlparse.split(rawQuery))
]
queryToRun = '\n'.join(self.getOptionsForSgdbCli()['before'] + stripped_queries)
self.Command.createAndRun(self.builArgs('explain plan'), queryToRun, callback, timeout=self.timeout)
args = self.buildArgs('explain plan')
env = self.buildEnv()
self.Command.createAndRun(args, env, queryToRun, callback, timeout=self.timeout)

def execute(self, queries, callback, stream=False):
queryToRun = ''
Expand Down Expand Up @@ -165,9 +179,11 @@ def execute(self, queries, callback, stream=False):
if self.history:
self.history.add(queryToRun)

self.Command.createAndRun(self.builArgs(), queryToRun, callback, options={'show_query': self.show_query}, timeout=self.timeout, stream=stream)
args = self.buildArgs()
env = self.buildEnv()
self.Command.createAndRun(args, env, queryToRun, callback, options={'show_query': self.show_query}, timeout=self.timeout, stream=stream)

def builArgs(self, queryName=None):
def buildArgs(self, queryName=None):
cliOptions = self.getOptionsForSgdbCli()
args = [self.cli]

Expand Down Expand Up @@ -206,6 +222,31 @@ def builArgs(self, queryName=None):
Log('Using cli args ' + ' '.join(args))
return args

def buildEnv(self):
cliOptions = self.getOptionsForSgdbCli()
env = dict()

# append **optional** environment variables dict (if any)
optionalEnv = cliOptions.get('env_optional')
if optionalEnv: # only if we have optional args
if isinstance(optionalEnv, dict):
for var, value in optionalEnv.items():
formattedValue = self.formatOptionalArgument(value, self.options)
if formattedValue:
env.update({var: formattedValue})

# append environment variables dict (if any)
staticEnv = cliOptions.get('env')
if staticEnv: # only if we have optional args
if isinstance(staticEnv, dict):
for var, value in staticEnv.items():
formattedValue = value.format(**self.options)
if formattedValue:
env.update({var: formattedValue})

Log('Environment for command: ' + str(env))
return env

def getOptionsForSgdbCli(self):
return self.settings.get('cli_options', {}).get(self.type)

Expand Down

0 comments on commit 9c9c4cd

Please sign in to comment.