Skip to content

Commit

Permalink
ntt/ntt.cuh: make sense of |aux_data| in LDE_aux and de-duplicate LDE().
Browse files Browse the repository at this point in the history
  • Loading branch information
dot-asm committed Dec 6, 2023
1 parent f5067ce commit adfb183
Showing 1 changed file with 20 additions and 45 deletions.
65 changes: 20 additions & 45 deletions ntt/ntt.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -164,47 +164,6 @@ public:
return RustError{cudaSuccess};
}

static RustError LDE(const gpu_t& gpu, fr_t* inout,
uint32_t lg_domain_size, uint32_t lg_blowup)
{
try {
gpu.select();

size_t domain_size = (size_t)1 << lg_domain_size;
size_t ext_domain_size = domain_size << lg_blowup;
dev_ptr_t<fr_t> d_ext_domain{ext_domain_size, gpu};
fr_t* d_domain = &d_ext_domain[ext_domain_size - domain_size];

gpu.HtoD(&d_domain[0], inout, domain_size);

NTT_internal(&d_domain[0], lg_domain_size,
InputOutputOrder::NR, Direction::inverse,
Type::standard, gpu);

const auto gen_powers =
NTTParameters::all()[gpu.id()]->partial_group_gen_powers;

LDE_launch(gpu, &d_ext_domain[0], &d_domain[0],
gen_powers, lg_domain_size, lg_blowup);

NTT_internal(&d_ext_domain[0], lg_domain_size + lg_blowup,
InputOutputOrder::RN, Direction::forward,
Type::standard, gpu);

gpu.DtoH(inout, &d_ext_domain[0], ext_domain_size);
gpu.sync();
} catch (const cuda_error& e) {
gpu.sync();
#ifdef TAKE_RESPONSIBILITY_FOR_ERROR_MESSAGE
return RustError{e.code(), e.what()};
#else
return RustError{e.code()};
#endif
}

return RustError{cudaSuccess};
}

protected:
static void LDE_launch(stream_t& stream,
fr_t* ext_domain_data, fr_t* domain_data,
Expand Down Expand Up @@ -246,17 +205,20 @@ protected:

public:
static RustError LDE_aux(const gpu_t& gpu, fr_t* inout,
uint32_t lg_domain_size, uint32_t lg_blowup)
uint32_t lg_domain_size, uint32_t lg_blowup,
fr_t *aux_out = nullptr)
{
try {
size_t domain_size = (size_t)1 << lg_domain_size;
size_t ext_domain_size = domain_size << lg_blowup;
size_t aux_size = aux_out != nullptr ? domain_size : 0;
// The 2nd to last 'domain_size' chunk will hold the original data
// The last chunk will get the bit reversed iNTT data
dev_ptr_t<fr_t> d_inout{ext_domain_size + domain_size, gpu}; // + domain_size for aux buffer
dev_ptr_t<fr_t> d_inout{ext_domain_size + aux_size, gpu}; // + domain_size for aux buffer
fr_t* aux_data = &d_inout[ext_domain_size];
fr_t* domain_data = &d_inout[ext_domain_size - domain_size]; // aligned to the end
fr_t* ext_domain_data = &d_inout[0];

gpu.HtoD(domain_data, inout, domain_size);

NTT_internal(domain_data, lg_domain_size,
Expand All @@ -266,7 +228,12 @@ public:
const auto gen_powers =
NTTParameters::all()[gpu.id()]->partial_group_gen_powers;

bit_rev(aux_data, domain_data, lg_domain_size, gpu);
event_t sync_event;

if (aux_out != nullptr) {
bit_rev(aux_data, domain_data, lg_domain_size, gpu);
sync_event.record(gpu);
}

LDE_launch(gpu, ext_domain_data, domain_data, gen_powers,
lg_domain_size, lg_blowup);
Expand All @@ -276,7 +243,11 @@ public:
InputOutputOrder::RN, Direction::forward,
Type::standard, gpu);

gpu.DtoH(inout, ext_domain_data, domain_size << lg_blowup);
if (aux_out != nullptr) {
sync_event.wait(gpu[0]);
gpu[0].DtoH(aux_out, aux_data, aux_size);
}
gpu.DtoH(inout, ext_domain_data, ext_domain_size);
gpu.sync();
} catch (const cuda_error& e) {
gpu.sync();
Expand All @@ -290,6 +261,10 @@ public:
return RustError{cudaSuccess};
}

static RustError LDE(const gpu_t& gpu, fr_t* inout,
uint32_t lg_domain_size, uint32_t lg_blowup)
{ return LDE_aux(gpu, inout, lg_domain_size, lg_blowup); }

static void Base_dev_ptr(stream_t& stream, fr_t* d_inout,
uint32_t lg_domain_size, InputOutputOrder order,
Direction direction, Type type)
Expand Down

0 comments on commit adfb183

Please sign in to comment.