diff --git a/docs/src/tensor_prod.md b/docs/src/tensor_prod.md index 472f5bc8..015aef85 100644 --- a/docs/src/tensor_prod.md +++ b/docs/src/tensor_prod.md @@ -48,34 +48,35 @@ Now let's train various surrogate models and evaluate their performance on the t ``` # Train different surrogate models -function train_surrogates(x_train, y_train) - loba = LobachevskySurrogate(x_train, y_train) - krig = Kriging(x_train, y_train) +function train_surrogates(x_train, y_train, lb, ub, alpha=2.0, n=6) + loba = LobachevskySurrogate(x_train, y_train, lb, ub, alpha=alpha, n=n) + krig = Kriging(x_train, y_train, lb, ub) return loba, krig end # Evaluate and compare surrogate model performances function evaluate_surrogates(loba, krig, x_test) - loba_pred = loba(x_test) - krig_pred = krig(x_test) + loba_pred = loba.(x_test) + krig_pred = krig.(x_test) return loba_pred, krig_pred end # Plot surrogate predictions against the true function -function plot_surrogate_predictions(loba_pred, krig_pred, y_test, a, lb, ub) - xs = range(lb, ub, length=1000) - - plot(xs, tensor_product_function.(Ref(xs), a), label="True function", legend=:top) - plot!(xs, loba_pred, label="Lobachevsky") - plot!(xs, krig_pred, label="Kriging") +function plot_surrogate_predictions(loba_pred, krig_pred, x_test, y_test, a, lb, ub) + xs = collect(x_test) # Convert x_test to an array + plot(xs, tensor_product_function.(xs, a), label="True Function", legend=:top) + plot!(collect(x_test), loba_pred, seriestype=:scatter, label="Lobachevsky") + plot!(collect(x_test), krig_pred, seriestype=:scatter, label="Kriging") + plot!(collect(x_test), fill(y_test, length(x_test)), seriestype=:scatter, label="Sampled points") # Use fill to create an array of the same length as x_test end # Train surrogates and evaluate their performance -loba, krig = train_surrogates(x_train, y_train) +lb, ub = minimum(x_train), maximum(x_train) +loba, krig = train_surrogates(x_train, y_train, lb, ub) loba_pred, krig_pred = evaluate_surrogates(loba, krig, x_test) -# Plot surrogate predictions against the true function -plot_surrogate_predictions(loba_pred, krig_pred, y_test, a, lb, ub) +# Plotting Results +plot_surrogate_predictions(loba_pred, krig_pred, x_test, y_test, 2.0, lb, ub) ``` # Reporting the best Surrogate Model