Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor improvements. #44

Merged
merged 1 commit into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ target_link_libraries(zimtohrli_base PRIVATE absl::check)
target_link_libraries(zimtohrli_base PUBLIC hwy portaudio absl::statusor absl::span sndfile)

add_library(zimtohrli_visqol_adapter STATIC
cpp/zimt/visqol_model.h
cpp/zimt/visqol_model.cc
cpp/zimt/visqol.h
cpp/zimt/visqol.cc
cpp/zimt/resample.h
Expand Down
2 changes: 0 additions & 2 deletions cpp/zimt/goohrli.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,6 @@ void SetPerceptualSampleRate(Zimtohrli zimtohrli, float f) {
static_cast<zimtohrli::Zimtohrli*>(zimtohrli)->perceptual_sample_rate = f;
}

typedef void* ViSQOL;

// Allocates a new zimtohrli::ViSQOL and hands it back as an opaque ViSQOL
// handle. FreeViSQOL deletes a handle produced here.
ViSQOL CreateViSQOL() {
  zimtohrli::ViSQOL* const instance = new zimtohrli::ViSQOL();
  return instance;
}

// Deletes the zimtohrli::ViSQOL instance behind the opaque handle `v`.
// The handle is converted back with a named cast (static_cast) rather than a
// C-style cast, matching the static_cast used by the Zimtohrli wrappers in
// this file.
void FreeViSQOL(ViSQOL v) { delete static_cast<zimtohrli::ViSQOL*>(v); }
Expand Down
8 changes: 4 additions & 4 deletions cpp/zimt/visqol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@

#include "absl/log/check.h"
#include "absl/types/span.h"
#include "libsvm_nu_svr_model.h"
#include "visqol_api.h"
#include "zimt/resample.h"
#include "zimt/visqol_model.h"

constexpr size_t SAMPLE_RATE = 48000;

Expand All @@ -39,8 +39,8 @@ ViSQOL::ViSQOL() {
populated_path_template.data(), populated_path_template.size()));
std::ofstream output_stream(model_path_);
CHECK(output_stream.good());
output_stream.write(reinterpret_cast<char*>(visqol_model_bytes),
visqol_model_bytes_len);
absl::Span<const char> model = ViSQOLModel();
output_stream.write(model.data(), model.size());
CHECK(output_stream.good());
output_stream.close();
CHECK(output_stream.good());
Expand Down Expand Up @@ -82,7 +82,7 @@ float ViSQOL::MOS(absl::Span<const float> reference,
visqol.Measure(absl::Span<double>(resampled_reference.data(),
resampled_reference.size()),
absl::Span<double>(resampled_degraded.data(),
resampled_reference.size()));
resampled_degraded.size()));
CHECK_OK(comparison_status_or);

Visqol::SimilarityResultMsg similarity_result = comparison_status_or.value();
Expand Down
27 changes: 27 additions & 0 deletions cpp/zimt/visqol_model.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright 2024 The Zimtohrli Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "visqol_model.h"

#include "absl/types/span.h"
#include "libsvm_nu_svr_model.h"

namespace zimtohrli {

// Wraps the visqol_model_bytes array (visqol_model_bytes_len elements) in an
// absl::Span<char>, reinterpreting the array's storage as chars.
absl::Span<char> ViSQOLModel() {
  char* const model_begin = reinterpret_cast<char*>(visqol_model_bytes);
  return absl::Span<char>(model_begin, visqol_model_bytes_len);
}

} // namespace zimtohrli
27 changes: 27 additions & 0 deletions cpp/zimt/visqol_model.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright 2024 The Zimtohrli Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef CPP_ZIMT_VISQOL_MODEL_H_
#define CPP_ZIMT_VISQOL_MODEL_H_

#include "absl/types/span.h"

namespace zimtohrli {

// Returns the bytes of the default ViSQOL model.
//
// NOTE(review): the returned absl::Span is presumably a non-owning view of
// data embedded in the binary — confirm the backing storage outlives all uses
// and whether callers are expected to treat it as read-only.
absl::Span<char> ViSQOLModel();

} // namespace zimtohrli

#endif // CPP_ZIMT_VISQOL_MODEL_H_
52 changes: 34 additions & 18 deletions go/bin/compare/compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ import (
func main() {
pathA := flag.String("path_a", "", "Path to ffmpeg-decodable file with signal A.")
pathB := flag.String("path_b", "", "Path to ffmpeg-decodable file with signal B.")
visqol := flag.Bool("visqol", false, "Whether to use ViSQOL instead of Zimtohrli metrics.")
outputZimtohrliDistance := flag.Bool("output_zimtohrli_distance", false, "Whether to output the raw Zimtohrli distance instead of a mapped mean opinion score.")
perChannel := flag.Bool("per_channel", false, "Whether to output the produced metric per channel instead of a single value for all channels.")
frequencyResolution := flag.Float64("frequency_resolution", float64(goohrli.DefaultFrequencyResolution()), "Band width of smallest filter, i.e. expected frequency resolution of human hearing.")
flag.Parse()

if *pathA == "" || *pathB == "" {
Expand All @@ -55,25 +55,41 @@ func main() {
log.Panic(fmt.Errorf("%q has %v channels, and %q has %v channels", *pathA, len(signalA.Samples), *pathB, len(signalB.Samples)))
}

getMetric := func(f float64) float64 {
if *outputZimtohrliDistance {
return f
}
return goohrli.MOSFromZimtohrli(f)
}

g := goohrli.New(signalA.Rate, *frequencyResolution)
if *perChannel {
for channelIndex := range signalA.Samples {
measurement := goohrli.Measure(signalA.Samples[channelIndex])
goohrli.NormalizeAmplitude(measurement.MaxAbsAmplitude, signalB.Samples[channelIndex])
fmt.Println(getMetric(g.Distance(signalA.Samples[channelIndex], signalB.Samples[channelIndex])))
if *visqol {
v := goohrli.NewViSQOL()
if *perChannel {
for channelIndex := range signalA.Samples {
mos := v.MOS(signalA.Rate, signalA.Samples[channelIndex], signalB.Samples[channelIndex])
fmt.Println(mos)
}
} else {
mos, err := v.AudioMOS(signalA, signalB)
if err != nil {
log.Panic(err)
}
fmt.Println(mos)
}
} else {
dist, err := g.NormalizedAudioDistance(signalA, signalB)
if err != nil {
log.Panic(err)
getMetric := func(f float64) float64 {
if *outputZimtohrliDistance {
return f
}
return goohrli.MOSFromZimtohrli(f)
}

g := goohrli.New(signalA.Rate, goohrli.DefaultFrequencyResolution())
if *perChannel {
for channelIndex := range signalA.Samples {
measurement := goohrli.Measure(signalA.Samples[channelIndex])
goohrli.NormalizeAmplitude(measurement.MaxAbsAmplitude, signalB.Samples[channelIndex])
fmt.Println(getMetric(g.Distance(signalA.Samples[channelIndex], signalB.Samples[channelIndex])))
}
} else {
dist, err := g.NormalizedAudioDistance(signalA, signalB)
if err != nil {
log.Panic(err)
}
fmt.Println(getMetric(dist))
}
fmt.Println(getMetric(dist))
}
}
4 changes: 2 additions & 2 deletions go/bin/score/score.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ func main() {
calculate := flag.String("calculate", "", "Path to a database directory with a study to calculate metrics for.")
calculateZimtohrli := flag.Bool("calculate_zimtohrli", true, "Whether to calculate Zimtohrli scores.")
calculateViSQOL := flag.Bool("calculate_visqol", false, "Whether to calculate ViSQOL scores.")
zimtohrliFrequencyResolution := flag.Float64("zimtohrli_frequency_resolution", float64(goohrli.DefaultFrequencyResolution()), "Smallest bandwidth of the Zimtohrli filterbank.")
zimtohrliPerceptualSampleRate := flag.Float64("zimtohrli_perceptual_sample_rate", float64(goohrli.DefaultPerceptualSampleRate()), "Sample rate of the Zimtohrli spectrograms.")
zimtohrliFrequencyResolution := flag.Float64("zimtohrli_frequency_resolution", goohrli.DefaultFrequencyResolution(), "Smallest bandwidth of the Zimtohrli filterbank.")
zimtohrliPerceptualSampleRate := flag.Float64("zimtohrli_perceptual_sample_rate", goohrli.DefaultPerceptualSampleRate(), "Sample rate of the Zimtohrli spectrograms.")
correlate := flag.String("correlate", "", "Path to a database directory with a study to correlate scores for.")
hist := flag.String("hist", "", "Path to a database directory with a study to provide JND histograms for.")
histThresholds := flag.String("hist_thresholds", "0,5,10,15,20,25,30,35,40,45,50,55,60,65,70,75,80,85,90,95,100", "A comma separated list of Zimtohrli distance thresholds to compute histograms for.")
Expand Down
Binary file modified go/goohrli/goohrli.a
Binary file not shown.
16 changes: 8 additions & 8 deletions go/goohrli/goohrli.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ import (
)

// DefaultFrequencyResolution returns the default frequency resolution corresponding to the minimum width (at the low frequency end) of the Zimtohrli filter bank.
func DefaultFrequencyResolution() float32 {
return float32(C.DefaultFrequencyResolution())
func DefaultFrequencyResolution() float64 {
return float64(C.DefaultFrequencyResolution())
}

// DefaultPerceptualSampleRate returns the default perceptual sample rate corresponding to the human hearing sensitivity to timing changes.
func DefaultPerceptualSampleRate() float32 {
return float32(C.DefaultPerceptualSampleRate())
func DefaultPerceptualSampleRate() float64 {
return float64(C.DefaultPerceptualSampleRate())
}

// EnergyAndMaxAbsAmplitude holds the energy and maximum absolute amplitude of a measurement.
Expand Down Expand Up @@ -200,12 +200,12 @@ func NewViSQOL() *ViSQOL {
}

// MOS returns the ViSQOL mean opinion score of the degraded samples compared to the reference samples.
func (g *ViSQOL) MOS(sampleRate float64, reference []float32, degraded []float32) float64 {
return float64(C.MOS(g.visqol, C.float(sampleRate), (*C.float)(&reference[0]), C.int(len(reference)), (*C.float)(&degraded[0]), C.int(len(degraded))))
func (v *ViSQOL) MOS(sampleRate float64, reference []float32, degraded []float32) float64 {
return float64(C.MOS(v.visqol, C.float(sampleRate), (*C.float)(&reference[0]), C.int(len(reference)), (*C.float)(&degraded[0]), C.int(len(degraded))))
}

// AudioMOS returns the ViSQOL mean opinion score of the degraded audio compared to the reference audio.
func (g *ViSQOL) AudioMOS(reference, degraded *audio.Audio) (float64, error) {
func (v *ViSQOL) AudioMOS(reference, degraded *audio.Audio) (float64, error) {
sumOfSquares := 0.0
if reference.Rate != degraded.Rate {
return 0, fmt.Errorf("the audio files don't have the same sample rate: %v, %v", reference.Rate, degraded.Rate)
Expand All @@ -214,7 +214,7 @@ func (g *ViSQOL) AudioMOS(reference, degraded *audio.Audio) (float64, error) {
return 0, fmt.Errorf("the audio files don't have the same number of channels: %v, %v", len(reference.Samples), len(degraded.Samples))
}
for channelIndex := range reference.Samples {
mos := g.MOS(reference.Rate, reference.Samples[channelIndex], degraded.Samples[channelIndex])
mos := v.MOS(reference.Rate, reference.Samples[channelIndex], degraded.Samples[channelIndex])
sumOfSquares += mos * mos
}
return math.Sqrt(sumOfSquares / float64(len(reference.Samples))), nil
Expand Down
Loading