//RUN: mlir-opt -test-linalg-transform-patterns=test-bubble-up-extract-slice-op-pattern -split-input-file %s | FileCheck %s

// Fully dynamic case: offsets and sizes of the slice are all SSA values, so
// the bubbled-up slices on the operands stay dynamically shaped.
func.func @dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>, %arg2: index, %arg3: index, %arg4: index, %arg5:index) -> tensor<?x?xf32> {
  %0 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
    outs(%arg0 : tensor<?x?xf32>) {
  ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
    %add = arith.addf %b0, %b1 : f32
    linalg.yield %add : f32
  } -> tensor<?x?xf32>
  %1 = tensor.extract_slice %0 [%arg2, %arg3] [%arg4, %arg5] [1, 1]
    : tensor<?x?xf32> to tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}

// CHECK: func @dynamic
// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[%arg3] [%arg5] [1] : tensor<?xf32> to tensor<?xf32>
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[%arg2, %arg3] [%arg4, %arg5] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[SLICE2]] : tensor<?x?xf32>)
// CHECK: return %[[GENERIC]] : tensor<?x?xf32>

// -----

// Fully static case: the slice has constant offsets/sizes, so the operand
// slices get static result types.
func.func @static(%arg0: tensor<16x8xf32>, %arg1: tensor<8xf32>) -> tensor<4x2xf32> {
  %0 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } ins(%arg0, %arg1 : tensor<16x8xf32>, tensor<8xf32>)
    outs(%arg0 : tensor<16x8xf32>) {
  ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
    %add = arith.addf %b0, %b1 : f32
    linalg.yield %add : f32
  } -> tensor<16x8xf32>
  %1 = tensor.extract_slice %0 [8, 4] [4, 2] [1, 1]
    : tensor<16x8xf32> to tensor<4x2xf32>
  return %1 : tensor<4x2xf32>
}

// CHECK: func @static
// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<16x8xf32> to tensor<4x2xf32>
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[4] [2] [1] : tensor<8xf32> to tensor<2xf32>
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<16x8xf32> to tensor<4x2xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<4x2xf32>, tensor<2xf32>) outs(%[[SLICE2]] : tensor<4x2xf32>)
// CHECK: return %[[GENERIC]] : tensor<4x2xf32>

// -----

// Mixed case: one dimension of the slice is static, the other dynamic.
func.func @mixed(%arg0: tensor<?x8xf32>, %arg1: tensor<8xf32>, %arg2: index, %arg3: index) -> tensor<?x2xf32> {
  %0 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } ins(%arg0, %arg1 : tensor<?x8xf32>, tensor<8xf32>)
    outs(%arg0 : tensor<?x8xf32>) {
  ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
    %add = arith.addf %b0, %b1 : f32
    linalg.yield %add : f32
  } -> tensor<?x8xf32>
  %1 = tensor.extract_slice %0 [8, %arg2] [%arg3, 2] [1, 1]
    : tensor<?x8xf32> to tensor<?x2xf32>
  return %1 : tensor<?x2xf32>
}

// CHECK: func @mixed
// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[8, %arg2] [%arg3, 2] [1, 1] : tensor<?x8xf32> to tensor<?x2xf32>
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[%arg2] [2] [1] : tensor<8xf32> to tensor<2xf32>
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[8, %arg2] [%arg3, 2] [1, 1] : tensor<?x8xf32> to tensor<?x2xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<?x2xf32>, tensor<2xf32>) outs(%[[SLICE2]] : tensor<?x2xf32>)
// CHECK: return %[[GENERIC]] : tensor<?x2xf32>

// -----

// Dynamic producer, static slice: the bubbled-up operand slices refine the
// dynamic operand types to static ones.
func.func @dynamic_to_static(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<4x2xf32> {
  %0 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
    outs(%arg0 : tensor<?x?xf32>) {
  ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
    %add = arith.addf %b0, %b1 : f32
    linalg.yield %add : f32
  } -> tensor<?x?xf32>
  %1 = tensor.extract_slice %0 [8, 4] [4, 2] [1, 1]
    : tensor<?x?xf32> to tensor<4x2xf32>
  return %1 : tensor<4x2xf32>
}

// CHECK: func @dynamic_to_static
// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<?x?xf32> to tensor<4x2xf32>
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[4] [2] [1] : tensor<?xf32> to tensor<2xf32>
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %arg0[8, 4] [4, 2] [1, 1] : tensor<?x?xf32> to tensor<4x2xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]] : tensor<4x2xf32>, tensor<2xf32>) outs(%[[SLICE2]] : tensor<4x2xf32>)
// CHECK: return %[[GENERIC]] : tensor<4x2xf32>

// -----

// Named contraction op: the operand slices take the full reduction (k)
// dimension; splat constants are folded into smaller splats.
func.func @matmul_slice() -> tensor<2x2xf32> {
  %lhs = arith.constant dense<1.0> : tensor<4x4xf32>
  %rhs = arith.constant dense<1.0> : tensor<4x4xf32>
  %dst = arith.constant dense<[[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0], [8.0, 9.0, 10.0, 11.0], [12.0, 13.0, 14.0, 15.0]]> : tensor<4x4xf32>
  %0 = linalg.matmul ins(%lhs, %rhs : tensor<4x4xf32>, tensor<4x4xf32>) outs(%dst : tensor<4x4xf32>) -> tensor<4x4xf32>
  %1 = tensor.extract_slice %0[1,1][2,2][1,1] : tensor<4x4xf32> to tensor<2x2xf32>
  return %1 : tensor<2x2xf32>
}

// CHECK: func @matmul_slice
// CHECK: %[[SLICE0:.+]] = arith.constant dense<1.000000e+00> : tensor<2x4xf32>
// CHECK: %[[SLICE1:.+]] = arith.constant dense<1.000000e+00> : tensor<4x2xf32>
// CHECK: %[[SLICE3:.+]] = tensor.extract_slice %[[CST:.+]][1, 1] [2, 2] [1, 1] : tensor<4x4xf32> to tensor<2x2xf32>
// CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[SLICE0]], %[[SLICE1]] : tensor<2x4xf32>, tensor<4x2xf32>) outs(%[[SLICE3]] : tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: return %[[MATMUL]] : tensor<2x2xf32>

// -----

// Convolution: the input-slice window is computed from the output slice via
// the strides, and the slice is also bubbled up through the fill producer.
func.func @conv_slice(%input: tensor<1x225x225x3xf32>, %filter: tensor<3x3x3x32xf32>) -> tensor<1x32x32x16xf32> {
  %c112 = arith.constant 112 : index
  %c32 = arith.constant 32 : index
  %c16 = arith.constant 16 : index
  %c8 = arith.constant 8 : index
  %c4 = arith.constant 4 : index
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.0 : f32

  %init = tensor.empty() : tensor<1x112x112x32xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>

  %conv = linalg.conv_2d_nhwc_hwcf
    {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
    ins(%input, %filter : tensor<1x225x225x3xf32>, tensor<3x3x3x32xf32>)
    outs(%fill : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>

  %slice = tensor.extract_slice %conv [0, 64, 64, 16] [1, 32, 32, 16] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x32x32x16xf32>

  return %slice : tensor<1x32x32x16xf32>
}

// CHECK: func @conv_slice
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x112x112x32xf32>
// CHECK: %[[SLICE0:.+]] = tensor.extract_slice %arg0[0, 128, 128, 0] [1, 65, 65, 3] [1, 1, 1, 1] : tensor<1x225x225x3xf32> to tensor<1x65x65x3xf32>
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice %arg1[0, 0, 0, 16] [3, 3, 3, 16] [1, 1, 1, 1] : tensor<3x3x3x32xf32> to tensor<3x3x3x16xf32>
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[INIT]][0, 64, 64, 16] [1, 32, 32, 16] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x32x32x16xf32>
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[CST:.+]] : f32) outs(%[[SLICE2]] : tensor<1x32x32x16xf32>) -> tensor<1x32x32x16xf32>
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[SLICE0]], %[[SLICE1]] : tensor<1x65x65x3xf32>, tensor<3x3x3x16xf32>) outs(%[[FILL]] : tensor<1x32x32x16xf32>) -> tensor<1x32x32x16xf32>
// CHECK: return %[[CONV]] : tensor<1x32x32x16xf32>

// -----

// The slice is not supposed to be bubbled up when it is rank-reducing.
func.func @rank_reducing_slice(%width : index) -> tensor<1x1x1x?xf32> {
  %cst = arith.constant 1.000000e+00 : f32
  %init = tensor.empty(%width) : tensor<1x?xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<1x?xf32>) -> tensor<1x?xf32>
  %slice = tensor.extract_slice %fill[0, 0] [1, %width] [1, 1] : tensor<1x?xf32> to tensor<?xf32>
  %c0 = arith.constant 0 : index
  %sz0 = tensor.dim %slice, %c0 : tensor<?xf32>
  %expand = tensor.expand_shape %slice [[0, 1, 2, 3]] output_shape [1, 1, 1, %sz0] : tensor<?xf32> into tensor<1x1x1x?xf32>
  return %expand : tensor<1x1x1x?xf32>
}

// CHECK: func @rank_reducing_slice
// CHECK: %[[INIT:.+]] = tensor.empty
// CHECK: %[[FILL:.+]] = linalg.fill ins
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[FILL]]
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]]
// CHECK: return %[[EXPAND]]