Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

token type ids can be set by optional argument up to python wrapper #1418

Merged
merged 4 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/ctranslate2/encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,20 @@ namespace ctranslate2 {
std::future<EncoderForwardOutput>
forward_batch_async(std::vector<std::vector<std::string>> tokens);

std::future<EncoderForwardOutput>
forward_batch_async(std::vector<std::vector<std::string>> tokens, std::vector<std::vector<size_t>> token_type_ids);

std::future<EncoderForwardOutput>
forward_batch_async(std::vector<std::vector<size_t>> ids);

std::future<EncoderForwardOutput>
forward_batch_async(std::vector<std::vector<size_t>> ids, std::vector<std::vector<size_t>> token_type_ids);

std::future<EncoderForwardOutput>
forward_batch_async(const StorageView& ids, const StorageView& lengths);

std::future<EncoderForwardOutput>
forward_batch_async(const StorageView& ids, const StorageView& lengths, std::vector<std::vector<size_t>> token_type_ids);
guillaumekln marked this conversation as resolved.
Show resolved Hide resolved
};

}
7 changes: 5 additions & 2 deletions include/ctranslate2/models/language_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,15 @@ namespace ctranslate2 {
}

EncoderForwardOutput forward(const std::vector<std::vector<std::string>>& tokens);
EncoderForwardOutput forward(const std::vector<std::vector<std::string>>& tokens, const std::vector<std::vector<size_t>>& token_type_ids);
EncoderForwardOutput forward(const std::vector<std::vector<size_t>>& ids);
EncoderForwardOutput forward(const std::vector<std::vector<size_t>>& ids, const std::vector<std::vector<size_t>>& token_type_ids);
EncoderForwardOutput forward(const StorageView& ids, const StorageView& lengths);
EncoderForwardOutput forward(const StorageView& ids, const StorageView& lengths, const std::vector<std::vector<size_t>>& token_type_ids);
guillaumekln marked this conversation as resolved.
Show resolved Hide resolved

protected:
virtual EncoderForwardOutput
forward_impl(const StorageView& ids, const StorageView& lengths) = 0;
forward_impl(const StorageView& ids, const StorageView& lengths, const StorageView& token_type_ids) = 0;

private:
const std::shared_ptr<const LanguageModel> _model;
Expand All @@ -144,7 +147,7 @@ namespace ctranslate2 {

protected:
EncoderForwardOutput
forward_impl(const StorageView& ids, const StorageView& lengths) override;
forward_impl(const StorageView& ids, const StorageView& lengths, const StorageView& token_type_ids) override;

private:
const std::shared_ptr<const LanguageModel> _model;
Expand Down
22 changes: 14 additions & 8 deletions python/cpp/encoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,25 @@ namespace ctranslate2 {

EncoderForwardOutput
forward_batch(const std::variant<BatchTokens, BatchIds, StorageView>& inputs,
const std::optional<StorageView>& lengths) {
const std::optional<StorageView>& lengths,
const std::optional<BatchIds>& token_type_ids
) {
std::future<EncoderForwardOutput> future;

switch (inputs.index()) {
case 0:
future = _pool->forward_batch_async(std::get<BatchTokens>(inputs));
future = (token_type_ids) ? _pool->forward_batch_async(std::get<BatchTokens>(inputs), token_type_ids.value())
: _pool->forward_batch_async(std::get<BatchTokens>(inputs));
guillaumekln marked this conversation as resolved.
Show resolved Hide resolved
break;
case 1:
future = _pool->forward_batch_async(std::get<BatchIds>(inputs));
future = (token_type_ids) ? _pool->forward_batch_async(std::get<BatchIds>(inputs), token_type_ids.value())
: _pool->forward_batch_async(std::get<BatchIds>(inputs));
break;
case 2:
if (!lengths)
throw std::invalid_argument("lengths vector is required when passing a dense input");
future = _pool->forward_batch_async(std::get<StorageView>(inputs), lengths.value());
future = (token_type_ids) ? _pool->forward_batch_async(std::get<StorageView>(inputs), lengths.value(), token_type_ids.value())
: _pool->forward_batch_async(std::get<StorageView>(inputs), lengths.value());
break;
}

Expand Down Expand Up @@ -81,8 +86,8 @@ namespace ctranslate2 {
device: Device to use (possible values are: cpu, cuda, auto).
device_index: Device IDs where to place this encoder on.
compute_type: Model computation type or a dictionary mapping a device name
to the computation type (possible values are: default, auto, int8, int8_float32,
int8_float16, int8_bfloat16, int16, float16, bfloat16, float32).
to the computation type (possible values are: default, auto, int8, int8_float16,
int8_bfloat16, int16, float16, bfloat16, float32).
guillaumekln marked this conversation as resolved.
Show resolved Hide resolved
inter_threads: Maximum number of parallel generations.
intra_threads: Number of OpenMP threads per encoder (0 to use a default value).
max_queued_batches: Maximum number of batches in the queue (-1 for unlimited,
Expand All @@ -97,8 +102,6 @@ namespace ctranslate2 {
"Device this encoder is running on.")
.def_property_readonly("device_index", &EncoderWrapper::device_index,
"List of device IDs where this encoder is running on.")
.def_property_readonly("compute_type", &EncoderWrapper::compute_type,
"Computation type used by the model.")
guillaumekln marked this conversation as resolved.
Show resolved Hide resolved
.def_property_readonly("num_encoders", &EncoderWrapper::num_replicas,
"Number of encoders backing this instance.")
.def_property_readonly("num_queued_batches", &EncoderWrapper::num_queued_batches,
Expand All @@ -109,6 +112,7 @@ namespace ctranslate2 {
.def("forward_batch", &EncoderWrapper::forward_batch,
py::arg("inputs"),
py::arg("lengths")=py::none(),
py::arg("token_type_ids")=py::none(),
py::call_guard<py::gil_scoped_release>(),
R"pbdoc(
Forwards a batch of sequences in the encoder.
Expand All @@ -119,6 +123,8 @@ namespace ctranslate2 {
``[batch_size, max_length]`` (e.g. created from a Numpy array or PyTorch tensor).
lengths: The length of each sequence as a int32 array with shape
``[batch_size]``. Required when :obj:`inputs` is a dense array.
token_type_ids: A batch of token type IDs with the same shape as :obj:`inputs`,
i.e. ``[batch_size, max_length]``.
guillaumekln marked this conversation as resolved.
Show resolved Hide resolved

Returns:
The encoder model output.
Expand Down
27 changes: 27 additions & 0 deletions src/encoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@ namespace ctranslate2 {
return encoder.forward(tokens);
});
}

std::future<EncoderForwardOutput>
Encoder::forward_batch_async(std::vector<std::vector<std::string>> tokens, std::vector<std::vector<size_t>> token_type_ids) {
return post<EncoderForwardOutput>(
[tokens = std::move(tokens), token_type_ids = std::move(token_type_ids)]
(models::SequenceEncoderReplica& encoder) {
return encoder.forward(tokens, token_type_ids);
});
}

std::future<EncoderForwardOutput>
Encoder::forward_batch_async(std::vector<std::vector<size_t>> ids) {
Expand All @@ -19,6 +28,15 @@ namespace ctranslate2 {
return encoder.forward(ids);
});
}

std::future<EncoderForwardOutput>
Encoder::forward_batch_async(std::vector<std::vector<size_t>> ids, std::vector<std::vector<size_t>> token_type_ids) {
return post<EncoderForwardOutput>(
[ids = std::move(ids), token_type_ids = std::move(token_type_ids)]
(models::SequenceEncoderReplica& encoder) {
return encoder.forward(ids, token_type_ids);
});
}

std::future<EncoderForwardOutput>
Encoder::forward_batch_async(const StorageView& ids, const StorageView& lengths) {
Expand All @@ -29,4 +47,13 @@ namespace ctranslate2 {
});
}

std::future<EncoderForwardOutput>
Encoder::forward_batch_async(const StorageView& ids, const StorageView& lengths, std::vector<std::vector<size_t>> token_type_ids) {
return post<EncoderForwardOutput>(
[ids = ids.sync_copy(), lengths = lengths.sync_copy(), token_type_ids = std::move(token_type_ids)]
(models::SequenceEncoderReplica& encoder) {
return encoder.forward(ids, lengths, token_type_ids);
});
}

}
39 changes: 31 additions & 8 deletions src/models/language_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -298,30 +298,49 @@ namespace ctranslate2 {

EncoderForwardOutput
SequenceEncoderReplica::forward(const std::vector<std::vector<std::string>>& tokens) {
  // No token type IDs supplied: delegate with an empty batch of type IDs.
  return forward(tokens, std::vector<std::vector<size_t>>());
}

EncoderForwardOutput
SequenceEncoderReplica::forward(const std::vector<std::vector<std::string>>& tokens, const std::vector<std::vector<size_t>>& token_type_ids) {
  // Convert the token strings to vocabulary IDs and delegate to the ID-based
  // overload, propagating the caller-provided token type IDs.
  // Fix: a stale `return forward(vocabulary.to_ids(tokens));` (the pre-change
  // line left in the text) preceded the intended call, making the call that
  // forwards token_type_ids unreachable. Only the three-argument delegation
  // remains.
  const auto& vocabulary = _model->get_vocabulary();
  return forward(vocabulary.to_ids(tokens), token_type_ids);
}

EncoderForwardOutput
SequenceEncoderReplica::forward(const std::vector<std::vector<size_t>>& ids) {
  // Overload without token type IDs: forward an empty type-ID batch.
  return forward(ids, std::vector<std::vector<size_t>>());
}

EncoderForwardOutput
SequenceEncoderReplica::forward(const std::vector<std::vector<size_t>>& ids, const std::vector<std::vector<size_t>>& token_type_ids) {
  // Pack the ragged ID batch into a dense CPU tensor plus a per-sequence
  // lengths vector, then delegate to the StorageView-based overload.
  // Fix: a stale `return forward(input_ids, lengths);` (the pre-change line
  // left in the text) preceded the intended three-argument call, so the
  // token_type_ids argument was never forwarded. Only the three-argument
  // delegation remains.
  StorageView lengths;
  StorageView input_ids = layers::make_sequence_inputs(ids, Device::CPU, 1, &lengths);
  return forward(input_ids, lengths, token_type_ids);
}

EncoderForwardOutput
SequenceEncoderReplica::forward(const StorageView& ids, const StorageView& lengths) {
  // Delegate to the full overload, passing no token type IDs.
  return forward(ids, lengths, std::vector<std::vector<size_t>>());
}

EncoderForwardOutput
SequenceEncoderReplica::forward(const StorageView& ids, const StorageView& lengths, const std::vector<std::vector<size_t>>& token_type_ids) {
PROFILE("SequenceEncoderReplica::forward");
const auto& model = *this->model();
const auto device = model.device();
const auto scoped_device_setter = model.get_scoped_device_setter();

StorageView input_token_type_ids = layers::make_sequence_inputs(token_type_ids, Device::CPU, 1, nullptr);
guillaumekln marked this conversation as resolved.
Show resolved Hide resolved
EncoderForwardOutput output;

if (ids.device() != device)
output = forward_impl(ids.to(device), lengths.to(device));
output = forward_impl(ids.to(device), lengths.to(device), input_token_type_ids.to(device));
guillaumekln marked this conversation as resolved.
Show resolved Hide resolved
else
output = forward_impl(ids, lengths);
output = forward_impl(ids, lengths, input_token_type_ids);

// Ensure all operations are finished before returning the output.
synchronize_stream(device);
Expand All @@ -342,7 +361,7 @@ namespace ctranslate2 {
}

EncoderForwardOutput
EncoderReplica::forward_impl(const StorageView& ids, const StorageView& lengths) {
EncoderReplica::forward_impl(const StorageView& ids, const StorageView& lengths, const StorageView& token_type_ids) {
if (ids.rank() != 2)
throw std::invalid_argument("Expected input ids to have 2 dimensions, but got "
+ std::to_string(ids.rank())
Expand All @@ -360,9 +379,13 @@ namespace ctranslate2 {
std::vector<StorageView> inputs{ids};

if (_encoder->num_input_features() > 1) {
StorageView token_type_ids(ids.shape(), ids.dtype(), device);
token_type_ids.zero();
inputs.emplace_back(std::move(token_type_ids));
if (token_type_ids.empty()) {
StorageView placeholder_type_ids(ids.shape(), ids.dtype(), device);
placeholder_type_ids.zero();
inputs.emplace_back(std::move(placeholder_type_ids));
} else {
inputs.emplace_back(std::move(token_type_ids));
guillaumekln marked this conversation as resolved.
Show resolved Hide resolved
}
}

StorageView last_hidden_state(dtype, device);
Expand Down