Skip to content

Commit

Permalink
[ngcodegen][model] refactor and optimize
Browse files Browse the repository at this point in the history
  • Loading branch information
apalala committed Dec 10, 2023
1 parent 224612c commit c48c19f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 22 deletions.
22 changes: 11 additions & 11 deletions tatsu/ngcodegen/objectmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .. import grammars, objectmodel
from ..mixins.indent import IndentPrintMixin
from ..util import compress_seq, safe_name
from ..util.misc import topological_sort
from ..util.misc import topsort

HEADER = """\
#!/usr/bin/env python3
Expand Down Expand Up @@ -74,39 +74,39 @@ def generate_model(self, grammar: grammars.Grammar):
}
rule_specs = {name: specs for name, specs in rule_specs.items() if specs}

all_base_spec = {
specs_by_name = {
s.class_name: s.base
for specs in rule_specs.values()
for s in specs
}
base = self._model_base_class_name()
all_base_spec[base] = base_type_name
base = self._model_base_name()
specs_by_name[base] = base_type_name

all_model_names = list(reversed(all_base_spec.keys()))
all_specs = {
(s.class_name, s.base)
for specs in rule_specs.values()
for s in specs
}
model_names = topsort(reversed(specs_by_name), all_specs)

all_model_names = topological_sort(all_model_names, all_specs)
model_to_rule = {
rule_specs[name][0].class_name: rule
for name, rule in rule_index.items()
if name in rule_specs
}

for model_name in all_model_names:
if model_name in dir(builtins):
for model_name in model_names:
if model_name in vars(builtins):
continue
if rule := model_to_rule.get(model_name):
self._gen_rule_class(rule, rule_specs[rule.name])
else:
self._gen_base_class(model_name, all_base_spec.get(model_name))
self._gen_base_class(model_name, specs_by_name.get(model_name))

return self.printed_text()

def _model_base_class_name(self):
@staticmethod
def _model_base_name():
return 'ModelBase'

def _gen_base_class(self, class_name: str, base: str | None):
Expand Down Expand Up @@ -139,7 +139,7 @@ def _gen_rule_class(self, rule: grammars.Rule, specs: list[BaseClassSpec]):

def _base_class_specs(self, rule: grammars.Rule) -> list[BaseClassSpec]:
spec = rule.params[0].split('::') if rule.params else []
base = [self._model_base_class_name()]
base = [self._model_base_name()]
class_names = [safe_name(n) for n in spec] + base
return [
BaseClassSpec(class_name, class_names[i + 1])
Expand Down
21 changes: 10 additions & 11 deletions tatsu/util/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,29 +85,28 @@ def findfirst(pattern, string, pos=None, endpos=None, flags=0, default=_undefine
)


def topological_sort(nodes: Iterable[_T], order: Iterable[tuple[_T, _T]]) -> list[_T]:
def topsort(nodes: Iterable[_T], order: Iterable[tuple[_T, _T]]) -> list[_T]:
# https://en.wikipedia.org/wiki/Topological_sorting

order = set(order)
result: list[_T] = [] # Empty list that will contain the sorted elements

pending = [ # Set of all nodes with no incoming edge
n for n in nodes
if not any(x for (x, y) in order if y == n)
def with_incoming():
return {m for (_, m) in order}

pending = [ # Set of all nodes with no incoming edges
n for n in nodes if n not in with_incoming()
]
while pending:
n = pending.pop()
result.insert(0, n)

# nodes m with an edge from n to m
outgoing = {m for (x, m) in order if x == n}
# node m with an edge e from n to m
for m in outgoing:
order.remove((n, m))
if not any(x for x, y in order if y == m):
# m has no other incoming edges
pending.append(m)
order -= {(n, m) for m in outgoing}
pending.extend(outgoing - with_incoming())

if order:
raise ValueError('There are cycles in the graph')
raise ValueError('There are cycles in the topological order')

return result # a topologically sorted list

0 comments on commit c48c19f

Please sign in to comment.