From 92da59a22b9f8efadc7b8a007d0371fe9273970d Mon Sep 17 00:00:00 2001 From: Nick Nuon <76173566+Nick-Nuon@users.noreply.github.com> Date: Thu, 20 Jun 2024 09:08:35 -0400 Subject: [PATCH] Avx512 Validation (#45) * NoErrorIncompleteThenASCIIAvx512 without validate count passing * addition * all tests pass save BadHeaderBitsAvx512 and Validaecount * All tests wo validatecount passing * all tests working + benchmarks * cleanup * Update src/UTF8.cs --------- Co-authored-by: Daniel Lemire --- benchmark/Benchmark.cs | 16 ++ src/UTF8.cs | 395 ++++++++++++++++++++++++++++++------ test/UTF8ValidationTests.cs | 70 ++++++- 3 files changed, 411 insertions(+), 70 deletions(-) diff --git a/benchmark/Benchmark.cs b/benchmark/Benchmark.cs index e3af605..c288921 100644 --- a/benchmark/Benchmark.cs +++ b/benchmark/Benchmark.cs @@ -238,6 +238,22 @@ public unsafe void SIMDUtf8ValidationRealDataArm64() } + [Benchmark] + [BenchmarkCategory("avx512")] + public unsafe void SIMDUtf8ValidationRealDataAvx512() + { + if (allLinesUtf8 != null) + { + RunUtf8ValidationBenchmark(allLinesUtf8, (byte* pInputBuffer, int inputLength) => + { + int dummyUtf16CodeUnitCountAdjustment, dummyScalarCountAdjustment; + // Call the method with additional out parameters within the lambda. + // You must handle these additional out parameters inside the lambda, as they cannot be passed back through the delegate. + return SimdUnicode.UTF8.GetPointerToFirstInvalidByteAvx512(pInputBuffer, inputLength, out dummyUtf16CodeUnitCountAdjustment, out dummyScalarCountAdjustment); + }); + } + } + [Benchmark] [BenchmarkCategory("avx")] public unsafe void SIMDUtf8ValidationRealDataAvx2() diff --git a/src/UTF8.cs b/src/UTF8.cs index edffb54..ad5b92c 100644 --- a/src/UTF8.cs +++ b/src/UTF8.cs @@ -10,7 +10,6 @@ namespace SimdUnicode public static class UTF8 { - // Returns &inputBuffer[inputLength] if the input buffer is valid. /// /// Given an input buffer of byte length , @@ -32,10 +31,10 @@ public static class UTF8 { return GetPointerToFirstInvalidByteAvx2(pInputBuffer, inputLength, out Utf16CodeUnitCountAdjustment, out ScalarCodeUnitCountAdjustment); } - /*if (Vector512.IsHardwareAccelerated && Avx512Vbmi2.IsSupported) + if (Vector512.IsHardwareAccelerated && Avx512Vbmi.IsSupported) { - return GetPointerToFirstInvalidByteAvx512(pInputBuffer, inputLength); - }*/ + return GetPointerToFirstInvalidByteAvx512(pInputBuffer, inputLength, out Utf16CodeUnitCountAdjustment, out ScalarCodeUnitCountAdjustment); + } if (Ssse3.IsSupported) { return GetPointerToFirstInvalidByteSse(pInputBuffer, inputLength,out Utf16CodeUnitCountAdjustment, out ScalarCodeUnitCountAdjustment); @@ -196,69 +195,6 @@ private static (int utfAdjust, int scalarAdjust) GetFinalScalarUtfAdjustments(by } - // We scan the input from buf to len, possibly going back howFarBack bytes, to find the end of - // a valid UTF-8 sequence. We return buf + len if the buffer is valid, otherwise we return the - // pointer to the first invalid byte. Also updated the utf16CodeUnitCountAdjustment and scalarCountAdjustment - private unsafe static byte* RewindAndValidateWithErrors(int howFarBack, byte* buf, int len, ref int utf16CodeUnitCountAdjustment, ref int scalarCountAdjustment) - { - int extraLen = 0; - bool foundLeadingBytes = false; - - // Print the byte value at the buf pointer - byte* PinputPlusProcessedlength = buf; - int TooLongErroronEdgeUtfadjust = 0; - int TooLongErroronEdgeScalaradjust = 0; - - for (int i = 0; i <= howFarBack; i++) - { - byte candidateByte = buf[0 - i]; - foundLeadingBytes = (candidateByte & 0b11000000) != 0b10000000; - - if (foundLeadingBytes) - { - - (TooLongErroronEdgeUtfadjust, TooLongErroronEdgeScalaradjust) = GetFinalScalarUtfAdjustments(candidateByte); - - buf -= i; - break; - } - } - - if (!foundLeadingBytes) - { - return buf - howFarBack; - } - - int TailUtf16CodeUnitCountAdjustment = 0; - int TailScalarCountAdjustment = 0; - - byte* invalidBytePointer = GetPointerToFirstInvalidByteScalar(buf, len + extraLen, out TailUtf16CodeUnitCountAdjustment, out TailScalarCountAdjustment); - - // We need to take care of eg - // 11011110 10101101 11110000 10101101 10101111 10011111 11010111 10101000 11001101 10111001 11010100 10000111 11101111 10010000 10000000 11110011 - // 10110100 10101100 10100111 11100100 10101011 10011111 11101111 10100010 10110010 11011100 10100000 00100010 *11110000* 10011001 10101011 10000011 - // 10000000 10100010 11101110 10010101 10101001 11010100 10100111 11110000 10101001 10011101 10011011 11100100 10101011 10010111 11100110 10011001 <= Too long error @ 32 byte edge - // 10010000 11101111 10111111 10010110 11001010 10000000 11000111 10100010 11110010 10111100 10111011 10010100 11101001 10001011 10000110 11110100 - // Without the following check, the 11110000 byte is erroneously double counted: the SIMD procedure counts it once, then it is counted again by the scalar function - // Normally , if there is an error, this does not cause an issue: most erronous utf-8 unit will not be counted - // but it is in the case of too long as if you take for example (1111---- 10----- 10----- 10-----) 10----- - // the part between parentheses will be counted as valid and thus scalaradjust/utfadjust will be incremented once too much - - bool isContinuationByte = (invalidBytePointer[0] & 0xC0) == 0x80; - bool isOnEdge = (invalidBytePointer == PinputPlusProcessedlength); - - if (isContinuationByte && isOnEdge) - { - utf16CodeUnitCountAdjustment += TooLongErroronEdgeUtfadjust; - scalarCountAdjustment += TooLongErroronEdgeScalaradjust; - } - - utf16CodeUnitCountAdjustment += TailUtf16CodeUnitCountAdjustment; - scalarCountAdjustment += TailScalarCountAdjustment; - - return invalidBytePointer; - } - public unsafe static byte* GetPointerToFirstInvalidByteScalar(byte* pInputBuffer, int inputLength, out int utf16CodeUnitCountAdjustment, out int scalarCountAdjustment) { @@ -965,6 +901,331 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust return GetPointerToFirstInvalidByteScalar(pInputBuffer + processedLength, inputLength - processedLength, out utf16CodeUnitCountAdjustment, out scalarCountAdjustment); } + public unsafe static byte* GetPointerToFirstInvalidByteAvx512(byte* pInputBuffer, int inputLength, out int utf16CodeUnitCountAdjustment, out int scalarCountAdjustment) + { + int processedLength = 0; + if (pInputBuffer == null || inputLength <= 0) + { + utf16CodeUnitCountAdjustment = 0; + scalarCountAdjustment = 0; + return pInputBuffer; + } + + if (inputLength > 256) + { + + // We skip any ASCII characters at the start of the buffer + int asciirun = 0; + for (; asciirun + 128 <= inputLength; asciirun += 128) + { + + Vector512 block1 = Avx512F.LoadVector512(pInputBuffer + asciirun); + Vector512 block2 = Avx512F.LoadVector512(pInputBuffer + asciirun + 64); + Vector512 or = Avx512F.Or(block1, block2); + if (or.ExtractMostSignificantBits() != 0) + { + break; + } + } + processedLength = asciirun; + + if (processedLength + 64 < inputLength) + { + + Vector512 prevInputBlock = Vector512.Zero; + + Vector512 maxValue = Vector512.Create( + 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 0b11110000 - 1, 0b11100000 - 1, 0b11000000 - 1); + Vector512 prevIncomplete = Avx512BW.SubtractSaturate(prevInputBlock, maxValue); + + + Vector512 shuf1 = Vector512.Create(TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG, + TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG, + TWO_CONTS, TWO_CONTS, TWO_CONTS, TWO_CONTS, + TOO_SHORT | OVERLONG_2, + TOO_SHORT, + TOO_SHORT | OVERLONG_3 | SURROGATE, + TOO_SHORT | TOO_LARGE | TOO_LARGE_1000 | OVERLONG_4, + TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG, + TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG, + TWO_CONTS, TWO_CONTS, TWO_CONTS, TWO_CONTS, + TOO_SHORT | OVERLONG_2, + TOO_SHORT, + TOO_SHORT | OVERLONG_3 | SURROGATE, + TOO_SHORT | TOO_LARGE | TOO_LARGE_1000 | OVERLONG_4, + TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG, + TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG, + TWO_CONTS, TWO_CONTS, TWO_CONTS, TWO_CONTS, + TOO_SHORT | OVERLONG_2, + TOO_SHORT, + TOO_SHORT | OVERLONG_3 | SURROGATE, + TOO_SHORT | TOO_LARGE | TOO_LARGE_1000 | OVERLONG_4, + TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG, + TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG, + TWO_CONTS, TWO_CONTS, TWO_CONTS, TWO_CONTS, + TOO_SHORT | OVERLONG_2, + TOO_SHORT, + TOO_SHORT | OVERLONG_3 | SURROGATE, + TOO_SHORT | TOO_LARGE | TOO_LARGE_1000 | OVERLONG_4); + + Vector512 shuf2 = Vector512.Create( + CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4, + CARRY | OVERLONG_2, + CARRY, + CARRY, + CARRY | TOO_LARGE, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000 | SURROGATE, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4, + CARRY | OVERLONG_2, + CARRY, + CARRY, + CARRY | TOO_LARGE, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000 | SURROGATE, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4, + CARRY | OVERLONG_2, + CARRY, + CARRY, + CARRY | TOO_LARGE, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000 | SURROGATE, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4, + CARRY | OVERLONG_2, + CARRY, + CARRY, + CARRY | TOO_LARGE, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000 | SURROGATE, + CARRY | TOO_LARGE | TOO_LARGE_1000, + CARRY | TOO_LARGE | TOO_LARGE_1000); + Vector512 shuf3 = Vector512.Create(TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT, + TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT, + TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE_1000 | OVERLONG_4, + TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE, + TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE | TOO_LARGE, + TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE | TOO_LARGE, + TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT, + TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT, + TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE_1000 | OVERLONG_4, + TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE, + TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE | TOO_LARGE, + TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE | TOO_LARGE, + TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT, + TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT, + TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT, + TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE_1000 | OVERLONG_4, + TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE, + TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE | TOO_LARGE, + TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE | TOO_LARGE, + TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT, + TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT, + TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE_1000 | OVERLONG_4, + TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE, + TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE | TOO_LARGE, + TOO_LONG | OVERLONG_2 | TWO_CONTS | SURROGATE | TOO_LARGE, + TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT); + + Vector512 thirdByte = Vector512.Create((byte)(0b11100000u - 0x80)); + Vector512 fourthByte = Vector512.Create((byte)(0b11110000u - 0x80)); + Vector512 v0f = Vector512.Create((byte)0x0F); + Vector512 v80 = Vector512.Create((byte)0x80); + /**** + * So we want to count the number of 4-byte sequences, + * the number of 4-byte sequences, 3-byte sequences, and + * the number of 2-byte sequences. + * We can do it indirectly. We know how many bytes in total + * we have (length). Let us assume that the length covers + * only complete sequences (we need to adjust otherwise). + * We have that + * length = 4 * n4 + 3 * n3 + 2 * n2 + n1 + * where n1 is the number of 1-byte sequences (ASCII), + * n2 is the number of 2-byte sequences, n3 is the number + * of 3-byte sequences, and n4 is the number of 4-byte sequences. + * + * Let ncon be the number of continuation bytes, then we have + * length = n4 + n3 + n2 + ncon + n1 + * + * We can solve for n2 and n3 in terms of the other variables: + * n3 = n1 - 2 * n4 + 2 * ncon - length + * n2 = -2 * n1 + n4 - 4 * ncon + 2 * length + * Thus we only need to count the number of continuation bytes, + * the number of ASCII bytes and the number of 4-byte sequences. + */ + //////////// + // The *block* here is what begins at processedLength and ends + // at processedLength/16*16 or when an error occurs. + /////////// + int start_point = processedLength; + + // The block goes from processedLength to processedLength/16*16. + int asciibytes = 0; // number of ascii bytes in the block (could also be called n1) + int contbytes = 0; // number of continuation bytes in the block + int n4 = 0; // number of 4-byte sequences that start in this block + for (; processedLength + 64 <= inputLength; processedLength += 64) + { + + Vector512 currentBlock = Avx512F.LoadVector512(pInputBuffer + processedLength); + ulong mask = currentBlock.ExtractMostSignificantBits(); + if (mask == 0) + { + // We have an ASCII block, no need to process it, but + // we need to check if the previous block was incomplete. + if (Avx512BW.CompareGreaterThan(prevIncomplete,Vector512.Zero).ExtractMostSignificantBits() != 0) + { + int off = processedLength >= 3 ? processedLength - 3 : processedLength; + byte* invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(16 - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3); + // So the code is correct up to invalidBytePointer + if (invalidBytePointer < pInputBuffer + processedLength) + { + removeCounters(invalidBytePointer, pInputBuffer + processedLength, ref asciibytes, ref n4, ref contbytes); + } + else + { + addCounters(pInputBuffer + processedLength, invalidBytePointer, ref asciibytes, ref n4, ref contbytes); + } + int totalbyteasciierror = processedLength - start_point; + (utf16CodeUnitCountAdjustment, scalarCountAdjustment) = CalculateN2N3FinalSIMDAdjustments(asciibytes, n4, contbytes, totalbyteasciierror); + return invalidBytePointer; + } + prevIncomplete = Vector512.Zero; + } + else // Contains non-ASCII characters, we need to do non-trivial processing + { + // Use SubtractSaturate to effectively compare if bytes in block are greater than markers. + Vector512 movemask = Vector512.Create(28,29,30,31,0,1,2,3,4,5,6,7,8,9,10,11); + Vector512 shuffled = Avx512F.PermuteVar16x32x2(currentBlock.AsInt32(), movemask , prevInputBlock.AsInt32()).AsByte(); + prevInputBlock = currentBlock; + + Vector512 prev1 = Avx512BW.AlignRight(prevInputBlock, shuffled, (byte)(16 - 1)); + Vector512 byte_1_high = Avx512BW.Shuffle(shuf1, Avx512BW.ShiftRightLogical(prev1.AsUInt16(), 4).AsByte() & v0f);// takes the XXXX 0000 part of the previous byte + Vector512 byte_1_low = Avx512BW.Shuffle(shuf2, (prev1 & v0f)); // takes the 0000 XXXX part of the previous part + Vector512 byte_2_high = Avx512BW.Shuffle(shuf3, Avx512BW.ShiftRightLogical(currentBlock.AsUInt16(), 4).AsByte() & v0f); // takes the XXXX 0000 part of the current byte + Vector512 sc = Avx512F.And(Avx512F.And(byte_1_high, byte_1_low), byte_2_high); + Vector512 prev2 = Avx512BW.AlignRight(prevInputBlock, shuffled, (byte)(16 - 2)); + Vector512 prev3 = Avx512BW.AlignRight(prevInputBlock, shuffled, (byte)(16 - 3)); + Vector512 isThirdByte = Avx512BW.SubtractSaturate(prev2, thirdByte); + Vector512 isFourthByte = Avx512BW.SubtractSaturate(prev3, fourthByte); + Vector512 must23 = Avx512F.Or(isThirdByte, isFourthByte); + Vector512 must23As80 = Avx512F.And(must23, v80); + Vector512 error = Avx512F.Xor(must23As80, sc); + + if (Avx512BW.CompareGreaterThan(error,Vector512.Zero).ExtractMostSignificantBits() != 0) + { + byte* invalidBytePointer; + if (processedLength == 0) + { + invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(0, pInputBuffer + processedLength, inputLength - processedLength); + } + else + { + invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(processedLength - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3); + } + if (invalidBytePointer < pInputBuffer + processedLength) + { + removeCounters(invalidBytePointer, pInputBuffer + processedLength, ref asciibytes, ref n4, ref contbytes); + } + else + { + addCounters(pInputBuffer + processedLength, invalidBytePointer, ref asciibytes, ref n4, ref contbytes); + } + int total_bytes_processed = (int)(invalidBytePointer - (pInputBuffer + start_point)); + (utf16CodeUnitCountAdjustment, scalarCountAdjustment) = CalculateN2N3FinalSIMDAdjustments(asciibytes, n4, contbytes, total_bytes_processed); + return invalidBytePointer; + } + + prevIncomplete = Avx512BW.SubtractSaturate(currentBlock, maxValue); + contbytes += (int)Popcnt.X64.PopCount(byte_2_high.ExtractMostSignificantBits()); + // We use two instructions (SubtractSaturate and ExtractMostSignificantBits) to update n4, with one arithmetic operation. + n4 += (int)Popcnt.X64.PopCount(Avx512BW.SubtractSaturate(currentBlock, fourthByte).ExtractMostSignificantBits()); + } + + // important: we just update asciibytes if there was no error. + // We count the number of ascii bytes in the block using just some simple arithmetic + // and no expensive operation: + asciibytes += (int)(64 - Popcnt.X64.PopCount(mask)); + } + // We may still have an error. + if (processedLength < inputLength || Avx512BW.CompareGreaterThan(prevIncomplete,Vector512.Zero).ExtractMostSignificantBits() != 0 ) + { + byte* invalidBytePointer; + if (processedLength == 0) + { + invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(0, pInputBuffer + processedLength, inputLength - processedLength); + } + else + { + invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(processedLength - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3); + + } + if (invalidBytePointer != pInputBuffer + inputLength) + { + if (invalidBytePointer < pInputBuffer + processedLength) + { + removeCounters(invalidBytePointer, pInputBuffer + processedLength, ref asciibytes, ref n4, ref contbytes); + } + else + { + addCounters(pInputBuffer + processedLength, invalidBytePointer, ref asciibytes, ref n4, ref contbytes); + } + int total_bytes_processed = (int)(invalidBytePointer - (pInputBuffer + start_point)); + (utf16CodeUnitCountAdjustment, scalarCountAdjustment) = CalculateN2N3FinalSIMDAdjustments(asciibytes, n4, contbytes, total_bytes_processed); + return invalidBytePointer; + } + else + { + addCounters(pInputBuffer + processedLength, invalidBytePointer, ref asciibytes, ref n4, ref contbytes); + } + } + int final_total_bytes_processed = inputLength - start_point; + (utf16CodeUnitCountAdjustment, scalarCountAdjustment) = CalculateN2N3FinalSIMDAdjustments(asciibytes, n4, contbytes, final_total_bytes_processed); + return pInputBuffer + inputLength; + } + } + return GetPointerToFirstInvalidByteScalar(pInputBuffer + processedLength, inputLength - processedLength, out utf16CodeUnitCountAdjustment, out scalarCountAdjustment); + } + public unsafe static byte* GetPointerToFirstInvalidByteArm64(byte* pInputBuffer, int inputLength, out int utf16CodeUnitCountAdjustment, out int scalarCountAdjustment) { int processedLength = 0; diff --git a/test/UTF8ValidationTests.cs b/test/UTF8ValidationTests.cs index 03cc0f6..16ac531 100644 --- a/test/UTF8ValidationTests.cs +++ b/test/UTF8ValidationTests.cs @@ -139,6 +139,14 @@ public void simpleGoodSequencesSse() simpleGoodSequences(SimdUnicode.UTF8.GetPointerToFirstInvalidByteSse); } + [Trait("Category", "avx512")] + [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Avx512)] + public void simpleGoodSequencesAvx512() + { + simpleGoodSequences(SimdUnicode.UTF8.GetPointerToFirstInvalidByteAvx512); + } + + private void BadSequences(Utf8ValidationFunction utf8ValidationDelegate) { string[] badSequences = { @@ -209,6 +217,13 @@ public void BadSequencesAvx2() BadSequences(SimdUnicode.UTF8.GetPointerToFirstInvalidByteAvx2); } + [Trait("Category", "avx512")] + [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Avx512)] + public void BadSequencesAvx512() + { + BadSequences(SimdUnicode.UTF8.GetPointerToFirstInvalidByteAvx512); + } + [Trait("Category", "arm64")] [FactOnSystemRequirementAttribute(TestSystemRequirements.Arm64)] public void BadSequencesArm64() @@ -263,6 +278,13 @@ public void NoErrorAvx2() NoError(SimdUnicode.UTF8.GetPointerToFirstInvalidByteAvx2); } + [Trait("Category", "avx512")] + [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Avx512)] + public void NoErrorAvx512() + { + NoError(SimdUnicode.UTF8.GetPointerToFirstInvalidByteAvx512); + } + [Trait("Category", "arm64")] [FactOnSystemRequirementAttribute(TestSystemRequirements.Arm64)] public void NoErrorArm64() @@ -323,6 +345,13 @@ public void NoErrorSpecificByteCountAvx2() NoErrorSpecificByteCount(SimdUnicode.UTF8.GetPointerToFirstInvalidByteAvx2); } + [Trait("Category", "avx512")] + [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Avx512)] + public void NoErrorSpecificByteCountAvx512() + { + NoErrorSpecificByteCount(SimdUnicode.UTF8.GetPointerToFirstInvalidByteAvx512); + } + [Trait("Category", "arm64")] [FactOnSystemRequirementAttribute(TestSystemRequirements.Arm64)] public void NoErrorSpecificByteCountArm64() @@ -344,14 +373,13 @@ private void NoErrorIncompleteThenASCII(Utf8ValidationFunction utf8ValidationDel allAscii.InsertRange(incompleteLocation, singleBytes); var utf8 = allAscii.ToArray(); - int cutOffLength = 128;//utf8.Length - rand.Next(1, firstCodeLength); + int cutOffLength = 128; cutOffLength = Math.Min(cutOffLength, outputLength); // Ensure it doesn't exceed the length of truncatedUtf8 byte[] truncatedUtf8 = new byte[outputLength]; // Initialized to zero Array.Copy(utf8, 0, truncatedUtf8, 0, cutOffLength); bool isValidUtf8 = ValidateUtf8(truncatedUtf8, utf8ValidationDelegate); - // string utf8HexString = BitConverter.ToString(truncatedUtf8).Replace("-", " "); try { Assert.False(isValidUtf8); @@ -389,6 +417,13 @@ public void NoErrorIncompleteThenASCIIAvx2() NoErrorIncompleteThenASCII(SimdUnicode.UTF8.GetPointerToFirstInvalidByteAvx2); } + [Trait("Category", "avx512")] + [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Avx512)] + public void NoErrorIncompleteThenASCIIAvx512() + { + NoErrorIncompleteThenASCII(SimdUnicode.UTF8.GetPointerToFirstInvalidByteAvx512); + } + [Trait("Category", "arm64")] [FactOnSystemRequirementAttribute(TestSystemRequirements.Arm64)] @@ -405,7 +440,7 @@ private void NoErrorIncompleteAt256Vector(Utf8ValidationFunction utf8ValidationD { var allAscii = new List(Enumerable.Repeat((byte)0, 256)); int firstcodeLength = rand.Next(2, 5); - List singlebytes = generator.Generate(1, firstcodeLength); //recall:generate a utf8 code between 2 and 4 bytes + List singlebytes = generator.Generate(1, firstcodeLength); //generate a utf8 code between 2 and 4 bytes int incompleteLocation = 128 - rand.Next(1, firstcodeLength - 1); allAscii.InsertRange(incompleteLocation, singlebytes); @@ -450,6 +485,13 @@ public void NoErrorIncompleteAt256VectorAvx2() NoErrorIncompleteAt256Vector(SimdUnicode.UTF8.GetPointerToFirstInvalidByteAvx2); } + [Trait("Category", "avx51")] + [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Avx512)] + public void NoErrorIncompleteAt256VectorAvx512() + { + NoErrorIncompleteAt256Vector(SimdUnicode.UTF8.GetPointerToFirstInvalidByteAvx512); + } + [Trait("Category", "arm64")] [FactOnSystemRequirementAttribute(TestSystemRequirements.Arm64)] public void NoErrorIncompleteAt256VectorArm64() @@ -513,6 +555,13 @@ public void BadHeaderBitsAvx2() BadHeaderBits(SimdUnicode.UTF8.GetPointerToFirstInvalidByteAvx2); } + [Trait("Category", "avx512")] + [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Avx512)] + public void BadHeaderBitsAvx512() + { + BadHeaderBits(SimdUnicode.UTF8.GetPointerToFirstInvalidByteAvx512); + } + [Trait("Category", "arm64")] [FactOnSystemRequirementAttribute(TestSystemRequirements.Arm64)] public void BadHeaderBitsArm64() @@ -575,6 +624,13 @@ public void TooShortErrorAvx2() TooShortError(SimdUnicode.UTF8.GetPointerToFirstInvalidByteAvx2); } + [Trait("Category", "avx512")] + [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Avx512)] + public void TooShortErrorAvx512() + { + TooShortError(SimdUnicode.UTF8.GetPointerToFirstInvalidByteAvx512); + } + [Trait("Category", "arm64")] [FactOnSystemRequirementAttribute(TestSystemRequirements.Arm64)] public void TooShortErrorArm64() @@ -637,6 +693,14 @@ public void TooLongErrorAvx2() TooLongError(SimdUnicode.UTF8.GetPointerToFirstInvalidByteAvx2); } + [Trait("Category", "avx512")] + [FactOnSystemRequirementAttribute(TestSystemRequirements.X64Avx512)] + public void TooLongErrorAvx512() + { + TooLongError(SimdUnicode.UTF8.GetPointerToFirstInvalidByteAvx512); + } + + [Trait("Category", "arm64")] [FactOnSystemRequirementAttribute(TestSystemRequirements.Arm64)] public void TooLongErrorArm64()