-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
0.4.3 NN - Autograd - Random - Plot - Faster iteration (#42)
* testing yielding over iterators * tri iterator * iter complete * sorting for tensors * matrix exponential using pade * neural network basics * remove req * move images * full implementation of autograd and better networks * Added plplot bindings * merge PRs * docs * remove dataframes * more activation * default bias * random creation methods using alea
- Loading branch information
1 parent
694080e
commit 4c2f50f
Showing
66 changed files
with
4,932 additions
and
695 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
## Basic XOR Classifier | ||
|
||
The following implements a simple XOR classifier to show how to use | ||
`num.cr`'s `Network` class. Plotting is done via `ishi`. | ||
|
||
```crystal | ||
ctx = Num::Grad::Context(Tensor(Float64)).new | ||
bsz = 32 | ||
x_train_bool = Tensor.random(0_u8...2_u8, [bsz * 100, 2]) | ||
y_bool = x_train_bool[..., ...1] ^ x_train_bool[..., 1...] | ||
x_train = ctx.variable(x_train_bool.as_type(Float64)) | ||
y = y_bool.as_type(Float64) | ||
net = Num::NN::Network.new(ctx) do | ||
linear 2, 3 | ||
relu | ||
linear 3, 1 | ||
sgd 0.7 | ||
sigmoid_cross_entropy_loss | ||
end | ||
losses = [] of Float64 | ||
50.times do |epoch| | ||
100.times do |batch_id| | ||
offset = batch_id * 32 | ||
x = x_train[offset...offset + 32] | ||
target = y[offset...offset + 32] | ||
y_pred = net.forward(x) | ||
loss = net.loss(y_pred, target) | ||
puts "Epoch is: #{epoch}" | ||
puts "Batch id: #{batch_id}" | ||
puts "Loss is: #{loss.value.value}" | ||
losses << loss.value.value | ||
loss.backprop | ||
net.optimizer.update | ||
end | ||
end | ||
``` | ||
|
||
``` | ||
... | ||
Epoch is: 49 | ||
Batch id: 95 | ||
Loss is: 0.00065050072686102 | ||
Epoch is: 49 | ||
Batch id: 96 | ||
Loss is: 0.0006024037564266797 | ||
Epoch is: 49 | ||
Batch id: 97 | ||
Loss is: 0.0005297538443899917 | ||
Epoch is: 49 | ||
Batch id: 98 | ||
Loss is: 0.0005765025171222869 | ||
Epoch is: 49 | ||
Batch id: 99 | ||
Loss is: 0.0005290653040218895 | ||
``` | ||
|
||
The Network learns this function very quickly, as XOR is one of the simplest | ||
distributions to hit. Since the training data is so limited, accuracy | ||
can be a bit jagged, but eventually the network smooths out. | ||
|
||
### Loss over time | ||
|
||
![xorloss](xor_classifier_loss.png) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# Copyright (c) 2020 Crystal Data Contributors | ||
# | ||
# MIT License | ||
# | ||
# Permission is hereby granted, free of charge, to any person obtaining | ||
# a copy of this software and associated documentation files (the | ||
# "Software"), to deal in the Software without restriction, including | ||
# without limitation the rights to use, copy, modify, merge, publish, | ||
# distribute, sublicense, and/or sell copies of the Software, and to | ||
# permit persons to whom the Software is furnished to do so, subject to | ||
# the following conditions: | ||
# | ||
# The above copyright notice and this permission notice shall be | ||
# included in all copies or substantial portions of the Software. | ||
# | ||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, | ||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF | ||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND | ||
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE | ||
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION | ||
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION | ||
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | ||
|
||
require "../../src/num" | ||
|
||
ctx = Num::Grad::Context(Tensor(Float64)).new | ||
|
||
bsz = 32 | ||
|
||
x_train_bool = Tensor.random(0_u8...2_u8, [bsz * 100, 2]) | ||
|
||
y_bool = x_train_bool[..., ...1] ^ x_train_bool[..., 1...] | ||
|
||
x_train = ctx.variable(x_train_bool.as_type(Float64)) | ||
y = y_bool.as_type(Float64) | ||
|
||
net = Num::NN::Network.new(ctx) do | ||
linear 2, 3 | ||
relu | ||
linear 3, 1 | ||
sgd 0.7 | ||
sigmoid_cross_entropy_loss | ||
end | ||
|
||
losses = [] of Float64 | ||
|
||
50.times do |epoch| | ||
100.times do |batch_id| | ||
offset = batch_id * 32 | ||
x = x_train[offset...offset + 32] | ||
target = y[offset...offset + 32] | ||
|
||
y_pred = net.forward(x) | ||
|
||
loss = net.loss(y_pred, target) | ||
|
||
puts "Epoch is: #{epoch}" | ||
puts "Batch id: #{batch_id}" | ||
puts "Loss is: #{loss.value.value}" | ||
losses << loss.value.value | ||
|
||
loss.backprop | ||
net.optimizer.update | ||
end | ||
end | ||
|
||
Num::Plot::Plot.plot do | ||
scatter (0...losses.size), losses | ||
x_label "Epochs" | ||
y_label "Loss" | ||
label "XOR Classifier Loss" | ||
end |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
## Simple Scatter Plot | ||
|
||
Using [PLplot](http://plplot.sourceforge.net/index.php), a fantastic library for scientific plotting, charts are extremely easy to create with Num.cr. | ||
|
||
```crystal | ||
x = (0...100) | ||
y = Tensor(Float64).rand([100]) | ||
Num::Plot::Plot.plot do | ||
term nil | ||
scatter x, y, code: 14, color: 2 | ||
end | ||
``` | ||
|
||
![scatter](simple_scatter.png) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Copyright (c) 2020 Crystal Data Contributors | ||
# | ||
# MIT License | ||
# | ||
# Permission is hereby granted, free of charge, to any person obtaining | ||
# a copy of this software and associated documentation files (the | ||
# "Software"), to deal in the Software without restriction, including | ||
# without limitation the rights to use, copy, modify, merge, publish, | ||
# distribute, sublicense, and/or sell copies of the Software, and to | ||
# permit persons to whom the Software is furnished to do so, subject to | ||
# the following conditions: | ||
# | ||
# The above copyright notice and this permission notice shall be | ||
# included in all copies or substantial portions of the Software. | ||
# | ||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, | ||
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF | ||
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND | ||
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE | ||
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION | ||
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION | ||
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | ||
|
||
require "../../src/num" | ||
|
||
x = (0...100) | ||
y = Tensor(Float64).rand([100]) | ||
|
||
Num::Plot::Plot.plot do | ||
term nil | ||
scatter x, y, code: 14, color: 2 | ||
end |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
require "../spec_helper" | ||
|
||
describe Tensor do | ||
it "sorts a one dimensional Tensor" do | ||
a = [4, 3, 2, 1].to_tensor | ||
result = Num.sort(a) | ||
expected = [1, 2, 3, 4].to_tensor | ||
assert_array_equal(result, expected) | ||
end | ||
|
||
it "sorts a strided Tensor" do | ||
a = [4, 3, 2, 1].to_tensor[{..., 2}] | ||
result = Num.sort(a) | ||
expected = [2, 4] | ||
assert_array_equal(result, expected) | ||
end | ||
|
||
it "sorts a Tensor along an axis" do | ||
a = [[3, 5, 6], [1, 1, 2], [9, 2, 3]].to_tensor | ||
result = Num.sort(a, 0) | ||
expected = [[1, 1, 2], [3, 2, 3], [9, 5, 6]].to_tensor | ||
assert_array_equal(result, expected) | ||
end | ||
|
||
it "sorts a strided Tensor along an axis" do | ||
a = [[3, 4, 5, 1], [2, 1, 3, 2], [4, 7, 6, 2]].to_tensor[..., {..., 2}] | ||
result = Num.sort(a, 0) | ||
expected = [[2, 3], [3, 5], [4, 6]].to_tensor | ||
assert_array_equal(result, expected) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.