diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py
index c011d6cd..f7f316d3 100644
--- a/lit_tests/kernel/wave/codegen.py
+++ b/lit_tests/kernel/wave/codegen.py
@@ -383,6 +383,70 @@ def test(
         # CHECK-SAME: strided<[16, 1], offset: ?>>, vector<16xindex>, vector<16xi1>, vector<16xf16>


+@run_test
+def test_read_write_dynamic_mapping():
+    constraints: list[tkw.Constraint] = [
+        tkw.HardwareConstraint(
+            threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 16, N: 16}
+        )
+    ]
+    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N)]
+
+    i = tkw.IndexMapping.iterator(0)
+    j = tkw.IndexMapping.iterator(1)
+    k = tkw.IndexMapping.dynamic_val(0)
+    mapping = tkw.IndexMapping(
+        num_iterators=2,
+        inputs={M: k, N: j},
+        outputs={M: i, N: j},
+        dynamic_val_mappings={M: i, N: j},
+    )
+
+    @tkw.wave(constraints)
+    def test(
+        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
+        off: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32],
+        b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
+    ):
+        offset = tkw.read(off, elements_per_thread=16)
+        res = tkw.read(
+            a,
+            mapping=mapping,
+            mapping_dynamic_vals=(offset,),
+            elements_per_thread=16,
+        )
+        tkw.write(res, b, elements_per_thread=16)
+
+    with codegen_test_context(canonicalize=True):
+        a = torch.randn(16, 16, dtype=torch.float16)
+        off = torch.randint(16, (16, 16), dtype=torch.int32)
+        b = torch.zeros(16, 16, dtype=torch.float16)
+        print(test(a, off, b).module_op)
+
+        # CHECK: func.func @test(%[[ARG0:.*]]: !stream.binding, %[[ARG1:.*]]: !stream.binding, %[[ARG2:.*]]: !stream.binding)
+        # CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16>
+        # CHECK-DAG: %[[CST0:.*]] = arith.constant dense<16> : vector<16xindex>
+        # CHECK-DAG: %[[CST1:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex>
+        # CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+        # CHECK: %[[D0:.*]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<16x16xi32, strided<[16, 1], offset: ?>>
+        # CHECK: %[[D9:.*]] = vector.load %[[D0]][%[[D5:.*]], %[[D8:.*]]] : memref<16x16xi32, strided<[16, 1], offset: ?>>, vector<16xi32>
+        # CHECK: %[[D10:.*]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<16x16xf16, strided<[16, 1], offset: ?>>
+        # CHECK: %[[D11:.*]] = arith.index_cast %[[D9]] : vector<16xi32> to vector<16xindex>
+        # CHECK: %[[D12:.*]] = arith.muli %[[D11]], %[[CST0]] overflow<nsw, nuw> : vector<16xindex>
+        # CHECK: %[[D13:.*]] = vector.splat %{{.*}} : vector<16xindex>
+        # CHECK: %[[D14:.*]] = arith.addi %[[D13]], %[[D12]] overflow<nsw, nuw> : vector<16xindex>
+        # CHECK: %[[D15:.*]] = vector.splat %{{.*}} : vector<16xindex>
+        # CHECK: %[[D16:.*]] = arith.addi %[[D14]], %[[D15]] overflow<nsw, nuw> : vector<16xindex>
+        # CHECK: %[[D17:.*]] = arith.addi %[[D16]], %[[CST1]] overflow<nsw, nuw> : vector<16xindex>
+        # CHECK: %[[D18:.*]] = vector.constant_mask [16] : vector<16xi1>
+        # CHECK: %[[D19:.*]] = vector.gather %[[D10]][%[[C0]], %[[C0]]] [%[[D17]]], %[[D18]], %[[CST]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<16xindex>, vector<16xi1>, vector<16xf16> into vector<16xf16>
+        # CHECK: %[[D20:.*]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<16x16xf16, strided<[16, 1], offset: ?>>
+        # CHECK: vector.store %[[D19]], %[[D20]][%[[D5]], %[[D8]]] : memref<16x16xf16, strided<[16, 1], offset: ?>>, vector<16xf16>
+
+
 @run_test
 def test_dynamic_copy():
     constraints: list[tkw.Constraint] = [