From 570c9187be72fb61c1331ac76056468f7481fa37 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Thu, 23 May 2024 23:33:48 -0700 Subject: [PATCH] Improve watching --- coconut/command/command.py | 63 +++++++++++++++++++++++++------------- coconut/command/watch.py | 6 +--- coconut/exceptions.py | 2 +- 3 files changed, 43 insertions(+), 28 deletions(-) diff --git a/coconut/command/command.py b/coconut/command/command.py index fc6fe2d3e..9f0a51f0e 100644 --- a/coconut/command/command.py +++ b/coconut/command/command.py @@ -576,7 +576,7 @@ def compile_file(self, filepath, write=True, package=False, force=False, **kwarg self.compile(filepath, destpath, package, force=force, **kwargs) return destpath - def compile(self, codepath, destpath=None, package=False, run=False, force=False, show_unchanged=True): + def compile(self, codepath, destpath=None, package=False, run=False, force=False, show_unchanged=True, callback=None): """Compile a source Coconut file to a destination Python file.""" with univ_open(codepath, "r") as opened: code = readfile(opened) @@ -603,29 +603,39 @@ def compile(self, codepath, destpath=None, package=False, run=False, force=False else: logger.show_tabulated("Compiling", showpath(codepath), "...") - def callback(compiled): - if destpath is None: - logger.show_tabulated("Compiled", showpath(codepath), "without writing to file.") - else: - with univ_open(destpath, "w") as opened: - writefile(opened, compiled) - logger.show_tabulated("Compiled to", showpath(destpath), ".") - if self.display: - logger.print(compiled) - if run: + def inner_callback(compiled): + try: if destpath is None: - self.execute(compiled, path=codepath, allow_show=False) + logger.show_tabulated("Compiled", showpath(codepath), "without writing to file.") else: - self.execute_file(destpath, argv_source_path=codepath) + with univ_open(destpath, "w") as opened: + writefile(opened, compiled) + logger.show_tabulated("Compiled to", showpath(destpath), ".") + if self.display: + logger.print(compiled) + if run: + if destpath is None: + self.execute(compiled, path=codepath, allow_show=False) + else: + self.execute_file(destpath, argv_source_path=codepath) + except BaseException as err: + if callback is not None: + callback(False, err) + raise + else: + if callback is not None: + callback(True, destpath) parse_kwargs = dict( codepath=codepath, use_cache=self.use_cache, ) + if callback is not None: + parse_kwargs["error_callback"] = lambda err: callback(False, err) if package is True: - self.submit_comp_job(codepath, callback, "parse_package", code, package_level=package_level, **parse_kwargs) + self.submit_comp_job(codepath, inner_callback, "parse_package", code, package_level=package_level, **parse_kwargs) elif package is False: - self.submit_comp_job(codepath, callback, "parse_file", code, **parse_kwargs) + self.submit_comp_job(codepath, inner_callback, "parse_file", code, **parse_kwargs) else: raise CoconutInternalException("invalid value for package", package) @@ -669,6 +679,7 @@ def create_package(self, dirpath, retries_left=create_package_retries): def submit_comp_job(self, path, callback, method, *args, **kwargs): """Submits a job on self.comp to be run in parallel.""" + error_callback = kwargs.pop("error_callback", None) if self.executor is None: with self.handling_exceptions(): callback(getattr(self.comp, method)(*args, **kwargs)) @@ -681,8 +692,14 @@ def callback_wrapper(completed_future): """Ensures that all errors are always caught, since errors raised in a callback won't be propagated.""" with logger.in_path(path): # handle errors in the path context with self.handling_exceptions(): - result = completed_future.result() - callback(result) + try: + result = completed_future.result() + except BaseException as err: + if error_callback is not None: + error_callback(err) + raise + else: + callback(result) future.add_done_callback(callback_wrapper) def register_exit_code(self, code=1, errmsg=None, err=None): @@ -1138,7 +1155,11 @@ def watch(self, all_compile_path_kwargs): def interrupt(): interrupted[0] = True - def recompile(path, **kwargs): + def recompile(path, callback, **kwargs): + def inner_callback(ok, path): + if ok: + self.run_mypy(path) + callback() path = fixpath(path) src = kwargs.pop("source") dest = kwargs.pop("dest") @@ -1150,14 +1171,14 @@ def recompile(path, **kwargs): # correct the compilation path based on the relative position of path to src dirpath = os.path.dirname(path) writedir = os.path.join(dest, os.path.relpath(dirpath, src)) - filepaths = self.compile_path( + self.compile_path( path, writedir, show_unchanged=False, handling_exceptions_kwargs=dict(on_keyboard_interrupt=interrupt), + callback=inner_callback, **kwargs # no comma for py2 ) - self.run_mypy(filepaths) observer = Observer() watchers = [] @@ -1171,8 +1192,6 @@ def recompile(path, **kwargs): try: while not interrupted[0]: time.sleep(watch_interval) - for wcher in watchers: - wcher.keep_watching() except KeyboardInterrupt: interrupt() finally: diff --git a/coconut/command/watch.py b/coconut/command/watch.py index c7046c397..281900a6b 100644 --- a/coconut/command/watch.py +++ b/coconut/command/watch.py @@ -46,10 +46,6 @@ def __init__(self, recompile, *args, **kwargs): self.recompile = recompile self.args = args self.kwargs = kwargs - self.keep_watching() - - def keep_watching(self): - """Allows recompiling previously-compiled files.""" self.saw = set() def on_modified(self, event): @@ -57,4 +53,4 @@ def on_modified(self, event): path = event.src_path if path not in self.saw: self.saw.add(path) - self.recompile(path, *self.args, **self.kwargs) + self.recompile(path, callback=lambda: self.saw.remove(path), *self.args, **self.kwargs) diff --git a/coconut/exceptions.py b/coconut/exceptions.py index 89843a428..016eeb7d2 100644 --- a/coconut/exceptions.py +++ b/coconut/exceptions.py @@ -201,7 +201,7 @@ def message(self, message, source, point, ln, extra=None, endpoint=None, filenam message_parts += ["|"] else: message_parts += ["/", "~" * (len(lines[0]) - point_ind - 1)] - message_parts += ["~" * (max_line_len - len(lines[0])), "\n"] + message_parts += ["~" * (max_line_len - len(lines[0]) + 1), "\n"] # add code, highlighting all of it together code_parts = []