diff --git a/rewrite/mypy.ini b/rewrite/mypy.ini index 00216c34..c805eaba 100644 --- a/rewrite/mypy.ini +++ b/rewrite/mypy.ini @@ -5,3 +5,6 @@ warn_unused_ignores = true warn_return_any = true warn_unreachable = true disallow_untyped_calls = false +disallow_untyped_defs = false +# to disable relative imports error +disable_error_code = misc diff --git a/rewrite/rewrite/java/remote/__init__.py b/rewrite/rewrite/java/remote/__init__.py index c22cfe26..86335464 100644 --- a/rewrite/rewrite/java/remote/__init__.py +++ b/rewrite/rewrite/java/remote/__init__.py @@ -1,5 +1,12 @@ __path__ = __import__('pkgutil').extend_path(__path__, __name__) +from typing import TypeVar + from .extensions import * from .receiver import * from .sender import * + +__all__ = [name for name in dir() if not name.startswith('_') and not isinstance(globals()[name], TypeVar)] + +from .register import register_codecs +register_codecs() diff --git a/rewrite/rewrite/java/remote/register.py b/rewrite/rewrite/java/remote/register.py new file mode 100644 index 00000000..7baeb6e0 --- /dev/null +++ b/rewrite/rewrite/java/remote/register.py @@ -0,0 +1,260 @@ +from typing import Optional, List, Type, Any + +from _cbor2 import break_marker +from cbor2 import CBOREncoder, CBORDecoder +from rewrite_remote import RemotingContext, SerializationContext, DeserializationContext + +from ..tree import JavaType + + +def register_codecs(): + # SenderContext.register(J, JavaSender) + # ReceiverContext.register(J, JavaReceiver) + + RemotingContext.register_value_deserializer( + 'org.openrewrite.java.tree.JavaType$Primitive', + deserialize_java_primitive + ) + RemotingContext.register_value_serializer( + JavaType.Primitive, + serialize_java_primitive + ) + RemotingContext.register_value_deserializer( + 'org.openrewrite.java.tree.JavaType$Class', + deserialize_java_class + ) + RemotingContext.register_value_serializer( + JavaType.Class, + serialize_java_class + ) + RemotingContext.register_value_deserializer( + 'org.openrewrite.java.tree.JavaType$Method', + deserialize_java_method + ) + RemotingContext.register_value_serializer( + JavaType.Method, + serialize_java_method + ) + RemotingContext.register_value_deserializer( + 'org.openrewrite.java.tree.JavaType$Variable', + deserialize_java_variable + ) + RemotingContext.register_value_serializer( + JavaType.Variable, + serialize_java_variable + ) + RemotingContext.register_value_deserializer( + 'org.openrewrite.java.tree.JavaType$Array', + deserialize_java_array + ) + RemotingContext.register_value_serializer( + JavaType.Array, + serialize_java_array + ) + RemotingContext.register_value_deserializer( + 'org.openrewrite.java.tree.JavaType$Parameterized', + deserialize_java_parameterized + ) + RemotingContext.register_value_serializer( + JavaType.Parameterized, + serialize_java_parameterized + ) + RemotingContext.register_value_deserializer( + 'org.openrewrite.java.tree.JavaType$GenericTypeVariable', + deserialize_java_generic_type_variable + ) + RemotingContext.register_value_serializer( + JavaType.GenericTypeVariable, + serialize_java_generic_type_variable + ) + RemotingContext.register_value_deserializer( + 'org.openrewrite.java.tree.JavaType$Unknown', + deserialize_java_unknown + ) + RemotingContext.register_value_serializer( + JavaType.Unknown, + serialize_java_unknown + ) + + +def deserialize_java_class(type_: str, decoder: CBORDecoder, context: DeserializationContext) -> JavaType.Class: + cls = JavaType.ShallowClass() if type_ == 'org.openrewrite.java.tree.JavaType$ShallowClass' else JavaType.Class() + while not (key := decoder.decode()) == break_marker: + if key == '@ref': + context.remoting_context.add_by_id(decoder.decode(), cls) + elif key == 'flagsBitMap': + setattr(cls, '_flags_bit_map', decoder.decode()) + elif key == 'fullyQualifiedName': + name = decoder.decode() + setattr(cls, '_fully_qualified_name', name) + elif key == 'kind': + setattr(cls, '_kind', context.deserialize(JavaType.FullyQualified.Kind, decoder)) + elif key == 'typeParameters': + setattr(cls, '_type_parameters', context.deserialize(List[JavaType], decoder)) + elif key == 'supertype': + setattr(cls, '_supertype', context.deserialize(JavaType.FullyQualified, decoder)) + elif key == 'owningClass': + setattr(cls, '_owning_class', context.deserialize(JavaType.FullyQualified, decoder)) + elif key == 'annotations': + setattr(cls, '_annotations', context.deserialize(List[JavaType.FullyQualified], decoder)) + elif key == 'interfaces': + setattr(cls, '_interfaces', context.deserialize(List[JavaType.FullyQualified], decoder)) + elif key == 'members': + setattr(cls, '_members', context.deserialize(List[JavaType.Variable], decoder)) + elif key == 'methods': + setattr(cls, '_methods', context.deserialize(List[JavaType.Method], decoder)) + return cls + + +def deserialize_java_method(_: str, decoder: CBORDecoder, context: DeserializationContext) -> JavaType.Method: + method = JavaType.Method() + while not (key := decoder.decode()) == break_marker: + if key == '@ref': + context.remoting_context.add_by_id(decoder.decode(), method) + elif key == 'flagsBitMap': + setattr(method, '_flags_bit_map', decoder.decode()) + elif key == 'declaringType': + setattr(method, '_declaring_type', context.deserialize(JavaType.FullyQualified, decoder)) + elif key == 'name': + setattr(method, '_name', decoder.decode()) + elif key == 'returnType': + setattr(method, '_return_type', context.deserialize(JavaType, decoder)) + elif key == 'parameterNames': + setattr(method, '_parameter_names', context.deserialize(List[str], decoder)) + elif key == 'parameterTypes': + setattr(method, '_parameter_types', context.deserialize(List[JavaType], decoder)) + elif key == 'thrownExceptions': + setattr(method, '_thrown_exceptions', context.deserialize(List[JavaType.FullyQualified], decoder)) + elif key == 'annotations': + setattr(method, '_annotations', context.deserialize(List[JavaType.FullyQualified], decoder)) + elif key == 'defaultValue': + setattr(method, '_default_value', context.deserialize(List[str], decoder)) + elif key == 'declaredFormalTypeNames': + setattr(method, '_declared_formal_type_names', context.deserialize(List[str], decoder)) + return method + + +def deserialize_java_variable(_: str, decoder: CBORDecoder, context: DeserializationContext) -> JavaType.Variable: + variable = JavaType.Variable() + while not (key := decoder.decode()) == break_marker: + if key == '@ref': + context.remoting_context.add_by_id(decoder.decode(), variable) + elif key == 'flagsBitMap': + setattr(variable, '_flags_bit_map', decoder.decode()) + elif key == 'name': + setattr(variable, '_name', decoder.decode()) + elif key == 'owner': + setattr(variable, '_owner', context.deserialize(JavaType, decoder)) + elif key == 'type': + setattr(variable, '_type', context.deserialize(JavaType, decoder)) + elif key == 'annotations': + setattr(variable, '_annotations', context.deserialize(List[JavaType.FullyQualified], decoder)) + return variable + + +def deserialize_java_array(_: str, decoder: CBORDecoder, context: DeserializationContext) -> JavaType.Array: + array = JavaType.Array() + while not (key := decoder.decode()) == break_marker: + if key == '@ref': + context.remoting_context.add_by_id(decoder.decode(), array) + elif key == 'elemType': + setattr(array, '_elem_type', context.deserialize(JavaType, decoder)) + elif key == 'annotations': + setattr(array, '_annotations', context.deserialize(List[JavaType.FullyQualified], decoder)) + return array + + +def deserialize_java_parameterized(_: str, decoder: CBORDecoder, + context: DeserializationContext) -> JavaType.Parameterized: + param = JavaType.Parameterized() + while not (key := decoder.decode()) == break_marker: + if key == '@ref': + context.remoting_context.add_by_id(decoder.decode(), param) + elif key == 'type': + setattr(param, '_type', context.deserialize(JavaType.FullyQualified, decoder)) + elif key == 'typeParameters': + setattr(param, '_type_parameters', context.deserialize(List[JavaType], decoder)) + return param + + +def deserialize_java_generic_type_variable(_: str, decoder: CBORDecoder, + context: DeserializationContext) -> JavaType.GenericTypeVariable: + type_variable = JavaType.GenericTypeVariable() + while not (key := decoder.decode()) == break_marker: + if key == '@ref': + context.remoting_context.add_by_id(decoder.decode(), type_variable) + elif key == 'name': + setattr(type_variable, '_name', decoder.decode()) + elif key == 'variance': + setattr(type_variable, '_variance', context.deserialize(JavaType.GenericTypeVariable.Variance, decoder)) + elif key == 'bounds': + setattr(type_variable, '_bounds', context.deserialize(List[JavaType], decoder)) + return type_variable + + +def deserialize_java_primitive(_: str, decoder: CBORDecoder, context: DeserializationContext) -> JavaType.Primitive: + kind = decoder.decode() + assert decoder.decode() == break_marker + return JavaType.Primitive(kind) + + +def deserialize_java_unknown(_: str, decoder: CBORDecoder, context: DeserializationContext) -> JavaType.Unknown: + unknown = JavaType.Unknown() + while not (key := decoder.decode()) == break_marker: + if key == '@ref': + context.remoting_context.add_by_id(decoder.decode(), unknown) + else: + decoder.decode() + return unknown + + +def serialize_java_primitive(value: JavaType.Primitive, type_name: Optional[str], encoder: CBOREncoder, + context: SerializationContext) -> None: + encoder.encode(['org.openrewrite.java.tree.JavaType$Primitive', value.value]) + + +def serialize_java_class(value: JavaType.Class, type_name: Optional[str], encoder: CBOREncoder, + context: SerializationContext) -> None: + pass + + +def serialize_java_method(value: JavaType.Method, type_name: Optional[str], encoder: CBOREncoder, + context: SerializationContext) -> None: + pass + + +def serialize_java_variable(value: JavaType.Variable, type_name: Optional[str], encoder: CBOREncoder, + context: SerializationContext) -> None: + pass + + +def serialize_java_array(value: JavaType.Array, type_name: Optional[str], encoder: CBOREncoder, + context: SerializationContext) -> None: + pass + + +def serialize_java_parameterized(value: JavaType.Parameterized, type_name: Optional[str], encoder: CBOREncoder, + context: SerializationContext) -> None: + pass + + +def serialize_java_generic_type_variable(value: JavaType.GenericTypeVariable, type_name: Optional[str], + encoder: CBOREncoder, context: SerializationContext) -> None: + pass + + +def serialize_java_union(value: JavaType.GenericTypeVariable, type_name: Optional[str], encoder: CBOREncoder, + context: SerializationContext) -> None: + pass + + +def serialize_java_unknown(value: JavaType.GenericTypeVariable, type_name: Optional[str], encoder: CBOREncoder, + context: SerializationContext) -> None: + id = context.remoting_context.try_get_id(value) + if id is not None: + encoder.encode(id) + return + encoder.encode({ + '@c': 'org.openrewrite.java.tree.JavaType$Unknown', + '@ref': context.remoting_context.add(value) + }) diff --git a/rewrite/rewrite/java/support_types.py b/rewrite/rewrite/java/support_types.py index accff282..e2197b1b 100644 --- a/rewrite/rewrite/java/support_types.py +++ b/rewrite/rewrite/java/support_types.py @@ -127,7 +127,6 @@ def format_first_prefix(cls, trees: List[J2], prefix: Space) -> List[J2]: return formatted_trees return trees - EMPTY: ClassVar[Space] SINGLE_SPACE: ClassVar[Space] @@ -323,23 +322,54 @@ class MethodCall(Expression): class JavaType(ABC): class FullyQualified: + class Kind(Enum): + Class = 0 + Enum = 1 + Interface = 2 + Annotation = 3 + Record = 4 + + class Unknown(FullyQualified): pass - class Class: + class Class(FullyQualified): pass - class Parameterized: + class ShallowClass(Class): pass - class Primitive: + class Parameterized(FullyQualified): pass + class GenericTypeVariable: + class Variance(Enum): + Invariant = 0 + Covariant = 1 + Contravariant = 2 + + class Primitive(Enum): + Boolean = 0 + Byte = 1 + Char = 2 + Double = 3 + Float = 4 + Int = 5 + Long = 6 + Short = 7 + Void = 8 + String = 9 + None_ = 10 + Null = 11 + class Method: pass class Variable: pass + class Array: + pass + T = TypeVar('T') J2 = TypeVar('J2', bound=J)