Skip to content

Commit

Permalink
ok now then
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Apr 9, 2024
1 parent 5c60c4a commit 43bced8
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions test/pytorch.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using TaijaInteroperability

model_file = "neural_network_class"
class_name = "NeuralNetwork"
model_location = "$(pwd())"
Expand All @@ -19,17 +21,17 @@ if VERSION >= v"1.8"
# Create and save model in the model_path directory
create_new_pytorch_model(data, model_path)
train_and_save_pytorch_model(data, model_location, pickle_path)
model_loaded = pytorch_model_loader(
model_loaded = TaijaInteroperability.pytorch_model_loader(
model_location,
model_file,
class_name,
pickle_path,
)

model_pytorch = PyTorchModel(model_loaded, data.likelihood)
model_pytorch = TaijaInteroperability.PyTorchModel(model_loaded, data.likelihood)

@testset "Test for errors" begin
@test_throws ArgumentError PyTorchModel(model_loaded, :regression)
@test_throws ArgumentError TaijaInteroperability.PyTorchModel(model_loaded, :regression)
end

@testset "$name" begin
Expand Down Expand Up @@ -70,13 +72,13 @@ if VERSION >= v"1.8"
model_location,
pickle_path,
)
model_loaded = pytorch_model_loader(
model_loaded = TaijaInteroperability.pytorch_model_loader(
model_location,
model_file,
class_name,
pickle_path,
)
M = PyTorchModel(model_loaded, counterfactual_data.likelihood)
M = TaijaInteroperability.PyTorchModel(model_loaded, counterfactual_data.likelihood)

# Randomly selected factual:
Random.seed!(123)
Expand Down

0 comments on commit 43bced8

Please sign in to comment.