forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
FunctionsManual.cpp
7188 lines (6647 loc) · 247 KB
/
FunctionsManual.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include <torch/csrc/autograd/FunctionsManual.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/variable.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/ScalarOps.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/Utils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/core/Reduction.h>
#include <ATen/core/grad_mode.h>
#include <ATen/native/Activation.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/OptionalArrayRef.h>
#include <c10/util/SmallBuffer.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <algorithm>
#include <ciso646>
#include <functional>
#include <numeric>
#include <utility>
// Helper functions for autogenerated code
// These used to be inlined into the codegened Functions.cpp
namespace torch::autograd::generated::details {
using at::areAnyTensorSubclassLike;
using at::IntArrayRef;
using at::OptionalIntArrayRef;
using at::Scalar;
using at::Tensor;
using at::TensorList;
// Error text reported when double backward is requested through a CuDNN RNN.
// Adjacent literals are concatenated by the compiler into a single string.
const char* kCudnnDoubleBackwardMsg =
    "Double backwards is not supported for CuDNN RNNs due to limitations in "
    "the CuDNN API. To run double backwards, please disable the CuDNN backend "
    "temporarily while running the forward pass of your RNN. For example: \n"
    "with torch.backends.cudnn.flags(enabled=False):\n"
    " output = model(inputs)";
// Reduces an elementwise loss according to `reduction`:
//   at::Reduction::Mean -> mean over all elements,
//   at::Reduction::Sum  -> sum over all elements,
//   any other value     -> `unreduced` is returned unchanged.
Tensor apply_loss_reduction(const Tensor& unreduced, int64_t reduction) {
  switch (reduction) {
    case at::Reduction::Mean:
      return unreduced.mean();
    case at::Reduction::Sum:
      return unreduced.sum();
    default:
      return unreduced;
  }
}
// True only when the optional holds a tensor and that tensor is defined.
static bool isDefined(const std::optional<Tensor>& t) {
  if (!t.has_value()) {
    return false;
  }
  return t->defined();
}
// Unwraps an optional tensor; an empty optional becomes an undefined Tensor.
Tensor toNonOptTensor(const std::optional<Tensor>& t) {
  if (!t.has_value()) {
    return Tensor();
  }
  return *t;
}
// Returns the level-0 forward-mode gradient of `t`, or an undefined Tensor
// when `t` is empty or holds an undefined tensor.
Tensor toNonOptFwGrad(const std::optional<Tensor>& t) {
  if (!t.has_value() || !t->defined()) {
    return Tensor();
  }
  return t->_fw_grad(/*level */ 0);
}
// Returns the level-0 primal of `t`, or an undefined Tensor when `t` is empty
// or undefined. Tensors flagged as wrapped numbers are returned unchanged
// instead of going through _fw_primal.
Tensor toNonOptPrimal(const std::optional<Tensor>& t) {
  if (!t.has_value() || !t->defined()) {
    return Tensor();
  }
  const bool wrapped_number = t->unsafeGetTensorImpl()->is_wrapped_number();
  return wrapped_number ? *t : t->_fw_primal(/* level */ 0);
}
// Stores the single tensor `t` into the one-element slot `range` of `out`.
// Fails if the slot ends past `out` or does not have exactly one element.
void copy_range(variable_list& out, IndexRange range, const Tensor& t) {
  const auto slot_begin = range.first;
  const auto slot_end = range.second;
  TORCH_CHECK(slot_end <= out.size());
  TORCH_CHECK(
      slot_end - slot_begin == 1, "inconsistent range for Tensor output");
  out[slot_begin] = t;
}
// Stores the tensors of `t`, in order, into the slots [range.first,
// range.second) of `out`. Fails if the slot range ends past `out` or its
// length differs from t.size().
void copy_range(variable_list& out, IndexRange range, at::ArrayRef<Tensor> t) {
  const auto slot_begin = range.first;
  const auto slot_end = range.second;
  TORCH_CHECK(slot_end <= out.size());
  TORCH_CHECK(
      slot_end - slot_begin == t.size(),
      "inconsistent range for TensorList output");
  for (const auto k : c10::irange(t.size())) {
    out[slot_begin + k] = t[k];
  }
}
// Gradient of copysign w.r.t. `self`: grad * (result / self). The ratio is
// forced to 0 where self == 0 so the division by zero does not leak through.
Tensor copysign_tensor_self_backward(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& result) {
  auto sign_ratio = (result / self).masked_fill_(self == 0, 0);
  return grad * sign_ratio;
}
// Raises (via TORCH_CHECK_NOT_IMPLEMENTED) an error saying the derivative of
// `name` is not implemented, followed by `reason` when one is given.
// It never returns normally; `T` only lets callers use it where a value of
// that type is expected.
//
// `reason` may be empty or null; both mean "no extra explanation".
template <typename T>
T not_implemented_base(const char* name, const char* reason) {
  std::string msg =
      c10::str("the derivative for '", name, "' is not implemented.");
  // Checking for nullptr avoids crashing with a segfault instead of raising
  // the intended error when a caller passes no reason at all.
  if (reason != nullptr && reason[0] != '\0') {
    msg = c10::str(msg, " ", reason);
  }
  TORCH_CHECK_NOT_IMPLEMENTED(false, msg);
}
// Tensor-returning entry point for a missing derivative formula. Always
// throws; the return statement only satisfies the signature.
Tensor not_implemented(const char* name, const char* reason) {
  using Result = Tensor;
  return not_implemented_base<Result>(name, reason);
}
std::vector<Tensor> not_implemented_list(const char* name, const char* reason) {
return not_implemented_base<std::vector<Tensor>>(name, reason);
}
// Returns `t` itself when `s` equals one (floating-point or integral/bool
// scalar), otherwise `t * s`. This skips a multiplication by one.
Tensor maybe_multiply(const Tensor& t, const Scalar& s) {
  const auto equals_one = [&s]() -> bool {
    if (s.isFloatingPoint()) {
      return s.toSymFloat() == 1;
    }
    if (s.isIntegral(/*includeBool=*/true)) {
      return s.toSymInt() == 1;
    }
    return false;
  };
  if (!equals_one()) {
    return t * s;
  }
  return t;
}
// Product of sizes[d] over every dim d in `dim` (negative dims are wrapped).
// An empty `sizes` (0-dim tensor) yields 1 regardless of `dim`.
int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) {
  if (sizes.empty()) {
    return 1;
  }
  const auto rank = static_cast<int64_t>(sizes.size());
  int64_t product = 1;
  for (const auto raw_dim : dim) {
    product *= sizes[at::maybe_wrap_dim(raw_dim, rank)];
  }
  return product;
}
// SymInt variant: product of sizes[d] over every dim d in `dim` (negative
// dims are wrapped). An empty `sizes` yields 1 regardless of `dim`.
static c10::SymInt _safe_size(c10::SymIntArrayRef sizes, c10::IntArrayRef dim) {
  if (sizes.empty()) {
    return 1;
  }
  const auto rank = static_cast<int64_t>(sizes.size());
  c10::SymInt product = 1;
  for (const auto raw_dim : dim) {
    product *= sizes[at::maybe_wrap_dim(raw_dim, rank)];
  }
  return product;
}
// If the input dtype `self_st` is real but the gradient came out complex,
// keep only the real part; otherwise return the gradient untouched.
Tensor handle_r_to_c(ScalarType self_st, Tensor gradient_result) {
  const bool real_input_complex_grad =
      !at::isComplexType(self_st) && gradient_result.is_complex();
  if (!real_input_complex_grad) {
    return gradient_result;
  }
  return at::real(gradient_result);
}
// Tensor overload: if `self` is real but the gradient came out complex, keep
// only the real part; otherwise return the gradient untouched.
static Tensor handle_r_to_c(const Tensor& self, Tensor gradient_result) {
  const bool real_input_complex_grad =
      !self.is_complex() && gradient_result.is_complex();
  if (!real_input_complex_grad) {
    return gradient_result;
  }
  return at::real(gradient_result);
}
// Re-inserts size-1 axes at the positions in `dims` so that `output`, the
// result of a reduction over `dims` with keepdim=false, can broadcast against
// the un-reduced input again. With keepdim=true `output` is returned as-is.
//
// `dims` may contain negative indices. They are interpreted relative to the
// rank of the un-reduced tensor, i.e. output.dim() + dims.size().
// Out-of-range entries raise through at::maybe_wrap_dim instead of writing
// outside `target_shape`.
Tensor restore_reduced_dims(
    const Tensor& output,
    IntArrayRef dims,
    bool keepdim) {
  if (keepdim) {
    return output;
  }
  const int64_t total_dims =
      output.dim() + static_cast<int64_t>(dims.size());
  // 0 marks a slot not yet filled, 1 marks a reduced axis; the remaining
  // slots receive output's sizes in order.
  std::vector<c10::SymInt> target_shape(static_cast<size_t>(total_dims), 0);
  for (const int64_t d : dims) {
    target_shape[at::maybe_wrap_dim(d, total_dims)] = 1;
  }
  size_t next = 0;
  for (const c10::SymInt& size : output.sym_sizes()) {
    while (target_shape[next] > 0) {
      ++next;
    }
    target_shape[next++] = size;
  }
  return output.reshape_symint(target_shape);
}
// Divides `grad` by the number of mask-selected entries along `dims` (summed
// with keepdim=true), then multiplies by `mask`. With a 0/1 mask (assumed —
// confirm with callers) every selected entry gets an equal share of the
// gradient.
Tensor scale_grad_by_count(
    const Tensor& grad,
    const Tensor& mask,
    IntArrayRef dims) {
  const auto selected_count = mask.sum(dims, /*keepdim=*/true);
  return (grad / selected_count) * mask;
}
// JVP of amax/amin: the tangent `dx` is averaged over every position where
// `x` equals the reduced `result`, so ties share the derivative evenly.
Tensor amaxamin_jvp(
    const Tensor& x,
    const Tensor& dx,
    const Tensor& result,
    IntArrayRef dim,
    bool keepdim) {
  const auto expanded_result = restore_reduced_dims(result, dim, keepdim);
  const auto hits = x == expanded_result;
  const auto tangent_total = at::where(hits, dx, 0.).sum(dim, keepdim);
  return tangent_total / hits.sum(dim, keepdim);
}
std::tuple<Tensor, Tensor> _euclidean_dist_backward(
const Tensor& grad,
const Tensor& x1,
const Tensor& x2,
const Tensor& res) {
if (!grad.defined()) {
return std::tuple<Tensor, Tensor>(Tensor(), Tensor());
}
// handle case at 0 where we return a subgradient containing 0
Tensor ratio = grad / res;
ratio.masked_fill_(res == 0, 0);
return std::tuple<Tensor, Tensor>{
x1 * ratio.sum(-1, true) - ratio.matmul(x2),
x2 * ratio.sum(-2, false).unsqueeze(-1) - ratio.mT().matmul(x1)};
}
// Convenience overload: forwards to the general norm_backward with an empty
// dim list and keepdim=true.
Tensor norm_backward(
    const Tensor& grad,
    const Tensor& self,
    const std::optional<Scalar>& p_,
    const Tensor& norm) {
  const IntArrayRef no_dims{};
  return norm_backward(grad, self, p_, norm, no_dims, /*keepdim=*/true);
}
// Backward of the vector p-norm `norm` of `self` over `dim` (p defaults to 2).
// With keepdim=false, grad and norm are first unsqueezed back to self's rank.
Tensor norm_backward(
    Tensor grad,
    const Tensor& self,
    const std::optional<Scalar>& p_,
    Tensor norm,
    IntArrayRef dim,
    bool keepdim) {
  // NB: We mask fill the NaNs in the output to be zero but still do float
  // division
  // by zero, which ASAN complains about. One way to appease ASAN is to fill
  // the problematic values with something arbitrary before the division,
  // but we decide not to due to the perf hit. Instead we just silence ASAN
  // where necessary
  const size_t ndim = self.dim();
  const double p = p_.value_or(2.0).toDouble();
  if (!keepdim && self.dim() != 0) {
    grad = unsqueeze_multiple(grad, dim, ndim);
    norm = unsqueeze_multiple(norm, dim, ndim);
  }

  if (p == 0.0) {
    // Undefined tensor: no gradient for the 0-"norm".
    return {};
  }
  if (p == 1.0) {
    return self.sgn() * grad;
  }
  if (p == 2.0) {
    return grad * (self / norm).masked_fill_(norm == 0, 0);
  }
  if (std::isinf(p)) {
    // Derivative of amax(abs(self), dim, keepdim) that respects NaNs: an entry
    // counts as an argmax if |self| == norm or it is NaN, and the gradient is
    // split evenly among all such entries.
    const auto self_abs = self.abs();
    const auto is_argmax = self_abs.eq(norm).logical_or(self_abs.isnan());
    return self.sgn() * ((grad / is_argmax.sum(dim, true)) * is_argmax);
  }
  if (p < 1.0) {
    const auto self_scaled =
        self.sgn() * self.abs().pow_(p - 1).masked_fill_(self == 0, 0);
    return self_scaled * grad * norm.pow(1 - p);
  }
  if (p < 2.0) {
    const auto self_scaled = self.sgn() * self.abs().pow_(p - 1);
    auto scale_v = grad / norm.pow(p - 1);
    scale_v.masked_fill_(norm == 0, 0);
    return self_scaled * scale_v;
  }
  // p > 2
  const auto self_scaled = self * self.abs().pow_(p - 2);
  auto scale_v = grad / norm.pow(p - 1);
  scale_v.masked_fill_(norm == 0, 0);
  return self_scaled * scale_v;
}
// See norm_backward above for a note on ignoring the sanitizer
// Forward-mode derivative (JVP) of the vector p-norm of `self_p` in the
// tangent direction `self_t`, reduced over `dim` with the given `keepdim`.
// `norm` is the primal result of the forward pass; p defaults to 2.
// Each branch combines primal and tangent as at::real(<primal term> *
// self_t.conj()) before summing, so complex inputs give a real tangent.
Tensor norm_jvp(
const Tensor& self_p,
const Tensor& self_t,
const std::optional<Scalar>& p_,
Tensor norm,
IntArrayRef dim,
bool keepdim) {
// NB: currently norm_jvp is also reused for dist's jvp (which has two
// differentiable inputs)
// but self_t still cannot be a ZT because that would require both self_t
// and other_t to be ZT
TORCH_INTERNAL_ASSERT(!self_t._is_zerotensor());
size_t ndim = self_p.dim(); // composite compliance?
double p = p_.value_or(2.0).toDouble();
if (p == 0.0) {
// The 0-"norm" gets an all-zero tangent shaped like `norm`.
return at::zeros_like(norm);
} else if (p == 1.0) {
// Tensor-subclass-like tangents take the out-of-place mul; otherwise the
// freshly created sgn() tensor is updated in place.
auto result = self_p.sgn();
result = areAnyTensorSubclassLike({self_t}) ? result.mul(self_t.conj())
: result.mul_(self_t.conj());
result = at::real(result);
return result.sum(dim, keepdim);
} else if (p == 2.0) {
// sum(real(x * conj(t))) / norm, with 0 where norm == 0.
auto result = self_p.mul(self_t.conj());
result = at::real(result);
result = result.sum(dim, keepdim);
return result.div_(norm).masked_fill_(norm == 0, 0);
} else if (std::isinf(p)) {
// Only this branch needs `norm` broadcastable against self_p.
if (!keepdim && self_p.dim() != 0) {
norm = unsqueeze_multiple(norm, dim, ndim);
}
// Entries count as maxima if |x| == norm, or if both x and norm are NaN.
// The tangent is averaged over them (divided by their count nb_max).
const auto self_isnan = self_p.isnan();
const auto norm_isnan = norm.isnan();
const auto& self_and_norm_isnan = areAnyTensorSubclassLike({norm})
? self_isnan.logical_and(norm_isnan)
: self_isnan.logical_and_(norm_isnan);
const auto is_eq_max =
(self_p.abs() == norm).logical_or_(self_and_norm_isnan).type_as(norm);
auto nb_max = is_eq_max.count_nonzero(dim);
if (self_p.dim() != 0) {
nb_max = unsqueeze_multiple(nb_max, dim, ndim);
}
return (at::real(self_p.sgn() * self_t.conj()) * is_eq_max / nb_max)
.sum(dim, keepdim);
} else if (p < 1.0) {
// |x|^(p-1) is masked to 0 where x == 0 (it would be inf there for p < 1).
auto sumpow_t = (self_p.abs().pow_(p - 1).masked_fill_(self_p == 0, 0) *
at::real(self_p.sgn() * self_t.conj()))
.sum(dim, keepdim);
return sumpow_t * norm.pow(1 - p);
} else if (p < 2.0) {
auto sumpow_t =
(self_p.abs().pow_(p - 1) * at::real(self_p.sgn() * self_t.conj()))
.sum(dim, keepdim);
auto out = sumpow_t / norm.pow(p - 1);
return out.masked_fill_(norm == 0, 0);
} else {
// p > 2: x * |x|^(p-2) is used in place of sgn(x) * |x|^(p-1).
auto sumpow_t =
(self_p.abs().pow_(p - 2) * at::real(self_p * self_t.conj()))
.sum(dim, keepdim);
auto out = sumpow_t / norm.pow(p - 1);
return out.masked_fill_(norm == 0, 0);
}
}
// Convenience overload: forwards to the general norm_jvp with an empty dim
// list and keepdim=true.
Tensor norm_jvp(
    const Tensor& self_p,
    const Tensor& self_t,
    const std::optional<Scalar>& p_,
    Tensor norm) {
  const IntArrayRef no_dims{};
  return norm_jvp(
      self_p, self_t, p_, std::move(norm), no_dims, /*keepdim=*/true);
}
// Backward of nested_from_padded: pads the nested gradient back to a dense
// tensor shaped like `input`.
// With do_transform_0213 the gradient is padded to [d0, d2, d1 * d3], viewed
// as [d0, d2, d1, d3] and permuted (0, 2, 1, 3) to land back in input's
// [d0, d1, d2, d3] layout. This presumably inverts the forward 0213
// transform — confirm against the forward op.
Tensor _nested_from_padded_backward(
    const Tensor& grad,
    const Tensor& input,
    bool do_transform_0213) {
  if (!do_transform_0213) {
    return grad.to_padded_tensor(0, input.sizes());
  }
  const int64_t d0 = input.size(0);
  const int64_t d1 = input.size(1);
  const int64_t d2 = input.size(2);
  const int64_t d3 = input.size(3);
  auto padded_sizes = {d0, d2, d1 * d3};
  auto unflattened_sizes = {d0, d2, d1, d3};
  return grad.to_padded_tensor(0, padded_sizes)
      .view(unflattened_sizes)
      .permute({0, 2, 1, 3});
}
std::tuple<Tensor, Tensor, Tensor> linear_double_backward(
const variable_list& grads,
const Tensor& self,
const Tensor& grad_output,
const Tensor& weight) {
if (!grad_output.defined()) {
return std::make_tuple(Tensor(), Tensor(), Tensor());
}
Tensor grad_self, grad_grad_output, grad_weight;
if (grads[1].defined()) {
grad_self =
(grad_output.dim() == 1 ? grad_output.unsqueeze(0) : grad_output)
.matmul(grads[1]);
if (grad_output.dim() == 1) {
grad_self = grad_self.squeeze(0);
}
}
if (grads[0].defined()) {
grad_weight =
(grad_output.dim() == 1 ? grad_output.unsqueeze(1) : grad_output.mT())
.matmul(grads[0].dim() == 1 ? grads[0].unsqueeze(0) : grads[0]);
}
if (grads[0].defined() || grads[1].defined() || grads[2].defined()) {
grad_grad_output = at::zeros_like(grad_output);
if (grad_output.dim() == 1) {
grad_grad_output = grad_grad_output.unsqueeze(0);
}
}
if (grads[0].defined()) {
grad_grad_output = grad_grad_output +
(grads[0].dim() == 1 ? grads[0].unsqueeze(0) : grads[0])
.matmul(weight.mT());
}
if (grads[1].defined()) {
grad_grad_output = grad_grad_output +
(self.dim() == 1 ? self.unsqueeze(0) : self).matmul(grads[1].mT());
}
if (grads[2].defined()) {
grad_grad_output = grad_grad_output + grads[2];
}
if (grad_grad_output.defined() && grad_output.dim() == 1) {
grad_grad_output = grad_grad_output.squeeze(0);
}
return std::make_tuple(
std::move(grad_self),
std::move(grad_grad_output),
std::move(grad_weight));
}
// JVP of linalg.vector_norm: forwards to norm_jvp, with an absent dim list
// treated as empty. The dtype argument needs no handling here; it is dealt
// with via broadcasting inside the function.
Tensor linalg_vector_norm_jvp(
    const Tensor& self_p,
    const Tensor& self_t,
    const Scalar& scalar_ord,
    Tensor norm,
    const at::OptionalIntArrayRef& opt_dim,
    bool keepdim) {
  const IntArrayRef dim =
      opt_dim.has_value() ? opt_dim.value() : IntArrayRef();
  return norm_jvp(self_p, self_t, scalar_ord, std::move(norm), dim, keepdim);
}
// Backward of linalg.vector_norm: forwards to norm_backward, with an absent
// dim list treated as empty. The dtype argument needs no handling here; it is
// dealt with via broadcasting inside the function.
Tensor linalg_vector_norm_backward(
    Tensor grad,
    const Tensor& self,
    const Scalar& scalar_ord,
    Tensor norm,
    const at::OptionalIntArrayRef& opt_dim,
    bool keepdim) {
  const IntArrayRef dim =
      opt_dim.has_value() ? opt_dim.value() : IntArrayRef();
  return norm_backward(
      std::move(grad), self, scalar_ord, std::move(norm), dim, keepdim);
}
// Gradient of self ** exponent w.r.t. self for a Scalar exponent:
// grad * conj(exponent * self^(exponent - 1)). A zero exponent gives zeros.
// The exponent is used as a complex or real double depending on its kind.
Tensor pow_backward(Tensor grad, const Tensor& self, const Scalar& exponent) {
  if (exponent.equal(0.0)) {
    return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  const auto scaled_grad = [&](auto e) {
    return grad * (e * self.pow(e - 1)).conj();
  };
  Tensor out;
  if (exponent.isComplex()) {
    out = scaled_grad(exponent.toComplexDouble());
  } else {
    out = scaled_grad(exponent.toDouble());
  }
  return handle_r_to_c(self, std::move(out));
}
// Gradient of self ** exponent w.r.t. self for a tensor exponent. Positions
// with exponent == 0 get 0; elsewhere grad * conj(exponent * self^(exponent-1)).
Tensor pow_backward_self(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& exponent) {
  const auto zero = at::zeros({}, grad.options());
  const auto general = grad * (exponent * self.pow(exponent - 1)).conj();
  auto out = at::where(exponent == 0.0, zero, general);
  return handle_r_to_c(self, std::move(out));
}
// Caveats:
// We define d(a^b)/db at a = 0 and b < 0 to be -inf. This is due to
// d(a^b)/db -> -inf for a fixed b as a -> +0
// Currently, tensorflow defines d(a^b)/db = nan for a = 0 and b < 0.
//
// We define d(a^b)/db = 0 for a = 0 and b = 0 by continuity as
// d(a^b)/db = 0 for a > 0 and b -> +0.
// Currently, tensorflow agrees with us.
//
// Gradient of self ** exponent w.r.t. exponent (tensor base):
// grad * conj(result * log(self)), forced to 0 where self == 0 and the
// exponent is a non-negative real number.
Tensor pow_backward_exponent(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& exponent,
    const Tensor& result) {
  Tensor exponent_nonneg_real;
  if (exponent.is_complex()) {
    exponent_nonneg_real =
        at::logical_and(at::imag(exponent) == 0, at::real(exponent) >= 0);
  } else {
    exponent_nonneg_real = exponent >= 0;
  }
  const auto zero_grad_mask = at::logical_and(self == 0, exponent_nonneg_real);
  // `.to()` is a no-op when self already has the promoted dtype.
  const auto self_promoted = self.to(at::result_type(self, exponent));
  const auto d_exponent = (result * self_promoted.log()).conj();
  auto out = grad *
      at::where(zero_grad_mask, at::zeros({}, grad.options()), d_exponent);
  return handle_r_to_c(exponent, std::move(out));
}
// Gradient of base ** exponent w.r.t. exponent for a Scalar base:
// grad * conj(result * log(base)). A real base is promoted to complex when the
// exponent is complex. For base == 0 the gradient is 0 wherever the exponent
// is a non-negative real number.
Tensor pow_backward_exponent(
    const Tensor& grad,
    const Scalar& base,
    const Tensor& exponent,
    const Tensor& result) {
  Scalar log_base_input = base;
  if (exponent.is_complex() && !base.isComplex()) {
    log_base_input = Scalar(base.toComplexDouble());
  }
  const auto d_exponent = (result * log_base_input.log()).conj();

  if (!base.equal(0.0)) {
    return handle_r_to_c(exponent, grad * d_exponent);
  }

  Tensor exponent_nonneg_real;
  if (exponent.is_complex()) {
    exponent_nonneg_real =
        at::logical_and(at::imag(exponent) == 0, at::real(exponent) >= 0);
  } else {
    exponent_nonneg_real = exponent >= 0;
  }
  auto out = grad *
      at::where(
          exponent_nonneg_real, at::zeros({}, grad.options()), d_exponent);
  return handle_r_to_c(exponent, std::move(out));
}
// Gradient of angle(self).
// - Real input: zeros shaped like self.
// - Complex input: grad * self / |self|^2 * i, with 0 where self == 0.
Tensor angle_backward(const Tensor& grad, const Tensor& self) {
  if (!self.is_complex()) {
    return at::zeros_like(self, at::MemoryFormat::Preserve);
  }
  const Scalar imaginary_unit(c10::complex<double>{0.0, 1.0});
  const auto general = grad * self / self.abs().pow(2) * imaginary_unit;
  return at::where(self == 0.0, at::zeros({}, self.options()), general);
}
// Gradient of the multivariate log-gamma of order p:
// grad * sum_k digamma(self + offset_k).
// The p offsets are (1 - p)/2, ..., -0.5, 0 in steps of 0.5.
Tensor mvlgamma_backward(const Tensor& grad, const Tensor& self, int64_t p) {
  const double first_offset = -static_cast<double>(p) / 2. + 0.5;
  // use strided here regardless of self's layout; useful for e.g. NJT
  const auto offsets = at::arange(
      first_offset, 0.5, 0.5, self.options().layout(c10::kStrided));
  auto shifted = offsets.add(self.unsqueeze(-1));
  return grad * shifted.digamma_().sum(-1);
}
// Gradient of sgn(x) given the incoming gradient `gx` and the forward result
// `sgn`.
// - Real input: an efficient zero tensor.
// - Complex input: (gx - sgn^2 * conj(gx)) / (2|x|), defined as 0 at x == 0.
Tensor sgn_backward(const Tensor& x, const Tensor& gx, const Tensor& sgn) {
  if (!x.is_complex()) {
    return at::_efficientzerotensor(sgn.sizes(), sgn.options());
  }
  const auto magnitude = x.abs();
  auto grad_x = (gx - (sgn * sgn) * gx.conj()) / (2. * magnitude);
  grad_x.masked_fill_(magnitude == 0., 0.);
  return grad_x;
}
// Gradient w.r.t. the fill value of masked_fill: the sum of `grad` over the
// masked positions.
// masked_select's output shape is data-dependent, which does not work well
// with functorch. Tensor-subclass-like inputs therefore use where() + sum().
Tensor masked_fill_backward(const Tensor& grad, const Tensor& mask) {
  if (areAnyTensorSubclassLike({grad, mask})) {
    return at::where(mask, grad, 0).sum();
  }
  return grad.masked_select(mask).sum();
}
// Gradient of self * other w.r.t. self: grad * conj(other). The result is
// reduced to its real part when self's dtype `self_st` is real.
template <typename T>
Tensor mul_tensor_backward(const Tensor& grad, T other, ScalarType self_st) {
  return handle_r_to_c(self_st, grad * other.conj());
}
template Tensor mul_tensor_backward(const Tensor&, Tensor, ScalarType);
template Tensor mul_tensor_backward(const Tensor&, Scalar, ScalarType);
// Gradient of self / other w.r.t. self.
// - With any rounding mode: zeros of dtype `self_st`.
// - Otherwise: grad / conj(other), reduced to its real part for real self.
template <typename T>
Tensor div_tensor_self_backward(
    const Tensor& grad,
    T other,
    ScalarType self_st,
    const std::optional<c10::string_view>& rounding_mode) {
  if (rounding_mode) {
    return at::zeros_like(grad, grad.options().dtype(self_st));
  }
  return handle_r_to_c(self_st, grad / other.conj());
}
template Tensor div_tensor_self_backward(
    const Tensor&,
    Tensor,
    ScalarType,
    const std::optional<c10::string_view>&);
template Tensor div_tensor_self_backward(
    const Tensor&,
    Scalar,
    ScalarType,
    const std::optional<c10::string_view>&);
// Overload without a rounding mode: plain true division.
template <typename T>
Tensor div_tensor_self_backward(
    const Tensor& grad,
    T other,
    ScalarType self_st) {
  const std::optional<c10::string_view> no_rounding = std::nullopt;
  return div_tensor_self_backward(grad, std::move(other), self_st, no_rounding);
}
template Tensor div_tensor_self_backward(const Tensor&, Tensor, ScalarType);
template Tensor div_tensor_self_backward(const Tensor&, Scalar, ScalarType);
// Gradient of self / other w.r.t. other.
// - With any rounding mode: zeros of other's dtype.
// - Otherwise: -grad * conj(self / other^2), reduced to its real part for
//   real `other`.
Tensor div_tensor_other_backward(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& other,
    const std::optional<c10::string_view>& rounding_mode) {
  if (rounding_mode) {
    return at::zeros_like(grad, grad.options().dtype(other.scalar_type()));
  }
  const auto quotient_over_other = (self / other) / other;
  auto grad_other = -grad * quotient_over_other.conj();
  return handle_r_to_c(other, std::move(grad_other));
}
// Overload without a rounding mode: plain true division.
Tensor div_tensor_other_backward(
    const Tensor& grad,
    const Tensor& self,
    const Tensor& other) {
  const std::optional<c10::string_view> no_rounding = std::nullopt;
  return div_tensor_other_backward(grad, self, other, no_rounding);
}
// Backward of permute(fwd_dims): permute `grad` by the inverse permutation,
// where inverse[wrap(fwd_dims[i])] = i.
Tensor permute_backwards(const Tensor& grad, IntArrayRef fwd_dims) {
  const auto ndims = static_cast<int64_t>(fwd_dims.size());
  std::vector<int64_t> inverse(fwd_dims.size());
  for (const int64_t i : c10::irange(ndims)) {
    inverse[at::maybe_wrap_dim(fwd_dims[i], ndims)] = i;
  }
  return grad.permute(inverse);
}
// Gradient of rad2deg(x) = x * 180 / pi: a constant factor of 180 / pi.
Tensor rad2deg_backward(const Tensor& grad) {
  constexpr double kDegreesPerRadian =
      57.295779513082320876798154814105170332405472466564;
  return at::mul(grad, Scalar(kDegreesPerRadian));
}
// Gradient of deg2rad(x) = x * pi / 180: a constant factor of pi / 180.
Tensor deg2rad_backward(const Tensor& grad) {
  constexpr double kRadiansPerDegree =
      0.017453292519943295769236907684886127134428718885417;
  return at::mul(grad, Scalar(kRadiansPerDegree));
}
// Inserts size-1 axes into `t` at the positions given by `opt_dim`, relative
// to a result of rank `n_dims`. Zero or one explicit dim takes a fast path.
// Otherwise the positions come from at::dim_list_to_bitset and are
// unsqueezed in increasing order.
Tensor unsqueeze_multiple(
    const Tensor& t,
    OptionalIntArrayRef opt_dim,
    size_t n_dims) {
  if (opt_dim.has_value()) {
    const IntArrayRef dims = opt_dim.value();
    switch (dims.size()) {
      case 0:
        return t;
      case 1:
        return t.unsqueeze(dims[0]);
      default:
        break;
    }
  }
  const auto insert_at = at::dim_list_to_bitset(opt_dim, n_dims);
  Tensor expanded = t;
  for (const auto pos : c10::irange(n_dims)) {
    if (insert_at[pos]) {
      expanded = expanded.unsqueeze(static_cast<int64_t>(pos));
    }
  }
  return expanded;
}
// Backward of sum: broadcast `grad` back to `sizes`. When the reduced axes
// were dropped (keepdim=false, non-scalar input, explicit non-empty dims),
// they are re-inserted first.
Tensor sum_backward(
    const Tensor& grad,
    c10::SymIntArrayRef sizes,
    OptionalIntArrayRef opt_dims,
    bool keepdim) {
  const bool has_explicit_dims =
      opt_dims.has_value() && !opt_dims.value().empty();
  if (!keepdim && !sizes.empty() && has_explicit_dims) {
    return unsqueeze_multiple(grad, opt_dims, sizes.size())
        .expand_symint(sizes);
  }
  return grad.expand_symint(sizes);
}
// IntArrayRef-dims variant of sum_backward. Only the path that needs no axis
// re-insertion is supported; the keepdim=false path with explicit dims on a
// non-scalar input raises NotImplemented.
Tensor sum_backward(
    const Tensor& grad,
    c10::SymIntArrayRef sizes,
    c10::IntArrayRef dims,
    bool keepdim) {
  const bool needs_unsqueeze = !keepdim && !sizes.empty() && !dims.empty();
  // we are only using `keepdim=true` path for SymInts for now
  TORCH_CHECK_NOT_IMPLEMENTED(
      !needs_unsqueeze,
      "Only the keepdim=true path is implemented to support symints in autograd");
  return grad.expand_symint(sizes);
}
// Backward of nansum: the sum_backward result, zeroed where `self` is NaN.
Tensor nansum_backward(
    const Tensor& grad,
    const Tensor& self,
    at::OptionalIntArrayRef dims,
    bool keepdim) {
  const auto not_nan = self.isnan().logical_not();
  return sum_backward(grad, self.sym_sizes(), dims, keepdim) * not_nan;
}
// Backward of mean: sum_backward divided by the number of averaged elements.
// That number is `numel` for a full reduction (no dims or empty dims), and
// otherwise the product of the reduced sizes of `shape`.
Tensor mean_backward(
    const Tensor& grad,
    c10::SymIntArrayRef shape,
    OptionalIntArrayRef opt_dim,
    c10::SymInt numel,
    bool keepdim) {
  const bool reduces_everything =
      !opt_dim.has_value() || opt_dim.value().empty();
  c10::SymInt divisor = reduces_everything
      ? std::move(numel)
      : _safe_size(shape, opt_dim.value());
  return sum_backward(grad, shape, opt_dim, keepdim) / std::move(divisor);
}
// Returns a copy of `list` with its elements in reverse order.
std::vector<c10::SymInt> reverse_list_symint(const c10::SymIntArrayRef list) {
  return std::vector<c10::SymInt>(list.rbegin(), list.rend());
}
// Returns a copy of `list` with its elements in reverse order.
std::vector<int64_t> reverse_list(const IntArrayRef list) {
  return std::vector<int64_t>(list.rbegin(), list.rend());
}
// Gradient of prod along `dim` without dividing by the input, so zeros in
// `inp` are handled correctly.
// Each position receives grad times the product of all other elements along
// `dim`. That product is built from exclusive cumulative products taken from
// the left and from the right.
Tensor prod_safe_zeros_backward(
    const Tensor& grad,
    const Tensor& inp,
    int64_t dim) {
  if (inp.sym_numel() == 0) {
    // When input has a zero sized dimension (empty tensor),
    // we don't need to actually compute the grads.
    // So we just reshape `grad` as `input`.
    return grad.expand_as(inp);
  }
  const auto len = inp.sym_size(dim);
  if (len == 1) {
    return grad;
  }
  auto ones_shape = inp.sym_sizes().vec();
  ones_shape[dim] = 1;
  const Tensor ones = at::ones_symint(ones_shape, grad.options());
  // [1, a, a*b, ...]: product of everything strictly before each position.
  const Tensor prefix =
      at::cat({ones, inp.narrow_symint(dim, 0, len - 1)}, dim).cumprod(dim);
  // [..., b*c, c, 1]: product of everything strictly after each position.
  const Tensor tail_reversed = inp.narrow_symint(dim, 1, len - 1).flip(dim);
  const Tensor suffix =
      at::cat({ones, tail_reversed}, dim).cumprod(dim).flip(dim);
  return grad * (prefix * suffix).conj();
}
// note that the gradient for prod is equivalent to:
// cumprod(exclusive, normal) * cumprod(exclusive, reverse), e.g.:
// input:                        [    a,     b,     c]
// cumprod(exclusive, normal):   [1    ,     a, a * b]
// cumprod(exclusive, reverse):  [b * c,     c,     1]
// product:                      [b * c, a * c, a * b]
// and this is safe under input with 0s.
//
// Backward of a full prod() reduction. With no zeros it uses
// grad * conj(result / input). If grad mode is off and there are two or more
// zeros it returns zeros. Otherwise (a single zero, or grad mode on) it uses
// the division-free path on the flattened input.
Tensor prod_backward(
    const Tensor& grad,
    const Tensor& input,
    const Tensor& result) {
  if (input.dim() == 0) {
    return grad;
  }
  const auto flat_safe_path = [&]() {
    return prod_safe_zeros_backward(grad, input.contiguous().view(-1), 0)
        .view_as(input);
  };
  if (input.is_meta() || isTensorSubclassLike(input)) {
    // For Composite Compliance, always take the safer (and slower) path
    return flat_safe_path();
  }
  const Tensor zero_idx = (input == 0).nonzero();
  if (zero_idx.sym_numel() == 0) {
    return grad * (result / input).conj();
  }
  if (!at::GradMode::is_enabled() && zero_idx.sym_size(0) > 1) {
    return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  return flat_safe_path();
}
// Backward of prod along `dim`. When keepdim=false, grad and result get the
// reduced axis back first.
// - If the input has no zeros: grad * conj(result / input).
// - Otherwise (and always for meta/subclass-like inputs): the division-free
//   prod_safe_zeros_backward.
Tensor prod_backward(
    Tensor grad,
    const Tensor& input,
    Tensor result,
    int64_t dim,
    bool keepdim) {
  if (input.dim() == 0) {
    return grad;
  }
  dim = at::maybe_wrap_dim(dim, static_cast<int64_t>(input.sym_sizes().size()));
  if (!keepdim) {
    // `prod` reduces the dimension at `dim`,
    // so, unsqueeze `grad` and `result` at dim.
    grad = grad.unsqueeze(dim);
    result = result.unsqueeze(dim);
  }
  if (input.is_meta() || isTensorSubclassLike(input)) {
    // For Composite Compliance, always take the safer (and slower) path
    return prod_safe_zeros_backward(grad, input, dim);
  }
  const int64_t zero_count =
      (input == 0).sum(dim, true).sum().item<int64_t>();
  if (zero_count != 0) {
    return prod_safe_zeros_backward(grad, input, dim);
  }
  return grad * (result / input).conj();
}
// Shared JVP helper for linear solves with solution X of A X = B.
// It computes dA applied to X, treating a vector right-hand side as a single
// column. It then hands A, dB and that term to `solve`.
template <typename solve_f>
static Tensor generic_solve_jvp(
    solve_f solve,
    const Tensor& X,
    const Tensor& A,
    const Tensor& dA,
    const Tensor& dB) {
  Tensor dA_contrib;
  if (at::native::linalg_solve_is_vector_rhs(dA, dB)) {
    dA_contrib = dA.matmul(X.unsqueeze(-1)).squeeze(-1);
  } else {
    dA_contrib = dA.matmul(X);
  }
  // In general,
  // dX = solve(A, dB - dA_contrib), but this behavior is different for
  // lu_solve. For refer to lu_solve_jvp for more details on this.
  return solve(A, dB, dA_contrib);
}
// Backward of cumsum along `dim`: a reverse cumulative sum, so each position
// receives the sum of the gradients at that position and every later one.
Tensor cumsum_backward(const Tensor& grad, int64_t dim) {
  const bool trivial = grad.sym_numel() <= 1 || grad.sym_size(dim) == 1;
  if (trivial) {
    return grad;
  }
  return grad.flip(dim).cumsum(dim).flip(dim);
}
// Backward of logsumexp over `dim`: grad * conj(exp(self - result)).
// When keepdim=false, grad and result get the reduced axes back first.
Tensor logsumexp_backward(
    Tensor grad,
    const Tensor& self,
    Tensor result,
    IntArrayRef dim,
    bool keepdim) {
  if (!keepdim && self.dim() != 0) {
    const auto rank = self.sym_sizes().size();
    grad = unsqueeze_multiple(grad, dim, rank);
    result = unsqueeze_multiple(result, dim, rank);
  }
  const auto weights = (self - result).exp();
  return grad * weights.conj();
}
// Backward of logcumsumexp along `dim`, computed in log space.
// - Complex grads go through a single log.
// - Real grads may be negative, so their positive and negative parts are
//   processed separately and subtracted at the end.
Tensor logcumsumexp_backward(
    Tensor grad,
    const Tensor& self,
    const Tensor& result,
    int64_t dim) {
  if (grad.dim() == 0 || grad.sym_numel() == 0) {
    return grad;
  }
  // Reference: https://github.com/tensorflow/tensorflow/blob/
  // 2a5910906a0e0f3dbc186ff9db6386d81a63448c/tensorflow/python/ops/math_grad.py#L1832-L1863
  auto scalar_min = AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(
      at::ScalarType::BFloat16,
      at::typeMetaToScalarType(grad.dtype()),
      "logcumsumexp_backward",
      []() { return c10::Scalar(std::numeric_limits<scalar_t>::lowest()); });

  // logcumsumexp accumulated from the end of `dim` towards its start.
  const auto reversed_logcumsumexp = [dim](const Tensor& x) {
    return at::flip(at::logcumsumexp(at::flip(x, {dim}), dim), {dim});
  };
  // Maps a log-magnitude gradient back to an input-space contribution.
  const auto from_log_grad = [&](const Tensor& log_g) {
    return (reversed_logcumsumexp(log_g - result) + self).exp();
  };

  if (at::is_complex(grad)) {
    // no trick separating the positive and negative required
    return from_log_grad(grad.conj().log()).conj();
  }
  // The lowest finite value of the dtype stands in for log(0) at positions
  // excluded from each part.
  const auto grad_min = at::scalar_tensor(scalar_min, grad.options());
  const auto log_abs_grad = grad.abs().log();
  const auto log_grad_positive = at::where(grad > 0, log_abs_grad, grad_min);
  const auto log_grad_negative = at::where(grad < 0, log_abs_grad, grad_min);
  return from_log_grad(log_grad_positive) - from_log_grad(log_grad_negative);
}
// Forward-mode derivative (JVP) of logcumsumexp along `dim`:
// cumsum(exp(x) * t) / (cumsum(exp(x)) + eps).
// For real inputs exp(x) is shifted by max(x) along `dim`
// (exp-normalize trick); the shift cancels in the ratio.
// A ZeroTensor tangent is not supported and trips an internal assert.
Tensor logcumsumexp_jvp(
    const Tensor& self_p,
    const Tensor& self_t,
    int64_t dim) {
  // Check the precondition before doing any work. The statement is
  // terminated explicitly rather than relying on the macro expanding to an
  // if-block.
  TORCH_INTERNAL_ASSERT(!self_t._is_zerotensor());

  // Mostly taken from logsumexp_jvp
  // NB: for simplicity, we recompute some values that can be reused from
  // forward
  auto self_p_exp = [&self_p, dim]() {
    if (!at::is_complex(self_p)) {
      // Use the exp-normalize trick
      return (self_p - std::get<0>(at::max(self_p, dim, true))).exp();
    }
    // at::max doesn't support complex128
    return self_p.exp();
  }();
  auto cumsumexp_p = self_p_exp.cumsum(dim);
  // Keeps the denominator away from exactly zero.
  constexpr double eps = 1e-13;
  if (areAnyTensorSubclassLike({self_p, self_t})) {
    // Out-of-place product so a subclass-like tangent is not written into a
    // plain tensor in place.
    auto result = (self_p_exp * self_t).cumsum(dim);
    result /= cumsumexp_p.add_(eps);
    return result;
  }
  self_p_exp *= self_t;
  auto cumsumexp_t = self_p_exp.cumsum(dim);
  cumsumexp_t /= cumsumexp_p.add_(eps);
  return cumsumexp_t;
}
// Backward of unbind: stacks the per-slice gradients along `dim`. Shape and
// options come from the first defined gradient. Undefined entries are
// replaced by zeros of that shape.
Tensor unbind_backward(const variable_list& grads, int64_t dim) {
  c10::SymIntArrayRef sizes;
  at::TensorOptions o;
  const auto first_defined = std::find_if(
      grads.begin(), grads.end(), [](const Variable& v) { return v.defined(); });
  if (first_defined != grads.end()) {
    sizes = first_defined->sym_sizes();
    o = static_cast<Tensor>(*first_defined).options();
  }
  auto slices = fmap(grads, [&](const Variable& v) {
    if (v.defined()) {
      return static_cast<Tensor>(v);
    }
    return at::zeros({}, o).expand_symint(sizes);
  });
  return at::stack(slices, dim);
}
Tensor unbind_backward_nested(
const variable_list& grads,
const Tensor& nt_sizes,