diff --git a/atcoder/convolution.hpp b/atcoder/convolution.hpp index 7b27f81..059b5c6 100644 --- a/atcoder/convolution.hpp +++ b/atcoder/convolution.hpp @@ -199,7 +199,6 @@ template * = nullptr> std::vector convolution_fft(std::vector a, std::vector b) { int n = int(a.size()), m = int(b.size()); int z = (int)internal::bit_ceil((unsigned int)(n + m - 1)); - assert(mint::mod() % z == 1); a.resize(z); internal::butterfly(a); b.resize(z); @@ -220,6 +219,10 @@ template * = nullptr> std::vector convolution(std::vector&& a, std::vector&& b) { int n = int(a.size()), m = int(b.size()); if (!n || !m) return {}; + + int z = (int)internal::bit_ceil((unsigned int)(n + m - 1)); + assert(mint::mod() % z == 1); + if (std::min(n, m) <= 60) return convolution_naive(a, b); return internal::convolution_fft(a, b); } @@ -229,6 +232,10 @@ std::vector convolution(const std::vector& a, const std::vector& b) { int n = int(a.size()), m = int(b.size()); if (!n || !m) return {}; + + int z = (int)internal::bit_ceil((unsigned int)(n + m - 1)); + assert(mint::mod() % z == 1); + if (std::min(n, m) <= 60) return convolution_naive(a, b); return internal::convolution_fft(a, b); } @@ -241,6 +248,10 @@ std::vector convolution(const std::vector& a, const std::vector& b) { if (!n || !m) return {}; using mint = static_modint; + + int z = (int)internal::bit_ceil((unsigned int)(n + m - 1)); + assert(mint::mod() % z == 1); + std::vector a2(n), b2(m); for (int i = 0; i < n; i++) { a2[i] = mint(a[i]); @@ -280,7 +291,7 @@ std::vector convolution_ll(const std::vector& a, static_assert(MOD1 % (1ull << MAX_AB_BIT) == 1, "MOD1 isn't enough to support an array length of 2^24."); static_assert(MOD2 % (1ull << MAX_AB_BIT) == 1, "MOD2 isn't enough to support an array length of 2^24."); static_assert(MOD3 % (1ull << MAX_AB_BIT) == 1, "MOD3 isn't enough to support an array length of 2^24."); - assert(a.size() + b.size() - 1 <= (1ull << MAX_AB_BIT)); + assert(n + m - 1 <= (1 << MAX_AB_BIT)); auto c1 = convolution(a, b); auto c2 = convolution(a, b);