diff --git a/objc2-encode/src/encoding.rs b/objc2-encode/src/encoding.rs index 70746d3d7..4bda9112a 100644 --- a/objc2-encode/src/encoding.rs +++ b/objc2-encode/src/encoding.rs @@ -489,11 +489,6 @@ mod tests { !"{SomeStruct=}"; } - fn struct_unicode() { - Encoding::Struct("☃", &[Encoding::Char]); - "{☃=c}"; - } - fn pointer_struct() { Encoding::Pointer(&Encoding::Struct("SomeStruct", &[Encoding::Char, Encoding::Int])); !Encoding::Pointer(&Encoding::Struct("SomeStruct", &[Encoding::Int, Encoding::Char])); @@ -590,5 +585,22 @@ mod tests { ); "{abc=^[8B](def=@?)^^b255?}"; } + + fn identifier() { + Encoding::Struct("_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789", &[]); + "{_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789=}"; + } + } + + #[test] + #[should_panic = "Struct name was not a valid identifier"] + fn struct_unicode() { + let _ = Encoding::Struct("☃", &[Encoding::Char]).to_string(); + } + + #[test] + #[should_panic = "Union name was not a valid identifier"] + fn union_invalid_identifier() { + let _ = Encoding::Union("a-b", &[Encoding::Char]).equivalent_to_str("(☃=c)"); } } diff --git a/objc2-encode/src/helper.rs b/objc2-encode/src/helper.rs index e67c84c8e..bb64a6b2e 100644 --- a/objc2-encode/src/helper.rs +++ b/objc2-encode/src/helper.rs @@ -198,8 +198,32 @@ impl<'a> Helper<'a> { Pointer(t) => Self::Indirection(IndirectionKind::Pointer, t), Atomic(t) => Self::Indirection(IndirectionKind::Atomic, t), Array(len, item) => Self::Array(len, item), - Struct(name, fields) => Self::Container(ContainerKind::Struct, name, fields), - Union(name, members) => Self::Container(ContainerKind::Union, name, members), + Struct(name, fields) => { + if !verify_name(name) { + panic!("Struct name was not a valid identifier"); + } + Self::Container(ContainerKind::Struct, name, fields) + } + Union(name, members) => { + if !verify_name(name) { + panic!("Union name was not a valid identifier"); + } + Self::Container(ContainerKind::Union, name, members) + } } } } + +/// Check whether the name is a valid identifier +const fn verify_name(name: &str) -> bool { + let bytes = name.as_bytes(); + let mut i = 0; + while i < bytes.len() { + let byte = bytes[i]; + if !(byte.is_ascii_alphanumeric() || byte == b'_') { + return false; + } + i += 1; + } + true +}