Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

~5% worse, but faster and simpler #122

Merged
merged 1 commit into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 97 additions & 94 deletions cpp/zimt/fourier_bank.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,121 +71,124 @@ float SimpleDb(float energy) {
return kMul * log(energy + epsilon);
}

void FinalizeDb(hwy::AlignedNDArray<float, 2>& channels, float mul,
size_t out_ix) {
float masker_down[kNumRotators];
for (int k = 0; k < kNumRotators; ++k) {
float v = SimpleDb(mul * channels[{out_ix}][k]);
channels[{out_ix}][k] = Loudness(k, v);
void PrepareMasker(hwy::AlignedNDArray<float, 2>& channels,
float *masker,
size_t out_ix) {
if (out_ix < 3) {
for (int k = 0; k < kNumRotators; ++k) {
masker[k] = channels[{out_ix}][k];
}
} else {
// convolve in time and freq, 5 freq bins, 3 time bins
static const double c[12] = {
0.011551012731481482,
0.02009898726851852,
0.27419898726851855,

-0.04009898726851849,
0.3270268229166667,
0.6400989872685185,

0.36397005208333333,
0.6505010127314814,
0.8000989872685186,

-0.15930101273148148,
1.5483130497685185,
8.31009898726852,
};
static const float div = 1.0 / (2*(c[0]+c[1]+c[2]+c[3]+c[4]+c[5]+c[6]+c[7]+c[8])+c[9]+c[10]+c[11]);
for (int k = 0; k < kNumRotators; ++k) {
int prev3 = std::max(0, k - 3);
int prev2 = std::max(0, k - 2);
int prev1 = std::max(0, k - 1);
int currk = k;
int next1 = std::min<int>(kNumRotators - 1, k + 1);
int next2 = std::min<int>(kNumRotators - 1, k + 2);
int next3 = std::min<int>(kNumRotators - 1, k + 3);
size_t oi2 = out_ix - 2;
size_t oi1 = out_ix - 1;
size_t oi0 = out_ix - 0;

float v =
channels[{oi2}][prev3] * c[0] + channels[{oi1}][prev3] * c[1] + channels[{oi0}][prev3] * c[2] +
channels[{oi2}][prev2] * c[3] + channels[{oi1}][prev2] * c[4] + channels[{oi0}][prev2] * c[5] +
channels[{oi2}][prev1] * c[6] + channels[{oi1}][prev1] * c[7] + channels[{oi0}][prev1] * c[8] +
channels[{oi2}][currk] * c[9] + channels[{oi1}][currk] * c[10] + channels[{oi0}][currk] * c[11] +
channels[{oi2}][next1] * c[6] + channels[{oi1}][next1] * c[7] + channels[{oi0}][next1] * c[8] +
channels[{oi2}][next2] * c[3] + channels[{oi1}][next2] * c[4] + channels[{oi0}][next2] * c[5] +
channels[{oi2}][next3] * c[0] + channels[{oi1}][next3] * c[1] + channels[{oi0}][next3] * c[2];

masker[k] = v * div;
}
}
double masker = 0.0;
static const double octaves_in_20_to_20000 = log(20000/20.)/log(2);
static const double octaves_per_rot =
octaves_in_20_to_20000 / float(kNumRotators - 1);
static const double masker_step_per_octave_up_0 = 19.53945781131615;
static const double masker_step_per_octave_up_1 = 24.714118008386887;
static const double masker_step_per_octave_up_2 = 6.449301354309956;
static const double masker_step_per_octave_up_0 = 20.54547806594578;
static const double masker_step_per_octave_up_1 = 24.608097753757256;
static const double masker_step_per_octave_up_2 = 6.0;
static const double masker_step_per_rot_up_0 = octaves_per_rot * masker_step_per_octave_up_0;
static const double masker_step_per_rot_up_1 = octaves_per_rot * masker_step_per_octave_up_1;
static const double masker_step_per_rot_up_2 = octaves_per_rot * masker_step_per_octave_up_2;
static const double masker_gap_up = 21.309406898722074;
static const float maskingStrengthUp = 0.2056434702527141;
static const float up_blur = 0.9442717063037425;
static const float fraction_up = 1.1657467617827404;

static const double masker_step_per_octave_down = 53.40273959309446;
static const double masker_step_per_octave_down = 53.40075984772409;
static const double masker_step_per_rot_down = octaves_per_rot * masker_step_per_octave_down;
static const double masker_gap_down = 19.08401096304284;
static const float maskingStrengthDown = 0.18030917038808858;
static const float down_blur = 0.7148792180987857;
// propagate masker up
float mask = 0;
for (int k = 0; k < kNumRotators; ++k) {
float v = masker[k];
if (mask < v) {
mask = v;
}
masker[k] = std::max<float>(masker[k], mask);
if (3 * k < kNumRotators) {
mask -= masker_step_per_rot_up_0;
} else if (3 * k < 2 * kNumRotators) {
mask -= masker_step_per_rot_up_1;
} else {
mask -= masker_step_per_rot_up_2;
}
}
// propagate masker down
mask = 0;
for (int k = kNumRotators - 1; k >= 0; --k) {
float v = masker[k];
if (mask < v) {
mask = v;
}
masker[k] = std::max<float>(masker[k], mask);
mask -= masker_step_per_rot_down;
}
}

static const float min_limit = -11.3968870989223;
static const float fraction_down = 1.0197608300379997;
void FinalizeDb(hwy::AlignedNDArray<float, 2>& channels, float mul,
size_t out_ix) {
float masker[kNumRotators];
for (int k = 0; k < kNumRotators; ++k) {
float v = SimpleDb(mul * channels[{out_ix}][k]);
channels[{out_ix}][k] = Loudness(k, v);
}
PrepareMasker(channels, &masker[0], out_ix);

static const float temporal0 = 0.09979167061501665;
static const float temporal1 = 0.14429505133534495;
static const float temporal2 = 0.009228598592129168;
static const float weightp = 0.1792443302507868;
static const float weightm = 0.7954490998745948;

static const float mask_k = 0.08709005149742773;
static const double masker_gap = 20.716199363425925;
static const float maskingStrength = 0.22591336897956596;

static const float min_limit = -11.3968870989223;

// Scan frequencies from bottom to top, let lower frequencies to mask higher frequencies.
// 'masker' maintains the masking envelope from one bin to next.
for (int k = 0; k < kNumRotators; ++k) {
float v = channels[{out_ix}][k];
if (out_ix != 0) {
v = (1.0 - mask_k) * v + mask_k * channels[{out_ix - 1}][k];
}
double mask = masker[k] - masker_gap;
if (v < min_limit) {
v = min_limit;
}
float v2 = (1 - up_blur) * v2 + up_blur * v;
if (k == 0) {
v2 = v;
}
if (masker < v2) {
masker = v2;
}
float mask = fraction_up * masker - masker_gap_up;
if (v < mask) {
v = maskingStrengthUp * mask + (1.0 - maskingStrengthUp) * v;
}
channels[{out_ix}][k] = v;
if (3 * k < kNumRotators) {
masker -= masker_step_per_rot_up_0;
} else if (3 * k < 2 * kNumRotators) {
masker -= masker_step_per_rot_up_1;
} else {
masker -= masker_step_per_rot_up_2;
}
}
// Scan frequencies from top to bottom, let higher frequencies to mask lower frequencies.
// 'masker' maintains the masking envelope from one bin to next.
masker = 0.0;
for (int k = kNumRotators - 1; k >= 0; --k) {
float v = channels[{out_ix}][k];
if (out_ix != 0) {
v = (1.0 - mask_k) * v + mask_k * channels[{out_ix - 1}][k];
}
float v2 = (1 - down_blur) * v2 + down_blur * v;
if (k == kNumRotators - 1) {
v2 = v;
}
if (masker < v) {
masker = v;
}
float mask = fraction_down * masker - masker_gap_down;
if (v < mask) {
v = maskingStrengthDown * mask + (1.0 - maskingStrengthDown) * v;
v = maskingStrength * mask + (1.0 - maskingStrength) * v;
}
channels[{out_ix}][k] = v;
masker -= masker_step_per_rot_down;
}
// temporal masker
if (out_ix >= 3) {
for (int k = 0; k < kNumRotators; ++k) {
float m = (temporal0 * channels[{out_ix - 1}][k] +
temporal1 * channels[{out_ix - 2}][k] +
temporal2 * channels[{out_ix - 3}][k]) / (temporal0 + temporal1 + temporal2);
if (m > channels[{out_ix}][k]) {
channels[{out_ix}][k] -= weightp * (m - channels[{out_ix}][k]);
} else {
channels[{out_ix}][k] -= weightm * (m - channels[{out_ix}][k]);
}
/*
// todo(jyrki): explore with this
static const float temporal_masker0 = 0.1387454636244773;
channels[{out_ix}][k] -=
temporal_masker0 * (channels[{out_ix - 1}][k] - channels[{out_ix}][k]);
static const float temporal_masker1 = 0.08715440670406614;
channels[{out_ix}][k] -=
temporal_masker1 * (channels[{out_ix - 2}][k] - channels[{out_ix}][k]);
static const float temporal_masker2 = -0.03785233735225447;
channels[{out_ix}][k] -=
temporal_masker2 * (channels[{out_ix - 3}][k] - channels[{out_ix}][k]);
*/
}
}
}

Expand Down Expand Up @@ -254,8 +257,8 @@ Rotators::Rotators(int num_channels, std::vector<float> frequency,
window[i] = std::pow(kWindow, bw * kBandwidthMagic);
float windowM1 = 1.0f - window[i];
float f = frequency[i] * 2.0f * M_PI / sample_rate;
static const float full_scale_sine_db = exp(75.27858635739499);
const float gainer = 2.0f * sqrt(full_scale_sine_db);
static const float full_scale_sine_db = exp(76.66488071851488);
const float gainer = sqrt(full_scale_sine_db);
gain[i] = gainer * filter_gains[i] * pow(windowM1, 3.0);
rot[0][i] = float(std::cos(f));
rot[1][i] = float(-std::sin(f));
Expand Down
2 changes: 1 addition & 1 deletion cpp/zimt/fourier_bank.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

namespace tabuli {

constexpr int64_t kNumRotators = 150;
constexpr int64_t kNumRotators = 128;

struct PerChannel {
// [0..1] is for real and imag of 1st leaking accumulation
Expand Down
4 changes: 2 additions & 2 deletions cpp/zimt/zimtohrli.h
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,10 @@ struct Zimtohrli {
std::optional<CamFilterbank> cam_filterbank;

// The window in perceptual_sample_rate time steps when compting the NSIM.
size_t nsim_step_window = 16;
size_t nsim_step_window = 8;

// The window in channels when computing the NSIM.
size_t nsim_channel_window = 30;
size_t nsim_channel_window = 16;

// The window of the dynamic time warp that matches audio signals.
//
Expand Down
Binary file modified go/goohrli/goohrli.a
Binary file not shown.
Loading