Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve basic_string::find_first_of and basic_string::find_last_of vectorization for large needles or very large haystacks #5029

Open
wants to merge 46 commits into
base: main
Choose a base branch
from

Conversation

AlexGuteniev
Copy link
Contributor

@AlexGuteniev AlexGuteniev commented Oct 20, 2024

Follow up on #4934 (comment):

The case bm<AlgType::str_member_last, char>/400/50 is changing rom 113 ns to 195 ns, a speedup of 0.58.

Looked closer into that case, and made it even faster than it was.

🗺️ Summary of changes

This PR consists of the following changes:

  • Introduced __std_find_first_of_trivial_pos_N family that is used by strings and string view. The existing __std_find_first_of_trivial_N is still used by the standalone algorithm
  • Moved most of the vectorization decision making into the separately compiled code (further simplifying control flow in the header code as a side effect)
  • Added vectorized bitmap algorithm, in addition to the existing vectorized nested loop (two of them for different element sizes), scalar bitmap, and scalar nested loop algorithms
  • Reimplemented a copy of scalar bitmap algorithm in the separately compiled code
  • Implemented threshold system that better corresponds to the expected run time
  • Restored using scalar bitmap algorithm in header in constexpr context, because why not

⚙️ Vector bitmap algorithm

It is an AVX2-only algorithm. It processes 8 values at once.

In a similar way to the existing scalar bitmap algorithm, can be used when all needle character values do not exceed 255. Instead of having an array of 256 bool values, it uses an actual bitmap. The whole bitmap can fit into __m256i variable, that is, an AVX2 register.

If another AVX2 register contains 8 32-bit values, which are indices to 32-bit bitmap parts, _mm256_permutevar8x32_epi32 (vpermd) can look up 8 parts at once. The indices to the parts are high 3 bits of 8 bit values. The low 5 bits can be then used to obtain the exact bit in 32-bit sequence by a shift. In AVX2 there's are variable 32-bit shift that use a vector of shift values instead of just one for all: _mm256_srlv_epi32, _mm256_sllv_epi32. The resulting mask can be obtained by _mm256_movemask_ps.

Bitmap building

Small needles

Unfortunately, there's no instruction in AVX2 that can combine bits from different values of the same vector in a single element. This means that the bitmap building has to be fully scalar, or at least partially (when doing some processing in parallel, but doing final steps in scalar)

The scalar bitmap building loop performs rather poorly, worse than a loop that builds bool array. So I implemented a loop that uses vector instructions for that, so it uses vector registers and no stack, it seems faster than creating a stack array and loading it after. The key things in this approach is that a value from one of the shifts is expanded via _mm256_cvtepu8_epi64, so a 32-bit shift becomes a 256-bit shift of a lower granularity, the granularity is added back by another shift.

I've managed to have only a slight improvement when trying to partially parallel it, and the complexity of bitmap building grew significantly, so let's probably don't to it.

A different variations of non-parellel bitmap building have about the same performance, so I kept almost the same that I tried at first, except that I adjusted it to work fine in 32-bit x86 too.

Large needles

The vector instructions loop that builds performs poorly relative to the bool array building loop. At some point it makes sense to build bool array and compress it to a bitmap. As the size of array/bitmap is constant, it is constant instructions sequence, without loop, and it takes constant time.

Test for suitable element values

This is done separately, before creating the bitmap. This separate check is vectorized, and allows to bail out quickly, if values aren't right, without building the bitmap. There isn't specific benchmark for that currently, but I think this would work.

Advantage over existing

The cases where the needle and haystack product is big enough to make the existing vector algorithms bad, but the haystack is still way bigger that the needle, so the scalar bitmap lookup is also bad. Added some of them to the benchmark.

Surprisingly, this extends to the case with very small needles. With over like 1000 element, vector bitmap wins over SSE4.2 even for just a few needle elements.

Can we have this in SSE?

No. There's _mm256_shuffle_epi8 to do the bitmap parts extraction. But there's no variable vector shift. There isn't even variable vector shift in AVX2 with vector element width smaller than 32. So probably nothing better than using 8-element AVX2 vector.

⚖️ Selecting algorithm

⚠️ Actual vs run time vs full haystack length

The problem with estimating run time in advance is that we don't know how long will it run. The algorithm doesn't run full haystack, if the position is found earlier.

But when selecting algorithm we know only full length. Knowing the full length we can at least estimate the worst case.

Let's still start with worst case, will get back to early return possibility later on.

Run time evaluation

The nested loop algorithms, both scalar and both vectors, are O(m*n), and definitely the vector algorithms is preferred for any noticeably high values of m and n.

Also any bitmap algorithm is faster than nested scalar, unless the element is found in the very first position. So we can safely exclude the nested loop scalar from consideration.

Both scalar and vector bitmap algorithms are some sort of O(n + m), and they have quite different weights of m and n. Specifically, vector bitmap algorithm treat needle length way worse than haystack length, because this part is not parallel, and scalar bitmap algorithm treats them almost equally (surprisingly, needle has slightly less weight). Due to large needle mode, the difference of needle impact on run time between vector and scalar bitmap is constant, in favor of scalar bitmap. This justifies a constant threshold, eventuated during benchmarking at about 48.

Vector nested loop algorithm clearly outperforms when both n and m are small, so their product is also small. In specific cases, vector algorithm is linear, if either n or m is within a single vectorization unit. In this case it doesn't even have a nested loop (for short needle it is a deliberate optimization, for small haystack it is the result of the separate haystack tail processing).

After benchmarking these edge cases, it can be seen that vector nested loop outperforms everything for long needle small haystack, but it doesn't always outperform vector bitmap for short needle / large haystack. The former allows to exclude scalar bitmap algorithm from the consideration: with any not very small haystack, vector bitmap algorithm advantage is noticeable. Very small set of cases where scalar bitmap can win (small but not very small haystack and long needle) still don't give it a solid win, these cases are ultimately bound by the same scalar bitmap building loop for both algorithms. The benchmark here still may show noticeable difference, but only because these are different instances of that loop, and some codegen factors or other random factors might affect it.

So we need to pick:

  • Between AVX bitmap and scalar bitmap for AVX2, which we'll do using a threshold
  • Between AVX bitmap and vector nested loop for AVX2 and enough haystack length fir AVX bitmap
  • Between scalar bitmap and vector nested loop for SSE4.2 or enough haystack length fir AVX bitmap

It is hard to reason about the threshold functions, so the thresholds were obtained by aggressive benchmarking.

Considering early return

There is early return possibility.
If we don't consider it, we may pick a bitmap algorithm where vector nested loop is better.
If we will expect it, but it will not happen, we may pick vector nested loop when a bitmap algorithm is better.

Looks like that the latter gives worse error.

Generally the price of error is small for short needles. Long needles are gambling cases. But even for long needles the price for not picking vector nested loop when it is better is no more than 2x.

Why this dispatch is not in headers?

No big reason.

There's overflow multiply instrisic used from <intrin.h>, but that one is not essential.

Maybe also this will make maintenance easier, by having fewer functions exposed from vector_algorithm.cpp

Otherwise I guess I'm just like hiding the complexity under a carpet.

🛑 Risks

This time I don't see anything that seems incorrect, it is a complex change with some risks to consider:

  • Regressing some performance for some cases due to spending some time deciding/dispatching. I know, but it is a small one.
  • Regressing some performance due to potentially sometimes worse choice of algorithms. The current thresholds give better big picture, still in some border cases it might give slightly worse answer
  • In particular, might give worse choice for the best case, where the element is found immediately (discussed above)
  • Different performance behavior on different CPUs might break fine tuning. Older AMDs that do AVX2 in two takes is most of the concern.
  • Complexity of the vector tricks as usual
  • Changed __std_find_last_of_trivial_pos_N usage, see below

Changed __std_find_last_of_trivial_pos_N usage

__std_find_last_of_trivial_pos_N has been shipped in #4934. Now it does the bitmap, which is not what old code expects. Although all bad would happen is when the header implementation would fail the scalar bitmap due to bad values, this would unnecessary try the bitmap again. This time the attempt would be even faster due to the vectorization of checking, unless the user does not have SSE4.2

I just don't want to add more functions with more names just for this reason

Not wanting to have this situation for another function is the reason I made this PR before the _not_ vectorization (remaining for find 🐱 family)

⏱️ Benchmark results

Click to expand:
Benchmark main this
bm<AlgType::str_member_first, char>/2/3 5.39 ns 5.43 ns
bm<AlgType::str_member_first, char>/6/81 35.0 ns 23.2 ns
bm<AlgType::str_member_first, char>/7/4 12.8 ns 15.7 ns
bm<AlgType::str_member_first, char>/9/3 11.1 ns 13.8 ns
bm<AlgType::str_member_first, char>/22/5 11.2 ns 14.6 ns
bm<AlgType::str_member_first, char>/58/2 12.7 ns 14.7 ns
bm<AlgType::str_member_first, char>/75/85 55.8 ns 46.1 ns
bm<AlgType::str_member_first, char>/102/4 16.2 ns 17.5 ns
bm<AlgType::str_member_first, char>/200/46 73.7 ns 38.4 ns
bm<AlgType::str_member_first, char>/325/1 34.0 ns 36.8 ns
bm<AlgType::str_member_first, char>/400/50 129 ns 53.4 ns
bm<AlgType::str_member_first, char>/1011/11 91.3 ns 106 ns
bm<AlgType::str_member_first, char>/1280/46 436 ns 126 ns
bm<AlgType::str_member_first, char>/1502/23 356 ns 138 ns
bm<AlgType::str_member_first, char>/2203/54 554 ns 206 ns
bm<AlgType::str_member_first, char>/3056/7 264 ns 232 ns
bm<AlgType::str_member_first, wchar_t>/2/3 14.3 ns 13.3 ns
bm<AlgType::str_member_first, wchar_t>/6/81 41.1 ns 44.9 ns
bm<AlgType::str_member_first, wchar_t>/7/4 17.3 ns 18.3 ns
bm<AlgType::str_member_first, wchar_t>/9/3 13.7 ns 18.4 ns
bm<AlgType::str_member_first, wchar_t>/22/5 14.4 ns 19.2 ns
bm<AlgType::str_member_first, wchar_t>/58/2 18.5 ns 23.2 ns
bm<AlgType::str_member_first, wchar_t>/75/85 76.0 ns 60.6 ns
bm<AlgType::str_member_first, wchar_t>/102/4 25.6 ns 29.7 ns
bm<AlgType::str_member_first, wchar_t>/200/46 110 ns 54.5 ns
bm<AlgType::str_member_first, wchar_t>/325/1 64.5 ns 46.8 ns
bm<AlgType::str_member_first, wchar_t>/400/50 184 ns 65.1 ns
bm<AlgType::str_member_first, wchar_t>/1011/11 479 ns 117 ns
bm<AlgType::str_member_first, wchar_t>/1280/46 487 ns 154 ns
bm<AlgType::str_member_first, wchar_t>/1502/23 692 ns 163 ns
bm<AlgType::str_member_first, wchar_t>/2203/54 809 ns 269 ns
bm<AlgType::str_member_first, wchar_t>/3056/7 557 ns 327 ns
bm<AlgType::str_member_first, wchar_t, L'\x03B1'>/2/3 16.1 ns 17.2 ns
bm<AlgType::str_member_first, wchar_t, L'\x03B1'>/6/81 195 ns 29.3 ns
bm<AlgType::str_member_first, wchar_t, L'\x03B1'>/7/4 26.0 ns 18.1 ns
bm<AlgType::str_member_first, wchar_t, L'\x03B1'>/9/3 13.4 ns 18.5 ns
bm<AlgType::str_member_first, wchar_t, L'\x03B1'>/22/5 14.1 ns 19.4 ns
bm<AlgType::str_member_first, wchar_t, L'\x03B1'>/58/2 18.5 ns 23.2 ns
bm<AlgType::str_member_first, wchar_t, L'\x03B1'>/75/85 189 ns 170 ns
bm<AlgType::str_member_first, wchar_t, L'\x03B1'>/102/4 25.9 ns 29.9 ns
bm<AlgType::str_member_first, wchar_t, L'\x03B1'>/200/46 277 ns 247 ns
bm<AlgType::str_member_first, wchar_t, L'\x03B1'>/325/1 64.3 ns 69.0 ns
bm<AlgType::str_member_first, wchar_t, L'\x03B1'>/400/50 613 ns 532 ns
bm<AlgType::str_member_first, wchar_t, L'\x03B1'>/1011/11 513 ns 394 ns
bm<AlgType::str_member_first, wchar_t, L'\x03B1'>/1280/46 1631 ns 1414 ns
bm<AlgType::str_member_first, wchar_t, L'\x03B1'>/1502/23 995 ns 838 ns
bm<AlgType::str_member_first, wchar_t, L'\x03B1'>/2203/54 3135 ns 2828 ns
bm<AlgType::str_member_first, wchar_t, L'\x03B1'>/3056/7 559 ns 564 ns
bm<AlgType::str_member_first, char32_t>/2/3 13.0 ns 11.9 ns
bm<AlgType::str_member_first, char32_t>/6/81 40.0 ns 25.3 ns
bm<AlgType::str_member_first, char32_t>/7/4 15.6 ns 16.7 ns
bm<AlgType::str_member_first, char32_t>/9/3 13.9 ns 17.6 ns
bm<AlgType::str_member_first, char32_t>/22/5 14.3 ns 20.9 ns
bm<AlgType::str_member_first, char32_t>/58/2 14.3 ns 22.2 ns
bm<AlgType::str_member_first, char32_t>/75/85 61.3 ns 55.2 ns
bm<AlgType::str_member_first, char32_t>/102/4 16.4 ns 27.2 ns
bm<AlgType::str_member_first, char32_t>/200/46 110 ns 46.5 ns
bm<AlgType::str_member_first, char32_t>/325/1 27.3 ns 39.1 ns
bm<AlgType::str_member_first, char32_t>/400/50 183 ns 60.6 ns
bm<AlgType::str_member_first, char32_t>/1011/11 333 ns 127 ns
bm<AlgType::str_member_first, char32_t>/1280/46 489 ns 142 ns
bm<AlgType::str_member_first, char32_t>/1502/23 555 ns 164 ns
bm<AlgType::str_member_first, char32_t>/2203/54 818 ns 250 ns
bm<AlgType::str_member_first, char32_t>/3056/7 539 ns 281 ns
bm<AlgType::str_member_first, char32_t, U'\x03B1'>/2/3 17.0 ns 13.9 ns
bm<AlgType::str_member_first, char32_t, U'\x03B1'>/6/81 189 ns 25.7 ns
bm<AlgType::str_member_first, char32_t, U'\x03B1'>/7/4 27.9 ns 16.7 ns
bm<AlgType::str_member_first, char32_t, U'\x03B1'>/9/3 14.2 ns 16.9 ns
bm<AlgType::str_member_first, char32_t, U'\x03B1'>/22/5 14.9 ns 20.1 ns
bm<AlgType::str_member_first, char32_t, U'\x03B1'>/58/2 15.2 ns 18.8 ns
bm<AlgType::str_member_first, char32_t, U'\x03B1'>/75/85 202 ns 203 ns
bm<AlgType::str_member_first, char32_t, U'\x03B1'>/102/4 16.8 ns 22.4 ns
bm<AlgType::str_member_first, char32_t, U'\x03B1'>/200/46 284 ns 283 ns
bm<AlgType::str_member_first, char32_t, U'\x03B1'>/325/1 25.1 ns 29.9 ns
bm<AlgType::str_member_first, char32_t, U'\x03B1'>/400/50 597 ns 601 ns
bm<AlgType::str_member_first, char32_t, U'\x03B1'>/1011/11 333 ns 330 ns
bm<AlgType::str_member_first, char32_t, U'\x03B1'>/1280/46 1731 ns 1739 ns
bm<AlgType::str_member_first, char32_t, U'\x03B1'>/1502/23 1011 ns 1002 ns
bm<AlgType::str_member_first, char32_t, U'\x03B1'>/2203/54 3445 ns 3492 ns
bm<AlgType::str_member_first, char32_t, U'\x03B1'>/3056/7 541 ns 541 ns
bm<AlgType::str_member_last, char>/2/3 5.15 ns 5.19 ns
bm<AlgType::str_member_last, char>/6/81 31.2 ns 21.0 ns
bm<AlgType::str_member_last, char>/7/4 11.8 ns 16.2 ns
bm<AlgType::str_member_last, char>/9/3 10.6 ns 13.2 ns
bm<AlgType::str_member_last, char>/22/5 11.2 ns 13.7 ns
bm<AlgType::str_member_last, char>/58/2 12.3 ns 14.9 ns
bm<AlgType::str_member_last, char>/75/85 58.2 ns 43.1 ns
bm<AlgType::str_member_last, char>/102/4 15.2 ns 17.7 ns
bm<AlgType::str_member_last, char>/200/46 60.6 ns 34.9 ns
bm<AlgType::str_member_last, char>/325/1 34.7 ns 36.7 ns
bm<AlgType::str_member_last, char>/400/50 138 ns 50.3 ns
bm<AlgType::str_member_last, char>/1011/11 94.9 ns 91.4 ns
bm<AlgType::str_member_last, char>/1280/46 363 ns 113 ns
bm<AlgType::str_member_last, char>/1502/23 290 ns 128 ns
bm<AlgType::str_member_last, char>/2203/54 606 ns 204 ns
bm<AlgType::str_member_last, char>/3056/7 270 ns 251 ns
bm<AlgType::str_member_last, wchar_t>/2/3 13.3 ns 10.8 ns
bm<AlgType::str_member_last, wchar_t>/6/81 42.0 ns 49.9 ns
bm<AlgType::str_member_last, wchar_t>/7/4 15.7 ns 16.2 ns
bm<AlgType::str_member_last, wchar_t>/9/3 13.6 ns 17.0 ns
bm<AlgType::str_member_last, wchar_t>/22/5 14.6 ns 18.2 ns
bm<AlgType::str_member_last, wchar_t>/58/2 18.0 ns 20.8 ns
bm<AlgType::str_member_last, wchar_t>/75/85 82.8 ns 58.4 ns
bm<AlgType::str_member_last, wchar_t>/102/4 24.7 ns 29.9 ns
bm<AlgType::str_member_last, wchar_t>/200/46 118 ns 49.7 ns
bm<AlgType::str_member_last, wchar_t>/325/1 61.5 ns 43.5 ns
bm<AlgType::str_member_last, wchar_t>/400/50 191 ns 62.6 ns
bm<AlgType::str_member_last, wchar_t>/1011/11 404 ns 115 ns
bm<AlgType::str_member_last, wchar_t>/1280/46 493 ns 153 ns
bm<AlgType::str_member_last, wchar_t>/1502/23 587 ns 162 ns
bm<AlgType::str_member_last, wchar_t>/2203/54 830 ns 259 ns
bm<AlgType::str_member_last, wchar_t>/3056/7 529 ns 326 ns
bm<AlgType::str_member_last, wchar_t, L'\x03B1'>/2/3 15.7 ns 13.5 ns
bm<AlgType::str_member_last, wchar_t, L'\x03B1'>/6/81 159 ns 28.9 ns
bm<AlgType::str_member_last, wchar_t, L'\x03B1'>/7/4 25.4 ns 17.3 ns
bm<AlgType::str_member_last, wchar_t, L'\x03B1'>/9/3 14.3 ns 18.1 ns
bm<AlgType::str_member_last, wchar_t, L'\x03B1'>/22/5 15.3 ns 18.5 ns
bm<AlgType::str_member_last, wchar_t, L'\x03B1'>/58/2 18.2 ns 21.6 ns
bm<AlgType::str_member_last, wchar_t, L'\x03B1'>/75/85 189 ns 166 ns
bm<AlgType::str_member_last, wchar_t, L'\x03B1'>/102/4 24.7 ns 29.1 ns
bm<AlgType::str_member_last, wchar_t, L'\x03B1'>/200/46 265 ns 255 ns
bm<AlgType::str_member_last, wchar_t, L'\x03B1'>/325/1 62.0 ns 67.4 ns
bm<AlgType::str_member_last, wchar_t, L'\x03B1'>/400/50 568 ns 525 ns
bm<AlgType::str_member_last, wchar_t, L'\x03B1'>/1011/11 507 ns 400 ns
bm<AlgType::str_member_last, wchar_t, L'\x03B1'>/1280/46 1617 ns 1391 ns
bm<AlgType::str_member_last, wchar_t, L'\x03B1'>/1502/23 1030 ns 854 ns
bm<AlgType::str_member_last, wchar_t, L'\x03B1'>/2203/54 3165 ns 2720 ns
bm<AlgType::str_member_last, wchar_t, L'\x03B1'>/3056/7 525 ns 563 ns

@AlexGuteniev AlexGuteniev requested a review from a team as a code owner October 20, 2024 14:28
@StephanTLavavej StephanTLavavej added the performance Must go faster label Oct 20, 2024
@StephanTLavavej StephanTLavavej self-assigned this Oct 20, 2024
@StephanTLavavej

This comment was marked as resolved.

@AlexGuteniev

This comment was marked as resolved.

@AlexGuteniev
Copy link
Contributor Author

https://github.com/AlexGuteniev/STL/tree/ascii-table-experiment is a branch with altered benchmark program that I used to confirm the thresholds characteristics and find out their values. It is experimental science, not just plain theory 🔬 !

@AlexGuteniev

This comment was marked as resolved.

@AlexGuteniev AlexGuteniev changed the title Improve basic_string::find_first_of and basic_string::find_last_of vectorization for large needles Improve basic_string::find_first_of and basic_string::find_last_of vectorization for large needles or very large haystacks Oct 21, 2024
@AlexGuteniev AlexGuteniev marked this pull request as draft October 22, 2024 10:04
@AlexGuteniev AlexGuteniev marked this pull request as ready for review October 27, 2024 16:58
@AlexGuteniev AlexGuteniev removed their assignment Oct 27, 2024
@StephanTLavavej StephanTLavavej self-assigned this Oct 27, 2024
stl/src/vector_algorithms.cpp Outdated Show resolved Hide resolved
stl/src/vector_algorithms.cpp Outdated Show resolved Hide resolved
stl/src/vector_algorithms.cpp Outdated Show resolved Hide resolved
stl/src/vector_algorithms.cpp Outdated Show resolved Hide resolved
stl/src/vector_algorithms.cpp Outdated Show resolved Hide resolved
stl/src/vector_algorithms.cpp Outdated Show resolved Hide resolved
stl/src/vector_algorithms.cpp Outdated Show resolved Hide resolved
stl/src/vector_algorithms.cpp Outdated Show resolved Hide resolved
stl/src/vector_algorithms.cpp Outdated Show resolved Hide resolved
stl/src/vector_algorithms.cpp Outdated Show resolved Hide resolved
@StephanTLavavej

This comment was marked as resolved.

Even though sll is potentially more expensive than sllv,
the broadcast is still more expensive
@AlexGuteniev

This comment was marked as resolved.

stl/src/vector_algorithms.cpp Outdated Show resolved Hide resolved
stl/src/vector_algorithms.cpp Outdated Show resolved Hide resolved
stl/src/vector_algorithms.cpp Outdated Show resolved Hide resolved
stl/src/vector_algorithms.cpp Outdated Show resolved Hide resolved
stl/src/vector_algorithms.cpp Outdated Show resolved Hide resolved
stl/src/vector_algorithms.cpp Outdated Show resolved Hide resolved
stl/src/vector_algorithms.cpp Outdated Show resolved Hide resolved
stl/src/vector_algorithms.cpp Outdated Show resolved Hide resolved
stl/src/vector_algorithms.cpp Outdated Show resolved Hide resolved
stl/src/vector_algorithms.cpp Outdated Show resolved Hide resolved
@StephanTLavavej

This comment was marked as resolved.

@AlexGuteniev

This comment was marked as resolved.

@AlexGuteniev
Copy link
Contributor Author

@StephanTLavavej mentioned on Discord that it is ok to use three-way strategy function.

As the implementation of it still needs to be established (there was one but it matched previous AVX bitmap implementation), moving to work in progress again.

@AlexGuteniev AlexGuteniev marked this pull request as draft November 9, 2024 18:48
@AlexGuteniev AlexGuteniev marked this pull request as ready for review November 23, 2024 13:51
@AlexGuteniev AlexGuteniev removed their assignment Nov 23, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
performance Must go faster
Projects
Status: Work In Progress
Development

Successfully merging this pull request may close these issues.

3 participants