Skip to content

Commit

Permalink
fix: dont requires generic trait to declare Clone + PartialEq
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard committed Jun 14, 2022
1 parent e09adf5 commit 53678ef
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 28 deletions.
54 changes: 54 additions & 0 deletions mock-it_codegen/src/generics.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use proc_macro2::Ident;
use quote::quote;
use syn::{parse2, Generics, WhereClause, WherePredicate};

pub struct TraitGenerics {
generics: Generics,
}

impl TraitGenerics {
pub fn new(generics: Generics) -> Self {
Self { generics }
}

pub fn types(&self) -> Vec<Ident> {
self.generics
.type_params()
.into_iter()
.map(|tp| tp.ident.clone())
.collect()
}

pub fn add_predicates(&mut self, predicate: WherePredicate) {
let where_clause: WhereClause = match &self.generics.where_clause {
Some(val) => parse2(quote! {
#val
#predicate,
})
.unwrap(),
None => parse2(quote! {
where
#predicate,
})
.unwrap(),
};
self.generics.where_clause = Some(where_clause);
}
}

impl Into<Generics> for TraitGenerics {
fn into(self) -> Generics {
self.generics
}
}

pub fn add_generics(generics: &Generics) -> Generics {
let mut trait_generics = TraitGenerics::new(generics.clone());
trait_generics
.types()
.into_iter()
.map(|ty| parse2(quote! { #ty: Clone + PartialEq }).unwrap())
.for_each(|predicate| trait_generics.add_predicates(predicate));

trait_generics.into()
}
28 changes: 14 additions & 14 deletions mock-it_codegen/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
extern crate proc_macro;

mod generics;
mod mock_fn;
mod trait_method;

use generics::add_generics;
use mock_fn::{mock_fns, MockFn};
use proc_macro2::TokenStream;
use quote::quote;
use syn::{parse_macro_input, Ident, Item};
use trait_method::{get_trait_method_types, TraitMethodType};
use trait_method::get_trait_method_types;

/// Generate a mock struct from a trait. The mock struct will be named after the
/// trait, with "Mock" appended.
Expand All @@ -25,9 +27,6 @@ pub fn mock_it(
_ => panic!("Only traits can be mocked with the mock_it macro"),
};

let generics = item_trait.generics.clone();
let generics_where = item_trait.generics.where_clause.clone();

let trait_method_types = get_trait_method_types(&item_trait);
let mock_fns = mock_fns(trait_method_types.clone());
let helper_functions: Vec<TokenStream> = mock_fns
Expand All @@ -43,16 +42,19 @@ pub fn mock_it(
let fields = create_fields(&mock_fns);
let field_init = create_field_init(&mock_fns);
let trait_impls = create_trait_impls(&mock_fns);
let clone_impl = create_clone_impl(&trait_method_types);
let clone_impl = create_clone_impl(&mock_fns);

let generics = add_generics(&item_trait.generics);
let (generics_impl, generics_ty, generics_where) = generics.split_for_impl();

let output = quote! {
#item_trait

pub struct #mock_ident #generics #generics_where {
pub struct #mock_ident #generics_ty #generics_where {
#(#fields),*
}

impl #generics #mock_ident #generics #generics_where {
impl #generics_impl #mock_ident #generics_ty #generics_where {
pub fn new() -> Self {
#mock_ident {
#(#field_init),*
Expand All @@ -63,15 +65,15 @@ pub fn mock_it(

}

impl #generics std::clone::Clone for #mock_ident #generics #generics_where {
impl #generics_impl std::clone::Clone for #mock_ident #generics_ty #generics_where {
fn clone(&self) -> Self {
#mock_ident {
#(#clone_impl),*
}
}
}

impl #generics #trait_ident #generics for #mock_ident #generics #generics_where {
impl #generics_impl #trait_ident #generics_ty for #mock_ident #generics_ty #generics_where {
#(#trait_impls)*
}
};
Expand Down Expand Up @@ -110,11 +112,9 @@ fn create_field_init(mock_fns: &Vec<MockFn>) -> Vec<TokenStream> {
}

/// Create the clone implementation
fn create_clone_impl(
trait_method_types: &Vec<TraitMethodType>,
) -> impl Iterator<Item = TokenStream> + '_ {
trait_method_types.iter().map(|method_type| {
let ident = &method_type.signature.ident;
fn create_clone_impl(mock_fns: &Vec<MockFn>) -> impl Iterator<Item = TokenStream> + '_ {
mock_fns.iter().map(|mock_fn| {
let ident = &mock_fn.signature().ident;
quote! {
#ident: self.#ident.clone()
}
Expand Down
15 changes: 3 additions & 12 deletions tests/codegen_generic_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@ mod simple {
use mock_it::{any, eq, mock_it};

#[mock_it]
trait ATrait<T>
where
T: Clone + PartialEq,
{
trait ATrait<T> {
fn a_fn(&self, arg1: T);
}

Expand Down Expand Up @@ -45,10 +42,7 @@ mod two_methods {
use mock_it::{any, eq, mock_it};

#[mock_it]
trait ATrait<T>
where
T: Clone + PartialEq,
{
trait ATrait<T> {
fn a_fn(&self, arg1: T);
fn another_fn(&self, arg1: &str);
}
Expand Down Expand Up @@ -85,10 +79,7 @@ mod with_lifetime {
use mock_it::{any, eq, mock_it};

#[mock_it]
trait ATrait<'a, T>
where
T: Clone + PartialEq,
{
trait ATrait<'a, T> {
fn a_fn(&self, arg1: &'a T) -> &'a str;
}

Expand Down
4 changes: 2 additions & 2 deletions tests/codegen_reference_sized.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use mock_it::{any, eq, mock_it};
use mock_it::{any, eq};

#[mock_it]
#[cfg_attr(test, mock_it::mock_it)]
pub trait ATrait {
fn a_fn(&self, sized1: &usize, sized2: &String) -> String;
}
Expand Down

0 comments on commit 53678ef

Please sign in to comment.