Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix extract_slice causing compilation errors #17519

Closed
wants to merge 2 commits into from

Conversation

IanWood1
Copy link
Contributor

@IanWood1 IanWood1 commented May 29, 2024

Addresses this SHARK-TestSuite issue related to ExtractSliceOp causing compilation failures. The issue originated from a pattern of dequantize-like -> tensor.extract_slice -> quantize-like (within a dispatch). This means that the tensor.extract_slice op was operating on f32 instead of the quantized i8 types. MLIR doesn't handle bufferization on extractslice/differing bit-width casts (see mlir/*/EmptyTensorElimination.cpp), causing stack allocations in the middle of dispatches.

This is addressed by adding linalg.generic ops before and after the extract_slice op that cast back to the low bit-width type, eliminating the f32 tensors (after getting merged with the dequantize-like/quantize generic

TODOs:

  • Add iree tests (although passes SHARK-TestSuite, add iree tests)
  • Fix handling/matching of element types

Related issues:
nod-ai/SHARK-TestSuite#182
nod-ai/SHARK-ModelDev#683

@IanWood1 IanWood1 added benchmarks:cuda Run default CUDA benchmarks benchmarks:x86_64 Run default x86_64 benchmarks benchmarks:comp-stats Run default compilation statistics benchmarks benchmarks:android-cpu Run default Android CPU benchmarks benchmarks:android-gpu Run default Android GPU benchmarks benchmarks:vulkan-nvidia Run default Vulkan benchmarks on NVIDIA GPU labels May 29, 2024
Copy link

github-actions bot commented May 30, 2024

Abbreviated Benchmark Summary

@ commit a108d2c016256b0e08484ce739e78f17399d5f68 (no previous benchmark results to compare)

Data-Tiling Comparison Table

Click to show
Name No-DT (baseline) DT-Only DT-UK
BertForMaskedLMTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 217.391 (1.0X) N/A 106.261 (2.0X)
BertLargeTF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 780.824 (1.0X) N/A 219.350 (3.6X)
DeepLabV3_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 6.988 (1.0X) N/A 8.522 (0.8X)
DeepLabV3_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 31.652 (1.0X) N/A 29.963 (1.1X)
EfficientNetV2STF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 35.973 (1.0X) N/A 34.361 (1.0X)
EfficientNetV2STF(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 280.420 (1.0X) N/A 229.646 (1.2X)
EfficientNet_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 5.831 (1.0X) N/A 5.006 (1.2X)
EfficientNet_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 26.928 (1.0X) N/A 13.104 (2.1X)
GPT2_117M_TF_1X1XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 9.152 (1.0X) N/A 8.774 (1.0X)
GPT2_117M_TF_1X1XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 70.511 (1.0X) N/A 40.981 (1.7X)
GPT2_117M_TF_1X4XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 11.202 (1.0X) N/A 8.629 (1.3X)
GPT2_117M_TF_1X4XI32(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 89.969 (1.0X) N/A 41.996 (2.1X)
MiniLML12H384Uncased(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 12.299 (1.0X) N/A 13.258 (0.9X)
MiniLML12H384Uncased(stablehlo) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 80.159 (1.0X) N/A 57.320 (1.4X)
MobileBertSquad_fp16(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 33.772 (1.0X) N/A 61.929 (0.5X)
MobileBertSquad_fp16(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 178.860 (1.0X) N/A 186.315 (1.0X)
MobileBertSquad_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 34.361 (1.0X) N/A 62.335 (0.6X)
MobileBertSquad_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 179.644 (1.0X) N/A 191.804 (0.9X)
MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[15-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 66.878 (1.0X) N/A 62.848 (1.1X)
MobileBertSquad_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 489.299 (1.0X) N/A 214.455 (2.3X)
MobileNetV1_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 4.562 (1.0X) N/A 4.598 (1.0X)
MobileNetV1_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 25.110 (1.0X) N/A 18.185 (1.4X)
MobileNetV2_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 3.762 (1.0X) N/A 4.903 (0.8X)
MobileNetV2_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 11.895 (1.0X) N/A 11.925 (1.0X)
MobileNetV2_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 5.870 (1.0X) N/A 5.417 (1.1X)
MobileNetV2_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 21.708 (1.0X) N/A 11.866 (1.8X)
MobileNetV3Small_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 2.906 (1.0X) N/A 2.821 (1.0X)
MobileNetV3Small_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 2.793 (1.0X) N/A 2.657 (1.1X)
MobileSSD_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 8.452 (1.0X) N/A 9.889 (0.9X)
MobileSSD_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 34.936 (1.0X) N/A 31.804 (1.1X)
PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 0.818 (1.0X) N/A 0.604 (1.4X)
PersonDetect_int8(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 0.749 (1.0X) N/A 0.540 (1.4X)
PoseNet_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 4.173 (1.0X) N/A 5.153 (0.8X)
PoseNet_fp32(tflite) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 17.719 (1.0X) N/A 18.994 (0.9X)
matmul_256x256x2048_i8_i4_i32_tile_config_default(linalg) [x86_64-cascadelake-linux_gnu-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ c2-standard-60[cpu] 7.562 (1.0X) N/A 7.579 (1.0X)
DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 48.315 (1.0X) N/A 44.379 (1.1X)
DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 50.529 (1.0X) N/A 44.637 (1.1X)
DeepLabV3_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 30.099 (1.0X) N/A 27.816 (1.1X)
GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 92.179 (1.0X) N/A 21.761 (4.2X)
GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 92.755 (1.0X) N/A 22.318 (4.2X)
GPT2_117M_TF_1X1XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 51.770 (1.0X) N/A 22.051 (2.3X)
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 126.326 (1.0X) N/A 26.824 (4.7X)
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 141.428 (1.0X) N/A 28.751 (4.9X)
GPT2_117M_TF_1X4XI32(stablehlo) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 75.850 (1.0X) N/A 26.449 (2.9X)
MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 698.868 (1.0X) N/A 353.006 (2.0X)
MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 700.793 (1.0X) N/A 362.529 (1.9X)
MobileBertSquad_fp32(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 395.617 (1.0X) N/A 217.410 (1.8X)
MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 1117.483 (1.0X) N/A 304.817 (3.7X)
MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 1117.272 (1.0X) N/A 304.699 (3.7X)
MobileBertSquad_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 573.713 (1.0X) N/A 179.319 (3.2X)
Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 2102.731 (1.0X) N/A 307.008 (6.8X)
Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[1-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 2104.604 (1.0X) N/A 305.181 (6.9X)
Vit_int8(tflite) [armv8.2-a-generic-linux_android29-llvm_cpu] local_task(embedded_elf)[2-thread,full-inference,system-scheduling] with default @ pixel-6-pro[big-cores] 1122.035 (1.0X) N/A 182.760 (6.1X)
matmul_256x256x2048_i8_i4_i32_tile_config_default(linalg) [armv8.2-a-generic-linux_android29-llvm_cpu] local_sync(embedded_elf)[full-inference,default-flags] with default @ pixel-6-pro[big-cores] 12.163 (1.0X) N/A 1.279 (9.5X)

Raw Latencies

Benchmark Name Average Latency (ms) Median Latency (ms) Latency Standard Deviation (ms)
BertForMaskedLMTF(stablehlo) [x86\_64-cascadelake-linux\_gnu-llvm\_cpu][experimental-flags,no-dt] local\_task(embedded\_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 217.391 211.830 14.754
BertLargeTF(stablehlo) [x86\_64-cascadelake-linux\_gnu-llvm\_cpu][default-flags,dt-uk] local\_task(embedded\_elf)[30-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 219.350 219.757 1.930
DeepLabV3\_fp32(tflite) [x86\_64-cascadelake-linux\_gnu-llvm\_cpu][default-flags,dt-uk] local\_task(embedded\_elf)[8-thread,full-inference,default-flags] with default @ c2-standard-60[cpu] 8.522 8.510 0.030

[Top 3 out of 135 results showed]

No improved or regressed compilation metrics 🏖️

For more information:

Source Workflow Run

@IanWood1 IanWood1 removed the benchmarks:comp-stats Run default compilation statistics benchmarks label May 30, 2024
@IanWood1 IanWood1 removed benchmarks:cuda Run default CUDA benchmarks benchmarks:x86_64 Run default x86_64 benchmarks benchmarks:android-cpu Run default Android CPU benchmarks benchmarks:android-gpu Run default Android GPU benchmarks benchmarks:vulkan-nvidia Run default Vulkan benchmarks on NVIDIA GPU labels Jun 3, 2024
@IanWood1 IanWood1 marked this pull request as ready for review June 3, 2024 16:04
Copy link
Contributor

@zjgarvey zjgarvey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this, @IanWood1.

I'm a bit worried about the generic ops produced in this pattern. For example, if the scale of the initial dequant op is very small so that the results are all very close to zero, does your generated generic op for casting f32 -> i8 just give a zero tensor?

sourceGenericOp.getIteratorTypesArray(),
[&](OpBuilder &nestedBuilder, Location loc, ValueRange args) {
// Custom region for f32 -> i8 conversion
auto castOp = nestedBuilder.create<arith::FPToSIOp>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does modification by the quantization scale come into play here? If you just cast the fp to si, I can't imagine you will retain any useful information.

Is it possible to copy the payload of a generic op, but modify the other information?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah your right, this is definitely a problem. I think that solution might work, I'll take a look

@IanWood1 IanWood1 closed this Jun 3, 2024
@IanWood1 IanWood1 deleted the extract_slice branch November 6, 2024 01:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants