diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 30d3ae5..4d2f3fd 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -21,7 +21,6 @@ jobs:
           - '1.7'
           - '1.8'
           - '1.9'
-          - 'nightly'
         os:
           - ubuntu-latest
         arch:
diff --git a/.gitignore b/.gitignore
index 7f14b3c..c7c2e32 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,4 +9,6 @@
 # Folder and files generated by the package manager based on CondaPkg.toml (Python dependencies)
 **/.CondaPkg/
-*.pyc
\ No newline at end of file
+*.pyc
+
+*.pt
\ No newline at end of file
diff --git a/CondaPkg.toml b/CondaPkg.toml
index e6aff2f..c6cb615 100644
--- a/CondaPkg.toml
+++ b/CondaPkg.toml
@@ -2,4 +2,4 @@ channels = ["pytorch", "conda-forge"]
 
 [deps]
 numpy = ""
-pytorch = ""
+pytorch = ""
\ No newline at end of file
diff --git a/Project.toml b/Project.toml
index dbe1c41..0f15ece 100644
--- a/Project.toml
+++ b/Project.toml
@@ -7,6 +7,7 @@ version = "1.0.0-DEV"
 CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
 CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
 
 [weakdeps]
 PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
@@ -18,9 +19,3 @@ RTorchModelExt = "RCall"
 
 [compat]
 julia = "1.9"
-
-[extras]
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-
-[targets]
-test = ["Test"]
diff --git a/src/TaijaInteroperability.jl b/src/TaijaInteroperability.jl
index cf0ed79..5c8578e 100644
--- a/src/TaijaInteroperability.jl
+++ b/src/TaijaInteroperability.jl
@@ -1,5 +1,10 @@
 module TaijaInteroperability
 
+using PackageExtensionCompat
+function __init__()
+    @require_extensions
+end
+
 include("CounterfactualExplanations/CounterfactualExplanations.jl")
 
 end
diff --git a/test/Project.toml b/test/Project.toml
new file mode 100644
index 0000000..93ab25a
--- /dev/null
+++ b/test/Project.toml
@@ -0,0 +1,9 @@
+[deps]
+CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
+CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
+PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
diff --git a/test/pytorch.jl b/test/pytorch.jl
new file mode 100644
index 0000000..5bb923a
--- /dev/null
+++ b/test/pytorch.jl
@@ -0,0 +1,153 @@
+model_file = "neural_network_class"
+class_name = "NeuralNetwork"
+model_location = "$(pwd())"
+model_path = "$(pwd())/neural_network_class.py"
+pickle_path = "$(pwd())/pretrained_model.pt"
+
+# Using PyTorch models is supported only for Julia versions >= 1.8
+if VERSION >= v"1.8"
+    ENV["KMP_DUPLICATE_LIB_OK"] = "TRUE"
+
+    torch = PythonCall.pyimport("torch")
+    @testset "PyTorch model test" begin
+        for (key, value) in synthetic
+            name = string(key)
+            @testset "$name" begin
+                data = value[:data]
+                X = data.X
+
+                # Create and save model in the model_path directory
+                create_new_pytorch_model(data, model_path)
+                train_and_save_pytorch_model(data, model_location, pickle_path)
+                model_loaded = Models.pytorch_model_loader(
+                    model_location, model_file, class_name, pickle_path
+                )
+
+                model_pytorch = Models.PyTorchModel(model_loaded, data.likelihood)
+
+                @testset "Test for errors" begin
+                    @test_throws ArgumentError Models.PyTorchModel(
+                        model_loaded, :regression
+                    )
+                end
+
+                @testset "$name" begin
+                    @testset "Verify the correctness of the likelihood field" begin
field" begin + @test model_pytorch.likelihood == data.likelihood + end + @testset "Matrix of inputs" begin + @test size(Models.logits(model_pytorch, X))[2] == size(X, 2) + @test size(Models.probs(model_pytorch, X))[2] == size(X, 2) + end + @testset "Vector of inputs" begin + @test size(Models.logits(model_pytorch, X[:, 1]), 2) == 1 + @test size(Models.probs(model_pytorch, X[:, 1]), 2) == 1 + end + end + + # Clean up the temporary files + remove_file(model_path) + remove_file(pickle_path) + end + end + end + + @testset "Counterfactuals for PyTorch models" begin + # Test the Python models on only one generator to avoid the pipeline getting too slow + # All generators are tested in the generators/counterfactuals.jl file + generator = GravitationalGenerator() + for (key, value) in synthetic + name = string(key) + @testset "$name" begin + counterfactual_data = value[:data] + X = counterfactual_data.X + + # Create and save model in the model_path directory + create_new_pytorch_model(counterfactual_data, model_path) + train_and_save_pytorch_model( + counterfactual_data, model_location, pickle_path + ) + model_loaded = Models.pytorch_model_loader( + model_location, model_file, class_name, pickle_path + ) + M = Models.PyTorchModel(model_loaded, counterfactual_data.likelihood) + + # Randomly selected factual: + Random.seed!(123) + x = DataPreprocessing.select_factual( + counterfactual_data, Random.rand(1:size(X, 2)) + ) + # Choose target: + y = Models.predict_label(M, counterfactual_data, x) + target = get_target(counterfactual_data, y[1]) + # Single sample: + counterfactual = CounterfactualExplanations.generate_counterfactual( + x, target, counterfactual_data, M, generator + ) + + @testset "Predetermined outputs" begin + if generator.latent_space + @test counterfactual.params[:latent_space] + end + @test counterfactual.target == target + @test counterfactual.x == x && + CounterfactualExplanations.factual(counterfactual) == x + @test CounterfactualExplanations.factual_label(counterfactual) == y + @test CounterfactualExplanations.factual_probability(counterfactual) == + probs(M, x) + end + + @testset "Convergence" begin + @testset "Non-trivial case" begin + counterfactual_data.generative_model = nothing + # Threshold reached if converged: + γ = 0.9 + max_iter = 1000 + counterfactual = CounterfactualExplanations.generate_counterfactual( + x, + target, + counterfactual_data, + M, + generator; + max_iter=max_iter, + decision_threshold=γ, + ) + using CounterfactualExplanations: counterfactual_probability + @test !CounterfactualExplanations.converged(counterfactual) || + CounterfactualExplanations.target_probs(counterfactual)[1] >= + γ # either not converged or threshold reached + @test !CounterfactualExplanations.converged(counterfactual) || + length(path(counterfactual)) <= max_iter + end + + @testset "Trivial case (already in target class)" begin + counterfactual_data.generative_model = nothing + # Already in target and exceeding threshold probability: + y = Models.predict_label(M, counterfactual_data, x) + target = y[1] + γ = minimum([1 / length(counterfactual_data.y_levels), 0.5]) + counterfactual = CounterfactualExplanations.generate_counterfactual( + x, + target, + counterfactual_data, + M, + generator; + decision_threshold=γ, + initialization=:identity, + ) + x′ = CounterfactualExplanations.decode_state(counterfactual) + if counterfactual.generator.latent_space == false + @test isapprox(counterfactual.x, x′; atol=1e-6) + @test CounterfactualExplanations.converged(counterfactual) + @test 
+                        end
+                        @test CounterfactualExplanations.total_steps(counterfactual) == 0
+                    end
+                end
+
+                # Clean up the temporary files
+                remove_file(model_path)
+                remove_file(pickle_path)
+            end
+        end
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 204cea8..70ca7e6 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,6 +1,17 @@
+using CounterfactualExplanations
+using CounterfactualExplanations.Data
+using CounterfactualExplanations.DataPreprocessing
+using CounterfactualExplanations.Models
+using Printf
+using PythonCall
+using Random
 using TaijaInteroperability
 using Test
 
+include("utils.jl")
+
+synthetic = _load_synthetic()
+
 @testset "TaijaInteroperability.jl" begin
-    # Write your tests here.
+    include("pytorch.jl")
 end
diff --git a/test/utils.jl b/test/utils.jl
new file mode 100644
index 0000000..b0c6a41
--- /dev/null
+++ b/test/utils.jl
@@ -0,0 +1,171 @@
+"""
+    _load_synthetic()
+
+Loads synthetic datasets and fits the standard catalogue of models to them.
+"""
+function _load_synthetic()
+    # Data:
+    data_sets = Dict(
+        :classification_binary => load_linearly_separable(),
+        :classification_multi => load_multi_class(),
+    )
+    # Models
+    synthetic = Dict()
+    for (likelihood, data) in data_sets
+        models = Dict()
+        for (model_name, model) in Models.standard_models_catalogue
+            M = fit_model(data, model_name)
+            models[model_name] = Dict(:raw_model => M.model, :model => M)
+        end
+        synthetic[likelihood] = Dict(:models => models, :data => data)
+    end
+    return synthetic
+end
+
+"""
+    get_target(counterfactual_data::CounterfactualData, factual_label::RawTargetType)
+
+Returns a target label that is different from the factual label.
+"""
+function get_target(counterfactual_data::CounterfactualData, factual_label::RawTargetType)
+    target = rand(
+        counterfactual_data.y_levels[counterfactual_data.y_levels .!= factual_label]
+    )
+    return target
+end
+
+"""
+    _load_pretrained_models()
+
+Loads pretrained Flux models.
+"""
+function _load_pretrained_models()
+    pretrained = Dict(
+        :mnist => Dict(
+            :models => Dict(
+                :mlp => Models.load_mnist_mlp(),
+                :ensemble => Models.load_mnist_ensemble(),
+            ),
+            :latent => Dict(
+                :vae_strong => Models.load_mnist_vae(; strong=true),
+                :vae_weak => Models.load_mnist_vae(; strong=false),
+            ),
+        ),
+        :fashion_mnist => Dict(
+            :models => Dict(
+                :mlp => Models.load_fashion_mnist_mlp(),
+                :ensemble => Models.load_fashion_mnist_ensemble(),
+            ),
+            :latent => Dict(
+                :vae_strong => Models.load_fashion_mnist_vae(; strong=true),
+                :vae_weak => Models.load_fashion_mnist_vae(; strong=false),
+            ),
+        ),
+    )
+
+    if !Sys.iswindows()
+        pretrained[:cifar_10] = Dict(
+            :models => Dict(
+                :mlp => Models.load_cifar_10_mlp(),
+                :ensemble => Models.load_cifar_10_ensemble(),
+            ),
+            :latent => Dict(
+                :vae_strong => Models.load_cifar_10_vae(; strong=true),
+                :vae_weak => Models.load_cifar_10_vae(; strong=false),
+            ),
+        )
+    end
+
+    return pretrained
+end
+
+"""
+    create_new_pytorch_model(data::CounterfactualData, model_path::String)
+
+Creates a new PyTorch model and saves it to a Python file.
+"""
+function create_new_pytorch_model(data::CounterfactualData, model_path::String)
+    in_size = size(data.X)[1]
+    out_size = size(data.y)[1]
+
+    class_str = """
+    from torch import nn
+
+    class NeuralNetwork(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.model = nn.Sequential(
+                nn.Flatten(),
+                nn.Linear($(in_size), 32),
+                nn.Sigmoid(),
+                nn.Linear(32, $(out_size))
+            )
+
+        def forward(self, x):
+            return self.model(x)
+    """
+
+    open(model_path, "w") do f
+        @printf(f, "%s", class_str)
+    end
+
+    return nothing
+end
+
+"""
+    train_and_save_pytorch_model(data::CounterfactualData, model_location::String, pickle_path::String)
+
+Trains a PyTorch model and saves it to a pickle file.
+"""
+function train_and_save_pytorch_model(
+    data::CounterfactualData, model_location::String, pickle_path::String
+)
+    sys = PythonCall.pyimport("sys")
+
+    if !in(model_location, sys.path)
+        sys.path.append(model_location)
+    end
+
+    importlib = PythonCall.pyimport("importlib")
+    neural_network_class = importlib.import_module("neural_network_class")
+    importlib.reload(neural_network_class)
+    NeuralNetwork = neural_network_class.NeuralNetwork
+    model = NeuralNetwork()
+
+    x_python, y_python = CounterfactualExplanations.DataPreprocessing.preprocess_python_data(
+        data
+    )
+
+    optimizer = torch.optim.Adam(model.parameters(); lr=0.1)
+    loss_fun = torch.nn.BCEWithLogitsLoss()
+
+    # Training
+    for _ in 1:100
+        # Compute prediction and loss:
+        output = model(x_python).squeeze()
+        sleep(1)
+        loss = loss_fun(output, y_python.t())
+        # Backpropagation:
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+    end
+
+    torch.save(model, pickle_path)
+    return nothing
+end
+
+"""
+    remove_file(file_path::String)
+
+Removes a file from the specified path.
+"""
+function remove_file(file_path::String)
+    try
+        rm(file_path) # removes the file
+        println("File $file_path removed successfully.")
+        return nothing
+    catch e
+        throw(ArgumentError("Error occurred while removing file $file_path: $e"))
+    end
+end
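Note: the sketch below is not part of the diff. It is a minimal summary of the workflow the new tests exercise end to end: loading a pickled PyTorch model through the PythonCall-backed package extension and generating a counterfactual for it. It assumes the Python class file and the pickled weights already exist on disk (e.g. as produced by `create_new_pytorch_model` and `train_and_save_pytorch_model` above); all function and field names are taken from the test code, while the factual index and target selection are placeholders.

```julia
using CounterfactualExplanations
using CounterfactualExplanations.Data: load_linearly_separable
using CounterfactualExplanations.DataPreprocessing: select_factual
using CounterfactualExplanations.Models
using PythonCall # loading PythonCall activates the PyTorch extension
using TaijaInteroperability

data = load_linearly_separable()

# Load the pickled model: directory, Python module (file) name, class name, weights path.
model_loaded = Models.pytorch_model_loader(
    pwd(), "neural_network_class", "NeuralNetwork", joinpath(pwd(), "pretrained_model.pt")
)

# Wrap the Python object so that logits/probs and generators treat it like a native model.
M = Models.PyTorchModel(model_loaded, data.likelihood)

# Pick a factual, choose any other label as the target, and search for a counterfactual.
x = select_factual(data, 1)
y = Models.predict_label(M, data, x)
target = data.y_levels[findfirst(!=(y[1]), data.y_levels)]
ce = CounterfactualExplanations.generate_counterfactual(
    x, target, data, M, GravitationalGenerator()
)
```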