diff --git a/.ci/pnnx.yml b/.ci/pnnx.yml index a08379ff8dbc..9fc85ac9ecc5 100644 --- a/.ci/pnnx.yml +++ b/.ci/pnnx.yml @@ -31,39 +31,51 @@ jobs: include: - torch-version: 1.8.1 torchvision-version: 0.9.1 + torchaudio-version: 0.8.1 - torch-version: 1.9.1 torchvision-version: 0.10.1 + torchaudio-version: 0.9.1 - torch-version: 1.10.0 torchvision-version: 0.11.1 + torchaudio-version: '0.10.0+cpu' - torch-version: 1.11.0 torchvision-version: 0.12.0 + torchaudio-version: '0.11.0+cpu' - torch-version: 1.12.0 torchvision-version: 0.13.0 + torchaudio-version: '0.12.0+cpu' - torch-version: 1.13.0 torchvision-version: 0.14.0 + torchaudio-version: '0.13.0+cpu' - torch-version: 2.0.0 torchvision-version: 0.15.1 + torchaudio-version: '2.0.0+cpu' - torch-version: 2.1.0 torchvision-version: 0.16.0 + torchaudio-version: '2.1.0+cpu' - torch-version: 2.2.1 torchvision-version: 0.17.1 + torchaudio-version: '2.2.1+cpu' - torch-version: 2.3.0 torchvision-version: 0.18.0 + torchaudio-version: '2.3.0+cpu' - torch-version: 2.4.0 torchvision-version: 0.19.0 + torchaudio-version: '2.4.0+cpu' - torch-version: 2.5.0 torchvision-version: 0.20.0 + torchaudio-version: '2.5.0+cpu' runs-on: pool-name: docker @@ -169,7 +181,7 @@ jobs: - name: setup-pytorch run: | export PYTHONUSERBASE=${{ci.workspace}}/torch-${{matrix.torch-version}} - pip3 install --user torch==${{matrix.torch-version}}+cpu torchvision==${{matrix.torchvision-version}}+cpu --index-url https://download.pytorch.org/whl/cpu + pip3 install --user torch==${{matrix.torch-version}}+cpu torchvision==${{matrix.torchvision-version}}+cpu torchaudio==${{matrix.torchaudio-version}} --index-url https://download.pytorch.org/whl/cpu pip3 install --user onnx pip3 install --user onnxscript diff --git a/docs/developer-guide/operators.md b/docs/developer-guide/operators.md index 4c82fd472c10..10fe1f03f0f6 100644 --- a/docs/developer-guide/operators.md +++ b/docs/developer-guide/operators.md @@ -46,6 +46,7 @@ * [Input](#input) * 
[InstanceNorm](#instancenorm) * [Interp](#interp) +* [InverseSpectrogram](#inversespectrogram) * [LayerNorm](#layernorm) * [Log](#log) * [LRN](#lrn) @@ -81,6 +82,7 @@ * [Slice](#slice) * [Softmax](#softmax) * [Softplus](#softplus) +* [Spectrogram](#spectrogram) * [Split](#split) * [Swish](#swish) * [TanH](#tanh) @@ -1141,6 +1143,30 @@ Resize type: - 2 = Bilinear - 3 = Bicubic +# InverseSpectrogram +``` +x1 = x as complex +x1 = x1 * sqrt(norm) if normalized +y = istft(x1) +y1 = unpad(y) if center + +if returns == 0 return y1 as complex +if returns == 1 return y1 real +if returns == 2 return y1 imag +``` + +* one_blob_only + +| param id | name | type | default | description | +| --------- | ------------- | ----- | --------- | ----------------- | +| 0 | n_fft | int | 0 | | +| 1 | returns | int | 1 | | +| 2 | hoplen | int | n_fft / 4 | | +| 3 | winlen | int | n_fft | | +| 4 | window_type | int | 0 | 0=ones 1=hann 2=hamming | +| 5 | center | int | 1 | | +| 7 | normalized | int | 0 | 0=no 1=n_fft 2=window-l2-energy | + # LayerNorm ``` split x along outmost axis into part x0, x1 ... @@ -1829,6 +1855,31 @@ y = log(exp(x) + 1) * one_blob_only * support_inplace +# Spectrogram +``` +x1 = pad(x) if center +y = stft(x1) +y = y / sqrt(norm) if normalized + +if power == 0 return y as real +if power == 1 return magnitude +if power == 2 return square of magnitude +``` + +* one_blob_only + +| param id | name | type | default | description | +| --------- | ------------- | ----- | --------- | ----------------- | +| 0 | n_fft | int | 0 | | +| 1 | power | int | 0 | | +| 2 | hoplen | int | n_fft / 4 | | +| 3 | winlen | int | n_fft | | +| 4 | window_type | int | 0 | 0=ones 1=hann 2=hamming | +| 5 | center | int | 1 | | +| 6 | pad_type | int | 2 | 0=CONSTANT 1=REPLICATE 2=REFLECT | +| 7 | normalized | int | 0 | 0=no 1=n_fft 2=window-l2-energy | +| 8 | onesided | int | 1 | | + # Split ``` y0, y1 ... 
= x
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 4aa952e3d0f0..21d5ae5eae31 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -167,6 +167,8 @@ ncnn_add_layer(Diag)
 ncnn_add_layer(CELU)
 ncnn_add_layer(Shrink)
 ncnn_add_layer(RMSNorm)
+ncnn_add_layer(Spectrogram)
+ncnn_add_layer(InverseSpectrogram)
 
 if(NCNN_VULKAN)
     ncnn_add_shader(${CMAKE_CURRENT_SOURCE_DIR}/convert_ycbcr.comp)
diff --git a/src/layer/inversespectrogram.cpp b/src/layer/inversespectrogram.cpp
new file mode 100644
index 000000000000..08aa0f86d10e
--- /dev/null
+++ b/src/layer/inversespectrogram.cpp
@@ -0,0 +1,270 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "inversespectrogram.h"
+
+namespace ncnn {
+
+InverseSpectrogram::InverseSpectrogram()
+{
+    one_blob_only = true;
+    support_inplace = false;
+}
+
+// Read the layer parameters and pre-compute the analysis window, zero padded to n_fft taps.
+// Returns 0 on success, -1 on parameters that cannot describe a window, -100 on allocation failure.
+int InverseSpectrogram::load_param(const ParamDict& pd)
+{
+    n_fft = pd.get(0, 0);
+    // NOTE(review): docs/developer-guide/operators.md lists 1 as the default of "returns",
+    // this code falls back to 0 (complex output) - confirm which one is intended
+    returns = pd.get(1, 0);
+    hoplen = pd.get(2, n_fft / 4);
+    winlen = pd.get(3, n_fft);
+    window_type = pd.get(4, 0);
+    center = pd.get(5, 1);
+    normalized = pd.get(7, 0);
+
+    // winlen taps are written into an n_fft sized buffer, anything else would overrun it
+    if (n_fft <= 0 || winlen < 0 || winlen > n_fft)
+        return -1;
+
+    // only 0=ones 1=hann 2=hamming can be generated, other types would leave the taps unset
+    if (window_type < 0 || window_type > 2)
+        return -1;
+
+    // generate window, the extra slot at [n_fft] keeps the window l2 norm
+    window_data.create(normalized == 2 ? n_fft + 1 : n_fft);
+    if (window_data.empty())
+        return -100;
+
+    {
+        float* p = window_data;
+
+        // leading zeros, the window sits in the middle of the n_fft frame
+        for (int i = 0; i < (n_fft - winlen) / 2; i++)
+        {
+            *p++ = 0.f;
+        }
+        if (window_type == 0)
+        {
+            // all ones
+            for (int i = 0; i < winlen; i++)
+            {
+                *p++ = 1.f;
+            }
+        }
+        if (window_type == 1)
+        {
+            // hann window
+            for (int i = 0; i < winlen; i++)
+            {
+                *p++ = 0.5f * (1 - cosf(2 * 3.14159265358979323846 * i / winlen));
+            }
+        }
+        if (window_type == 2)
+        {
+            // hamming window
+            for (int i = 0; i < winlen; i++)
+            {
+                *p++ = 0.54f - 0.46f * cosf(2 * 3.14159265358979323846 * i / winlen);
+            }
+        }
+        // trailing zeros
+        for (int i = 0; i < n_fft - winlen - (n_fft - winlen) / 2; i++)
+        {
+            *p++ = 0.f;
+        }
+
+        // pre-calculated window norm factor
+        if (normalized == 2)
+        {
+            float sqsum = 0.f;
+            for (int i = 0; i < n_fft; i++)
+            {
+                sqsum += window_data[i] * window_data[i];
+            }
+            window_data[n_fft] = sqrt(sqsum);
+        }
+    }
+
+    return 0;
+}
+
+// Inverse short-time fourier transform by overlap-add of the windowed inverse dft of each frame.
+// bottom_blob is read as w=2 (re, im), h=frames, c=frequency bins.
+// top_blob is w=2 (re, im), h=outsize when returns == 0, otherwise a 1-d blob of outsize samples.
+// Returns 0 on success, -1 on an unexpected bin count, -100 on allocation failure.
+int InverseSpectrogram::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
+{
+    // https://github.com/librosa/librosa/blob/main/librosa/core/spectrum.py#L630
+
+    // TODO custom window
+    // TODO output length
+
+    const int frames = bottom_blob.h;
+    const int freqs = bottom_blob.c;
+
+    const int onesided = freqs == n_fft / 2 + 1 ? 1 : 0;
+
+    // every bin up to n_fft is read below, reject blobs that do not carry them
+    if (onesided == 0 && freqs != n_fft)
+        return -1;
+
+    // with center the n_fft / 2 border on both ends is cut away
+    const int outsize = center ? (frames - 1) * hoplen + (n_fft - n_fft / 2 * 2) : (frames - 1) * hoplen + n_fft;
+
+    const size_t elemsize = bottom_blob.elemsize;
+
+    if (returns == 0)
+    {
+        top_blob.create(2, outsize, elemsize, opt.blob_allocator);
+    }
+    else
+    {
+        top_blob.create(outsize, elemsize, opt.blob_allocator);
+    }
+    if (top_blob.empty())
+        return -100;
+
+    Mat window_sumsquare(outsize + n_fft, elemsize, opt.workspace_allocator);
+    if (window_sumsquare.empty())
+        return -100;
+
+    // full complex spectrum of one frame, allocated once and overwritten by every frame
+    Mat sp(2, n_fft, (size_t)4u, opt.workspace_allocator);
+    if (sp.empty())
+        return -100;
+
+    top_blob.fill(0.f);
+    window_sumsquare.fill(0.f);
+
+    for (int j = 0; j < frames; j++)
+    {
+        // collect complex
+        if (onesided == 1)
+        {
+            for (int k = 0; k < n_fft / 2 + 1; k++)
+            {
+                sp.row(k)[0] = bottom_blob.channel(k).row(j)[0];
+                sp.row(k)[1] = bottom_blob.channel(k).row(j)[1];
+            }
+            // the missing upper half is the complex conjugate of the mirrored lower half
+            for (int k = n_fft / 2 + 1; k < n_fft; k++)
+            {
+                sp.row(k)[0] = bottom_blob.channel(n_fft - k).row(j)[0];
+                sp.row(k)[1] = -bottom_blob.channel(n_fft - k).row(j)[1];
+            }
+        }
+        else
+        {
+            for (int k = 0; k < n_fft; k++)
+            {
+                sp.row(k)[0] = bottom_blob.channel(k).row(j)[0];
+                sp.row(k)[1] = bottom_blob.channel(k).row(j)[1];
+            }
+        }
+
+        // scale the spectrum back up by the selected norm factor
+        if (normalized == 1)
+        {
+            float norm = sqrt(n_fft);
+            for (int i = 0; i < 2 * n_fft; i++)
+            {
+                sp[i] *= norm;
+            }
+        }
+        if (normalized == 2)
+        {
+            float norm = window_data[n_fft];
+            for (int i = 0; i < 2 * n_fft; i++)
+            {
+                sp[i] *= norm;
+            }
+        }
+
+        // each i lands on its own output sample, so the accumulation below is race free
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int i = 0; i < n_fft; i++)
+        {
+            // inverse dft
+            float re = 0.f;
+            float im = 0.f;
+            for (int k = 0; k < n_fft; k++)
+            {
+                double angle = 2 * 3.14159265358979323846 * i * k / n_fft;
+
+                re += sp.row(k)[0] * cosf(angle) - sp.row(k)[1] * sinf(angle);
+                im += sp.row(k)[0] * sinf(angle) + sp.row(k)[1] * cosf(angle);
+            }
+
+            re /= n_fft;
+            im /= n_fft;
+
+            // apply window
+            re *= window_data[i];
+            im *= window_data[i];
+
+            int output_index = j * hoplen + i;
+            if (center)
+            {
+                output_index -= n_fft / 2;
+            }
+            if (output_index >= 0 && output_index < outsize)
+            {
+                // square window
+                window_sumsquare[output_index] += window_data[i] * window_data[i];
+
+                if (returns == 0)
+                {
+                    top_blob.row(output_index)[0] += re;
+                    top_blob.row(output_index)[1] += im;
+                }
+                if (returns == 1)
+                {
+                    top_blob[output_index] += re;
+                }
+                if (returns == 2)
+                {
+                    top_blob[output_index] += im;
+                }
+            }
+        }
+    }
+
+    // square window norm
+    if (returns == 0)
+    {
+        for (int i = 0; i < outsize; i++)
+        {
+            if (window_sumsquare[i] != 0.f)
+            {
+                top_blob.row(i)[0] /= window_sumsquare[i];
+                top_blob.row(i)[1] /= window_sumsquare[i];
+            }
+        }
+    }
+    else
+    {
+        for (int i = 0; i < outsize; i++)
+        {
+            if (window_sumsquare[i] != 0.f)
+                top_blob[i] /= window_sumsquare[i];
+        }
+    }
+
+    return 0;
+}
+
+} // namespace ncnn
diff --git a/src/layer/inversespectrogram.h b/src/layer/inversespectrogram.h
new file mode 100644
index 000000000000..969868d1540a
--- /dev/null
+++ b/src/layer/inversespectrogram.h
@@ -0,0 +1,45 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#ifndef LAYER_INVERSESPECTROGRAM_H
+#define LAYER_INVERSESPECTROGRAM_H
+
+#include "layer.h"
+
+namespace ncnn {
+
+class InverseSpectrogram : public Layer
+{
+public:
+    InverseSpectrogram();
+
+    virtual int load_param(const ParamDict& pd);
+
+    virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const;
+
+public:
+    int n_fft;       // dft length, param id 0
+    int returns;     // 0=complex 1=real 2=imag
+    int hoplen;      // step between the start of consecutive frames
+    int winlen;      // generated window taps, zero padded up to n_fft
+    int window_type; // 0=ones 1=hann 2=hamming
+    int center;      // non-zero shortens the output by the n_fft / 2 border
+    int normalized;  // 0=disabled 1=sqrt(n_fft) 2=window-l2-energy
+
+    Mat window_data; // n_fft taps, plus the window l2 norm at [n_fft] when normalized == 2
+};
+
+} // namespace ncnn
+
+#endif // LAYER_INVERSESPECTROGRAM_H
diff --git a/src/layer/spectrogram.cpp b/src/layer/spectrogram.cpp
new file mode 100644
index 000000000000..f616131579b6
--- /dev/null
+++ b/src/layer/spectrogram.cpp
@@ -0,0 +1,252 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "spectrogram.h"
+
+namespace ncnn {
+
+Spectrogram::Spectrogram()
+{
+    one_blob_only = true;
+    support_inplace = false;
+}
+
+// Read the layer parameters and pre-compute the analysis window, zero padded to n_fft taps.
+// Returns 0 on success, -1 on parameters that cannot be framed, -100 on allocation failure.
+int Spectrogram::load_param(const ParamDict& pd)
+{
+    n_fft = pd.get(0, 0);
+    power = pd.get(1, 0);
+    hoplen = pd.get(2, n_fft / 4);
+    winlen = pd.get(3, n_fft);
+    window_type = pd.get(4, 0);
+    center = pd.get(5, 1);
+    pad_type = pd.get(6, 2);
+    normalized = pd.get(7, 0);
+    onesided = pd.get(8, 1);
+
+    // winlen taps are written into an n_fft sized buffer, anything else would overrun it
+    if (n_fft <= 0 || winlen < 0 || winlen > n_fft)
+        return -1;
+
+    // hoplen divides the frame count in forward, its n_fft / 4 default is 0 for n_fft < 4
+    if (hoplen <= 0)
+        return -1;
+
+    // only 0=ones 1=hann 2=hamming can be generated, other types would leave the taps unset
+    if (window_type < 0 || window_type > 2)
+        return -1;
+
+    // generate window, the extra slot at [n_fft] keeps the reciprocal window l2 norm
+    window_data.create(normalized == 2 ? n_fft + 1 : n_fft);
+    if (window_data.empty())
+        return -100;
+
+    {
+        float* p = window_data;
+
+        // leading zeros, the window sits in the middle of the n_fft frame
+        for (int i = 0; i < (n_fft - winlen) / 2; i++)
+        {
+            *p++ = 0.f;
+        }
+        if (window_type == 0)
+        {
+            // all ones
+            for (int i = 0; i < winlen; i++)
+            {
+                *p++ = 1.f;
+            }
+        }
+        if (window_type == 1)
+        {
+            // hann window
+            for (int i = 0; i < winlen; i++)
+            {
+                *p++ = 0.5f * (1 - cosf(2 * 3.14159265358979323846 * i / winlen));
+            }
+        }
+        if (window_type == 2)
+        {
+            // hamming window
+            for (int i = 0; i < winlen; i++)
+            {
+                *p++ = 0.54f - 0.46f * cosf(2 * 3.14159265358979323846 * i / winlen);
+            }
+        }
+        // trailing zeros
+        for (int i = 0; i < n_fft - winlen - (n_fft - winlen) / 2; i++)
+        {
+            *p++ = 0.f;
+        }
+
+        // pre-calculated window norm factor
+        if (normalized == 2)
+        {
+            float sqsum = 0.f;
+            for (int i = 0; i < n_fft; i++)
+            {
+                sqsum += window_data[i] * window_data[i];
+            }
+            window_data[n_fft] = 1.f / sqrt(sqsum);
+        }
+    }
+
+    return 0;
+}
+
+// Short-time fourier transform of a 1-d signal by direct dft of each windowed frame.
+// bottom_blob is read along w; with center == 1 it is first padded by n_fft / 2 on both ends.
+// top_blob is w=2 (re, im), h=frames, c=freqs when power == 0, otherwise w=frames, h=freqs.
+// Returns 0 on success, -1 when the signal is shorter than one frame, -100 on allocation failure.
+int Spectrogram::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
+{
+    // https://pytorch.org/audio/stable/generated/torchaudio.functional.spectrogram.html
+
+    // TODO custom window
+
+    Mat bottom_blob_bordered = bottom_blob;
+    if (center == 1)
+    {
+        Option opt_b = opt;
+        opt_b.blob_allocator = opt.workspace_allocator;
+        if (pad_type == 0)
+            copy_make_border(bottom_blob, bottom_blob_bordered, 0, 0, n_fft / 2, n_fft / 2, BORDER_CONSTANT, 0.f, opt_b);
+        if (pad_type == 1)
+            copy_make_border(bottom_blob, bottom_blob_bordered, 0, 0, n_fft / 2, n_fft / 2, BORDER_REPLICATE, 0.f, opt_b);
+        if (pad_type == 2)
+            copy_make_border(bottom_blob, bottom_blob_bordered, 0, 0, n_fft / 2, n_fft / 2, BORDER_REFLECT, 0.f, opt_b);
+        if (bottom_blob_bordered.empty())
+            return -100;
+    }
+
+    const int size = bottom_blob_bordered.w;
+
+    // the truncating division below still yields one frame for size < n_fft, which would be read past the end
+    if (size < n_fft)
+        return -1;
+
+    const int frames = (size - n_fft) / hoplen + 1;
+    const int freqs_onesided = n_fft / 2 + 1;
+    const int freqs = onesided ? freqs_onesided : n_fft;
+
+    const size_t elemsize = bottom_blob_bordered.elemsize;
+
+    if (power == 0)
+    {
+        top_blob.create(2, frames, freqs, elemsize, opt.blob_allocator);
+    }
+    else
+    {
+        top_blob.create(frames, freqs, elemsize, opt.blob_allocator);
+    }
+    if (top_blob.empty())
+        return -100;
+
+    // bins 0 .. n_fft / 2 are computed directly
+    #pragma omp parallel for num_threads(opt.num_threads)
+    for (int i = 0; i < freqs_onesided; i++)
+    {
+        const float* ptr = bottom_blob_bordered;
+        float* outptr = power == 0 ? top_blob.channel(i) : top_blob.row(i);
+
+        for (int j = 0; j < frames; j++)
+        {
+            float re = 0.f;
+            float im = 0.f;
+            for (int k = 0; k < n_fft; k++)
+            {
+                float v = ptr[k];
+
+                // apply window
+                v *= window_data[k];
+
+                // dft
+                double angle = 2 * 3.14159265358979323846 * i * k / n_fft;
+
+                re += v * cosf(angle); // + imag * sinf(angle);
+                im -= v * sinf(angle); // + imag * cosf(angle);
+            }
+
+            if (normalized == 1)
+            {
+                float norm = 1.f / sqrt(n_fft);
+                re *= norm;
+                im *= norm;
+            }
+            if (normalized == 2)
+            {
+                float norm = window_data[n_fft];
+                re *= norm;
+                im *= norm;
+            }
+
+            if (power == 0)
+            {
+                // complex as real
+                outptr[0] = re;
+                outptr[1] = im;
+                outptr += 2;
+            }
+            if (power == 1)
+            {
+                // magnitude
+                outptr[0] = sqrt(re * re + im * im);
+                outptr += 1;
+            }
+            if (power == 2)
+            {
+                // square of magnitude
+                outptr[0] = re * re + im * im;
+                outptr += 1;
+            }
+
+            ptr += hoplen;
+        }
+    }
+
+    if (!onesided)
+    {
+        // the remaining bins are filled from the mirrored bin n_fft - i instead of a second dft
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int i = freqs_onesided; i < n_fft; i++)
+        {
+            if (power == 0)
+            {
+                const float* ptr = top_blob.channel(n_fft - i);
+                float* outptr = top_blob.channel(i);
+
+                for (int j = 0; j < frames; j++)
+                {
+                    // complex as real
+                    outptr[0] = ptr[0];
+                    outptr[1] = -ptr[1];
+                    ptr += 2;
+                    outptr += 2;
+                }
+            }
+            else // if (power == 1 || power == 2)
+            {
+                const float* ptr = top_blob.row(n_fft - i);
+                float* outptr = top_blob.row(i);
+
+                memcpy(outptr, ptr, frames * sizeof(float));
+            }
+        }
+    }
+
+    return 0;
+}
+
+} // namespace ncnn
diff --git a/src/layer/spectrogram.h b/src/layer/spectrogram.h
new file mode 100644
index 000000000000..712dadafd18b
--- /dev/null
+++ b/src/layer/spectrogram.h
@@ -0,0 +1,47 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#ifndef LAYER_SPECTROGRAM_H
+#define LAYER_SPECTROGRAM_H
+
+#include "layer.h"
+
+namespace ncnn {
+
+class Spectrogram : public Layer
+{
+public:
+    Spectrogram();
+
+    virtual int load_param(const ParamDict& pd);
+
+    virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const;
+
+public:
+    int n_fft;       // dft length, param id 0
+    int power;       // 0=complex as (re, im) 1=magnitude 2=square of magnitude
+    int hoplen;      // step between the start of consecutive frames
+    int winlen;      // generated window taps, zero padded up to n_fft
+    int window_type; // 0=ones 1=hann 2=hamming
+    int center;      // 1 pads n_fft / 2 samples on both ends before framing
+    int pad_type;    // 0=CONSTANT 1=REPLICATE 2=REFLECT
+    int normalized;  // 0=disabled 1=sqrt(n_fft) 2=window-l2-energy
+    int onesided;    // non-zero keeps only the n_fft / 2 + 1 lower bins
+
+    Mat window_data; // n_fft taps, plus 1 / window l2 norm at [n_fft] when normalized == 2
+};
+
+} // namespace ncnn
+
+#endif // LAYER_SPECTROGRAM_H
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index e2ddc32a00dc..f55859e736ea 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -117,6 +117,7 @@ ncnn_add_layer_test(HardSwish)
 ncnn_add_layer_test(InnerProduct)
 ncnn_add_layer_test(InstanceNorm)
 ncnn_add_layer_test(Interp)
+ncnn_add_layer_test(InverseSpectrogram)
 ncnn_add_layer_test(LayerNorm)
 ncnn_add_layer_test(LRN)
 ncnn_add_layer_test(LSTM)
@@ -154,6 +155,7 @@ ncnn_add_layer_test(Sigmoid)
 ncnn_add_layer_test(Slice)
 ncnn_add_layer_test(Softmax)
 ncnn_add_layer_test(Softplus)
+ncnn_add_layer_test(Spectrogram)
 ncnn_add_layer_test(Squeeze)
 ncnn_add_layer_test(Swish)
 ncnn_add_layer_test(TanH)
diff --git a/tests/test_inversespectrogram.cpp b/tests/test_inversespectrogram.cpp
new file mode 100644
index 000000000000..59796efdf93a
--- /dev/null
+++ b/tests/test_inversespectrogram.cpp
@@ -0,0 +1,56 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License.
You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "testutil.h"
+
+static int test_inversespectrogram(int frames, int freqs, int n_fft, int returns, int hoplen, int winlen, int window_type, int center, int normalized)
+{
+    ncnn::Mat a = RandomMat(2, frames, freqs); // w=2 (re, im), h=frames, c=freqs
+
+    ncnn::ParamDict pd;
+    pd.set(0, n_fft);
+    pd.set(1, returns);
+    pd.set(2, hoplen);
+    pd.set(3, winlen);
+    pd.set(4, window_type);
+    pd.set(5, center);
+    pd.set(7, normalized); // id 6 is not read by InverseSpectrogram::load_param
+
+    std::vector<ncnn::Mat> weights(0); // the layer carries no weights
+
+    int ret = test_layer("InverseSpectrogram", pd, weights, a);
+    if (ret != 0)
+    {
+        fprintf(stderr, "test_inversespectrogram failed frames=%d freqs=%d n_fft=%d returns=%d hoplen=%d winlen=%d window_type=%d center=%d normalized=%d\n", frames, freqs, n_fft, returns, hoplen, winlen, window_type, center, normalized);
+    }
+
+    return ret;
+}
+
+static int test_inversespectrogram_0()
+{
+    return 0 // args: frames, freqs, n_fft, returns, hoplen, winlen, window_type, center, normalized
+           || test_inversespectrogram(17, 1, 1, 0, 1, 1, 0, 1, 0)
+           || test_inversespectrogram(39, 9, 17, 0, 7, 15, 0, 0, 1)
+           || test_inversespectrogram(128, 6, 10, 0, 2, 7, 1, 1, 1)
+           || test_inversespectrogram(255, 17, 17, 1, 14, 17, 2, 0, 0)
+           || test_inversespectrogram(124, 28, 55, 2, 12, 55, 1, 1, 2);
+}
+
+int main()
+{
+    SRAND(7767517);
+
+    return test_inversespectrogram_0();
+}
diff --git a/tests/test_spectrogram.cpp b/tests/test_spectrogram.cpp
new file mode 100644
index 000000000000..b58ddd3cfba7
--- /dev/null
+++ b/tests/test_spectrogram.cpp
@@ -0,0 +1,58 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "testutil.h"
+
+static int test_spectrogram(int size, int n_fft, int power, int hoplen, int winlen, int window_type, int center, int pad_type, int normalized, int onesided)
+{
+    ncnn::Mat a = RandomMat(size); // 1-d input signal of size samples
+
+    ncnn::ParamDict pd;
+    pd.set(0, n_fft);
+    pd.set(1, power);
+    pd.set(2, hoplen);
+    pd.set(3, winlen);
+    pd.set(4, window_type);
+    pd.set(5, center);
+    pd.set(6, pad_type);
+    pd.set(7, normalized);
+    pd.set(8, onesided);
+
+    std::vector<ncnn::Mat> weights(0); // the layer carries no weights
+
+    int ret = test_layer("Spectrogram", pd, weights, a);
+    if (ret != 0)
+    {
+        fprintf(stderr, "test_spectrogram failed size=%d n_fft=%d power=%d hoplen=%d winlen=%d window_type=%d center=%d pad_type=%d normalized=%d onesided=%d\n", size, n_fft, power, hoplen, winlen, window_type, center, pad_type, normalized, onesided);
+    }
+
+    return ret;
+}
+
+static int test_spectrogram_0()
+{
+    return 0 // args: size, n_fft, power, hoplen, winlen, window_type, center, pad_type, normalized, onesided
+           || test_spectrogram(17, 1, 0, 1, 1, 0, 1, 0, 0, 0)
+           || test_spectrogram(39, 17, 0, 7, 15, 0, 0, 0, 1, 0)
+           || test_spectrogram(128, 10, 0, 2, 7, 1, 1, 1, 1, 1)
+           || test_spectrogram(255, 17, 1, 14, 17, 2, 0, 0, 0, 1)
+           || test_spectrogram(124, 55, 2, 12, 55, 1, 1, 2, 2, 0);
+}
+
+int main()
+{
+    SRAND(7767517);
+
+    return test_spectrogram_0();
+}
diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt
index 2281875dbd43..89a4a52d02c0 100644
--- a/tools/pnnx/src/CMakeLists.txt
+++ b/tools/pnnx/src/CMakeLists.txt
@@ -306,6 +306,9 @@
set(pnnx_pass_level2_SRCS
 
     pass_level2/nn_quantized_FloatFunctional.cpp
 
+    pass_level2/torchaudio_F_inverse_spectrogram.cpp
+    pass_level2/torchaudio_F_spectrogram.cpp
+
     pass_level2/nn_GRU.cpp
     pass_level2/nn_LSTM.cpp
     pass_level2/nn_RNN.cpp
@@ -570,6 +573,7 @@ set(pnnx_pass_ncnn_SRCS
     pass_ncnn/torch_cumsum.cpp
     pass_ncnn/torch_diag.cpp
     pass_ncnn/torch_flatten.cpp
+    pass_ncnn/torch_istft.cpp
     pass_ncnn/torch_logsumexp.cpp
     pass_ncnn/torch_matmul.cpp
     pass_ncnn/torch_max.cpp
@@ -582,9 +586,12 @@ set(pnnx_pass_ncnn_SRCS
     pass_ncnn/torch_slice_scatter.cpp
     pass_ncnn/torch_squeeze.cpp
     pass_ncnn/torch_sum.cpp
+    pass_ncnn/torch_stft.cpp
     pass_ncnn/torch_t.cpp
     pass_ncnn/torch_transpose.cpp
     pass_ncnn/torch_unsqueeze.cpp
+    pass_ncnn/torchaudio_F_inverse_spectrogram.cpp
+    pass_ncnn/torchaudio_F_spectrogram.cpp
     pass_ncnn/torchvision_DeformConv2d.cpp
 )
 
diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp
index 9e616699a9e5..394754273b72 100644
--- a/tools/pnnx/src/ir.cpp
+++ b/tools/pnnx/src/ir.cpp
@@ -1458,6 +1458,7 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
     fprintf(pyfp, "import torch.nn.functional as F\n");
     fprintf(pyfp, "try:\n");
     fprintf(pyfp, "    import torchvision\n");
+    fprintf(pyfp, "    import torchaudio\n");
     fprintf(pyfp, "except:\n");
     fprintf(pyfp, "    pass\n");
 
diff --git a/tools/pnnx/src/pass_level2/torch_stft.cpp b/tools/pnnx/src/pass_level2/torch_stft.cpp
index 8a5290bcc747..544defffefbf 100644
--- a/tools/pnnx/src/pass_level2/torch_stft.cpp
+++ b/tools/pnnx/src/pass_level2/torch_stft.cpp
@@ -43,6 +43,7 @@ pnnx.Output output 1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/) const
     {
+        op->params["pad_mode"] = "reflect";
         op->params["center"] = false;
     }
 };
@@ -55,7 +56,7 @@ class torch_stft_1 : public GraphRewriterPass
     const char* match_pattern_graph() const
     {
         return R"PNNXIR(7767517
-37 36
+36 35
 pnnx.Input input_0 0 1 input
 pnnx.Input input_1 0 1 n_fft
 pnnx.Input input_2 0 1 hop_length
@@ -79,19 +80,18 @@ prim::Constant op_11 0 1 21 value=%pad_left
 prim::Constant op_12 0 1 63 value=%pad_right
 prim::ListConstruct op_13 2 1 21 63 22
 prim::Constant op_14 0 1 23 value=%pad_mode
-prim::Constant op_15 0 1 24 value=None
-aten::pad op_16 4 1 a 22 23 24 b
-prim::Constant op_17 0 1 64 value=1
-aten::size op_18 2 1 b 64 27
-prim::NumToTensor op_19 1 1 27 28
-aten::Int op_20 1 1 28 31
-prim::Constant op_21 0 1 33 value=2
-aten::size op_22 2 1 b 33 34
-prim::NumToTensor op_23 1 1 34 35
-aten::Int op_24 1 1 35 40
-prim::ListConstruct op_25 2 1 31 40 41
-aten::view op_26 2 1 b 41 c
-aten::stft op_27 8 1 c n_fft hop_length win_length window normalized onesided return_complex out
+F.pad op_15 3 1 a 22 23 b
+prim::Constant op_16 0 1 64 value=1
+aten::size op_17 2 1 b 64 27
+prim::NumToTensor op_18 1 1 27 28
+aten::Int op_29 1 1 28 31
+prim::Constant op_20 0 1 33 value=2
+aten::size op_21 2 1 b 33 34
+prim::NumToTensor op_22 1 1 34 35
+aten::Int op_23 1 1 35 40
+prim::ListConstruct op_24 2 1 31 40 41
+aten::view op_25 2 1 b 41 c
+aten::stft op_26 8 1 c n_fft hop_length win_length window normalized onesided return_complex out
 pnnx.Output output 1 0 out
 )PNNXIR";
     }
@@ -110,4 +110,88 @@ pnnx.Output output 1 0 out
 
 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_stft_1, 19)
 
+class torch_stft_2 : public torch_stft_1 // input viewed to 3 dims, F.pad, viewed back to 1 dim, then aten::stft
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+29 28
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 n_fft
+pnnx.Input input_2 0 1 hop_length
+pnnx.Input input_3 0 1 win_length
+pnnx.Input input_4 0 1 window
+pnnx.Input input_5 0 1 normalized
+pnnx.Input input_6 0 1 onesided
+pnnx.Input input_7 0 1 return_complex
+prim::Constant op_0 0 1 11 value=0
+aten::size op_1 2 1 input 11 12
+prim::NumToTensor op_2 1 1 12 13
+aten::Int op_3 1 1 13 18
+prim::Constant op_4 0 1 15 value=1
+prim::Constant op_5 0 1 121 value=1
+prim::ListConstruct op_6 3 1 15 121 18 19
+aten::view op_7 2 1 input 19 a
+prim::Constant op_8 0 1 22 value=%pad_left
+prim::Constant op_9 0 1 122 value=%pad_right
+prim::ListConstruct op_10 2 1 22 122 23
+prim::Constant op_11 0 1 24 value=%pad_mode
+F.pad op_12 3 1 a 23 24 b
+prim::Constant op_13 0 1 28 value=2
+aten::size op_14 2 1 b 28 29
+prim::NumToTensor op_15 1 1 29 30
+aten::Int op_16 1 1 30 34
+prim::ListConstruct op_17 1 1 34 35
+aten::view op_18 2 1 b 35 c
+aten::stft op_19 8 1 c n_fft hop_length win_length window normalized onesided return_complex out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_stft_2, 19)
+
+class torch_stft_3 : public torch_stft_1 // same graph as torch_stft_2 with the pad mode carried as the F.pad mode= attribute
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+29 28
+pnnx.Input input_0 0 1 input
+pnnx.Input input_1 0 1 n_fft
+pnnx.Input input_2 0 1 hop_length
+pnnx.Input input_3 0 1 win_length
+pnnx.Input input_4 0 1 window
+pnnx.Input input_5 0 1 normalized
+pnnx.Input input_6 0 1 onesided
+pnnx.Input input_7 0 1 return_complex
+prim::Constant op_0 0 1 11 value=0
+aten::size op_1 2 1 input 11 12
+prim::NumToTensor op_2 1 1 12 13
+aten::Int op_3 1 1 13 18
+prim::Constant op_4 0 1 15 value=1
+prim::Constant op_5 0 1 121 value=1
+prim::ListConstruct op_6 3 1 15 121 18 19
+aten::view op_7 2 1 input 19 a
+prim::Constant op_8 0 1 22 value=%pad_left
+prim::Constant op_9 0 1 122 value=%pad_right
+prim::ListConstruct op_10 2 1 22 122 23
+prim::Constant op_11 0 1 24 value=None
+F.pad op_12 3 1 a 23 24 b mode=%pad_mode
+prim::Constant op_13 0 1 28 value=2
+aten::size op_14 2 1 b 28 29
+prim::NumToTensor op_15 1 1 29 30
+aten::Int op_16 1 1 30 34
+prim::ListConstruct op_17 1 1 34 35
+aten::view op_18 2 1 b 35 c
+aten::stft op_19 8 1 c n_fft hop_length win_length window normalized onesided return_complex out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_stft_3, 19)
+
 } // namespace pnnx
diff --git a/tools/pnnx/src/pass_level2/torchaudio_F_inverse_spectrogram.cpp b/tools/pnnx/src/pass_level2/torchaudio_F_inverse_spectrogram.cpp
new file mode 100644
index 000000000000..c2dd7db90644
--- /dev/null
+++ b/tools/pnnx/src/pass_level2/torchaudio_F_inverse_spectrogram.cpp
@@ -0,0 +1,165 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "pass_level2.h"
+
+namespace pnnx {
+
+class torchaudio_F_inverse_spectrogram : public GraphRewriterPass // reshape, aten::istft, reshape to a single dim
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+29 28
+pnnx.Input input_0 0 1 spectrogram
+pnnx.Input input_1 0 1 n_fft
+pnnx.Input input_2 0 1 hop_length
+pnnx.Input input_3 0 1 win_length
+pnnx.Input input_4 0 1 window
+pnnx.Input input_5 0 1 center
+pnnx.Input input_6 0 1 onesided
+prim::Constant op_0 0 1 13 value=0
+aten::size op_1 2 1 spectrogram 13 14
+prim::NumToTensor op_2 1 1 14 15
+aten::Int op_3 1 1 15 18
+prim::Constant op_4 0 1 20 value=1
+aten::size op_5 2 1 spectrogram 20 21
+prim::NumToTensor op_6 1 1 21 22
+aten::Int op_7 1 1 22 28
+prim::Constant op_8 0 1 24 value=-1
+prim::ListConstruct op_9 3 1 24 18 28 29
+aten::reshape op_10 2 1 spectrogram 29 spectrogram.1
+prim::Constant op_11 0 1 normalized value=%normalized
+prim::Constant op_12 0 1 length value=None
+prim::Constant op_13 0 1 return_complex value=False
+aten::istft op_14 10 1 spectrogram.1 n_fft hop_length win_length window center normalized onesided length return_complex waveform.1
+prim::Constant op_15 0 1 75 value=1
+aten::size op_16 2 1 waveform.1 75 42
+prim::NumToTensor op_17 1 1 42 43
+aten::Int op_18 1 1 43 47
+prim::ListConstruct op_19 1 1 47 48
+aten::reshape op_20 2 1 waveform.1 48 out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "torchaudio.functional.inverse_spectrogram";
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
+    {
+        op->params["length"] = Parameter();
+        op->params["pad"] = 0;
+        if (captured_params.at("normalized").b) // the captured aten::istft normalized flag
+        {
+            op->params["normalized"] = "frame_length";
+        }
+        else
+        {
+            op->params["normalized"] = false;
+        }
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_inverse_spectrogram, 6)
+
+class torchaudio_F_inverse_spectrogram_0 : public torchaudio_F_inverse_spectrogram // variant whose final reshape is to (size(0), length)
+{
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+33 32
+pnnx.Input input_0 0 1 spectrogram
+pnnx.Input input_1 0 1 n_fft
+pnnx.Input input_2 0 1 hop_length
+pnnx.Input input_3 0 1 win_length
+pnnx.Input input_4 0 1 window
+pnnx.Input input_5 0 1 center
+pnnx.Input input_6 0 1 onesided
+prim::Constant op_0 0 1 13 value=0
+aten::size op_1 2 1 spectrogram 13 14
+prim::NumToTensor op_2 1 1 14 15
+aten::Int op_3 1 1 15 18
+prim::Constant op_4 0 1 20 value=1
+aten::size op_5 2 1 spectrogram 20 21
+prim::NumToTensor op_6 1 1 21 22
+aten::Int op_7 1 1 22 25
+prim::Constant op_8 0 1 27 value=2
+aten::size op_9 2 1 spectrogram 27 28
+prim::NumToTensor op_10 1 1 28 29
+aten::Int op_11 1 1 29 35
+prim::Constant op_12 0 1 31 value=-1
+prim::ListConstruct op_13 3 1 31 25 35 36
+aten::reshape op_14 2 1 spectrogram 36 spectrogram.1
+prim::Constant op_15 0 1 normalized value=%normalized
+prim::Constant op_16 0 1 length value=None
+prim::Constant op_17 0 1 return_complex value=False
+aten::istft op_18 10 1 spectrogram.1 n_fft hop_length win_length window center normalized onesided length return_complex waveform.1
+prim::Constant op_19 0 1 83 value=1
+aten::size op_20 2 1 waveform.1 83 49
+prim::NumToTensor op_21 1 1 49 50
+aten::Int op_22 1 1 50 55
+prim::ListConstruct op_23 2 1 18 55 56
+aten::reshape op_24 2 1 waveform.1 56 out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_inverse_spectrogram_0, 6)
+
+class torchaudio_F_inverse_spectrogram_1 : public GraphRewriterPass // folds the spectrogram * sqrt(sum(window^2)) pre-scale into normalized="window"
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+15 14
+pnnx.Input input_0 0 1 spectrogram
+pnnx.Input input_1 0 1 n_fft
+pnnx.Input input_2 0 1 hop_length
+pnnx.Input input_3 0 1 win_length
+pnnx.Input input_4 0 1 window
+pnnx.Input input_5 0 1 center
+pnnx.Input input_6 0 1 onesided
+prim::Constant op_0 0 1 13 value=2.000000e+00
+aten::pow op_1 2 1 window 13 14
+prim::Constant op_2 0 1 87 value=None
+aten::sum op_3 2 1 14 87 16
+aten::sqrt op_4 1 1 16 17
+aten::mul op_5 2 1 spectrogram 17 spectrogram.1
+torchaudio.functional.inverse_spectrogram op_6 7 1 spectrogram.1 n_fft hop_length win_length window center onesided out normalized=False length=%length pad=%pad
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "torchaudio.functional.inverse_spectrogram";
+    }
+
+    void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
+    {
+        op->params["length"] = captured_params.at("length");
+        op->params["pad"] = captured_params.at("pad");
+        op->params["normalized"] = "window";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_inverse_spectrogram_1, 7)
+
+} // namespace pnnx
diff --git a/tools/pnnx/src/pass_level2/torchaudio_F_spectrogram.cpp b/tools/pnnx/src/pass_level2/torchaudio_F_spectrogram.cpp
new file mode 100644
index 000000000000..cf123c78d633
--- /dev/null
+++ b/tools/pnnx/src/pass_level2/torchaudio_F_spectrogram.cpp
@@ -0,0 +1,709 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torchaudio_F_spectrogram : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +27 26 +pnnx.Input input_0 0 1 waveform +pnnx.Input input_1 0 1 n_fft +pnnx.Input input_2 0 1 hop_length +pnnx.Input input_3 0 1 win_length +pnnx.Input input_4 0 1 window +pnnx.Input input_5 0 1 onesided +prim::Constant op_0 0 1 11 value=0 +aten::size op_1 2 1 waveform 11 12 +prim::NumToTensor op_2 1 1 12 13 +aten::Int op_3 1 1 13 18 +prim::Constant op_4 0 1 15 value=-1 +prim::ListConstruct op_5 2 1 15 18 19 +aten::reshape op_6 2 1 waveform 19 waveform.1 +prim::Constant op_7 0 1 normalized value=%normalized +prim::Constant op_8 0 1 return_complex value=True +aten::stft op_9 8 1 waveform.1 n_fft hop_length win_length window normalized onesided return_complex spec_f.1 +prim::Constant op_10 0 1 29 value=1 +aten::size op_11 2 1 spec_f.1 29 30 +prim::NumToTensor op_12 1 1 30 31 +aten::Int op_13 1 1 31 34 +prim::Constant op_14 0 1 36 value=2 +aten::size op_15 2 1 spec_f.1 36 37 +prim::NumToTensor op_16 1 1 37 38 +aten::Int op_17 1 1 38 43 +prim::ListConstruct op_18 2 1 34 43 44 +aten::reshape op_19 2 1 spec_f.1 44 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torchaudio.functional.spectrogram"; + } + + void 
write(Operator* op, const std::map& captured_params) const + { + op->params["pad"] = 0; + op->params["pad_mode"] = "reflect"; + op->params["center"] = false; + op->params["power"] = Parameter(); + if (captured_params.at("normalized").b) + { + op->params["normalized"] = "frame_length"; + } + else + { + op->params["normalized"] = false; + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram, 6) + +class torchaudio_F_spectrogram_0 : public torchaudio_F_spectrogram +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +31 30 +pnnx.Input input_0 0 1 waveform +pnnx.Input input_1 0 1 n_fft +pnnx.Input input_2 0 1 hop_length +pnnx.Input input_3 0 1 win_length +pnnx.Input input_4 0 1 window +pnnx.Input input_5 0 1 onesided +prim::Constant op_0 0 1 11 value=0 +aten::size op_1 2 1 waveform 11 12 +prim::NumToTensor op_2 1 1 12 13 +aten::Int op_3 1 1 13 16 +prim::Constant op_4 0 1 18 value=1 +aten::size op_5 2 1 waveform 18 19 +prim::NumToTensor op_6 1 1 19 20 +aten::Int op_7 1 1 20 25 +prim::Constant op_8 0 1 22 value=-1 +prim::ListConstruct op_9 2 1 22 25 26 +aten::reshape op_10 2 1 waveform 26 waveform.1 +prim::Constant op_11 0 1 normalized value=%normalized +prim::Constant op_12 0 1 return_complex value=True +aten::stft op_13 8 1 waveform.1 n_fft hop_length win_length window normalized onesided return_complex spec_f.1 +prim::Constant op_14 0 1 72 value=1 +aten::size op_15 2 1 spec_f.1 72 36 +prim::NumToTensor op_16 1 1 36 37 +aten::Int op_17 1 1 37 40 +prim::Constant op_18 0 1 42 value=2 +aten::size op_19 2 1 spec_f.1 42 43 +prim::NumToTensor op_20 1 1 43 44 +aten::Int op_21 1 1 44 50 +prim::ListConstruct op_22 3 1 16 40 50 51 +aten::reshape op_23 2 1 spec_f.1 51 out +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_0, 6) + +class torchaudio_F_spectrogram_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return 
R"PNNXIR(7767517 +58 57 +pnnx.Input input_0 0 1 waveform +pnnx.Input input_1 0 1 n_fft +pnnx.Input input_2 0 1 hop_length +pnnx.Input input_3 0 1 win_length +pnnx.Input input_4 0 1 window +pnnx.Input input_5 0 1 onesided +prim::Constant op_0 0 1 18 value=1 +aten::size op_1 2 1 waveform 18 19 +prim::NumToTensor op_2 1 1 19 20 +aten::Int op_3 1 1 20 25 +prim::Constant op_4 0 1 22 value=-1 +prim::ListConstruct op_5 2 1 22 25 26 +aten::reshape op_6 2 1 waveform 26 waveform.1 +prim::Constant op_7 0 1 106 value=0 +aten::size op_8 2 1 waveform.1 106 29 +prim::NumToTensor op_9 1 1 29 30 +aten::Int op_10 1 1 30 33 +prim::Constant op_11 0 1 107 value=1 +aten::size op_12 2 1 waveform.1 107 35 +prim::NumToTensor op_13 1 1 35 36 +aten::Int op_14 1 1 36 41 +prim::Constant op_15 0 1 108 value=1 +prim::ListConstruct op_16 3 1 108 33 41 42 +aten::view op_17 2 1 waveform.1 42 input0.1 +prim::Constant op_18 0 1 45 value=%pad_left +prim::Constant op_19 0 1 109 value=%pad_right +prim::ListConstruct op_20 2 1 45 109 46 +prim::Constant op_21 0 1 47 value=%pad_mode +prim::Constant op_22 0 1 110 value=None +aten::pad op_23 4 1 input0.1 46 47 110 input1.1 +prim::Constant op_24 0 1 111 value=1 +aten::size op_25 2 1 input1.1 111 51 +prim::NumToTensor op_26 1 1 51 52 +aten::Int op_27 1 1 52 55 +prim::Constant op_28 0 1 57 value=2 +aten::size op_29 2 1 input1.1 57 58 +prim::NumToTensor op_30 1 1 58 59 +aten::Int op_31 1 1 59 64 +prim::ListConstruct op_32 2 1 55 64 65 +aten::view op_33 2 1 input1.1 65 input2.1 +prim::Constant op_34 0 1 normalized value=%normalized +prim::Constant op_35 0 1 return_complex value=True +aten::stft op_36 8 1 input2.1 n_fft hop_length win_length window normalized onesided return_complex spec_f.1 +prim::Constant op_37 0 1 11 value=0 +aten::size op_38 2 1 waveform 11 12 +prim::NumToTensor op_39 1 1 12 13 +aten::Int op_40 1 1 13 16 +prim::Constant op_41 0 1 116 value=1 +aten::size op_42 2 1 spec_f.1 116 75 +prim::NumToTensor op_43 1 1 75 76 +aten::Int op_44 1 1 76 79 
+prim::Constant op_45 0 1 117 value=2 +aten::size op_46 2 1 spec_f.1 117 81 +prim::NumToTensor op_47 1 1 81 82 +aten::Int op_48 1 1 82 88 +prim::ListConstruct op_49 3 1 16 79 88 89 +aten::reshape op_50 2 1 spec_f.1 89 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torchaudio.functional.spectrogram"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["pad"] = 0; + op->params["pad_mode"] = captured_params.at("pad_mode"); + op->params["center"] = true; + op->params["power"] = Parameter(); + if (captured_params.at("normalized").b) + { + op->params["normalized"] = "frame_length"; + } + else + { + op->params["normalized"] = false; + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1, 6) + +class torchaudio_F_spectrogram_1_1 : public torchaudio_F_spectrogram_1 +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +63 62 +pnnx.Input input_0 0 1 waveform +pnnx.Input input_1 0 1 n_fft +pnnx.Input input_2 0 1 hop_length +pnnx.Input input_3 0 1 win_length +pnnx.Input input_4 0 1 window +pnnx.Input input_5 0 1 onesided +prim::Constant op_0 0 1 11 value=0 +aten::size op_1 2 1 waveform 11 12 +prim::NumToTensor op_2 1 1 12 13 +aten::Int op_3 1 1 13 18 +prim::Constant op_4 0 1 15 value=-1 +prim::ListConstruct op_5 2 1 15 18 19 +aten::reshape op_6 2 1 waveform 19 waveform.1 +prim::Constant op_7 0 1 108 value=0 +aten::size op_8 2 1 waveform.1 108 22 +prim::NumToTensor op_9 1 1 22 23 +aten::Int op_10 1 1 23 26 +prim::Constant op_11 0 1 28 value=1 +aten::size op_12 2 1 waveform.1 28 29 +prim::NumToTensor op_13 1 1 29 30 +aten::Int op_14 1 1 30 35 +prim::Constant op_15 0 1 109 value=1 +prim::ListConstruct op_16 3 1 109 26 35 36 +aten::view op_17 2 1 waveform.1 36 input0.1 +prim::Constant op_18 0 1 39 value=%pad_left +prim::Constant op_19 0 1 110 value=%pad_right +prim::ListConstruct op_20 2 1 39 110 40 +prim::Constant op_21 0 1 41 
value=%pad_mode +prim::Constant op_22 0 1 111 value=None +aten::pad op_23 4 1 input0.1 40 41 111 input1.1 +prim::Constant op_24 0 1 112 value=1 +aten::size op_25 2 1 input1.1 112 45 +prim::NumToTensor op_26 1 1 45 46 +aten::Int op_27 1 1 46 49 +prim::Constant op_28 0 1 51 value=2 +aten::size op_29 2 1 input1.1 51 52 +prim::NumToTensor op_30 1 1 52 53 +aten::Int op_31 1 1 53 58 +prim::ListConstruct op_32 2 1 49 58 59 +aten::view op_33 2 1 input1.1 59 input2.1 +prim::Constant op_34 0 1 normalized value=%normalized +prim::Constant op_35 0 1 return_complex value=True +aten::stft op_36 8 1 input2.1 n_fft hop_length win_length window normalized onesided return_complex spec_f.1 +prim::Constant op_37 0 1 117 value=1 +aten::size op_38 2 1 spec_f.1 117 69 +prim::NumToTensor op_39 1 1 69 70 +aten::Int op_40 1 1 70 73 +prim::Constant op_50 0 1 118 value=2 +aten::size op_51 2 1 spec_f.1 118 75 +prim::NumToTensor op_52 1 1 75 76 +aten::Int op_53 1 1 76 81 +prim::ListConstruct op_54 2 1 73 81 82 +aten::reshape op_55 2 1 spec_f.1 82 out +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1_1, 6) + +class torchaudio_F_spectrogram_1_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +52 51 +pnnx.Input input_0 0 1 waveform +pnnx.Input input_1 0 1 n_fft +pnnx.Input input_2 0 1 hop_length +pnnx.Input input_3 0 1 win_length +pnnx.Input input_4 0 1 window +pnnx.Input input_5 0 1 onesided +prim::Constant op_0 0 1 211 value=0 +aten::size op_1 2 1 waveform 211 107 +prim::NumToTensor op_2 1 1 107 108 +aten::Int op_3 1 1 108 112 +prim::Constant op_4 0 1 212 value=-1 +prim::ListConstruct op_5 2 1 212 112 113 +aten::reshape op_6 2 1 waveform 113 input3.1 +prim::Constant op_7 0 1 213 value=0 +aten::size op_8 2 1 input3.1 213 116 +prim::NumToTensor op_9 1 1 116 117 +aten::Int op_10 1 1 117 120 +prim::Constant op_11 0 1 214 value=1 +aten::size op_12 2 1 input3.1 214 122 
+prim::NumToTensor op_13 1 1 122 123 +aten::Int op_14 1 1 123 128 +prim::Constant op_15 0 1 215 value=1 +prim::ListConstruct op_16 3 1 215 120 128 129 +aten::view op_17 2 1 input3.1 129 input4.1 +prim::Constant op_18 0 1 216 value=%pad_left +prim::Constant op_19 0 1 217 value=%pad_right +prim::ListConstruct op_20 2 1 216 217 132 +aten::reflection_pad1d op_21 2 1 input4.1 132 input5.1 +prim::Constant op_22 0 1 218 value=1 +aten::size op_23 2 1 input5.1 218 135 +prim::NumToTensor op_24 1 1 135 136 +aten::Int op_25 1 1 136 139 +prim::Constant op_26 0 1 219 value=2 +aten::size op_27 2 1 input5.1 219 141 +prim::NumToTensor op_28 1 1 141 142 +aten::Int op_29 1 1 142 147 +prim::ListConstruct op_30 2 1 139 147 148 +aten::view op_31 2 1 input5.1 148 input6.1 +prim::Constant op_32 0 1 normalized value=%normalized +prim::Constant op_33 0 1 return_complex value=True +aten::stft op_34 8 1 input6.1 n_fft hop_length win_length window normalized onesided return_complex spec_f2.1 +prim::Constant op_35 0 1 226 value=1 +aten::size op_36 2 1 spec_f2.1 226 157 +prim::NumToTensor op_37 1 1 157 158 +aten::Int op_38 1 1 158 161 +prim::Constant op_39 0 1 227 value=2 +aten::size op_40 2 1 spec_f2.1 227 163 +prim::NumToTensor op_41 1 1 163 164 +aten::Int op_42 1 1 164 169 +prim::ListConstruct op_43 2 1 161 169 170 +aten::reshape op_44 2 1 spec_f2.1 170 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torchaudio.functional.spectrogram"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["pad"] = 0; + op->params["pad_mode"] = "reflect"; + op->params["center"] = true; + op->params["power"] = Parameter(); + if (captured_params.at("normalized").b) + { + op->params["normalized"] = "frame_length"; + } + else + { + op->params["normalized"] = false; + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1_2, 6) + +class torchaudio_F_spectrogram_1_3 : public torchaudio_F_spectrogram_1_2 +{ 
+public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +56 55 +pnnx.Input input_0 0 1 waveform +pnnx.Input input_1 0 1 n_fft +pnnx.Input input_2 0 1 hop_length +pnnx.Input input_3 0 1 win_length +pnnx.Input input_4 0 1 window +pnnx.Input input_5 0 1 onesided +prim::Constant op_0 0 1 11 value=0 +aten::size op_1 2 1 waveform 11 12 +prim::NumToTensor op_2 1 1 12 13 +aten::Int op_3 1 1 13 16 +prim::Constant op_4 0 1 18 value=1 +aten::size op_5 2 1 waveform 18 19 +prim::NumToTensor op_6 1 1 19 20 +aten::Int op_7 1 1 20 25 +prim::Constant op_8 0 1 22 value=-1 +prim::ListConstruct op_9 2 1 22 25 26 +aten::reshape op_10 2 1 waveform 26 input.1 +prim::Constant op_11 0 1 326 value=0 +aten::size op_12 2 1 input.1 326 29 +prim::NumToTensor op_13 1 1 29 30 +aten::Int op_14 1 1 30 33 +prim::Constant op_15 0 1 327 value=1 +aten::size op_16 2 1 input.1 327 35 +prim::NumToTensor op_17 1 1 35 36 +aten::Int op_18 1 1 36 41 +prim::Constant op_19 0 1 328 value=1 +prim::ListConstruct op_20 3 1 328 33 41 42 +aten::view op_21 2 1 input.1 42 input0.1 +prim::Constant op_22 0 1 45 value=%pad_left +prim::Constant op_23 0 1 329 value=%pad_right +prim::ListConstruct op_24 2 1 45 329 46 +aten::reflection_pad1d op_25 2 1 input0.1 46 input1.1 +prim::Constant op_26 0 1 330 value=1 +aten::size op_27 2 1 input1.1 330 49 +prim::NumToTensor op_28 1 1 49 50 +aten::Int op_29 1 1 50 53 +prim::Constant op_30 0 1 55 value=2 +aten::size op_31 2 1 input1.1 55 56 +prim::NumToTensor op_32 1 1 56 57 +aten::Int op_33 1 1 57 62 +prim::ListConstruct op_34 2 1 53 62 63 +aten::view op_35 2 1 input1.1 63 input2.1 +prim::Constant op_36 0 1 normalized value=%normalized +prim::Constant op_37 0 1 return_complex value=True +aten::stft op_38 8 1 input2.1 n_fft hop_length win_length window normalized onesided return_complex spec_f.1 +prim::Constant op_39 0 1 334 value=1 +aten::size op_40 2 1 spec_f.1 334 74 +prim::NumToTensor op_41 1 1 74 75 +aten::Int op_42 1 1 75 78 +prim::Constant op_43 0 1 335 
value=2 +aten::size op_44 2 1 spec_f.1 335 80 +prim::NumToTensor op_45 1 1 80 81 +aten::Int op_46 1 1 81 87 +prim::ListConstruct op_47 3 1 16 78 87 88 +aten::reshape op_48 2 1 spec_f.1 88 out +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1_3, 6) + +class torchaudio_F_spectrogram_1_4 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +53 52 +pnnx.Input input_0 0 1 waveform +pnnx.Input input_1 0 1 n_fft +pnnx.Input input_2 0 1 hop_length +pnnx.Input input_3 0 1 win_length +pnnx.Input input_4 0 1 window +pnnx.Input input_5 0 1 onesided +prim::Constant op_0 0 1 211 value=0 +aten::size op_1 2 1 waveform 211 107 +prim::NumToTensor op_2 1 1 107 108 +aten::Int op_3 1 1 108 112 +prim::Constant op_4 0 1 212 value=-1 +prim::ListConstruct op_5 2 1 212 112 113 +aten::reshape op_6 2 1 waveform 113 input3.1 +prim::Constant op_7 0 1 213 value=0 +aten::size op_8 2 1 input3.1 213 116 +prim::NumToTensor op_9 1 1 116 117 +aten::Int op_10 1 1 117 120 +prim::Constant op_11 0 1 214 value=1 +aten::size op_12 2 1 input3.1 214 122 +prim::NumToTensor op_13 1 1 122 123 +aten::Int op_14 1 1 123 128 +prim::Constant op_15 0 1 215 value=1 +prim::ListConstruct op_16 3 1 215 120 128 129 +aten::view op_17 2 1 input3.1 129 input4.1 +prim::Constant op_18 0 1 216 value=%pad_left +prim::Constant op_19 0 1 217 value=%pad_right +prim::ListConstruct op_20 2 1 216 217 132 +prim::Constant op_21 0 1 46 value=0.000000e+00 +aten::constant_pad_nd op_22 3 1 input4.1 132 46 input5.1 +prim::Constant op_23 0 1 218 value=1 +aten::size op_24 2 1 input5.1 218 135 +prim::NumToTensor op_25 1 1 135 136 +aten::Int op_26 1 1 136 139 +prim::Constant op_27 0 1 219 value=2 +aten::size op_28 2 1 input5.1 219 141 +prim::NumToTensor op_29 1 1 141 142 +aten::Int op_30 1 1 142 147 +prim::ListConstruct op_31 2 1 139 147 148 +aten::view op_32 2 1 input5.1 148 input6.1 +prim::Constant op_33 0 1 normalized 
value=%normalized +prim::Constant op_34 0 1 return_complex value=True +aten::stft op_35 8 1 input6.1 n_fft hop_length win_length window normalized onesided return_complex spec_f2.1 +prim::Constant op_36 0 1 226 value=1 +aten::size op_37 2 1 spec_f2.1 226 157 +prim::NumToTensor op_38 1 1 157 158 +aten::Int op_39 1 1 158 161 +prim::Constant op_40 0 1 227 value=2 +aten::size op_41 2 1 spec_f2.1 227 163 +prim::NumToTensor op_42 1 1 163 164 +aten::Int op_43 1 1 164 169 +prim::ListConstruct op_44 2 1 161 169 170 +aten::reshape op_45 2 1 spec_f2.1 170 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torchaudio.functional.spectrogram"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["pad"] = 0; + op->params["pad_mode"] = "constant"; + op->params["center"] = true; + op->params["power"] = Parameter(); + if (captured_params.at("normalized").b) + { + op->params["normalized"] = "frame_length"; + } + else + { + op->params["normalized"] = false; + } + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1_4, 6) + +class torchaudio_F_spectrogram_1_5 : public torchaudio_F_spectrogram_1_4 +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +57 56 +pnnx.Input input_0 0 1 waveform +pnnx.Input input_1 0 1 n_fft +pnnx.Input input_2 0 1 hop_length +pnnx.Input input_3 0 1 win_length +pnnx.Input input_4 0 1 window +pnnx.Input input_5 0 1 onesided +prim::Constant op_0 0 1 11 value=0 +aten::size op_1 2 1 waveform 11 12 +prim::NumToTensor op_2 1 1 12 13 +aten::Int op_3 1 1 13 16 +prim::Constant op_4 0 1 18 value=1 +aten::size op_5 2 1 waveform 18 19 +prim::NumToTensor op_6 1 1 19 20 +aten::Int op_7 1 1 20 25 +prim::Constant op_8 0 1 22 value=-1 +prim::ListConstruct op_9 2 1 22 25 26 +aten::reshape op_10 2 1 waveform 26 input.1 +prim::Constant op_11 0 1 326 value=0 +aten::size op_12 2 1 input.1 326 29 +prim::NumToTensor op_13 1 1 29 30 +aten::Int op_14 1 1 
30 33 +prim::Constant op_15 0 1 327 value=1 +aten::size op_16 2 1 input.1 327 35 +prim::NumToTensor op_17 1 1 35 36 +aten::Int op_18 1 1 36 41 +prim::Constant op_19 0 1 328 value=1 +prim::ListConstruct op_20 3 1 328 33 41 42 +aten::view op_21 2 1 input.1 42 input0.1 +prim::Constant op_22 0 1 45 value=%pad_left +prim::Constant op_23 0 1 329 value=%pad_right +prim::ListConstruct op_24 2 1 45 329 46 +prim::Constant op_25 0 1 47 value=0.000000e+00 +aten::constant_pad_nd op_26 3 1 input0.1 46 47 input1.1 +prim::Constant op_27 0 1 330 value=1 +aten::size op_28 2 1 input1.1 330 49 +prim::NumToTensor op_29 1 1 49 50 +aten::Int op_30 1 1 50 53 +prim::Constant op_31 0 1 55 value=2 +aten::size op_32 2 1 input1.1 55 56 +prim::NumToTensor op_33 1 1 56 57 +aten::Int op_34 1 1 57 62 +prim::ListConstruct op_35 2 1 53 62 63 +aten::view op_36 2 1 input1.1 63 input2.1 +prim::Constant op_37 0 1 normalized value=%normalized +prim::Constant op_38 0 1 return_complex value=True +aten::stft op_39 8 1 input2.1 n_fft hop_length win_length window normalized onesided return_complex spec_f.1 +prim::Constant op_40 0 1 334 value=1 +aten::size op_41 2 1 spec_f.1 334 74 +prim::NumToTensor op_42 1 1 74 75 +aten::Int op_43 1 1 75 78 +prim::Constant op_44 0 1 335 value=2 +aten::size op_45 2 1 spec_f.1 335 80 +prim::NumToTensor op_46 1 1 80 81 +aten::Int op_47 1 1 81 87 +prim::ListConstruct op_48 3 1 16 78 87 88 +aten::reshape op_49 2 1 spec_f.1 88 out +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1_5, 6) + +class torchaudio_F_spectrogram_2 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +14 13 +pnnx.Input input_0 0 1 waveform +pnnx.Input input_1 0 1 n_fft +pnnx.Input input_2 0 1 hop_length +pnnx.Input input_3 0 1 win_length +pnnx.Input input_4 0 1 window +pnnx.Input input_5 0 1 onesided +torchaudio.functional.spectrogram op_0 6 1 waveform n_fft hop_length win_length 
window onesided spec power=None normalized=False center=%center pad=%pad pad_mode=%pad_mode +prim::Constant op_1 0 1 92 value=2.000000e+00 +aten::pow op_2 2 1 window 92 93 +prim::Constant op_3 0 1 127 value=None +aten::sum op_4 2 1 93 127 95 +aten::sqrt op_5 1 1 95 96 +aten::div op_6 2 1 spec 96 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torchaudio.functional.spectrogram"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["pad"] = captured_params.at("pad"); + op->params["pad_mode"] = captured_params.at("pad_mode"); + op->params["center"] = captured_params.at("center"); + op->params["power"] = Parameter(); + op->params["normalized"] = "window"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_2, 7) + +class torchaudio_F_spectrogram_3 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +9 8 +pnnx.Input input_0 0 1 waveform +pnnx.Input input_1 0 1 n_fft +pnnx.Input input_2 0 1 hop_length +pnnx.Input input_3 0 1 win_length +pnnx.Input input_4 0 1 window +pnnx.Input input_5 0 1 onesided +torchaudio.functional.spectrogram op_0 6 1 waveform n_fft hop_length win_length window onesided spec power=None normalized=%normalized center=%center pad=%pad pad_mode=%pad_mode +aten::abs op_1 1 1 spec out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torchaudio.functional.spectrogram"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["pad"] = captured_params.at("pad"); + op->params["pad_mode"] = captured_params.at("pad_mode"); + op->params["center"] = captured_params.at("center"); + op->params["normalized"] = captured_params.at("normalized"); + op->params["power"] = 1; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_3, 8) + +class torchaudio_F_spectrogram_4 : public GraphRewriterPass +{ +public: + 
const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +10 9 +pnnx.Input input_0 0 1 waveform +pnnx.Input input_1 0 1 n_fft +pnnx.Input input_2 0 1 hop_length +pnnx.Input input_3 0 1 win_length +pnnx.Input input_4 0 1 window +pnnx.Input input_5 0 1 onesided +torchaudio.functional.spectrogram op_0 6 1 waveform n_fft hop_length win_length window onesided spec power=1 normalized=%normalized center=%center pad=%pad pad_mode=%pad_mode +prim::Constant op_1 0 1 391 value=2 +aten::pow op_2 2 1 spec 391 out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torchaudio.functional.spectrogram"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["pad"] = captured_params.at("pad"); + op->params["pad_mode"] = captured_params.at("pad_mode"); + op->params["center"] = captured_params.at("center"); + op->params["normalized"] = captured_params.at("normalized"); + op->params["power"] = 2; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_4, 9) + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torch_istft.cpp b/tools/pnnx/src/pass_ncnn/torch_istft.cpp new file mode 100644 index 000000000000..3acbe6540095 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torch_istft.cpp @@ -0,0 +1,203 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class torch_istft : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +torch.view_as_complex op_0 1 1 input a +torch.istft op_1 1 1 a out center=%center hop_length=%hop_length length=%length n_fft=%n_fft normalized=%normalized onesided=%onesided return_complex=False win_length=%win_length window=None +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "InverseSpectrogram"; + } + + const char* name_str() const + { + return "istft"; + } + + void write(Operator* op, const std::map& captured_params) const + { + op->params["0"] = captured_params.at("n_fft"); + op->params["1"] = 1; // returns + op->params["2"] = captured_params.at("hop_length"); + op->params["3"] = captured_params.at("win_length"); + op->params["4"] = 0; // all ones + op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0; + op->params["7"] = captured_params.at("normalized").type == 1 && captured_params.at("normalized").b ? 
// Approximate float comparison.
// Returns true when a and b are exactly equal, differ by at most epsilon, or
// differ by less than epsilon relative to the larger of the two magnitudes.
static bool NearlyEqual(float a, float b, float epsilon)
{
    if (a == b)
        return true;

    const float diff = (float)fabs(a - b);
    if (diff <= epsilon)
        return true;

    // relative error
    return diff < epsilon * std::max(fabs(a), fabs(b));
}

// Classifies window coefficients against the closed-form windows tested below.
//
// window_data  the window coefficients, one per sample
//
// Returns 0 when every sample is ~1, 1 when the samples follow
// 0.5 * (1 - cos(2*pi*i/N)), 2 when they follow 0.54 - 0.46 * cos(2*pi*i/N),
// and -1 when none of them fits or the window is empty.
// When several formulas fit, the lowest code wins (ones, then hann, then hamming).
static int detect_window_type(const std::vector<float>& window_data)
{
    const int winlen = (int)window_data.size();

    // an empty window would satisfy every formula vacuously, report it as unknown
    if (winlen == 0)
        return -1;

    // written out because M_PI is not standard C++ and is missing on MSVC
    // unless _USE_MATH_DEFINES is defined before the math header
    const double pi = 3.14159265358979323846;
    const float tolerance = 0.001f;

    bool is_one = true;
    bool is_hann = true;
    bool is_hamming = true;
    for (int i = 0; i < winlen; i++)
    {
        const float sample = window_data[i];
        const double phase_cos = cos(2 * pi * i / winlen);

        if (!NearlyEqual(sample, 1.f, tolerance))
            is_one = false;

        if (!NearlyEqual(sample, (float)(0.5f * (1 - phase_cos)), tolerance))
            is_hann = false;

        if (!NearlyEqual(sample, (float)(0.54f - 0.46f * phase_cos), tolerance))
            is_hamming = false;

        // nothing left to confirm once every candidate has been ruled out
        if (!is_one && !is_hann && !is_hamming)
            return -1;
    }

    if (is_one)
        return 0;
    if (is_hann)
        return 1;
    if (is_hamming)
        return 2;

    return -1;
}
normalized=%normalized onesided=%onesided return_complex=False win_length=%win_length +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "InverseSpectrogram"; + } + + const char* name_str() const + { + return "istft"; + } + + bool match(const std::map& /*captured_params*/, const std::map& captured_attrs) const + { + const std::vector window_data = captured_attrs.at("op_1.data").get_float32_data(); + const int window_type = detect_window_type(window_data); + return window_type != -1; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + const std::vector window_data = captured_attrs.at("op_1.data").get_float32_data(); + const int window_type = detect_window_type(window_data); + + op->params["0"] = captured_params.at("n_fft"); + op->params["1"] = 1; // returns + op->params["2"] = captured_params.at("hop_length"); + op->params["3"] = captured_params.at("win_length"); + op->params["4"] = window_type; + op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0; + op->params["7"] = captured_params.at("normalized").type == 1 && captured_params.at("normalized").b ? 
1 : 0; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_istft_2, 20) + +class torch_istft_3 : public torch_istft_2 +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +6 5 +pnnx.Input input 0 1 input +torch.view_as_complex op_0 1 1 input a +pnnx.Attribute op_1 0 1 window @data +torch.istft op_2 2 1 a window b center=%center hop_length=%hop_length length=%length n_fft=%n_fft normalized=%normalized onesided=%onesided return_complex=True win_length=%win_length +torch.view_as_real op_3 1 1 b out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + torch_istft_2::write(op, captured_params, captured_attrs); + + op->params["1"] = 0; // returns + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_istft_3, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torch_stft.cpp b/tools/pnnx/src/pass_ncnn/torch_stft.cpp new file mode 100644 index 000000000000..2b2296ccbc2c --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torch_stft.cpp @@ -0,0 +1,176 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +class torch_stft : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +torch.stft op_0 1 1 input a center=%center pad_mode=%pad_mode hop_length=%hop_length n_fft=%n_fft normalized=%normalized onesided=%onesided return_complex=True win_length=%win_length window=None +torch.view_as_real op_1 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Spectrogram"; + } + + const char* name_str() const + { + return "stft"; + } + + void write(Operator* op, const std::map& captured_params) const + { + const std::string& pad_mode = captured_params.at("pad_mode").s; + int pad_type = 2; + if (pad_mode == "constant") + pad_type = 0; + if (pad_mode == "replicate") + pad_type = 1; + if (pad_mode == "reflect") + pad_type = 2; + const int onesided = captured_params.at("onesided").type == 1 && captured_params.at("onesided").b == false ? 0 : 1; + + op->params["0"] = captured_params.at("n_fft"); + op->params["1"] = 0; // power + op->params["2"] = captured_params.at("hop_length"); + op->params["3"] = captured_params.at("win_length"); + op->params["4"] = 0; // all ones + op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0; + op->params["6"] = pad_type; + op->params["7"] = captured_params.at("normalized").type == 1 && captured_params.at("normalized").b ? 
1 : 0; + op->params["8"] = onesided; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_stft, 20) + +static bool NearlyEqual(float a, float b, float epsilon) +{ + if (a == b) + return true; + + float diff = (float)fabs(a - b); + if (diff <= epsilon) + return true; + + // relative error + return diff < epsilon * std::max(fabs(a), fabs(b)); +} + +static int detect_window_type(const std::vector& window_data) +{ + const int winlen = (int)window_data.size(); + + bool is_one = true; + bool is_hann = true; + bool is_hamming = true; + for (int i = 0; i < winlen; i++) + { + if (!NearlyEqual(window_data[i], 1.f, 0.001)) + is_one = false; + + if (!NearlyEqual(window_data[i], 0.5f * (1 - cos(2 * M_PI * i / winlen)), 0.001)) + is_hann = false; + + if (!NearlyEqual(window_data[i], 0.54f - 0.46f * cos(2 * M_PI * i / winlen), 0.001)) + is_hamming = false; + } + + if (is_one) + return 0; + if (is_hann) + return 1; + if (is_hamming) + return 2; + + return -1; +} + +class torch_stft_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_0 0 1 window @data +torch.stft op_1 2 1 input window a center=%center pad_mode=%pad_mode hop_length=%hop_length n_fft=%n_fft normalized=%normalized onesided=%onesided return_complex=True win_length=%win_length +torch.view_as_real op_2 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Spectrogram"; + } + + const char* name_str() const + { + return "stft"; + } + + bool match(const std::map& /*captured_params*/, const std::map& captured_attrs) const + { + const std::vector window_data = captured_attrs.at("op_0.data").get_float32_data(); + const int window_type = detect_window_type(window_data); + return window_type != -1; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + const std::vector window_data = 
captured_attrs.at("op_0.data").get_float32_data(); + const int window_type = detect_window_type(window_data); + + const std::string& pad_mode = captured_params.at("pad_mode").s; + int pad_type = 2; + if (pad_mode == "constant") + pad_type = 0; + if (pad_mode == "replicate") + pad_type = 1; + if (pad_mode == "reflect") + pad_type = 2; + const int onesided = captured_params.at("onesided").type == 1 && captured_params.at("onesided").b == false ? 0 : 1; + + op->params["0"] = captured_params.at("n_fft"); + op->params["1"] = 0; // power + op->params["2"] = captured_params.at("hop_length"); + op->params["3"] = captured_params.at("win_length"); + op->params["4"] = window_type; + op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0; + op->params["6"] = pad_type; + op->params["7"] = captured_params.at("normalized").type == 1 && captured_params.at("normalized").b ? 1 : 0; + op->params["8"] = onesided; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_stft_1, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torchaudio_F_inverse_spectrogram.cpp b/tools/pnnx/src/pass_ncnn/torchaudio_F_inverse_spectrogram.cpp new file mode 100644 index 000000000000..d712fcc2990f --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torchaudio_F_inverse_spectrogram.cpp @@ -0,0 +1,127 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +static bool NearlyEqual(float a, float b, float epsilon) +{ + if (a == b) + return true; + + float diff = (float)fabs(a - b); + if (diff <= epsilon) + return true; + + // relative error + return diff < epsilon * std::max(fabs(a), fabs(b)); +} + +static int detect_window_type(const std::vector& window_data) +{ + const int winlen = (int)window_data.size(); + + bool is_one = true; + bool is_hann = true; + bool is_hamming = true; + for (int i = 0; i < winlen; i++) + { + if (!NearlyEqual(window_data[i], 1.f, 0.001)) + is_one = false; + + if (!NearlyEqual(window_data[i], 0.5f * (1 - cos(2 * M_PI * i / winlen)), 0.001)) + is_hann = false; + + if (!NearlyEqual(window_data[i], 0.54f - 0.46f * cos(2 * M_PI * i / winlen), 0.001)) + is_hamming = false; + } + + if (is_one) + return 0; + if (is_hann) + return 1; + if (is_hamming) + return 2; + + return -1; +} + +class torchaudio_F_inverse_spectrogram : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_0 0 1 window @data +torch.view_as_complex op_1 1 1 input a +torchaudio.functional.inverse_spectrogram op_2 2 1 a window out center=%center hop_length=%hop_length length=None n_fft=%n_fft normalized=%normalized onesided=%onesided pad=0 win_length=%win_length +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "InverseSpectrogram"; + } + + const char* name_str() const + { + return "inverse_spectrogram"; + } + + bool match(const std::map& /*captured_params*/, const std::map& captured_attrs) const + { + const std::vector window_data = captured_attrs.at("op_0.data").get_float32_data(); + const int window_type = detect_window_type(window_data); + return window_type != -1; + } + + void write(Operator* op, const std::map& 
captured_params, const std::map& captured_attrs) const + { + const std::vector window_data = captured_attrs.at("op_0.data").get_float32_data(); + const int window_type = detect_window_type(window_data); + + int normalized = 0; + if (captured_params.at("normalized").type == 1) + { + normalized = captured_params.at("normalized").b ? 2 : 0; + } + if (captured_params.at("normalized").type == 4) + { + if (captured_params.at("normalized").s == "frame_length") + normalized = 1; + if (captured_params.at("normalized").s == "window") + normalized = 2; + } + + op->params["0"] = captured_params.at("n_fft"); + op->params["1"] = 1; // returns + op->params["2"] = captured_params.at("hop_length"); + op->params["3"] = captured_params.at("win_length"); + op->params["4"] = window_type; + op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 1 : 0; + op->params["7"] = normalized; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torchaudio_F_inverse_spectrogram, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/torchaudio_F_spectrogram.cpp b/tools/pnnx/src/pass_ncnn/torchaudio_F_spectrogram.cpp new file mode 100644 index 000000000000..5c42dc191704 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/torchaudio_F_spectrogram.cpp @@ -0,0 +1,233 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +static bool NearlyEqual(float a, float b, float epsilon) +{ + if (a == b) + return true; + + float diff = (float)fabs(a - b); + if (diff <= epsilon) + return true; + + // relative error + return diff < epsilon * std::max(fabs(a), fabs(b)); +} + +static int detect_window_type(const std::vector& window_data) +{ + const int winlen = (int)window_data.size(); + + bool is_one = true; + bool is_hann = true; + bool is_hamming = true; + for (int i = 0; i < winlen; i++) + { + if (!NearlyEqual(window_data[i], 1.f, 0.001)) + is_one = false; + + if (!NearlyEqual(window_data[i], 0.5f * (1 - cos(2 * M_PI * i / winlen)), 0.001)) + is_hann = false; + + if (!NearlyEqual(window_data[i], 0.54f - 0.46f * cos(2 * M_PI * i / winlen), 0.001)) + is_hamming = false; + } + + if (is_one) + return 0; + if (is_hann) + return 1; + if (is_hamming) + return 2; + + return -1; +} + +class torchaudio_F_spectrogram : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +5 4 +pnnx.Input input 0 1 input +pnnx.Attribute op_0 0 1 window @data +torchaudio.functional.spectrogram op_1 2 1 input window a n_fft=%n_fft hop_length=%hop_length win_length=%win_length onesided=%onesided power=%power normalized=%normalized center=%center pad=%pad pad_mode=%pad_mode +torch.view_as_real op_2 1 1 a out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Spectrogram"; + } + + const char* name_str() const + { + return "spectrogram"; + } + + bool match(const std::map& captured_params, const std::map& captured_attrs) const + { + if (captured_params.at("power").type != 0) + return false; + + const std::vector window_data = captured_attrs.at("op_0.data").get_float32_data(); + const int window_type = detect_window_type(window_data); + return window_type != -1; + } 
+ + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + const std::vector window_data = captured_attrs.at("op_0.data").get_float32_data(); + const int window_type = detect_window_type(window_data); + + const std::string& pad_mode = captured_params.at("pad_mode").s; + int pad_type = 2; + if (pad_mode == "constant") + pad_type = 0; + if (pad_mode == "replicate") + pad_type = 1; + if (pad_mode == "reflect") + pad_type = 2; + const int onesided = captured_params.at("onesided").type == 1 && captured_params.at("onesided").b == false ? 0 : 1; + int normalized = 0; + if (captured_params.at("normalized").type == 1) + { + normalized = captured_params.at("normalized").b ? 2 : 0; + } + if (captured_params.at("normalized").type == 4) + { + if (captured_params.at("normalized").s == "frame_length") + normalized = 1; + if (captured_params.at("normalized").s == "window") + normalized = 2; + } + + op->params["0"] = captured_params.at("n_fft"); + op->params["1"] = 0; // power + op->params["2"] = captured_params.at("hop_length"); + op->params["3"] = captured_params.at("win_length"); + op->params["4"] = window_type; + op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 
1 : 0; + op->params["6"] = pad_type; + op->params["7"] = normalized; + op->params["8"] = onesided; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram, 20) + +class torchaudio_F_spectrogram_1 : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +pnnx.Attribute op_0 0 1 window @data +torchaudio.functional.spectrogram op_1 2 1 input window out n_fft=%n_fft hop_length=%hop_length win_length=%win_length onesided=%onesided power=%power normalized=%normalized center=%center pad=%pad pad_mode=%pad_mode +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Spectrogram"; + } + + const char* name_str() const + { + return "spectrogram"; + } + + bool match(const std::map& captured_params, const std::map& captured_attrs) const + { + if (captured_params.at("power").type == 0) + return false; + + const std::vector window_data = captured_attrs.at("op_0.data").get_float32_data(); + const int window_type = detect_window_type(window_data); + return window_type != -1; + } + + void write(Operator* op, const std::map& captured_params, const std::map& captured_attrs) const + { + const std::vector window_data = captured_attrs.at("op_0.data").get_float32_data(); + const int window_type = detect_window_type(window_data); + + const std::string& pad_mode = captured_params.at("pad_mode").s; + int pad_type = 2; + if (pad_mode == "constant") + pad_type = 0; + if (pad_mode == "replicate") + pad_type = 1; + if (pad_mode == "reflect") + pad_type = 2; + const int onesided = captured_params.at("onesided").type == 1 && captured_params.at("onesided").b == false ? 0 : 1; + int normalized = 0; + if (captured_params.at("normalized").type == 1) + { + normalized = captured_params.at("normalized").b ? 
2 : 0; + } + if (captured_params.at("normalized").type == 4) + { + if (captured_params.at("normalized").s == "frame_length") + normalized = 1; + if (captured_params.at("normalized").s == "window") + normalized = 2; + } + + int power = 0; + if (captured_params.at("power").type == 2) + { + power = captured_params.at("power").i; + if (power != 1 && power != 2) + fprintf(stderr, "unsupported spectrogram power %d\n", power); + } + if (captured_params.at("power").type == 3) + { + if (NearlyEqual(captured_params.at("power").f, 1.0, 0.0001)) + power = 1; + else if (NearlyEqual(captured_params.at("power").f, 2.0, 0.0001)) + power = 2; + else + fprintf(stderr, "unsupported spectrogram power %f\n", captured_params.at("power").f); + } + + op->params["0"] = captured_params.at("n_fft"); + op->params["1"] = power; + op->params["2"] = captured_params.at("hop_length"); + op->params["3"] = captured_params.at("win_length"); + op->params["4"] = window_type; + op->params["5"] = captured_params.at("center").type == 1 && captured_params.at("center").b ? 
1 : 0; + op->params["6"] = pad_type; + op->params["7"] = normalized; + op->params["8"] = onesided; + } +}; + +REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torchaudio_F_spectrogram_1, 20) + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index 0dd566c37b58..88266525a46c 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -360,6 +360,11 @@ if(TorchVision_FOUND) pnnx_add_test(torchvision_RoIAlign) endif() +pnnx_add_test(torchaudio_F_inverse_spectrogram) +pnnx_add_test(torchaudio_F_spectrogram) +pnnx_add_test(torchaudio_InverseSpectrogram) +pnnx_add_test(torchaudio_Spectrogram) + add_subdirectory(ncnn) if(onnxruntime_FOUND) diff --git a/tools/pnnx/tests/ncnn/CMakeLists.txt b/tools/pnnx/tests/ncnn/CMakeLists.txt index 49cb063f335e..42c3bed32e05 100644 --- a/tools/pnnx/tests/ncnn/CMakeLists.txt +++ b/tools/pnnx/tests/ncnn/CMakeLists.txt @@ -175,6 +175,9 @@ pnnx_ncnn_add_test(torch_transpose) pnnx_ncnn_add_test(torch_unbind) pnnx_ncnn_add_test(torch_unsqueeze) +pnnx_ncnn_add_test(torch_istft) +pnnx_ncnn_add_test(torch_stft) + pnnx_ncnn_add_test(torch_abs) pnnx_ncnn_add_test(torch_acos) pnnx_ncnn_add_test(torch_asin) @@ -217,3 +220,8 @@ pnnx_ncnn_add_test(ncnn_numpy_binaryop_broadcast) if(TorchVision_FOUND) pnnx_ncnn_add_test(torchvision_DeformConv2d) endif() + +pnnx_ncnn_add_test(torchaudio_F_inverse_spectrogram) +pnnx_ncnn_add_test(torchaudio_F_spectrogram) +pnnx_ncnn_add_test(torchaudio_InverseSpectrogram) +pnnx_ncnn_add_test(torchaudio_Spectrogram) diff --git a/tools/pnnx/tests/ncnn/test_torch_istft.py b/tools/pnnx/tests/ncnn/test_torch_istft.py new file mode 100644 index 000000000000..bda4ab72f6aa --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torch_istft.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. 
+# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = torch.view_as_complex(x) + y = torch.view_as_complex(y) + z = torch.view_as_complex(z) + w = torch.view_as_complex(w) + out0 = torch.istft(x, n_fft=64, window=torch.hann_window(44), win_length=44, center=True, normalized=True, return_complex=False) + out1 = torch.istft(y, n_fft=128, center=False, onesided=True, return_complex=False) + out2 = torch.istft(z, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, center=True, onesided=True, return_complex=False) + out3 = torch.istft(w, n_fft=512, center=False, onesided=False, return_complex=True) + out3 = torch.view_as_real(out3) + return out0, out1, out2, out3 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(33, 161, 2) + y = torch.rand(65, 77, 2) + z = torch.rand(257, 8, 2) + w = torch.rand(512, 4, 2) + + a = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_torch_istft.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torch_istft.pt inputshape=[33,161,2],[65,77,2],[257,8,2],[512,4,2]") + + # ncnn inference + import test_torch_istft_ncnn + b = test_torch_istft_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-3, 1e-3): + 
return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_torch_stft.py b/tools/pnnx/tests/ncnn/test_torch_stft.py new file mode 100644 index 000000000000..ac403ee7e769 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torch_stft.py @@ -0,0 +1,65 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + out0 = torch.stft(x, n_fft=64, window=torch.hann_window(44), win_length=44, center=True, normalized=True, return_complex=True) + out1 = torch.stft(x, n_fft=128, center=False, onesided=True, return_complex=True) + out2 = torch.stft(y, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, center=True, pad_mode='constant', onesided=True, return_complex=True) + out3 = torch.stft(y, n_fft=512, center=True, onesided=False, return_complex=True) + out0 = torch.view_as_real(out0) + out1 = torch.view_as_real(out1) + out2 = torch.view_as_real(out2) + out3 = torch.view_as_real(out3) + return out0, out1, out2, out3 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(2560) + y = torch.rand(1000) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_torch_stft.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torch_stft.pt inputshape=[2560],[1000]") + + # ncnn inference + import test_torch_stft_ncnn + b = test_torch_stft_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-3, 1e-3): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_torchaudio_F_inverse_spectrogram.py b/tools/pnnx/tests/ncnn/test_torchaudio_F_inverse_spectrogram.py new file mode 100644 index 000000000000..5c19cca1d1a7 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torchaudio_F_inverse_spectrogram.py @@ -0,0 +1,72 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. 
+# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +from packaging import version + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + x = torch.view_as_complex(x) + y = torch.view_as_complex(y) + z = torch.view_as_complex(z) + w = torch.view_as_complex(w) + out0 = torchaudio.functional.inverse_spectrogram(x, n_fft=64, window=torch.hann_window(44), win_length=44, hop_length=16, pad=0, center=True, normalized='window', length=None) + out1 = torchaudio.functional.inverse_spectrogram(y, n_fft=128, window=torch.hann_window(128), win_length=128, hop_length=3, pad=0, center=True, onesided=True, normalized=False, length=None) + out2 = torchaudio.functional.inverse_spectrogram(z, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, pad=0, center=True, onesided=True, normalized='frame_length', length=None) + out3 = torchaudio.functional.inverse_spectrogram(w, n_fft=1024, window=torch.hamming_window(512), win_length=512, hop_length=128, pad=0, center=True, onesided=True, normalized=False, length=None) + return out0, out1, out2, out3 + +def test(): + if version.parse(torchaudio.__version__) < version.parse('0.10.0'): + return True + + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(33, 161, 2) + y = torch.rand(65, 77, 2) + z = torch.rand(257, 8, 2) + w = torch.rand(513, 4, 2) + + a = net(x, y, z, 
w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_torchaudio_F_inverse_spectrogram.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torchaudio_F_inverse_spectrogram.pt inputshape=[33,161,2],[65,77,2],[257,8,2],[513,4,2]") + + # ncnn inference + import test_torchaudio_F_inverse_spectrogram_ncnn + b = test_torchaudio_F_inverse_spectrogram_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-3, 1e-3): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_torchaudio_F_spectrogram.py b/tools/pnnx/tests/ncnn/test_torchaudio_F_spectrogram.py new file mode 100644 index 000000000000..379fa4723a67 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torchaudio_F_spectrogram.py @@ -0,0 +1,63 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + out0 = torchaudio.functional.spectrogram(x, n_fft=64, window=torch.hann_window(44), win_length=44, hop_length=16, pad=0, center=True, normalized='window', power=1) + out1 = torchaudio.functional.spectrogram(x, n_fft=128, window=torch.hann_window(128), win_length=128, hop_length=3, pad=0, center=False, onesided=True, normalized=False, power=None) + out2 = torchaudio.functional.spectrogram(y, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, pad=0, center=True, pad_mode='constant', onesided=True, normalized='frame_length', power=2) + out3 = torchaudio.functional.spectrogram(y, n_fft=512, window=torch.hamming_window(512), win_length=512, hop_length=128, pad=32, center=True, onesided=False, normalized=False, power=2) + out1 = torch.view_as_real(out1) + return out0, out1, out2, out3 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(2560) + y = torch.rand(1000) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_torchaudio_F_spectrogram.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torchaudio_F_spectrogram.pt inputshape=[2560],[1000]") + + # ncnn inference + import test_torchaudio_F_spectrogram_ncnn + b = test_torchaudio_F_spectrogram_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-3, 1e-3): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_torchaudio_InverseSpectrogram.py b/tools/pnnx/tests/ncnn/test_torchaudio_InverseSpectrogram.py new file mode 100644 index 000000000000..e48c4d6411cc --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torchaudio_InverseSpectrogram.py @@ -0,0 +1,77 @@ +# Tencent is pleased to 
support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +from packaging import version + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.s0 = torchaudio.transforms.InverseSpectrogram(n_fft=64, window_fn=torch.hann_window, win_length=44, hop_length=16, pad=0, center=True, normalized='window') + self.s1 = torchaudio.transforms.InverseSpectrogram(n_fft=128, window_fn=torch.hann_window, win_length=128, hop_length=3, pad=0, center=True, onesided=True, normalized=False) + self.s2 = torchaudio.transforms.InverseSpectrogram(n_fft=512, window_fn=torch.hamming_window, win_length=256, hop_length=128, pad=0, center=True, onesided=True, normalized='frame_length') + self.s3 = torchaudio.transforms.InverseSpectrogram(n_fft=1024, window_fn=torch.hamming_window, win_length=512, hop_length=128, pad=0, center=True, onesided=True, normalized=False) + + def forward(self, x, y, z, w): + x = torch.view_as_complex(x) + y = torch.view_as_complex(y) + z = torch.view_as_complex(z) + w = torch.view_as_complex(w) + out0 = self.s0(x) + out1 = self.s1(y) + out2 = self.s2(z) + out3 = self.s3(w) + return out0, out1, out2, out3 + +def test(): + if version.parse(torchaudio.__version__) < version.parse('0.10.0'): + return True + + net = Model() + 
net.eval() + + torch.manual_seed(0) + x = torch.rand(33, 161, 2) + y = torch.rand(65, 77, 2) + z = torch.rand(257, 8, 2) + w = torch.rand(513, 4, 2) + + a = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_torchaudio_InverseSpectrogram.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torchaudio_InverseSpectrogram.pt inputshape=[33,161,2],[65,77,2],[257,8,2],[513,4,2]") + + # ncnn inference + import test_torchaudio_InverseSpectrogram_ncnn + b = test_torchaudio_InverseSpectrogram_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-3, 1e-3): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/ncnn/test_torchaudio_Spectrogram.py b/tools/pnnx/tests/ncnn/test_torchaudio_Spectrogram.py new file mode 100644 index 000000000000..d054b84de618 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torchaudio_Spectrogram.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.s0 = torchaudio.transforms.Spectrogram(n_fft=64, window_fn=torch.hann_window, win_length=44, hop_length=16, pad=0, center=True, normalized='window', power=1) + self.s1 = torchaudio.transforms.Spectrogram(n_fft=128, window_fn=torch.hann_window, win_length=128, hop_length=3, pad=0, center=False, onesided=True, normalized=False, power=None) + self.s2 = torchaudio.transforms.Spectrogram(n_fft=512, window_fn=torch.hamming_window, win_length=256, hop_length=128, pad=0, center=True, pad_mode='constant', onesided=True, normalized='frame_length', power=2) + self.s3 = torchaudio.transforms.Spectrogram(n_fft=512, window_fn=torch.hamming_window, win_length=512, hop_length=128, pad=32, center=True, onesided=False, normalized=False, power=2) + + def forward(self, x, y): + out0 = self.s0(x) + out1 = self.s1(x) + out2 = self.s2(y) + out3 = self.s3(y) + out1 = torch.view_as_real(out1) + return out0, out1, out2, out3 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(2560) + y = torch.rand(1000) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_torchaudio_Spectrogram.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torchaudio_Spectrogram.pt inputshape=[2560],[1000]") + + # ncnn inference + import test_torchaudio_Spectrogram_ncnn + b = test_torchaudio_Spectrogram_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-3, 1e-3): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_istft.py b/tools/pnnx/tests/test_torch_istft.py index bde15e8f3353..bdcabd966bae 100644 --- a/tools/pnnx/tests/test_torch_istft.py +++ b/tools/pnnx/tests/test_torch_istft.py @@ -21,9 +21,9 @@ def 
__init__(self): super(Model, self).__init__() def forward(self, x, y, z, w): - out0 = torch.istft(x, n_fft=64, center=True, normalized=True, return_complex=False) + out0 = torch.istft(x, n_fft=64, window=torch.hann_window(44), win_length=44, center=True, normalized=True, return_complex=False) out1 = torch.istft(y, n_fft=128, center=False, onesided=True, return_complex=False) - out2 = torch.istft(z, n_fft=512, center=True, onesided=True, return_complex=False) + out2 = torch.istft(z, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, center=True, onesided=True, return_complex=False) out3 = torch.istft(w, n_fft=512, center=False, onesided=False, return_complex=True) return out0, out1, out2, out3 @@ -52,7 +52,7 @@ def test(): b = test_torch_istft_pnnx.test_inference() for a0, b0 in zip(a, b): - if not torch.equal(a0, b0): + if not torch.allclose(a0, b0, 1e-4, 1e-4): return False return True diff --git a/tools/pnnx/tests/test_torch_stft.py b/tools/pnnx/tests/test_torch_stft.py index 4d347c579004..cc280f3e84d0 100644 --- a/tools/pnnx/tests/test_torch_stft.py +++ b/tools/pnnx/tests/test_torch_stft.py @@ -21,10 +21,10 @@ def __init__(self): super(Model, self).__init__() def forward(self, x, y): - out0 = torch.stft(x, n_fft=64, center=True, pad_mode='reflect', normalized=True, return_complex=True) + out0 = torch.stft(x, n_fft=64, window=torch.hann_window(44), win_length=44, center=True, normalized=True, return_complex=True) out1 = torch.stft(x, n_fft=128, center=False, onesided=True, return_complex=True) - out2 = torch.stft(y, n_fft=512, center=True, pad_mode='constant', onesided=True, return_complex=True) - out3 = torch.stft(y, n_fft=512, center=False, onesided=False, return_complex=True) + out2 = torch.stft(y, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, center=True, pad_mode='constant', onesided=True, return_complex=True) + out3 = torch.stft(y, n_fft=512, center=True, onesided=False, return_complex=True) return 
out0, out1, out2, out3 def test(): @@ -50,7 +50,7 @@ def test(): b = test_torch_stft_pnnx.test_inference() for a0, b0 in zip(a, b): - if not torch.equal(a0, b0): + if not torch.allclose(a0, b0, 1e-4, 1e-4): return False return True diff --git a/tools/pnnx/tests/test_torchaudio_F_inverse_spectrogram.py b/tools/pnnx/tests/test_torchaudio_F_inverse_spectrogram.py new file mode 100644 index 000000000000..92623934a38d --- /dev/null +++ b/tools/pnnx/tests/test_torchaudio_F_inverse_spectrogram.py @@ -0,0 +1,68 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +from packaging import version + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z, w): + out0 = torchaudio.functional.inverse_spectrogram(x, n_fft=64, window=torch.hann_window(44), win_length=44, hop_length=16, pad=0, center=True, normalized='window', length=None) + out1 = torchaudio.functional.inverse_spectrogram(y, n_fft=128, window=torch.hann_window(128), win_length=128, hop_length=3, pad=0, center=True, onesided=True, normalized=False, length=None) + out2 = torchaudio.functional.inverse_spectrogram(z, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, pad=0, center=True, onesided=True, normalized='frame_length', length=None) + out3 = torchaudio.functional.inverse_spectrogram(w, n_fft=512, window=torch.hamming_window(512), win_length=512, hop_length=128, pad=0, center=True, onesided=False, normalized=False, length=None) + return out0, out1, out2, out3 + +def test(): + if version.parse(torchaudio.__version__) < version.parse('0.10.0'): + return True + + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(3, 33, 161, dtype=torch.complex64) + y = torch.rand(1, 65, 77, dtype=torch.complex64) + z = torch.rand(257, 8, dtype=torch.complex64) + w = torch.rand(512, 4, dtype=torch.complex64) + + a = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_torchaudio_F_inverse_spectrogram.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torchaudio_F_inverse_spectrogram.pt inputshape=[3,33,161]c64,[1,65,77]c64,[257,8]c64,[512,4]c64") + + # pnnx inference + import test_torchaudio_F_inverse_spectrogram_pnnx + b = test_torchaudio_F_inverse_spectrogram_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + 
exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torchaudio_F_spectrogram.py b/tools/pnnx/tests/test_torchaudio_F_spectrogram.py new file mode 100644 index 000000000000..ec5a5486c5d9 --- /dev/null +++ b/tools/pnnx/tests/test_torchaudio_F_spectrogram.py @@ -0,0 +1,62 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + out0 = torchaudio.functional.spectrogram(x, n_fft=64, window=torch.hann_window(44), win_length=44, hop_length=16, pad=0, center=True, normalized='window', power=1) + out1 = torchaudio.functional.spectrogram(x, n_fft=128, window=torch.hann_window(128), win_length=128, hop_length=3, pad=0, center=False, onesided=True, normalized=False, power=None) + out2 = torchaudio.functional.spectrogram(y, n_fft=512, window=torch.hamming_window(256), win_length=256, hop_length=128, pad=0, center=True, pad_mode='constant', onesided=True, normalized='frame_length', power=2) + out3 = torchaudio.functional.spectrogram(y, n_fft=512, window=torch.hamming_window(512), win_length=512, hop_length=128, pad=32, center=True, onesided=False, normalized=False, power=2) + return out0, out1, out2, out3 + +def test(): + net = 
Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(3, 2560) + y = torch.rand(1000) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_torchaudio_F_spectrogram.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torchaudio_F_spectrogram.pt inputshape=[3,2560],[1000]") + + # pnnx inference + import test_torchaudio_F_spectrogram_pnnx + b = test_torchaudio_F_spectrogram_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torchaudio_InverseSpectrogram.py b/tools/pnnx/tests/test_torchaudio_InverseSpectrogram.py new file mode 100644 index 000000000000..7080ddd1267b --- /dev/null +++ b/tools/pnnx/tests/test_torchaudio_InverseSpectrogram.py @@ -0,0 +1,73 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio +from packaging import version + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.s0 = torchaudio.transforms.InverseSpectrogram(n_fft=64, window_fn=torch.hann_window, win_length=44, hop_length=16, pad=0, center=True, normalized='window') + self.s1 = torchaudio.transforms.InverseSpectrogram(n_fft=128, window_fn=torch.hann_window, win_length=128, hop_length=3, pad=0, center=True, onesided=True, normalized=False) + self.s2 = torchaudio.transforms.InverseSpectrogram(n_fft=512, window_fn=torch.hamming_window, win_length=256, hop_length=128, pad=0, center=True, onesided=True, normalized='frame_length') + self.s3 = torchaudio.transforms.InverseSpectrogram(n_fft=512, window_fn=torch.hamming_window, win_length=512, hop_length=128, pad=0, center=True, onesided=False, normalized=False) + + def forward(self, x, y, z, w): + out0 = self.s0(x) + out1 = self.s1(y) + out2 = self.s2(z) + out3 = self.s3(w) + return out0, out1, out2, out3 + +def test(): + if version.parse(torchaudio.__version__) < version.parse('0.10.0'): + return True + + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(3, 33, 161, dtype=torch.complex64) + y = torch.rand(1, 65, 77, dtype=torch.complex64) + z = torch.rand(257, 8, dtype=torch.complex64) + w = torch.rand(512, 4, dtype=torch.complex64) + + a = net(x, y, z, w) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z, w)) + mod.save("test_torchaudio_InverseSpectrogram.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torchaudio_InverseSpectrogram.pt inputshape=[3,33,161]c64,[1,65,77]c64,[257,8]c64,[512,4]c64") + + # pnnx inference + import test_torchaudio_InverseSpectrogram_pnnx + b = test_torchaudio_InverseSpectrogram_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if 
test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torchaudio_Spectrogram.py b/tools/pnnx/tests/test_torchaudio_Spectrogram.py new file mode 100644 index 000000000000..e4887050c95b --- /dev/null +++ b/tools/pnnx/tests/test_torchaudio_Spectrogram.py @@ -0,0 +1,67 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.s0 = torchaudio.transforms.Spectrogram(n_fft=64, window_fn=torch.hann_window, win_length=44, hop_length=16, pad=0, center=True, normalized='window', power=1) + self.s1 = torchaudio.transforms.Spectrogram(n_fft=128, window_fn=torch.hann_window, win_length=128, hop_length=3, pad=0, center=False, onesided=True, normalized=False, power=None) + self.s2 = torchaudio.transforms.Spectrogram(n_fft=512, window_fn=torch.hamming_window, win_length=256, hop_length=128, pad=0, center=True, pad_mode='constant', onesided=True, normalized='frame_length', power=2) + self.s3 = torchaudio.transforms.Spectrogram(n_fft=512, window_fn=torch.hamming_window, win_length=512, hop_length=128, pad=32, center=True, onesided=False, normalized=False, power=2) + + def forward(self, x, y): + out0 = self.s0(x) + out1 = self.s1(x) + out2 = 
self.s2(y) + out3 = self.s3(y) + return out0, out1, out2, out3 + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(3, 2560) + y = torch.rand(1000) + + a = net(x, y) + + # export torchscript + mod = torch.jit.trace(net, (x, y)) + mod.save("test_torchaudio_Spectrogram.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torchaudio_Spectrogram.pt inputshape=[3,2560],[1000]") + + # pnnx inference + import test_torchaudio_Spectrogram_pnnx + b = test_torchaudio_Spectrogram_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)