// RUN: mlir-opt %s -allow-unregistered-dialect -test-pdl-bytecode-pass -split-input-file | FileCheck %s

// -----

//===----------------------------------------------------------------------===//
// 1-layer perceptron with split fwd/bwd operations
//===----------------------------------------------------------------------===//

module @patterns {
  // fc_fwd
  pdl.pattern : benefit(1) {
    %in_type = pdl.type
    %out_type = pdl.type
    %weight_type = pdl.type
    %rxact = pdl.operand : %in_type
    %weight = pdl.operand : %weight_type

    %attr0 = pdl.attribute = false
    %op0 = pdl.operation "tf.MatMul" (%rxact, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr0} -> (%out_type : !pdl.type)

    pdl.rewrite %op0 {
      %op1 = pdl.operation "kernel.FcFwd" (%rxact, %weight : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
      %val1 = pdl.result 0 of %op1 // txact
      pdl.replace %op0 with (%val1 : !pdl.value) // tf.MatMul
    }
  }

  // fc_bwd
  pdl.pattern : benefit(4) {
    %in_type = pdl.type
    %out_type = pdl.type
    %weight_type = pdl.type
    %const_type = pdl.type
    %rxact = pdl.operand : %in_type
    %rxdelta = pdl.operand : %out_type
    %weight = pdl.operand : %weight_type

    %attr0 = pdl.attribute = true
    %attr1 = pdl.attribute = false
    %op0 = pdl.operation "tf.MatMul" (%rxact, %rxdelta : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr1} -> (%weight_type : !pdl.type)
    %val0 = pdl.result 0 of %op0
    %op1 = pdl.operation "tf.Const" -> (%const_type : !pdl.type)
    %val1 = pdl.result 0 of %op1
    %op2 = pdl.operation "tf.Mul" (%val0, %val1 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
    %val2 = pdl.result 0 of %op2
    %op3 = pdl.operation "tf.Sub" (%weight, %val2 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)

    pdl.rewrite %op3 {
      %op4 = pdl.operation "kernel.FcBwd" (%rxact, %rxdelta, %weight : !pdl.value, !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
      %val4 = pdl.result 0 of %op4 // weight_out
      pdl.replace %op3 with (%val4 : !pdl.value) // tf.Sub
      pdl.erase %op2 // tf.Mul
      pdl.erase %op1 // tf.Const
      pdl.erase %op0 // tf.MatMul
    }
  }

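  // Multi-root pattern: the tf.Mean (%op2) and tf.Mul (%op5) below have no
  // users inside the pattern, so pdl.rewrite takes no root argument and both
  // roots are replaced together by a single kernel.SoftmaxCrossEntropy.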
  // softmax_cross_entropy
  pdl.pattern : benefit(6) {
    %in_type = pdl.type
    %label_type = pdl.type
    %loss_type = pdl.type
    %mean_loss_type = pdl.type
    %mean_const_type = pdl.type
    %mul_const_type = pdl.type
    %rxact = pdl.operand : %in_type
    %rxlabel = pdl.operand : %label_type

    %op0 = pdl.operation "tf.SparseSoftmaxCrossEntropyWithLogits" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%loss_type, %in_type : !pdl.type, !pdl.type)
    %val0_0 = pdl.result 0 of %op0 // loss
    %val0_1 = pdl.result 1 of %op0 // gradient
    %op1 = pdl.operation "tf.Const" -> (%mean_const_type : !pdl.type)
    %val1 = pdl.result 0 of %op1
    %op2 = pdl.operation "tf.Mean" (%val0_0, %val1 : !pdl.value, !pdl.value) -> (%mean_loss_type : !pdl.type)
    %val2 = pdl.result 0 of %op2
    %op3 = pdl.operation "tf.PreventGradient" (%val0_1 : !pdl.value) -> (%in_type : !pdl.type)
    %val3 = pdl.result 0 of %op3
    %op4 = pdl.operation "tf.Const" -> (%mul_const_type : !pdl.type)
    %val4 = pdl.result 0 of %op4
    %op5 = pdl.operation "tf.Mul" (%val3, %val4 : !pdl.value, !pdl.value) -> (%in_type : !pdl.type)

    pdl.rewrite { // roots: %op2, %op5
      %op6 = pdl.operation "kernel.SoftmaxCrossEntropy" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%mean_loss_type, %in_type : !pdl.type, !pdl.type)
      %val6_0 = pdl.result 0 of %op6 // txloss
      %val6_1 = pdl.result 1 of %op6 // txdelta
      pdl.replace %op5 with (%val6_1 : !pdl.value) // tf.Mul
      pdl.erase %op4 // tf.Const
      pdl.erase %op3 // tf.PreventGradient
      pdl.replace %op2 with (%val6_0 : !pdl.value) // tf.Mean
      pdl.erase %op1 // tf.Const
      pdl.erase %op0 // tf.SparseSoftmaxCrossEntropyWithLogits
    }
  }
}

// CHECK-LABEL: test.mlp_split
// CHECK: %[[FWD:.*]] = "kernel.FcFwd"(%arg0, %arg2) : (tensor<2x20xf32>, tensor<20x10xf32>) -> tensor<2x10xf32>
// CHECK: %[[SM:.*]]:2 = "kernel.SoftmaxCrossEntropy"(%[[FWD]], %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<f32>, tensor<2x10xf32>)
// CHECK: %[[BWD:.*]] = "kernel.FcBwd"(%arg0, %[[SM]]#1, %arg2) : (tensor<2x20xf32>, tensor<2x10xf32>, tensor<20x10xf32>) -> tensor<20x10xf32>
// CHECK: return %[[SM]]#0, %[[BWD]] : tensor<f32>, tensor<20x10xf32>
module @ir attributes { test.mlp_split } {
  func.func @main(%arg0: tensor<2x20xf32>, %arg1: tensor<2xi32>, %arg2: tensor<20x10xf32>) -> (tensor<f32>, tensor<20x10xf32>) {
    %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
    %1 = "tf.Const"() {value = dense<1.000000e-01> : tensor<f32>} : () -> tensor<f32>
    %2 = "tf.Const"() {value = dense<5.000000e-01> : tensor<2x1xf32>} : () -> tensor<2x1xf32>
    %3 = "tf.MatMul"(%arg0, %arg2) {transpose_a = false, transpose_b = false} : (tensor<2x20xf32>, tensor<20x10xf32>) -> tensor<2x10xf32>
    %loss, %backprop = "tf.SparseSoftmaxCrossEntropyWithLogits"(%3, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<2xf32>, tensor<2x10xf32>)
    %4 = "tf.Mean"(%loss, %0) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
    %5 = "tf.PreventGradient"(%backprop) : (tensor<2x10xf32>) -> tensor<2x10xf32>
    %6 = "tf.Mul"(%5, %2) : (tensor<2x10xf32>, tensor<2x1xf32>) -> tensor<2x10xf32>
    %7 = "tf.MatMul"(%arg0, %6) {transpose_a = true, transpose_b = false} : (tensor<2x20xf32>, tensor<2x10xf32>) -> tensor<20x10xf32>
    %8 = "tf.Mul"(%7, %1) : (tensor<20x10xf32>, tensor<f32>) -> tensor<20x10xf32>
    %9 = "tf.Sub"(%arg2, %8) : (tensor<20x10xf32>, tensor<20x10xf32>) -> tensor<20x10xf32>
    return %4, %9 : tensor<f32>, tensor<20x10xf32>
  }
}

// -----

//===----------------------------------------------------------------------===//
// 2-layer perceptron with fused fwd/bwd operations
//===----------------------------------------------------------------------===//

module @patterns {

  // gradient descent
  pdl.pattern : benefit(3) {
    %const_type = pdl.type
    %param_type = pdl.type
    %param = pdl.operand : %param_type
    %gradient = pdl.operand : %param_type

    %attr0 = pdl.attribute
    %op0 = pdl.operation "tf.Const" {"value" = %attr0} -> (%const_type : !pdl.type)
    %val0 = pdl.result 0 of %op0
    %op1 = pdl.operation "tf.Mul" (%gradient, %val0 : !pdl.value, !pdl.value) -> (%param_type : !pdl.type)
    %val1 = pdl.result 0 of %op1
    %op2 = pdl.operation "tf.Sub" (%param, %val1 : !pdl.value, !pdl.value) -> (%param_type : !pdl.type)

    pdl.rewrite %op2 {
      %op3 = pdl.operation "kernel.GD" (%param, %gradient : !pdl.value, !pdl.value) -> (%param_type : !pdl.type)
      %val3 = pdl.result 0 of %op3
      pdl.replace %op2 with (%val3 : !pdl.value) // tf.Sub
      pdl.erase %op1 // tf.Mul
    }
  }

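  // The FC patterns below match the kernel.GD update ops created by the
  // gradient-descent pattern above, so that rewrite has to fire before
  // either of them can match.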
  // first FC
  pdl.pattern : benefit(8) {
    %in_type = pdl.type
    %out_type = pdl.type
    %weight_type = pdl.type
    %bias_type = pdl.type
    %rxact = pdl.operand : %in_type
    %rxdelta = pdl.operand : %out_type
    %weight = pdl.operand : %weight_type
    %bias = pdl.operand : %bias_type

    %attr0 = pdl.attribute = false
    %op0 = pdl.operation "tf.MatMul" (%rxact, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr0} -> (%out_type : !pdl.type)
    %val0 = pdl.result 0 of %op0
    %op1 = pdl.operation "tf.BiasAdd" (%val0, %bias : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
    %val1 = pdl.result 0 of %op1
    %op2 = pdl.operation "tf.Relu" (%val1 : !pdl.value) -> (%out_type : !pdl.type)
    %val2 = pdl.result 0 of %op2
    %op3 = pdl.operation "tf.ReluGrad" (%rxdelta, %val2 : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
    %val3 = pdl.result 0 of %op3
    %attr1 = pdl.attribute = true
    %op4 = pdl.operation "tf.MatMul" (%rxact, %val3 : !pdl.value, !pdl.value) {"transpose_a" = %attr1, "transpose_b" = %attr0} -> (%weight_type : !pdl.type)
    %val4 = pdl.result 0 of %op4
    %op5 = pdl.operation "kernel.GD" (%weight, %val4 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
    %op6 = pdl.operation "tf.BiasAddGrad" (%val3 : !pdl.value) -> (%bias_type : !pdl.type)
    %val6 = pdl.result 0 of %op6
    %op7 = pdl.operation "kernel.GD" (%bias, %val6 : !pdl.value, !pdl.value) -> (%bias_type : !pdl.type)

    pdl.rewrite { // roots: %op2, %op5, %op7
      %op8 = pdl.operation "kernel.FcWithBias" (%rxact, %rxdelta, %weight, %bias : !pdl.value, !pdl.value, !pdl.value, !pdl.value) -> (%out_type, %weight_type, %bias_type : !pdl.type, !pdl.type, !pdl.type)
      %val8_0 = pdl.result 0 of %op8 // txact
      %val8_1 = pdl.result 1 of %op8 // weight_out
      %val8_2 = pdl.result 2 of %op8 // bias_out
      pdl.replace %op7 with (%val8_2 : !pdl.value) // kernel.GD
      pdl.erase %op6 // tf.BiasAddGrad
      pdl.replace %op5 with (%val8_1 : !pdl.value) // kernel.GD
      pdl.erase %op4 // tf.MatMul
      pdl.erase %op3 // tf.ReluGrad
      pdl.replace %op2 with (%val8_0 : !pdl.value) // tf.Relu
      pdl.erase %op1 // tf.BiasAdd
      pdl.erase %op0 // tf.MatMul
    }
  }

  // second FC
  pdl.pattern : benefit(4) {
    %in_type = pdl.type
    %out_type = pdl.type
    %weight_type = pdl.type
    %rxact = pdl.operand : %in_type
    %rxdelta = pdl.operand : %out_type
    %weight = pdl.operand : %weight_type

    %attr0 = pdl.attribute = false
    %op0 = pdl.operation "tf.MatMul" (%rxact, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr0} -> (%out_type : !pdl.type)
    %attr1 = pdl.attribute = true
    %op1 = pdl.operation "tf.MatMul" (%rxdelta, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr1} -> (%in_type : !pdl.type)
    %op2 = pdl.operation "tf.MatMul" (%rxact, %rxdelta : !pdl.value, !pdl.value) {"transpose_a" = %attr1, "transpose_b" = %attr0} -> (%weight_type : !pdl.type)
    %val2 = pdl.result 0 of %op2
    %op3 = pdl.operation "kernel.GD" (%weight, %val2 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)

    pdl.rewrite { // roots: %op0, %op1, %op3
      %op4 = pdl.operation "kernel.Fc" (%rxact, %rxdelta, %weight : !pdl.value, !pdl.value, !pdl.value) -> (%out_type, %in_type, %weight_type : !pdl.type, !pdl.type, !pdl.type)
      %val4_0 = pdl.result 0 of %op4 // txact
      %val4_1 = pdl.result 1 of %op4 // txdelta
      %val4_2 = pdl.result 2 of %op4 // weight_out
      pdl.replace %op3 with (%val4_2 : !pdl.value) // kernel.GD
      pdl.erase %op2 // tf.MatMul
      pdl.replace %op1 with (%val4_1 : !pdl.value) // tf.MatMul
      pdl.replace %op0 with (%val4_0 : !pdl.value) // tf.MatMul
    }
  }

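  // The softmax_cross_entropy pattern is repeated verbatim from the split
  // test above: -split-input-file processes each section independently, so
  // patterns cannot be shared across sections.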
  // softmax_cross_entropy
  pdl.pattern : benefit(6) {
    %in_type = pdl.type
    %label_type = pdl.type
    %loss_type = pdl.type
    %mean_loss_type = pdl.type
    %mean_const_type = pdl.type
    %mul_const_type = pdl.type
    %rxact = pdl.operand : %in_type
    %rxlabel = pdl.operand : %label_type

    %op0 = pdl.operation "tf.SparseSoftmaxCrossEntropyWithLogits" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%loss_type, %in_type : !pdl.type, !pdl.type)
    %val0_0 = pdl.result 0 of %op0 // loss
    %val0_1 = pdl.result 1 of %op0 // gradient
    %op1 = pdl.operation "tf.Const" -> (%mean_const_type : !pdl.type)
    %val1 = pdl.result 0 of %op1
    %op2 = pdl.operation "tf.Mean" (%val0_0, %val1 : !pdl.value, !pdl.value) -> (%mean_loss_type : !pdl.type)
    %val2 = pdl.result 0 of %op2
    %op3 = pdl.operation "tf.PreventGradient" (%val0_1 : !pdl.value) -> (%in_type : !pdl.type)
    %val3 = pdl.result 0 of %op3
    %op4 = pdl.operation "tf.Const" -> (%mul_const_type : !pdl.type)
    %val4 = pdl.result 0 of %op4
    %op5 = pdl.operation "tf.Mul" (%val3, %val4 : !pdl.value, !pdl.value) -> (%in_type : !pdl.type)

    pdl.rewrite { // roots: %op2, %op5
      %op6 = pdl.operation "kernel.SoftmaxCrossEntropy" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%mean_loss_type, %in_type : !pdl.type, !pdl.type)
      %val6_0 = pdl.result 0 of %op6 // txloss
      %val6_1 = pdl.result 1 of %op6 // txdelta
      pdl.replace %op5 with (%val6_1 : !pdl.value) // tf.Mul
      pdl.erase %op4 // tf.Const
      pdl.erase %op3 // tf.PreventGradient
      pdl.replace %op2 with (%val6_0 : !pdl.value) // tf.Mean
      pdl.erase %op1 // tf.Const
      pdl.erase %op0 // tf.SparseSoftmaxCrossEntropyWithLogits
    }
  }
}

// CHECK-LABEL: test.mlp_fused
// CHECK: %[[FC2:.*]]:3 = "kernel.Fc"(%[[FC1:.*]]#0, %[[SM:.*]]#1, %arg4) : (tensor<2x256xf32>, tensor<2x10xf32>, tensor<256x10xf32>) -> (tensor<2x10xf32>, tensor<2x256xf32>, tensor<256x10xf32>)
// CHECK: %[[SM]]:2 = "kernel.SoftmaxCrossEntropy"(%[[FC2]]#0, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<f32>, tensor<2x10xf32>)
// CHECK: %[[FC1]]:3 = "kernel.FcWithBias"(%arg0, %[[FC2]]#1, %arg3, %arg2) : (tensor<2x20xf32>, tensor<2x256xf32>, tensor<20x256xf32>, tensor<256xf32>) -> (tensor<2x256xf32>, tensor<20x256xf32>, tensor<256xf32>)
module @ir attributes { test.mlp_fused } {
  // Result values elided; the original return types were
  // (tensor<f32>, tensor<256xf32>, tensor<20x256xf32>, tensor<256x10xf32>).
  func.func @main(%arg0: tensor<2x20xf32>, %arg1: tensor<2xi32>, %arg2: tensor<256xf32>, %arg3: tensor<20x256xf32>, %arg4: tensor<256x10xf32>) -> () {
    // The replacement operations fuse the forward and backward passes;
    // therefore, the resulting graph is not a DAG. To address this, we wrap
    // the operations in a graph region.
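    // (In a graph region, SSA dominance is not enforced, so an operation may
    // use values defined later in the same block.)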
268 "test.graph_region"() ({ 269 %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> 270 %1 = "tf.Const"() {value = dense<1.000000e-01> : tensor<f32>} : () -> tensor<f32> 271 %2 = "tf.Const"() {value = dense<5.000000e-01> : tensor<2x1xf32>} : () -> tensor<2x1xf32> 272 %3 = "tf.MatMul"(%arg0, %arg3) {transpose_a = false, transpose_b = false} : (tensor<2x20xf32>, tensor<20x256xf32>) -> tensor<2x256xf32> 273 %4 = "tf.BiasAdd"(%3, %arg2) {data_format = "NHWC"} : (tensor<2x256xf32>, tensor<256xf32>) -> tensor<2x256xf32> 274 %5 = "tf.Relu"(%4) : (tensor<2x256xf32>) -> tensor<2x256xf32> 275 %6 = "tf.MatMul"(%5, %arg4) {transpose_a = false, transpose_b = false} : (tensor<2x256xf32>, tensor<256x10xf32>) -> tensor<2x10xf32> 276 %loss, %backprop = "tf.SparseSoftmaxCrossEntropyWithLogits"(%6, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<2xf32>, tensor<2x10xf32>) 277 %7 = "tf.Mean"(%loss, %0) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32> 278 %8 = "tf.PreventGradient"(%backprop) : (tensor<2x10xf32>) -> tensor<2x10xf32> 279 %9 = "tf.Mul"(%8, %2) : (tensor<2x10xf32>, tensor<2x1xf32>) -> tensor<2x10xf32> 280 %10 = "tf.MatMul"(%9, %arg4) {transpose_a = false, transpose_b = true} : (tensor<2x10xf32>, tensor<256x10xf32>) -> tensor<2x256xf32> 281 %11 = "tf.MatMul"(%5, %9) {transpose_a = true, transpose_b = false} : (tensor<2x256xf32>, tensor<2x10xf32>) -> tensor<256x10xf32> 282 %12 = "tf.ReluGrad"(%10, %5) : (tensor<2x256xf32>, tensor<2x256xf32>) -> tensor<2x256xf32> 283 %13 = "tf.BiasAddGrad"(%12) {data_format = "NHWC"} : (tensor<2x256xf32>) -> tensor<256xf32> 284 %14 = "tf.MatMul"(%arg0, %12) {transpose_a = true, transpose_b = false} : (tensor<2x20xf32>, tensor<2x256xf32>) -> tensor<20x256xf32> 285 %15 = "tf.Mul"(%14, %1) : (tensor<20x256xf32>, tensor<f32>) -> tensor<20x256xf32> 286 %16 = "tf.Sub"(%arg3, %15) : (tensor<20x256xf32>, tensor<20x256xf32>) -> tensor<20x256xf32> 287 %17 = "tf.Mul"(%13, %1) : (tensor<256xf32>, tensor<f32>) -> tensor<256xf32> 288 %18 = "tf.Sub"(%arg2, %17) : (tensor<256xf32>, tensor<256xf32>) -> tensor<256xf32> 289 %19 = "tf.Mul"(%11, %1) : (tensor<256x10xf32>, tensor<f32>) -> tensor<256x10xf32> 290 %20 = "tf.Sub"(%arg4, %19) : (tensor<256x10xf32>, tensor<256x10xf32>) -> tensor<256x10xf32> 291 }) : () -> () 292 return 293 } 294} 295