diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 3065514a1..6ecbc4774 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -7,7 +7,6 @@ jobs: matrix: python-version: - '2.7' - - '3.5' - '3.6' - '3.7' - '3.8' @@ -24,9 +23,9 @@ jobs: fail-fast: false name: Python ${{ matrix.python-version }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Setup python - uses: MatteoH2O1999/setup-python@v2 + uses: MatteoH2O1999/setup-python@v4 with: python-version: ${{ matrix.python-version }} cache: pip diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c7868a2d8..e1ea9b6af 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,7 @@ repos: - --experimental - --ignore=W503,E501,E722,E402,E721 - repo: https://github.com/pre-commit/pre-commit-hooks.git - rev: v4.5.0 + rev: v4.6.0 hooks: - id: check-added-large-files - id: fix-byte-order-marker diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 12b79fd46..53f4cdbf7 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -155,7 +155,7 @@ After you've tested your changes locally, you'll want to add more permanent test 1. Preparation: 1. Run `make check-reqs` and update dependencies as necessary 2. Run `sudo make format` - 3. Make sure `make test`, `make test-py2`, and `make test-easter-eggs` are passing + 3. Make sure `make test`, `make test-pyright`, and `make test-easter-eggs` are passing 4. Ensure that `coconut --watch` can successfully compile files when they're modified 5. Check changes in [`compiled-cocotest`](https://github.com/evhub/compiled-cocotest), [`pyprover`](https://github.com/evhub/pyprover), and [`coconut-prelude`](https://github.com/evhub/coconut-prelude) 6. Check [Codebeat](https://codebeat.co/a/evhub/projects) and [LGTM](https://lgtm.com/dashboard) for `coconut` and `compiled-cocotest` diff --git a/DOCS.md b/DOCS.md index 69cd43466..a341d3a4a 100644 --- a/DOCS.md +++ b/DOCS.md @@ -92,6 +92,7 @@ The full list of optional dependencies is: - `kernel`: lightweight subset of `jupyter` that only includes the dependencies that are strictly necessary for Coconut's [Jupyter kernel](#kernel). - `watch`: enables use of the `--watch` flag. - `mypy`: enables use of the `--mypy` flag. +- `pyright`: enables use of the `--pyright` flag. - `xonsh`: enables use of Coconut's [`xonsh` support](#xonsh-support). - `numpy`: installs everything necessary for making use of Coconut's [`numpy` integration](#numpy-integration). - `jupyterlab`: installs everything necessary to use [JupyterLab](https://github.com/jupyterlab/jupyterlab) with Coconut. @@ -121,11 +122,11 @@ depth: 1 ``` coconut [-h] [--and source [dest ...]] [-v] [-t version] [-i] [-p] [-a] [-l] - [--no-line-numbers] [-k] [-w] [-r] [-n] [-d] [-q] [-s] [--no-tco] - [--no-wrap-types] [-c code] [--incremental] [-j processes] [-f] [--minify] - [--jupyter ...] [--mypy ...] [--argv ...] [--tutorial] [--docs] [--style name] - [--vi-mode] [--recursion-limit limit] [--stack-size kbs] [--site-install] - [--site-uninstall] [--verbose] [--trace] [--profile] + [--no-line-numbers] [-k] [-w] [-r] [-n] [-d] [-q] [-s] [--no-tco] [--no-wrap-types] + [-c code] [-j processes] [-f] [--minify] [--jupyter ...] [--mypy ...] [--pyright] + [--argv ...] [--tutorial] [--docs] [--style name] [--vi-mode] + [--recursion-limit limit] [--stack-size kbs] [--fail-fast] [--no-cache] + [--site-install] [--site-uninstall] [--verbose] [--trace] [--profile] [source] [dest] ``` @@ -184,6 +185,7 @@ dest destination directory for compiled files (defaults to Jupyter) --mypy ... run MyPy on compiled Python (remaining args passed to MyPy) (implies --package --line-numbers) +--pyright run Pyright on compiled Python (implies --package) --argv ..., --args ... set sys.argv to source plus remaining args for use in the Coconut script being run @@ -280,6 +282,8 @@ To make Coconut built-ins universal across Python versions, Coconut makes availa - `py_xrange` - `py_repr` - `py_breakpoint` +- `py_min` +- `py_max` _Note: Coconut's `repr` can be somewhat tricky, as it will attempt to remove the `u` before reprs of unicode strings on Python 2, but will not always be able to do so if the unicode string is nested._ @@ -334,6 +338,7 @@ If the `--strict` (`-s` for short) flag is enabled, Coconut will perform additio The style issues which will cause `--strict` to throw an error are: - mixing of tabs and spaces +- use of `"hello" "world"` implicit string concatenation (use explicit `+` instead) - use of `from __future__` imports (Coconut does these automatically) - inheriting from `object` in classes (Coconut does this automatically) - semicolons at end of lines @@ -449,6 +454,10 @@ You can also run `mypy`—or any other static type checker—directly on the com To distribute your code with checkable type annotations, you'll need to include `coconut` as a dependency (though a `--no-deps` install should be fine), as installing it is necessary to make the requisite stub files available. You'll also probably want to include a [`py.typed`](https://peps.python.org/pep-0561/) file. +##### Pyright Integration + +Though not as well-supported as MyPy, Coconut also has built-in [Pyright](https://github.com/microsoft/pyright) support. Simply pass `--pyright` to automatically run Pyright on all compiled code. To adjust Pyright options, rather than pass them at the command-line, add your settings to the file `~/.coconut_pyrightconfig.json` (automatically generated the first time `coconut --pyright` is run). + ##### Syntax To explicitly annotate your code with types to be checked, Coconut supports (on all Python versions): @@ -464,7 +473,7 @@ Sometimes, MyPy will not know how to handle certain Coconut constructs, such as ##### Interpreter -Coconut even supports `--mypy` in the interpreter, which will intelligently scan each new line of code, in the context of previous lines, for newly-introduced MyPy errors. For example: +Coconut even supports `--mypy` (though not `--pyright`) in the interpreter, which will intelligently scan each new line of code, in the context of previous lines, for newly-introduced MyPy errors. For example: ```coconut_pycon >>> a: str = count()[0] :14: error: Incompatible types in assignment (expression has type "int", variable has type "str") @@ -541,9 +550,9 @@ a `b` c, left (captures lambda) all custom operators ?? left (short-circuits) ..>, <.., ..*>, <*.., n/a (captures lambda) - ..**>, <**.. + ..**>, <**.., etc. |>, <|, |*>, <*|, left (captures lambda) - |**>, <**| + |**>, <**|, etc. ==, !=, <, >, <=, >=, in, not in, @@ -1318,11 +1327,10 @@ data Empty() from Tree data Leaf(n) from Tree data Node(l, r) from Tree -def depth(Tree()) = 0 - -addpattern def depth(Tree(n)) = 1 - -addpattern def depth(Tree(l, r)) = 1 + max([depth(l), depth(r)]) +case def depth: + case(Tree()) = 0 + case(Tree(n)) = 1 + case(Tree(l, r)) = 1 + max(depth(l), depth(r)) Empty() |> depth |> print Leaf(5) |> depth |> print @@ -1338,26 +1346,26 @@ def duplicate_first([x] + xs as l) = ``` _Showcases head-tail splitting, one of the most common uses of pattern-matching, where a `+ ` (or `:: ` for any iterable) at the end of a list or tuple literal can be used to match the rest of the sequence._ -``` -def sieve([head] :: tail) = - [head] :: sieve(n for n in tail if n % head) - -addpattern def sieve((||)) = [] +```coconut +case def sieve: + case([head] :: tail) = + [head] :: sieve(n for n in tail if n % head) + case((||)) = [] ``` _Showcases how to match against iterators, namely that the empty iterator case (`(||)`) must come last, otherwise that case will exhaust the whole iterator before any other pattern has a chance to match against it._ -``` +```coconut def odd_primes(p=3) = (p,) :: filter(=> _ % p != 0, odd_primes(p + 2)) def primes() = (2,) :: odd_primes() -def twin_primes(_ :: [p, (.-2) -> p] :: ps) = - [(p, p+2)] :: twin_primes([p + 2] :: ps) - -addpattern def twin_primes() = # type: ignore - twin_primes(primes()) +case def twin_primes: + case(_ :: [p, (.-2) -> p] :: ps) = + [(p, p+2)] :: twin_primes([p + 2] :: ps) + case() = + twin_primes(primes()) twin_primes()$[:5] |> list |> print ``` @@ -1386,7 +1394,7 @@ match : ``` where `` is any `match` pattern, `` is the item to match against, `` is an optional additional check, and `` is simply code that is executed if the header above it succeeds. Note the absence of an `in` in the `match` statements: that's because the `` in `case ` is taking its place. If no `else` is present and no match succeeds, then the `case` statement is simply skipped over as with [`match` statements](#match) (though unlike [destructuring assignments](#destructuring-assignment)). -Additionally, `cases` can be used as the top-level keyword instead of `match`, and in such a `case` block `match` is allowed for each case rather than `case`. _Deprecated: Coconut also supports `case` instead of `cases` as the top-level keyword for backwards-compatibility purposes._ +_Deprecated: Additionally, `cases` or `case` can be used as the top-level keyword instead of `match`, and in such a block `match` is used for each case rather than `case`._ ##### Examples @@ -1520,15 +1528,14 @@ data Empty() data Leaf(n) data Node(l, r) -def size(Empty()) = 0 - -addpattern def size(Leaf(n)) = 1 - -addpattern def size(Node(l, r)) = size(l) + size(r) +case def size: + case(Empty()) = 0 + case(Leaf(n)) = 1 + case(Node(l, r)) = size(l) + size(r) size(Node(Empty(), Leaf(10))) == 1 ``` -_Showcases the algebraic nature of `data` types when combined with pattern-matching._ +_Showcases the use of pattern-matching to deconstruct `data` types._ ```coconut data vector(*pts): @@ -2219,7 +2226,7 @@ quad = 5 * x**2 + 3 * x + 1 When passing in long variable names as keyword arguments of the same name, Coconut supports the syntax ``` -f(...=long_variable_name) +f(long_variable_name=) ``` as a shorthand for ``` @@ -2228,6 +2235,8 @@ f(long_variable_name=long_variable_name) Such syntax is also supported in [partial application](#partial-application) and [anonymous `namedtuple`s](#anonymous-namedtuples). +_Deprecated: Coconut also supports `f(...=long_variable_name)` as an alternative shorthand syntax._ + ##### Example **Coconut:** @@ -2235,8 +2244,8 @@ Such syntax is also supported in [partial application](#partial-application) and really_long_variable_name_1 = get_1() really_long_variable_name_2 = get_2() main_func( - ...=really_long_variable_name_1, - ...=really_long_variable_name_2, + really_long_variable_name_1=, + really_long_variable_name_2=, ) ``` @@ -2521,6 +2530,58 @@ range(5) |> last_two |> print _Can't be done without a long series of checks at the top of the function. See the compiled code for the Python syntax._ +### `case` Functions + +For easily defining a pattern-matching function with many different cases, Coconut provides the `case def` syntax based on Coconut's [`case`](#case) syntax. The basic syntax is +``` +case def : + case(, , ... [if ]): + + case(, , ... [if ]): + + ... +``` +where the patterns in each `case` are checked in sequence until a match is found and the body under that match is executed, or a [`MatchError`](#matcherror) is raised. Each `case(...)` statement is effectively treated as a separate pattern-matching function signature that is checked independently, as if they had each been defined separately and then combined with [`addpattern`](#addpattern). + +Any individual body can also be defined with [assignment function syntax](#assignment-functions) such that +``` +case def : + case(, , ... [if ]) = +``` +is equivalent to +``` +case def : + case(, , ... [if ]): return +``` + +`case` function definition can also be combined with `async` functions, [`copyclosure` functions](#copyclosure-functions), and [`yield` functions](#explicit-generators). The various keywords in front of the `def` can be put in any order. + +`case def` also allows for easily providing type annotations for pattern-matching functions. To add type annotations, inside the body of the `case def`, instead of just `case(...)` statements, include some `type(...)` statements as well, which will compile into [`typing.overload`](https://docs.python.org/3/library/typing.html#overload) declarations. The syntax is +``` +case def []: + type(: , : , ...) -> + type(: , : , ...) -> + ... +``` +which can be interspersed with the `case(...)` statements. + +##### Example + +**Coconut:** +```coconut +case def my_min[T]: + type(x: T, y: T) -> T + case(x, y if x <= y) = x + case(x, y) = y + + type(xs: T[]) -> T + case([x]) = x + case([x] + xs) = my_min(x, my_min(xs)) +``` + +**Python:** +_Can't be done without a long series of checks for each pattern-matching. See the compiled code for the Python syntax._ + ### `addpattern` Functions Coconut provides the `addpattern def` syntax as a shortcut for the full @@ -2531,10 +2592,12 @@ match def func(...): ``` syntax using the [`addpattern`](#addpattern) decorator. -Additionally, `addpattern def` will act just like a normal [`match def`](#pattern-matching-functions) if the function has not previously been defined, allowing for `addpattern def` to be used for each case rather than requiring `match def` for the first case and `addpattern def` for future cases. - If you want to put a decorator on an `addpattern def` function, make sure to put it on the _last_ pattern function. +For complex multi-pattern functions, it is generally recommended to use [`case def`](#case-functions) over `addpattern def` in most situations. + +_Deprecated: `addpattern def` will act just like a normal [`match def`](#pattern-matching-functions) if the function has not previously been defined. This will show a [`CoconutWarning`](#coconutwarning) and is not recommended._ + ##### Example **Coconut:** @@ -2952,7 +3015,7 @@ depth: 1 Takes one argument that is a [pattern-matching function](#pattern-matching-functions), and returns a decorator that adds the patterns in the existing function to the new function being decorated, where the existing patterns are checked first, then the new. `addpattern` also supports a shortcut syntax where the new patterns can be passed in directly. Roughly equivalent to: -``` +```coconut_python def _pattern_adder(base_func, add_func): def add_pattern_func(*args, **kwargs): try: @@ -2990,7 +3053,7 @@ print_type() # appears to work print_type(1) # TypeError: print_type() takes 0 positional arguments but 1 was given ``` -This can be fixed by using either the `match` or `addpattern` keyword. For example: +This can be fixed by using either the `match` keyword. For example: ```coconut match def print_type(): print("Received no arguments.") @@ -3341,6 +3404,10 @@ Additionally, if you are using [view patterns](#match), you might need to raise In some cases where there are multiple Coconut packages installed at the same time, there may be multiple `MatchError`s defined in different packages. Coconut can perform some magic under the hood to make sure that all these `MatchError`s will seamlessly interoperate, but only if all such packages are compiled in [`--package` mode rather than `--standalone` mode](#compilation-modes). +### `CoconutWarning` + +`CoconutWarning` is the [`Warning`](https://docs.python.org/3/library/exceptions.html#Warning) subclass used for all runtime Coconut warnings; see [`warnings`](https://docs.python.org/3/library/warnings.html). + ### Generic Built-In Functions @@ -4594,7 +4661,7 @@ else: #### `reveal_type` and `reveal_locals` -When using MyPy, `reveal_type()` will cause MyPy to print the type of `` and `reveal_locals()` will cause MyPy to print the types of the current `locals()`. At runtime, `reveal_type(x)` is always the identity function and `reveal_locals()` always returns `None`. See [the MyPy documentation](https://mypy.readthedocs.io/en/stable/common_issues.html#reveal-type) for more information. +When using static type analysis tools integrated with Coconut such as [MyPy](#mypy-integration), `reveal_type()` will cause MyPy to print the type of `` and `reveal_locals()` will cause MyPy to print the types of the current `locals()`. At runtime, `reveal_type(x)` is always the identity function and `reveal_locals()` always returns `None`. See [the MyPy documentation](https://mypy.readthedocs.io/en/stable/common_issues.html#reveal-type) for more information. ##### Example diff --git a/HELP.md b/HELP.md index 8c78644af..9b87056f4 100644 --- a/HELP.md +++ b/HELP.md @@ -379,7 +379,7 @@ def factorial(n): ``` By making use of the [Coconut `addpattern` syntax](./DOCS.md#addpattern), we can take that from three indentation levels down to one. Take a look: -``` +```coconut def factorial(0) = 1 addpattern def factorial(int() as n if n > 0) = diff --git a/Makefile b/Makefile index eb2094c8f..e96ed3eef 100644 --- a/Makefile +++ b/Makefile @@ -161,8 +161,16 @@ test-mypy-tests: clean-no-tests python ./coconut/tests/dest/runner.py python ./coconut/tests/dest/extras.py +# same as test-mypy but uses pyright instead +.PHONY: test-pyright +test-pyright: export COCONUT_USE_COLOR=TRUE +test-pyright: clean + python ./coconut/tests --strict --keep-lines --force --target sys --no-cache --pyright + python ./coconut/tests/dest/runner.py + python ./coconut/tests/dest/extras.py + # same as test-univ but includes verbose output for better debugging -# regex for getting non-timing lines: ^(?!\s*(Time|Packrat|Loaded|Saving|Adaptive|Errorless|Grammar|Failed|Incremental|Pruned)\s)[^\n]*\n* +# regex for getting non-timing lines: ^(?!'|\s*(Time|Packrat|Loaded|Saving|Adaptive|Errorless|Grammar|Failed|Incremental|Pruned|Compiled)\s)[^\n]*\n* .PHONY: test-verbose test-verbose: export COCONUT_USE_COLOR=TRUE test-verbose: clean @@ -170,6 +178,14 @@ test-verbose: clean python ./coconut/tests/dest/runner.py python ./coconut/tests/dest/extras.py +# same as test-verbose but reuses the incremental cache +.PHONY: test-verbose-cache +test-verbose-cache: export COCONUT_USE_COLOR=TRUE +test-verbose-cache: clean-no-tests + python ./coconut/tests --strict --keep-lines --force --verbose + python ./coconut/tests/dest/runner.py + python ./coconut/tests/dest/extras.py + # same as test-verbose but doesn't use the incremental cache .PHONY: test-verbose-no-cache test-verbose-no-cache: export COCONUT_USE_COLOR=TRUE @@ -359,7 +375,7 @@ check-reqs: .PHONY: profile profile: export COCONUT_USE_COLOR=TRUE profile: - coconut ./coconut/tests/src/cocotest/agnostic/util.coco ./coconut/tests/dest/cocotest --force --jobs 0 --profile --verbose --stack-size 4096 --recursion-limit 4096 2>&1 | tee ./profile.log + coconut ./coconut/tests/src/cocotest/agnostic/util.coco ./coconut/tests/dest/cocotest --force --verbose --profile --stack-size 4096 --recursion-limit 4096 2>&1 | tee ./profile.log .PHONY: open-speedscope open-speedscope: diff --git a/__coconut__/__init__.pyi b/__coconut__/__init__.pyi index 09313eb57..fedb0bb90 100644 --- a/__coconut__/__init__.pyi +++ b/__coconut__/__init__.pyi @@ -175,6 +175,8 @@ py_reversed = reversed py_enumerate = enumerate py_repr = repr py_breakpoint = breakpoint +py_min = min +py_max = max # all py_ functions, but not py_ types, go here chr = _builtins.chr @@ -189,6 +191,8 @@ zip = _builtins.zip filter = _builtins.filter reversed = _builtins.reversed enumerate = _builtins.enumerate +min = _builtins.min +max = _builtins.max _coconut_py_str = py_str @@ -235,6 +239,11 @@ def scan( _coconut_scan = scan +class CoconutWarning(Warning): + pass +_coconut_CoconutWarning = CoconutWarning + + class MatchError(Exception): """Pattern-matching error. Has attributes .pattern, .value, and .message.""" pattern: _t.Optional[_t.Text] @@ -275,30 +284,30 @@ def call( _y: _U, _z: _V, ) -> _W: ... -# @_t.overload -# def call( -# _func: _t.Callable[_t.Concatenate[_T, _P], _U], -# _x: _T, -# *args: _t.Any, -# **kwargs: _t.Any, -# ) -> _U: ... -# @_t.overload -# def call( -# _func: _t.Callable[_t.Concatenate[_T, _U, _P], _V], -# _x: _T, -# _y: _U, -# *args: _t.Any, -# **kwargs: _t.Any, -# ) -> _V: ... -# @_t.overload -# def call( -# _func: _t.Callable[_t.Concatenate[_T, _U, _V, _P], _W], -# _x: _T, -# _y: _U, -# _z: _V, -# *args: _t.Any, -# **kwargs: _t.Any, -# ) -> _W: ... +@_t.overload +def call( + _func: _t.Callable[_t.Concatenate[_T, _P], _U], + _x: _T, + *args: _t.Any, + **kwargs: _t.Any, +) -> _U: ... +@_t.overload +def call( + _func: _t.Callable[_t.Concatenate[_T, _U, _P], _V], + _x: _T, + _y: _U, + *args: _t.Any, + **kwargs: _t.Any, +) -> _V: ... +@_t.overload +def call( + _func: _t.Callable[_t.Concatenate[_T, _U, _V, _P], _W], + _x: _T, + _y: _U, + _z: _V, + *args: _t.Any, + **kwargs: _t.Any, +) -> _W: ... @_t.overload def call( _func: _t.Callable[..., _T], @@ -312,6 +321,7 @@ def call( """ ... +# call = _coconut.operator.call _coconut_tail_call = call of = _deprecated("use call instead")(call) diff --git a/coconut/__coconut__.pyi b/coconut/__coconut__.pyi index 520b56973..92a5a9dce 100644 --- a/coconut/__coconut__.pyi +++ b/coconut/__coconut__.pyi @@ -1,2 +1,2 @@ from __coconut__ import * -from __coconut__ import _coconut_tail_call, _coconut_tco, _coconut_call_set_names, _coconut_handle_cls_kwargs, _coconut_handle_cls_stargs, _namedtuple_of, _coconut, _coconut_Expected, _coconut_MatchError, _coconut_SupportsAdd, _coconut_SupportsMinus, _coconut_SupportsMul, _coconut_SupportsPow, _coconut_SupportsTruediv, _coconut_SupportsFloordiv, _coconut_SupportsMod, _coconut_SupportsAnd, _coconut_SupportsXor, _coconut_SupportsOr, _coconut_SupportsLshift, _coconut_SupportsRshift, _coconut_SupportsMatmul, _coconut_SupportsInv, _coconut_Expected, _coconut_MatchError, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_complex_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_arr_concat_op, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten, _coconut_multiset, _coconut_back_none_pipe, _coconut_back_none_star_pipe, _coconut_back_none_dubstar_pipe, _coconut_forward_none_compose, _coconut_back_none_compose, _coconut_forward_none_star_compose, _coconut_back_none_star_compose, _coconut_forward_none_dubstar_compose, _coconut_back_none_dubstar_compose, _coconut_call_or_coefficient, _coconut_in, _coconut_not_in, _coconut_attritemgetter, _coconut_if_op +from __coconut__ import _coconut_tail_call, _coconut_tco, _coconut_call_set_names, _coconut_handle_cls_kwargs, _coconut_handle_cls_stargs, _namedtuple_of, _coconut, _coconut_Expected, _coconut_MatchError, _coconut_SupportsAdd, _coconut_SupportsMinus, _coconut_SupportsMul, _coconut_SupportsPow, _coconut_SupportsTruediv, _coconut_SupportsFloordiv, _coconut_SupportsMod, _coconut_SupportsAnd, _coconut_SupportsXor, _coconut_SupportsOr, _coconut_SupportsLshift, _coconut_SupportsRshift, _coconut_SupportsMatmul, _coconut_SupportsInv, _coconut_Expected, _coconut_MatchError, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_complex_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_arr_concat_op, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten, _coconut_multiset, _coconut_back_none_pipe, _coconut_back_none_star_pipe, _coconut_back_none_dubstar_pipe, _coconut_forward_none_compose, _coconut_back_none_compose, _coconut_forward_none_star_compose, _coconut_back_none_star_compose, _coconut_forward_none_dubstar_compose, _coconut_back_none_dubstar_compose, _coconut_call_or_coefficient, _coconut_in, _coconut_not_in, _coconut_attritemgetter, _coconut_if_op, _coconut_CoconutWarning diff --git a/coconut/_pyparsing.py b/coconut/_pyparsing.py index 6d08487a6..f3101d42e 100644 --- a/coconut/_pyparsing.py +++ b/coconut/_pyparsing.py @@ -20,7 +20,6 @@ from coconut.root import * # NOQA import os -import re import sys import traceback from warnings import warn @@ -39,7 +38,6 @@ min_versions, max_versions, pure_python_env_var, - enable_pyparsing_warnings, use_left_recursion_if_available, get_bool_env_var, use_computation_graph_env_var, @@ -50,6 +48,7 @@ num_displayed_timing_items, use_cache_file, use_line_by_line_parser, + incremental_use_hybrid, ) from coconut.util import get_clock_time # NOQA from coconut.util import ( @@ -147,6 +146,7 @@ # ----------------------------------------------------------------------------------------------------------------------- if MODERN_PYPARSING: + ParserElement.leaveWhitespace = ParserElement.leave_whitespace SUPPORTS_PACKRAT_CONTEXT = False elif CPYPARSING: @@ -243,8 +243,9 @@ def enableIncremental(*args, **kwargs): use_computation_graph_env_var, default=( not MODERN_PYPARSING # not yet supported - # commented out to minimize memory footprint when running tests: - # and not PYPY # experimentally determined + # technically PYPY is faster without the computation graph, but + # it breaks some features and balloons the memory footprint + # and not PYPY ), ) @@ -265,7 +266,7 @@ def enableIncremental(*args, **kwargs): maybe_make_safe = getattr(_pyparsing, "maybe_make_safe", None) -if enable_pyparsing_warnings: +if DEVELOP: if MODERN_PYPARSING: _pyparsing.enable_all_warnings() else: @@ -276,7 +277,11 @@ def enableIncremental(*args, **kwargs): if MODERN_PYPARSING and use_left_recursion_if_available: ParserElement.enable_left_recursion() elif SUPPORTS_INCREMENTAL and use_incremental_if_available: - ParserElement.enableIncremental(default_incremental_cache_size, still_reset_cache=not never_clear_incremental_cache) + ParserElement.enableIncremental( + default_incremental_cache_size, + still_reset_cache=not never_clear_incremental_cache, + hybrid_mode=incremental_use_hybrid, + ) elif use_packrat_parser: ParserElement.enablePackrat(packrat_cache_size) @@ -290,22 +295,6 @@ def enableIncremental(*args, **kwargs): all_parse_elements = None -# ----------------------------------------------------------------------------------------------------------------------- -# MISSING OBJECTS: -# ----------------------------------------------------------------------------------------------------------------------- - -python_quoted_string = getattr(_pyparsing, "python_quoted_string", None) -if python_quoted_string is None: - python_quoted_string = _pyparsing.Combine( - # multiline strings must come first - (_pyparsing.Regex(r'"""(?:[^"\\]|""(?!")|"(?!"")|\\.)*', flags=re.MULTILINE) + '"""').setName("multiline double quoted string") - | (_pyparsing.Regex(r"'''(?:[^'\\]|''(?!')|'(?!'')|\\.)*", flags=re.MULTILINE) + "'''").setName("multiline single quoted string") - | (_pyparsing.Regex(r'"(?:[^"\n\r\\]|(?:\\")|(?:\\(?:[^x]|x[0-9a-fA-F]+)))*') + '"').setName("double quoted string") - | (_pyparsing.Regex(r"'(?:[^'\n\r\\]|(?:\\')|(?:\\(?:[^x]|x[0-9a-fA-F]+)))*") + "'").setName("single quoted string") - ).setName("Python quoted string") - _pyparsing.python_quoted_string = python_quoted_string - - # ----------------------------------------------------------------------------------------------------------------------- # FAST REPRS: # ----------------------------------------------------------------------------------------------------------------------- @@ -559,5 +548,5 @@ def start_profiling(): def print_profiling_results(): """Print all profiling results.""" - print_timing_info() print_poorly_ordered_MatchFirsts() + print_timing_info() diff --git a/coconut/api.pyi b/coconut/api.pyi index 850b2eb89..f80fb0538 100644 --- a/coconut/api.pyi +++ b/coconut/api.pyi @@ -27,7 +27,9 @@ from coconut.command.command import Command class CoconutException(Exception): """Coconut Exception.""" - ... + + def syntax_err(self) -> SyntaxError: + ... #----------------------------------------------------------------------------------------------------------------------- # COMMAND: diff --git a/coconut/command/cli.py b/coconut/command/cli.py index 5ea28e199..c542cab6f 100644 --- a/coconut/command/cli.py +++ b/coconut/command/cli.py @@ -216,6 +216,12 @@ help="run MyPy on compiled Python (remaining args passed to MyPy) (implies --package --line-numbers)", ) +arguments.add_argument( + "--pyright", + action="store_true", + help="run Pyright on compiled Python (implies --package)", +) + arguments.add_argument( "--argv", "--args", type=str, diff --git a/coconut/command/command.py b/coconut/command/command.py index fc6fe2d3e..4d4debab8 100644 --- a/coconut/command/command.py +++ b/coconut/command/command.py @@ -73,6 +73,7 @@ coconut_cache_dir, coconut_sys_kwargs, interpreter_uses_incremental, + pyright_config_file, ) from coconut.util import ( univ_open, @@ -83,8 +84,6 @@ first_import_time, ) from coconut.command.util import ( - writefile, - readfile, showpath, rem_encoding, Runner, @@ -104,6 +103,7 @@ run_with_stack_size, proc_run_args, get_python_lib, + update_pyright_config, ) from coconut.compiler.util import ( should_indent, @@ -128,6 +128,7 @@ class Command(object): display = False # corresponds to --display flag jobs = 0 # corresponds to --jobs flag mypy_args = None # corresponds to --mypy flag + pyright = False # corresponds to --pyright flag argv_args = None # corresponds to --argv flag stack_size = 0 # corresponds to --stack-size flag use_cache = USE_CACHE # corresponds to --no-cache flag @@ -252,6 +253,8 @@ def execute_args(self, args, interact=True, original_args=None): logger.log("Directly passed args:", original_args) logger.log("Parsed args:", args) + type_checking_arg = "--mypy" if args.mypy else "--pyright" if args.pyright else None + # validate args and show warnings if args.stack_size and args.stack_size % 4 != 0: logger.warn("--stack-size should generally be a multiple of 4, not {stack_size} (to support 4 KB pages)".format(stack_size=args.stack_size)) @@ -259,8 +262,8 @@ def execute_args(self, args, interact=True, original_args=None): logger.warn("using --mypy running with --no-line-numbers is not recommended; mypy error messages won't include Coconut line numbers") if args.interact and args.run: logger.warn("extraneous --run argument passed; --interact implies --run") - if args.package and self.mypy: - logger.warn("extraneous --package argument passed; --mypy implies --package") + if args.package and type_checking_arg: + logger.warn("extraneous --package argument passed; --{type_checking_arg} implies --package".format(type_checking_arg=type_checking_arg)) # validate args and raise errors if args.line_numbers and args.no_line_numbers: @@ -269,10 +272,10 @@ def execute_args(self, args, interact=True, original_args=None): raise CoconutException("cannot --site-install and --site-uninstall simultaneously") if args.standalone and args.package: raise CoconutException("cannot compile as both --package and --standalone") - if args.standalone and self.mypy: - raise CoconutException("cannot compile as both --package (implied by --mypy) and --standalone") - if args.no_write and self.mypy: - raise CoconutException("cannot compile with --no-write when using --mypy") + if args.standalone and type_checking_arg: + raise CoconutException("cannot compile as both --package (implied by --{type_checking_arg}) and --standalone".format(type_checking_arg=type_checking_arg)) + if args.no_write and type_checking_arg: + raise CoconutException("cannot compile with --no-write when using --{type_checking_arg}".format(type_checking_arg=type_checking_arg)) for and_args in getattr(args, "and") or []: if len(and_args) > 2: raise CoconutException( @@ -328,25 +331,24 @@ def execute_args(self, args, interact=True, original_args=None): no_tco=args.no_tco, no_wrap=args.no_wrap_types, ) - self.comp.warm_up( - streamline=( - not self.using_jobs - and (args.watch or args.profile) - ), - enable_incremental_mode=( - not self.using_jobs - and args.watch - ), - set_debug_names=( - args.verbose - or args.trace - or args.profile - ), - ) + if not self.using_jobs: + self.comp.warm_up( + streamline=( + args.watch + or args.profile + ), + set_debug_names=( + args.verbose + or args.trace + or args.profile + ), + ) - # process mypy args and print timing info (must come after compiler setup) + # process mypy + pyright args and print timing info (must come after compiler setup) if args.mypy is not None: self.set_mypy_args(args.mypy) + if args.pyright: + self.enable_pyright() logger.log_compiler_stats(self.comp) # do compilation, keeping track of compiled filepaths @@ -378,8 +380,8 @@ def execute_args(self, args, interact=True, original_args=None): for kwargs in all_compile_path_kwargs: filepaths += self.compile_path(**kwargs) - # run mypy on compiled files - self.run_mypy(filepaths) + # run type checking on compiled files + self.run_type_checking(filepaths) # do extra compilation if there is any if extra_compile_path_kwargs: @@ -459,7 +461,7 @@ def process_source_dest(self, source, dest, args): processed_dest = dest # determine package mode - if args.package or self.mypy: + if args.package or self.type_checking: package = True elif args.standalone: package = False @@ -506,7 +508,7 @@ def process_source_dest(self, source, dest, args): ] return main_compilation_tasks, extra_compilation_tasks - def compile_path(self, source, dest=True, package=True, handling_exceptions_kwargs={}, **kwargs): + def compile_path(self, source, dest=True, package=True, **kwargs): """Compile a path and return paths to compiled files.""" if not isinstance(dest, bool): dest = fixpath(dest) @@ -514,11 +516,11 @@ def compile_path(self, source, dest=True, package=True, handling_exceptions_kwar destpath = self.compile_file(source, dest, package, **kwargs) return [destpath] if destpath is not None else [] elif os.path.isdir(source): - return self.compile_folder(source, dest, package, handling_exceptions_kwargs=handling_exceptions_kwargs, **kwargs) + return self.compile_folder(source, dest, package, **kwargs) else: raise CoconutException("could not find source path", source) - def compile_folder(self, directory, write=True, package=True, handling_exceptions_kwargs={}, **kwargs): + def compile_folder(self, directory, write=True, package=True, **kwargs): """Compile a directory and return paths to compiled files.""" if not isinstance(write, bool) and os.path.isfile(write): raise CoconutException("destination path cannot point to a file when compiling a directory") @@ -530,7 +532,7 @@ def compile_folder(self, directory, write=True, package=True, handling_exception writedir = os.path.join(write, os.path.relpath(dirpath, directory)) for filename in filenames: if os.path.splitext(filename)[1] in code_exts: - with self.handling_exceptions(**handling_exceptions_kwargs): + with self.handling_exceptions(**kwargs.get("handling_exceptions_kwargs", {})): destpath = self.compile_file(os.path.join(dirpath, filename), writedir, package, **kwargs) if destpath is not None: filepaths.append(destpath) @@ -576,10 +578,10 @@ def compile_file(self, filepath, write=True, package=False, force=False, **kwarg self.compile(filepath, destpath, package, force=force, **kwargs) return destpath - def compile(self, codepath, destpath=None, package=False, run=False, force=False, show_unchanged=True): + def compile(self, codepath, destpath=None, package=False, run=False, force=False, show_unchanged=True, handling_exceptions_kwargs={}, callback=None): """Compile a source Coconut file to a destination Python file.""" with univ_open(codepath, "r") as opened: - code = readfile(opened) + code = opened.read() package_level = -1 if destpath is not None: @@ -599,16 +601,18 @@ def compile(self, codepath, destpath=None, package=False, run=False, force=False logger.print(foundhash) if run: self.execute_file(destpath, argv_source_path=codepath) + if callback is not None: + callback(destpath) else: logger.show_tabulated("Compiling", showpath(codepath), "...") - def callback(compiled): + def inner_callback(compiled): if destpath is None: logger.show_tabulated("Compiled", showpath(codepath), "without writing to file.") else: with univ_open(destpath, "w") as opened: - writefile(opened, compiled) + opened.write(compiled) logger.show_tabulated("Compiled to", showpath(destpath), ".") if self.display: logger.print(compiled) @@ -617,15 +621,17 @@ def callback(compiled): self.execute(compiled, path=codepath, allow_show=False) else: self.execute_file(destpath, argv_source_path=codepath) + if callback is not None: + callback(destpath) parse_kwargs = dict( codepath=codepath, use_cache=self.use_cache, ) if package is True: - self.submit_comp_job(codepath, callback, "parse_package", code, package_level=package_level, **parse_kwargs) + self.submit_comp_job(codepath, inner_callback, handling_exceptions_kwargs, "parse_package", code, package_level=package_level, **parse_kwargs) elif package is False: - self.submit_comp_job(codepath, callback, "parse_file", code, **parse_kwargs) + self.submit_comp_job(codepath, inner_callback, handling_exceptions_kwargs, "parse_file", code, **parse_kwargs) else: raise CoconutInternalException("invalid value for package", package) @@ -656,7 +662,7 @@ def create_package(self, dirpath, retries_left=create_package_retries): filepath = os.path.join(dirpath, "__coconut__.py") try: with univ_open(filepath, "w") as opened: - writefile(opened, self.comp.getheader("__coconut__")) + opened.write(self.comp.getheader("__coconut__")) except OSError: logger.log_exc() if retries_left <= 0: @@ -667,10 +673,10 @@ def create_package(self, dirpath, retries_left=create_package_retries): time.sleep(random.random() / 10) self.create_package(dirpath, retries_left - 1) - def submit_comp_job(self, path, callback, method, *args, **kwargs): + def submit_comp_job(self, path, callback, handling_exceptions_kwargs, method, *args, **kwargs): """Submits a job on self.comp to be run in parallel.""" if self.executor is None: - with self.handling_exceptions(): + with self.handling_exceptions(**handling_exceptions_kwargs): callback(getattr(self.comp, method)(*args, **kwargs)) else: path = showpath(path) @@ -680,7 +686,7 @@ def submit_comp_job(self, path, callback, method, *args, **kwargs): def callback_wrapper(completed_future): """Ensures that all errors are always caught, since errors raised in a callback won't be propagated.""" with logger.in_path(path): # handle errors in the path context - with self.handling_exceptions(): + with self.handling_exceptions(**handling_exceptions_kwargs): result = completed_future.result() callback(result) future.add_done_callback(callback_wrapper) @@ -693,19 +699,19 @@ def register_exit_code(self, code=1, errmsg=None, err=None): errmsg = format_error(err) else: errmsg = err.__class__.__name__ - if errmsg is not None: - if self.errmsg is None: - self.errmsg = errmsg - elif errmsg not in self.errmsg: - if logger.verbose: - self.errmsg += "\nAnd error: " + errmsg - else: - self.errmsg += "; " + errmsg - if code is not None: + if code: + if errmsg is not None: + if self.errmsg is None: + self.errmsg = errmsg + elif errmsg not in self.errmsg: + if logger.verbose: + self.errmsg += "\nAnd error: " + errmsg + else: + self.errmsg += "; " + errmsg self.exit_code = code or self.exit_code @contextmanager - def handling_exceptions(self, exit_on_error=None, on_keyboard_interrupt=None): + def handling_exceptions(self, exit_on_error=None, error_callback=None): """Perform proper exception handling.""" if exit_on_error is None: exit_on_error = self.fail_fast @@ -715,26 +721,29 @@ def handling_exceptions(self, exit_on_error=None, on_keyboard_interrupt=None): yield else: yield - except SystemExit as err: - self.register_exit_code(err.code) # make sure we don't catch GeneratorExit below except GeneratorExit: raise + except SystemExit as err: + self.register_exit_code(err.code) + if error_callback is not None: + error_callback(err) except BaseException as err: if isinstance(err, CoconutException): logger.print_exc() - elif isinstance(err, KeyboardInterrupt): - if on_keyboard_interrupt is not None: - on_keyboard_interrupt() - else: + elif not isinstance(err, KeyboardInterrupt): logger.print_exc() logger.printerr(report_this_text) self.register_exit_code(err=err) + if error_callback is not None: + error_callback(err) if exit_on_error: self.exit_on_error() def set_jobs(self, jobs, profile=False): """Set --jobs.""" + if profile and jobs is None: + jobs = 0 if jobs in (None, "sys"): self.jobs = jobs else: @@ -788,7 +797,7 @@ def has_hash_of(self, destpath, code, package_level): """Determine if a file has the hash of the code.""" if destpath is not None and os.path.isfile(destpath): with univ_open(destpath, "r") as opened: - compiled = readfile(opened) + compiled = opened.read() hashash = gethash(compiled) if hashash is not None: newhash = self.comp.genhash(code, package_level) @@ -876,19 +885,19 @@ def execute(self, compiled=None, path=None, use_eval=False, allow_show=True): logger.print(compiled) if path is None: # header is not included - if not self.mypy: + if not self.type_checking: no_str_code = self.comp.remove_strs(compiled) if no_str_code is not None: result = mypy_builtin_regex.search(no_str_code) if result: - logger.warn("found mypy-only built-in " + repr(result.group(0)) + "; pass --mypy to use mypy-only built-ins at the interpreter") + logger.warn("found type-checking-only built-in " + repr(result.group(0)) + "; pass --mypy to use such built-ins at the interpreter") else: # header is included compiled = rem_encoding(compiled) self.runner.run(compiled, use_eval=use_eval, path=path, all_errors_exit=path is not None) - self.run_mypy(code=self.runner.was_run_code()) + self.run_type_checking(code=self.runner.was_run_code()) def execute_file(self, destpath, **kwargs): """Execute compiled file.""" @@ -908,25 +917,31 @@ def check_runner(self, set_sys_vars=True, argv_source_path=""): # set up runner if self.runner is None: - self.runner = Runner(self.comp, exit=self.exit_runner, store=self.mypy) + self.runner = Runner(self.comp, exit=self.exit_runner, store=self.type_checking) # pass runner to prompt self.prompt.set_runner(self.runner) @property - def mypy(self): - """Whether using MyPy or not.""" - return self.mypy_args is not None + def type_checking(self): + """Whether using a static type-checker or not.""" + return self.mypy_args is not None or self.pyright + + @property + def type_checking_version(self): + """What version of Python to type check against.""" + return ver_tuple_to_str(get_target_info_smart(self.comp.target, mode="highest")) def set_mypy_args(self, mypy_args=None): """Set MyPy arguments.""" if mypy_args is None: self.mypy_args = None - elif mypy_install_arg in mypy_args: + stub_dir = set_mypy_path() + + if mypy_install_arg in mypy_args: if mypy_args != [mypy_install_arg]: raise CoconutException("'--mypy install' cannot be used alongside other --mypy arguments") - stub_dir = set_mypy_path() logger.show_sig("Successfully installed MyPy stubs into " + repr(stub_dir)) self.mypy_args = None @@ -936,7 +951,7 @@ def set_mypy_args(self, mypy_args=None): if not any(arg.startswith("--python-version") for arg in self.mypy_args): self.mypy_args += [ "--python-version", - ver_tuple_to_str(get_target_info_smart(self.comp.target, mode="highest")), + self.type_checking_version, ] if not any(arg.startswith("--python-executable") for arg in self.mypy_args): @@ -956,10 +971,15 @@ def set_mypy_args(self, mypy_args=None): logger.log("MyPy args:", self.mypy_args) self.mypy_errs = [] - def run_mypy(self, paths=(), code=None): - """Run MyPy with arguments.""" - if self.mypy: - set_mypy_path() + def enable_pyright(self): + """Enable the use of Pyright for type-checking.""" + update_pyright_config() + self.pyright = True + + def run_type_checking(self, paths=(), code=None): + """Run type-checking on the given paths / code.""" + if self.mypy_args is not None: + set_mypy_path(ensure_stubs=False) from coconut.command.mypy import mypy_run args = list(paths) + self.mypy_args if code is not None: # interpreter @@ -983,6 +1003,19 @@ def run_mypy(self, paths=(), code=None): if code is not None: # interpreter logger.printerr(line) self.mypy_errs.append(line) + if self.pyright: + if code is not None: + logger.warn("--pyright only works on files, not code snippets or at the interpreter (use --mypy instead)") + if paths: + try: + from pyright import main + except ImportError: + raise CoconutException( + "coconut --pyright requires Pyright", + extra="run '{python} -m pip install coconut[pyright]' to fix".format(python=sys.executable), + ) + args = ["--project", pyright_config_file, "--pythonversion", self.type_checking_version] + list(paths) + self.register_exit_code(main(args), errmsg="Pyright error") def run_silent_cmd(self, *args): """Same as run_cmd$(show_output=logger.verbose).""" @@ -1135,29 +1168,36 @@ def watch(self, all_compile_path_kwargs): interrupted = [False] # in list to allow modification - def interrupt(): - interrupted[0] = True - - def recompile(path, **kwargs): + def recompile(path, callback, **kwargs): + def error_callback(err): + if isinstance(err, KeyboardInterrupt): + interrupted[0] = True + callback() path = fixpath(path) src = kwargs.pop("source") dest = kwargs.pop("dest") if os.path.isfile(path) and os.path.splitext(path)[1] in code_exts: - with self.handling_exceptions(on_keyboard_interrupt=interrupt): + with self.handling_exceptions(error_callback=error_callback): if dest is True or dest is None: writedir = dest else: # correct the compilation path based on the relative position of path to src dirpath = os.path.dirname(path) writedir = os.path.join(dest, os.path.relpath(dirpath, src)) - filepaths = self.compile_path( + + def inner_callback(path): + self.run_type_checking([path]) + callback() + self.compile_path( path, writedir, show_unchanged=False, - handling_exceptions_kwargs=dict(on_keyboard_interrupt=interrupt), + handling_exceptions_kwargs=dict(error_callback=error_callback), + callback=inner_callback, **kwargs # no comma for py2 ) - self.run_mypy(filepaths) + else: + callback() observer = Observer() watchers = [] @@ -1171,10 +1211,8 @@ def recompile(path, **kwargs): try: while not interrupted[0]: time.sleep(watch_interval) - for wcher in watchers: - wcher.keep_watching() except KeyboardInterrupt: - interrupt() + interrupted[0] = True finally: if interrupted[0]: logger.show_sig("Got KeyboardInterrupt; stopping watcher.") diff --git a/coconut/command/mypy.py b/coconut/command/mypy.py index 57366b490..bcc4f636d 100644 --- a/coconut/command/mypy.py +++ b/coconut/command/mypy.py @@ -34,7 +34,7 @@ from mypy.api import run except ImportError: raise CoconutException( - "--mypy flag requires MyPy library", + "coconut --mypy requires MyPy", extra="run '{python} -m pip install coconut[mypy]' to fix".format(python=sys.executable), ) diff --git a/coconut/command/util.py b/coconut/command/util.py index c4e0b1e7d..fe26947d8 100644 --- a/coconut/command/util.py +++ b/coconut/command/util.py @@ -24,6 +24,7 @@ import subprocess import shutil import threading +import json from select import select from contextlib import contextmanager from functools import partial @@ -50,6 +51,7 @@ get_encoding, get_clock_time, assert_remove_prefix, + univ_open, ) from coconut.constants import ( WINDOWS, @@ -88,6 +90,9 @@ high_proc_prio, call_timeout, use_fancy_call_output, + extra_pyright_args, + pyright_config_file, + tabideal, ) if PY26: @@ -148,17 +153,23 @@ # ----------------------------------------------------------------------------------------------------------------------- -def writefile(openedfile, newcontents): - """Set the contents of a file.""" +def writefile(openedfile, newcontents, in_json=False, **kwargs): + """Set the entire contents of a file regardless of current position.""" openedfile.seek(0) openedfile.truncate() - openedfile.write(newcontents) + if in_json: + json.dump(newcontents, openedfile, **kwargs) + else: + openedfile.write(newcontents, **kwargs) -def readfile(openedfile): - """Read the contents of a file.""" +def readfile(openedfile, in_json=False, **kwargs): + """Read the entire contents of a file regardless of current position.""" openedfile.seek(0) - return str(openedfile.read()) + if in_json: + return json.load(openedfile, **kwargs) + else: + return str(openedfile.read(**kwargs)) def open_website(url): @@ -450,8 +461,8 @@ def symlink(link_to, link_from): shutil.copytree(link_to, link_from) -def install_mypy_stubs(): - """Properly symlink mypy stub files.""" +def install_stubs(): + """Properly symlink stub files for type-checking purposes.""" # unlink stub_dirs so we know rm_dir_or_link won't clear them for stub_name in stub_dir_names: unlink(os.path.join(base_stub_dir, stub_name)) @@ -477,10 +488,12 @@ def set_env_var(name, value): os.environ[py_str(name)] = py_str(value) -def set_mypy_path(): +def set_mypy_path(ensure_stubs=True): """Put Coconut stubs in MYPYPATH.""" + if ensure_stubs: + install_stubs() # mypy complains about the path if we don't use / over \ - install_dir = install_mypy_stubs().replace(os.sep, "/") + install_dir = installed_stub_dir.replace(os.sep, "/") original = os.getenv(mypy_path_env_var) if original is None: new_mypy_path = install_dir @@ -494,6 +507,28 @@ def set_mypy_path(): return install_dir +def update_pyright_config(python_version=None): + """Save an updated pyrightconfig.json.""" + stubs_dir = install_stubs() + update_existing = os.path.exists(pyright_config_file) + with univ_open(pyright_config_file, "r+" if update_existing else "w") as config_file: + if update_existing: + try: + config = readfile(config_file, in_json=True) + except ValueError: + raise CoconutException("invalid JSON syntax in " + repr(pyright_config_file)) + else: + config = extra_pyright_args.copy() + if "extraPaths" not in config: + config["extraPaths"] = [] + if stubs_dir not in config["extraPaths"]: + config["extraPaths"].append(stubs_dir) + if python_version is not None: + config["pythonVersion"] = python_version + writefile(config_file, config, in_json=True, indent=tabideal) + return pyright_config_file + + def is_empty_pipe(pipe, default=None): """Determine if the given pipe file object is empty.""" if pipe.closed: diff --git a/coconut/command/watch.py b/coconut/command/watch.py index c7046c397..f70a15b51 100644 --- a/coconut/command/watch.py +++ b/coconut/command/watch.py @@ -21,6 +21,9 @@ import sys +from functools import partial + +from coconut.terminal import logger from coconut.exceptions import CoconutException try: @@ -46,15 +49,28 @@ def __init__(self, recompile, *args, **kwargs): self.recompile = recompile self.args = args self.kwargs = kwargs - self.keep_watching() - - def keep_watching(self): - """Allows recompiling previously-compiled files.""" self.saw = set() + self.saw_twice = set() def on_modified(self, event): """Handle a file modified event.""" - path = event.src_path - if path not in self.saw: + self.handle(event.src_path) + + def handle(self, path): + """Handle a potential recompilation event for the given path.""" + if path in self.saw: + logger.log("Skipping watch event for: " + repr(path) + "\n\t(currently compiling: " + repr(self.saw) + ")") + self.saw_twice.add(path) + else: + logger.log("Handling watch event for: " + repr(path) + "\n\t(currently compiling: " + repr(self.saw) + ")") self.saw.add(path) - self.recompile(path, *self.args, **self.kwargs) + self.saw_twice.discard(path) + self.recompile(path, callback=partial(self.callback, path), *self.args, **self.kwargs) + + def callback(self, path): + """Callback for after recompiling the given path.""" + self.saw.discard(path) + if path in self.saw_twice: + logger.log("Submitting deferred watch event for: " + repr(path) + "\n\t(currently deferred: " + repr(self.saw_twice) + ")") + self.saw_twice.discard(path) + self.handle(path) diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index b567759dc..458d6c283 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -93,6 +93,7 @@ import_existing, use_adaptive_any_of, reverse_any_of, + tempsep, ) from coconut.util import ( pickleable_obj, @@ -104,6 +105,7 @@ get_clock_time, get_name, assert_remove_prefix, + assert_remove_suffix, dictset, noop_ctx, ) @@ -123,7 +125,7 @@ complain, internal_assert, ) -from coconut.compiler.matching import Matcher +from coconut.compiler.matching import Matcher, match_funcdef_setup_code from coconut.compiler.grammar import ( Grammar, lazy_list_handle, @@ -134,6 +136,7 @@ itemgetter_handle, partial_op_item_handle, partial_arr_concat_handle, + split_args_list, ) from coconut.compiler.util import ( ExceptionNode, @@ -184,6 +187,7 @@ manage, sub_all, ComputationNode, + StartOfStrGrammar, ) from coconut.compiler.header import ( minify_header, @@ -310,75 +314,6 @@ def special_starred_import_handle(imp_all=False): return out -def split_args_list(tokens, loc): - """Splits function definition arguments.""" - pos_only_args = [] - req_args = [] - default_args = [] - star_arg = None - kwd_only_args = [] - dubstar_arg = None - pos = 0 - for arg in tokens: - # only the first two components matter; if there's a third it's a typedef - arg = arg[:2] - - if len(arg) == 1: - if arg[0] == "*": - # star sep (pos = 2) - if pos >= 2: - raise CoconutDeferredSyntaxError("star separator at invalid position in function definition", loc) - pos = 2 - elif arg[0] == "/": - # slash sep (pos = 0) - if pos > 0: - raise CoconutDeferredSyntaxError("slash separator at invalid position in function definition", loc) - if pos_only_args: - raise CoconutDeferredSyntaxError("only one slash separator allowed in function definition", loc) - if not req_args: - raise CoconutDeferredSyntaxError("slash separator must come after arguments to mark as positional-only", loc) - pos_only_args = req_args - req_args = [] - else: - # pos arg (pos = 0) - if pos == 0: - req_args.append(arg[0]) - # kwd only arg (pos = 2) - elif pos == 2: - kwd_only_args.append((arg[0], None)) - else: - raise CoconutDeferredSyntaxError("non-default arguments must come first or after star argument/separator", loc) - - else: - internal_assert(arg[1] is not None, "invalid arg[1] in split_args_list", arg) - - if arg[0] == "*": - # star arg (pos = 2) - if pos >= 2: - raise CoconutDeferredSyntaxError("star argument at invalid position in function definition", loc) - pos = 2 - star_arg = arg[1] - elif arg[0] == "**": - # dub star arg (pos = 3) - if pos == 3: - raise CoconutDeferredSyntaxError("double star argument at invalid position in function definition", loc) - pos = 3 - dubstar_arg = arg[1] - else: - # def arg (pos = 1) - if pos <= 1: - pos = 1 - default_args.append((arg[0], arg[1])) - # kwd only arg (pos = 2) - elif pos <= 2: - pos = 2 - kwd_only_args.append((arg[0], arg[1])) - else: - raise CoconutDeferredSyntaxError("invalid default argument in function definition", loc) - - return pos_only_args, req_args, default_args, star_arg, kwd_only_args, dubstar_arg - - def reconstitute_paramdef(pos_only_args, req_args, default_args, star_arg, kwd_only_args, dubstar_arg): """Convert the results of split_args_list back into a parameter defintion string.""" args_list = [] @@ -606,6 +541,7 @@ def reset(self, keep_state=False, filename=None): self.add_code_before_replacements = {} self.add_code_before_ignore_names = {} self.remaining_original = None + self.shown_warnings = set() @contextmanager def inner_environment(self, ln=None): @@ -623,6 +559,7 @@ def inner_environment(self, ln=None): kept_lines, self.kept_lines = self.kept_lines, [] num_lines, self.num_lines = self.num_lines, 0 remaining_original, self.remaining_original = self.remaining_original, None + shown_warnings, self.shown_warnings = self.shown_warnings, set() try: with ComputationNode.using_overrides(): yield @@ -638,6 +575,7 @@ def inner_environment(self, ln=None): self.kept_lines = kept_lines self.num_lines = num_lines self.remaining_original = remaining_original + self.shown_warnings = shown_warnings @contextmanager def disable_checks(self): @@ -842,6 +780,7 @@ def bind(cls): cls.testlist_star_namedexpr <<= attach(cls.testlist_star_namedexpr_tokens, cls.method("testlist_star_expr_handle")) cls.ellipsis <<= attach(cls.ellipsis_tokens, cls.method("ellipsis_handle")) cls.f_string <<= attach(cls.f_string_tokens, cls.method("f_string_handle")) + cls.funcname_typeparams <<= attach(cls.funcname_typeparams_tokens, cls.method("funcname_typeparams_handle")) # standard handlers of the form name <<= attach(name_ref, method("name_handle")) cls.term <<= attach(cls.term_ref, cls.method("term_handle")) @@ -855,6 +794,7 @@ def bind(cls): cls.full_match <<= attach(cls.full_match_ref, cls.method("full_match_handle")) cls.name_match_funcdef <<= attach(cls.name_match_funcdef_ref, cls.method("name_match_funcdef_handle")) cls.op_match_funcdef <<= attach(cls.op_match_funcdef_ref, cls.method("op_match_funcdef_handle")) + cls.base_case_funcdef <<= attach(cls.base_case_funcdef_ref, cls.method("base_case_funcdef_handle")) cls.yield_from <<= attach(cls.yield_from_ref, cls.method("yield_from_handle")) cls.typedef <<= attach(cls.typedef_ref, cls.method("typedef_handle")) cls.typedef_default <<= attach(cls.typedef_default_ref, cls.method("typedef_handle")) @@ -873,7 +813,6 @@ def bind(cls): cls.base_match_for_stmt <<= attach(cls.base_match_for_stmt_ref, cls.method("base_match_for_stmt_handle")) cls.async_with_for_stmt <<= attach(cls.async_with_for_stmt_ref, cls.method("async_with_for_stmt_handle")) cls.unsafe_typedef_tuple <<= attach(cls.unsafe_typedef_tuple_ref, cls.method("unsafe_typedef_tuple_handle")) - cls.funcname_typeparams <<= attach(cls.funcname_typeparams_ref, cls.method("funcname_typeparams_handle")) cls.impl_call <<= attach(cls.impl_call_ref, cls.method("impl_call_handle")) cls.protocol_intersect_expr <<= attach(cls.protocol_intersect_expr_ref, cls.method("protocol_intersect_expr_handle")) @@ -1003,11 +942,14 @@ def strict_err(self, *args, **kwargs): if self.strict: raise self.make_err(CoconutStyleError, *args, **kwargs) - def syntax_warning(self, *args, **kwargs): + def syntax_warning(self, message, original, loc, **kwargs): """Show a CoconutSyntaxWarning. Usage: self.syntax_warning(message, original, loc) """ - logger.warn_err(self.make_err(CoconutSyntaxWarning, *args, **kwargs)) + key = (message, loc) + if key not in self.shown_warnings: + logger.warn_err(self.make_err(CoconutSyntaxWarning, message, original, loc, **kwargs)) + self.shown_warnings.add(key) def strict_err_or_warn(self, *args, **kwargs): """Raises an error if in strict mode, otherwise raises a warning. Usage: @@ -1364,7 +1306,7 @@ def streamline(self, grammars, inputstring=None, force=False, inner=False): input_len = 0 if inputstring is None else len(inputstring) if force or (streamline_grammar_for_len is not None and input_len > streamline_grammar_for_len): start_time = get_clock_time() - prep_grammar(grammar, streamline=True) + prep_grammar(grammar, for_scan=False, streamline=True) logger.log_lambda( lambda: "Streamlined {grammar} in {time} seconds{info}.".format( grammar=get_name(grammar), @@ -1561,7 +1503,7 @@ def str_proc(self, inputstring, **kwargs): hold["exprs"][-1] += c elif hold["paren_level"] > 0: raise self.make_err(CoconutSyntaxError, "imbalanced parentheses in format string expression", inputstring, i, reformat=False) - elif match_in(self.end_f_str_expr, remaining_text): + elif does_parse(self.end_f_str_expr, remaining_text): hold["in_expr"] = False hold["str_parts"].append(c) else: @@ -1570,10 +1512,7 @@ def str_proc(self, inputstring, **kwargs): # if we might be at the end of the string elif hold["stop"] is not None: - if c == "\\": - self.str_hold_contents(hold, append=hold["stop"] + c) - hold["stop"] = None - elif c == hold["start"][0]: + if c == hold["start"][0]: hold["stop"] += c elif len(hold["stop"]) > len(hold["start"]): raise self.make_err(CoconutSyntaxError, "invalid number of closing " + repr(hold["start"][0]) + "s", inputstring, i, reformat=False) @@ -1581,8 +1520,9 @@ def str_proc(self, inputstring, **kwargs): done = True rerun = True else: - self.str_hold_contents(hold, append=hold["stop"] + c) + self.str_hold_contents(hold, append=hold["stop"]) hold["stop"] = None + rerun = True # if we might be at the start of an f string expr elif hold.get("saw_brace", False): @@ -1597,15 +1537,16 @@ def str_proc(self, inputstring, **kwargs): hold["exprs"].append("") rerun = True + elif is_f and c == "{": + hold["saw_brace"] = True + self.str_hold_contents(hold, append=c) + # backslashes should escape quotes, but nothing else elif count_end(self.str_hold_contents(hold), "\\") % 2 == 1: self.str_hold_contents(hold, append=c) elif c == hold["start"]: done = True elif c == hold["start"][0]: hold["stop"] = c - elif is_f and c == "{": - hold["saw_brace"] = True - self.str_hold_contents(hold, append=c) else: self.str_hold_contents(hold, append=c) @@ -2187,11 +2128,11 @@ def tre_return_handle(loc, tokens): type_ignore=self.type_ignore_comment(), ) self.tre_func_name <<= base_keyword(func_name).suppress() - return attach( - self.tre_return, + return StartOfStrGrammar(attach( + self.tre_return_base, tre_return_handle, greedy=True, - ) + )) def detect_is_gen(self, raw_lines): """Determine if the given function code is for a generator.""" @@ -2364,9 +2305,10 @@ def proc_funcdef(self, original, loc, decorators, funcdef, is_async, in_method, def_stmt = raw_lines.pop(0) out = [] - # detect addpattern/copyclosure functions + # detect keyword functions addpattern = False copyclosure = False + typed_case_def = False done = False while not done: if def_stmt.startswith("addpattern "): @@ -2375,6 +2317,11 @@ def proc_funcdef(self, original, loc, decorators, funcdef, is_async, in_method, elif def_stmt.startswith("copyclosure "): def_stmt = assert_remove_prefix(def_stmt, "copyclosure ") copyclosure = True + elif def_stmt.startswith("case "): + def_stmt = assert_remove_prefix(def_stmt, "case ") + case_def_ref, def_stmt = def_stmt.split(unwrapper, 1) + type_param_code, all_type_defs = self.get_ref("case_def", case_def_ref) + typed_case_def = True elif def_stmt.startswith("def"): done = True else: @@ -2450,6 +2397,7 @@ def proc_funcdef(self, original, loc, decorators, funcdef, is_async, in_method, try: {addpattern_decorator} = _coconut_addpattern({func_name}) {type_ignore} except _coconut.NameError: + _coconut.warnings.warn("Deprecated use of 'addpattern def {func_name}' with no pre-existing '{func_name}' function (use 'match def {func_name}' for the first definition or switch to 'case def' syntax)", _coconut_CoconutWarning) {addpattern_decorator} = lambda f: f """, add_newline=True, @@ -2580,7 +2528,7 @@ def {mock_var}({mock_paramdef}): # assemble tre'd function comment, rest = split_leading_comments(func_code) - indent, base, dedent = split_leading_trailing_indent(rest, 1) + indent, base, dedent = split_leading_trailing_indent(rest, max_indents=1) base, base_dedent = split_trailing_indent(base) docstring, base = self.split_docstring(base) @@ -2614,6 +2562,29 @@ def {mock_var}({mock_paramdef}): if is_match_func: decorators += "@_coconut_mark_as_match\n" # binds most tightly + # handle typed case def functions (must happen before decorators are cleared out) + type_code = None + if typed_case_def: + internal_assert(len(all_type_defs) not in (0, 2), "invalid typed case def all_type_defs", all_type_defs) + if undotted_name is not None: + all_type_defs = [ + "def " + def_name + assert_remove_prefix(type_def, "def " + func_name) + for type_def in all_type_defs + ] + type_def_lines = [] + for i, type_def in enumerate(all_type_defs): + type_def_lines.append( + ("@_coconut.typing.overload\n" if i < len(all_type_defs) - 1 else "") + + decorators + + self.deferred_code_proc(type_def) + ) + if undotted_name is not None: + type_def_lines.append("{func_name} = {def_name}".format( + func_name=func_name, + def_name=def_name, + )) + type_code = self.deferred_code_proc(type_param_code) + "\n".join(type_def_lines) + # handle dotted function definition if undotted_name is not None: out.append( @@ -2645,7 +2616,7 @@ def {mock_var}({mock_paramdef}): out += [decorators, def_stmt, func_code] decorators = "" - # handle copyclosure functions + # handle copyclosure functions and type_code if copyclosure: vars_var = self.get_temp_var("func_vars", loc) func_from_vars = vars_var + '["' + def_name + '"]' @@ -2658,24 +2629,39 @@ def {mock_var}({mock_paramdef}): handle_indentation( ''' if _coconut.typing.TYPE_CHECKING: - {code} + {type_code} {vars_var} = {{"{def_name}": {def_name}}} else: {vars_var} = _coconut.globals().copy() {vars_var}.update(_coconut.locals()) _coconut_exec({code_str}, {vars_var}) {func_name} = {func_from_vars} - ''', + ''', add_newline=True, ).format( func_name=func_name, def_name=def_name, vars_var=vars_var, - code=code, + type_code=code if type_code is None else type_code, code_str=self.wrap_str_of(self.reformat_post_deferred_code_proc(code)), func_from_vars=func_from_vars, ), ] + elif type_code: + out = [ + handle_indentation( + ''' +if _coconut.typing.TYPE_CHECKING: + {type_code} +else: + {code} + ''', + add_newline=True, + ).format( + type_code=type_code, + code="".join(out), + ), + ] internal_assert(not decorators, "unhandled decorators", decorators) return "".join(out) @@ -2731,25 +2717,21 @@ def deferred_code_proc(self, inputstring, add_code_at_start=False, ignore_names= func_id = int(assert_remove_prefix(line, funcwrapper)) original, loc, decorators, funcdef, is_async, in_method, is_stmt_lambda = self.get_ref("func", func_id) - # process inner code + # process inner code (we use tempsep to tell what was newly added before the funcdef) decorators = self.deferred_code_proc(decorators, add_code_at_start=True, ignore_names=ignore_names, **kwargs) - funcdef = self.deferred_code_proc(funcdef, ignore_names=ignore_names, **kwargs) - - # handle any non-function code that was added before the funcdef - pre_def_lines = [] - post_def_lines = [] - funcdef_lines = list(literal_lines(funcdef, True)) - for i, line in enumerate(funcdef_lines): - if self.def_regex.match(line): - pre_def_lines = funcdef_lines[:i] - post_def_lines = funcdef_lines[i:] - break - internal_assert(post_def_lines, "no def statement found in funcdef", funcdef) - - out.append(bef_ind) - out.extend(pre_def_lines) - out.append(self.proc_funcdef(original, loc, decorators, "".join(post_def_lines), is_async, in_method, is_stmt_lambda)) - out.append(aft_ind) + raw_funcdef = self.deferred_code_proc(tempsep + funcdef, ignore_names=ignore_names, **kwargs) + + pre_funcdef, post_funcdef = raw_funcdef.split(tempsep) + func_indent, func_code, func_dedent = split_leading_trailing_indent(post_funcdef, symmetric=True) + + out += [ + bef_ind, + pre_funcdef, + func_indent, + self.proc_funcdef(original, loc, decorators, func_code, is_async, in_method, is_stmt_lambda), + func_dedent, + aft_ind, + ] # look for add_code_before regexes else: @@ -2804,7 +2786,7 @@ def polish(self, inputstring, final_endline=True, **kwargs): # HANDLERS: # ----------------------------------------------------------------------------------------------------------------------- - def split_function_call(self, tokens, loc): + def split_function_call(self, original, loc, tokens): """Split into positional arguments and keyword arguments.""" pos_args = [] star_args = [] @@ -2827,7 +2809,10 @@ def split_function_call(self, tokens, loc): star_args.append(argstr) elif arg[0] == "**": dubstar_args.append(argstr) + elif arg[1] == "=": + kwd_args.append(arg[0] + "=" + arg[0]) elif arg[0] == "...": + self.strict_err_or_warn("'...={name}' shorthand is deprecated, use '{name}=' shorthand instead".format(name=arg[1]), original, loc) kwd_args.append(arg[1] + "=" + arg[1]) else: kwd_args.append(argstr) @@ -2843,9 +2828,9 @@ def split_function_call(self, tokens, loc): return pos_args, star_args, kwd_args, dubstar_args - def function_call_handle(self, loc, tokens): + def function_call_handle(self, original, loc, tokens): """Enforce properly ordered function parameters.""" - return "(" + join_args(*self.split_function_call(tokens, loc)) + ")" + return "(" + join_args(*self.split_function_call(original, loc, tokens)) + ")" def pipe_item_split(self, original, loc, tokens): """Process a pipe item, which could be a partial, an attribute access, a method call, or an expression. @@ -2866,7 +2851,7 @@ def pipe_item_split(self, original, loc, tokens): return "expr", tokens elif "partial" in tokens: func, args = tokens - pos_args, star_args, kwd_args, dubstar_args = self.split_function_call(args, loc) + pos_args, star_args, kwd_args, dubstar_args = self.split_function_call(original, loc, args) return "partial", (func, join_args(pos_args, star_args), join_args(kwd_args, dubstar_args)) elif "attrgetter" in tokens: name, args = attrgetter_atom_split(tokens) @@ -3012,17 +2997,17 @@ def pipe_handle(self, original, loc, tokens, **kwargs): raise CoconutDeferredSyntaxError("cannot star pipe into operator partial", loc) op, arg = split_item return "({op})({x}, {arg})".format(op=op, x=subexpr, arg=arg) + elif name == "await": + internal_assert(not split_item, "invalid split await pipe item tokens", split_item) + if stars: + raise CoconutDeferredSyntaxError("cannot star pipe into await", loc) + return self.await_expr_handle(original, loc, [subexpr]) elif name == "right arr concat partial": if stars: raise CoconutDeferredSyntaxError("cannot star pipe into array concatenation operator partial", loc) op, arg = split_item internal_assert(op.lstrip(";") == "", "invalid arr concat op", op) return "_coconut_arr_concat_op({dim}, {x}, {arg})".format(dim=len(op), x=subexpr, arg=arg) - elif name == "await": - internal_assert(not split_item, "invalid split await pipe item tokens", split_item) - if stars: - raise CoconutDeferredSyntaxError("cannot star pipe into await", loc) - return self.await_expr_handle(original, loc, [subexpr]) elif name == "namedexpr": if stars: raise CoconutDeferredSyntaxError("cannot star pipe into named expression partial", loc) @@ -3086,7 +3071,7 @@ def item_handle(self, original, loc, tokens): elif trailer[0] == "$[": out = "_coconut_iter_getitem(" + out + ", " + trailer[1] + ")" elif trailer[0] == "$(?": - pos_args, star_args, base_kwd_args, dubstar_args = self.split_function_call(trailer[1], loc) + pos_args, star_args, base_kwd_args, dubstar_args = self.split_function_call(original, loc, trailer[1]) has_question_mark = False needs_complex_partial = False @@ -3243,12 +3228,21 @@ def classdef_handle(self, original, loc, tokens): """Process class definitions.""" decorators, name, paramdefs, classlist_toks, body = tokens - out = "".join(paramdefs) + decorators + "class " + name + out = "" + + # paramdefs are type params on >= 3.12 and type var assignments on < 3.12 + if paramdefs: + if self.target_info >= (3, 12): + name += "[" + ", ".join(paramdefs) + "]" + else: + out += "".join(paramdefs) + + out += decorators + "class " + name # handle classlist base_classes = [] if classlist_toks: - pos_args, star_args, kwd_args, dubstar_args = self.split_function_call(classlist_toks, loc) + pos_args, star_args, kwd_args, dubstar_args = self.split_function_call(original, loc, classlist_toks) # check for just inheriting from object if ( @@ -3273,7 +3267,7 @@ def classdef_handle(self, original, loc, tokens): base_classes.append(join_args(pos_args, star_args, kwd_args, dubstar_args)) - if paramdefs: + if paramdefs and self.target_info < (3, 12): base_classes.append(self.get_generic_for_typevars()) if not classlist_toks and not self.target.startswith("3"): @@ -3307,16 +3301,7 @@ def match_datadef_handle(self, original, loc, tokens): check_var = self.get_temp_var("match_check", loc) matcher = self.get_matcher(original, loc, check_var, name_list=[]) - - pos_only_args, req_args, default_args, star_arg, kwd_only_args, dubstar_arg = split_args_list(matches, loc) - matcher.match_function( - pos_only_match_args=pos_only_args, - match_args=req_args + default_args, - star_arg=star_arg, - kwd_only_match_args=kwd_only_args, - dubstar_arg=dubstar_arg, - ) - + matcher.match_function_toks(matches) if cond is not None: matcher.add_guard(cond) @@ -3515,8 +3500,14 @@ def assemble_data(self, decorators, name, namedtuple_call, inherit, extra_stmts, definition of Expected in header.py_template. """ # create class - out = [ - "".join(paramdefs), + out = [] + if paramdefs: + # paramdefs are type params on >= 3.12 and type var assignments on < 3.12 + if self.target_info >= (3, 12): + name += "[" + ", ".join(paramdefs) + "]" + else: + out += ["".join(paramdefs)] + out += [ decorators, "class ", name, @@ -3525,7 +3516,7 @@ def assemble_data(self, decorators, name, namedtuple_call, inherit, extra_stmts, ] if inherit is not None: out += [", ", inherit] - if paramdefs: + if paramdefs and self.target_info < (3, 12): out += [", ", self.get_generic_for_typevars()] if not self.target.startswith("3"): out.append(", _coconut.object") @@ -3547,7 +3538,7 @@ def __rmul__(self, other): return _coconut.NotImplemented def __eq__(self, other): return self.__class__ is other.__class__ and _coconut.tuple.__eq__(self, other) def __hash__(self): - return _coconut.tuple.__hash__(self) ^ hash(self.__class__) + return _coconut.tuple.__hash__(self) ^ _coconut.hash(self.__class__) """, add_newline=True, ).format( @@ -3585,7 +3576,7 @@ def __hash__(self): return "".join(out) - def anon_namedtuple_handle(self, tokens): + def anon_namedtuple_handle(self, original, loc, tokens): """Handle anonymous named tuples.""" names = [] types = {} @@ -3598,7 +3589,10 @@ def anon_namedtuple_handle(self, tokens): types[i] = typedef else: raise CoconutInternalException("invalid anonymous named item", tok) - if name == "...": + if item == "=": + item = name + elif name == "...": + self.strict_err_or_warn("'...={item}' shorthand is deprecated, use '{item}=' shorthand instead".format(item=item), original, loc) name = item names.append(name) items.append(item) @@ -3630,7 +3624,7 @@ def single_import(self, loc, path, imp_as, type_ignore=False): fake_mods = imp_as.split(".") for i in range(1, len(fake_mods)): mod_name = ".".join(fake_mods[:i]) - out.extend(( + out += [ "try:", openindent + mod_name, closeindent + "except:", @@ -3638,7 +3632,7 @@ def single_import(self, loc, path, imp_as, type_ignore=False): closeindent + "else:", openindent + "if not _coconut.isinstance(" + mod_name + ", _coconut.types.ModuleType):", openindent + mod_name + ' = _coconut.types.ModuleType(_coconut_py_str("' + mod_name + '"))' + closeindent * 2, - )) + ] out.append(".".join(fake_mods) + " = " + import_as_var) else: out.append(import_stmt(imp_from, imp, imp_as)) @@ -3848,16 +3842,7 @@ def name_match_funcdef_handle(self, original, loc, tokens): check_var = self.get_temp_var("match_check", loc) matcher = self.get_matcher(original, loc, check_var) - - pos_only_args, req_args, default_args, star_arg, kwd_only_args, dubstar_arg = split_args_list(matches, loc) - matcher.match_function( - pos_only_match_args=pos_only_args, - match_args=req_args + default_args, - star_arg=star_arg, - kwd_only_match_args=kwd_only_args, - dubstar_arg=dubstar_arg, - ) - + matcher.match_function_toks(matches) if cond is not None: matcher.add_guard(cond) @@ -3888,6 +3873,110 @@ def op_match_funcdef_handle(self, original, loc, tokens): name_tokens.append(cond) return self.name_match_funcdef_handle(original, loc, name_tokens) + def base_case_funcdef_handle(self, original, loc, tokens): + """Process case def function definitions.""" + if len(tokens) == 2: + name_toks, cases = tokens + docstring = None + elif len(tokens) == 3: + name_toks, docstring, cases = tokens + else: + raise CoconutInternalException("invalid case function definition tokens", tokens) + + type_param_code = "" + if len(name_toks) == 1: + name, = name_toks + else: + name, paramdefs = name_toks + # paramdefs are type params on >= 3.12 and type var assignments on < 3.12 + if self.target_info >= (3, 12): + name += "[" + ", ".join(paramdefs) + "]" + else: + type_param_code = "".join(paramdefs) + + check_var = self.get_temp_var("match_check", loc) + + all_case_code = [] + all_type_defs = [] + for case_toks in cases: + if "match" in case_toks: + if len(case_toks) == 2: + matches, body = case_toks + cond = None + else: + matches, cond, body = case_toks + matcher = self.get_matcher(original, loc, check_var) + matcher.match_function_toks(matches, include_setup=False) + if cond is not None: + matcher.add_guard(cond) + all_case_code.append(handle_indentation(""" +if not {check_var}: + {match_to_kwargs_var} = {match_to_kwargs_var}_store.copy() + {match_out} + if {check_var}: + {body} + """).format( + check_var=check_var, + match_to_kwargs_var=match_to_kwargs_var, + match_out=matcher.out(), + body=body, + )) + elif "type" in case_toks: + typed_params, typed_ret = case_toks + all_type_defs.append(handle_indentation(""" +def {name}{typed_params}{typed_ret} + {docstring} + return {ellipsis} + """).format( + name=name, + typed_params=typed_params, + typed_ret=typed_ret, + docstring=docstring if docstring is not None else "", + ellipsis=self.any_type_ellipsis(), + )) + else: + raise CoconutInternalException("invalid case_funcdef case_toks", case_toks) + + if not all_case_code: + raise CoconutDeferredSyntaxError("case def with no case patterns", loc) + if type_param_code and not all_type_defs: + raise CoconutDeferredSyntaxError("type parameters in case def but no type cases", loc) + + if len(all_type_defs) > 1: + all_type_defs.append(handle_indentation(""" +def {name}(*_coconut_args, **_coconut_kwargs): + {docstring} + return {ellipsis} + """).format( + name=name, + docstring=docstring if docstring is not None else "", + ellipsis=self.any_type_ellipsis(), + )) + + func_code = handle_indentation(""" +def {name}({match_func_paramdef}): + {docstring} + {check_var} = False + {setup_code} + {match_to_kwargs_var}_store = {match_to_kwargs_var} + {all_case_code} + {error} + """).format( + name=name, + match_func_paramdef=match_func_paramdef, + docstring=docstring if docstring is not None else "", + check_var=check_var, + setup_code=match_funcdef_setup_code(), + match_to_kwargs_var=match_to_kwargs_var, + all_case_code="\n".join(all_case_code), + error=self.pattern_error(original, loc, match_to_args_var, check_var, function_match_error_var), + ) + + if not (type_param_code or all_type_defs): + return func_code + + return "case " + self.add_ref("case_def", (type_param_code, all_type_defs)) + unwrapper + func_code + def set_literal_handle(self, tokens): """Converts set literals to the right form for the target Python.""" internal_assert(len(tokens) == 1 and len(tokens[0]) == 1, "invalid set literal tokens", tokens) @@ -4137,8 +4226,7 @@ def typed_assign_stmt_handle(self, tokens): ).format( name=name, value=( - value if value is not None - else "_coconut.typing.cast(_coconut.typing.Any, {ellipsis})".format(ellipsis=self.ellipsis_handle()) + value if value is not None else self.any_type_ellipsis() ), comment=self.wrap_type_comment(typedef), annotation=self.wrap_typedef(typedef, for_py_typedef=False, duplicate=True), @@ -4165,6 +4253,10 @@ def ellipsis_handle(self, tokens=None): ellipsis_handle.ignore_arguments = True + def any_type_ellipsis(self): + """Get an ellipsis cast to Any type.""" + return "_coconut.typing.cast(_coconut.typing.Any, {ellipsis})".format(ellipsis=self.ellipsis_handle()) + def match_case_tokens(self, match_var, check_var, original, tokens, top): """Build code for matching the given case.""" if len(tokens) == 3: @@ -4191,9 +4283,20 @@ def cases_stmt_handle(self, original, loc, tokens): else: raise CoconutInternalException("invalid case tokens", tokens) - self.internal_assert(block_kwd in ("cases", "case", "match"), original, loc, "invalid case statement keyword", block_kwd) if block_kwd == "case": - self.strict_err_or_warn("deprecated case keyword at top level in case ...: match ...: block (use Python 3.10 match ...: case ...: syntax instead)", original, loc) + self.strict_err_or_warn( + "deprecated case keyword at top level in case ...: match ...: block (use Python 3.10 match ...: case ...: syntax instead)", + original, + loc, + ) + elif block_kwd == "cases": + self.syntax_warning( + "deprecated cases keyword at top level in cases ...: match ...: block (use Python 3.10 match ...: case ...: syntax instead)", + original, + loc, + ) + else: + self.internal_assert(block_kwd == "match", original, loc, "invalid case statement keyword", block_kwd) check_var = self.get_temp_var("case_match_check", loc) match_var = self.get_temp_var("case_match_to", loc) @@ -4233,11 +4336,11 @@ def f_string_handle(self, original, loc, tokens): # handle Python 3.8 f string = specifier for i, expr in enumerate(exprs): - if expr.endswith("="): + expr_rstrip = expr.rstrip() + if expr_rstrip.endswith("="): before = string_parts[i] - internal_assert(before[-1] == "{", "invalid format string split", (string_parts, exprs)) - string_parts[i] = before[:-1] + expr + "{" - exprs[i] = expr[:-1] + string_parts[i] = assert_remove_suffix(before, "{") + expr + "{" + exprs[i] = assert_remove_suffix(expr_rstrip, "=") # compile Coconut expressions compiled_exprs = [] @@ -4465,15 +4568,21 @@ def async_with_for_stmt_handle(self, original, loc, tokens): loop=loop ) - def string_atom_handle(self, tokens): + def string_atom_handle(self, original, loc, tokens, allow_silent_concat=False): """Handle concatenation of string literals.""" internal_assert(len(tokens) >= 1, "invalid string literal tokens", tokens) - if any(s.endswith(")") for s in tokens): # has .format() calls - return "(" + " + ".join(tokens) + ")" - elif any(s.startswith(("f", "rf")) for s in tokens): # has f-strings - return " ".join(tokens) + if len(tokens) == 1: + return tokens[0] else: - return self.eval_now(" ".join(tokens)) + if not allow_silent_concat: + self.strict_err_or_warn("found Python-style implicit string concatenation (use explicit '+' instead)", original, loc) + if any(s.endswith(")") for s in tokens): # has .format() calls + # parens are necessary for string_atom_handle + return "(" + " + ".join(tokens) + ")" + elif any(s.startswith(("f", "rf")) for s in tokens): # has f-strings + return " ".join(tokens) + else: + return self.eval_now(" ".join(tokens)) string_atom_handle.ignore_one_token = True @@ -4493,6 +4602,8 @@ def term_handle(self, tokens): out += [op, term] return " ".join(out) + term_handle.ignore_one_token = True + def impl_call_handle(self, loc, tokens): """Process implicit function application or coefficient syntax.""" internal_assert(len(tokens) >= 2, "invalid implicit call / coefficient tokens", tokens) @@ -4572,15 +4683,21 @@ def funcname_typeparams_handle(self, tokens): return name else: name, paramdefs = tokens - return self.add_code_before_marker_with_replacement(name, "".join(paramdefs), add_spaces=False) + # paramdefs are type params on >= 3.12 and type var assignments on < 3.12 + if self.target_info >= (3, 12): + return name + "[" + ", ".join(paramdefs) + "]" + else: + return self.add_code_before_marker_with_replacement(name, "".join(paramdefs), add_spaces=False) funcname_typeparams_handle.ignore_one_token = True def type_param_handle(self, original, loc, tokens): """Compile a type param into an assignment.""" args = "" + raw_bound = None bound_op = None bound_op_type = "" + stars = "" if "TypeVar" in tokens: TypeVarFunc = "TypeVar" bound_op_type = "bound" @@ -4588,18 +4705,24 @@ def type_param_handle(self, original, loc, tokens): name_loc, name = tokens else: name_loc, name, bound_op, bound = tokens + # raw_bound is for >=3.12, so it is for_py_typedef, but args is for <3.12, so it isn't + raw_bound = self.wrap_typedef(bound, for_py_typedef=True) args = ", bound=" + self.wrap_typedef(bound, for_py_typedef=False) elif "TypeVar constraint" in tokens: TypeVarFunc = "TypeVar" bound_op_type = "constraint" name_loc, name, bound_op, constraints = tokens + # for_py_typedef is different in the two cases here as above + raw_bound = ", ".join(self.wrap_typedef(c, for_py_typedef=True) for c in constraints) args = ", " + ", ".join(self.wrap_typedef(c, for_py_typedef=False) for c in constraints) elif "TypeVarTuple" in tokens: TypeVarFunc = "TypeVarTuple" name_loc, name = tokens + stars = "*" elif "ParamSpec" in tokens: TypeVarFunc = "ParamSpec" name_loc, name = tokens + stars = "**" else: raise CoconutInternalException("invalid type_param tokens", tokens) @@ -4620,8 +4743,14 @@ def type_param_handle(self, original, loc, tokens): loc, ) + # on >= 3.12, return a type param + if self.target_info >= (3, 12): + return stars + name + (": " + raw_bound if raw_bound is not None else "") + + # on < 3.12, return a type variable assignment + kwargs = "" - # uncomment these lines whenever mypy adds support for infer_variance in TypeVar + # TODO: uncomment these lines whenever mypy adds support for infer_variance in TypeVar # (and remove the warning about it in the DOCS) # if TypeVarFunc == "TypeVar": # kwargs += ", infer_variance=True" @@ -4652,6 +4781,7 @@ def type_param_handle(self, original, loc, tokens): def get_generic_for_typevars(self): """Get the Generic instances for the current typevars.""" + internal_assert(self.target_info < (3, 12), "get_generic_for_typevars should only be used on targets < 3.12") typevar_info = self.current_parsing_context("typevars") internal_assert(typevar_info is not None, "get_generic_for_typevars called with no typevars") generics = [] @@ -4685,16 +4815,18 @@ def type_alias_stmt_handle(self, tokens): paramdefs = () else: name, paramdefs, typedef = tokens - out = "".join(paramdefs) + + # paramdefs are type params on >= 3.12 and type var assignments on < 3.12 if self.target_info >= (3, 12): - out += "type " + name + " = " + self.wrap_typedef(typedef, for_py_typedef=True) + if paramdefs: + name += "[" + ", ".join(paramdefs) + "]" + return "type " + name + " = " + self.wrap_typedef(typedef, for_py_typedef=True) else: - out += self.typed_assign_stmt_handle([ + return "".join(paramdefs) + self.typed_assign_stmt_handle([ name, "_coconut.typing.TypeAlias", self.wrap_typedef(typedef, for_py_typedef=False), ]) - return out def where_item_handle(self, tokens): """Manage where items.""" @@ -5077,7 +5209,7 @@ def warm_up(self, streamline=False, enable_incremental_mode=False, set_debug_nam self.streamline(self.file_parser, force=streamline) self.streamline(self.eval_parser, force=streamline) if enable_incremental_mode: - enable_incremental_parsing() + enable_incremental_parsing(reason="explicit warm_up call") # end: ENDPOINTS diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py index 967930699..7e5bd8e6e 100644 --- a/coconut/compiler/grammar.py +++ b/coconut/compiler/grammar.py @@ -38,9 +38,7 @@ Literal, OneOrMore, Optional, - ParserElement, StringEnd, - StringStart, Word, ZeroOrMore, hexnums, @@ -48,7 +46,6 @@ originalTextFor, nestedExpr, FollowedBy, - python_quoted_string, restOfLine, ) @@ -119,6 +116,7 @@ using_fast_grammar_methods, disambiguate_literal, any_of, + StartOfStrGrammar, ) @@ -183,6 +181,75 @@ def pipe_info(op): return direction, stars, none_aware +def split_args_list(tokens, loc): + """Splits function definition arguments.""" + pos_only_args = [] + req_args = [] + default_args = [] + star_arg = None + kwd_only_args = [] + dubstar_arg = None + pos = 0 + for arg in tokens: + # only the first two components matter; if there's a third it's a typedef + arg = arg[:2] + + if len(arg) == 1: + if arg[0] == "*": + # star sep (pos = 2) + if pos >= 2: + raise CoconutDeferredSyntaxError("star separator at invalid position in function definition", loc) + pos = 2 + elif arg[0] == "/": + # slash sep (pos = 0) + if pos > 0: + raise CoconutDeferredSyntaxError("slash separator at invalid position in function definition", loc) + if pos_only_args: + raise CoconutDeferredSyntaxError("only one slash separator allowed in function definition", loc) + if not req_args: + raise CoconutDeferredSyntaxError("slash separator must come after arguments to mark as positional-only", loc) + pos_only_args = req_args + req_args = [] + else: + # pos arg (pos = 0) + if pos == 0: + req_args.append(arg[0]) + # kwd only arg (pos = 2) + elif pos == 2: + kwd_only_args.append((arg[0], None)) + else: + raise CoconutDeferredSyntaxError("non-default arguments must come first or after star argument/separator", loc) + + else: + internal_assert(arg[1] is not None, "invalid arg[1] in split_args_list", arg) + + if arg[0] == "*": + # star arg (pos = 2) + if pos >= 2: + raise CoconutDeferredSyntaxError("star argument at invalid position in function definition", loc) + pos = 2 + star_arg = arg[1] + elif arg[0] == "**": + # dub star arg (pos = 3) + if pos == 3: + raise CoconutDeferredSyntaxError("double star argument at invalid position in function definition", loc) + pos = 3 + dubstar_arg = arg[1] + else: + # def arg (pos = 1) + if pos <= 1: + pos = 1 + default_args.append((arg[0], arg[1])) + # kwd only arg (pos = 2) + elif pos <= 2: + pos = 2 + kwd_only_args.append((arg[0], arg[1])) + else: + raise CoconutDeferredSyntaxError("invalid default argument in function definition", loc) + + return pos_only_args, req_args, default_args, star_arg, kwd_only_args, dubstar_arg + + # end: HELPERS # ----------------------------------------------------------------------------------------------------------------------- # HANDLERS: @@ -523,15 +590,6 @@ def join_match_funcdef(tokens): ) -def kwd_err_msg_handle(tokens): - """Handle keyword parse error messages.""" - kwd, = tokens - if kwd == "def": - return "invalid function definition" - else: - return 'invalid use of the keyword "' + kwd + '"' - - def alt_ternary_handle(tokens): """Handle if ... then ... else ternary operator.""" cond, if_true, if_false = tokens @@ -864,7 +922,6 @@ class Grammar(object): # rparen handles simple stmts ending parenthesized stmt lambdas end_simple_stmt_item = FollowedBy(newline | semicolon | rparen) - start_marker = StringStart() moduledoc_marker = condense(ZeroOrMore(lineitem) - Optional(moduledoc_item)) end_marker = StringEnd() indent = Literal(openindent) @@ -1180,11 +1237,12 @@ class Grammar(object): call_item = ( unsafe_name + default - # ellipsis must come before namedexpr_test - | ellipsis_tokens + equals.suppress() + refname - | namedexpr_test | star + test | dubstar + test + | refname + equals # new long name ellision syntax + | ellipsis_tokens + equals.suppress() + refname # old long name ellision syntax + # must come at end + | namedexpr_test ) function_call_tokens = lparen.suppress() + ( # everything here must end with rparen @@ -1234,7 +1292,7 @@ class Grammar(object): maybe_typedef = Optional(colon.suppress() + typedef_test) anon_namedtuple_ref = tokenlist( Group( - unsafe_name + maybe_typedef + equals.suppress() + test + unsafe_name + maybe_typedef + (equals.suppress() + test | equals) | ellipsis_tokens + maybe_typedef + equals.suppress() + refname ), comma, @@ -1512,14 +1570,12 @@ class Grammar(object): # arith_expr = exprlist(term, addop) # shift_expr = exprlist(arith_expr, shift) # and_expr = exprlist(shift_expr, amp) - and_expr = exprlist( - term, - any_of( - addop, - shift, - amp, - ), + term_op = any_of( + addop, + shift, + amp, ) + and_expr = exprlist(term, term_op) protocol_intersect_expr = Forward() protocol_intersect_expr_ref = tokenlist(and_expr, amp_colon, allow_trailing=False) @@ -2182,11 +2238,12 @@ class Grammar(object): with_stmt = Forward() funcname_typeparams = Forward() - funcname_typeparams_ref = dotted_setname + Optional(type_params) + funcname_typeparams_tokens = dotted_setname + Optional(type_params) name_funcdef = condense(funcname_typeparams + parameters) op_tfpdef = unsafe_typedef_default | condense(setname + Optional(default)) op_funcdef_arg = setname | condense(lparen.suppress() + op_tfpdef + rparen.suppress()) op_funcdef_name = unsafe_backtick.suppress() + funcname_typeparams + unsafe_backtick.suppress() + op_funcdef_name_tokens = unsafe_backtick.suppress() + funcname_typeparams_tokens + unsafe_backtick.suppress() op_funcdef = attach( Group(Optional(op_funcdef_arg)) + op_funcdef_name @@ -2274,44 +2331,108 @@ class Grammar(object): base_match_funcdef + end_func_equals - ( - attach(implicit_return_stmt, make_suite_handle) - | ( + ( newline.suppress() - indent.suppress() - Optional(docstring) - attach(math_funcdef_body, make_suite_handle) - dedent.suppress() ) + | attach(implicit_return_stmt, make_suite_handle) ), join_match_funcdef, ) ) - async_stmt = Forward() - async_with_for_stmt = Forward() - async_with_for_stmt_ref = ( - labeled_group( - (keyword("async") + keyword("with") + keyword("for")).suppress() - + assignlist + keyword("in").suppress() - - test - - suite_with_else_tokens, - "normal", + base_case_funcdef = Forward() + base_case_funcdef_ref = ( + keyword("def").suppress() + + Group( + funcname_typeparams_tokens + | op_funcdef_name_tokens ) - | labeled_group( - (any_len_perm( - keyword("match"), - required=(keyword("async"), keyword("with")), - ) + keyword("for")).suppress() - + many_match + keyword("in").suppress() - - test - - suite_with_else_tokens, - "match", + + colon.suppress() + - newline.suppress() + - indent.suppress() + - Optional(docstring) + - Group(OneOrMore( + labeled_group( + keyword("case").suppress() + + lparen.suppress() + + match_args_list + + match_guard + + rparen.suppress() + + ( + colon.suppress() + + ( + newline.suppress() + + indent.suppress() + + attach(condense(OneOrMore(stmt)), make_suite_handle) + + dedent.suppress() + | attach(simple_stmt, make_suite_handle) + ) + | equals.suppress() + + ( + ( + newline.suppress() + + indent.suppress() + + attach(math_funcdef_body, make_suite_handle) + + dedent.suppress() + ) + | attach(implicit_return_stmt, make_suite_handle) + ) + ), + "match", + ) + | labeled_group( + keyword("type").suppress() + + parameters + + return_typedef + + newline.suppress(), + "type", + ) + )) + - dedent.suppress() + ) + case_funcdef = keyword("case").suppress() + base_case_funcdef + + keyword_normal_funcdef = Group( + any_len_perm_at_least_one( + keyword("yield"), + keyword("copyclosure"), ) + ) + (funcdef | math_funcdef) + keyword_match_funcdef = Group( + any_len_perm_at_least_one( + keyword("yield"), + keyword("copyclosure"), + keyword("match").suppress(), + # addpattern is detected later + keyword("addpattern"), + ) + ) + (def_match_funcdef | math_match_funcdef) + keyword_case_funcdef = Group( + any_len_perm_at_least_one( + keyword("yield"), + keyword("copyclosure"), + required=(keyword("case").suppress(),), + ) + ) + base_case_funcdef + keyword_funcdef = Forward() + keyword_funcdef_ref = ( + keyword_normal_funcdef + | keyword_match_funcdef + | keyword_case_funcdef ) - async_stmt_ref = addspace( - keyword("async") + (with_stmt | any_for_stmt) # handles async [match] for - | keyword("match").suppress() + keyword("async") + base_match_for_stmt # handles match async for - | async_with_for_stmt + + normal_funcdef_stmt = ( + # match funcdefs must come after normal + funcdef + | math_funcdef + | match_funcdef + | math_match_funcdef + | case_funcdef + | keyword_funcdef ) async_funcdef = keyword("async").suppress() + (funcdef | math_funcdef) @@ -2321,7 +2442,15 @@ class Grammar(object): # addpattern is detected later keyword("addpattern"), required=(keyword("async").suppress(),), - ) + (def_match_funcdef | math_match_funcdef), + ) + (def_match_funcdef | math_match_funcdef) + ) + async_case_funcdef = addspace( + any_len_perm( + required=( + keyword("case").suppress(), + keyword("async").suppress(), + ), + ) + base_case_funcdef ) async_keyword_normal_funcdef = Group( @@ -2341,41 +2470,56 @@ class Grammar(object): required=(keyword("async").suppress(),), ) ) + (def_match_funcdef | math_match_funcdef) + async_keyword_case_funcdef = Group( + any_len_perm_at_least_one( + keyword("yield"), + keyword("copyclosure"), + required=( + keyword("async").suppress(), + keyword("case").suppress(), + ), + ) + ) + base_case_funcdef async_keyword_funcdef = Forward() - async_keyword_funcdef_ref = async_keyword_normal_funcdef | async_keyword_match_funcdef + async_keyword_funcdef_ref = ( + async_keyword_normal_funcdef + | async_keyword_match_funcdef + | async_keyword_case_funcdef + ) async_funcdef_stmt = ( # match funcdefs must come after normal async_funcdef | async_match_funcdef + | async_case_funcdef | async_keyword_funcdef ) - keyword_normal_funcdef = Group( - any_len_perm_at_least_one( - keyword("yield"), - keyword("copyclosure"), + async_stmt = Forward() + async_with_for_stmt = Forward() + async_with_for_stmt_ref = ( + labeled_group( + (keyword("async") + keyword("with") + keyword("for")).suppress() + + assignlist + keyword("in").suppress() + - test + - suite_with_else_tokens, + "normal", ) - ) + (funcdef | math_funcdef) - keyword_match_funcdef = Group( - any_len_perm_at_least_one( - keyword("yield"), - keyword("copyclosure"), - keyword("match").suppress(), - # addpattern is detected later - keyword("addpattern"), + | labeled_group( + (any_len_perm( + keyword("match"), + required=(keyword("async"), keyword("with")), + ) + keyword("for")).suppress() + + many_match + keyword("in").suppress() + - test + - suite_with_else_tokens, + "match", ) - ) + (def_match_funcdef | math_match_funcdef) - keyword_funcdef = Forward() - keyword_funcdef_ref = keyword_normal_funcdef | keyword_match_funcdef - - normal_funcdef_stmt = ( - # match funcdefs must come after normal - funcdef - | math_funcdef - | match_funcdef - | math_match_funcdef - | keyword_funcdef + ) + async_stmt_ref = addspace( + keyword("async") + (with_stmt | any_for_stmt) # handles async [match] for + | keyword("match").suppress() + keyword("async") + base_match_for_stmt # handles match async for + | async_with_for_stmt ) datadef = Forward() @@ -2522,19 +2666,19 @@ class Grammar(object): line = newline | stmt file_input = condense(moduledoc_marker - ZeroOrMore(line)) - raw_file_parser = start_marker - file_input - end_marker + raw_file_parser = StartOfStrGrammar(file_input - end_marker) line_by_line_file_parser = ( - start_marker - moduledoc_marker - stores_loc_item, - start_marker - line - stores_loc_item, + StartOfStrGrammar(moduledoc_marker - stores_loc_item), + StartOfStrGrammar(line - stores_loc_item), ) file_parser = line_by_line_file_parser if USE_LINE_BY_LINE else raw_file_parser single_input = condense(Optional(line) - ZeroOrMore(newline)) eval_input = condense(testlist - ZeroOrMore(newline)) - single_parser = start_marker - single_input - end_marker - eval_parser = start_marker - eval_input - end_marker - some_eval_parser = start_marker + eval_input + single_parser = StartOfStrGrammar(single_input - end_marker) + eval_parser = StartOfStrGrammar(eval_input - end_marker) + some_eval_parser = StartOfStrGrammar(eval_input) parens = originalTextFor(nestedExpr("(", ")", ignoreExpr=None)) brackets = originalTextFor(nestedExpr("[", "]", ignoreExpr=None)) @@ -2552,15 +2696,16 @@ class Grammar(object): ) ) unsafe_xonsh_parser, _impl_call_ref = disable_inside( - single_parser, + single_input - end_marker, unsafe_impl_call_ref, ) impl_call_ref <<= _impl_call_ref - xonsh_parser, _anything_stmt, _xonsh_command = disable_outside( + _xonsh_parser, _anything_stmt, _xonsh_command = disable_outside( unsafe_xonsh_parser, unsafe_anything_stmt, unsafe_xonsh_command, ) + xonsh_parser = StartOfStrGrammar(_xonsh_parser) anything_stmt <<= _anything_stmt xonsh_command <<= _xonsh_command @@ -2574,7 +2719,8 @@ class Grammar(object): whitespace_regex = compile_regex(r"\s") - def_regex = compile_regex(r"\b((async|addpattern|copyclosure)\s+)*def\b") + def_regex = compile_regex(r"((async|addpattern|copyclosure)\s+)*def\b") + yield_regex = compile_regex(r"\byield(?!\s+_coconut\.asyncio\.From)\b") yield_from_regex = compile_regex(r"\byield\s+from\b") @@ -2583,7 +2729,7 @@ class Grammar(object): noqa_regex = compile_regex(r"\b[Nn][Oo][Qq][Aa]\b") - just_non_none_atom = start_marker + ~keyword("None") + known_atom + end_marker + just_non_none_atom = StartOfStrGrammar(~keyword("None") + known_atom + end_marker) original_function_call_tokens = ( lparen.suppress() + rparen.suppress() @@ -2593,9 +2739,8 @@ class Grammar(object): ) tre_func_name = Forward() - tre_return = ( - start_marker - + keyword("return").suppress() + tre_return_base = ( + keyword("return").suppress() + maybeparens( lparen, tre_func_name + original_function_call_tokens, @@ -2603,9 +2748,8 @@ class Grammar(object): ) + end_marker ) - tco_return = attach( - start_marker - + keyword("return").suppress() + tco_return = StartOfStrGrammar(attach( + keyword("return").suppress() + maybeparens( lparen, disallow_keywords(untcoable_funcs, with_suffix="(") @@ -2630,7 +2774,7 @@ class Grammar(object): tco_return_handle, # this is the root in what it's used for, so might as well evaluate greedily greedy=True, - ) + )) rest_of_lambda = Forward() lambdas = keyword("lambda") - rest_of_lambda - colon @@ -2670,9 +2814,8 @@ class Grammar(object): )) ) - split_func = ( - start_marker - - keyword("def").suppress() + split_func = StartOfStrGrammar( + keyword("def").suppress() - unsafe_dotted_name - Optional(brackets).suppress() - lparen.suppress() @@ -2686,13 +2829,13 @@ class Grammar(object): | ~indent + ~dedent + any_char + keyword("for") + unsafe_name + keyword("in") ) - just_a_string = start_marker + string_atom + end_marker + just_a_string = StartOfStrGrammar(string_atom + end_marker) end_of_line = end_marker | Literal("\n") | pound unsafe_equals = Literal("=") - parse_err_msg = start_marker + ( + parse_err_msg = StartOfStrGrammar( # should be in order of most likely to actually be the source of the error first fixto( ZeroOrMore(~questionmark + ~Literal("\n") + any_char) @@ -2704,29 +2847,38 @@ class Grammar(object): "misplaced '?' (naked '?' is only supported inside partial application arguments)", ) | fixto(Optional(keyword("if") + skip_to_in_line(unsafe_equals)) + equals, "misplaced assignment (maybe should be '==')") - | attach(any_keyword_in(keyword_vars + reserved_vars), kwd_err_msg_handle) + | fixto(keyword("def"), "invalid function definition") | fixto(end_of_line, "misplaced newline (maybe missing ':')") ) start_f_str_regex = compile_regex(r"\br?fr?$") start_f_str_regex_len = 4 - end_f_str_expr = combine(start_marker + (rbrace | colon | bang)) + end_f_str_expr = StartOfStrGrammar(combine(rbrace | colon | bang).leaveWhitespace()) + + python_quoted_string = regex_item( + # multiline strings must come first + r'"""(?:[^"\\]|\n|""(?!")|"(?!"")|\\.)*"""' + r"|'''(?:[^'\\]|\n|''(?!')|'(?!'')|\\.)*'''" + r'|"(?:[^"\n\r\\]|(?:\\")|(?:\\(?:[^x]|x[0-9a-fA-F]+)))*"' + r"|'(?:[^'\n\r\\]|(?:\\')|(?:\\(?:[^x]|x[0-9a-fA-F]+)))*'" + ) - string_start = start_marker + python_quoted_string + string_start = StartOfStrGrammar(python_quoted_string) - no_unquoted_newlines = start_marker + ZeroOrMore(python_quoted_string | ~Literal("\n") + any_char) + end_marker + no_unquoted_newlines = StartOfStrGrammar( + ZeroOrMore(python_quoted_string | ~Literal("\n") + any_char) + + end_marker + ) - operator_stmt = ( - start_marker - + keyword("operator").suppress() + operator_stmt = StartOfStrGrammar( + keyword("operator").suppress() + restOfLine ) unsafe_import_from_name = condense(ZeroOrMore(unsafe_dot) + unsafe_dotted_name | OneOrMore(unsafe_dot)) - from_import_operator = ( - start_marker - + keyword("from").suppress() + from_import_operator = StartOfStrGrammar( + keyword("from").suppress() + unsafe_import_from_name + keyword("import").suppress() + keyword("operator").suppress() @@ -2754,7 +2906,7 @@ def add_to_grammar_init_time(cls): def set_grammar_names(): """Set names of grammar elements to their variable names.""" for varname, val in vars(Grammar).items(): - if isinstance(val, ParserElement): + if hasattr(val, "setName"): val.setName(varname) diff --git a/coconut/compiler/header.py b/coconut/compiler/header.py index 39b2d2664..989e0c6a8 100644 --- a/coconut/compiler/header.py +++ b/coconut/compiler/header.py @@ -641,7 +641,7 @@ def __anext__(self): # (extra_format_dict is to keep indentation levels matching) extra_format_dict = dict( # when anything is added to this list it must also be added to *both* __coconut__ stub files - underscore_imports="{tco_comma}{call_set_names_comma}{handle_cls_args_comma}_namedtuple_of, _coconut, _coconut_Expected, _coconut_MatchError, _coconut_SupportsAdd, _coconut_SupportsMinus, _coconut_SupportsMul, _coconut_SupportsPow, _coconut_SupportsTruediv, _coconut_SupportsFloordiv, _coconut_SupportsMod, _coconut_SupportsAnd, _coconut_SupportsXor, _coconut_SupportsOr, _coconut_SupportsLshift, _coconut_SupportsRshift, _coconut_SupportsMatmul, _coconut_SupportsInv, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_complex_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_arr_concat_op, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten, _coconut_multiset, _coconut_back_none_pipe, _coconut_back_none_star_pipe, _coconut_back_none_dubstar_pipe, _coconut_forward_none_compose, _coconut_back_none_compose, _coconut_forward_none_star_compose, _coconut_back_none_star_compose, _coconut_forward_none_dubstar_compose, _coconut_back_none_dubstar_compose, _coconut_call_or_coefficient, _coconut_in, _coconut_not_in, _coconut_attritemgetter, _coconut_if_op".format(**format_dict), + underscore_imports="{tco_comma}{call_set_names_comma}{handle_cls_args_comma}_namedtuple_of, _coconut, _coconut_Expected, _coconut_MatchError, _coconut_SupportsAdd, _coconut_SupportsMinus, _coconut_SupportsMul, _coconut_SupportsPow, _coconut_SupportsTruediv, _coconut_SupportsFloordiv, _coconut_SupportsMod, _coconut_SupportsAnd, _coconut_SupportsXor, _coconut_SupportsOr, _coconut_SupportsLshift, _coconut_SupportsRshift, _coconut_SupportsMatmul, _coconut_SupportsInv, _coconut_iter_getitem, _coconut_base_compose, _coconut_forward_compose, _coconut_back_compose, _coconut_forward_star_compose, _coconut_back_star_compose, _coconut_forward_dubstar_compose, _coconut_back_dubstar_compose, _coconut_pipe, _coconut_star_pipe, _coconut_dubstar_pipe, _coconut_back_pipe, _coconut_back_star_pipe, _coconut_back_dubstar_pipe, _coconut_none_pipe, _coconut_none_star_pipe, _coconut_none_dubstar_pipe, _coconut_bool_and, _coconut_bool_or, _coconut_none_coalesce, _coconut_minus, _coconut_map, _coconut_partial, _coconut_complex_partial, _coconut_get_function_match_error, _coconut_base_pattern_func, _coconut_addpattern, _coconut_sentinel, _coconut_assert, _coconut_raise, _coconut_mark_as_match, _coconut_reiterable, _coconut_self_match_types, _coconut_dict_merge, _coconut_exec, _coconut_comma_op, _coconut_arr_concat_op, _coconut_mk_anon_namedtuple, _coconut_matmul, _coconut_py_str, _coconut_flatten, _coconut_multiset, _coconut_back_none_pipe, _coconut_back_none_star_pipe, _coconut_back_none_dubstar_pipe, _coconut_forward_none_compose, _coconut_back_none_compose, _coconut_forward_none_star_compose, _coconut_back_none_star_compose, _coconut_forward_none_dubstar_compose, _coconut_back_none_dubstar_compose, _coconut_call_or_coefficient, _coconut_in, _coconut_not_in, _coconut_attritemgetter, _coconut_if_op, _coconut_CoconutWarning".format(**format_dict), import_typing=pycondition( (3, 5), if_ge=''' @@ -865,6 +865,21 @@ def async_map(*args, **kwargs): "{_coconut_}zip".format(**format_dict): "zip", }, ), + def_call=pycondition( + (3, 11), + if_ge=r''' +call = _coconut.operator.call + ''', + if_lt=r''' +def call(_coconut_f{comma_slash}, *args, **kwargs): + """Function application operator function. + + Equivalent to: + def call(f, /, *args, **kwargs) = f(*args, **kwargs). + """ + return _coconut_f(*args, **kwargs) + '''.format(**format_dict), + ), ) format_dict.update(extra_format_dict) @@ -924,8 +939,8 @@ def getheader(which, use_hash, target, no_tco, strict, no_wrap): header += "from __future__ import print_function, absolute_import, unicode_literals, division\n" # including generator_stop here is fine, even though to universalize generator returns # we raise StopIteration errors, since we only do so when target_info < (3, 3) - elif target_info >= (3, 13): - # 3.13 supports lazy annotations, so we should just use that instead of from __future__ import annotations + elif target_info >= (3, 14): + # 3.14 supports lazy annotations, so we should just use that instead of from __future__ import annotations header += "from __future__ import generator_stop\n" elif target_info >= (3, 7): if no_wrap: @@ -953,24 +968,24 @@ def getheader(which, use_hash, target, no_tco, strict, no_wrap): _coconut_cached__coconut__ = _coconut_sys.modules.get({__coconut__}) _coconut_file_dir = {coconut_file_dir} _coconut_pop_path = False -if _coconut_cached__coconut__ is None or getattr(_coconut_cached__coconut__, "_coconut_header_info", None) != _coconut_header_info and _coconut_os.path.dirname(_coconut_cached__coconut__.__file__ or "") != _coconut_file_dir: +if _coconut_cached__coconut__ is None or getattr(_coconut_cached__coconut__, "_coconut_header_info", None) != _coconut_header_info and _coconut_os.path.dirname(_coconut_cached__coconut__.__file__ or "") != _coconut_file_dir: # type: ignore if _coconut_cached__coconut__ is not None: _coconut_sys.modules[{_coconut_cached__coconut__}] = _coconut_cached__coconut__ del _coconut_sys.modules[{__coconut__}] _coconut_sys.path.insert(0, _coconut_file_dir) _coconut_pop_path = True _coconut_module_name = _coconut_os.path.splitext(_coconut_os.path.basename(_coconut_file_dir))[0] - if _coconut_module_name and _coconut_module_name[0].isalpha() and all(c.isalpha() or c.isdigit() for c in _coconut_module_name) and "__init__.py" in _coconut_os.listdir(_coconut_file_dir): - _coconut_full_module_name = str(_coconut_module_name + ".__coconut__") + if _coconut_module_name and _coconut_module_name[0].isalpha() and all(c.isalpha() or c.isdigit() for c in _coconut_module_name) and "__init__.py" in _coconut_os.listdir(_coconut_file_dir): # type: ignore + _coconut_full_module_name = str(_coconut_module_name + ".__coconut__") # type: ignore import __coconut__ as _coconut__coconut__ _coconut__coconut__.__name__ = _coconut_full_module_name - for _coconut_v in vars(_coconut__coconut__).values(): - if getattr(_coconut_v, "__module__", None) == {__coconut__}: + for _coconut_v in vars(_coconut__coconut__).values(): # type: ignore + if getattr(_coconut_v, "__module__", None) == {__coconut__}: # type: ignore try: _coconut_v.__module__ = _coconut_full_module_name except AttributeError: - _coconut_v_type = type(_coconut_v) - if getattr(_coconut_v_type, "__module__", None) == {__coconut__}: + _coconut_v_type = type(_coconut_v) # type: ignore + if getattr(_coconut_v_type, "__module__", None) == {__coconut__}: # type: ignore _coconut_v_type.__module__ = _coconut_full_module_name _coconut_sys.modules[_coconut_full_module_name] = _coconut__coconut__ from __coconut__ import * diff --git a/coconut/compiler/matching.py b/coconut/compiler/matching.py index 99e5457f5..9690dc9d9 100644 --- a/coconut/compiler/matching.py +++ b/coconut/compiler/matching.py @@ -46,6 +46,8 @@ match_to_args_var, match_to_kwargs_var, ) +from coconut.util import noop_ctx +from coconut.compiler.grammar import split_args_list from coconut.compiler.util import ( paren_join, handle_indentation, @@ -93,6 +95,24 @@ def get_match_names(match): return names +def match_funcdef_setup_code( + first_arg=match_first_arg_var, + args=match_to_args_var, +): + """Get initial code to set up a match funcdef.""" + # pop the FunctionMatchError from context + # and fix args to include first_arg, which we have to do to make super work + return handle_indentation(""" +{function_match_error_var} = _coconut_get_function_match_error() +if {first_arg} is not _coconut_sentinel: + {args} = ({first_arg},) + {args} + """).format( + function_match_error_var=function_match_error_var, + first_arg=first_arg, + args=args, + ) + + # ----------------------------------------------------------------------------------------------------------------------- # MATCHER: # ----------------------------------------------------------------------------------------------------------------------- @@ -393,6 +413,18 @@ def check_len_in(self, min_len, max_len, item): else: self.add_check(str(min_len) + " <= _coconut.len(" + item + ") <= " + str(max_len)) + def match_function_toks(self, match_arg_toks, include_setup=True): + """Match pattern-matching function tokens.""" + pos_only_args, req_args, default_args, star_arg, kwd_only_args, dubstar_arg = split_args_list(match_arg_toks, self.loc) + self.match_function( + pos_only_match_args=pos_only_args, + match_args=req_args + default_args, + star_arg=star_arg, + kwd_only_match_args=kwd_only_args, + dubstar_arg=dubstar_arg, + include_setup=include_setup, + ) + def match_function( self, first_arg=match_first_arg_var, @@ -403,24 +435,13 @@ def match_function( star_arg=None, kwd_only_match_args=(), dubstar_arg=None, + include_setup=True, ): """Matches a pattern-matching function.""" - # before everything, pop the FunctionMatchError from context - self.add_def(function_match_error_var + " = _coconut_get_function_match_error()") - # and fix args to include first_arg, which we have to do to make super work - self.add_def( - handle_indentation( - """ -if {first_arg} is not _coconut_sentinel: - {args} = ({first_arg},) + {args} - """, - ).format( - first_arg=first_arg, - args=args, - ) - ) + if include_setup: + self.add_def(match_funcdef_setup_code(first_arg, args)) - with self.down_a_level(): + with self.down_a_level() if include_setup else noop_ctx(): self.match_in_args_kwargs(pos_only_match_args, match_args, args, kwargs, allow_star_args=star_arg is not None) @@ -618,6 +639,22 @@ def proc_sequence_match(self, tokens, iter_match=False): elif "elem" in group: group_type = "elem_matches" group_contents = group + # must check for f_string before string, since a mixture will be tagged as both + elif "f_string" in group: + group_type = "f_string" + # f strings are always unicode + if seq_type is None: + seq_type = '"' + elif seq_type != '"': + raise CoconutDeferredSyntaxError("string literals and byte literals cannot be mixed in string patterns", self.loc) + for str_literal in group: + if str_literal.startswith("b"): + raise CoconutDeferredSyntaxError("string literals and byte literals cannot be mixed in string patterns", self.loc) + if len(group) == 1: + str_item = group[0] + else: + str_item = self.comp.string_atom_handle(self.original, self.loc, group, allow_silent_concat=True) + group_contents = (str_item, "_coconut.len(" + str_item + ")") elif "string" in group: group_type = "string" for str_literal in group: @@ -634,16 +671,6 @@ def proc_sequence_match(self, tokens, iter_match=False): else: str_item = self.comp.eval_now(" ".join(group)) group_contents = (str_item, len(self.comp.literal_eval(str_item))) - elif "f_string" in group: - group_type = "f_string" - # f strings are always unicode - if seq_type is None: - seq_type = '"' - elif seq_type != '"': - raise CoconutDeferredSyntaxError("string literals and byte literals cannot be mixed in string patterns", self.loc) - internal_assert(len(group) == 1, "invalid f string sequence match group", group) - str_item = group[0] - group_contents = (str_item, "_coconut.len(" + str_item + ")") else: raise CoconutInternalException("invalid sequence match group", group) seq_groups.append((group_type, group_contents)) @@ -661,12 +688,12 @@ def handle_sequence(self, seq_type, seq_groups, item, iter_match=False): bounded = False elif gtype == "elem_matches": min_len_int += len(gcontents) - elif gtype == "string": - str_item, str_len = gcontents - min_len_int += str_len elif gtype == "f_string": str_item, str_len = gcontents min_len_strs.append(str_len) + elif gtype == "string": + str_item, str_len = gcontents + min_len_int += str_len else: raise CoconutInternalException("invalid sequence match group type", gtype) min_len = add_int_and_strs(min_len_int, min_len_strs) @@ -690,17 +717,17 @@ def handle_sequence(self, seq_type, seq_groups, item, iter_match=False): self.add_check("_coconut.len(" + head_var + ") == " + str(len(matches))) self.match_all_in(matches, head_var) start_ind_int += len(matches) + elif seq_groups[0][0] == "f_string": + internal_assert(not iter_match, "cannot be both f string and iter match") + _, (str_item, str_len) = seq_groups.pop(0) + self.add_check(item + ".startswith(" + str_item + ")") + start_ind_strs.append(str_len) elif seq_groups[0][0] == "string": internal_assert(not iter_match, "cannot be both string and iter match") _, (str_item, str_len) = seq_groups.pop(0) if str_len > 0: self.add_check(item + ".startswith(" + str_item + ")") start_ind_int += str_len - elif seq_groups[0][0] == "f_string": - internal_assert(not iter_match, "cannot be both f string and iter match") - _, (str_item, str_len) = seq_groups.pop(0) - self.add_check(item + ".startswith(" + str_item + ")") - start_ind_strs.append(str_len) if not seq_groups: return start_ind = add_int_and_strs(start_ind_int, start_ind_strs) @@ -714,17 +741,17 @@ def handle_sequence(self, seq_type, seq_groups, item, iter_match=False): for i, match in enumerate(matches): self.match(match, item + "[-" + str(len(matches) - i) + "]") last_ind_int -= len(matches) + elif seq_groups[-1][0] == "f_string": + internal_assert(not iter_match, "cannot be both f string and iter match") + _, (str_item, str_len) = seq_groups.pop() + self.add_check(item + ".endswith(" + str_item + ")") + last_ind_strs.append("-" + str_len) elif seq_groups[-1][0] == "string": internal_assert(not iter_match, "cannot be both string and iter match") _, (str_item, str_len) = seq_groups.pop() if str_len > 0: self.add_check(item + ".endswith(" + str_item + ")") last_ind_int -= str_len - elif seq_groups[-1][0] == "f_string": - internal_assert(not iter_match, "cannot be both f string and iter match") - _, (str_item, str_len) = seq_groups.pop() - self.add_check(item + ".endswith(" + str_item + ")") - last_ind_strs.append("-" + str_len) if not seq_groups: return last_ind = add_int_and_strs(last_ind_int, last_ind_strs) diff --git a/coconut/compiler/templates/header.py_template b/coconut/compiler/templates/header.py_template index a332de645..c5cfb8f26 100644 --- a/coconut/compiler/templates/header.py_template +++ b/coconut/compiler/templates/header.py_template @@ -61,7 +61,7 @@ class _coconut{object}:{COMMENT.EVERYTHING_HERE_MUST_BE_COPIED_TO_STUB_FILE} reiterables = abc.Sequence, abc.Mapping, abc.Set fmappables = list, tuple, dict, set, frozenset, bytes, bytearray abc.Sequence.register(collections.deque) - Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, ord, property, range, reversed, set, setattr, slice, str, sum, super, tuple, type, vars, zip, repr, print{comma_bytearray} = Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, ord, property, range, reversed, set, setattr, slice, str, sum, {lstatic}super{rstatic}, tuple, type, vars, zip, {lstatic}repr{rstatic}, {lstatic}print{rstatic}{comma_bytearray} + Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, ord, property, range, reversed, set, setattr, slice, str, sum, super, tuple, type, vars, zip, repr, print{comma_bytearray} = Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, {lstatic}min{rstatic}, {lstatic}max{rstatic}, next, object, ord, property, range, reversed, set, setattr, slice, str, sum, {lstatic}super{rstatic}, tuple, type, vars, zip, {lstatic}repr{rstatic}, {lstatic}print{rstatic}{comma_bytearray} @_coconut.functools.wraps(_coconut.functools.partial) def _coconut_partial(_coconut_func, *args, **kwargs): partial_func = _coconut.functools.partial(_coconut_func, *args, **kwargs) @@ -136,6 +136,10 @@ def _coconut_xarray_to_numpy(obj): return obj.to_dataframe().to_numpy() else: return obj.to_numpy() +class CoconutWarning(Warning{comma_object}): + """Exception class used for all Coconut warnings.""" + __slots__ = () +_coconut_CoconutWarning = CoconutWarning class MatchError(_coconut_baseclass, Exception): """Pattern-matching error. Has attributes .pattern, .value, and .message."""{COMMENT.no_slots_to_allow_setattr_below} max_val_repr_len = 500 @@ -767,10 +771,7 @@ Additionally supports Cartesian products of numpy arrays.""" if iterables: it_modules = [_coconut_get_base_module(it) for it in iterables] if _coconut.all(mod in _coconut.numpy_modules for mod in it_modules): - if _coconut.any(mod in _coconut.xarray_modules for mod in it_modules): - iterables = tuple((_coconut_xarray_to_numpy(it) if mod in _coconut.xarray_modules else it) for it, mod in _coconut.zip(iterables, it_modules)) - if _coconut.any(mod in _coconut.pandas_modules for mod in it_modules): - iterables = tuple((it.to_numpy() if mod in _coconut.pandas_modules else it) for it, mod in _coconut.zip(iterables, it_modules)) + iterables = tuple((it.to_numpy() if mod in _coconut.pandas_modules else _coconut_xarray_to_numpy(it) if mod in _coconut.xarray_modules else it) for it, mod in _coconut.zip(iterables, it_modules)) if _coconut.any(mod in _coconut.jax_numpy_modules for mod in it_modules): from jax import numpy else: @@ -843,7 +844,7 @@ class map(_coconut_baseclass, _coconut.map): def __len__(self): if not _coconut.all(_coconut.isinstance(it, _coconut.abc.Sized) for it in self.iters): return _coconut.NotImplemented - return _coconut.min(_coconut.len(it) for it in self.iters) + return _coconut.min((_coconut.len(it) for it in self.iters), default=0) def __repr__(self): return "%s(%r, %s)" % (self.__class__.__name__, self.func, ", ".join((_coconut.repr(it) for it in self.iters))) def __reduce__(self): @@ -985,7 +986,7 @@ class zip(_coconut_baseclass, _coconut.zip): def __len__(self): if not _coconut.all(_coconut.isinstance(it, _coconut.abc.Sized) for it in self.iters): return _coconut.NotImplemented - return _coconut.min(_coconut.len(it) for it in self.iters) + return _coconut.min((_coconut.len(it) for it in self.iters), default=0) def __repr__(self): return "zip(%s%s)" % (", ".join((_coconut.repr(it) for it in self.iters)), ", strict=True" if self.strict else "") def __reduce__(self): @@ -1036,7 +1037,7 @@ class zip_longest(zip): def __len__(self): if not _coconut.all(_coconut.isinstance(it, _coconut.abc.Sized) for it in self.iters): return _coconut.NotImplemented - return _coconut.max(_coconut.len(it) for it in self.iters) + return _coconut.max((_coconut.len(it) for it in self.iters), default=0) def __repr__(self): return "zip_longest(%s, fillvalue=%s)" % (", ".join((_coconut.repr(it) for it in self.iters)), _coconut.repr(self.fillvalue)) def __reduce__(self): @@ -1100,12 +1101,7 @@ class multi_enumerate(_coconut_has_iter): through inner iterables and produces a tuple index representing the index in each inner iterable. Supports indexing. - For numpy arrays, effectively equivalent to: - it = np.nditer(iterable, flags=["multi_index", "refs_ok"]) - for x in it: - yield it.multi_index, x - - Also supports len for numpy arrays. + For numpy arrays, uses np.nditer under the hood and supports len. """ __slots__ = () def __repr__(self): @@ -1439,7 +1435,7 @@ def addpattern(base_func, *add_funcs, **kwargs): """ allow_any_func = kwargs.pop("allow_any_func", False) if not allow_any_func and not _coconut.getattr(base_func, "_coconut_is_match", False): - _coconut.warnings.warn("Possible misuse of addpattern with non-pattern-matching function " + _coconut.repr(base_func) + " (pass allow_any_func=True to dismiss)", stacklevel=2) + _coconut.warnings.warn("Possible misuse of addpattern with non-pattern-matching function " + _coconut.repr(base_func) + " (pass allow_any_func=True to dismiss)", _coconut_CoconutWarning, 2) if kwargs: raise _coconut.TypeError("addpattern() got unexpected keyword arguments " + _coconut.repr(kwargs)) if add_funcs: @@ -1709,13 +1705,7 @@ def ident(x, **kwargs): if side_effect is not None: side_effect(x) return x -def call(_coconut_f{comma_slash}, *args, **kwargs): - """Function application operator function. - - Equivalent to: - def call(f, /, *args, **kwargs) = f(*args, **kwargs). - """ - return _coconut_f(*args, **kwargs) +{def_call} def safe_call(_coconut_f{comma_slash}, *args, **kwargs): """safe_call is a version of call that catches any Exceptions and returns an Expected containing either the result or the error. @@ -1962,10 +1952,10 @@ def all_equal(iterable, to=_coconut_sentinel): """ iterable_module = _coconut_get_base_module(iterable) if iterable_module in _coconut.numpy_modules: - if iterable_module in _coconut.xarray_modules: - iterable = _coconut_xarray_to_numpy(iterable) - elif iterable_module in _coconut.pandas_modules: + if iterable_module in _coconut.pandas_modules: iterable = iterable.to_numpy() + elif iterable_module in _coconut.xarray_modules: + iterable = _coconut_xarray_to_numpy(iterable) return not _coconut.len(iterable) or (iterable == (iterable[0] if to is _coconut_sentinel else to)).all() first_item = to for item in iterable: diff --git a/coconut/compiler/util.py b/coconut/compiler/util.py index ffbbf6151..bd434b363 100644 --- a/coconut/compiler/util.py +++ b/coconut/compiler/util.py @@ -72,6 +72,7 @@ ParserElement, MatchFirst, And, + StringStart, _trim_arity, _ParseResultsWithOffset, all_parse_elements, @@ -133,6 +134,9 @@ require_cache_clear_frac, reverse_any_of, all_keywords, + always_keep_parse_name_prefix, + keep_if_unchanged_parse_name_prefix, + incremental_use_hybrid, ) from coconut.exceptions import ( CoconutException, @@ -147,7 +151,7 @@ indexable_evaluated_tokens_types = (ParseResults, list, tuple) -def evaluate_all_tokens(all_tokens, **kwargs): +def evaluate_all_tokens(all_tokens, expand_inner=True, **kwargs): """Recursively evaluate all the tokens in all_tokens.""" all_evaluated_toks = [] for toks in all_tokens: @@ -156,10 +160,34 @@ def evaluate_all_tokens(all_tokens, **kwargs): # short-circuit the computation and return them, since they imply this parse contains invalid syntax if isinstance(evaluated_toks, ExceptionNode): return None, evaluated_toks - all_evaluated_toks.append(evaluated_toks) + elif expand_inner and isinstance(evaluated_toks, MergeNode): + all_evaluated_toks = ParseResults(all_evaluated_toks) + all_evaluated_toks += evaluated_toks # use += to avoid an unnecessary copy + else: + all_evaluated_toks.append(evaluated_toks) return all_evaluated_toks, None +def make_modified_tokens(old_tokens, new_toklist=None, new_tokdict=None, cls=ParseResults): + """Construct a modified ParseResults object from the given ParseResults object.""" + if new_toklist is None: + if DEVELOP: # avoid the overhead of the call if not develop + internal_assert(new_tokdict is None, "if new_toklist is None, new_tokdict must be None", new_tokdict) + new_toklist = old_tokens._ParseResults__toklist + new_tokdict = old_tokens._ParseResults__tokdict + # we have to pass name=None here and then set __name after otherwise + # the constructor might generate a new tokdict item we don't want; + # this also ensures that asList and modal don't matter, since they + # only do anything when you name is not None, so we don't pass them + new_tokens = cls(new_toklist) + new_tokens._ParseResults__name = old_tokens._ParseResults__name + new_tokens._ParseResults__parent = old_tokens._ParseResults__parent + new_tokens._ParseResults__accumNames.update(old_tokens._ParseResults__accumNames) + if new_tokdict is not None: + new_tokens._ParseResults__tokdict.update(new_tokdict) + return new_tokens + + def evaluate_tokens(tokens, **kwargs): """Evaluate the given tokens in the computation graph. Very performance sensitive.""" @@ -175,7 +203,7 @@ def evaluate_tokens(tokens, **kwargs): if isinstance(tokens, ParseResults): # evaluate the list portion of the ParseResults - old_toklist, old_name, asList, modal = tokens.__getnewargs__() + old_toklist = tokens._ParseResults__toklist new_toklist = None for eval_old_toklist, eval_new_toklist in evaluated_toklists: if old_toklist == eval_old_toklist: @@ -188,26 +216,25 @@ def evaluate_tokens(tokens, **kwargs): # overwrite evaluated toklists rather than appending, since this # should be all the information we need for evaluating the dictionary evaluated_toklists = ((old_toklist, new_toklist),) - # we have to pass name=None here and then set __name after otherwise - # the constructor might generate a new tokdict item we don't want - new_tokens = ParseResults(new_toklist, None, asList, modal) - new_tokens._ParseResults__name = old_name - new_tokens._ParseResults__accumNames.update(tokens._ParseResults__accumNames) # evaluate the dictionary portion of the ParseResults new_tokdict = {} for name, occurrences in tokens._ParseResults__tokdict.items(): new_occurrences = [] for value, position in occurrences: - new_value = evaluate_tokens(value, is_final=is_final, evaluated_toklists=evaluated_toklists) - if isinstance(new_value, ExceptionNode): - return new_value + if value is None: # fake value created by build_new_toks_for + new_value = None + else: + new_value = evaluate_tokens(value, is_final=is_final, evaluated_toklists=evaluated_toklists) + if isinstance(new_value, ExceptionNode): + return new_value new_occurrences.append(_ParseResultsWithOffset(new_value, position)) new_tokdict[name] = new_occurrences - new_tokens._ParseResults__tokdict.update(new_tokdict) + + new_tokens = make_modified_tokens(tokens, new_toklist, new_tokdict) if DEVELOP: # avoid the overhead of the call if not develop - internal_assert(set(tokens._ParseResults__tokdict.keys()) == set(new_tokens._ParseResults__tokdict.keys()), "evaluate_tokens on ParseResults failed to maintain tokdict keys", (tokens, "->", new_tokens)) + internal_assert(set(tokens._ParseResults__tokdict.keys()) <= set(new_tokens._ParseResults__tokdict.keys()), "evaluate_tokens on ParseResults failed to maintain tokdict keys", (tokens, "->", new_tokens)) return new_tokens @@ -238,14 +265,22 @@ def evaluate_tokens(tokens, **kwargs): result = tokens.evaluate() if is_final and isinstance(result, ExceptionNode): raise result.exception - return result + elif isinstance(result, ParseResults): + return make_modified_tokens(result, cls=MergeNode) + elif isinstance(result, list): + if len(result) == 1: + return result[0] + else: + return MergeNode(result) + else: + return result elif isinstance(tokens, list): result, exc_node = evaluate_all_tokens(tokens, is_final=is_final, evaluated_toklists=evaluated_toklists) return result if exc_node is None else exc_node elif isinstance(tokens, tuple): - result, exc_node = evaluate_all_tokens(tokens, is_final=is_final, evaluated_toklists=evaluated_toklists) + result, exc_node = evaluate_all_tokens(tokens, expand_inner=False, is_final=is_final, evaluated_toklists=evaluated_toklists) return tuple(result) if exc_node is None else exc_node elif isinstance(tokens, ExceptionNode): @@ -260,6 +295,31 @@ def evaluate_tokens(tokens, **kwargs): raise CoconutInternalException("invalid computation graph tokens", tokens) +class MergeNode(ParseResults): + """A special type of ParseResults object that should be merged into outer tokens.""" + __slots__ = () + + +def build_new_toks_for(tokens, new_toklist, unchanged=False): + """Build new tokens from tokens to return just new_toklist.""" + if USE_COMPUTATION_GRAPH and not isinstance(new_toklist, ExceptionNode): + keep_names = [ + n for n in tokens._ParseResults__tokdict + if n.startswith(always_keep_parse_name_prefix) or unchanged and n.startswith(keep_if_unchanged_parse_name_prefix) + ] + if tokens._ParseResults__name is not None and ( + tokens._ParseResults__name.startswith(always_keep_parse_name_prefix) + or unchanged and tokens._ParseResults__name.startswith(keep_if_unchanged_parse_name_prefix) + ): + keep_names.append(tokens._ParseResults__name) + if keep_names: + new_tokens = make_modified_tokens(tokens, new_toklist) + for name in keep_names: + new_tokens[name] = None + return new_tokens + return new_toklist + + class ComputationNode(object): """A single node in the computation graph.""" __slots__ = ("action", "original", "loc", "tokens") @@ -284,10 +344,9 @@ def __new__(cls, action, original, loc, tokens, ignore_no_tokens=False, ignore_o If ignore_no_tokens, then don't call the action if there are no tokens. If ignore_one_token, then don't call the action if there is only one token. If greedy, then never defer the action until later.""" - if ignore_no_tokens and len(tokens) == 0: - return [] - elif ignore_one_token and len(tokens) == 1: - return tokens[0] # could be a ComputationNode, so we can't have an __init__ + if ignore_no_tokens and len(tokens) == 0 or ignore_one_token and len(tokens) == 1: + # could be a ComputationNode, so we can't have an __init__ + return build_new_toks_for(tokens, tokens, unchanged=True) else: self = super(ComputationNode, cls).__new__(cls) if trim_arity: @@ -321,7 +380,7 @@ def evaluate(self): if isinstance(evaluated_toks, ExceptionNode): return evaluated_toks # short-circuit if we got an ExceptionNode try: - return self.action( + result = self.action( self.original, self.loc, evaluated_toks, @@ -336,6 +395,14 @@ def evaluate(self): embed(depth=2) else: raise error + out = build_new_toks_for(evaluated_toks, result) + if logger.tracing: # avoid the overhead if not tracing + dropped_keys = set(self.tokens._ParseResults__tokdict.keys()) + if isinstance(out, ParseResults): + dropped_keys -= set(out._ParseResults__tokdict.keys()) + if dropped_keys: + logger.log_tag(self.name, "DROP " + repr(dropped_keys), wrap=False) + return out def __repr__(self): """Get a representation of the entire computation graph below this node.""" @@ -492,7 +559,10 @@ def force_reset_packrat_cache(): """Forcibly reset the packrat cache and all packrat stats.""" if ParserElement._incrementalEnabled: ParserElement._incrementalEnabled = False - ParserElement.enableIncremental(incremental_mode_cache_size if in_incremental_mode() else default_incremental_cache_size, still_reset_cache=False) + ParserElement.enableIncremental( + incremental_mode_cache_size if in_incremental_mode() else default_incremental_cache_size, + **ParserElement.getIncrementalInfo() # no comma for py2 + ) else: ParserElement._packratEnabled = False ParserElement.enablePackrat(packrat_cache_size) @@ -524,6 +594,7 @@ def parsing_context(inner_parse=None): yield finally: ParserElement._incrementalWithResets = incrementalWithResets + dehybridize_cache() elif ( current_cache_matters and will_clear_cache @@ -541,12 +612,50 @@ def parsing_context(inner_parse=None): if logger.verbose: ParserElement.packrat_cache_stats[0] += old_cache_stats[0] ParserElement.packrat_cache_stats[1] += old_cache_stats[1] + elif not will_clear_cache: + try: + yield + finally: + dehybridize_cache() else: yield -def prep_grammar(grammar, streamline=False): +class StartOfStrGrammar(object): + """A container object that denotes grammars that should always be parsed at the start of the string.""" + __slots__ = ("grammar",) + start_marker = StringStart() + + def __init__(self, grammar): + self.grammar = grammar + + def with_start_marker(self): + """Get the grammar with the start marker.""" + internal_assert(not CPYPARSING, "StartOfStrGrammar.with_start_marker() should only be necessary without cPyparsing") + return self.start_marker + self.grammar + + def apply(self, grammar_transformer): + """Apply a function to transform the grammar.""" + self.grammar = grammar_transformer(self.grammar) + + @property + def name(self): + return get_name(self.grammar) + + def setName(self, *args, **kwargs): + """Equivalent to .grammar.setName.""" + return self.grammar.setName(*args, **kwargs) + + +def prep_grammar(grammar, for_scan, streamline=False, add_unpack=False): """Prepare a grammar item to be used as the root of a parse.""" + if isinstance(grammar, StartOfStrGrammar): + if for_scan: + grammar = grammar.with_start_marker() + else: + grammar = grammar.grammar + if add_unpack: + grammar = add_action(grammar, unpack) grammar = trace(grammar) if streamline: grammar.streamlined = False @@ -559,7 +668,7 @@ def prep_grammar(grammar, streamline=False): def parse(grammar, text, inner=None, eval_parse_tree=True): """Parse text using grammar.""" with parsing_context(inner): - result = prep_grammar(grammar).parseString(text) + result = prep_grammar(grammar, for_scan=False).parseString(text) if eval_parse_tree: result = unpack(result) return result @@ -580,8 +689,12 @@ def does_parse(grammar, text, inner=None): def all_matches(grammar, text, inner=None, eval_parse_tree=True): """Find all matches for grammar in text.""" + kwargs = {} + if CPYPARSING and isinstance(grammar, StartOfStrGrammar): + grammar = grammar.grammar + kwargs["maxStartLoc"] = 0 with parsing_context(inner): - for tokens, start, stop in prep_grammar(grammar).scanString(text): + for tokens, start, stop in prep_grammar(grammar, for_scan=True).scanString(text, **kwargs): if eval_parse_tree: tokens = unpack(tokens) yield tokens, start, stop @@ -603,8 +716,12 @@ def match_in(grammar, text, inner=None): def transform(grammar, text, inner=None): """Transform text by replacing matches to grammar.""" + kwargs = {} + if CPYPARSING and isinstance(grammar, StartOfStrGrammar): + grammar = grammar.grammar + kwargs["maxStartLoc"] = 0 with parsing_context(inner): - result = prep_grammar(add_action(grammar, unpack)).transformString(text) + result = prep_grammar(grammar, add_unpack=True, for_scan=True).transformString(text, **kwargs) if result == text: result = None return result @@ -692,6 +809,22 @@ def get_target_info_smart(target, mode="lowest"): # PARSING INTROSPECTION: # ----------------------------------------------------------------------------------------------------------------------- +# incremental lookup indices +_lookup_elem = 0 +_lookup_orig = 1 +_lookup_loc = 2 +# _lookup_bools = 3 +# _lookup_context = 4 +assert _lookup_elem == 0, "lookup must start with elem" + +# incremental value indices +_value_exc_loc_or_ret = 0 +# _value_furthest_loc = 1 +_value_useful = -1 +assert _value_exc_loc_or_ret == 0, "value must start with exc loc / ret" +assert _value_useful == -1, "value must end with usefullness obj" + + def maybe_copy_elem(item, name): """Copy the given grammar element if it's referenced somewhere else.""" item_ref_count = sys.getrefcount(item) if CPYTHON and not on_new_python else float("inf") @@ -824,7 +957,7 @@ def execute_clear_strat(clear_cache): if clear_cache == "useless": keys_to_del = [] for lookup, value in cache.items(): - if not value[-1][0]: + if not value[_value_useful][0]: keys_to_del.append(lookup) for del_key in keys_to_del: del cache[del_key] @@ -837,6 +970,24 @@ def execute_clear_strat(clear_cache): return orig_cache_len +def dehybridize_cache(): + """Dehybridize any hybrid entries in the incremental parsing cache.""" + if ( + CPYPARSING + # if we're not in incremental mode, we just throw away the cache + # after every parse, so no need to dehybridize it + and in_incremental_mode() + and ParserElement.getIncrementalInfo()["hybrid_mode"] + ): + cache = get_pyparsing_cache() + new_entries = {} + for lookup, value in cache.items(): + cached_item = value[0] + if cached_item is not True and not isinstance(cached_item, int): + new_entries[lookup] = (True,) + value[1:] + cache.update(new_entries) + + def clear_packrat_cache(force=False): """Clear the packrat cache if applicable. Very performance-sensitive for incremental parsing mode.""" @@ -845,6 +996,8 @@ def clear_packrat_cache(force=False): if DEVELOP: start_time = get_clock_time() orig_cache_len = execute_clear_strat(clear_cache) + # always dehybridize after cache clear so we're dehybridizing the fewest items + dehybridize_cache() if DEVELOP and orig_cache_len is not None: logger.log("Pruned packrat cache from {orig_len} items to {new_len} items using {strat!r} strategy ({time} secs).".format( orig_len=orig_cache_len, @@ -859,10 +1012,10 @@ def get_cache_items_for(original, only_useful=False, exclude_stale=True): """Get items from the pyparsing cache filtered to only be from parsing original.""" cache = get_pyparsing_cache() for lookup, value in cache.items(): - got_orig = lookup[1] + got_orig = lookup[_lookup_orig] internal_assert(lambda: isinstance(got_orig, (bytes, str)), "failed to look up original in pyparsing cache item", (lookup, value)) if ParserElement._incrementalEnabled: - (is_useful,) = value[-1] + (is_useful,) = value[_value_useful] if only_useful and not is_useful: continue if exclude_stale and is_useful >= 2: @@ -876,13 +1029,13 @@ def get_highest_parse_loc(original): Note that there's no point in filtering for successes/failures, since we always see both at the same locations.""" highest_loc = 0 for lookup, _ in get_cache_items_for(original): - loc = lookup[2] + loc = lookup[_lookup_loc] if loc > highest_loc: highest_loc = loc return highest_loc -def enable_incremental_parsing(): +def enable_incremental_parsing(reason="explicit enable_incremental_parsing call"): """Enable incremental parsing mode where prefix/suffix parses are reused.""" if not SUPPORTS_INCREMENTAL: return False @@ -890,10 +1043,15 @@ def enable_incremental_parsing(): return True ParserElement._incrementalEnabled = False try: - ParserElement.enableIncremental(incremental_mode_cache_size, still_reset_cache=False, cache_successes=incremental_mode_cache_successes) + ParserElement.enableIncremental( + incremental_mode_cache_size, + still_reset_cache=False, + cache_successes=incremental_mode_cache_successes, + hybrid_mode=incremental_mode_cache_successes and incremental_use_hybrid, + ) except ImportError as err: raise CoconutException(str(err)) - logger.log("Incremental parsing mode enabled.") + logger.log("Incremental parsing mode enabled due to {reason}.".format(reason=reason)) return True @@ -919,7 +1077,7 @@ def pickle_cache(original, cache_path, include_incremental=True, protocol=pickle break if len(pickleable_cache_items) >= incremental_cache_limit: break - loc = lookup[2] + loc = lookup[_lookup_loc] # only include cache items that aren't at the start or end, since those # are the only ones that parseIncremental will reuse if 0 < loc < len(original) - 1: @@ -929,6 +1087,7 @@ def pickle_cache(original, cache_path, include_incremental=True, protocol=pickle if validation_dict is not None: validation_dict[identifier] = elem.__class__.__name__ pickleable_lookup = (identifier,) + lookup[1:] + internal_assert(value[_value_exc_loc_or_ret] is True or isinstance(value[_value_exc_loc_or_ret], int), "cache must be dehybridized before pickling", value[_value_exc_loc_or_ret]) pickleable_cache_items.append((pickleable_lookup, value)) all_adaptive_stats = {} @@ -1017,6 +1176,7 @@ def unpickle_cache(cache_path): if maybe_elem is not None: if validation_dict is not None: internal_assert(maybe_elem.__class__.__name__ == validation_dict[identifier], "incremental cache pickle-unpickle inconsistency", (maybe_elem, validation_dict[identifier])) + internal_assert(value[_value_exc_loc_or_ret] is True or isinstance(value[_value_exc_loc_or_ret], int), "attempting to unpickle hybrid cache item", value[_value_exc_loc_or_ret]) lookup = (maybe_elem,) + pickleable_lookup[1:] usefullness = value[-1][0] internal_assert(usefullness, "loaded useless cache item", (lookup, value)) @@ -1039,7 +1199,7 @@ def load_cache_for(inputstring, codepath): incremental_enabled = True incremental_info = "using incremental parsing mode since it was already enabled" elif len(inputstring) < disable_incremental_for_len: - incremental_enabled = enable_incremental_parsing() + incremental_enabled = enable_incremental_parsing(reason="input length") if incremental_enabled: incremental_info = "incremental parsing mode enabled due to len == {input_len} < {max_len}".format( input_len=len(inputstring), @@ -1077,7 +1237,7 @@ def load_cache_for(inputstring, codepath): incremental_info=incremental_info, )) if incremental_enabled: - logger.warn("Populating initial parsing cache (compilation may take longer than usual)...") + logger.warn("Populating initial parsing cache (initial compilation may take a while; pass --no-cache to disable)...") else: cache_path = None logger.log("Declined to load cache for {filename!r} ({incremental_info}).".format( @@ -1281,6 +1441,23 @@ def labeled_group(item, label): return Group(item(label)) +def fake_labeled_group(item, label): + """Apply a label to an item in a group and then destroy the group. + Only useful with special labels that stick around.""" + + def fake_labeled_group_handle(tokens): + internal_assert(label in tokens, "failed to label with " + repr(label) + " for tokens", tokens) + [item], = tokens + return item + return attach(labeled_group(item, label), fake_labeled_group_handle) + + +def add_labels(tokens): + """Parse action to gather all the attached labels.""" + item, = tokens + return (item, tokens._ParseResults__tokdict.keys()) + + def invalid_syntax(item, msg, **kwargs): """Mark a grammar item as an invalid item that raises a syntax err with msg.""" if isinstance(item, str): @@ -1356,30 +1533,44 @@ def maybeparens(lparen, item, rparen, prefer_parens=False): return item | lparen.suppress() + item + rparen.suppress() -def interleaved_tokenlist(required_item, other_item, sep, allow_trailing=False, at_least_two=False): +def interleaved_tokenlist(required_item, other_item, sep, allow_trailing=False, at_least_two=False, multi_group=True): """Create a grammar to match interleaved required_items and other_items, where required_item must show up at least once.""" sep = sep.suppress() + + def one_or_more_group(item): + return Group(OneOrMore(item)) if multi_group else OneOrMore(Group(item)) + if at_least_two: out = ( # required sep other (sep other)* Group(required_item) - + Group(OneOrMore(sep + other_item)) + + one_or_more_group(sep + other_item) # other (sep other)* sep required (sep required)* - | Group(other_item + ZeroOrMore(sep + other_item)) - + Group(OneOrMore(sep + required_item)) + | ( + Group(other_item + ZeroOrMore(sep + other_item)) + if multi_group else + Group(other_item) + ZeroOrMore(Group(sep + other_item)) + ) + one_or_more_group(sep + required_item) # required sep required (sep required)* - | Group(required_item + OneOrMore(sep + required_item)) + | ( + Group(required_item + OneOrMore(sep + required_item)) + if multi_group else + Group(required_item) + OneOrMore(Group(sep + required_item)) + ) ) else: out = ( - Optional(Group(OneOrMore(other_item + sep))) - + Group(required_item + ZeroOrMore(sep + required_item)) - + Optional(Group(OneOrMore(sep + other_item))) + Optional(one_or_more_group(other_item + sep)) + + ( + Group(required_item + ZeroOrMore(sep + required_item)) + if multi_group else + Group(required_item) + ZeroOrMore(Group(sep + required_item)) + ) + Optional(one_or_more_group(sep + other_item)) ) out += ZeroOrMore( - Group(OneOrMore(sep + required_item)) - | Group(OneOrMore(sep + other_item)), + one_or_more_group(sep + required_item) + | one_or_more_group(sep + other_item) ) if allow_trailing: out += Optional(sep) @@ -1804,10 +1995,12 @@ def split_trailing_indent(inputstr, max_indents=None, handle_comments=True): return inputstr, "".join(reversed(indents_from_end)) -def split_leading_trailing_indent(line, max_indents=None): +def split_leading_trailing_indent(line, symmetric=False, **kwargs): """Split leading and trailing indent.""" - leading_indent, line = split_leading_indent(line, max_indents) - line, trailing_indent = split_trailing_indent(line, max_indents) + leading_indent, line = split_leading_indent(line, **kwargs) + if symmetric: + kwargs["max_indents"] = leading_indent.count(openindent) + line, trailing_indent = split_trailing_indent(line, **kwargs) return leading_indent, line, trailing_indent diff --git a/coconut/constants.py b/coconut/constants.py index c9d7d095a..4b10de3b2 100644 --- a/coconut/constants.py +++ b/coconut/constants.py @@ -93,7 +93,7 @@ def get_path_env_var(env_var, default): PY38 and not WINDOWS and not PYPY - # disabled until MyPy supports PEP 695 + # TODO: disabled until MyPy supports PEP 695 and not PY312 ) XONSH = ( @@ -111,7 +111,6 @@ def get_path_env_var(env_var, default): # set this to False only ever temporarily for ease of debugging use_fast_pyparsing_reprs = get_bool_env_var("COCONUT_FAST_PYPARSING_REPRS", True) -enable_pyparsing_warnings = DEVELOP warn_on_multiline_regex = False default_whitespace_chars = " \t\f" # the only non-newline whitespace Python allows @@ -138,25 +137,29 @@ def get_path_env_var(env_var, default): use_cache_file = True -disable_incremental_for_len = 46080 - adaptive_any_of_env_var = "COCONUT_ADAPTIVE_ANY_OF" use_adaptive_any_of = get_bool_env_var(adaptive_any_of_env_var, True) +use_line_by_line_parser = False + +# 0 for always disabled; float("inf") for always enabled +# (this determines when compiler.util.enable_incremental_parsing() is used) +disable_incremental_for_len = 20480 + # note that _parseIncremental produces much smaller caches use_incremental_if_available = False -use_line_by_line_parser = False - # these only apply to use_incremental_if_available, not compiler.util.enable_incremental_parsing() default_incremental_cache_size = None repeatedly_clear_incremental_cache = True never_clear_incremental_cache = False +# also applies to compiler.util.enable_incremental_parsing() if incremental_mode_cache_successes is True +incremental_use_hybrid = True # this is what gets used in compiler.util.enable_incremental_parsing() incremental_mode_cache_size = None incremental_cache_limit = 2097152 # clear cache when it gets this large -incremental_mode_cache_successes = False +incremental_mode_cache_successes = False # if False, also disables hybrid mode require_cache_clear_frac = 0.3125 # require that at least this much of the cache must be cleared on each cache clear use_left_recursion_if_available = False @@ -179,22 +182,22 @@ def get_path_env_var(env_var, default): sys.setrecursionlimit(default_recursion_limit) # modules that numpy-like arrays can live in -xarray_modules = ( - "xarray", +jax_numpy_modules = ( + "jaxlib", ) pandas_modules = ( "pandas", ) -jax_numpy_modules = ( - "jaxlib", +xarray_modules = ( + "xarray", ) numpy_modules = ( "numpy", "torch", ) + ( - xarray_modules + jax_numpy_modules + pandas_modules - + jax_numpy_modules + + xarray_modules ) legal_indent_chars = " \t" # the only Python-legal indent chars @@ -219,6 +222,7 @@ def get_path_env_var(env_var, default): (3, 11), (3, 12), (3, 13), + (3, 14), ) # must be in ascending order and kept up-to-date with https://devguide.python.org/versions @@ -230,6 +234,7 @@ def get_path_env_var(env_var, default): ("311", dt.datetime(2027, 11, 1)), ("312", dt.datetime(2028, 11, 1)), ("313", dt.datetime(2029, 11, 1)), + ("314", dt.datetime(2030, 11, 1)), ) # must match supported vers above and must be replicated in DOCS @@ -248,6 +253,7 @@ def get_path_env_var(env_var, default): "311", "312", "313", + "314", ) pseudo_targets = { "universal": "", @@ -287,10 +293,11 @@ def get_path_env_var(env_var, default): openindent = "\u204b" # reverse pilcrow closeindent = "\xb6" # pilcrow strwrapper = "\u25b6" # black right-pointing triangle -errwrapper = "\u24d8" # circled letter i early_passthrough_wrapper = "\u2038" # caret lnwrapper = "\u2021" # double dagger unwrapper = "\u23f9" # stop square +errwrapper = "\u24d8" # circled letter i +tempsep = "\u22ee" # vertical ellipsis funcwrapper = "def:" # must be tuples for .startswith / .endswith purposes @@ -308,12 +315,17 @@ def get_path_env_var(env_var, default): # together should include all the constants defined above delimiter_symbols = tuple(open_chars + close_chars + str_chars) + ( strwrapper, - errwrapper, early_passthrough_wrapper, unwrapper, + "`", + ":", + ",", + ";", ) + indchars + comment_chars reserved_compiler_symbols = delimiter_symbols + ( reserved_prefix, + errwrapper, + tempsep, funcwrapper, ) @@ -606,6 +618,9 @@ def get_path_env_var(env_var, default): "BaseException", "BaseExceptionGroup", "GeneratorExit", "KeyboardInterrupt", "SystemExit", "Exception", "ArithmeticError", "FloatingPointError", "OverflowError", "ZeroDivisionError", "AssertionError", "AttributeError", "BufferError", "EOFError", "ExceptionGroup", "BaseExceptionGroup", "ImportError", "ModuleNotFoundError", "LookupError", "IndexError", "KeyError", "MemoryError", "NameError", "UnboundLocalError", "OSError", "BlockingIOError", "ChildProcessError", "ConnectionError", "BrokenPipeError", "ConnectionAbortedError", "ConnectionRefusedError", "ConnectionResetError", "FileExistsError", "FileNotFoundError", "InterruptedError", "IsADirectoryError", "NotADirectoryError", "PermissionError", "ProcessLookupError", "TimeoutError", "ReferenceError", "RuntimeError", "NotImplementedError", "RecursionError", "StopAsyncIteration", "StopIteration", "SyntaxError", "IndentationError", "TabError", "SystemError", "TypeError", "ValueError", "UnicodeError", "UnicodeDecodeError", "UnicodeEncodeError", "UnicodeTranslateError", "Warning", "BytesWarning", "DeprecationWarning", "EncodingWarning", "FutureWarning", "ImportWarning", "PendingDeprecationWarning", "ResourceWarning", "RuntimeWarning", "SyntaxWarning", "UnicodeWarning", "UserWarning", ) +always_keep_parse_name_prefix = "HAS_" +keep_if_unchanged_parse_name_prefix = "IS_" + # ----------------------------------------------------------------------------------------------------------------------- # COMMAND CONSTANTS: # ----------------------------------------------------------------------------------------------------------------------- @@ -658,6 +673,8 @@ def get_path_env_var(env_var, default): ) installed_stub_dir = os.path.join(coconut_home, ".coconut_stubs") +pyright_config_file = os.path.join(coconut_home, ".coconut_pyrightconfig.json") + watch_interval = .1 # seconds info_tabulation = 18 # offset for tabulated info messages @@ -710,6 +727,10 @@ def get_path_env_var(env_var, default): ": note: ", ) +extra_pyright_args = { + "reportPossiblyUnboundVariable": False, +} + oserror_retcode = 127 kilobyte = 1024 @@ -819,6 +840,8 @@ def get_path_env_var(env_var, default): "py_xrange", "py_repr", "py_breakpoint", + "py_min", + "py_max", "_namedtuple_of", "reveal_type", "reveal_locals", @@ -831,6 +854,7 @@ def get_path_env_var(env_var, default): coconut_exceptions = ( "MatchError", + "CoconutWarning", ) highlight_builtins = coconut_specific_builtins + interp_only_builtins + python_builtins @@ -970,6 +994,11 @@ def get_path_env_var(env_var, default): "types-backports", ("typing", "py<35"), ), + "pyright": ( + "pyright", + "types-backports", + ("typing", "py<35"), + ), "watch": ( "watchdog", ), @@ -1008,7 +1037,7 @@ def get_path_env_var(env_var, default): # min versions are inclusive unpinned_min_versions = { - "cPyparsing": (2, 4, 7, 2, 3, 2), + "cPyparsing": (2, 4, 7, 2, 4, 0), ("pre-commit", "py3"): (3,), ("psutil", "py>=27"): (5,), "jupyter": (1, 0), @@ -1017,30 +1046,32 @@ def get_path_env_var(env_var, default): ("argparse", "py<27"): (1, 4), "pexpect": (4,), ("trollius", "py<3;cpy"): (2, 2), - "requests": (2, 31), + "requests": (2, 32), ("numpy", "py39"): (1, 26), ("xarray", "py39"): (2024,), ("dataclasses", "py==36"): (0, 8), ("aenum", "py<34"): (3, 1, 15), "pydata-sphinx-theme": (0, 15), - "myst-parser": (2,), + "myst-parser": (3,), "sphinx": (7,), - "mypy[python2]": (1, 8), + "mypy[python2]": (1, 10), + "pyright": (1, 1), ("jupyter-console", "py37"): (6, 6), ("typing", "py<35"): (3, 10), - ("typing_extensions", "py>=38"): (4, 9), + ("typing_extensions", "py>=38"): (4, 12), ("ipykernel", "py38"): (6,), ("jedi", "py39"): (0, 19), - ("pygments", "py>=39"): (2, 17), - ("xonsh", "py39"): (0, 15), - ("pytest", "py38"): (8,), + ("pygments", "py>=39"): (2, 18), + ("xonsh", "py39"): (0, 16), ("async_generator", "py35"): (1, 10), ("exceptiongroup", "py37;py<311"): (1,), - ("ipython", "py>=310"): (8, 22), + ("ipython", "py>=310"): (8, 25), "py-spy": (0, 3), } pinned_min_versions = { + # don't upgrade this; it breaks xonsh + ("pytest", "py38"): (8, 0), # don't upgrade these; they break on Python 3.9 ("numpy", "py34;py<39"): (1, 18), ("ipython", "py==39"): (8, 18), @@ -1107,6 +1138,7 @@ def get_path_env_var(env_var, default): ("jedi", "py<39"): _, ("pywinpty", "py<3;windows"): _, ("ipython", "py3;py<37"): _, + ("pytest", "py38"): _, } classifiers = ( diff --git a/coconut/exceptions.py b/coconut/exceptions.py index 89843a428..016eeb7d2 100644 --- a/coconut/exceptions.py +++ b/coconut/exceptions.py @@ -201,7 +201,7 @@ def message(self, message, source, point, ln, extra=None, endpoint=None, filenam message_parts += ["|"] else: message_parts += ["/", "~" * (len(lines[0]) - point_ind - 1)] - message_parts += ["~" * (max_line_len - len(lines[0])), "\n"] + message_parts += ["~" * (max_line_len - len(lines[0]) + 1), "\n"] # add code, highlighting all of it together code_parts = [] diff --git a/coconut/requirements.py b/coconut/requirements.py index 05be7b6d4..c6db84597 100644 --- a/coconut/requirements.py +++ b/coconut/requirements.py @@ -223,6 +223,7 @@ def everything_in(req_dict): "kernel": get_reqs("kernel"), "watch": get_reqs("watch"), "mypy": get_reqs("mypy"), + "pyright": get_reqs("pyright"), "xonsh": get_reqs("xonsh"), "numpy": get_reqs("numpy"), } diff --git a/coconut/root.py b/coconut/root.py index 44fe2b5c8..9c5e80b58 100644 --- a/coconut/root.py +++ b/coconut/root.py @@ -23,7 +23,7 @@ # VERSION: # ----------------------------------------------------------------------------------------------------------------------- -VERSION = "3.1.0" +VERSION = "3.1.1" VERSION_NAME = None # False for release, int >= 1 for develop DEVELOP = False @@ -61,16 +61,16 @@ def _get_target_info(target): # if a new assignment is added below, a new builtins import should be added alongside it _base_py3_header = r'''from builtins import chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr -py_bytes, py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_repr = bytes, chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr -_coconut_py_str, _coconut_py_super, _coconut_py_dict = str, super, dict +py_bytes, py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_repr, py_min, py_max = bytes, chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr, min, max +_coconut_py_str, _coconut_py_super, _coconut_py_dict, _coconut_py_min, _coconut_py_max = str, super, dict, min, max from functools import wraps as _coconut_wraps exec("_coconut_exec = exec") ''' # if a new assignment is added below, a new builtins import should be added alongside it _base_py2_header = r'''from __builtin__ import chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr, long -py_bytes, py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_raw_input, py_xrange, py_repr = bytes, chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr -_coconut_py_raw_input, _coconut_py_xrange, _coconut_py_int, _coconut_py_long, _coconut_py_print, _coconut_py_str, _coconut_py_super, _coconut_py_unicode, _coconut_py_repr, _coconut_py_dict, _coconut_py_bytes = raw_input, xrange, int, long, print, str, super, unicode, repr, dict, bytes +py_bytes, py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_raw_input, py_xrange, py_repr, py_min, py_max = bytes, chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, raw_input, xrange, repr, min, max +_coconut_py_raw_input, _coconut_py_xrange, _coconut_py_int, _coconut_py_long, _coconut_py_print, _coconut_py_str, _coconut_py_super, _coconut_py_unicode, _coconut_py_repr, _coconut_py_dict, _coconut_py_bytes, _coconut_py_min, _coconut_py_max = raw_input, xrange, int, long, print, str, super, unicode, repr, dict, bytes, min, max from functools import wraps as _coconut_wraps from collections import Sequence as _coconut_Sequence from future_builtins import * @@ -278,26 +278,26 @@ def __call__(self, obj): _coconut_operator.methodcaller = _coconut_methodcaller ''' -_non_py37_extras = r'''def _coconut_default_breakpointhook(*args, **kwargs): - hookname = _coconut.os.getenv("PYTHONBREAKPOINT") - if hookname != "0": - if not hookname: - hookname = "pdb.set_trace" - modname, dot, funcname = hookname.rpartition(".") - if not dot: - modname = "builtins" if _coconut_sys.version_info >= (3,) else "__builtin__" - if _coconut_sys.version_info >= (2, 7): - import importlib - module = importlib.import_module(modname) +_below_py34_extras = '''def min(*args, **kwargs): + if len(args) == 1 and "default" in kwargs: + obj = tuple(args[0]) + default = kwargs.pop("default") + if len(obj): + return _coconut_py_min(obj, **kwargs) else: - import imp - module = imp.load_module(modname, *imp.find_module(modname)) - hook = _coconut.getattr(module, funcname) - return hook(*args, **kwargs) -if not hasattr(_coconut_sys, "__breakpointhook__"): - _coconut_sys.__breakpointhook__ = _coconut_default_breakpointhook -def breakpoint(*args, **kwargs): - return _coconut.getattr(_coconut_sys, "breakpointhook", _coconut_default_breakpointhook)(*args, **kwargs) + return default + else: + return _coconut_py_min(*args, **kwargs) +def max(*args, **kwargs): + if len(args) == 1 and "default" in kwargs: + obj = tuple(args[0]) + default = kwargs.pop("default") + if len(obj): + return _coconut_py_max(obj, **kwargs) + else: + return default + else: + return _coconut_py_max(*args, **kwargs) ''' _finish_dict_def = ''' @@ -321,6 +321,26 @@ def __subclasscheck__(cls, subcls): ''' _below_py37_extras = '''from collections import OrderedDict as _coconut_OrderedDict +def _coconut_default_breakpointhook(*args, **kwargs): + hookname = _coconut.os.getenv("PYTHONBREAKPOINT") + if hookname != "0": + if not hookname: + hookname = "pdb.set_trace" + modname, dot, funcname = hookname.rpartition(".") + if not dot: + modname = "builtins" if _coconut_sys.version_info >= (3,) else "__builtin__" + if _coconut_sys.version_info >= (2, 7): + import importlib + module = importlib.import_module(modname) + else: + import imp + module = imp.load_module(modname, *imp.find_module(modname)) + hook = _coconut.getattr(module, funcname) + return hook(*args, **kwargs) +if not hasattr(_coconut_sys, "__breakpointhook__"): + _coconut_sys.__breakpointhook__ = _coconut_default_breakpointhook +def breakpoint(*args, **kwargs): + return _coconut.getattr(_coconut_sys, "breakpointhook", _coconut_default_breakpointhook)(*args, **kwargs) class _coconut_dict_base(_coconut_OrderedDict): __slots__ = () __doc__ = getattr(_coconut_OrderedDict, "__doc__", "") @@ -385,15 +405,17 @@ def _get_root_header(version="universal"): header += r'''py_breakpoint = breakpoint ''' elif version == "3": - header += r'''if _coconut_sys.version_info < (3, 7): -''' + _indent(_non_py37_extras) + r'''else: + header += r'''if _coconut_sys.version_info >= (3, 7): py_breakpoint = breakpoint ''' - else: - assert version.startswith("2"), version - header += _non_py37_extras - if version == "2": - header += _py26_extras + elif version == "2": + header += _py26_extras + + if version.startswith("2"): + header += _below_py34_extras + elif version_info < (3, 4): + header += r'''if _coconut_sys.version_info < (3, 4): +''' + _indent(_below_py34_extras) if version == "3": header += r'''if _coconut_sys.version_info < (3, 7): diff --git a/coconut/terminal.py b/coconut/terminal.py index 3fe3cad9d..3ff7d432a 100644 --- a/coconut/terminal.py +++ b/coconut/terminal.py @@ -186,11 +186,15 @@ def logging(self): def should_use_color(file=None): """Determine if colors should be used for the given file object.""" use_color = get_bool_env_var(use_color_env_var, default=None) + if use_color is None: + use_color = get_bool_env_var("PYTHON_COLORS", default=None) if use_color is not None: return use_color - if get_bool_env_var("CLICOLOR_FORCE") or get_bool_env_var("FORCE_COLOR"): + if get_bool_env_var("NO_COLOR"): + return False + if get_bool_env_var("FORCE_COLOR") or get_bool_env_var("CLICOLOR_FORCE"): return True - return file is not None and not isatty(file) + return file is not None and isatty(file) # ----------------------------------------------------------------------------------------------------------------------- @@ -495,7 +499,7 @@ def print_trace(self, *args): trace = " ".join(str(arg) for arg in args) self.printlog(_indent(trace, self.trace_ind)) - def log_tag(self, tag, block, multiline=False, force=False): + def log_tag(self, tag, block, multiline=False, wrap=True, force=False): """Logs a tagged message if tracing.""" if self.tracing or force: assert not (not DEVELOP and force), tag @@ -505,7 +509,7 @@ def log_tag(self, tag, block, multiline=False, force=False): if multiline: self.print_trace(tagstr + "\n" + displayable(block)) else: - self.print_trace(tagstr, ascii(block)) + self.print_trace(tagstr, ascii(block) if wrap else block) def log_trace(self, expr, original, loc, item=None, extra=None): """Formats and displays a trace if tracing.""" diff --git a/coconut/tests/__main__.py b/coconut/tests/__main__.py index 649ac82ed..ea35cb527 100644 --- a/coconut/tests/__main__.py +++ b/coconut/tests/__main__.py @@ -44,12 +44,13 @@ def main(args=None): # compile everything print("Compiling Coconut test suite with args %r and agnostic_target=%r." % (args, agnostic_target)) + type_checking = "--mypy" in args or "--pyright" in args comp_all( args, agnostic_target=agnostic_target, - expect_retcode=0 if "--mypy" not in args else None, + expect_retcode=0 if not type_checking else None, check_errors="--verbose" not in args, - ignore_output=WINDOWS and "--mypy" not in args, + ignore_output=WINDOWS and not type_checking, ) diff --git a/coconut/tests/main_test.py b/coconut/tests/main_test.py index 155f8e17b..07b3a04c2 100644 --- a/coconut/tests/main_test.py +++ b/coconut/tests/main_test.py @@ -109,7 +109,11 @@ else None ) -jupyter_timeout = 120 + +def pexpect(p, out): + """p.expect(out) with timeout""" + p.expect(out, timeout=120) + tests_dir = os.path.dirname(os.path.relpath(__file__)) src = os.path.join(tests_dir, "src") @@ -150,6 +154,8 @@ ignore_error_lines_with = ( # ignore SyntaxWarnings containing assert_raises or raise "raise", + # ignore Pyright errors + " - error: ", ) mypy_snip = "a: str = count()[0]" @@ -174,8 +180,12 @@ "DeprecationWarning: The distutils package is deprecated", "from distutils.version import LooseVersion", ": SyntaxWarning: 'int' object is not ", - " assert_raises(", + ": CoconutWarning: Deprecated use of ", + " assert_raises(", + " assert ", "Populating initial parsing cache", + "_coconut.warnings.warn(", + ": SyntaxWarning: invalid escape sequence", ) kernel_installation_msg = ( @@ -339,7 +349,7 @@ def call( continue # combine mypy error lines - if any(infix in line for infix in mypy_err_infixes): + if any(infix in line for infix in mypy_err_infixes) and i < len(raw_lines) - 1: # always add the next line, since it might be a continuation of the error message line += "\n" + raw_lines[i + 1] i += 1 @@ -680,14 +690,17 @@ def run( """Compiles and runs tests.""" assert use_run_arg + run_directory < 2 + if manage_cache and "--no-cache" not in args: + args = ["--no-cache"] + args + if agnostic_target is None: agnostic_args = args else: agnostic_args = ["--target", str(agnostic_target)] + args - with (using_caches() if manage_cache else noop_ctx()): + with using_caches() if manage_cache else noop_ctx(): with using_dest(): - with (using_dest(additional_dest) if "--and" in args else noop_ctx()): + with using_dest(additional_dest) if "--and" in args else noop_ctx(): spec_kwargs = kwargs.copy() spec_kwargs["always_sys"] = always_sys @@ -802,7 +815,7 @@ def comp_prelude(args=[], **kwargs): def run_prelude(**kwargs): """Runs coconut-prelude.""" call(["make", "base-install"], cwd=prelude) - call(["pytest", "--strict-markers", "-s", os.path.join(prelude, "prelude")], assert_output="passed", **kwargs) + call(["pytest", "--strict-markers", "-s", os.path.join(prelude, "prelude")], assert_output=" passed in ", assert_output_only_at_end=False, **kwargs) def comp_bbopt(args=[], **kwargs): @@ -919,37 +932,37 @@ def test_import_runnable(self): if not WINDOWS and XONSH: def test_xontrib(self): p = spawn_cmd("xonsh") - p.expect("$") + pexpect(p, "$") p.sendline("xontrib load coconut") - p.expect("$") + pexpect(p, "$") p.sendline("!(ls -la) |> bool") - p.expect("True") + pexpect(p, "True") p.sendline("'1; 2' |> print") - p.expect("1; 2") + pexpect(p, "1; 2") p.sendline('$ENV_VAR = "ABC"') - p.expect("$") + pexpect(p, "$") p.sendline('echo f"{$ENV_VAR}"; echo f"{$ENV_VAR}"') - p.expect("ABC") - p.expect("ABC") + pexpect(p, "ABC") + pexpect(p, "ABC") p.sendline('len("""1\n3\n5""")\n') - p.expect("5") + pexpect(p, "5") if not PYPY or PY39: if PY36: p.sendline("echo 123;; 123") - p.expect("123;; 123") + pexpect(p, "123;; 123") p.sendline("echo abc; echo abc") - p.expect("abc") - p.expect("abc") + pexpect(p, "abc") + pexpect(p, "abc") p.sendline("echo abc; print(1 |> (.+1))") - p.expect("abc") - p.expect("2") + pexpect(p, "abc") + pexpect(p, "2") p.sendline('execx("10 |> print")') - p.expect("subprocess mode") + pexpect(p, ["subprocess mode", "IndexError"]) p.sendline("xontrib unload coconut") - p.expect("$") + pexpect(p, "$") if (not PYPY or PY39) and PY36: p.sendline("1 |> print") - p.expect("subprocess mode") + pexpect(p, ["subprocess mode", "IndexError"]) p.sendeof() if p.isalive(): p.terminate() @@ -974,12 +987,12 @@ def test_kernel_installation(self): if not WINDOWS and not PYPY: def test_jupyter_console(self): p = spawn_cmd("coconut --jupyter console") - p.expect("In", timeout=jupyter_timeout) + pexpect(p, "In") p.sendline("%load_ext coconut") - p.expect("In", timeout=jupyter_timeout) + pexpect(p, "In") p.sendline("`exit`") if sys.version_info[:2] != (3, 6): - p.expect("Shutting down kernel|shutting down", timeout=jupyter_timeout) + pexpect(p, "Shutting down kernel|shutting down") if p.isalive(): p.terminate() @@ -1092,11 +1105,11 @@ def test_bbopt(self): if not PYPY and PY38 and not PY310: install_bbopt() - def test_pyprover(self): - with using_paths(pyprover): - comp_pyprover() - if PY38: - run_pyprover() + # def test_pyprover(self): + # with using_paths(pyprover): + # comp_pyprover() + # if PY38: + # run_pyprover() def test_pyston(self): with using_paths(pyston): diff --git a/coconut/tests/src/cocotest/agnostic/primary_1.coco b/coconut/tests/src/cocotest/agnostic/primary_1.coco index b8e9a44d5..bfe7888cf 100644 --- a/coconut/tests/src/cocotest/agnostic/primary_1.coco +++ b/coconut/tests/src/cocotest/agnostic/primary_1.coco @@ -56,10 +56,10 @@ def primary_test_1() -> bool: assert x == 5 x == 6 assert x == 5 - assert r"hello, world" == "hello, world" == "hello," " " "world" + assert r"hello, world" == "hello, world" == "hello," + " " + "world" assert "\n " == """ """ - assert "\\" "\"" == "\\\"" + assert "\\" + "\"" == "\\\"" assert """ """ == "\n\n" @@ -812,9 +812,9 @@ def primary_test_1() -> bool: else: assert False x = 1 - assert f"{x}" f"{x}" == "11" - assert f"{x}" "{x}" == "1{x}" - assert "{x}" f"{x}" == "{x}1" + assert f"{x}" + f"{x}" == "11" + assert f"{x}" + "{x}" == "1{x}" + assert "{x}" + f"{x}" == "{x}1" assert (if False then 1 else 2) == 2 == (if False then 1 else if True then 2 else 3) class metaA(type): def __instancecheck__(cls, inst): @@ -1054,12 +1054,10 @@ def primary_test_1() -> bool: assert False init :: (3,) = (|1, 2, 3|) assert init == (1, 2) - assert "a\"z""a"'"'"z" == 'a"za"z' - assert b"ab" b"cd" == b"abcd" == rb"ab" br"cd" + assert "a\"z"+"a"+'"'+"z" == 'a"za"z' + assert b"ab" + b"cd" == b"abcd" == rb"ab" + br"cd" "a" + "c" = "ac" b"a" + b"c" = b"ac" - "a" "c" = "ac" - b"a" b"c" = b"ac" (1, *xs, 4) = (|1, 2, 3, 4|) assert xs == [2, 3] assert xs `isinstance` list @@ -1146,9 +1144,9 @@ def primary_test_1() -> bool: key = "abc" f"{key}: " + value = "abc: xyz" assert value == "xyz" - f"{key}" ": " + value = "abc: 123" + f"{key}" + ": " + value = "abc: 123" assert value == "123" - "{" f"{key}" ": " + value + "}" = "{abc: aaa}" + "{" + f"{key}" + ": " + value + "}" = "{abc: aaa}" assert value == "aaa" try: 2 @ 3 # type: ignore diff --git a/coconut/tests/src/cocotest/agnostic/primary_2.coco b/coconut/tests/src/cocotest/agnostic/primary_2.coco index ee8ca556b..e95fa4c61 100644 --- a/coconut/tests/src/cocotest/agnostic/primary_2.coco +++ b/coconut/tests/src/cocotest/agnostic/primary_2.coco @@ -313,10 +313,10 @@ def primary_test_2() -> bool: f.is_f = True # type: ignore assert (f ..*> (+)).is_f # type: ignore really_long_var = 10 - assert (...=really_long_var) == (10,) - assert (...=really_long_var, abc="abc") == (10, "abc") - assert (abc="abc", ...=really_long_var) == ("abc", 10) - assert (...=really_long_var).really_long_var == 10 # type: ignore + assert (really_long_var=) == (10,) + assert (really_long_var=, abc="abc") == (10, "abc") + assert (abc="abc", really_long_var=) == ("abc", 10) + assert (really_long_var=).really_long_var == 10 # type: ignore n = [0] assert n[0] == 0 assert_raises(-> m{{1:2,2:3}}, TypeError) @@ -455,6 +455,17 @@ def primary_test_2() -> bool: match def maybe_dup(x, y=x) = (x, y) assert maybe_dup(1) == (1, 1) == maybe_dup(x=1) assert maybe_dup(1, 2) == (1, 2) == maybe_dup(x=1, y=2) + assert min((), default=10) == 10 == max((), default=10) + assert py_min(3, 4) == 3 == py_max(2, 3) + assert len(zip()) == 0 == len(zip_longest()) # type: ignore + assert CoconutWarning `issubclass` Warning + x = y = 2 + assert f"{x + y = }" == "x + y = 4" + assert f""" +"{x}" +""" == '\n"2"\n' + assert f"\{1}" == "\\1" + assert f''' '{1}' ''' == " '1' " with process_map.multiple_sequential_calls(): # type: ignore assert map((+), range(3), range(4)$[:-1], strict=True) |> list == [0, 2, 4] == process_map((+), range(3), range(4)$[:-1], strict=True) |> list # type: ignore diff --git a/coconut/tests/src/cocotest/agnostic/suite.coco b/coconut/tests/src/cocotest/agnostic/suite.coco index 45d96810a..ae651af80 100644 --- a/coconut/tests/src/cocotest/agnostic/suite.coco +++ b/coconut/tests/src/cocotest/agnostic/suite.coco @@ -10,6 +10,7 @@ from .util import operator CONST from .util import operator “ from .util import operator ” from .util import operator ! +from .util import operator <> operator lol operator ++ @@ -116,6 +117,8 @@ def suite_test() -> bool: test_factorial(factorial2) test_factorial(factorial4) test_factorial(factorial5) + test_factorial(factorial6) + test_factorial(factorial7) test_factorial(fact, test_none=False) test_factorial(fact_, test_none=False) test_factorial(factorial, test_none=False) @@ -150,9 +153,11 @@ def suite_test() -> bool: assert -4 == neg_square_u(2) ≠ 4 ∩ 0 ≤ neg_square_u(0) ≤ 0 assert is_null(null1()) assert is_null(null2()) - assert empty() |> depth_1 == 0 == empty() |> depth_2 - assert leaf(5) |> depth_1 == 1 == leaf(5) |> depth_2 - assert node(leaf(2), node(empty(), leaf(3))) |> depth_1 == 3 == node(leaf(2), node(empty(), leaf(3))) |> depth_2 + for depth in (depth_1, depth_2, depth_3): + assert empty() |> depth == 0 # type: ignore + assert leaf(5) |> depth == 1 # type: ignore + assert node(leaf(2), node(empty(), leaf(3))) |> depth == 3 # type: ignore + assert size(node(empty(), leaf(10))) == 1 == size_(node(empty(), leaf(10))) assert maybes(5, square, plus1) == 26 assert maybes(None, square, plus1) is None assert square <| 2 == 4 @@ -435,6 +440,7 @@ def suite_test() -> bool: assert partition([1, 2, 3], 2) |> map$(tuple) |> list == [(1,), (3, 2)] == partition_([1, 2, 3], 2) |> map$(tuple) |> list assert myreduce((+), (1, 2, 3)) == 6 assert recurse_n_times(10000) + assert recurse_n_times_(10000) assert fake_recurse_n_times(10000) a = clsA() assert ((not)..a.true)() is False @@ -533,7 +539,7 @@ def suite_test() -> bool: assert False tv = typed_vector() assert repr(tv) == "typed_vector(x=0, y=0)" - for obj in (factorial, iadd, collatz, recurse_n_times): + for obj in (factorial, iadd, collatz, recurse_n_times, recurse_n_times_): assert obj.__doc__ == "this is a docstring", obj assert list_type((|1,2|)) == "at least 2" assert list_type((|1|)) == "at least 1" @@ -628,8 +634,10 @@ def suite_test() -> bool: assert dt.lam() == dt assert dt.comp() == (dt,) assert dt.N()$[:2] |> list == [(dt, 0), (dt, 1)] == dt.N_()$[:2] |> list - assert map(Ad().ef, range(5)) |> list == range(1, 6) |> list - assert Ad().ef 1 == 2 + assert map(HasDefs().a_def, range(5)) |> list == range(1, 6) |> list + assert HasDefs().a_def 1 == 2 + assert HasDefs().case_def 1 == 0 == HasDefs().case_def_ 1 + assert HasDefs.__annotations__.keys() |> set == {"a_def"}, HasDefs.__annotations__ assert store.plus1 store.one == store.two assert ret_locals()["my_loc"] == 1 assert ret_globals()["my_glob"] == 1 @@ -739,7 +747,8 @@ def suite_test() -> bool: (-> 3) -> _ `isinstance` int = "a" # type: ignore assert empty_it() |> list == [] == empty_it_of_int(1) |> list assert just_it(1) |> list == [1] - assert just_it_of_int(1) |> list == [1] == just_it_of_int_(1) |> list + assert just_it_of_int1(1) |> list == [1] == just_it_of_int2(1) |> list + assert just_it_of_int3(1) |> list == [1] == just_it_of_int4(1) |> list assert must_be_int(4) == 4 == must_be_int_(4) assert typed_plus(1, 2) == 3 (class inh_A() `isinstance` clsA) `isinstance` object = inh_A() @@ -864,7 +873,7 @@ forward 2""") == 900 assert split1_comma(",") == ("", "") assert split1_comma("abcd") == ("abcd", "") assert primes()$[:5] |> tuple == (2, 3, 5, 7, 11) - assert twin_primes()$[:5] |> list == [(3, 5), (5, 7), (11, 13), (17, 19), (29, 31)] + assert twin_primes()$[:5] |> list == [(3, 5), (5, 7), (11, 13), (17, 19), (29, 31)] == twin_primes_()$[:5] |> list assert stored_default(2) == [2, 1] == stored_default_cls()(2) assert stored_default(2) == [2, 1, 2, 1] == stored_default_cls()(2) if sys.version_info >= (3,): # naive namespace classes don't work on py2 @@ -1036,7 +1045,7 @@ forward 2""") == 900 assert (+) `on` (.*2) <*| (3, 5) == 16 assert test_super_B().method({'somekey': 'string', 'someotherkey': 42}) assert outer_func_normal() |> map$(call) |> list == [4] * 5 - for outer_func in (outer_func_1, outer_func_2, outer_func_3, outer_func_4, outer_func_5): + for outer_func in (outer_func_1, outer_func_2, outer_func_3, outer_func_4, outer_func_5, outer_func_6): assert outer_func() |> map$(call) |> list == range(5) |> list assert get_glob() == 0 assert wrong_get_set_glob(10) == 0 @@ -1046,8 +1055,8 @@ forward 2""") == 900 assert InitAndIter(range(3)) |> fmap$((.+1), fallback_to_init=True) == InitAndIter(range(1, 4)) assert_raises(-> InitAndIter(range(3)) |> fmap$(.+1), TypeError) really_long_var = 10 - assert ret_args_kwargs(...=really_long_var) == ((), {"really_long_var": 10}) == ret_args_kwargs$(...=really_long_var)() - assert ret_args_kwargs(123, ...=really_long_var, abc="abc") == ((123,), {"really_long_var": 10, "abc": "abc"}) == ret_args_kwargs$(123, ...=really_long_var, abc="abc")() + assert ret_args_kwargs(really_long_var=) == ((), {"really_long_var": 10}) == ret_args_kwargs$(really_long_var=)() + assert ret_args_kwargs(123, really_long_var=, abc="abc") == ((123,), {"really_long_var": 10, "abc": "abc"}) == ret_args_kwargs$(123, really_long_var=, abc="abc")() assert "Coconut version of typing" in typing.__doc__ numlist: NumList = [1, 2.3, 5] assert hasloc([[1, 2]]).loc[0][1] == 2 == hasloc([[1, 2]]) |> .loc[0][1] @@ -1079,6 +1088,10 @@ forward 2""") == 900 assert first_false_and_last_true([3, 2, 1, 0, "11", "1", ""]) == (0, "1") assert ret_args_kwargs ↤** dict(a=1) == ((), dict(a=1)) assert ret_args_kwargs ↤**? None is None + assert [1, 2, 3] |> reduce_with_init$(+) == 6 == (1, 2, 3) |> iter |> reduce_with_init$((+), init=0) + assert min(1, 2) == 1 == my_min(1, 2) + assert min([1, 2]) == 1 == my_min([1, 2]) + assert 3 <> 4 with process_map.multiple_sequential_calls(): # type: ignore assert process_map(tuple <.. (|>)$(to_sort), qsorts) |> list == [to_sort |> sorted |> tuple] * len(qsorts) diff --git a/coconut/tests/src/cocotest/agnostic/util.coco b/coconut/tests/src/cocotest/agnostic/util.coco index f58003eec..56dbe52c5 100644 --- a/coconut/tests/src/cocotest/agnostic/util.coco +++ b/coconut/tests/src/cocotest/agnostic/util.coco @@ -2,6 +2,7 @@ import sys import random import pickle +import typing import operator # NOQA from contextlib import contextmanager from functools import wraps @@ -10,6 +11,8 @@ from collections import defaultdict, deque __doc__ = "docstring" # Helpers: +___ = typing.cast(typing.Any, ...) + def rand_list(n): '''Generate a random list of length n.''' return [random.randrange(10) for x in range(0, n)] @@ -237,13 +240,18 @@ operator ” (“) = (”) = (,) ..> map$(str) ..> "".join operator ! -addpattern def (int(x))! = 0 if x else 1 # type: ignore +match def (int(x))! = 0 if x else 1 # type: ignore addpattern def (float(x))! = 0.0 if x else 1.0 # type: ignore addpattern def x! if x = False # type: ignore addpattern def x! = True # type: ignore +operator <> +case def <>: + case(x, y if x < y) = True + case(x, y if x > y) = True + case(x, y) = False + # Type aliases: -import typing if sys.version_info >= (3, 5) or TYPE_CHECKING: type list_or_tuple = list | tuple @@ -395,6 +403,11 @@ def recurse_n_times(n) = return True recurse_n_times(n-1) +case def recurse_n_times_: + """this is a docstring""" + case(0) = True + case(n) = recurse_n_times_(n-1) + def is_even(n) = if not n: return True @@ -632,12 +645,26 @@ def factorial5(value): else: return None raise TypeError() +case def factorial6[Num: (int, float)]: + """Factorial function""" + type(n: Num, acc: Num = ___) -> Num + case (0, acc=1): + return acc + case (int(n), acc=1 if n > 0): + return factorial6(n - 1, acc * n) + case (int(n), acc=... if n < 0): + return None +case def factorial7[Num <: int | float]: + type(n: Num, acc: Num = ___) -> Num + case(0, acc=1) = acc + case(int(n), acc=1 if n > 0) = factorial7(n - 1, acc * n) + case(int(n), acc=... if n < 0) = None match def fact(n) = fact(n, 1) match addpattern def fact(0, acc) = acc # type: ignore addpattern match def fact(n, acc) = fact(n-1, acc*n) # type: ignore -addpattern def factorial(0, acc=1) = acc +match def factorial(0, acc=1) = acc addpattern def factorial(int() as n, acc=1 if n > 0) = # type: ignore """this is a docstring""" factorial(n-1, acc*n) @@ -766,6 +793,20 @@ def depth_2(t): match tree(l=l, r=r) in t: return 1 + max([depth_2(l), depth_2(r)]) +case def depth_3: + case(tree()) = 0 + case(tree(n=n)) = 1 + case(tree(l=l, r=r)) = 1 + max(depth_3(l), depth_3(r)) + +def size(empty()) = 0 +addpattern def size(leaf(n)) = 1 # type: ignore +addpattern def size(node(l, r)) = size(l) + size(r) # type: ignore + +case def size_: + case(empty()) = 0 + case(leaf(n)) = 1 + case(node(l, r)) = size_(l) + size_(r) + class Tree data Node(*children) from Tree data Leaf(elem) from Tree @@ -913,12 +954,12 @@ class MySubExc(MyExc): class test_super_A: @classmethod - addpattern def method(cls, {'somekey': str()}) = True + match def method(cls, {'somekey': str()}) = True class test_super_B(test_super_A): @classmethod - addpattern def method(cls, {'someotherkey': int(), **rest}) = + match def method(cls, {'someotherkey': int(), **rest}) = super().method(rest) @@ -1397,12 +1438,24 @@ class descriptor_test: [(self, i)] :: self.N_(i=i+1) -# Function named Ad.ef -class Ad: - ef: typing.Callable +# Annotation checking +class HasDefs: + a_def: typing.Callable + + @staticmethod + case def case_def: + type(_: int) -> int + case(0) = 1 + case(1) = 0 + +def HasDefs.a_def(self, 0) = 1 # type: ignore +addpattern def HasDefs.a_def(self, x) = x + 1 # type: ignore -def Ad.ef(self, 0) = 1 # type: ignore -addpattern def Ad.ef(self, x) = x + 1 # type: ignore +@staticmethod # type: ignore +case def HasDefs.case_def_: # type: ignore + type(_: int) -> int + case(0) = 1 + case(1) = 0 # Storage class @@ -1511,12 +1564,20 @@ yield def just_it(x): yield x yield def empty_it_of_int(int() as x): pass -yield match def just_it_of_int(int() as x): +yield match def just_it_of_int1(int() as x): yield x -match yield def just_it_of_int_(int() as x): +match yield def just_it_of_int2(int() as x): yield x +yield case def just_it_of_int3: + case(int() as x): + yield x + +case yield def just_it_of_int4: + case(int() as x): + yield x + yield def num_it() -> int$[]: yield 5 @@ -1686,7 +1747,7 @@ sum_evens = ( ) -# n-ary reduction +# reduction def binary_reduce(binop, it) = ( it |> reiterable @@ -1703,6 +1764,8 @@ def nary_reduce(n, op, it) = ( binary_reduce_ = nary_reduce$(2) +match def reduce_with_init(f, xs, init=type(xs[0])()) = reduce(f, xs, init) + # last/end import operator @@ -1897,6 +1960,12 @@ def twin_primes(_ :: [p, (.-2) -> p] :: ps) = addpattern def twin_primes() = # type: ignore twin_primes(primes()) +case def twin_primes_: + case(_ :: [p, (.-2) -> p] :: ps) = + [(p, p+2)] :: twin_primes_([p + 2] :: ps) + case() = + twin_primes_(primes()) + # class matching class HasElems: @@ -2044,3 +2113,24 @@ def outer_func_5() -> (() -> int)[]: copyclosure def inner_func() -> int = x funcs.append(inner_func) return funcs + +def outer_func_6(): + funcs = [] + for x in range(5): + copyclosure case def inner_func: + case(y) = y + case() = x + funcs.append(inner_func) + return funcs + + +# case def + +case def my_min[T]: + type(xs: T[]) -> T + case([x]) = x + case([x] + xs) = my_min(x, my_min(xs)) + + type(x: T, y: T) -> T + case(x, y if x <= y) = x + case(x, y) = y diff --git a/coconut/tests/src/cocotest/non_strict/non_strict_test.coco b/coconut/tests/src/cocotest/non_strict/non_strict_test.coco index a21b8a155..5338ea7ed 100644 --- a/coconut/tests/src/cocotest/non_strict/non_strict_test.coco +++ b/coconut/tests/src/cocotest/non_strict/non_strict_test.coco @@ -93,6 +93,28 @@ def non_strict_test() -> bool: @recursive_iterator def fib() = (1, 1) :: map((+), fib(), fib()$[1:]) assert fib()$[:5] |> list == [1, 1, 2, 3, 5] + addpattern def args_or_kwargs(*args) = args + addpattern def args_or_kwargs(**kwargs) = kwargs # type: ignore + assert args_or_kwargs(1, 2) == (1, 2) + very_long_name = 10 + assert args_or_kwargs(short_name=5, very_long_name=) == {"short_name": 5, "very_long_name": 10} + assert "hello," " " "world" == "hello, world" + assert "\\" "\"" == "\\\"" + x = 1 + assert f"{x}" f"{x}" == "11" + assert f"{x}" "{x}" == "1{x}" + assert "{x}" f"{x}" == "{x}1" + assert "a\"z""a"'"'"z" == 'a"za"z' + assert b"ab" b"cd" == b"abcd" == rb"ab" br"cd" + "a" "c" = "ac" + b"a" b"c" = b"ac" + key = "abc" + f"{key}" ": " + value = "abc: 123" + assert value == "123" + "{" f"{key}" ": " + value + "}" = "{abc: aaa}" + assert value == "aaa" + assert """ """\ + == " " return True if __name__ == "__main__": diff --git a/coconut/tests/src/cocotest/target_sys/target_sys_test.coco b/coconut/tests/src/cocotest/target_sys/target_sys_test.coco index c65bc4125..6600aaa50 100644 --- a/coconut/tests/src/cocotest/target_sys/target_sys_test.coco +++ b/coconut/tests/src/cocotest/target_sys/target_sys_test.coco @@ -48,12 +48,16 @@ def asyncio_test() -> bool: async def async_map_0(args): return thread_map(args[0], *args[1:]) - async def async_map_1(args) = thread_map(args[0], *args[1:]) - async def async_map_2([func] + iters) = thread_map(func, *iters) - async match def async_map_3([func] + iters) = thread_map(func, *iters) - match async def async_map_4([func] + iters) = thread_map(func, *iters) + async def async_map_1(args) = map(args[0], *args[1:]) + async def async_map_2([func] + iters) = map(func, *iters) + async match def async_map_3([func] + iters) = map(func, *iters) + match async def async_map_4([func] + iters) = map(func, *iters) + async case def async_map_5: + case([func] + iters) = map(func, *iters) + case async def async_map_6: + case([func] + iters) = map(func, *iters) async def async_map_test() = - for async_map_ in (async_map_0, async_map_1, async_map_2, async_map_3, async_map_4): + for async_map_ in (async_map_0, async_map_1, async_map_2, async_map_3, async_map_4, async_map_5, async_map_6): assert (await ((pow$(2), range(5)) |> async_map_)) |> tuple == (1, 2, 4, 8, 16) True diff --git a/coconut/tests/src/extras.coco b/coconut/tests/src/extras.coco index 0bb22fbde..13c69496f 100644 --- a/coconut/tests/src/extras.coco +++ b/coconut/tests/src/extras.coco @@ -1,4 +1,5 @@ import os +import sys from collections.abc import Sequence os.environ["COCONUT_USE_COLOR"] = "False" @@ -14,6 +15,7 @@ from coconut.constants import ( PYPY, ) # type: ignore from coconut._pyparsing import USE_COMPUTATION_GRAPH # type: ignore +from coconut.terminal import logger from coconut.exceptions import ( CoconutSyntaxError, CoconutStyleError, @@ -32,7 +34,7 @@ from coconut.convenience import ( ) -def assert_raises(c, Exc, not_Exc=None, err_has=None): +def assert_raises(c, Exc, not_Exc=None, err_has=None) -> None: """Test whether callable c raises an exception of type Exc.""" if not_Exc is None and Exc is CoconutSyntaxError: not_Exc = CoconutParseError @@ -183,6 +185,15 @@ parsing failed for format string expression: 1+ (line 2) ^ """.strip()) + assert_raises(-> parse(""" +case def f[T]: + case(x) = x +""".strip()), CoconutSyntaxError) + assert_raises(-> parse(""" +case def f[T]: + type(x: T) -> T +""".strip()), CoconutSyntaxError) + assert_raises(-> parse("(|*?>)"), CoconutSyntaxError, err_has="'|?*>'") assert_raises(-> parse("(|**?>)"), CoconutSyntaxError, err_has="'|?**>'") assert_raises(-> parse("( parse("range(1,10) |> reduce$(*, initializer = 1000) |> print"), CoconutParseError, err_has=( "\n \\~~^", "\n \\~~~~~~~~~~~~~~~~~~~~~~~^", + ) + ( + ("\n \\~~~~~~~~~~~~^",) + if PYPY else () )) assert_raises(-> parse("a := b"), CoconutParseError, err_has=( "\n ^", @@ -263,7 +277,6 @@ def f() = assert_raises(-> parse('''f"""{ }"""'''), CoconutSyntaxError, err_has="parsing failed for format string expression") - assert_raises(-> parse("return = 1"), CoconutParseError, err_has='invalid use of the keyword "return"') assert_raises(-> parse("if a = b: pass"), CoconutParseError, err_has="misplaced assignment") assert_raises(-> parse("while a == b"), CoconutParseError, err_has="misplaced newline") @@ -313,7 +326,6 @@ def g(x) = x assert parse("def f(x):\n ${var}", "xonsh") == "def f(x):\n ${var}\n" assert "data ABC" not in parse("data ABC:\n ${var}", "xonsh") - assert parse('"abc" "xyz"', "lenient") == "'abcxyz'" assert "builder" not in parse("def x -> x", "lenient") assert parse("def x -> x", "lenient").count("def") == 1 assert "builder" in parse("x -> def y -> (x, y)", "lenient") @@ -325,7 +337,8 @@ def g(x) = x return True -def test_convenience() -> bool: +def test_api() -> bool: + assert not logger.enable_colors(sys.stdout) if IPY: import coconut.highlighter # noqa # type: ignore @@ -368,6 +381,9 @@ line 6''') assert_raises(-> parse("match def kwd_only_x_is_int_def_0(*, x is int = 0) = x"), CoconutStyleError, err_has=( "\n ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~|", "\n ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~/", + ) + ( + ("\n ^",) + if PYPY else () )) try: parse(""" @@ -398,25 +414,28 @@ import abc except CoconutStyleError as err: assert str(err) == """found unused import 'abc' (add '# NOQA' to suppress) (remove --strict to downgrade to a warning) (line 1) import abc""" - assert_raises(-> parse("""class A(object): - 1 - 2 - 3 - 4 - 5 - 6 - 7 - 8 - 9 - 10 - 11 - 12 - 13 - 14 - 15"""), CoconutStyleError, err_has="\n ...\n") + assert_raises(-> parse(""" +class A(object): + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 + 11 + 12 + 13 + 14 + 15 + """.strip()), CoconutStyleError, **(dict(err_has="\n ...\n") if not PYPY else {})) + assert_raises(-> parse('["abc", "def" "ghi"]'), CoconutStyleError, err_has="implicit string concatenation") setup(line_numbers=False, strict=True, target="sys") - assert_raises(-> parse("await f x"), CoconutParseError, err_has='invalid use of the keyword "await"') + assert_raises(-> parse("await f x"), CoconutParseError) setup(line_numbers=False, target="2.7") assert parse("from io import BytesIO", mode="lenient") == "from io import BytesIO" @@ -467,15 +486,16 @@ async def async_map_test() = assert parse("a[x, *y]") setup(line_numbers=False, target="3.12") - assert parse("type Num = int | float").strip().endswith(""" -# Compiled Coconut: ----------------------------------------------------------- - -type Num = int | float""".strip()) - assert parse("type L[T] = list[T]").strip().endswith(""" -# Compiled Coconut: ----------------------------------------------------------- - -_coconut_typevar_T_0 = _coconut.typing.TypeVar("_coconut_typevar_T_0") -type L = list[_coconut_typevar_T_0]""".strip()) + assert parse("type Num = int | float", "lenient").strip() == """ +type Num = int | float +""".strip() + assert parse("type L[T] = list[T]", "lenient").strip() == """ +type L[T] = list[T] +""".strip() + assert parse("def f[T](x) = x", "lenient") == """ +def f[T](x): + return x +""".strip() setup(line_numbers=False, minify=True) assert parse("123 # derp", "lenient") == "123# derp" @@ -509,7 +529,7 @@ class F: def test_kernel() -> bool: # hide imports so as to not enable incremental parsing until we want to - if PY35: + if PY35 or TYPE_CHECKING: import asyncio from coconut.icoconut import CoconutKernel # type: ignore from jupyter_client.session import Session @@ -531,7 +551,7 @@ def test_kernel() -> bool: k = CoconutKernel() fake_session = FakeSession() assert k.shell is not None - k.shell.displayhook.session = fake_session + k.shell.displayhook.session = fake_session # type: ignore exec_result = k.do_execute("derp = pow$(?, 2)", False, True, {"two": "(+)(1, 1)"}, True) |> unwrap_future$(loop) assert exec_result["status"] == "ok", exec_result @@ -734,7 +754,7 @@ def test_extras() -> bool: print(".") # newline bc we print stuff after this assert test_setup_none() is True # ... print(".") # ditto - assert test_convenience() is True # .... + assert test_api() is True # .... # everything after here uses incremental parsing, so it must come last print(".", end="") assert test_incremental() is True # ..... diff --git a/coconut/util.py b/coconut/util.py index f9f4905d0..51b8abc3c 100644 --- a/coconut/util.py +++ b/coconut/util.py @@ -291,6 +291,19 @@ def assert_remove_prefix(inputstr, prefix, allow_no_prefix=False): remove_prefix = partial(assert_remove_prefix, allow_no_prefix=True) +def assert_remove_suffix(inputstr, suffix, allow_no_suffix=False): + """Remove prefix asserting that inputstr starts with it.""" + assert suffix, suffix + if not allow_no_suffix: + assert inputstr.endswith(suffix), inputstr + elif not inputstr.endswith(suffix): + return inputstr + return inputstr[:-len(suffix)] + + +remove_suffix = partial(assert_remove_suffix, allow_no_suffix=True) + + def ensure_dir(dirpath, logger=None): """Ensure that a directory exists.""" if not os.path.exists(dirpath):