diff --git a/impl/Cargo.toml b/impl/Cargo.toml index 170a12f..75c5b5e 100644 --- a/impl/Cargo.toml +++ b/impl/Cargo.toml @@ -22,6 +22,7 @@ syn = { version = "2", default-features = false, features = [ proc-macro2 = "1" quote = "1" indices = "0.3" +regex = "1" [features] default = [] diff --git a/impl/src/resolve.rs b/impl/src/resolve.rs index a5c4787..652f9d6 100644 --- a/impl/src/resolve.rs +++ b/impl/src/resolve.rs @@ -1,10 +1,12 @@ use std::collections::HashMap; use crate::ast::{ - AstErrorDeclaration, AstErrorSet, AstErrorVariant, AstInlineErrorVariantField, Disabled, RefError + AstErrorDeclaration, AstErrorSet, AstErrorVariant, AstInlineErrorVariantField, Disabled, + RefError, }; use crate::expand::{ErrorEnum, ErrorVariant, Named, SourceStruct, SourceTuple, Struct}; +use quote::ToTokens; use syn::{Attribute, Ident, TypeParam}; /// Constructs [ErrorEnum]s from the ast, resolving any references to other sets. The returned result is @@ -21,7 +23,8 @@ pub(crate) fn resolve(error_set: AstErrorSet) -> syn::Result> { parts, } = declaration; - let mut error_enum_builder = ErrorEnumBuilder::new(error_name, attributes, generics, disabled); + let mut error_enum_builder = + ErrorEnumBuilder::new(error_name, attributes, generics, disabled); for part in parts.into_iter() { match part { @@ -133,32 +136,42 @@ fn resolve_builders_helper<'a>( let type_path = syn::TypePath { qself: None, path }; syn::Type::Path(type_path) } - // rename the generics inside the variants to the new declared name - to avoid collisions. - let mut rename = HashMap::::new(); + + // rename the generics inside the variant fields to the new declared name - for `...= X ..`, `T` in this case. + let mut generic_type_to_new_generic_type = HashMap::::new(); + let mut generic_type_to_new_generic_type_str = HashMap::::new(); + let mut generic_type_str_to_regex = HashMap::::new(); for (ref_part_generic, ref_error_enum_generic) in ref_part .generic_refs .iter() .zip(ref_error_enum_builder.generics.iter()) { - rename.insert( + let old = ref_error_enum_generic.ident.to_string(); + // e.g. For "X", matches "", but not "" or "X" + let generic_identification_pattern = format!( + r"(?P[^\w\d]){}(?P[^\w\d])", + regex::escape(&old) + ); + let re = regex::Regex::new(&generic_identification_pattern).unwrap(); + generic_type_str_to_regex.insert(old.clone(), re); + let new = ref_part_generic.to_string(); + generic_type_to_new_generic_type_str.insert(old, new); + generic_type_to_new_generic_type.insert( ident_to_type(ref_error_enum_generic.ident.clone()), ident_to_type(ref_part_generic.clone()), ); } - // let error_variants = Vec::new(); + for error_variant in ref_error_enum_builder.error_variants.iter() { let new_fields = if let Some(fields) = &error_variant.fields { let mut new_fields = Vec::new(); for field in fields.iter() { - if rename.contains_key(&field.r#type) { - let new_type = rename.get(&field.r#type).unwrap().clone(); - new_fields.push(AstInlineErrorVariantField { - name: field.name.clone(), - r#type: new_type.clone(), - }); - } else { - new_fields.push(field.clone()); - } + new_fields.push(replace_generics_in_fields( + field, + &generic_type_to_new_generic_type, + &generic_type_to_new_generic_type_str, + &generic_type_str_to_regex, + )); } Some(new_fields) } else { @@ -226,7 +239,12 @@ struct ErrorEnumBuilder { } impl ErrorEnumBuilder { - fn new(error_name: Ident, attributes: Vec, generics: Vec, disabled: Disabled) -> Self { + fn new( + error_name: Ident, + attributes: Vec, + generics: Vec, + disabled: Disabled, + ) -> Self { Self { attributes, error_name, @@ -331,3 +349,44 @@ fn reshape(this: AstErrorVariant) -> ErrorVariant { } } } + +//************************************************************************// + +fn replace_generics_in_fields( + field: &AstInlineErrorVariantField, + old_to_new: &HashMap, + old_to_new_str: &HashMap, + old_to_match_regex: &HashMap, +) -> AstInlineErrorVariantField { + if old_to_new.contains_key(&field.r#type) { + let new_type = old_to_new.get(&field.r#type).unwrap().clone(); + return AstInlineErrorVariantField { + name: field.name.clone(), + r#type: new_type.clone(), + }; + } + // return field.clone(); + let field_type_str = field.r#type.to_token_stream().to_string(); + for (original_type, new_type) in old_to_new_str { + let regex = &old_to_match_regex[original_type]; + let replaced = replace_part(&field_type_str, new_type, regex); + if field_type_str != replaced { + let new_type = syn::parse_str::(&replaced) + .expect("Failed to parse replaced type back into type"); + return AstInlineErrorVariantField { + name: field.name.clone(), + r#type: new_type.clone(), + }; + } + } + return field.clone(); +} + +/// Assumes regex is `"(?P[^\w\d]){}(?P[^\w\d])"` as declared earlier +fn replace_part(input: &str, replacement: &str, re: ®ex::Regex) -> String { + re.replace_all(input, |caps: ®ex::Captures| { + // Reconstruct the matched segment with the replacement + format!("{}{}{}", &caps["before"], replacement, &caps["after"]) + }) + .to_string() +} diff --git a/tests/mod.rs b/tests/mod.rs index 3bfdeea..f035984 100644 --- a/tests/mod.rs +++ b/tests/mod.rs @@ -860,6 +860,36 @@ pub mod from_for_generic_and_regular { assert!(matches!(x, X::B(_))); } } + +#[cfg(test)] +pub mod generics_nested { + use error_set::error_set; + + #[derive(Debug)] + pub struct Wrapper(T); + + impl std::fmt::Display for Wrapper { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Wrapper({})", self.0) + } + } + + error_set!{ + X = { + A { + a: Wrapper + } + }; + Z = X; + } + + #[test] + fn test() { + let _x = X::A { a: Wrapper(1) }; + let _z = Z::A { a: Wrapper(1) }; + } +} + #[cfg(test)] pub mod should_not_compile_tests {