From 7849439072f8803fc02bfeb04f678b644117bc16 Mon Sep 17 00:00:00 2001 From: MRIDUL JAIN <105979087+Spinachboul@users.noreply.github.com> Date: Sat, 13 Jan 2024 23:22:51 +0530 Subject: [PATCH] Update tensor_prod.md --- docs/src/tensor_prod.md | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/src/tensor_prod.md b/docs/src/tensor_prod.md index 57374a42..472f5bc8 100644 --- a/docs/src/tensor_prod.md +++ b/docs/src/tensor_prod.md @@ -19,7 +19,7 @@ function generate_data(n, lb, ub, a) x_train = sample(n, lb, ub, SobolSample()) y_train = tensor_product_function(x_train, a) - x_test = sample(1000, lb, ub, SobolSample()) # Generating test data + x_test = sample(1000, lb, ub, RandomSample()) # Generating test data y_test = tensor_product_function(x_test, a) # Generating test labels return x_train, y_train, x_test, y_test @@ -28,10 +28,9 @@ end # Visualize training data and the true function function plot_data_and_true_function(x_train, y_train, x_test, y_test, a, lb, ub) xs = range(lb, ub, length=1000) - - scatter(x_train, y_train, label="Training points", xlims=(lb, ub), ylims=(-1, 1), legend=:top) - plot!(xs, tensor_product_function.(Ref(xs), a), label="True function", legend=:top) - scatter!(x_test, y_test, label="Test points") + plot(xs, tensor_product_function.(xs, a), label="True Function", legend=:top) + scatter!(x_train, repeat([y_train], length(x_train)), label="Training Points", xlims=(lb,ub), ylims=(-1,1)) + scatter!(x_test, repeat([y_test], length(x_test)), label="Test Points") end # Generate data and plot