diff --git a/CMakeLists.txt b/CMakeLists.txt index 5ed85ef..378f479 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,6 +43,8 @@ target_link_libraries(zimtohrli_base PRIVATE absl::check) target_link_libraries(zimtohrli_base PUBLIC hwy portaudio absl::statusor absl::span sndfile) add_library(zimtohrli_visqol_adapter STATIC + cpp/zimt/visqol_model.h + cpp/zimt/visqol_model.cc cpp/zimt/visqol.h cpp/zimt/visqol.cc cpp/zimt/resample.h diff --git a/cpp/zimt/goohrli.cc b/cpp/zimt/goohrli.cc index 979c855..df4cd3e 100644 --- a/cpp/zimt/goohrli.cc +++ b/cpp/zimt/goohrli.cc @@ -121,8 +121,6 @@ void SetPerceptualSampleRate(Zimtohrli zimtohrli, float f) { static_cast(zimtohrli)->perceptual_sample_rate = f; } -typedef void* ViSQOL; - ViSQOL CreateViSQOL() { return new zimtohrli::ViSQOL(); } void FreeViSQOL(ViSQOL v) { delete (zimtohrli::ViSQOL*)(v); } diff --git a/cpp/zimt/visqol.cc b/cpp/zimt/visqol.cc index 5b3e7ed..53f640e 100644 --- a/cpp/zimt/visqol.cc +++ b/cpp/zimt/visqol.cc @@ -18,9 +18,9 @@ #include "absl/log/check.h" #include "absl/types/span.h" -#include "libsvm_nu_svr_model.h" #include "visqol_api.h" #include "zimt/resample.h" +#include "zimt/visqol_model.h" constexpr size_t SAMPLE_RATE = 48000; @@ -39,8 +39,8 @@ ViSQOL::ViSQOL() { populated_path_template.data(), populated_path_template.size())); std::ofstream output_stream(model_path_); CHECK(output_stream.good()); - output_stream.write(reinterpret_cast(visqol_model_bytes), - visqol_model_bytes_len); + absl::Span model = ViSQOLModel(); + output_stream.write(model.data(), model.size()); CHECK(output_stream.good()); output_stream.close(); CHECK(output_stream.good()); @@ -82,7 +82,7 @@ float ViSQOL::MOS(absl::Span reference, visqol.Measure(absl::Span(resampled_reference.data(), resampled_reference.size()), absl::Span(resampled_degraded.data(), - resampled_reference.size())); + resampled_degraded.size())); CHECK_OK(comparison_status_or); Visqol::SimilarityResultMsg similarity_result = comparison_status_or.value(); diff --git a/cpp/zimt/visqol_model.cc b/cpp/zimt/visqol_model.cc new file mode 100644 index 0000000..ea5afc6 --- /dev/null +++ b/cpp/zimt/visqol_model.cc @@ -0,0 +1,27 @@ +// Copyright 2024 The Zimtohrli Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "visqol_model.h" + +#include "absl/types/span.h" +#include "libsvm_nu_svr_model.h" + +namespace zimtohrli { + +absl::Span ViSQOLModel() { + return absl::Span(reinterpret_cast(visqol_model_bytes), + visqol_model_bytes_len); +} + +} // namespace zimtohrli \ No newline at end of file diff --git a/cpp/zimt/visqol_model.h b/cpp/zimt/visqol_model.h new file mode 100644 index 0000000..f4479aa --- /dev/null +++ b/cpp/zimt/visqol_model.h @@ -0,0 +1,27 @@ +// Copyright 2024 The Zimtohrli Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef CPP_ZIMT_VISQOL_MODEL_H_ +#define CPP_ZIMT_VISQOL_MODEL_H_ + +#include "absl/types/span.h" + +namespace zimtohrli { + +// Returns the bytes of the default ViSQOL model. +absl::Span ViSQOLModel(); + +} // namespace zimtohrli + +#endif // CPP_ZIMT_VISQOL_MODEL_H_ diff --git a/go/bin/compare/compare.go b/go/bin/compare/compare.go index 5aa138e..638c2d5 100644 --- a/go/bin/compare/compare.go +++ b/go/bin/compare/compare.go @@ -28,9 +28,9 @@ import ( func main() { pathA := flag.String("path_a", "", "Path to ffmpeg-decodable file with signal A.") pathB := flag.String("path_b", "", "Path to ffmpeg-decodable file with signal B.") + visqol := flag.Bool("visqol", false, "Whether to use ViSQOL instead of Zimtohrli metrics.") outputZimtohrliDistance := flag.Bool("output_zimtohrli_distance", false, "Whether to output the raw Zimtohrli distance instead of a mapped mean opinion score.") perChannel := flag.Bool("per_channel", false, "Whether to output the produced metric per channel instead of a single value for all channels.") - frequencyResolution := flag.Float64("frequency_resolution", float64(goohrli.DefaultFrequencyResolution()), "Band width of smallest filter, i.e. expected frequency resolution of human hearing.") flag.Parse() if *pathA == "" || *pathB == "" { @@ -55,25 +55,41 @@ func main() { log.Panic(fmt.Errorf("%q has %v channels, and %q has %v channels", *pathA, len(signalA.Samples), *pathB, len(signalB.Samples))) } - getMetric := func(f float64) float64 { - if *outputZimtohrliDistance { - return f - } - return goohrli.MOSFromZimtohrli(f) - } - - g := goohrli.New(signalA.Rate, *frequencyResolution) - if *perChannel { - for channelIndex := range signalA.Samples { - measurement := goohrli.Measure(signalA.Samples[channelIndex]) - goohrli.NormalizeAmplitude(measurement.MaxAbsAmplitude, signalB.Samples[channelIndex]) - fmt.Println(getMetric(g.Distance(signalA.Samples[channelIndex], signalB.Samples[channelIndex]))) + if *visqol { + v := goohrli.NewViSQOL() + if *perChannel { + for channelIndex := range signalA.Samples { + mos := v.MOS(signalA.Rate, signalA.Samples[channelIndex], signalB.Samples[channelIndex]) + fmt.Println(mos) + } + } else { + mos, err := v.AudioMOS(signalA, signalB) + if err != nil { + log.Panic(err) + } + fmt.Println(mos) } } else { - dist, err := g.NormalizedAudioDistance(signalA, signalB) - if err != nil { - log.Panic(err) + getMetric := func(f float64) float64 { + if *outputZimtohrliDistance { + return f + } + return goohrli.MOSFromZimtohrli(f) + } + + g := goohrli.New(signalA.Rate, goohrli.DefaultFrequencyResolution()) + if *perChannel { + for channelIndex := range signalA.Samples { + measurement := goohrli.Measure(signalA.Samples[channelIndex]) + goohrli.NormalizeAmplitude(measurement.MaxAbsAmplitude, signalB.Samples[channelIndex]) + fmt.Println(getMetric(g.Distance(signalA.Samples[channelIndex], signalB.Samples[channelIndex]))) + } + } else { + dist, err := g.NormalizedAudioDistance(signalA, signalB) + if err != nil { + log.Panic(err) + } + fmt.Println(getMetric(dist)) } - fmt.Println(getMetric(dist)) } } diff --git a/go/bin/score/score.go b/go/bin/score/score.go index 1213025..3de267b 100644 --- a/go/bin/score/score.go +++ b/go/bin/score/score.go @@ -40,8 +40,8 @@ func main() { calculate := flag.String("calculate", "", "Path to a database directory with a study to calculate metrics for.") calculateZimtohrli := flag.Bool("calculate_zimtohrli", true, "Whether to calculate Zimtohrli scores.") calculateViSQOL := flag.Bool("calculate_visqol", false, "Whether to calculate ViSQOL scores.") - zimtohrliFrequencyResolution := flag.Float64("zimtohrli_frequency_resolution", float64(goohrli.DefaultFrequencyResolution()), "Smallest bandwidth of the Zimtohrli filterbank.") - zimtohrliPerceptualSampleRate := flag.Float64("zimtohrli_perceptual_sample_rate", float64(goohrli.DefaultPerceptualSampleRate()), "Sample rate of the Zimtohrli spectrograms.") + zimtohrliFrequencyResolution := flag.Float64("zimtohrli_frequency_resolution", goohrli.DefaultFrequencyResolution(), "Smallest bandwidth of the Zimtohrli filterbank.") + zimtohrliPerceptualSampleRate := flag.Float64("zimtohrli_perceptual_sample_rate", goohrli.DefaultPerceptualSampleRate(), "Sample rate of the Zimtohrli spectrograms.") correlate := flag.String("correlate", "", "Path to a database directory with a study to correlate scores for.") hist := flag.String("hist", "", "Path to a database directory with a study to provide JND histograms for.") histThresholds := flag.String("hist_thresholds", "0,5,10,15,20,25,30,35,40,45,50,55,60,65,70,75,80,85,90,95,100", "A comma separated list of Zimtohrli distance thresholds to compute histograms for.") diff --git a/go/goohrli/goohrli.a b/go/goohrli/goohrli.a index fdddff9..733e8fc 100644 Binary files a/go/goohrli/goohrli.a and b/go/goohrli/goohrli.a differ diff --git a/go/goohrli/goohrli.go b/go/goohrli/goohrli.go index 5dc7a89..113a3c9 100644 --- a/go/goohrli/goohrli.go +++ b/go/goohrli/goohrli.go @@ -31,13 +31,13 @@ import ( ) // DefaultFrequencyResolution returns the default frequency resolution corresponding to the minimum width (at the low frequency end) of the Zimtohrli filter bank. -func DefaultFrequencyResolution() float32 { - return float32(C.DefaultFrequencyResolution()) +func DefaultFrequencyResolution() float64 { + return float64(C.DefaultFrequencyResolution()) } // DefaultPerceptualSampleRate returns the default perceptual sample rate corresponding to the human hearing sensitivity to timing changes. -func DefaultPerceptualSampleRate() float32 { - return float32(C.DefaultPerceptualSampleRate()) +func DefaultPerceptualSampleRate() float64 { + return float64(C.DefaultPerceptualSampleRate()) } // EnergyAndMaxAbsAmplitude is holds the energy and maximum absolute amplitude of a measurement. @@ -200,12 +200,12 @@ func NewViSQOL() *ViSQOL { } // MOS returns the ViSQOL mean opinion score of the degraded samples comapred to the reference samples. -func (g *ViSQOL) MOS(sampleRate float64, reference []float32, degraded []float32) float64 { - return float64(C.MOS(g.visqol, C.float(sampleRate), (*C.float)(&reference[0]), C.int(len(reference)), (*C.float)(°raded[0]), C.int(len(degraded)))) +func (v *ViSQOL) MOS(sampleRate float64, reference []float32, degraded []float32) float64 { + return float64(C.MOS(v.visqol, C.float(sampleRate), (*C.float)(&reference[0]), C.int(len(reference)), (*C.float)(°raded[0]), C.int(len(degraded)))) } // AudioMOS returns the ViSQOL mean opinion score of the degraded audio compared to the reference audio. -func (g *ViSQOL) AudioMOS(reference, degraded *audio.Audio) (float64, error) { +func (v *ViSQOL) AudioMOS(reference, degraded *audio.Audio) (float64, error) { sumOfSquares := 0.0 if reference.Rate != degraded.Rate { return 0, fmt.Errorf("the audio files don't have the same sample rate: %v, %v", reference.Rate, degraded.Rate) @@ -214,7 +214,7 @@ func (g *ViSQOL) AudioMOS(reference, degraded *audio.Audio) (float64, error) { return 0, fmt.Errorf("the audio files don't have the same number of channels: %v, %v", len(reference.Samples), len(degraded.Samples)) } for channelIndex := range reference.Samples { - mos := g.MOS(reference.Rate, reference.Samples[channelIndex], degraded.Samples[channelIndex]) + mos := v.MOS(reference.Rate, reference.Samples[channelIndex], degraded.Samples[channelIndex]) sumOfSquares += mos * mos } return math.Sqrt(sumOfSquares / float64(len(reference.Samples))), nil