Skip to content

Commit

Permalink
#tf-data Compress iterator checkpoints.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 486227313
  • Loading branch information
yangustc07 authored and tensorflower-gardener committed Nov 4, 2022
1 parent 5e418c3 commit 881f491
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 9 deletions.
7 changes: 6 additions & 1 deletion tensorflow/core/data/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ package(
exports_files([
"captured_function.cc",
"captured_function.h",
"compression_utils.cc",
"compression_utils.h",
"dataset_utils.cc",
"dataset_utils.h",
"finalization_utils.cc",
Expand Down Expand Up @@ -93,10 +95,11 @@ tf_cc_test(
":dataset_test_base",
"//tensorflow/core:framework",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/platform:status_matchers",
"//tensorflow/tsl/platform:status_matchers",
],
)

Expand Down Expand Up @@ -326,10 +329,12 @@ cc_library(
srcs = ["serialization_utils.cc"],
hdrs = ["serialization_utils.h"],
deps = [
":compression_utils",
":dataset_utils",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/core:status",
],
)
Expand Down
17 changes: 17 additions & 0 deletions tensorflow/core/data/compression_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/snappy.h"
#include "tensorflow/core/platform/status.h"
Expand All @@ -30,6 +31,14 @@ limitations under the License.

namespace tensorflow {
namespace data {
namespace {

// Format version written into every `CompressedElement` produced by
// `CompressElement`. `UncompressElement` compares an incoming element's
// version against this value before reading it, so bump this number whenever
// the `CompressedElement` proto changes.
constexpr int kCompressedElementVersion{0};

}  // namespace

class Iov {
public:
Expand Down Expand Up @@ -122,13 +131,18 @@ Status CompressElement(const std::vector<Tensor>& element,
out->mutable_data())) {
return errors::Internal("Failed to compress using snappy.");
}
out->set_version(kCompressedElementVersion);
VLOG(3) << "Compressed element from " << iov.NumBytes() << " bytes to "
<< out->data().size() << " bytes";
return OkStatus();
}

Status UncompressElement(const CompressedElement& compressed,
std::vector<Tensor>* out) {
if (compressed.version() != kCompressedElementVersion) {
return errors::Internal("Unsupported compressed element version: ",
compressed.version());
}
int num_components = compressed.component_metadata_size();
out->clear();
out->reserve(num_components);
Expand Down Expand Up @@ -242,5 +256,8 @@ StatusOr<std::vector<Tensor>> DeserializeAndUncompress(
return tensors;
}

// Registers a variant decode function for `CompressedElement` under the type
// name "tensorflow.data.CompressedElement".
// NOTE(review): presumably this is what lets a serialized `CompressedElement`
// held in a DT_VARIANT tensor be restored to its typed form when read back --
// confirm against variant_op_registry.h.
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(CompressedElement,
"tensorflow.data.CompressedElement");

} // namespace data
} // namespace tensorflow
23 changes: 21 additions & 2 deletions tensorflow/core/data/compression_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@ limitations under the License.

#include "tensorflow/core/data/dataset_test_base.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/platform/status_matchers.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/tsl/platform/status_matchers.h"

namespace tensorflow {
namespace data {
namespace {

using ::tensorflow::testing::StatusIs;
using ::testing::HasSubstr;
using ::tsl::testing::StatusIs;

TEST(CompressionUtilsTest, Exceeds4GB) {
std::vector<Tensor> element = {
Expand Down Expand Up @@ -85,6 +86,24 @@ TEST_P(ParameterizedCompressionUtilsTest, RoundTrip) {
ExpectEqual(element, round_trip_element, /*compare_order=*/true));
}

// Checks that `CompressElement` stamps version 0 on the element it produces.
TEST_P(ParameterizedCompressionUtilsTest, CompressedElementVersion) {
  const std::vector<Tensor> components = GetParam();
  CompressedElement compressed_element;
  TF_ASSERT_OK(CompressElement(components, &compressed_element));
  EXPECT_EQ(0, compressed_element.version());
}

// Checks that `UncompressElement` reports an INTERNAL error when the element's
// version has been changed after compression.
TEST_P(ParameterizedCompressionUtilsTest, VersionMismatch) {
  const std::vector<Tensor> components = GetParam();
  CompressedElement compressed_element;
  TF_ASSERT_OK(CompressElement(components, &compressed_element));

  // Overwrite the version written by `CompressElement` with a different one.
  compressed_element.set_version(1);

  std::vector<Tensor> uncompressed;
  const Status status = UncompressElement(compressed_element, &uncompressed);
  EXPECT_THAT(status, StatusIs(error::INTERNAL));
}

INSTANTIATE_TEST_SUITE_P(Instantiation, ParameterizedCompressionUtilsTest,
::testing::ValuesIn(TestCases()));

Expand Down
57 changes: 51 additions & 6 deletions tensorflow/core/data/serialization_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/serialization_utils.h"

#include <memory>
Expand All @@ -22,8 +21,10 @@ limitations under the License.

#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/data/compression_utils.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/dataset.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
Expand Down Expand Up @@ -428,7 +429,7 @@ std::string IteratorStateVariant::TypeName() {

// Copy constructor: deep-copies the `VariantTensorData` held by `other`, or
// leaves `data_` unset when `other` holds none.
IteratorStateVariant::IteratorStateVariant(const IteratorStateVariant& other) {
  if (other.data_) {
    // Copy the stored state directly. Running it through `Decode` first would
    // be wasted work, because whatever `Decode` stores in `data_` is replaced
    // by this assignment.
    data_ = std::make_unique<VariantTensorData>(*other.data_);
  }
}

Expand All @@ -439,19 +440,63 @@ Status IteratorStateVariant::InitializeFromVariantData(
}

// Serializes this variant's state into `*data`.
//
// The tensors held in `data_` are compressed with `CompressElement` and written
// as one scalar DT_VARIANT tensor wrapping the resulting `CompressedElement`.
// The type name is set to `TypeName()` and the metadata string is copied from
// `data_`. If compression fails, a warning is logged and `*data` receives an
// uncompressed copy of `*data_` instead.
void IteratorStateVariant::Encode(VariantTensorData* data) const {
  CompressedElement compressed_tensors;
  Status s = CompressElement(data_->tensors(), &compressed_tensors);
  if (!s.ok()) {
    LOG(WARNING) << "Failed to compress iterator state variant: " << s;
    *data = *data_;
    return;
  }

  // `*data_` is deliberately NOT copied into `*data` on this path. Doing so
  // would leave the uncompressed tensors in place, and the compressed tensor
  // appended below would no longer be the only tensor, which is the shape
  // `GetCompressedElement` requires to recognise compressed state.
  data->set_type_name(TypeName());
  data->set_metadata(data_->metadata_string());
  Tensor tensor(DT_VARIANT, TensorShape({}));
  tensor.scalar<Variant>()() = std::move(compressed_tensors);
  *data->add_tensors() = std::move(tensor);
}

// Restores this variant's state from `data`.
//
// Returns false, leaving `data_` untouched, when `data`'s type name differs
// from `TypeName()`; returns true in every other case.
//
// If `GetCompressedElement` finds a compressed payload in `data`, the payload
// is uncompressed and the resulting tensors, together with `data`'s metadata
// string, become the stored state. If there is no compressed payload, or
// uncompression fails (a warning is logged), `data` is stored unchanged.
bool IteratorStateVariant::Decode(VariantTensorData data) {
  if (data.type_name() != TypeName()) {
    return false;
  }

  // Inspect `data` before transferring it anywhere: swapping or moving its
  // contents out first would leave nothing for `GetCompressedElement` to
  // examine, and the emptied object would then be stored as the state.
  const CompressedElement* compressed = GetCompressedElement(data);
  if (!compressed) {
    data_ = std::make_unique<VariantTensorData>(std::move(data));
    return true;
  }

  std::vector<Tensor> tensors;
  Status s = UncompressElement(*compressed, &tensors);
  if (!s.ok()) {
    // Best effort: keep the data as received rather than failing the decode.
    LOG(WARNING) << "Failed to uncompress iterator state variant: " << s;
    data_ = std::make_unique<VariantTensorData>(std::move(data));
    return true;
  }

  data_ = std::make_unique<VariantTensorData>();
  data_->set_type_name(TypeName());
  data_->set_metadata(data.metadata_string());
  for (auto& tensor : tensors) {
    *data_->add_tensors() = std::move(tensor);
  }
  return true;
}

// Looks for a compressed payload in `data`.
//
// `data` is treated as compressed only when it holds exactly one tensor and
// that tensor is a scalar of dtype DT_VARIANT. In that case the result is
// whatever `Variant::get<CompressedElement>()` yields for the scalar's value;
// otherwise the result is nullptr.
const CompressedElement* IteratorStateVariant::GetCompressedElement(
    const VariantTensorData& data) {
  // Check the tensor count first so that `tensors(0)` is only touched when it
  // exists.
  if (data.tensors_size() != 1) {
    return nullptr;
  }

  const Tensor& candidate = data.tensors(0);
  if (!TensorShapeUtils::IsScalar(candidate.shape()) ||
      candidate.dtype() != DT_VARIANT) {
    return nullptr;
  }

  return candidate.scalar<Variant>()().get<CompressedElement>();
}

std::string IteratorStateVariant::DebugString() const {
if (data_) {
return strings::StrCat("IteratorStateVariant<", data_->DebugString(), ">");
Expand Down
13 changes: 13 additions & 0 deletions tensorflow/core/data/serialization_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include <string>

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/dataset.pb.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/status.h"

Expand Down Expand Up @@ -161,12 +162,24 @@ class IteratorStateVariant {
// Returns a borrowed pointer to the underlying VariantTensorData.
const VariantTensorData* GetData() const { return data_.get(); }

// Encodes this `IteratorStateVariant` into `*data`. Data will be compressed
// and stored as a scalar `CompressedElement` tensor, or left uncompressed if
// compression fails.
void Encode(VariantTensorData* data) const;

// Decodes from `data`. If `data` contains a single scalar `CompressedElement`
// tensor, it is assumed to be compressed by `Encode`, and will be
// uncompressed as part of `Decode`. If uncompression fails, a warning is
// logged and `data` is kept as received.
//
// Returns false if the type name of `data` does not match `TypeName()`, and
// true otherwise.
bool Decode(VariantTensorData data);

std::string DebugString() const;

private:
// Returns the compressed element in `data`. If `data` does not contain a
// compressed element, returns nullptr.
static const CompressedElement* GetCompressedElement(
const VariantTensorData& data);

std::unique_ptr<VariantTensorData> data_;
};

Expand Down
17 changes: 17 additions & 0 deletions tensorflow/core/data/serialization_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,13 @@ class ParameterizedIteratorStateVariantTest
decoder.Decode(encoded_data);
return *decoder.GetData();
}

// Passes `data` directly to `IteratorStateVariant::Decode`, with no preceding
// `Encode`, and returns a copy of the state the variant holds afterwards.
// The boolean result of `Decode` is not checked.
StatusOr<VariantTensorData> DecodeUncompressed(
    const VariantTensorData& data) const {
  IteratorStateVariant variant;
  variant.Decode(data);
  const VariantTensorData* decoded = variant.GetData();
  return *decoded;
}
};

std::vector<std::vector<Tensor>> TestCases() {
Expand All @@ -265,6 +272,16 @@ TEST_P(ParameterizedIteratorStateVariantTest, EncodeAndDecode) {
}
}

// Feeds the fixture's `VariantTensorData` directly to `Decode` (no `Encode`
// beforehand) and checks that the type name and every tensor come back
// unchanged.
TEST_P(ParameterizedIteratorStateVariantTest, DecodeUncompressed) {
  VariantTensorData data = GetVariantTensorData();
  TF_ASSERT_OK_AND_ASSIGN(VariantTensorData result, DecodeUncompressed(data));

  EXPECT_EQ(result.type_name(), data.type_name());
  // Abort on a count mismatch. Without this check the loop below would pass
  // when `result` is missing tensors, and would index past the end of `data`
  // when `result` has extra ones.
  ASSERT_EQ(result.tensors_size(), data.tensors_size());
  for (int i = 0; i < result.tensors_size(); ++i) {
    test::ExpectEqual(result.tensors(i), data.tensors(i));
  }
}

INSTANTIATE_TEST_SUITE_P(Instantiation, ParameterizedIteratorStateVariantTest,
::testing::ValuesIn(TestCases()));

Expand Down
6 changes: 6 additions & 0 deletions tensorflow/core/framework/dataset.proto
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ message CompressedElement {
bytes data = 1;
// Metadata for the components of the element.
repeated CompressedComponentMetadata component_metadata = 2;
// Version of the CompressedElement. CompressedElements may be stored on disk
// and read back by later versions of code, so we store a version number to
// help readers understand which version they are reading. When you add a new
// field to this proto, you need to increment kCompressedElementVersion in
// tensorflow/core/data/compression_utils.cc.
int32 version = 3;
}

// An uncompressed dataset element.
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/core/kernels/data/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1504,6 +1504,7 @@ filegroup(
name = "portable_all_op_kernels_headers",
srcs = [
"//tensorflow/core/data:captured_function.h",
"//tensorflow/core/data:compression_utils.h",
"//tensorflow/core/data:dataset_utils.h",
"//tensorflow/core/data:finalization_utils.h",
"//tensorflow/core/data:metric_utils.h",
Expand All @@ -1530,6 +1531,7 @@ filegroup(
srcs = [
":portable_all_op_kernels_headers",
"//tensorflow/core/data:captured_function.cc",
"//tensorflow/core/data:compression_utils.cc",
"//tensorflow/core/data:dataset_utils.cc",
"//tensorflow/core/data:finalization_utils.cc",
"//tensorflow/core/data:metric_utils.cc",
Expand Down

0 comments on commit 881f491

Please sign in to comment.