Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made the MOS mapping parameters mutable. #110

Merged
merged 1 commit into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions cpp/zimt/compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,11 @@ std::ostream& operator<<(std::ostream& outs, const DistanceData& data) {
return outs;
}

float GetMetric(float zimtohrli_score) {
float GetMetric(const zimtohrli::Zimtohrli& z, float zimtohrli_score) {
if (absl::GetFlag(FLAGS_output_zimtohrli_distance)) {
return zimtohrli_score;
}
return MOSFromZimtohrli(zimtohrli_score);
return z.mos_mapper.Map(zimtohrli_score);
}

int Main(int argc, char* argv[]) {
Expand Down Expand Up @@ -340,16 +340,16 @@ int Main(int argc, char* argv[]) {
z.Distance(false, file_a_spectrograms[channel_index], spectrogram_b)
.value;
if (per_channel) {
std::cout << GetMetric(distance) << std::endl;
std::cout << GetMetric(z, distance) << std::endl;
} else {
sum_of_squares += distance * distance;
}
}
if (!per_channel) {
for (int file_b_index = 0; file_b_index < file_b_vector.size();
++file_b_index) {
std::cout << GetMetric(std::sqrt(sum_of_squares /
float(file_a->Info().channels)))
std::cout << GetMetric(z, std::sqrt(sum_of_squares /
float(file_a->Info().channels)))
<< std::endl;
}
}
Expand Down Expand Up @@ -413,13 +413,13 @@ int Main(int argc, char* argv[]) {
const float distance = phons_channel_distance.distance.value;
sum_of_squares += distance * distance;

std::cout << " Channel MOS: " << MOSFromZimtohrli(distance)
std::cout << " Channel MOS: " << z.mos_mapper.Map(distance)
<< std::endl;
}
const float zimtohrli_file_distance =
std::sqrt(sum_of_squares / float(comparison.analysis_a.size()));
std::cout << " File distance: " << zimtohrli_file_distance << std::endl;
std::cout << " File MOS: " << MOSFromZimtohrli(zimtohrli_file_distance)
std::cout << " File MOS: " << z.mos_mapper.Map(zimtohrli_file_distance)
<< std::endl;
}
return 0;
Expand Down
24 changes: 17 additions & 7 deletions cpp/zimt/goohrli.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ int NumLoudnessTFParams() {
return NUM_LOUDNESS_T_F_PARAMS;
}

// Returns NUM_MOS_MAPPER_PARAMS, the number of MOS mapper parameters.
int NumMOSMapperParams() {
// Sanity check: the constant is expected to equal the size of the params
// array of a default-constructed zimtohrli::MOSMapper.
CHECK_EQ(NUM_MOS_MAPPER_PARAMS, zimtohrli::MOSMapper{}.params.size());
return NUM_MOS_MAPPER_PARAMS;
}

EnergyAndMaxAbsAmplitude Measure(const float* signal, int size) {
hwy::AlignedNDArray<float, 1> signal_array({static_cast<size_t>(size)});
hwy::CopyBytes(signal, signal_array.data(), size * sizeof(float));
Expand All @@ -67,11 +72,12 @@ EnergyAndMaxAbsAmplitude NormalizeAmplitude(float max_abs_amplitude,
.MaxAbsAmplitude = measurements.max_abs_amplitude};
}

float MOSFromZimtohrli(float zimtohrli_distance) {
return zimtohrli::MOSFromZimtohrli(zimtohrli_distance);
float MOSFromZimtohrli(const Zimtohrli zimtohrli, float zimtohrli_distance) {
const zimtohrli::Zimtohrli* z = static_cast<zimtohrli::Zimtohrli*>(zimtohrli);
return z->mos_mapper.Map(zimtohrli_distance);
}

Zimtohrli CreateZimtohrli(ZimtohrliParameters params) {
Zimtohrli CreateZimtohrli(const ZimtohrliParameters params) {
zimtohrli::Cam cam{.minimum_bandwidth_hz = params.FrequencyResolution,
.filter_order = params.FilterOrder,
.filter_pass_band_ripple = params.FilterPassBandRipple,
Expand All @@ -88,9 +94,9 @@ void FreeZimtohrli(Zimtohrli zimtohrli) {
delete static_cast<zimtohrli::Zimtohrli*>(zimtohrli);
}

float Distance(Zimtohrli zimtohrli, float* data_a, int size_a, float* data_b,
int size_b) {
zimtohrli::Zimtohrli* z = static_cast<zimtohrli::Zimtohrli*>(zimtohrli);
float Distance(const Zimtohrli zimtohrli, float* data_a, int size_a,
float* data_b, int size_b) {
const zimtohrli::Zimtohrli* z = static_cast<zimtohrli::Zimtohrli*>(zimtohrli);
hwy::AlignedNDArray<float, 1> signal_a({static_cast<size_t>(size_a)});
hwy::CopyBytes(data_a, signal_a.data(), size_a * sizeof(float));
hwy::AlignedNDArray<float, 1> signal_b({static_cast<size_t>(size_b)});
Expand All @@ -105,7 +111,7 @@ float Distance(Zimtohrli zimtohrli, float* data_a, int size_a, float* data_b,
}

ZimtohrliParameters GetZimtohrliParameters(const Zimtohrli zimtohrli) {
zimtohrli::Zimtohrli* z = static_cast<zimtohrli::Zimtohrli*>(zimtohrli);
const zimtohrli::Zimtohrli* z = static_cast<zimtohrli::Zimtohrli*>(zimtohrli);
ZimtohrliParameters result;
result.SampleRate = z->cam_filterbank->sample_rate;
const hwy::AlignedNDArray<float, 2>& thresholds =
Expand Down Expand Up @@ -133,6 +139,8 @@ ZimtohrliParameters GetZimtohrliParameters(const Zimtohrli zimtohrli) {
sizeof(result.LoudnessLUParams));
std::memcpy(result.LoudnessTFParams, z->loudness.t_f_params.data(),
sizeof(result.LoudnessTFParams));
std::memcpy(result.MOSMapperParams, z->mos_mapper.params.data(),
sizeof(result.MOSMapperParams));
return result;
}

Expand All @@ -157,6 +165,8 @@ void SetZimtohrliParameters(Zimtohrli zimtohrli,
sizeof(parameters.LoudnessLUParams));
std::memcpy(z->loudness.t_f_params.data(), parameters.LoudnessTFParams,
sizeof(parameters.LoudnessTFParams));
std::memcpy(z->mos_mapper.params.data(), parameters.MOSMapperParams,
sizeof(parameters.MOSMapperParams));
}

ZimtohrliParameters DefaultZimtohrliParameters(float sample_rate) {
Expand Down
11 changes: 3 additions & 8 deletions cpp/zimt/mos.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,14 @@ namespace zimtohrli {

namespace {

const std::array<float, 3> params = {1.000e+00, -7.449e-09, 3.344e+00};

float sigmoid(float x) {
// Evaluates params[0] / (params[1] + e^(params[2] * x)).
float sigmoid(const std::array<float, 3>& params, float x) {
  const float exponential = std::exp(params[2] * x);
  const float denominator = params[1] + exponential;
  return params[0] / denominator;
}

const float zero_crossing_reciprocal = 1.0 / sigmoid(0);

} // namespace

// Optimized using `mos_mapping.ipynb`.
float MOSFromZimtohrli(float zimtohrli_distance) {
return 1.0 + 4.0 * sigmoid(zimtohrli_distance) * zero_crossing_reciprocal;
// Maps a Zimtohrli distance to an approximate MOS: the sigmoid of the distance
// is normalized by the sigmoid at zero, scaled by 4 and offset by 1.
float MOSMapper::Map(float zimtohrli_distance) const {
  const float at_distance = sigmoid(params, zimtohrli_distance);
  const float at_zero = sigmoid(params, 0);
  return 1.0 + 4.0 * at_distance / at_zero;
}

} // namespace zimtohrli
27 changes: 20 additions & 7 deletions cpp/zimt/mos.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,28 @@
#ifndef CPP_ZIMT_MOS_H_
#define CPP_ZIMT_MOS_H_

#include <array>

namespace zimtohrli {

// Returns a _very_approximate_ mean opinion score based on the
// provided Zimtohrli distance.
// This is calibrated using default settings of v0.1.5, with a
// minimum channel bandwidth (zimtohrli::Cam.minimum_bandwidth_hz)
// of 5Hz and perceptual sample rate
// (zimtohrli::Distance(..., perceptual_sample_rate, ...)) of 100Hz.
float MOSFromZimtohrli(float zimtohrli_distance);
// Converts Zimtohrli distances into approximate mean opinion scores (MOS).
struct MOSMapper {
  // Coefficients of the sigmoid evaluated by Map:
  //   s(x) = params[0] / (params[1] + e^(params[2] * x))
  std::array<float, 3> params = {1.000e+00, -7.449e-09, 3.344e+00};

  // Returns a _very_approximate_ mean opinion score for the provided
  // Zimtohrli distance, computed as:
  //   MOS = 1 + 4 * s(distance) / s(0)
  //
  // The default params are calibrated using default settings of v0.1.5, with
  // a minimum channel bandwidth (zimtohrli::Cam.minimum_bandwidth_hz) of 5Hz
  // and a perceptual sample rate
  // (zimtohrli::Distance(..., perceptual_sample_rate, ...)) of 100Hz.
  float Map(float zimtohrli_distance) const;
};

} // namespace zimtohrli

Expand Down
3 changes: 2 additions & 1 deletion cpp/zimt/mos_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ TEST(MOS, MOSFromZimtohrli) {
const std::vector<float> zimt_scores = {0, 0.1, 0.5, 0.7, 1.0};
const std::vector<float> mos = {5.0, 3.8630697727203369, 1.751483678817749,
1.3850023746490479, 1.1411819458007812};
const MOSMapper m;
for (size_t index = 0; index < zimt_scores.size(); ++index) {
ASSERT_NEAR(MOSFromZimtohrli(zimt_scores[index]), mos[index], 1e-2);
ASSERT_NEAR(m.Map(zimt_scores[index]), mos[index], 1e-2);
}
}

Expand Down
30 changes: 13 additions & 17 deletions cpp/zimt/pyohrli.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,22 @@ PyObject* Pyohrli_distance(PyohrliObject* self, PyObject* const* args,
return PyFloat_FromDouble(distance.value);
}

// Python-callable method: converts the single Zimtohrli distance argument to
// an approximate mean opinion score using this instance's MOS mapper.
//
// args/nargs follow the METH_FASTCALL convention; exactly one argument is
// accepted. Returns a new Python float, the result of BadArgument when the
// argument count is wrong, or nullptr when the argument cannot be converted
// to a float (the conversion error stays set for the caller).
PyObject* Pyohrli_mos_from_zimtohrli(PyohrliObject* self, PyObject* const* args,
                                     Py_ssize_t nargs) {
  if (nargs != 1) {
    return BadArgument("not exactly 1 argument provided");
  }
  // PyFloat_AsDouble reports failure by returning -1.0 with an exception set;
  // -1.0 alone is a legitimate value, so both conditions must be checked.
  const double zimtohrli_distance = PyFloat_AsDouble(args[0]);
  if (zimtohrli_distance == -1.0 && PyErr_Occurred()) {
    return nullptr;
  }
  return PyFloat_FromDouble(
      self->zimtohrli->mos_mapper.Map(zimtohrli_distance));
}

// Table of PyMethodDef entries exposing Pyohrli_distance and
// Pyohrli_mos_from_zimtohrli under their Python-visible names.
PyMethodDef Pyohrli_methods[] = {
    {"distance", reinterpret_cast<PyCFunction>(Pyohrli_distance),
     METH_FASTCALL, "Returns the distance between the two provided signals."},
    {"mos_from_zimtohrli",
     reinterpret_cast<PyCFunction>(Pyohrli_mos_from_zimtohrli), METH_FASTCALL,
     "Returns an approximate mean opinion score based on the provided "
     "Zimtohrli distance."},
    {nullptr, nullptr, 0, nullptr}  // Sentinel terminating the table.
};

Expand All @@ -150,28 +163,11 @@ PyTypeObject PyohrliType = {
.tp_new = PyType_GenericNew,
};

PyObject* MOSFromZimtohrli(PyohrliObject* self, PyObject* const* args,
Py_ssize_t nargs) {
if (nargs != 1) {
return BadArgument("not exactly 1 argument provided");
}
return PyFloat_FromDouble(
zimtohrli::MOSFromZimtohrli(PyFloat_AsDouble(args[0])));
}

static PyMethodDef PyohrliModuleMethods[] = {
{"MOSFromZimtohrli", (PyCFunction)MOSFromZimtohrli, METH_FASTCALL,
"Returns an approximate mean opinion score based on the provided "
"Zimtohrli distance."},
{NULL, NULL, 0, NULL},
};

PyModuleDef PyohrliModule = {
.m_base = PyModuleDef_HEAD_INIT,
.m_name = "pyohrli",
.m_doc = "Python wrapper around the C++ zimtohrli library.",
.m_size = -1,
.m_methods = PyohrliModuleMethods,
};

PyMODINIT_FUNC PyInit__pyohrli(void) {
Expand Down
9 changes: 4 additions & 5 deletions cpp/zimt/pyohrli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,6 @@
import _pyohrli


def mos_from_zimtohrli(zimtohrli_distance: float) -> float:
"""Returns an approximate mean opinion score based on the provided Zimtohrli distance."""
return _pyohrli.MOSFromZimtohrli(zimtohrli_distance)


class Pyohrli:
"""Wrapper around C++ zimtohrli::Zimtohrli."""

Expand Down Expand Up @@ -56,6 +51,10 @@ def distance(self, signal_a: npt.ArrayLike, signal_b: npt.ArrayLike) -> float:
np.asarray(signal_b).astype(np.float32).ravel().data,
)

def mos_from_zimtohrli(self, zimtohrli_distance: float) -> float:
    """Maps a Zimtohrli distance to an approximate mean opinion score.

    Args:
      zimtohrli_distance: The Zimtohrli distance to convert.

    Returns:
      Whatever self._cc_pyohrli.mos_from_zimtohrli returns for the distance.
    """
    mapper = self._cc_pyohrli.mos_from_zimtohrli
    return mapper(zimtohrli_distance)

@property
def full_scale_sine_db(self) -> float:
"""Reference intensity for an amplitude 1.0 sine wave at 1kHz.
Expand Down
3 changes: 2 additions & 1 deletion cpp/zimt/pyohrli_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,9 @@ def test_nyquist_threshold(self):
dict(zimtohrli_distance=1.0, mos=1.1411819458007812),
)
def test_mos_from_zimtohrli(self, zimtohrli_distance: float, mos: float):
metric = pyohrli.Pyohrli(48000.0)
self.assertAlmostEqual(
mos, pyohrli.mos_from_zimtohrli(zimtohrli_distance), places=3
mos, metric.mos_from_zimtohrli(zimtohrli_distance), places=3
)


Expand Down
4 changes: 4 additions & 0 deletions cpp/zimt/zimtohrli.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "zimt/cam.h"
#include "zimt/loudness.h"
#include "zimt/masking.h"
#include "zimt/mos.h"

namespace zimtohrli {

Expand Down Expand Up @@ -326,6 +327,9 @@ struct Zimtohrli {
// Perceptual intensity model.
Loudness loudness;

// MOS mapping model.
MOSMapper mos_mapper;

// Whether the masking model is applied when creating spectrograms.
bool apply_masking = true;

Expand Down
4 changes: 2 additions & 2 deletions go/bin/compare/compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,12 @@ func main() {
}

if *zimtohrli {
g := goohrli.New(zimtohrliParameters)
getMetric := func(f float64) float64 {
if *outputZimtohrliDistance {
return f
}
return goohrli.MOSFromZimtohrli(f)
return g.MOSFromZimtohrli(f)
}

if err := zimtohrliParameters.Update([]byte(*zimtohrliParametersJSON)); err != nil {
Expand All @@ -117,7 +118,6 @@ func main() {
log.Printf("Using %+v", zimtohrliParameters)
}
zimtohrliParameters.SampleRate = signalA.Rate
g := goohrli.New(zimtohrliParameters)
if *perChannel {
for channelIndex := range signalA.Samples {
measurement := goohrli.Measure(signalA.Samples[channelIndex])
Expand Down
16 changes: 14 additions & 2 deletions go/bin/score/score.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ func main() {
optimizeNumSteps := flag.Float64("optimize_num_steps", 1000, "Number of steps for the simulated annealing.")
workers := flag.Int("workers", runtime.NumCPU(), "Number of concurrent workers for tasks.")
failFast := flag.Bool("fail_fast", false, "Whether to panic immediately on any error.")
optimizeMapping := flag.String("optimize_mapping", "", "Glob to directories with databases to optimize the MOS mapping for.")
flag.Parse()

if *details == "" && *calculate == "" && *correlate == "" && *accuracy == "" && *leaderboard == "" && *report == "" && *optimize == "" {
if *details == "" && *calculate == "" && *correlate == "" && *accuracy == "" && *leaderboard == "" && *report == "" && *optimize == "" && *optimizeMapping == "" {
flag.Usage()
os.Exit(1)
}
Expand All @@ -88,10 +89,21 @@ func main() {
f.Sync()
}
}
err = bundles.Optimize(*optimizeStartStep, *optimizeNumSteps, optimizeLog)
if err = bundles.Optimize(*optimizeStartStep, *optimizeNumSteps, optimizeLog); err != nil {
log.Fatal(err)
}
}

if *optimizeMapping != "" {
bundles, err := data.OpenBundles(*optimizeMapping)
if err != nil {
log.Fatal(err)
}
params, err := bundles.OptimizeMapping()
if err != nil {
log.Fatal(err)
}
fmt.Println(params)
}

if *calculate != "" {
Expand Down
4 changes: 4 additions & 0 deletions go/data/study.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,10 @@ func (r ReferenceBundles) Split(rng *rand.Rand, split float64) (ReferenceBundles
return left, right
}

// OptimizeMapping currently performs no optimization: it returns a nil
// parameter slice and a nil error.
func (r ReferenceBundles) OptimizeMapping() ([]float32, error) {
	var params []float32
	return params, nil
}

// OptimizationEvent is a step in the optimization process.
type OptimizationEvent struct {
Parameters goohrli.Parameters
Expand Down
Binary file modified go/goohrli/goohrli.a
Binary file not shown.
Loading
Loading