Skip to content

Commit

Permalink
Fix formatting and default argument values
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaumekln committed Aug 28, 2023
1 parent 67182ef commit cc34ad6
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 25 deletions.
12 changes: 8 additions & 4 deletions include/ctranslate2/encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@ namespace ctranslate2 {
using ReplicaPool::ReplicaPool;

std::future<EncoderForwardOutput>
forward_batch_async(std::vector<std::vector<std::string>> tokens, std::vector<std::vector<size_t>> token_type_ids);
forward_batch_async(std::vector<std::vector<std::string>> tokens,
std::vector<std::vector<size_t>> token_type_ids = {});

std::future<EncoderForwardOutput>
forward_batch_async(std::vector<std::vector<size_t>> ids, std::vector<std::vector<size_t>> token_type_ids);

forward_batch_async(std::vector<std::vector<size_t>> ids,
std::vector<std::vector<size_t>> token_type_ids = {});

std::future<EncoderForwardOutput>
forward_batch_async(const StorageView& ids, const StorageView& lengths, std::vector<std::vector<size_t>> token_type_ids);
forward_batch_async(const StorageView& ids,
const StorageView& lengths,
std::vector<std::vector<size_t>> token_type_ids = {});
};

}
18 changes: 13 additions & 5 deletions include/ctranslate2/models/language_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,19 @@ namespace ctranslate2 {
return model.as_sequence_encoder();
}

EncoderForwardOutput forward(const std::vector<std::vector<std::string>>& tokens, const std::vector<std::vector<size_t>>& token_type_ids);
EncoderForwardOutput forward(const std::vector<std::vector<size_t>>& ids, const std::vector<std::vector<size_t>>& token_type_ids);
EncoderForwardOutput forward(const StorageView& ids, const StorageView& lengths, const std::vector<std::vector<size_t>>& token_type_ids);
EncoderForwardOutput forward(const std::vector<std::vector<std::string>>& tokens,
const std::vector<std::vector<size_t>>& token_type_ids = {});
EncoderForwardOutput forward(const std::vector<std::vector<size_t>>& ids,
const std::vector<std::vector<size_t>>& token_type_ids = {});
EncoderForwardOutput forward(const StorageView& ids,
const StorageView& lengths,
const std::vector<std::vector<size_t>>& token_type_ids = {});

protected:
virtual EncoderForwardOutput
forward_impl(const StorageView& ids, const StorageView& lengths, const StorageView& token_type_ids) = 0;
forward_impl(const StorageView& ids,
const StorageView& lengths,
const StorageView& token_type_ids) = 0;

private:
const std::shared_ptr<const LanguageModel> _model;
Expand All @@ -144,7 +150,9 @@ namespace ctranslate2 {

protected:
EncoderForwardOutput
forward_impl(const StorageView& ids, const StorageView& lengths, const StorageView& token_type_ids) override;
forward_impl(const StorageView& ids,
const StorageView& lengths,
const StorageView& token_type_ids) override;

private:
const std::shared_ptr<const LanguageModel> _model;
Expand Down
13 changes: 10 additions & 3 deletions python/cpp/encoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,22 @@ namespace ctranslate2 {

switch (inputs.index()) {
case 0:
future = _pool->forward_batch_async(std::get<BatchTokens>(inputs), token_type_ids.value_or(std::vector<std::vector<size_t>>()));
future = _pool->forward_batch_async(
std::get<BatchTokens>(inputs),
token_type_ids.value_or(std::vector<std::vector<size_t>>()));
break;
case 1:
future = _pool->forward_batch_async(std::get<BatchIds>(inputs), token_type_ids.value_or(std::vector<std::vector<size_t>>()));
future = _pool->forward_batch_async(
std::get<BatchIds>(inputs),
token_type_ids.value_or(std::vector<std::vector<size_t>>()));
break;
case 2:
if (!lengths)
throw std::invalid_argument("lengths vector is required when passing a dense input");
future = _pool->forward_batch_async(std::get<StorageView>(inputs), lengths.value(), token_type_ids.value_or(std::vector<std::vector<size_t>>()));
future = _pool->forward_batch_async(
std::get<StorageView>(inputs),
lengths.value(),
token_type_ids.value_or(std::vector<std::vector<size_t>>()));
break;
}

Expand Down
18 changes: 12 additions & 6 deletions src/encoder.cc
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
#include "ctranslate2/encoder.h"

namespace ctranslate2 {

std::future<EncoderForwardOutput>
Encoder::forward_batch_async(std::vector<std::vector<std::string>> tokens, std::vector<std::vector<size_t>> token_type_ids = {}) {
Encoder::forward_batch_async(std::vector<std::vector<std::string>> tokens,
std::vector<std::vector<size_t>> token_type_ids) {
return post<EncoderForwardOutput>(
[tokens = std::move(tokens), token_type_ids = std::move(token_type_ids)]
(models::SequenceEncoderReplica& encoder) {
return encoder.forward(tokens, token_type_ids);
});
}

std::future<EncoderForwardOutput>
Encoder::forward_batch_async(std::vector<std::vector<size_t>> ids, std::vector<std::vector<size_t>> token_type_ids = {}) {
Encoder::forward_batch_async(std::vector<std::vector<size_t>> ids,
std::vector<std::vector<size_t>> token_type_ids) {
return post<EncoderForwardOutput>(
[ids = std::move(ids), token_type_ids = std::move(token_type_ids)]
(models::SequenceEncoderReplica& encoder) {
Expand All @@ -21,9 +23,13 @@ namespace ctranslate2 {
}

std::future<EncoderForwardOutput>
Encoder::forward_batch_async(const StorageView& ids, const StorageView& lengths, std::vector<std::vector<size_t>> token_type_ids = {}) {
Encoder::forward_batch_async(const StorageView& ids,
const StorageView& lengths,
std::vector<std::vector<size_t>> token_type_ids) {
return post<EncoderForwardOutput>(
[ids = ids.sync_copy(), lengths = lengths.sync_copy(), token_type_ids = std::move(token_type_ids)]
[ids = ids.sync_copy(),
lengths = lengths.sync_copy(),
token_type_ids = std::move(token_type_ids)]
(models::SequenceEncoderReplica& encoder) {
return encoder.forward(ids, lengths, token_type_ids);
});
Expand Down
21 changes: 14 additions & 7 deletions src/models/language_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -294,22 +294,27 @@ namespace ctranslate2 {
decoder(ids, lengths, state, logits);
return logits;
}



EncoderForwardOutput
SequenceEncoderReplica::forward(const std::vector<std::vector<std::string>>& tokens, const std::vector<std::vector<size_t>>& token_type_ids = {}) {
SequenceEncoderReplica::forward(const std::vector<std::vector<std::string>>& tokens,
const std::vector<std::vector<size_t>>& token_type_ids) {
const auto& vocabulary = _model->get_vocabulary();
return forward(vocabulary.to_ids(tokens), token_type_ids);
}

EncoderForwardOutput
SequenceEncoderReplica::forward(const std::vector<std::vector<size_t>>& ids, const std::vector<std::vector<size_t>>& token_type_ids = {}) {
SequenceEncoderReplica::forward(const std::vector<std::vector<size_t>>& ids,
const std::vector<std::vector<size_t>>& token_type_ids) {
StorageView lengths;
StorageView input_ids = layers::make_sequence_inputs(ids, Device::CPU, 1, &lengths);
return forward(input_ids, lengths, token_type_ids);
}

EncoderForwardOutput
SequenceEncoderReplica::forward(const StorageView& ids, const StorageView& lengths, const std::vector<std::vector<size_t>>& token_type_ids = {}) {
SequenceEncoderReplica::forward(const StorageView& ids,
const StorageView& lengths,
const std::vector<std::vector<size_t>>& token_type_ids) {
PROFILE("SequenceEncoderReplica::forward");
const auto& model = *this->model();
const auto device = model.device();
Expand Down Expand Up @@ -342,7 +347,9 @@ namespace ctranslate2 {
}

EncoderForwardOutput
EncoderReplica::forward_impl(const StorageView& ids, const StorageView& lengths, const StorageView& token_type_ids) {
EncoderReplica::forward_impl(const StorageView& ids,
const StorageView& lengths,
const StorageView& token_type_ids) {
if (ids.rank() != 2)
throw std::invalid_argument("Expected input ids to have 2 dimensions, but got "
+ std::to_string(ids.rank())
Expand Down

0 comments on commit cc34ad6

Please sign in to comment.