diff --git a/ext/PyTorchModelExt/generators.jl b/ext/PyTorchModelExt/generators.jl index 18503be..1e74202 100644 --- a/ext/PyTorchModelExt/generators.jl +++ b/ext/PyTorchModelExt/generators.jl @@ -8,7 +8,7 @@ The gradients are calculated through PyTorch using PythonCall.jl. # Arguments - `generator::AbstractGradientBasedGenerator`: The generator object that is used to generate the counterfactual explanation. -- `M::Models.PyTorchModel`: The PyTorch model for which the counterfactual is generated. +- `M::PyTorchModel`: The PyTorch model for which the counterfactual is generated. - `ce::AbstractCounterfactualExplanation`: The counterfactual explanation object for which the gradient is calculated. # Returns diff --git a/ext/PyTorchModelExt/utils.jl b/ext/PyTorchModelExt/utils.jl index dca5950..0a1e695 100644 --- a/ext/PyTorchModelExt/utils.jl +++ b/ext/PyTorchModelExt/utils.jl @@ -1,5 +1,5 @@ """ - CounterfactualExplanations.pytorch_model_loader(model_path::String, model_file::String, class_name::String, pickle_path::String) + pytorch_model_loader(model_path::String, model_file::String, class_name::String, pickle_path::String) Loads a previously saved PyTorch model. @@ -44,7 +44,7 @@ function TaijaInteroperability.pytorch_model_loader( end """ - CounterfactualExplanations.preprocess_python_data(data::CounterfactualData) + preprocess_python_data(data::CounterfactualData) Converts a `CounterfactualData` object to an input tensor and a label tensor. diff --git a/test/pytorch.jl b/test/pytorch.jl index 5bb923a..0e24ce9 100644 --- a/test/pytorch.jl +++ b/test/pytorch.jl @@ -19,14 +19,14 @@ if VERSION >= v"1.8" # Create and save model in the model_path directory create_new_pytorch_model(data, model_path) train_and_save_pytorch_model(data, model_location, pickle_path) - model_loaded = Models.pytorch_model_loader( + model_loaded = pytorch_model_loader( model_location, model_file, class_name, pickle_path ) - model_pytorch = Models.PyTorchModel(model_loaded, data.likelihood) + model_pytorch = PyTorchModel(model_loaded, data.likelihood) @testset "Test for errors" begin - @test_throws ArgumentError Models.PyTorchModel( + @test_throws ArgumentError PyTorchModel( model_loaded, :regression ) end @@ -67,10 +67,10 @@ if VERSION >= v"1.8" train_and_save_pytorch_model( counterfactual_data, model_location, pickle_path ) - model_loaded = Models.pytorch_model_loader( + model_loaded = pytorch_model_loader( model_location, model_file, class_name, pickle_path ) - M = Models.PyTorchModel(model_loaded, counterfactual_data.likelihood) + M = PyTorchModel(model_loaded, counterfactual_data.likelihood) # Randomly selected factual: Random.seed!(123)