Skip to content

Commit

Permalink
add for checker for cmd ids orders and union types
Browse files Browse the repository at this point in the history
  • Loading branch information
wekesa360 committed Nov 27, 2023
1 parent aaf01d1 commit 4d9526d
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 0 deletions.
51 changes: 51 additions & 0 deletions projects/jdwp/tools/check_command_ids_order.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from argparse import ArgumentParser
import os
from projects.jdwp.defs.schema import CommandSet


def check_command_ids(command_set: CommandSet) -> None:
"""Verifies if command IDs in a command set are in ascending order."""
sorted_command_ids = [command.id for command in command_set.commands]
if sorted_command_ids != sorted(sorted_command_ids):
raise ValueError(
f"Command IDs in {command_set.name} are NOT in ascending order."
)


def handle_file_error(file_path: str, e: Exception) -> None:
"""Handles errors related to loading a command set file."""
print(f"Error loading file: {file_path} ({e})")
exit(1)


def main() -> None:
parser = ArgumentParser(description="Check command IDs in command sets")
parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
parser.add_argument(
"command_sets_dir", type=str, help="Directory containing command set files"
)
args = parser.parse_args()

command_sets_dir = os.path.abspath(args.command_sets_dir)
command_set_files = [f for f in os.listdir(command_sets_dir) if f.endswith(".py")]

for file in command_set_files:
try:
module_name = os.path.splitext(file)[0]
module_name_title = module_name.replace("_", " ").title().replace(" ", "")
module = __import__(
f"projects.jdwp.defs.command_sets.{module_name}", fromlist=[module_name]
)

if args.verbose:
print(f"Checking {module_name_title}...")

if any(isinstance(obj, CommandSet) for obj in vars(module).values()):
command_set = getattr(module, module_name_title)
check_command_ids(command_set)
except Exception as e:
handle_file_error(file, e)


if __name__ == "__main__":
main()
107 changes: 107 additions & 0 deletions projects/jdwp/tools/check_union_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import argparse
import os
from projects.jdwp.defs.constants import ModifierKind
from projects.jdwp.defs.command_sets.event_request import (
CountModifier,
ConditionalModifier,
ThreadOnlyModifier,
ClassOnlyModifier,
ClassMatchModifier,
ClassExcludeModifier,
StepModifier,
LocationOnlyModifier,
ExceptionOnlyModifier,
FieldOnlyModifier,
InstanceOnlyModifier,
SourceNameMatchModifier,
)
from projects.jdwp.defs.schema import CommandSet, TaggedUnion, UnionTag, Array, Struct


def expected_case_value_for_modifier_kind(modifier_kind: ModifierKind) -> Struct:
if modifier_kind == ModifierKind.COUNT:
return CountModifier
elif modifier_kind == ModifierKind.CONDITIONAL:
return ConditionalModifier
elif modifier_kind == ModifierKind.THREAD_ONLY:
return ThreadOnlyModifier
elif modifier_kind == ModifierKind.CLASS_ONLY:
return ClassOnlyModifier
elif modifier_kind == ModifierKind.CLASS_MATCH:
return ClassMatchModifier
elif modifier_kind == ModifierKind.CLASS_EXCLUDE:
return ClassExcludeModifier
elif modifier_kind == ModifierKind.STEP:
return StepModifier
elif modifier_kind == ModifierKind.LOCATION_ONLY:
return LocationOnlyModifier
elif modifier_kind == ModifierKind.EXCEPTION_ONLY:
return ExceptionOnlyModifier
elif modifier_kind == ModifierKind.FIELD_ONLY:
return FieldOnlyModifier
elif modifier_kind == ModifierKind.INSTANCE_ONLY:
return InstanceOnlyModifier
elif modifier_kind == ModifierKind.SOURCE_NAME_MATCH:
return SourceNameMatchModifier
else:
raise ValueError(f"Unknown modifierKind: {modifier_kind}")


def check_tagged_union_mappings(command_set: CommandSet) -> None:
for command in command_set.commands:
if command.out is None:
continue

for field in command.out.fields:
mod_kind_value = None

if isinstance(field.type, Array):
for element in field.type.element_type.fields:
if isinstance(element.type, UnionTag):
mod_kind_value = element.type.value
if isinstance(element.type, TaggedUnion):
union_tag: UnionTag = element.type.tag
union_cases: dict = element.type.cases
if union_tag.type.value != mod_kind_value:
print(
f"Error in command '{command.name}': Invalid union tag for field '{field.name}'."
)
exit(1)
for case_value, case_type in union_cases.items():
expected_case_value = expected_case_value_for_modifier_kind(
case_value
)
if expected_case_value != case_type:
print(
f"Error in command '{command.name}': Invalid case type for field '{field.name}'."
)
exit(1)


def main() -> None:
parser = argparse.ArgumentParser(
description="Check tagged union mappings in command sets"
)
parser.add_argument(
"command_sets_dir", type=str, help="Directory containing command set files"
)
args = parser.parse_args()

command_sets_dir: str = args.command_sets_dir
for file in os.listdir(command_sets_dir):
if file.endswith(".py"):
module_name: str = os.path.splitext(file)[0]
module_name_title: str = (
module_name.replace("_", " ").title().replace(" ", "")
)
module = __import__(
f"projects.jdwp.defs.command_sets.{module_name}", fromlist=[module_name]
)

command_set: CommandSet = getattr(module, module_name_title)
if isinstance(command_set, CommandSet):
check_tagged_union_mappings(command_set)


if __name__ == "__main__":
main()

0 comments on commit 4d9526d

Please sign in to comment.