// RUN: mlir-opt %s -allow-unregistered-dialect -test-pdl-bytecode-pass -split-input-file | FileCheck %s

// -----

//===----------------------------------------------------------------------===//
// 1-layer perceptron with split fwd/bwd operations
//===----------------------------------------------------------------------===//

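// This test exercises multi-root pattern matching in the PDL bytecode
// interpreter: the TF-dialect MLP below is rewritten into fused "kernel"
// ops. The fc_fwd and fc_bwd patterns are single-rooted; the
// softmax_cross_entropy pattern has two roots.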
module @patterns {
  // fc_fwd
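  // Matches the forward pass of the layer: a single tf.MatMul with both
  // transpose attributes set to false, replaced by kernel.FcFwd.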
  pdl.pattern : benefit(1) {
    %in_type = pdl.type
    %out_type = pdl.type
    %weight_type = pdl.type
    %rxact = pdl.operand : %in_type
    %weight = pdl.operand : %weight_type

    %attr0 = pdl.attribute = false
    %op0 = pdl.operation "tf.MatMul" (%rxact, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr0} -> (%out_type : !pdl.type)

    pdl.rewrite %op0 {
      %op1 = pdl.operation "kernel.FcFwd" (%rxact, %weight : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
      %val1 = pdl.result 0 of %op1  // txact
      pdl.replace %op0 with (%val1 : !pdl.value)  // tf.MatMul
    }
  }

  // fc_bwd
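  // Matches the backward pass of the layer: the weight gradient
  // rxact^T * rxdelta is scaled by a constant and subtracted from the
  // weight. Single-rooted at the final tf.Sub (%op3).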
  pdl.pattern : benefit(4) {
    %in_type = pdl.type
    %out_type = pdl.type
    %weight_type = pdl.type
    %const_type = pdl.type
    %rxact = pdl.operand : %in_type
    %rxdelta = pdl.operand : %out_type
    %weight = pdl.operand : %weight_type

    %attr0 = pdl.attribute = true
    %attr1 = pdl.attribute = false
    %op0 = pdl.operation "tf.MatMul" (%rxact, %rxdelta : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr1} -> (%weight_type : !pdl.type)
    %val0 = pdl.result 0 of %op0
    %op1 = pdl.operation "tf.Const" -> (%const_type : !pdl.type)
    %val1 = pdl.result 0 of %op1
    %op2 = pdl.operation "tf.Mul" (%val0, %val1 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
    %val2 = pdl.result 0 of %op2
    %op3 = pdl.operation "tf.Sub" (%weight, %val2 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)

    pdl.rewrite %op3 {
      %op4 = pdl.operation "kernel.FcBwd" (%rxact, %rxdelta, %weight : !pdl.value, !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
      %val4 = pdl.result 0 of %op4  // weight_out
      pdl.replace %op3 with (%val4 : !pdl.value)  // tf.Sub
      pdl.erase %op2  // tf.Mul
      pdl.erase %op1  // tf.Const
      pdl.erase %op0  // tf.MatMul
    }
  }

  // softmax_cross_entropy
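  // A two-root pattern: the reduced loss (%op2, tf.Mean) and the gradient
  // path (%op5, tf.Mul after tf.PreventGradient) are fused into one
  // kernel.SoftmaxCrossEntropy op that yields both results.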
  pdl.pattern : benefit(6) {
    %in_type = pdl.type
    %label_type = pdl.type
    %loss_type = pdl.type
    %mean_loss_type = pdl.type
    %mean_const_type = pdl.type
    %mul_const_type = pdl.type
    %rxact = pdl.operand : %in_type
    %rxlabel = pdl.operand : %label_type

    %op0 = pdl.operation "tf.SparseSoftmaxCrossEntropyWithLogits" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%loss_type, %in_type : !pdl.type, !pdl.type)
    %val0_0 = pdl.result 0 of %op0  // loss
    %val0_1 = pdl.result 1 of %op0  // gradient
    %op1 = pdl.operation "tf.Const" -> (%mean_const_type : !pdl.type)
    %val1 = pdl.result 0 of %op1
    %op2 = pdl.operation "tf.Mean" (%val0_0, %val1 : !pdl.value, !pdl.value) -> (%mean_loss_type : !pdl.type)
    %val2 = pdl.result 0 of %op2
    %op3 = pdl.operation "tf.PreventGradient" (%val0_1 : !pdl.value) -> (%in_type : !pdl.type)
    %val3 = pdl.result 0 of %op3
    %op4 = pdl.operation "tf.Const" -> (%mul_const_type : !pdl.type)
    %val4 = pdl.result 0 of %op4
    %op5 = pdl.operation "tf.Mul" (%val3, %val4 : !pdl.value, !pdl.value) -> (%in_type : !pdl.type)

    pdl.rewrite {  // roots: %op2, %op5
      %op6 = pdl.operation "kernel.SoftmaxCrossEntropy" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%mean_loss_type, %in_type : !pdl.type, !pdl.type)
      %val6_0 = pdl.result 0 of %op6  // txloss
      %val6_1 = pdl.result 1 of %op6  // txdelta
      pdl.replace %op5 with (%val6_1 : !pdl.value)  // tf.Mul
      pdl.erase %op4  // tf.Const
      pdl.erase %op3  // tf.PreventGradient
      pdl.replace %op2 with (%val6_0 : !pdl.value)  // tf.Mean
      pdl.erase %op1  // tf.Const
      pdl.erase %op0  // tf.SparseSoftmaxCrossEntropyWithLogits
    }
  }
}

// CHECK-LABEL: test.mlp_split
// CHECK: %[[FWD:.*]] = "kernel.FcFwd"(%arg0, %arg2) : (tensor<2x20xf32>, tensor<20x10xf32>) -> tensor<2x10xf32>
// CHECK: %[[SM:.*]]:2 = "kernel.SoftmaxCrossEntropy"(%[[FWD]], %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<f32>, tensor<2x10xf32>)
// CHECK: %[[BWD:.*]] = "kernel.FcBwd"(%arg0, %[[SM]]#1, %arg2) : (tensor<2x20xf32>, tensor<2x10xf32>, tensor<20x10xf32>) -> tensor<20x10xf32>
// CHECK: return %[[SM]]#0, %[[BWD]] : tensor<f32>, tensor<20x10xf32>
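// Input IR: one training step of a 1-layer perceptron in the TF dialect
// (forward tf.MatMul, softmax cross-entropy loss, and an SGD weight update).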
module @ir attributes { test.mlp_split } {
  func.func @main(%arg0: tensor<2x20xf32>, %arg1: tensor<2xi32>, %arg2: tensor<20x10xf32>) -> (tensor<f32>, tensor<20x10xf32>) {
    %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
    %1 = "tf.Const"() {value = dense<1.000000e-01> : tensor<f32>} : () -> tensor<f32>
    %2 = "tf.Const"() {value = dense<5.000000e-01> : tensor<2x1xf32>} : () -> tensor<2x1xf32>
    %3 = "tf.MatMul"(%arg0, %arg2) {transpose_a = false, transpose_b = false} : (tensor<2x20xf32>, tensor<20x10xf32>) -> tensor<2x10xf32>
    %loss, %backprop = "tf.SparseSoftmaxCrossEntropyWithLogits"(%3, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<2xf32>, tensor<2x10xf32>)
    %4 = "tf.Mean"(%loss, %0) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
    %5 = "tf.PreventGradient"(%backprop) : (tensor<2x10xf32>) -> tensor<2x10xf32>
    %6 = "tf.Mul"(%5, %2) : (tensor<2x10xf32>, tensor<2x1xf32>) -> tensor<2x10xf32>
    %7 = "tf.MatMul"(%arg0, %6) {transpose_a = true, transpose_b = false} : (tensor<2x20xf32>, tensor<2x10xf32>) -> tensor<20x10xf32>
    %8 = "tf.Mul"(%7, %1) : (tensor<20x10xf32>, tensor<f32>) -> tensor<20x10xf32>
    %9 = "tf.Sub"(%arg2, %8) : (tensor<20x10xf32>, tensor<20x10xf32>) -> tensor<20x10xf32>
    return %4, %9 : tensor<f32>, tensor<20x10xf32>
  }
}

// -----

//===----------------------------------------------------------------------===//
// 2-layer perceptron with fused fwd/bwd operations
//===----------------------------------------------------------------------===//

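// In this test case each fused kernel op computes both the forward and the
// backward pass of a layer, so every pattern below except gradient descent
// is multi-rooted.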
module @patterns {

  // gradient descent
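  // Matches the SGD update "param - c * gradient" and replaces it with
  // kernel.GD. The unbound %attr0 accepts any "value" attribute on the
  // constant.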
  pdl.pattern : benefit(3) {
    %const_type = pdl.type
    %param_type = pdl.type
    %param = pdl.operand : %param_type
    %gradient = pdl.operand : %param_type

    %attr0 = pdl.attribute
    %op0 = pdl.operation "tf.Const" {"value" = %attr0} -> (%const_type : !pdl.type)
    %val0 = pdl.result 0 of %op0
    %op1 = pdl.operation "tf.Mul" (%gradient, %val0 : !pdl.value, !pdl.value) -> (%param_type : !pdl.type)
    %val1 = pdl.result 0 of %op1
    %op2 = pdl.operation "tf.Sub" (%param, %val1 : !pdl.value, !pdl.value) -> (%param_type : !pdl.type)

    pdl.rewrite %op2 {
      %op3 = pdl.operation "kernel.GD" (%param, %gradient : !pdl.value, !pdl.value) -> (%param_type : !pdl.type)
      %val3 = pdl.result 0 of %op3
      pdl.replace %op2 with (%val3 : !pdl.value)  // tf.Sub
      pdl.erase %op1  // tf.Mul
    }
  }

  // first FC
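  // Three roots: the forward activation %op2 (tf.Relu) and the two
  // kernel.GD parameter updates %op5 and %op7, which are produced by the
  // gradient-descent pattern above. All of them fuse into one
  // kernel.FcWithBias op.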
  pdl.pattern : benefit(8) {
    %in_type = pdl.type
    %out_type = pdl.type
    %weight_type = pdl.type
    %bias_type = pdl.type
    %rxact = pdl.operand : %in_type
    %rxdelta = pdl.operand : %out_type
    %weight = pdl.operand : %weight_type
    %bias = pdl.operand : %bias_type

    %attr0 = pdl.attribute = false
    %op0 = pdl.operation "tf.MatMul" (%rxact, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr0} -> (%out_type : !pdl.type)
    %val0 = pdl.result 0 of %op0
    %op1 = pdl.operation "tf.BiasAdd" (%val0, %bias : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
    %val1 = pdl.result 0 of %op1
    %op2 = pdl.operation "tf.Relu" (%val1 : !pdl.value) -> (%out_type : !pdl.type)
    %val2 = pdl.result 0 of %op2
    %op3 = pdl.operation "tf.ReluGrad" (%rxdelta, %val2 : !pdl.value, !pdl.value) -> (%out_type : !pdl.type)
    %val3 = pdl.result 0 of %op3
    %attr1 = pdl.attribute = true
    %op4 = pdl.operation "tf.MatMul" (%rxact, %val3 : !pdl.value, !pdl.value) {"transpose_a" = %attr1, "transpose_b" = %attr0} -> (%weight_type : !pdl.type)
    %val4 = pdl.result 0 of %op4
    %op5 = pdl.operation "kernel.GD" (%weight, %val4 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)
    %op6 = pdl.operation "tf.BiasAddGrad" (%val3 : !pdl.value) -> (%bias_type : !pdl.type)
    %val6 = pdl.result 0 of %op6
    %op7 = pdl.operation "kernel.GD" (%bias, %val6 : !pdl.value, !pdl.value) -> (%bias_type : !pdl.type)

    pdl.rewrite {  // roots: %op2, %op5, %op7
      %op8 = pdl.operation "kernel.FcWithBias" (%rxact, %rxdelta, %weight, %bias : !pdl.value, !pdl.value, !pdl.value, !pdl.value) -> (%out_type, %weight_type, %bias_type : !pdl.type, !pdl.type, !pdl.type)
      %val8_0 = pdl.result 0 of %op8  // txact
      %val8_1 = pdl.result 1 of %op8  // weight_out
      %val8_2 = pdl.result 2 of %op8  // bias_out
      pdl.replace %op7 with (%val8_2 : !pdl.value)  // kernel.GD
      pdl.erase %op6  // tf.BiasAddGrad
      pdl.replace %op5 with (%val8_1 : !pdl.value)  // kernel.GD
      pdl.erase %op4  // tf.MatMul
      pdl.erase %op3  // tf.ReluGrad
      pdl.replace %op2 with (%val8_0 : !pdl.value)  // tf.Relu
      pdl.erase %op1  // tf.BiasAdd
      pdl.erase %op0  // tf.MatMul
    }
  }

  // second FC
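  // Three roots: the forward %op0 (tf.MatMul), the delta %op1 propagated to
  // the previous layer, and the weight update %op3 (kernel.GD), fused into
  // a single kernel.Fc op.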
  pdl.pattern : benefit(4) {
    %in_type = pdl.type
    %out_type = pdl.type
    %weight_type = pdl.type
    %rxact = pdl.operand : %in_type
    %rxdelta = pdl.operand : %out_type
    %weight = pdl.operand : %weight_type

    %attr0 = pdl.attribute = false
    %op0 = pdl.operation "tf.MatMul" (%rxact, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr0} -> (%out_type : !pdl.type)
    %attr1 = pdl.attribute = true
    %op1 = pdl.operation "tf.MatMul" (%rxdelta, %weight : !pdl.value, !pdl.value) {"transpose_a" = %attr0, "transpose_b" = %attr1} -> (%in_type : !pdl.type)
    %op2 = pdl.operation "tf.MatMul" (%rxact, %rxdelta : !pdl.value, !pdl.value) {"transpose_a" = %attr1, "transpose_b" = %attr0} -> (%weight_type : !pdl.type)
    %val2 = pdl.result 0 of %op2
    %op3 = pdl.operation "kernel.GD" (%weight, %val2 : !pdl.value, !pdl.value) -> (%weight_type : !pdl.type)

    pdl.rewrite {  // roots: %op0, %op1, %op3
      %op4 = pdl.operation "kernel.Fc" (%rxact, %rxdelta, %weight : !pdl.value, !pdl.value, !pdl.value) -> (%out_type, %in_type, %weight_type : !pdl.type, !pdl.type, !pdl.type)
      %val4_0 = pdl.result 0 of %op4  // txact
      %val4_1 = pdl.result 1 of %op4  // txdelta
      %val4_2 = pdl.result 2 of %op4  // weight_out
      pdl.replace %op3 with (%val4_2 : !pdl.value)  // kernel.GD
      pdl.erase %op2  // tf.MatMul
      pdl.replace %op1 with (%val4_1 : !pdl.value)  // tf.MatMul
      pdl.replace %op0 with (%val4_0 : !pdl.value)  // tf.MatMul
    }
  }

  // softmax_cross_entropy
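  // Identical to the softmax_cross_entropy pattern in the first test case.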
  pdl.pattern : benefit(6) {
    %in_type = pdl.type
    %label_type = pdl.type
    %loss_type = pdl.type
    %mean_loss_type = pdl.type
    %mean_const_type = pdl.type
    %mul_const_type = pdl.type
    %rxact = pdl.operand : %in_type
    %rxlabel = pdl.operand : %label_type

    %op0 = pdl.operation "tf.SparseSoftmaxCrossEntropyWithLogits" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%loss_type, %in_type : !pdl.type, !pdl.type)
    %val0_0 = pdl.result 0 of %op0  // loss
    %val0_1 = pdl.result 1 of %op0  // gradient
    %op1 = pdl.operation "tf.Const" -> (%mean_const_type : !pdl.type)
    %val1 = pdl.result 0 of %op1
    %op2 = pdl.operation "tf.Mean" (%val0_0, %val1 : !pdl.value, !pdl.value) -> (%mean_loss_type : !pdl.type)
    %val2 = pdl.result 0 of %op2
    %op3 = pdl.operation "tf.PreventGradient" (%val0_1 : !pdl.value) -> (%in_type : !pdl.type)
    %val3 = pdl.result 0 of %op3
    %op4 = pdl.operation "tf.Const" -> (%mul_const_type : !pdl.type)
    %val4 = pdl.result 0 of %op4
    %op5 = pdl.operation "tf.Mul" (%val3, %val4 : !pdl.value, !pdl.value) -> (%in_type : !pdl.type)

    pdl.rewrite {  // roots: %op2, %op5
      %op6 = pdl.operation "kernel.SoftmaxCrossEntropy" (%rxact, %rxlabel : !pdl.value, !pdl.value) -> (%mean_loss_type, %in_type : !pdl.type, !pdl.type)
      %val6_0 = pdl.result 0 of %op6  // txloss
      %val6_1 = pdl.result 1 of %op6  // txdelta
      pdl.replace %op5 with (%val6_1 : !pdl.value)  // tf.Mul
      pdl.erase %op4  // tf.Const
      pdl.erase %op3  // tf.PreventGradient
      pdl.replace %op2 with (%val6_0 : !pdl.value)  // tf.Mean
      pdl.erase %op1  // tf.Const
      pdl.erase %op0  // tf.SparseSoftmaxCrossEntropyWithLogits
    }
  }
}

// CHECK-LABEL: test.mlp_fused
// CHECK: %[[FC2:.*]]:3 = "kernel.Fc"(%[[FC1:.*]]#0, %[[SM:.*]]#1, %arg4) : (tensor<2x256xf32>, tensor<2x10xf32>, tensor<256x10xf32>) -> (tensor<2x10xf32>, tensor<2x256xf32>, tensor<256x10xf32>)
// CHECK: %[[SM]]:2 = "kernel.SoftmaxCrossEntropy"(%[[FC2]]#0, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<f32>, tensor<2x10xf32>)
// CHECK: %[[FC1]]:3 = "kernel.FcWithBias"(%arg0, %[[FC2]]#1, %arg3, %arg2) : (tensor<2x20xf32>, tensor<2x256xf32>, tensor<20x256xf32>, tensor<256xf32>) -> (tensor<2x256xf32>, tensor<20x256xf32>, tensor<256xf32>)
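// Input IR: one training step of a 2-layer perceptron (a 256-unit hidden
// layer with bias and ReLU, followed by a 10-unit output layer) with SGD
// updates for all three parameters.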
module @ir attributes { test.mlp_fused } {
  // The function originally returned (tensor<f32>, tensor<256xf32>,
  // tensor<20x256xf32>, tensor<256x10xf32>). The replacement operations fuse
  // the forward and backward passes, so the resulting graph is not a DAG; to
  // keep the IR legal, we wrap the operations in a graph region and return
  // nothing.
  func.func @main(%arg0: tensor<2x20xf32>, %arg1: tensor<2xi32>, %arg2: tensor<256xf32>, %arg3: tensor<20x256xf32>, %arg4: tensor<256x10xf32>) -> () {
268    "test.graph_region"() ({
269      %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
270      %1 = "tf.Const"() {value = dense<1.000000e-01> : tensor<f32>} : () -> tensor<f32>
271      %2 = "tf.Const"() {value = dense<5.000000e-01> : tensor<2x1xf32>} : () -> tensor<2x1xf32>
272      %3 = "tf.MatMul"(%arg0, %arg3) {transpose_a = false, transpose_b = false} : (tensor<2x20xf32>, tensor<20x256xf32>) -> tensor<2x256xf32>
273      %4 = "tf.BiasAdd"(%3, %arg2) {data_format = "NHWC"} : (tensor<2x256xf32>, tensor<256xf32>) -> tensor<2x256xf32>
274      %5 = "tf.Relu"(%4) : (tensor<2x256xf32>) -> tensor<2x256xf32>
275      %6 = "tf.MatMul"(%5, %arg4) {transpose_a = false, transpose_b = false} : (tensor<2x256xf32>, tensor<256x10xf32>) -> tensor<2x10xf32>
276      %loss, %backprop = "tf.SparseSoftmaxCrossEntropyWithLogits"(%6, %arg1) : (tensor<2x10xf32>, tensor<2xi32>) -> (tensor<2xf32>, tensor<2x10xf32>)
277      %7 = "tf.Mean"(%loss, %0) {keep_dims = false} : (tensor<2xf32>, tensor<1xi32>) -> tensor<f32>
278      %8 = "tf.PreventGradient"(%backprop) : (tensor<2x10xf32>) -> tensor<2x10xf32>
279      %9 = "tf.Mul"(%8, %2) : (tensor<2x10xf32>, tensor<2x1xf32>) -> tensor<2x10xf32>
280      %10 = "tf.MatMul"(%9, %arg4) {transpose_a = false, transpose_b = true} : (tensor<2x10xf32>, tensor<256x10xf32>) -> tensor<2x256xf32>
281      %11 = "tf.MatMul"(%5, %9) {transpose_a = true, transpose_b = false} : (tensor<2x256xf32>, tensor<2x10xf32>) -> tensor<256x10xf32>
282      %12 = "tf.ReluGrad"(%10, %5) : (tensor<2x256xf32>, tensor<2x256xf32>) -> tensor<2x256xf32>
283      %13 = "tf.BiasAddGrad"(%12) {data_format = "NHWC"} : (tensor<2x256xf32>) -> tensor<256xf32>
284      %14 = "tf.MatMul"(%arg0, %12) {transpose_a = true, transpose_b = false} : (tensor<2x20xf32>, tensor<2x256xf32>) -> tensor<20x256xf32>
285      %15 = "tf.Mul"(%14, %1) : (tensor<20x256xf32>, tensor<f32>) -> tensor<20x256xf32>
286      %16 = "tf.Sub"(%arg3, %15) : (tensor<20x256xf32>, tensor<20x256xf32>) -> tensor<20x256xf32>
287      %17 = "tf.Mul"(%13, %1) : (tensor<256xf32>, tensor<f32>) -> tensor<256xf32>
288      %18 = "tf.Sub"(%arg2, %17) : (tensor<256xf32>, tensor<256xf32>) -> tensor<256xf32>
289      %19 = "tf.Mul"(%11, %1) : (tensor<256x10xf32>, tensor<f32>) -> tensor<256x10xf32>
290      %20 = "tf.Sub"(%arg4, %19) : (tensor<256x10xf32>, tensor<256x10xf32>) -> tensor<256x10xf32>
291    }) : () -> ()
292    return
293  }
294}
295