Skip to content

Commit

Permalink
Support string tensors (#1079)
Browse files Browse the repository at this point in the history
Signed-off-by: Tao He <[email protected]>
  • Loading branch information
sighingnow authored Dec 6, 2022
1 parent f8a6c0a commit 019d2dd
Show file tree
Hide file tree
Showing 8 changed files with 397 additions and 37 deletions.
8 changes: 8 additions & 0 deletions modules/basic/ds/arrow_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ std::shared_ptr<arrow::Table> ConcatenateTables(

std::shared_ptr<arrow::DataType> FromAnyType(AnyType type) {
switch (type) {
case AnyType::Undefined:
return arrow::null();
case AnyType::Int32:
return arrow::int32();
case AnyType::UInt32:
Expand All @@ -59,6 +61,12 @@ std::shared_ptr<arrow::DataType> FromAnyType(AnyType type) {
return arrow::float32();
case AnyType::Double:
return arrow::float64();
case AnyType::String:
return arrow::large_utf8();
case AnyType::Date32:
return arrow::int32();
case AnyType::Date64:
return arrow::int64();
default:
return arrow::null();
}
Expand Down
39 changes: 31 additions & 8 deletions modules/basic/ds/dataframe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,19 +78,42 @@ const std::shared_ptr<arrow::RecordBatch> DataFrame::AsBatch(bool copy) const {
} else if (auto tensor =
std::dynamic_pointer_cast<Tensor<double>>(df_col)) {
num_rows = tensor->shape()[0];
} else if (auto tensor =
std::dynamic_pointer_cast<Tensor<std::string>>(df_col)) {
num_rows = tensor->shape()[0];
}

std::shared_ptr<arrow::Buffer> copied_buffer;
if (copy) {
CHECK_ARROW_ERROR_AND_ASSIGN(
copied_buffer,
df_col->buffer()->CopySlice(0, df_col->buffer()->size()));
} else {
copied_buffer = df_col->buffer();
std::vector<std::shared_ptr<arrow::Buffer>> buffer{
nullptr /* null bitmap */};

// process the second buffer for std::string type
if (auto tensor = std::dynamic_pointer_cast<Tensor<std::string>>(df_col)) {
std::shared_ptr<arrow::Buffer> copied_buffer;
if (copy) {
CHECK_ARROW_ERROR_AND_ASSIGN(
copied_buffer,
df_col->buffer()->CopySlice(0, df_col->auxiliary_buffer()->size()));
} else {
copied_buffer = df_col->buffer();
}
buffer.push_back(copied_buffer);
}

// process buffer
{
std::shared_ptr<arrow::Buffer> copied_buffer;
if (copy) {
CHECK_ARROW_ERROR_AND_ASSIGN(
copied_buffer,
df_col->buffer()->CopySlice(0, df_col->buffer()->size()));
} else {
copied_buffer = df_col->buffer();
}
buffer.push_back(copied_buffer);
}

columns[i] = arrow::MakeArray(arrow::ArrayData::Make(
FromAnyType(df_col->value_type()), num_rows, {nullptr, copied_buffer}));
FromAnyType(df_col->value_type()), num_rows, buffer));

std::shared_ptr<arrow::Scalar> sca;
CHECK_ARROW_ERROR_AND_ASSIGN(sca, columns[i]->GetScalar(0));
Expand Down
148 changes: 147 additions & 1 deletion modules/basic/ds/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ class ITensorBuilder {
*/
template <typename T>
class TensorBuilder : public ITensorBuilder, public TensorBaseBuilder<T> {
public:
using value_t = T;
using value_pointer_t = T*;
using value_const_pointer_t = const T*;

public:
/**
* @brief Initialize the TensorBuilder with the tensor shape.
Expand Down Expand Up @@ -136,7 +141,7 @@ class TensorBuilder : public ITensorBuilder, public TensorBaseBuilder<T> {
* @brief Get the data pointer of the tensor.
*
*/
inline T* data() const { return this->data_; }
inline value_pointer_t data() const { return this->data_; }

/**
* @brief Build the tensor.
Expand All @@ -153,6 +158,147 @@ class TensorBuilder : public ITensorBuilder, public TensorBaseBuilder<T> {
T* data_;
};

/**
* @brief TensorBuilder is used for building tensors that supported by vineyard
*
* @tparam T
*/
template <>
class TensorBuilder<std::string> : public ITensorBuilder,
public TensorBaseBuilder<std::string> {
public:
using value_t = detail::arrow_string_view;
using value_pointer_t = uint8_t*;
using value_const_pointer_t = const uint8_t*;

public:
/**
* @brief Initialize the TensorBuilder with the tensor shape.
*
* @param client The client connected to the vineyard server.
* @param shape The shape of the tensor.
*/
TensorBuilder(Client& client, std::vector<int64_t> const& shape)
: TensorBaseBuilder<std::string>(client) {
this->set_value_type_(AnyType(AnyTypeEnum<std::string>::value));
this->set_shape_(shape);
this->buffer_writer_ = std::make_shared<arrow::LargeStringBuilder>();
}

/**
* @brief Initialize the TensorBuilder for a partition of a GlobalTensor.
*
* @param client The client connected to the vineyard server.
* @param shape The shape of the partition.
* @param partition_index The partition index in the global tensor.
*/
TensorBuilder(Client& client, std::vector<int64_t> const& shape,
std::vector<int64_t> const& partition_index)
: TensorBuilder(client, shape) {
this->set_partition_index_(partition_index);
}

/**
* @brief Get the shape of the tensor.
*
* @return The shape vector where the ith element represents
* the size of the ith axis.
*/
std::vector<int64_t> const& shape() const { return this->shape_; }

/**
* @brief Get the index of this partition in the global tensor.
*
* @return The index vector where the ith element represents the index
* in the ith axis.
*/
std::vector<int64_t> const& partition_index() const {
return this->partition_index_;
}

/**
* @brief Set the shape of the tensor.
*
* @param shape The vector for the shape, where the ith element
* represents the size of the shape in the ith axis.
*/
void set_shape(std::vector<int64_t> const& shape) { this->set_shape_(shape); }

/**
* @brief Set the index in the global tensor.
*
* @param partition_index The vector of indices, where the ith element
* represents the index in the ith axis.
*/
void set_partition_index(std::vector<int64_t> const& partition_index) {
this->set_partition_index_(partition_index);
}

/**
* @brief Get the strides of the tensor.
*
* @return The strides of the tensor. The definition of the tensor's strides
* can be found in https://pytorch.org/docs/stable/tensor_attributes.html
*/
std::vector<int64_t> strides() const {
std::vector<int64_t> vec(this->shape_.size());
vec[this->shape_.size() - 1] = 1 /* special case for std::string */;
for (size_t i = this->shape_.size() - 1; i > 0; --i) {
vec[i - 1] = vec[i] * this->shape_[i];
}
return vec;
}

/**
* @brief Get the data pointer of the tensor.
*
*/
inline value_pointer_t data() const {
return const_cast<value_pointer_t>(this->buffer_writer_->value_data());
}

/**
* @brief Append value to the builder.
*/
inline Status Append(value_t const& value) {
RETURN_ON_ARROW_ERROR(
this->buffer_writer_->Append(value.data(), value.size()));
return Status::OK();
}

/**
* @brief Append value to the builder.
*/
inline Status Append(value_const_pointer_t value, const size_t length) {
RETURN_ON_ARROW_ERROR(this->buffer_writer_->Append(value, length));
return Status::OK();
}

/**
* @brief Append value to the builder.
*/
inline Status Append(std::string const& value) {
RETURN_ON_ARROW_ERROR(this->buffer_writer_->Append(value));
return Status::OK();
}

/**
* @brief Build the tensor.
*
* @param client The client connceted to the vineyard server.
*/
Status Build(Client& client) override {
std::shared_ptr<arrow::Array> array;
RETURN_ON_ARROW_ERROR_AND_ASSIGN(array, buffer_writer_->Finish());
this->set_buffer_(std::make_shared<LargeStringArrayBuilder>(
client, std::dynamic_pointer_cast<arrow::LargeStringArray>(array)));
return Status::OK();
}

private:
std::shared_ptr<arrow::LargeStringBuilder> buffer_writer_;
};

class GlobalTensorBaseBuilder;

/**
Expand Down
Loading

0 comments on commit 019d2dd

Please sign in to comment.