Skip to content

Commit

Permalink
remove_line_end_token fix
Browse files Browse the repository at this point in the history
Signed-off-by: Suraj Aralihalli <[email protected]>
  • Loading branch information
SurajAralihalli committed Apr 26, 2024
1 parent 03e9824 commit e2bd05e
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 37 deletions.
82 changes: 55 additions & 27 deletions cpp/src/io/json/nested_json_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -297,17 +297,16 @@ struct TransduceTokenKeepLineEnd {

if (is_end_of_invalid_line) {
switch (relative_offset) {
case 0 : return SymbolT{token_t::StructEnd, 0};
case 1 : return SymbolT{token_t::StructBegin, 0};
case 2 : return SymbolT{token_t::LineEnd, 0};
default: return SymbolT{token_t::LineEnd, 0}; // doesn't appear
case 0: return SymbolT{token_t::StructEnd, 0};
case 1: return SymbolT{token_t::StructBegin, 0};
case 2: return SymbolT{token_t::LineEnd, 0};
default: return SymbolT{token_t::LineEnd, 0}; // doesn't appear
}
} else if (is_end_of_valid_line) {
return SymbolT{token_t::LineEnd, 0};
} else {
return read_symbol;
}

}

template <typename SymbolT>
Expand Down Expand Up @@ -1539,34 +1538,63 @@ std::pair<rmm::device_uvector<PdaTokenT>, rmm::device_uvector<SymbolOffsetT>> pr
// Instantiate FST for post-processing the token stream to remove all tokens that belong to an
// invalid JSON line
token_filter::UnwrapTokenFromSymbolOp sgid_op{};
auto filter_fst =
fst::detail::make_fst(fst::detail::make_symbol_group_lut(token_filter::symbol_groups, sgid_op),
fst::detail::make_transition_table(token_filter::transition_table),
fst::detail::make_translation_functor(token_filter::TransduceTokenKeepLineEnd{}),
stream);

auto const mr = rmm::mr::get_current_device_resource();
rmm::device_scalar<SymbolOffsetT> d_num_selected_tokens(stream, mr);
rmm::device_uvector<PdaTokenT> filtered_tokens_out{tokens.size(), stream, mr};
rmm::device_uvector<SymbolOffsetT> filtered_token_indices_out{tokens.size(), stream, mr};

// The FST is run on the reverse token stream, discarding all tokens between ErrorBegin and the
// next LineEnd (LineEnd, inv_token_0, inv_token_1, ..., inv_token_n, ErrorBegin, LineEnd, ...),
// emitting a [StructBegin, StructEnd] pair on the end of such an invalid line. In that example,
// inv_token_i for i in [0, n] together with the ErrorBegin are removed and replaced with
// StructBegin, StructEnd. Also, all LineEnd are removed as well, as these are not relevant after
// this stage anymore
filter_fst.Transduce(
thrust::make_reverse_iterator(thrust::make_zip_iterator(tokens.data(), token_indices.data()) +
tokens.size()),
static_cast<SymbolOffsetT>(tokens.size()),
thrust::make_reverse_iterator(
thrust::make_zip_iterator(filtered_tokens_out.data(), filtered_token_indices_out.data()) +
tokens.size()),
thrust::make_discard_iterator(),
d_num_selected_tokens.data(),
token_filter::start_state,
stream);
if (remove_line_end_token) {
// The FST is run on the reverse token stream, discarding all tokens between ErrorBegin and the
// next LineEnd (LineEnd, inv_token_0, inv_token_1, ..., inv_token_n, ErrorBegin, LineEnd, ...),
// emitting a [StructBegin, StructEnd] pair on the end of such an invalid line. In that example,
// inv_token_i for i in [0, n] together with the ErrorBegin are removed and replaced with
// StructBegin, StructEnd. All LineEnd tokens are removed as well, as these are no longer
// relevant after this stage.

auto filter_fst = fst::detail::make_fst(
fst::detail::make_symbol_group_lut(token_filter::symbol_groups, sgid_op),
fst::detail::make_transition_table(token_filter::transition_table),
fst::detail::make_translation_functor(token_filter::TransduceToken{}),
stream);

filter_fst.Transduce(
thrust::make_reverse_iterator(thrust::make_zip_iterator(tokens.data(), token_indices.data()) +
tokens.size()),
static_cast<SymbolOffsetT>(tokens.size()),
thrust::make_reverse_iterator(
thrust::make_zip_iterator(filtered_tokens_out.data(), filtered_token_indices_out.data()) +
tokens.size()),
thrust::make_discard_iterator(),
d_num_selected_tokens.data(),
token_filter::start_state,
stream);
} else {
// The FST is run on the reverse token stream, discarding all tokens between ErrorBegin and the
// next LineEnd (LineEnd, inv_token_0, inv_token_1, ..., inv_token_n, ErrorBegin, LineEnd, ...),
// emitting a [LineEnd, StructBegin, StructEnd] on the end of such an invalid line. In that
// example, inv_token_i for i in [0, n] together with the ErrorBegin are removed and replaced
// with LineEnd, StructBegin, StructEnd. Unlike the previous case, LineEnd tokens are retained;
// however, their corresponding token index is written as 0.

auto filter_fst = fst::detail::make_fst(
fst::detail::make_symbol_group_lut(token_filter::symbol_groups, sgid_op),
fst::detail::make_transition_table(token_filter::transition_table),
fst::detail::make_translation_functor(token_filter::TransduceTokenKeepLineEnd{}),
stream);

filter_fst.Transduce(
thrust::make_reverse_iterator(thrust::make_zip_iterator(tokens.data(), token_indices.data()) +
tokens.size()),
static_cast<SymbolOffsetT>(tokens.size()),
thrust::make_reverse_iterator(
thrust::make_zip_iterator(filtered_tokens_out.data(), filtered_token_indices_out.data()) +
tokens.size()),
thrust::make_discard_iterator(),
d_num_selected_tokens.data(),
token_filter::start_state,
stream);
}

auto const num_total_tokens = d_num_selected_tokens.value(stream);
rmm::device_uvector<PdaTokenT> tokens_out{num_total_tokens, stream, mr};
Expand Down
18 changes: 8 additions & 10 deletions cpp/tests/io/nested_json_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1162,7 +1162,7 @@ TEST_F(JsonTest, TokenStreamWithLineEnd)
using cuio_json::SymbolOffsetT;
using cuio_json::SymbolT;
// Test input
std::string const input = R"({"a":1}
std::string const input = R"({"a":1}
{"b":2}
{"c"
{"d":4})";
Expand All @@ -1181,14 +1181,14 @@ TEST_F(JsonTest, TokenStreamWithLineEnd)

// Parse the JSON and get the token stream
auto [d_tokens_gpu, d_token_indices_gpu] = cuio_json::detail::get_token_stream(
d_input, default_options, stream, rmm::mr::get_current_device_resource());
d_input, default_options, stream, rmm::mr::get_current_device_resource(), false);
// Copy back the number of tokens that were written
auto const tokens_gpu = cudf::detail::make_std_vector_async(d_tokens_gpu, stream);
auto const tokens_gpu = cudf::detail::make_std_vector_async(d_tokens_gpu, stream);

// Golden token stream sample
using token_t = cuio_json::token_t;
using token_t = cuio_json::token_t;
std::vector<cuio_json::PdaTokenT> const golden_token_stream = {
//Line 1
// Line 1
token_t::LineEnd,
token_t::StructBegin,
token_t::StructMemberBegin,
Expand All @@ -1198,7 +1198,7 @@ TEST_F(JsonTest, TokenStreamWithLineEnd)
token_t::ValueEnd,
token_t::StructMemberEnd,
token_t::StructEnd,
//Line 2
// Line 2
token_t::LineEnd,
token_t::StructBegin,
token_t::StructMemberBegin,
Expand All @@ -1208,11 +1208,11 @@ TEST_F(JsonTest, TokenStreamWithLineEnd)
token_t::ValueEnd,
token_t::StructMemberEnd,
token_t::StructEnd,
//Error
// Error
token_t::LineEnd,
token_t::StructBegin,
token_t::StructEnd,
//Line 3
// Line 3
token_t::LineEnd,
token_t::StructBegin,
token_t::StructMemberBegin,
Expand All @@ -1233,6 +1233,4 @@ TEST_F(JsonTest, TokenStreamWithLineEnd)
}
}



CUDF_TEST_PROGRAM_MAIN()

0 comments on commit e2bd05e

Please sign in to comment.