Skip to content

Commit

Permalink
Update tensor_prod.md
Browse files Browse the repository at this point in the history
  • Loading branch information
Spinachboul authored Jan 13, 2024
1 parent acc001b commit 7849439
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions docs/src/tensor_prod.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function generate_data(n, lb, ub, a)
x_train = sample(n, lb, ub, SobolSample())
y_train = tensor_product_function(x_train, a)
x_test = sample(1000, lb, ub, SobolSample()) # Generating test data
x_test = sample(1000, lb, ub, RandomSample()) # Generating test data
y_test = tensor_product_function(x_test, a) # Generating test labels
return x_train, y_train, x_test, y_test
Expand All @@ -28,10 +28,9 @@ end
# Visualize training data and the true function
function plot_data_and_true_function(x_train, y_train, x_test, y_test, a, lb, ub)
xs = range(lb, ub, length=1000)
scatter(x_train, y_train, label="Training points", xlims=(lb, ub), ylims=(-1, 1), legend=:top)
plot!(xs, tensor_product_function.(Ref(xs), a), label="True function", legend=:top)
scatter!(x_test, y_test, label="Test points")
plot(xs, tensor_product_function.(xs, a), label="True Function", legend=:top)
scatter!(x_train, repeat([y_train], length(x_train)), label="Training Points", xlims=(lb,ub), ylims=(-1,1))
scatter!(x_test, repeat([y_test], length(x_test)), label="Test Points")
end
# Generate data and plot
Expand Down

0 comments on commit 7849439

Please sign in to comment.