forked from phoenix104104/LapSRN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
vllab_dag_loss.m
60 lines (56 loc) · 1.95 KB
/
vllab_dag_loss.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
classdef vllab_dag_loss < dagnn.Loss
% -------------------------------------------------------------------------
% Description:
%   loss object for dagnn class
%   if using your own MatConvNet version, copy this file to [matconvnet]/matlab/+dagnn
%
% Parameters:
%   - lambda    : weight of loss (scales both the loss value and its derivative)
%   - loss_type : supports 'L1' or 'L2' loss (matched case-insensitively)
%
% Citation:
%   Deep Laplacian Pyramid Networks for Fast and Accurate Super-Resolution
%   Wei-Sheng Lai, Jia-Bin Huang, Narendra Ahuja, and Ming-Hsuan Yang
%   IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2017
%
% Contact:
%   Wei-Sheng Lai
%   University of California, Merced
% -------------------------------------------------------------------------

  properties
    lambda = 1;       % multiplicative weight applied to the loss and its gradient
    loss_type = 'L2'; % 'L1' or 'L2'; selects vllab_nn_L1_loss / vllab_nn_L2_loss
  end

  methods
    function outputs = forward(obj, inputs, params)
      % Compute the weighted loss between inputs{1} (prediction) and
      % inputs{2} (target), and update the layer's running average.
      outputs{1} = obj.lambda * obj.applyLoss(inputs{1}, inputs{2}) ;

      % Running average of the loss weighted by the number of samples seen
      % so far (size along dimension 4). This mirrors the update done by
      % MatConvNet's dagnn.Loss. NOTE(review): it assumes the vllab_nn_*
      % loss returns a value summed over the batch. Confirm against the helper.
      n = obj.numAveraged ;
      m = n + size(inputs{1},4) ;
      obj.average = (n * obj.average + gather(outputs{1})) / m ;
      obj.numAveraged = m ;
    end

    function [derInputs, derParams] = backward(obj, inputs, params, derOutputs)
      % Back-propagate derOutputs{1} through the weighted loss.
      % Only the prediction (inputs{1}) receives a derivative. The target
      % (inputs{2}) gets an empty derivative, and the layer has no parameters.
      derInputs{1} = obj.lambda * obj.applyLoss(inputs{1}, inputs{2}, derOutputs{1}) ;
      derInputs{2} = [] ;
      derParams = {} ;
    end

    function obj = vllab_dag_loss(varargin)
      % Construct the layer; name/value pairs set the properties above.
      obj.load(varargin) ;
    end
  end

  methods (Access = private)
    function y = applyLoss(obj, x, c, varargin)
      % Single dispatch point for the loss selected by obj.loss_type, shared
      % by forward and backward so the two can never disagree.
      %   y = applyLoss(obj, x, c)        -> loss value
      %   y = applyLoss(obj, x, c, dzdy)  -> whatever vllab_nn_*_loss returns
      %                                      when given the output derivative
      %                                      (used as the derivative w.r.t. x)
      switch upper(obj.loss_type)
        case 'L2'
          y = vllab_nn_L2_loss(x, c, varargin{:}) ;
        case 'L1'
          y = vllab_nn_L1_loss(x, c, varargin{:}) ;
        otherwise
          error('vllab_dag_loss:unknownLossType', 'Unknown loss %s', obj.loss_type) ;
      end
    end
  end
end