Skip to content

Commit

Permalink
feat: Support nesed generics
Browse files Browse the repository at this point in the history
  • Loading branch information
mcmah309 committed Dec 23, 2024
1 parent 1d6e577 commit 2ffec06
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 16 deletions.
1 change: 1 addition & 0 deletions impl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ syn = { version = "2", default-features = false, features = [
proc-macro2 = "1"
quote = "1"
indices = "0.3"
regex = "1"

[features]
default = []
Expand Down
91 changes: 75 additions & 16 deletions impl/src/resolve.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use std::collections::HashMap;

use crate::ast::{
AstErrorDeclaration, AstErrorSet, AstErrorVariant, AstInlineErrorVariantField, Disabled, RefError
AstErrorDeclaration, AstErrorSet, AstErrorVariant, AstInlineErrorVariantField, Disabled,
RefError,
};
use crate::expand::{ErrorEnum, ErrorVariant, Named, SourceStruct, SourceTuple, Struct};

use quote::ToTokens;
use syn::{Attribute, Ident, TypeParam};

/// Constructs [ErrorEnum]s from the ast, resolving any references to other sets. The returned result is
Expand All @@ -21,7 +23,8 @@ pub(crate) fn resolve(error_set: AstErrorSet) -> syn::Result<Vec<ErrorEnum>> {
parts,
} = declaration;

let mut error_enum_builder = ErrorEnumBuilder::new(error_name, attributes, generics, disabled);
let mut error_enum_builder =
ErrorEnumBuilder::new(error_name, attributes, generics, disabled);

for part in parts.into_iter() {
match part {
Expand Down Expand Up @@ -133,32 +136,42 @@ fn resolve_builders_helper<'a>(
let type_path = syn::TypePath { qself: None, path };
syn::Type::Path(type_path)
}
// rename the generics inside the variants to the new declared name - to avoid collisions.
let mut rename = HashMap::<syn::Type, syn::Type>::new();

// rename the generics inside the variant fields to the new declared name - for `...= X<T> ..`, `T` in this case.
let mut generic_type_to_new_generic_type = HashMap::<syn::Type, syn::Type>::new();
let mut generic_type_to_new_generic_type_str = HashMap::<String, String>::new();
let mut generic_type_str_to_regex = HashMap::<String, regex::Regex>::new();
for (ref_part_generic, ref_error_enum_generic) in ref_part
.generic_refs
.iter()
.zip(ref_error_enum_builder.generics.iter())
{
rename.insert(
let old = ref_error_enum_generic.ident.to_string();
// e.g. For "X", matches "<X>", but not "<X" or "X>" or "X"
let generic_identification_pattern = format!(
r"(?P<before>[^\w\d]){}(?P<after>[^\w\d])",
regex::escape(&old)
);
let re = regex::Regex::new(&generic_identification_pattern).unwrap();
generic_type_str_to_regex.insert(old.clone(), re);
let new = ref_part_generic.to_string();
generic_type_to_new_generic_type_str.insert(old, new);
generic_type_to_new_generic_type.insert(
ident_to_type(ref_error_enum_generic.ident.clone()),
ident_to_type(ref_part_generic.clone()),
);
}
// let error_variants = Vec::new();

for error_variant in ref_error_enum_builder.error_variants.iter() {
let new_fields = if let Some(fields) = &error_variant.fields {
let mut new_fields = Vec::new();
for field in fields.iter() {
if rename.contains_key(&field.r#type) {
let new_type = rename.get(&field.r#type).unwrap().clone();
new_fields.push(AstInlineErrorVariantField {
name: field.name.clone(),
r#type: new_type.clone(),
});
} else {
new_fields.push(field.clone());
}
new_fields.push(replace_generics_in_fields(
field,
&generic_type_to_new_generic_type,
&generic_type_to_new_generic_type_str,
&generic_type_str_to_regex,
));
}
Some(new_fields)
} else {
Expand Down Expand Up @@ -226,7 +239,12 @@ struct ErrorEnumBuilder {
}

impl ErrorEnumBuilder {
fn new(error_name: Ident, attributes: Vec<Attribute>, generics: Vec<TypeParam>, disabled: Disabled) -> Self {
fn new(
error_name: Ident,
attributes: Vec<Attribute>,
generics: Vec<TypeParam>,
disabled: Disabled,
) -> Self {
Self {
attributes,
error_name,
Expand Down Expand Up @@ -331,3 +349,44 @@ fn reshape(this: AstErrorVariant) -> ErrorVariant {
}
}
}

//************************************************************************//

fn replace_generics_in_fields(
field: &AstInlineErrorVariantField,
old_to_new: &HashMap<syn::Type, syn::Type>,
old_to_new_str: &HashMap<String, String>,
old_to_match_regex: &HashMap<String, regex::Regex>,
) -> AstInlineErrorVariantField {
if old_to_new.contains_key(&field.r#type) {
let new_type = old_to_new.get(&field.r#type).unwrap().clone();
return AstInlineErrorVariantField {
name: field.name.clone(),
r#type: new_type.clone(),
};
}
// return field.clone();
let field_type_str = field.r#type.to_token_stream().to_string();
for (original_type, new_type) in old_to_new_str {
let regex = &old_to_match_regex[original_type];
let replaced = replace_part(&field_type_str, new_type, regex);
if field_type_str != replaced {
let new_type = syn::parse_str::<syn::Type>(&replaced)
.expect("Failed to parse replaced type back into type");
return AstInlineErrorVariantField {
name: field.name.clone(),
r#type: new_type.clone(),
};
}
}
return field.clone();
}

/// Assumes regex is `"(?P<before>[^\w\d]){}(?P<after>[^\w\d])"` as declared earlier
fn replace_part(input: &str, replacement: &str, re: &regex::Regex) -> String {
re.replace_all(input, |caps: &regex::Captures| {
// Reconstruct the matched segment with the replacement
format!("{}{}{}", &caps["before"], replacement, &caps["after"])
})
.to_string()
}
30 changes: 30 additions & 0 deletions tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,36 @@ pub mod from_for_generic_and_regular {
assert!(matches!(x, X::B(_)));
}
}

#[cfg(test)]
pub mod generics_nested {
use error_set::error_set;

#[derive(Debug)]
pub struct Wrapper<T: core::fmt::Debug + core::fmt::Display>(T);

impl<T: core::fmt::Debug + core::fmt::Display> std::fmt::Display for Wrapper<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Wrapper({})", self.0)
}
}

error_set!{
X<H: core::fmt::Debug + core::fmt::Display> = {
A {
a: Wrapper<H>
}
};
Z<T: core::fmt::Debug + core::fmt::Display> = X<T>;
}

#[test]
fn test() {
let _x = X::A { a: Wrapper(1) };
let _z = Z::A { a: Wrapper(1) };
}
}

#[cfg(test)]
pub mod should_not_compile_tests {

Expand Down

0 comments on commit 2ffec06

Please sign in to comment.