Skip to content

Commit

Permalink
Refactor message content serialization and validation: Introduce a de…
Browse files Browse the repository at this point in the history
…fault method to dump message content
  • Loading branch information
MHHukiewitz committed Jan 31, 2024
1 parent e872112 commit f7cc2fb
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 22 deletions.
53 changes: 35 additions & 18 deletions aleph_message/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pydantic import BaseModel, Field, validator
from typing_extensions import TypeAlias

from ..utils import dump_content
from .abstract import BaseContent
from .base import Chain, HashType, MessageType
from .execution.base import MachineType, Payment, PaymentType # noqa
Expand Down Expand Up @@ -105,7 +106,7 @@ class ForgetContent(BaseContent):
"""Content of a FORGET message"""

hashes: List[ItemHash]
aggregates: List[ItemHash] = Field(default_factory=list)
aggregates: Optional[List[ItemHash]] = None
reason: Optional[str] = None

def __hash__(self):
Expand Down Expand Up @@ -179,6 +180,36 @@ def check_item_content(cls, v: Optional[str], values) -> Optional[str]:
)
return v

@validator("content")
def check_content(cls, v, values):
item_type = values["item_type"]
if item_type == ItemType.inline:
try:
item_content = json.loads(values["item_content"])
except JSONDecodeError:
raise ValueError(
"Field 'item_content' does not appear to be valid JSON"
)
json_dump = json.loads(v.json())
for key, value in json_dump.items():
if value != item_content[key]:
if isinstance(value, list):
for item in value:
if item not in item_content[key]:
raise ValueError(
f"Field 'content.{key}' does not match 'item_content.{key}': {item} != {item_content[key]}"
)
if isinstance(value, dict):
for item in value.items():
if item not in item_content[key].items():
raise ValueError(
f"Field 'content.{key}' does not match 'item_content.{key}': {value} != {item_content[key]}"
)
raise ValueError(
f"Field 'content.{key}' does not match 'item_content.{key}': {value} != {item_content[key]} or type mismatch ({type(value)} != {type(item_content[key])})"
)
return v

@validator("item_hash")
def check_item_hash(cls, v: ItemHash, values) -> ItemHash:
item_type = values["item_type"]
Expand Down Expand Up @@ -255,20 +286,6 @@ class ProgramMessage(BaseMessage):
type: Literal[MessageType.program]
content: ProgramContent

@validator("content")
def check_content(cls, v, values):
item_type = values["item_type"]
if item_type == ItemType.inline:
item_content = json.loads(values["item_content"])
if v.dict(exclude_none=True) != item_content:
# Print differences
vdict = v.dict(exclude_none=True)
for key, value in item_content.items():
if vdict[key] != value:
print(f"{key}: {vdict[key]} != {value}")
raise ValueError("Content and item_content differ")
return v


class InstanceMessage(BaseMessage):
type: Literal[MessageType.instance]
Expand Down Expand Up @@ -315,12 +332,12 @@ def parse_message(message_dict: Dict) -> AlephMessage:


def add_item_content_and_hash(message_dict: Dict, inplace: bool = False):
# TODO: I really don't like this function. There is no validation of the
# message_dict, if it is indeed a real message, and can lead to unexpected results.
if not inplace:
message_dict = copy(message_dict)

message_dict["item_content"] = json.dumps(
message_dict["content"], separators=(",", ":")
)
message_dict["item_content"] = dump_content(message_dict["content"])
message_dict["item_hash"] = sha256(
message_dict["item_content"].encode()
).hexdigest()
Expand Down
5 changes: 5 additions & 0 deletions aleph_message/models/abstract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from pydantic import BaseModel

from aleph_message.utils import dump_content


def hashable(obj):
"""Convert `obj` into a hashable object."""
Expand All @@ -23,3 +25,6 @@ class BaseContent(BaseModel):

address: str
time: float

def json(self, *args, **kwargs):
return dump_content(self)
7 changes: 6 additions & 1 deletion aleph_message/models/execution/volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from enum import Enum
from typing import Literal, Optional, Union

from pydantic import ConstrainedInt
from pydantic import ConstrainedInt, Extra

from ...utils import Gigabytes, gigabyte_to_mebibyte
from ..abstract import HashableModel
Expand All @@ -18,6 +18,11 @@ class AbstractVolume(HashableModel, ABC):
@abstractmethod
def is_read_only(self): ...

class Config:
# This is the only type where we really need to forbid extra fields.
# Otherwise the pydantic_encoder will take the first allowed type instead of the correct one.
extra = Extra.forbid


class ImmutableVolume(AbstractVolume):
ref: ItemHash
Expand Down
4 changes: 2 additions & 2 deletions aleph_message/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,12 +271,12 @@ def test_create_new_message():
"chain": "ETH",
"sender": "0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9",
"type": "POST",
"time": "1625652287.017",
"time": 1625652287.017,
"item_type": "inline",
"content": {
"address": "0x101d8D16372dBf5f1614adaE95Ee5CCE61998Fc9",
"type": "test-message",
"time": "1625652287.017",
"time": 1625652287.017,
"content": {
"hello": "world",
},
Expand Down
41 changes: 40 additions & 1 deletion aleph_message/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from __future__ import annotations

import json
import math
from typing import NewType
from datetime import date, datetime, time
from typing import Any, Dict, NewType, Union

from pydantic import BaseModel
from pydantic.json import pydantic_encoder

Megabytes = NewType("Megabytes", int)
Mebibytes = NewType("Mebibytes", int)
Expand All @@ -15,3 +20,37 @@ def gigabyte_to_mebibyte(n: Gigabytes) -> Mebibytes:
mebibyte = 2**20
gigabyte = 10**9
return Mebibytes(math.ceil(n * gigabyte / mebibyte))


def extended_json_encoder(obj: Any) -> Any:
"""
Extended JSON encoder for dumping objects that contain pydantic models and datetime objects.
"""
if isinstance(obj, datetime):
return obj.timestamp()
elif isinstance(obj, date):
return obj.toordinal()
elif isinstance(obj, time):
return obj.hour * 3600 + obj.minute * 60 + obj.second + obj.microsecond / 1e6
else:
return pydantic_encoder(obj)


def dump_content(obj: Union[Dict, BaseModel]) -> str:
"""Dump message content as JSON string."""
if isinstance(obj, dict):
# without None values
obj = obj.copy()
for key in list(obj.keys()):
if obj[key] is None:
del obj[key]
return json.dumps(obj, separators=(",", ":"), default=extended_json_encoder)

if isinstance(obj, BaseModel):
return json.dumps(
obj.dict(exclude_none=True),
separators=(",", ":"),
default=extended_json_encoder,
)

raise TypeError(f"Invalid type: `{type(obj)}`")

0 comments on commit f7cc2fb

Please sign in to comment.