From dc9357f2e479f382d0f2ac9fb7a5cee53fc99e47 Mon Sep 17 00:00:00 2001 From: hamish Date: Mon, 14 Aug 2023 12:33:17 +0100 Subject: [PATCH 1/4] token type ids can be set by optional argument at python wrapper --- include/ctranslate2/encoder.h | 9 +++++ include/ctranslate2/models/language_model.h | 7 ++-- python/cpp/encoder.cc | 22 +++++++----- src/encoder.cc | 27 ++++++++++++++ src/models/language_model.cc | 39 ++++++++++++++++----- 5 files changed, 86 insertions(+), 18 deletions(-) diff --git a/include/ctranslate2/encoder.h b/include/ctranslate2/encoder.h index b3d264ddf..e0360300d 100644 --- a/include/ctranslate2/encoder.h +++ b/include/ctranslate2/encoder.h @@ -13,11 +13,20 @@ namespace ctranslate2 { std::future forward_batch_async(std::vector> tokens); + std::future + forward_batch_async(std::vector> tokens, std::vector> token_type_ids); + std::future forward_batch_async(std::vector> ids); + std::future + forward_batch_async(std::vector> ids, std::vector> token_type_ids); + std::future forward_batch_async(const StorageView& ids, const StorageView& lengths); + + std::future + forward_batch_async(const StorageView& ids, const StorageView& lengths, std::vector> token_type_ids); }; } diff --git a/include/ctranslate2/models/language_model.h b/include/ctranslate2/models/language_model.h index fdab65d88..3689bf66d 100644 --- a/include/ctranslate2/models/language_model.h +++ b/include/ctranslate2/models/language_model.h @@ -124,12 +124,15 @@ namespace ctranslate2 { } EncoderForwardOutput forward(const std::vector>& tokens); + EncoderForwardOutput forward(const std::vector>& tokens, const std::vector>& token_type_ids); EncoderForwardOutput forward(const std::vector>& ids); + EncoderForwardOutput forward(const std::vector>& ids, const std::vector>& token_type_ids); EncoderForwardOutput forward(const StorageView& ids, const StorageView& lengths); + EncoderForwardOutput forward(const StorageView& ids, const StorageView& lengths, const std::vector>& token_type_ids); protected: virtual EncoderForwardOutput - forward_impl(const StorageView& ids, const StorageView& lengths) = 0; + forward_impl(const StorageView& ids, const StorageView& lengths, const StorageView& token_type_ids) = 0; private: const std::shared_ptr _model; @@ -144,7 +147,7 @@ namespace ctranslate2 { protected: EncoderForwardOutput - forward_impl(const StorageView& ids, const StorageView& lengths) override; + forward_impl(const StorageView& ids, const StorageView& lengths, const StorageView& token_type_ids) override; private: const std::shared_ptr _model; diff --git a/python/cpp/encoder.cc b/python/cpp/encoder.cc index 2263f1b99..cbad9c0d6 100644 --- a/python/cpp/encoder.cc +++ b/python/cpp/encoder.cc @@ -13,20 +13,25 @@ namespace ctranslate2 { EncoderForwardOutput forward_batch(const std::variant& inputs, - const std::optional& lengths) { + const std::optional& lengths, + const std::optional& token_type_ids + ) { std::future future; switch (inputs.index()) { case 0: - future = _pool->forward_batch_async(std::get(inputs)); + future = (token_type_ids) ? _pool->forward_batch_async(std::get(inputs), token_type_ids.value()) + : _pool->forward_batch_async(std::get(inputs)); break; case 1: - future = _pool->forward_batch_async(std::get(inputs)); + future = (token_type_ids) ? _pool->forward_batch_async(std::get(inputs), token_type_ids.value()) + : _pool->forward_batch_async(std::get(inputs)); break; case 2: if (!lengths) throw std::invalid_argument("lengths vector is required when passing a dense input"); - future = _pool->forward_batch_async(std::get(inputs), lengths.value()); + future = (token_type_ids) ? _pool->forward_batch_async(std::get(inputs), lengths.value(), token_type_ids.value()) + : _pool->forward_batch_async(std::get(inputs), lengths.value()); break; } @@ -81,8 +86,8 @@ namespace ctranslate2 { device: Device to use (possible values are: cpu, cuda, auto). device_index: Device IDs where to place this encoder on. compute_type: Model computation type or a dictionary mapping a device name - to the computation type (possible values are: default, auto, int8, int8_float32, - int8_float16, int8_bfloat16, int16, float16, bfloat16, float32). + to the computation type (possible values are: default, auto, int8, int8_float16, + int8_bfloat16, int16, float16, bfloat16, float32). inter_threads: Maximum number of parallel generations. intra_threads: Number of OpenMP threads per encoder (0 to use a default value). max_queued_batches: Maximum numbers of batches in the queue (-1 for unlimited, @@ -97,8 +102,6 @@ namespace ctranslate2 { "Device this encoder is running on.") .def_property_readonly("device_index", &EncoderWrapper::device_index, "List of device IDs where this encoder is running on.") - .def_property_readonly("compute_type", &EncoderWrapper::compute_type, - "Computation type used by the model.") .def_property_readonly("num_encoders", &EncoderWrapper::num_replicas, "Number of encoders backing this instance.") .def_property_readonly("num_queued_batches", &EncoderWrapper::num_queued_batches, @@ -109,6 +112,7 @@ namespace ctranslate2 { .def("forward_batch", &EncoderWrapper::forward_batch, py::arg("inputs"), py::arg("lengths")=py::none(), + py::arg("token_type_ids")=py::none(), py::call_guard(), R"pbdoc( Forwards a batch of sequences in the encoder. @@ -119,6 +123,8 @@ namespace ctranslate2 { ``[batch_size, max_length]`` (e.g. created from a Numpy array or PyTorch tensor). lengths: The length of each sequence as a int32 array with shape ``[batch_size]``. Required when :obj:`inputs` is a dense array. + token_type_ids: A batch of token type IDs of same shape as :obj:`inputs`. + ``[batch_size, max_length]``. Returns: The encoder model output. diff --git a/src/encoder.cc b/src/encoder.cc index efbe49d7e..177062275 100644 --- a/src/encoder.cc +++ b/src/encoder.cc @@ -10,6 +10,15 @@ namespace ctranslate2 { return encoder.forward(tokens); }); } + + std::future + Encoder::forward_batch_async(std::vector> tokens, std::vector> token_type_ids) { + return post( + [tokens = std::move(tokens), token_type_ids = std::move(token_type_ids)] + (models::SequenceEncoderReplica& encoder) { + return encoder.forward(tokens, token_type_ids); + }); + } std::future Encoder::forward_batch_async(std::vector> ids) { @@ -19,6 +28,15 @@ namespace ctranslate2 { return encoder.forward(ids); }); } + + std::future + Encoder::forward_batch_async(std::vector> ids, std::vector> token_type_ids) { + return post( + [ids = std::move(ids), token_type_ids = std::move(token_type_ids)] + (models::SequenceEncoderReplica& encoder) { + return encoder.forward(ids, token_type_ids); + }); + } std::future Encoder::forward_batch_async(const StorageView& ids, const StorageView& lengths) { @@ -29,4 +47,13 @@ namespace ctranslate2 { }); } + std::future + Encoder::forward_batch_async(const StorageView& ids, const StorageView& lengths, std::vector> token_type_ids) { + return post( + [ids = ids.sync_copy(), lengths = lengths.sync_copy(), token_type_ids = std::move(token_type_ids)] + (models::SequenceEncoderReplica& encoder) { + return encoder.forward(ids, lengths, token_type_ids); + }); + } + } diff --git a/src/models/language_model.cc b/src/models/language_model.cc index fb09d18c7..58fa17d06 100644 --- a/src/models/language_model.cc +++ b/src/models/language_model.cc @@ -298,30 +298,49 @@ namespace ctranslate2 { EncoderForwardOutput SequenceEncoderReplica::forward(const std::vector>& tokens) { + std::vector> token_type_ids; + return forward(tokens, token_type_ids); + } + + EncoderForwardOutput + SequenceEncoderReplica::forward(const std::vector>& tokens, const std::vector>& token_type_ids) { const auto& vocabulary = _model->get_vocabulary(); - return forward(vocabulary.to_ids(tokens)); + return forward(vocabulary.to_ids(tokens), token_type_ids); } EncoderForwardOutput SequenceEncoderReplica::forward(const std::vector>& ids) { + std::vector> token_type_ids; + return forward(ids, token_type_ids); + } + + EncoderForwardOutput + SequenceEncoderReplica::forward(const std::vector>& ids, const std::vector>& token_type_ids) { StorageView lengths; StorageView input_ids = layers::make_sequence_inputs(ids, Device::CPU, 1, &lengths); - return forward(input_ids, lengths); + return forward(input_ids, lengths, token_type_ids); } EncoderForwardOutput SequenceEncoderReplica::forward(const StorageView& ids, const StorageView& lengths) { + std::vector> token_type_ids; + return forward(ids, lengths, token_type_ids); + } + + EncoderForwardOutput + SequenceEncoderReplica::forward(const StorageView& ids, const StorageView& lengths, const std::vector>& token_type_ids) { PROFILE("SequenceEncoderReplica::forward"); const auto& model = *this->model(); const auto device = model.device(); const auto scoped_device_setter = model.get_scoped_device_setter(); + StorageView input_token_type_ids = layers::make_sequence_inputs(token_type_ids, Device::CPU, 1, nullptr); EncoderForwardOutput output; if (ids.device() != device) - output = forward_impl(ids.to(device), lengths.to(device)); + output = forward_impl(ids.to(device), lengths.to(device), input_token_type_ids.to(device)); else - output = forward_impl(ids, lengths); + output = forward_impl(ids, lengths, input_token_type_ids); // Ensure all operations are finished before returning the output. synchronize_stream(device); @@ -342,7 +361,7 @@ namespace ctranslate2 { } EncoderForwardOutput - EncoderReplica::forward_impl(const StorageView& ids, const StorageView& lengths) { + EncoderReplica::forward_impl(const StorageView& ids, const StorageView& lengths, const StorageView& token_type_ids) { if (ids.rank() != 2) throw std::invalid_argument("Expected input ids to have 2 dimensions, but got " + std::to_string(ids.rank()) @@ -360,9 +379,13 @@ namespace ctranslate2 { std::vector inputs{ids}; if (_encoder->num_input_features() > 1) { - StorageView token_type_ids(ids.shape(), ids.dtype(), device); - token_type_ids.zero(); - inputs.emplace_back(std::move(token_type_ids)); + if (token_type_ids.empty()) { + StorageView placeholder_type_ids(ids.shape(), ids.dtype(), device); + placeholder_type_ids.zero(); + inputs.emplace_back(std::move(placeholder_type_ids)); + } else { + inputs.emplace_back(std::move(token_type_ids)); + } } StorageView last_hidden_state(dtype, device); From 8baff132247f891504e30a6590b2f4ff7bd829cd Mon Sep 17 00:00:00 2001 From: hamish Date: Fri, 18 Aug 2023 14:36:37 +0100 Subject: [PATCH 2/4] changed function overloads to function defaults --- include/ctranslate2/encoder.h | 9 ------ include/ctranslate2/models/language_model.h | 3 -- python/cpp/encoder.cc | 17 +++++------ src/encoder.cc | 33 ++------------------- src/models/language_model.cc | 31 ++++--------------- 5 files changed, 17 insertions(+), 76 deletions(-) diff --git a/include/ctranslate2/encoder.h b/include/ctranslate2/encoder.h index e0360300d..521f2611e 100644 --- a/include/ctranslate2/encoder.h +++ b/include/ctranslate2/encoder.h @@ -10,20 +10,11 @@ namespace ctranslate2 { public: using ReplicaPool::ReplicaPool; - std::future - forward_batch_async(std::vector> tokens); - std::future forward_batch_async(std::vector> tokens, std::vector> token_type_ids); - std::future - forward_batch_async(std::vector> ids); - std::future forward_batch_async(std::vector> ids, std::vector> token_type_ids); - - std::future - forward_batch_async(const StorageView& ids, const StorageView& lengths); std::future forward_batch_async(const StorageView& ids, const StorageView& lengths, std::vector> token_type_ids); diff --git a/include/ctranslate2/models/language_model.h b/include/ctranslate2/models/language_model.h index 3689bf66d..7c583afab 100644 --- a/include/ctranslate2/models/language_model.h +++ b/include/ctranslate2/models/language_model.h @@ -123,11 +123,8 @@ namespace ctranslate2 { return model.as_sequence_encoder(); } - EncoderForwardOutput forward(const std::vector>& tokens); EncoderForwardOutput forward(const std::vector>& tokens, const std::vector>& token_type_ids); - EncoderForwardOutput forward(const std::vector>& ids); EncoderForwardOutput forward(const std::vector>& ids, const std::vector>& token_type_ids); - EncoderForwardOutput forward(const StorageView& ids, const StorageView& lengths); EncoderForwardOutput forward(const StorageView& ids, const StorageView& lengths, const std::vector>& token_type_ids); protected: diff --git a/python/cpp/encoder.cc b/python/cpp/encoder.cc index cbad9c0d6..2b1cab353 100644 --- a/python/cpp/encoder.cc +++ b/python/cpp/encoder.cc @@ -20,18 +20,15 @@ namespace ctranslate2 { switch (inputs.index()) { case 0: - future = (token_type_ids) ? _pool->forward_batch_async(std::get(inputs), token_type_ids.value()) - : _pool->forward_batch_async(std::get(inputs)); + future = _pool->forward_batch_async(std::get(inputs), token_type_ids.value_or(std::vector>())); break; case 1: - future = (token_type_ids) ? _pool->forward_batch_async(std::get(inputs), token_type_ids.value()) - : _pool->forward_batch_async(std::get(inputs)); + future = _pool->forward_batch_async(std::get(inputs), token_type_ids.value_or(std::vector>())); break; case 2: if (!lengths) throw std::invalid_argument("lengths vector is required when passing a dense input"); - future = (token_type_ids) ? _pool->forward_batch_async(std::get(inputs), lengths.value(), token_type_ids.value()) - : _pool->forward_batch_async(std::get(inputs), lengths.value()); + future = _pool->forward_batch_async(std::get(inputs), lengths.value(), token_type_ids.value_or(std::vector>())); break; } @@ -86,8 +83,8 @@ namespace ctranslate2 { device: Device to use (possible values are: cpu, cuda, auto). device_index: Device IDs where to place this encoder on. compute_type: Model computation type or a dictionary mapping a device name - to the computation type (possible values are: default, auto, int8, int8_float16, - int8_bfloat16, int16, float16, bfloat16, float32). + to the computation type (possible values are: default, auto, int8, int8_float32, + int8_float16, int8_bfloat16, int16, float16, bfloat16, float32). inter_threads: Maximum number of parallel generations. intra_threads: Number of OpenMP threads per encoder (0 to use a default value). max_queued_batches: Maximum numbers of batches in the queue (-1 for unlimited, @@ -102,6 +99,8 @@ namespace ctranslate2 { "Device this encoder is running on.") .def_property_readonly("device_index", &EncoderWrapper::device_index, "List of device IDs where this encoder is running on.") + .def_property_readonly("compute_type", &EncoderWrapper::compute_type, + "Computation type used by the model.") .def_property_readonly("num_encoders", &EncoderWrapper::num_replicas, "Number of encoders backing this instance.") .def_property_readonly("num_queued_batches", &EncoderWrapper::num_queued_batches, @@ -124,7 +123,7 @@ namespace ctranslate2 { lengths: The length of each sequence as a int32 array with shape ``[batch_size]``. Required when :obj:`inputs` is a dense array. token_type_ids: A batch of token type IDs of same shape as :obj:`inputs`. - ``[batch_size, max_length]``. + ``[batch_size, max_length]``. Returns: The encoder model output. diff --git a/src/encoder.cc b/src/encoder.cc index 177062275..1e3a35ca5 100644 --- a/src/encoder.cc +++ b/src/encoder.cc @@ -1,36 +1,18 @@ #include "ctranslate2/encoder.h" namespace ctranslate2 { - - std::future - Encoder::forward_batch_async(std::vector> tokens) { - return post( - [tokens = std::move(tokens)] - (models::SequenceEncoderReplica& encoder) { - return encoder.forward(tokens); - }); - } std::future - Encoder::forward_batch_async(std::vector> tokens, std::vector> token_type_ids) { + Encoder::forward_batch_async(std::vector> tokens, std::vector> token_type_ids = {}) { return post( [tokens = std::move(tokens), token_type_ids = std::move(token_type_ids)] (models::SequenceEncoderReplica& encoder) { return encoder.forward(tokens, token_type_ids); }); } - - std::future - Encoder::forward_batch_async(std::vector> ids) { - return post( - [ids = std::move(ids)] - (models::SequenceEncoderReplica& encoder) { - return encoder.forward(ids); - }); - } std::future - Encoder::forward_batch_async(std::vector> ids, std::vector> token_type_ids) { + Encoder::forward_batch_async(std::vector> ids, std::vector> token_type_ids = {}) { return post( [ids = std::move(ids), token_type_ids = std::move(token_type_ids)] (models::SequenceEncoderReplica& encoder) { @@ -39,16 +21,7 @@ namespace ctranslate2 { } std::future - Encoder::forward_batch_async(const StorageView& ids, const StorageView& lengths) { - return post( - [ids = ids.sync_copy(), lengths = lengths.sync_copy()] - (models::SequenceEncoderReplica& encoder) { - return encoder.forward(ids, lengths); - }); - } - - std::future - Encoder::forward_batch_async(const StorageView& ids, const StorageView& lengths, std::vector> token_type_ids) { + Encoder::forward_batch_async(const StorageView& ids, const StorageView& lengths, std::vector> token_type_ids = {}) { return post( [ids = ids.sync_copy(), lengths = lengths.sync_copy(), token_type_ids = std::move(token_type_ids)] (models::SequenceEncoderReplica& encoder) { diff --git a/src/models/language_model.cc b/src/models/language_model.cc index 58fa17d06..e7a60f5c4 100644 --- a/src/models/language_model.cc +++ b/src/models/language_model.cc @@ -294,51 +294,32 @@ namespace ctranslate2 { decoder(ids, lengths, state, logits); return logits; } - - - EncoderForwardOutput - SequenceEncoderReplica::forward(const std::vector>& tokens) { - std::vector> token_type_ids; - return forward(tokens, token_type_ids); - } EncoderForwardOutput - SequenceEncoderReplica::forward(const std::vector>& tokens, const std::vector>& token_type_ids) { + SequenceEncoderReplica::forward(const std::vector>& tokens, const std::vector>& token_type_ids = {}) { const auto& vocabulary = _model->get_vocabulary(); return forward(vocabulary.to_ids(tokens), token_type_ids); } - - EncoderForwardOutput - SequenceEncoderReplica::forward(const std::vector>& ids) { - std::vector> token_type_ids; - return forward(ids, token_type_ids); - } EncoderForwardOutput - SequenceEncoderReplica::forward(const std::vector>& ids, const std::vector>& token_type_ids) { + SequenceEncoderReplica::forward(const std::vector>& ids, const std::vector>& token_type_ids = {}) { StorageView lengths; StorageView input_ids = layers::make_sequence_inputs(ids, Device::CPU, 1, &lengths); return forward(input_ids, lengths, token_type_ids); } - - EncoderForwardOutput - SequenceEncoderReplica::forward(const StorageView& ids, const StorageView& lengths) { - std::vector> token_type_ids; - return forward(ids, lengths, token_type_ids); - } EncoderForwardOutput - SequenceEncoderReplica::forward(const StorageView& ids, const StorageView& lengths, const std::vector>& token_type_ids) { + SequenceEncoderReplica::forward(const StorageView& ids, const StorageView& lengths, const std::vector>& token_type_ids = {}) { PROFILE("SequenceEncoderReplica::forward"); const auto& model = *this->model(); const auto device = model.device(); const auto scoped_device_setter = model.get_scoped_device_setter(); - StorageView input_token_type_ids = layers::make_sequence_inputs(token_type_ids, Device::CPU, 1, nullptr); + StorageView input_token_type_ids = layers::make_sequence_inputs(token_type_ids, device); EncoderForwardOutput output; if (ids.device() != device) - output = forward_impl(ids.to(device), lengths.to(device), input_token_type_ids.to(device)); + output = forward_impl(ids.to(device), lengths.to(device), input_token_type_ids); else output = forward_impl(ids, lengths, input_token_type_ids); @@ -384,7 +365,7 @@ namespace ctranslate2 { placeholder_type_ids.zero(); inputs.emplace_back(std::move(placeholder_type_ids)); } else { - inputs.emplace_back(std::move(token_type_ids)); + inputs.emplace_back(token_type_ids); } } From 67182ef35e44173d44268f005ca7ad6782d078f5 Mon Sep 17 00:00:00 2001 From: Hamish Hall Date: Fri, 18 Aug 2023 17:17:49 +0100 Subject: [PATCH 3/4] space styling --- python/cpp/encoder.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cpp/encoder.cc b/python/cpp/encoder.cc index 2b1cab353..a03e4e194 100644 --- a/python/cpp/encoder.cc +++ b/python/cpp/encoder.cc @@ -84,7 +84,7 @@ namespace ctranslate2 { device_index: Device IDs where to place this encoder on. compute_type: Model computation type or a dictionary mapping a device name to the computation type (possible values are: default, auto, int8, int8_float32, - int8_float16, int8_bfloat16, int16, float16, bfloat16, float32). + int8_float16, int8_bfloat16, int16, float16, bfloat16, float32). inter_threads: Maximum number of parallel generations. intra_threads: Number of OpenMP threads per encoder (0 to use a default value). max_queued_batches: Maximum numbers of batches in the queue (-1 for unlimited, @@ -100,7 +100,7 @@ namespace ctranslate2 { .def_property_readonly("device_index", &EncoderWrapper::device_index, "List of device IDs where this encoder is running on.") .def_property_readonly("compute_type", &EncoderWrapper::compute_type, - "Computation type used by the model.") + "Computation type used by the model.") .def_property_readonly("num_encoders", &EncoderWrapper::num_replicas, "Number of encoders backing this instance.") .def_property_readonly("num_queued_batches", &EncoderWrapper::num_queued_batches, From 0d8e1ffed6b448cb60bc267c292d304565db9fb7 Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Mon, 28 Aug 2023 10:29:16 +0200 Subject: [PATCH 4/4] Fix formatting and default argument values --- include/ctranslate2/encoder.h | 12 ++++++++---- include/ctranslate2/models/language_model.h | 18 +++++++++++++----- python/cpp/encoder.cc | 16 +++++++++++----- src/encoder.cc | 18 ++++++++++++------ src/models/language_model.cc | 21 ++++++++++++++------- 5 files changed, 58 insertions(+), 27 deletions(-) diff --git a/include/ctranslate2/encoder.h b/include/ctranslate2/encoder.h index 521f2611e..ab75e293e 100644 --- a/include/ctranslate2/encoder.h +++ b/include/ctranslate2/encoder.h @@ -11,13 +11,17 @@ namespace ctranslate2 { using ReplicaPool::ReplicaPool; std::future - forward_batch_async(std::vector> tokens, std::vector> token_type_ids); + forward_batch_async(std::vector> tokens, + std::vector> token_type_ids = {}); std::future - forward_batch_async(std::vector> ids, std::vector> token_type_ids); - + forward_batch_async(std::vector> ids, + std::vector> token_type_ids = {}); + std::future - forward_batch_async(const StorageView& ids, const StorageView& lengths, std::vector> token_type_ids); + forward_batch_async(const StorageView& ids, + const StorageView& lengths, + std::vector> token_type_ids = {}); }; } diff --git a/include/ctranslate2/models/language_model.h b/include/ctranslate2/models/language_model.h index 7c583afab..7532b9a3a 100644 --- a/include/ctranslate2/models/language_model.h +++ b/include/ctranslate2/models/language_model.h @@ -123,13 +123,19 @@ namespace ctranslate2 { return model.as_sequence_encoder(); } - EncoderForwardOutput forward(const std::vector>& tokens, const std::vector>& token_type_ids); - EncoderForwardOutput forward(const std::vector>& ids, const std::vector>& token_type_ids); - EncoderForwardOutput forward(const StorageView& ids, const StorageView& lengths, const std::vector>& token_type_ids); + EncoderForwardOutput forward(const std::vector>& tokens, + const std::vector>& token_type_ids = {}); + EncoderForwardOutput forward(const std::vector>& ids, + const std::vector>& token_type_ids = {}); + EncoderForwardOutput forward(const StorageView& ids, + const StorageView& lengths, + const std::vector>& token_type_ids = {}); protected: virtual EncoderForwardOutput - forward_impl(const StorageView& ids, const StorageView& lengths, const StorageView& token_type_ids) = 0; + forward_impl(const StorageView& ids, + const StorageView& lengths, + const StorageView& token_type_ids) = 0; private: const std::shared_ptr _model; @@ -144,7 +150,9 @@ namespace ctranslate2 { protected: EncoderForwardOutput - forward_impl(const StorageView& ids, const StorageView& lengths, const StorageView& token_type_ids) override; + forward_impl(const StorageView& ids, + const StorageView& lengths, + const StorageView& token_type_ids) override; private: const std::shared_ptr _model; diff --git a/python/cpp/encoder.cc b/python/cpp/encoder.cc index a03e4e194..7eb8abac7 100644 --- a/python/cpp/encoder.cc +++ b/python/cpp/encoder.cc @@ -14,21 +14,27 @@ namespace ctranslate2 { EncoderForwardOutput forward_batch(const std::variant& inputs, const std::optional& lengths, - const std::optional& token_type_ids - ) { + const std::optional& token_type_ids) { std::future future; switch (inputs.index()) { case 0: - future = _pool->forward_batch_async(std::get(inputs), token_type_ids.value_or(std::vector>())); + future = _pool->forward_batch_async( + std::get(inputs), + token_type_ids.value_or(std::vector>())); break; case 1: - future = _pool->forward_batch_async(std::get(inputs), token_type_ids.value_or(std::vector>())); + future = _pool->forward_batch_async( + std::get(inputs), + token_type_ids.value_or(std::vector>())); break; case 2: if (!lengths) throw std::invalid_argument("lengths vector is required when passing a dense input"); - future = _pool->forward_batch_async(std::get(inputs), lengths.value(), token_type_ids.value_or(std::vector>())); + future = _pool->forward_batch_async( + std::get(inputs), + lengths.value(), + token_type_ids.value_or(std::vector>())); break; } diff --git a/src/encoder.cc b/src/encoder.cc index 1e3a35ca5..3e7f728bc 100644 --- a/src/encoder.cc +++ b/src/encoder.cc @@ -1,18 +1,20 @@ #include "ctranslate2/encoder.h" namespace ctranslate2 { - + std::future - Encoder::forward_batch_async(std::vector> tokens, std::vector> token_type_ids = {}) { + Encoder::forward_batch_async(std::vector> tokens, + std::vector> token_type_ids) { return post( [tokens = std::move(tokens), token_type_ids = std::move(token_type_ids)] (models::SequenceEncoderReplica& encoder) { return encoder.forward(tokens, token_type_ids); }); } - + std::future - Encoder::forward_batch_async(std::vector> ids, std::vector> token_type_ids = {}) { + Encoder::forward_batch_async(std::vector> ids, + std::vector> token_type_ids) { return post( [ids = std::move(ids), token_type_ids = std::move(token_type_ids)] (models::SequenceEncoderReplica& encoder) { @@ -21,9 +23,13 @@ namespace ctranslate2 { } std::future - Encoder::forward_batch_async(const StorageView& ids, const StorageView& lengths, std::vector> token_type_ids = {}) { + Encoder::forward_batch_async(const StorageView& ids, + const StorageView& lengths, + std::vector> token_type_ids) { return post( - [ids = ids.sync_copy(), lengths = lengths.sync_copy(), token_type_ids = std::move(token_type_ids)] + [ids = ids.sync_copy(), + lengths = lengths.sync_copy(), + token_type_ids = std::move(token_type_ids)] (models::SequenceEncoderReplica& encoder) { return encoder.forward(ids, lengths, token_type_ids); }); diff --git a/src/models/language_model.cc b/src/models/language_model.cc index e7a60f5c4..466e42594 100644 --- a/src/models/language_model.cc +++ b/src/models/language_model.cc @@ -294,22 +294,27 @@ namespace ctranslate2 { decoder(ids, lengths, state, logits); return logits; } - + + EncoderForwardOutput - SequenceEncoderReplica::forward(const std::vector>& tokens, const std::vector>& token_type_ids = {}) { + SequenceEncoderReplica::forward(const std::vector>& tokens, + const std::vector>& token_type_ids) { const auto& vocabulary = _model->get_vocabulary(); return forward(vocabulary.to_ids(tokens), token_type_ids); } - + EncoderForwardOutput - SequenceEncoderReplica::forward(const std::vector>& ids, const std::vector>& token_type_ids = {}) { + SequenceEncoderReplica::forward(const std::vector>& ids, + const std::vector>& token_type_ids) { StorageView lengths; StorageView input_ids = layers::make_sequence_inputs(ids, Device::CPU, 1, &lengths); return forward(input_ids, lengths, token_type_ids); } - + EncoderForwardOutput - SequenceEncoderReplica::forward(const StorageView& ids, const StorageView& lengths, const std::vector>& token_type_ids = {}) { + SequenceEncoderReplica::forward(const StorageView& ids, + const StorageView& lengths, + const std::vector>& token_type_ids) { PROFILE("SequenceEncoderReplica::forward"); const auto& model = *this->model(); const auto device = model.device(); @@ -342,7 +347,9 @@ namespace ctranslate2 { } EncoderForwardOutput - EncoderReplica::forward_impl(const StorageView& ids, const StorageView& lengths, const StorageView& token_type_ids) { + EncoderReplica::forward_impl(const StorageView& ids, + const StorageView& lengths, + const StorageView& token_type_ids) { if (ids.rank() != 2) throw std::invalid_argument("Expected input ids to have 2 dimensions, but got " + std::to_string(ids.rank())