Skip to content

Commit

Permalink
chore: Add tests and validate schemas
Browse files Browse the repository at this point in the history
  • Loading branch information
kulikthebird committed Nov 26, 2024
1 parent 2a737c8 commit 777aca3
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 82 deletions.
25 changes: 9 additions & 16 deletions packages/cw-schema-codegen/playground/playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class __InnerStruct(BaseModel):
a: 'SomeEnum'
Field5: __InnerStruct

root: Union[Field1, Field2, Field3, Field4, Field5]
root: Union[Field1, Field2, Field3, Field4, Field5,]


class UnitStructure(RootModel):
Expand All @@ -46,21 +46,14 @@ class NamedStructure(BaseModel):
### TESTS:
###

for (index, input) in enumerate(sys.stdin):
input = input.rstrip()
try:
if index < 5:
deserialized = SomeEnum.model_validate_json(input)
elif index == 5:
deserialized = UnitStructure.model_validate_json(input)
elif index == 6:
deserialized = TupleStructure.model_validate_json(input)
else:
deserialized = NamedStructure.model_validate_json(input)
except:
raise(Exception(f"This json can't be deserialized: {input}"))
serialized = deserialized.model_dump_json()
print(serialized)
print(SomeEnum.model_validate_json(sys.stdin.readline().rstrip()).model_dump_json())
print(SomeEnum.model_validate_json(sys.stdin.readline().rstrip()).model_dump_json())
print(SomeEnum.model_validate_json(sys.stdin.readline().rstrip()).model_dump_json())
print(SomeEnum.model_validate_json(sys.stdin.readline().rstrip()).model_dump_json())
print(SomeEnum.model_validate_json(sys.stdin.readline().rstrip()).model_dump_json())
print(UnitStructure.model_validate_json(sys.stdin.readline().rstrip()).model_dump_json())
print(TupleStructure.model_validate_json(sys.stdin.readline().rstrip()).model_dump_json())
print(NamedStructure.model_validate_json(sys.stdin.readline().rstrip()).model_dump_json())


# def handle_msg(json):
Expand Down
27 changes: 13 additions & 14 deletions packages/cw-schema-codegen/src/python/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use self::template::{
EnumTemplate, EnumVariantTemplate, FieldTemplate, StructTemplate, TypeTemplate,
};
use heck::ToPascalCase;
use std::{borrow::Cow, io};

pub mod template;
Expand All @@ -15,15 +14,15 @@ fn expand_node_name<'a>(
let items = &schema.definitions[items];
format!("{}[]", expand_node_name(schema, items)).into()
}
cw_schema::NodeType::Float => "number".into(),
cw_schema::NodeType::Double => "number".into(),
cw_schema::NodeType::Boolean => "boolean".into(),
cw_schema::NodeType::String => "string".into(),
cw_schema::NodeType::Integer { .. } => "string".into(),
cw_schema::NodeType::Binary => "Uint8Array".into(),
cw_schema::NodeType::Float => "float".into(),
cw_schema::NodeType::Double => "float".into(),
cw_schema::NodeType::Boolean => "bool".into(),
cw_schema::NodeType::String => "str".into(),
cw_schema::NodeType::Integer { .. } => "int".into(),
cw_schema::NodeType::Binary => "bytes".into(),
cw_schema::NodeType::Optional { inner } => {
let inner = &schema.definitions[inner];
format!("{} | null", expand_node_name(schema, inner)).into()
format!("typing.Optional[{}]", expand_node_name(schema, inner)).into()
}
cw_schema::NodeType::Struct(..) => node.name.as_ref().into(),
cw_schema::NodeType::Tuple { ref items } => {
Expand All @@ -37,13 +36,13 @@ fn expand_node_name<'a>(
}
cw_schema::NodeType::Enum { .. } => node.name.as_ref().into(),

cw_schema::NodeType::Decimal { .. } => "string".into(),
cw_schema::NodeType::Address => "string".into(),
cw_schema::NodeType::Decimal { .. } => "decimal.Decimal".into(),
cw_schema::NodeType::Address => "str".into(),
cw_schema::NodeType::Checksum => todo!(),
cw_schema::NodeType::HexBinary => todo!(),
cw_schema::NodeType::Timestamp => todo!(),
cw_schema::NodeType::Unit => Cow::Borrowed("void"),
_ => todo!()
cw_schema::NodeType::Unit => "None".into(),
_ => todo!(),
}
}

Expand Down Expand Up @@ -83,7 +82,7 @@ where
.map(|item| expand_node_name(schema, &schema.definitions[*item]))
.collect(),
),
_ => todo!()
_ => todo!(),
},
};

Expand Down Expand Up @@ -125,7 +124,7 @@ where
.collect(),
}
}
_ => todo!()
_ => todo!(),
},
})
.collect(),
Expand Down
27 changes: 15 additions & 12 deletions packages/cw-schema-codegen/templates/python/enum.tpl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This code is @generated by cw-schema-codegen. Do not modify this manually.

import typing
import decimal
from pydantic import BaseModel, RootModel

class {{ name }}(RootModel):
Expand All @@ -15,24 +16,26 @@ class {{ variant.name }}(RootModel):
"""{% for doc in variant.docs %}
{{ doc }}
{% endfor %}"""
root: None
root: typing.Literal['{{ variant.name }}']
{% when TypeTemplate::Tuple with (types) %}
class {{ variant.name }}(BaseModel):
"""{% for doc in variant.docs %}
{{ doc }}
{% endfor %}"""
{{ variant.name }}: typing.Tuple[{{ types|join(", ") }}]
{% when TypeTemplate::Named with { fields } %}
class __Inner:
"""{% for doc in variant.docs %}
{{ doc }}
{% endfor %}"""
{% for field in fields %}
{{ field.name }}: {{ field.ty }}
"""{% for doc in field.docs %}
# {{ doc }}
{% endfor %}"""
{% endfor %}
class {{ variant.name }}(BaseModel):
class __Inner(BaseModel):
"""{% for doc in variant.docs %}
{{ doc }}
{% endfor %}"""
{% for field in fields %}
{{ field.name }}: {{ field.ty }}
"""{% for doc in field.docs %}
{{ doc }}
{% endfor %}"""
{% endfor %}
{{ variant.name }}: __Inner
{% endmatch %}
{% endfor %}
{% endfor %}
root: typing.Union[ {% for variant in variants %} {{ variant.name }}, {% endfor %} ]
3 changes: 2 additions & 1 deletion packages/cw-schema-codegen/templates/python/struct.tpl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This code is @generated by cw-schema-codegen. Do not modify this manually.

import typing
import decimal
from pydantic import BaseModel, RootModel


Expand All @@ -10,7 +11,7 @@ class {{ name }}(RootModel):
'''{% for doc in docs %}
{{ doc }}
{% endfor %}'''
root: None
root: None
{% when TypeTemplate::Tuple with (types) %}
class {{ name }}(RootModel):
'''{% for doc in docs %}
Expand Down
175 changes: 151 additions & 24 deletions packages/cw-schema-codegen/tests/python_tpl.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,156 @@
use std::borrow::Cow;
use cw_schema::Schemaifier;
use serde::{Deserialize, Serialize};
use std::io::Write;

use askama::Template;
use cw_schema_codegen::python::template::{
EnumTemplate, EnumVariantTemplate, FieldTemplate, StructTemplate, TypeTemplate,
};
#[derive(Schemaifier, Serialize, Deserialize)]
pub enum SomeEnum {
Field1,
Field2(u32, u32),
Field3 { a: String, b: u32 },
// Field4(Box<SomeEnum>), // TODO tkulik: Do we want to support Box<T> ?
// Field5 { a: Box<SomeEnum> },
}

#[derive(Schemaifier, Serialize, Deserialize)]
pub struct UnitStructure;

#[derive(Schemaifier, Serialize, Deserialize)]
pub struct TupleStructure(u32, String, u128);

#[derive(Schemaifier, Serialize, Deserialize)]
pub struct NamedStructure {
a: String,
b: u8,
c: SomeEnum,
}

#[test]
fn simple_enum() {
let tpl = EnumTemplate {
name: Cow::Borrowed("Simple"),
docs: Cow::Borrowed(&[Cow::Borrowed("Simple enum")]),
variants: Cow::Borrowed(&[
EnumVariantTemplate {
name: Cow::Borrowed("One"),
docs: Cow::Borrowed(&[Cow::Borrowed("One variant")]),
ty: TypeTemplate::Unit,
},
EnumVariantTemplate {
name: Cow::Borrowed("Two"),
docs: Cow::Borrowed(&[Cow::Borrowed("Two variant")]),
ty: TypeTemplate::Unit,
},
]),
};

let rendered = tpl.render().unwrap();
insta::assert_snapshot!(rendered);
// generate the schemas for each of the above types
let schemas = [
cw_schema::schema_of::<SomeEnum>(),
cw_schema::schema_of::<UnitStructure>(),
cw_schema::schema_of::<TupleStructure>(),
cw_schema::schema_of::<NamedStructure>(),
];

// run the codegen to typescript
for schema in schemas {
let cw_schema::Schema::V1(schema) = schema else {
panic!();
};

let output = schema
.definitions
.iter()
.map(|node| {
let mut buf = Vec::new();
cw_schema_codegen::python::process_node(&mut buf, &schema, node).unwrap();
String::from_utf8(buf).unwrap()
})
.collect::<String>();

insta::assert_snapshot!(output);
}
}

macro_rules! validator {
($typ:ty) => {{
let a: Box<dyn FnOnce(&str) -> ()> = Box::new(|output| {
serde_json::from_str::<$typ>(output).unwrap();
});
a
}};
}

#[test]
fn assert_validity() {
let schemas = [
(
"SomeEnum",
cw_schema::schema_of::<SomeEnum>(),
serde_json::to_string(&SomeEnum::Field1).unwrap(),
validator!(SomeEnum),
),
(
"SomeEnum",
cw_schema::schema_of::<SomeEnum>(),
serde_json::to_string(&SomeEnum::Field2(10, 23)).unwrap(),
validator!(SomeEnum),
),
(
"SomeEnum",
cw_schema::schema_of::<SomeEnum>(),
serde_json::to_string(&SomeEnum::Field3 {
a: "sdf".to_string(),
b: 12,
})
.unwrap(),
validator!(SomeEnum),
),
(
"UnitStructure",
cw_schema::schema_of::<UnitStructure>(),
serde_json::to_string(&UnitStructure {}).unwrap(),
validator!(UnitStructure),
),
(
"TupleStructure",
cw_schema::schema_of::<TupleStructure>(),
serde_json::to_string(&TupleStructure(10, "aasdf".to_string(), 2)).unwrap(),
validator!(TupleStructure),
),
(
"NamedStructure",
cw_schema::schema_of::<NamedStructure>(),
serde_json::to_string(&NamedStructure {
a: "awer".to_string(),
b: 4,
c: SomeEnum::Field1,
})
.unwrap(),
validator!(NamedStructure),
),
];

for (type_name, schema, example, validator) in schemas {
let cw_schema::Schema::V1(schema) = schema else {
unreachable!();
};

let schema_output = schema
.definitions
.iter()
.map(|node| {
let mut buf = Vec::new();
cw_schema_codegen::python::process_node(&mut buf, &schema, node).unwrap();
String::from_utf8(buf).unwrap()
})
.collect::<String>();

let mut file = tempfile::NamedTempFile::with_suffix(".py").unwrap();
file.write_all(schema_output.as_bytes()).unwrap();
file.write(
format!(
"import sys; print({type_name}.model_validate_json('{example}').model_dump_json())"
)
.as_bytes(),
)
.unwrap();
file.flush().unwrap();

let output = std::process::Command::new("python")
.arg(file.path())
.output()
.unwrap();

assert!(
output.status.success(),
"stdout: {stdout}, stderr: {stderr}\n\n schema:\n {schema_output}",
stdout = String::from_utf8_lossy(&output.stdout),
stderr = String::from_utf8_lossy(&output.stderr),
);

validator(&String::from_utf8_lossy(&output.stdout))
}
}
Loading

0 comments on commit 777aca3

Please sign in to comment.