diff --git a/SQLTools.py b/SQLTools.py index 6fe2c19..f0fece2 100644 --- a/SQLTools.py +++ b/SQLTools.py @@ -4,6 +4,7 @@ import os import re import logging +from collections import OrderedDict import sublime from sublime_plugin import WindowCommand, EventListener, TextCommand @@ -30,10 +31,10 @@ CONNECTIONS_FILENAME_DEFAULT = None QUERIES_FILENAME = None QUERIES_FILENAME_DEFAULT = None -settings = None -queries = None -connections = None -history = None +settingsStore = None +queriesStore = None +connectionsStore = None +historyStore = None # create pluggin logger DEFAULT_LOG_LEVEL = logging.WARNING @@ -58,7 +59,7 @@ def startPlugin(): global SETTINGS_FILENAME, SETTINGS_FILENAME_DEFAULT global CONNECTIONS_FILENAME, CONNECTIONS_FILENAME_DEFAULT global QUERIES_FILENAME, QUERIES_FILENAME_DEFAULT - global settings, queries, connections, history + global settingsStore, queriesStore, connectionsStore, historyStore USER_FOLDER = getSublimeUserFolder() DEFAULT_FOLDER = os.path.dirname(__file__) @@ -71,71 +72,60 @@ def startPlugin(): QUERIES_FILENAME_DEFAULT = os.path.join(DEFAULT_FOLDER, SQLTOOLS_QUERIES_FILE) try: - settings = Settings(SETTINGS_FILENAME, default=SETTINGS_FILENAME_DEFAULT) + settingsStore = Settings(SETTINGS_FILENAME, default=SETTINGS_FILENAME_DEFAULT) except Exception as e: msg = '{0}: Failed to parse {1} file'.format(__package__, SQLTOOLS_SETTINGS_FILE) logging.exception(msg) Window().status_message(msg) try: - connections = Settings(CONNECTIONS_FILENAME, default=CONNECTIONS_FILENAME_DEFAULT) + connectionsStore = Settings(CONNECTIONS_FILENAME, default=CONNECTIONS_FILENAME_DEFAULT) except Exception as e: msg = '{0}: Failed to parse {1} file'.format(__package__, SQLTOOLS_CONNECTIONS_FILE) logging.exception(msg) Window().status_message(msg) - queries = Storage(QUERIES_FILENAME, default=QUERIES_FILENAME_DEFAULT) - history = History(settings.get('history_size', 100)) + queriesStore = Storage(QUERIES_FILENAME, default=QUERIES_FILENAME_DEFAULT) + historyStore = History(settingsStore.get('history_size', 100)) - if settings.get('debug', False): + if settingsStore.get('debug', False): plugin_logger.setLevel(logging.DEBUG) else: plugin_logger.setLevel(DEFAULT_LOG_LEVEL) - Connection.setTimeout(settings.get('thread_timeout', 15)) - Connection.setHistoryManager(history) + Connection.setTimeout(settingsStore.get('thread_timeout', 15)) + Connection.setHistoryManager(historyStore) logger.info('plugin (re)loaded') logger.info('version %s', __version__) -def getConnections(): - connectionsObj = {} +def readConnections(): + mergedConnections = {} # fixes #39 and #45 - if not connections: + if not connectionsStore: startPlugin() - options = connections.get('connections', {}) - allSettings = settings.all() - - for name, config in options.items(): - connectionsObj[name] = createConnection(name, config, settings=allSettings) - - # project settings + # global connections + globalConnectionsDict = connectionsStore.get('connections', {}) + # project-specific connections + projectConnectionsDict = {} projectData = Window().project_data() if projectData: - options = projectData.get('connections', {}) - for name, config in options.items(): - connectionsObj[name] = createConnection(name, config, settings=allSettings) + projectConnectionsDict = projectData.get('connections', {}) - return connectionsObj + # merge connections + mergedConnections = globalConnectionsDict.copy() + mergedConnections.update(projectConnectionsDict) + ordered = OrderedDict(sorted(mergedConnections.items())) -def createConnection(name, config, settings): - newConnection = None - # if DB cli binary could not be found in path a FileNotFoundError is thrown - try: - newConnection = Connection(name, config, settings=settings) - except FileNotFoundError as e: - # use only first line of the Exception in status message - Window().status_message(__package__ + ": " + str(e).splitlines()[0]) - raise e - return newConnection + return ordered -def loadDefaultConnection(): - default = connections.get('default', False) +def getDefaultConnectionName(): + default = connectionsStore.get('default', False) if not default: return return default @@ -176,19 +166,19 @@ def toNewTab(content, name="", suffix="SQLTools Saved Query"): def insertContent(content): view = View() # getting the settings local to this view/tab - settings = view.settings() + viewSettings = view.settings() # saving the original settings for "auto_indent", or True if none set - autoIndent = settings.get('auto_indent', True) + autoIndent = viewSettings.get('auto_indent', True) # turn off automatic indenting otherwise the tabbing of the original # string is not respected after a newline is encountered - settings.set('auto_indent', False) + viewSettings.set('auto_indent', False) view.run_command('insert', {'characters': content}) # restore "auto_indent" setting - settings.set('auto_indent', autoIndent) + viewSettings.set('auto_indent', autoIndent) def getOutputPlace(syntax=None, name="SQLTools Result"): - showResultOnWindow = settings.get('show_result_on_window', False) + showResultOnWindow = settingsStore.get('show_result_on_window', False) if not showResultOnWindow: resultContainer = Window().find_output_panel(name) if resultContainer is None: @@ -209,7 +199,7 @@ def getOutputPlace(syntax=None, name="SQLTools Result"): resultContainer.settings().set("word_wrap", "false") def onInitialOutputCallback(): - if settings.get('clear_output', False): + if settingsStore.get('clear_output', False): resultContainer.set_read_only(False) resultContainer.run_command('select_all') resultContainer.run_command('left_delete') @@ -233,7 +223,7 @@ def onInitialOutputCallback(): # if case this is an output pannel, show it Window().run_command("show_panel", {"panel": "output." + name}) - if settings.get('focus_on_result', False): + if settingsStore.get('focus_on_result', False): Window().focus_view(resultContainer) return resultContainer, onInitialOutputCallback @@ -263,12 +253,12 @@ def getSelectionRegions(): # 'file', 'view' = use text of current view # 'paragraph' = paragraph(s) (text between newlines) # 'line' = current line(s) - expandTo = settings.get('expand_to', 'file') + expandTo = settingsStore.get('expand_to', 'file') if not expandTo: expandTo = 'file' # keep compatibility with previous settings - expandToParagraph = settings.get('expand_to_paragraph') + expandToParagraph = settingsStore.get('expand_to_paragraph') if expandToParagraph is True: expandTo = 'paragraph' @@ -306,7 +296,7 @@ def getCurrentSyntax(): class ST(EventListener): - connectionList = None + connectionDict = None conn = None tables = [] columns = [] @@ -315,115 +305,186 @@ class ST(EventListener): @staticmethod def bootstrap(): - ST.connectionList = getConnections() - ST.checkDefaultConnection() + ST.connectionDict = readConnections() + ST.setDefaultConnection() @staticmethod - def checkDefaultConnection(): - default = loadDefaultConnection() + def setDefaultConnection(): + default = getDefaultConnectionName() if not default: return - + if default not in ST.connectionDict: + logger.error('connection "%s" set as default, but it does not exists', default) + return logger.info('default connection is set to "%s"', default) + ST.setConnection(default) + + @staticmethod + def setConnection(connectionName, callback=None): + if not connectionName: + return + + if connectionName not in ST.connectionDict: + return + + settings = settingsStore.all() + config = ST.connectionDict.get(connectionName) + + promptKeys = [key for key, value in config.items() if value is None] + promptDict = {} + logger.info('[setConnection] prompt keys {}'.format(promptKeys)) + + def mergeConfig(config, promptedKeys=None): + merged = config.copy() + if promptedKeys: + merged.update(promptedKeys) + return merged + + def createConnection(connectionName, config, settings, callback=None): + # if DB cli binary could not be found in path a FileNotFoundError is thrown + try: + ST.conn = Connection(connectionName, config, settings=settings) + except FileNotFoundError as e: + # use only first line of the Exception in status message + Window().status_message(__package__ + ": " + str(e).splitlines()[0]) + raise e + ST.loadConnectionData(callback) + + if not promptKeys: + createConnection(connectionName, config, settings, callback) + return - try: - ST.conn = ST.connectionList[default] - except KeyError as e: - logger.error('connection "%s" set as default, but it does not exists', default) - else: - ST.loadConnectionData() + + def setMissingKey(key, value): + nonlocal promptDict + if value is None: + return + promptDict[key] = value + if promptKeys: + promptNext() + else: + merged = mergeConfig(config, promptDict) + createConnection(connectionName, merged, settings, callback) + + def promptNext(): + nonlocal promptKeys + if not promptKeys: + merged = mergeConfig(config, promptDict) + createConnection(connectionName, merged, settings, callback) + key = promptKeys.pop(); + Window().show_input_panel( + 'Connection ' + key, + '', + lambda userInput: setMissingKey(key, userInput), + None, + None) + + promptNext() @staticmethod - def loadConnectionData(tablesCallback=None, columnsCallback=None, functionsCallback=None): + def loadConnectionData(callback=None): # clear the list of identifiers (in case connection is changed) ST.tables = [] ST.columns = [] ST.functions = [] ST.completion = None - callbacksRun = 0 + objectsLoaded = 0 if not ST.conn: return - def tbCallback(tables): - ST.tables = tables - - nonlocal callbacksRun - callbacksRun += 1 - if callbacksRun == 3: - ST.completion = Completion(ST.tables, ST.columns, ST.functions, settings=settings) + def afterAllDataHasLoaded(): + ST.completion = Completion(ST.tables, ST.columns, ST.functions, settings=settingsStore) + logger.info('completions loaded') + if (callback): + callback() - if tablesCallback: - tablesCallback() + def tablesCallback(tables): + ST.tables = tables + nonlocal objectsLoaded + objectsLoaded += 1 + logger.info('loaded tables : "{0}"'.format(tables)) + if objectsLoaded == 3: + afterAllDataHasLoaded() - def colCallback(columns): + def columnsCallback(columns): ST.columns = columns + nonlocal objectsLoaded + objectsLoaded += 1 + logger.info('loaded columns : "{0}"'.format(columns)) + if objectsLoaded == 3: + afterAllDataHasLoaded() - nonlocal callbacksRun - callbacksRun += 1 - if callbacksRun == 3: - ST.completion = Completion(ST.tables, ST.columns, ST.functions, settings=settings) - - if columnsCallback: - columnsCallback() - - def funcCallback(functions): + def functionsCallback(functions): ST.functions = functions + nonlocal objectsLoaded + objectsLoaded += 1 + logger.info('loaded functions: "{0}"'.format(functions)) + if objectsLoaded == 3: + logger.info('all objects loaded') + afterAllDataHasLoaded() - nonlocal callbacksRun - callbacksRun += 1 - if callbacksRun == 3: - ST.completion = Completion(ST.tables, ST.columns, ST.functions, settings=settings) - - if functionsCallback: - functionsCallback() - - ST.conn.getTables(tbCallback) - ST.conn.getColumns(colCallback) - ST.conn.getFunctions(funcCallback) - - @staticmethod - def setConnection(index, tablesCallback=None, columnsCallback=None, functionsCallback=None): - if index < 0 or index > (len(ST.connectionList) - 1): - return - - connListNames = list(ST.connectionList.keys()) - connListNames.sort() - ST.conn = ST.connectionList.get(connListNames[index]) - ST.loadConnectionData(tablesCallback, columnsCallback, functionsCallback) - logger.info('Connection "{0}" selected'.format(ST.conn)) + ST.conn.getTables(tablesCallback) + ST.conn.getColumns(columnsCallback) + ST.conn.getFunctions(functionsCallback) @staticmethod - def selectConnection(tablesCallback=None, columnsCallback=None, functionsCallback=None): - ST.connectionList = getConnections() - if len(ST.connectionList) == 0: + def selectConnectionQuickPanel(callback=None): + ST.connectionDict = readConnections() + if len(ST.connectionDict) == 0: sublime.message_dialog('You need to setup your connections first.') return - menu = [] - for name, conn in ST.connectionList.items(): - menu.append([name, conn.info()]) - menu.sort() - Window().show_quick_panel(menu, lambda index: ST.setConnection(index, tablesCallback, columnsCallback, functionsCallback)) + def connectionMenuList(connDictionary): + menuItemsList = [] + template = '{dbtype}://{user}{host}{port}{db}' + for name, config in ST.connectionDict.items(): + dbtype = config.get('type', '') + user = '{}@'.format(config.get('username', '')) if 'username' in config else '' + # user = config.get('username', '') + host=config.get('host', '') + port = ':{}'.format(config.get('port', '')) if 'port' in config else '' + db = '/{}'.format(config.get('database', '')) if 'database' in config else '' + connectionInfo = template.format( + dbtype=dbtype, + user=user, + host=host, + port=port, + db=db) + menuItemsList.append([name, connectionInfo]) + menuItemsList.sort() + return menuItemsList + + def onConnectionSelected(index, callback): + menuItemsList = connectionMenuList(ST.connectionDict) + if index < 0 or index >= len(menuItemsList): + return + connectionName = menuItemsList[index][0] + ST.setConnection(connectionName, callback) + logger.info('Connection "{0}" selected'.format(connectionName)) + + menu = connectionMenuList(ST.connectionDict) + # show pannel with callback above + Window().show_quick_panel(menu, lambda index: onConnectionSelected(index, callback)) @staticmethod - def selectTable(callback): + def showTablesQuickPanel(callback): if len(ST.tables) == 0: sublime.message_dialog('Your database has no tables.') return - ST.show_quick_panel_with_selection(ST.tables, callback) + ST.showQuickPanelWithSelection(ST.tables, callback) @staticmethod - def selectFunction(callback): + def showFunctionsQuickPanel(callback): if len(ST.functions) == 0: sublime.message_dialog('Your database has no functions.') return - ST.show_quick_panel_with_selection(ST.functions, callback) + ST.showQuickPanelWithSelection(ST.functions, callback) @staticmethod - def show_quick_panel_with_selection(arrayOfValues, callback): + def showQuickPanelWithSelection(arrayOfValues, callback): w = Window(); view = w.active_view() selection = view.sel()[0] @@ -525,17 +586,17 @@ def run(): class StSelectConnection(WindowCommand): @staticmethod def run(): - ST.selectConnection() + ST.selectConnectionQuickPanel() class StShowRecords(WindowCommand): @staticmethod def run(): if not ST.conn: - ST.selectConnection(tablesCallback=lambda: Window().run_command('st_show_records')) + ST.selectConnectionQuickPanel(callback=lambda: Window().run_command('st_show_records')) return - def cb(index): + def onTableSelected(index): if index < 0: return None Window().status_message(MESSAGE_RUNNING_CMD) @@ -545,7 +606,7 @@ def cb(index): tableName, createOutput(prependText=prependText)) - ST.selectTable(cb) + ST.showTablesQuickPanel(callback=onTableSelected) class StDescTable(WindowCommand): @@ -554,16 +615,16 @@ def run(): currentSyntax = getCurrentSyntax() if not ST.conn: - ST.selectConnection(tablesCallback=lambda: Window().run_command('st_desc_table')) + ST.selectConnectionQuickPanel(callback=lambda: Window().run_command('st_desc_table')) return - def cb(index): + def onTableSelected(index): if index < 0: return None Window().status_message(MESSAGE_RUNNING_CMD) return ST.conn.getTableDescription(ST.tables[index], createOutput(syntax=currentSyntax)) - ST.selectTable(cb) + ST.showTablesQuickPanel(callback=onTableSelected) class StDescFunction(WindowCommand): @@ -572,10 +633,10 @@ def run(): currentSyntax = getCurrentSyntax() if not ST.conn: - ST.selectConnection(functionsCallback=lambda: Window().run_command('st_desc_function')) + ST.selectConnectionQuickPanel(callback=lambda: Window().run_command('st_desc_function')) return - def cb(index): + def onFunctionSelected(index): if index < 0: return None Window().status_message(MESSAGE_RUNNING_CMD) @@ -584,7 +645,7 @@ def cb(index): # get everything until first occurrence of "(", e.g. get "function_name" # from "function_name(int)" - ST.selectFunction(cb) + ST.showFunctionsQuickPanel(callback=onFunctionSelected) class StRefreshConnectionData(WindowCommand): @@ -599,7 +660,7 @@ class StExplainPlan(WindowCommand): @staticmethod def run(): if not ST.conn: - ST.selectConnection(tablesCallback=lambda: Window().run_command('st_explain_plan')) + ST.selectConnectionQuickPanel(callback=lambda: Window().run_command('st_explain_plan')) return Window().status_message(MESSAGE_RUNNING_CMD) @@ -610,7 +671,7 @@ class StExecute(WindowCommand): @staticmethod def run(): if not ST.conn: - ST.selectConnection(tablesCallback=lambda: Window().run_command('st_execute')) + ST.selectConnectionQuickPanel(callback=lambda: Window().run_command('st_execute')) return Window().status_message(MESSAGE_RUNNING_CMD) @@ -621,7 +682,7 @@ class StExecuteAll(WindowCommand): @staticmethod def run(): if not ST.conn: - ST.selectConnection(tablesCallback=lambda: Window().run_command('st_execute_all')) + ST.selectConnectionQuickPanel(callback=lambda: Window().run_command('st_execute_all')) return Window().status_message(MESSAGE_RUNNING_CMD) @@ -639,7 +700,7 @@ def run(edit): for region in selectionRegions: textToFormat = View().substr(region) - View().replace(edit, region, Utils.formatSql(textToFormat, settings.get('format', {}))) + View().replace(edit, region, Utils.formatSql(textToFormat, settingsStore.get('format', {}))) class StFormatAll(TextCommand): @@ -647,7 +708,7 @@ class StFormatAll(TextCommand): def run(edit): region = sublime.Region(0, View().size()) textToFormat = View().substr(region) - View().replace(edit, region, Utils.formatSql(textToFormat, settings.get('format', {}))) + View().replace(edit, region, Utils.formatSql(textToFormat, settingsStore.get('format', {}))) class StVersion(WindowCommand): @@ -660,19 +721,19 @@ class StHistory(WindowCommand): @staticmethod def run(): if not ST.conn: - ST.selectConnection(functionsCallback=lambda: Window().run_command('st_history')) + ST.selectConnectionQuickPanel(callback=lambda: Window().run_command('st_history')) return - if len(history.all()) == 0: + if len(historyStore.all()) == 0: sublime.message_dialog('History is empty.') return def cb(index): if index < 0: return None - return ST.conn.execute(history.get(index), createOutput()) + return ST.conn.execute(historyStore.get(index), createOutput()) - Window().show_quick_panel(history.all(), cb) + Window().show_quick_panel(historyStore.all(), cb) class StSaveQuery(WindowCommand): @@ -681,7 +742,7 @@ def run(): query = getSelectionText() def cb(alias): - queries.add(alias, query) + queriesStore.add(alias, query) Window().show_input_panel('Query alias', '', cb, None, None) @@ -689,11 +750,11 @@ class StListQueries(WindowCommand): @staticmethod def run(mode="run"): if mode == "run" and not ST.conn: - ST.selectConnection(functionsCallback=lambda: Window().run_command('st_list_queries', + ST.selectConnectionQuickPanel(callback=lambda: Window().run_command('st_list_queries', {'mode': mode})) return - queriesList = queries.all() + queriesList = queriesStore.all() if len(queriesList) == 0: sublime.message_dialog('No saved queries.') return @@ -727,10 +788,10 @@ class StRemoveSavedQuery(WindowCommand): @staticmethod def run(): if not ST.conn: - ST.selectConnection(functionsCallback=lambda: Window().run_command('st_remove_saved_query')) + ST.selectConnectionQuickPanel(callback=lambda: Window().run_command('st_remove_saved_query')) return - queriesList = queries.all() + queriesList = queriesStore.all() if len(queriesList) == 0: sublime.message_dialog('No saved queries.') return @@ -744,7 +805,7 @@ def cb(index): if index < 0: return None - return queries.delete(options[index][0]) + return queriesStore.delete(options[index][0]) try: Window().show_quick_panel(options, cb) except Exception: diff --git a/SQLToolsAPI/Command.py b/SQLToolsAPI/Command.py index 1f2d26b..4f016be 100644 --- a/SQLToolsAPI/Command.py +++ b/SQLToolsAPI/Command.py @@ -177,9 +177,14 @@ def stop(self): self.process = None logger.info("command execution exceeded timeout (%s s), process killed", self.timeout) - self.callback("Command execution time exceeded 'thread_timeout' ({0} s).\nProcess killed!\n\n" - .format(self.timeout)) + self.callback(("Command execution time exceeded 'thread_timeout' ({0} s).\n" + "Process killed!\n\n" + ).format(self.timeout)) except Exception: + logger.info("command execution exceeded timeout (%s s), process could not be killed", self.timeout) + self.callback(("Command execution time exceeded 'thread_timeout' ({0} s).\n" + "Process could not be killed!\n\n" + ).format(self.timeout)) pass @staticmethod diff --git a/SQLToolsAPI/Connection.py b/SQLToolsAPI/Connection.py index 727a446..8acfa71 100644 --- a/SQLToolsAPI/Connection.py +++ b/SQLToolsAPI/Connection.py @@ -49,19 +49,19 @@ def __init__(self, name, options, settings=None, commandClass='ThreadCommand'): self.Command = getattr(C, commandClass) self.name = name - self.options = options + self.options = {k: v for k, v in options.items() if v is not None} if settings is None: settings = {} self.settings = settings - self.type = options.get('type', None) - self.host = options.get('host', None) - self.port = options.get('port', None) - self.database = options.get('database', None) - self.username = options.get('username', None) - self.password = options.get('password', None) - self.encoding = options.get('encoding', 'utf-8') + self.type = self.options.get('type', None) + self.host = self.options.get('host', None) + self.port = self.options.get('port', None) + self.database = self.options.get('database', None) + self.username = self.options.get('username', None) + self.password = self.options.get('password', None) + self.encoding = self.options.get('encoding', 'utf-8') self.encoding = self.encoding or 'utf-8' # defaults to utf-8 if not _encoding_exists(self.encoding): self.encoding = 'utf-8' @@ -70,7 +70,7 @@ def __init__(self, name, options, settings=None, commandClass='ThreadCommand'): self.show_query = settings.get('show_query', False) self.rowsLimit = settings.get('show_records', {}).get('limit', 50) self.useStreams = settings.get('use_streams', False) - self.cli = settings.get('cli')[options['type']] + self.cli = settings.get('cli')[self.options['type']] cli_path = shutil.which(self.cli) if cli_path is None: @@ -87,6 +87,8 @@ def info(self): def runInternalNamedQueryCommand(self, queryName, callback): query = self.getNamedQuery(queryName) if not query: + emptyList = [] + callback(emptyList) return queryToRun = self.buildNamedQuery(queryName, query) diff --git a/SQLToolsConnections.sublime-settings b/SQLToolsConnections.sublime-settings index e648047..db8e1c4 100644 --- a/SQLToolsConnections.sublime-settings +++ b/SQLToolsConnections.sublime-settings @@ -2,6 +2,7 @@ "connections": { /* "Generic Template": { // Connection name, used in menu (Display name) + // connection properties set to "null" will prompt for value when connecting "type" : "pgsql", // DB type: (mysql, pgsql, oracle, vertica, sqlite, firebird, sqsh) "host" : "HOSTNAME", // DB host to connect to "port" : PORT, // DB port