Skip to content

Commit

Permalink
Minor improvements.
Browse files Browse the repository at this point in the history
- Extracted the xxd-provided ViSQOL model bytes to a separate versioned
  C++ file to simplify providing it by other means.
- Removed the redundant ViSQOL typedef in goohrli.cc.
- Added ViSQOL score support to compare.go.
- Changed more float32's to float64's in Go land for uniformity.
- Fixed the bug in visqol.cc where the size of the resampled reference
  audio was used as the size of the span over the degraded audio.
  • Loading branch information
zond committed Apr 26, 2024
1 parent 642865b commit 801d813
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 34 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ target_link_libraries(zimtohrli_base PRIVATE absl::check)
target_link_libraries(zimtohrli_base PUBLIC hwy portaudio absl::statusor absl::span sndfile)

add_library(zimtohrli_visqol_adapter STATIC
cpp/zimt/visqol_model.h
cpp/zimt/visqol_model.cc
cpp/zimt/visqol.h
cpp/zimt/visqol.cc
cpp/zimt/resample.h
Expand Down
2 changes: 0 additions & 2 deletions cpp/zimt/goohrli.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,6 @@ void SetPerceptualSampleRate(Zimtohrli zimtohrli, float f) {
static_cast<zimtohrli::Zimtohrli*>(zimtohrli)->perceptual_sample_rate = f;
}

typedef void* ViSQOL;

// Heap-allocates a default-constructed zimtohrli::ViSQOL and hands it back
// as an opaque ViSQOL handle. The caller owns the handle and releases it
// with FreeViSQOL.
ViSQOL CreateViSQOL() {
  zimtohrli::ViSQOL* const instance = new zimtohrli::ViSQOL();
  return instance;
}

// Destroys a ViSQOL instance previously returned by CreateViSQOL.
//
// The opaque handle is converted back to its concrete type with a named
// cast (matching the static_cast used for Zimtohrli handles in this file)
// so that the correct destructor runs.
void FreeViSQOL(ViSQOL v) { delete static_cast<zimtohrli::ViSQOL*>(v); }
Expand Down
8 changes: 4 additions & 4 deletions cpp/zimt/visqol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@

#include "absl/log/check.h"
#include "absl/types/span.h"
#include "libsvm_nu_svr_model.h"
#include "visqol_api.h"
#include "zimt/resample.h"
#include "zimt/visqol_model.h"

constexpr size_t SAMPLE_RATE = 48000;

Expand All @@ -39,8 +39,8 @@ ViSQOL::ViSQOL() {
populated_path_template.data(), populated_path_template.size()));
std::ofstream output_stream(model_path_);
CHECK(output_stream.good());
output_stream.write(reinterpret_cast<char*>(visqol_model_bytes),
visqol_model_bytes_len);
absl::Span<const char> model = ViSQOLModel();
output_stream.write(model.data(), model.size());
CHECK(output_stream.good());
output_stream.close();
CHECK(output_stream.good());
Expand Down Expand Up @@ -82,7 +82,7 @@ float ViSQOL::MOS(absl::Span<const float> reference,
visqol.Measure(absl::Span<double>(resampled_reference.data(),
resampled_reference.size()),
absl::Span<double>(resampled_degraded.data(),
resampled_reference.size()));
resampled_degraded.size()));
CHECK_OK(comparison_status_or);

Visqol::SimilarityResultMsg similarity_result = comparison_status_or.value();
Expand Down
27 changes: 27 additions & 0 deletions cpp/zimt/visqol_model.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright 2024 The Zimtohrli Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "visqol_model.h"

#include "absl/types/span.h"
#include "libsvm_nu_svr_model.h"

namespace zimtohrli {

// Wraps the embedded model array `visqol_model_bytes` (with length
// `visqol_model_bytes_len`) in a span of chars. The span refers to the
// array directly; nothing is copied.
absl::Span<char> ViSQOLModel() {
  // The embedded array is reinterpreted as chars so that callers can pass
  // the data straight to char-based stream APIs.
  char* const model_begin = reinterpret_cast<char*>(visqol_model_bytes);
  const auto model_size = visqol_model_bytes_len;
  return absl::Span<char>(model_begin, model_size);
}

} // namespace zimtohrli
27 changes: 27 additions & 0 deletions cpp/zimt/visqol_model.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright 2024 The Zimtohrli Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef CPP_ZIMT_VISQOL_MODEL_H_
#define CPP_ZIMT_VISQOL_MODEL_H_

#include "absl/types/span.h"

namespace zimtohrli {

// Returns the bytes of the default ViSQOL model.
//
// The result is an absl::Span, i.e. a non-owning view of the model data;
// calling this function does not copy the bytes.
absl::Span<char> ViSQOLModel();

} // namespace zimtohrli

#endif // CPP_ZIMT_VISQOL_MODEL_H_
52 changes: 34 additions & 18 deletions go/bin/compare/compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ import (
func main() {
pathA := flag.String("path_a", "", "Path to ffmpeg-decodable file with signal A.")
pathB := flag.String("path_b", "", "Path to ffmpeg-decodable file with signal B.")
visqol := flag.Bool("visqol", false, "Whether to use ViSQOL instead of Zimtohrli metrics.")
outputZimtohrliDistance := flag.Bool("output_zimtohrli_distance", false, "Whether to output the raw Zimtohrli distance instead of a mapped mean opinion score.")
perChannel := flag.Bool("per_channel", false, "Whether to output the produced metric per channel instead of a single value for all channels.")
frequencyResolution := flag.Float64("frequency_resolution", float64(goohrli.DefaultFrequencyResolution()), "Band width of smallest filter, i.e. expected frequency resolution of human hearing.")
flag.Parse()

if *pathA == "" || *pathB == "" {
Expand All @@ -55,25 +55,41 @@ func main() {
log.Panic(fmt.Errorf("%q has %v channels, and %q has %v channels", *pathA, len(signalA.Samples), *pathB, len(signalB.Samples)))
}

getMetric := func(f float64) float64 {
if *outputZimtohrliDistance {
return f
}
return goohrli.MOSFromZimtohrli(f)
}

g := goohrli.New(signalA.Rate, *frequencyResolution)
if *perChannel {
for channelIndex := range signalA.Samples {
measurement := goohrli.Measure(signalA.Samples[channelIndex])
goohrli.NormalizeAmplitude(measurement.MaxAbsAmplitude, signalB.Samples[channelIndex])
fmt.Println(getMetric(g.Distance(signalA.Samples[channelIndex], signalB.Samples[channelIndex])))
if *visqol {
v := goohrli.NewViSQOL()
if *perChannel {
for channelIndex := range signalA.Samples {
mos := v.MOS(signalA.Rate, signalA.Samples[channelIndex], signalB.Samples[channelIndex])
fmt.Println(mos)
}
} else {
mos, err := v.AudioMOS(signalA, signalB)
if err != nil {
log.Panic(err)
}
fmt.Println(mos)
}
} else {
dist, err := g.NormalizedAudioDistance(signalA, signalB)
if err != nil {
log.Panic(err)
getMetric := func(f float64) float64 {
if *outputZimtohrliDistance {
return f
}
return goohrli.MOSFromZimtohrli(f)
}

g := goohrli.New(signalA.Rate, goohrli.DefaultFrequencyResolution())
if *perChannel {
for channelIndex := range signalA.Samples {
measurement := goohrli.Measure(signalA.Samples[channelIndex])
goohrli.NormalizeAmplitude(measurement.MaxAbsAmplitude, signalB.Samples[channelIndex])
fmt.Println(getMetric(g.Distance(signalA.Samples[channelIndex], signalB.Samples[channelIndex])))
}
} else {
dist, err := g.NormalizedAudioDistance(signalA, signalB)
if err != nil {
log.Panic(err)
}
fmt.Println(getMetric(dist))
}
fmt.Println(getMetric(dist))
}
}
4 changes: 2 additions & 2 deletions go/bin/score/score.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ func main() {
calculate := flag.String("calculate", "", "Path to a database directory with a study to calculate metrics for.")
calculateZimtohrli := flag.Bool("calculate_zimtohrli", true, "Whether to calculate Zimtohrli scores.")
calculateViSQOL := flag.Bool("calculate_visqol", false, "Whether to calculate ViSQOL scores.")
zimtohrliFrequencyResolution := flag.Float64("zimtohrli_frequency_resolution", float64(goohrli.DefaultFrequencyResolution()), "Smallest bandwidth of the Zimtohrli filterbank.")
zimtohrliPerceptualSampleRate := flag.Float64("zimtohrli_perceptual_sample_rate", float64(goohrli.DefaultPerceptualSampleRate()), "Sample rate of the Zimtohrli spectrograms.")
zimtohrliFrequencyResolution := flag.Float64("zimtohrli_frequency_resolution", goohrli.DefaultFrequencyResolution(), "Smallest bandwidth of the Zimtohrli filterbank.")
zimtohrliPerceptualSampleRate := flag.Float64("zimtohrli_perceptual_sample_rate", goohrli.DefaultPerceptualSampleRate(), "Sample rate of the Zimtohrli spectrograms.")
correlate := flag.String("correlate", "", "Path to a database directory with a study to correlate scores for.")
hist := flag.String("hist", "", "Path to a database directory with a study to provide JND histograms for.")
histThresholds := flag.String("hist_thresholds", "0,5,10,15,20,25,30,35,40,45,50,55,60,65,70,75,80,85,90,95,100", "A comma separated list of Zimtohrli distance thresholds to compute histograms for.")
Expand Down
Binary file modified go/goohrli/goohrli.a
Binary file not shown.
16 changes: 8 additions & 8 deletions go/goohrli/goohrli.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ import (
)

// DefaultFrequencyResolution returns the default frequency resolution corresponding to the minimum width (at the low frequency end) of the Zimtohrli filter bank.
func DefaultFrequencyResolution() float32 {
return float32(C.DefaultFrequencyResolution())
func DefaultFrequencyResolution() float64 {
return float64(C.DefaultFrequencyResolution())
}

// DefaultPerceptualSampleRate returns the default perceptual sample rate corresponding to the human hearing sensitivity to timing changes.
func DefaultPerceptualSampleRate() float32 {
return float32(C.DefaultPerceptualSampleRate())
func DefaultPerceptualSampleRate() float64 {
return float64(C.DefaultPerceptualSampleRate())
}

// EnergyAndMaxAbsAmplitude holds the energy and maximum absolute amplitude of a measurement.
Expand Down Expand Up @@ -200,12 +200,12 @@ func NewViSQOL() *ViSQOL {
}

// MOS returns the ViSQOL mean opinion score of the degraded samples compared to the reference samples.
func (g *ViSQOL) MOS(sampleRate float64, reference []float32, degraded []float32) float64 {
return float64(C.MOS(g.visqol, C.float(sampleRate), (*C.float)(&reference[0]), C.int(len(reference)), (*C.float)(&degraded[0]), C.int(len(degraded))))
func (v *ViSQOL) MOS(sampleRate float64, reference []float32, degraded []float32) float64 {
return float64(C.MOS(v.visqol, C.float(sampleRate), (*C.float)(&reference[0]), C.int(len(reference)), (*C.float)(&degraded[0]), C.int(len(degraded))))
}

// AudioMOS returns the ViSQOL mean opinion score of the degraded audio compared to the reference audio.
func (g *ViSQOL) AudioMOS(reference, degraded *audio.Audio) (float64, error) {
func (v *ViSQOL) AudioMOS(reference, degraded *audio.Audio) (float64, error) {
sumOfSquares := 0.0
if reference.Rate != degraded.Rate {
return 0, fmt.Errorf("the audio files don't have the same sample rate: %v, %v", reference.Rate, degraded.Rate)
Expand All @@ -214,7 +214,7 @@ func (g *ViSQOL) AudioMOS(reference, degraded *audio.Audio) (float64, error) {
return 0, fmt.Errorf("the audio files don't have the same number of channels: %v, %v", len(reference.Samples), len(degraded.Samples))
}
for channelIndex := range reference.Samples {
mos := g.MOS(reference.Rate, reference.Samples[channelIndex], degraded.Samples[channelIndex])
mos := v.MOS(reference.Rate, reference.Samples[channelIndex], degraded.Samples[channelIndex])
sumOfSquares += mos * mos
}
return math.Sqrt(sumOfSquares / float64(len(reference.Samples))), nil
Expand Down

0 comments on commit 801d813

Please sign in to comment.