diff --git a/EpiAware/src/EpiLatentModels/modifiers/TransformLatentModel.jl b/EpiAware/src/EpiLatentModels/modifiers/TransformLatentModel.jl index 7a9c4476a..a8367840d 100644 --- a/EpiAware/src/EpiLatentModels/modifiers/TransformLatentModel.jl +++ b/EpiAware/src/EpiLatentModels/modifiers/TransformLatentModel.jl @@ -3,8 +3,8 @@ The `TransformLatentModel` struct represents a latent model that applies a trans ## Constructors -- `TransformLatentModel(model, trans_function)`: Constructs a `TransformLatentModel` instance with the specified latent model and transformation function. -- `TransformLatentModel(; model, trans_function)`: Constructs a `TransformLatentModel` instance with the specified latent model and transformation function using named arguments. +- `TransformLatentModel(model, transform)`: Constructs a `TransformLatentModel` instance with the specified latent model and transformation function. +- `TransformLatentModel(; model, transform)`: Constructs a `TransformLatentModel` instance with the specified latent model and transformation function using named arguments. ## Example @@ -20,7 +20,7 @@ trans_model() "The latent model to transform." model::M "The transformation function." - trans_function::F + transform::F end """ @@ -38,6 +38,6 @@ Generate latent variables using the specified `TransformLatentModel`. """ @model function EpiAwareBase.generate_latent(model::TransformLatentModel, n) @submodel untransformed = generate_latent(model.model, n) - latent = model.trans_function(untransformed) + latent = model.transform(untransformed) return latent end diff --git a/EpiAware/test/EpiLatentModels/modifiers/TransformLatentModel.jl b/EpiAware/test/EpiLatentModels/modifiers/TransformLatentModel.jl index ed17d1f56..b71075a8e 100644 --- a/EpiAware/test/EpiLatentModels/modifiers/TransformLatentModel.jl +++ b/EpiAware/test/EpiLatentModels/modifiers/TransformLatentModel.jl @@ -4,7 +4,7 @@ trans = TransformLatentModel(Intercept(Normal(2, 0.2)), x -> x .|> exp) @test typeof(trans) <: AbstractTuringLatentModel @test trans.model == Intercept(Normal(2, 0.2)) - @test trans.trans_function([1, 2, 3]) == [exp(1), exp(2), exp(3)] + @test trans.transform([1, 2, 3]) == [exp(1), exp(2), exp(3)] end @testitem "TransformLatentModel generate_latent method" begin