1fec6c5acSUday Bondhugula //===- TestPatterns.cpp - Test dialect pattern driver ---------------------===// 2fec6c5acSUday Bondhugula // 3fec6c5acSUday Bondhugula // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4fec6c5acSUday Bondhugula // See https://llvm.org/LICENSE.txt for license information. 5fec6c5acSUday Bondhugula // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6fec6c5acSUday Bondhugula // 7fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 8fec6c5acSUday Bondhugula 9fec6c5acSUday Bondhugula #include "TestDialect.h" 10e95e94adSJeff Niu #include "TestOps.h" 119c5982efSAlex Zinenko #include "TestTypes.h" 12abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 1323aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 1423aa5a74SRiver Riddle #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 15c0a6318dSMatthias Springer #include "mlir/Dialect/Tensor/IR/Tensor.h" 160f8a6b7dSJakub Kuderski #include "mlir/IR/BuiltinAttributes.h" 172bf423b0SRob Suderman #include "mlir/IR/Matchers.h" 180f8a6b7dSJakub Kuderski #include "mlir/IR/Visitors.h" 19fec6c5acSUday Bondhugula #include "mlir/Pass/Pass.h" 20fec6c5acSUday Bondhugula #include "mlir/Transforms/DialectConversion.h" 2126f93d9fSAlex Zinenko #include "mlir/Transforms/FoldUtils.h" 22b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 230f8a6b7dSJakub Kuderski #include "mlir/Transforms/WalkPatternRewriteDriver.h" 24dc3dc974SMehdi Amini #include "llvm/ADT/ScopeExit.h" 250f8a6b7dSJakub Kuderski #include <cstdint> 269ba37b3bSJacques Pienaar 27fec6c5acSUday Bondhugula using namespace mlir; 287776b19eSStephen Neuendorffer using namespace test; 29fec6c5acSUday Bondhugula 30fec6c5acSUday Bondhugula // Native function for testing NativeCodeCall 31fec6c5acSUday Bondhugula static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { 32fec6c5acSUday Bondhugula return choice.getValue() ? input1 : input2; 33fec6c5acSUday Bondhugula } 34fec6c5acSUday Bondhugula 3529429d1aSJacques Pienaar static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { 3629429d1aSJacques Pienaar rewriter.create<OpI>(loc, input); 37fec6c5acSUday Bondhugula } 38fec6c5acSUday Bondhugula 39fec6c5acSUday Bondhugula static void handleNoResultOp(PatternRewriter &rewriter, 40fec6c5acSUday Bondhugula OpSymbolBindingNoResult op) { 41fec6c5acSUday Bondhugula // Turn the no result op to a one-result op. 426a994233SJacques Pienaar rewriter.create<OpSymbolBindingB>(op.getLoc(), op.getOperand().getType(), 436a994233SJacques Pienaar op.getOperand()); 44fec6c5acSUday Bondhugula } 45fec6c5acSUday Bondhugula 4634b5482bSChia-hung Duan static bool getFirstI32Result(Operation *op, Value &value) { 4734b5482bSChia-hung Duan if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) 4834b5482bSChia-hung Duan return false; 4934b5482bSChia-hung Duan value = op->getResult(0); 5034b5482bSChia-hung Duan return true; 5134b5482bSChia-hung Duan } 5234b5482bSChia-hung Duan 5334b5482bSChia-hung Duan static Value bindNativeCodeCallResult(Value value) { return value; } 5434b5482bSChia-hung Duan 55d7314b3cSChia-hung Duan static SmallVector<Value, 2> bindMultipleNativeCodeCallResult(Value input1, 56d7314b3cSChia-hung Duan Value input2) { 57d7314b3cSChia-hung Duan return SmallVector<Value, 2>({input2, input1}); 58d7314b3cSChia-hung Duan } 59d7314b3cSChia-hung Duan 6001641197SAlexEichenberger // Test that natives calls are only called once during rewrites. 6101641197SAlexEichenberger // OpM_Test will return Pi, increased by 1 for each subsequent calls. 6201641197SAlexEichenberger // This let us check the number of times OpM_Test was called by inspecting 6301641197SAlexEichenberger // the returned value in the MLIR output. 6401641197SAlexEichenberger static int64_t opMIncreasingValue = 314159265; 6502b6fb21SMehdi Amini static Attribute opMTest(PatternRewriter &rewriter, Value val) { 6601641197SAlexEichenberger int64_t i = opMIncreasingValue++; 6701641197SAlexEichenberger return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); 6801641197SAlexEichenberger } 6901641197SAlexEichenberger 70fec6c5acSUday Bondhugula namespace { 71fec6c5acSUday Bondhugula #include "TestPatterns.inc" 72be0a7e9fSMehdi Amini } // namespace 73fec6c5acSUday Bondhugula 74fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 75c484c7ddSChia-hung Duan // Test Reduce Pattern Interface 76c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===// 77c484c7ddSChia-hung Duan 787776b19eSStephen Neuendorffer void test::populateTestReductionPatterns(RewritePatternSet &patterns) { 79c484c7ddSChia-hung Duan populateWithGenerated(patterns); 80c484c7ddSChia-hung Duan } 81c484c7ddSChia-hung Duan 82c484c7ddSChia-hung Duan //===----------------------------------------------------------------------===// 83fec6c5acSUday Bondhugula // Canonicalizer Driver. 84fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 85fec6c5acSUday Bondhugula 86fec6c5acSUday Bondhugula namespace { 8726f93d9fSAlex Zinenko struct FoldingPattern : public RewritePattern { 8826f93d9fSAlex Zinenko public: 8926f93d9fSAlex Zinenko FoldingPattern(MLIRContext *context) 9026f93d9fSAlex Zinenko : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), 9126f93d9fSAlex Zinenko /*benefit=*/1, context) {} 9226f93d9fSAlex Zinenko 9326f93d9fSAlex Zinenko LogicalResult matchAndRewrite(Operation *op, 9426f93d9fSAlex Zinenko PatternRewriter &rewriter) const override { 95506fd672SMatthias Springer // Exercise createOrFold API for a single-result operation that is folded 96506fd672SMatthias Springer // upon construction. The operation being created has an in-place folder, 97506fd672SMatthias Springer // and it should be still present in the output. Furthermore, the folder 98506fd672SMatthias Springer // should not crash when attempting to recover the (unchanged) operation 99506fd672SMatthias Springer // result. 100506fd672SMatthias Springer Value result = rewriter.createOrFold<TestOpInPlaceFold>( 101506fd672SMatthias Springer op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0)); 10226f93d9fSAlex Zinenko assert(result); 10326f93d9fSAlex Zinenko rewriter.replaceOp(op, result); 10426f93d9fSAlex Zinenko return success(); 10526f93d9fSAlex Zinenko } 10626f93d9fSAlex Zinenko }; 10726f93d9fSAlex Zinenko 108e4635e63SRiver Riddle /// This pattern creates a foldable operation at the entry point of the block. 109e4635e63SRiver Riddle /// This tests the situation where the operation folder will need to replace an 110e4635e63SRiver Riddle /// operation with a previously created constant that does not initially 111e4635e63SRiver Riddle /// dominate the operation to replace. 112e4635e63SRiver Riddle struct FolderInsertBeforePreviouslyFoldedConstantPattern 113e4635e63SRiver Riddle : public OpRewritePattern<TestCastOp> { 114e4635e63SRiver Riddle public: 115e4635e63SRiver Riddle using OpRewritePattern<TestCastOp>::OpRewritePattern; 116e4635e63SRiver Riddle 117e4635e63SRiver Riddle LogicalResult matchAndRewrite(TestCastOp op, 118e4635e63SRiver Riddle PatternRewriter &rewriter) const override { 119e4635e63SRiver Riddle if (!op->hasAttr("test_fold_before_previously_folded_op")) 120e4635e63SRiver Riddle return failure(); 121e4635e63SRiver Riddle rewriter.setInsertionPointToStart(op->getBlock()); 122e4635e63SRiver Riddle 123a54f4eaeSMogball auto constOp = rewriter.create<arith::ConstantOp>( 124a54f4eaeSMogball op.getLoc(), rewriter.getBoolAttr(true)); 125e4635e63SRiver Riddle rewriter.replaceOpWithNewOp<TestCastOp>(op, rewriter.getI32Type(), 126e4635e63SRiver Riddle Value(constOp)); 127e4635e63SRiver Riddle return success(); 128e4635e63SRiver Riddle } 129e4635e63SRiver Riddle }; 130e4635e63SRiver Riddle 13143f0d5f9SMehdi Amini /// This pattern matches test.op_commutative2 with the first operand being 13243f0d5f9SMehdi Amini /// another test.op_commutative2 with a constant on the right side and fold it 13343f0d5f9SMehdi Amini /// away by propagating it as its result. This is intend to check that patterns 13443f0d5f9SMehdi Amini /// are applied after the commutative property moves constant to the right. 13543f0d5f9SMehdi Amini struct FolderCommutativeOp2WithConstant 13643f0d5f9SMehdi Amini : public OpRewritePattern<TestCommutative2Op> { 13743f0d5f9SMehdi Amini public: 13843f0d5f9SMehdi Amini using OpRewritePattern<TestCommutative2Op>::OpRewritePattern; 13943f0d5f9SMehdi Amini 14043f0d5f9SMehdi Amini LogicalResult matchAndRewrite(TestCommutative2Op op, 14143f0d5f9SMehdi Amini PatternRewriter &rewriter) const override { 14243f0d5f9SMehdi Amini auto operand = 14343f0d5f9SMehdi Amini dyn_cast_or_null<TestCommutative2Op>(op->getOperand(0).getDefiningOp()); 14443f0d5f9SMehdi Amini if (!operand) 14543f0d5f9SMehdi Amini return failure(); 14643f0d5f9SMehdi Amini Attribute constInput; 14743f0d5f9SMehdi Amini if (!matchPattern(operand->getOperand(1), m_Constant(&constInput))) 14843f0d5f9SMehdi Amini return failure(); 14943f0d5f9SMehdi Amini rewriter.replaceOp(op, operand->getOperand(1)); 15043f0d5f9SMehdi Amini return success(); 15143f0d5f9SMehdi Amini } 15243f0d5f9SMehdi Amini }; 15343f0d5f9SMehdi Amini 1542fea658aSMatthias Springer /// This pattern matches test.any_attr_of_i32_str ops. In case of an integer 1552fea658aSMatthias Springer /// attribute with value smaller than MaxVal, it increments the value by 1. 1562fea658aSMatthias Springer template <int MaxVal> 1572fea658aSMatthias Springer struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> { 1582fea658aSMatthias Springer using OpRewritePattern<AnyAttrOfOp>::OpRewritePattern; 1592fea658aSMatthias Springer 1602fea658aSMatthias Springer LogicalResult matchAndRewrite(AnyAttrOfOp op, 1612fea658aSMatthias Springer PatternRewriter &rewriter) const override { 1625550c821STres Popp auto intAttr = dyn_cast<IntegerAttr>(op.getAttr()); 1632fea658aSMatthias Springer if (!intAttr) 1642fea658aSMatthias Springer return failure(); 1652fea658aSMatthias Springer int64_t val = intAttr.getInt(); 1662fea658aSMatthias Springer if (val >= MaxVal) 1672fea658aSMatthias Springer return failure(); 1685fcf907bSMatthias Springer rewriter.modifyOpInPlace( 1692fea658aSMatthias Springer op, [&]() { op.setAttrAttr(rewriter.getI32IntegerAttr(val + 1)); }); 1702fea658aSMatthias Springer return success(); 1712fea658aSMatthias Springer } 1722fea658aSMatthias Springer }; 1732fea658aSMatthias Springer 174ed9194beSMatthias Springer /// This patterns adds an "eligible" attribute to "foo.maybe_eligible_op". 175ed9194beSMatthias Springer struct MakeOpEligible : public RewritePattern { 176ed9194beSMatthias Springer MakeOpEligible(MLIRContext *context) 177ed9194beSMatthias Springer : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {} 178ed9194beSMatthias Springer 179ed9194beSMatthias Springer LogicalResult matchAndRewrite(Operation *op, 180ed9194beSMatthias Springer PatternRewriter &rewriter) const override { 181ed9194beSMatthias Springer if (op->hasAttr("eligible")) 182ed9194beSMatthias Springer return failure(); 1835fcf907bSMatthias Springer rewriter.modifyOpInPlace( 184ed9194beSMatthias Springer op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); }); 185ed9194beSMatthias Springer return success(); 186ed9194beSMatthias Springer } 187ed9194beSMatthias Springer }; 188ed9194beSMatthias Springer 189ed9194beSMatthias Springer /// This pattern hoists eligible ops out of a "test.one_region_op". 190ed9194beSMatthias Springer struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> { 191ed9194beSMatthias Springer using OpRewritePattern<test::OneRegionOp>::OpRewritePattern; 192ed9194beSMatthias Springer 193ed9194beSMatthias Springer LogicalResult matchAndRewrite(test::OneRegionOp op, 194ed9194beSMatthias Springer PatternRewriter &rewriter) const override { 195ed9194beSMatthias Springer Operation *terminator = op.getRegion().front().getTerminator(); 196ed9194beSMatthias Springer Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp(); 197ed9194beSMatthias Springer if (toBeHoisted->getParentOp() != op) 198ed9194beSMatthias Springer return failure(); 199ed9194beSMatthias Springer if (!toBeHoisted->hasAttr("eligible")) 200ed9194beSMatthias Springer return failure(); 2015cc0f76dSMatthias Springer rewriter.moveOpBefore(toBeHoisted, op); 202ed9194beSMatthias Springer return success(); 203ed9194beSMatthias Springer } 204ed9194beSMatthias Springer }; 205ed9194beSMatthias Springer 206da784a25SMatthias Springer /// This pattern moves "test.move_before_parent_op" before the parent op. 207da784a25SMatthias Springer struct MoveBeforeParentOp : public RewritePattern { 208da784a25SMatthias Springer MoveBeforeParentOp(MLIRContext *context) 209da784a25SMatthias Springer : RewritePattern("test.move_before_parent_op", /*benefit=*/1, context) {} 210da784a25SMatthias Springer 211da784a25SMatthias Springer LogicalResult matchAndRewrite(Operation *op, 212da784a25SMatthias Springer PatternRewriter &rewriter) const override { 213da784a25SMatthias Springer // Do not hoist past functions. 214da784a25SMatthias Springer if (isa<FunctionOpInterface>(op->getParentOp())) 215da784a25SMatthias Springer return failure(); 216da784a25SMatthias Springer rewriter.moveOpBefore(op, op->getParentOp()); 217da784a25SMatthias Springer return success(); 218da784a25SMatthias Springer } 219da784a25SMatthias Springer }; 220da784a25SMatthias Springer 2210f8a6b7dSJakub Kuderski /// This pattern moves "test.move_after_parent_op" after the parent op. 2220f8a6b7dSJakub Kuderski struct MoveAfterParentOp : public RewritePattern { 2230f8a6b7dSJakub Kuderski MoveAfterParentOp(MLIRContext *context) 2240f8a6b7dSJakub Kuderski : RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {} 2250f8a6b7dSJakub Kuderski 2260f8a6b7dSJakub Kuderski LogicalResult matchAndRewrite(Operation *op, 2270f8a6b7dSJakub Kuderski PatternRewriter &rewriter) const override { 2280f8a6b7dSJakub Kuderski // Do not hoist past functions. 2290f8a6b7dSJakub Kuderski if (isa<FunctionOpInterface>(op->getParentOp())) 2300f8a6b7dSJakub Kuderski return failure(); 2310f8a6b7dSJakub Kuderski 2320f8a6b7dSJakub Kuderski int64_t moveForwardBy = 0; 2330f8a6b7dSJakub Kuderski if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance")) 2340f8a6b7dSJakub Kuderski moveForwardBy = advanceBy.getInt(); 2350f8a6b7dSJakub Kuderski 2360f8a6b7dSJakub Kuderski Operation *moveAfter = op->getParentOp(); 2370f8a6b7dSJakub Kuderski for (int64_t i = 0; i < moveForwardBy; ++i) 2380f8a6b7dSJakub Kuderski moveAfter = moveAfter->getNextNode(); 2390f8a6b7dSJakub Kuderski 2400f8a6b7dSJakub Kuderski rewriter.moveOpAfter(op, moveAfter); 2410f8a6b7dSJakub Kuderski return success(); 2420f8a6b7dSJakub Kuderski } 2430f8a6b7dSJakub Kuderski }; 2440f8a6b7dSJakub Kuderski 245c672b342SMatthias Springer /// This pattern inlines blocks that are nested in 246c672b342SMatthias Springer /// "test.inline_blocks_into_parent" into the parent block. 247c672b342SMatthias Springer struct InlineBlocksIntoParent : public RewritePattern { 248c672b342SMatthias Springer InlineBlocksIntoParent(MLIRContext *context) 249c672b342SMatthias Springer : RewritePattern("test.inline_blocks_into_parent", /*benefit=*/1, 250c672b342SMatthias Springer context) {} 251c672b342SMatthias Springer 252c672b342SMatthias Springer LogicalResult matchAndRewrite(Operation *op, 253c672b342SMatthias Springer PatternRewriter &rewriter) const override { 254c672b342SMatthias Springer bool changed = false; 255c672b342SMatthias Springer for (Region &r : op->getRegions()) { 256c672b342SMatthias Springer while (!r.empty()) { 257c672b342SMatthias Springer rewriter.inlineBlockBefore(&r.front(), op); 258c672b342SMatthias Springer changed = true; 259c672b342SMatthias Springer } 260c672b342SMatthias Springer } 261c672b342SMatthias Springer return success(changed); 262c672b342SMatthias Springer } 263c672b342SMatthias Springer }; 264c672b342SMatthias Springer 265c2675ba9SMatthias Springer /// This pattern splits blocks at "test.split_block_here" and replaces the op 266c2675ba9SMatthias Springer /// with a new op (to prevent an infinite loop of block splitting). 267c2675ba9SMatthias Springer struct SplitBlockHere : public RewritePattern { 268c2675ba9SMatthias Springer SplitBlockHere(MLIRContext *context) 269c2675ba9SMatthias Springer : RewritePattern("test.split_block_here", /*benefit=*/1, context) {} 270c2675ba9SMatthias Springer 271c2675ba9SMatthias Springer LogicalResult matchAndRewrite(Operation *op, 272c2675ba9SMatthias Springer PatternRewriter &rewriter) const override { 273c2675ba9SMatthias Springer rewriter.splitBlock(op->getBlock(), op->getIterator()); 274c2675ba9SMatthias Springer Operation *newOp = rewriter.create( 275c2675ba9SMatthias Springer op->getLoc(), 276c2675ba9SMatthias Springer OperationName("test.new_op", op->getContext()).getIdentifier(), 277c2675ba9SMatthias Springer op->getOperands(), op->getResultTypes()); 278c2675ba9SMatthias Springer rewriter.replaceOp(op, newOp); 279c2675ba9SMatthias Springer return success(); 280c2675ba9SMatthias Springer } 281c2675ba9SMatthias Springer }; 282c2675ba9SMatthias Springer 283237a799eSMatthias Springer /// This pattern clones "test.clone_me" ops. 284237a799eSMatthias Springer struct CloneOp : public RewritePattern { 285237a799eSMatthias Springer CloneOp(MLIRContext *context) 286237a799eSMatthias Springer : RewritePattern("test.clone_me", /*benefit=*/1, context) {} 287237a799eSMatthias Springer 288237a799eSMatthias Springer LogicalResult matchAndRewrite(Operation *op, 289237a799eSMatthias Springer PatternRewriter &rewriter) const override { 290237a799eSMatthias Springer // Do not clone already cloned ops to avoid going into an infinite loop. 291237a799eSMatthias Springer if (op->hasAttr("was_cloned")) 292237a799eSMatthias Springer return failure(); 293237a799eSMatthias Springer Operation *cloned = rewriter.clone(*op); 294237a799eSMatthias Springer cloned->setAttr("was_cloned", rewriter.getUnitAttr()); 295237a799eSMatthias Springer return success(); 296237a799eSMatthias Springer } 297237a799eSMatthias Springer }; 298237a799eSMatthias Springer 299b840d296SMatthias Springer /// This pattern clones regions of "test.clone_region_before" ops before the 300b840d296SMatthias Springer /// parent block. 301b840d296SMatthias Springer struct CloneRegionBeforeOp : public RewritePattern { 302b840d296SMatthias Springer CloneRegionBeforeOp(MLIRContext *context) 303b840d296SMatthias Springer : RewritePattern("test.clone_region_before", /*benefit=*/1, context) {} 304b840d296SMatthias Springer 305b840d296SMatthias Springer LogicalResult matchAndRewrite(Operation *op, 306b840d296SMatthias Springer PatternRewriter &rewriter) const override { 307b840d296SMatthias Springer // Do not clone already cloned ops to avoid going into an infinite loop. 308b840d296SMatthias Springer if (op->hasAttr("was_cloned")) 309b840d296SMatthias Springer return failure(); 310b840d296SMatthias Springer for (Region &r : op->getRegions()) 311b840d296SMatthias Springer rewriter.cloneRegionBefore(r, op->getBlock()); 312b840d296SMatthias Springer op->setAttr("was_cloned", rewriter.getUnitAttr()); 313b840d296SMatthias Springer return success(); 314b840d296SMatthias Springer } 315b840d296SMatthias Springer }; 316b840d296SMatthias Springer 3170f8a6b7dSJakub Kuderski /// Replace an operation may introduce the re-visiting of its users. 3180f8a6b7dSJakub Kuderski class ReplaceWithNewOp : public RewritePattern { 3190f8a6b7dSJakub Kuderski public: 3200f8a6b7dSJakub Kuderski ReplaceWithNewOp(MLIRContext *context) 3210f8a6b7dSJakub Kuderski : RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {} 3225e50dd04SRiver Riddle 3230f8a6b7dSJakub Kuderski LogicalResult matchAndRewrite(Operation *op, 3240f8a6b7dSJakub Kuderski PatternRewriter &rewriter) const override { 3250f8a6b7dSJakub Kuderski Operation *newOp; 3260f8a6b7dSJakub Kuderski if (op->hasAttr("create_erase_op")) { 3270f8a6b7dSJakub Kuderski newOp = rewriter.create( 3280f8a6b7dSJakub Kuderski op->getLoc(), 3290f8a6b7dSJakub Kuderski OperationName("test.erase_op", op->getContext()).getIdentifier(), 3300f8a6b7dSJakub Kuderski ValueRange(), TypeRange()); 3310f8a6b7dSJakub Kuderski } else { 3320f8a6b7dSJakub Kuderski newOp = rewriter.create( 3330f8a6b7dSJakub Kuderski op->getLoc(), 3340f8a6b7dSJakub Kuderski OperationName("test.new_op", op->getContext()).getIdentifier(), 3350f8a6b7dSJakub Kuderski op->getOperands(), op->getResultTypes()); 3360f8a6b7dSJakub Kuderski } 3370f8a6b7dSJakub Kuderski // "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp". 3380f8a6b7dSJakub Kuderski // A "notifyOperationReplaced" callback is triggered in either case. 3390f8a6b7dSJakub Kuderski rewriter.replaceAllOpUsesWith(op, newOp->getResults()); 3400f8a6b7dSJakub Kuderski rewriter.eraseOp(op); 3410f8a6b7dSJakub Kuderski return success(); 3420f8a6b7dSJakub Kuderski } 3430f8a6b7dSJakub Kuderski }; 3447814b559Srkayaith 3450f8a6b7dSJakub Kuderski /// Erases the first child block of the matched "test.erase_first_block" 3460f8a6b7dSJakub Kuderski /// operation. 3470f8a6b7dSJakub Kuderski class EraseFirstBlock : public RewritePattern { 3480f8a6b7dSJakub Kuderski public: 3490f8a6b7dSJakub Kuderski EraseFirstBlock(MLIRContext *context) 3500f8a6b7dSJakub Kuderski : RewritePattern("test.erase_first_block", /*benefit=*/1, context) {} 3510f8a6b7dSJakub Kuderski 3520f8a6b7dSJakub Kuderski LogicalResult matchAndRewrite(Operation *op, 3530f8a6b7dSJakub Kuderski PatternRewriter &rewriter) const override { 3540f8a6b7dSJakub Kuderski for (Region &r : op->getRegions()) { 3550f8a6b7dSJakub Kuderski for (Block &b : r.getBlocks()) { 3560f8a6b7dSJakub Kuderski rewriter.eraseBlock(&b); 3570f8a6b7dSJakub Kuderski return success(); 3580f8a6b7dSJakub Kuderski } 3590f8a6b7dSJakub Kuderski } 3600f8a6b7dSJakub Kuderski 3610f8a6b7dSJakub Kuderski return failure(); 3620f8a6b7dSJakub Kuderski } 3630f8a6b7dSJakub Kuderski }; 3640f8a6b7dSJakub Kuderski 3650f8a6b7dSJakub Kuderski struct TestGreedyPatternDriver 3660f8a6b7dSJakub Kuderski : public PassWrapper<TestGreedyPatternDriver, OperationPass<>> { 3670f8a6b7dSJakub Kuderski MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver) 3680f8a6b7dSJakub Kuderski 3690f8a6b7dSJakub Kuderski TestGreedyPatternDriver() = default; 3700f8a6b7dSJakub Kuderski TestGreedyPatternDriver(const TestGreedyPatternDriver &other) 3710f8a6b7dSJakub Kuderski : PassWrapper(other) {} 3720f8a6b7dSJakub Kuderski 3730f8a6b7dSJakub Kuderski StringRef getArgument() const final { return "test-greedy-patterns"; } 374b5e22e6dSMehdi Amini StringRef getDescription() const final { return "Run test dialect patterns"; } 37541574554SRiver Riddle void runOnOperation() override { 376dc4e913bSChris Lattner mlir::RewritePatternSet patterns(&getContext()); 37783e3c6a7SMehdi Amini populateWithGenerated(patterns); 378fec6c5acSUday Bondhugula 379fec6c5acSUday Bondhugula // Verify named pattern is generated with expected name. 380e4635e63SRiver Riddle patterns.add<FoldingPattern, TestNamedPatternRule, 38143f0d5f9SMehdi Amini FolderInsertBeforePreviouslyFoldedConstantPattern, 382ed9194beSMatthias Springer FolderCommutativeOp2WithConstant, HoistEligibleOps, 383ed9194beSMatthias Springer MakeOpEligible>(&getContext()); 384fec6c5acSUday Bondhugula 3852fea658aSMatthias Springer // Additional patterns for testing the GreedyPatternRewriteDriver. 3862fea658aSMatthias Springer patterns.insert<IncrementIntAttribute<3>>(&getContext()); 3872fea658aSMatthias Springer 3887814b559Srkayaith GreedyRewriteConfig config; 3897814b559Srkayaith config.useTopDownTraversal = this->useTopDownTraversal; 3902fea658aSMatthias Springer config.maxIterations = this->maxIterations; 39109dfc571SJacques Pienaar config.fold = this->fold; 39209dfc571SJacques Pienaar config.cseConstants = this->cseConstants; 39309dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns), config); 394fec6c5acSUday Bondhugula } 3957814b559Srkayaith 3967814b559Srkayaith Option<bool> useTopDownTraversal{ 3977814b559Srkayaith *this, "top-down", 3987814b559Srkayaith llvm::cl::desc("Seed the worklist in general top-down order"), 3997814b559Srkayaith llvm::cl::init(GreedyRewriteConfig().useTopDownTraversal)}; 4002fea658aSMatthias Springer Option<int> maxIterations{ 4012fea658aSMatthias Springer *this, "max-iterations", 4022fea658aSMatthias Springer llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"), 4032fea658aSMatthias Springer llvm::cl::init(GreedyRewriteConfig().maxIterations)}; 40409dfc571SJacques Pienaar Option<bool> fold{*this, "fold", llvm::cl::desc("Whether to fold"), 40509dfc571SJacques Pienaar llvm::cl::init(GreedyRewriteConfig().fold)}; 40609dfc571SJacques Pienaar Option<bool> cseConstants{*this, "cse-constants", 40709dfc571SJacques Pienaar llvm::cl::desc("Whether to CSE constants"), 40809dfc571SJacques Pienaar llvm::cl::init(GreedyRewriteConfig().cseConstants)}; 409fec6c5acSUday Bondhugula }; 410ba3a9f51SChia-hung Duan 411695a5a6aSMatthias Springer struct DumpNotifications : public RewriterBase::Listener { 412237a799eSMatthias Springer void notifyBlockInserted(Block *block, Region *previous, 413237a799eSMatthias Springer Region::iterator previousIt) override { 41460a20bd6SMatthias Springer llvm::outs() << "notifyBlockInserted"; 41560a20bd6SMatthias Springer if (block->getParentOp()) { 41660a20bd6SMatthias Springer llvm::outs() << " into " << block->getParentOp()->getName() << ": "; 41760a20bd6SMatthias Springer } else { 41860a20bd6SMatthias Springer llvm::outs() << " into unknown op: "; 41960a20bd6SMatthias Springer } 420237a799eSMatthias Springer if (previous == nullptr) { 421237a799eSMatthias Springer llvm::outs() << "was unlinked\n"; 422237a799eSMatthias Springer } else { 423237a799eSMatthias Springer llvm::outs() << "was linked\n"; 424237a799eSMatthias Springer } 425237a799eSMatthias Springer } 426da784a25SMatthias Springer void notifyOperationInserted(Operation *op, 427da784a25SMatthias Springer OpBuilder::InsertPoint previous) override { 428da784a25SMatthias Springer llvm::outs() << "notifyOperationInserted: " << op->getName(); 429da784a25SMatthias Springer if (!previous.isSet()) { 430da784a25SMatthias Springer llvm::outs() << ", was unlinked\n"; 431da784a25SMatthias Springer } else { 43260a20bd6SMatthias Springer if (!previous.getPoint().getNodePtr()) { 43360a20bd6SMatthias Springer llvm::outs() << ", was linked, exact position unknown\n"; 43460a20bd6SMatthias Springer } else if (previous.getPoint() == previous.getBlock()->end()) { 435da784a25SMatthias Springer llvm::outs() << ", was last in block\n"; 436da784a25SMatthias Springer } else { 437da784a25SMatthias Springer llvm::outs() << ", previous = " << previous.getPoint()->getName() 438da784a25SMatthias Springer << "\n"; 439da784a25SMatthias Springer } 440da784a25SMatthias Springer } 441da784a25SMatthias Springer } 44260a20bd6SMatthias Springer void notifyBlockErased(Block *block) override { 44360a20bd6SMatthias Springer llvm::outs() << "notifyBlockErased\n"; 44460a20bd6SMatthias Springer } 445914e6074SMatthias Springer void notifyOperationErased(Operation *op) override { 446914e6074SMatthias Springer llvm::outs() << "notifyOperationErased: " << op->getName() << "\n"; 447695a5a6aSMatthias Springer } 44860a20bd6SMatthias Springer void notifyOperationModified(Operation *op) override { 44960a20bd6SMatthias Springer llvm::outs() << "notifyOperationModified: " << op->getName() << "\n"; 45060a20bd6SMatthias Springer } 45160a20bd6SMatthias Springer void notifyOperationReplaced(Operation *op, ValueRange values) override { 45260a20bd6SMatthias Springer llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n"; 45360a20bd6SMatthias Springer } 454695a5a6aSMatthias Springer }; 455695a5a6aSMatthias Springer 456ba3a9f51SChia-hung Duan struct TestStrictPatternDriver 457ba3a9f51SChia-hung Duan : public PassWrapper<TestStrictPatternDriver, OperationPass<func::FuncOp>> { 458ba3a9f51SChia-hung Duan public: 459ba3a9f51SChia-hung Duan MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStrictPatternDriver) 460ba3a9f51SChia-hung Duan 461ba3a9f51SChia-hung Duan TestStrictPatternDriver() = default; 4626f469d60SMehdi Amini TestStrictPatternDriver(const TestStrictPatternDriver &other) 4636f469d60SMehdi Amini : PassWrapper(other) { 46487e345b1SMatthias Springer strictMode = other.strictMode; 46587e345b1SMatthias Springer } 466ba3a9f51SChia-hung Duan 467ba3a9f51SChia-hung Duan StringRef getArgument() const final { return "test-strict-pattern-driver"; } 468ba3a9f51SChia-hung Duan StringRef getDescription() const final { 46987e345b1SMatthias Springer return "Test strict mode of pattern driver"; 470ba3a9f51SChia-hung Duan } 471ba3a9f51SChia-hung Duan 472ba3a9f51SChia-hung Duan void runOnOperation() override { 473774416bdSMatthias Springer MLIRContext *ctx = &getContext(); 474774416bdSMatthias Springer mlir::RewritePatternSet patterns(ctx); 4754bba8bd3SIngo Müller patterns.add< 4764bba8bd3SIngo Müller // clang-format off 4774bba8bd3SIngo Müller ChangeBlockOp, 478237a799eSMatthias Springer CloneOp, 479b840d296SMatthias Springer CloneRegionBeforeOp, 480c672b342SMatthias Springer EraseOp, 481da784a25SMatthias Springer ImplicitChangeOp, 482c672b342SMatthias Springer InlineBlocksIntoParent, 483c672b342SMatthias Springer InsertSameOp, 484c672b342SMatthias Springer MoveBeforeParentOp, 485c2675ba9SMatthias Springer ReplaceWithNewOp, 486c2675ba9SMatthias Springer SplitBlockHere 4874bba8bd3SIngo Müller // clang-format on 4884bba8bd3SIngo Müller >(ctx); 489ba3a9f51SChia-hung Duan SmallVector<Operation *> ops; 490ba3a9f51SChia-hung Duan getOperation()->walk([&](Operation *op) { 491ba3a9f51SChia-hung Duan StringRef opName = op->getName().getStringRef(); 4924bba8bd3SIngo Müller if (opName == "test.insert_same_op" || opName == "test.change_block_op" || 493da784a25SMatthias Springer opName == "test.replace_with_new_op" || opName == "test.erase_op" || 494c672b342SMatthias Springer opName == "test.move_before_parent_op" || 495c2675ba9SMatthias Springer opName == "test.inline_blocks_into_parent" || 496b840d296SMatthias Springer opName == "test.split_block_here" || opName == "test.clone_me" || 497b840d296SMatthias Springer opName == "test.clone_region_before") { 498ba3a9f51SChia-hung Duan ops.push_back(op); 499ba3a9f51SChia-hung Duan } 500ba3a9f51SChia-hung Duan }); 501ba3a9f51SChia-hung Duan 502695a5a6aSMatthias Springer DumpNotifications dumpNotifications; 5036bdecbcbSMatthias Springer GreedyRewriteConfig config; 504695a5a6aSMatthias Springer config.listener = &dumpNotifications; 50587e345b1SMatthias Springer if (strictMode == "AnyOp") { 5066bdecbcbSMatthias Springer config.strictMode = GreedyRewriteStrictness::AnyOp; 50787e345b1SMatthias Springer } else if (strictMode == "ExistingAndNewOps") { 5086bdecbcbSMatthias Springer config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; 50987e345b1SMatthias Springer } else if (strictMode == "ExistingOps") { 5106bdecbcbSMatthias Springer config.strictMode = GreedyRewriteStrictness::ExistingOps; 51187e345b1SMatthias Springer } else { 51287e345b1SMatthias Springer llvm_unreachable("invalid strictness option"); 51387e345b1SMatthias Springer } 51487e345b1SMatthias Springer 515ba3a9f51SChia-hung Duan // Check if these transformations introduce visiting of operations that 516ba3a9f51SChia-hung Duan // are not in the `ops` set (The new created ops are valid). An invalid 517ba3a9f51SChia-hung Duan // operation will trigger the assertion while processing. 518774416bdSMatthias Springer bool changed = false; 519774416bdSMatthias Springer bool allErased = false; 52009dfc571SJacques Pienaar (void)applyOpPatternsGreedily(ArrayRef(ops), std::move(patterns), config, 5216bdecbcbSMatthias Springer &changed, &allErased); 522774416bdSMatthias Springer Builder b(ctx); 523774416bdSMatthias Springer getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed)); 524774416bdSMatthias Springer getOperation()->setAttr("pattern_driver_all_erased", 525774416bdSMatthias Springer b.getBoolAttr(allErased)); 526ba3a9f51SChia-hung Duan } 527ba3a9f51SChia-hung Duan 52887e345b1SMatthias Springer Option<std::string> strictMode{ 52987e345b1SMatthias Springer *this, "strictness", 53087e345b1SMatthias Springer llvm::cl::desc("Can be {AnyOp, ExistingAndNewOps, ExistingOps}"), 53187e345b1SMatthias Springer llvm::cl::init("AnyOp")}; 53287e345b1SMatthias Springer 533ba3a9f51SChia-hung Duan private: 534ba3a9f51SChia-hung Duan // New inserted operation is valid for further transformation. 535ba3a9f51SChia-hung Duan class InsertSameOp : public RewritePattern { 536ba3a9f51SChia-hung Duan public: 537ba3a9f51SChia-hung Duan InsertSameOp(MLIRContext *context) 538ba3a9f51SChia-hung Duan : RewritePattern("test.insert_same_op", /*benefit=*/1, context) {} 539ba3a9f51SChia-hung Duan 540ba3a9f51SChia-hung Duan LogicalResult matchAndRewrite(Operation *op, 541ba3a9f51SChia-hung Duan PatternRewriter &rewriter) const override { 542ba3a9f51SChia-hung Duan if (op->hasAttr("skip")) 543ba3a9f51SChia-hung Duan return failure(); 544ba3a9f51SChia-hung Duan 545ba3a9f51SChia-hung Duan Operation *newOp = 546ba3a9f51SChia-hung Duan rewriter.create(op->getLoc(), op->getName().getIdentifier(), 547ba3a9f51SChia-hung Duan op->getOperands(), op->getResultTypes()); 5485fcf907bSMatthias Springer rewriter.modifyOpInPlace( 549e219e66eSMatthias Springer op, [&]() { op->setAttr("skip", rewriter.getBoolAttr(true)); }); 550ba3a9f51SChia-hung Duan newOp->setAttr("skip", rewriter.getBoolAttr(true)); 551ba3a9f51SChia-hung Duan 552ba3a9f51SChia-hung Duan return success(); 553ba3a9f51SChia-hung Duan } 554ba3a9f51SChia-hung Duan }; 555ba3a9f51SChia-hung Duan 5564bba8bd3SIngo Müller // Remove an operation may introduce the re-visiting of its operands. 557ba3a9f51SChia-hung Duan class EraseOp : public RewritePattern { 558ba3a9f51SChia-hung Duan public: 559ba3a9f51SChia-hung Duan EraseOp(MLIRContext *context) 560ba3a9f51SChia-hung Duan : RewritePattern("test.erase_op", /*benefit=*/1, context) {} 561ba3a9f51SChia-hung Duan LogicalResult matchAndRewrite(Operation *op, 562ba3a9f51SChia-hung Duan PatternRewriter &rewriter) const override { 563ba3a9f51SChia-hung Duan rewriter.eraseOp(op); 564ba3a9f51SChia-hung Duan return success(); 565ba3a9f51SChia-hung Duan } 566ba3a9f51SChia-hung Duan }; 5674bba8bd3SIngo Müller 5684bba8bd3SIngo Müller // The following two patterns test RewriterBase::replaceAllUsesWith. 5694bba8bd3SIngo Müller // 5704bba8bd3SIngo Müller // That function replaces all usages of a Block (or a Value) with another one 5714bba8bd3SIngo Müller // *and tracks these changes in the rewriter.* The GreedyPatternRewriteDriver 5724bba8bd3SIngo Müller // with GreedyRewriteStrictness::AnyOp uses that tracking to construct its 5734bba8bd3SIngo Müller // worklist: when an op is modified, it is added to the worklist. The two 5744bba8bd3SIngo Müller // patterns below make the tracking observable: ChangeBlockOp replaces all 5754bba8bd3SIngo Müller // usages of a block and that pattern is applied because the corresponding ops 5764bba8bd3SIngo Müller // are put on the initial worklist (see above). ImplicitChangeOp does an 5774bba8bd3SIngo Müller // unrelated change but ops of the corresponding type are *not* on the initial 5784bba8bd3SIngo Müller // worklist, so the effect of the second pattern is only visible if the 5794bba8bd3SIngo Müller // tracking and subsequent adding to the worklist actually works. 5804bba8bd3SIngo Müller 5814bba8bd3SIngo Müller // Replace all usages of the first successor with the second successor. 5824bba8bd3SIngo Müller class ChangeBlockOp : public RewritePattern { 5834bba8bd3SIngo Müller public: 5844bba8bd3SIngo Müller ChangeBlockOp(MLIRContext *context) 5854bba8bd3SIngo Müller : RewritePattern("test.change_block_op", /*benefit=*/1, context) {} 5864bba8bd3SIngo Müller LogicalResult matchAndRewrite(Operation *op, 5874bba8bd3SIngo Müller PatternRewriter &rewriter) const override { 5884bba8bd3SIngo Müller if (op->getNumSuccessors() < 2) 5894bba8bd3SIngo Müller return failure(); 5904bba8bd3SIngo Müller Block *firstSuccessor = op->getSuccessor(0); 5914bba8bd3SIngo Müller Block *secondSuccessor = op->getSuccessor(1); 5924bba8bd3SIngo Müller if (firstSuccessor == secondSuccessor) 5934bba8bd3SIngo Müller return failure(); 5944bba8bd3SIngo Müller // This is the function being tested: 5954bba8bd3SIngo Müller rewriter.replaceAllUsesWith(firstSuccessor, secondSuccessor); 5964bba8bd3SIngo Müller // Using the following line instead would make the test fail: 5974bba8bd3SIngo Müller // firstSuccessor->replaceAllUsesWith(secondSuccessor); 5984bba8bd3SIngo Müller return success(); 5994bba8bd3SIngo Müller } 6004bba8bd3SIngo Müller }; 6014bba8bd3SIngo Müller 6024bba8bd3SIngo Müller // Changes the successor to the parent block. 6034bba8bd3SIngo Müller class ImplicitChangeOp : public RewritePattern { 6044bba8bd3SIngo Müller public: 6054bba8bd3SIngo Müller ImplicitChangeOp(MLIRContext *context) 6064bba8bd3SIngo Müller : RewritePattern("test.implicit_change_op", /*benefit=*/1, context) {} 6074bba8bd3SIngo Müller LogicalResult matchAndRewrite(Operation *op, 6084bba8bd3SIngo Müller PatternRewriter &rewriter) const override { 6094bba8bd3SIngo Müller if (op->getNumSuccessors() < 1 || op->getSuccessor(0) == op->getBlock()) 6104bba8bd3SIngo Müller return failure(); 6115fcf907bSMatthias Springer rewriter.modifyOpInPlace(op, 6125fcf907bSMatthias Springer [&]() { op->setSuccessor(op->getBlock(), 0); }); 6134bba8bd3SIngo Müller return success(); 6144bba8bd3SIngo Müller } 6154bba8bd3SIngo Müller }; 616ba3a9f51SChia-hung Duan }; 617ba3a9f51SChia-hung Duan 6180f8a6b7dSJakub Kuderski struct TestWalkPatternDriver final 6190f8a6b7dSJakub Kuderski : PassWrapper<TestWalkPatternDriver, OperationPass<>> { 6200f8a6b7dSJakub Kuderski MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver) 6210f8a6b7dSJakub Kuderski 6220f8a6b7dSJakub Kuderski TestWalkPatternDriver() = default; 6230f8a6b7dSJakub Kuderski TestWalkPatternDriver(const TestWalkPatternDriver &other) 6240f8a6b7dSJakub Kuderski : PassWrapper(other) {} 6250f8a6b7dSJakub Kuderski 6260f8a6b7dSJakub Kuderski StringRef getArgument() const override { 6270f8a6b7dSJakub Kuderski return "test-walk-pattern-rewrite-driver"; 6280f8a6b7dSJakub Kuderski } 6290f8a6b7dSJakub Kuderski StringRef getDescription() const override { 6300f8a6b7dSJakub Kuderski return "Run test walk pattern rewrite driver"; 6310f8a6b7dSJakub Kuderski } 6320f8a6b7dSJakub Kuderski void runOnOperation() override { 6330f8a6b7dSJakub Kuderski mlir::RewritePatternSet patterns(&getContext()); 6340f8a6b7dSJakub Kuderski 6350f8a6b7dSJakub Kuderski // Patterns for testing the WalkPatternRewriteDriver. 6360f8a6b7dSJakub Kuderski patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp, 6370f8a6b7dSJakub Kuderski MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>( 6380f8a6b7dSJakub Kuderski &getContext()); 6390f8a6b7dSJakub Kuderski 6400f8a6b7dSJakub Kuderski DumpNotifications dumpListener; 6410f8a6b7dSJakub Kuderski walkAndApplyPatterns(getOperation(), std::move(patterns), 6420f8a6b7dSJakub Kuderski dumpNotifications ? &dumpListener : nullptr); 6430f8a6b7dSJakub Kuderski } 6440f8a6b7dSJakub Kuderski 6450f8a6b7dSJakub Kuderski Option<bool> dumpNotifications{ 6460f8a6b7dSJakub Kuderski *this, "dump-notifications", 6470f8a6b7dSJakub Kuderski llvm::cl::desc("Print rewrite listener notifications"), 6480f8a6b7dSJakub Kuderski llvm::cl::init(false)}; 6490f8a6b7dSJakub Kuderski }; 6500f8a6b7dSJakub Kuderski 651be0a7e9fSMehdi Amini } // namespace 652fec6c5acSUday Bondhugula 653fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 654fec6c5acSUday Bondhugula // ReturnType Driver. 655fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 656fec6c5acSUday Bondhugula 657fec6c5acSUday Bondhugula namespace { 658fec6c5acSUday Bondhugula // Generate ops for each instance where the type can be successfully inferred. 659fec6c5acSUday Bondhugula template <typename OpTy> 660fec6c5acSUday Bondhugula static void invokeCreateWithInferredReturnType(Operation *op) { 661fec6c5acSUday Bondhugula auto *context = op->getContext(); 66258ceae95SRiver Riddle auto fop = op->getParentOfType<func::FuncOp>(); 663fec6c5acSUday Bondhugula auto location = UnknownLoc::get(context); 664fec6c5acSUday Bondhugula OpBuilder b(op); 665fec6c5acSUday Bondhugula b.setInsertionPointAfter(op); 666fec6c5acSUday Bondhugula 667fec6c5acSUday Bondhugula // Use permutations of 2 args as operands. 668fec6c5acSUday Bondhugula assert(fop.getNumArguments() >= 2); 669fec6c5acSUday Bondhugula for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { 670fec6c5acSUday Bondhugula for (int j = 0; j < e; ++j) { 671fec6c5acSUday Bondhugula std::array<Value, 2> values = {{fop.getArgument(i), fop.getArgument(j)}}; 672fec6c5acSUday Bondhugula SmallVector<Type, 2> inferredReturnTypes; 6735eae715aSJacques Pienaar if (succeeded(OpTy::inferReturnTypes( 674bbe5bf17SMehdi Amini context, std::nullopt, values, op->getDiscardableAttrDictionary(), 6755e118f93SMehdi Amini op->getPropertiesStorage(), op->getRegions(), 6765e118f93SMehdi Amini inferredReturnTypes))) { 677fec6c5acSUday Bondhugula OperationState state(location, OpTy::getOperationName()); 6789db53a18SRiver Riddle // TODO: Expand to regions. 679bb1d976fSAlex Zinenko OpTy::build(b, state, values, op->getAttrs()); 68014ecafd0SChia-hung Duan (void)b.create(state); 681fec6c5acSUday Bondhugula } 682fec6c5acSUday Bondhugula } 683fec6c5acSUday Bondhugula } 684fec6c5acSUday Bondhugula } 685fec6c5acSUday Bondhugula 686fec6c5acSUday Bondhugula static void reifyReturnShape(Operation *op) { 687fec6c5acSUday Bondhugula OpBuilder b(op); 688fec6c5acSUday Bondhugula 689fec6c5acSUday Bondhugula // Use permutations of 2 args as operands. 690fec6c5acSUday Bondhugula auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op); 691fec6c5acSUday Bondhugula SmallVector<Value, 2> shapes; 692851d02f6SWenyi Zhao if (failed(shapedOp.reifyReturnTypeShapes(b, op->getOperands(), shapes)) || 6939b051703SMaheshRavishankar !llvm::hasSingleElement(shapes)) 694fec6c5acSUday Bondhugula return; 69589de9cc8SMehdi Amini for (const auto &it : llvm::enumerate(shapes)) { 696fec6c5acSUday Bondhugula op->emitRemark() << "value " << it.index() << ": " 697fec6c5acSUday Bondhugula << it.value().getDefiningOp(); 698fec6c5acSUday Bondhugula } 6999b051703SMaheshRavishankar } 700fec6c5acSUday Bondhugula 70180aca1eaSRiver Riddle struct TestReturnTypeDriver 70258ceae95SRiver Riddle : public PassWrapper<TestReturnTypeDriver, OperationPass<func::FuncOp>> { 7035e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReturnTypeDriver) 7045e50dd04SRiver Riddle 705e2310704SJulian Gross void getDependentDialects(DialectRegistry ®istry) const override { 706c0a6318dSMatthias Springer registry.insert<tensor::TensorDialect>(); 707e2310704SJulian Gross } 708b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-return-type"; } 709b5e22e6dSMehdi Amini StringRef getDescription() const final { return "Run return type functions"; } 710e2310704SJulian Gross 71141574554SRiver Riddle void runOnOperation() override { 71241574554SRiver Riddle if (getOperation().getName() == "testCreateFunctions") { 713fec6c5acSUday Bondhugula std::vector<Operation *> ops; 714fec6c5acSUday Bondhugula // Collect ops to avoid triggering on inserted ops. 71541574554SRiver Riddle for (auto &op : getOperation().getBody().front()) 716fec6c5acSUday Bondhugula ops.push_back(&op); 717fec6c5acSUday Bondhugula // Generate test patterns for each, but skip terminator. 718984b800aSserge-sans-paille for (auto *op : llvm::ArrayRef(ops).drop_back()) { 719fec6c5acSUday Bondhugula // Test create method of each of the Op classes below. The resultant 720fec6c5acSUday Bondhugula // output would be in reverse order underneath `op` from which 721fec6c5acSUday Bondhugula // the attributes and regions are used. 722fec6c5acSUday Bondhugula invokeCreateWithInferredReturnType<OpWithInferTypeInterfaceOp>(op); 7235267ed05SAmanda Tang invokeCreateWithInferredReturnType<OpWithInferTypeAdaptorInterfaceOp>( 7245267ed05SAmanda Tang op); 725fec6c5acSUday Bondhugula invokeCreateWithInferredReturnType< 726fec6c5acSUday Bondhugula OpWithShapedTypeInferTypeInterfaceOp>(op); 727fec6c5acSUday Bondhugula }; 728fec6c5acSUday Bondhugula return; 729fec6c5acSUday Bondhugula } 73041574554SRiver Riddle if (getOperation().getName() == "testReifyFunctions") { 731fec6c5acSUday Bondhugula std::vector<Operation *> ops; 732fec6c5acSUday Bondhugula // Collect ops to avoid triggering on inserted ops. 73341574554SRiver Riddle for (auto &op : getOperation().getBody().front()) 734fec6c5acSUday Bondhugula if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op)) 735fec6c5acSUday Bondhugula ops.push_back(&op); 736fec6c5acSUday Bondhugula // Generate test patterns for each, but skip terminator. 737fec6c5acSUday Bondhugula for (auto *op : ops) 738fec6c5acSUday Bondhugula reifyReturnShape(op); 739fec6c5acSUday Bondhugula } 740fec6c5acSUday Bondhugula } 741fec6c5acSUday Bondhugula }; 742be0a7e9fSMehdi Amini } // namespace 743fec6c5acSUday Bondhugula 7449ba37b3bSJacques Pienaar namespace { 7459ba37b3bSJacques Pienaar struct TestDerivedAttributeDriver 74658ceae95SRiver Riddle : public PassWrapper<TestDerivedAttributeDriver, 74758ceae95SRiver Riddle OperationPass<func::FuncOp>> { 7485e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDerivedAttributeDriver) 7495e50dd04SRiver Riddle 750b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-derived-attr"; } 751b5e22e6dSMehdi Amini StringRef getDescription() const final { 752b5e22e6dSMehdi Amini return "Run test derived attributes"; 753b5e22e6dSMehdi Amini } 75441574554SRiver Riddle void runOnOperation() override; 7559ba37b3bSJacques Pienaar }; 756be0a7e9fSMehdi Amini } // namespace 7579ba37b3bSJacques Pienaar 75841574554SRiver Riddle void TestDerivedAttributeDriver::runOnOperation() { 75941574554SRiver Riddle getOperation().walk([](DerivedAttributeOpInterface dOp) { 7609ba37b3bSJacques Pienaar auto dAttr = dOp.materializeDerivedAttributes(); 7619ba37b3bSJacques Pienaar if (!dAttr) 7629ba37b3bSJacques Pienaar return; 7639ba37b3bSJacques Pienaar for (auto d : dAttr) 7640c7890c8SRiver Riddle dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue(); 7659ba37b3bSJacques Pienaar }); 7669ba37b3bSJacques Pienaar } 7679ba37b3bSJacques Pienaar 768fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 769fec6c5acSUday Bondhugula // Legalization Driver. 770fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 771fec6c5acSUday Bondhugula 772fec6c5acSUday Bondhugula namespace { 773fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 774fec6c5acSUday Bondhugula // Region-Block Rewrite Testing 775fec6c5acSUday Bondhugula 776684a5a30SMatthias Springer /// This pattern applies a signature conversion to a block inside a detached 777684a5a30SMatthias Springer /// region. 778684a5a30SMatthias Springer struct TestDetachedSignatureConversion : public ConversionPattern { 779684a5a30SMatthias Springer TestDetachedSignatureConversion(MLIRContext *ctx) 780684a5a30SMatthias Springer : ConversionPattern("test.detached_signature_conversion", /*benefit=*/1, 781684a5a30SMatthias Springer ctx) {} 782684a5a30SMatthias Springer 783684a5a30SMatthias Springer LogicalResult 784684a5a30SMatthias Springer matchAndRewrite(Operation *op, ArrayRef<Value> operands, 785684a5a30SMatthias Springer ConversionPatternRewriter &rewriter) const final { 786684a5a30SMatthias Springer if (op->getNumRegions() != 1) 787684a5a30SMatthias Springer return failure(); 788412e30b2SMatthias Springer OperationState state(op->getLoc(), "test.legal_op", operands, 789684a5a30SMatthias Springer op->getResultTypes(), {}, BlockRange()); 790684a5a30SMatthias Springer Region *newRegion = state.addRegion(); 791684a5a30SMatthias Springer rewriter.inlineRegionBefore(op->getRegion(0), *newRegion, 792684a5a30SMatthias Springer newRegion->begin()); 793684a5a30SMatthias Springer TypeConverter::SignatureConversion result(newRegion->getNumArguments()); 794684a5a30SMatthias Springer for (unsigned i = 0, e = newRegion->getNumArguments(); i < e; ++i) 795684a5a30SMatthias Springer result.addInputs(i, rewriter.getF64Type()); 796684a5a30SMatthias Springer rewriter.applySignatureConversion(&newRegion->front(), result); 797684a5a30SMatthias Springer Operation *newOp = rewriter.create(state); 798684a5a30SMatthias Springer rewriter.replaceOp(op, newOp->getResults()); 799684a5a30SMatthias Springer return success(); 800684a5a30SMatthias Springer } 801684a5a30SMatthias Springer }; 802684a5a30SMatthias Springer 803fec6c5acSUday Bondhugula /// This pattern is a simple pattern that inlines the first region of a given 804fec6c5acSUday Bondhugula /// operation into the parent region. 805fec6c5acSUday Bondhugula struct TestRegionRewriteBlockMovement : public ConversionPattern { 806fec6c5acSUday Bondhugula TestRegionRewriteBlockMovement(MLIRContext *ctx) 807fec6c5acSUday Bondhugula : ConversionPattern("test.region", 1, ctx) {} 808fec6c5acSUday Bondhugula 809fec6c5acSUday Bondhugula LogicalResult 810fec6c5acSUday Bondhugula matchAndRewrite(Operation *op, ArrayRef<Value> operands, 811fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const final { 812fec6c5acSUday Bondhugula // Inline this region into the parent region. 813fec6c5acSUday Bondhugula auto &parentRegion = *op->getParentRegion(); 814b0750e2dSTres Popp auto &opRegion = op->getRegion(0); 815830b9b07SMehdi Amini if (op->getDiscardableAttr("legalizer.should_clone")) 816b0750e2dSTres Popp rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); 817fec6c5acSUday Bondhugula else 818b0750e2dSTres Popp rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); 819b0750e2dSTres Popp 820830b9b07SMehdi Amini if (op->getDiscardableAttr("legalizer.erase_old_blocks")) { 821b0750e2dSTres Popp while (!opRegion.empty()) 822b0750e2dSTres Popp rewriter.eraseBlock(&opRegion.front()); 823b0750e2dSTres Popp } 824fec6c5acSUday Bondhugula 825fec6c5acSUday Bondhugula // Drop this operation. 826fec6c5acSUday Bondhugula rewriter.eraseOp(op); 827fec6c5acSUday Bondhugula return success(); 828fec6c5acSUday Bondhugula } 829fec6c5acSUday Bondhugula }; 830fec6c5acSUday Bondhugula /// This pattern is a simple pattern that generates a region containing an 831fec6c5acSUday Bondhugula /// illegal operation. 832fec6c5acSUday Bondhugula struct TestRegionRewriteUndo : public RewritePattern { 833fec6c5acSUday Bondhugula TestRegionRewriteUndo(MLIRContext *ctx) 834fec6c5acSUday Bondhugula : RewritePattern("test.region_builder", 1, ctx) {} 835fec6c5acSUday Bondhugula 836fec6c5acSUday Bondhugula LogicalResult matchAndRewrite(Operation *op, 837fec6c5acSUday Bondhugula PatternRewriter &rewriter) const final { 838fec6c5acSUday Bondhugula // Create the region operation with an entry block containing arguments. 839fec6c5acSUday Bondhugula OperationState newRegion(op->getLoc(), "test.region"); 840fec6c5acSUday Bondhugula newRegion.addRegion(); 84114ecafd0SChia-hung Duan auto *regionOp = rewriter.create(newRegion); 842fec6c5acSUday Bondhugula auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); 843e084679fSRiver Riddle entryBlock->addArgument(rewriter.getIntegerType(64), 844e084679fSRiver Riddle rewriter.getUnknownLoc()); 845fec6c5acSUday Bondhugula 846fec6c5acSUday Bondhugula // Add an explicitly illegal operation to ensure the conversion fails. 847fec6c5acSUday Bondhugula rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getIntegerType(32)); 848fec6c5acSUday Bondhugula rewriter.create<TestValidOp>(op->getLoc(), ArrayRef<Value>()); 849fec6c5acSUday Bondhugula 850fec6c5acSUday Bondhugula // Drop this operation. 851fec6c5acSUday Bondhugula rewriter.eraseOp(op); 852fec6c5acSUday Bondhugula return success(); 853fec6c5acSUday Bondhugula } 854fec6c5acSUday Bondhugula }; 855f27f1e8cSAlex Zinenko /// A simple pattern that creates a block at the end of the parent region of the 856f27f1e8cSAlex Zinenko /// matched operation. 857f27f1e8cSAlex Zinenko struct TestCreateBlock : public RewritePattern { 858f27f1e8cSAlex Zinenko TestCreateBlock(MLIRContext *ctx) 859f27f1e8cSAlex Zinenko : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} 860f27f1e8cSAlex Zinenko 861f27f1e8cSAlex Zinenko LogicalResult matchAndRewrite(Operation *op, 862f27f1e8cSAlex Zinenko PatternRewriter &rewriter) const final { 863f27f1e8cSAlex Zinenko Region ®ion = *op->getParentRegion(); 864f27f1e8cSAlex Zinenko Type i32Type = rewriter.getIntegerType(32); 865e084679fSRiver Riddle Location loc = op->getLoc(); 866e084679fSRiver Riddle rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 867e084679fSRiver Riddle rewriter.create<TerminatorOp>(loc); 86871d50c89SMatthias Springer rewriter.eraseOp(op); 869f27f1e8cSAlex Zinenko return success(); 870f27f1e8cSAlex Zinenko } 871f27f1e8cSAlex Zinenko }; 872f27f1e8cSAlex Zinenko 873a23d0559SKazuaki Ishizaki /// A simple pattern that creates a block containing an invalid operation in 874f27f1e8cSAlex Zinenko /// order to trigger the block creation undo mechanism. 875f27f1e8cSAlex Zinenko struct TestCreateIllegalBlock : public RewritePattern { 876f27f1e8cSAlex Zinenko TestCreateIllegalBlock(MLIRContext *ctx) 877f27f1e8cSAlex Zinenko : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} 878f27f1e8cSAlex Zinenko 879f27f1e8cSAlex Zinenko LogicalResult matchAndRewrite(Operation *op, 880f27f1e8cSAlex Zinenko PatternRewriter &rewriter) const final { 881f27f1e8cSAlex Zinenko Region ®ion = *op->getParentRegion(); 882f27f1e8cSAlex Zinenko Type i32Type = rewriter.getIntegerType(32); 883e084679fSRiver Riddle Location loc = op->getLoc(); 884e084679fSRiver Riddle rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}, {loc, loc}); 885f27f1e8cSAlex Zinenko // Create an illegal op to ensure the conversion fails. 886e084679fSRiver Riddle rewriter.create<ILLegalOpF>(loc, i32Type); 887e084679fSRiver Riddle rewriter.create<TerminatorOp>(loc); 88871d50c89SMatthias Springer rewriter.eraseOp(op); 889f27f1e8cSAlex Zinenko return success(); 890f27f1e8cSAlex Zinenko } 891f27f1e8cSAlex Zinenko }; 892fec6c5acSUday Bondhugula 8930816de16SRiver Riddle /// A simple pattern that tests the undo mechanism when replacing the uses of a 8940816de16SRiver Riddle /// block argument. 8950816de16SRiver Riddle struct TestUndoBlockArgReplace : public ConversionPattern { 8960816de16SRiver Riddle TestUndoBlockArgReplace(MLIRContext *ctx) 8970816de16SRiver Riddle : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} 8980816de16SRiver Riddle 8990816de16SRiver Riddle LogicalResult 9000816de16SRiver Riddle matchAndRewrite(Operation *op, ArrayRef<Value> operands, 9010816de16SRiver Riddle ConversionPatternRewriter &rewriter) const final { 9020816de16SRiver Riddle auto illegalOp = 9030816de16SRiver Riddle rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 904e2b71610SRahul Joshi rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), 905a0b21046SRahul Kayaith illegalOp->getResult(0)); 9065fcf907bSMatthias Springer rewriter.modifyOpInPlace(op, [] {}); 9070816de16SRiver Riddle return success(); 9080816de16SRiver Riddle } 9090816de16SRiver Riddle }; 9100816de16SRiver Riddle 9118f4cd2c7SMatthias Springer /// This pattern hoists ops out of a "test.hoist_me" and then fails conversion. 9128f4cd2c7SMatthias Springer /// This is to test the rollback logic. 9138f4cd2c7SMatthias Springer struct TestUndoMoveOpBefore : public ConversionPattern { 9148f4cd2c7SMatthias Springer TestUndoMoveOpBefore(MLIRContext *ctx) 9158f4cd2c7SMatthias Springer : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {} 9168f4cd2c7SMatthias Springer 9178f4cd2c7SMatthias Springer LogicalResult 9188f4cd2c7SMatthias Springer matchAndRewrite(Operation *op, ArrayRef<Value> operands, 9198f4cd2c7SMatthias Springer ConversionPatternRewriter &rewriter) const override { 9208f4cd2c7SMatthias Springer rewriter.moveOpBefore(op, op->getParentOp()); 9218f4cd2c7SMatthias Springer // Replace with an illegal op to ensure the conversion fails. 9228f4cd2c7SMatthias Springer rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type()); 9238f4cd2c7SMatthias Springer return success(); 9248f4cd2c7SMatthias Springer } 9258f4cd2c7SMatthias Springer }; 9268f4cd2c7SMatthias Springer 927df48026bSAlex Zinenko /// A rewrite pattern that tests the undo mechanism when erasing a block. 928df48026bSAlex Zinenko struct TestUndoBlockErase : public ConversionPattern { 929df48026bSAlex Zinenko TestUndoBlockErase(MLIRContext *ctx) 930df48026bSAlex Zinenko : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} 931df48026bSAlex Zinenko 932df48026bSAlex Zinenko LogicalResult 933df48026bSAlex Zinenko matchAndRewrite(Operation *op, ArrayRef<Value> operands, 934df48026bSAlex Zinenko ConversionPatternRewriter &rewriter) const final { 935df48026bSAlex Zinenko Block *secondBlock = &*std::next(op->getRegion(0).begin()); 936df48026bSAlex Zinenko rewriter.setInsertionPointToStart(secondBlock); 937df48026bSAlex Zinenko rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 938df48026bSAlex Zinenko rewriter.eraseBlock(secondBlock); 9395fcf907bSMatthias Springer rewriter.modifyOpInPlace(op, [] {}); 940df48026bSAlex Zinenko return success(); 941df48026bSAlex Zinenko } 942df48026bSAlex Zinenko }; 943df48026bSAlex Zinenko 9443a70335bSMatthias Springer /// A pattern that modifies a property in-place, but keeps the op illegal. 9453a70335bSMatthias Springer struct TestUndoPropertiesModification : public ConversionPattern { 9463a70335bSMatthias Springer TestUndoPropertiesModification(MLIRContext *ctx) 9473a70335bSMatthias Springer : ConversionPattern("test.with_properties", /*benefit=*/1, ctx) {} 9483a70335bSMatthias Springer LogicalResult 9493a70335bSMatthias Springer matchAndRewrite(Operation *op, ArrayRef<Value> operands, 9503a70335bSMatthias Springer ConversionPatternRewriter &rewriter) const final { 9513a70335bSMatthias Springer if (!op->hasAttr("modify_inplace")) 9523a70335bSMatthias Springer return failure(); 9533a70335bSMatthias Springer rewriter.modifyOpInPlace( 9543a70335bSMatthias Springer op, [&]() { cast<TestOpWithProperties>(op).getProperties().setA(42); }); 9553a70335bSMatthias Springer return success(); 9563a70335bSMatthias Springer } 9573a70335bSMatthias Springer }; 9583a70335bSMatthias Springer 959fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 960fec6c5acSUday Bondhugula // Type-Conversion Rewrite Testing 961fec6c5acSUday Bondhugula 962fec6c5acSUday Bondhugula /// This patterns erases a region operation that has had a type conversion. 963fec6c5acSUday Bondhugula struct TestDropOpSignatureConversion : public ConversionPattern { 964ce254598SMatthias Springer TestDropOpSignatureConversion(MLIRContext *ctx, 965ce254598SMatthias Springer const TypeConverter &converter) 96676f3c2f3SRiver Riddle : ConversionPattern(converter, "test.drop_region_op", 1, ctx) {} 967fec6c5acSUday Bondhugula LogicalResult 968fec6c5acSUday Bondhugula matchAndRewrite(Operation *op, ArrayRef<Value> operands, 969fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const override { 970fec6c5acSUday Bondhugula Region ®ion = op->getRegion(0); 971fec6c5acSUday Bondhugula Block *entry = ®ion.front(); 972fec6c5acSUday Bondhugula 973fec6c5acSUday Bondhugula // Convert the original entry arguments. 974ce254598SMatthias Springer const TypeConverter &converter = *getTypeConverter(); 975fec6c5acSUday Bondhugula TypeConverter::SignatureConversion result(entry->getNumArguments()); 9768d67d187SRiver Riddle if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), 9778d67d187SRiver Riddle result)) || 9788d67d187SRiver Riddle failed(rewriter.convertRegionTypes(®ion, converter, &result))) 979fec6c5acSUday Bondhugula return failure(); 980fec6c5acSUday Bondhugula 981fec6c5acSUday Bondhugula // Convert the region signature and just drop the operation. 982fec6c5acSUday Bondhugula rewriter.eraseOp(op); 983fec6c5acSUday Bondhugula return success(); 984fec6c5acSUday Bondhugula } 985fec6c5acSUday Bondhugula }; 986fec6c5acSUday Bondhugula /// This pattern simply updates the operands of the given operation. 987fec6c5acSUday Bondhugula struct TestPassthroughInvalidOp : public ConversionPattern { 9883cc311abSMatthias Springer TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter) 9893cc311abSMatthias Springer : ConversionPattern(converter, "test.invalid", 1, ctx) {} 990fec6c5acSUday Bondhugula LogicalResult 9919df63b26SMatthias Springer matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, 992fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const final { 9939df63b26SMatthias Springer SmallVector<Value> flattened; 9949df63b26SMatthias Springer for (auto it : llvm::enumerate(operands)) { 9959df63b26SMatthias Springer ValueRange range = it.value(); 9969df63b26SMatthias Springer if (range.size() == 1) { 9979df63b26SMatthias Springer flattened.push_back(range.front()); 9989df63b26SMatthias Springer continue; 9999df63b26SMatthias Springer } 10009df63b26SMatthias Springer 10019df63b26SMatthias Springer // This is a 1:N replacement. Insert a test.cast op. (That's what the 10029df63b26SMatthias Springer // argument materialization used to do.) 10039df63b26SMatthias Springer flattened.push_back( 10049df63b26SMatthias Springer rewriter 10059df63b26SMatthias Springer .create<TestCastOp>(op->getLoc(), 10069df63b26SMatthias Springer op->getOperand(it.index()).getType(), range) 10079df63b26SMatthias Springer .getResult()); 10089df63b26SMatthias Springer } 10099df63b26SMatthias Springer rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, flattened, 10101a36588eSKazu Hirata std::nullopt); 1011fec6c5acSUday Bondhugula return success(); 1012fec6c5acSUday Bondhugula } 1013fec6c5acSUday Bondhugula }; 10143815f478SMatthias Springer /// Replace with valid op, but simply drop the operands. This is used in a 10153815f478SMatthias Springer /// regression where we used to generate circular unrealized_conversion_cast 10163815f478SMatthias Springer /// ops. 10173815f478SMatthias Springer struct TestDropAndReplaceInvalidOp : public ConversionPattern { 10183815f478SMatthias Springer TestDropAndReplaceInvalidOp(MLIRContext *ctx, const TypeConverter &converter) 10193815f478SMatthias Springer : ConversionPattern(converter, 10203815f478SMatthias Springer "test.drop_operands_and_replace_with_valid", 1, ctx) { 10213815f478SMatthias Springer } 10223815f478SMatthias Springer LogicalResult 10233815f478SMatthias Springer matchAndRewrite(Operation *op, ArrayRef<Value> operands, 10243815f478SMatthias Springer ConversionPatternRewriter &rewriter) const final { 10253815f478SMatthias Springer rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, ValueRange(), 10263815f478SMatthias Springer std::nullopt); 10273815f478SMatthias Springer return success(); 10283815f478SMatthias Springer } 10293815f478SMatthias Springer }; 1030fec6c5acSUday Bondhugula /// This pattern handles the case of a split return value. 1031fec6c5acSUday Bondhugula struct TestSplitReturnType : public ConversionPattern { 1032fec6c5acSUday Bondhugula TestSplitReturnType(MLIRContext *ctx) 1033fec6c5acSUday Bondhugula : ConversionPattern("test.return", 1, ctx) {} 1034fec6c5acSUday Bondhugula LogicalResult 10359df63b26SMatthias Springer matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, 1036fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const final { 1037fec6c5acSUday Bondhugula // Check for a return of F32. 1038fec6c5acSUday Bondhugula if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) 1039fec6c5acSUday Bondhugula return failure(); 10409df63b26SMatthias Springer rewriter.replaceOpWithNewOp<TestReturnOp>(op, operands[0]); 1041fec6c5acSUday Bondhugula return success(); 1042fec6c5acSUday Bondhugula } 1043fec6c5acSUday Bondhugula }; 1044fec6c5acSUday Bondhugula 1045fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 1046fec6c5acSUday Bondhugula // Multi-Level Type-Conversion Rewrite Testing 1047fec6c5acSUday Bondhugula struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { 1048fec6c5acSUday Bondhugula TestChangeProducerTypeI32ToF32(MLIRContext *ctx) 1049fec6c5acSUday Bondhugula : ConversionPattern("test.type_producer", 1, ctx) {} 1050fec6c5acSUday Bondhugula LogicalResult 1051fec6c5acSUday Bondhugula matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1052fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const final { 1053fec6c5acSUday Bondhugula // If the type is I32, change the type to F32. 1054fec6c5acSUday Bondhugula if (!Type(*op->result_type_begin()).isSignlessInteger(32)) 1055fec6c5acSUday Bondhugula return failure(); 1056fec6c5acSUday Bondhugula rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF32Type()); 1057fec6c5acSUday Bondhugula return success(); 1058fec6c5acSUday Bondhugula } 1059fec6c5acSUday Bondhugula }; 1060fec6c5acSUday Bondhugula struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { 1061fec6c5acSUday Bondhugula TestChangeProducerTypeF32ToF64(MLIRContext *ctx) 1062fec6c5acSUday Bondhugula : ConversionPattern("test.type_producer", 1, ctx) {} 1063fec6c5acSUday Bondhugula LogicalResult 1064fec6c5acSUday Bondhugula matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1065fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const final { 1066fec6c5acSUday Bondhugula // If the type is F32, change the type to F64. 1067fec6c5acSUday Bondhugula if (!Type(*op->result_type_begin()).isF32()) 1068fec6c5acSUday Bondhugula return rewriter.notifyMatchFailure(op, "expected single f32 operand"); 1069fec6c5acSUday Bondhugula rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getF64Type()); 1070fec6c5acSUday Bondhugula return success(); 1071fec6c5acSUday Bondhugula } 1072fec6c5acSUday Bondhugula }; 1073fec6c5acSUday Bondhugula struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { 1074fec6c5acSUday Bondhugula TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) 1075fec6c5acSUday Bondhugula : ConversionPattern("test.type_producer", 10, ctx) {} 1076fec6c5acSUday Bondhugula LogicalResult 1077fec6c5acSUday Bondhugula matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1078fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const final { 1079fec6c5acSUday Bondhugula // Always convert to B16, even though it is not a legal type. This tests 1080fec6c5acSUday Bondhugula // that values are unmapped correctly. 1081fec6c5acSUday Bondhugula rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, rewriter.getBF16Type()); 1082fec6c5acSUday Bondhugula return success(); 1083fec6c5acSUday Bondhugula } 1084fec6c5acSUday Bondhugula }; 1085fec6c5acSUday Bondhugula struct TestUpdateConsumerType : public ConversionPattern { 1086fec6c5acSUday Bondhugula TestUpdateConsumerType(MLIRContext *ctx) 1087fec6c5acSUday Bondhugula : ConversionPattern("test.type_consumer", 1, ctx) {} 1088fec6c5acSUday Bondhugula LogicalResult 1089fec6c5acSUday Bondhugula matchAndRewrite(Operation *op, ArrayRef<Value> operands, 1090fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const final { 1091fec6c5acSUday Bondhugula // Verify that the incoming operand has been successfully remapped to F64. 1092fec6c5acSUday Bondhugula if (!operands[0].getType().isF64()) 1093fec6c5acSUday Bondhugula return failure(); 1094fec6c5acSUday Bondhugula rewriter.replaceOpWithNewOp<TestTypeConsumerOp>(op, operands[0]); 1095fec6c5acSUday Bondhugula return success(); 1096fec6c5acSUday Bondhugula } 1097fec6c5acSUday Bondhugula }; 1098fec6c5acSUday Bondhugula 1099fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 1100fec6c5acSUday Bondhugula // Non-Root Replacement Rewrite Testing 1101fec6c5acSUday Bondhugula /// This pattern generates an invalid operation, but replaces it before the 1102fec6c5acSUday Bondhugula /// pattern is finished. This checks that we don't need to legalize the 1103fec6c5acSUday Bondhugula /// temporary op. 1104fec6c5acSUday Bondhugula struct TestNonRootReplacement : public RewritePattern { 1105fec6c5acSUday Bondhugula TestNonRootReplacement(MLIRContext *ctx) 1106fec6c5acSUday Bondhugula : RewritePattern("test.replace_non_root", 1, ctx) {} 1107fec6c5acSUday Bondhugula 1108fec6c5acSUday Bondhugula LogicalResult matchAndRewrite(Operation *op, 1109fec6c5acSUday Bondhugula PatternRewriter &rewriter) const final { 1110fec6c5acSUday Bondhugula auto resultType = *op->result_type_begin(); 1111fec6c5acSUday Bondhugula auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType); 1112fec6c5acSUday Bondhugula auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType); 1113fec6c5acSUday Bondhugula 111471d50c89SMatthias Springer rewriter.replaceOp(illegalOp, legalOp); 111571d50c89SMatthias Springer rewriter.replaceOp(op, illegalOp); 1116fec6c5acSUday Bondhugula return success(); 1117fec6c5acSUday Bondhugula } 1118fec6c5acSUday Bondhugula }; 1119bd1ccfe6SRiver Riddle 1120bd1ccfe6SRiver Riddle //===----------------------------------------------------------------------===// 1121bd1ccfe6SRiver Riddle // Recursive Rewrite Testing 1122bd1ccfe6SRiver Riddle /// This pattern is applied to the same operation multiple times, but has a 1123bd1ccfe6SRiver Riddle /// bounded recursion. 1124bd1ccfe6SRiver Riddle struct TestBoundedRecursiveRewrite 1125bd1ccfe6SRiver Riddle : public OpRewritePattern<TestRecursiveRewriteOp> { 11262257e4a7SRiver Riddle using OpRewritePattern<TestRecursiveRewriteOp>::OpRewritePattern; 11272257e4a7SRiver Riddle 11282257e4a7SRiver Riddle void initialize() { 1129b99bd771SRiver Riddle // The conversion target handles bounding the recursion of this pattern. 1130b99bd771SRiver Riddle setHasBoundedRewriteRecursion(); 1131b99bd771SRiver Riddle } 1132bd1ccfe6SRiver Riddle 1133bd1ccfe6SRiver Riddle LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, 1134bd1ccfe6SRiver Riddle PatternRewriter &rewriter) const final { 1135bd1ccfe6SRiver Riddle // Decrement the depth of the op in-place. 11365fcf907bSMatthias Springer rewriter.modifyOpInPlace(op, [&] { 11376a994233SJacques Pienaar op->setAttr("depth", rewriter.getI64IntegerAttr(op.getDepth() - 1)); 1138bd1ccfe6SRiver Riddle }); 1139bd1ccfe6SRiver Riddle return success(); 1140bd1ccfe6SRiver Riddle } 1141bd1ccfe6SRiver Riddle }; 11425d5df06aSAlex Zinenko 11435d5df06aSAlex Zinenko struct TestNestedOpCreationUndoRewrite 11445d5df06aSAlex Zinenko : public OpRewritePattern<IllegalOpWithRegionAnchor> { 11455d5df06aSAlex Zinenko using OpRewritePattern<IllegalOpWithRegionAnchor>::OpRewritePattern; 11465d5df06aSAlex Zinenko 11475d5df06aSAlex Zinenko LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, 11485d5df06aSAlex Zinenko PatternRewriter &rewriter) const final { 11495d5df06aSAlex Zinenko // rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 11505d5df06aSAlex Zinenko rewriter.replaceOpWithNewOp<IllegalOpWithRegion>(op); 11515d5df06aSAlex Zinenko return success(); 11525d5df06aSAlex Zinenko }; 11535d5df06aSAlex Zinenko }; 1154a360a978SMehdi Amini 1155a360a978SMehdi Amini // This pattern matches `test.blackhole` and delete this op and its producer. 1156a360a978SMehdi Amini struct TestReplaceEraseOp : public OpRewritePattern<BlackHoleOp> { 1157a360a978SMehdi Amini using OpRewritePattern<BlackHoleOp>::OpRewritePattern; 1158a360a978SMehdi Amini 1159a360a978SMehdi Amini LogicalResult matchAndRewrite(BlackHoleOp op, 1160a360a978SMehdi Amini PatternRewriter &rewriter) const final { 1161a360a978SMehdi Amini Operation *producer = op.getOperand().getDefiningOp(); 1162a360a978SMehdi Amini // Always erase the user before the producer, the framework should handle 1163a360a978SMehdi Amini // this correctly. 1164a360a978SMehdi Amini rewriter.eraseOp(op); 1165a360a978SMehdi Amini rewriter.eraseOp(producer); 1166a360a978SMehdi Amini return success(); 1167a360a978SMehdi Amini }; 1168a360a978SMehdi Amini }; 1169ec03bbe8SVladislav Vinogradov 1170ec03bbe8SVladislav Vinogradov // This pattern replaces explicitly illegal op with explicitly legal op, 1171ec03bbe8SVladislav Vinogradov // but in addition creates unregistered operation. 1172ec03bbe8SVladislav Vinogradov struct TestCreateUnregisteredOp : public OpRewritePattern<ILLegalOpG> { 1173ec03bbe8SVladislav Vinogradov using OpRewritePattern<ILLegalOpG>::OpRewritePattern; 1174ec03bbe8SVladislav Vinogradov 1175ec03bbe8SVladislav Vinogradov LogicalResult matchAndRewrite(ILLegalOpG op, 1176ec03bbe8SVladislav Vinogradov PatternRewriter &rewriter) const final { 1177ec03bbe8SVladislav Vinogradov IntegerAttr attr = rewriter.getI32IntegerAttr(0); 11788e123ca6SRiver Riddle Value val = rewriter.create<arith::ConstantOp>(op->getLoc(), attr); 1179ec03bbe8SVladislav Vinogradov rewriter.replaceOpWithNewOp<LegalOpC>(op, val); 1180ec03bbe8SVladislav Vinogradov return success(); 1181ec03bbe8SVladislav Vinogradov }; 1182ec03bbe8SVladislav Vinogradov }; 11833815f478SMatthias Springer 11843815f478SMatthias Springer class TestEraseOp : public ConversionPattern { 11853815f478SMatthias Springer public: 11863815f478SMatthias Springer TestEraseOp(MLIRContext *ctx) : ConversionPattern("test.erase_op", 1, ctx) {} 11873815f478SMatthias Springer LogicalResult 11883815f478SMatthias Springer matchAndRewrite(Operation *op, ArrayRef<Value> operands, 11893815f478SMatthias Springer ConversionPatternRewriter &rewriter) const final { 11903815f478SMatthias Springer // Erase op without replacements. 11913815f478SMatthias Springer rewriter.eraseOp(op); 11923815f478SMatthias Springer return success(); 11933815f478SMatthias Springer } 11943815f478SMatthias Springer }; 11953815f478SMatthias Springer 11969df63b26SMatthias Springer /// This pattern matches a test.duplicate_block_args op and duplicates all 11979df63b26SMatthias Springer /// block arguments. 11989df63b26SMatthias Springer class TestDuplicateBlockArgs 11999df63b26SMatthias Springer : public OpConversionPattern<DuplicateBlockArgsOp> { 12009df63b26SMatthias Springer using OpConversionPattern<DuplicateBlockArgsOp>::OpConversionPattern; 12019df63b26SMatthias Springer 12029df63b26SMatthias Springer LogicalResult 12039df63b26SMatthias Springer matchAndRewrite(DuplicateBlockArgsOp op, OpAdaptor adaptor, 12049df63b26SMatthias Springer ConversionPatternRewriter &rewriter) const override { 12059df63b26SMatthias Springer if (op.getIsLegal()) 12069df63b26SMatthias Springer return failure(); 12079df63b26SMatthias Springer rewriter.startOpModification(op); 12089df63b26SMatthias Springer Block *body = &op.getBody().front(); 12099df63b26SMatthias Springer TypeConverter::SignatureConversion result(body->getNumArguments()); 12109df63b26SMatthias Springer for (auto it : llvm::enumerate(body->getArgumentTypes())) 12119df63b26SMatthias Springer result.addInputs(it.index(), {it.value(), it.value()}); 12129df63b26SMatthias Springer rewriter.applySignatureConversion(body, result, getTypeConverter()); 12139df63b26SMatthias Springer op.setIsLegal(true); 12149df63b26SMatthias Springer rewriter.finalizeOpModification(op); 12159df63b26SMatthias Springer return success(); 12169df63b26SMatthias Springer } 12179df63b26SMatthias Springer }; 12189df63b26SMatthias Springer 12199df63b26SMatthias Springer /// This pattern replaces test.repetitive_1_to_n_consumer ops with a test.valid 12209df63b26SMatthias Springer /// op. The pattern supports 1:N replacements and forwards the replacement 12219df63b26SMatthias Springer /// values of the single operand as test.valid operands. 12229df63b26SMatthias Springer class TestRepetitive1ToNConsumer : public ConversionPattern { 12239df63b26SMatthias Springer public: 12249df63b26SMatthias Springer TestRepetitive1ToNConsumer(MLIRContext *ctx) 12259df63b26SMatthias Springer : ConversionPattern("test.repetitive_1_to_n_consumer", 1, ctx) {} 12269df63b26SMatthias Springer LogicalResult 12279df63b26SMatthias Springer matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, 12289df63b26SMatthias Springer ConversionPatternRewriter &rewriter) const final { 12299df63b26SMatthias Springer // A single operand is expected. 12309df63b26SMatthias Springer if (op->getNumOperands() != 1) 12319df63b26SMatthias Springer return failure(); 12329df63b26SMatthias Springer rewriter.replaceOpWithNewOp<TestValidOp>(op, operands.front()); 12339df63b26SMatthias Springer return success(); 12349df63b26SMatthias Springer } 12359df63b26SMatthias Springer }; 12369df63b26SMatthias Springer 1237412e30b2SMatthias Springer /// A pattern that tests two back-to-back 1 -> 2 op replacements. 1238412e30b2SMatthias Springer class TestMultiple1ToNReplacement : public ConversionPattern { 1239412e30b2SMatthias Springer public: 1240412e30b2SMatthias Springer TestMultiple1ToNReplacement(MLIRContext *ctx, const TypeConverter &converter) 1241412e30b2SMatthias Springer : ConversionPattern(converter, "test.multiple_1_to_n_replacement", 1, 1242412e30b2SMatthias Springer ctx) {} 1243412e30b2SMatthias Springer LogicalResult 1244412e30b2SMatthias Springer matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, 1245412e30b2SMatthias Springer ConversionPatternRewriter &rewriter) const final { 1246412e30b2SMatthias Springer // Helper function that replaces the given op with a new op of the given 1247412e30b2SMatthias Springer // name and doubles each result (1 -> 2 replacement of each result). 1248412e30b2SMatthias Springer auto replaceWithDoubleResults = [&](Operation *op, StringRef name) { 1249412e30b2SMatthias Springer SmallVector<Type> types; 1250412e30b2SMatthias Springer for (Type t : op->getResultTypes()) { 1251412e30b2SMatthias Springer types.push_back(t); 1252412e30b2SMatthias Springer types.push_back(t); 1253412e30b2SMatthias Springer } 1254412e30b2SMatthias Springer OperationState state(op->getLoc(), name, 1255412e30b2SMatthias Springer /*operands=*/{}, types, op->getAttrs()); 1256412e30b2SMatthias Springer auto *newOp = rewriter.create(state); 1257412e30b2SMatthias Springer SmallVector<ValueRange> repls; 1258412e30b2SMatthias Springer for (size_t i = 0, e = op->getNumResults(); i < e; ++i) 1259412e30b2SMatthias Springer repls.push_back(newOp->getResults().slice(2 * i, 2)); 1260412e30b2SMatthias Springer rewriter.replaceOpWithMultiple(op, repls); 1261412e30b2SMatthias Springer return newOp; 1262412e30b2SMatthias Springer }; 1263412e30b2SMatthias Springer 1264412e30b2SMatthias Springer // Replace test.multiple_1_to_n_replacement with test.step_1. 1265412e30b2SMatthias Springer Operation *repl1 = replaceWithDoubleResults(op, "test.step_1"); 1266412e30b2SMatthias Springer // Now replace test.step_1 with test.legal_op. 1267412e30b2SMatthias Springer replaceWithDoubleResults(repl1, "test.legal_op"); 1268412e30b2SMatthias Springer return success(); 1269412e30b2SMatthias Springer } 1270412e30b2SMatthias Springer }; 1271412e30b2SMatthias Springer 1272fec6c5acSUday Bondhugula } // namespace 1273fec6c5acSUday Bondhugula 1274fec6c5acSUday Bondhugula namespace { 1275fec6c5acSUday Bondhugula struct TestTypeConverter : public TypeConverter { 1276fec6c5acSUday Bondhugula using TypeConverter::TypeConverter; 12775c5dafc5SAlex Zinenko TestTypeConverter() { 12785c5dafc5SAlex Zinenko addConversion(convertType); 12794589dd92SRiver Riddle addSourceMaterialization(materializeCast); 12805c5dafc5SAlex Zinenko } 1281fec6c5acSUday Bondhugula 1282fec6c5acSUday Bondhugula static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { 1283fec6c5acSUday Bondhugula // Drop I16 types. 1284fec6c5acSUday Bondhugula if (t.isSignlessInteger(16)) 1285fec6c5acSUday Bondhugula return success(); 1286fec6c5acSUday Bondhugula 1287fec6c5acSUday Bondhugula // Convert I64 to F64. 1288fec6c5acSUday Bondhugula if (t.isSignlessInteger(64)) { 1289*f023da12SMatthias Springer results.push_back(Float64Type::get(t.getContext())); 1290fec6c5acSUday Bondhugula return success(); 1291fec6c5acSUday Bondhugula } 1292fec6c5acSUday Bondhugula 12935c5dafc5SAlex Zinenko // Convert I42 to I43. 12945c5dafc5SAlex Zinenko if (t.isInteger(42)) { 12951b97cdf8SRiver Riddle results.push_back(IntegerType::get(t.getContext(), 43)); 12965c5dafc5SAlex Zinenko return success(); 12975c5dafc5SAlex Zinenko } 12985c5dafc5SAlex Zinenko 1299fec6c5acSUday Bondhugula // Split F32 into F16,F16. 1300fec6c5acSUday Bondhugula if (t.isF32()) { 1301*f023da12SMatthias Springer results.assign(2, Float16Type::get(t.getContext())); 1302fec6c5acSUday Bondhugula return success(); 1303fec6c5acSUday Bondhugula } 1304fec6c5acSUday Bondhugula 130508e6566dSMatthias Springer // Drop I24 types. 130608e6566dSMatthias Springer if (t.isInteger(24)) { 130708e6566dSMatthias Springer return success(); 130808e6566dSMatthias Springer } 130908e6566dSMatthias Springer 1310fec6c5acSUday Bondhugula // Otherwise, convert the type directly. 1311fec6c5acSUday Bondhugula results.push_back(t); 1312fec6c5acSUday Bondhugula return success(); 1313fec6c5acSUday Bondhugula } 1314fec6c5acSUday Bondhugula 13155c5dafc5SAlex Zinenko /// Hook for materializing a conversion. This is necessary because we generate 13165c5dafc5SAlex Zinenko /// 1->N type mappings. 1317f18c3e4eSMatthias Springer static Value materializeCast(OpBuilder &builder, Type resultType, 13184589dd92SRiver Riddle ValueRange inputs, Location loc) { 13194589dd92SRiver Riddle return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 13205c5dafc5SAlex Zinenko } 1321fec6c5acSUday Bondhugula }; 1322fec6c5acSUday Bondhugula 1323fec6c5acSUday Bondhugula struct TestLegalizePatternDriver 1324c3a8ac79SMatthias Springer : public PassWrapper<TestLegalizePatternDriver, OperationPass<>> { 13255e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLegalizePatternDriver) 13265e50dd04SRiver Riddle 1327b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-legalize-patterns"; } 1328b5e22e6dSMehdi Amini StringRef getDescription() const final { 1329b5e22e6dSMehdi Amini return "Run test dialect legalization patterns"; 1330b5e22e6dSMehdi Amini } 1331fec6c5acSUday Bondhugula /// The mode of conversion to use with the driver. 1332fec6c5acSUday Bondhugula enum class ConversionMode { Analysis, Full, Partial }; 1333fec6c5acSUday Bondhugula 1334fec6c5acSUday Bondhugula TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} 1335fec6c5acSUday Bondhugula 1336ec03bbe8SVladislav Vinogradov void getDependentDialects(DialectRegistry ®istry) const override { 133799fd12f2SMehdi Amini registry.insert<func::FuncDialect, test::TestDialect>(); 1338ec03bbe8SVladislav Vinogradov } 1339ec03bbe8SVladislav Vinogradov 1340722f909fSRiver Riddle void runOnOperation() override { 1341fec6c5acSUday Bondhugula TestTypeConverter converter; 1342dc4e913bSChris Lattner mlir::RewritePatternSet patterns(&getContext()); 13431d909c9aSChris Lattner populateWithGenerated(patterns); 13443cc311abSMatthias Springer patterns 13453cc311abSMatthias Springer .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion, 1346684a5a30SMatthias Springer TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock, 13473cc311abSMatthias Springer TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType, 13483cc311abSMatthias Springer TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, 13493cc311abSMatthias Springer TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, 13503cc311abSMatthias Springer TestNonRootReplacement, TestBoundedRecursiveRewrite, 13513cc311abSMatthias Springer TestNestedOpCreationUndoRewrite, TestReplaceEraseOp, 13523cc311abSMatthias Springer TestCreateUnregisteredOp, TestUndoMoveOpBefore, 13539df63b26SMatthias Springer TestUndoPropertiesModification, TestEraseOp, 13549df63b26SMatthias Springer TestRepetitive1ToNConsumer>(&getContext()); 13553cc311abSMatthias Springer patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp, 1356412e30b2SMatthias Springer TestPassthroughInvalidOp, TestMultiple1ToNReplacement>( 1357412e30b2SMatthias Springer &getContext(), converter); 13589df63b26SMatthias Springer patterns.add<TestDuplicateBlockArgs>(converter, &getContext()); 1359ed4749f9SIvan Butygin mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, 1360ed4749f9SIvan Butygin converter); 13613a506b31SChris Lattner mlir::populateCallOpTypeConversionPattern(patterns, converter); 1362fec6c5acSUday Bondhugula 1363fec6c5acSUday Bondhugula // Define the conversion target used for the test. 1364fec6c5acSUday Bondhugula ConversionTarget target(getContext()); 1365973ddb7dSMehdi Amini target.addLegalOp<ModuleOp>(); 1366ec03bbe8SVladislav Vinogradov target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, 13678f4cd2c7SMatthias Springer TerminatorOp, OneRegionOp>(); 1368412e30b2SMatthias Springer target.addLegalOp(OperationName("test.legal_op", &getContext())); 1369fec6c5acSUday Bondhugula target 1370fec6c5acSUday Bondhugula .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); 1371fec6c5acSUday Bondhugula target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { 1372fec6c5acSUday Bondhugula // Don't allow F32 operands. 1373fec6c5acSUday Bondhugula return llvm::none_of(op.getOperandTypes(), 1374fec6c5acSUday Bondhugula [](Type type) { return type.isF32(); }); 1375fec6c5acSUday Bondhugula }); 137658ceae95SRiver Riddle target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 13774a3460a7SRiver Riddle return converter.isSignatureLegal(op.getFunctionType()) && 13788d67d187SRiver Riddle converter.isLegal(&op.getBody()); 13798d67d187SRiver Riddle }); 138023aa5a74SRiver Riddle target.addDynamicallyLegalOp<func::CallOp>( 138123aa5a74SRiver Riddle [&](func::CallOp op) { return converter.isLegal(op); }); 1382fec6c5acSUday Bondhugula 1383a54f4eaeSMogball // TestCreateUnregisteredOp creates `arith.constant` operation, 1384ec03bbe8SVladislav Vinogradov // which was not added to target intentionally to test 1385ec03bbe8SVladislav Vinogradov // correct error code from conversion driver. 1386ec03bbe8SVladislav Vinogradov target.addDynamicallyLegalOp<ILLegalOpG>([](ILLegalOpG) { return false; }); 1387ec03bbe8SVladislav Vinogradov 1388fec6c5acSUday Bondhugula // Expect the type_producer/type_consumer operations to only operate on f64. 1389fec6c5acSUday Bondhugula target.addDynamicallyLegalOp<TestTypeProducerOp>( 1390fec6c5acSUday Bondhugula [](TestTypeProducerOp op) { return op.getType().isF64(); }); 1391fec6c5acSUday Bondhugula target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 1392fec6c5acSUday Bondhugula return op.getOperand().getType().isF64(); 1393fec6c5acSUday Bondhugula }); 1394fec6c5acSUday Bondhugula 1395fec6c5acSUday Bondhugula // Check support for marking certain operations as recursively legal. 139658ceae95SRiver Riddle target.markOpRecursivelyLegal<func::FuncOp, ModuleOp>([](Operation *op) { 1397fec6c5acSUday Bondhugula return static_cast<bool>( 1398fec6c5acSUday Bondhugula op->getAttrOfType<UnitAttr>("test.recursively_legal")); 1399fec6c5acSUday Bondhugula }); 1400fec6c5acSUday Bondhugula 1401bd1ccfe6SRiver Riddle // Mark the bound recursion operation as dynamically legal. 1402bd1ccfe6SRiver Riddle target.addDynamicallyLegalOp<TestRecursiveRewriteOp>( 14036a994233SJacques Pienaar [](TestRecursiveRewriteOp op) { return op.getDepth() == 0; }); 1404bd1ccfe6SRiver Riddle 14054513050fSChristian Ulmann // Create a dynamically legal rule that can only be legalized by folding it. 14064513050fSChristian Ulmann target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>( 14074513050fSChristian Ulmann [](TestOpInPlaceSelfFold op) { return op.getFolded(); }); 14084513050fSChristian Ulmann 14099df63b26SMatthias Springer target.addDynamicallyLegalOp<DuplicateBlockArgsOp>( 14109df63b26SMatthias Springer [](DuplicateBlockArgsOp op) { return op.getIsLegal(); }); 14119df63b26SMatthias Springer 1412fec6c5acSUday Bondhugula // Handle a partial conversion. 1413fec6c5acSUday Bondhugula if (mode == ConversionMode::Partial) { 14148de482eaSLucy Fox DenseSet<Operation *> unlegalizedOps; 1415a2821094SMatthias Springer ConversionConfig config; 141660a20bd6SMatthias Springer DumpNotifications dumpNotifications; 141760a20bd6SMatthias Springer config.listener = &dumpNotifications; 1418a2821094SMatthias Springer config.unlegalizedOps = &unlegalizedOps; 1419a2821094SMatthias Springer if (failed(applyPartialConversion(getOperation(), target, 1420a2821094SMatthias Springer std::move(patterns), config))) { 1421ec03bbe8SVladislav Vinogradov getOperation()->emitRemark() << "applyPartialConversion failed"; 1422ec03bbe8SVladislav Vinogradov } 14238de482eaSLucy Fox // Emit remarks for each legalizable operation. 14248de482eaSLucy Fox for (auto *op : unlegalizedOps) 14258de482eaSLucy Fox op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 1426fec6c5acSUday Bondhugula return; 1427fec6c5acSUday Bondhugula } 1428fec6c5acSUday Bondhugula 1429fec6c5acSUday Bondhugula // Handle a full conversion. 1430fec6c5acSUday Bondhugula if (mode == ConversionMode::Full) { 1431fec6c5acSUday Bondhugula // Check support for marking unknown operations as dynamically legal. 1432fec6c5acSUday Bondhugula target.markUnknownOpDynamicallyLegal([](Operation *op) { 1433fec6c5acSUday Bondhugula return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); 1434fec6c5acSUday Bondhugula }); 1435fec6c5acSUday Bondhugula 143660a20bd6SMatthias Springer ConversionConfig config; 143760a20bd6SMatthias Springer DumpNotifications dumpNotifications; 143860a20bd6SMatthias Springer config.listener = &dumpNotifications; 1439ec03bbe8SVladislav Vinogradov if (failed(applyFullConversion(getOperation(), target, 144060a20bd6SMatthias Springer std::move(patterns), config))) { 1441ec03bbe8SVladislav Vinogradov getOperation()->emitRemark() << "applyFullConversion failed"; 1442ec03bbe8SVladislav Vinogradov } 1443fec6c5acSUday Bondhugula return; 1444fec6c5acSUday Bondhugula } 1445fec6c5acSUday Bondhugula 1446fec6c5acSUday Bondhugula // Otherwise, handle an analysis conversion. 1447fec6c5acSUday Bondhugula assert(mode == ConversionMode::Analysis); 1448fec6c5acSUday Bondhugula 1449fec6c5acSUday Bondhugula // Analyze the convertible operations. 1450fec6c5acSUday Bondhugula DenseSet<Operation *> legalizedOps; 1451a2821094SMatthias Springer ConversionConfig config; 1452a2821094SMatthias Springer config.legalizableOps = &legalizedOps; 14533fffffa8SRiver Riddle if (failed(applyAnalysisConversion(getOperation(), target, 1454a2821094SMatthias Springer std::move(patterns), config))) 1455fec6c5acSUday Bondhugula return signalPassFailure(); 1456fec6c5acSUday Bondhugula 1457fec6c5acSUday Bondhugula // Emit remarks for each legalizable operation. 1458fec6c5acSUday Bondhugula for (auto *op : legalizedOps) 1459fec6c5acSUday Bondhugula op->emitRemark() << "op '" << op->getName() << "' is legalizable"; 1460fec6c5acSUday Bondhugula } 1461fec6c5acSUday Bondhugula 1462fec6c5acSUday Bondhugula /// The mode of conversion to use. 1463fec6c5acSUday Bondhugula ConversionMode mode; 1464fec6c5acSUday Bondhugula }; 1465be0a7e9fSMehdi Amini } // namespace 1466fec6c5acSUday Bondhugula 1467fec6c5acSUday Bondhugula static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode> 1468fec6c5acSUday Bondhugula legalizerConversionMode( 1469fec6c5acSUday Bondhugula "test-legalize-mode", 1470fec6c5acSUday Bondhugula llvm::cl::desc("The legalization mode to use with the test driver"), 1471fec6c5acSUday Bondhugula llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), 1472fec6c5acSUday Bondhugula llvm::cl::values( 1473fec6c5acSUday Bondhugula clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, 1474fec6c5acSUday Bondhugula "analysis", "Perform an analysis conversion"), 1475fec6c5acSUday Bondhugula clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", 1476fec6c5acSUday Bondhugula "Perform a full conversion"), 1477fec6c5acSUday Bondhugula clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, 1478fec6c5acSUday Bondhugula "partial", "Perform a partial conversion"))); 1479fec6c5acSUday Bondhugula 1480fec6c5acSUday Bondhugula //===----------------------------------------------------------------------===// 1481fec6c5acSUday Bondhugula // ConversionPatternRewriter::getRemappedValue testing. This method is used 14825aacce3dSKazuaki Ishizaki // to get the remapped value of an original value that was replaced using 1483fec6c5acSUday Bondhugula // ConversionPatternRewriter. 1484fec6c5acSUday Bondhugula namespace { 1485015192c6SRiver Riddle struct TestRemapValueTypeConverter : public TypeConverter { 1486015192c6SRiver Riddle using TypeConverter::TypeConverter; 1487015192c6SRiver Riddle 1488015192c6SRiver Riddle TestRemapValueTypeConverter() { 1489015192c6SRiver Riddle addConversion( 1490015192c6SRiver Riddle [](Float32Type type) { return Float64Type::get(type.getContext()); }); 1491015192c6SRiver Riddle addConversion([](Type type) { return type; }); 1492015192c6SRiver Riddle } 1493015192c6SRiver Riddle }; 1494015192c6SRiver Riddle 1495fec6c5acSUday Bondhugula /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with 1496fec6c5acSUday Bondhugula /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original 1497fec6c5acSUday Bondhugula /// operand twice. 1498fec6c5acSUday Bondhugula /// 1499fec6c5acSUday Bondhugula /// Example: 1500fec6c5acSUday Bondhugula /// %1 = test.one_variadic_out_one_variadic_in1"(%0) 1501fec6c5acSUday Bondhugula /// is replaced with: 1502fec6c5acSUday Bondhugula /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) 1503fec6c5acSUday Bondhugula struct OneVResOneVOperandOp1Converter 1504fec6c5acSUday Bondhugula : public OpConversionPattern<OneVResOneVOperandOp1> { 1505fec6c5acSUday Bondhugula using OpConversionPattern<OneVResOneVOperandOp1>::OpConversionPattern; 1506fec6c5acSUday Bondhugula 1507fec6c5acSUday Bondhugula LogicalResult 1508ef976337SRiver Riddle matchAndRewrite(OneVResOneVOperandOp1 op, OpAdaptor adaptor, 1509fec6c5acSUday Bondhugula ConversionPatternRewriter &rewriter) const override { 1510fec6c5acSUday Bondhugula auto origOps = op.getOperands(); 1511fec6c5acSUday Bondhugula assert(std::distance(origOps.begin(), origOps.end()) == 1 && 1512fec6c5acSUday Bondhugula "One operand expected"); 1513fec6c5acSUday Bondhugula Value origOp = *origOps.begin(); 1514fec6c5acSUday Bondhugula SmallVector<Value, 2> remappedOperands; 1515fec6c5acSUday Bondhugula // Replicate the remapped original operand twice. Note that we don't used 1516fec6c5acSUday Bondhugula // the remapped 'operand' since the goal is testing 'getRemappedValue'. 1517fec6c5acSUday Bondhugula remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 1518fec6c5acSUday Bondhugula remappedOperands.push_back(rewriter.getRemappedValue(origOp)); 1519fec6c5acSUday Bondhugula 1520fec6c5acSUday Bondhugula rewriter.replaceOpWithNewOp<OneVResOneVOperandOp1>(op, op.getResultTypes(), 1521fec6c5acSUday Bondhugula remappedOperands); 1522fec6c5acSUday Bondhugula return success(); 1523fec6c5acSUday Bondhugula } 1524fec6c5acSUday Bondhugula }; 1525fec6c5acSUday Bondhugula 1526015192c6SRiver Riddle /// A rewriter pattern that tests that blocks can be merged. 1527015192c6SRiver Riddle struct TestRemapValueInRegion 1528015192c6SRiver Riddle : public OpConversionPattern<TestRemappedValueRegionOp> { 1529015192c6SRiver Riddle using OpConversionPattern<TestRemappedValueRegionOp>::OpConversionPattern; 1530015192c6SRiver Riddle 1531015192c6SRiver Riddle LogicalResult 1532015192c6SRiver Riddle matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor, 1533015192c6SRiver Riddle ConversionPatternRewriter &rewriter) const final { 1534015192c6SRiver Riddle Block &block = op.getBody().front(); 1535015192c6SRiver Riddle Operation *terminator = block.getTerminator(); 1536015192c6SRiver Riddle 1537015192c6SRiver Riddle // Merge the block into the parent region. 1538015192c6SRiver Riddle Block *parentBlock = op->getBlock(); 1539015192c6SRiver Riddle Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator()); 1540015192c6SRiver Riddle rewriter.mergeBlocks(&block, parentBlock, ValueRange()); 1541015192c6SRiver Riddle rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange()); 1542015192c6SRiver Riddle 1543015192c6SRiver Riddle // Replace the results of this operation with the remapped terminator 1544015192c6SRiver Riddle // values. 1545015192c6SRiver Riddle SmallVector<Value> terminatorOperands; 1546015192c6SRiver Riddle if (failed(rewriter.getRemappedValues(terminator->getOperands(), 1547015192c6SRiver Riddle terminatorOperands))) 1548015192c6SRiver Riddle return failure(); 1549015192c6SRiver Riddle 1550015192c6SRiver Riddle rewriter.eraseOp(terminator); 1551015192c6SRiver Riddle rewriter.replaceOp(op, terminatorOperands); 1552015192c6SRiver Riddle return success(); 1553015192c6SRiver Riddle } 1554015192c6SRiver Riddle }; 1555015192c6SRiver Riddle 155680aca1eaSRiver Riddle struct TestRemappedValue 1557c3a8ac79SMatthias Springer : public mlir::PassWrapper<TestRemappedValue, OperationPass<>> { 15585e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRemappedValue) 15595e50dd04SRiver Riddle 1560b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-remapped-value"; } 1561b5e22e6dSMehdi Amini StringRef getDescription() const final { 1562b5e22e6dSMehdi Amini return "Test public remapped value mechanism in ConversionPatternRewriter"; 1563b5e22e6dSMehdi Amini } 156441574554SRiver Riddle void runOnOperation() override { 1565015192c6SRiver Riddle TestRemapValueTypeConverter typeConverter; 1566015192c6SRiver Riddle 1567dc4e913bSChris Lattner mlir::RewritePatternSet patterns(&getContext()); 1568dc4e913bSChris Lattner patterns.add<OneVResOneVOperandOp1Converter>(&getContext()); 1569015192c6SRiver Riddle patterns.add<TestChangeProducerTypeF32ToF64, TestUpdateConsumerType>( 1570015192c6SRiver Riddle &getContext()); 1571015192c6SRiver Riddle patterns.add<TestRemapValueInRegion>(typeConverter, &getContext()); 1572fec6c5acSUday Bondhugula 1573fec6c5acSUday Bondhugula mlir::ConversionTarget target(getContext()); 157458ceae95SRiver Riddle target.addLegalOp<ModuleOp, func::FuncOp, TestReturnOp>(); 1575015192c6SRiver Riddle 1576015192c6SRiver Riddle // Expect the type_producer/type_consumer operations to only operate on f64. 1577015192c6SRiver Riddle target.addDynamicallyLegalOp<TestTypeProducerOp>( 1578015192c6SRiver Riddle [](TestTypeProducerOp op) { return op.getType().isF64(); }); 1579015192c6SRiver Riddle target.addDynamicallyLegalOp<TestTypeConsumerOp>([](TestTypeConsumerOp op) { 1580015192c6SRiver Riddle return op.getOperand().getType().isF64(); 1581015192c6SRiver Riddle }); 1582015192c6SRiver Riddle 1583fec6c5acSUday Bondhugula // We make OneVResOneVOperandOp1 legal only when it has more that one 1584fec6c5acSUday Bondhugula // operand. This will trigger the conversion that will replace one-operand 1585fec6c5acSUday Bondhugula // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. 1586fec6c5acSUday Bondhugula target.addDynamicallyLegalOp<OneVResOneVOperandOp1>( 1587015192c6SRiver Riddle [](Operation *op) { return op->getNumOperands() > 1; }); 1588fec6c5acSUday Bondhugula 158941574554SRiver Riddle if (failed(mlir::applyFullConversion(getOperation(), target, 15903fffffa8SRiver Riddle std::move(patterns)))) { 1591fec6c5acSUday Bondhugula signalPassFailure(); 1592fec6c5acSUday Bondhugula } 1593fec6c5acSUday Bondhugula } 1594fec6c5acSUday Bondhugula }; 1595be0a7e9fSMehdi Amini } // namespace 1596fec6c5acSUday Bondhugula 159780d7ac3bSRiver Riddle //===----------------------------------------------------------------------===// 159880d7ac3bSRiver Riddle // Test patterns without a specific root operation kind 159980d7ac3bSRiver Riddle //===----------------------------------------------------------------------===// 160080d7ac3bSRiver Riddle 160180d7ac3bSRiver Riddle namespace { 160280d7ac3bSRiver Riddle /// This pattern matches and removes any operation in the test dialect. 160380d7ac3bSRiver Riddle struct RemoveTestDialectOps : public RewritePattern { 160476f3c2f3SRiver Riddle RemoveTestDialectOps(MLIRContext *context) 160576f3c2f3SRiver Riddle : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} 160680d7ac3bSRiver Riddle 160780d7ac3bSRiver Riddle LogicalResult matchAndRewrite(Operation *op, 160880d7ac3bSRiver Riddle PatternRewriter &rewriter) const override { 160980d7ac3bSRiver Riddle if (!isa<TestDialect>(op->getDialect())) 161080d7ac3bSRiver Riddle return failure(); 161180d7ac3bSRiver Riddle rewriter.eraseOp(op); 161280d7ac3bSRiver Riddle return success(); 161380d7ac3bSRiver Riddle } 161480d7ac3bSRiver Riddle }; 161580d7ac3bSRiver Riddle 161680d7ac3bSRiver Riddle struct TestUnknownRootOpDriver 1617c3a8ac79SMatthias Springer : public mlir::PassWrapper<TestUnknownRootOpDriver, OperationPass<>> { 16185e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnknownRootOpDriver) 16195e50dd04SRiver Riddle 1620b5e22e6dSMehdi Amini StringRef getArgument() const final { 1621b5e22e6dSMehdi Amini return "test-legalize-unknown-root-patterns"; 1622b5e22e6dSMehdi Amini } 1623b5e22e6dSMehdi Amini StringRef getDescription() const final { 1624b5e22e6dSMehdi Amini return "Test public remapped value mechanism in ConversionPatternRewriter"; 1625b5e22e6dSMehdi Amini } 162641574554SRiver Riddle void runOnOperation() override { 1627dc4e913bSChris Lattner mlir::RewritePatternSet patterns(&getContext()); 162876f3c2f3SRiver Riddle patterns.add<RemoveTestDialectOps>(&getContext()); 162980d7ac3bSRiver Riddle 163080d7ac3bSRiver Riddle mlir::ConversionTarget target(getContext()); 163180d7ac3bSRiver Riddle target.addIllegalDialect<TestDialect>(); 163241574554SRiver Riddle if (failed(applyPartialConversion(getOperation(), target, 163341574554SRiver Riddle std::move(patterns)))) 163480d7ac3bSRiver Riddle signalPassFailure(); 163580d7ac3bSRiver Riddle } 163680d7ac3bSRiver Riddle }; 1637be0a7e9fSMehdi Amini } // namespace 163880d7ac3bSRiver Riddle 16394589dd92SRiver Riddle //===----------------------------------------------------------------------===// 164088bc24a7SMathieu Fehr // Test patterns that uses operations and types defined at runtime 164188bc24a7SMathieu Fehr //===----------------------------------------------------------------------===// 164288bc24a7SMathieu Fehr 164388bc24a7SMathieu Fehr namespace { 164488bc24a7SMathieu Fehr /// This pattern matches dynamic operations 'test.one_operand_two_results' and 164588bc24a7SMathieu Fehr /// replace them with dynamic operations 'test.generic_dynamic_op'. 164688bc24a7SMathieu Fehr struct RewriteDynamicOp : public RewritePattern { 164788bc24a7SMathieu Fehr RewriteDynamicOp(MLIRContext *context) 164888bc24a7SMathieu Fehr : RewritePattern("test.dynamic_one_operand_two_results", /*benefit=*/1, 164988bc24a7SMathieu Fehr context) {} 165088bc24a7SMathieu Fehr 165188bc24a7SMathieu Fehr LogicalResult matchAndRewrite(Operation *op, 165288bc24a7SMathieu Fehr PatternRewriter &rewriter) const override { 165388bc24a7SMathieu Fehr assert(op->getName().getStringRef() == 165488bc24a7SMathieu Fehr "test.dynamic_one_operand_two_results" && 165588bc24a7SMathieu Fehr "rewrite pattern should only match operations with the right name"); 165688bc24a7SMathieu Fehr 165788bc24a7SMathieu Fehr OperationState state(op->getLoc(), "test.dynamic_generic", 165888bc24a7SMathieu Fehr op->getOperands(), op->getResultTypes(), 165988bc24a7SMathieu Fehr op->getAttrs()); 166088bc24a7SMathieu Fehr auto *newOp = rewriter.create(state); 166188bc24a7SMathieu Fehr rewriter.replaceOp(op, newOp->getResults()); 166288bc24a7SMathieu Fehr return success(); 166388bc24a7SMathieu Fehr } 166488bc24a7SMathieu Fehr }; 166588bc24a7SMathieu Fehr 166688bc24a7SMathieu Fehr struct TestRewriteDynamicOpDriver 1667253afd03SMatthias Springer : public PassWrapper<TestRewriteDynamicOpDriver, OperationPass<>> { 166888bc24a7SMathieu Fehr MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRewriteDynamicOpDriver) 166988bc24a7SMathieu Fehr 167088bc24a7SMathieu Fehr void getDependentDialects(DialectRegistry ®istry) const override { 167188bc24a7SMathieu Fehr registry.insert<TestDialect>(); 167288bc24a7SMathieu Fehr } 167388bc24a7SMathieu Fehr StringRef getArgument() const final { return "test-rewrite-dynamic-op"; } 167488bc24a7SMathieu Fehr StringRef getDescription() const final { 167588bc24a7SMathieu Fehr return "Test rewritting on dynamic operations"; 167688bc24a7SMathieu Fehr } 167788bc24a7SMathieu Fehr void runOnOperation() override { 167888bc24a7SMathieu Fehr RewritePatternSet patterns(&getContext()); 167988bc24a7SMathieu Fehr patterns.add<RewriteDynamicOp>(&getContext()); 168088bc24a7SMathieu Fehr 168188bc24a7SMathieu Fehr ConversionTarget target(getContext()); 168288bc24a7SMathieu Fehr target.addIllegalOp( 168388bc24a7SMathieu Fehr OperationName("test.dynamic_one_operand_two_results", &getContext())); 168488bc24a7SMathieu Fehr target.addLegalOp(OperationName("test.dynamic_generic", &getContext())); 168588bc24a7SMathieu Fehr if (failed(applyPartialConversion(getOperation(), target, 168688bc24a7SMathieu Fehr std::move(patterns)))) 168788bc24a7SMathieu Fehr signalPassFailure(); 168888bc24a7SMathieu Fehr } 168988bc24a7SMathieu Fehr }; 169088bc24a7SMathieu Fehr } // end anonymous namespace 169188bc24a7SMathieu Fehr 169288bc24a7SMathieu Fehr //===----------------------------------------------------------------------===// 16934589dd92SRiver Riddle // Test type conversions 16944589dd92SRiver Riddle //===----------------------------------------------------------------------===// 16954589dd92SRiver Riddle 16964589dd92SRiver Riddle namespace { 16974589dd92SRiver Riddle struct TestTypeConversionProducer 16984589dd92SRiver Riddle : public OpConversionPattern<TestTypeProducerOp> { 16994589dd92SRiver Riddle using OpConversionPattern<TestTypeProducerOp>::OpConversionPattern; 17004589dd92SRiver Riddle LogicalResult 1701ef976337SRiver Riddle matchAndRewrite(TestTypeProducerOp op, OpAdaptor adaptor, 17024589dd92SRiver Riddle ConversionPatternRewriter &rewriter) const final { 17034589dd92SRiver Riddle Type resultType = op.getType(); 17049c5982efSAlex Zinenko Type convertedType = getTypeConverter() 17059c5982efSAlex Zinenko ? getTypeConverter()->convertType(resultType) 17069c5982efSAlex Zinenko : resultType; 17075550c821STres Popp if (isa<FloatType>(resultType)) 17084589dd92SRiver Riddle resultType = rewriter.getF64Type(); 17094589dd92SRiver Riddle else if (resultType.isInteger(16)) 17104589dd92SRiver Riddle resultType = rewriter.getIntegerType(64); 17115550c821STres Popp else if (isa<test::TestRecursiveType>(resultType) && 17129c5982efSAlex Zinenko convertedType != resultType) 17139c5982efSAlex Zinenko resultType = convertedType; 17144589dd92SRiver Riddle else 17154589dd92SRiver Riddle return failure(); 17164589dd92SRiver Riddle 17174589dd92SRiver Riddle rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, resultType); 17184589dd92SRiver Riddle return success(); 17194589dd92SRiver Riddle } 17204589dd92SRiver Riddle }; 17214589dd92SRiver Riddle 17220409eb28SAlex Zinenko /// Call signature conversion and then fail the rewrite to trigger the undo 17230409eb28SAlex Zinenko /// mechanism. 17240409eb28SAlex Zinenko struct TestSignatureConversionUndo 17250409eb28SAlex Zinenko : public OpConversionPattern<TestSignatureConversionUndoOp> { 17260409eb28SAlex Zinenko using OpConversionPattern<TestSignatureConversionUndoOp>::OpConversionPattern; 17270409eb28SAlex Zinenko 17280409eb28SAlex Zinenko LogicalResult 1729ef976337SRiver Riddle matchAndRewrite(TestSignatureConversionUndoOp op, OpAdaptor adaptor, 17300409eb28SAlex Zinenko ConversionPatternRewriter &rewriter) const final { 17310409eb28SAlex Zinenko (void)rewriter.convertRegionTypes(&op->getRegion(0), *getTypeConverter()); 17320409eb28SAlex Zinenko return failure(); 17330409eb28SAlex Zinenko } 17340409eb28SAlex Zinenko }; 17350409eb28SAlex Zinenko 17364070f305SRiver Riddle /// Call signature conversion without providing a type converter to handle 17374070f305SRiver Riddle /// materializations. 17384070f305SRiver Riddle struct TestTestSignatureConversionNoConverter 17394070f305SRiver Riddle : public OpConversionPattern<TestSignatureConversionNoConverterOp> { 1740ce254598SMatthias Springer TestTestSignatureConversionNoConverter(const TypeConverter &converter, 17414070f305SRiver Riddle MLIRContext *context) 17424070f305SRiver Riddle : OpConversionPattern<TestSignatureConversionNoConverterOp>(context), 17434070f305SRiver Riddle converter(converter) {} 17444070f305SRiver Riddle 17454070f305SRiver Riddle LogicalResult 17464070f305SRiver Riddle matchAndRewrite(TestSignatureConversionNoConverterOp op, OpAdaptor adaptor, 17474070f305SRiver Riddle ConversionPatternRewriter &rewriter) const final { 17484070f305SRiver Riddle Region ®ion = op->getRegion(0); 17494070f305SRiver Riddle Block *entry = ®ion.front(); 17504070f305SRiver Riddle 17514070f305SRiver Riddle // Convert the original entry arguments. 17524070f305SRiver Riddle TypeConverter::SignatureConversion result(entry->getNumArguments()); 17534070f305SRiver Riddle if (failed( 17544070f305SRiver Riddle converter.convertSignatureArgs(entry->getArgumentTypes(), result))) 17554070f305SRiver Riddle return failure(); 175652050f3fSMatthias Springer rewriter.modifyOpInPlace(op, [&] { 175752050f3fSMatthias Springer rewriter.applySignatureConversion(®ion.front(), result); 175852050f3fSMatthias Springer }); 17594070f305SRiver Riddle return success(); 17604070f305SRiver Riddle } 17614070f305SRiver Riddle 1762ce254598SMatthias Springer const TypeConverter &converter; 17634070f305SRiver Riddle }; 17644070f305SRiver Riddle 17650409eb28SAlex Zinenko /// Just forward the operands to the root op. This is essentially a no-op 17660409eb28SAlex Zinenko /// pattern that is used to trigger target materialization. 17670409eb28SAlex Zinenko struct TestTypeConsumerForward 17680409eb28SAlex Zinenko : public OpConversionPattern<TestTypeConsumerOp> { 17690409eb28SAlex Zinenko using OpConversionPattern<TestTypeConsumerOp>::OpConversionPattern; 17700409eb28SAlex Zinenko 17710409eb28SAlex Zinenko LogicalResult 1772ef976337SRiver Riddle matchAndRewrite(TestTypeConsumerOp op, OpAdaptor adaptor, 17730409eb28SAlex Zinenko ConversionPatternRewriter &rewriter) const final { 17745fcf907bSMatthias Springer rewriter.modifyOpInPlace(op, 1775ef976337SRiver Riddle [&] { op->setOperands(adaptor.getOperands()); }); 17760409eb28SAlex Zinenko return success(); 17770409eb28SAlex Zinenko } 17780409eb28SAlex Zinenko }; 17790409eb28SAlex Zinenko 17805b91060dSAlex Zinenko struct TestTypeConversionAnotherProducer 17815b91060dSAlex Zinenko : public OpRewritePattern<TestAnotherTypeProducerOp> { 17825b91060dSAlex Zinenko using OpRewritePattern<TestAnotherTypeProducerOp>::OpRewritePattern; 17835b91060dSAlex Zinenko 17845b91060dSAlex Zinenko LogicalResult matchAndRewrite(TestAnotherTypeProducerOp op, 17855b91060dSAlex Zinenko PatternRewriter &rewriter) const final { 17865b91060dSAlex Zinenko rewriter.replaceOpWithNewOp<TestTypeProducerOp>(op, op.getType()); 17875b91060dSAlex Zinenko return success(); 17885b91060dSAlex Zinenko } 17895b91060dSAlex Zinenko }; 17905b91060dSAlex Zinenko 17918fc32942SMatthias Springer struct TestReplaceWithLegalOp : public ConversionPattern { 17923cc311abSMatthias Springer TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx) 17933cc311abSMatthias Springer : ConversionPattern(converter, "test.replace_with_legal_op", 17943cc311abSMatthias Springer /*benefit=*/1, ctx) {} 17958fc32942SMatthias Springer LogicalResult 17968fc32942SMatthias Springer matchAndRewrite(Operation *op, ArrayRef<Value> operands, 17978fc32942SMatthias Springer ConversionPatternRewriter &rewriter) const final { 17988fc32942SMatthias Springer rewriter.replaceOpWithNewOp<LegalOpD>(op, operands[0]); 17998fc32942SMatthias Springer return success(); 18008fc32942SMatthias Springer } 18018fc32942SMatthias Springer }; 18028fc32942SMatthias Springer 18034589dd92SRiver Riddle struct TestTypeConversionDriver 1804c3a8ac79SMatthias Springer : public PassWrapper<TestTypeConversionDriver, OperationPass<>> { 18055e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTypeConversionDriver) 18065e50dd04SRiver Riddle 1807f9dc2b70SMehdi Amini void getDependentDialects(DialectRegistry ®istry) const override { 1808f9dc2b70SMehdi Amini registry.insert<TestDialect>(); 1809f9dc2b70SMehdi Amini } 1810b5e22e6dSMehdi Amini StringRef getArgument() const final { 1811b5e22e6dSMehdi Amini return "test-legalize-type-conversion"; 1812b5e22e6dSMehdi Amini } 1813b5e22e6dSMehdi Amini StringRef getDescription() const final { 1814b5e22e6dSMehdi Amini return "Test various type conversion functionalities in DialectConversion"; 1815b5e22e6dSMehdi Amini } 1816f9dc2b70SMehdi Amini 18174589dd92SRiver Riddle void runOnOperation() override { 18184589dd92SRiver Riddle // Initialize the type converter. 1819dc3dc974SMehdi Amini SmallVector<Type, 2> conversionCallStack; 18204589dd92SRiver Riddle TypeConverter converter; 18214589dd92SRiver Riddle 18224589dd92SRiver Riddle /// Add the legal set of type conversions. 18234589dd92SRiver Riddle converter.addConversion([](Type type) -> Type { 18244589dd92SRiver Riddle // Treat F64 as legal. 18254589dd92SRiver Riddle if (type.isF64()) 18264589dd92SRiver Riddle return type; 18274589dd92SRiver Riddle // Allow converting BF16/F16/F32 to F64. 18284589dd92SRiver Riddle if (type.isBF16() || type.isF16() || type.isF32()) 1829*f023da12SMatthias Springer return Float64Type::get(type.getContext()); 18304589dd92SRiver Riddle // Otherwise, the type is illegal. 18314589dd92SRiver Riddle return nullptr; 18324589dd92SRiver Riddle }); 18334589dd92SRiver Riddle converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &) { 18344589dd92SRiver Riddle // Drop all integer types. 18354589dd92SRiver Riddle return success(); 18364589dd92SRiver Riddle }); 18379c5982efSAlex Zinenko converter.addConversion( 18389c5982efSAlex Zinenko // Convert a recursive self-referring type into a non-self-referring 18399c5982efSAlex Zinenko // type named "outer_converted_type" that contains a SimpleAType. 1840dc3dc974SMehdi Amini [&](test::TestRecursiveType type, 1841dc3dc974SMehdi Amini SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> { 18429c5982efSAlex Zinenko // If the type is already converted, return it to indicate that it is 18439c5982efSAlex Zinenko // legal. 18449c5982efSAlex Zinenko if (type.getName() == "outer_converted_type") { 18459c5982efSAlex Zinenko results.push_back(type); 18469c5982efSAlex Zinenko return success(); 18479c5982efSAlex Zinenko } 18489c5982efSAlex Zinenko 1849dc3dc974SMehdi Amini conversionCallStack.push_back(type); 1850dc3dc974SMehdi Amini auto popConversionCallStack = llvm::make_scope_exit( 1851dc3dc974SMehdi Amini [&conversionCallStack]() { conversionCallStack.pop_back(); }); 1852dc3dc974SMehdi Amini 18539c5982efSAlex Zinenko // If the type is on the call stack more than once (it is there at 18549c5982efSAlex Zinenko // least once because of the _current_ call, which is always the last 18559c5982efSAlex Zinenko // element on the stack), we've hit the recursive case. Just return 18569c5982efSAlex Zinenko // SimpleAType here to create a non-recursive type as a result. 1857dc3dc974SMehdi Amini if (llvm::is_contained(ArrayRef(conversionCallStack).drop_back(), 1858dc3dc974SMehdi Amini type)) { 18599c5982efSAlex Zinenko results.push_back(test::SimpleAType::get(type.getContext())); 18609c5982efSAlex Zinenko return success(); 18619c5982efSAlex Zinenko } 18629c5982efSAlex Zinenko 18639c5982efSAlex Zinenko // Convert the body recursively. 18649c5982efSAlex Zinenko auto result = test::TestRecursiveType::get(type.getContext(), 18659c5982efSAlex Zinenko "outer_converted_type"); 18669c5982efSAlex Zinenko if (failed(result.setBody(converter.convertType(type.getBody())))) 18679c5982efSAlex Zinenko return failure(); 18689c5982efSAlex Zinenko results.push_back(result); 18699c5982efSAlex Zinenko return success(); 18709c5982efSAlex Zinenko }); 18714589dd92SRiver Riddle 18724589dd92SRiver Riddle /// Add the legal set of type materializations. 18734589dd92SRiver Riddle converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, 18744589dd92SRiver Riddle ValueRange inputs, 18754589dd92SRiver Riddle Location loc) -> Value { 18764589dd92SRiver Riddle // Allow casting from F64 back to F32. 18774589dd92SRiver Riddle if (!resultType.isF16() && inputs.size() == 1 && 18784589dd92SRiver Riddle inputs[0].getType().isF64()) 18794589dd92SRiver Riddle return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 18804589dd92SRiver Riddle // Allow producing an i32 or i64 from nothing. 18814589dd92SRiver Riddle if ((resultType.isInteger(32) || resultType.isInteger(64)) && 18824589dd92SRiver Riddle inputs.empty()) 18834589dd92SRiver Riddle return builder.create<TestTypeProducerOp>(loc, resultType); 18844589dd92SRiver Riddle // Allow producing an i64 from an integer. 18855550c821STres Popp if (isa<IntegerType>(resultType) && inputs.size() == 1 && 18865550c821STres Popp isa<IntegerType>(inputs[0].getType())) 18874589dd92SRiver Riddle return builder.create<TestCastOp>(loc, resultType, inputs).getResult(); 18884589dd92SRiver Riddle // Otherwise, fail. 18894589dd92SRiver Riddle return nullptr; 18904589dd92SRiver Riddle }); 18914589dd92SRiver Riddle 18924589dd92SRiver Riddle // Initialize the conversion target. 18934589dd92SRiver Riddle mlir::ConversionTarget target(getContext()); 18948fc32942SMatthias Springer target.addLegalOp<LegalOpD>(); 18954589dd92SRiver Riddle target.addDynamicallyLegalOp<TestTypeProducerOp>([](TestTypeProducerOp op) { 18965550c821STres Popp auto recursiveType = dyn_cast<test::TestRecursiveType>(op.getType()); 18979c5982efSAlex Zinenko return op.getType().isF64() || op.getType().isInteger(64) || 18989c5982efSAlex Zinenko (recursiveType && 18999c5982efSAlex Zinenko recursiveType.getName() == "outer_converted_type"); 19004589dd92SRiver Riddle }); 190158ceae95SRiver Riddle target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 19024a3460a7SRiver Riddle return converter.isSignatureLegal(op.getFunctionType()) && 19034589dd92SRiver Riddle converter.isLegal(&op.getBody()); 19044589dd92SRiver Riddle }); 19054589dd92SRiver Riddle target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) { 19064589dd92SRiver Riddle // Allow casts from F64 to F32. 19074589dd92SRiver Riddle return (*op.operand_type_begin()).isF64() && op.getType().isF32(); 19084589dd92SRiver Riddle }); 19094070f305SRiver Riddle target.addDynamicallyLegalOp<TestSignatureConversionNoConverterOp>( 19104070f305SRiver Riddle [&](TestSignatureConversionNoConverterOp op) { 19114070f305SRiver Riddle return converter.isLegal(op.getRegion().front().getArgumentTypes()); 19124070f305SRiver Riddle }); 19134589dd92SRiver Riddle 19144589dd92SRiver Riddle // Initialize the set of rewrite patterns. 1915dc4e913bSChris Lattner RewritePatternSet patterns(&getContext()); 19163cc311abSMatthias Springer patterns 19173cc311abSMatthias Springer .add<TestTypeConsumerForward, TestTypeConversionProducer, 19184070f305SRiver Riddle TestSignatureConversionUndo, 19193cc311abSMatthias Springer TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>( 19203cc311abSMatthias Springer converter, &getContext()); 19213cc311abSMatthias Springer patterns.add<TestTypeConversionAnotherProducer>(&getContext()); 1922ed4749f9SIvan Butygin mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, 1923ed4749f9SIvan Butygin converter); 19244589dd92SRiver Riddle 19253fffffa8SRiver Riddle if (failed(applyPartialConversion(getOperation(), target, 19263fffffa8SRiver Riddle std::move(patterns)))) 19274589dd92SRiver Riddle signalPassFailure(); 19284589dd92SRiver Riddle } 19294589dd92SRiver Riddle }; 1930be0a7e9fSMehdi Amini } // namespace 19314589dd92SRiver Riddle 1932c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===// 1933d4a53f3bSAlex Zinenko // Test Target Materialization With No Uses 1934d4a53f3bSAlex Zinenko //===----------------------------------------------------------------------===// 1935d4a53f3bSAlex Zinenko 1936d4a53f3bSAlex Zinenko namespace { 1937d4a53f3bSAlex Zinenko struct ForwardOperandPattern : public OpConversionPattern<TestTypeChangerOp> { 1938d4a53f3bSAlex Zinenko using OpConversionPattern<TestTypeChangerOp>::OpConversionPattern; 1939d4a53f3bSAlex Zinenko 1940d4a53f3bSAlex Zinenko LogicalResult 1941d4a53f3bSAlex Zinenko matchAndRewrite(TestTypeChangerOp op, OpAdaptor adaptor, 1942d4a53f3bSAlex Zinenko ConversionPatternRewriter &rewriter) const final { 1943d4a53f3bSAlex Zinenko rewriter.replaceOp(op, adaptor.getOperands()); 1944d4a53f3bSAlex Zinenko return success(); 1945d4a53f3bSAlex Zinenko } 1946d4a53f3bSAlex Zinenko }; 1947d4a53f3bSAlex Zinenko 1948d4a53f3bSAlex Zinenko struct TestTargetMaterializationWithNoUses 1949c3a8ac79SMatthias Springer : public PassWrapper<TestTargetMaterializationWithNoUses, OperationPass<>> { 19505e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 19515e50dd04SRiver Riddle TestTargetMaterializationWithNoUses) 19525e50dd04SRiver Riddle 1953d4a53f3bSAlex Zinenko StringRef getArgument() const final { 1954d4a53f3bSAlex Zinenko return "test-target-materialization-with-no-uses"; 1955d4a53f3bSAlex Zinenko } 1956d4a53f3bSAlex Zinenko StringRef getDescription() const final { 1957d4a53f3bSAlex Zinenko return "Test a special case of target materialization in DialectConversion"; 1958d4a53f3bSAlex Zinenko } 1959d4a53f3bSAlex Zinenko 1960d4a53f3bSAlex Zinenko void runOnOperation() override { 1961d4a53f3bSAlex Zinenko TypeConverter converter; 1962d4a53f3bSAlex Zinenko converter.addConversion([](Type t) { return t; }); 1963d4a53f3bSAlex Zinenko converter.addConversion([](IntegerType intTy) -> Type { 1964d4a53f3bSAlex Zinenko if (intTy.getWidth() == 16) 1965d4a53f3bSAlex Zinenko return IntegerType::get(intTy.getContext(), 64); 1966d4a53f3bSAlex Zinenko return intTy; 1967d4a53f3bSAlex Zinenko }); 1968d4a53f3bSAlex Zinenko converter.addTargetMaterialization( 1969d4a53f3bSAlex Zinenko [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { 1970d4a53f3bSAlex Zinenko return builder.create<TestCastOp>(loc, type, inputs).getResult(); 1971d4a53f3bSAlex Zinenko }); 1972d4a53f3bSAlex Zinenko 1973d4a53f3bSAlex Zinenko ConversionTarget target(getContext()); 1974d4a53f3bSAlex Zinenko target.addIllegalOp<TestTypeChangerOp>(); 1975d4a53f3bSAlex Zinenko 1976d4a53f3bSAlex Zinenko RewritePatternSet patterns(&getContext()); 1977d4a53f3bSAlex Zinenko patterns.add<ForwardOperandPattern>(converter, &getContext()); 1978d4a53f3bSAlex Zinenko 1979d4a53f3bSAlex Zinenko if (failed(applyPartialConversion(getOperation(), target, 1980d4a53f3bSAlex Zinenko std::move(patterns)))) 1981d4a53f3bSAlex Zinenko signalPassFailure(); 1982d4a53f3bSAlex Zinenko } 1983d4a53f3bSAlex Zinenko }; 1984d4a53f3bSAlex Zinenko } // namespace 1985d4a53f3bSAlex Zinenko 1986d4a53f3bSAlex Zinenko //===----------------------------------------------------------------------===// 1987c8fb6ee3SRiver Riddle // Test Block Merging 1988c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===// 1989c8fb6ee3SRiver Riddle 1990e888886cSMaheshRavishankar namespace { 1991e888886cSMaheshRavishankar /// A rewriter pattern that tests that blocks can be merged. 1992e888886cSMaheshRavishankar struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> { 1993e888886cSMaheshRavishankar using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern; 1994e888886cSMaheshRavishankar 1995e888886cSMaheshRavishankar LogicalResult 1996ef976337SRiver Riddle matchAndRewrite(TestMergeBlocksOp op, OpAdaptor adaptor, 1997e888886cSMaheshRavishankar ConversionPatternRewriter &rewriter) const final { 19986a994233SJacques Pienaar Block &firstBlock = op.getBody().front(); 1999e888886cSMaheshRavishankar Operation *branchOp = firstBlock.getTerminator(); 20006a994233SJacques Pienaar Block *secondBlock = &*(std::next(op.getBody().begin())); 2001e888886cSMaheshRavishankar auto succOperands = branchOp->getOperands(); 2002e888886cSMaheshRavishankar SmallVector<Value, 2> replacements(succOperands); 2003e888886cSMaheshRavishankar rewriter.eraseOp(branchOp); 2004e888886cSMaheshRavishankar rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 20055fcf907bSMatthias Springer rewriter.modifyOpInPlace(op, [] {}); 2006e888886cSMaheshRavishankar return success(); 2007e888886cSMaheshRavishankar } 2008e888886cSMaheshRavishankar }; 2009e888886cSMaheshRavishankar 2010e888886cSMaheshRavishankar /// A rewrite pattern to tests the undo mechanism of blocks being merged. 2011e888886cSMaheshRavishankar struct TestUndoBlocksMerge : public ConversionPattern { 2012e888886cSMaheshRavishankar TestUndoBlocksMerge(MLIRContext *ctx) 2013e888886cSMaheshRavishankar : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} 2014e888886cSMaheshRavishankar LogicalResult 2015e888886cSMaheshRavishankar matchAndRewrite(Operation *op, ArrayRef<Value> operands, 2016e888886cSMaheshRavishankar ConversionPatternRewriter &rewriter) const final { 2017e888886cSMaheshRavishankar Block &firstBlock = op->getRegion(0).front(); 2018e888886cSMaheshRavishankar Operation *branchOp = firstBlock.getTerminator(); 2019e888886cSMaheshRavishankar Block *secondBlock = &*(std::next(op->getRegion(0).begin())); 2020e888886cSMaheshRavishankar rewriter.setInsertionPointToStart(secondBlock); 2021e888886cSMaheshRavishankar rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type()); 2022e888886cSMaheshRavishankar auto succOperands = branchOp->getOperands(); 2023e888886cSMaheshRavishankar SmallVector<Value, 2> replacements(succOperands); 2024e888886cSMaheshRavishankar rewriter.eraseOp(branchOp); 2025e888886cSMaheshRavishankar rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); 20265fcf907bSMatthias Springer rewriter.modifyOpInPlace(op, [] {}); 2027e888886cSMaheshRavishankar return success(); 2028e888886cSMaheshRavishankar } 2029e888886cSMaheshRavishankar }; 2030e888886cSMaheshRavishankar 2031e888886cSMaheshRavishankar /// A rewrite mechanism to inline the body of the op into its parent, when both 2032e888886cSMaheshRavishankar /// ops can have a single block. 2033e888886cSMaheshRavishankar struct TestMergeSingleBlockOps 2034e888886cSMaheshRavishankar : public OpConversionPattern<SingleBlockImplicitTerminatorOp> { 2035e888886cSMaheshRavishankar using OpConversionPattern< 2036e888886cSMaheshRavishankar SingleBlockImplicitTerminatorOp>::OpConversionPattern; 2037e888886cSMaheshRavishankar 2038e888886cSMaheshRavishankar LogicalResult 2039ef976337SRiver Riddle matchAndRewrite(SingleBlockImplicitTerminatorOp op, OpAdaptor adaptor, 2040e888886cSMaheshRavishankar ConversionPatternRewriter &rewriter) const final { 2041e888886cSMaheshRavishankar SingleBlockImplicitTerminatorOp parentOp = 20420bf4a82aSChristian Sigg op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 2043e888886cSMaheshRavishankar if (!parentOp) 2044e888886cSMaheshRavishankar return failure(); 20456a994233SJacques Pienaar Block &innerBlock = op.getRegion().front(); 2046e888886cSMaheshRavishankar TerminatorOp innerTerminator = 2047e888886cSMaheshRavishankar cast<TerminatorOp>(innerBlock.getTerminator()); 204842c31d83SMatthias Springer rewriter.inlineBlockBefore(&innerBlock, op); 2049e888886cSMaheshRavishankar rewriter.eraseOp(innerTerminator); 2050e888886cSMaheshRavishankar rewriter.eraseOp(op); 2051e888886cSMaheshRavishankar return success(); 2052e888886cSMaheshRavishankar } 2053e888886cSMaheshRavishankar }; 2054e888886cSMaheshRavishankar 2055e888886cSMaheshRavishankar struct TestMergeBlocksPatternDriver 2056c3a8ac79SMatthias Springer : public PassWrapper<TestMergeBlocksPatternDriver, OperationPass<>> { 20575e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMergeBlocksPatternDriver) 20585e50dd04SRiver Riddle 2059b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-merge-blocks"; } 2060b5e22e6dSMehdi Amini StringRef getDescription() const final { 2061b5e22e6dSMehdi Amini return "Test Merging operation in ConversionPatternRewriter"; 2062b5e22e6dSMehdi Amini } 2063e888886cSMaheshRavishankar void runOnOperation() override { 2064e888886cSMaheshRavishankar MLIRContext *context = &getContext(); 2065dc4e913bSChris Lattner mlir::RewritePatternSet patterns(context); 2066dc4e913bSChris Lattner patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>( 2067e888886cSMaheshRavishankar context); 2068e888886cSMaheshRavishankar ConversionTarget target(*context); 206958ceae95SRiver Riddle target.addLegalOp<func::FuncOp, ModuleOp, TerminatorOp, TestBranchOp, 2070973ddb7dSMehdi Amini TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>(); 2071e888886cSMaheshRavishankar target.addIllegalOp<ILLegalOpF>(); 2072e888886cSMaheshRavishankar 2073e888886cSMaheshRavishankar /// Expect the op to have a single block after legalization. 2074e888886cSMaheshRavishankar target.addDynamicallyLegalOp<TestMergeBlocksOp>( 2075e888886cSMaheshRavishankar [&](TestMergeBlocksOp op) -> bool { 20766a994233SJacques Pienaar return llvm::hasSingleElement(op.getBody()); 2077e888886cSMaheshRavishankar }); 2078e888886cSMaheshRavishankar 2079e888886cSMaheshRavishankar /// Only allow `test.br` within test.merge_blocks op. 2080e888886cSMaheshRavishankar target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool { 20810bf4a82aSChristian Sigg return op->getParentOfType<TestMergeBlocksOp>(); 2082e888886cSMaheshRavishankar }); 2083e888886cSMaheshRavishankar 2084e888886cSMaheshRavishankar /// Expect that all nested test.SingleBlockImplicitTerminator ops are 2085e888886cSMaheshRavishankar /// inlined. 2086e888886cSMaheshRavishankar target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>( 2087e888886cSMaheshRavishankar [&](SingleBlockImplicitTerminatorOp op) -> bool { 20880bf4a82aSChristian Sigg return !op->getParentOfType<SingleBlockImplicitTerminatorOp>(); 2089e888886cSMaheshRavishankar }); 2090e888886cSMaheshRavishankar 2091e888886cSMaheshRavishankar DenseSet<Operation *> unlegalizedOps; 2092a2821094SMatthias Springer ConversionConfig config; 2093a2821094SMatthias Springer config.unlegalizedOps = &unlegalizedOps; 20943fffffa8SRiver Riddle (void)applyPartialConversion(getOperation(), target, std::move(patterns), 2095a2821094SMatthias Springer config); 2096e888886cSMaheshRavishankar for (auto *op : unlegalizedOps) 2097e888886cSMaheshRavishankar op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; 2098e888886cSMaheshRavishankar } 2099e888886cSMaheshRavishankar }; 2100e888886cSMaheshRavishankar } // namespace 2101e888886cSMaheshRavishankar 21024589dd92SRiver Riddle //===----------------------------------------------------------------------===// 2103c8fb6ee3SRiver Riddle // Test Selective Replacement 2104c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===// 2105c8fb6ee3SRiver Riddle 2106c8fb6ee3SRiver Riddle namespace { 2107c8fb6ee3SRiver Riddle /// A rewrite mechanism to inline the body of the op into its parent, when both 2108c8fb6ee3SRiver Riddle /// ops can have a single block. 2109c8fb6ee3SRiver Riddle struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> { 2110c8fb6ee3SRiver Riddle using OpRewritePattern<TestCastOp>::OpRewritePattern; 2111c8fb6ee3SRiver Riddle 2112c8fb6ee3SRiver Riddle LogicalResult matchAndRewrite(TestCastOp op, 2113c8fb6ee3SRiver Riddle PatternRewriter &rewriter) const final { 2114c8fb6ee3SRiver Riddle if (op.getNumOperands() != 2) 2115c8fb6ee3SRiver Riddle return failure(); 2116c8fb6ee3SRiver Riddle OperandRange operands = op.getOperands(); 2117c8fb6ee3SRiver Riddle 2118c8fb6ee3SRiver Riddle // Replace non-terminator uses with the first operand. 211959a92019SMatthias Springer rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) { 2120fe7c0d90SRiver Riddle return operand.getOwner()->hasTrait<OpTrait::IsTerminator>(); 2121c8fb6ee3SRiver Riddle }); 2122c8fb6ee3SRiver Riddle // Replace everything else with the second operand if the operation isn't 2123c8fb6ee3SRiver Riddle // dead. 2124c8fb6ee3SRiver Riddle rewriter.replaceOp(op, op.getOperand(1)); 2125c8fb6ee3SRiver Riddle return success(); 2126c8fb6ee3SRiver Riddle } 2127c8fb6ee3SRiver Riddle }; 2128c8fb6ee3SRiver Riddle 2129c8fb6ee3SRiver Riddle struct TestSelectiveReplacementPatternDriver 2130c8fb6ee3SRiver Riddle : public PassWrapper<TestSelectiveReplacementPatternDriver, 2131c8fb6ee3SRiver Riddle OperationPass<>> { 21325e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 21335e50dd04SRiver Riddle TestSelectiveReplacementPatternDriver) 21345e50dd04SRiver Riddle 2135b5e22e6dSMehdi Amini StringRef getArgument() const final { 2136b5e22e6dSMehdi Amini return "test-pattern-selective-replacement"; 2137b5e22e6dSMehdi Amini } 2138b5e22e6dSMehdi Amini StringRef getDescription() const final { 2139b5e22e6dSMehdi Amini return "Test selective replacement in the PatternRewriter"; 2140b5e22e6dSMehdi Amini } 2141c8fb6ee3SRiver Riddle void runOnOperation() override { 2142c8fb6ee3SRiver Riddle MLIRContext *context = &getContext(); 2143dc4e913bSChris Lattner mlir::RewritePatternSet patterns(context); 2144dc4e913bSChris Lattner patterns.add<TestSelectiveOpReplacementPattern>(context); 214509dfc571SJacques Pienaar (void)applyPatternsGreedily(getOperation(), std::move(patterns)); 2146c8fb6ee3SRiver Riddle } 2147c8fb6ee3SRiver Riddle }; 2148c8fb6ee3SRiver Riddle } // namespace 2149c8fb6ee3SRiver Riddle 2150c8fb6ee3SRiver Riddle //===----------------------------------------------------------------------===// 21514589dd92SRiver Riddle // PassRegistration 21524589dd92SRiver Riddle //===----------------------------------------------------------------------===// 21534589dd92SRiver Riddle 2154fec6c5acSUday Bondhugula namespace mlir { 215572c65b69SAlexander Belyaev namespace test { 2156fec6c5acSUday Bondhugula void registerPatternsTestPass() { 2157b5e22e6dSMehdi Amini PassRegistration<TestReturnTypeDriver>(); 2158fec6c5acSUday Bondhugula 2159b5e22e6dSMehdi Amini PassRegistration<TestDerivedAttributeDriver>(); 21609ba37b3bSJacques Pienaar 21610f8a6b7dSJakub Kuderski PassRegistration<TestGreedyPatternDriver>(); 2162ba3a9f51SChia-hung Duan PassRegistration<TestStrictPatternDriver>(); 21630f8a6b7dSJakub Kuderski PassRegistration<TestWalkPatternDriver>(); 2164fec6c5acSUday Bondhugula 2165b5e22e6dSMehdi Amini PassRegistration<TestLegalizePatternDriver>([] { 2166b5e22e6dSMehdi Amini return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode); 2167fec6c5acSUday Bondhugula }); 2168fec6c5acSUday Bondhugula 2169b5e22e6dSMehdi Amini PassRegistration<TestRemappedValue>(); 217080d7ac3bSRiver Riddle 2171b5e22e6dSMehdi Amini PassRegistration<TestUnknownRootOpDriver>(); 21724589dd92SRiver Riddle 2173b5e22e6dSMehdi Amini PassRegistration<TestTypeConversionDriver>(); 2174d4a53f3bSAlex Zinenko PassRegistration<TestTargetMaterializationWithNoUses>(); 2175e888886cSMaheshRavishankar 217688bc24a7SMathieu Fehr PassRegistration<TestRewriteDynamicOpDriver>(); 217788bc24a7SMathieu Fehr 2178b5e22e6dSMehdi Amini PassRegistration<TestMergeBlocksPatternDriver>(); 2179b5e22e6dSMehdi Amini PassRegistration<TestSelectiveReplacementPatternDriver>(); 2180fec6c5acSUday Bondhugula } 218172c65b69SAlexander Belyaev } // namespace test 2182fec6c5acSUday Bondhugula } // namespace mlir 2183