Skip to content

Commit

Permalink
fix convolution constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
yosupo06 committed Mar 26, 2023
1 parent 47b2ec4 commit b76ba51
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions atcoder/convolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ template <class mint, internal::is_static_modint_t<mint>* = nullptr>
std::vector<mint> convolution_fft(std::vector<mint> a, std::vector<mint> b) {
int n = int(a.size()), m = int(b.size());
int z = (int)internal::bit_ceil((unsigned int)(n + m - 1));
assert(mint::mod() % z == 1);
a.resize(z);
internal::butterfly(a);
b.resize(z);
Expand All @@ -220,6 +219,10 @@ template <class mint, internal::is_static_modint_t<mint>* = nullptr>
std::vector<mint> convolution(std::vector<mint>&& a, std::vector<mint>&& b) {
int n = int(a.size()), m = int(b.size());
if (!n || !m) return {};

int z = (int)internal::bit_ceil((unsigned int)(n + m - 1));
assert(mint::mod() % z == 1);

if (std::min(n, m) <= 60) return convolution_naive(a, b);
return internal::convolution_fft(a, b);
}
Expand All @@ -229,6 +232,10 @@ std::vector<mint> convolution(const std::vector<mint>& a,
const std::vector<mint>& b) {
int n = int(a.size()), m = int(b.size());
if (!n || !m) return {};

int z = (int)internal::bit_ceil((unsigned int)(n + m - 1));
assert(mint::mod() % z == 1);

if (std::min(n, m) <= 60) return convolution_naive(a, b);
return internal::convolution_fft(a, b);
}
Expand All @@ -241,6 +248,10 @@ std::vector<T> convolution(const std::vector<T>& a, const std::vector<T>& b) {
if (!n || !m) return {};

using mint = static_modint<mod>;

int z = (int)internal::bit_ceil((unsigned int)(n + m - 1));
assert(mint::mod() % z == 1);

std::vector<mint> a2(n), b2(m);
for (int i = 0; i < n; i++) {
a2[i] = mint(a[i]);
Expand Down Expand Up @@ -280,7 +291,7 @@ std::vector<long long> convolution_ll(const std::vector<long long>& a,
static_assert(MOD1 % (1ull << MAX_AB_BIT) == 1, "MOD1 isn't enough to support an array length of 2^24.");
static_assert(MOD2 % (1ull << MAX_AB_BIT) == 1, "MOD2 isn't enough to support an array length of 2^24.");
static_assert(MOD3 % (1ull << MAX_AB_BIT) == 1, "MOD3 isn't enough to support an array length of 2^24.");
assert(a.size() + b.size() - 1 <= (1ull << MAX_AB_BIT));
assert(n + m - 1 <= (1 << MAX_AB_BIT));

auto c1 = convolution<MOD1>(a, b);
auto c2 = convolution<MOD2>(a, b);
Expand Down

0 comments on commit b76ba51

Please sign in to comment.