Skip to content

Commit

Permalink
[ngcodegen][objectmodel] add model generator
Browse files Browse the repository at this point in the history
  • Loading branch information
apalala committed Nov 29, 2023
1 parent 18462a5 commit e38a3dc
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 37 deletions.
89 changes: 53 additions & 36 deletions tatsu/ngcodegen/objectmodel.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import inspect
import re
from collections import namedtuple

from .. import grammars, objectmodel
from ..exceptions import CodegenError
from ..mixins.indent import IndentPrintMixin
from ..util import safe_name, compress_seq
from ..walkers import NodeWalker
Expand Down Expand Up @@ -45,23 +42,22 @@ def __init__(self, context=None, types=None):
"""


TypeSpec = namedtuple('TypeSpec', ['class_name', 'base'])
BaseClassSpec = namedtuple('TypeSpec', ['class_name', 'base'])


def codegen(model: grammars.Model, parser_name: str = '', base_type: type = objectmodel.Node) -> str:
def modelgen(model: grammars.Model, parser_name: str = '', base_type: type = objectmodel.Node) -> str:
generator = PythonModelGenerator(parser_name=parser_name, base_type=base_type)
generator.walk(model)
return generator.printed_text()
return generator.generate_model(model)


class PythonModelGenerator(IndentPrintMixin, NodeWalker):
class PythonModelGenerator(IndentPrintMixin):

def __init__(self, parser_name: str = '', base_type: type = objectmodel.Node):
super().__init__()
self.base_type = base_type
self.parser_name = parser_name or None

def walk_Grammar(self, grammar: grammars.Grammar):
def generate_model(self, grammar: grammars.Grammar):
base_type_qual = self.base_type.__module__
base_type_import = f'from {self.base_type.__module__} import {self.base_type.__name__.split('.')[-1]}'

Expand All @@ -78,17 +74,55 @@ def walk_Grammar(self, grammar: grammars.Grammar):

rule_index = {rule.name: rule for rule in grammar.rules}
rule_specs = {
rule.name: self._type_specs(rule)
rule.name: self._base_class_specs(rule)
for rule in grammar.rules
}

model_classes = {s.class_name for spec in rule_specs.values() for s in spec}
base_classes = {s.base for spec in rule_specs.values() for s in spec}
base_classes -= model_classes
# raise Exception('HERE', base_classes, model_classes)

for base_name in base_classes:
if base_name in rule_specs:
self._gen_base_class(rule_specs[base_name])

for model_name, rule in rule_index.items():
if model_name in rule_index:
self._gen_rule_class(
rule,
rule_specs[model_name],
)

return self.printed_text()

def _gen_base_class(self, spec: BaseClassSpec):
self.print()
self.print()
if spec.base:
self.print(f'class {spec.class_name}({spec.base}):')
else:
self.print(f'class {spec.class_name}:')
with self.indent():
self.print('pass')

def _gen_rule_class(self, rule: grammars.Rule, specs: list[BaseClassSpec]):
if not specs:
return
spec = specs[0]
arguments = sorted({safe_name(d) for d, _ in compress_seq(rule.defines())})

self.print()
self.print()
self.print('@dataclass(eq=False)')
self.print(f'class {spec.class_name}({spec.base}):')
with self.indent():
if not arguments:
self.print('pass')
for arg in arguments:
self.print(f'{arg}: Any = None')

def walk_Rule(self, rule: grammars.Rule):
specs = self._type_specs(rule)
specs = self._base_class_specs(rule)
if not specs:
return

Expand All @@ -108,30 +142,13 @@ def walk_Rule(self, rule: grammars.Rule):
for arg in arguments:
self.print(f'{arg}: Any = None')

def _type_specs(self, rule: grammars.Rule) -> TypeSpec:
if not self._get_node_class_name(rule):
return []
def _base_class_specs(self, rule: grammars.Rule) -> BaseClassSpec:
if not rule.params:
return ()

spec = rule.params[0].split('::')
class_names = [safe_name(n) for n in spec] + [f'{self.parser_name}ModelBase']

typespec = []
for i, class_name in enumerate(class_names[:-1]):
base = class_names[i + 1]
typespec.append(TypeSpec(class_name, base))

return typespec

@staticmethod
def _get_node_class_name(rule: grammars.Rule):
if not rule.params:
return None

node_names = rule.params[0]
if not isinstance(node_names, str):
return None
if not re.match(r'\w+(?:::\w+)*', node_names):
return None
if not node_names[0].isupper():
return None
return node_names
return tuple(
BaseClassSpec(class_name, class_names[i + 1])
for i, class_name in enumerate(class_names[:-1])
)
2 changes: 1 addition & 1 deletion tatsu/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def main():
elif args.ng_parser:
result = ngpythoncg(model)
elif args.ng_model:
result = ngobjectmodel.codegen(model, parser_name=args.name)
result = ngobjectmodel.modelgen(model, parser_name=args.name)
else:
result = pythoncg(model)

Expand Down

0 comments on commit e38a3dc

Please sign in to comment.