// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-with-reshape-by-collapsing -split-input-file | FileCheck %s

// Fusing a tensor.expand_shape producer into a linalg.generic consumer by
// collapsing the generic's iteration space: the (d0, d1) unit dims of the
// expand are folded away, the init is collapse_shape'd, and a single
// expand_shape is re-materialized on the result. The dynamic leading dim
// forces the pattern to recompute the output shape (tensor.dim + divsi).
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>

// CHECK-LABEL: func @reshape
// CHECK-SAME: (%[[A:.*]]: tensor<?x16xf32>, %[[B:.*]]: tensor<16xf32>, %[[INIT:.*]]: tensor<?x112x16xf32>, %[[SZ0:.*]]: index)
// CHECK: %[[C112:.*]] = arith.constant 112 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[RI:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2]] : tensor<?x112x16xf32> into tensor<?x16xf32>
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<?x16xf32>)
// CHECK: %[[DIM:.*]] = tensor.dim %[[R]], %[[C0]] : tensor<?x16xf32>
// CHECK: %[[VAL_1:.*]] = arith.divsi %[[DIM]], %[[C112]] : index
// CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[\[}}0, 1], [2]] output_shape [%[[VAL_1]], 112, 16] : tensor<?x16xf32> into tensor<?x112x16xf32>
// CHECK: return %[[RR]] : tensor<?x112x16xf32>
func.func @reshape(%A: tensor<?x16xf32>, %B: tensor<16xf32>, %init: tensor<?x112x16xf32>, %sz0: index) -> tensor<?x112x16xf32> {
  %0 = tensor.expand_shape %A [[0, 1], [2]] output_shape [%sz0, 112, 16]
      : tensor<?x16xf32> into tensor<?x112x16xf32>
  %2 = linalg.generic {indexing_maps = [
    affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>,
    affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%0, %B : tensor<?x112x16xf32>, tensor<16xf32>)
    outs(%init : tensor<?x112x16xf32>) {
  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
    %s = arith.subf %arg1, %arg2 : f32
    linalg.yield %s : f32
  } -> tensor<?x112x16xf32>
  return %2 : tensor<?x112x16xf32>
}

// -----

// Same fusion with TWO expand_shape producers feeding the same generic:
// both expands collapse into the generic, leaving one generic on the
// collapsed 2-D domain and one expand_shape on the result. Static shapes,
// so no dim/divsi recomputation is expected here.
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>

// CHECK-LABEL: func @reshape_multiple
// CHECK-SAME: (%[[A:.*]]: tensor<12544x16xf32>, %[[B:.*]]: tensor<12544x16xf32>, %[[C:.*]]: tensor<16xf32>)
// CHECK: %[[I:.*]] = tensor.empty() : tensor<112x112x16xf32>
// CHECK: %[[RI:.*]] = tensor.collapse_shape %[[I]] {{\[}}[0, 1], [2]] : tensor<112x112x16xf32> into tensor<12544x16xf32>
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP2]], #[[$MAP3]], #[[$MAP2]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%[[A]], %[[B]], %[[C]] : tensor<12544x16xf32>, tensor<12544x16xf32>, tensor<16xf32>) outs(%[[RI]] : tensor<12544x16xf32>)
// CHECK: %[[RR:.*]] = tensor.expand_shape %[[R]] {{\[}}[0, 1], [2]] output_shape [112, 112, 16] : tensor<12544x16xf32> into tensor<112x112x16xf32>
// CHECK: return %[[RR]] : tensor<112x112x16xf32>
func.func @reshape_multiple(%A: tensor<12544x16xf32>, %B: tensor<12544x16xf32>,
    %C: tensor<16xf32>) -> tensor<112x112x16xf32> {
  %0 = tensor.expand_shape %A [[0, 1], [2]] output_shape [112, 112, 16]
      : tensor<12544x16xf32> into tensor<112x112x16xf32>
  %1 = tensor.expand_shape %B [[0, 1], [2]] output_shape [112, 112, 16]
      : tensor<12544x16xf32> into tensor<112x112x16xf32>
  %2 = tensor.empty() : tensor<112x112x16xf32>
  %3 = linalg.generic {indexing_maps = [
    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
    affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
    affine_map<(d0, d1, d2) -> (d2)>,
    affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%0, %1, %C : tensor<112x112x16xf32>, tensor<112x112x16xf32>, tensor<16xf32>)
    outs(%2 : tensor<112x112x16xf32>) {
  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32):
    %s = arith.subf %arg1, %arg2 : f32
    %m = arith.mulf %s, %arg3 : f32
    linalg.yield %m : f32
  } -> tensor<112x112x16xf32>
  return %3 : tensor<112x112x16xf32>
}

// -----

// Negative test: since the second source is broadcast from d1 (indexing map
// (d0, d1, d2) -> (d1)), the d0 and d1 dimensions cannot be merged, so the
// expand_shape must NOT be fused — the generic stays 3-D.
// CHECK-LABEL: func @reshape_negative
// CHECK: tensor.expand_shape {{.*}} {{\[\[}}0, 1], [2]] output_shape [112, 112, 16] : tensor<12544x16xf32> into tensor<112x112x16xf32>
// CHECK: linalg.generic
// CHECK: } -> tensor<112x112x16xf32>
func.func @reshape_negative(%A: tensor<12544x16xf32>, %B: tensor<112xf32>) -> tensor<112x112x16xf32> {
  %20 = tensor.expand_shape %A [[0, 1], [2]] output_shape [112, 112, 16]
      : tensor<12544x16xf32> into tensor<112x112x16xf32>
  %21 = tensor.empty() : tensor<112x112x16xf32>
  %22 = linalg.generic {indexing_maps = [
    affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d1)>,
    affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%20, %B : tensor<112x112x16xf32>, tensor<112xf32>)
    outs(%21 : tensor<112x112x16xf32>) {
  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
    %s = arith.subf %arg1, %arg2 : f32
    linalg.yield %s : f32
  } -> tensor<112x112x16xf32>
  return %22 : tensor<112x112x16xf32>
}

// -----

// Mixed element types: the expanded input is i32 while the init/result are
// f32. Checks that after fusion the collapsed generic keeps the correct
// operand types (ins tensor<6x5xi32>, outs tensor<6x5xf32>) and the
// re-expanded result is f32.
func.func @type_correctness(%arg0 : tensor<6x5xi32>, %arg1 : tensor<5xf32>,
    %arg2 : tensor<5xf32>) -> tensor<2x3x5xf32> {
  %cst_6 = arith.constant 1.000000e+00 : f32
  %cst_7 = arith.constant 7.000000e+00 : f32
  %cst_8 = arith.constant 1.1920929E-7 : f32
  %25 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 3, 5]
      : tensor<6x5xi32> into tensor<2x3x5xi32>
  %26 = tensor.empty() : tensor<2x3x5xf32>
  %28 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                     affine_map<(d0, d1, d2) -> (d2)>,
                     affine_map<(d0, d1, d2) -> (d2)>,
                     affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%25, %arg1, %arg2 : tensor<2x3x5xi32>, tensor<5xf32>, tensor<5xf32>)
    outs(%26 : tensor<2x3x5xf32>) {
  ^bb0(%arg6: i32, %arg7: f32, %arg8: f32, %arg9: f32):
    %29 = arith.sitofp %arg6 : i32 to f32
    %30 = arith.addf %arg7, %cst_8 : f32
    %31 = arith.divf %cst_7, %30 : f32
    %32 = arith.divf %cst_6, %31 : f32
    %33 = arith.mulf %29, %32 : f32
    %34 = arith.addf %33, %arg8 : f32
    linalg.yield %34 : f32
  } -> tensor<2x3x5xf32>
  return %28 : tensor<2x3x5xf32>
}
// CHECK-LABEL: func @type_correctness
// CHECK: %[[OP:.+]] = linalg.generic
// CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : tensor<6x5xi32>, tensor<5xf32>, tensor<5xf32>)
// CHECK-SAME: outs(%{{.+}} : tensor<6x5xf32>)
// CHECK: tensor.expand_shape %[[OP]]
// CHECK-SAME: tensor<6x5xf32> into tensor<2x3x5xf32>