// xref: /llvm-project/mlir/test/Dialect/Arith/expand-ops.mlir (revision 30badf96bbaa5ddfd8049442e573fd270a89ddc8)
// RUN: mlir-opt %s -arith-expand="include-bf16=true" -split-input-file | FileCheck %s

// Test ceil divide with signed integer.
// arith.ceildivsi is expanded into two divsi-based candidates: the first is
// selected when both operands are strictly negative or both strictly positive,
// the second otherwise.
// CHECK-LABEL:       func @ceildivi
// CHECK-SAME:     ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {
func.func @ceildivi(%arg0: i32, %arg1: i32) -> (i32) {
  %res = arith.ceildivsi %arg0, %arg1 : i32
  return %res : i32

// CHECK:           [[ONE:%.+]] = arith.constant 1 : i32
// CHECK:           [[ZERO:%.+]] = arith.constant 0 : i32
// CHECK:           [[MINONE:%.+]] = arith.constant -1 : i32
// CHECK:           [[CMP1:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : i32
// CHECK:           [[X:%.+]] = arith.select [[CMP1]], [[MINONE]], [[ONE]] : i32
// CHECK:           [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : i32
// CHECK:           [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : i32
// CHECK:           [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : i32
// CHECK:           [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : i32
// CHECK:           [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : i32
// CHECK:           [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : i32
// CHECK:           [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : i32
// CHECK:           [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : i32
// CHECK:           [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32
// CHECK:           [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : i32
// CHECK:           [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
// CHECK:           [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
// CHECK:           [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
// CHECK:           [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : i32
// CHECK:           return [[RES]] : i32
}

// -----

// Test ceil divide with index type.
// Same expansion as the i32 case: the first divsi-based candidate is selected
// when both operands are strictly negative or both strictly positive, the
// second otherwise.
// CHECK-LABEL:       func @ceildivi_index
// CHECK-SAME:     ([[ARG0:%.+]]: index, [[ARG1:%.+]]: index) -> index {
func.func @ceildivi_index(%arg0: index, %arg1: index) -> (index) {
  %res = arith.ceildivsi %arg0, %arg1 : index
  return %res : index

// CHECK:           [[ONE:%.+]] = arith.constant 1 : index
// CHECK:           [[ZERO:%.+]] = arith.constant 0 : index
// CHECK:           [[MINONE:%.+]] = arith.constant -1 : index
// CHECK:           [[CMP1:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
// CHECK:           [[X:%.+]] = arith.select [[CMP1]], [[MINONE]], [[ONE]] : index
// CHECK:           [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : index
// CHECK:           [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index
// CHECK:           [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : index
// CHECK:           [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : index
// CHECK:           [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : index
// CHECK:           [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : index
// CHECK:           [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index
// CHECK:           [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index
// CHECK:           [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
// CHECK:           [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
// CHECK:           [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
// CHECK:           [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
// CHECK:           [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
// CHECK:           [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : index
// CHECK:           return [[RES]] : index
}

// -----

// Test floor divide with signed integer.
// arith.floordivsi is expanded into a truncating divsi plus a -1 correction
// that is applied only when the division is inexact (quotient * divisor !=
// dividend) and the dividend and divisor have opposite signs.
// CHECK-LABEL:       func @floordivi
// CHECK-SAME:     ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {
func.func @floordivi(%arg0: i32, %arg1: i32) -> (i32) {
  %res = arith.floordivsi %arg0, %arg1 : i32
  return %res : i32
// CHECK:   %[[QUOTIENT:.*]] = arith.divsi [[ARG0]], [[ARG1]] : i32
// CHECK:   %[[PRODUCT:.*]] = arith.muli %[[QUOTIENT]], [[ARG1]] : i32
// CHECK:   %[[NOT_EQ_PRODUCT:.*]] = arith.cmpi ne, [[ARG0]], %[[PRODUCT]] : i32
// CHECK-DAG:   %[[ZERO:.*]] = arith.constant 0 : i32
// CHECK:   %[[NEG_DIVIDEND:.*]] = arith.cmpi slt, [[ARG0]], %[[ZERO]] : i32
// CHECK:   %[[NEG_DIVISOR:.*]] = arith.cmpi slt, [[ARG1]], %[[ZERO]] : i32
// CHECK:   %[[OPPOSITE_SIGN:.*]] = arith.cmpi ne, %[[NEG_DIVIDEND]], %[[NEG_DIVISOR]] : i1
// CHECK:   %[[CONDITION:.*]] = arith.andi %[[NOT_EQ_PRODUCT]], %[[OPPOSITE_SIGN]] : i1
// CHECK-DAG:   %[[NEG_ONE:.*]] = arith.constant -1 : i32
// CHECK:   %[[MINUS_ONE:.*]] = arith.addi %[[QUOTIENT]], %[[NEG_ONE]] : i32
// CHECK:   %[[RES:.*]] = arith.select %[[CONDITION]], %[[MINUS_ONE]], %[[QUOTIENT]] : i32
// CHECK:   return %[[RES]] : i32
}

// -----

// Test floor divide with index type.
// Same expansion as the i32 case: a truncating divsi, corrected by -1 only
// when the division is inexact and the dividend and divisor have opposite
// signs.
// CHECK-LABEL:       func @floordivi_index
// CHECK-SAME:     ([[ARG0:%.+]]: index, [[ARG1:%.+]]: index) -> index {
func.func @floordivi_index(%arg0: index, %arg1: index) -> (index) {
  %res = arith.floordivsi %arg0, %arg1 : index
  return %res : index
// CHECK:   %[[QUOTIENT:.*]] = arith.divsi [[ARG0]], [[ARG1]] : index
// CHECK:   %[[PRODUCT:.*]] = arith.muli %[[QUOTIENT]], [[ARG1]] : index
// CHECK:   %[[NOT_EQ_PRODUCT:.*]] = arith.cmpi ne, [[ARG0]], %[[PRODUCT]] : index
// CHECK-DAG:   %[[ZERO:.*]] = arith.constant 0 : index
// CHECK:   %[[NEG_DIVIDEND:.*]] = arith.cmpi slt, [[ARG0]], %[[ZERO]] : index
// CHECK:   %[[NEG_DIVISOR:.*]] = arith.cmpi slt, [[ARG1]], %[[ZERO]] : index
// CHECK:   %[[OPPOSITE_SIGN:.*]] = arith.cmpi ne, %[[NEG_DIVIDEND]], %[[NEG_DIVISOR]] : i1
// CHECK:   %[[CONDITION:.*]] = arith.andi %[[NOT_EQ_PRODUCT]], %[[OPPOSITE_SIGN]] : i1
// CHECK-DAG:   %[[NEG_ONE:.*]] = arith.constant -1 : index
// CHECK:   %[[MINUS_ONE:.*]] = arith.addi %[[QUOTIENT]], %[[NEG_ONE]] : index
// CHECK:   %[[RES:.*]] = arith.select %[[CONDITION]], %[[MINUS_ONE]], %[[QUOTIENT]] : index
// CHECK:   return %[[RES]] : index
}

// -----

// Test floor divide with vector.
// The element-wise expansion matches the scalar case; constants become
// dense splats and the predicates are vector<4xi1>.
// CHECK-LABEL:   func.func @floordivi_vec(
// CHECK-SAME:                             %[[VAL_0:.*]]: vector<4xi32>,
// CHECK-SAME:                             %[[VAL_1:.*]]: vector<4xi32>) -> vector<4xi32> {
func.func @floordivi_vec(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> (vector<4xi32>) {
  %res = arith.floordivsi %arg0, %arg1 : vector<4xi32>
  return %res : vector<4xi32>
// CHECK:   %[[QUOTIENT:.*]] = arith.divsi %[[VAL_0]], %[[VAL_1]] : vector<4xi32>
// CHECK:   %[[PRODUCT:.*]] = arith.muli %[[QUOTIENT]], %[[VAL_1]] : vector<4xi32>
// CHECK:   %[[NOT_EQ_PRODUCT:.*]] = arith.cmpi ne, %[[VAL_0]], %[[PRODUCT]] : vector<4xi32>
// CHECK-DAG:   %[[ZERO:.*]] = arith.constant dense<0> : vector<4xi32>
// CHECK:   %[[NEG_DIVIDEND:.*]] = arith.cmpi slt, %[[VAL_0]], %[[ZERO]] : vector<4xi32>
// CHECK:   %[[NEG_DIVISOR:.*]] = arith.cmpi slt, %[[VAL_1]], %[[ZERO]] : vector<4xi32>
// CHECK:   %[[OPPOSITE_SIGN:.*]] = arith.cmpi ne, %[[NEG_DIVIDEND]], %[[NEG_DIVISOR]] : vector<4xi1>
// CHECK:   %[[CONDITION:.*]] = arith.andi %[[NOT_EQ_PRODUCT]], %[[OPPOSITE_SIGN]] : vector<4xi1>
// CHECK-DAG:   %[[NEG_ONE:.*]] = arith.constant dense<-1> : vector<4xi32>
// CHECK:   %[[MINUS_ONE:.*]] = arith.addi %[[QUOTIENT]], %[[NEG_ONE]] : vector<4xi32>
// CHECK:   %[[RES:.*]] = arith.select %[[CONDITION]], %[[MINUS_ONE]], %[[QUOTIENT]] : vector<4xi1>, vector<4xi32>
// CHECK:   return %[[RES]] : vector<4xi32>
}

// -----

// Test ceil divide with unsigned integer.
// arith.ceildivui is expanded to (dividend - 1) /u divisor + 1, with a select
// that returns 0 when the dividend is 0.
// CHECK-LABEL:       func @ceildivui
// CHECK-SAME:     ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {
func.func @ceildivui(%arg0: i32, %arg1: i32) -> (i32) {
  %res = arith.ceildivui %arg0, %arg1 : i32
  return %res : i32
// CHECK:           [[ZERO:%.+]] = arith.constant 0 : i32
// CHECK:           [[ISZERO:%.+]] = arith.cmpi eq, [[ARG0]], [[ZERO]] : i32
// CHECK:           [[ONE:%.+]] = arith.constant 1 : i32
// CHECK:           [[SUB:%.+]] = arith.subi [[ARG0]], [[ONE]] : i32
// CHECK:           [[DIV:%.+]] = arith.divui [[SUB]], [[ARG1]] : i32
// CHECK:           [[PLUS_ONE:%.+]] = arith.addi [[DIV]], [[ONE]] : i32
// CHECK:           [[RES:%.+]] = arith.select [[ISZERO]], [[ZERO]], [[PLUS_ONE]] : i32
// CHECK:           return [[RES]] : i32
}

// -----

// Test unsigned ceil divide with index.
// Same expansion as the i32 case: (dividend - 1) /u divisor + 1, or 0 when
// the dividend is 0.
// CHECK-LABEL:       func @ceildivui_index
// CHECK-SAME:     ([[ARG0:%.+]]: index, [[ARG1:%.+]]: index) -> index {
func.func @ceildivui_index(%arg0: index, %arg1: index) -> (index) {
  %res = arith.ceildivui %arg0, %arg1 : index
  return %res : index
// CHECK:           [[ZERO:%.+]] = arith.constant 0 : index
// CHECK:           [[ISZERO:%.+]] = arith.cmpi eq, [[ARG0]], [[ZERO]] : index
// CHECK:           [[ONE:%.+]] = arith.constant 1 : index
// CHECK:           [[SUB:%.+]] = arith.subi [[ARG0]], [[ONE]] : index
// CHECK:           [[DIV:%.+]] = arith.divui [[SUB]], [[ARG1]] : index
// CHECK:           [[PLUS_ONE:%.+]] = arith.addi [[DIV]], [[ONE]] : index
// CHECK:           [[RES:%.+]] = arith.select [[ISZERO]], [[ZERO]], [[PLUS_ONE]] : index
// CHECK:           return [[RES]] : index
}

// -----

// arith.maximumf is expanded into a `ugt` compare/select, followed by a select
// that forwards RHS when RHS is NaN. The unordered compare already picks LHS
// when LHS is NaN, so a NaN in either operand yields NaN.
// CHECK-LABEL: func @maximumf
func.func @maximumf(%a: f32, %b: f32) -> f32 {
  %result = arith.maximumf %a, %b : f32
  return %result : f32
}
// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS]], %[[RHS]] : f32
// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : f32
// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : f32
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32
// CHECK-NEXT: return %[[RESULT]] : f32

// -----

// Vector form of the maximumf expansion. The compares and selects operate
// element-wise on vector<4xf16>.
// CHECK-LABEL: func @maximumf_vector
func.func @maximumf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> {
  %result = arith.maximumf %a, %b : vector<4xf16>
  return %result : vector<4xf16>
}
// CHECK-SAME: %[[LHS:.*]]: vector<4xf16>, %[[RHS:.*]]: vector<4xf16>)
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS]], %[[RHS]] : vector<4xf16>
// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]]
// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : vector<4xf16>
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]]
// CHECK-NEXT: return %[[RESULT]] : vector<4xf16>

// -----

// arith.maxnumf is expanded into a `ugt` compare/select, followed by a select
// that replaces the result with RHS when LHS is NaN. When only RHS is NaN, the
// unordered compare picks LHS, so a single NaN operand is ignored.
// CHECK-LABEL: func @maxnumf
func.func @maxnumf(%a: f32, %b: f32) -> f32 {
  %result = arith.maxnumf %a, %b : f32
  return %result : f32
}

// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ugt, %[[LHS]], %[[RHS]] : f32
// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : f32
// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[LHS]] : f32
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32
// CHECK-NEXT: return %[[RESULT]] : f32

// -----

// arith.minimumf is expanded into a `ult` compare/select, followed by a select
// that forwards RHS when RHS is NaN. The unordered compare already picks LHS
// when LHS is NaN, so a NaN in either operand yields NaN.
// CHECK-LABEL: func @minimumf
func.func @minimumf(%a: f32, %b: f32) -> f32 {
  %result = arith.minimumf %a, %b : f32
  return %result : f32
}

// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ult, %[[LHS]], %[[RHS]] : f32
// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : f32
// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[RHS]], %[[RHS]] : f32
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32
// CHECK-NEXT: return %[[RESULT]] : f32

// -----

// arith.minnumf is expanded into a `ult` compare/select, followed by a select
// that replaces the result with RHS when LHS is NaN. When only RHS is NaN, the
// unordered compare picks LHS, so a single NaN operand is ignored.
// CHECK-LABEL: func @minnumf
func.func @minnumf(%a: f32, %b: f32) -> f32 {
  %result = arith.minnumf %a, %b : f32
  return %result : f32
}

// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpf ult, %[[LHS]], %[[RHS]] : f32
// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : f32
// CHECK-NEXT: %[[IS_NAN:.*]] = arith.cmpf uno, %[[LHS]], %[[LHS]] : f32
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[IS_NAN]], %[[RHS]], %[[SELECT]] : f32
// CHECK-NEXT: return %[[RESULT]] : f32

// -----

// arith.truncf f32 -> bf16 is expanded into integer arithmetic on the f32 bit
// pattern. It adds a rounding bias of 0x7FFF plus bit 16 (round to nearest,
// ties to even), shifts right by 16 and truncates to i16. NaN inputs are
// mapped to the bf16 NaN bit pattern 0x7FC0 (32704).
func.func @truncf_f32(%arg0 : f32) -> bf16 {
    %0 = arith.truncf %arg0 : f32 to bf16
    return %0 : bf16
}

// CHECK-LABEL: @truncf_f32
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32
// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : i32
// CHECK-DAG: %[[C7FC0_i16:.+]] = arith.constant 32704 : i16
// CHECK-DAG: %[[C7FFF:.+]] = arith.constant 32767 : i32
// CHECK-DAG: %[[ISNAN:.+]] = arith.cmpf une, %arg0, %arg0 : f32
// CHECK-DAG: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32
// CHECK-DAG: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C16]] : i32
// CHECK-DAG: %[[BIT16:.+]] = arith.andi %[[SHRUI]], %[[C1]] : i32
// CHECK-DAG: %[[ROUNDING_BIAS:.+]] = arith.addi %[[BIT16]], %[[C7FFF]] : i32
// CHECK-DAG: %[[BIASED:.+]] = arith.addi %[[BITCAST]], %[[ROUNDING_BIAS]] : i32
// CHECK-DAG: %[[BIASED_SHIFTED:.+]] = arith.shrui %[[BIASED]], %[[C16]] : i32
// CHECK-DAG: %[[NORMAL_CASE_RESULT_i16:.+]] = arith.trunci %[[BIASED_SHIFTED]] : i32 to i16
// CHECK-DAG: %[[SELECT:.+]] = arith.select %[[ISNAN]], %[[C7FC0_i16]], %[[NORMAL_CASE_RESULT_i16]] : i16
// CHECK-DAG: %[[RESULT:.+]] = arith.bitcast %[[SELECT]] : i16 to bf16
// CHECK: return %[[RESULT]]

// -----

// The vector form of truncf f32 -> bf16 must be expanded as well, so no
// truncf operation may remain in this function's output.
func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
    %0 = arith.truncf %arg0 : vector<4xf32> to vector<4xbf16>
    return %0 : vector<4xbf16>
}

// CHECK-LABEL: @truncf_vector_f32
// CHECK-NOT: arith.truncf

// -----

// arith.maxsi is expanded into a signed `sgt` compare followed by a select.
func.func @maxsi(%a: i32, %b: i32) -> i32 {
  %result = arith.maxsi %a, %b : i32
  return %result : i32
}
// CHECK-LABEL: func @maxsi
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi sgt, %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: return %[[RESULT]] : i32

// -----

// arith.minsi is expanded into a signed `slt` compare followed by a select.
func.func @minsi(%a: i32, %b: i32) -> i32 {
  %result = arith.minsi %a, %b : i32
  return %result : i32
}
// CHECK-LABEL: func @minsi
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi slt, %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: return %[[RESULT]] : i32

// -----

// arith.maxui is expanded into an unsigned `ugt` compare followed by a select.
func.func @maxui(%a: i32, %b: i32) -> i32 {
  %result = arith.maxui %a, %b : i32
  return %result : i32
}
// CHECK-LABEL: func @maxui
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ugt, %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: return %[[RESULT]] : i32

// -----

// arith.minui is expanded into an unsigned `ult` compare followed by a select.
func.func @minui(%a: i32, %b: i32) -> i32 {
  %result = arith.minui %a, %b : i32
  return %result : i32
}
// CHECK-LABEL: func @minui
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: return %[[RESULT]] : i32
