Skip to content

Commit

Permalink
replace DefaultRange with Encode trait
Browse files Browse the repository at this point in the history
  • Loading branch information
micahrj committed Jan 25, 2024
1 parent a4020a5 commit 6b55ff7
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 69 deletions.
55 changes: 33 additions & 22 deletions coupler-derive/src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,19 +131,27 @@ pub fn parse_param(field: &Field) -> Result<Option<ParamAttr>, Error> {
}))
}

pub fn gen_range(field: &Field, param: &ParamAttr) -> TokenStream {
pub fn gen_encode(field: &Field, param: &ParamAttr, value: impl ToTokens) -> TokenStream {
let ty = &field.ty;
if let Some(range) = &param.range {
range.to_token_stream()
quote! { ::coupler::params::Range::<#ty>::encode(&(#range), #value) }
} else {
let ty = &field.ty;
quote! { <#ty as ::coupler::params::DefaultRange>::default_range() }
quote! { <#ty as ::coupler::params::Encode>::encode(#value) }
}
}

pub fn gen_decode(field: &Field, param: &ParamAttr, value: impl ToTokens) -> TokenStream {
let ty = &field.ty;
if let Some(range) = &param.range {
quote! { ::coupler::params::Range::<#ty>::decode(&(#range), #value) }
} else {
quote! { <#ty as ::coupler::params::Encode>::decode(#value) }
}
}

struct ParamField<'a> {
field: &'a Field,
param: ParamAttr,
range: TokenStream,
}

fn parse_fields(input: &DeriveInput) -> Result<Vec<ParamField>, Error> {
Expand Down Expand Up @@ -171,12 +179,7 @@ fn parse_fields(input: &DeriveInput) -> Result<Vec<ParamField>, Error> {

for field in &fields.named {
if let Some(param) = parse_param(field)? {
let range = gen_range(field, &param);
param_fields.push(ParamField {
field,
param,
range,
});
param_fields.push(ParamField { field, param });
}
}

Expand All @@ -193,15 +196,23 @@ pub fn expand_params(input: &DeriveInput) -> Result<TokenStream, Error> {
let ident = field.field.ident.as_ref().unwrap();
let ty = &field.field.ty;
let id = &field.param.id;
let range = &field.range;

let name = if let Some(name) = &field.param.name {
name.clone()
} else {
LitStr::new(&ident.to_string(), ident.span())
};

let encode = quote! { ::coupler::params::Range::<#ty>::encode(&(#range), __value) };
let default = gen_encode(&field.field, &field.param, quote! { __default.#ident });

let steps = if let Some(range) = &field.param.range {
let ty = &field.field.ty;
quote! { ::coupler::params::Range::<#ty>::steps(&(#range)) }
} else {
quote! { <#ty as ::coupler::params::Encode>::steps() }
};

let encode = gen_encode(&field.field, &field.param, quote! { __value });
let parse = if let Some(parse) = &field.param.parse {
quote! {
match (#parse)(__str) {
Expand All @@ -218,7 +229,7 @@ pub fn expand_params(input: &DeriveInput) -> Result<TokenStream, Error> {
}
};

let decode = quote! { ::coupler::params::Range::<#ty>::decode(&(#range), __value) };
let decode = gen_decode(&field.field, &field.param, quote! { __value });
let display = if let Some(display) = &field.param.display {
quote! { (#display)(#decode, __formatter) }
} else if let Some(format) = &field.param.format {
Expand All @@ -231,8 +242,8 @@ pub fn expand_params(input: &DeriveInput) -> Result<TokenStream, Error> {
::coupler::params::ParamInfo {
id: #id,
name: ::std::string::ToString::to_string(#name),
default: ::coupler::params::Range::<#ty>::encode(&(#range), __default.#ident),
steps: ::coupler::params::Range::<#ty>::steps(&(#range)),
default: #default,
steps: #steps,
parse: ::std::boxed::Box::new(|__str| #parse),
display: ::std::boxed::Box::new(|__value, __formatter| #display),
}
Expand All @@ -241,26 +252,26 @@ pub fn expand_params(input: &DeriveInput) -> Result<TokenStream, Error> {

let set_cases = fields.iter().map(|field| {
let ident = &field.field.ident;
let ty = &field.field.ty;
let id = &field.param.id;
let range = &field.range;

let decode = gen_decode(&field.field, &field.param, quote! { __value });

quote! {
#id => {
self.#ident = ::coupler::params::Range::<#ty>::decode(&(#range), __value);
self.#ident = #decode;
}
}
});

let get_cases = fields.iter().map(|field| {
let ident = &field.field.ident;
let ty = &field.field.ty;
let id = &field.param.id;
let range = &field.range;

let encode = gen_encode(&field.field, &field.param, quote! { self.#ident });

quote! {
#id => {
::coupler::params::Range::<#ty>::encode(&(#range), self.#ident)
#encode
}
}
});
Expand Down
13 changes: 5 additions & 8 deletions coupler-derive/src/smooth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use proc_macro2::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Error, Expr, Field, Fields, Type};

use super::params::{gen_range, parse_param, ParamAttr};
use super::params::{gen_decode, parse_param, ParamAttr};

struct SmoothAttr {
style: Type,
Expand Down Expand Up @@ -194,25 +194,22 @@ pub fn expand_smooth(input: &DeriveInput) -> Result<TokenStream, Error> {

let set_cases = fields.iter().filter_map(|field| {
let param = field.param.as_ref()?;
let range = gen_range(&field.field, param);

let ident = &field.field.ident;
let ty = &field.field.ty;
let id = &param.id;

let decode = gen_decode(&field.field, param, quote! { __value });

if field.smooth.is_some() {
Some(quote! {
#id => {
::coupler::params::smooth::Smoother::set(
&mut self.#ident,
::coupler::params::Range::<#ty>::decode(&(#range), __value),
);
::coupler::params::smooth::Smoother::set(&mut self.#ident, #decode);
}
})
} else {
Some(quote! {
#id => {
self.#ident = ::coupler::params::Range::<#ty>::decode(&(#range), __value);
self.#ident = #decode;
}
})
}
Expand Down
2 changes: 1 addition & 1 deletion src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub use coupler_derive::Params;
mod range;
pub mod smooth;

pub use range::{DefaultRange, Enum, EnumRange, Range};
pub use range::{Encode, Enum, Range};

pub type ParamId = u32;
pub type ParamValue = f64;
Expand Down
67 changes: 29 additions & 38 deletions src/params/range.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::marker::PhantomData;

use super::ParamValue;

pub trait Range<T> {
Expand All @@ -8,10 +6,10 @@ pub trait Range<T> {
fn decode(&self, value: ParamValue) -> T;
}

pub trait DefaultRange: Sized {
type Range: Range<Self>;

fn default_range() -> Self::Range;
pub trait Encode {
fn steps() -> Option<u32>;
fn encode(self) -> ParamValue;
fn decode(value: ParamValue) -> Self;
}

macro_rules! float_range {
Expand Down Expand Up @@ -50,12 +48,17 @@ macro_rules! float_range {
}
}

impl DefaultRange for $float {
type Range = std::ops::Range<$float>;
impl Encode for $float {
fn steps() -> Option<u32> {
(0.0..1.0).steps()
}

#[inline]
fn default_range() -> Self::Range {
0.0..1.0
fn encode(self) -> ParamValue {
(0.0..1.0).encode(self)
}

fn decode(value: ParamValue) -> Self {
(0.0..1.0).decode(value)
}
}
};
Expand Down Expand Up @@ -104,12 +107,17 @@ macro_rules! int_range {
}
}

impl DefaultRange for $int {
type Range = std::ops::Range<$int>;
impl Encode for $int {
fn steps() -> Option<u32> {
(0..2).steps()
}

#[inline]
fn default_range() -> Self::Range {
0..1
fn encode(self) -> ParamValue {
(0..2).encode(self)
}

fn decode(value: ParamValue) -> Self {
(0..2).decode(value)
}
}
};
Expand All @@ -131,39 +139,22 @@ pub trait Enum {
fn from_index(index: u32) -> Self;
}

pub struct EnumRange<E>(PhantomData<E>);

impl<E> EnumRange<E> {
pub fn new() -> EnumRange<E> {
EnumRange(PhantomData)
}
}

impl<E: Enum> Range<E> for EnumRange<E> {
fn steps(&self) -> Option<u32> {
impl<E: Enum> Encode for E {
fn steps() -> Option<u32> {
Some(E::values())
}

fn encode(&self, value: E) -> ParamValue {
fn encode(self) -> ParamValue {
let steps_recip = 1.0 / E::values() as f64;
(value.to_index() as f64 + 0.5) * steps_recip
(self.to_index() as f64 + 0.5) * steps_recip
}

fn decode(&self, value: ParamValue) -> E {
fn decode(value: ParamValue) -> E {
let steps = E::values() as f64;
E::from_index((value * steps) as u32)
}
}

impl<E: Enum> DefaultRange for E {
type Range = EnumRange<E>;

#[inline]
fn default_range() -> Self::Range {
EnumRange::new()
}
}

impl Enum for bool {
fn values() -> u32 {
2
Expand Down

0 comments on commit 6b55ff7

Please sign in to comment.