// RUN: mlir-opt %s -split-input-file -pass-pipeline="builtin.module(func.func(test-fold-arith-extf-into-vector-contract-patterns,convert-vector-to-gpu{use-nvgpu=true},cse))" | FileCheck %s
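// The pipeline first folds arith.extf into vector.contract
// (test-fold-arith-extf-into-vector-contract-patterns), so the mixed-precision
// contract keeps f16 operands with an f32 accumulator, then lowers the vector
// ops to nvgpu ldmatrix/mma.sync (convert-vector-to-gpu{use-nvgpu=true}), and
// finally runs CSE to clean up redundant ops created by the conversion.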

//###############################################################################################
// FP16 input, F32 accumulation row-row-row (ldmatrix x4 for matrixA and ldmatrix x4 for matrixB)
//###############################################################################################

#map0 = affine_map<(d0, d1) -> (d1, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
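// #map0 is the (d0, d1) -> (d1, d0) permutation used to read B transposed;
// #map1/#map2/#map3 are the (m, k), (n, k), (m, n) indexing maps of a
// row-row-row matmul contraction.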

// CHECK-LABEL: func @m16n8k16_mmasync16816_f16_f16_f32_row_row_row
func.func @m16n8k16_mmasync16816_f16_f16_f32_row_row_row(%arg0: memref<42x32xf16, #gpu.address_space<workgroup>>, %arg1: memref<32x64xf16, #gpu.address_space<workgroup>>, %arg2: memref<42x64xf32, #gpu.address_space<workgroup>>) {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %cst_f16 = arith.constant 0.000000e+00 : f16
  %cst_f32 = arith.constant 0.000000e+00 : f32

  // CHECK-DAG: nvgpu.ldmatrix %arg0[%{{.*}}, %{{.*}}] {numTiles = 4 : i32, transpose = false}
  %A = vector.transfer_read %arg0[%c0, %c0], %cst_f16 {in_bounds = [true, true]} : memref<42x32xf16, #gpu.address_space<workgroup>>, vector<16x16xf16>
  %A_f32 = arith.extf %A : vector<16x16xf16> to vector<16x16xf32>
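  // The extf on A should be folded into both contracts below, so A is loaded
  // directly as f16 by the ldmatrix above rather than widened in registers.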


  // CHECK-DAG: nvgpu.ldmatrix %arg1[%{{.*}}, %{{.*}}] {numTiles = 4 : i32, transpose = true}
  %B = vector.transfer_read %arg1[%c0, %c0], %cst_f16 {permutation_map = #map0, in_bounds = [true, true]} : memref<32x64xf16, #gpu.address_space<workgroup>>, vector<16x16xf16>
  %C = vector.transfer_read %arg2[%c0, %c0], %cst_f32 {in_bounds = [true, true]} : memref<42x64xf32, #gpu.address_space<workgroup>>, vector<16x16xf32>

  %B0 = vector.extract_strided_slice %B {offsets = [0, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
  %B0_f32 = arith.extf %B0 : vector<8x16xf16> to vector<8x16xf32>
  %C0 = vector.extract_strided_slice %C {offsets = [0, 0], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf32> to vector<16x8xf32>
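  // B and C are split into 8x16 and 16x8 tiles so that each vector.contract
  // below matches the m16n8k16 mma.sync shape. The CHECK line shows the
  // per-thread fragments after warp distribution: vector<4x2xf16> of A,
  // vector<2x2xf16> of B, and vector<2x2xf32> of the accumulator.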

  // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
  %D0 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A_f32, %B0_f32, %C0 : vector<16x16xf32>, vector<8x16xf32> into vector<16x8xf32>
  vector.transfer_write %D0, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf32>, memref<42x64xf32, #gpu.address_space<workgroup>>


  %B1 = vector.extract_strided_slice %B {offsets = [8, 0], sizes = [8, 16], strides = [1, 1]} : vector<16x16xf16> to vector<8x16xf16>
  %B1_f32 = arith.extf %B1 : vector<8x16xf16> to vector<8x16xf32>
  %C1 = vector.extract_strided_slice %C {offsets = [0, 8], sizes = [16, 8], strides = [1, 1]} : vector<16x16xf32> to vector<16x8xf32>
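  // The second contract pairs the same A tile with rows 8..16 of B (the n
  // dimension) and columns 8..16 of C, so its result belongs at column offset
  // %c8 of the output.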

  // CHECK-DAG: nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
  %D1 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A_f32, %B1_f32, %C1 : vector<16x16xf32>, vector<8x16xf32> into vector<16x8xf32>
  vector.transfer_write %D1, %arg2[%c0, %c8] {in_bounds = [true, true]} : vector<16x8xf32>, memref<42x64xf32, #gpu.address_space<workgroup>>

  return
}