xref: /llvm-project/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp (revision 4aba595f092e8e05e92656b23944ce6619465a78)
1 //===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the folders and canonicalization patterns for SPIR-V ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <optional>
14 #include <utility>
15 
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 
18 #include "mlir/Dialect/CommonFolders.h"
19 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
20 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
21 #include "mlir/Dialect/UB/IR/UBOps.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallVectorExtras.h"
26 
27 using namespace mlir;
28 
29 //===----------------------------------------------------------------------===//
30 // Common utility functions
31 //===----------------------------------------------------------------------===//
32 
33 /// Returns the boolean value under the hood if the given `boolAttr` is a scalar
34 /// or splat vector bool constant.
getScalarOrSplatBoolAttr(Attribute attr)35 static std::optional<bool> getScalarOrSplatBoolAttr(Attribute attr) {
36   if (!attr)
37     return std::nullopt;
38 
39   if (auto boolAttr = llvm::dyn_cast<BoolAttr>(attr))
40     return boolAttr.getValue();
41   if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(attr))
42     if (splatAttr.getElementType().isInteger(1))
43       return splatAttr.getSplatValue<bool>();
44   return std::nullopt;
45 }
46 
47 // Extracts an element from the given `composite` by following the given
48 // `indices`. Returns a null Attribute if error happens.
extractCompositeElement(Attribute composite,ArrayRef<unsigned> indices)49 static Attribute extractCompositeElement(Attribute composite,
50                                          ArrayRef<unsigned> indices) {
51   // Check that given composite is a constant.
52   if (!composite)
53     return {};
54   // Return composite itself if we reach the end of the index chain.
55   if (indices.empty())
56     return composite;
57 
58   if (auto vector = llvm::dyn_cast<ElementsAttr>(composite)) {
59     assert(indices.size() == 1 && "must have exactly one index for a vector");
60     return vector.getValues<Attribute>()[indices[0]];
61   }
62 
63   if (auto array = llvm::dyn_cast<ArrayAttr>(composite)) {
64     assert(!indices.empty() && "must have at least one index for an array");
65     return extractCompositeElement(array.getValue()[indices[0]],
66                                    indices.drop_front());
67   }
68 
69   return {};
70 }
71 
isDivZeroOrOverflow(const APInt & a,const APInt & b)72 static bool isDivZeroOrOverflow(const APInt &a, const APInt &b) {
73   bool div0 = b.isZero();
74   bool overflow = a.isMinSignedValue() && b.isAllOnes();
75 
76   return div0 || overflow;
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // TableGen'erated canonicalizers
81 //===----------------------------------------------------------------------===//
82 
83 namespace {
84 #include "SPIRVCanonicalization.inc"
85 } // namespace
86 
87 //===----------------------------------------------------------------------===//
88 // spirv.AccessChainOp
89 //===----------------------------------------------------------------------===//
90 
91 namespace {
92 
93 /// Combines chained `spirv::AccessChainOp` operations into one
94 /// `spirv::AccessChainOp` operation.
95 struct CombineChainedAccessChain final
96     : OpRewritePattern<spirv::AccessChainOp> {
97   using OpRewritePattern::OpRewritePattern;
98 
matchAndRewrite__anona5f61ae80211::CombineChainedAccessChain99   LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
100                                 PatternRewriter &rewriter) const override {
101     auto parentAccessChainOp =
102         accessChainOp.getBasePtr().getDefiningOp<spirv::AccessChainOp>();
103 
104     if (!parentAccessChainOp) {
105       return failure();
106     }
107 
108     // Combine indices.
109     SmallVector<Value, 4> indices(parentAccessChainOp.getIndices());
110     llvm::append_range(indices, accessChainOp.getIndices());
111 
112     rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
113         accessChainOp, parentAccessChainOp.getBasePtr(), indices);
114 
115     return success();
116   }
117 };
118 } // namespace
119 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)120 void spirv::AccessChainOp::getCanonicalizationPatterns(
121     RewritePatternSet &results, MLIRContext *context) {
122   results.add<CombineChainedAccessChain>(context);
123 }
124 
125 //===----------------------------------------------------------------------===//
126 // spirv.IAddCarry
127 //===----------------------------------------------------------------------===//
128 
129 // We are required to use CompositeConstructOp to create a constant struct as
130 // they are not yet implemented as constant, hence we can not do so in a fold.
131 struct IAddCarryFold final : OpRewritePattern<spirv::IAddCarryOp> {
132   using OpRewritePattern::OpRewritePattern;
133 
matchAndRewriteIAddCarryFold134   LogicalResult matchAndRewrite(spirv::IAddCarryOp op,
135                                 PatternRewriter &rewriter) const override {
136     Location loc = op.getLoc();
137     Value lhs = op.getOperand1();
138     Value rhs = op.getOperand2();
139     Type constituentType = lhs.getType();
140 
141     // iaddcarry (x, 0) = <0, x>
142     if (matchPattern(rhs, m_Zero())) {
143       Value constituents[2] = {rhs, lhs};
144       rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
145                                                                constituents);
146       return success();
147     }
148 
149     // According to the SPIR-V spec:
150     //
151     //  Result Type must be from OpTypeStruct.  The struct must have two
152     //  members...
153     //
154     //  Member 0 of the result gets the low-order bits (full component width) of
155     //  the addition.
156     //
157     //  Member 1 of the result gets the high-order (carry) bit of the result of
158     //  the addition. That is, it gets the value 1 if the addition overflowed
159     //  the component width, and 0 otherwise.
160     Attribute lhsAttr;
161     Attribute rhsAttr;
162     if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
163         !matchPattern(rhs, m_Constant(&rhsAttr)))
164       return failure();
165 
166     auto adds = constFoldBinaryOp<IntegerAttr>(
167         {lhsAttr, rhsAttr},
168         [](const APInt &a, const APInt &b) { return a + b; });
169     if (!adds)
170       return failure();
171 
172     auto carrys = constFoldBinaryOp<IntegerAttr>(
173         ArrayRef{adds, lhsAttr}, [](const APInt &a, const APInt &b) {
174           APInt zero = APInt::getZero(a.getBitWidth());
175           return a.ult(b) ? (zero + 1) : zero;
176         });
177 
178     if (!carrys)
179       return failure();
180 
181     Value addsVal =
182         rewriter.create<spirv::ConstantOp>(loc, constituentType, adds);
183 
184     Value carrysVal =
185         rewriter.create<spirv::ConstantOp>(loc, constituentType, carrys);
186 
187     // Create empty struct
188     Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
189     // Fill in adds at id 0
190     Value intermediate =
191         rewriter.create<spirv::CompositeInsertOp>(loc, addsVal, undef, 0);
192     // Fill in carrys at id 1
193     rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, carrysVal,
194                                                           intermediate, 1);
195     return success();
196   }
197 };
198 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)199 void spirv::IAddCarryOp::getCanonicalizationPatterns(
200     RewritePatternSet &patterns, MLIRContext *context) {
201   patterns.add<IAddCarryFold>(context);
202 }
203 
204 //===----------------------------------------------------------------------===//
205 // spirv.[S|U]MulExtended
206 //===----------------------------------------------------------------------===//
207 
208 // We are required to use CompositeConstructOp to create a constant struct as
209 // they are not yet implemented as constant, hence we can not do so in a fold.
210 template <typename MulOp, bool IsSigned>
211 struct MulExtendedFold final : OpRewritePattern<MulOp> {
212   using OpRewritePattern<MulOp>::OpRewritePattern;
213 
matchAndRewriteMulExtendedFold214   LogicalResult matchAndRewrite(MulOp op,
215                                 PatternRewriter &rewriter) const override {
216     Location loc = op.getLoc();
217     Value lhs = op.getOperand1();
218     Value rhs = op.getOperand2();
219     Type constituentType = lhs.getType();
220 
221     // [su]mulextended (x, 0) = <0, 0>
222     if (matchPattern(rhs, m_Zero())) {
223       Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
224       Value constituents[2] = {zero, zero};
225       rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
226                                                                constituents);
227       return success();
228     }
229 
230     // According to the SPIR-V spec:
231     //
232     // Result Type must be from OpTypeStruct.  The struct must have two
233     // members...
234     //
235     // Member 0 of the result gets the low-order bits of the multiplication.
236     //
237     // Member 1 of the result gets the high-order bits of the multiplication.
238     Attribute lhsAttr;
239     Attribute rhsAttr;
240     if (!matchPattern(lhs, m_Constant(&lhsAttr)) ||
241         !matchPattern(rhs, m_Constant(&rhsAttr)))
242       return failure();
243 
244     auto lowBits = constFoldBinaryOp<IntegerAttr>(
245         {lhsAttr, rhsAttr},
246         [](const APInt &a, const APInt &b) { return a * b; });
247 
248     if (!lowBits)
249       return failure();
250 
251     auto highBits = constFoldBinaryOp<IntegerAttr>(
252         {lhsAttr, rhsAttr}, [](const APInt &a, const APInt &b) {
253           if (IsSigned) {
254             return llvm::APIntOps::mulhs(a, b);
255           } else {
256             return llvm::APIntOps::mulhu(a, b);
257           }
258         });
259 
260     if (!highBits)
261       return failure();
262 
263     Value lowBitsVal =
264         rewriter.create<spirv::ConstantOp>(loc, constituentType, lowBits);
265 
266     Value highBitsVal =
267         rewriter.create<spirv::ConstantOp>(loc, constituentType, highBits);
268 
269     // Create empty struct
270     Value undef = rewriter.create<spirv::UndefOp>(loc, op.getType());
271     // Fill in lowBits at id 0
272     Value intermediate =
273         rewriter.create<spirv::CompositeInsertOp>(loc, lowBitsVal, undef, 0);
274     // Fill in highBits at id 1
275     rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(op, highBitsVal,
276                                                           intermediate, 1);
277     return success();
278   }
279 };
280 
281 using SMulExtendedOpFold = MulExtendedFold<spirv::SMulExtendedOp, true>;
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)282 void spirv::SMulExtendedOp::getCanonicalizationPatterns(
283     RewritePatternSet &patterns, MLIRContext *context) {
284   patterns.add<SMulExtendedOpFold>(context);
285 }
286 
287 struct UMulExtendedOpXOne final : OpRewritePattern<spirv::UMulExtendedOp> {
288   using OpRewritePattern::OpRewritePattern;
289 
matchAndRewriteUMulExtendedOpXOne290   LogicalResult matchAndRewrite(spirv::UMulExtendedOp op,
291                                 PatternRewriter &rewriter) const override {
292     Location loc = op.getLoc();
293     Value lhs = op.getOperand1();
294     Value rhs = op.getOperand2();
295     Type constituentType = lhs.getType();
296 
297     // umulextended (x, 1) = <x, 0>
298     if (matchPattern(rhs, m_One())) {
299       Value zero = spirv::ConstantOp::getZero(constituentType, loc, rewriter);
300       Value constituents[2] = {lhs, zero};
301       rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, op.getType(),
302                                                                constituents);
303       return success();
304     }
305 
306     return failure();
307   }
308 };
309 
310 using UMulExtendedOpFold = MulExtendedFold<spirv::UMulExtendedOp, false>;
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)311 void spirv::UMulExtendedOp::getCanonicalizationPatterns(
312     RewritePatternSet &patterns, MLIRContext *context) {
313   patterns.add<UMulExtendedOpFold, UMulExtendedOpXOne>(context);
314 }
315 
316 //===----------------------------------------------------------------------===//
317 // spirv.UMod
318 //===----------------------------------------------------------------------===//
319 
320 // Input:
321 //    %0 = spirv.UMod %arg0, %const32 : i32
322 //    %1 = spirv.UMod %0, %const4 : i32
323 // Output:
324 //    %0 = spirv.UMod %arg0, %const32 : i32
325 //    %1 = spirv.UMod %arg0, %const4 : i32
326 
327 // The transformation is only applied if one divisor is a multiple of the other.
328 
329 // TODO(https://github.com/llvm/llvm-project/issues/63174): Add support for vector constants
330 struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
331   using OpRewritePattern::OpRewritePattern;
332 
matchAndRewriteUModSimplification333   LogicalResult matchAndRewrite(spirv::UModOp umodOp,
334                                 PatternRewriter &rewriter) const override {
335     auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
336     if (!prevUMod)
337       return failure();
338 
339     IntegerAttr prevValue;
340     IntegerAttr currValue;
341     if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) ||
342         !matchPattern(umodOp.getOperand(1), m_Constant(&currValue)))
343       return failure();
344 
345     APInt prevConstValue = prevValue.getValue();
346     APInt currConstValue = currValue.getValue();
347 
348     // Ensure that one divisor is a multiple of the other. If not, fail the
349     // transformation.
350     if (prevConstValue.urem(currConstValue) != 0 &&
351         currConstValue.urem(prevConstValue) != 0)
352       return failure();
353 
354     // The transformation is safe. Replace the existing UMod operation with a
355     // new UMod operation, using the original dividend and the current divisor.
356     rewriter.replaceOpWithNewOp<spirv::UModOp>(
357         umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));
358 
359     return success();
360   }
361 };
362 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)363 void spirv::UModOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
364                                                 MLIRContext *context) {
365   patterns.insert<UModSimplification>(context);
366 }
367 
368 //===----------------------------------------------------------------------===//
369 // spirv.BitcastOp
370 //===----------------------------------------------------------------------===//
371 
fold(FoldAdaptor)372 OpFoldResult spirv::BitcastOp::fold(FoldAdaptor /*adaptor*/) {
373   Value curInput = getOperand();
374   if (getType() == curInput.getType())
375     return curInput;
376 
377   // Look through nested bitcasts.
378   if (auto prevCast = curInput.getDefiningOp<spirv::BitcastOp>()) {
379     Value prevInput = prevCast.getOperand();
380     if (prevInput.getType() == getType())
381       return prevInput;
382 
383     getOperandMutable().assign(prevInput);
384     return getResult();
385   }
386 
387   // TODO(kuhar): Consider constant-folding the operand attribute.
388   return {};
389 }
390 
391 //===----------------------------------------------------------------------===//
392 // spirv.CompositeExtractOp
393 //===----------------------------------------------------------------------===//
394 
fold(FoldAdaptor adaptor)395 OpFoldResult spirv::CompositeExtractOp::fold(FoldAdaptor adaptor) {
396   Value compositeOp = getComposite();
397 
398   while (auto insertOp =
399              compositeOp.getDefiningOp<spirv::CompositeInsertOp>()) {
400     if (getIndices() == insertOp.getIndices())
401       return insertOp.getObject();
402     compositeOp = insertOp.getComposite();
403   }
404 
405   if (auto constructOp =
406           compositeOp.getDefiningOp<spirv::CompositeConstructOp>()) {
407     auto type = llvm::cast<spirv::CompositeType>(constructOp.getType());
408     if (getIndices().size() == 1 &&
409         constructOp.getConstituents().size() == type.getNumElements()) {
410       auto i = llvm::cast<IntegerAttr>(*getIndices().begin());
411       if (i.getValue().getSExtValue() <
412           static_cast<int64_t>(constructOp.getConstituents().size()))
413         return constructOp.getConstituents()[i.getValue().getSExtValue()];
414     }
415   }
416 
417   auto indexVector = llvm::map_to_vector(getIndices(), [](Attribute attr) {
418     return static_cast<unsigned>(llvm::cast<IntegerAttr>(attr).getInt());
419   });
420   return extractCompositeElement(adaptor.getComposite(), indexVector);
421 }
422 
423 //===----------------------------------------------------------------------===//
424 // spirv.Constant
425 //===----------------------------------------------------------------------===//
426 
fold(FoldAdaptor)427 OpFoldResult spirv::ConstantOp::fold(FoldAdaptor /*adaptor*/) {
428   return getValue();
429 }
430 
431 //===----------------------------------------------------------------------===//
432 // spirv.IAdd
433 //===----------------------------------------------------------------------===//
434 
fold(FoldAdaptor adaptor)435 OpFoldResult spirv::IAddOp::fold(FoldAdaptor adaptor) {
436   // x + 0 = x
437   if (matchPattern(getOperand2(), m_Zero()))
438     return getOperand1();
439 
440   // According to the SPIR-V spec:
441   //
442   // The resulting value will equal the low-order N bits of the correct result
443   // R, where N is the component width and R is computed with enough precision
444   // to avoid overflow and underflow.
445   return constFoldBinaryOp<IntegerAttr>(
446       adaptor.getOperands(),
447       [](APInt a, const APInt &b) { return std::move(a) + b; });
448 }
449 
450 //===----------------------------------------------------------------------===//
451 // spirv.IMul
452 //===----------------------------------------------------------------------===//
453 
fold(FoldAdaptor adaptor)454 OpFoldResult spirv::IMulOp::fold(FoldAdaptor adaptor) {
455   // x * 0 == 0
456   if (matchPattern(getOperand2(), m_Zero()))
457     return getOperand2();
458   // x * 1 = x
459   if (matchPattern(getOperand2(), m_One()))
460     return getOperand1();
461 
462   // According to the SPIR-V spec:
463   //
464   // The resulting value will equal the low-order N bits of the correct result
465   // R, where N is the component width and R is computed with enough precision
466   // to avoid overflow and underflow.
467   return constFoldBinaryOp<IntegerAttr>(
468       adaptor.getOperands(),
469       [](const APInt &a, const APInt &b) { return a * b; });
470 }
471 
472 //===----------------------------------------------------------------------===//
473 // spirv.ISub
474 //===----------------------------------------------------------------------===//
475 
fold(FoldAdaptor adaptor)476 OpFoldResult spirv::ISubOp::fold(FoldAdaptor adaptor) {
477   // x - x = 0
478   if (getOperand1() == getOperand2())
479     return Builder(getContext()).getIntegerAttr(getType(), 0);
480 
481   // According to the SPIR-V spec:
482   //
483   // The resulting value will equal the low-order N bits of the correct result
484   // R, where N is the component width and R is computed with enough precision
485   // to avoid overflow and underflow.
486   return constFoldBinaryOp<IntegerAttr>(
487       adaptor.getOperands(),
488       [](APInt a, const APInt &b) { return std::move(a) - b; });
489 }
490 
491 //===----------------------------------------------------------------------===//
492 // spirv.SDiv
493 //===----------------------------------------------------------------------===//
494 
fold(FoldAdaptor adaptor)495 OpFoldResult spirv::SDivOp::fold(FoldAdaptor adaptor) {
496   // sdiv (x, 1) = x
497   if (matchPattern(getOperand2(), m_One()))
498     return getOperand1();
499 
500   // According to the SPIR-V spec:
501   //
502   // Signed-integer division of Operand 1 divided by Operand 2.
503   // Results are computed per component. Behavior is undefined if Operand 2 is
504   // 0. Behavior is undefined if Operand 2 is -1 and Operand 1 is the minimum
505   // representable value for the operands' type, causing signed overflow.
506   //
507   // So don't fold during undefined behavior.
508   bool div0OrOverflow = false;
509   auto res = constFoldBinaryOp<IntegerAttr>(
510       adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
511         if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
512           div0OrOverflow = true;
513           return a;
514         }
515         return a.sdiv(b);
516       });
517   return div0OrOverflow ? Attribute() : res;
518 }
519 
520 //===----------------------------------------------------------------------===//
521 // spirv.SMod
522 //===----------------------------------------------------------------------===//
523 
fold(FoldAdaptor adaptor)524 OpFoldResult spirv::SModOp::fold(FoldAdaptor adaptor) {
525   // smod (x, 1) = 0
526   if (matchPattern(getOperand2(), m_One()))
527     return Builder(getContext()).getZeroAttr(getType());
528 
529   // According to SPIR-V spec:
530   //
531   // Signed remainder operation for the remainder whose sign matches the sign
532   // of Operand 2. Behavior is undefined if Operand 2 is 0. Behavior is
533   // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
534   // value for the operands' type, causing signed overflow. Otherwise, the
535   // result is the remainder r of Operand 1 divided by Operand 2 where if
536   // r ≠ 0, the sign of r is the same as the sign of Operand 2.
537   //
538   // So don't fold during undefined behavior
539   bool div0OrOverflow = false;
540   auto res = constFoldBinaryOp<IntegerAttr>(
541       adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
542         if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
543           div0OrOverflow = true;
544           return a;
545         }
546         APInt c = a.abs().urem(b.abs());
547         if (c.isZero())
548           return c;
549         if (b.isNegative()) {
550           APInt zero = APInt::getZero(c.getBitWidth());
551           return a.isNegative() ? (zero - c) : (b + c);
552         }
553         return a.isNegative() ? (b - c) : c;
554       });
555   return div0OrOverflow ? Attribute() : res;
556 }
557 
558 //===----------------------------------------------------------------------===//
559 // spirv.SRem
560 //===----------------------------------------------------------------------===//
561 
fold(FoldAdaptor adaptor)562 OpFoldResult spirv::SRemOp::fold(FoldAdaptor adaptor) {
563   // x % 1 = 0
564   if (matchPattern(getOperand2(), m_One()))
565     return Builder(getContext()).getZeroAttr(getType());
566 
567   // According to SPIR-V spec:
568   //
569   // Signed remainder operation for the remainder whose sign matches the sign
570   // of Operand 1. Behavior is undefined if Operand 2 is 0. Behavior is
571   // undefined if Operand 2 is -1 and Operand 1 is the minimum representable
572   // value for the operands' type, causing signed overflow. Otherwise, the
573   // result is the remainder r of Operand 1 divided by Operand 2 where if
574   // r ≠ 0, the sign of r is the same as the sign of Operand 1.
575 
576   // Don't fold if it would do undefined behavior.
577   bool div0OrOverflow = false;
578   auto res = constFoldBinaryOp<IntegerAttr>(
579       adaptor.getOperands(), [&](APInt a, const APInt &b) {
580         if (div0OrOverflow || isDivZeroOrOverflow(a, b)) {
581           div0OrOverflow = true;
582           return a;
583         }
584         return a.srem(b);
585       });
586   return div0OrOverflow ? Attribute() : res;
587 }
588 
589 //===----------------------------------------------------------------------===//
590 // spirv.UDiv
591 //===----------------------------------------------------------------------===//
592 
fold(FoldAdaptor adaptor)593 OpFoldResult spirv::UDivOp::fold(FoldAdaptor adaptor) {
594   // udiv (x, 1) = x
595   if (matchPattern(getOperand2(), m_One()))
596     return getOperand1();
597 
598   // According to the SPIR-V spec:
599   //
600   // Unsigned-integer division of Operand 1 divided by Operand 2. Behavior is
601   // undefined if Operand 2 is 0.
602   //
603   // So don't fold during undefined behavior.
604   bool div0 = false;
605   auto res = constFoldBinaryOp<IntegerAttr>(
606       adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
607         if (div0 || b.isZero()) {
608           div0 = true;
609           return a;
610         }
611         return a.udiv(b);
612       });
613   return div0 ? Attribute() : res;
614 }
615 
616 //===----------------------------------------------------------------------===//
617 // spirv.UMod
618 //===----------------------------------------------------------------------===//
619 
fold(FoldAdaptor adaptor)620 OpFoldResult spirv::UModOp::fold(FoldAdaptor adaptor) {
621   // umod (x, 1) = 0
622   if (matchPattern(getOperand2(), m_One()))
623     return Builder(getContext()).getZeroAttr(getType());
624 
625   // According to the SPIR-V spec:
626   //
627   // Unsigned modulo operation of Operand 1 modulo Operand 2. Behavior is
628   // undefined if Operand 2 is 0.
629   //
630   // So don't fold during undefined behavior.
631   bool div0 = false;
632   auto res = constFoldBinaryOp<IntegerAttr>(
633       adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
634         if (div0 || b.isZero()) {
635           div0 = true;
636           return a;
637         }
638         return a.urem(b);
639       });
640   return div0 ? Attribute() : res;
641 }
642 
643 //===----------------------------------------------------------------------===//
644 // spirv.SNegate
645 //===----------------------------------------------------------------------===//
646 
fold(FoldAdaptor adaptor)647 OpFoldResult spirv::SNegateOp::fold(FoldAdaptor adaptor) {
648   // -(-x) = 0 - (0 - x) = x
649   auto op = getOperand();
650   if (auto negateOp = op.getDefiningOp<spirv::SNegateOp>())
651     return negateOp->getOperand(0);
652 
653   // According to the SPIR-V spec:
654   //
655   // Signed-integer subtract of Operand from zero.
656   return constFoldUnaryOp<IntegerAttr>(
657       adaptor.getOperands(), [](const APInt &a) {
658         APInt zero = APInt::getZero(a.getBitWidth());
659         return zero - a;
660       });
661 }
662 
663 //===----------------------------------------------------------------------===//
664 // spirv.NotOp
665 //===----------------------------------------------------------------------===//
666 
fold(spirv::NotOp::FoldAdaptor adaptor)667 OpFoldResult spirv::NotOp::fold(spirv::NotOp::FoldAdaptor adaptor) {
668   // !(!x) = x
669   auto op = getOperand();
670   if (auto notOp = op.getDefiningOp<spirv::NotOp>())
671     return notOp->getOperand(0);
672 
673   // According to the SPIR-V spec:
674   //
675   // Complement the bits of Operand.
676   return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(), [&](APInt a) {
677     a.flipAllBits();
678     return a;
679   });
680 }
681 
682 //===----------------------------------------------------------------------===//
683 // spirv.LogicalAnd
684 //===----------------------------------------------------------------------===//
685 
fold(FoldAdaptor adaptor)686 OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
687   if (std::optional<bool> rhs =
688           getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
689     // x && true = x
690     if (*rhs)
691       return getOperand1();
692 
693     // x && false = false
694     if (!*rhs)
695       return adaptor.getOperand2();
696   }
697 
698   return Attribute();
699 }
700 
701 //===----------------------------------------------------------------------===//
702 // spirv.LogicalEqualOp
703 //===----------------------------------------------------------------------===//
704 
705 OpFoldResult
fold(spirv::LogicalEqualOp::FoldAdaptor adaptor)706 spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
707   // x == x -> true
708   if (getOperand1() == getOperand2()) {
709     auto trueAttr = BoolAttr::get(getContext(), true);
710     if (isa<IntegerType>(getType()))
711       return trueAttr;
712     if (auto vecTy = dyn_cast<VectorType>(getType()))
713       return SplatElementsAttr::get(vecTy, trueAttr);
714   }
715 
716   return constFoldBinaryOp<IntegerAttr>(
717       adaptor.getOperands(), [](const APInt &a, const APInt &b) {
718         return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
719       });
720 }
721 
722 //===----------------------------------------------------------------------===//
723 // spirv.LogicalNotEqualOp
724 //===----------------------------------------------------------------------===//
725 
fold(FoldAdaptor adaptor)726 OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
727   if (std::optional<bool> rhs =
728           getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
729     // x != false -> x
730     if (!rhs.value())
731       return getOperand1();
732   }
733 
734   // x == x -> false
735   if (getOperand1() == getOperand2()) {
736     auto falseAttr = BoolAttr::get(getContext(), false);
737     if (isa<IntegerType>(getType()))
738       return falseAttr;
739     if (auto vecTy = dyn_cast<VectorType>(getType()))
740       return SplatElementsAttr::get(vecTy, falseAttr);
741   }
742 
743   return constFoldBinaryOp<IntegerAttr>(
744       adaptor.getOperands(), [](const APInt &a, const APInt &b) {
745         return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
746       });
747 }
748 
749 //===----------------------------------------------------------------------===//
750 // spirv.LogicalNot
751 //===----------------------------------------------------------------------===//
752 
fold(FoldAdaptor adaptor)753 OpFoldResult spirv::LogicalNotOp::fold(FoldAdaptor adaptor) {
754   // !(!x) = x
755   auto op = getOperand();
756   if (auto notOp = op.getDefiningOp<spirv::LogicalNotOp>())
757     return notOp->getOperand(0);
758 
759   // According to the SPIR-V spec:
760   //
761   // Complement the bits of Operand.
762   return constFoldUnaryOp<IntegerAttr>(adaptor.getOperands(),
763                                        [](const APInt &a) {
764                                          APInt zero = APInt::getZero(1);
765                                          return a == 1 ? zero : (zero + 1);
766                                        });
767 }
768 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)769 void spirv::LogicalNotOp::getCanonicalizationPatterns(
770     RewritePatternSet &results, MLIRContext *context) {
771   results
772       .add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
773            ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
774           context);
775 }
776 
777 //===----------------------------------------------------------------------===//
778 // spirv.LogicalOr
779 //===----------------------------------------------------------------------===//
780 
fold(FoldAdaptor adaptor)781 OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
782   if (auto rhs = getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
783     if (*rhs) {
784       // x || true = true
785       return adaptor.getOperand2();
786     }
787 
788     if (!*rhs) {
789       // x || false = x
790       return getOperand1();
791     }
792   }
793 
794   return Attribute();
795 }
796 
797 //===----------------------------------------------------------------------===//
798 // spirv.SelectOp
799 //===----------------------------------------------------------------------===//
800 
fold(FoldAdaptor adaptor)801 OpFoldResult spirv::SelectOp::fold(FoldAdaptor adaptor) {
802   // spirv.Select _ x x -> x
803   Value trueVals = getTrueValue();
804   Value falseVals = getFalseValue();
805   if (trueVals == falseVals)
806     return trueVals;
807 
808   ArrayRef<Attribute> operands = adaptor.getOperands();
809 
810   // spirv.Select true  x y -> x
811   // spirv.Select false x y -> y
812   if (auto boolAttr = getScalarOrSplatBoolAttr(operands[0]))
813     return *boolAttr ? trueVals : falseVals;
814 
815   // Check that all the operands are constant
816   if (!operands[0] || !operands[1] || !operands[2])
817     return Attribute();
818 
819   // Note: getScalarOrSplatBoolAttr will always return a boolAttr if we are in
820   // the scalar case. Hence, we are only required to consider the case of
821   // DenseElementsAttr in foldSelectOp.
822   auto condAttrs = dyn_cast<DenseElementsAttr>(operands[0]);
823   auto trueAttrs = dyn_cast<DenseElementsAttr>(operands[1]);
824   auto falseAttrs = dyn_cast<DenseElementsAttr>(operands[2]);
825   if (!condAttrs || !trueAttrs || !falseAttrs)
826     return Attribute();
827 
828   auto elementResults = llvm::to_vector<4>(trueAttrs.getValues<Attribute>());
829   auto iters = llvm::zip_equal(elementResults, condAttrs.getValues<BoolAttr>(),
830                                falseAttrs.getValues<Attribute>());
831   for (auto [result, cond, falseRes] : iters) {
832     if (!cond.getValue())
833       result = falseRes;
834   }
835 
836   auto resultType = trueAttrs.getType();
837   return DenseElementsAttr::get(cast<ShapedType>(resultType), elementResults);
838 }
839 
840 //===----------------------------------------------------------------------===//
841 // spirv.IEqualOp
842 //===----------------------------------------------------------------------===//
843 
fold(spirv::IEqualOp::FoldAdaptor adaptor)844 OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
845   // x == x -> true
846   if (getOperand1() == getOperand2()) {
847     auto trueAttr = BoolAttr::get(getContext(), true);
848     if (isa<IntegerType>(getType()))
849       return trueAttr;
850     if (auto vecTy = dyn_cast<VectorType>(getType()))
851       return SplatElementsAttr::get(vecTy, trueAttr);
852   }
853 
854   return constFoldBinaryOp<IntegerAttr>(
855       adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
856         return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
857       });
858 }
859 
860 //===----------------------------------------------------------------------===//
861 // spirv.INotEqualOp
862 //===----------------------------------------------------------------------===//
863 
fold(spirv::INotEqualOp::FoldAdaptor adaptor)864 OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
865   // x == x -> false
866   if (getOperand1() == getOperand2()) {
867     auto falseAttr = BoolAttr::get(getContext(), false);
868     if (isa<IntegerType>(getType()))
869       return falseAttr;
870     if (auto vecTy = dyn_cast<VectorType>(getType()))
871       return SplatElementsAttr::get(vecTy, falseAttr);
872   }
873 
874   return constFoldBinaryOp<IntegerAttr>(
875       adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
876         return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
877       });
878 }
879 
880 //===----------------------------------------------------------------------===//
881 // spirv.SGreaterThan
882 //===----------------------------------------------------------------------===//
883 
884 OpFoldResult
fold(spirv::SGreaterThanOp::FoldAdaptor adaptor)885 spirv::SGreaterThanOp::fold(spirv::SGreaterThanOp::FoldAdaptor adaptor) {
886   // x == x -> false
887   if (getOperand1() == getOperand2()) {
888     auto falseAttr = BoolAttr::get(getContext(), false);
889     if (isa<IntegerType>(getType()))
890       return falseAttr;
891     if (auto vecTy = dyn_cast<VectorType>(getType()))
892       return SplatElementsAttr::get(vecTy, falseAttr);
893   }
894 
895   return constFoldBinaryOp<IntegerAttr>(
896       adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
897         return a.sgt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
898       });
899 }
900 
901 //===----------------------------------------------------------------------===//
902 // spirv.SGreaterThanEqual
903 //===----------------------------------------------------------------------===//
904 
fold(spirv::SGreaterThanEqualOp::FoldAdaptor adaptor)905 OpFoldResult spirv::SGreaterThanEqualOp::fold(
906     spirv::SGreaterThanEqualOp::FoldAdaptor adaptor) {
907   // x == x -> true
908   if (getOperand1() == getOperand2()) {
909     auto trueAttr = BoolAttr::get(getContext(), true);
910     if (isa<IntegerType>(getType()))
911       return trueAttr;
912     if (auto vecTy = dyn_cast<VectorType>(getType()))
913       return SplatElementsAttr::get(vecTy, trueAttr);
914   }
915 
916   return constFoldBinaryOp<IntegerAttr>(
917       adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
918         return a.sge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
919       });
920 }
921 
922 //===----------------------------------------------------------------------===//
923 // spirv.UGreaterThan
924 //===----------------------------------------------------------------------===//
925 
926 OpFoldResult
fold(spirv::UGreaterThanOp::FoldAdaptor adaptor)927 spirv::UGreaterThanOp::fold(spirv::UGreaterThanOp::FoldAdaptor adaptor) {
928   // x == x -> false
929   if (getOperand1() == getOperand2()) {
930     auto falseAttr = BoolAttr::get(getContext(), false);
931     if (isa<IntegerType>(getType()))
932       return falseAttr;
933     if (auto vecTy = dyn_cast<VectorType>(getType()))
934       return SplatElementsAttr::get(vecTy, falseAttr);
935   }
936 
937   return constFoldBinaryOp<IntegerAttr>(
938       adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
939         return a.ugt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
940       });
941 }
942 
943 //===----------------------------------------------------------------------===//
944 // spirv.UGreaterThanEqual
945 //===----------------------------------------------------------------------===//
946 
fold(spirv::UGreaterThanEqualOp::FoldAdaptor adaptor)947 OpFoldResult spirv::UGreaterThanEqualOp::fold(
948     spirv::UGreaterThanEqualOp::FoldAdaptor adaptor) {
949   // x == x -> true
950   if (getOperand1() == getOperand2()) {
951     auto trueAttr = BoolAttr::get(getContext(), true);
952     if (isa<IntegerType>(getType()))
953       return trueAttr;
954     if (auto vecTy = dyn_cast<VectorType>(getType()))
955       return SplatElementsAttr::get(vecTy, trueAttr);
956   }
957 
958   return constFoldBinaryOp<IntegerAttr>(
959       adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
960         return a.uge(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
961       });
962 }
963 
964 //===----------------------------------------------------------------------===//
965 // spirv.SLessThan
966 //===----------------------------------------------------------------------===//
967 
fold(spirv::SLessThanOp::FoldAdaptor adaptor)968 OpFoldResult spirv::SLessThanOp::fold(spirv::SLessThanOp::FoldAdaptor adaptor) {
969   // x == x -> false
970   if (getOperand1() == getOperand2()) {
971     auto falseAttr = BoolAttr::get(getContext(), false);
972     if (isa<IntegerType>(getType()))
973       return falseAttr;
974     if (auto vecTy = dyn_cast<VectorType>(getType()))
975       return SplatElementsAttr::get(vecTy, falseAttr);
976   }
977 
978   return constFoldBinaryOp<IntegerAttr>(
979       adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
980         return a.slt(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
981       });
982 }
983 
984 //===----------------------------------------------------------------------===//
985 // spirv.SLessThanEqual
986 //===----------------------------------------------------------------------===//
987 
988 OpFoldResult
fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor)989 spirv::SLessThanEqualOp::fold(spirv::SLessThanEqualOp::FoldAdaptor adaptor) {
990   // x == x -> true
991   if (getOperand1() == getOperand2()) {
992     auto trueAttr = BoolAttr::get(getContext(), true);
993     if (isa<IntegerType>(getType()))
994       return trueAttr;
995     if (auto vecTy = dyn_cast<VectorType>(getType()))
996       return SplatElementsAttr::get(vecTy, trueAttr);
997   }
998 
999   return constFoldBinaryOp<IntegerAttr>(
1000       adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1001         return a.sle(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1002       });
1003 }
1004 
1005 //===----------------------------------------------------------------------===//
1006 // spirv.ULessThan
1007 //===----------------------------------------------------------------------===//
1008 
fold(spirv::ULessThanOp::FoldAdaptor adaptor)1009 OpFoldResult spirv::ULessThanOp::fold(spirv::ULessThanOp::FoldAdaptor adaptor) {
1010   // x == x -> false
1011   if (getOperand1() == getOperand2()) {
1012     auto falseAttr = BoolAttr::get(getContext(), false);
1013     if (isa<IntegerType>(getType()))
1014       return falseAttr;
1015     if (auto vecTy = dyn_cast<VectorType>(getType()))
1016       return SplatElementsAttr::get(vecTy, falseAttr);
1017   }
1018 
1019   return constFoldBinaryOp<IntegerAttr>(
1020       adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1021         return a.ult(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1022       });
1023 }
1024 
1025 //===----------------------------------------------------------------------===//
1026 // spirv.ULessThanEqual
1027 //===----------------------------------------------------------------------===//
1028 
1029 OpFoldResult
fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor)1030 spirv::ULessThanEqualOp::fold(spirv::ULessThanEqualOp::FoldAdaptor adaptor) {
1031   // x == x -> true
1032   if (getOperand1() == getOperand2()) {
1033     auto trueAttr = BoolAttr::get(getContext(), true);
1034     if (isa<IntegerType>(getType()))
1035       return trueAttr;
1036     if (auto vecTy = dyn_cast<VectorType>(getType()))
1037       return SplatElementsAttr::get(vecTy, trueAttr);
1038   }
1039 
1040   return constFoldBinaryOp<IntegerAttr>(
1041       adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
1042         return a.ule(b) ? APInt::getAllOnes(1) : APInt::getZero(1);
1043       });
1044 }
1045 
1046 //===----------------------------------------------------------------------===//
1047 // spirv.ShiftLeftLogical
1048 //===----------------------------------------------------------------------===//
1049 
fold(spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor)1050 OpFoldResult spirv::ShiftLeftLogicalOp::fold(
1051     spirv::ShiftLeftLogicalOp::FoldAdaptor adaptor) {
1052   // x << 0 -> x
1053   if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1054     return getOperand1();
1055   }
1056 
1057   // Unfortunately due to below undefined behaviour can't fold 0 for Base.
1058 
1059   // Results are computed per component, and within each component, per bit...
1060   //
1061   // The result is undefined if Shift is greater than or equal to the bit width
1062   // of the components of Base.
1063   //
1064   // So we can use the APInt << method, but don't fold if undefined behaviour.
1065   bool shiftToLarge = false;
1066   auto res = constFoldBinaryOp<IntegerAttr>(
1067       adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1068         if (shiftToLarge || b.uge(a.getBitWidth())) {
1069           shiftToLarge = true;
1070           return a;
1071         }
1072         return a << b;
1073       });
1074   return shiftToLarge ? Attribute() : res;
1075 }
1076 
1077 //===----------------------------------------------------------------------===//
1078 // spirv.ShiftRightArithmetic
1079 //===----------------------------------------------------------------------===//
1080 
fold(spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor)1081 OpFoldResult spirv::ShiftRightArithmeticOp::fold(
1082     spirv::ShiftRightArithmeticOp::FoldAdaptor adaptor) {
1083   // x >> 0 -> x
1084   if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1085     return getOperand1();
1086   }
1087 
1088   // Unfortunately due to below undefined behaviour can't fold 0, -1 for Base.
1089 
1090   // Results are computed per component, and within each component, per bit...
1091   //
1092   // The result is undefined if Shift is greater than or equal to the bit width
1093   // of the components of Base.
1094   //
1095   // So we can use the APInt ashr method, but don't fold if undefined behaviour.
1096   bool shiftToLarge = false;
1097   auto res = constFoldBinaryOp<IntegerAttr>(
1098       adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1099         if (shiftToLarge || b.uge(a.getBitWidth())) {
1100           shiftToLarge = true;
1101           return a;
1102         }
1103         return a.ashr(b);
1104       });
1105   return shiftToLarge ? Attribute() : res;
1106 }
1107 
1108 //===----------------------------------------------------------------------===//
1109 // spirv.ShiftRightLogical
1110 //===----------------------------------------------------------------------===//
1111 
fold(spirv::ShiftRightLogicalOp::FoldAdaptor adaptor)1112 OpFoldResult spirv::ShiftRightLogicalOp::fold(
1113     spirv::ShiftRightLogicalOp::FoldAdaptor adaptor) {
1114   // x >> 0 -> x
1115   if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1116     return getOperand1();
1117   }
1118 
1119   // Unfortunately due to below undefined behaviour can't fold 0 for Base.
1120 
1121   // Results are computed per component, and within each component, per bit...
1122   //
1123   // The result is undefined if Shift is greater than or equal to the bit width
1124   // of the components of Base.
1125   //
1126   // So we can use the APInt lshr method, but don't fold if undefined behaviour.
1127   bool shiftToLarge = false;
1128   auto res = constFoldBinaryOp<IntegerAttr>(
1129       adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
1130         if (shiftToLarge || b.uge(a.getBitWidth())) {
1131           shiftToLarge = true;
1132           return a;
1133         }
1134         return a.lshr(b);
1135       });
1136   return shiftToLarge ? Attribute() : res;
1137 }
1138 
1139 //===----------------------------------------------------------------------===//
1140 // spirv.BitwiseAndOp
1141 //===----------------------------------------------------------------------===//
1142 
1143 OpFoldResult
fold(spirv::BitwiseAndOp::FoldAdaptor adaptor)1144 spirv::BitwiseAndOp::fold(spirv::BitwiseAndOp::FoldAdaptor adaptor) {
1145   // x & x -> x
1146   if (getOperand1() == getOperand2()) {
1147     return getOperand1();
1148   }
1149 
1150   APInt rhsMask;
1151   if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
1152     // x & 0 -> 0
1153     if (rhsMask.isZero())
1154       return getOperand2();
1155 
1156     // x & <all ones> -> x
1157     if (rhsMask.isAllOnes())
1158       return getOperand1();
1159 
1160     // (UConvert x : iN to iK) & <mask with N low bits set> -> UConvert x
1161     if (auto zext = getOperand1().getDefiningOp<spirv::UConvertOp>()) {
1162       int valueBits =
1163           getElementTypeOrSelf(zext.getOperand()).getIntOrFloatBitWidth();
1164       if (rhsMask.zextOrTrunc(valueBits).isAllOnes())
1165         return getOperand1();
1166     }
1167   }
1168 
1169   // According to the SPIR-V spec:
1170   //
1171   // Type is a scalar or vector of integer type.
1172   // Results are computed per component, and within each component, per bit.
1173   // So we can use the APInt & method.
1174   return constFoldBinaryOp<IntegerAttr>(
1175       adaptor.getOperands(),
1176       [](const APInt &a, const APInt &b) { return a & b; });
1177 }
1178 
1179 //===----------------------------------------------------------------------===//
1180 // spirv.BitwiseOrOp
1181 //===----------------------------------------------------------------------===//
1182 
fold(spirv::BitwiseOrOp::FoldAdaptor adaptor)1183 OpFoldResult spirv::BitwiseOrOp::fold(spirv::BitwiseOrOp::FoldAdaptor adaptor) {
1184   // x | x -> x
1185   if (getOperand1() == getOperand2()) {
1186     return getOperand1();
1187   }
1188 
1189   APInt rhsMask;
1190   if (matchPattern(adaptor.getOperand2(), m_ConstantInt(&rhsMask))) {
1191     // x | 0 -> x
1192     if (rhsMask.isZero())
1193       return getOperand1();
1194 
1195     // x | <all ones> -> <all ones>
1196     if (rhsMask.isAllOnes())
1197       return getOperand2();
1198   }
1199 
1200   // According to the SPIR-V spec:
1201   //
1202   // Type is a scalar or vector of integer type.
1203   // Results are computed per component, and within each component, per bit.
1204   // So we can use the APInt | method.
1205   return constFoldBinaryOp<IntegerAttr>(
1206       adaptor.getOperands(),
1207       [](const APInt &a, const APInt &b) { return a | b; });
1208 }
1209 
1210 //===----------------------------------------------------------------------===//
1211 // spirv.BitwiseXorOp
1212 //===----------------------------------------------------------------------===//
1213 
1214 OpFoldResult
fold(spirv::BitwiseXorOp::FoldAdaptor adaptor)1215 spirv::BitwiseXorOp::fold(spirv::BitwiseXorOp::FoldAdaptor adaptor) {
1216   // x ^ 0 -> x
1217   if (matchPattern(adaptor.getOperand2(), m_Zero())) {
1218     return getOperand1();
1219   }
1220 
1221   // x ^ x -> 0
1222   if (getOperand1() == getOperand2())
1223     return Builder(getContext()).getZeroAttr(getType());
1224 
1225   // According to the SPIR-V spec:
1226   //
1227   // Type is a scalar or vector of integer type.
1228   // Results are computed per component, and within each component, per bit.
1229   // So we can use the APInt ^ method.
1230   return constFoldBinaryOp<IntegerAttr>(
1231       adaptor.getOperands(),
1232       [](const APInt &a, const APInt &b) { return a ^ b; });
1233 }
1234 
1235 //===----------------------------------------------------------------------===//
1236 // spirv.mlir.selection
1237 //===----------------------------------------------------------------------===//
1238 
1239 namespace {
1240 // Blocks from the given `spirv.mlir.selection` operation must satisfy the
1241 // following layout:
1242 //
1243 //       +-----------------------------------------------+
1244 //       | header block                                  |
1245 //       | spirv.BranchConditionalOp %cond, ^case0, ^case1 |
1246 //       +-----------------------------------------------+
1247 //                            /   \
1248 //                             ...
1249 //
1250 //
1251 //   +------------------------+    +------------------------+
1252 //   | case #0                |    | case #1                |
1253 //   | spirv.Store %ptr %value0 |    | spirv.Store %ptr %value1 |
1254 //   | spirv.Branch ^merge      |    | spirv.Branch ^merge      |
1255 //   +------------------------+    +------------------------+
1256 //
1257 //
1258 //                             ...
1259 //                            \   /
1260 //                              v
1261 //                       +-------------+
1262 //                       | merge block |
1263 //                       +-------------+
1264 //
1265 struct ConvertSelectionOpToSelect final : OpRewritePattern<spirv::SelectionOp> {
1266   using OpRewritePattern::OpRewritePattern;
1267 
matchAndRewrite__anona5f61ae82511::ConvertSelectionOpToSelect1268   LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
1269                                 PatternRewriter &rewriter) const override {
1270     Operation *op = selectionOp.getOperation();
1271     Region &body = op->getRegion(0);
1272     // Verifier allows an empty region for `spirv.mlir.selection`.
1273     if (body.empty()) {
1274       return failure();
1275     }
1276 
1277     // Check that region consists of 4 blocks:
1278     // header block, `true` block, `false` block and merge block.
1279     if (llvm::range_size(body) != 4) {
1280       return failure();
1281     }
1282 
1283     Block *headerBlock = selectionOp.getHeaderBlock();
1284     if (!onlyContainsBranchConditionalOp(headerBlock)) {
1285       return failure();
1286     }
1287 
1288     auto brConditionalOp =
1289         cast<spirv::BranchConditionalOp>(headerBlock->front());
1290 
1291     Block *trueBlock = brConditionalOp.getSuccessor(0);
1292     Block *falseBlock = brConditionalOp.getSuccessor(1);
1293     Block *mergeBlock = selectionOp.getMergeBlock();
1294 
1295     if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
1296       return failure();
1297 
1298     Value trueValue = getSrcValue(trueBlock);
1299     Value falseValue = getSrcValue(falseBlock);
1300     Value ptrValue = getDstPtr(trueBlock);
1301     auto storeOpAttributes =
1302         cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
1303 
1304     auto selectOp = rewriter.create<spirv::SelectOp>(
1305         selectionOp.getLoc(), trueValue.getType(),
1306         brConditionalOp.getCondition(), trueValue, falseValue);
1307     rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
1308                                     selectOp.getResult(), storeOpAttributes);
1309 
1310     // `spirv.mlir.selection` is not needed anymore.
1311     rewriter.eraseOp(op);
1312     return success();
1313   }
1314 
1315 private:
1316   // Checks that given blocks follow the following rules:
1317   // 1. Each conditional block consists of two operations, the first operation
1318   //    is a `spirv.Store` and the last operation is a `spirv.Branch`.
1319   // 2. Each `spirv.Store` uses the same pointer and the same memory attributes.
1320   // 3. A control flow goes into the given merge block from the given
1321   //    conditional blocks.
1322   LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
1323                                          Block *mergeBlock) const;
1324 
onlyContainsBranchConditionalOp__anona5f61ae82511::ConvertSelectionOpToSelect1325   bool onlyContainsBranchConditionalOp(Block *block) const {
1326     return llvm::hasSingleElement(*block) &&
1327            isa<spirv::BranchConditionalOp>(block->front());
1328   }
1329 
isSameAttrList__anona5f61ae82511::ConvertSelectionOpToSelect1330   bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
1331     return lhs->getDiscardableAttrDictionary() ==
1332                rhs->getDiscardableAttrDictionary() &&
1333            lhs.getProperties() == rhs.getProperties();
1334   }
1335 
1336   // Returns a source value for the given block.
getSrcValue__anona5f61ae82511::ConvertSelectionOpToSelect1337   Value getSrcValue(Block *block) const {
1338     auto storeOp = cast<spirv::StoreOp>(block->front());
1339     return storeOp.getValue();
1340   }
1341 
1342   // Returns a destination value for the given block.
getDstPtr__anona5f61ae82511::ConvertSelectionOpToSelect1343   Value getDstPtr(Block *block) const {
1344     auto storeOp = cast<spirv::StoreOp>(block->front());
1345     return storeOp.getPtr();
1346   }
1347 };
1348 
canCanonicalizeSelection(Block * trueBlock,Block * falseBlock,Block * mergeBlock) const1349 LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
1350     Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
1351   // Each block must consists of 2 operations.
1352   if (llvm::range_size(*trueBlock) != 2 || llvm::range_size(*falseBlock) != 2) {
1353     return failure();
1354   }
1355 
1356   auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
1357   auto trueBrBranchOp =
1358       dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
1359   auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
1360   auto falseBrBranchOp =
1361       dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
1362 
1363   if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
1364       !falseBrBranchOp) {
1365     return failure();
1366   }
1367 
1368   // Checks that given type is valid for `spirv.SelectOp`.
1369   // According to SPIR-V spec:
1370   // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
1371   // Starting with version 1.4, Result Type can additionally be a composite type
1372   // other than a vector."
1373   bool isScalarOrVector =
1374       llvm::cast<spirv::SPIRVType>(trueBrStoreOp.getValue().getType())
1375           .isScalarOrVector();
1376 
1377   // Check that each `spirv.Store` uses the same pointer, memory access
1378   // attributes and a valid type of the value.
1379   if ((trueBrStoreOp.getPtr() != falseBrStoreOp.getPtr()) ||
1380       !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
1381     return failure();
1382   }
1383 
1384   if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
1385       (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
1386     return failure();
1387   }
1388 
1389   return success();
1390 }
1391 } // namespace
1392 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1393 void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1394                                                      MLIRContext *context) {
1395   results.add<ConvertSelectionOpToSelect>(context);
1396 }
1397