
Simple Linear Regression Neural Network



My very first trained model. This Python script demonstrates how to create and train a simple neural network for linear regression using TensorFlow and NumPy.

Usage

Run the script with a Python interpreter. On startup, it asks whether to train the neural network. If you answer y, the script fits the network to the training data; otherwise, it prints a warning that the network has not been trained.

python simple_linear_regression_nn.py

Neural Network Architecture

The neural network consists of a single dense layer with one neuron, so it computes ŷ = w·x + b, which is exactly the form of a simple linear regression model. The model is compiled with stochastic gradient descent ('sgd') as the optimizer and mean squared error ('mean_squared_error') as the loss function.

import tensorflow as tf
# A single Dense unit with no activation computes y_hat = w * x + b
model = tf.keras.Sequential([tf.keras.layers.Dense(units=1, input_shape=[1])])
model.compile(optimizer='sgd', loss='mean_squared_error')
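The computation performed by the one-unit Dense layer and the mean squared error loss can be sketched in plain Python. This is an illustration only, not TensorFlow code: `dense_forward` and `mse` are hypothetical helpers, and the parameter values passed to them are chosen for demonstration.

```python
def dense_forward(w, b, xs):
    """One Dense unit with no activation: y_hat = w * x + b for each input."""
    return [w * x + b for x in xs]


def mse(y_true, y_pred):
    """Mean squared error: average of the squared differences."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


xs = [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0]
ys = [-2.0, 1.0, 4.0, 7.0, 10.0, 13.0]

# With w = 3 and b = 1 every prediction matches its target, so the loss is zero.
print(mse(ys, dense_forward(3.0, 1.0, xs)))  # 0.0
# Any other (w, b) gives a positive loss, which SGD tries to reduce.
print(mse(ys, dense_forward(2.0, 0.0, xs)))
```

Training amounts to searching for the (w, b) pair that makes this loss as small as possible.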

Input Data

The training data consist of six input values (xs) and their corresponding target outputs (ys), stored as NumPy float arrays. Every pair satisfies y = 3x + 1, which is the relationship the network has to learn.

import numpy as np
xs = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=float)
ys = np.array([-2.0, 1.0, 4.0, 7.0, 10.0, 13.0], dtype=float)
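Because the data lie exactly on a line, the weight and bias the network should converge to can be checked independently. This pure-Python sketch, which is not part of the original script (`least_squares_fit` is a hypothetical helper), computes the ordinary least-squares line in closed form.

```python
def least_squares_fit(xs, ys):
    """Return (slope, intercept) of the ordinary least-squares line through the points."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # slope = cov(x, y) / var(x); the 1/n factors cancel, so plain sums suffice
    cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    slope = cov_xy / var_x
    intercept = mean_y - slope * mean_x
    return slope, intercept


xs = [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0]
ys = [-2.0, 1.0, 4.0, 7.0, 10.0, 13.0]
print(least_squares_fit(xs, ys))  # (3.0, 1.0)
```

A well-trained network should therefore end up with a weight close to 3 and a bias close to 1.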

Training the Neural Network

The script asks the user whether to train the neural network. If the answer is y (case-insensitive), the model is fitted to the input and output data with model.fit for 500 epochs, i.e. 500 full passes over the six training examples.

user_request = input("Do you want to train the neural network? (y/n) ")
if user_request.lower() == "y":
    model.fit(xs, ys, epochs=500)
    print("Training complete.")
else:
    print("Training skipped. (Beware: The neural network is not trained.)")
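For intuition about what model.fit does here, the following pure-Python sketch runs full-batch gradient descent on the same MSE loss. It rests on assumptions: a learning rate of 0.01 (the default for Keras's SGD optimizer), one update per epoch (all six samples fit in a single batch of Keras's default batch size, 32), and parameters initialized to zero, whereas Keras initializes the weight randomly. The `train` helper is hypothetical.

```python
def train(xs, ys, epochs=500, learning_rate=0.01):
    """Fit y_hat = w * x + b by full-batch gradient descent on the mean squared error."""
    w, b = 0.0, 0.0  # assumption: zero init (Keras draws the initial weight randomly)
    n = len(xs)
    for _ in range(epochs):
        errors = [(w * x + b) - y for x, y in zip(xs, ys)]
        # Gradients of mean((w*x + b - y)^2) with respect to w and b
        grad_w = 2.0 / n * sum(e * x for e, x in zip(errors, xs))
        grad_b = 2.0 / n * sum(errors)
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    return w, b


xs = [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0]
ys = [-2.0, 1.0, 4.0, 7.0, 10.0, 13.0]
w, b = train(xs, ys)
print(f"w = {w:.3f}, b = {b:.3f}, prediction for 10.0 = {w * 10.0 + b:.2f}")
```

With these settings the learned parameters end up very close to w = 3 and b = 1, so the prediction for 10.0 lands near 31.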

Making Predictions

Whether or not training was performed, the script then uses the model to predict the output for the input value 10.0. After training, the prediction should be close to 31 (3 × 10 + 1); if training was skipped, the untrained model returns an essentially arbitrary value.

print(model.predict(np.array([10.0])))

Modules Used

The script utilizes the following Python modules:

  • TensorFlow - An open-source machine learning framework.
  • NumPy - A powerful library for numerical operations in Python.

Ensure these modules are installed in your Python environment before running the script.

pip install tensorflow numpy
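If you are unsure whether both modules are available, a small check can be run first. This is a sketch using only the standard library's importlib.util.find_spec; the helper name `missing_modules` is hypothetical.

```python
import importlib.util


def missing_modules(names):
    """Return the module names that cannot be located in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(["tensorflow", "numpy"])
if missing:
    print("Missing modules:", ", ".join(missing), "(install with: pip install tensorflow numpy)")
else:
    print("TensorFlow and NumPy are installed.")
```

find_spec only locates a module without importing it, so the check stays fast even for a large package like TensorFlow.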
