forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 2
/
THCStorage.cpp
70 lines (55 loc) · 1.94 KB
/
THCStorage.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include <THC/THCStorage.hpp>
#include <THC/THCGeneral.h>
#include <TH/THHalf.h>
#include <new>
#include <THC/generic/THCStorage.cpp>
#include <THC/THCGenerateAllTypes.h>
#include <THC/generic/THCStorage.cpp>
#include <THC/THCGenerateComplexTypes.h>
#include <THC/generic/THCStorage.cpp>
#include <THC/THCGenerateBoolType.h>
#include <THC/generic/THCStorage.cpp>
#include <THC/THCGenerateBFloat16Type.h>
#include <c10/util/intrusive_ptr.h>
/**
 * Resize the CUDA storage `self` so that it holds `size_bytes` bytes.
 *
 * - `size_bytes == 0`: the storage's data pointer is replaced by a null
 *   DataPtr tagged with the current CUDA device and its byte count is set
 *   to zero.
 * - `size_bytes > 0`: a new buffer is requested from the storage's own
 *   allocator. If the storage already holds data, the first
 *   min(old nbytes, size_bytes) bytes are copied into the new buffer with
 *   cudaMemcpyAsync on the current CUDA stream. The new buffer then replaces
 *   the old one.
 *
 * Errors are reported through THArgCheck (negative size), THAssert (storage
 * has no allocator), THCudaCheck (CUDA runtime failures) and THError (storage
 * is not resizable).
 *
 * @param state      THC state, forwarded to the peer-to-peer access helper.
 * @param self       Storage to resize; modified in place.
 * @param size_bytes Requested size in bytes; must be non-negative.
 */
void THCStorage_resizeBytes(
    THCState* state,
    THCStorage* self,
    ptrdiff_t size_bytes) {
  THArgCheck(size_bytes >= 0, 2, "invalid size");
  THAssert(self->allocator() != nullptr);

  int current_device;
  THCudaCheck(cudaGetDevice(&current_device));

  if (!self->resizable()) {
    THError("Trying to resize storage that is not resizable");
  }

  // Shrinking to nothing: drop the buffer and keep only a device-tagged null.
  if (size_bytes == 0) {
    const at::Device cuda_device(at::DeviceType::CUDA, current_device);
    self->set_data_ptr(at::DataPtr(nullptr, cuda_device));
    self->set_nbytes(0);
    return;
  }

  at::DataPtr replacement = self->allocator()->allocate(size_bytes);

  if (self->data_ptr()) {
    // The old buffer may live on a different device than the current one, so
    // make sure peer-to-peer access is enabled before the copy.
    const int storage_device = THCStorage_getDevice(state, self);
    THCState_getPeerToPeerAccess(state, current_device, storage_device);

    // Preserve as much of the old contents as fits in the new buffer.
    const auto bytes_to_copy = THMin(self->nbytes(), size_bytes);
    const auto stream = c10::cuda::getCurrentCUDAStream();
    THCudaCheck(cudaMemcpyAsync(
        replacement.get(),
        self->data(),
        bytes_to_copy,
        cudaMemcpyDeviceToDevice,
        stream));
  }

  // Install the new buffer; the previous DataPtr is overwritten by this call.
  self->set_data_ptr(std::move(replacement));
  self->set_nbytes(size_bytes);
}
/**
 * Return the index of the device that `storage` reports via `device()`.
 *
 * @param state   Not used by this function.
 * @param storage Storage whose device index is queried.
 * @return The storage's device index, converted to `int`.
 */
int THCStorage_getDevice(THCState* state, const THCStorage* storage) {
  const auto storage_device = storage->device();
  return storage_device.index();
}
/**
 * Create a new storage backed by the CUDA caching allocator.
 *
 * The storage is built as an intrusive_ptr and then released, so the caller
 * receives a raw pointer rather than a smart pointer.
 *
 * NOTE(review): per c10::StorageImpl's byte-size constructor the arguments
 * are expected to be (tag, size in bytes = 0, allocator, resizable = true);
 * confirm against the StorageImpl header.
 *
 * @param state Not used by this function.
 * @return Raw pointer to the newly created storage.
 */
THCStorage* THCStorage_new(THCState* state) {
  auto impl = c10::make_intrusive<at::StorageImpl>(
      c10::StorageImpl::use_byte_size_t(),
      0,
      c10::cuda::CUDACachingAllocator::get(),
      true);
  // Detach from the smart pointer; callers work with the raw pointer.
  THStorage* raw_storage = impl.release();
  return raw_storage;
}