1 //===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the folders and canonicalization patterns for SPIR-V ops.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include <optional>
14 #include <utility>
15
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17
18 #include "mlir/Dialect/CommonFolders.h"
19 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
20 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
21 #include "mlir/Dialect/UB/IR/UBOps.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVectorExtras.h"
26
27 using namespace mlir;
28
29 //===----------------------------------------------------------------------===//
30 // Common utility functions
31 //===----------------------------------------------------------------------===//
32
33 /// Returns the boolean value under the hood if the given `boolAttr` is a scalar
34 /// or splat vector bool constant.
getScalarOrSplatBoolAttr(Attribute attr)35 static std::optional<bool> getScalarOrSplatBoolAttr(Attribute attr) {
36 if (!attr)
37 return std::nullopt;
38
39 if (auto boolAttr = llvm::dyn_cast<BoolAttr>(attr))
40 return boolAttr.getValue();
41 if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr))
42 if (splatAttr.getElementType().isInteger(1))
43 return splatAttr.getSplatValue<bool>();
44 return std::nullopt;
45 }
46
47 // Extracts an element from the given `composite` by following the given
48 // `indices`. Returns a null Attribute if error happens.
extractCompositeElement(Attribute composite,ArrayRef<unsigned> indices)49 static Attribute extractCompositeElement(Attribute composite,
50 ArrayRef<unsigned> indices) {
51 // Check that given composite is a constant.
52 if (!composite)
53 return {};
54 // Return composite itself if we reach the end of the index chain.
55 if (indices.empty())
56 return composite;
57
58 if (auto vector = llvm::dyn_cast<ElementsAttr>(composite)) {
59 assert(indices.size() == 1 && "must have exactly one index for a vector");
60 return vector.getValues<Attribute>()[indices[0]];
61 }
62
63 if (auto array = llvm::dyn_cast<ArrayAttr>(composite)) {
64 assert(!indices.empty() && "must have at least one index for an array");
65 return extractCompositeElement(array.getValue()[indices[0]],
66 indices.drop_front());
67 }
68
69 return {};
70 }
71
isDivZeroOrOverflow(const APInt & a,const APInt & b)72 static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
73 bool div0 = b.isZero();
74 bool overflow = a.isMinSignedValue() && b.isAllOnes();
75
76 return div0 || overflow;
77 }
78
79 //===----------------------------------------------------------------------===//
80 // TableGen'erated canonicalizers
81 //===----------------------------------------------------------------------===//
82
namespace {
// TableGen-generated canonicalization patterns. The ConvertLogicalNotOf*
// patterns registered for spirv.LogicalNot below are presumably defined here.
// The anonymous namespace keeps them local to this translation unit.
#include "SPIRVCanonicalization.inc"
} // namespace
86
87 //===----------------------------------------------------------------------===//
88 // spirv.AccessChainOp
89 //===----------------------------------------------------------------------===//
90
91 namespace {
92
93 /// Combines chained `spirv::AccessChainOp` operations into one
94 /// `spirv::AccessChainOp` operation.
95 struct CombineChainedAccessChain final
96 : OpRewritePattern<spirv::AccessChainOp> {
97 using OpRewritePattern::OpRewritePattern;
98
matchAndRewrite__anona5f61ae80211::CombineChainedAccessChain99 LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
100 PatternRewriter &rewriter) const override {
101 auto parentAccessChainOp =
102 accessChainOp.getBasePtr().getDefiningOp<spirv::AccessChainOp>();
103
104 if (!parentAccessChainOp) {
105 return failure();
106 }
107
108 // Combine indices.
109 SmallVector<Value, 4> indices(parentAccessChainOp.getIndices());
110 llvm::append_range(indices, accessChainOp.getIndices());
111
112 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
113 accessChainOp, parentAccessChainOp.getBasePtr(), indices);
114
115 return success();
116 }
117 };
118 } // namespace
119
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)120 void spirv::AccessChainOp::getCanonicalizationPatterns(
121 RewritePatternSet &results, MLIRContext *context) {
122 results.add<CombineChainedAccessChain>(context);
123 }
124
125 //===----------------------------------------------------------------------===//
126 // spirv.IAddCarry
127 //===----------------------------------------------------------------------===//
128
129 // We are required to use CompositeConstructOp to create a constant struct as
130 // they are not yet implemented as constant, hence we can not do so in a fold.
131 struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
132 using OpRewritePattern::OpRewritePattern;
133
matchAndRewriteIAddCarryFold134 LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
135 PatternRewriter &rewriter) const override {
136 Location loc = op.getLoc();
137 Value lhs = op.getOperand1();
138 Value rhs = op.getOperand2();
139 Type constituentType = lhs.getType();
140
141 // iaddcarry (x, 0) = <0, x>
142 if (matchPattern(rhs, m_Zero())) {
143 Value constituents[2] = {rhs, lhs};
144 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
145 constituents);
146 return success();
147 }
148
149 // According to the SPIR-V spec:
150 //
151 // Result Type must be from OpTypeStruct. The struct must have two
152 // members...
153 //
154 // Member 0 of the result gets the low-order bits (full component width) of
155 // the addition.
156 //
157 // Member 1 of the result gets the high-order (carry) bit of the result of
158 // the addition. That is, it gets the value 1 if the addition overflowed
159 // the component width, and 0 otherwise.
160 Attribute lhsAttr;
161 Attribute rhsAttr;
162 if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
163 !matchPattern(rhs, m_Constant(&rhsAttr)))
164 return failure();
165
166 auto adds = constFoldBinaryOp<IntegerAttr>(
167 {lhsAttr, rhsAttr},
168 [](const APInt &a, const APInt &b) { return a + b; });
169 if (!adds)
170 return failure();
171
172 auto carrys = constFoldBinaryOp<IntegerAttr>(
173 ArrayRef{adds, lhsAttr}, [](const APInt &a, const APInt &b) {
174 APInt zero = APInt::getZero(a.getBitWidth());
175 return a.ult(b) ? (zero + 1) : zero;
176 });
177
178 if (!carrys)
179 return failure();
180
181 Value addsVal =
182 rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
183
184 Value carrysVal =
185 rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
186
187 // Create empty struct
188 Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
189 // Fill in adds at id 0
190 Value intermediate =
191 rewriter.create<spirv::CompositeInsertOp>(loc, addsVal, undef, 0);
192 // Fill in carrys at id 1
193 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, carrysVal,
194 intermediate, 1);
195 return success();
196 }
197 };
198
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)199 void spirv::IAddCarryOp::getCanonicalizationPatterns(
200 RewritePatternSet &patterns, MLIRContext *context) {
201 patterns.add<IAddCarryFold>(context);
202 }
203
204 //===----------------------------------------------------------------------===//
205 // spirv.[S|U]MulExtended
206 //===----------------------------------------------------------------------===//
207
208 // We are required to use CompositeConstructOp to create a constant struct as
209 // they are not yet implemented as constant, hence we can not do so in a fold.
210 template <typename MulOp, bool IsSigned>
211 struct MulExtendedFold final : OpRewritePattern<MulOp> {
212 using OpRewritePattern<MulOp>::OpRewritePattern;
213
matchAndRewriteMulExtendedFold214 LogicalResult matchAndRewrite(MulOp op,
215 PatternRewriter &rewriter) const override {
216 Location loc = op.getLoc();
217 Value lhs = op.getOperand1();
218 Value rhs = op.getOperand2();
219 Type constituentType = lhs.getType();
220
221 // [su]mulextended (x, 0) = <0, 0>
222 if (matchPattern(rhs, m_Zero())) {
223 Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
224 Value constituents[2] = {zero, zero};
225 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
226 constituents);
227 return success();
228 }
229
230 // According to the SPIR-V spec:
231 //
232 // Result Type must be from OpTypeStruct. The struct must have two
233 // members...
234 //
235 // Member 0 of the result gets the low-order bits of the multiplication.
236 //
237 // Member 1 of the result gets the high-order bits of the multiplication.
238 Attribute lhsAttr;
239 Attribute rhsAttr;
240 if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
241 !matchPattern(rhs, m_Constant(&rhsAttr)))
242 return failure();
243
244 auto lowBits = constFoldBinaryOp<IntegerAttr>(
245 {lhsAttr, rhsAttr},
246 [](const APInt &a, const APInt &b) { return a * b; });
247
248 if (!lowBits)
249 return failure();
250
251 auto highBits = constFoldBinaryOp<IntegerAttr>(
252 {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
253 if (IsSigned) {
254 return llvm::APIntOps::mulhs(a, b);
255 } else {
256 return llvm::APIntOps::mulhu(a, b);
257 }
258 });
259
260 if (!highBits)
261 return failure();
262
263 Value lowBitsVal =
264 rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
265
266 Value highBitsVal =
267 rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
268
269 // Create empty struct
270 Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
271 // Fill in lowBits at id 0
272 Value intermediate =
273 rewriter.create<spirv::CompositeInsertOp>(loc, lowBitsVal, undef, 0);
274 // Fill in highBits at id 1
275 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highBitsVal,
276 intermediate, 1);
277 return success();
278 }
279 };
280
281 using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)282 void spirv::SMulExtendedOp::getCanonicalizationPatterns(
283 RewritePatternSet &patterns, MLIRContext *context) {
284 patterns.add<SMulExtendedOpFold>(context);
285 }
286
287 struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
288 using OpRewritePattern::OpRewritePattern;
289
matchAndRewriteUMulExtendedOpXOne290 LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
291 PatternRewriter &rewriter) const override {
292 Location loc = op.getLoc();
293 Value lhs = op.getOperand1();
294 Value rhs = op.getOperand2();
295 Type constituentType = lhs.getType();
296
297 // umulextended (x, 1) = <x, 0>
298 if (matchPattern(rhs, m_One())) {
299 Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
300 Value constituents[2] = {lhs, zero};
301 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
302 constituents);
303 return success();
304 }
305
306 return failure();
307 }
308 };
309
310 using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>;
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)311 void spirv::UMulExtendedOp::getCanonicalizationPatterns(
312 RewritePatternSet &patterns, MLIRContext *context) {
313 patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
314 }
315
316 //===----------------------------------------------------------------------===//
317 // spirv.UMod
318 //===----------------------------------------------------------------------===//
319
320 // Input:
321 // %0 = spirv.UMod %arg0, %const32 : i32
322 // %1 = spirv.UMod %0, %const4 : i32
323 // Output:
324 // %0 = spirv.UMod %arg0, %const32 : i32
325 // %1 = spirv.UMod %arg0, %const4 : i32
326
327 // The transformation is only applied if one divisor is a multiple of the other.
328
329 // TODO(https://github.com/llvm/llvm-project/issues/63174): Add support for vector constants
330 struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
331 using OpRewritePattern::OpRewritePattern;
332
matchAndRewriteUModSimplification333 LogicalResult matchAndRewrite(spirv::UModOp umodOp,
334 PatternRewriter &rewriter) const override {
335 auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
336 if (!prevUMod)
337 return failure();
338
339 IntegerAttr prevValue;
340 IntegerAttr currValue;
341 if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) ||
342 !matchPattern(umodOp.getOperand(1), m_Constant(&currValue)))
343 return failure();
344
345 APInt prevConstValue = prevValue.getValue();
346 APInt currConstValue = currValue.getValue();
347
348 // Ensure that one divisor is a multiple of the other. If not, fail the
349 // transformation.
350 if (prevConstValue.urem(currConstValue) != 0 &&
351 currConstValue.urem(prevConstValue) != 0)
352 return failure();
353
354 // The transformation is safe. Replace the existing UMod operation with a
355 // new UMod operation, using the original dividend and the current divisor.
356 rewriter.replaceOpWithNewOp<spirv::UModOp>(
357 umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));
358
359 return success();
360 }
361 };
362
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)363 void spirv::UModOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
364 MLIRContext *context) {
365 patterns.insert<UModSimplification>(context);
366 }
367
368 //===----------------------------------------------------------------------===//
369 // spirv.BitcastOp
370 //===----------------------------------------------------------------------===//
371
fold(FoldAdaptor)372 OpFoldResult spirv::BitcastOp::fold(FoldAdaptor /*adaptor*/) {
373 Value curInput = getOperand();
374 if (getType() == curInput.getType())
375 return curInput;
376
377 // Look through nested bitcasts.
378 if (auto prevCast = curInput.getDefiningOp<spirv::BitcastOp>()) {
379 Value prevInput = prevCast.getOperand();
380 if (prevInput.getType() == getType())
381 return prevInput;
382
383 getOperandMutable().assign(prevInput);
384 return getResult();
385 }
386
387 // TODO(kuhar): Consider constant-folding the operand attribute.
388 return {};
389 }
390
391 //===----------------------------------------------------------------------===//
392 // spirv.CompositeExtractOp
393 //===----------------------------------------------------------------------===//
394
fold(FoldAdaptor adaptor)395 OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
396 Value compositeOp = getComposite();
397
398 while (auto insertOp =
399 compositeOp.getDefiningOp<spirv::CompositeInsertOp>()) {
400 if (getIndices() == insertOp.getIndices())
401 return insertOp.getObject();
402 compositeOp = insertOp.getComposite();
403 }
404
405 if (auto constructOp =
406 compositeOp.getDefiningOp<spirv::CompositeConstructOp>()) {
407 auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
408 if (getIndices().size() == 1 &&
409 constructOp.getConstituents().size() == type.getNumElements()) {
410 auto i = llvm::cast<IntegerAttr>(*getIndices().begin());
411 if (i.getValue().getSExtValue() <
412 static_cast<int64_t>(constructOp.getConstituents().size()))
413 return constructOp.getConstituents()[i.getValue().getSExtValue()];
414 }
415 }
416
417 auto indexVector = llvm::map_to_vector(getIndices(), [](Attribute attr) {
418 return static_cast<unsigned>(llvm::cast<IntegerAttr>(attr).getInt());
419 });
420 return extractCompositeElement(adaptor.getComposite(), indexVector);
421 }
422
423 //===----------------------------------------------------------------------===//
424 // spirv.Constant
425 //===----------------------------------------------------------------------===//
426
fold(FoldAdaptor)427 OpFoldResult spirv::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
428 return getValue();
429 }
430
431 //===----------------------------------------------------------------------===//
432 // spirv.IAdd
433 //===----------------------------------------------------------------------===//
434
fold(FoldAdaptor adaptor)435 OpFoldResult spirv::IAddOp::fold(FoldAdaptor adaptor) {
436 // x + 0 = x
437 if (matchPattern(getOperand2(), m_Zero()))
438 return getOperand1();
439
440 // According to the SPIR-V spec:
441 //
442 // The resulting value will equal the low-order N bits of the correct result
443 // R, where N is the component width and R is computed with enough precision
444 // to avoid overflow and underflow.
445 return constFoldBinaryOp<IntegerAttr>(
446 adaptor.getOperands(),
447 [](APInt a, const APInt &b) { return std::move(a) + b; });
448 }
449
450 //===----------------------------------------------------------------------===//
451 // spirv.IMul
452 //===----------------------------------------------------------------------===//
453
fold(FoldAdaptor adaptor)454 OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) {
455 // x * 0 == 0
456 if (matchPattern(getOperand2(), m_Zero()))
457 return getOperand2();
458 // x * 1 = x
459 if (matchPattern(getOperand2(), m_One()))
460 return getOperand1();
461
462 // According to the SPIR-V spec:
463 //
464 // The resulting value will equal the low-order N bits of the correct result
465 // R, where N is the component width and R is computed with enough precision
466 // to avoid overflow and underflow.
467 return constFoldBinaryOp<IntegerAttr>(
468 adaptor.getOperands(),
469 [](const APInt &a, const APInt &b) { return a * b; });
470 }
471
472 //===----------------------------------------------------------------------===//
473 // spirv.ISub
474 //===----------------------------------------------------------------------===//
475
fold(FoldAdaptor adaptor)476 OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
477 // x - x = 0
478 if (getOperand1() == getOperand2())
479 return Builder(getContext()).getIntegerAttr(getType(), 0);
480
481 // According to the SPIR-V spec:
482 //
483 // The resulting value will equal the low-order N bits of the correct result
484 // R, where N is the component width and R is computed with enough precision
485 // to avoid overflow and underflow.
486 return constFoldBinaryOp<IntegerAttr>(
487 adaptor.getOperands(),
488 [](APInt a, const APInt &b) { return std::move(a) - b; });
489 }
490
491 //===----------------------------------------------------------------------===//
492 // spirv.SDiv
493 //===----------------------------------------------------------------------===//
494
fold(FoldAdaptor adaptor)495 OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
496 // sdiv (x, 1) = x
497 if (matchPattern(getOperand2(), m_One()))
498 return getOperand1();
499
500 // According to the SPIR-V spec:
501 //
502 // Signed-integer division of Operand 1 divided by Operand 2.
503 // Results are computed per component. Behavior is undefined if Operand 2 is
504 // 0. Behavior is undefined if Operand 2 is -1 and Operand 1 is the minimum
505 // representable value for the operands' type, causing signed overflow.
506 //
507 // So don't fold during undefined behavior.
508 bool div0OrOverflow = false;
509 auto res = constFoldBinaryOp<IntegerAttr>(
510 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
511 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
512 div0OrOverflow = true;
513 return a;
514 }
515 return a.sdiv(b);
516 });
517 return div0OrOverflow ? Attribute() : res;
518 }
519
520 //===----------------------------------------------------------------------===//
521 // spirv.SMod
522 //===----------------------------------------------------------------------===//
523
/// Folds `smod(x, 1)` to 0 and evaluates constant operands elementwise.
/// If any element pair would trigger undefined behavior, nothing is folded.
OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
  // smod (x, 1) = 0
  if (matchPattern(getOperand2(), m_One()))
    return Builder(getContext()).getZeroAttr(getType());

  // According to SPIR-V spec:
  //
  // Signed remainder operation for the remainder whose sign matches the sign
  // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
  // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
  // value for the operands' type, causing signed overflow. Otherwise, the
  // result is the remainder r of Operand 1 divided by Operand 2 where if
  // r ≠ 0, the sign of r is the same as the sign of Operand 2.
  //
  // So don't fold during undefined behavior
  // The flag is sticky: once any element pair is undefined, the computed
  // values are discarded and the fold is abandoned.
  bool div0OrOverflow = false;
  auto res = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
        if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
          div0OrOverflow = true;
          // Placeholder value; the result is thrown away below.
          return a;
        }
        // Remainder magnitude computed on absolute values. urem treats the bit
        // patterns as unsigned.
        APInt c = a.abs().urem(b.abs());
        if (c.isZero())
          return c;
        if (b.isNegative()) {
          APInt zero = APInt::getZero(c.getBitWidth());
          // The result must take b's (negative) sign: -c when a is also
          // negative, otherwise b + c.
          return a.isNegative() ? (zero - c) : (b + c);
        }
        // The result must take b's (positive) sign: b - c when a is negative,
        // otherwise c itself.
        return a.isNegative() ? (b - c) : c;
      });
  return div0OrOverflow ? Attribute() : res;
}
557
558 //===----------------------------------------------------------------------===//
559 // spirv.SRem
560 //===----------------------------------------------------------------------===//
561
fold(FoldAdaptor adaptor)562 OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
563 // x % 1 = 0
564 if (matchPattern(getOperand2(), m_One()))
565 return Builder(getContext()).getZeroAttr(getType());
566
567 // According to SPIR-V spec:
568 //
569 // Signed remainder operation for the remainder whose sign matches the sign
570 // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
571 // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
572 // value for the operands' type, causing signed overflow. Otherwise, the
573 // result is the remainder r of Operand 1 divided by Operand 2 where if
574 // r ≠ 0, the sign of r is the same as the sign of Operand 1.
575
576 // Don't fold if it would do undefined behavior.
577 bool div0OrOverflow = false;
578 auto res = constFoldBinaryOp<IntegerAttr>(
579 adaptor.getOperands(), [&](APInt a, const APInt &b) {
580 if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
581 div0OrOverflow = true;
582 return a;
583 }
584 return a.srem(b);
585 });
586 return div0OrOverflow ? Attribute() : res;
587 }
588
589 //===----------------------------------------------------------------------===//
590 // spirv.UDiv
591 //===----------------------------------------------------------------------===//
592
fold(FoldAdaptor adaptor)593 OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
594 // udiv (x, 1) = x
595 if (matchPattern(getOperand2(), m_One()))
596 return getOperand1();
597
598 // According to the SPIR-V spec:
599 //
600 // Unsigned-integer division of Operand 1 divided by Operand 2. Behavior is
601 // undefined if Operand 2 is 0.
602 //
603 // So don't fold during undefined behavior.
604 bool div0 = false;
605 auto res = constFoldBinaryOp<IntegerAttr>(
606 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
607 if (div0 || b.isZero()) {
608 div0 = true;
609 return a;
610 }
611 return a.udiv(b);
612 });
613 return div0 ? Attribute() : res;
614 }
615
616 //===----------------------------------------------------------------------===//
617 // spirv.UMod
618 //===----------------------------------------------------------------------===//
619
fold(FoldAdaptor adaptor)620 OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
621 // umod (x, 1) = 0
622 if (matchPattern(getOperand2(), m_One()))
623 return Builder(getContext()).getZeroAttr(getType());
624
625 // According to the SPIR-V spec:
626 //
627 // Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
628 // undefined if Operand 2 is 0.
629 //
630 // So don't fold during undefined behavior.
631 bool div0 = false;
632 auto res = constFoldBinaryOp<IntegerAttr>(
633 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
634 if (div0 || b.isZero()) {
635 div0 = true;
636 return a;
637 }
638 return a.urem(b);
639 });
640 return div0 ? Attribute() : res;
641 }
642
643 //===----------------------------------------------------------------------===//
644 // spirv.SNegate
645 //===----------------------------------------------------------------------===//
646
fold(FoldAdaptor adaptor)647 OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
648 // -(-x) = 0 - (0 - x) = x
649 auto op = getOperand();
650 if (auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
651 return negateOp->getOperand(0);
652
653 // According to the SPIR-V spec:
654 //
655 // Signed-integer subtract of Operand from zero.
656 return constFoldUnaryOp<IntegerAttr>(
657 adaptor.getOperands(), [](const APInt &a) {
658 APInt zero = APInt::getZero(a.getBitWidth());
659 return zero - a;
660 });
661 }
662
663 //===----------------------------------------------------------------------===//
664 // spirv.NotOp
665 //===----------------------------------------------------------------------===//
666
fold(spirv::NotOp::FoldAdaptor adaptor)667 OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
668 // !(!x) = x
669 auto op = getOperand();
670 if (auto notOp = op.getDefiningOp<spirv::NotOp>())
671 return notOp->getOperand(0);
672
673 // According to the SPIR-V spec:
674 //
675 // Complement the bits of Operand.
676 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), [&](APInt a) {
677 a.flipAllBits();
678 return a;
679 });
680 }
681
682 //===----------------------------------------------------------------------===//
683 // spirv.LogicalAnd
684 //===----------------------------------------------------------------------===//
685
fold(FoldAdaptor adaptor)686 OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
687 if (std::optional<bool> rhs =
688 getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
689 // x && true = x
690 if (*rhs)
691 return getOperand1();
692
693 // x && false = false
694 if (!*rhs)
695 return adaptor.getOperand2();
696 }
697
698 return Attribute();
699 }
700
701 //===----------------------------------------------------------------------===//
702 // spirv.LogicalEqualOp
703 //===----------------------------------------------------------------------===//
704
705 OpFoldResult
fold(spirv::LogicalEqualOp::FoldAdaptor adaptor)706 spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
707 // x == x -> true
708 if (getOperand1() == getOperand2()) {
709 auto trueAttr = BoolAttr::get(getContext(), true);
710 if (isa<IntegerType>(getType()))
711 return trueAttr;
712 if (auto vecTy = dyn_cast<VectorType>(getType()))
713 return SplatElementsAttr::get(vecTy, trueAttr);
714 }
715
716 return constFoldBinaryOp<IntegerAttr>(
717 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
718 return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
719 });
720 }
721
722 //===----------------------------------------------------------------------===//
723 // spirv.LogicalNotEqualOp
724 //===----------------------------------------------------------------------===//
725
fold(FoldAdaptor adaptor)726 OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
727 if (std::optional<bool> rhs =
728 getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
729 // x != false -> x
730 if (!rhs.value())
731 return getOperand1();
732 }
733
734 // x == x -> false
735 if (getOperand1() == getOperand2()) {
736 auto falseAttr = BoolAttr::get(getContext(), false);
737 if (isa<IntegerType>(getType()))
738 return falseAttr;
739 if (auto vecTy = dyn_cast<VectorType>(getType()))
740 return SplatElementsAttr::get(vecTy, falseAttr);
741 }
742
743 return constFoldBinaryOp<IntegerAttr>(
744 adaptor.getOperands(), [](const APInt &a, const APInt &b) {
745 return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
746 });
747 }
748
749 //===----------------------------------------------------------------------===//
750 // spirv.LogicalNot
751 //===----------------------------------------------------------------------===//
752
fold(FoldAdaptor adaptor)753 OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
754 // !(!x) = x
755 auto op = getOperand();
756 if (auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
757 return notOp->getOperand(0);
758
759 // According to the SPIR-V spec:
760 //
761 // Complement the bits of Operand.
762 return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
763 [](const APInt &a) {
764 APInt zero = APInt::getZero(1);
765 return a == 1 ? zero : (zero + 1);
766 });
767 }
768
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)769 void spirv::LogicalNotOp::getCanonicalizationPatterns(
770 RewritePatternSet &results, MLIRContext *context) {
771 results
772 .add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
773 ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
774 context);
775 }
776
777 //===----------------------------------------------------------------------===//
778 // spirv.LogicalOr
779 //===----------------------------------------------------------------------===//
780
fold(FoldAdaptor adaptor)781 OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
782 if (auto rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
783 if (*rhs) {
784 // x || true = true
785 return adaptor.getOperand2();
786 }
787
788 if (!*rhs) {
789 // x || false = x
790 return getOperand1();
791 }
792 }
793
794 return Attribute();
795 }
796
797 //===----------------------------------------------------------------------===//
798 // spirv.SelectOp
799 //===----------------------------------------------------------------------===//
800
fold(FoldAdaptor adaptor)801 OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
802 // spirv.Select _ x x -> x
803 Value trueVals = getTrueValue();
804 Value falseVals = getFalseValue();
805 if (trueVals == falseVals)
806 return trueVals;
807
808 ArrayRef<Attribute> operands = adaptor.getOperands();
809
810 // spirv.Select true x y -> x
811 // spirv.Select false x y -> y
812 if (auto boolAttr = getScalarOrSplatBoolAttr(operands[0]))
813 return *boolAttr ? trueVals : falseVals;
814
815 // Check that all the operands are constant
816 if (!operands[0] || !operands[1] || !operands[2])
817 return Attribute();
818
819 // Note: getScalarOrSplatBoolAttr will always return a boolAttr if we are in
820 // the scalar case. Hence, we are only required to consider the case of
821 // DenseElementsAttr in foldSelectOp.
822 auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
823 auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
824 auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
825 if (!condAttrs || !trueAttrs || !falseAttrs)
826 return Attribute();
827
828 auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<Attribute>());
829 auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<BoolAttr>(),
830 falseAttrs.getValues<Attribute>());
831 for (auto [result, cond, falseRes] : iters) {
832 if (!cond.getValue())
833 result = falseRes;
834 }
835
836 auto resultType = trueAttrs.getType();
837 return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
838 }
839
840 //===----------------------------------------------------------------------===//
841 // spirv.IEqualOp
842 //===----------------------------------------------------------------------===//
843
fold(spirv::IEqualOp::FoldAdaptor adaptor)844 OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
845 // x == x -> true
846 if (getOperand1() == getOperand2()) {
847 auto trueAttr = BoolAttr::get(getContext(), true);
848 if (isa<IntegerType>(getType()))
849 return trueAttr;
850 if (auto vecTy = dyn_cast<VectorType>(getType()))
851 return SplatElementsAttr::get(vecTy, trueAttr);
852 }
853
854 return constFoldBinaryOp<IntegerAttr>(
855 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
856 return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
857 });
858 }
859
860 //===----------------------------------------------------------------------===//
861 // spirv.INotEqualOp
862 //===----------------------------------------------------------------------===//
863
fold(spirv::INotEqualOp::FoldAdaptor adaptor)864 OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
865 // x == x -> false
866 if (getOperand1() == getOperand2()) {
867 auto falseAttr = BoolAttr::get(getContext(), false);
868 if (isa<IntegerType>(getType()))
869 return falseAttr;
870 if (auto vecTy = dyn_cast<VectorType>(getType()))
871 return SplatElementsAttr::get(vecTy, falseAttr);
872 }
873
874 return constFoldBinaryOp<IntegerAttr>(
875 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
876 return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
877 });
878 }
879
880 //===----------------------------------------------------------------------===//
881 // spirv.SGreaterThan
882 //===----------------------------------------------------------------------===//
883
884 OpFoldResult
fold(spirv::SGreaterThanOp::FoldAdaptor adaptor)885 spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
886 // x == x -> false
887 if (getOperand1() == getOperand2()) {
888 auto falseAttr = BoolAttr::get(getContext(), false);
889 if (isa<IntegerType>(getType()))
890 return falseAttr;
891 if (auto vecTy = dyn_cast<VectorType>(getType()))
892 return SplatElementsAttr::get(vecTy, falseAttr);
893 }
894
895 return constFoldBinaryOp<IntegerAttr>(
896 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
897 return a.sgt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
898 });
899 }
900
901 //===----------------------------------------------------------------------===//
902 // spirv.SGreaterThanEqual
903 //===----------------------------------------------------------------------===//
904
fold(spirv::SGreaterThanEqualOp::FoldAdaptor adaptor)905 OpFoldResult spirv::SGreaterThanEqualOp::fold(
906 spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
907 // x == x -> true
908 if (getOperand1() == getOperand2()) {
909 auto trueAttr = BoolAttr::get(getContext(), true);
910 if (isa<IntegerType>(getType()))
911 return trueAttr;
912 if (auto vecTy = dyn_cast<VectorType>(getType()))
913 return SplatElementsAttr::get(vecTy, trueAttr);
914 }
915
916 return constFoldBinaryOp<IntegerAttr>(
917 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
918 return a.sge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
919 });
920 }
921
922 //===----------------------------------------------------------------------===//
923 // spirv.UGreaterThan
924 //===----------------------------------------------------------------------===//
925
926 OpFoldResult
fold(spirv::UGreaterThanOp::FoldAdaptor adaptor)927 spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
928 // x == x -> false
929 if (getOperand1() == getOperand2()) {
930 auto falseAttr = BoolAttr::get(getContext(), false);
931 if (isa<IntegerType>(getType()))
932 return falseAttr;
933 if (auto vecTy = dyn_cast<VectorType>(getType()))
934 return SplatElementsAttr::get(vecTy, falseAttr);
935 }
936
937 return constFoldBinaryOp<IntegerAttr>(
938 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
939 return a.ugt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
940 });
941 }
942
943 //===----------------------------------------------------------------------===//
944 // spirv.UGreaterThanEqual
945 //===----------------------------------------------------------------------===//
946
fold(spirv::UGreaterThanEqualOp::FoldAdaptor adaptor)947 OpFoldResult spirv::UGreaterThanEqualOp::fold(
948 spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
949 // x == x -> true
950 if (getOperand1() == getOperand2()) {
951 auto trueAttr = BoolAttr::get(getContext(), true);
952 if (isa<IntegerType>(getType()))
953 return trueAttr;
954 if (auto vecTy = dyn_cast<VectorType>(getType()))
955 return SplatElementsAttr::get(vecTy, trueAttr);
956 }
957
958 return constFoldBinaryOp<IntegerAttr>(
959 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
960 return a.uge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
961 });
962 }
963
964 //===----------------------------------------------------------------------===//
965 // spirv.SLessThan
966 //===----------------------------------------------------------------------===//
967
fold(spirv::SLessThanOp::FoldAdaptor adaptor)968 OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
969 // x == x -> false
970 if (getOperand1() == getOperand2()) {
971 auto falseAttr = BoolAttr::get(getContext(), false);
972 if (isa<IntegerType>(getType()))
973 return falseAttr;
974 if (auto vecTy = dyn_cast<VectorType>(getType()))
975 return SplatElementsAttr::get(vecTy, falseAttr);
976 }
977
978 return constFoldBinaryOp<IntegerAttr>(
979 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
980 return a.slt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
981 });
982 }
983
984 //===----------------------------------------------------------------------===//
985 // spirv.SLessThanEqual
986 //===----------------------------------------------------------------------===//
987
988 OpFoldResult
fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor)989 spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
990 // x == x -> true
991 if (getOperand1() == getOperand2()) {
992 auto trueAttr = BoolAttr::get(getContext(), true);
993 if (isa<IntegerType>(getType()))
994 return trueAttr;
995 if (auto vecTy = dyn_cast<VectorType>(getType()))
996 return SplatElementsAttr::get(vecTy, trueAttr);
997 }
998
999 return constFoldBinaryOp<IntegerAttr>(
1000 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1001 return a.sle(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1002 });
1003 }
1004
1005 //===----------------------------------------------------------------------===//
1006 // spirv.ULessThan
1007 //===----------------------------------------------------------------------===//
1008
fold(spirv::ULessThanOp::FoldAdaptor adaptor)1009 OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
1010 // x == x -> false
1011 if (getOperand1() == getOperand2()) {
1012 auto falseAttr = BoolAttr::get(getContext(), false);
1013 if (isa<IntegerType>(getType()))
1014 return falseAttr;
1015 if (auto vecTy = dyn_cast<VectorType>(getType()))
1016 return SplatElementsAttr::get(vecTy, falseAttr);
1017 }
1018
1019 return constFoldBinaryOp<IntegerAttr>(
1020 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1021 return a.ult(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1022 });
1023 }
1024
1025 //===----------------------------------------------------------------------===//
1026 // spirv.ULessThanEqual
1027 //===----------------------------------------------------------------------===//
1028
1029 OpFoldResult
fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor)1030 spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
1031 // x == x -> true
1032 if (getOperand1() == getOperand2()) {
1033 auto trueAttr = BoolAttr::get(getContext(), true);
1034 if (isa<IntegerType>(getType()))
1035 return trueAttr;
1036 if (auto vecTy = dyn_cast<VectorType>(getType()))
1037 return SplatElementsAttr::get(vecTy, trueAttr);
1038 }
1039
1040 return constFoldBinaryOp<IntegerAttr>(
1041 adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1042 return a.ule(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1043 });
1044 }
1045
1046 //===----------------------------------------------------------------------===//
1047 // spirv.ShiftLeftLogical
1048 //===----------------------------------------------------------------------===//
1049
fold(spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor)1050 OpFoldResult spirv::ShiftLeftLogicalOp::fold(
1051 spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
1052 // x << 0 -> x
1053 if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1054 return getOperand1();
1055 }
1056
1057 // Unfortunately due to below undefined behaviour can't fold 0 for Base.
1058
1059 // Results are computed per component, and within each component, per bit...
1060 //
1061 // The result is undefined if Shift is greater than or equal to the bit width
1062 // of the components of Base.
1063 //
1064 // So we can use the APInt << method, but don't fold if undefined behaviour.
1065 bool shiftToLarge = false;
1066 auto res = constFoldBinaryOp<IntegerAttr>(
1067 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1068 if (shiftToLarge || b.uge(a.getBitWidth())) {
1069 shiftToLarge = true;
1070 return a;
1071 }
1072 return a << b;
1073 });
1074 return shiftToLarge ? Attribute() : res;
1075 }
1076
1077 //===----------------------------------------------------------------------===//
1078 // spirv.ShiftRightArithmetic
1079 //===----------------------------------------------------------------------===//
1080
fold(spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor)1081 OpFoldResult spirv::ShiftRightArithmeticOp::fold(
1082 spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
1083 // x >> 0 -> x
1084 if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1085 return getOperand1();
1086 }
1087
1088 // Unfortunately due to below undefined behaviour can't fold 0, -1 for Base.
1089
1090 // Results are computed per component, and within each component, per bit...
1091 //
1092 // The result is undefined if Shift is greater than or equal to the bit width
1093 // of the components of Base.
1094 //
1095 // So we can use the APInt ashr method, but don't fold if undefined behaviour.
1096 bool shiftToLarge = false;
1097 auto res = constFoldBinaryOp<IntegerAttr>(
1098 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1099 if (shiftToLarge || b.uge(a.getBitWidth())) {
1100 shiftToLarge = true;
1101 return a;
1102 }
1103 return a.ashr(b);
1104 });
1105 return shiftToLarge ? Attribute() : res;
1106 }
1107
1108 //===----------------------------------------------------------------------===//
1109 // spirv.ShiftRightLogical
1110 //===----------------------------------------------------------------------===//
1111
fold(spirv::ShiftRightLogicalOp::FoldAdaptor adaptor)1112 OpFoldResult spirv::ShiftRightLogicalOp::fold(
1113 spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
1114 // x >> 0 -> x
1115 if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1116 return getOperand1();
1117 }
1118
1119 // Unfortunately due to below undefined behaviour can't fold 0 for Base.
1120
1121 // Results are computed per component, and within each component, per bit...
1122 //
1123 // The result is undefined if Shift is greater than or equal to the bit width
1124 // of the components of Base.
1125 //
1126 // So we can use the APInt lshr method, but don't fold if undefined behaviour.
1127 bool shiftToLarge = false;
1128 auto res = constFoldBinaryOp<IntegerAttr>(
1129 adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1130 if (shiftToLarge || b.uge(a.getBitWidth())) {
1131 shiftToLarge = true;
1132 return a;
1133 }
1134 return a.lshr(b);
1135 });
1136 return shiftToLarge ? Attribute() : res;
1137 }
1138
1139 //===----------------------------------------------------------------------===//
1140 // spirv.BitwiseAndOp
1141 //===----------------------------------------------------------------------===//
1142
1143 OpFoldResult
fold(spirv::BitwiseAndOp::FoldAdaptor adaptor)1144 spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
1145 // x & x -> x
1146 if (getOperand1() == getOperand2()) {
1147 return getOperand1();
1148 }
1149
1150 APInt rhsMask;
1151 if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
1152 // x & 0 -> 0
1153 if (rhsMask.isZero())
1154 return getOperand2();
1155
1156 // x & <all ones> -> x
1157 if (rhsMask.isAllOnes())
1158 return getOperand1();
1159
1160 // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
1161 if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
1162 int valueBits =
1163 getElementTypeOrSelf(zext.getOperand()).getIntOrFloatBitWidth();
1164 if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
1165 return getOperand1();
1166 }
1167 }
1168
1169 // According to the SPIR-V spec:
1170 //
1171 // Type is a scalar or vector of integer type.
1172 // Results are computed per component, and within each component, per bit.
1173 // So we can use the APInt & method.
1174 return constFoldBinaryOp<IntegerAttr>(
1175 adaptor.getOperands(),
1176 [](const APInt &a, const APInt &b) { return a & b; });
1177 }
1178
1179 //===----------------------------------------------------------------------===//
1180 // spirv.BitwiseOrOp
1181 //===----------------------------------------------------------------------===//
1182
fold(spirv::BitwiseOrOp::FoldAdaptor adaptor)1183 OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
1184 // x | x -> x
1185 if (getOperand1() == getOperand2()) {
1186 return getOperand1();
1187 }
1188
1189 APInt rhsMask;
1190 if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
1191 // x | 0 -> x
1192 if (rhsMask.isZero())
1193 return getOperand1();
1194
1195 // x | <all ones> -> <all ones>
1196 if (rhsMask.isAllOnes())
1197 return getOperand2();
1198 }
1199
1200 // According to the SPIR-V spec:
1201 //
1202 // Type is a scalar or vector of integer type.
1203 // Results are computed per component, and within each component, per bit.
1204 // So we can use the APInt | method.
1205 return constFoldBinaryOp<IntegerAttr>(
1206 adaptor.getOperands(),
1207 [](const APInt &a, const APInt &b) { return a | b; });
1208 }
1209
1210 //===----------------------------------------------------------------------===//
1211 // spirv.BitwiseXorOp
1212 //===----------------------------------------------------------------------===//
1213
1214 OpFoldResult
fold(spirv::BitwiseXorOp::FoldAdaptor adaptor)1215 spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
1216 // x ^ 0 -> x
1217 if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1218 return getOperand1();
1219 }
1220
1221 // x ^ x -> 0
1222 if (getOperand1() == getOperand2())
1223 return Builder(getContext()).getZeroAttr(getType());
1224
1225 // According to the SPIR-V spec:
1226 //
1227 // Type is a scalar or vector of integer type.
1228 // Results are computed per component, and within each component, per bit.
1229 // So we can use the APInt ^ method.
1230 return constFoldBinaryOp<IntegerAttr>(
1231 adaptor.getOperands(),
1232 [](const APInt &a, const APInt &b) { return a ^ b; });
1233 }
1234
1235 //===----------------------------------------------------------------------===//
1236 // spirv.mlir.selection
1237 //===----------------------------------------------------------------------===//
1238
1239 namespace {
1240 // Blocks from the given `spirv.mlir.selection` operation must satisfy the
1241 // following layout:
1242 //
1243 // +-----------------------------------------------+
1244 // | header block |
1245 // | spirv.BranchConditionalOp %cond, ^case0, ^case1 |
1246 // +-----------------------------------------------+
1247 // / \
1248 // ...
1249 //
1250 //
1251 // +------------------------+ +------------------------+
1252 // | case #0 | | case #1 |
1253 // | spirv.Store %ptr %value0 | | spirv.Store %ptr %value1 |
1254 // | spirv.Branch ^merge | | spirv.Branch ^merge |
1255 // +------------------------+ +------------------------+
1256 //
1257 //
1258 // ...
1259 // \ /
1260 // v
1261 // +-------------+
1262 // | merge block |
1263 // +-------------+
1264 //
1265 struct ConvertSelectionOpToSelect final : OpRewritePattern<spirv::SelectionOp> {
1266 using OpRewritePattern::OpRewritePattern;
1267
matchAndRewrite__anona5f61ae82511::ConvertSelectionOpToSelect1268 LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
1269 PatternRewriter &rewriter) const override {
1270 Operation *op = selectionOp.getOperation();
1271 Region &body = op->getRegion(0);
1272 // Verifier allows an empty region for `spirv.mlir.selection`.
1273 if (body.empty()) {
1274 return failure();
1275 }
1276
1277 // Check that region consists of 4 blocks:
1278 // header block, `true` block, `false` block and merge block.
1279 if (llvm::range_size(body) != 4) {
1280 return failure();
1281 }
1282
1283 Block *headerBlock = selectionOp.getHeaderBlock();
1284 if (!onlyContainsBranchConditionalOp(headerBlock)) {
1285 return failure();
1286 }
1287
1288 auto brConditionalOp =
1289 cast<spirv::BranchConditionalOp>(headerBlock->front());
1290
1291 Block *trueBlock = brConditionalOp.getSuccessor(0);
1292 Block *falseBlock = brConditionalOp.getSuccessor(1);
1293 Block *mergeBlock = selectionOp.getMergeBlock();
1294
1295 if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
1296 return failure();
1297
1298 Value trueValue = getSrcValue(trueBlock);
1299 Value falseValue = getSrcValue(falseBlock);
1300 Value ptrValue = getDstPtr(trueBlock);
1301 auto storeOpAttributes =
1302 cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
1303
1304 auto selectOp = rewriter.create<spirv::SelectOp>(
1305 selectionOp.getLoc(), trueValue.getType(),
1306 brConditionalOp.getCondition(), trueValue, falseValue);
1307 rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
1308 selectOp.getResult(), storeOpAttributes);
1309
1310 // `spirv.mlir.selection` is not needed anymore.
1311 rewriter.eraseOp(op);
1312 return success();
1313 }
1314
1315 private:
1316 // Checks that given blocks follow the following rules:
1317 // 1. Each conditional block consists of two operations, the first operation
1318 // is a `spirv.Store` and the last operation is a `spirv.Branch`.
1319 // 2. Each `spirv.Store` uses the same pointer and the same memory attributes.
1320 // 3. A control flow goes into the given merge block from the given
1321 // conditional blocks.
1322 LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
1323 Block *mergeBlock) const;
1324
onlyContainsBranchConditionalOp__anona5f61ae82511::ConvertSelectionOpToSelect1325 bool onlyContainsBranchConditionalOp(Block *block) const {
1326 return llvm::hasSingleElement(*block) &&
1327 isa<spirv::BranchConditionalOp>(block->front());
1328 }
1329
isSameAttrList__anona5f61ae82511::ConvertSelectionOpToSelect1330 bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
1331 return lhs->getDiscardableAttrDictionary() ==
1332 rhs->getDiscardableAttrDictionary() &&
1333 lhs.getProperties() == rhs.getProperties();
1334 }
1335
1336 // Returns a source value for the given block.
getSrcValue__anona5f61ae82511::ConvertSelectionOpToSelect1337 Value getSrcValue(Block *block) const {
1338 auto storeOp = cast<spirv::StoreOp>(block->front());
1339 return storeOp.getValue();
1340 }
1341
1342 // Returns a destination value for the given block.
getDstPtr__anona5f61ae82511::ConvertSelectionOpToSelect1343 Value getDstPtr(Block *block) const {
1344 auto storeOp = cast<spirv::StoreOp>(block->front());
1345 return storeOp.getPtr();
1346 }
1347 };
1348
canCanonicalizeSelection(Block * trueBlock,Block * falseBlock,Block * mergeBlock) const1349 LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
1350 Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
1351 // Each block must consists of 2 operations.
1352 if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {
1353 return failure();
1354 }
1355
1356 auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
1357 auto trueBrBranchOp =
1358 dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
1359 auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
1360 auto falseBrBranchOp =
1361 dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
1362
1363 if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
1364 !falseBrBranchOp) {
1365 return failure();
1366 }
1367
1368 // Checks that given type is valid for `spirv.SelectOp`.
1369 // According to SPIR-V spec:
1370 // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
1371 // Starting with version 1.4, Result Type can additionally be a composite type
1372 // other than a vector."
1373 bool isScalarOrVector =
1374 llvm::cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
1375 .isScalarOrVector();
1376
1377 // Check that each `spirv.Store` uses the same pointer, memory access
1378 // attributes and a valid type of the value.
1379 if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
1380 !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
1381 return failure();
1382 }
1383
1384 if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
1385 (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
1386 return failure();
1387 }
1388
1389 return success();
1390 }
1391 } // namespace
1392
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1393 void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1394 MLIRContext *context) {
1395 results.add<ConvertSelectionOpToSelect>(context);
1396 }
1397