diff --git a/Network.lua b/Network.lua index 57097f1..ac4c8c6 100644 --- a/Network.lua +++ b/Network.lua @@ -130,7 +130,7 @@ function Network:trainNetwork(epochs, optimizerParams) sizes = self.calSize(sizes) local predictions = self.model:forward(inputs) local loss = criterion:forward(predictions, targets, sizes) - if loss == math.huge then loss = 0 print("Recieved an inf cost!") end + if loss == math.huge or loss == -math.huge then loss = 0 print("Recieved an inf cost!") end self.model:zeroGradParameters() local gradOutput = criterion:backward(predictions, targets) self.model:backward(inputs, gradOutput)