Skip to content

Commit

Permalink
feat(rust): Adds the ability to specify custom derive attributes (#678)
Browse files Browse the repository at this point in the history
I also did a little bit of cleanup of clippy warnings while I was in
these files.

Closes #554

Signed-off-by: Taylor Thomas <[email protected]>
  • Loading branch information
thomastaylor312 authored Sep 27, 2023
1 parent 84dd60c commit 8c2abf4
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 28 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

54 changes: 42 additions & 12 deletions crates/rust-lib/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use heck::*;
use std::collections::BTreeSet;
use std::fmt;
use std::str::FromStr;
use wit_bindgen_core::abi::{Bitcast, LiftLower, WasmType};
Expand Down Expand Up @@ -91,6 +92,9 @@ pub trait RustGenerator<'a> {

fn is_exported_resource(&self, ty: TypeId) -> bool;

/// Return any additional derive attributes to add to the generated types
fn additional_derives(&self) -> &[String];

fn mark_resource_owned(&mut self, resource: TypeId);

fn push_str(&mut self, s: &str);
Expand Down Expand Up @@ -159,7 +163,7 @@ pub trait RustGenerator<'a> {
param_mode: TypeMode,
sig: &FnSig,
) -> Vec<String> {
let params = self.print_docs_and_params(func, param_mode, &sig);
let params = self.print_docs_and_params(func, param_mode, sig);
if let FunctionKind::Constructor(_) = &func.kind {
self.push_str(" -> Self")
} else {
Expand Down Expand Up @@ -198,7 +202,7 @@ pub trait RustGenerator<'a> {
} else {
&func.name
};
self.push_str(&to_rust_ident(&func_name));
self.push_str(&to_rust_ident(func_name));
if let Some(generics) = &sig.generics {
self.push_str(generics);
}
Expand Down Expand Up @@ -563,7 +567,7 @@ pub trait RustGenerator<'a> {
if self.uses_two_names(&info) {
result.push((self.param_name(ty), TypeMode::AllBorrowed("'a")));
}
return result;
result
}

fn print_typedef_record(
Expand All @@ -574,6 +578,9 @@ pub trait RustGenerator<'a> {
derive_component: bool,
) {
let info = self.info(id);
// We use a BTree set to make sure we don't have any duplicates and we have a stable order
let additional_derives: BTreeSet<String> =
self.additional_derives().iter().cloned().collect();
for (name, mode) in self.modes_of(id) {
let lt = self.lifetime_for(&info, mode);
self.rustdoc(docs);
Expand All @@ -586,12 +593,17 @@ pub trait RustGenerator<'a> {
self.push_str("#[derive(wasmtime::component::Lower)]\n");
self.push_str("#[component(record)]\n");
}

let mut derives = additional_derives.clone();
if info.is_copy() {
self.push_str("#[repr(C)]\n");
self.push_str("#[derive(Copy, Clone)]\n");
derives.extend(["Copy", "Clone"].into_iter().map(|s| s.to_string()));
} else if info.is_clone() {
self.push_str("#[derive(Clone)]\n");
derives.insert("Clone".to_string());
}
if !derives.is_empty() {
self.push_str("#[derive(");
self.push_str(&derives.into_iter().collect::<Vec<_>>().join(", "));
self.push_str(")]\n")
}
self.push_str(&format!("pub struct {}", name));
self.print_generics(lt);
Expand Down Expand Up @@ -707,7 +719,9 @@ pub trait RustGenerator<'a> {
Self: Sized,
{
let info = self.info(id);

// We use a BTree set to make sure we don't have any duplicates and have a stable order
let additional_derives: BTreeSet<String> =
self.additional_derives().iter().cloned().collect();
for (name, mode) in self.modes_of(id) {
self.rustdoc(docs);
let lt = self.lifetime_for(&info, mode);
Expand All @@ -719,10 +733,16 @@ pub trait RustGenerator<'a> {
self.push_str("#[derive(wasmtime::component::Lower)]\n");
self.push_str(&format!("#[component({})]\n", derive_component));
}
let mut derives = additional_derives.clone();
if info.is_copy() {
self.push_str("#[derive(Copy, Clone)]\n");
derives.extend(["Copy", "Clone"].into_iter().map(|s| s.to_string()));
} else if info.is_clone() {
self.push_str("#[derive(Clone)]\n");
derives.insert("Clone".to_string());
}
if !derives.is_empty() {
self.push_str("#[derive(");
self.push_str(&derives.into_iter().collect::<Vec<_>>().join(", "));
self.push_str(")]\n")
}
self.push_str(&format!("pub enum {name}"));
self.print_generics(lt);
Expand Down Expand Up @@ -873,7 +893,17 @@ pub trait RustGenerator<'a> {
}
self.push_str("#[repr(");
self.int_repr(enum_.tag());
self.push_str(")]\n#[derive(Clone, Copy, PartialEq, Eq)]\n");
self.push_str(")]\n");
// We use a BTree set to make sure we don't have any duplicates and a stable order
let mut derives: BTreeSet<String> = self.additional_derives().iter().cloned().collect();
derives.extend(
["Clone", "Copy", "PartialEq", "Eq"]
.into_iter()
.map(|s| s.to_string()),
);
self.push_str("#[derive(");
self.push_str(&derives.into_iter().collect::<Vec<_>>().join(", "));
self.push_str(")]\n");
self.push_str(&format!("pub enum {name} {{\n"));
for case in enum_.cases.iter() {
self.rustdoc(&case.docs);
Expand Down Expand Up @@ -1167,10 +1197,10 @@ pub trait RustFunctionGenerator {
for (field, val) in ty.fields.iter().zip(operands) {
result.push_str(&to_rust_ident(&field.name));
result.push_str(": ");
result.push_str(&val);
result.push_str(val);
result.push_str(",\n");
}
result.push_str("}");
result.push('}');
results.push(result);
}

Expand Down
3 changes: 2 additions & 1 deletion crates/rust-macro/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ test = false

[dependencies]
proc-macro2 = "1.0"
syn = "2.0"
syn = { version = "2.0", features = ["printing"] }
quote = "1"
wit-bindgen-core = { workspace = true }
wit-bindgen-rust = { workspace = true }
wit-bindgen-rust-lib = { workspace = true }
Expand Down
23 changes: 20 additions & 3 deletions crates/rust-macro/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use proc_macro2::{Span, TokenStream};
use quote::ToTokens;
use std::path::{Path, PathBuf};
use syn::parse::{Error, Parse, ParseStream, Result};
use syn::punctuated::Punctuated;
Expand Down Expand Up @@ -73,6 +74,12 @@ impl Parse for Config {
opts.stubs = true;
}
Opt::ExportPrefix(prefix) => opts.export_prefix = Some(prefix.value()),
Opt::AdditionalDerives(paths) => {
opts.additional_derive_attributes = paths
.into_iter()
.map(|p| p.into_token_stream().to_string())
.collect()
}
}
}
} else {
Expand Down Expand Up @@ -101,7 +108,7 @@ fn parse_source(source: &Option<Source>) -> anyhow::Result<(Resolve, PackageId,
let root = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap());
let mut parse = |path: &Path| -> anyhow::Result<_> {
if path.is_dir() {
let (pkg, sources) = resolve.push_dir(&path)?;
let (pkg, sources) = resolve.push_dir(path)?;
files = sources;
Ok(pkg)
} else {
Expand All @@ -112,9 +119,9 @@ fn parse_source(source: &Option<Source>) -> anyhow::Result<(Resolve, PackageId,
};
let pkg = match source {
Some(Source::Inline(s)) => {
resolve.push(UnresolvedPackage::parse("macro-input".as_ref(), &s)?)?
resolve.push(UnresolvedPackage::parse("macro-input".as_ref(), s)?)?
}
Some(Source::Path(s)) => parse(&root.join(&s))?,
Some(Source::Path(s)) => parse(&root.join(s))?,
None => parse(&root.join("wit"))?,
};

Expand Down Expand Up @@ -159,6 +166,7 @@ mod kw {
syn::custom_keyword!(exports);
syn::custom_keyword!(stubs);
syn::custom_keyword!(export_prefix);
syn::custom_keyword!(additional_derives);
}

#[derive(Clone)]
Expand Down Expand Up @@ -216,6 +224,8 @@ enum Opt {
Exports(Vec<Export>),
Stubs,
ExportPrefix(syn::LitStr),
// Parse as paths so we can take the concrete types/macro names rather than raw strings
AdditionalDerives(Vec<syn::Path>),
}

impl Parse for Opt {
Expand Down Expand Up @@ -306,6 +316,13 @@ impl Parse for Opt {
input.parse::<kw::export_prefix>()?;
input.parse::<Token![:]>()?;
Ok(Opt::ExportPrefix(input.parse()?))
} else if l.peek(kw::additional_derives) {
input.parse::<kw::additional_derives>()?;
input.parse::<Token![:]>()?;
let contents;
syn::bracketed!(contents in input);
let list = Punctuated::<_, Token![,]>::parse_terminated(&contents)?;
Ok(Opt::AdditionalDerives(list.iter().cloned().collect()))
} else {
Err(l.error())
}
Expand Down
3 changes: 3 additions & 0 deletions crates/rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,6 @@ clap = { workspace = true, optional = true }
[dev-dependencies]
wit-bindgen = { path = '../guest-rust' }
test-helpers = { path = '../test-helpers' }
# For use with the custom attributes test
serde = { version = "1.0", features = ["derive"] }
serde_json = "1"
39 changes: 27 additions & 12 deletions crates/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,13 @@ pub struct Opts {
/// This defaults to `wit_bindgen::bitflags`.
#[cfg_attr(feature = "clap", arg(long))]
pub bitflags_path: Option<String>,

/// Additional derive attributes to add to generated types. If using in a CLI, this flag can be
/// specified multiple times to add multiple attributes.
///
/// These derive attributes will be added to any generated structs or enums
#[cfg_attr(feature = "clap", arg(long = "additional_derive_attribute", short = 'd', default_values_t = Vec::<String>::new()))]
pub additional_derive_attributes: Vec<String>,
}

impl Opts {
Expand Down Expand Up @@ -543,9 +550,11 @@ impl InterfaceGenerator<'_> {
.entry(export_key)
.or_insert((trait_name, local_impl_name, Vec::new()));
let prev = mem::take(&mut self.src);
let mut sig = FnSig::default();
sig.use_item_name = true;
sig.private = true;
let mut sig = FnSig {
use_item_name: true,
private: true,
..Default::default()
};
if let FunctionKind::Method(_) = &func.kind {
sig.self_arg = Some("&self".into());
sig.self_is_first_param = true;
Expand Down Expand Up @@ -915,9 +924,11 @@ impl InterfaceGenerator<'_> {
if self.gen.skip.contains(&func.name) {
continue;
}
let mut sig = FnSig::default();
sig.use_item_name = true;
sig.private = true;
let mut sig = FnSig {
use_item_name: true,
private: true,
..Default::default()
};
if let FunctionKind::Method(_) = &func.kind {
sig.self_arg = Some("&self".into());
sig.self_is_first_param = true;
Expand All @@ -939,6 +950,10 @@ impl<'a> RustGenerator<'a> for InterfaceGenerator<'a> {
self.gen.opts.ownership
}

fn additional_derives(&self) -> &[String] {
&self.gen.opts.additional_derive_attributes
}

fn path_to_interface(&self, interface: InterfaceId) -> Option<String> {
let mut path = String::new();
if let Identifier::Interface(cur, name) = self.identifier {
Expand All @@ -958,7 +973,7 @@ impl<'a> RustGenerator<'a> for InterfaceGenerator<'a> {
}
}
let name = &self.gen.interface_names[&interface];
path.push_str(&name);
path.push_str(name);
Some(path)
}

Expand Down Expand Up @@ -1340,7 +1355,7 @@ impl<'a, 'b> FunctionBindgen<'a, 'b> {
sig.push_str(wasm_type(*param));
sig.push_str(", ");
}
sig.push_str(")");
sig.push(')');
assert!(results.len() < 2);
for result in results.iter() {
sig.push_str(" -> ");
Expand Down Expand Up @@ -1398,7 +1413,7 @@ impl Bindgen for FunctionBindgen<'_, '_> {
}

fn finish_block(&mut self, operands: &mut Vec<String>) {
if self.cleanup.len() > 0 {
if !self.cleanup.is_empty() {
self.needs_cleanup_list = true;
self.push_str("cleanup_list.extend_from_slice(&[");
for (ptr, layout) in mem::take(&mut self.cleanup) {
Expand Down Expand Up @@ -1677,7 +1692,7 @@ impl Bindgen for FunctionBindgen<'_, '_> {
self.push_str(&format!(" => {{\n{block}\n}}\n"));
}
}
if results.len() == 0 {
if results.is_empty() {
self.push_str("}\n");
} else {
self.push_str("};\n");
Expand Down Expand Up @@ -1901,7 +1916,7 @@ impl Bindgen for FunctionBindgen<'_, '_> {
self.push_str(&format!(
"if ptr.is_null()\n{{\nalloc::handle_alloc_error({layout});\n}}\nptr\n}}",
));
self.push_str(&format!("else {{\n::core::ptr::null_mut()\n}};\n",));
self.push_str("else {{\n::core::ptr::null_mut()\n}};\n");
self.push_str(&format!("for (i, e) in {vec}.into_iter().enumerate() {{\n",));
self.push_str(&format!(
"let base = {result} as i32 + (i as i32) * {size};\n",
Expand Down Expand Up @@ -1964,7 +1979,7 @@ impl Bindgen for FunctionBindgen<'_, '_> {
);

// ... then call the function with all our operands
if sig.results.len() > 0 {
if !sig.results.is_empty() {
self.push_str("let ret = ");
results.push("ret".to_string());
}
Expand Down
46 changes: 46 additions & 0 deletions crates/rust/tests/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,49 @@ mod package_with_versions {
}
}
}

mod custom_derives {
use std::collections::{hash_map::RandomState, HashSet};

wit_bindgen::generate!({
inline: "
package my:inline
interface blah {
record foo {
field1: string,
field2: list<u32>
}
bar: func(cool: foo)
}
world baz {
export blah
}
",
exports: {
"my:inline/blah": Component
},
// Clone is included by default almost everywhere, so include it here to make sure it
// doesn't conflict
additional_derives: [serde::Serialize, serde::Deserialize, Hash, Clone, PartialEq, Eq],
});

use exports::my::inline::blah::Foo;

struct Component;
impl exports::my::inline::blah::Guest for Component {
fn bar(cool: Foo) {
// Check that built in derives that I've added actually work by seeing that this hashes
let _blah: HashSet<Foo, RandomState> = HashSet::from_iter([Foo {
field1: "hello".to_string(),
field2: vec![1, 2, 3],
}]);

// Check that the attributes from an external crate actually work. If they don't work,
// compilation will fail here
let _ = serde_json::to_string(&cool);
}
}
}

0 comments on commit 8c2abf4

Please sign in to comment.