Skip to content

Commit

Permalink
come on
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Apr 9, 2024
1 parent 1b64d56 commit 88ac1ed
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 52 deletions.
33 changes: 18 additions & 15 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
using TaijaInteroperability
using Documenter

DocMeta.setdocmeta!(TaijaInteroperability, :DocTestSetup, :(using TaijaInteroperability); recursive=true)
DocMeta.setdocmeta!(
TaijaInteroperability,
:DocTestSetup,
:(using TaijaInteroperability);
recursive = true,
)

makedocs(;
modules=[TaijaInteroperability],
authors="Patrick Altmeyer",
repo="https://github.com/JuliaTrustworthyAI/TaijaInteroperability.jl/blob/{commit}{path}#{line}",
sitename="TaijaInteroperability.jl",
format=Documenter.HTML(;
prettyurls=get(ENV, "CI", "false") == "true",
canonical="https://JuliaTrustworthyAI.github.io/TaijaInteroperability.jl",
edit_link="master",
assets=String[],
modules = [TaijaInteroperability],
authors = "Patrick Altmeyer",
repo = "https://github.com/JuliaTrustworthyAI/TaijaInteroperability.jl/blob/{commit}{path}#{line}",
sitename = "TaijaInteroperability.jl",
format = Documenter.HTML(;
prettyurls = get(ENV, "CI", "false") == "true",
canonical = "https://JuliaTrustworthyAI.github.io/TaijaInteroperability.jl",
edit_link = "master",
assets = String[],
),
pages=[
"Home" => "index.md",
],
pages = ["Home" => "index.md"],
)

deploydocs(;
repo="github.com/JuliaTrustworthyAI/TaijaInteroperability.jl",
devbranch="master",
repo = "github.com/JuliaTrustworthyAI/TaijaInteroperability.jl",
devbranch = "master",
)
2 changes: 1 addition & 1 deletion ext/PyTorchModelExt/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ struct PyTorchModel <: AbstractDifferentiableModel
else
throw(
ArgumentError(
"`type` should be in `[:classification_binary,:classification_multi]`"
"`type` should be in `[:classification_binary,:classification_multi]`",
),
)
end
Expand Down
5 changes: 4 additions & 1 deletion ext/PyTorchModelExt/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ model = pytorch_model_loader(
```
"""
function TaijaInteroperability.pytorch_model_loader(
model_path::String, model_file::String, class_name::String, pickle_path::String
model_path::String,
model_file::String,
class_name::String,
pickle_path::String,
)
sys = PythonCall.pyimport("sys")
torch = PythonCall.pyimport("torch")
Expand Down
4 changes: 2 additions & 2 deletions src/CounterfactualExplanations/CounterfactualExplanations.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
export PyTorchModel, pytorch_model_loader, preprocess_python_data
export RTorchModel, rtorch_model_loader
export RTorchModel, rtorch_model_loader

include("PyTorchModel.jl")
include("RTorchModel.jl")
include("RTorchModel.jl")
53 changes: 36 additions & 17 deletions test/pytorch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@ if VERSION >= v"1.8"
create_new_pytorch_model(data, model_path)
train_and_save_pytorch_model(data, model_location, pickle_path)
model_loaded = pytorch_model_loader(
model_location, model_file, class_name, pickle_path
model_location,
model_file,
class_name,
pickle_path,
)

model_pytorch = PyTorchModel(model_loaded, data.likelihood)

@testset "Test for errors" begin
@test_throws ArgumentError PyTorchModel(
model_loaded, :regression
)
@test_throws ArgumentError PyTorchModel(model_loaded, :regression)
end

@testset "$name" begin
Expand Down Expand Up @@ -65,24 +66,34 @@ if VERSION >= v"1.8"
# Create and save model in the model_path directory
create_new_pytorch_model(counterfactual_data, model_path)
train_and_save_pytorch_model(
counterfactual_data, model_location, pickle_path
counterfactual_data,
model_location,
pickle_path,
)
model_loaded = pytorch_model_loader(
model_location, model_file, class_name, pickle_path
model_location,
model_file,
class_name,
pickle_path,
)
M = PyTorchModel(model_loaded, counterfactual_data.likelihood)

# Randomly selected factual:
Random.seed!(123)
x = DataPreprocessing.select_factual(
counterfactual_data, Random.rand(1:size(X, 2))
counterfactual_data,
Random.rand(1:size(X, 2)),
)
# Choose target:
y = Models.predict_label(M, counterfactual_data, x)
target = get_target(counterfactual_data, y[1])
# Single sample:
counterfactual = CounterfactualExplanations.generate_counterfactual(
x, target, counterfactual_data, M, generator
x,
target,
counterfactual_data,
M,
generator,
)

@testset "Predetermined outputs" begin
Expand All @@ -91,10 +102,10 @@ if VERSION >= v"1.8"
end
@test counterfactual.target == target
@test counterfactual.x == x &&
CounterfactualExplanations.factual(counterfactual) == x
CounterfactualExplanations.factual(counterfactual) == x
@test CounterfactualExplanations.factual_label(counterfactual) == y
@test CounterfactualExplanations.factual_probability(counterfactual) ==
probs(M, x)
probs(M, x)
end

@testset "Convergence" begin
Expand All @@ -103,21 +114,25 @@ if VERSION >= v"1.8"
# Threshold reached if converged:
γ = 0.9
max_iter = 1000
conv =
CounterfactualExplanations.Convergence.DecisionThresholdConvergence(
γ,
max_iter,
)
counterfactual = CounterfactualExplanations.generate_counterfactual(
x,
target,
counterfactual_data,
M,
generator;
max_iter=max_iter,
decision_threshold=γ,
convergence = conv,
)
using CounterfactualExplanations: counterfactual_probability
@test !CounterfactualExplanations.converged(counterfactual) ||
CounterfactualExplanations.target_probs(counterfactual)[1] >=
CounterfactualExplanations.target_probs(counterfactual)[1] >=
γ # either not converged or threshold reached
@test !CounterfactualExplanations.converged(counterfactual) ||
length(path(counterfactual)) <= max_iter
length(path(counterfactual)) <= max_iter
end

@testset "Trivial case (already in target class)" begin
Expand All @@ -126,18 +141,22 @@ if VERSION >= v"1.8"
y = Models.predict_label(M, counterfactual_data, x)
target = y[1]
γ = minimum([1 / length(counterfactual_data.y_levels), 0.5])
conv =
CounterfactualExplanations.Convergence.DecisionThresholdConvergence(
γ,
)
counterfactual = CounterfactualExplanations.generate_counterfactual(
x,
target,
counterfactual_data,
M,
generator;
decision_threshold=γ,
initialization=:identity,
convergence = conv,
initialization = :identity,
)
x′ = CounterfactualExplanations.decode_state(counterfactual)
if counterfactual.generator.latent_space == false
@test isapprox(counterfactual.x, x′; atol=1e-6)
@test isapprox(counterfactual.x, x′; atol = 1e-6)
@test CounterfactualExplanations.converged(counterfactual)
@test CounterfactualExplanations.terminated(counterfactual)
end
Expand Down
32 changes: 16 additions & 16 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ Loads synthetic data, models, and generators.
function _load_synthetic()
# Data:
data_sets = Dict(
:classification_binary => CounterfactualData(TaijaData.load_linearly_separable()...),
:classification_binary =>
CounterfactualData(TaijaData.load_linearly_separable()...),
:classification_multi => CounterfactualData(TaijaData.load_multi_class()...),
)
# Models
Expand All @@ -28,9 +29,8 @@ end
Returns a target label that is different from the factual label.
"""
function get_target(counterfactual_data::CounterfactualData, factual_label::RawTargetType)
target = rand(
counterfactual_data.y_levels[counterfactual_data.y_levels .!= factual_label]
)
target =
rand(counterfactual_data.y_levels[counterfactual_data.y_levels.!=factual_label])
return target
end

Expand All @@ -47,8 +47,8 @@ function _load_pretrained_models()
:ensemble => Models.load_mnist_ensemble(),
),
:latent => Dict(
:vae_strong => Models.load_mnist_vae(; strong=true),
:vae_weak => Models.load_mnist_vae(; strong=false),
:vae_strong => Models.load_mnist_vae(; strong = true),
:vae_weak => Models.load_mnist_vae(; strong = false),
),
),
:fashion_mnist => Dict(
Expand All @@ -57,8 +57,8 @@ function _load_pretrained_models()
:ensemble => Models.load_fashion_mnist_ensemble(),
),
:latent => Dict(
:vae_strong => Models.load_fashion_mnist_vae(; strong=true),
:vae_weak => Models.load_fashion_mnist_vae(; strong=false),
:vae_strong => Models.load_fashion_mnist_vae(; strong = true),
:vae_weak => Models.load_fashion_mnist_vae(; strong = false),
),
),
)
Expand All @@ -70,8 +70,8 @@ function _load_pretrained_models()
:ensemble => Models.load_cifar_10_ensemble(),
),
:latent => Dict(
:vae_strong => Models.load_cifar_10_vae(; strong=true),
:vae_weak => Models.load_cifar_10_vae(; strong=false),
:vae_strong => Models.load_cifar_10_vae(; strong = true),
:vae_weak => Models.load_cifar_10_vae(; strong = false),
),
)
end
Expand Down Expand Up @@ -118,7 +118,9 @@ end
Trains a PyTorch model and saves it to a pickle file.
"""
function train_and_save_pytorch_model(
data::CounterfactualData, model_location::String, pickle_path::String
data::CounterfactualData,
model_location::String,
pickle_path::String,
)
sys = PythonCall.pyimport("sys")

Expand All @@ -132,15 +134,13 @@ function train_and_save_pytorch_model(
NeuralNetwork = neural_network_class.NeuralNetwork
model = NeuralNetwork()

x_python, y_python = preprocess_python_data(
data
)
x_python, y_python = preprocess_python_data(data)

optimizer = torch.optim.Adam(model.parameters(); lr=0.1)
optimizer = torch.optim.Adam(model.parameters(); lr = 0.1)
loss_fun = torch.nn.BCEWithLogitsLoss()

# Training
for _ in 1:100
for _ = 1:100
# Compute prediction and loss:
output = model(x_python).squeeze()
sleep(1)
Expand Down

0 comments on commit 88ac1ed

Please sign in to comment.