// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.tensor.rewrite_as_constant
    } : !transform.op<"func.func">
    transform.yield
  }
}

// CHECK-LABEL: func @tensor_generate_constant(
// CHECK: %[[cst:.*]] = arith.constant dense<5.000000e+00> : tensor<2x3x5xf32>
// CHECK: return %[[cst]]
func.func @tensor_generate_constant() -> tensor<2x3x5xf32> {
  %cst = arith.constant 5.0 : f32
  %0 = tensor.generate {
  ^bb0(%arg0: index, %arg1: index, %arg2: index):
    tensor.yield %cst : f32
  } : tensor<2x3x5xf32>
  return %0 : tensor<2x3x5xf32>
}

// CHECK-LABEL: func @pad_of_ints(
// CHECK: %[[cst:.*]] = arith.constant dense<[
// CHECK-SAME{LITERAL}: [0, 0, 0, 0],
// CHECK-SAME{LITERAL}: [0, 6, 7, 0],
// CHECK-SAME{LITERAL}: [0, 8, 9, 0],
// CHECK-SAME{LITERAL}: [0, 0, 0, 0]
// CHECK-SAME{LITERAL}: ]> : tensor<4x4xi32>
// CHECK: %[[cast:.*]] = tensor.cast %[[cst]] : tensor<4x4xi32> to tensor<?x?xi32>
// CHECK: return %[[cast]]
func.func @pad_of_ints() -> tensor<?x?xi32> {
  %init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
  %pad_value = arith.constant 0 : i32

  %c1 = arith.constant 1 : index

  %0 = tensor.pad %init low[%c1, %c1] high[%c1, %c1] {
  ^bb0(%arg1: index, %arg2: index):
    tensor.yield %pad_value : i32
  } : tensor<2x2xi32> to tensor<?x?xi32>

  return %0 : tensor<?x?xi32>
}

// CHECK-LABEL: func @pad_of_floats(
// CHECK: %[[cst:.*]] = arith.constant dense<[
// CHECK-SAME{LITERAL}: [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
// CHECK-SAME{LITERAL}: [0.000000e+00, 6.000000e+00, 7.000000e+00, 0.000000e+00],
// CHECK-SAME{LITERAL}: [0.000000e+00, 8.000000e+00, 9.000000e+00, 0.000000e+00],
// CHECK-SAME{LITERAL}: [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00]
// CHECK-SAME{LITERAL}: ]> : tensor<4x4xf32>
// CHECK: return %[[cst]]

func.func @pad_of_floats() -> tensor<4x4xf32> {
  %init = arith.constant dense<[[6.0, 7.0], [8.0, 9.0]]> : tensor<2x2xf32>
  %pad_value = arith.constant 0.0 : f32

  %0 = tensor.pad %init low[1, 1] high[1, 1] {
  ^bb0(%arg1: index, %arg2: index):
    tensor.yield %pad_value : f32
  } : tensor<2x2xf32> to tensor<4x4xf32>

  return %0 : tensor<4x4xf32>
}

// CHECK-LABEL: func @pad_of_ints_no_low_dims(
// CHECK: %[[cst:.*]] = arith.constant dense<[
// CHECK-SAME{LITERAL}: [6, 7, 0],
// CHECK-SAME{LITERAL}: [8, 9, 0],
// CHECK-SAME{LITERAL}: [0, 0, 0]
// CHECK-SAME{LITERAL}: ]> : tensor<3x3xi32>
// CHECK: return %[[cst]]
func.func @pad_of_ints_no_low_dims() -> tensor<3x3xi32> {
  %init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
  %pad_value = arith.constant 0 : i32

  %0 = tensor.pad %init low[0, 0] high[1, 1] {
  ^bb0(%arg1: index, %arg2: index):
    tensor.yield %pad_value : i32
  } : tensor<2x2xi32> to tensor<3x3xi32>

  return %0 : tensor<3x3xi32>
}

// CHECK-LABEL: func @pad_of_ints_no_high_dims(
// CHECK: %[[cst:.*]] = arith.constant dense<[
// CHECK-SAME{LITERAL}: [0, 0, 0],
// CHECK-SAME{LITERAL}: [0, 6, 7],
// CHECK-SAME{LITERAL}: [0, 8, 9]
// CHECK-SAME{LITERAL}: ]> : tensor<3x3xi32>
// CHECK: return %[[cst]]
func.func @pad_of_ints_no_high_dims() -> tensor<3x3xi32> {
  %init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
  %pad_value = arith.constant 0 : i32

  %0 = tensor.pad %init low[1, 1] high[0, 0] {
  ^bb0(%arg1: index, %arg2: index):
    tensor.yield %pad_value : i32
  } : tensor<2x2xi32> to tensor<3x3xi32>

  return %0 : tensor<3x3xi32>
}

// CHECK-LABEL: func @pad_multi_use_do_not_fold(
// CHECK: %[[pad:.+]] = tensor.pad
// CHECK: return %[[pad]]
func.func @pad_multi_use_do_not_fold() -> (tensor<?x?xi32>, tensor<2x2xi32>) {
  %init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
  %pad_value = arith.constant 0 : i32

  %c1 = arith.constant 1 : index

  %0 = tensor.pad %init low[%c1, %c1] high[%c1, %c1] {
  ^bb0(%arg1: index, %arg2: index):
    tensor.yield %pad_value : i32
  } : tensor<2x2xi32> to tensor<?x?xi32>

  return %0, %init : tensor<?x?xi32>, tensor<2x2xi32>
}

// -----

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.tensor.rewrite_as_constant aggressive
    } : !transform.op<"func.func">
    transform.yield
  }
}

// CHECK-LABEL: func @pad_aggressive_fold(
// CHECK: %[[init:.*]] = arith.constant dense<7> : tensor<2x2xi32>
// CHECK: %[[cst:.*]] = arith.constant dense<[
// CHECK-SAME{LITERAL}: [0, 0, 0, 0],
// CHECK-SAME{LITERAL}: [0, 7, 7, 0],
// CHECK-SAME{LITERAL}: [0, 7, 7, 0],
// CHECK-SAME{LITERAL}: [0, 0, 0, 0]
// CHECK-SAME{LITERAL}: ]> : tensor<4x4xi32>
// CHECK: %[[cast:.*]] = tensor.cast %[[cst]] : tensor<4x4xi32> to tensor<?x?xi32>
// CHECK: return %[[cast]]
func.func @pad_aggressive_fold() -> (tensor<?x?xi32>, tensor<2x2xi32>) {
  %init = arith.constant dense<7> : tensor<2x2xi32>
  %pad_value = arith.constant 0 : i32

  %c1 = arith.constant 1 : index

  %0 = tensor.pad %init low[%c1, %c1] high[%c1, %c1] {
  ^bb0(%arg1: index, %arg2: index):
    tensor.yield %pad_value : i32
  } : tensor<2x2xi32> to tensor<?x?xi32>

  return %0, %init : tensor<?x?xi32>, tensor<2x2xi32>
}