// RUN: mlir-opt %s -linalg-fuse-elementwise-ops -split-input-file | FileCheck %s

// Regression tests for folding a pure transpose of a constant tensor into a
// single, already-transposed constant. Each case lives in its own split of
// the input file. The "fold" cases expect the returned value to be the new
// constant; the "nofold" cases expect the linalg.generic op to survive.

// 2-D f32: the input is read through the (d1, d0) map and yielded unchanged,
// so the expected result is the transposed 3x2 constant.
// CHECK-LABEL: @transpose_fold_2d_fp32
func.func @transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
  // CHECK: %[[CST:.+]] = arith.constant
  // CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
  %1 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
  ^bb0(%arg1: f32, %arg2: f32):
    linalg.yield %arg1 : f32
  } -> tensor<3x2xf32>
  // CHECK: return %[[CST]]
  return %1 : tensor<3x2xf32>
}

// -----

// Same shape and maps as the f32 case above, with f64 elements.
// CHECK-LABEL: @transpose_fold_2d_fp64
func.func @transpose_fold_2d_fp64(%init: tensor<3x2xf64>) -> tensor<3x2xf64> {
  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf64>
  // CHECK: %[[CST:.+]] = arith.constant
  // CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf64>
  %1 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } ins(%input : tensor<2x3xf64>) outs(%init : tensor<3x2xf64>) {
  ^bb0(%arg1: f64, %arg2: f64):
    linalg.yield %arg1 : f64
  } -> tensor<3x2xf64>
  // CHECK: return %[[CST]]
  return %1 : tensor<3x2xf64>
}

// -----

// 4-D i32: here the permutation sits on the output map (d2, d0, d3, d1), so
// out[d2][d0][d3][d1] = in[d0][d1][d2][d3]. The result type is matched as well
// so that this case is distinguishable from the i16 variant.
// CHECK-LABEL: @transpose_fold_4d_i32
func.func @transpose_fold_4d_i32(%init: tensor<3x1x4x2xi32>) -> tensor<3x1x4x2xi32> {
  %input = arith.constant dense<[[
    [[ 0,  1,  2,  3], [ 4,  5,  6,  7], [ 8,  9, 10, 11]],
    [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
  ]]> : tensor<1x2x3x4xi32>
  // CHECK: %[[CST:.+]] = arith.constant dense<[
  // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
  // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
  // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
  // CHECK-SAME{LITERAL}: ]> : tensor<3x1x4x2xi32>
  %1 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
  } ins(%input : tensor<1x2x3x4xi32>) outs(%init : tensor<3x1x4x2xi32>) {
  ^bb0(%arg1: i32, %arg2: i32):
    linalg.yield %arg1 : i32
  } -> tensor<3x1x4x2xi32>
  // CHECK: return %[[CST]]
  return %1 : tensor<3x1x4x2xi32>
}

// -----

// 4-D i16: identical data and maps to the i32 case, with i16 elements.
// CHECK-LABEL: @transpose_fold_4d_i16
func.func @transpose_fold_4d_i16(%init: tensor<3x1x4x2xi16>) -> tensor<3x1x4x2xi16> {
  %input = arith.constant dense<[[
    [[ 0,  1,  2,  3], [ 4,  5,  6,  7], [ 8,  9, 10, 11]],
    [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
  ]]> : tensor<1x2x3x4xi16>
  // CHECK: %[[CST:.+]] = arith.constant dense<[
  // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
  // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
  // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
  // CHECK-SAME{LITERAL}: ]> : tensor<3x1x4x2xi16>
  %1 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
  } ins(%input : tensor<1x2x3x4xi16>) outs(%init : tensor<3x1x4x2xi16>) {
  ^bb0(%arg1: i16, %arg2: i16):
    linalg.yield %arg1 : i16
  } -> tensor<3x1x4x2xi16>
  // CHECK: return %[[CST]]
  return %1 : tensor<3x1x4x2xi16>
}

// -----

// Negative case: the transposed operand is a function argument rather than a
// constant, and the generic op is expected to remain.
// CHECK-LABEL: @transpose_nofold_non_cst_input
func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>, %init: tensor<3x2xf32>) -> tensor<3x2xf32> {
  // CHECK: linalg.generic
  %1 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
  ^bb0(%arg1: f32, %arg2: f32):
    linalg.yield %arg1 : f32
  } -> tensor<3x2xf32>
  return %1 : tensor<3x2xf32>
}

// -----

// Negative case: the region yields a scalar constant defined outside the op
// instead of its input block argument, and the generic op is expected to
// remain.
// CHECK-LABEL: @transpose_nofold_yield_const
func.func @transpose_nofold_yield_const(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
  %cst = arith.constant 8.0 : f32
  // CHECK: linalg.generic
  %1 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
  ^bb0(%arg1: f32, %arg2: f32):
    linalg.yield %cst : f32
  } -> tensor<3x2xf32>
  return %1 : tensor<3x2xf32>
}

// -----

// Negative case: the region computes arith.addf before yielding, so it is not
// a plain copy of the input, and the generic op is expected to remain.
// CHECK-LABEL: @transpose_nofold_multi_ops_in_region
func.func @transpose_nofold_multi_ops_in_region(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
  // CHECK: linalg.generic
  %1 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
  ^bb0(%arg1: f32, %arg2: f32):
    %add = arith.addf %arg1, %arg1 : f32
    linalg.yield %add : f32
  } -> tensor<3x2xf32>
  return %1 : tensor<3x2xf32>
}

// -----

// Named-op form: linalg.transpose with permutation [1, 0] on the same 2x3 f32
// constant is expected to produce the same transposed constant as the generic
// 2-D f32 case.
// CHECK-LABEL: @named_transpose_fold_2d_fp32
func.func @named_transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
  %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
  // CHECK: %[[CST:.+]] = arith.constant
  // CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
  %1 = linalg.transpose ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) permutation = [1, 0]
  // CHECK: return %[[CST]]
  return %1 : tensor<3x2xf32>
}

// -----

