// RUN: mlir-opt -gpu-decompose-memrefs -allow-unregistered-dialect -split-input-file %s | FileCheck %s

// Every case below accesses a 3-D memref argument from inside a gpu.launch
// body, indexed by the thread ids. The expected output extracts the strided
// metadata of the argument before the launch and, inside the launch, computes
// a linear offset with affine.apply that feeds a memref.reinterpret_cast of
// the extracted base buffer.

// Scalar store through an identity-layout memref. The checked map has no
// offset term, and the innermost thread id is added without a stride factor.
//
//      CHECK: #[[LINEARIZE:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>
//      CHECK: @decompose_store
// CHECK-SAME: (%[[STORED:.*]]: f32, %[[MEMREF:.*]]: memref<?x?x?xf32>)
//      CHECK: %[[BASE_BUF:.*]], %[[OFF:.*]], %[[SZ:.*]]:3, %[[STR:.*]]:3 = memref.extract_strided_metadata %[[MEMREF]]
//      CHECK: gpu.launch
// CHECK-SAME: threads(%[[TID_X:.*]], %[[TID_Y:.*]], %[[TID_Z:.*]]) in
//      CHECK: %[[LIN:.*]] = affine.apply #[[LINEARIZE]]()[%[[TID_X]], %[[STR]]#0, %[[TID_Y]], %[[STR]]#1, %[[TID_Z]]]
//      CHECK: %[[VIEW:.*]] = memref.reinterpret_cast %[[BASE_BUF]] to offset: [%[[LIN]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
//      CHECK: memref.store %[[STORED]], %[[VIEW]][] : memref<f32, strided<[], offset: ?>>
func.func @decompose_store(%value : f32, %dst : memref<?x?x?xf32>) {
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %two = arith.constant 2 : index
  %extent0 = memref.dim %dst, %zero : memref<?x?x?xf32>
  %extent1 = memref.dim %dst, %one : memref<?x?x?xf32>
  %extent2 = memref.dim %dst, %two : memref<?x?x?xf32>
  gpu.launch blocks(%blk_x, %blk_y, %blk_z) in (%nblk_x = %one, %nblk_y = %one, %nblk_z = %one)
             threads(%thr_x, %thr_y, %thr_z) in (%nthr_x = %extent0, %nthr_y = %extent1, %nthr_z = %extent2) {
    memref.store %value, %dst[%thr_x, %thr_y, %thr_z] : memref<?x?x?xf32>
    gpu.terminator
  }
  return
}

// -----

// Scalar store through a fully dynamic strided layout. Here the extracted
// offset and all three extracted strides appear as map operands.
//
//      CHECK: #[[LINEARIZE:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * s2 + s3 * s4 + s5 * s6)>
//      CHECK: @decompose_store_strided
// CHECK-SAME: (%[[STORED:.*]]: f32, %[[MEMREF:.*]]: memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>)
//      CHECK: %[[BASE_BUF:.*]], %[[OFF:.*]], %[[SZ:.*]]:3, %[[STR:.*]]:3 = memref.extract_strided_metadata %[[MEMREF]]
//      CHECK: gpu.launch
// CHECK-SAME: threads(%[[TID_X:.*]], %[[TID_Y:.*]], %[[TID_Z:.*]]) in
//      CHECK: %[[LIN:.*]] = affine.apply #[[LINEARIZE]]()[%[[OFF]], %[[TID_X]], %[[STR]]#0, %[[TID_Y]], %[[STR]]#1, %[[TID_Z]], %[[STR]]#2]
//      CHECK: %[[VIEW:.*]] = memref.reinterpret_cast %[[BASE_BUF]] to offset: [%[[LIN]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
//      CHECK: memref.store %[[STORED]], %[[VIEW]][] : memref<f32, strided<[], offset: ?>>
func.func @decompose_store_strided(%value : f32, %dst : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) {
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %two = arith.constant 2 : index
  %extent0 = memref.dim %dst, %zero : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
  %extent1 = memref.dim %dst, %one : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
  %extent2 = memref.dim %dst, %two : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
  gpu.launch blocks(%blk_x, %blk_y, %blk_z) in (%nblk_x = %one, %nblk_y = %one, %nblk_z = %one)
             threads(%thr_x, %thr_y, %thr_z) in (%nthr_x = %extent0, %nthr_y = %extent1, %nthr_z = %extent2) {
    memref.store %value, %dst[%thr_x, %thr_y, %thr_z] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
    gpu.terminator
  }
  return
}

// -----

// Scalar load through an identity-layout memref. The loaded value is handed
// to the unregistered `test.test` op (hence -allow-unregistered-dialect).
//
//      CHECK: #[[LINEARIZE:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>
//      CHECK: @decompose_load
// CHECK-SAME: (%[[MEMREF:.*]]: memref<?x?x?xf32>)
//      CHECK: %[[BASE_BUF:.*]], %[[OFF:.*]], %[[SZ:.*]]:3, %[[STR:.*]]:3 = memref.extract_strided_metadata %[[MEMREF]]
//      CHECK: gpu.launch
// CHECK-SAME: threads(%[[TID_X:.*]], %[[TID_Y:.*]], %[[TID_Z:.*]]) in
//      CHECK: %[[LIN:.*]] = affine.apply #[[LINEARIZE]]()[%[[TID_X]], %[[STR]]#0, %[[TID_Y]], %[[STR]]#1, %[[TID_Z]]]
//      CHECK: %[[VIEW:.*]] = memref.reinterpret_cast %[[BASE_BUF]] to offset: [%[[LIN]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
//      CHECK: %[[LOADED:.*]] = memref.load %[[VIEW]][] : memref<f32, strided<[], offset: ?>>
//      CHECK: "test.test"(%[[LOADED]]) : (f32) -> ()
func.func @decompose_load(%src : memref<?x?x?xf32>) {
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %two = arith.constant 2 : index
  %extent0 = memref.dim %src, %zero : memref<?x?x?xf32>
  %extent1 = memref.dim %src, %one : memref<?x?x?xf32>
  %extent2 = memref.dim %src, %two : memref<?x?x?xf32>
  gpu.launch blocks(%blk_x, %blk_y, %blk_z) in (%nblk_x = %one, %nblk_y = %one, %nblk_z = %one)
             threads(%thr_x, %thr_y, %thr_z) in (%nthr_x = %extent0, %nthr_y = %extent1, %nthr_z = %extent2) {
    %elem = memref.load %src[%thr_x, %thr_y, %thr_z] : memref<?x?x?xf32>
    "test.test"(%elem) : (f32) -> ()
    gpu.terminator
  }
  return
}

// -----

// Subview with unit strides. The result is a 3-D reinterpret_cast that reuses
// the two outer source strides and a constant innermost stride of 1; the
// sizes are matched loosely and not pinned.
//
//      CHECK: #[[LINEARIZE:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>
//      CHECK: @decompose_subview
// CHECK-SAME: (%[[MEMREF:.*]]: memref<?x?x?xf32>)
//      CHECK: %[[BASE_BUF:.*]], %[[OFF:.*]], %[[SZ:.*]]:3, %[[STR:.*]]:3 = memref.extract_strided_metadata %[[MEMREF]]
//      CHECK: gpu.launch
// CHECK-SAME: threads(%[[TID_X:.*]], %[[TID_Y:.*]], %[[TID_Z:.*]]) in
//      CHECK: %[[LIN:.*]] = affine.apply #[[LINEARIZE]]()[%[[TID_X]], %[[STR]]#0, %[[TID_Y]], %[[STR]]#1, %[[TID_Z]]]
//      CHECK: %[[VIEW:.*]] = memref.reinterpret_cast %[[BASE_BUF]] to offset: [%[[LIN]]], sizes: [%{{.*}}, %{{.*}}, %{{.*}}], strides: [%[[STR]]#0, %[[STR]]#1, 1]
//      CHECK: "test.test"(%[[VIEW]]) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
func.func @decompose_subview(%src : memref<?x?x?xf32>) {
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %two = arith.constant 2 : index
  %extent0 = memref.dim %src, %zero : memref<?x?x?xf32>
  %extent1 = memref.dim %src, %one : memref<?x?x?xf32>
  %extent2 = memref.dim %src, %two : memref<?x?x?xf32>
  gpu.launch blocks(%blk_x, %blk_y, %blk_z) in (%nblk_x = %one, %nblk_y = %one, %nblk_z = %one)
             threads(%thr_x, %thr_y, %thr_z) in (%nthr_x = %extent0, %nthr_y = %extent1, %nthr_z = %extent2) {
    %view = memref.subview %src[%thr_x, %thr_y, %thr_z] [%two, %two, %two] [%one, %one, %one] : memref<?x?x?xf32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
    "test.test"(%view) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
    gpu.terminator
  }
  return
}

// -----

// Subview with static strides [2, 3, 4]. The two outer source strides are
// scaled by separate affine.apply ops (by 2 and by 3), and the innermost
// stride of the result is the literal 4.
//
//      CHECK: #[[SCALE2:.*]] = affine_map<()[s0] -> (s0 * 2)>
//      CHECK: #[[SCALE3:.*]] = affine_map<()[s0] -> (s0 * 3)>
//      CHECK: #[[LINEARIZE:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>
//      CHECK: @decompose_subview_strided
// CHECK-SAME: (%[[MEMREF:.*]]: memref<?x?x?xf32>)
//      CHECK: %[[BASE_BUF:.*]], %[[OFF:.*]], %[[SZ:.*]]:3, %[[STR:.*]]:3 = memref.extract_strided_metadata %[[MEMREF]]
//      CHECK: gpu.launch
// CHECK-SAME: threads(%[[TID_X:.*]], %[[TID_Y:.*]], %[[TID_Z:.*]]) in
//      CHECK: %[[SCALED0:.*]] = affine.apply #[[SCALE2]]()[%[[STR]]#0]
//      CHECK: %[[SCALED1:.*]] = affine.apply #[[SCALE3]]()[%[[STR]]#1]
//      CHECK: %[[LIN:.*]] = affine.apply #[[LINEARIZE]]()[%[[TID_X]], %[[STR]]#0, %[[TID_Y]], %[[STR]]#1, %[[TID_Z]]]
//      CHECK: %[[VIEW:.*]] = memref.reinterpret_cast %[[BASE_BUF]] to offset: [%[[LIN]]], sizes: [%{{.*}}, %{{.*}}, %{{.*}}], strides: [%[[SCALED0]], %[[SCALED1]], 4]
//      CHECK: "test.test"(%[[VIEW]]) : (memref<?x?x?xf32, strided<[?, ?, 4], offset: ?>>) -> ()
func.func @decompose_subview_strided(%src : memref<?x?x?xf32>) {
  %zero = arith.constant 0 : index
  %one = arith.constant 1 : index
  %two = arith.constant 2 : index
  %extent0 = memref.dim %src, %zero : memref<?x?x?xf32>
  %extent1 = memref.dim %src, %one : memref<?x?x?xf32>
  %extent2 = memref.dim %src, %two : memref<?x?x?xf32>
  gpu.launch blocks(%blk_x, %blk_y, %blk_z) in (%nblk_x = %one, %nblk_y = %one, %nblk_z = %one)
             threads(%thr_x, %thr_y, %thr_z) in (%nthr_x = %extent0, %nthr_y = %extent1, %nthr_z = %extent2) {
    %view = memref.subview %src[%thr_x, %thr_y, %thr_z] [%two, %two, %two] [2, 3, 4] : memref<?x?x?xf32> to memref<?x?x?xf32, strided<[?, ?, 4], offset: ?>>
    "test.test"(%view) : (memref<?x?x?xf32, strided<[?, ?, 4], offset: ?>>) -> ()
    gpu.terminator
  }
  return
}