Skip to content

Commit

Permalink
Start implementing JavaType serde logic
Browse files Browse the repository at this point in the history
  • Loading branch information
knutwannheden committed Dec 31, 2024
1 parent e745d45 commit 4483ec5
Show file tree
Hide file tree
Showing 4 changed files with 304 additions and 4 deletions.
3 changes: 3 additions & 0 deletions rewrite/mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ warn_unused_ignores = true
warn_return_any = true
warn_unreachable = true
disallow_untyped_calls = false
disallow_untyped_defs = false
# to disable relative imports error
disable_error_code = misc
7 changes: 7 additions & 0 deletions rewrite/rewrite/java/remote/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
__path__ = __import__('pkgutil').extend_path(__path__, __name__)

from typing import TypeVar

from .extensions import *
from .receiver import *
from .sender import *

__all__ = [name for name in dir() if not name.startswith('_') and not isinstance(globals()[name], TypeVar)]

from .register import register_codecs
register_codecs()
260 changes: 260 additions & 0 deletions rewrite/rewrite/java/remote/register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
from typing import Optional, List, Type, Any

from _cbor2 import break_marker
from cbor2 import CBOREncoder, CBORDecoder
from rewrite_remote import RemotingContext, SerializationContext, DeserializationContext

from ..tree import JavaType


def register_codecs():
# SenderContext.register(J, JavaSender)
# ReceiverContext.register(J, JavaReceiver)

RemotingContext.register_value_deserializer(
'org.openrewrite.java.tree.JavaType$Primitive',
deserialize_java_primitive
)
RemotingContext.register_value_serializer(
JavaType.Primitive,
serialize_java_primitive
)
RemotingContext.register_value_deserializer(
'org.openrewrite.java.tree.JavaType$Class',
deserialize_java_class
)
RemotingContext.register_value_serializer(
JavaType.Class,
serialize_java_class
)
RemotingContext.register_value_deserializer(
'org.openrewrite.java.tree.JavaType$Method',
deserialize_java_method
)
RemotingContext.register_value_serializer(
JavaType.Method,
serialize_java_method
)
RemotingContext.register_value_deserializer(
'org.openrewrite.java.tree.JavaType$Variable',
deserialize_java_variable
)
RemotingContext.register_value_serializer(
JavaType.Variable,
serialize_java_variable
)
RemotingContext.register_value_deserializer(
'org.openrewrite.java.tree.JavaType$Array',
deserialize_java_array
)
RemotingContext.register_value_serializer(
JavaType.Array,
serialize_java_array
)
RemotingContext.register_value_deserializer(
'org.openrewrite.java.tree.JavaType$Parameterized',
deserialize_java_parameterized
)
RemotingContext.register_value_serializer(
JavaType.Parameterized,
serialize_java_parameterized
)
RemotingContext.register_value_deserializer(
'org.openrewrite.java.tree.JavaType$GenericTypeVariable',
deserialize_java_generic_type_variable
)
RemotingContext.register_value_serializer(
JavaType.GenericTypeVariable,
serialize_java_generic_type_variable
)
RemotingContext.register_value_deserializer(
'org.openrewrite.java.tree.JavaType$Unknown',
deserialize_java_unknown
)
RemotingContext.register_value_serializer(
JavaType.Unknown,
serialize_java_unknown
)


def deserialize_java_class(type_: str, decoder: CBORDecoder, context: DeserializationContext) -> JavaType.Class:
cls = JavaType.ShallowClass() if type_ == 'org.openrewrite.java.tree.JavaType$ShallowClass' else JavaType.Class()
while not (key := decoder.decode()) == break_marker:
if key == '@ref':
context.remoting_context.add_by_id(decoder.decode(), cls)
elif key == 'flagsBitMap':
setattr(cls, '_flags_bit_map', decoder.decode())
elif key == 'fullyQualifiedName':
name = decoder.decode()
setattr(cls, '_fully_qualified_name', name)
elif key == 'kind':
setattr(cls, '_kind', context.deserialize(JavaType.FullyQualified.Kind, decoder))
elif key == 'typeParameters':
setattr(cls, '_type_parameters', context.deserialize(List[JavaType], decoder))
elif key == 'supertype':
setattr(cls, '_supertype', context.deserialize(JavaType.FullyQualified, decoder))
elif key == 'owningClass':
setattr(cls, '_owning_class', context.deserialize(JavaType.FullyQualified, decoder))
elif key == 'annotations':
setattr(cls, '_annotations', context.deserialize(List[JavaType.FullyQualified], decoder))
elif key == 'interfaces':
setattr(cls, '_interfaces', context.deserialize(List[JavaType.FullyQualified], decoder))
elif key == 'members':
setattr(cls, '_members', context.deserialize(List[JavaType.Variable], decoder))
elif key == 'methods':
setattr(cls, '_methods', context.deserialize(List[JavaType.Method], decoder))
return cls


def deserialize_java_method(_: str, decoder: CBORDecoder, context: DeserializationContext) -> JavaType.Method:
method = JavaType.Method()
while not (key := decoder.decode()) == break_marker:
if key == '@ref':
context.remoting_context.add_by_id(decoder.decode(), method)
elif key == 'flagsBitMap':
setattr(method, '_flags_bit_map', decoder.decode())
elif key == 'declaringType':
setattr(method, '_declaring_type', context.deserialize(JavaType.FullyQualified, decoder))
elif key == 'name':
setattr(method, '_name', decoder.decode())
elif key == 'returnType':
setattr(method, '_return_type', context.deserialize(JavaType, decoder))
elif key == 'parameterNames':
setattr(method, '_parameter_names', context.deserialize(List[str], decoder))
elif key == 'parameterTypes':
setattr(method, '_parameter_types', context.deserialize(List[JavaType], decoder))
elif key == 'thrownExceptions':
setattr(method, '_thrown_exceptions', context.deserialize(List[JavaType.FullyQualified], decoder))
elif key == 'annotations':
setattr(method, '_annotations', context.deserialize(List[JavaType.FullyQualified], decoder))
elif key == 'defaultValue':
setattr(method, '_default_value', context.deserialize(List[str], decoder))
elif key == 'declaredFormalTypeNames':
setattr(method, '_declared_formal_type_names', context.deserialize(List[str], decoder))
return method


def deserialize_java_variable(_: str, decoder: CBORDecoder, context: DeserializationContext) -> JavaType.Variable:
variable = JavaType.Variable()
while not (key := decoder.decode()) == break_marker:
if key == '@ref':
context.remoting_context.add_by_id(decoder.decode(), variable)
elif key == 'flagsBitMap':
setattr(variable, '_flags_bit_map', decoder.decode())
elif key == 'name':
setattr(variable, '_name', decoder.decode())
elif key == 'owner':
setattr(variable, '_owner', context.deserialize(JavaType, decoder))
elif key == 'type':
setattr(variable, '_type', context.deserialize(JavaType, decoder))
elif key == 'annotations':
setattr(variable, '_annotations', context.deserialize(List[JavaType.FullyQualified], decoder))
return variable


def deserialize_java_array(_: str, decoder: CBORDecoder, context: DeserializationContext) -> JavaType.Array:
array = JavaType.Array()
while not (key := decoder.decode()) == break_marker:
if key == '@ref':
context.remoting_context.add_by_id(decoder.decode(), array)
elif key == 'elemType':
setattr(array, '_elem_type', context.deserialize(JavaType, decoder))
elif key == 'annotations':
setattr(array, '_annotations', context.deserialize(List[JavaType.FullyQualified], decoder))
return array


def deserialize_java_parameterized(_: str, decoder: CBORDecoder,
context: DeserializationContext) -> JavaType.Parameterized:
param = JavaType.Parameterized()
while not (key := decoder.decode()) == break_marker:
if key == '@ref':
context.remoting_context.add_by_id(decoder.decode(), param)
elif key == 'type':
setattr(param, '_type', context.deserialize(JavaType.FullyQualified, decoder))
elif key == 'typeParameters':
setattr(param, '_type_parameters', context.deserialize(List[JavaType], decoder))
return param


def deserialize_java_generic_type_variable(_: str, decoder: CBORDecoder,
context: DeserializationContext) -> JavaType.GenericTypeVariable:
type_variable = JavaType.GenericTypeVariable()
while not (key := decoder.decode()) == break_marker:
if key == '@ref':
context.remoting_context.add_by_id(decoder.decode(), type_variable)
elif key == 'name':
setattr(type_variable, '_name', decoder.decode())
elif key == 'variance':
setattr(type_variable, '_variance', context.deserialize(JavaType.GenericTypeVariable.Variance, decoder))
elif key == 'bounds':
setattr(type_variable, '_bounds', context.deserialize(List[JavaType], decoder))
return type_variable


def deserialize_java_primitive(_: str, decoder: CBORDecoder, context: DeserializationContext) -> JavaType.Primitive:
kind = decoder.decode()
assert decoder.decode() == break_marker
return JavaType.Primitive(kind)


def deserialize_java_unknown(_: str, decoder: CBORDecoder, context: DeserializationContext) -> JavaType.Unknown:
unknown = JavaType.Unknown()
while not (key := decoder.decode()) == break_marker:
if key == '@ref':
context.remoting_context.add_by_id(decoder.decode(), unknown)
else:
decoder.decode()
return unknown


def serialize_java_primitive(value: JavaType.Primitive, type_name: Optional[str], encoder: CBOREncoder,
context: SerializationContext) -> None:
encoder.encode(['org.openrewrite.java.tree.JavaType$Primitive', value.value])


def serialize_java_class(value: JavaType.Class, type_name: Optional[str], encoder: CBOREncoder,
context: SerializationContext) -> None:
pass


def serialize_java_method(value: JavaType.Method, type_name: Optional[str], encoder: CBOREncoder,
context: SerializationContext) -> None:
pass


def serialize_java_variable(value: JavaType.Variable, type_name: Optional[str], encoder: CBOREncoder,
context: SerializationContext) -> None:
pass


def serialize_java_array(value: JavaType.Array, type_name: Optional[str], encoder: CBOREncoder,
context: SerializationContext) -> None:
pass


def serialize_java_parameterized(value: JavaType.Parameterized, type_name: Optional[str], encoder: CBOREncoder,
context: SerializationContext) -> None:
pass


def serialize_java_generic_type_variable(value: JavaType.GenericTypeVariable, type_name: Optional[str],
encoder: CBOREncoder, context: SerializationContext) -> None:
pass


def serialize_java_union(value: JavaType.GenericTypeVariable, type_name: Optional[str], encoder: CBOREncoder,
context: SerializationContext) -> None:
pass


def serialize_java_unknown(value: JavaType.GenericTypeVariable, type_name: Optional[str], encoder: CBOREncoder,
context: SerializationContext) -> None:
id = context.remoting_context.try_get_id(value)
if id is not None:
encoder.encode(id)
return
encoder.encode({
'@c': 'org.openrewrite.java.tree.JavaType$Unknown',
'@ref': context.remoting_context.add(value)
})
38 changes: 34 additions & 4 deletions rewrite/rewrite/java/support_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ def format_first_prefix(cls, trees: List[J2], prefix: Space) -> List[J2]:
return formatted_trees
return trees


EMPTY: ClassVar[Space]
SINGLE_SPACE: ClassVar[Space]

Expand Down Expand Up @@ -323,23 +322,54 @@ class MethodCall(Expression):

class JavaType(ABC):
class FullyQualified:
class Kind(Enum):
Class = 0
Enum = 1
Interface = 2
Annotation = 3
Record = 4

class Unknown(FullyQualified):
pass

class Class:
class Class(FullyQualified):
pass

class Parameterized:
class ShallowClass(Class):
pass

class Primitive:
class Parameterized(FullyQualified):
pass

class GenericTypeVariable:
class Variance(Enum):
Invariant = 0
Covariant = 1
Contravariant = 2

class Primitive(Enum):
Boolean = 0
Byte = 1
Char = 2
Double = 3
Float = 4
Int = 5
Long = 6
Short = 7
Void = 8
String = 9
None_ = 10
Null = 11

class Method:
pass

class Variable:
pass

class Array:
pass


T = TypeVar('T')
J2 = TypeVar('J2', bound=J)
Expand Down

0 comments on commit 4483ec5

Please sign in to comment.