Skip to content

Commit

Permalink
Relax shape checks for Whisper input features (#1446)
Browse files Browse the repository at this point in the history
* Relax shape checks for Whisper input features

* Add const

* Clarify method name
  • Loading branch information
guillaumekln authored Aug 29, 2023
1 parent 6fc333f commit f4ef902
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 29 deletions.
21 changes: 18 additions & 3 deletions include/ctranslate2/layers/whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,27 @@ namespace ctranslate2 {
return _output_norm.output_size();
}

dim_t output_time() const {
dim_t max_output_time() const {
return _position_embedding.num_positions();
}

dim_t input_time() const {
return output_time() * 2;
dim_t input_size() const {
return _conv1.input_size();
}

dim_t max_input_time() const {
return max_output_time() * 2;
}

bool is_encoded(const StorageView& features) const {
// Input features shape: [batch_size, input_size, input_time]
// Encoder output shape: [batch_size, input_time // 2, output_size]
//
// input_time is variable so we check that dimension 1 is different than its original value.

return (features.rank() == 3
&& features.dim(2) == output_size()
&& features.dim(1) != input_size());
}

private:
Expand Down
20 changes: 10 additions & 10 deletions python/cpp/whisper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ namespace ctranslate2 {
Arguments:
features: Mel spectogram of the audio, as a float array with shape
``[batch_size, 80, 3000]``.
``[batch_size, 80, chunk_length]``.
to_cpu: Copy the encoder output to the CPU before returning the value.
Returns:
Expand Down Expand Up @@ -233,9 +233,9 @@ namespace ctranslate2 {
Arguments:
features: Mel spectogram of the audio, as a float array with shape
``[batch_size, 80, 3000]``. This method also accepts the encoded features
returned by the method :meth:`ctranslate2.models.Whisper.encode`, which
have shape ``[batch_size, 1500, d_model]``.
``[batch_size, 80, chunk_length]``. This method also accepts the encoded
features returned by the method :meth:`ctranslate2.models.Whisper.encode`,
which have shape ``[batch_size, chunk_length // 2, d_model]``.
prompts: Batch of initial string tokens or token IDs.
asynchronous: Run the model asynchronously.
beam_size: Beam size (1 for greedy search).
Expand Down Expand Up @@ -271,9 +271,9 @@ namespace ctranslate2 {
Arguments:
features: Mel spectogram of the audio, as a float array with shape
``[batch_size, 80, 3000]``. This method also accepts the encoded features
returned by the method :meth:`ctranslate2.models.Whisper.encode`, which
have shape ``[batch_size, 1500, d_model]``.
``[batch_size, 80, chunk_length]``. This method also accepts the encoded
features returned by the method :meth:`ctranslate2.models.Whisper.encode`,
which have shape ``[batch_size, chunk_length // 2, d_model]``.
Returns:
For each batch, a list of pairs (language, probability) ordered from
Expand All @@ -296,9 +296,9 @@ namespace ctranslate2 {
Arguments:
features: Mel spectogram of the audio, as a float array with shape
``[batch_size, 80, 3000]``. This method also accepts the encoded features
returned by the method :meth:`ctranslate2.models.Whisper.encode`, which
have shape ``[batch_size, 1500, d_model]``.
``[batch_size, 80, chunk_length]``. This method also accepts the encoded
features returned by the method :meth:`ctranslate2.models.Whisper.encode`,
which have shape ``[batch_size, chunk_length // 2, d_model]``.
start_sequence: The start sequence tokens.
text_tokens: Batch of text tokens to align.
num_frames: Number of non padding frames in the features.
Expand Down
10 changes: 3 additions & 7 deletions python/tests/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,7 @@ def test_transformers_whisper_encode(self, tmp_dir, device):
)

@test_utils.only_on_linux
def test_transformers_whisper_invalid_shape(self, tmp_dir):
def test_transformers_whisper_partial_audio_context(self, tmp_dir):
import transformers

model_name = "openai/whisper-tiny"
Expand All @@ -963,13 +963,9 @@ def test_transformers_whisper_invalid_shape(self, tmp_dir):
features = ctranslate2.StorageView.from_array(inputs.input_features)

model = ctranslate2.models.Whisper(output_dir)
encoder_output = model.encode(features)

with pytest.raises(ValueError) as exception_info:
model.detect_language(features)

error_message = str(exception_info.value)
assert "(1, 80, 3000)" in error_message
assert "(1, 80, 1100)" in error_message
assert encoder_output.shape == [1, features.shape[2] // 2, 384]

@test_utils.only_on_linux
def test_transformers_whisper_include_tokenizer_json(self, tmp_dir):
Expand Down
10 changes: 4 additions & 6 deletions src/layers/whisper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,18 @@ namespace ctranslate2 {
void WhisperEncoder::operator()(const StorageView& features, StorageView& output) {
PROFILE("WhisperEncoder");

const dim_t expected_depth = _conv1.input_size();
const dim_t expected_time = input_time();

if (features.rank() != 3)
throw std::invalid_argument("Expected input features to have 3 dimensions, but got "
+ std::to_string(features.rank())
+ " dimension(s) instead");
if (features.dim(1) != expected_depth || features.dim(2) != expected_time)

if (features.dim(1) != input_size() || features.dim(2) > max_input_time())
throw std::invalid_argument("Invalid input features shape: expected an input with shape ("
+ std::to_string(features.dim(0))
+ ", "
+ std::to_string(expected_depth)
+ std::to_string(input_size())
+ ", "
+ std::to_string(expected_time)
+ std::to_string(std::min(features.dim(2), max_input_time()))
+ "), but got an input with shape ("
+ std::to_string(features.dim(0))
+ ", "
Expand Down
4 changes: 1 addition & 3 deletions src/models/whisper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,7 @@ namespace ctranslate2 {

features.move_to(device, dtype);

// Already encoded.
if (features.dim(-1) == _encoder->output_size()
&& features.dim(-2) == _encoder->output_time())
if (_encoder->is_encoded(features))
return features;

StorageView encoder_output(dtype, device);
Expand Down

0 comments on commit f4ef902

Please sign in to comment.