Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accept variable-length batch prompts for Whisper #1784

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ set(SOURCES
src/ops/mul.cc
src/ops/multinomial.cc
src/ops/multinomial_cpu.cc
src/ops/position_encodings_add.cc
src/ops/position_encodings_add_cpu.cc
src/ops/quantize.cc
src/ops/quantize_cpu.cc
src/ops/relu.cc
Expand Down Expand Up @@ -569,6 +571,7 @@ if (WITH_CUDA)
src/ops/layer_norm_gpu.cu
src/ops/mean_gpu.cu
src/ops/multinomial_gpu.cu
src/ops/position_encodings_add_gpu.cu
src/ops/rms_norm_gpu.cu
src/ops/rotary_gpu.cu
src/ops/softmax_gpu.cu
Expand Down
1 change: 1 addition & 0 deletions include/ctranslate2/layers/attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ namespace ctranslate2 {
virtual void operator()(const StorageView& queries,
const StorageView& values,
const StorageView* values_lengths,
const StorageView* values_offsets,
StorageView& output,
StorageView* cached_keys = nullptr,
StorageView* cached_values = nullptr,
Expand Down
6 changes: 5 additions & 1 deletion include/ctranslate2/layers/attention_layer.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ namespace ctranslate2 {
virtual void operator()(const StorageView& queries,
const StorageView& values,
const StorageView* values_lengths,
const StorageView* values_offsets,
StorageView& output,
StorageView* cached_keys = nullptr,
StorageView* cached_values = nullptr,
Expand All @@ -49,7 +50,10 @@ namespace ctranslate2 {
const dim_t num_heads,
const dim_t num_queries,
const bool mask_future = false,
const bool multi_query = false);
const bool multi_query = false,
const dim_t step = 0,
const StorageView* offsets = nullptr,
StorageView* values_offsets = nullptr);

protected:
const bool _tensor_parallel;
Expand Down
8 changes: 6 additions & 2 deletions include/ctranslate2/layers/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,14 @@ namespace ctranslate2 {
// Base class for position encoders.
class PositionEncoder : public Layer {
public:
void operator()(StorageView& input, dim_t index = 0);
void operator()(const StorageView& input, StorageView& output, dim_t index = 0);
void operator()(const StorageView& input,
StorageView& output,
dim_t step = 0,
const StorageView* offsets = nullptr);
protected:
virtual const StorageView& get_position_encoding(dim_t max_time) = 0;
private:
ops::PositionEncodingsAdd _add_op;
};

// Concrete position encoder loading encoding vectors from the model.
Expand Down
1 change: 1 addition & 0 deletions include/ctranslate2/layers/flash_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace ctranslate2 {
void operator()(const StorageView& queries,
const StorageView& values,
const StorageView* values_lengths,
const StorageView* values_offsets,
StorageView& output,
StorageView* cached_keys = nullptr,
StorageView* cached_values = nullptr,
Expand Down
1 change: 1 addition & 0 deletions include/ctranslate2/layers/transformer.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ namespace ctranslate2 {

void operator()(const StorageView& input,
const StorageView* input_lengths,
const StorageView* input_offsets,
const StorageView* memory,
const StorageView* memory_lengths,
StorageView* cached_self_attn_keys,
Expand Down
2 changes: 1 addition & 1 deletion include/ctranslate2/models/whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ namespace ctranslate2 {

std::vector<WhisperGenerationResult>
generate(StorageView features,
const std::vector<std::vector<size_t>>& prompts,
std::vector<std::vector<size_t>> prompts,
const WhisperOptions& options);

std::vector<std::vector<std::pair<std::string, float>>>
Expand Down
1 change: 1 addition & 0 deletions include/ctranslate2/ops/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@
#include "awq/gemv.h"
#include "awq/dequantize_awq.h"
#include "sum.h"
#include "position_encodings_add.h"
26 changes: 26 additions & 0 deletions include/ctranslate2/ops/position_encodings_add.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#pragma once

#include "op.h"

namespace ctranslate2 {
namespace ops {

class PositionEncodingsAdd : public Op {
public:
void operator()(const StorageView& input,
const StorageView& encodings,
StorageView& output,
const StorageView* offsets = nullptr,
const dim_t step = 0) const;

private:
template <Device D, typename T>
void compute(const dim_t step,
const StorageView* offsets,
const StorageView& input,
const StorageView& encodings,
StorageView& output) const;
};

}
}
14 changes: 12 additions & 2 deletions include/ctranslate2/ops/softmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,21 @@ namespace ctranslate2 {
void operator()(StorageView& x) const;
void operator()(const StorageView& x, StorageView& y) const override;
void operator()(const StorageView& x, const StorageView& lengths, StorageView& y) const;
void operator()(const StorageView& x, const StorageView* lengths, StorageView& y) const;
void operator()(const StorageView& x,
const StorageView& lengths,
const StorageView& offsets,
StorageView& y) const;
void operator()(const StorageView& x,
const StorageView* lengths,
const StorageView* offsets,
StorageView& y) const;

private:
template <Device D, typename T>
void compute(const StorageView& input, const StorageView* lengths, StorageView& output) const;
void compute(const StorageView& input,
const StorageView* lengths,
const StorageView* offsets,
StorageView& output) const;

bool _log;
};
Expand Down
10 changes: 10 additions & 0 deletions include/ctranslate2/padder.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,23 @@ namespace ctranslate2 {
const dim_t max_time = -1,
const dim_t pad_batch_to_multiple = 1);

Padder(const StorageView& lengths,
const StorageView* offsets,
const dim_t max_time = -1,
const dim_t pad_batch_to_multiple = 1);

// Merge batch and time dimensions and remove padding.
void remove_padding(StorageView& x) const;

// Split first dimension into batch and time dimensions and add padding.
void add_padding(StorageView& x) const;

private:
void initialize(const StorageView& lengths,
const StorageView* offsets,
const dim_t max_time,
const dim_t pad_batch_to_multiple);

dim_t _batch_size;
dim_t _max_time;
StorageView _padded_to_flat;
Expand Down
17 changes: 10 additions & 7 deletions include/ctranslate2/primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,16 @@ namespace ctranslate2 {
dim_t length,
dim_t vocabulary_size);

static void prepare_length_mask(const int32_t* lengths,
dim_t batch_size,
dim_t num_heads,
dim_t num_queries,
bool mask_future,
bool multi_query,
int32_t* mask);
static void prepare_mha_values_mask(const int32_t* lengths,
const int32_t* offsets,
dim_t batch_size,
dim_t num_heads,
dim_t num_queries,
bool mask_future,
bool multi_query,
dim_t step,
int32_t* values_lengths,
int32_t* values_offsets);

template <typename T>
static void transpose_2d(const T* a, const dim_t* dims, T* b);
Expand Down
26 changes: 26 additions & 0 deletions python/tests/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,32 @@ def teardown_class(cls):
pytest.approx(0.062380101531744, abs=1e-3),
],
),
(
"openai/whisper-tiny.en",
[
["<|startoftranscript|>"],
[
"<|startofprev|>",
"ĠAnd",
"Ġthen",
"Ġthe",
"ĠPresident",
"Ġshouted",
":",
"<|startoftranscript|>",
],
],
[
" Mr. Quilter is the apostle of the middle classes, and we are glad"
" to welcome his gospel.",
" And so my fellow Americans ask not what your country can do for you,"
" ask what you can do for your country.",
],
[
pytest.approx(0.02644546702504158, abs=1e-4),
pytest.approx(0.008309835568070412, abs=1e-3),
],
),
],
)
def test_transformers_whisper(
Expand Down
26 changes: 16 additions & 10 deletions src/cpu/kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ namespace ctranslate2 {
template<>
void softmax<TARGET_ISA>(const float* input,
const int32_t* lengths,
const int32_t* offsets,
float* output,
dim_t batch_size,
dim_t depth,
Expand All @@ -410,24 +411,29 @@ namespace ctranslate2 {

parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) {
for (dim_t i = begin; i < end; ++i) {
const dim_t start = offsets ? offsets[i] : 0;
const dim_t size = lengths ? lengths[i] : depth - start;

const dim_t offset = i * depth;
const float* x = input + offset;
float* y = output + offset;

dim_t size = depth;
if (lengths) {
size = lengths[i];
// Directly set 0 in output for out of range positions.

// Directly set 0 in output for out of range positions.
for (dim_t j = size; j < depth; ++j) {
if (size <= 0) {
for (dim_t j = 0; j < depth; ++j)
y[j] = 0;
}

if (size == 0) {
continue;
}
continue;
}

for (dim_t j = 0; j < start; ++j)
y[j] = 0;
for (dim_t j = start + size; j < depth; ++j)
y[j] = 0;

x += start;
y += start;

const auto x_max = reduce_max<TARGET_ISA>(x, size);
const auto vec_x_max = VecType::load(x_max);

Expand Down
1 change: 1 addition & 0 deletions src/cpu/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ namespace ctranslate2 {
template <CpuIsa ISA>
void softmax(const float* input,
const int32_t* lengths,
const int32_t* offsets,
float* output,
dim_t batch_size,
dim_t depth,
Expand Down
38 changes: 24 additions & 14 deletions src/cpu/primitives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -414,21 +414,31 @@ namespace ctranslate2 {
}

template<>
void primitives<Device::CPU>::prepare_length_mask(const int32_t* lengths,
dim_t batch_size,
dim_t num_heads,
dim_t num_queries,
bool mask_future,
bool multi_query,
int32_t* mask) {
void primitives<Device::CPU>::prepare_mha_values_mask(const int32_t* lengths,
const int32_t* offsets,
dim_t batch_size,
dim_t num_heads,
dim_t num_queries,
bool mask_future,
bool multi_query,
dim_t step,
int32_t* values_lengths,
int32_t* values_offsets) {
for (dim_t b = 0; b < batch_size; ++b) {
const auto length = lengths[b];
auto* batch_mask = mask + b * num_heads * num_queries;
for (dim_t i = 0; i < num_heads * num_queries; ++i) {
batch_mask[i] = (mask_future
? std::min(length,
int32_t((multi_query ? i / num_heads : i % num_queries) + 1))
: length);
const auto offset = offsets ? offsets[b] : 0;
const auto length = lengths[b] + int32_t(step) - offset;
const auto batch_offset = b * num_heads * num_queries;

for (dim_t i = batch_offset; i < batch_offset + num_heads * num_queries; ++i) {
if (mask_future) {
const int32_t time = step + (multi_query ? i / num_heads : i % num_queries);
values_lengths[i] = time < offset ? 0 : std::min(time - offset + 1, length);
} else {
values_lengths[i] = length;
}

if (values_offsets)
values_offsets[i] = offset;
}
}
}
Expand Down
Loading
Loading