Skip to content

Commit

Permalink
Optimize indexing in function negative_dtw (#1464)
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaumekln authored Sep 11, 2023
1 parent eec0bf1 commit 5262f81
Showing 1 changed file with 22 additions and 18 deletions.
40 changes: 22 additions & 18 deletions src/dtw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,20 @@

namespace ctranslate2 {

static std::vector<std::pair<dim_t, dim_t>> backtrace(StorageView trace) {
dim_t i = trace.dim(0) - 1;
dim_t j = trace.dim(1) - 1;

for (dim_t k = 0; k < trace.dim(1); ++k)
trace.at<int32_t>({0, k}) = 2;
for (dim_t k = 0; k < trace.dim(0); ++k)
trace.at<int32_t>({k, 0}) = 1;
static std::vector<std::pair<dim_t, dim_t>> backtrace(std::vector<std::vector<int>> trace,
dim_t i,
dim_t j) {
for (dim_t k = 0; k <= j; ++k)
trace[0][k] = 2;
for (dim_t k = 0; k <= i; ++k)
trace[k][0] = 1;

std::vector<std::pair<dim_t, dim_t>> result;

while (i > 0 || j > 0) {
result.emplace_back(i - 1, j - 1);

const auto t = trace.at<int32_t>({i, j});
const int t = trace[i][j];

if (t == 0) {
--i;
Expand All @@ -39,19 +38,22 @@ namespace ctranslate2 {
}

std::vector<std::pair<dim_t, dim_t>> negative_dtw(const StorageView& x) {
constexpr float inf = std::numeric_limits<float>::infinity();
const dim_t n = x.dim(0);
const dim_t m = x.dim(1);

StorageView cost({n + 1, m + 1}, std::numeric_limits<float>::infinity());
StorageView trace({n + 1, m + 1}, int32_t(-1));
std::vector<std::vector<float>> cost(n + 1, std::vector<float>(m + 1, inf));
std::vector<std::vector<int>> trace(n + 1, std::vector<int>(m + 1, -1));

cost[0][0] = 0;

cost.at<float>({0, 0}) = 0;
const auto* x_data = x.data<float>();

for (dim_t j = 1; j < m + 1; ++j) {
for (dim_t i = 1; i < n + 1; ++i) {
const float c0 = cost.at<float>({i - 1, j - 1});
const float c1 = cost.at<float>({i - 1, j});
const float c2 = cost.at<float>({i, j - 1});
const float c0 = cost[i - 1][j - 1];
const float c1 = cost[i - 1][j];
const float c2 = cost[i][j - 1];

float c = 0;
int t = 0;
Expand All @@ -67,12 +69,14 @@ namespace ctranslate2 {
t = 2;
}

cost.at<float>({i, j}) = -x.at<float>({i - 1, j - 1}) + c;
trace.at<int32_t>({i, j}) = t;
const float v = x_data[(i - 1) * x.dim(1) + (j - 1)];

cost[i][j] = -v + c;
trace[i][j] = t;
}
}

return backtrace(std::move(trace));
return backtrace(std::move(trace), n, m);
}

}

0 comments on commit 5262f81

Please sign in to comment.