// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-reassociative-reshape-folding %s | FileCheck %s

// expand_shape of a rank-reducing extract_slice is expected to become a single
// extract_slice of the original source that directly produces the expanded
// type, for both reassociations used below.
// CHECK-LABEL: func @expand_shape_of_rank_reducing_extract(
// CHECK-SAME:    %[[t:.*]]: tensor<?x?x?x?xf32>
// CHECK-DAG:     %[[extract1:.*]] = tensor.extract_slice %[[t]][0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x1x1x5xf32>
// CHECK-DAG:     %[[extract2:.*]] = tensor.extract_slice %[[t]][0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x1x1x5xf32>
// CHECK:         return %[[extract1]], %[[extract2]]
func.func @expand_shape_of_rank_reducing_extract(
    %t: tensor<?x?x?x?xf32>, %idx: index)
  -> (tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>)
{
  %0 = tensor.extract_slice %t[0, 0, 0, 0][%idx, 1, 1, 5][1, 1, 1, 1]
      : tensor<?x?x?x?xf32> to tensor<?x1x5xf32>
  %c0 = arith.constant 0 : index
  %sz0 = tensor.dim %0, %c0 : tensor<?x1x5xf32>
  %1 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%sz0, 1, 1, 5]
      : tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
  %2 = tensor.expand_shape %0 [[0, 1], [2], [3]] output_shape [%sz0, 1, 1, 5]
      : tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
  return %1, %2 : tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>
}

// -----

// A collapse_shape that only merges each unit dim of an extract_slice result
// with its neighbouring dynamic dim is expected to fold into a single
// rank-reducing extract_slice (no collapse_shape left).
// CHECK-LABEL: func @unpadding_collapse_of_extract_slice(
// CHECK-SAME:    %[[t:.*]]: tensor<?x?x?x?xf32>
// CHECK-SAME:    %[[x:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:    %[[y:[a-zA-Z0-9_]+]]: index
// CHECK:         %[[extract:.*]] = tensor.extract_slice %[[t]][%[[x]], %[[y]], 0, 0]
// CHECK-SAME:      [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?xf32>
// CHECK:         return %[[extract]]
func.func @unpadding_collapse_of_extract_slice(
    %t: tensor<?x?x?x?xf32>, %x: index, %y: index)
  -> tensor<?x?xf32> {
  %c1 = arith.constant 1 : index
  %c3 = arith.constant 3 : index
  %sz0 = tensor.dim %t, %c1 : tensor<?x?x?x?xf32>
  %sz1 = tensor.dim %t, %c3 : tensor<?x?x?x?xf32>
  %0 = tensor.extract_slice %t[%x, %y, 0, 0] [1, %sz0, 1, %sz1] [1, 1, 1, 1]
      : tensor<?x?x?x?xf32> to tensor<1x?x1x?xf32>
  %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
      : tensor<1x?x1x?xf32> into tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}

// -----

// Here the collapse_shape merges two dynamic dims ([1, 2] of ?x?x?), so it is
// expected to remain after the rank-reducing extract_slice.
// CHECK-LABEL: func @non_unpadding_collapse_of_extract_slice(
// CHECK-SAME:    %[[t:.*]]: tensor<?x?x?x?xf32>
// CHECK-SAME:    %[[x:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:    %[[y:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:    %[[sz:[a-zA-Z0-9_]+]]: index
// CHECK:         %[[extract:.*]] = tensor.extract_slice %[[t]][%[[x]], %[[y]], 0, 0]
// CHECK-SAME:      [%{{.*}}, %{{.*}}, %[[sz]], 1] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?xf32>
// CHECK:         %[[collapse:.*]] = tensor.collapse_shape %[[extract]] {{\[}}[0], [1, 2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
// CHECK:         return %[[collapse]]
func.func @non_unpadding_collapse_of_extract_slice(
    %t: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
  -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %sz0 = tensor.dim %t, %c0 : tensor<?x?x?x?xf32>
  %sz1 = tensor.dim %t, %c1 : tensor<?x?x?x?xf32>
  %0 = tensor.extract_slice %t[%x, %y, 0, 0] [%sz0, %sz1, %sz, 1] [1, 1, 1, 1]
      : tensor<?x?x?x?xf32> to tensor<?x?x?xf32>
  %1 = tensor.collapse_shape %0 [[0], [1, 2]]
      : tensor<?x?x?xf32> into tensor<?x?xf32>
  return %1 : tensor<?x?xf32>
}

// -----

// Same shapes as the unpadding case, but the extract_slice result is also
// returned directly; both the extract_slice and the collapse_shape are
// expected to remain unchanged.
// CHECK-LABEL: func @unpadding_collapse_of_extract_slice_with_multiple_users(
// CHECK-SAME:    %[[t:.*]]: tensor<?x?x?x?xf32>
// CHECK-SAME:    %[[x:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:    %[[y:[a-zA-Z0-9_]+]]: index
// CHECK:         %[[extract:.*]] = tensor.extract_slice %[[t]][%[[x]], %[[y]], 0, 0]
// CHECK-SAME:      [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<1x?x1x?xf32>
// CHECK:         %[[collapse:.*]] = tensor.collapse_shape %[[extract]] {{\[}}[0, 1], [2, 3]] : tensor<1x?x1x?xf32> into tensor<?x?xf32>
// CHECK:         return %[[extract]], %[[collapse]]
func.func @unpadding_collapse_of_extract_slice_with_multiple_users(
    %t: tensor<?x?x?x?xf32>, %x: index, %y: index)
  -> (tensor<1x?x1x?xf32>, tensor<?x?xf32>) {
  %c1 = arith.constant 1 : index
  %c3 = arith.constant 3 : index
  %sz0 = tensor.dim %t, %c1 : tensor<?x?x?x?xf32>
  %sz1 = tensor.dim %t, %c3 : tensor<?x?x?x?xf32>
  %0 = tensor.extract_slice %t[%x, %y, 0, 0] [1, %sz0, 1, %sz1] [1, 1, 1, 1]
      : tensor<?x?x?x?xf32> to tensor<1x?x1x?xf32>
  %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
      : tensor<1x?x1x?xf32> into tensor<?x?xf32>
  return %0, %1 : tensor<1x?x1x?xf32>, tensor<?x?xf32>
}

// -----

// A collapse_shape that only drops a unit dim, feeding a rank-reducing
// insert_slice, is expected to disappear: the uncollapsed tensor is inserted
// directly.
// CHECK-LABEL: func @rank_reducing_insert_of_collapse_shape(
// CHECK-SAME:    %[[t:.*]]: tensor<?x1x1x5xf32>
// CHECK:         %[[insert:.*]] = tensor.insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0]
// CHECK-SAME:      [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x1x1x5xf32> into tensor<?x?x?x?xf32>
// CHECK:         return %[[insert]]
func.func @rank_reducing_insert_of_collapse_shape(
    %t: tensor<?x1x1x5xf32>, %d: tensor<?x?x?x?xf32>, %sz: index)
  -> tensor<?x?x?x?xf32> {
  %0 = tensor.collapse_shape %t [[0, 1], [2], [3]]
      : tensor<?x1x1x5xf32> into tensor<?x1x5xf32>
  %1 = tensor.insert_slice %0 into %d[0, 0, 0, 0][%sz, 1, 1, 5][1, 1, 1, 1]
      : tensor<?x1x5xf32> into tensor<?x?x?x?xf32>
  return %1 : tensor<?x?x?x?xf32>
}

// -----

// Same as the previous case, but with tensor.parallel_insert_slice inside an
// scf.forall terminator.
// CHECK-LABEL: func @rank_reducing_parallel_insert_of_collapse_shape(
// CHECK-SAME:    %[[t:.*]]: tensor<?x1x1x5xf32>
// CHECK:         tensor.parallel_insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0]
// CHECK-SAME:      [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x1x1x5xf32> into tensor<?x?x?x?xf32>
func.func @rank_reducing_parallel_insert_of_collapse_shape(
    %t: tensor<?x1x1x5xf32>, %d: tensor<?x?x?x?xf32>, %sz: index, %thr: index)
  -> tensor<?x?x?x?xf32> {
  %0 = tensor.collapse_shape %t [[0, 1], [2], [3]]
      : tensor<?x1x1x5xf32> into tensor<?x1x5xf32>
  %1 = scf.forall (%iv) in (%thr) shared_outs(%o = %d) -> (tensor<?x?x?x?xf32>) {
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %0 into %o[0, 0, 0, 0][%sz, 1, 1, 5][1, 1, 1, 1]
          : tensor<?x1x5xf32> into tensor<?x?x?x?xf32>
    }
  }
  return %1 : tensor<?x?x?x?xf32>
}

// -----

// An expand_shape that only introduces unit dims, feeding an insert_slice, is
// expected to fold so that the unexpanded 2-D tensor is inserted directly
// (rank-reducing insert).
// CHECK-LABEL: func @insert_of_padding_expand_shape(
// CHECK-SAME:    %[[t:.*]]: tensor<?x?xf32>
// CHECK-SAME:    %[[d:.*]]: tensor<?x?x?x?xf32>
// CHECK-SAME:    %[[x:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:    %[[y:[a-zA-Z0-9_]+]]: index
// CHECK:         %[[insert:.*]] = tensor.insert_slice %[[t]] into %[[d]][%[[x]], %[[y]], 0, 0]
// CHECK-SAME:      [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
// CHECK:         return %[[insert]]
func.func @insert_of_padding_expand_shape(
    %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
  -> tensor<?x?x?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
  %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
  %0 = tensor.expand_shape %t [[0, 1], [2, 3]] output_shape [1, %sz0, 1, %sz1]
      : tensor<?x?xf32> into tensor<1x?x1x?xf32>
  %1 = tensor.insert_slice %0 into %d[%x, %y, 0, 0][1, %sz0, 1, %sz1][1, 1, 1, 1]
      : tensor<1x?x1x?xf32> into tensor<?x?x?x?xf32>
  return %1 : tensor<?x?x?x?xf32>
}

// -----

// Here the expand_shape splits a dim into two dynamic dims ([0, 1] with sizes
// [%sz, %sz0]), so it is expected to remain in front of the insert_slice.
// CHECK-LABEL: func @insert_of_non_padding_expand_shape(
// CHECK-SAME:    %[[t:.*]]: tensor<?x?xf32>
// CHECK-SAME:    %[[d:.*]]: tensor<?x?x?x?xf32>
// CHECK-SAME:    %[[x:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:    %[[y:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:    %[[sz:[a-zA-Z0-9_]+]]: index
// CHECK:         %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]]
// CHECK-SAME:      output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
// CHECK:         %[[insert:.*]] = tensor.insert_slice %[[expand]] into %[[d]][%[[x]], %[[y]], 0, 0]
// CHECK-SAME:      [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
// CHECK:         return %[[insert]]
func.func @insert_of_non_padding_expand_shape(
    %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
  -> tensor<?x?x?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
  %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
  %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [%sz, %sz0, %sz1]
      : tensor<?x?xf32> into tensor<?x?x?xf32>
  %1 = tensor.insert_slice %0 into %d[%x, %y, 0, 0][%sz, 1, %sz0, %sz1][1, 1, 1, 1]
      : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
  return %1 : tensor<?x?x?x?xf32>
}

// -----

// parallel_insert_slice variant of the unit-dim-only expand_shape case: the
// 2-D tensor is expected to be inserted directly.
// CHECK-LABEL: func @parallel_insert_of_padding_expand_shape(
// CHECK-SAME:    %[[t:.*]]: tensor<?x?xf32>
// CHECK-SAME:    %[[d:.*]]: tensor<?x?x?x?xf32>
// CHECK-SAME:    %[[x:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:    %[[y:[a-zA-Z0-9_]+]]: index
// CHECK:         tensor.parallel_insert_slice %[[t]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0]
// CHECK-SAME:      [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
func.func @parallel_insert_of_padding_expand_shape(
    %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
  -> tensor<?x?x?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
  %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
  %0 = tensor.expand_shape %t [[0, 1], [2, 3]] output_shape [1, %sz0, 1, %sz1]
      : tensor<?x?xf32> into tensor<1x?x1x?xf32>
  %1 = scf.forall (%i, %j) in (%x, %y) shared_outs(%o = %d) -> (tensor<?x?x?x?xf32>) {
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %0 into %o[%i, %j, 0, 0][1, %sz0, 1, %sz1][1, 1, 1, 1]
          : tensor<1x?x1x?xf32> into tensor<?x?x?x?xf32>
    }
  }
  return %1 : tensor<?x?x?x?xf32>
}

// -----

// parallel_insert_slice variant of the dynamic-split expand_shape case: the
// expand_shape is expected to remain.
// CHECK-LABEL: func @parallel_insert_of_non_padding_expand_shape(
// CHECK-SAME:    %[[t:.*]]: tensor<?x?xf32>
// CHECK-SAME:    %[[d:.*]]: tensor<?x?x?x?xf32>
// CHECK-SAME:    %[[x:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:    %[[y:[a-zA-Z0-9_]+]]: index
// CHECK-SAME:    %[[sz:[a-zA-Z0-9_]+]]: index
// CHECK:         %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]]
// CHECK-SAME:      output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
// CHECK:         tensor.parallel_insert_slice %[[expand]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0]
// CHECK-SAME:      [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
func.func @parallel_insert_of_non_padding_expand_shape(
    %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
  -> tensor<?x?x?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
  %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
  %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [%sz, %sz0, %sz1]
      : tensor<?x?xf32> into tensor<?x?x?xf32>
  %1 = scf.forall (%i, %j) in (%x, %y) shared_outs(%o = %d) -> (tensor<?x?x?x?xf32>) {
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %0 into %o[%i, %j, 0, 0][%sz, 1, %sz0, %sz1][1, 1, 1, 1]
          : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
    }
  }
  return %1 : tensor<?x?x?x?xf32>
}