Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize basic_string::rfind (the single character overload) #5087

Merged
merged 6 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 40 additions & 12 deletions benchmarks/src/find_and_count.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,38 @@
#include <cstdint>
#include <cstdlib>
#include <ranges>
#include <string>
#include <type_traits>
#include <vector>

#include "skewed_allocator.hpp"

// Selects which search routine a bm<> instantiation measures.
enum class Op {
    FindSized, // NOTE(review): presumably ranges::find with a real end iterator - the call site is not visible here
    FindUnsized, // ranges::find(a.begin(), unreachable_sentinel, ...)
    Count, // ranges::count(a.begin(), a.end(), ...)
    StringFind, // basic_string::find, single character overload
    StringRFind, // basic_string::rfind, single character overload
};

using namespace std;

template <class T, Op Operation>
template <class T, template <class> class Alloc, Op Operation>
void bm(benchmark::State& state) {
const auto size = static_cast<size_t>(state.range(0));
const auto pos = static_cast<size_t>(state.range(1));

vector<T> a(size, T{'0'});
using Container = conditional_t<Operation == Op::StringFind || Operation == Op::StringRFind,
basic_string<T, char_traits<T>, Alloc<T>>, vector<T, Alloc<T>>>;
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved

Container a(size, T{'0'});

if (pos < size) {
a[pos] = T{'1'};
if constexpr (Operation == Op::StringRFind) {
a[size - pos - 1] = T{'1'};
} else {
a[pos] = T{'1'};
}
} else {
if constexpr (Operation == Op::FindUnsized) {
abort();
Expand All @@ -39,6 +52,10 @@ void bm(benchmark::State& state) {
benchmark::DoNotOptimize(ranges::find(a.begin(), unreachable_sentinel, T{'1'}));
} else if constexpr (Operation == Op::Count) {
benchmark::DoNotOptimize(ranges::count(a.begin(), a.end(), T{'1'}));
} else if constexpr (Operation == Op::StringFind) {
benchmark::DoNotOptimize(a.find(T{'1'}));
} else if constexpr (Operation == Op::StringRFind) {
benchmark::DoNotOptimize(a.rfind(T{'1'}));
}
}
}
Expand All @@ -50,17 +67,28 @@ void common_args(auto bm) {
}


BENCHMARK(bm<uint8_t, Op::FindSized>)->Apply(common_args);
BENCHMARK(bm<uint8_t, Op::FindUnsized>)->Apply(common_args);
BENCHMARK(bm<uint8_t, Op::Count>)->Apply(common_args);
BENCHMARK(bm<uint8_t, not_highly_aligned_allocator, Op::FindSized>)->Apply(common_args);
BENCHMARK(bm<uint8_t, highly_aligned_allocator, Op::FindSized>)->Apply(common_args);
BENCHMARK(bm<uint8_t, not_highly_aligned_allocator, Op::FindUnsized>)->Apply(common_args);
BENCHMARK(bm<uint8_t, highly_aligned_allocator, Op::FindUnsized>)->Apply(common_args);
BENCHMARK(bm<uint8_t, not_highly_aligned_allocator, Op::Count>)->Apply(common_args);
BENCHMARK(bm<uint8_t, highly_aligned_allocator, Op::Count>)->Apply(common_args);
BENCHMARK(bm<char, not_highly_aligned_allocator, Op::StringFind>)->Apply(common_args);
BENCHMARK(bm<char, highly_aligned_allocator, Op::StringFind>)->Apply(common_args);
BENCHMARK(bm<char, not_highly_aligned_allocator, Op::StringRFind>)->Apply(common_args);
BENCHMARK(bm<char, highly_aligned_allocator, Op::StringRFind>)->Apply(common_args);

BENCHMARK(bm<uint16_t, Op::FindSized>)->Apply(common_args);
BENCHMARK(bm<uint16_t, Op::Count>)->Apply(common_args);
BENCHMARK(bm<uint16_t, not_highly_aligned_allocator, Op::FindSized>)->Apply(common_args);
BENCHMARK(bm<uint16_t, not_highly_aligned_allocator, Op::Count>)->Apply(common_args);
BENCHMARK(bm<wchar_t, not_highly_aligned_allocator, Op::StringFind>)->Apply(common_args);
BENCHMARK(bm<wchar_t, not_highly_aligned_allocator, Op::StringRFind>)->Apply(common_args);

BENCHMARK(bm<uint32_t, Op::FindSized>)->Apply(common_args);
BENCHMARK(bm<uint32_t, Op::Count>)->Apply(common_args);
BENCHMARK(bm<uint32_t, not_highly_aligned_allocator, Op::FindSized>)->Apply(common_args);
BENCHMARK(bm<uint32_t, not_highly_aligned_allocator, Op::Count>)->Apply(common_args);
BENCHMARK(bm<char32_t, not_highly_aligned_allocator, Op::StringFind>)->Apply(common_args);
BENCHMARK(bm<char32_t, not_highly_aligned_allocator, Op::StringRFind>)->Apply(common_args);

BENCHMARK(bm<uint64_t, Op::FindSized>)->Apply(common_args);
BENCHMARK(bm<uint64_t, Op::Count>)->Apply(common_args);
BENCHMARK(bm<uint64_t, not_highly_aligned_allocator, Op::FindSized>)->Apply(common_args);
BENCHMARK(bm<uint64_t, not_highly_aligned_allocator, Op::Count>)->Apply(common_args);

BENCHMARK_MAIN();
19 changes: 18 additions & 1 deletion stl/inc/__msvc_string_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,24 @@ constexpr size_t _Traits_rfind_ch(_In_reads_(_Hay_size) const _Traits_ptr_t<_Tra
return static_cast<size_t>(-1);
}

for (auto _Match_try = _Haystack + (_STD min)(_Start_at, _Hay_size - 1);; --_Match_try) {
const size_t _Actual_start_at = (_STD min)(_Start_at, _Hay_size - 1);

#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Is_implementation_handled_char_traits<_Traits>) {
if (!_STD _Is_constant_evaluated()) {
const auto _End = _Haystack + _Actual_start_at + 1;
const auto _Ptr = _STD _Find_last_vectorized(_Haystack, _End, _Ch);

if (_Ptr != _End) {
return static_cast<size_t>(_Ptr - _Haystack);
} else {
return static_cast<size_t>(-1);
}
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS

for (auto _Match_try = _Haystack + _Actual_start_at;; --_Match_try) {
if (_Traits::eq(*_Match_try, _Ch)) {
return static_cast<size_t>(_Match_try - _Haystack); // found a match
}
Expand Down
32 changes: 0 additions & 32 deletions stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,6 @@ _Min_max_element_t __stdcall __std_minmax_element_8(const void* _First, const vo
_Min_max_element_t __stdcall __std_minmax_element_f(const void* _First, const void* _Last, bool _Unused) noexcept;
_Min_max_element_t __stdcall __std_minmax_element_d(const void* _First, const void* _Last, bool _Unused) noexcept;

// Find-last search routines, declared once per element width in bytes (1, 2, 4, 8).
// Each takes the range as two untyped pointers plus the sought value as an unsigned integer of the matching width,
// and returns an untyped pointer.
// NOTE(review): _Traits_rfind_ch treats a result equal to _Last as "no match" - confirm against the definitions,
// which are not visible here.
const void* __stdcall __std_find_last_trivial_1(const void* _First, const void* _Last, uint8_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_2(const void* _First, const void* _Last, uint16_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_4(const void* _First, const void* _Last, uint32_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_8(const void* _First, const void* _Last, uint64_t _Val) noexcept;

__declspec(noalias) _Min_max_1i __stdcall __std_minmax_1i(const void* _First, const void* _Last) noexcept;
__declspec(noalias) _Min_max_1u __stdcall __std_minmax_1u(const void* _First, const void* _Last) noexcept;
__declspec(noalias) _Min_max_2i __stdcall __std_minmax_2i(const void* _First, const void* _Last) noexcept;
Expand Down Expand Up @@ -162,33 +157,6 @@ auto _Minmax_vectorized(_Ty* const _First, _Ty* const _Last) noexcept {
}
}

// Typed front end for the __std_find_last_trivial_N routines: selects the routine from the kind of _TVal and the
// size of _Ty, passes _First, _Last and _Val through, and converts the untyped result back to _Ty*.
// NOTE(review): _Traits_rfind_ch treats a result equal to _Last as "no match" - confirm against the routines'
// definitions, which are not visible here.
template <class _Ty, class _TVal>
_Ty* _Find_last_vectorized(_Ty* const _First, _Ty* const _Last, const _TVal _Val) noexcept {
if constexpr (is_pointer_v<_TVal> || is_null_pointer_v<_TVal>) {
// Pointer (or nullptr_t) needle: reinterpreted as uint64_t when _WIN64 is defined, otherwise as uint32_t.
#ifdef _WIN64
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_8(_First, _Last, reinterpret_cast<uint64_t>(_Val))));
#else
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_4(_First, _Last, reinterpret_cast<uint32_t>(_Val))));
#endif
} else if constexpr (sizeof(_Ty) == 1) {
// Non-pointer needle: sizeof(_Ty) selects the routine, and _Val is cast to the unsigned type of that width.
// The routines return const void*; the casts below turn that back into the caller's _Ty*.
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_1(_First, _Last, static_cast<uint8_t>(_Val))));
} else if constexpr (sizeof(_Ty) == 2) {
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_2(_First, _Last, static_cast<uint16_t>(_Val))));
} else if constexpr (sizeof(_Ty) == 4) {
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_4(_First, _Last, static_cast<uint32_t>(_Val))));
} else if constexpr (sizeof(_Ty) == 8) {
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_8(_First, _Last, static_cast<uint64_t>(_Val))));
} else {
// No routine exists for any other element size.
_STL_INTERNAL_STATIC_ASSERT(false); // unexpected size
}
}

template <class _Ty, class _TVal1, class _TVal2>
__declspec(noalias) void _Replace_vectorized(
_Ty* const _First, _Ty* const _Last, const _TVal1 _Old_val, const _TVal2 _New_val) noexcept {
Expand Down
32 changes: 32 additions & 0 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ const void* __stdcall __std_find_trivial_2(const void* _First, const void* _Last
const void* __stdcall __std_find_trivial_4(const void* _First, const void* _Last, uint32_t _Val) noexcept;
const void* __stdcall __std_find_trivial_8(const void* _First, const void* _Last, uint64_t _Val) noexcept;

// Find-last search routines, declared once per element width in bytes (1, 2, 4, 8).
// Each takes the range as two untyped pointers plus the sought value as an unsigned integer of the matching width,
// and returns an untyped pointer.
// NOTE(review): _Traits_rfind_ch treats a result equal to _Last as "no match" - confirm against the definitions,
// which are not visible here.
const void* __stdcall __std_find_last_trivial_1(const void* _First, const void* _Last, uint8_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_2(const void* _First, const void* _Last, uint16_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_4(const void* _First, const void* _Last, uint32_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_8(const void* _First, const void* _Last, uint64_t _Val) noexcept;

const void* __stdcall __std_find_first_of_trivial_1(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
const void* __stdcall __std_find_first_of_trivial_2(
Expand Down Expand Up @@ -217,6 +222,33 @@ _Ty* _Find_vectorized(_Ty* const _First, _Ty* const _Last, const _TVal _Val) noe
}
}

// Typed front end for the __std_find_last_trivial_N routines: selects the routine from the kind of _TVal and the
// size of _Ty, passes _First, _Last and _Val through, and converts the untyped result back to _Ty*.
// NOTE(review): _Traits_rfind_ch treats a result equal to _Last as "no match" - confirm against the routines'
// definitions, which are not visible here.
template <class _Ty, class _TVal>
_Ty* _Find_last_vectorized(_Ty* const _First, _Ty* const _Last, const _TVal _Val) noexcept {
if constexpr (is_pointer_v<_TVal> || is_null_pointer_v<_TVal>) {
// Pointer (or nullptr_t) needle: reinterpreted as uint64_t when _WIN64 is defined, otherwise as uint32_t.
#ifdef _WIN64
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_8(_First, _Last, reinterpret_cast<uint64_t>(_Val))));
#else
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_4(_First, _Last, reinterpret_cast<uint32_t>(_Val))));
#endif
} else if constexpr (sizeof(_Ty) == 1) {
// Non-pointer needle: sizeof(_Ty) selects the routine, and _Val is cast to the unsigned type of that width.
// The routines return const void*; the casts below turn that back into the caller's _Ty*.
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_1(_First, _Last, static_cast<uint8_t>(_Val))));
} else if constexpr (sizeof(_Ty) == 2) {
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_2(_First, _Last, static_cast<uint16_t>(_Val))));
} else if constexpr (sizeof(_Ty) == 4) {
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_4(_First, _Last, static_cast<uint32_t>(_Val))));
} else if constexpr (sizeof(_Ty) == 8) {
return const_cast<_Ty*>(
static_cast<const _Ty*>(::__std_find_last_trivial_8(_First, _Last, static_cast<uint64_t>(_Val))));
} else {
// No routine exists for any other element size.
_STL_INTERNAL_STATIC_ASSERT(false); // unexpected size
}
}

// find_first_of vectorization is likely to be a win after this size (in elements)
_INLINE_VAR constexpr ptrdiff_t _Threshold_find_first_of = 16;

Expand Down
17 changes: 17 additions & 0 deletions tests/std/tests/VSO_0000000_vector_algorithms/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1128,6 +1128,22 @@ void test_case_string_rfind_str(const basic_string<T>& input_haystack, const bas
assert(expected == actual);
}

// Cross-checks basic_string::rfind (single character overload) against last_known_good_find_last, a reference
// search defined elsewhere in this test file.
template <class T>
void test_case_string_rfind_ch(const basic_string<T>& input_haystack, const T value) {
    const auto first = input_haystack.begin();
    const auto last  = input_haystack.end();

    const auto reference_pos = last_known_good_find_last(first, last, value);

    // The reference result is compared against the end iterator to detect "no match"; in that case the expected
    // index is -1, which is what a not-found rfind result is required to equal after the cast below.
    const ptrdiff_t expected = reference_pos == last ? ptrdiff_t{-1} : reference_pos - first;

    const auto actual = static_cast<ptrdiff_t>(input_haystack.rfind(value));
    assert(expected == actual);
}

template <class T, class D>
void test_basic_string_dis(mt19937_64& gen, D& dis) {
basic_string<T> input_haystack;
Expand All @@ -1144,6 +1160,7 @@ void test_basic_string_dis(mt19937_64& gen, D& dis) {
test_case_string_find_last_of(input_haystack, input_needle);
test_case_string_find_str(input_haystack, input_needle);
test_case_string_rfind_str(input_haystack, input_needle);
test_case_string_rfind_ch(input_haystack, static_cast<T>(dis(gen)));

for (size_t attempts = 0; attempts < needleDataCount; ++attempts) {
input_needle.push_back(static_cast<T>(dis(gen)));
Expand Down