// Review summary: public headers and Python binding for optional token type IDs.
//
// - Encoder::forward_batch_async (all three overloads) and SequenceEncoderReplica::forward
//   (all three overloads) gain a trailing token_type_ids parameter declared with a `= {}`
//   default, so call sites that omit it keep compiling.
// - forward_impl gains a third parameter `const StorageView& token_type_ids` with no default,
//   both on the pure virtual declaration and on the override in the @@ -144 hunk; every
//   override has to be updated together with this change.
// - EncoderWrapper::forward_batch takes an additional optional token_type_ids and forwards
//   token_type_ids.value_or(<empty vector>) for all three variants of `inputs`; the pybind
//   argument defaults to None.
//
// NOTE(review): the added docstring describes token_type_ids as ``[batch_size, max_length]``
// "of same shape as inputs", but the binding forwards it through value_or(std::vector...),
// i.e. as a vector rather than a dense array -- confirm the wording, and state what happens
// when the argument is omitted.
// NOTE(review): template arguments (e.g. `std::vector>`, `std::get(inputs)`,
// `py::call_guard()`) appear to have been lost from this text -- confirm against the
// original patch before applying it.
diff --git a/include/ctranslate2/encoder.h b/include/ctranslate2/encoder.h index b3d264ddf..ab75e293e 100644 --- a/include/ctranslate2/encoder.h +++ b/include/ctranslate2/encoder.h @@ -11,13 +11,17 @@ namespace ctranslate2 { using ReplicaPool::ReplicaPool; std::future - forward_batch_async(std::vector> tokens); + forward_batch_async(std::vector> tokens, + std::vector> token_type_ids = {}); std::future - forward_batch_async(std::vector> ids); + forward_batch_async(std::vector> ids, + std::vector> token_type_ids = {}); std::future - forward_batch_async(const StorageView& ids, const StorageView& lengths); + forward_batch_async(const StorageView& ids, + const StorageView& lengths, + std::vector> token_type_ids = {}); }; } diff --git a/include/ctranslate2/models/language_model.h b/include/ctranslate2/models/language_model.h index fdab65d88..7532b9a3a 100644 --- a/include/ctranslate2/models/language_model.h +++ b/include/ctranslate2/models/language_model.h @@ -123,13 +123,19 @@ namespace ctranslate2 { return model.as_sequence_encoder(); } - EncoderForwardOutput forward(const std::vector>& tokens); - EncoderForwardOutput forward(const std::vector>& ids); - EncoderForwardOutput forward(const StorageView& ids, const StorageView& lengths); + EncoderForwardOutput forward(const std::vector>& tokens, + const std::vector>& token_type_ids = {}); + EncoderForwardOutput forward(const std::vector>& ids, + const std::vector>& token_type_ids = {}); + EncoderForwardOutput forward(const StorageView& ids, + const StorageView& lengths, + const std::vector>& token_type_ids = {}); protected: virtual EncoderForwardOutput - forward_impl(const StorageView& ids, const StorageView& lengths) = 0; + forward_impl(const StorageView& ids, + const StorageView& lengths, + const StorageView& token_type_ids) = 0; private: const std::shared_ptr _model; @@ -144,7 +150,9 @@ namespace ctranslate2 { protected: EncoderForwardOutput - forward_impl(const StorageView& ids, const StorageView& lengths) override; + 
forward_impl(const StorageView& ids, + const StorageView& lengths, + const StorageView& token_type_ids) override; private: const std::shared_ptr _model; diff --git a/python/cpp/encoder.cc b/python/cpp/encoder.cc index 2263f1b99..7eb8abac7 100644 --- a/python/cpp/encoder.cc +++ b/python/cpp/encoder.cc @@ -13,20 +13,28 @@ namespace ctranslate2 { EncoderForwardOutput forward_batch(const std::variant& inputs, - const std::optional& lengths) { + const std::optional& lengths, + const std::optional& token_type_ids) { std::future future; switch (inputs.index()) { case 0: - future = _pool->forward_batch_async(std::get(inputs)); + future = _pool->forward_batch_async( + std::get(inputs), + token_type_ids.value_or(std::vector>())); break; case 1: - future = _pool->forward_batch_async(std::get(inputs)); + future = _pool->forward_batch_async( + std::get(inputs), + token_type_ids.value_or(std::vector>())); break; case 2: if (!lengths) throw std::invalid_argument("lengths vector is required when passing a dense input"); - future = _pool->forward_batch_async(std::get(inputs), lengths.value()); + future = _pool->forward_batch_async( + std::get(inputs), + lengths.value(), + token_type_ids.value_or(std::vector>())); break; } @@ -109,6 +117,7 @@ namespace ctranslate2 { .def("forward_batch", &EncoderWrapper::forward_batch, py::arg("inputs"), py::arg("lengths")=py::none(), + py::arg("token_type_ids")=py::none(), py::call_guard(), R"pbdoc( Forwards a batch of sequences in the encoder. @@ -119,6 +128,8 @@ namespace ctranslate2 { ``[batch_size, max_length]`` (e.g. created from a Numpy array or PyTorch tensor). lengths: The length of each sequence as a int32 array with shape ``[batch_size]``. Required when :obj:`inputs` is a dense array. + token_type_ids: A batch of token type IDs of same shape as :obj:`inputs`. + ``[batch_size, max_length]``. Returns: The encoder model output. 
// Review summary: implementation side of the optional token type IDs.
//
// - Each Encoder::forward_batch_async overload moves token_type_ids into the posted lambda
//   and passes it on to SequenceEncoderReplica::forward.
// - forward(tokens, ...) maps tokens with vocabulary.to_ids and delegates; forward(ids, ...)
//   builds input_ids and lengths on Device::CPU with layers::make_sequence_inputs and
//   delegates. Both pass token_type_ids through unchanged.
// - forward(ids, lengths, ...) converts token_type_ids with
//   layers::make_sequence_inputs(token_type_ids, device) and hands the result to forward_impl,
//   whether or not ids/lengths had to be moved to the model device.
// - EncoderReplica::forward_impl: when _encoder->num_input_features() > 1, an empty
//   token_type_ids is replaced by a zero-filled StorageView with the shape and dtype of ids;
//   a non-empty one is appended to `inputs` as given.
//
// NOTE(review): make_sequence_inputs is called even when token_type_ids is empty --
// presumably it yields a StorageView for which empty() is true, which the zero placeholder
// branch relies on; confirm.
// NOTE(review): no check is visible in these hunks that a supplied token_type_ids matches
// ids.shape() (or the dtype of ids) before it is appended to `inputs`; consider throwing
// std::invalid_argument on mismatch, as is already done for the rank of ids.
// NOTE(review): in the visible hunk token_type_ids is only read inside the
// num_input_features() > 1 branch, so it looks silently ignored otherwise -- confirm intended.
diff --git a/src/encoder.cc b/src/encoder.cc index efbe49d7e..3e7f728bc 100644 --- a/src/encoder.cc +++ b/src/encoder.cc @@ -3,29 +3,35 @@ namespace ctranslate2 { std::future - Encoder::forward_batch_async(std::vector> tokens) { + Encoder::forward_batch_async(std::vector> tokens, + std::vector> token_type_ids) { return post( - [tokens = std::move(tokens)] + [tokens = std::move(tokens), token_type_ids = std::move(token_type_ids)] (models::SequenceEncoderReplica& encoder) { - return encoder.forward(tokens); + return encoder.forward(tokens, token_type_ids); }); } std::future - Encoder::forward_batch_async(std::vector> ids) { + Encoder::forward_batch_async(std::vector> ids, + std::vector> token_type_ids) { return post( - [ids = std::move(ids)] + [ids = std::move(ids), token_type_ids = std::move(token_type_ids)] (models::SequenceEncoderReplica& encoder) { - return encoder.forward(ids); + return encoder.forward(ids, token_type_ids); }); } std::future - Encoder::forward_batch_async(const StorageView& ids, const StorageView& lengths) { + Encoder::forward_batch_async(const StorageView& ids, + const StorageView& lengths, + std::vector> token_type_ids) { return post( - [ids = ids.sync_copy(), lengths = lengths.sync_copy()] + [ids = ids.sync_copy(), + lengths = lengths.sync_copy(), + token_type_ids = std::move(token_type_ids)] (models::SequenceEncoderReplica& encoder) { - return encoder.forward(ids, lengths); + return encoder.forward(ids, lengths, token_type_ids); }); } diff --git a/src/models/language_model.cc b/src/models/language_model.cc index fb09d18c7..466e42594 100644 --- a/src/models/language_model.cc +++ b/src/models/language_model.cc @@ -297,31 +297,36 @@ namespace ctranslate2 { EncoderForwardOutput - SequenceEncoderReplica::forward(const std::vector>& tokens) { + SequenceEncoderReplica::forward(const std::vector>& tokens, + const std::vector>& token_type_ids) { const auto& vocabulary = _model->get_vocabulary(); - return forward(vocabulary.to_ids(tokens)); + return 
forward(vocabulary.to_ids(tokens), token_type_ids); } EncoderForwardOutput - SequenceEncoderReplica::forward(const std::vector>& ids) { + SequenceEncoderReplica::forward(const std::vector>& ids, + const std::vector>& token_type_ids) { StorageView lengths; StorageView input_ids = layers::make_sequence_inputs(ids, Device::CPU, 1, &lengths); - return forward(input_ids, lengths); + return forward(input_ids, lengths, token_type_ids); } EncoderForwardOutput - SequenceEncoderReplica::forward(const StorageView& ids, const StorageView& lengths) { + SequenceEncoderReplica::forward(const StorageView& ids, + const StorageView& lengths, + const std::vector>& token_type_ids) { PROFILE("SequenceEncoderReplica::forward"); const auto& model = *this->model(); const auto device = model.device(); const auto scoped_device_setter = model.get_scoped_device_setter(); + StorageView input_token_type_ids = layers::make_sequence_inputs(token_type_ids, device); EncoderForwardOutput output; if (ids.device() != device) - output = forward_impl(ids.to(device), lengths.to(device)); + output = forward_impl(ids.to(device), lengths.to(device), input_token_type_ids); else - output = forward_impl(ids, lengths); + output = forward_impl(ids, lengths, input_token_type_ids); // Ensure all operations are finished before returning the output. 
synchronize_stream(device); @@ -342,7 +347,9 @@ namespace ctranslate2 { } EncoderForwardOutput - EncoderReplica::forward_impl(const StorageView& ids, const StorageView& lengths) { + EncoderReplica::forward_impl(const StorageView& ids, + const StorageView& lengths, + const StorageView& token_type_ids) { if (ids.rank() != 2) throw std::invalid_argument("Expected input ids to have 2 dimensions, but got " + std::to_string(ids.rank()) @@ -360,9 +367,13 @@ namespace ctranslate2 { std::vector inputs{ids}; if (_encoder->num_input_features() > 1) { - StorageView token_type_ids(ids.shape(), ids.dtype(), device); - token_type_ids.zero(); - inputs.emplace_back(std::move(token_type_ids)); + if (token_type_ids.empty()) { + StorageView placeholder_type_ids(ids.shape(), ids.dtype(), device); + placeholder_type_ids.zero(); + inputs.emplace_back(std::move(placeholder_type_ids)); + } else { + inputs.emplace_back(token_type_ids); + } } StorageView last_hidden_state(dtype, device);