Skip to content

Commit

Permalink
Update tensor_prod.md
Browse files Browse the repository at this point in the history
  • Loading branch information
Spinachboul authored Jan 13, 2024
1 parent 7849439 commit 3013518
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions docs/src/tensor_prod.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,34 +48,35 @@ Now let's train various surrogate models and evaluate their performance on the t

```
# Train different surrogate models
function train_surrogates(x_train, y_train)
loba = LobachevskySurrogate(x_train, y_train)
krig = Kriging(x_train, y_train)
function train_surrogates(x_train, y_train, lb, ub, alpha=2.0, n=6)
loba = LobachevskySurrogate(x_train, y_train, lb, ub, alpha=alpha, n=n)
krig = Kriging(x_train, y_train, lb, ub)
return loba, krig
end
# Evaluate and compare surrogate model performances
function evaluate_surrogates(loba, krig, x_test)
loba_pred = loba(x_test)
krig_pred = krig(x_test)
loba_pred = loba.(x_test)
krig_pred = krig.(x_test)
return loba_pred, krig_pred
end
# Plot surrogate predictions against the true function
function plot_surrogate_predictions(loba_pred, krig_pred, y_test, a, lb, ub)
xs = range(lb, ub, length=1000)
plot(xs, tensor_product_function.(Ref(xs), a), label="True function", legend=:top)
plot!(xs, loba_pred, label="Lobachevsky")
plot!(xs, krig_pred, label="Kriging")
function plot_surrogate_predictions(loba_pred, krig_pred, x_test, y_test, a, lb, ub)
xs = collect(x_test) # Convert x_test to an array
plot(xs, tensor_product_function.(xs, a), label="True Function", legend=:top)
plot!(collect(x_test), loba_pred, seriestype=:scatter, label="Lobachevsky")
plot!(collect(x_test), krig_pred, seriestype=:scatter, label="Kriging")
plot!(collect(x_test), fill(y_test, length(x_test)), seriestype=:scatter, label="Sampled points") # Use fill to create an array of the same length as x_test
end
# Train surrogates and evaluate their performance
loba, krig = train_surrogates(x_train, y_train)
lb, ub = minimum(x_train), maximum(x_train)
loba, krig = train_surrogates(x_train, y_train, lb, ub)
loba_pred, krig_pred = evaluate_surrogates(loba, krig, x_test)
# Plot surrogate predictions against the true function
plot_surrogate_predictions(loba_pred, krig_pred, y_test, a, lb, ub)
# Plotting Results
plot_surrogate_predictions(loba_pred, krig_pred, x_test, y_test, 2.0, lb, ub)
```

# Reporting the best Surrogate Model
Expand Down

0 comments on commit 3013518

Please sign in to comment.