Skip to content

Commit

Permalink
Refactored CutFullyMasked to avoid an extra loop.
Browse files Browse the repository at this point in the history
  • Loading branch information
Martin Bruse authored and zond committed May 27, 2024
1 parent 4319e10 commit 3e6ad14
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 25 deletions.
29 changes: 10 additions & 19 deletions cpp/zimt/masking.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,35 +207,26 @@ void HwyCutFullyMasked(const Masking& m,
FullMaskingCalculator full_masking_calculator(m, cam_delta);
const size_t num_samples = energy_channels_db.shape()[0];
const size_t num_channels = energy_channels_db.shape()[1];
hwy::AlignedNDArray<float, 2> max_masked(energy_channels_db.shape());
for (size_t sample_index = 0; sample_index < num_samples; ++sample_index) {
const float* energy_channels_db_data =
energy_channels_db[{sample_index}].data();
for (size_t probe_channel_index = 0; probe_channel_index < num_channels;
++probe_channel_index) {
float max_masked = std::numeric_limits<float>::min();
for (size_t masker_channel_index = 0; masker_channel_index < num_channels;
masker_channel_index += Lanes(d)) {
const Vec masker_level_db =
Load(d, energy_channels_db[{sample_index}].data() +
masker_channel_index);
Load(d, energy_channels_db_data + masker_channel_index);
const Vec full_masking_db = full_masking_calculator.Calculate(
masker_level_db, static_cast<float>(probe_channel_index) -
static_cast<float>(masker_channel_index));
max_masked[{sample_index}][probe_channel_index] =
std::max(max_masked[{sample_index}][probe_channel_index],
ReduceMax(d, full_masking_db));
max_masked = std::max(max_masked, ReduceMax(d, full_masking_db));
}
}
}
for (size_t sample_index = 0; sample_index < num_samples; ++sample_index) {
const float* energy_channels_db_data =
energy_channels_db[{sample_index}].data();
const float* max_masked_data = max_masked[{sample_index}].data();
float* non_masked_db_data = non_masked_db[{sample_index}].data();
for (size_t channel_index = 0; channel_index < num_channels;
channel_index += Lanes(d)) {
const Vec max_masking = Load(d, max_masked_data + channel_index);
const Vec probe = Load(d, energy_channels_db_data + channel_index);
Store(IfThenElse(Ge(max_masking, probe), Sub(probe, max_masking), probe),
d, non_masked_db_data + channel_index);
const float probe_energy_db =
energy_channels_db_data[probe_channel_index];
non_masked_db[{sample_index}][probe_channel_index] =
max_masked > probe_energy_db ? probe_energy_db - max_masked
: probe_energy_db;
}
}
}
Expand Down
15 changes: 9 additions & 6 deletions cpp/zimt/masking_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,16 @@ TEST(Masking, FullMasking) {
// Test of Masking::CutFullyMasked on a single sample with two channels: the
// loud channel 0 is expected to come back unchanged, the quiet channel 1 is
// expected to come back as a negative value.
//
// NOTE(review): this block is a rendered diff hunk whose +/- markers were
// lost, so old and new lines are interleaved. The annotations below follow
// the "9 additions & 6 deletions" count in the file header - confirm against
// the actual commit before relying on them.
TEST(Masking, CutFullyMasked) {
// 1 sample x 2 channels for both the input energies and the output buffer.
hwy::AlignedNDArray<float, 2> energy_channels({1, 2});
hwy::AlignedNDArray<float, 2> non_masked({1, 2});
// Presumably the removed line: a default-constructed Masking.
Masking m;
// Presumably the added replacement: the masking parameters are spelled out
// explicitly with designated initializers.
Masking m{.lower_zero_at_20 = -2,
.lower_zero_at_80 = -6,
.upper_zero_at_20 = 2,
.upper_zero_at_80 = 10,
.max_mask = 20};

// Presumably removed: the old scenario (90 vs 20, second argument 1), with a
// 1e-2 tolerance on both channels.
energy_channels[{0}] = {90, 20};
m.CutFullyMasked(energy_channels, 1, non_masked);
EXPECT_NEAR((non_masked[{0}][0]), 90, 1e-2) << "No self masking";
EXPECT_NEAR((non_masked[{0}][1]), -45.686698913574219, 1e-2)
<< "20dB fully masked by 90dB";
// Presumably added: the new scenario (80 vs 20, second argument 2). The
// masked channel is now checked against -28 with a tolerance of 1.
energy_channels[{0}] = {80, 20};
m.CutFullyMasked(energy_channels, 2, non_masked);
EXPECT_NEAR((non_masked[{0}][0]), 80, 1e-2) << "No self masking";
EXPECT_NEAR((non_masked[{0}][1]), -28, 1) << "20dB fully masked by 80dB";
}

void BM_FullMasking(benchmark::State& state) {
Expand Down

0 comments on commit 3e6ad14

Please sign in to comment.