forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
SparseCsrTensorMath.cpp
1489 lines (1312 loc) · 51.5 KB
/
SparseCsrTensorMath.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Parallel.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/grad_mode.h>
#include <ATen/mkl/Sparse.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/Resize.h>
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/native/TensorConversions.h>
#include <ATen/native/mkl/SparseBlasImpl.h>
#include <ATen/native/sparse/SparseBlasImpl.h>
#include <ATen/native/sparse/SparseCsrTensorMath.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>
#include <ATen/AccumulateType.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Operators.h>
#else
#include <ATen/ops/_conj_physical_native.h>
#include <ATen/ops/_convert_indices_from_coo_to_csr.h>
#include <ATen/ops/_convert_indices_from_coo_to_csr_native.h>
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
#include <ATen/ops/_convert_indices_from_csr_to_coo_native.h>
#include <ATen/ops/_sparse_bsr_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_compressed_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_csr_prod_native.h>
#include <ATen/ops/_sparse_csr_sum_native.h>
#include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
#include <ATen/ops/_sparse_mm_reduce_impl_backward_native.h>
#include <ATen/ops/_sparse_mm_reduce_impl_native.h>
#include <ATen/ops/_unique.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/abs_native.h>
#include <ATen/ops/add.h>
#include <ATen/ops/add_native.h>
#include <ATen/ops/addmm.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/angle.h>
#include <ATen/ops/angle_native.h>
#include <ATen/ops/asin.h>
#include <ATen/ops/asin_native.h>
#include <ATen/ops/asinh.h>
#include <ATen/ops/asinh_native.h>
#include <ATen/ops/atan.h>
#include <ATen/ops/atan_native.h>
#include <ATen/ops/atanh.h>
#include <ATen/ops/atanh_native.h>
#include <ATen/ops/ceil.h>
#include <ATen/ops/ceil_native.h>
#include <ATen/ops/conj_physical.h>
#include <ATen/ops/conj_physical_native.h>
#include <ATen/ops/copy_native.h>
#include <ATen/ops/deg2rad.h>
#include <ATen/ops/deg2rad_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/erf.h>
#include <ATen/ops/erf_native.h>
#include <ATen/ops/erfinv.h>
#include <ATen/ops/erfinv_native.h>
#include <ATen/ops/expm1.h>
#include <ATen/ops/expm1_native.h>
#include <ATen/ops/fill_native.h>
#include <ATen/ops/floor.h>
#include <ATen/ops/floor_native.h>
#include <ATen/ops/frac.h>
#include <ATen/ops/frac_native.h>
#include <ATen/ops/isinf.h>
#include <ATen/ops/isinf_native.h>
#include <ATen/ops/isnan.h>
#include <ATen/ops/isnan_native.h>
#include <ATen/ops/isneginf.h>
#include <ATen/ops/isneginf_native.h>
#include <ATen/ops/isposinf.h>
#include <ATen/ops/isposinf_native.h>
#include <ATen/ops/log1p.h>
#include <ATen/ops/log1p_native.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/mul_native.h>
#include <ATen/ops/neg.h>
#include <ATen/ops/neg_native.h>
#include <ATen/ops/normal_native.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/ones_like.h>
#include <ATen/ops/rad2deg.h>
#include <ATen/ops/rad2deg_native.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/relu_native.h>
#include <ATen/ops/resize_as_sparse_native.h>
#include <ATen/ops/result_type.h>
#include <ATen/ops/round.h>
#include <ATen/ops/round_native.h>
#include <ATen/ops/round_ops.h>
#include <ATen/ops/sgn.h>
#include <ATen/ops/sgn_native.h>
#include <ATen/ops/sign.h>
#include <ATen/ops/sign_native.h>
#include <ATen/ops/signbit.h>
#include <ATen/ops/signbit_native.h>
#include <ATen/ops/sin.h>
#include <ATen/ops/sin_native.h>
#include <ATen/ops/sinh.h>
#include <ATen/ops/sinh_native.h>
#include <ATen/ops/sparse_mask.h>
#include <ATen/ops/sparse_mask_native.h>
#include <ATen/ops/sqrt.h>
#include <ATen/ops/sqrt_native.h>
#include <ATen/ops/tan.h>
#include <ATen/ops/tan_native.h>
#include <ATen/ops/tanh.h>
#include <ATen/ops/tanh_native.h>
#include <ATen/ops/tensor.h>
#include <ATen/ops/threshold_backward.h>
#include <ATen/ops/threshold_backward_native.h>
#include <ATen/ops/trunc.h>
#include <ATen/ops/trunc_native.h>
#include <ATen/ops/zero_native.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#endif
#include <algorithm>
namespace at {
namespace meta {
// Meta function for _convert_indices_from_coo_to_csr.
// Checks that `self` is at most 1-dimensional and declares output 0 with
// `size + 1` elements on the device of `self`.
TORCH_META_FUNC(_convert_indices_from_coo_to_csr)
(const Tensor& self, const int64_t size, const bool out_int32) {
  TORCH_CHECK(
      self.dim() <= 1,
      "Input is supposed to be a vector, but got ",
      self.dim(),
      " dimensional tensor.");

  // The caller chooses the index width: int32 when `out_int32`, int64 otherwise.
  const ScalarType index_dtype = out_int32 ? ScalarType::Int : ScalarType::Long;
  const c10::TensorOptions out_options =
      TensorOptions().dtype(index_dtype).device(self.options().device());

  set_output_raw_strided(0, size + 1, {}, out_options);
}
// Meta function for _convert_indices_from_csr_to_coo.
// Requires crow_indices and col_indices to have the same number of
// dimensions and declares output 0 with shape
// {col_indices.dim() + 1, col_indices.numel()}, using int32 or int64
// depending on `out_int32`. `transpose` is not used for shape inference here.
TORCH_META_FUNC(_convert_indices_from_csr_to_coo)
(const Tensor& crow_indices,
 const Tensor& col_indices,
 const bool out_int32,
 const bool transpose) {
  // The second reported dimensionality must be that of col_indices;
  // reporting crow_indices.dim() twice made the message self-contradictory.
  TORCH_CHECK(
      crow_indices.dim() == col_indices.dim(), "crow_indices and col_indices are supposed to have"
      " the same dimensionality, but got ", crow_indices.dim(), " and ",
      col_indices.dim(), " dimensional tensors, respectively.");
  ScalarType scalar_type = out_int32 ? ScalarType::Int : ScalarType::Long;
  c10::TensorOptions options = crow_indices.options().dtype(scalar_type);
  set_output_raw_strided(0, {col_indices.dim() + 1, col_indices.numel()}, {}, options, {});
}
} // namespace meta
namespace {
// Applies an out-variant unary op to the values of a CSR tensor.
// When `result` is a different tensor than `self`, `result` first receives a
// copy of `self` (being resized beforehand if it is empty); afterwards
// `op_out(self.values(), result.values())` is invoked. Returns `result`.
template <typename F>
Tensor& unary_op_out(F op_out, const Tensor& self, Tensor& result) {
  TORCH_INTERNAL_ASSERT(self.is_sparse_csr());
  TORCH_INTERNAL_ASSERT(result.is_sparse_csr());

  const bool writes_in_place = result.is_same(self);
  if (!writes_in_place) {
    // A (0x0) result tensor has to be resized by hand to the size of `self`.
    if (result.numel() == 0) {
      at::native::resize_as_sparse_compressed_(result, self);
    }
    // copy_sparse_compressed_ checks the sizes of both tensors internally,
    // so no additional size check is performed here.
    at::native::copy_sparse_compressed_(result, self);
  }

  auto input_values = self.values();
  auto output_values = result.values();
  op_out(input_values, output_values);
  return result;
}
// Invokes the member function `op_inplace` on `self.values()`, forwarding
// `args` to it, and returns `self`.
template <typename F, typename... Args>
Tensor& unary_op_inplace(Tensor& self, const F& op_inplace, Args&&... args) {
  // The dispatch macro runs an empty lambda, i.e. it is used purely as a
  // layout check (presumably raising for layouts that are not sparse
  // compressed -- confirm against the macro definition).
  AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "unary_op_inplace", [](){});

  auto stored_values = self.values();
  (stored_values.*op_inplace)(std::forward<Args>(args)...);
  return self;
}
} // end anonymous namespace
namespace native {
using namespace at::sparse_csr;
// certain utility functions are usable from sparse COO.
using namespace at::sparse;
// Element-wise multiplication writing into the CSR tensor `r`.
// A strided operand paired with a CSR operand is first masked with the CSR
// operand's sparsity pattern and the function recurses. Otherwise both
// operands go through to_sparse(), are multiplied, and the product is
// converted to CSR and copied into `r`.
Tensor& mul_out_sparse_csr(const Tensor& t_, const Tensor& src_, Tensor& r) {
  // TODO: Use a specialized CSR kernel for performance if needed
  const bool lhs_is_strided = t_.layout() == kStrided;
  const bool rhs_is_strided = src_.layout() == kStrided;

  if (t_.is_sparse_csr() && rhs_is_strided) {
    return mul_out_sparse_csr(t_, src_.sparse_mask(t_), r);
  }
  if (lhs_is_strided && src_.is_sparse_csr()) {
    return mul_out_sparse_csr(t_.sparse_mask(src_), src_, r);
  }

  TORCH_CHECK(r.is_sparse_csr(), "Expected result Tensor to be of format CSR");

  const Tensor lhs_sparse = t_.to_sparse();
  const Tensor rhs_sparse = src_.to_sparse();
  const Tensor product = lhs_sparse.mul(rhs_sparse);
  const auto product_csr = product.to_sparse_csr();

  r.resize_as_sparse_(product_csr);
  r.copy_(product_csr);
  return r;
}
// Applies `op(sparse.values(), scalar.squeeze())` and builds a new sparse
// compressed tensor from cloned indices of `sparse` and the resulting
// values, cast to at::result_type(sparse, scalar).
// NOTE: this helper assumes scalar.numel() == 1.
template <typename op_t>
Tensor intersection_binary_op_with_wrapped_scalar(const Tensor& sparse, const Tensor& scalar, const op_t& op) {
  const auto raw_values = op(sparse.values(), scalar.squeeze());
  const auto out_values = raw_values.to(at::result_type(sparse, scalar));
  const auto out_sizes = infer_size(sparse.sizes(), scalar.sizes());

  auto [compressed_indices, plain_indices] = getCompressedPlainIndices(sparse);
  const auto out_options = sparse.options().dtype(out_values.scalar_type());

  return at::_sparse_compressed_tensor_unsafe(
      compressed_indices.clone(),
      plain_indices.clone(),
      out_values,
      out_sizes,
      out_options);
}
// In-place counterpart: applies `op(sparse.values(), scalar.squeeze())` and
// returns `sparse`. Fails when broadcasting `scalar` against `sparse` would
// change the shape of `sparse`; `op_name` is only used in that error message.
// NOTE: this helper assumes scalar.numel() == 1.
template <typename op_t>
Tensor& intersection_binary_op_with_wrapped_scalar_(Tensor& sparse, const Tensor& scalar, const string& op_name, const op_t& op) {
  const auto expected_shape = infer_size(sparse.sizes(), scalar.sizes());
  TORCH_CHECK(
      sparse.sizes() == expected_shape,
      op_name, "(): output with shape ", sparse.sizes(), " does not match ",
      "the broadcast shape ", expected_shape);

  // squeeze() is safe here: the check above established that `scalar`
  // broadcasts against `sparse` without changing its shape.
  auto stored_values = sparse.values();
  op(stored_values, scalar.squeeze());
  return sparse;
}
// Functional element-wise multiplication involving a CSR tensor.
// 0-dim strided operands are treated as wrapped scalars and multiplied into
// the values of the other operand. A strided operand paired with a CSR one is
// masked with the CSR sparsity pattern first. Everything else is delegated to
// at::mul_out with an empty 2-D result.
Tensor mul_sparse_csr(const Tensor& self, const Tensor& other) {
  const auto multiply = [](const Tensor& a, const Tensor& b) -> Tensor {
    return a.mul(b);
  };
  const bool self_is_strided = self.layout() == kStrided;
  const bool other_is_strided = other.layout() == kStrided;

  // Wrapped-Scalar operands (either side).
  if (self_is_strided && self.dim() == 0) {
    return intersection_binary_op_with_wrapped_scalar(other, self, multiply);
  }
  if (other_is_strided && other.dim() == 0) {
    return intersection_binary_op_with_wrapped_scalar(self, other, multiply);
  }

  // Mixed CSR / strided operands: restrict the strided one to the CSR pattern.
  if (self.is_sparse_csr() && other_is_strided) {
    return mul_sparse_csr(self, other.sparse_mask(self));
  }
  if (self_is_strided && other.is_sparse_csr()) {
    return mul_sparse_csr(self.sparse_mask(other), other);
  }

  const auto promoted_dtype = at::result_type(self, other);
  const auto out_options = self.options().dtype(promoted_dtype);
  // CSR is 2d!
  Tensor out = at::empty({0, 0}, out_options);
  return at::mul_out(out, self, other); // redispatch!
}
// In-place element-wise multiplication. A 0-dim strided `other` is treated
// as a wrapped scalar and multiplied into self.values(); every other case is
// delegated to at::mul_out with `self` as the output.
Tensor& mul_sparse_csr_(Tensor& self, const Tensor& other) {
  const bool other_is_wrapped_scalar =
      other.layout() == kStrided && other.dim() == 0;
  if (!other_is_wrapped_scalar) {
    return at::mul_out(self, self, other); // redispatch!
  }
  const auto multiply_in_place = [](Tensor& a, const Tensor& b) -> Tensor& {
    return a.mul_(b);
  };
  return intersection_binary_op_with_wrapped_scalar_(self, other, "mul_", multiply_in_place);
}
namespace {
// Builds the result of a functional unary op on a sparse compressed tensor.
// Type promotion is handled by running `op` on the values first and taking
// the dtype of its result; the output is then assembled from clones of the
// input's indices, the new values and the input's sizes.
template <typename F>
inline Tensor get_result_tensor_for_unary_op(F op, const Tensor& input) {
  auto input_values = input.values();
  auto new_values = op(input_values);

  // Pick the index accessors matching the (row- or column-) compressed layout.
  auto compressed = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(input.layout(),
      "get_result_tensor_for_unary_op",
      [&]{ return input.crow_indices(); },
      [&]{ return input.ccol_indices(); });
  auto plain = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(input.layout(),
      "get_result_tensor_for_unary_op",
      [&]{ return input.col_indices(); },
      [&]{ return input.row_indices(); });

  return at::_sparse_compressed_tensor_unsafe(
      compressed.clone(),
      plain.clone(),
      new_values,
      input.sizes(),
      input.options().dtype(new_values.scalar_type()));
}
} // namespace
// In-place normal_ for sparse compressed tensors.
// Forwards to Tensor::normal_(mean, std, gen) on self.values() through
// unary_op_inplace and returns `self`.
Tensor& normal_sparse_csr_(
Tensor& self,
double mean,
double std,
std::optional<Generator> gen) {
return unary_op_inplace(self, &Tensor::normal_, mean, std, gen);
}
// In-place fill_ for sparse compressed tensors.
// Forwards to TensorBase::fill_(value) on self.values() through
// unary_op_inplace and returns `self`.
Tensor& fill_sparse_csr_(Tensor& self, const Scalar& value) {
return unary_op_inplace(self, &TensorBase::fill_, value);
}
// sparse_mask for a sparse compressed `mask`.
// Requires `mask` to have a sparse compressed layout and the same sizes as
// `self`. Returns `self` itself when both are the same tensor, and a
// converted clone of `mask` when the mask is empty or has no specified
// elements. Otherwise the result is computed according to the layout of
// `self` (strided, same layout as mask, or any other layout).
Tensor sparse_mask_sparse_compressed(
    const Tensor& self,
    const Tensor& mask) {
  TORCH_CHECK(at::sparse_csr::is_sparse_compressed(mask),
              "sparse_mask_sparse_compressed expects mask to have sparse compressed layout, got ", mask.layout());
  TORCH_CHECK(
      mask.sizes().equals(self.sizes()),
      "sparse_mask(): operands have incompatible sizes; self has size ",
      self.sizes(),
      " but mask has size ",
      mask.sizes());

  if (self.is_same(mask)) {
    return self;
  }

  const bool mask_is_trivial = !mask.numel() || !mask._nnz();
  if (mask_is_trivial) {
    return mask.clone().to(self.device(), self.scalar_type());
  }

  if (self.layout() == kStrided) {
    // Materialize the mask's sparsity pattern as a dense boolean tensor and
    // let dense_to_sparse_with_mask select the matching entries of `self`.
    auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(mask);
    auto mask_values = mask.values();
    auto dense_mask = at::_sparse_compressed_tensor_unsafe(
        compressed_indices,
        plain_indices,
        at::ones({1}, self.options().dtype(kBool)).expand_as(mask_values),
        self.sizes(),
        self.options().dtype(kBool).layout(mask.layout())).to_dense();
    return AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
        mask.layout(), "sparse_mask_sparse_compressed",
        [&] {
          return at::native::dense_to_sparse_with_mask(self, dense_mask, mask.layout(), {}, mask.dense_dim());
        },
        [&] {
          auto blocksize = at::sparse_csr::getBlockSize(mask);
          return at::native::dense_to_sparse_with_mask(self, dense_mask, mask.layout(), blocksize, mask.dense_dim());
        });
  }

  if (self.layout() == mask.layout()) {
    // TODO: keeping this for BC but the method used here may lead to
    // incorrect indices.
    return self.mul(at::ones_like(mask)).to(self.scalar_type());
  }

  // TODO: keeping this for BC but the method used here cannot
  // support batch dimensions because sparse COO tensors are batch
  // dimension ignorant.
  return AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
      mask.layout(), "sparse_mask_sparse_compressed",
      [&] {
        return self.sparse_mask(mask.to_sparse()).to_sparse(mask.layout());
      },
      [&] {
        auto blocksize = at::sparse_csr::getBlockSize(mask);
        return self.sparse_mask(mask.to_sparse()).to_sparse(mask.layout(), blocksize);
      });
}
// Multiplies a CSR tensor by a Scalar: the values are scaled and a new CSR
// tensor is assembled from clones of the original indices. dtype and device
// of the result are taken from the scaled values.
Tensor mul_scalar_sparse_csr(const Tensor& self, const Scalar& other) {
  const auto scaled_values = self.values().mul(other);
  auto crow_copy = self.crow_indices().clone();
  auto col_copy = self.col_indices().clone();
  return at::native::_sparse_csr_tensor_unsafe(
      crow_copy,
      col_copy,
      scaled_values,
      self.sizes(),
      scaled_values.scalar_type(),
      self.layout(),
      scaled_values.device());
}
// In-place zero_ for sparse compressed tensors: calls resize_and_clear_ on
// the underlying impl with the tensor's current sparse_dim, dense_dim and
// sizes, then returns `self`.
Tensor& zero_sparse_csr_(Tensor& self) {
/*
csr.zero_() resets nnz to 0.
If the original sparsity pattern needs to be preserved, use
`csr.values().zero_()` instead.
The above behavior also implies that torch.zeros_like(csr) returns
a new tensor with nnz == 0. If one needs zeros_like semantics
where the result has the same sparsity pattern as the input, then use
`result = csr.clone(); result.values().zero_();`
*/
// The dispatch macro runs an empty lambda, i.e. it only acts as a layout check.
AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "zero_sparse_csr_", [](){});
get_sparse_csr_impl(self)->resize_and_clear_(self.sparse_dim(), self.dense_dim(), self.sizes());
return self;
}
/* Implementation of unary ufuncs supported for the sparse CSR layout.
* Only simple functions with 0->0 correspondence are currently supported. */
// Defines `<op_name>_sparse_csr_out(self, result)`, which applies
// at::<op_name>_outf to the values through unary_op_out.
#define CREATE_UNARY_UFUNC_OUT(op_name) \
Tensor& op_name##_sparse_csr_out(const Tensor& self, Tensor& result) { \
return unary_op_out(&at::op_name##_outf, self, result); \
}
// Defines `<op_name>_sparse_csr(self)`, which applies at::<op_name> to the
// values through get_result_tensor_for_unary_op and returns a new tensor.
#define CREATE_UNARY_UFUNC_FUNCTIONAL(op_name) \
Tensor op_name##_sparse_csr(const Tensor& self) { \
return get_result_tensor_for_unary_op(&at::op_name, self); \
}
// Defines `<op_name>_sparse_csr_(self)`, which applies Tensor::<op_name>_ to
// the values through unary_op_inplace.
#define CREATE_UNARY_UFUNC_INPLACE(op_name) \
Tensor& op_name##_sparse_csr_(Tensor& self) { \
return unary_op_inplace(self, &Tensor::op_name##_); \
}
// Defines all three variants (out, functional, in-place) for `op_name`.
#define CREATE_UNARY_UFUNC(op_name) \
CREATE_UNARY_UFUNC_OUT(op_name) \
CREATE_UNARY_UFUNC_FUNCTIONAL(op_name) \
CREATE_UNARY_UFUNC_INPLACE(op_name)
// Defines only the out and functional variants for `op_name`.
#define CREATE_UNARY_UFUNC_NO_INPLACE(op_name) \
CREATE_UNARY_UFUNC_OUT(op_name) \
CREATE_UNARY_UFUNC_FUNCTIONAL(op_name)
// Exhaustive list of the unary ufuncs supported by sparse compressed
// Each CREATE_UNARY_UFUNC(op) line below expands to three definitions:
// op_sparse_csr_out, op_sparse_csr and op_sparse_csr_.
CREATE_UNARY_UFUNC(abs)
CREATE_UNARY_UFUNC(asin)
CREATE_UNARY_UFUNC(asinh)
CREATE_UNARY_UFUNC(atan)
CREATE_UNARY_UFUNC(atanh)
CREATE_UNARY_UFUNC(ceil)
CREATE_UNARY_UFUNC(deg2rad)
CREATE_UNARY_UFUNC(erf)
CREATE_UNARY_UFUNC(erfinv)
CREATE_UNARY_UFUNC(expm1)
CREATE_UNARY_UFUNC(floor)
CREATE_UNARY_UFUNC(frac)
CREATE_UNARY_UFUNC(log1p)
CREATE_UNARY_UFUNC(neg)
CREATE_UNARY_UFUNC(rad2deg)
CREATE_UNARY_UFUNC(sign)
CREATE_UNARY_UFUNC(sin)
CREATE_UNARY_UFUNC(sinh)
CREATE_UNARY_UFUNC(sgn)
CREATE_UNARY_UFUNC(sqrt)
CREATE_UNARY_UFUNC(tan)
CREATE_UNARY_UFUNC(tanh)
CREATE_UNARY_UFUNC(trunc)
CREATE_UNARY_UFUNC(conj_physical)
// Because the macro expands to three consecutive definitions, the leading
// `static` applies only to the first one (relu_sparse_csr_out); the
// surrounding diagnostic push/pop silences -Wunused-function for it.
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")
static CREATE_UNARY_UFUNC(relu)
C10_DIAGNOSTIC_POP()
// With the addition of the `round.decimals` overload, using
// CREATE_UNARY_UFUNC leads to an unresolved overload, so the three round
// variants are written out by hand against the explicit at::_ops entry points.
// Out variant: applies at::_ops::round_out::call to the values via unary_op_out.
Tensor& round_sparse_csr_out(const Tensor& self, Tensor& result) {
return unary_op_out(&at::_ops::round_out::call, self, result);
}
// Functional round: applies at::_ops::round::call to the values and returns a
// new tensor built by get_result_tensor_for_unary_op.
Tensor round_sparse_csr(const Tensor& self) {
return get_result_tensor_for_unary_op(&at::_ops::round::call, self);
}
// In-place round: rounds self.values() in place and returns `self`.
// NOTE(review): unlike the macro-generated in-place ops, which go through
// unary_op_inplace and its layout dispatch, this one asserts is_sparse_csr()
// -- confirm whether the stricter check is intended.
Tensor& round_sparse_csr_(Tensor& self) {
TORCH_INTERNAL_ASSERT(self.is_sparse_csr());
self.values().round_();
return self;
}
// Functional threshold_backward for sparse compressed tensors.
// Runs at::threshold_backward on the values of `grad_output` and `self`; the
// result reuses (cloned) indices of `grad_output`.
Tensor threshold_backward_sparse_compressed(
    const Tensor& grad_output,
    const Tensor& self,
    const Scalar& threshold) {
  const auto backward_on_values = [&](const Tensor& grad_values) {
    return at::threshold_backward(grad_values, self.values(), threshold);
  };
  return get_result_tensor_for_unary_op(backward_on_values, grad_output);
}
// Out variant of threshold_backward for sparse compressed tensors.
// Delegates to unary_op_out with `grad_output` as input and `grad_input` as
// output; the op itself runs at::threshold_backward_outf on the values,
// using self.values() as the second operand.
Tensor& threshold_backward_sparse_compressed_out(
    const Tensor& grad_output,
    const Tensor& self,
    const Scalar& threshold,
    Tensor& grad_input) {
  const auto backward_on_values = [&](const Tensor& grad_values, Tensor& out_values) {
    return at::threshold_backward_outf(grad_values, self.values(), threshold, out_values);
  };
  return unary_op_out(backward_on_values, grad_output, grad_input);
}
// angle, isneginf, isposinf and signbit currently don't have an inplace variant
// (each line expands to op_sparse_csr_out and op_sparse_csr only)
CREATE_UNARY_UFUNC_NO_INPLACE(angle)
CREATE_UNARY_UFUNC_NO_INPLACE(isneginf)
CREATE_UNARY_UFUNC_NO_INPLACE(isposinf)
CREATE_UNARY_UFUNC_NO_INPLACE(signbit)
// isnan and isinf don't have an out variant
// (each line expands to op_sparse_csr only)
CREATE_UNARY_UFUNC_FUNCTIONAL(isnan)
CREATE_UNARY_UFUNC_FUNCTIONAL(isinf)
template <typename scalar_t>
void addmm_out_sparse_csr_native_cpu(
const Tensor& sparse,
const Tensor& dense,
const Tensor& r,
Scalar alpha,
Scalar beta) {
auto dim_i = sparse.size(0);
auto dim_k = dense.size(1);
auto csr = sparse.crow_indices();
auto col_indices = sparse.col_indices();
auto values = sparse.values();
scalar_t cast_alpha = alpha.to<scalar_t>();
r.mul_(beta);
AT_DISPATCH_INDEX_TYPES(
col_indices.scalar_type(), "csr_mm_crow_indices", [&]() {
auto csr_accessor = csr.accessor<index_t, 1>();
auto col_indices_accessor = col_indices.accessor<index_t, 1>();
auto values_accessor = values.accessor<scalar_t, 1>();
scalar_t* dense_ptr = dense.data_ptr<scalar_t>();
scalar_t* r_ptr = r.data_ptr<scalar_t>();
int64_t dense_stride0 = dense.stride(0);
int64_t dense_stride1 = dense.stride(1);
int64_t r_stride0 = r.stride(0);
int64_t r_stride1 = r.stride(1);
at::parallel_for(
0,
dim_i,
internal::GRAIN_SIZE,
[&](int64_t irow_start, int64_t irow_end) {
for (index_t h = irow_start; h < irow_end; ++h) {
index_t i_start = csr_accessor[h];
index_t i_end = csr_accessor[h + 1];
for (index_t i = i_start; i < i_end; i++) {
scalar_t val = values_accessor[i];
index_t col = col_indices_accessor[i];
at::native::cpublas::axpy<scalar_t>(
dim_k,
cast_alpha * val,
dense_ptr + col * dense_stride0,
dense_stride1,
r_ptr + h * r_stride0,
r_stride1);
}
}
});
});
}
// Functions for matrix multiplication.
// result = beta * self + alpha (mat1 @ mat2)
//
// CPU entry point for addmm where at least one operand may be sparse
// compressed. Steps, in order:
//   1. validate that mat1/mat2 are 2-D and have compatible inner sizes;
//   2. expand `self` to the product shape (skipped when result is self);
//   3. copy `self` into `result` (resizing `result` first);
//   4. return early for empty results and for sparse all-zero operands;
//   5. run either the MKL implementation or, without MKL, the native
//      CSR @ strided kernel for the supported layout combinations.
// Returns `result`.
Tensor& addmm_out_sparse_compressed_cpu(
const Tensor& self,
const Tensor& mat1,
const Tensor& mat2,
const Scalar& beta,
const Scalar& alpha,
Tensor& result) {
// All the checks are from addmm_out_cuda_impl (ATen/native/cuda/Blas.cpp) and
// TORCH_META_FUNC(addmm) (ATen/native/LinearAlgebra.cpp)
// TODO: remove code duplication and unify code
sparse::impl::_check_dim(mat1, 2, "mat1");
sparse::impl::_check_dim(mat2, 2, "mat2");
TORCH_CHECK(
mat1.size(1) == mat2.size(0), "mat1 and mat2 shapes cannot be multiplied (",
mat1.size(0), "x", mat1.size(1), " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
c10::MaybeOwned<at::Tensor> self_;
// Don't expand self if this is an in-place operation
if (&result == &self) {
self_ = c10::MaybeOwned<Tensor>::borrowed(self);
} else {
self_ = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm");
}
TORCH_CHECK(((self_->dim() == 2) &&
(self_->size(0) == mat1.size(0)) &&
(self_->size(1) == mat2.size(1))),
"The input tensor must be a matrix with size ",
mat1.size(0),
"x",
mat2.size(1),
", but got a ",
self_->dim(),
"-D tensor with size ",
self_->size(0),
"x",
self_->size(1));
// Seed `result` with the (expanded) input so the kernels below can work
// in place on it.
if (&result != &self) {
if (result.layout() == kStrided) {
at::native::resize_output(result, self_->sizes());
} else {
result.resize_as_sparse_(*self_);
}
result.copy_(*self_);
}
if (result.numel() == 0) {
// If result gets resized and is sparse compressed,
// its compressed_indices tensor will contain junk values
// so the whole tensor is not a valid compressed tensor.
// To combat that, result needs to get zeroed out.
if (at::sparse_csr::is_sparse_compressed(result)) {
result.zero_();
}
return result;
}
if (sparse::impl::_is_sparse_and_zero(mat1) || sparse::impl::_is_sparse_and_zero(mat2)) {
// According to docs, when beta==0 values in self should be ignored.
// nans and infs should not propagate
if (beta.toComplexDouble() == 0.) {
result.zero_();
} else {
result.mul_(beta);
}
return result;
}
#if !AT_USE_MKL_SPARSE()
// The custom impl addmm_out_sparse_csr_native_cpu only supports CSR @
// strided -> strided
// The strided @ sparse cases below are rewritten as
// (mat2^T @ mat1^T) written into result^T, so the same kernel can be reused.
if (mat1.layout() == kStrided) {
if (mat2.layout() == kSparseCsr) {
if (result.layout() == kStrided) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
result.scalar_type(), "addmm_sparse_dense", [&] {
addmm_out_sparse_csr_native_cpu<scalar_t>(
mat2.transpose(-2, -1).to_sparse_csr(),
mat1.transpose(-2, -1),
result.transpose(-2, -1),
alpha,
beta);
});
return result;
}
}
if (mat2.layout() == kSparseCsc) {
if (result.layout() == kStrided) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
result.scalar_type(), "addmm_sparse_dense", [&] {
addmm_out_sparse_csr_native_cpu<scalar_t>(
mat2.transpose(-2, -1),
mat1.transpose(-2, -1),
result.transpose(-2, -1),
alpha,
beta);
});
return result;
}
}
} else if (mat1.layout() == kSparseCsr) {
if (mat2.layout() == kStrided) {
if (result.layout() == kStrided) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
result.scalar_type(), "addmm_sparse_dense", [&] {
addmm_out_sparse_csr_native_cpu<scalar_t>(
mat1, mat2, result, alpha, beta);
});
return result;
}
}
} else if (mat1.layout() == kSparseCsc) {
if (mat2.layout() == kStrided) {
if (result.layout() == kStrided) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
result.scalar_type(), "addmm_sparse_dense", [&] {
addmm_out_sparse_csr_native_cpu<scalar_t>(
mat1.to_sparse_csr(), mat2, result, alpha, beta);
});
return result;
}
}
}
// Any layout combination that did not return above is unsupported
// without MKL.
TORCH_CHECK(
false,
"addmm: computation on CPU is not implemented for ",
result.layout(),
" + ",
mat1.layout(),
" @ ",
mat2.layout(),
" without MKL. PyTorch built with MKL has better support for addmm with sparse CPU tensors.");
#else
sparse::impl::mkl::addmm_out_sparse_csr(mat1, mat2, beta, alpha, result);
#endif
return result;
}
// Functional addmm: allocates an empty (0x0) output with the options of
// `self` and delegates the computation to at::addmm_out.
Tensor addmm_sparse_compressed_dense(
    const Tensor& self,
    const SparseCsrTensor& sparse,
    const Tensor& dense,
    const Scalar& beta,
    const Scalar& alpha) {
  Tensor out = at::empty({0, 0}, self.options());
  at::addmm_out(out, self, sparse, dense, beta, alpha);
  return out;
}
// Out variant of the sparse matrix product: expressed as addmm with an
// all-zero input (created with zeros_like(result)), beta = 0 and alpha = 1,
// i.e. result = mat1 @ mat2.
Tensor& _sparse_csr_mm_out(
const Tensor& mat1,
const Tensor& mat2,
Tensor& result) {
auto zero = at::zeros_like(result);
return at::addmm_out(result, zero, mat1, mat2, 0.0, 1.0);
}
// Matrix product mat1 @ mat2 for sparse compressed operands, implemented as
// addmm on an all-zero input with beta = 0 and alpha = 1.
//  - CSR @ CSR returns a sparse result (options taken from mat2).
//  - Any other CSR/CSC pairing is converted to CSR @ CSR first.
//  - CSC @ strided converts mat1 to CSR first.
//  - Otherwise the options of mat1 are used, switched to the strided layout
//    when mat2 is strided.
Tensor _sparse_csr_mm(const Tensor& mat1, const Tensor& mat2) {
  // Shared tail of all non-recursive paths.
  const auto multiply_into_zeros = [&](const TensorOptions& zero_options) {
    return at::addmm(
        at::zeros({mat1.size(0), mat2.size(1)}, zero_options),
        mat1,
        mat2,
        0.0,
        1.0);
  };
  const auto is_csr_or_csc = [](const Tensor& t) {
    return t.layout() == kSparseCsc || t.layout() == kSparseCsr;
  };

  if (mat1.is_sparse_csr() && mat2.is_sparse_csr()) {
    // Return sparse
    return multiply_into_zeros(mat2.options());
  }

  if (is_csr_or_csc(mat1) && is_csr_or_csc(mat2)) {
    // TODO: Expensive conversion to CSR. Should add native support for CSC.
    // Covers CSC @ CSR
    // Covers CSR @ CSC
    // Covers CSC @ CSC
    return _sparse_csr_mm(mat1.to_sparse_csr(), mat2.to_sparse_csr());
  }

  if (mat1.layout() == kSparseCsc && mat2.layout() == c10::kStrided) {
    // TODO: This is a costly conversion. We should have
    // native support for CSC.
    return _sparse_csr_mm(mat1.to_sparse_csr(), mat2);
  }

  // Default to taking options from mat1
  auto result_options = mat1.options();
  if (mat2.layout() == kStrided) {
    // if either arg is strided we return strided, so update the options if
    // mat2 is strided.
    result_options = result_options.layout(kStrided);
  }
  return multiply_into_zeros(result_options);
}
// Functions for element-wise addition.
// Functional add: allocates the output with the promoted dtype and
// contiguous memory format, modelled on `other` for add(sparse, dense) and
// on `self` in every other case, then delegates to at::add_out.
Tensor add_sparse_csr(
    const Tensor& self,
    const Tensor& other,
    const Scalar& alpha) {
  const auto promoted_dtype = at::result_type(self, other);
  alpha_check(promoted_dtype, alpha);

  // add(sparse, dense) -> dense, modelled on `other`;
  // add(dense, sparse) -> dense AND add(sparse, sparse) -> sparse, modelled on `self`.
  const bool sparse_plus_dense =
      self.layout() != kStrided && other.layout() == kStrided;
  const Tensor& prototype = sparse_plus_dense ? other : self;

  Tensor result = at::empty_like(
      prototype,
      prototype.options()
          .dtype(promoted_dtype)
          .memory_format(at::MemoryFormat::Contiguous));

  return at::add_out(result, self, other, alpha); // redispatch!
}
// In-place add: delegates to at::add_out with `self` as both the first
// operand and the output.
Tensor& add_sparse_csr_(
Tensor& self,
const Tensor& other,
const Scalar& alpha) {
return at::add_out(self, self, other, alpha); // redispatch!
}
static void add_out_dense_sparse_compressed_cpu(
const Tensor& out,
const Tensor& dense,
const SparseCsrTensor& src,
const Scalar& alpha) {
TORCH_INTERNAL_ASSERT(dense.layout() == kStrided);
TORCH_INTERNAL_ASSERT(
src.layout() == kSparseCsr || src.layout() == kSparseCsc);
TORCH_INTERNAL_ASSERT(dense.device() == kCPU || dense.device() == kMeta);
TORCH_CHECK(
out.is_contiguous(),
"out argument must be contiguous, but got: ",
out.suggest_memory_format());
TORCH_CHECK(
out.device() == dense.device(),
"add: expected 'out' to match dense tensor, but got tensor on device: ",
out.device());
TORCH_CHECK(
src.device() == dense.device(),
"add: expected 'src' to match dense tensor, but got tensor on device: ",
src.device());
TORCH_CHECK(
dense.sizes().equals(src.sizes()),
"add: expected 'self' and 'other' to have same size, but self has size ",
dense.sizes(),
" while other has size ",
src.sizes(),
" (FYI: op2-sparse addition does not currently support broadcasting)");
auto commonDtype = promoteTypes(dense.scalar_type(), src.scalar_type());
TORCH_CHECK(
canCast(commonDtype, out.scalar_type()),
"Can't convert result type ",
commonDtype,
" to output ",
out.scalar_type(),
" in add operation");
auto src_values = src.values();
resize_output(out, dense.sizes());
Tensor resultBuffer = out;
if (out.scalar_type() != commonDtype) {
resultBuffer = dense.to(commonDtype);
} else if (!is_same_tensor(out, dense)) {
resultBuffer.copy_(dense);
}
if (src._nnz() == 0) {
return;
}
TORCH_INTERNAL_ASSERT(dense.device() == kCPU);
auto valuesBuffer = src_values.to(commonDtype).reshape({-1, src_values.size(-1)});
resultBuffer = resultBuffer.view({-1, out.size(-2), out.size(-1)});
Tensor src_compressed_indices;
Tensor src_plain_indices;
std::tie(src_compressed_indices, src_plain_indices) =
at::sparse_csr::getCompressedPlainIndices(src);
src_compressed_indices =
src_compressed_indices.reshape({-1, src_compressed_indices.size(-1)});
src_plain_indices =
src_plain_indices.reshape({-1, src_plain_indices.size(-1)});
auto src_layout = src.layout();
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
kComplexHalf,
kHalf,
kBool,
kBFloat16,
commonDtype,
"add_out_op2_sparse_csr",
[&valuesBuffer,
&resultBuffer,
&alpha,
&src_compressed_indices,
&src_plain_indices,
&src_layout]() {
AT_DISPATCH_INDEX_TYPES(
src_compressed_indices.scalar_type(),
"csr_add_out_crow_indices",
[&valuesBuffer,
&resultBuffer,
&alpha,
&src_compressed_indices,
&src_plain_indices,
&src_layout]() {
auto batch_count =
resultBuffer.dim() > 2 ? resultBuffer.size(-3) : 1;
auto values_accessor = valuesBuffer.accessor<scalar_t, 2>();
scalar_t* out_ptr = resultBuffer.data_ptr<scalar_t>();
scalar_t cast_value = alpha.to<scalar_t>();
auto compressed_indices_accessor =
src_compressed_indices.accessor<index_t, 2>();
auto plain_indices_accessor =
src_plain_indices.accessor<index_t, 2>();
auto out_strides = resultBuffer.strides();
auto const out_stride_batch = out_strides[0];
auto const out_stride_compressed =
AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
src_layout,
"add_out_dense_sparse_compressed_cpu",
[&out_strides] { return out_strides[1]; },
[&out_strides] { return out_strides[2]; });
auto const out_stride_plain =
AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
src_layout,
"add_out_dense_sparse_compressed_cpu",
[&out_strides] { return out_strides[2]; },
[&out_strides] { return out_strides[1]; });
for (const auto batch_idx : c10::irange(batch_count)) {
for (const auto i_compressed :
c10::irange(src_compressed_indices.size(-1) - 1)) {
index_t start_index =
compressed_indices_accessor[batch_idx][i_compressed];
index_t end_index =
compressed_indices_accessor[batch_idx][i_compressed + 1];
for (const auto i : c10::irange(start_index, end_index)) {
auto i_plain = plain_indices_accessor[batch_idx][i];
auto index = batch_idx * out_stride_batch +
i_compressed * out_stride_compressed +
i_plain * out_stride_plain;
out_ptr[index] +=
cast_value * values_accessor[batch_idx][i];
}
}
}
});
});
if (out.scalar_type() != commonDtype) {
out.copy_(resultBuffer);
}
}
// CPU implementation of out = self + alpha * other where at least one
// operand is sparse compressed.
//  - strided self, sparse other: handled by the dense + alpha * sparse kernel.
//  - sparse self, strided other: the same kernel is reused with the
//    operands swapped (see the comment in that branch for how alpha is
//    handled).
//  - sparse self, sparse other: sizes must match; trivial cases are handled
//    by only_sparse_compressed_add_trivial_cases, everything else by
//    sparse::impl::cpu::add_out_sparse_csr.
// Returns `out`.
Tensor& add_out_sparse_compressed_cpu(
    const Tensor& self,
    const SparseCsrTensor& other,
    const Scalar& alpha,
    SparseCsrTensor& out) {
  if (self.layout() == kStrided) {
    add_out_dense_sparse_compressed_cpu(out, self, other, alpha);
  } else if (other.layout() == kStrided) {
    // The kernel computes `dense + alpha * sparse`, i.e. it always scales
    // its sparse argument. Here the dense operand (`other`) is the one that
    // has to be scaled, so simply swapping the arguments is only valid for
    // alpha == 1. For any other alpha the dense operand is scaled up front
    // and the sparse one is added with a unit factor.
    if (alpha.equal(1)) {
      add_out_dense_sparse_compressed_cpu(out, other, self, alpha);
    } else {
      add_out_dense_sparse_compressed_cpu(
          out, other.mul(alpha), self, /*alpha=*/1);
    }
  } else {
    TORCH_CHECK(
        self.sizes().equals(other.sizes()),
        "torch.add: Expected input tensors to have the same shape, but got tensor `self` with shape ",
        self.sizes(),
        " and tensor `other` with shape ",
        other.sizes());

    if (only_sparse_compressed_add_trivial_cases(self, other, alpha, out)) {
      return out;
    }

    at::native::resize_as_sparse_compressed_(out, self);
    sparse::impl::cpu::add_out_sparse_csr(self, other, alpha, out);
  }
  return out;
}
/*
Reductions on sparse CSR tensors using masked semantics.
- A CSR tensor is a 2D tensor that is specified by a 3-tuple
(crow_indices, col_indices, values).
- To support a reduction operator on a CSR tensor, define:
template <typename scalar_t>
struct Reduction...Op {
inline scalar_t operator()(const scalar_t& a, const scalar_t& b) const {
return a ... b;
}
inline scalar_t identity() const { return ...; }
};
Tensor _sparse_csr_..._cpu(const Tensor& input, IntArrayRef dims_to_sum, bool keepdim, std::optional<ScalarType> dtype) {
...
result = reduce_sparse_csr_cpu_template<scalar_t>(input_, dims_to_sum, keepdim, Reduction...Op<scalar_t>());
...
return result;
}