From f9d3fae5c6624b220336fb66fbed870185c6cf88 Mon Sep 17 00:00:00 2001 From: gofrendi Date: Sat, 30 Nov 2024 11:59:00 +0700 Subject: [PATCH] Add codemod function --- src/zrb/util/codemod/add_parent_to_class.py | 38 +++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 src/zrb/util/codemod/add_parent_to_class.py diff --git a/src/zrb/util/codemod/add_parent_to_class.py b/src/zrb/util/codemod/add_parent_to_class.py new file mode 100644 index 00000000..94aee6a2 --- /dev/null +++ b/src/zrb/util/codemod/add_parent_to_class.py @@ -0,0 +1,38 @@ +import libcst as cst + + +class ParentClassAdder(cst.CSTTransformer): + def __init__(self, class_name: str, parent_class_name: str): + self.class_name = class_name + self.parent_class_name = parent_class_name + self.class_found = False + + def leave_ClassDef( + self, original_node: cst.ClassDef, updated_node: cst.ClassDef + ) -> cst.ClassDef: + # Check if this is the target class + if original_node.name.value == self.class_name: + self.class_found = True + # Add the parent class to the existing bases + new_bases = ( + *updated_node.bases, + cst.Arg(value=cst.Name(self.parent_class_name)), + ) + return updated_node.with_changes(bases=new_bases) + return updated_node + + +def add_parent_to_class( + original_code: str, class_name: str, parent_class_name: str +) -> str: + # Parse the original code into a module + module = cst.parse_module(original_code) + # Initialize transformer with the class name and parent class name + transformer = ParentClassAdder(class_name, parent_class_name) + # Apply the transformation + modified_module = module.visit(transformer) + # Check if the class was found + if not transformer.class_found: + raise ValueError(f"Class {class_name} not found in the provided code.") + # Return the modified code + return modified_module.code