1b6eb26fdSRiver Riddle //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===// 2b6eb26fdSRiver Riddle // 3b6eb26fdSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4b6eb26fdSRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 5b6eb26fdSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6b6eb26fdSRiver Riddle // 7b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 8b6eb26fdSRiver Riddle 9b6eb26fdSRiver Riddle #include "mlir/Transforms/DialectConversion.h" 106ae7f66fSJacques Pienaar #include "mlir/Config/mlir-config.h" 11b6eb26fdSRiver Riddle #include "mlir/IR/Block.h" 12b6eb26fdSRiver Riddle #include "mlir/IR/Builders.h" 1365fcddffSRiver Riddle #include "mlir/IR/BuiltinOps.h" 143ace6851SMatthias Springer #include "mlir/IR/Dominance.h" 154d67b278SJeff Niu #include "mlir/IR/IRMapping.h" 16b884f4efSMatthias Springer #include "mlir/IR/Iterators.h" 1734a35a8bSMartin Erhart #include "mlir/Interfaces/FunctionInterfaces.h" 18b6eb26fdSRiver Riddle #include "mlir/Rewrite/PatternApplicator.h" 199c5982efSAlex Zinenko #include "llvm/ADT/ScopeExit.h" 20b6eb26fdSRiver Riddle #include "llvm/ADT/SetVector.h" 21b6eb26fdSRiver Riddle #include "llvm/ADT/SmallPtrSet.h" 22b6eb26fdSRiver Riddle #include "llvm/Support/Debug.h" 23b6eb26fdSRiver Riddle #include "llvm/Support/FormatVariadic.h" 24b6eb26fdSRiver Riddle #include "llvm/Support/SaveAndRestore.h" 25b6eb26fdSRiver Riddle #include "llvm/Support/ScopedPrinter.h" 2605423905SKazu Hirata #include <optional> 27b6eb26fdSRiver Riddle 28b6eb26fdSRiver Riddle using namespace mlir; 29b6eb26fdSRiver Riddle using namespace mlir::detail; 30b6eb26fdSRiver Riddle 31b6eb26fdSRiver Riddle #define DEBUG_TYPE "dialect-conversion" 32b6eb26fdSRiver Riddle 33b6eb26fdSRiver Riddle /// A utility function to log a successful result for the given reason. 34b6eb26fdSRiver Riddle template <typename... Args> 354efb7754SRiver Riddle static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) { 36b6eb26fdSRiver Riddle LLVM_DEBUG({ 37b6eb26fdSRiver Riddle os.unindent(); 38b6eb26fdSRiver Riddle os.startLine() << "} -> SUCCESS"; 39b6eb26fdSRiver Riddle if (!fmt.empty()) 40b6eb26fdSRiver Riddle os.getOStream() << " : " 41b6eb26fdSRiver Riddle << llvm::formatv(fmt.data(), std::forward<Args>(args)...); 42b6eb26fdSRiver Riddle os.getOStream() << "\n"; 43b6eb26fdSRiver Riddle }); 44b6eb26fdSRiver Riddle } 45b6eb26fdSRiver Riddle 46b6eb26fdSRiver Riddle /// A utility function to log a failure result for the given reason. 47b6eb26fdSRiver Riddle template <typename... Args> 484efb7754SRiver Riddle static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) { 49b6eb26fdSRiver Riddle LLVM_DEBUG({ 50b6eb26fdSRiver Riddle os.unindent(); 51b6eb26fdSRiver Riddle os.startLine() << "} -> FAILURE : " 52b6eb26fdSRiver Riddle << llvm::formatv(fmt.data(), std::forward<Args>(args)...) 53b6eb26fdSRiver Riddle << "\n"; 54b6eb26fdSRiver Riddle }); 55b6eb26fdSRiver Riddle } 56b6eb26fdSRiver Riddle 572fc71e4eSMatthias Springer /// Helper function that computes an insertion point where the given value is 582fc71e4eSMatthias Springer /// defined and can be used without a dominance violation. 592fc71e4eSMatthias Springer static OpBuilder::InsertPoint computeInsertPoint(Value value) { 602fc71e4eSMatthias Springer Block *insertBlock = value.getParentBlock(); 612fc71e4eSMatthias Springer Block::iterator insertPt = insertBlock->begin(); 622fc71e4eSMatthias Springer if (OpResult inputRes = dyn_cast<OpResult>(value)) 632fc71e4eSMatthias Springer insertPt = ++inputRes.getOwner()->getIterator(); 642fc71e4eSMatthias Springer return OpBuilder::InsertPoint(insertBlock, insertPt); 652fc71e4eSMatthias Springer } 662fc71e4eSMatthias Springer 673ace6851SMatthias Springer /// Helper function that computes an insertion point where the given values are 683ace6851SMatthias Springer /// defined and can be used without a dominance violation. 693ace6851SMatthias Springer static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) { 703ace6851SMatthias Springer assert(!vals.empty() && "expected at least one value"); 7195c5c5d4SMatthias Springer DominanceInfo domInfo; 723ace6851SMatthias Springer OpBuilder::InsertPoint pt = computeInsertPoint(vals.front()); 7395c5c5d4SMatthias Springer for (Value v : vals.drop_front()) { 7495c5c5d4SMatthias Springer // Choose the "later" insertion point. 7595c5c5d4SMatthias Springer OpBuilder::InsertPoint nextPt = computeInsertPoint(v); 7695c5c5d4SMatthias Springer if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(), 7795c5c5d4SMatthias Springer nextPt.getPoint())) { 7895c5c5d4SMatthias Springer // pt is before nextPt => choose nextPt. 7995c5c5d4SMatthias Springer pt = nextPt; 8095c5c5d4SMatthias Springer } else { 8195c5c5d4SMatthias Springer #ifndef NDEBUG 8295c5c5d4SMatthias Springer // nextPt should be before pt => choose pt. 8395c5c5d4SMatthias Springer // If pt, nextPt are no dominance relationship, then there is no valid 8495c5c5d4SMatthias Springer // insertion point at which all given values are defined. 8595c5c5d4SMatthias Springer bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(), 8695c5c5d4SMatthias Springer pt.getBlock(), pt.getPoint()); 8795c5c5d4SMatthias Springer assert(dom && "unable to find valid insertion point"); 8895c5c5d4SMatthias Springer #endif // NDEBUG 8995c5c5d4SMatthias Springer } 9095c5c5d4SMatthias Springer } 913ace6851SMatthias Springer return pt; 923ace6851SMatthias Springer } 933ace6851SMatthias Springer 94b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 95b6eb26fdSRiver Riddle // ConversionValueMapping 96b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 97b6eb26fdSRiver Riddle 983ace6851SMatthias Springer /// A vector of SSA values, optimized for the most common case of a single 993ace6851SMatthias Springer /// value. 1003ace6851SMatthias Springer using ValueVector = SmallVector<Value, 1>; 1013ace6851SMatthias Springer 102b6eb26fdSRiver Riddle namespace { 1033ace6851SMatthias Springer 1043ace6851SMatthias Springer /// Helper class to make it possible to use `ValueVector` as a key in DenseMap. 1053ace6851SMatthias Springer struct ValueVectorMapInfo { 106afef716eSMatthias Springer static ValueVector getEmptyKey() { return ValueVector{Value()}; } 107afef716eSMatthias Springer static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; } 1083ace6851SMatthias Springer static ::llvm::hash_code getHashValue(const ValueVector &val) { 1093ace6851SMatthias Springer return ::llvm::hash_combine_range(val.begin(), val.end()); 1103ace6851SMatthias Springer } 1113ace6851SMatthias Springer static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) { 1123ace6851SMatthias Springer return LHS == RHS; 1133ace6851SMatthias Springer } 1143ace6851SMatthias Springer }; 1153ace6851SMatthias Springer 1164d67b278SJeff Niu /// This class wraps a IRMapping to provide recursive lookup 117b6eb26fdSRiver Riddle /// functionality, i.e. we will traverse if the mapped value also has a mapping. 118b6eb26fdSRiver Riddle struct ConversionValueMapping { 1193761b675SMatthias Springer /// Return "true" if an SSA value is mapped to the given value. May return 1203761b675SMatthias Springer /// false positives. 1213761b675SMatthias Springer bool isMappedTo(Value value) const { return mappedTo.contains(value); } 1223761b675SMatthias Springer 1233ace6851SMatthias Springer /// Lookup the most recently mapped values with the desired types in the 124fcde4f65SMatthias Springer /// mapping. 125fcde4f65SMatthias Springer /// 126fcde4f65SMatthias Springer /// Special cases: 1273ace6851SMatthias Springer /// - If the desired type range is empty, simply return the most recently 1283ace6851SMatthias Springer /// mapped values. 1293ace6851SMatthias Springer /// - If there is no mapping to the desired types, also return the most 1303ace6851SMatthias Springer /// recently mapped values. 1313ace6851SMatthias Springer /// - If there is no mapping for the given values at all, return the given 132fcde4f65SMatthias Springer /// value. 1333ace6851SMatthias Springer ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const; 134b6eb26fdSRiver Riddle 1353ace6851SMatthias Springer /// Lookup the given value within the map, or return an empty vector if the 1363ace6851SMatthias Springer /// value is not mapped. If it is mapped, this follows the same behavior 1373ace6851SMatthias Springer /// as `lookupOrDefault`. 1383ace6851SMatthias Springer ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const; 139b6eb26fdSRiver Riddle 1403ace6851SMatthias Springer template <typename T> 1413ace6851SMatthias Springer struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {}; 1423ace6851SMatthias Springer 1433ace6851SMatthias Springer /// Map a value vector to the one provided. 1443ace6851SMatthias Springer template <typename OldVal, typename NewVal> 1453ace6851SMatthias Springer std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value> 1463ace6851SMatthias Springer map(OldVal &&oldVal, NewVal &&newVal) { 147015192c6SRiver Riddle LLVM_DEBUG({ 1483ace6851SMatthias Springer ValueVector next(newVal); 1493ace6851SMatthias Springer while (true) { 1503ace6851SMatthias Springer assert(next != oldVal && "inserting cyclic mapping"); 1513ace6851SMatthias Springer auto it = mapping.find(next); 1523ace6851SMatthias Springer if (it == mapping.end()) 1533ace6851SMatthias Springer break; 1543ace6851SMatthias Springer next = it->second; 1553ace6851SMatthias Springer } 156015192c6SRiver Riddle }); 1573ace6851SMatthias Springer for (Value v : newVal) 1583ace6851SMatthias Springer mappedTo.insert(v); 1593ace6851SMatthias Springer 1603ace6851SMatthias Springer mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal); 161015192c6SRiver Riddle } 162015192c6SRiver Riddle 1633ace6851SMatthias Springer /// Map a value vector or single value to the one provided. 1643ace6851SMatthias Springer template <typename OldVal, typename NewVal> 1653ace6851SMatthias Springer std::enable_if_t<!IsValueVector<OldVal>::value || 1663ace6851SMatthias Springer !IsValueVector<NewVal>::value> 1673ace6851SMatthias Springer map(OldVal &&oldVal, NewVal &&newVal) { 1683ace6851SMatthias Springer if constexpr (IsValueVector<OldVal>{}) { 1693ace6851SMatthias Springer map(std::forward<OldVal>(oldVal), ValueVector{newVal}); 1703ace6851SMatthias Springer } else if constexpr (IsValueVector<NewVal>{}) { 1713ace6851SMatthias Springer map(ValueVector{oldVal}, std::forward<NewVal>(newVal)); 1723ace6851SMatthias Springer } else { 1733ace6851SMatthias Springer map(ValueVector{oldVal}, ValueVector{newVal}); 1743ace6851SMatthias Springer } 1753ace6851SMatthias Springer } 1763ace6851SMatthias Springer 1773ace6851SMatthias Springer /// Drop the last mapping for the given values. 1783ace6851SMatthias Springer void erase(const ValueVector &value) { mapping.erase(value); } 179b6eb26fdSRiver Riddle 180b6eb26fdSRiver Riddle private: 181b6eb26fdSRiver Riddle /// Current value mappings. 1823ace6851SMatthias Springer DenseMap<ValueVector, ValueVector, ValueVectorMapInfo> mapping; 1833761b675SMatthias Springer 1843761b675SMatthias Springer /// All SSA values that are mapped to. May contain false positives. 1853761b675SMatthias Springer DenseSet<Value> mappedTo; 186b6eb26fdSRiver Riddle }; 187be0a7e9fSMehdi Amini } // namespace 188b6eb26fdSRiver Riddle 1893ace6851SMatthias Springer ValueVector 1903ace6851SMatthias Springer ConversionValueMapping::lookupOrDefault(Value from, 1913ace6851SMatthias Springer TypeRange desiredTypes) const { 1923ace6851SMatthias Springer // Try to find the deepest values that have the desired types. If there is no 1933ace6851SMatthias Springer // such mapping, simply return the deepest values. 1943ace6851SMatthias Springer ValueVector desiredValue; 1953ace6851SMatthias Springer ValueVector current{from}; 196b6eb26fdSRiver Riddle do { 1973ace6851SMatthias Springer // Store the current value if the types match. 198faa30be1SMatthias Springer if (TypeRange(ValueRange(current)) == desiredTypes) 1993ace6851SMatthias Springer desiredValue = current; 200b6eb26fdSRiver Riddle 2013ace6851SMatthias Springer // If possible, Replace each value with (one or multiple) mapped values. 2023ace6851SMatthias Springer ValueVector next; 2033ace6851SMatthias Springer for (Value v : current) { 2043ace6851SMatthias Springer auto it = mapping.find({v}); 2053ace6851SMatthias Springer if (it != mapping.end()) { 2063ace6851SMatthias Springer llvm::append_range(next, it->second); 2073ace6851SMatthias Springer } else { 2083ace6851SMatthias Springer next.push_back(v); 2093ace6851SMatthias Springer } 2103ace6851SMatthias Springer } 2113ace6851SMatthias Springer if (next != current) { 2123ace6851SMatthias Springer // If at least one value was replaced, continue the lookup from there. 2133ace6851SMatthias Springer current = std::move(next); 2143ace6851SMatthias Springer continue; 215b6eb26fdSRiver Riddle } 216b6eb26fdSRiver Riddle 2173ace6851SMatthias Springer // Otherwise: Check if there is a mapping for the entire vector. Such 2183ace6851SMatthias Springer // mappings are materializations. (N:M mapping are not supported for value 2193ace6851SMatthias Springer // replacements.) 2203ace6851SMatthias Springer // 2213ace6851SMatthias Springer // Note: From a correctness point of view, materializations do not have to 2223ace6851SMatthias Springer // be stored (and looked up) in the mapping. But for performance reasons, 2233ace6851SMatthias Springer // we choose to reuse existing IR (when possible) instead of creating it 2243ace6851SMatthias Springer // multiple times. 2253ace6851SMatthias Springer auto it = mapping.find(current); 2263ace6851SMatthias Springer if (it == mapping.end()) { 2273ace6851SMatthias Springer // No mapping found: The lookup stops here. 2283ace6851SMatthias Springer break; 2293ace6851SMatthias Springer } 2303ace6851SMatthias Springer current = it->second; 2313ace6851SMatthias Springer } while (true); 2323ace6851SMatthias Springer 2333ace6851SMatthias Springer // If the desired values were found use them, otherwise default to the leaf 2343ace6851SMatthias Springer // values. 2353ace6851SMatthias Springer // Note: If `desiredTypes` is empty, this function always returns `current`. 2363ace6851SMatthias Springer return !desiredValue.empty() ? std::move(desiredValue) : std::move(current); 2373ace6851SMatthias Springer } 2383ace6851SMatthias Springer 2393ace6851SMatthias Springer ValueVector ConversionValueMapping::lookupOrNull(Value from, 2403ace6851SMatthias Springer TypeRange desiredTypes) const { 2413ace6851SMatthias Springer ValueVector result = lookupOrDefault(from, desiredTypes); 2423ace6851SMatthias Springer if (result == ValueVector{from} || 243faa30be1SMatthias Springer (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes)) 2443ace6851SMatthias Springer return {}; 245015192c6SRiver Riddle return result; 246015192c6SRiver Riddle } 247015192c6SRiver Riddle 248b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 24901b55f16SRiver Riddle // Rewriter and Translation State 25001b55f16SRiver Riddle //===----------------------------------------------------------------------===// 25101b55f16SRiver Riddle namespace { 25201b55f16SRiver Riddle /// This class contains a snapshot of the current conversion rewriter state. 25301b55f16SRiver Riddle /// This is useful when saving and undoing a set of rewrites. 25401b55f16SRiver Riddle struct RewriterState { 25559ff4d13SMatthias Springer RewriterState(unsigned numRewrites, unsigned numIgnoredOperations, 256310a2788SMatthias Springer unsigned numReplacedOps) 25759ff4d13SMatthias Springer : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations), 258310a2788SMatthias Springer numReplacedOps(numReplacedOps) {} 25901b55f16SRiver Riddle 2608faefe36SMatthias Springer /// The current number of rewrites performed. 2618faefe36SMatthias Springer unsigned numRewrites; 26201b55f16SRiver Riddle 26301b55f16SRiver Riddle /// The current number of ignored operations. 26401b55f16SRiver Riddle unsigned numIgnoredOperations; 26555558cd0SMatthias Springer 2667b66b5d6SMatthias Springer /// The current number of replaced ops that are scheduled for erasure. 2677b66b5d6SMatthias Springer unsigned numReplacedOps; 26801b55f16SRiver Riddle }; 26901b55f16SRiver Riddle 27001b55f16SRiver Riddle //===----------------------------------------------------------------------===// 2718faefe36SMatthias Springer // IR rewrites 2728faefe36SMatthias Springer //===----------------------------------------------------------------------===// 2738faefe36SMatthias Springer 2748faefe36SMatthias Springer /// An IR rewrite that can be committed (upon success) or rolled back (upon 2758faefe36SMatthias Springer /// failure). 2768faefe36SMatthias Springer /// 2778faefe36SMatthias Springer /// The dialect conversion keeps track of IR modifications (requested by the 2788faefe36SMatthias Springer /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites 2798faefe36SMatthias Springer /// are directly applied to the IR as the rewriter API is used, some are applied 2808faefe36SMatthias Springer /// partially, and some are delayed until the `IRRewrite` objects are committed. 2818faefe36SMatthias Springer class IRRewrite { 2828faefe36SMatthias Springer public: 2838faefe36SMatthias Springer /// The kind of the rewrite. Rewrites can be undone if the conversion fails. 284e214f004SMatthias Springer /// Enum values are ordered, so that they can be used in `classof`: first all 285e214f004SMatthias Springer /// block rewrites, then all operation rewrites. 2868faefe36SMatthias Springer enum class Kind { 287e214f004SMatthias Springer // Block rewrites 2888faefe36SMatthias Springer CreateBlock, 2898faefe36SMatthias Springer EraseBlock, 2908faefe36SMatthias Springer InlineBlock, 2918faefe36SMatthias Springer MoveBlock, 2928f4cd2c7SMatthias Springer BlockTypeConversion, 293d68d2951SMatthias Springer ReplaceBlockArg, 294e214f004SMatthias Springer // Operation rewrites 295e214f004SMatthias Springer MoveOperation, 296d68d2951SMatthias Springer ModifyOperation, 2979ca70d72SMatthias Springer ReplaceOperation, 29859ff4d13SMatthias Springer CreateOperation, 29959ff4d13SMatthias Springer UnresolvedMaterialization 3008faefe36SMatthias Springer }; 3018faefe36SMatthias Springer 3028faefe36SMatthias Springer virtual ~IRRewrite() = default; 3038faefe36SMatthias Springer 304d68d2951SMatthias Springer /// Roll back the rewrite. Operations may be erased during rollback. 3058faefe36SMatthias Springer virtual void rollback() = 0; 3068faefe36SMatthias Springer 30760a20bd6SMatthias Springer /// Commit the rewrite. At this point, it is certain that the dialect 30860a20bd6SMatthias Springer /// conversion will succeed. All IR modifications, except for operation/block 30960a20bd6SMatthias Springer /// erasure, must be performed through the given rewriter. 31060a20bd6SMatthias Springer /// 31160a20bd6SMatthias Springer /// Instead of erasing operations/blocks, they should merely be unlinked 31260a20bd6SMatthias Springer /// commit phase and finally be erased during the cleanup phase. This is 31360a20bd6SMatthias Springer /// because internal dialect conversion state (such as `mapping`) may still 31460a20bd6SMatthias Springer /// be using them. 31560a20bd6SMatthias Springer /// 31660a20bd6SMatthias Springer /// Any IR modification that was already performed before the commit phase 31760a20bd6SMatthias Springer /// (e.g., insertion of an op) must be communicated to the listener that may 31860a20bd6SMatthias Springer /// be attached to the given rewriter. 31960a20bd6SMatthias Springer virtual void commit(RewriterBase &rewriter) {} 3208faefe36SMatthias Springer 3219606655fSMatthias Springer /// Cleanup operations/blocks. Cleanup is called after commit. 32260a20bd6SMatthias Springer virtual void cleanup(RewriterBase &rewriter) {} 323d68d2951SMatthias Springer 3248faefe36SMatthias Springer Kind getKind() const { return kind; } 3258faefe36SMatthias Springer 3268faefe36SMatthias Springer static bool classof(const IRRewrite *rewrite) { return true; } 3278faefe36SMatthias Springer 3288faefe36SMatthias Springer protected: 3298faefe36SMatthias Springer IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl) 3308faefe36SMatthias Springer : kind(kind), rewriterImpl(rewriterImpl) {} 3318faefe36SMatthias Springer 332a2821094SMatthias Springer const ConversionConfig &getConfig() const; 333a2821094SMatthias Springer 3348faefe36SMatthias Springer const Kind kind; 3358faefe36SMatthias Springer ConversionPatternRewriterImpl &rewriterImpl; 3368faefe36SMatthias Springer }; 3378faefe36SMatthias Springer 3388faefe36SMatthias Springer /// A block rewrite. 3398faefe36SMatthias Springer class BlockRewrite : public IRRewrite { 3408faefe36SMatthias Springer public: 3418faefe36SMatthias Springer /// Return the block that this rewrite operates on. 3428faefe36SMatthias Springer Block *getBlock() const { return block; } 3438faefe36SMatthias Springer 3448faefe36SMatthias Springer static bool classof(const IRRewrite *rewrite) { 3458faefe36SMatthias Springer return rewrite->getKind() >= Kind::CreateBlock && 346d68d2951SMatthias Springer rewrite->getKind() <= Kind::ReplaceBlockArg; 3478faefe36SMatthias Springer } 3488faefe36SMatthias Springer 3498faefe36SMatthias Springer protected: 3508faefe36SMatthias Springer BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl, 3518faefe36SMatthias Springer Block *block) 3528faefe36SMatthias Springer : IRRewrite(kind, rewriterImpl), block(block) {} 3538faefe36SMatthias Springer 3548faefe36SMatthias Springer // The block that this rewrite operates on. 3558faefe36SMatthias Springer Block *block; 3568faefe36SMatthias Springer }; 3578faefe36SMatthias Springer 3588faefe36SMatthias Springer /// Creation of a block. Block creations are immediately reflected in the IR. 3598faefe36SMatthias Springer /// There is no extra work to commit the rewrite. During rollback, the newly 3608faefe36SMatthias Springer /// created block is erased. 3618faefe36SMatthias Springer class CreateBlockRewrite : public BlockRewrite { 3628faefe36SMatthias Springer public: 3638faefe36SMatthias Springer CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block) 3648faefe36SMatthias Springer : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {} 3658faefe36SMatthias Springer 3668faefe36SMatthias Springer static bool classof(const IRRewrite *rewrite) { 3678faefe36SMatthias Springer return rewrite->getKind() == Kind::CreateBlock; 3688faefe36SMatthias Springer } 3698faefe36SMatthias Springer 37060a20bd6SMatthias Springer void commit(RewriterBase &rewriter) override { 37160a20bd6SMatthias Springer // The block was already created and inserted. Just inform the listener. 37260a20bd6SMatthias Springer if (auto *listener = rewriter.getListener()) 37360a20bd6SMatthias Springer listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{}); 37460a20bd6SMatthias Springer } 37560a20bd6SMatthias Springer 3768faefe36SMatthias Springer void rollback() override { 3778faefe36SMatthias Springer // Unlink all of the operations within this block, they will be deleted 3788faefe36SMatthias Springer // separately. 3798faefe36SMatthias Springer auto &blockOps = block->getOperations(); 3808faefe36SMatthias Springer while (!blockOps.empty()) 3818faefe36SMatthias Springer blockOps.remove(blockOps.begin()); 38223941019SMehdi Amini block->dropAllUses(); 38323941019SMehdi Amini if (block->getParent()) 384310a2788SMatthias Springer block->erase(); 38523941019SMehdi Amini else 3869ca70d72SMatthias Springer delete block; 3878faefe36SMatthias Springer } 3888faefe36SMatthias Springer }; 3898faefe36SMatthias Springer 3908faefe36SMatthias Springer /// Erasure of a block. Block erasures are partially reflected in the IR. Erased 3919606655fSMatthias Springer /// blocks are immediately unlinked, but only erased during cleanup. This makes 3929606655fSMatthias Springer /// it easier to rollback a block erasure: the block is simply inserted into its 3939606655fSMatthias Springer /// original location. 3948faefe36SMatthias Springer class EraseBlockRewrite : public BlockRewrite { 3958faefe36SMatthias Springer public: 39636d384b4SMatthias Springer EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block) 39736d384b4SMatthias Springer : BlockRewrite(Kind::EraseBlock, rewriterImpl, block), 39836d384b4SMatthias Springer region(block->getParent()), insertBeforeBlock(block->getNextNode()) {} 3998faefe36SMatthias Springer 4008faefe36SMatthias Springer static bool classof(const IRRewrite *rewrite) { 4018faefe36SMatthias Springer return rewrite->getKind() == Kind::EraseBlock; 4028faefe36SMatthias Springer } 4038faefe36SMatthias Springer 4048faefe36SMatthias Springer ~EraseBlockRewrite() override { 4059606655fSMatthias Springer assert(!block && 4069606655fSMatthias Springer "rewrite was neither rolled back nor committed/cleaned up"); 4078faefe36SMatthias Springer } 4088faefe36SMatthias Springer 4098faefe36SMatthias Springer void rollback() override { 4108faefe36SMatthias Springer // The block (owned by this rewrite) was not actually erased yet. It was 4118faefe36SMatthias Springer // just unlinked. Put it back into its original position. 4128faefe36SMatthias Springer assert(block && "expected block"); 4138faefe36SMatthias Springer auto &blockList = region->getBlocks(); 4148faefe36SMatthias Springer Region::iterator before = insertBeforeBlock 4158faefe36SMatthias Springer ? Region::iterator(insertBeforeBlock) 4168faefe36SMatthias Springer : blockList.end(); 4178faefe36SMatthias Springer blockList.insert(before, block); 4188faefe36SMatthias Springer block = nullptr; 4198faefe36SMatthias Springer } 4208faefe36SMatthias Springer 42160a20bd6SMatthias Springer void commit(RewriterBase &rewriter) override { 4228faefe36SMatthias Springer // Erase the block. 4238faefe36SMatthias Springer assert(block && "expected block"); 424d68d2951SMatthias Springer assert(block->empty() && "expected empty block"); 42560a20bd6SMatthias Springer 42660a20bd6SMatthias Springer // Notify the listener that the block is about to be erased. 42760a20bd6SMatthias Springer if (auto *listener = 42860a20bd6SMatthias Springer dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener())) 42960a20bd6SMatthias Springer listener->notifyBlockErased(block); 43060a20bd6SMatthias Springer } 43160a20bd6SMatthias Springer 43260a20bd6SMatthias Springer void cleanup(RewriterBase &rewriter) override { 43360a20bd6SMatthias Springer // Erase the block. 434d68d2951SMatthias Springer block->dropAllDefinedValueUses(); 4358faefe36SMatthias Springer delete block; 4368faefe36SMatthias Springer block = nullptr; 4378faefe36SMatthias Springer } 4388faefe36SMatthias Springer 4398faefe36SMatthias Springer private: 4408faefe36SMatthias Springer // The region in which this block was previously contained. 4418faefe36SMatthias Springer Region *region; 4428faefe36SMatthias Springer 4438faefe36SMatthias Springer // The original successor of this block before it was unlinked. "nullptr" if 4448faefe36SMatthias Springer // this block was the only block in the region. 4458faefe36SMatthias Springer Block *insertBeforeBlock; 4468faefe36SMatthias Springer }; 4478faefe36SMatthias Springer 4488faefe36SMatthias Springer /// Inlining of a block. This rewrite is immediately reflected in the IR. 4498faefe36SMatthias Springer /// Note: This rewrite represents only the inlining of the operations. The 4508faefe36SMatthias Springer /// erasure of the inlined block is a separate rewrite. 4518faefe36SMatthias Springer class InlineBlockRewrite : public BlockRewrite { 4528faefe36SMatthias Springer public: 4538faefe36SMatthias Springer InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, 4548faefe36SMatthias Springer Block *sourceBlock, Block::iterator before) 4558faefe36SMatthias Springer : BlockRewrite(Kind::InlineBlock, rewriterImpl, block), 4568faefe36SMatthias Springer sourceBlock(sourceBlock), 4578faefe36SMatthias Springer firstInlinedInst(sourceBlock->empty() ? nullptr 4588faefe36SMatthias Springer : &sourceBlock->front()), 4598faefe36SMatthias Springer lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) { 46060a20bd6SMatthias Springer // If a listener is attached to the dialect conversion, ops must be moved 46160a20bd6SMatthias Springer // one-by-one. When they are moved in bulk, notifications cannot be sent 46260a20bd6SMatthias Springer // because the ops that used to be in the source block at the time of the 46360a20bd6SMatthias Springer // inlining (before the "commit" phase) are unknown at the time when 46460a20bd6SMatthias Springer // notifications are sent (which is during the "commit" phase). 46560a20bd6SMatthias Springer assert(!getConfig().listener && 46660a20bd6SMatthias Springer "InlineBlockRewrite not supported if listener is attached"); 4678faefe36SMatthias Springer } 4688faefe36SMatthias Springer 4698faefe36SMatthias Springer static bool classof(const IRRewrite *rewrite) { 4708faefe36SMatthias Springer return rewrite->getKind() == Kind::InlineBlock; 4718faefe36SMatthias Springer } 4728faefe36SMatthias Springer 4738faefe36SMatthias Springer void rollback() override { 4748faefe36SMatthias Springer // Put the operations from the destination block (owned by the rewrite) 4758faefe36SMatthias Springer // back into the source block. 4768faefe36SMatthias Springer if (firstInlinedInst) { 4778faefe36SMatthias Springer assert(lastInlinedInst && "expected operation"); 4788faefe36SMatthias Springer sourceBlock->getOperations().splice(sourceBlock->begin(), 4798faefe36SMatthias Springer block->getOperations(), 4808faefe36SMatthias Springer Block::iterator(firstInlinedInst), 4818faefe36SMatthias Springer ++Block::iterator(lastInlinedInst)); 4828faefe36SMatthias Springer } 4838faefe36SMatthias Springer } 4848faefe36SMatthias Springer 4858faefe36SMatthias Springer private: 4868faefe36SMatthias Springer // The block that originally contained the operations. 4878faefe36SMatthias Springer Block *sourceBlock; 4888faefe36SMatthias Springer 4898faefe36SMatthias Springer // The first inlined operation. 4908faefe36SMatthias Springer Operation *firstInlinedInst; 4918faefe36SMatthias Springer 4928faefe36SMatthias Springer // The last inlined operation. 4938faefe36SMatthias Springer Operation *lastInlinedInst; 4948faefe36SMatthias Springer }; 4958faefe36SMatthias Springer 4968faefe36SMatthias Springer /// Moving of a block. This rewrite is immediately reflected in the IR. 4978faefe36SMatthias Springer class MoveBlockRewrite : public BlockRewrite { 4988faefe36SMatthias Springer public: 4998faefe36SMatthias Springer MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, 5008faefe36SMatthias Springer Region *region, Block *insertBeforeBlock) 5018faefe36SMatthias Springer : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region), 5028faefe36SMatthias Springer insertBeforeBlock(insertBeforeBlock) {} 5038faefe36SMatthias Springer 5048faefe36SMatthias Springer static bool classof(const IRRewrite *rewrite) { 5058faefe36SMatthias Springer return rewrite->getKind() == Kind::MoveBlock; 5068faefe36SMatthias Springer } 5078faefe36SMatthias Springer 50860a20bd6SMatthias Springer void commit(RewriterBase &rewriter) override { 50960a20bd6SMatthias Springer // The block was already moved. Just inform the listener. 51060a20bd6SMatthias Springer if (auto *listener = rewriter.getListener()) { 51160a20bd6SMatthias Springer // Note: `previousIt` cannot be passed because this is a delayed 51260a20bd6SMatthias Springer // notification and iterators into past IR state cannot be represented. 51360a20bd6SMatthias Springer listener->notifyBlockInserted(block, /*previous=*/region, 51460a20bd6SMatthias Springer /*previousIt=*/{}); 51560a20bd6SMatthias Springer } 51660a20bd6SMatthias Springer } 51760a20bd6SMatthias Springer 5188faefe36SMatthias Springer void rollback() override { 5198faefe36SMatthias Springer // Move the block back to its original position. 5208faefe36SMatthias Springer Region::iterator before = 5218faefe36SMatthias Springer insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end(); 5228faefe36SMatthias Springer region->getBlocks().splice(before, block->getParent()->getBlocks(), block); 5238faefe36SMatthias Springer } 5248faefe36SMatthias Springer 5258faefe36SMatthias Springer private: 5268faefe36SMatthias Springer // The region in which this block was previously contained. 5278faefe36SMatthias Springer Region *region; 5288faefe36SMatthias Springer 5298faefe36SMatthias Springer // The original successor of this block before it was moved. "nullptr" if 5308faefe36SMatthias Springer // this block was the only block in the region. 5318faefe36SMatthias Springer Block *insertBeforeBlock; 5328faefe36SMatthias Springer }; 5338faefe36SMatthias Springer 5348faefe36SMatthias Springer /// Block type conversion. This rewrite is partially reflected in the IR. 5358faefe36SMatthias Springer class BlockTypeConversionRewrite : public BlockRewrite { 5368faefe36SMatthias Springer public: 537bbd4af5dSMatthias Springer BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl, 538345ca6a6SMatthias Springer Block *origBlock, Block *newBlock) 539345ca6a6SMatthias Springer : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, origBlock), 540345ca6a6SMatthias Springer newBlock(newBlock) {} 5418faefe36SMatthias Springer 5428faefe36SMatthias Springer static bool classof(const IRRewrite *rewrite) { 5438faefe36SMatthias Springer return rewrite->getKind() == Kind::BlockTypeConversion; 5448faefe36SMatthias Springer } 5458faefe36SMatthias Springer 546345ca6a6SMatthias Springer Block *getOrigBlock() const { return block; } 547345ca6a6SMatthias Springer 548345ca6a6SMatthias Springer Block *getNewBlock() const { return newBlock; } 5492fc71e4eSMatthias Springer 55060a20bd6SMatthias Springer void commit(RewriterBase &rewriter) override; 55155558cd0SMatthias Springer 5528faefe36SMatthias Springer void rollback() override; 55355558cd0SMatthias Springer 55455558cd0SMatthias Springer private: 555345ca6a6SMatthias Springer /// The new block that was created as part of this signature conversion. 556345ca6a6SMatthias Springer Block *newBlock; 5578faefe36SMatthias Springer }; 5588f4cd2c7SMatthias Springer 559d68d2951SMatthias Springer /// Replacing a block argument. This rewrite is not immediately reflected in the 560d68d2951SMatthias Springer /// IR. An internal IR mapping is updated, but the actual replacement is delayed 561d68d2951SMatthias Springer /// until the rewrite is committed. 562d68d2951SMatthias Springer class ReplaceBlockArgRewrite : public BlockRewrite { 563d68d2951SMatthias Springer public: 564d68d2951SMatthias Springer ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl, 5653761b675SMatthias Springer Block *block, BlockArgument arg, 5663761b675SMatthias Springer const TypeConverter *converter) 5673761b675SMatthias Springer : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg), 5683761b675SMatthias Springer converter(converter) {} 569d68d2951SMatthias Springer 570d68d2951SMatthias Springer static bool classof(const IRRewrite *rewrite) { 571d68d2951SMatthias Springer return rewrite->getKind() == Kind::ReplaceBlockArg; 572d68d2951SMatthias Springer } 573d68d2951SMatthias Springer 57460a20bd6SMatthias Springer void commit(RewriterBase &rewriter) override; 575d68d2951SMatthias Springer 576d68d2951SMatthias Springer void rollback() override; 577d68d2951SMatthias Springer 578d68d2951SMatthias Springer private: 579d68d2951SMatthias Springer BlockArgument arg; 5803761b675SMatthias Springer 5813761b675SMatthias Springer /// The current type converter when the block argument was replaced. 5823761b675SMatthias Springer const TypeConverter *converter; 583d68d2951SMatthias Springer }; 584d68d2951SMatthias Springer 5858f4cd2c7SMatthias Springer /// An operation rewrite. 5868f4cd2c7SMatthias Springer class OperationRewrite : public IRRewrite { 5878f4cd2c7SMatthias Springer public: 5888f4cd2c7SMatthias Springer /// Return the operation that this rewrite operates on. 5898f4cd2c7SMatthias Springer Operation *getOperation() const { return op; } 5908f4cd2c7SMatthias Springer 5918f4cd2c7SMatthias Springer static bool classof(const IRRewrite *rewrite) { 5928f4cd2c7SMatthias Springer return rewrite->getKind() >= Kind::MoveOperation && 59359ff4d13SMatthias Springer rewrite->getKind() <= Kind::UnresolvedMaterialization; 5948f4cd2c7SMatthias Springer } 5958f4cd2c7SMatthias Springer 5968f4cd2c7SMatthias Springer protected: 5978f4cd2c7SMatthias Springer OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl, 5988f4cd2c7SMatthias Springer Operation *op) 5998f4cd2c7SMatthias Springer : IRRewrite(kind, rewriterImpl), op(op) {} 6008f4cd2c7SMatthias Springer 6018f4cd2c7SMatthias Springer // The operation that this rewrite operates on. 6028f4cd2c7SMatthias Springer Operation *op; 6038f4cd2c7SMatthias Springer }; 6048f4cd2c7SMatthias Springer 6058f4cd2c7SMatthias Springer /// Moving of an operation. This rewrite is immediately reflected in the IR. 6068f4cd2c7SMatthias Springer class MoveOperationRewrite : public OperationRewrite { 6078f4cd2c7SMatthias Springer public: 6088f4cd2c7SMatthias Springer MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, 6098f4cd2c7SMatthias Springer Operation *op, Block *block, Operation *insertBeforeOp) 6108f4cd2c7SMatthias Springer : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block), 6118f4cd2c7SMatthias Springer insertBeforeOp(insertBeforeOp) {} 6128f4cd2c7SMatthias Springer 6138f4cd2c7SMatthias Springer static bool classof(const IRRewrite *rewrite) { 6148f4cd2c7SMatthias Springer return rewrite->getKind() == Kind::MoveOperation; 6158f4cd2c7SMatthias Springer } 6168f4cd2c7SMatthias Springer 61760a20bd6SMatthias Springer void commit(RewriterBase &rewriter) override { 61860a20bd6SMatthias Springer // The operation was already moved. Just inform the listener. 61960a20bd6SMatthias Springer if (auto *listener = rewriter.getListener()) { 62060a20bd6SMatthias Springer // Note: `previousIt` cannot be passed because this is a delayed 62160a20bd6SMatthias Springer // notification and iterators into past IR state cannot be represented. 62260a20bd6SMatthias Springer listener->notifyOperationInserted( 62360a20bd6SMatthias Springer op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block, 62460a20bd6SMatthias Springer /*insertPt=*/{})); 62560a20bd6SMatthias Springer } 62660a20bd6SMatthias Springer } 62760a20bd6SMatthias Springer 6288f4cd2c7SMatthias Springer void rollback() override { 6298f4cd2c7SMatthias Springer // Move the operation back to its original position. 6308f4cd2c7SMatthias Springer Block::iterator before = 6318f4cd2c7SMatthias Springer insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end(); 6328f4cd2c7SMatthias Springer block->getOperations().splice(before, op->getBlock()->getOperations(), op); 6338f4cd2c7SMatthias Springer } 6348f4cd2c7SMatthias Springer 6358f4cd2c7SMatthias Springer private: 6368f4cd2c7SMatthias Springer // The block in which this operation was previously contained. 6378f4cd2c7SMatthias Springer Block *block; 6388f4cd2c7SMatthias Springer 63955558cd0SMatthias Springer // The original successor of this operation before it was moved. "nullptr" 64055558cd0SMatthias Springer // if this operation was the only operation in the region. 6418f4cd2c7SMatthias Springer Operation *insertBeforeOp; 6428f4cd2c7SMatthias Springer }; 643e214f004SMatthias Springer 644e214f004SMatthias Springer /// In-place modification of an op. This rewrite is immediately reflected in 645e214f004SMatthias Springer /// the IR. The previous state of the operation is stored in this object. 646e214f004SMatthias Springer class ModifyOperationRewrite : public OperationRewrite { 647e214f004SMatthias Springer public: 648e214f004SMatthias Springer ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, 649e214f004SMatthias Springer Operation *op) 650e214f004SMatthias Springer : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op), 6513f732c41SMatthias Springer name(op->getName()), loc(op->getLoc()), attrs(op->getAttrDictionary()), 652e214f004SMatthias Springer operands(op->operand_begin(), op->operand_end()), 6533a70335bSMatthias Springer successors(op->successor_begin(), op->successor_end()) { 6543a70335bSMatthias Springer if (OpaqueProperties prop = op->getPropertiesStorage()) { 6553a70335bSMatthias Springer // Make a copy of the properties. 6563a70335bSMatthias Springer propertiesStorage = operator new(op->getPropertiesStorageSize()); 6573a70335bSMatthias Springer OpaqueProperties propCopy(propertiesStorage); 6583f732c41SMatthias Springer name.initOpProperties(propCopy, /*init=*/prop); 6593a70335bSMatthias Springer } 6603a70335bSMatthias Springer } 661e214f004SMatthias Springer 662e214f004SMatthias Springer static bool classof(const IRRewrite *rewrite) { 663e214f004SMatthias Springer return rewrite->getKind() == Kind::ModifyOperation; 664e214f004SMatthias Springer } 665e214f004SMatthias Springer 6663a70335bSMatthias Springer ~ModifyOperationRewrite() override { 6673a70335bSMatthias Springer assert(!propertiesStorage && 6683a70335bSMatthias Springer "rewrite was neither committed nor rolled back"); 6693a70335bSMatthias Springer } 6703a70335bSMatthias Springer 67160a20bd6SMatthias Springer void commit(RewriterBase &rewriter) override { 67260a20bd6SMatthias Springer // Notify the listener that the operation was modified in-place. 67360a20bd6SMatthias Springer if (auto *listener = 67460a20bd6SMatthias Springer dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener())) 67560a20bd6SMatthias Springer listener->notifyOperationModified(op); 67660a20bd6SMatthias Springer 6773a70335bSMatthias Springer if (propertiesStorage) { 6783a70335bSMatthias Springer OpaqueProperties propCopy(propertiesStorage); 6793f732c41SMatthias Springer // Note: The operation may have been erased in the mean time, so 6803f732c41SMatthias Springer // OperationName must be stored in this object. 6813f732c41SMatthias Springer name.destroyOpProperties(propCopy); 6823a70335bSMatthias Springer operator delete(propertiesStorage); 6833a70335bSMatthias Springer propertiesStorage = nullptr; 6843a70335bSMatthias Springer } 6853a70335bSMatthias Springer } 6863a70335bSMatthias Springer 687e214f004SMatthias Springer void rollback() override { 688e214f004SMatthias Springer op->setLoc(loc); 689e214f004SMatthias Springer op->setAttrs(attrs); 690e214f004SMatthias Springer op->setOperands(operands); 691e214f004SMatthias Springer for (const auto &it : llvm::enumerate(successors)) 692e214f004SMatthias Springer op->setSuccessor(it.value(), it.index()); 6933a70335bSMatthias Springer if (propertiesStorage) { 6943a70335bSMatthias Springer OpaqueProperties propCopy(propertiesStorage); 6953a70335bSMatthias Springer op->copyProperties(propCopy); 6963f732c41SMatthias Springer name.destroyOpProperties(propCopy); 6973a70335bSMatthias Springer operator delete(propertiesStorage); 6983a70335bSMatthias Springer propertiesStorage = nullptr; 6993a70335bSMatthias Springer } 700e214f004SMatthias Springer } 701e214f004SMatthias Springer 702e214f004SMatthias Springer private: 7033f732c41SMatthias Springer OperationName name; 704e214f004SMatthias Springer LocationAttr loc; 705e214f004SMatthias Springer DictionaryAttr attrs; 706e214f004SMatthias Springer SmallVector<Value, 8> operands; 707e214f004SMatthias Springer SmallVector<Block *, 2> successors; 7083a70335bSMatthias Springer void *propertiesStorage = nullptr; 709e214f004SMatthias Springer }; 710d68d2951SMatthias Springer 711d68d2951SMatthias Springer /// Replacing an operation. Erasing an operation is treated as a special case 712d68d2951SMatthias Springer /// with "null" replacements. This rewrite is not immediately reflected in the 713d68d2951SMatthias Springer /// IR. An internal IR mapping is updated, but values are not replaced and the 714d68d2951SMatthias Springer /// original op is not erased until the rewrite is committed. 715d68d2951SMatthias Springer class ReplaceOperationRewrite : public OperationRewrite { 716d68d2951SMatthias Springer public: 717d68d2951SMatthias Springer ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, 7186093c26aSMatthias Springer Operation *op, const TypeConverter *converter) 719d68d2951SMatthias Springer : OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op), 7206093c26aSMatthias Springer converter(converter) {} 721d68d2951SMatthias Springer 722d68d2951SMatthias Springer static bool classof(const IRRewrite *rewrite) { 723d68d2951SMatthias Springer return rewrite->getKind() == Kind::ReplaceOperation; 724d68d2951SMatthias Springer } 725d68d2951SMatthias Springer 72660a20bd6SMatthias Springer void commit(RewriterBase &rewriter) override; 727d68d2951SMatthias Springer 728d68d2951SMatthias Springer void rollback() override; 729d68d2951SMatthias Springer 73060a20bd6SMatthias Springer void cleanup(RewriterBase &rewriter) override; 731d68d2951SMatthias Springer 732a622b21fSMatthias Springer private: 733d68d2951SMatthias Springer /// An optional type converter that can be used to materialize conversions 734d68d2951SMatthias Springer /// between the new and old values if necessary. 735d68d2951SMatthias Springer const TypeConverter *converter; 736d68d2951SMatthias Springer }; 7379ca70d72SMatthias Springer 7389ca70d72SMatthias Springer class CreateOperationRewrite : public OperationRewrite { 7399ca70d72SMatthias Springer public: 7409ca70d72SMatthias Springer CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, 7419ca70d72SMatthias Springer Operation *op) 7429ca70d72SMatthias Springer : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {} 7439ca70d72SMatthias Springer 7449ca70d72SMatthias Springer static bool classof(const IRRewrite *rewrite) { 7459ca70d72SMatthias Springer return rewrite->getKind() == Kind::CreateOperation; 7469ca70d72SMatthias Springer } 7479ca70d72SMatthias Springer 74860a20bd6SMatthias Springer void commit(RewriterBase &rewriter) override { 74960a20bd6SMatthias Springer // The operation was already created and inserted. Just inform the listener. 75060a20bd6SMatthias Springer if (auto *listener = rewriter.getListener()) 75160a20bd6SMatthias Springer listener->notifyOperationInserted(op, /*previous=*/{}); 75260a20bd6SMatthias Springer } 75360a20bd6SMatthias Springer 7549ca70d72SMatthias Springer void rollback() override; 7559ca70d72SMatthias Springer }; 75659ff4d13SMatthias Springer 75759ff4d13SMatthias Springer /// The type of materialization. 75859ff4d13SMatthias Springer enum MaterializationKind { 75959ff4d13SMatthias Springer /// This materialization materializes a conversion from an illegal type to a 76059ff4d13SMatthias Springer /// legal one. 761bbd4af5dSMatthias Springer Target, 762bbd4af5dSMatthias Springer 763bbd4af5dSMatthias Springer /// This materialization materializes a conversion from a legal type back to 764bbd4af5dSMatthias Springer /// an illegal one. 765bbd4af5dSMatthias Springer Source 76659ff4d13SMatthias Springer }; 76759ff4d13SMatthias Springer 76859ff4d13SMatthias Springer /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast" 76959ff4d13SMatthias Springer /// op. Unresolved materializations are erased at the end of the dialect 77059ff4d13SMatthias Springer /// conversion. 77159ff4d13SMatthias Springer class UnresolvedMaterializationRewrite : public OperationRewrite { 77259ff4d13SMatthias Springer public: 7730d906a42SMatthias Springer UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl, 7740d906a42SMatthias Springer UnrealizedConversionCastOp op, 7750d906a42SMatthias Springer const TypeConverter *converter, 776f2d500c6SMatthias Springer MaterializationKind kind, Type originalType, 7773ace6851SMatthias Springer ValueVector mappedValues); 77859ff4d13SMatthias Springer 77959ff4d13SMatthias Springer static bool classof(const IRRewrite *rewrite) { 78059ff4d13SMatthias Springer return rewrite->getKind() == Kind::UnresolvedMaterialization; 78159ff4d13SMatthias Springer } 78259ff4d13SMatthias Springer 7833815f478SMatthias Springer void rollback() override; 7843815f478SMatthias Springer 78559ff4d13SMatthias Springer UnrealizedConversionCastOp getOperation() const { 78659ff4d13SMatthias Springer return cast<UnrealizedConversionCastOp>(op); 78759ff4d13SMatthias Springer } 78859ff4d13SMatthias Springer 78959ff4d13SMatthias Springer /// Return the type converter of this materialization (which may be null). 79059ff4d13SMatthias Springer const TypeConverter *getConverter() const { 79159ff4d13SMatthias Springer return converterAndKind.getPointer(); 79259ff4d13SMatthias Springer } 79359ff4d13SMatthias Springer 79459ff4d13SMatthias Springer /// Return the kind of this materialization. 79559ff4d13SMatthias Springer MaterializationKind getMaterializationKind() const { 79659ff4d13SMatthias Springer return converterAndKind.getInt(); 79759ff4d13SMatthias Springer } 79859ff4d13SMatthias Springer 7990d906a42SMatthias Springer /// Return the original type of the SSA value. 8000d906a42SMatthias Springer Type getOriginalType() const { return originalType; } 8010d906a42SMatthias Springer 80259ff4d13SMatthias Springer private: 80359ff4d13SMatthias Springer /// The corresponding type converter to use when resolving this 80459ff4d13SMatthias Springer /// materialization, and the kind of this materialization. 805bbd4af5dSMatthias Springer llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind> 80659ff4d13SMatthias Springer converterAndKind; 8070d906a42SMatthias Springer 8080d906a42SMatthias Springer /// The original type of the SSA value. Only used for target 8090d906a42SMatthias Springer /// materializations. 8100d906a42SMatthias Springer Type originalType; 811f2d500c6SMatthias Springer 8123ace6851SMatthias Springer /// The values in the conversion value mapping that are being replaced by the 813f2d500c6SMatthias Springer /// results of this unresolved materialization. 8143ace6851SMatthias Springer ValueVector mappedValues; 81559ff4d13SMatthias Springer }; 8168faefe36SMatthias Springer } // namespace 8178faefe36SMatthias Springer 81879f41434SMatthias Springer #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 819e214f004SMatthias Springer /// Return "true" if there is an operation rewrite that matches the specified 820e214f004SMatthias Springer /// rewrite type and operation among the given rewrites. 821e214f004SMatthias Springer template <typename RewriteTy, typename R> 822e214f004SMatthias Springer static bool hasRewrite(R &&rewrites, Operation *op) { 823962534c4SAdrian Kuegel return any_of(std::forward<R>(rewrites), [&](auto &rewrite) { 824e214f004SMatthias Springer auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get()); 825e214f004SMatthias Springer return rewriteTy && rewriteTy->getOperation() == op; 826e214f004SMatthias Springer }); 827e214f004SMatthias Springer } 828e214f004SMatthias Springer 829345ca6a6SMatthias Springer /// Return "true" if there is a block rewrite that matches the specified 830345ca6a6SMatthias Springer /// rewrite type and block among the given rewrites. 831345ca6a6SMatthias Springer template <typename RewriteTy, typename R> 832345ca6a6SMatthias Springer static bool hasRewrite(R &&rewrites, Block *block) { 833345ca6a6SMatthias Springer return any_of(std::forward<R>(rewrites), [&](auto &rewrite) { 834345ca6a6SMatthias Springer auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get()); 835345ca6a6SMatthias Springer return rewriteTy && rewriteTy->getBlock() == block; 836345ca6a6SMatthias Springer }); 837345ca6a6SMatthias Springer } 83879f41434SMatthias Springer #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 839345ca6a6SMatthias Springer 8408faefe36SMatthias Springer //===----------------------------------------------------------------------===// 841b6eb26fdSRiver Riddle // ConversionPatternRewriterImpl 842b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 843b6eb26fdSRiver Riddle namespace mlir { 844b6eb26fdSRiver Riddle namespace detail { 845ea2d9383SMatthias Springer struct ConversionPatternRewriterImpl : public RewriterBase::Listener { 846a2821094SMatthias Springer explicit ConversionPatternRewriterImpl(MLIRContext *ctx, 847a2821094SMatthias Springer const ConversionConfig &config) 8483815f478SMatthias Springer : context(ctx), eraseRewriter(ctx), config(config) {} 849b6eb26fdSRiver Riddle 850b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 851b6eb26fdSRiver Riddle // State Management 852b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 853b6eb26fdSRiver Riddle 854b6eb26fdSRiver Riddle /// Return the current state of the rewriter. 855b6eb26fdSRiver Riddle RewriterState getCurrentState(); 856b6eb26fdSRiver Riddle 85759ff4d13SMatthias Springer /// Apply all requested operation rewrites. This method is invoked when the 85859ff4d13SMatthias Springer /// conversion process succeeds. 85959ff4d13SMatthias Springer void applyRewrites(); 86059ff4d13SMatthias Springer 861b6eb26fdSRiver Riddle /// Reset the state of the rewriter to a previously saved point. 862b6eb26fdSRiver Riddle void resetState(RewriterState state); 863b6eb26fdSRiver Riddle 8648faefe36SMatthias Springer /// Append a rewrite. Rewrites are committed upon success and rolled back upon 8658faefe36SMatthias Springer /// failure. 8668faefe36SMatthias Springer template <typename RewriteTy, typename... Args> 8678faefe36SMatthias Springer void appendRewrite(Args &&...args) { 8688faefe36SMatthias Springer rewrites.push_back( 8698faefe36SMatthias Springer std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...)); 8708faefe36SMatthias Springer } 871b6eb26fdSRiver Riddle 8728faefe36SMatthias Springer /// Undo the rewrites (motions, splits) one by one in reverse order until 8738faefe36SMatthias Springer /// "numRewritesToKeep" rewrites remains. 8748faefe36SMatthias Springer void undoRewrites(unsigned numRewritesToKeep = 0); 875b6eb26fdSRiver Riddle 876015192c6SRiver Riddle /// Remap the given values to those with potentially different types. Returns 877015192c6SRiver Riddle /// success if the values could be remapped, failure otherwise. `valueDiagTag` 878015192c6SRiver Riddle /// is the tag used when describing a value within a diagnostic, e.g. 879015192c6SRiver Riddle /// "operand". 8800de16fafSRamkumar Ramachandra LogicalResult remapValues(StringRef valueDiagTag, 8810de16fafSRamkumar Ramachandra std::optional<Location> inputLoc, 882015192c6SRiver Riddle PatternRewriter &rewriter, ValueRange values, 8833ace6851SMatthias Springer SmallVector<ValueVector> &remapped); 884b6eb26fdSRiver Riddle 8856008cd40SMatthias Springer /// Return "true" if the given operation is ignored, and does not need to be 886b6eb26fdSRiver Riddle /// converted. 887b6eb26fdSRiver Riddle bool isOpIgnored(Operation *op) const; 888b6eb26fdSRiver Riddle 8896008cd40SMatthias Springer /// Return "true" if the given operation was replaced or erased. 8906008cd40SMatthias Springer bool wasOpReplaced(Operation *op) const; 891b6eb26fdSRiver Riddle 892b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 893b6eb26fdSRiver Riddle // Type Conversion 894b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 895b6eb26fdSRiver Riddle 896b6eb26fdSRiver Riddle /// Convert the types of block arguments within the given region. 897b6eb26fdSRiver Riddle FailureOr<Block *> 898aaf5c818SMatthias Springer convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, 899aaf5c818SMatthias Springer const TypeConverter &converter, 900b6eb26fdSRiver Riddle TypeConverter::SignatureConversion *entryConversion); 901b6eb26fdSRiver Riddle 90255558cd0SMatthias Springer /// Apply the given signature conversion on the given block. The new block 90355558cd0SMatthias Springer /// containing the updated signature is returned. If no conversions were 90455558cd0SMatthias Springer /// necessary, e.g. if the block has no arguments, `block` is returned. 90555558cd0SMatthias Springer /// `converter` is used to generate any necessary cast operations that 90655558cd0SMatthias Springer /// translate between the origin argument types and those specified in the 90755558cd0SMatthias Springer /// signature conversion. 90855558cd0SMatthias Springer Block *applySignatureConversion( 909aaf5c818SMatthias Springer ConversionPatternRewriter &rewriter, Block *block, 910aaf5c818SMatthias Springer const TypeConverter *converter, 91155558cd0SMatthias Springer TypeConverter::SignatureConversion &signatureConversion); 9123b021fbdSKareemErgawy-TomTom 913b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 91459ff4d13SMatthias Springer // Materializations 91559ff4d13SMatthias Springer //===--------------------------------------------------------------------===// 9163815f478SMatthias Springer 9179df63b26SMatthias Springer /// Build an unresolved materialization operation given a range of output 9189df63b26SMatthias Springer /// types and a list of input operands. Returns the inputs if they their 9199df63b26SMatthias Springer /// types match the output types. 9209df63b26SMatthias Springer /// 9219df63b26SMatthias Springer /// If a cast op was built, it can optionally be returned with the `castOp` 9229df63b26SMatthias Springer /// output argument. 923f2d500c6SMatthias Springer /// 9243ace6851SMatthias Springer /// If `valuesToMap` is set to a non-null Value, then that value is mapped to 9259df63b26SMatthias Springer /// the results of the unresolved materialization in the conversion value 926f2d500c6SMatthias Springer /// mapping. 9279df63b26SMatthias Springer ValueRange buildUnresolvedMaterialization( 9289df63b26SMatthias Springer MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, 9293ace6851SMatthias Springer ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, 9309df63b26SMatthias Springer Type originalType, const TypeConverter *converter, 9319df63b26SMatthias Springer UnrealizedConversionCastOp *castOp = nullptr); 932aed43562SMatthias Springer 9333761b675SMatthias Springer /// Find a replacement value for the given SSA value in the conversion value 9343761b675SMatthias Springer /// mapping. The replacement value must have the same type as the given SSA 9353761b675SMatthias Springer /// value. If there is no replacement value with the correct type, find the 9363761b675SMatthias Springer /// latest replacement value (regardless of the type) and build a source 9373761b675SMatthias Springer /// materialization. 9383761b675SMatthias Springer Value findOrBuildReplacementValue(Value value, 9393761b675SMatthias Springer const TypeConverter *converter); 9403761b675SMatthias Springer 94159ff4d13SMatthias Springer //===--------------------------------------------------------------------===// 942b6eb26fdSRiver Riddle // Rewriter Notification Hooks 943b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 944b6eb26fdSRiver Riddle 945ea2d9383SMatthias Springer //// Notifies that an op was inserted. 946ea2d9383SMatthias Springer void notifyOperationInserted(Operation *op, 947ea2d9383SMatthias Springer OpBuilder::InsertPoint previous) override; 948ea2d9383SMatthias Springer 949ea2d9383SMatthias Springer /// Notifies that an op is about to be replaced with the given values. 9509df63b26SMatthias Springer void notifyOpReplaced(Operation *op, ArrayRef<ValueRange> newValues); 951b6eb26fdSRiver Riddle 952b6eb26fdSRiver Riddle /// Notifies that a block is about to be erased. 953b6eb26fdSRiver Riddle void notifyBlockIsBeingErased(Block *block); 954b6eb26fdSRiver Riddle 955ea2d9383SMatthias Springer /// Notifies that a block was inserted. 956ea2d9383SMatthias Springer void notifyBlockInserted(Block *block, Region *previous, 957ea2d9383SMatthias Springer Region::iterator previousIt) override; 958b6eb26fdSRiver Riddle 95942c31d83SMatthias Springer /// Notifies that a block is being inlined into another block. 96042c31d83SMatthias Springer void notifyBlockBeingInlined(Block *block, Block *srcBlock, 96142c31d83SMatthias Springer Block::iterator before); 962b6eb26fdSRiver Riddle 963b6eb26fdSRiver Riddle /// Notifies that a pattern match failed for the given reason. 964ea2d9383SMatthias Springer void 965ea2d9383SMatthias Springer notifyMatchFailure(Location loc, 966ea2d9383SMatthias Springer function_ref<void(Diagnostic &)> reasonCallback) override; 967b6eb26fdSRiver Riddle 968b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 96955558cd0SMatthias Springer // IR Erasure 97055558cd0SMatthias Springer //===--------------------------------------------------------------------===// 97155558cd0SMatthias Springer 97255558cd0SMatthias Springer /// A rewriter that keeps track of erased ops and blocks. It ensures that no 97355558cd0SMatthias Springer /// operation or block is erased multiple times. This rewriter assumes that 97455558cd0SMatthias Springer /// no new IR is created between calls to `eraseOp`/`eraseBlock`. 97555558cd0SMatthias Springer struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener { 97655558cd0SMatthias Springer public: 97755558cd0SMatthias Springer SingleEraseRewriter(MLIRContext *context) 97855558cd0SMatthias Springer : RewriterBase(context, /*listener=*/this) {} 97955558cd0SMatthias Springer 98055558cd0SMatthias Springer /// Erase the given op (unless it was already erased). 98155558cd0SMatthias Springer void eraseOp(Operation *op) override { 9823815f478SMatthias Springer if (wasErased(op)) 98355558cd0SMatthias Springer return; 98455558cd0SMatthias Springer op->dropAllUses(); 98555558cd0SMatthias Springer RewriterBase::eraseOp(op); 98655558cd0SMatthias Springer } 98755558cd0SMatthias Springer 98855558cd0SMatthias Springer /// Erase the given block (unless it was already erased). 98955558cd0SMatthias Springer void eraseBlock(Block *block) override { 9903815f478SMatthias Springer if (wasErased(block)) 99155558cd0SMatthias Springer return; 992d68d2951SMatthias Springer assert(block->empty() && "expected empty block"); 99355558cd0SMatthias Springer block->dropAllDefinedValueUses(); 99455558cd0SMatthias Springer RewriterBase::eraseBlock(block); 99555558cd0SMatthias Springer } 99655558cd0SMatthias Springer 9973815f478SMatthias Springer bool wasErased(void *ptr) const { return erased.contains(ptr); } 9983815f478SMatthias Springer 99955558cd0SMatthias Springer void notifyOperationErased(Operation *op) override { erased.insert(op); } 100060a20bd6SMatthias Springer 100155558cd0SMatthias Springer void notifyBlockErased(Block *block) override { erased.insert(block); } 100255558cd0SMatthias Springer 10033815f478SMatthias Springer private: 100455558cd0SMatthias Springer /// Pointers to all erased operations and blocks. 1005310a2788SMatthias Springer DenseSet<void *> erased; 100655558cd0SMatthias Springer }; 100755558cd0SMatthias Springer 100855558cd0SMatthias Springer //===--------------------------------------------------------------------===// 1009b6eb26fdSRiver Riddle // State 1010b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 1011b6eb26fdSRiver Riddle 101260a20bd6SMatthias Springer /// MLIR context. 101360a20bd6SMatthias Springer MLIRContext *context; 101455558cd0SMatthias Springer 10153815f478SMatthias Springer /// A rewriter that keeps track of ops/block that were already erased and 10163815f478SMatthias Springer /// skips duplicate op/block erasures. This rewriter is used during the 10173815f478SMatthias Springer /// "cleanup" phase. 10183815f478SMatthias Springer SingleEraseRewriter eraseRewriter; 10193815f478SMatthias Springer 1020b6eb26fdSRiver Riddle // Mapping between replaced values that differ in type. This happens when 1021b6eb26fdSRiver Riddle // replacing a value with one of a different type. 1022b6eb26fdSRiver Riddle ConversionValueMapping mapping; 1023b6eb26fdSRiver Riddle 1024b6eb26fdSRiver Riddle /// Ordered list of block operations (creations, splits, motions). 10258faefe36SMatthias Springer SmallVector<std::unique_ptr<IRRewrite>> rewrites; 1026b6eb26fdSRiver Riddle 10276008cd40SMatthias Springer /// A set of operations that should no longer be considered for legalization. 10286008cd40SMatthias Springer /// E.g., ops that are recursively legal. Ops that were replaced/erased are 10296008cd40SMatthias Springer /// tracked separately. 10304efb7754SRiver Riddle SetVector<Operation *> ignoredOps; 1031b6eb26fdSRiver Riddle 10326008cd40SMatthias Springer /// A set of operations that were replaced/erased. Such ops are not erased 10336008cd40SMatthias Springer /// immediately but only when the dialect conversion succeeds. In the mean 10346008cd40SMatthias Springer /// time, they should no longer be considered for legalization and any attempt 10356008cd40SMatthias Springer /// to modify/access them is invalid rewriter API usage. 10367b66b5d6SMatthias Springer SetVector<Operation *> replacedOps; 10377b66b5d6SMatthias Springer 1038d588e49aSMatthias Springer /// A mapping of all unresolved materializations (UnrealizedConversionCastOp) 1039d588e49aSMatthias Springer /// to the corresponding rewrite objects. 1040d588e49aSMatthias Springer DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *> 1041d588e49aSMatthias Springer unresolvedMaterializations; 10426093c26aSMatthias Springer 104301b55f16SRiver Riddle /// The current type converter, or nullptr if no type converter is currently 104401b55f16SRiver Riddle /// active. 1045ce254598SMatthias Springer const TypeConverter *currentTypeConverter = nullptr; 1046b6eb26fdSRiver Riddle 1047b49f155cSMatthias Springer /// A mapping of regions to type converters that should be used when 1048b49f155cSMatthias Springer /// converting the arguments of blocks within that region. 1049b49f155cSMatthias Springer DenseMap<Region *, const TypeConverter *> regionToConverter; 1050b49f155cSMatthias Springer 1051a2821094SMatthias Springer /// Dialect conversion configuration. 1052a2821094SMatthias Springer const ConversionConfig &config; 1053d68d2951SMatthias Springer 1054b6eb26fdSRiver Riddle #ifndef NDEBUG 1055b6eb26fdSRiver Riddle /// A set of operations that have pending updates. This tracking isn't 1056b6eb26fdSRiver Riddle /// strictly necessary, and is thus only active during debug builds for extra 1057b6eb26fdSRiver Riddle /// verification. 1058b6eb26fdSRiver Riddle SmallPtrSet<Operation *, 1> pendingRootUpdates; 1059b6eb26fdSRiver Riddle 1060b6eb26fdSRiver Riddle /// A logger used to emit diagnostics during the conversion process. 1061b6eb26fdSRiver Riddle llvm::ScopedPrinter logger{llvm::dbgs()}; 1062b6eb26fdSRiver Riddle #endif 1063b6eb26fdSRiver Riddle }; 1064be0a7e9fSMehdi Amini } // namespace detail 1065be0a7e9fSMehdi Amini } // namespace mlir 1066b6eb26fdSRiver Riddle 1067a2821094SMatthias Springer const ConversionConfig &IRRewrite::getConfig() const { 1068a2821094SMatthias Springer return rewriterImpl.config; 1069a2821094SMatthias Springer } 1070a2821094SMatthias Springer 107160a20bd6SMatthias Springer void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) { 107260a20bd6SMatthias Springer // Inform the listener about all IR modifications that have already taken 107360a20bd6SMatthias Springer // place: References to the original block have been replaced with the new 107460a20bd6SMatthias Springer // block. 10752a306845SMatthias Springer if (auto *listener = 10762a306845SMatthias Springer dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener())) 1077345ca6a6SMatthias Springer for (Operation *op : getNewBlock()->getUsers()) 107860a20bd6SMatthias Springer listener->notifyOperationModified(op); 10799606655fSMatthias Springer } 108055558cd0SMatthias Springer 108155558cd0SMatthias Springer void BlockTypeConversionRewrite::rollback() { 1082345ca6a6SMatthias Springer getNewBlock()->replaceAllUsesWith(getOrigBlock()); 108355558cd0SMatthias Springer } 108455558cd0SMatthias Springer 108560a20bd6SMatthias Springer void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { 10863761b675SMatthias Springer Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter); 1087d68d2951SMatthias Springer if (!repl) 1088d68d2951SMatthias Springer return; 1089d68d2951SMatthias Springer 1090d68d2951SMatthias Springer if (isa<BlockArgument>(repl)) { 109160a20bd6SMatthias Springer rewriter.replaceAllUsesWith(arg, repl); 1092d68d2951SMatthias Springer return; 1093d68d2951SMatthias Springer } 1094d68d2951SMatthias Springer 1095d68d2951SMatthias Springer // If the replacement value is an operation, we check to make sure that we 1096d68d2951SMatthias Springer // don't replace uses that are within the parent operation of the 1097d68d2951SMatthias Springer // replacement value. 1098d68d2951SMatthias Springer Operation *replOp = cast<OpResult>(repl).getOwner(); 1099d68d2951SMatthias Springer Block *replBlock = replOp->getBlock(); 110060a20bd6SMatthias Springer rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) { 1101d68d2951SMatthias Springer Operation *user = operand.getOwner(); 1102d68d2951SMatthias Springer return user->getBlock() != replBlock || replOp->isBeforeInBlock(user); 1103d68d2951SMatthias Springer }); 1104d68d2951SMatthias Springer } 1105d68d2951SMatthias Springer 11063ace6851SMatthias Springer void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); } 1107d68d2951SMatthias Springer 110860a20bd6SMatthias Springer void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { 11092a306845SMatthias Springer auto *listener = 11102a306845SMatthias Springer dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()); 111160a20bd6SMatthias Springer 111260a20bd6SMatthias Springer // Compute replacement values. 111360a20bd6SMatthias Springer SmallVector<Value> replacements = 111460a20bd6SMatthias Springer llvm::map_to_vector(op->getResults(), [&](OpResult result) { 11153761b675SMatthias Springer return rewriterImpl.findOrBuildReplacementValue(result, converter); 111660a20bd6SMatthias Springer }); 111760a20bd6SMatthias Springer 111860a20bd6SMatthias Springer // Notify the listener that the operation is about to be replaced. 111960a20bd6SMatthias Springer if (listener) 112060a20bd6SMatthias Springer listener->notifyOperationReplaced(op, replacements); 112160a20bd6SMatthias Springer 112260a20bd6SMatthias Springer // Replace all uses with the new values. 112360a20bd6SMatthias Springer for (auto [result, newValue] : 112460a20bd6SMatthias Springer llvm::zip_equal(op->getResults(), replacements)) 112560a20bd6SMatthias Springer if (newValue) 112660a20bd6SMatthias Springer rewriter.replaceAllUsesWith(result, newValue); 112760a20bd6SMatthias Springer 112860a20bd6SMatthias Springer // The original op will be erased, so remove it from the set of unlegalized 112960a20bd6SMatthias Springer // ops. 1130a2821094SMatthias Springer if (getConfig().unlegalizedOps) 1131a2821094SMatthias Springer getConfig().unlegalizedOps->erase(op); 113260a20bd6SMatthias Springer 113360a20bd6SMatthias Springer // Notify the listener that the operation (and its nested operations) was 113460a20bd6SMatthias Springer // erased. 113560a20bd6SMatthias Springer if (listener) { 113660a20bd6SMatthias Springer op->walk<WalkOrder::PostOrder>( 113760a20bd6SMatthias Springer [&](Operation *op) { listener->notifyOperationErased(op); }); 113860a20bd6SMatthias Springer } 113960a20bd6SMatthias Springer 1140d68d2951SMatthias Springer // Do not erase the operation yet. It may still be referenced in `mapping`. 114160a20bd6SMatthias Springer // Just unlink it for now and erase it during cleanup. 1142d68d2951SMatthias Springer op->getBlock()->getOperations().remove(op); 1143d68d2951SMatthias Springer } 1144d68d2951SMatthias Springer 1145d68d2951SMatthias Springer void ReplaceOperationRewrite::rollback() { 1146d68d2951SMatthias Springer for (auto result : op->getResults()) 11473ace6851SMatthias Springer rewriterImpl.mapping.erase({result}); 1148d68d2951SMatthias Springer } 1149d68d2951SMatthias Springer 115060a20bd6SMatthias Springer void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) { 115160a20bd6SMatthias Springer rewriter.eraseOp(op); 115260a20bd6SMatthias Springer } 1153d68d2951SMatthias Springer 11549ca70d72SMatthias Springer void CreateOperationRewrite::rollback() { 11559ca70d72SMatthias Springer for (Region ®ion : op->getRegions()) { 11569ca70d72SMatthias Springer while (!region.getBlocks().empty()) 11579ca70d72SMatthias Springer region.getBlocks().remove(region.getBlocks().begin()); 11589ca70d72SMatthias Springer } 11599ca70d72SMatthias Springer op->dropAllUses(); 1160310a2788SMatthias Springer op->erase(); 11619ca70d72SMatthias Springer } 11629ca70d72SMatthias Springer 1163d588e49aSMatthias Springer UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite( 1164d588e49aSMatthias Springer ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op, 1165f2d500c6SMatthias Springer const TypeConverter *converter, MaterializationKind kind, Type originalType, 11663ace6851SMatthias Springer ValueVector mappedValues) 1167d588e49aSMatthias Springer : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op), 1168f2d500c6SMatthias Springer converterAndKind(converter, kind), originalType(originalType), 11693ace6851SMatthias Springer mappedValues(std::move(mappedValues)) { 1170d5746d73SFrank Schlimbach assert((!originalType || kind == MaterializationKind::Target) && 11710d906a42SMatthias Springer "original type is valid only for target materializations"); 1172d588e49aSMatthias Springer rewriterImpl.unresolvedMaterializations[op] = this; 1173d588e49aSMatthias Springer } 1174d588e49aSMatthias Springer 117559ff4d13SMatthias Springer void UnresolvedMaterializationRewrite::rollback() { 11763ace6851SMatthias Springer if (!mappedValues.empty()) 11773ace6851SMatthias Springer rewriterImpl.mapping.erase(mappedValues); 1178d588e49aSMatthias Springer rewriterImpl.unresolvedMaterializations.erase(getOperation()); 1179310a2788SMatthias Springer op->erase(); 1180b6eb26fdSRiver Riddle } 1181b6eb26fdSRiver Riddle 1182b6eb26fdSRiver Riddle void ConversionPatternRewriterImpl::applyRewrites() { 1183d68d2951SMatthias Springer // Commit all rewrites. 118460a20bd6SMatthias Springer IRRewriter rewriter(context, config.listener); 11853761b675SMatthias Springer // Note: New rewrites may be added during the "commit" phase and the 11863761b675SMatthias Springer // `rewrites` vector may reallocate. 11873761b675SMatthias Springer for (size_t i = 0; i < rewrites.size(); ++i) 11883761b675SMatthias Springer rewrites[i]->commit(rewriter); 118960a20bd6SMatthias Springer 119060a20bd6SMatthias Springer // Clean up all rewrites. 1191d68d2951SMatthias Springer for (auto &rewrite : rewrites) 119260a20bd6SMatthias Springer rewrite->cleanup(eraseRewriter); 1193b6eb26fdSRiver Riddle } 1194b6eb26fdSRiver Riddle 1195b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 1196b6eb26fdSRiver Riddle // State Management 1197b6eb26fdSRiver Riddle 1198b6eb26fdSRiver Riddle RewriterState ConversionPatternRewriterImpl::getCurrentState() { 1199310a2788SMatthias Springer return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size()); 1200b6eb26fdSRiver Riddle } 1201b6eb26fdSRiver Riddle 1202b6eb26fdSRiver Riddle void ConversionPatternRewriterImpl::resetState(RewriterState state) { 12038faefe36SMatthias Springer // Undo any rewrites. 12048faefe36SMatthias Springer undoRewrites(state.numRewrites); 1205b6eb26fdSRiver Riddle 1206b6eb26fdSRiver Riddle // Pop all of the recorded ignored operations that are no longer valid. 1207b6eb26fdSRiver Riddle while (ignoredOps.size() != state.numIgnoredOperations) 1208b6eb26fdSRiver Riddle ignoredOps.pop_back(); 1209b6eb26fdSRiver Riddle 12107b66b5d6SMatthias Springer while (replacedOps.size() != state.numReplacedOps) 12117b66b5d6SMatthias Springer replacedOps.pop_back(); 1212b6eb26fdSRiver Riddle } 1213b6eb26fdSRiver Riddle 12148faefe36SMatthias Springer void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) { 12158faefe36SMatthias Springer for (auto &rewrite : 12168faefe36SMatthias Springer llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) 12178faefe36SMatthias Springer rewrite->rollback(); 12188faefe36SMatthias Springer rewrites.resize(numRewritesToKeep); 1219b6eb26fdSRiver Riddle } 1220b6eb26fdSRiver Riddle 1221b6eb26fdSRiver Riddle LogicalResult ConversionPatternRewriterImpl::remapValues( 12220de16fafSRamkumar Ramachandra StringRef valueDiagTag, std::optional<Location> inputLoc, 1223015192c6SRiver Riddle PatternRewriter &rewriter, ValueRange values, 12243ace6851SMatthias Springer SmallVector<ValueVector> &remapped) { 1225015192c6SRiver Riddle remapped.reserve(llvm::size(values)); 1226b6eb26fdSRiver Riddle 1227e4853be2SMehdi Amini for (const auto &it : llvm::enumerate(values)) { 1228b6eb26fdSRiver Riddle Value operand = it.value(); 1229b6eb26fdSRiver Riddle Type origType = operand.getType(); 1230015192c6SRiver Riddle Location operandLoc = inputLoc ? *inputLoc : operand.getLoc(); 1231023f7c93SMatthias Springer 1232023f7c93SMatthias Springer if (!currentTypeConverter) { 1233023f7c93SMatthias Springer // The current pattern does not have a type converter. I.e., it does not 1234023f7c93SMatthias Springer // distinguish between legal and illegal types. For each operand, simply 12353ace6851SMatthias Springer // pass through the most recently mapped values. 12363ace6851SMatthias Springer remapped.push_back(mapping.lookupOrDefault(operand)); 1237023f7c93SMatthias Springer continue; 1238023f7c93SMatthias Springer } 1239023f7c93SMatthias Springer 1240023f7c93SMatthias Springer // If there is no legal conversion, fail to match this pattern. 1241023f7c93SMatthias Springer SmallVector<Type, 1> legalTypes; 1242023f7c93SMatthias Springer if (failed(currentTypeConverter->convertType(origType, legalTypes))) { 12439a028afdSMatthias Springer notifyMatchFailure(operandLoc, [=](Diagnostic &diag) { 1244015192c6SRiver Riddle diag << "unable to convert type for " << valueDiagTag << " #" 1245015192c6SRiver Riddle << it.index() << ", type was " << origType; 1246b6eb26fdSRiver Riddle }); 12479a028afdSMatthias Springer return failure(); 1248b6eb26fdSRiver Riddle } 12499df63b26SMatthias Springer // If a type is converted to 0 types, there is nothing to do. 12509df63b26SMatthias Springer if (legalTypes.empty()) { 12519df63b26SMatthias Springer remapped.push_back({}); 12529df63b26SMatthias Springer continue; 12539df63b26SMatthias Springer } 12549df63b26SMatthias Springer 12553ace6851SMatthias Springer ValueVector repl = mapping.lookupOrDefault(operand, legalTypes); 1256faa30be1SMatthias Springer if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) { 12573ace6851SMatthias Springer // Mapped values have the correct type or there is an existing 12583ace6851SMatthias Springer // materialization. Or the operand is not mapped at all and has the 12593ace6851SMatthias Springer // correct type. 12603ace6851SMatthias Springer remapped.push_back(std::move(repl)); 12619df63b26SMatthias Springer continue; 12629df63b26SMatthias Springer } 12639df63b26SMatthias Springer 12643ace6851SMatthias Springer // Create a materialization for the most recently mapped values. 12653ace6851SMatthias Springer repl = mapping.lookupOrDefault(operand); 12663ace6851SMatthias Springer ValueRange castValues = buildUnresolvedMaterialization( 12679df63b26SMatthias Springer MaterializationKind::Target, computeInsertPoint(repl), operandLoc, 12683ace6851SMatthias Springer /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes, 12693ace6851SMatthias Springer /*originalType=*/origType, currentTypeConverter); 12703ace6851SMatthias Springer remapped.push_back(castValues); 1271b6eb26fdSRiver Riddle } 1272b6eb26fdSRiver Riddle return success(); 1273b6eb26fdSRiver Riddle } 1274b6eb26fdSRiver Riddle 1275b6eb26fdSRiver Riddle bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { 12766008cd40SMatthias Springer // Check to see if this operation is ignored or was replaced. 12776008cd40SMatthias Springer return replacedOps.count(op) || ignoredOps.count(op); 1278b6eb26fdSRiver Riddle } 1279b6eb26fdSRiver Riddle 12806008cd40SMatthias Springer bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const { 12816008cd40SMatthias Springer // Check to see if this operation was replaced. 12826008cd40SMatthias Springer return replacedOps.count(op); 1283b6eb26fdSRiver Riddle } 1284b6eb26fdSRiver Riddle 1285b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 1286b6eb26fdSRiver Riddle // Type Conversion 1287b6eb26fdSRiver Riddle 1288b6eb26fdSRiver Riddle FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes( 1289aaf5c818SMatthias Springer ConversionPatternRewriter &rewriter, Region *region, 1290aaf5c818SMatthias Springer const TypeConverter &converter, 1291b6eb26fdSRiver Riddle TypeConverter::SignatureConversion *entryConversion) { 1292b49f155cSMatthias Springer regionToConverter[region] = &converter; 1293b6eb26fdSRiver Riddle if (region->empty()) 1294b6eb26fdSRiver Riddle return nullptr; 1295b6eb26fdSRiver Riddle 129652050f3fSMatthias Springer // Convert the arguments of each non-entry block within the region. 1297aa6eb2afSKareemErgawy-TomTom for (Block &block : 1298aa6eb2afSKareemErgawy-TomTom llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) { 129952050f3fSMatthias Springer // Compute the signature for the block with the provided converter. 130052050f3fSMatthias Springer std::optional<TypeConverter::SignatureConversion> conversion = 130152050f3fSMatthias Springer converter.convertBlockSignature(&block); 130252050f3fSMatthias Springer if (!conversion) 1303b6eb26fdSRiver Riddle return failure(); 130452050f3fSMatthias Springer // Convert the block with the computed signature. 130552050f3fSMatthias Springer applySignatureConversion(rewriter, &block, &converter, *conversion); 1306aa6eb2afSKareemErgawy-TomTom } 130752050f3fSMatthias Springer 130852050f3fSMatthias Springer // Convert the entry block. If an entry signature conversion was provided, 130952050f3fSMatthias Springer // use that one. Otherwise, compute the signature with the type converter. 131052050f3fSMatthias Springer if (entryConversion) 131152050f3fSMatthias Springer return applySignatureConversion(rewriter, ®ion->front(), &converter, 131252050f3fSMatthias Springer *entryConversion); 131352050f3fSMatthias Springer std::optional<TypeConverter::SignatureConversion> conversion = 131452050f3fSMatthias Springer converter.convertBlockSignature(®ion->front()); 131552050f3fSMatthias Springer if (!conversion) 131652050f3fSMatthias Springer return failure(); 131752050f3fSMatthias Springer return applySignatureConversion(rewriter, ®ion->front(), &converter, 131852050f3fSMatthias Springer *conversion); 1319b6eb26fdSRiver Riddle } 1320b6eb26fdSRiver Riddle 132155558cd0SMatthias Springer Block *ConversionPatternRewriterImpl::applySignatureConversion( 1322aaf5c818SMatthias Springer ConversionPatternRewriter &rewriter, Block *block, 1323aaf5c818SMatthias Springer const TypeConverter *converter, 132455558cd0SMatthias Springer TypeConverter::SignatureConversion &signatureConversion) { 132579f41434SMatthias Springer #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 1326345ca6a6SMatthias Springer // A block cannot be converted multiple times. 132779f41434SMatthias Springer if (hasRewrite<BlockTypeConversionRewrite>(rewrites, block)) 132879f41434SMatthias Springer llvm::report_fatal_error("block was already converted"); 132979f41434SMatthias Springer #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 133079f41434SMatthias Springer 1331ddaf040eSMatthias Springer OpBuilder::InsertionGuard g(rewriter); 13327bb08ee8SMatthias Springer 133355558cd0SMatthias Springer // If no arguments are being changed or added, there is nothing to do. 133455558cd0SMatthias Springer unsigned origArgCount = block->getNumArguments(); 133555558cd0SMatthias Springer auto convertedTypes = signatureConversion.getConvertedTypes(); 133655558cd0SMatthias Springer if (llvm::equal(block->getArgumentTypes(), convertedTypes)) 133755558cd0SMatthias Springer return block; 133855558cd0SMatthias Springer 1339ddaf040eSMatthias Springer // Compute the locations of all block arguments in the new block. 134055558cd0SMatthias Springer SmallVector<Location> newLocs(convertedTypes.size(), 1341ddaf040eSMatthias Springer rewriter.getUnknownLoc()); 134255558cd0SMatthias Springer for (unsigned i = 0; i < origArgCount; ++i) { 134355558cd0SMatthias Springer auto inputMap = signatureConversion.getInputMapping(i); 134455558cd0SMatthias Springer if (!inputMap || inputMap->replacementValue) 134555558cd0SMatthias Springer continue; 134655558cd0SMatthias Springer Location origLoc = block->getArgument(i).getLoc(); 134755558cd0SMatthias Springer for (unsigned j = 0; j < inputMap->size; ++j) 134855558cd0SMatthias Springer newLocs[inputMap->inputNo + j] = origLoc; 134955558cd0SMatthias Springer } 135055558cd0SMatthias Springer 1351ddaf040eSMatthias Springer // Insert a new block with the converted block argument types and move all ops 1352ddaf040eSMatthias Springer // from the old block to the new block. 1353ddaf040eSMatthias Springer Block *newBlock = 1354ddaf040eSMatthias Springer rewriter.createBlock(block->getParent(), std::next(block->getIterator()), 1355ddaf040eSMatthias Springer convertedTypes, newLocs); 135660a20bd6SMatthias Springer 135760a20bd6SMatthias Springer // If a listener is attached to the dialect conversion, ops cannot be moved 135860a20bd6SMatthias Springer // to the destination block in bulk ("fast path"). This is because at the time 135960a20bd6SMatthias Springer // the notifications are sent, it is unknown which ops were moved. Instead, 136060a20bd6SMatthias Springer // ops should be moved one-by-one ("slow path"), so that a separate 136160a20bd6SMatthias Springer // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is 136260a20bd6SMatthias Springer // a bit more efficient, so we try to do that when possible. 136360a20bd6SMatthias Springer bool fastPath = !config.listener; 136460a20bd6SMatthias Springer if (fastPath) { 1365ddaf040eSMatthias Springer appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end()); 1366ddaf040eSMatthias Springer newBlock->getOperations().splice(newBlock->end(), block->getOperations()); 136760a20bd6SMatthias Springer } else { 136860a20bd6SMatthias Springer while (!block->empty()) 136960a20bd6SMatthias Springer rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end()); 137060a20bd6SMatthias Springer } 1371ddaf040eSMatthias Springer 1372ddaf040eSMatthias Springer // Replace all uses of the old block with the new block. 1373ddaf040eSMatthias Springer block->replaceAllUsesWith(newBlock); 137455558cd0SMatthias Springer 13754d46b460SBenjamin Kramer for (unsigned i = 0; i != origArgCount; ++i) { 13764d46b460SBenjamin Kramer BlockArgument origArg = block->getArgument(i); 1377bbd4af5dSMatthias Springer Type origArgType = origArg.getType(); 13784d46b460SBenjamin Kramer 1379bbd4af5dSMatthias Springer std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap = 1380bbd4af5dSMatthias Springer signatureConversion.getInputMapping(i); 1381bbd4af5dSMatthias Springer if (!inputMap) { 1382bbd4af5dSMatthias Springer // This block argument was dropped and no replacement value was provided. 1383bbd4af5dSMatthias Springer // Materialize a replacement value "out of thin air". 1384f2d500c6SMatthias Springer buildUnresolvedMaterialization( 13852fc71e4eSMatthias Springer MaterializationKind::Source, 13862fc71e4eSMatthias Springer OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), 13873ace6851SMatthias Springer /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(), 13880d906a42SMatthias Springer /*outputType=*/origArgType, /*originalType=*/Type(), converter); 13893761b675SMatthias Springer appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter); 13904d46b460SBenjamin Kramer continue; 13914d46b460SBenjamin Kramer } 13924d46b460SBenjamin Kramer 1393bbd4af5dSMatthias Springer if (Value repl = inputMap->replacementValue) { 1394bbd4af5dSMatthias Springer // This block argument was dropped and a replacement value was provided. 1395bbd4af5dSMatthias Springer assert(inputMap->size == 0 && 1396bbd4af5dSMatthias Springer "invalid to provide a replacement value when the argument isn't " 1397bbd4af5dSMatthias Springer "dropped"); 1398bbd4af5dSMatthias Springer mapping.map(origArg, repl); 13993761b675SMatthias Springer appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter); 1400bbd4af5dSMatthias Springer continue; 1401bbd4af5dSMatthias Springer } 1402bbd4af5dSMatthias Springer 14033ace6851SMatthias Springer // This is a 1->1+ mapping. 14044d46b460SBenjamin Kramer auto replArgs = 14054d46b460SBenjamin Kramer newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); 14063ace6851SMatthias Springer ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs); 14073ace6851SMatthias Springer mapping.map(origArg, std::move(replArgVals)); 14083761b675SMatthias Springer appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter); 14092aa96fcfSMatthias Springer } 1410f1e0657dSMatthias Springer 1411345ca6a6SMatthias Springer appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock); 1412aaf5c818SMatthias Springer 1413aaf5c818SMatthias Springer // Erase the old block. (It is just unlinked for now and will be erased during 1414aaf5c818SMatthias Springer // cleanup.) 1415aaf5c818SMatthias Springer rewriter.eraseBlock(block); 1416aaf5c818SMatthias Springer 141755558cd0SMatthias Springer return newBlock; 141855558cd0SMatthias Springer } 141955558cd0SMatthias Springer 1420b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 142159ff4d13SMatthias Springer // Materializations 142259ff4d13SMatthias Springer //===----------------------------------------------------------------------===// 142359ff4d13SMatthias Springer 142459ff4d13SMatthias Springer /// Build an unresolved materialization operation given an output type and set 142559ff4d13SMatthias Springer /// of input operands. 14269df63b26SMatthias Springer ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( 14272fc71e4eSMatthias Springer MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, 14283ace6851SMatthias Springer ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, 14299df63b26SMatthias Springer Type originalType, const TypeConverter *converter, 14309df63b26SMatthias Springer UnrealizedConversionCastOp *castOp) { 1431d5746d73SFrank Schlimbach assert((!originalType || kind == MaterializationKind::Target) && 14320d906a42SMatthias Springer "original type is valid only for target materializations"); 1433486f83faSMatthias Springer assert(TypeRange(inputs) != outputTypes && 1434486f83faSMatthias Springer "materialization is not necessary"); 143559ff4d13SMatthias Springer 143659ff4d13SMatthias Springer // Create an unresolved materialization. We use a new OpBuilder to avoid 143759ff4d13SMatthias Springer // tracking the materialization like we do for other operations. 14389df63b26SMatthias Springer OpBuilder builder(outputTypes.front().getContext()); 14392fc71e4eSMatthias Springer builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); 144059ff4d13SMatthias Springer auto convertOp = 14419df63b26SMatthias Springer builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs); 14423ace6851SMatthias Springer if (!valuesToMap.empty()) 14433ace6851SMatthias Springer mapping.map(valuesToMap, convertOp.getResults()); 14449df63b26SMatthias Springer if (castOp) 14459df63b26SMatthias Springer *castOp = convertOp; 14463ace6851SMatthias Springer appendRewrite<UnresolvedMaterializationRewrite>( 14473ace6851SMatthias Springer convertOp, converter, kind, originalType, std::move(valuesToMap)); 14489df63b26SMatthias Springer return convertOp.getResults(); 144959ff4d13SMatthias Springer } 145059ff4d13SMatthias Springer 14513761b675SMatthias Springer Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( 14523761b675SMatthias Springer Value value, const TypeConverter *converter) { 1453486f83faSMatthias Springer // Try to find a replacement value with the same type in the conversion value 1454486f83faSMatthias Springer // mapping. This includes cached materializations. We try to reuse those 1455486f83faSMatthias Springer // instead of generating duplicate IR. 14563ace6851SMatthias Springer ValueVector repl = mapping.lookupOrNull(value, value.getType()); 14573ace6851SMatthias Springer if (!repl.empty()) 14583ace6851SMatthias Springer return repl.front(); 14593761b675SMatthias Springer 14603761b675SMatthias Springer // Check if the value is dead. No replacement value is needed in that case. 14613761b675SMatthias Springer // This is an approximate check that may have false negatives but does not 14623761b675SMatthias Springer // require computing and traversing an inverse mapping. (We may end up 14633761b675SMatthias Springer // building source materializations that are never used and that fold away.) 14643761b675SMatthias Springer if (llvm::all_of(value.getUsers(), 14653761b675SMatthias Springer [&](Operation *op) { return replacedOps.contains(op); }) && 14663761b675SMatthias Springer !mapping.isMappedTo(value)) 14673761b675SMatthias Springer return Value(); 14683761b675SMatthias Springer 14693761b675SMatthias Springer // No replacement value was found. Get the latest replacement value 14703761b675SMatthias Springer // (regardless of the type) and build a source materialization to the 14713761b675SMatthias Springer // original type. 14723761b675SMatthias Springer repl = mapping.lookupOrNull(value); 14733ace6851SMatthias Springer if (repl.empty()) { 14743761b675SMatthias Springer // No replacement value is registered in the mapping. This means that the 14753761b675SMatthias Springer // value is dropped and no longer needed. (If the value were still needed, 14763761b675SMatthias Springer // a source materialization producing a replacement value "out of thin air" 14773761b675SMatthias Springer // would have already been created during `replaceOp` or 14783761b675SMatthias Springer // `applySignatureConversion`.) 14793761b675SMatthias Springer return Value(); 14803761b675SMatthias Springer } 14813ace6851SMatthias Springer 14823ace6851SMatthias Springer // Note: `computeInsertPoint` computes the "earliest" insertion point at 14833ace6851SMatthias Springer // which all values in `repl` are defined. It is important to emit the 14843ace6851SMatthias Springer // materialization at that location because the same materialization may be 14853ace6851SMatthias Springer // reused in a different context. (That's because materializations are cached 14863ace6851SMatthias Springer // in the conversion value mapping.) The insertion point of the 14873ace6851SMatthias Springer // materialization must be valid for all future users that may be created 14883ace6851SMatthias Springer // later in the conversion process. 14893ace6851SMatthias Springer Value castValue = 14903ace6851SMatthias Springer buildUnresolvedMaterialization(MaterializationKind::Source, 14913ace6851SMatthias Springer computeInsertPoint(repl), value.getLoc(), 1492*5f7568a3SMatthias Springer /*valuesToMap=*/repl, /*inputs=*/repl, 14933ace6851SMatthias Springer /*outputType=*/value.getType(), 14943ace6851SMatthias Springer /*originalType=*/Type(), converter) 14953ace6851SMatthias Springer .front(); 14963761b675SMatthias Springer return castValue; 14973761b675SMatthias Springer } 14983761b675SMatthias Springer 149959ff4d13SMatthias Springer //===----------------------------------------------------------------------===// 1500b6eb26fdSRiver Riddle // Rewriter Notification Hooks 1501b6eb26fdSRiver Riddle 1502ea2d9383SMatthias Springer void ConversionPatternRewriterImpl::notifyOperationInserted( 1503ea2d9383SMatthias Springer Operation *op, OpBuilder::InsertPoint previous) { 1504ea2d9383SMatthias Springer LLVM_DEBUG({ 1505ea2d9383SMatthias Springer logger.startLine() << "** Insert : '" << op->getName() << "'(" << op 1506ea2d9383SMatthias Springer << ")\n"; 1507ea2d9383SMatthias Springer }); 15086008cd40SMatthias Springer assert(!wasOpReplaced(op->getParentOp()) && 15096008cd40SMatthias Springer "attempting to insert into a block within a replaced/erased op"); 15106008cd40SMatthias Springer 15118f4cd2c7SMatthias Springer if (!previous.isSet()) { 15128f4cd2c7SMatthias Springer // This is a newly created op. 15139ca70d72SMatthias Springer appendRewrite<CreateOperationRewrite>(op); 15148f4cd2c7SMatthias Springer return; 15158f4cd2c7SMatthias Springer } 15168f4cd2c7SMatthias Springer Operation *prevOp = previous.getPoint() == previous.getBlock()->end() 15178f4cd2c7SMatthias Springer ? nullptr 15188f4cd2c7SMatthias Springer : &*previous.getPoint(); 15198f4cd2c7SMatthias Springer appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp); 1520ea2d9383SMatthias Springer } 1521ea2d9383SMatthias Springer 1522aed43562SMatthias Springer void ConversionPatternRewriterImpl::notifyOpReplaced( 15239df63b26SMatthias Springer Operation *op, ArrayRef<ValueRange> newValues) { 1524b6eb26fdSRiver Riddle assert(newValues.size() == op->getNumResults()); 15256008cd40SMatthias Springer assert(!ignoredOps.contains(op) && "operation was already replaced"); 1526b6eb26fdSRiver Riddle 1527c0cba25cSMatthias Springer // Check if replaced op is an unresolved materialization, i.e., an 1528c0cba25cSMatthias Springer // unrealized_conversion_cast op that was created by the conversion driver. 1529c0cba25cSMatthias Springer bool isUnresolvedMaterialization = false; 1530c0cba25cSMatthias Springer if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) 1531c0cba25cSMatthias Springer if (unresolvedMaterializations.contains(castOp)) 1532c0cba25cSMatthias Springer isUnresolvedMaterialization = true; 1533c0cba25cSMatthias Springer 1534b6eb26fdSRiver Riddle // Create mappings for each of the new result values. 15359df63b26SMatthias Springer for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults())) { 1536aed43562SMatthias Springer if (repl.empty()) { 15376093c26aSMatthias Springer // This result was dropped and no replacement value was provided. 1538c0cba25cSMatthias Springer if (isUnresolvedMaterialization) { 15396093c26aSMatthias Springer // Do not create another materializations if we are erasing a 15406093c26aSMatthias Springer // materialization. 1541b6eb26fdSRiver Riddle continue; 1542b6eb26fdSRiver Riddle } 15436093c26aSMatthias Springer 15446093c26aSMatthias Springer // Materialize a replacement value "out of thin air". 15459df63b26SMatthias Springer buildUnresolvedMaterialization( 15466093c26aSMatthias Springer MaterializationKind::Source, computeInsertPoint(result), 15473ace6851SMatthias Springer result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(), 15480d906a42SMatthias Springer /*outputType=*/result.getType(), /*originalType=*/Type(), 15490d906a42SMatthias Springer currentTypeConverter); 15509df63b26SMatthias Springer continue; 1551c0cba25cSMatthias Springer } else { 1552c0cba25cSMatthias Springer // Make sure that the user does not mess with unresolved materializations 1553c0cba25cSMatthias Springer // that were inserted by the conversion driver. We keep track of these 1554c0cba25cSMatthias Springer // ops in internal data structures. Erasing them must be allowed because 1555c0cba25cSMatthias Springer // this can happen when the user is erasing an entire block (including 1556c0cba25cSMatthias Springer // its body). But replacing them with another value should be forbidden 1557c0cba25cSMatthias Springer // to avoid problems with the `mapping`. 1558c0cba25cSMatthias Springer assert(!isUnresolvedMaterialization && 1559c0cba25cSMatthias Springer "attempting to replace an unresolved materialization"); 1560b6eb26fdSRiver Riddle } 1561b6eb26fdSRiver Riddle 1562c0cba25cSMatthias Springer // Remap result to replacement value. 1563aed43562SMatthias Springer if (repl.empty()) 1564aed43562SMatthias Springer continue; 15653ace6851SMatthias Springer mapping.map(result, repl); 15666093c26aSMatthias Springer } 15676093c26aSMatthias Springer 15686093c26aSMatthias Springer appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter); 15696008cd40SMatthias Springer // Mark this operation and all nested ops as replaced. 15706008cd40SMatthias Springer op->walk([&](Operation *op) { replacedOps.insert(op); }); 1571b6eb26fdSRiver Riddle } 1572b6eb26fdSRiver Riddle 1573b6eb26fdSRiver Riddle void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) { 157436d384b4SMatthias Springer appendRewrite<EraseBlockRewrite>(block); 1575b6eb26fdSRiver Riddle } 1576b6eb26fdSRiver Riddle 1577ea2d9383SMatthias Springer void ConversionPatternRewriterImpl::notifyBlockInserted( 15783ed98cb3SMatthias Springer Block *block, Region *previous, Region::iterator previousIt) { 15796008cd40SMatthias Springer assert(!wasOpReplaced(block->getParentOp()) && 15806008cd40SMatthias Springer "attempting to insert into a region within a replaced/erased op"); 158165a8e3a4SMehdi Amini LLVM_DEBUG( 158265a8e3a4SMehdi Amini { 158365a8e3a4SMehdi Amini Operation *parent = block->getParentOp(); 158465a8e3a4SMehdi Amini if (parent) { 158565a8e3a4SMehdi Amini logger.startLine() << "** Insert Block into : '" << parent->getName() 158665a8e3a4SMehdi Amini << "'(" << parent << ")\n"; 158765a8e3a4SMehdi Amini } else { 158865a8e3a4SMehdi Amini logger.startLine() 158965a8e3a4SMehdi Amini << "** Insert Block into detached Region (nullptr parent op)'"; 159065a8e3a4SMehdi Amini } 15919606655fSMatthias Springer }); 15926008cd40SMatthias Springer 15933ed98cb3SMatthias Springer if (!previous) { 15943ed98cb3SMatthias Springer // This is a newly created block. 15958faefe36SMatthias Springer appendRewrite<CreateBlockRewrite>(block); 15963ed98cb3SMatthias Springer return; 15973ed98cb3SMatthias Springer } 15983ed98cb3SMatthias Springer Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt; 15998faefe36SMatthias Springer appendRewrite<MoveBlockRewrite>(block, previous, prevBlock); 1600b6eb26fdSRiver Riddle } 1601b6eb26fdSRiver Riddle 160242c31d83SMatthias Springer void ConversionPatternRewriterImpl::notifyBlockBeingInlined( 160342c31d83SMatthias Springer Block *block, Block *srcBlock, Block::iterator before) { 16048faefe36SMatthias Springer appendRewrite<InlineBlockRewrite>(block, srcBlock, before); 1605b6eb26fdSRiver Riddle } 1606b6eb26fdSRiver Riddle 16079a028afdSMatthias Springer void ConversionPatternRewriterImpl::notifyMatchFailure( 1608b6eb26fdSRiver Riddle Location loc, function_ref<void(Diagnostic &)> reasonCallback) { 1609b6eb26fdSRiver Riddle LLVM_DEBUG({ 1610b6eb26fdSRiver Riddle Diagnostic diag(loc, DiagnosticSeverity::Remark); 1611b6eb26fdSRiver Riddle reasonCallback(diag); 1612b6eb26fdSRiver Riddle logger.startLine() << "** Failure : " << diag.str() << "\n"; 1613a2821094SMatthias Springer if (config.notifyCallback) 1614a2821094SMatthias Springer config.notifyCallback(diag); 1615b6eb26fdSRiver Riddle }); 1616b6eb26fdSRiver Riddle } 1617b6eb26fdSRiver Riddle 1618b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 1619b6eb26fdSRiver Riddle // ConversionPatternRewriter 1620b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 1621b6eb26fdSRiver Riddle 1622a2821094SMatthias Springer ConversionPatternRewriter::ConversionPatternRewriter( 1623a2821094SMatthias Springer MLIRContext *ctx, const ConversionConfig &config) 1624b6eb26fdSRiver Riddle : PatternRewriter(ctx), 1625a2821094SMatthias Springer impl(new detail::ConversionPatternRewriterImpl(ctx, config)) { 1626ea2d9383SMatthias Springer setListener(impl.get()); 1627c6532830SMatthias Springer } 1628c6532830SMatthias Springer 1629e5639b3fSMehdi Amini ConversionPatternRewriter::~ConversionPatternRewriter() = default; 1630b6eb26fdSRiver Riddle 163171d50c89SMatthias Springer void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) { 163271d50c89SMatthias Springer assert(op && newOp && "expected non-null op"); 163371d50c89SMatthias Springer replaceOp(op, newOp->getResults()); 163471d50c89SMatthias Springer } 163571d50c89SMatthias Springer 1636b6eb26fdSRiver Riddle void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { 163771d50c89SMatthias Springer assert(op->getNumResults() == newValues.size() && 163871d50c89SMatthias Springer "incorrect # of replacement values"); 1639b6eb26fdSRiver Riddle LLVM_DEBUG({ 1640b6eb26fdSRiver Riddle impl->logger.startLine() 1641b6eb26fdSRiver Riddle << "** Replace : '" << op->getName() << "'(" << op << ")\n"; 1642b6eb26fdSRiver Riddle }); 16439df63b26SMatthias Springer SmallVector<ValueRange> newVals; 16443ace6851SMatthias Springer for (size_t i = 0; i < newValues.size(); ++i) { 16453ace6851SMatthias Springer if (newValues[i]) { 16469df63b26SMatthias Springer newVals.push_back(newValues.slice(i, 1)); 16473ace6851SMatthias Springer } else { 16483ace6851SMatthias Springer newVals.push_back(ValueRange()); 16493ace6851SMatthias Springer } 16503ace6851SMatthias Springer } 1651aed43562SMatthias Springer impl->notifyOpReplaced(op, newVals); 1652aed43562SMatthias Springer } 1653aed43562SMatthias Springer 1654aed43562SMatthias Springer void ConversionPatternRewriter::replaceOpWithMultiple( 1655aed43562SMatthias Springer Operation *op, ArrayRef<ValueRange> newValues) { 1656aed43562SMatthias Springer assert(op->getNumResults() == newValues.size() && 1657aed43562SMatthias Springer "incorrect # of replacement values"); 1658aed43562SMatthias Springer LLVM_DEBUG({ 1659aed43562SMatthias Springer impl->logger.startLine() 1660aed43562SMatthias Springer << "** Replace : '" << op->getName() << "'(" << op << ")\n"; 1661aed43562SMatthias Springer }); 16629df63b26SMatthias Springer impl->notifyOpReplaced(op, newValues); 1663b6eb26fdSRiver Riddle } 1664b6eb26fdSRiver Riddle 1665b6eb26fdSRiver Riddle void ConversionPatternRewriter::eraseOp(Operation *op) { 1666b6eb26fdSRiver Riddle LLVM_DEBUG({ 1667b6eb26fdSRiver Riddle impl->logger.startLine() 1668b6eb26fdSRiver Riddle << "** Erase : '" << op->getName() << "'(" << op << ")\n"; 1669b6eb26fdSRiver Riddle }); 16709df63b26SMatthias Springer SmallVector<ValueRange> nullRepls(op->getNumResults(), {}); 1671b6eb26fdSRiver Riddle impl->notifyOpReplaced(op, nullRepls); 1672b6eb26fdSRiver Riddle } 1673b6eb26fdSRiver Riddle 1674b6eb26fdSRiver Riddle void ConversionPatternRewriter::eraseBlock(Block *block) { 16756008cd40SMatthias Springer assert(!impl->wasOpReplaced(block->getParentOp()) && 16766008cd40SMatthias Springer "attempting to erase a block within a replaced/erased op"); 16776008cd40SMatthias Springer 1678b6eb26fdSRiver Riddle // Mark all ops for erasure. 1679b6eb26fdSRiver Riddle for (Operation &op : *block) 1680b6eb26fdSRiver Riddle eraseOp(&op); 1681b6eb26fdSRiver Riddle 16828faefe36SMatthias Springer // Unlink the block from its parent region. The block is kept in the rewrite 16838faefe36SMatthias Springer // object and will be actually destroyed when rewrites are applied. This 1684b6eb26fdSRiver Riddle // allows us to keep the operations in the block live and undo the removal by 1685b6eb26fdSRiver Riddle // re-inserting the block. 1686d68d2951SMatthias Springer impl->notifyBlockIsBeingErased(block); 1687b6eb26fdSRiver Riddle block->getParent()->getBlocks().remove(block); 1688b6eb26fdSRiver Riddle } 1689b6eb26fdSRiver Riddle 1690b6eb26fdSRiver Riddle Block *ConversionPatternRewriter::applySignatureConversion( 169152050f3fSMatthias Springer Block *block, TypeConverter::SignatureConversion &conversion, 1692ce254598SMatthias Springer const TypeConverter *converter) { 169352050f3fSMatthias Springer assert(!impl->wasOpReplaced(block->getParentOp()) && 16946008cd40SMatthias Springer "attempting to apply a signature conversion to a block within a " 16956008cd40SMatthias Springer "replaced/erased op"); 169652050f3fSMatthias Springer return impl->applySignatureConversion(*this, block, converter, conversion); 1697b6eb26fdSRiver Riddle } 1698b6eb26fdSRiver Riddle 1699b6eb26fdSRiver Riddle FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes( 1700ce254598SMatthias Springer Region *region, const TypeConverter &converter, 1701b6eb26fdSRiver Riddle TypeConverter::SignatureConversion *entryConversion) { 17026008cd40SMatthias Springer assert(!impl->wasOpReplaced(region->getParentOp()) && 17036008cd40SMatthias Springer "attempting to apply a signature conversion to a block within a " 17046008cd40SMatthias Springer "replaced/erased op"); 1705aaf5c818SMatthias Springer return impl->convertRegionTypes(*this, region, converter, entryConversion); 1706b6eb26fdSRiver Riddle } 1707b6eb26fdSRiver Riddle 1708b6eb26fdSRiver Riddle void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, 1709b6eb26fdSRiver Riddle Value to) { 1710b6eb26fdSRiver Riddle LLVM_DEBUG({ 1711b6eb26fdSRiver Riddle Operation *parentOp = from.getOwner()->getParentOp(); 1712b6eb26fdSRiver Riddle impl->logger.startLine() << "** Replace Argument : '" << from 1713b6eb26fdSRiver Riddle << "'(in region of '" << parentOp->getName() 1714b6eb26fdSRiver Riddle << "'(" << from.getOwner()->getParentOp() << ")\n"; 1715b6eb26fdSRiver Riddle }); 17163761b675SMatthias Springer impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, 17173761b675SMatthias Springer impl->currentTypeConverter); 1718b6eb26fdSRiver Riddle impl->mapping.map(impl->mapping.lookupOrDefault(from), to); 1719b6eb26fdSRiver Riddle } 1720b6eb26fdSRiver Riddle 1721b6eb26fdSRiver Riddle Value ConversionPatternRewriter::getRemappedValue(Value key) { 17223ace6851SMatthias Springer SmallVector<ValueVector> remappedValues; 17231a36588eSKazu Hirata if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key, 1724015192c6SRiver Riddle remappedValues))) 1725015192c6SRiver Riddle return nullptr; 17269df63b26SMatthias Springer assert(remappedValues.front().size() == 1 && "1:N conversion not supported"); 17279df63b26SMatthias Springer return remappedValues.front().front(); 1728015192c6SRiver Riddle } 1729015192c6SRiver Riddle 1730015192c6SRiver Riddle LogicalResult 1731015192c6SRiver Riddle ConversionPatternRewriter::getRemappedValues(ValueRange keys, 1732015192c6SRiver Riddle SmallVectorImpl<Value> &results) { 1733015192c6SRiver Riddle if (keys.empty()) 1734015192c6SRiver Riddle return success(); 17353ace6851SMatthias Springer SmallVector<ValueVector> remapped; 17369df63b26SMatthias Springer if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys, 17379df63b26SMatthias Springer remapped))) 17389df63b26SMatthias Springer return failure(); 17399df63b26SMatthias Springer for (const auto &values : remapped) { 17409df63b26SMatthias Springer assert(values.size() == 1 && "1:N conversion not supported"); 17419df63b26SMatthias Springer results.push_back(values.front()); 17429df63b26SMatthias Springer } 17439df63b26SMatthias Springer return success(); 1744b6eb26fdSRiver Riddle } 1745b6eb26fdSRiver Riddle 174642c31d83SMatthias Springer void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, 174742c31d83SMatthias Springer Block::iterator before, 1748b6eb26fdSRiver Riddle ValueRange argValues) { 17496008cd40SMatthias Springer #ifndef NDEBUG 1750b6eb26fdSRiver Riddle assert(argValues.size() == source->getNumArguments() && 1751b6eb26fdSRiver Riddle "incorrect # of argument replacement values"); 17526008cd40SMatthias Springer assert(!impl->wasOpReplaced(source->getParentOp()) && 17536008cd40SMatthias Springer "attempting to inline a block from a replaced/erased op"); 17546008cd40SMatthias Springer assert(!impl->wasOpReplaced(dest->getParentOp()) && 17556008cd40SMatthias Springer "attempting to inline a block into a replaced/erased op"); 175642c31d83SMatthias Springer auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); }; 175742c31d83SMatthias Springer // The source block will be deleted, so it should not have any users (i.e., 175842c31d83SMatthias Springer // there should be no predecessors). 175942c31d83SMatthias Springer assert(llvm::all_of(source->getUsers(), opIgnored) && 176042c31d83SMatthias Springer "expected 'source' to have no predecessors"); 17616008cd40SMatthias Springer #endif // NDEBUG 176242c31d83SMatthias Springer 176360a20bd6SMatthias Springer // If a listener is attached to the dialect conversion, ops cannot be moved 176460a20bd6SMatthias Springer // to the destination block in bulk ("fast path"). This is because at the time 176560a20bd6SMatthias Springer // the notifications are sent, it is unknown which ops were moved. Instead, 176660a20bd6SMatthias Springer // ops should be moved one-by-one ("slow path"), so that a separate 176760a20bd6SMatthias Springer // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is 176860a20bd6SMatthias Springer // a bit more efficient, so we try to do that when possible. 176960a20bd6SMatthias Springer bool fastPath = !impl->config.listener; 177060a20bd6SMatthias Springer 177160a20bd6SMatthias Springer if (fastPath) 177242c31d83SMatthias Springer impl->notifyBlockBeingInlined(dest, source, before); 177360a20bd6SMatthias Springer 177460a20bd6SMatthias Springer // Replace all uses of block arguments. 1775b6eb26fdSRiver Riddle for (auto it : llvm::zip(source->getArguments(), argValues)) 1776b6eb26fdSRiver Riddle replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it)); 177760a20bd6SMatthias Springer 177860a20bd6SMatthias Springer if (fastPath) { 177960a20bd6SMatthias Springer // Move all ops at once. 178042c31d83SMatthias Springer dest->getOperations().splice(before, source->getOperations()); 178160a20bd6SMatthias Springer } else { 178260a20bd6SMatthias Springer // Move op by op. 178360a20bd6SMatthias Springer while (!source->empty()) 178460a20bd6SMatthias Springer moveOpBefore(&source->front(), dest, before); 178560a20bd6SMatthias Springer } 178660a20bd6SMatthias Springer 178760a20bd6SMatthias Springer // Erase the source block. 1788b6eb26fdSRiver Riddle eraseBlock(source); 1789b6eb26fdSRiver Riddle } 1790b6eb26fdSRiver Riddle 17915fcf907bSMatthias Springer void ConversionPatternRewriter::startOpModification(Operation *op) { 17926008cd40SMatthias Springer assert(!impl->wasOpReplaced(op) && 17936008cd40SMatthias Springer "attempting to modify a replaced/erased op"); 1794b6eb26fdSRiver Riddle #ifndef NDEBUG 1795b6eb26fdSRiver Riddle impl->pendingRootUpdates.insert(op); 1796b6eb26fdSRiver Riddle #endif 1797e214f004SMatthias Springer impl->appendRewrite<ModifyOperationRewrite>(op); 1798b6eb26fdSRiver Riddle } 1799b6eb26fdSRiver Riddle 18005fcf907bSMatthias Springer void ConversionPatternRewriter::finalizeOpModification(Operation *op) { 18016008cd40SMatthias Springer assert(!impl->wasOpReplaced(op) && 18026008cd40SMatthias Springer "attempting to modify a replaced/erased op"); 18035fcf907bSMatthias Springer PatternRewriter::finalizeOpModification(op); 1804b6eb26fdSRiver Riddle // There is nothing to do here, we only need to track the operation at the 1805b6eb26fdSRiver Riddle // start of the update. 1806b6eb26fdSRiver Riddle #ifndef NDEBUG 1807b6eb26fdSRiver Riddle assert(impl->pendingRootUpdates.erase(op) && 1808b6eb26fdSRiver Riddle "operation did not have a pending in-place update"); 1809b6eb26fdSRiver Riddle #endif 1810b6eb26fdSRiver Riddle } 1811b6eb26fdSRiver Riddle 18125fcf907bSMatthias Springer void ConversionPatternRewriter::cancelOpModification(Operation *op) { 1813b6eb26fdSRiver Riddle #ifndef NDEBUG 1814b6eb26fdSRiver Riddle assert(impl->pendingRootUpdates.erase(op) && 1815b6eb26fdSRiver Riddle "operation did not have a pending in-place update"); 1816b6eb26fdSRiver Riddle #endif 1817b6eb26fdSRiver Riddle // Erase the last update for this operation. 1818e214f004SMatthias Springer auto it = llvm::find_if( 1819e214f004SMatthias Springer llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) { 1820e214f004SMatthias Springer auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get()); 1821e214f004SMatthias Springer return modifyRewrite && modifyRewrite->getOperation() == op; 1822e214f004SMatthias Springer }); 1823e214f004SMatthias Springer assert(it != impl->rewrites.rend() && "no root update started on op"); 1824e214f004SMatthias Springer (*it)->rollback(); 1825e214f004SMatthias Springer int updateIdx = std::prev(impl->rewrites.rend()) - it; 1826e214f004SMatthias Springer impl->rewrites.erase(impl->rewrites.begin() + updateIdx); 1827b6eb26fdSRiver Riddle } 1828b6eb26fdSRiver Riddle 1829b6eb26fdSRiver Riddle detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { 1830b6eb26fdSRiver Riddle return *impl; 1831b6eb26fdSRiver Riddle } 1832b6eb26fdSRiver Riddle 1833b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 1834b6eb26fdSRiver Riddle // ConversionPattern 1835b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 1836b6eb26fdSRiver Riddle 18379df63b26SMatthias Springer SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands( 18389df63b26SMatthias Springer ArrayRef<ValueRange> operands) const { 18399df63b26SMatthias Springer SmallVector<Value> oneToOneOperands; 18409df63b26SMatthias Springer oneToOneOperands.reserve(operands.size()); 18419df63b26SMatthias Springer for (ValueRange operand : operands) { 18429df63b26SMatthias Springer if (operand.size() != 1) 18439df63b26SMatthias Springer llvm::report_fatal_error("pattern '" + getDebugName() + 18449df63b26SMatthias Springer "' does not support 1:N conversion"); 18459df63b26SMatthias Springer oneToOneOperands.push_back(operand.front()); 18469df63b26SMatthias Springer } 18479df63b26SMatthias Springer return oneToOneOperands; 18489df63b26SMatthias Springer } 18499df63b26SMatthias Springer 1850b6eb26fdSRiver Riddle LogicalResult 1851b6eb26fdSRiver Riddle ConversionPattern::matchAndRewrite(Operation *op, 1852b6eb26fdSRiver Riddle PatternRewriter &rewriter) const { 1853b6eb26fdSRiver Riddle auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter); 1854b6eb26fdSRiver Riddle auto &rewriterImpl = dialectRewriter.getImpl(); 1855b6eb26fdSRiver Riddle 185601b55f16SRiver Riddle // Track the current conversion pattern type converter in the rewriter. 1857abf0c6c0SJan Svoboda llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter, 1858abf0c6c0SJan Svoboda getTypeConverter()); 1859b6eb26fdSRiver Riddle 1860b6eb26fdSRiver Riddle // Remap the operands of the operation. 18613ace6851SMatthias Springer SmallVector<ValueVector> remapped; 1862015192c6SRiver Riddle if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter, 18639df63b26SMatthias Springer op->getOperands(), remapped))) { 1864b6eb26fdSRiver Riddle return failure(); 1865b6eb26fdSRiver Riddle } 18669df63b26SMatthias Springer SmallVector<ValueRange> remappedAsRange = 18679df63b26SMatthias Springer llvm::to_vector_of<ValueRange>(remapped); 18689df63b26SMatthias Springer return matchAndRewrite(op, remappedAsRange, dialectRewriter); 1869b6eb26fdSRiver Riddle } 1870b6eb26fdSRiver Riddle 1871b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 1872b6eb26fdSRiver Riddle // OperationLegalizer 1873b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 1874b6eb26fdSRiver Riddle 1875b6eb26fdSRiver Riddle namespace { 1876b6eb26fdSRiver Riddle /// A set of rewrite patterns that can be used to legalize a given operation. 1877b6eb26fdSRiver Riddle using LegalizationPatterns = SmallVector<const Pattern *, 1>; 1878b6eb26fdSRiver Riddle 1879b6eb26fdSRiver Riddle /// This class defines a recursive operation legalizer. 1880b6eb26fdSRiver Riddle class OperationLegalizer { 1881b6eb26fdSRiver Riddle public: 1882b6eb26fdSRiver Riddle using LegalizationAction = ConversionTarget::LegalizationAction; 1883b6eb26fdSRiver Riddle 1884370a6f09SMehdi Amini OperationLegalizer(const ConversionTarget &targetInfo, 18859b6bd709SMatthias Springer const FrozenRewritePatternSet &patterns, 18869b6bd709SMatthias Springer const ConversionConfig &config); 1887b6eb26fdSRiver Riddle 1888b6eb26fdSRiver Riddle /// Returns true if the given operation is known to be illegal on the target. 1889b6eb26fdSRiver Riddle bool isIllegal(Operation *op) const; 1890b6eb26fdSRiver Riddle 1891b6eb26fdSRiver Riddle /// Attempt to legalize the given operation. Returns success if the operation 1892b6eb26fdSRiver Riddle /// was legalized, failure otherwise. 1893b6eb26fdSRiver Riddle LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter); 1894b6eb26fdSRiver Riddle 1895b6eb26fdSRiver Riddle /// Returns the conversion target in use by the legalizer. 1896370a6f09SMehdi Amini const ConversionTarget &getTarget() { return target; } 1897b6eb26fdSRiver Riddle 1898b6eb26fdSRiver Riddle private: 1899b6eb26fdSRiver Riddle /// Attempt to legalize the given operation by folding it. 1900b6eb26fdSRiver Riddle LogicalResult legalizeWithFold(Operation *op, 1901b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter); 1902b6eb26fdSRiver Riddle 1903b6eb26fdSRiver Riddle /// Attempt to legalize the given operation by applying a pattern. Returns 1904b6eb26fdSRiver Riddle /// success if the operation was legalized, failure otherwise. 1905b6eb26fdSRiver Riddle LogicalResult legalizeWithPattern(Operation *op, 1906b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter); 1907b6eb26fdSRiver Riddle 1908b6eb26fdSRiver Riddle /// Return true if the given pattern may be applied to the given operation, 1909b6eb26fdSRiver Riddle /// false otherwise. 1910b6eb26fdSRiver Riddle bool canApplyPattern(Operation *op, const Pattern &pattern, 1911b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter); 1912b6eb26fdSRiver Riddle 1913b6eb26fdSRiver Riddle /// Legalize the resultant IR after successfully applying the given pattern. 1914b6eb26fdSRiver Riddle LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern, 1915b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter, 1916b6eb26fdSRiver Riddle RewriterState &curState); 1917b6eb26fdSRiver Riddle 1918b6eb26fdSRiver Riddle /// Legalizes the actions registered during the execution of a pattern. 19198faefe36SMatthias Springer LogicalResult 19208faefe36SMatthias Springer legalizePatternBlockRewrites(Operation *op, 1921b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter, 1922b6eb26fdSRiver Riddle ConversionPatternRewriterImpl &impl, 19238faefe36SMatthias Springer RewriterState &state, RewriterState &newState); 1924b6eb26fdSRiver Riddle LogicalResult legalizePatternCreatedOperations( 1925b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, 1926b6eb26fdSRiver Riddle RewriterState &state, RewriterState &newState); 1927b6eb26fdSRiver Riddle LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter, 1928b6eb26fdSRiver Riddle ConversionPatternRewriterImpl &impl, 1929b6eb26fdSRiver Riddle RewriterState &state, 1930b6eb26fdSRiver Riddle RewriterState &newState); 1931b6eb26fdSRiver Riddle 1932b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 1933b6eb26fdSRiver Riddle // Cost Model 1934b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 1935b6eb26fdSRiver Riddle 1936b6eb26fdSRiver Riddle /// Build an optimistic legalization graph given the provided patterns. This 1937b6eb26fdSRiver Riddle /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with 1938b6eb26fdSRiver Riddle /// patterns for operations that are not directly legal, but may be 1939b6eb26fdSRiver Riddle /// transitively legal for the current target given the provided patterns. 1940b6eb26fdSRiver Riddle void buildLegalizationGraph( 1941b6eb26fdSRiver Riddle LegalizationPatterns &anyOpLegalizerPatterns, 1942b6eb26fdSRiver Riddle DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); 1943b6eb26fdSRiver Riddle 1944b6eb26fdSRiver Riddle /// Compute the benefit of each node within the computed legalization graph. 1945b6eb26fdSRiver Riddle /// This orders the patterns within 'legalizerPatterns' based upon two 1946b6eb26fdSRiver Riddle /// criteria: 1947b6eb26fdSRiver Riddle /// 1) Prefer patterns that have the lowest legalization depth, i.e. 1948b6eb26fdSRiver Riddle /// represent the more direct mapping to the target. 1949b6eb26fdSRiver Riddle /// 2) When comparing patterns with the same legalization depth, prefer the 1950b6eb26fdSRiver Riddle /// pattern with the highest PatternBenefit. This allows for users to 1951b6eb26fdSRiver Riddle /// prefer specific legalizations over others. 1952b6eb26fdSRiver Riddle void computeLegalizationGraphBenefit( 1953b6eb26fdSRiver Riddle LegalizationPatterns &anyOpLegalizerPatterns, 1954b6eb26fdSRiver Riddle DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); 1955b6eb26fdSRiver Riddle 1956b6eb26fdSRiver Riddle /// Compute the legalization depth when legalizing an operation of the given 1957b6eb26fdSRiver Riddle /// type. 1958b6eb26fdSRiver Riddle unsigned computeOpLegalizationDepth( 1959b6eb26fdSRiver Riddle OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth, 1960b6eb26fdSRiver Riddle DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); 1961b6eb26fdSRiver Riddle 1962b6eb26fdSRiver Riddle /// Apply the conversion cost model to the given set of patterns, and return 1963b6eb26fdSRiver Riddle /// the smallest legalization depth of any of the patterns. See 1964b6eb26fdSRiver Riddle /// `computeLegalizationGraphBenefit` for the breakdown of the cost model. 1965b6eb26fdSRiver Riddle unsigned applyCostModelToPatterns( 1966b6eb26fdSRiver Riddle LegalizationPatterns &patterns, 1967b6eb26fdSRiver Riddle DenseMap<OperationName, unsigned> &minOpPatternDepth, 1968b6eb26fdSRiver Riddle DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); 1969b6eb26fdSRiver Riddle 1970b6eb26fdSRiver Riddle /// The current set of patterns that have been applied. 1971b6eb26fdSRiver Riddle SmallPtrSet<const Pattern *, 8> appliedPatterns; 1972b6eb26fdSRiver Riddle 1973b6eb26fdSRiver Riddle /// The legalization information provided by the target. 1974370a6f09SMehdi Amini const ConversionTarget ⌖ 1975b6eb26fdSRiver Riddle 1976b6eb26fdSRiver Riddle /// The pattern applicator to use for conversions. 1977b6eb26fdSRiver Riddle PatternApplicator applicator; 19789b6bd709SMatthias Springer 19799b6bd709SMatthias Springer /// Dialect conversion configuration. 19809b6bd709SMatthias Springer const ConversionConfig &config; 1981b6eb26fdSRiver Riddle }; 1982b6eb26fdSRiver Riddle } // namespace 1983b6eb26fdSRiver Riddle 1984370a6f09SMehdi Amini OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo, 19859b6bd709SMatthias Springer const FrozenRewritePatternSet &patterns, 19869b6bd709SMatthias Springer const ConversionConfig &config) 19879b6bd709SMatthias Springer : target(targetInfo), applicator(patterns), config(config) { 1988b6eb26fdSRiver Riddle // The set of patterns that can be applied to illegal operations to transform 1989b6eb26fdSRiver Riddle // them into legal ones. 1990b6eb26fdSRiver Riddle DenseMap<OperationName, LegalizationPatterns> legalizerPatterns; 1991b6eb26fdSRiver Riddle LegalizationPatterns anyOpLegalizerPatterns; 1992b6eb26fdSRiver Riddle 1993b6eb26fdSRiver Riddle buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns); 1994b6eb26fdSRiver Riddle computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns); 1995b6eb26fdSRiver Riddle } 1996b6eb26fdSRiver Riddle 1997b6eb26fdSRiver Riddle bool OperationLegalizer::isIllegal(Operation *op) const { 19982a3878eaSButygin return target.isIllegal(op); 1999b6eb26fdSRiver Riddle } 2000b6eb26fdSRiver Riddle 2001b6eb26fdSRiver Riddle LogicalResult 2002b6eb26fdSRiver Riddle OperationLegalizer::legalize(Operation *op, 2003b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter) { 2004b6eb26fdSRiver Riddle #ifndef NDEBUG 2005b6eb26fdSRiver Riddle const char *logLineComment = 2006b6eb26fdSRiver Riddle "//===-------------------------------------------===//\n"; 2007b6eb26fdSRiver Riddle 200801b55f16SRiver Riddle auto &logger = rewriter.getImpl().logger; 2009b6eb26fdSRiver Riddle #endif 2010b6eb26fdSRiver Riddle LLVM_DEBUG({ 201101b55f16SRiver Riddle logger.getOStream() << "\n"; 201201b55f16SRiver Riddle logger.startLine() << logLineComment; 201301b55f16SRiver Riddle logger.startLine() << "Legalizing operation : '" << op->getName() << "'(" 201401b55f16SRiver Riddle << op << ") {\n"; 201501b55f16SRiver Riddle logger.indent(); 2016b6eb26fdSRiver Riddle 2017b6eb26fdSRiver Riddle // If the operation has no regions, just print it here. 2018b6eb26fdSRiver Riddle if (op->getNumRegions() == 0) { 201901b55f16SRiver Riddle op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm()); 202001b55f16SRiver Riddle logger.getOStream() << "\n\n"; 2021b6eb26fdSRiver Riddle } 2022b6eb26fdSRiver Riddle }); 2023b6eb26fdSRiver Riddle 2024b6eb26fdSRiver Riddle // Check if this operation is legal on the target. 2025b6eb26fdSRiver Riddle if (auto legalityInfo = target.isLegal(op)) { 2026b6eb26fdSRiver Riddle LLVM_DEBUG({ 2027b6eb26fdSRiver Riddle logSuccess( 202801b55f16SRiver Riddle logger, "operation marked legal by the target{0}", 2029b6eb26fdSRiver Riddle legalityInfo->isRecursivelyLegal 2030b6eb26fdSRiver Riddle ? "; NOTE: operation is recursively legal; skipping internals" 2031b6eb26fdSRiver Riddle : ""); 203201b55f16SRiver Riddle logger.startLine() << logLineComment; 2033b6eb26fdSRiver Riddle }); 2034b6eb26fdSRiver Riddle 2035b6eb26fdSRiver Riddle // If this operation is recursively legal, mark its children as ignored so 2036b6eb26fdSRiver Riddle // that we don't consider them for legalization. 20376008cd40SMatthias Springer if (legalityInfo->isRecursivelyLegal) { 20386008cd40SMatthias Springer op->walk([&](Operation *nested) { 20396008cd40SMatthias Springer if (op != nested) 20406008cd40SMatthias Springer rewriter.getImpl().ignoredOps.insert(nested); 20416008cd40SMatthias Springer }); 20426008cd40SMatthias Springer } 20436008cd40SMatthias Springer 2044b6eb26fdSRiver Riddle return success(); 2045b6eb26fdSRiver Riddle } 2046b6eb26fdSRiver Riddle 2047b6eb26fdSRiver Riddle // Check to see if the operation is ignored and doesn't need to be converted. 2048b6eb26fdSRiver Riddle if (rewriter.getImpl().isOpIgnored(op)) { 2049b6eb26fdSRiver Riddle LLVM_DEBUG({ 205001b55f16SRiver Riddle logSuccess(logger, "operation marked 'ignored' during conversion"); 205101b55f16SRiver Riddle logger.startLine() << logLineComment; 2052b6eb26fdSRiver Riddle }); 2053b6eb26fdSRiver Riddle return success(); 2054b6eb26fdSRiver Riddle } 2055b6eb26fdSRiver Riddle 2056b6eb26fdSRiver Riddle // If the operation isn't legal, try to fold it in-place. 2057b6eb26fdSRiver Riddle // TODO: Should we always try to do this, even if the op is 2058b6eb26fdSRiver Riddle // already legal? 2059b6eb26fdSRiver Riddle if (succeeded(legalizeWithFold(op, rewriter))) { 2060b6eb26fdSRiver Riddle LLVM_DEBUG({ 206101b55f16SRiver Riddle logSuccess(logger, "operation was folded"); 206201b55f16SRiver Riddle logger.startLine() << logLineComment; 2063b6eb26fdSRiver Riddle }); 2064b6eb26fdSRiver Riddle return success(); 2065b6eb26fdSRiver Riddle } 2066b6eb26fdSRiver Riddle 2067b6eb26fdSRiver Riddle // Otherwise, we need to apply a legalization pattern to this operation. 2068b6eb26fdSRiver Riddle if (succeeded(legalizeWithPattern(op, rewriter))) { 2069b6eb26fdSRiver Riddle LLVM_DEBUG({ 207001b55f16SRiver Riddle logSuccess(logger, ""); 207101b55f16SRiver Riddle logger.startLine() << logLineComment; 2072b6eb26fdSRiver Riddle }); 2073b6eb26fdSRiver Riddle return success(); 2074b6eb26fdSRiver Riddle } 2075b6eb26fdSRiver Riddle 2076b6eb26fdSRiver Riddle LLVM_DEBUG({ 207701b55f16SRiver Riddle logFailure(logger, "no matched legalization pattern"); 207801b55f16SRiver Riddle logger.startLine() << logLineComment; 2079b6eb26fdSRiver Riddle }); 2080b6eb26fdSRiver Riddle return failure(); 2081b6eb26fdSRiver Riddle } 2082b6eb26fdSRiver Riddle 2083b6eb26fdSRiver Riddle LogicalResult 2084b6eb26fdSRiver Riddle OperationLegalizer::legalizeWithFold(Operation *op, 2085b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter) { 2086b6eb26fdSRiver Riddle auto &rewriterImpl = rewriter.getImpl(); 2087b6eb26fdSRiver Riddle RewriterState curState = rewriterImpl.getCurrentState(); 2088b6eb26fdSRiver Riddle 2089b6eb26fdSRiver Riddle LLVM_DEBUG({ 2090b6eb26fdSRiver Riddle rewriterImpl.logger.startLine() << "* Fold {\n"; 2091b6eb26fdSRiver Riddle rewriterImpl.logger.indent(); 2092b6eb26fdSRiver Riddle }); 2093b6eb26fdSRiver Riddle 2094b6eb26fdSRiver Riddle // Try to fold the operation. 2095b6eb26fdSRiver Riddle SmallVector<Value, 2> replacementValues; 2096b6eb26fdSRiver Riddle rewriter.setInsertionPoint(op); 2097b6eb26fdSRiver Riddle if (failed(rewriter.tryFold(op, replacementValues))) { 2098b6eb26fdSRiver Riddle LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold")); 2099b6eb26fdSRiver Riddle return failure(); 2100b6eb26fdSRiver Riddle } 21014513050fSChristian Ulmann // An empty list of replacement values indicates that the fold was in-place. 21024513050fSChristian Ulmann // As the operation changed, a new legalization needs to be attempted. 21034513050fSChristian Ulmann if (replacementValues.empty()) 21044513050fSChristian Ulmann return legalize(op, rewriter); 2105b6eb26fdSRiver Riddle 2106b6eb26fdSRiver Riddle // Insert a replacement for 'op' with the folded replacement values. 2107b6eb26fdSRiver Riddle rewriter.replaceOp(op, replacementValues); 2108b6eb26fdSRiver Riddle 2109b6eb26fdSRiver Riddle // Recursively legalize any new constant operations. 21109ca70d72SMatthias Springer for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size(); 2111b6eb26fdSRiver Riddle i != e; ++i) { 21129ca70d72SMatthias Springer auto *createOp = 21139ca70d72SMatthias Springer dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get()); 21149ca70d72SMatthias Springer if (!createOp) 21159ca70d72SMatthias Springer continue; 21169ca70d72SMatthias Springer if (failed(legalize(createOp->getOperation(), rewriter))) { 2117b6eb26fdSRiver Riddle LLVM_DEBUG(logFailure(rewriterImpl.logger, 2118e49c0e48SUday Bondhugula "failed to legalize generated constant '{0}'", 21199ca70d72SMatthias Springer createOp->getOperation()->getName())); 2120b6eb26fdSRiver Riddle rewriterImpl.resetState(curState); 2121b6eb26fdSRiver Riddle return failure(); 2122b6eb26fdSRiver Riddle } 2123b6eb26fdSRiver Riddle } 2124b6eb26fdSRiver Riddle 2125b6eb26fdSRiver Riddle LLVM_DEBUG(logSuccess(rewriterImpl.logger, "")); 2126b6eb26fdSRiver Riddle return success(); 2127b6eb26fdSRiver Riddle } 2128b6eb26fdSRiver Riddle 2129b6eb26fdSRiver Riddle LogicalResult 2130b6eb26fdSRiver Riddle OperationLegalizer::legalizeWithPattern(Operation *op, 2131b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter) { 2132b6eb26fdSRiver Riddle auto &rewriterImpl = rewriter.getImpl(); 2133b6eb26fdSRiver Riddle 2134b6eb26fdSRiver Riddle // Functor that returns if the given pattern may be applied. 2135b6eb26fdSRiver Riddle auto canApply = [&](const Pattern &pattern) { 21369b6bd709SMatthias Springer bool canApply = canApplyPattern(op, pattern, rewriter); 21379b6bd709SMatthias Springer if (canApply && config.listener) 21389b6bd709SMatthias Springer config.listener->notifyPatternBegin(pattern, op); 21399b6bd709SMatthias Springer return canApply; 2140b6eb26fdSRiver Riddle }; 2141b6eb26fdSRiver Riddle 2142b6eb26fdSRiver Riddle // Functor that cleans up the rewriter state after a pattern failed to match. 2143b6eb26fdSRiver Riddle RewriterState curState = rewriterImpl.getCurrentState(); 2144b6eb26fdSRiver Riddle auto onFailure = [&](const Pattern &pattern) { 2145e214f004SMatthias Springer assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); 2146b8c6b152SChia-hung Duan LLVM_DEBUG({ 2147b8c6b152SChia-hung Duan logFailure(rewriterImpl.logger, "pattern failed to match"); 2148a2821094SMatthias Springer if (rewriterImpl.config.notifyCallback) { 2149b8c6b152SChia-hung Duan Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark); 2150b8c6b152SChia-hung Duan diag << "Failed to apply pattern \"" << pattern.getDebugName() 2151b8c6b152SChia-hung Duan << "\" on op:\n" 2152b8c6b152SChia-hung Duan << *op; 2153a2821094SMatthias Springer rewriterImpl.config.notifyCallback(diag); 2154b8c6b152SChia-hung Duan } 2155b8c6b152SChia-hung Duan }); 21569b6bd709SMatthias Springer if (config.listener) 21579b6bd709SMatthias Springer config.listener->notifyPatternEnd(pattern, failure()); 2158b6eb26fdSRiver Riddle rewriterImpl.resetState(curState); 2159b6eb26fdSRiver Riddle appliedPatterns.erase(&pattern); 2160b6eb26fdSRiver Riddle }; 2161b6eb26fdSRiver Riddle 2162b6eb26fdSRiver Riddle // Functor that performs additional legalization when a pattern is 2163b6eb26fdSRiver Riddle // successfully applied. 2164b6eb26fdSRiver Riddle auto onSuccess = [&](const Pattern &pattern) { 2165e214f004SMatthias Springer assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); 2166b6eb26fdSRiver Riddle auto result = legalizePatternResult(op, pattern, rewriter, curState); 2167b6eb26fdSRiver Riddle appliedPatterns.erase(&pattern); 2168b6eb26fdSRiver Riddle if (failed(result)) 2169b6eb26fdSRiver Riddle rewriterImpl.resetState(curState); 21709b6bd709SMatthias Springer if (config.listener) 21719b6bd709SMatthias Springer config.listener->notifyPatternEnd(pattern, result); 2172b6eb26fdSRiver Riddle return result; 2173b6eb26fdSRiver Riddle }; 2174b6eb26fdSRiver Riddle 2175b6eb26fdSRiver Riddle // Try to match and rewrite a pattern on this operation. 2176b6eb26fdSRiver Riddle return applicator.matchAndRewrite(op, rewriter, canApply, onFailure, 2177b6eb26fdSRiver Riddle onSuccess); 2178b6eb26fdSRiver Riddle } 2179b6eb26fdSRiver Riddle 2180b6eb26fdSRiver Riddle bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern, 2181b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter) { 2182b6eb26fdSRiver Riddle LLVM_DEBUG({ 2183b6eb26fdSRiver Riddle auto &os = rewriter.getImpl().logger; 2184b6eb26fdSRiver Riddle os.getOStream() << "\n"; 2185b6eb26fdSRiver Riddle os.startLine() << "* Pattern : '" << op->getName() << " -> ("; 2186015192c6SRiver Riddle llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream()); 2187b6eb26fdSRiver Riddle os.getOStream() << ")' {\n"; 2188b6eb26fdSRiver Riddle os.indent(); 2189b6eb26fdSRiver Riddle }); 2190b6eb26fdSRiver Riddle 2191b6eb26fdSRiver Riddle // Ensure that we don't cycle by not allowing the same pattern to be 2192b6eb26fdSRiver Riddle // applied twice in the same recursion stack if it is not known to be safe. 2193b6eb26fdSRiver Riddle if (!pattern.hasBoundedRewriteRecursion() && 2194b6eb26fdSRiver Riddle !appliedPatterns.insert(&pattern).second) { 2195b6eb26fdSRiver Riddle LLVM_DEBUG( 2196b6eb26fdSRiver Riddle logFailure(rewriter.getImpl().logger, "pattern was already applied")); 2197b6eb26fdSRiver Riddle return false; 2198b6eb26fdSRiver Riddle } 2199b6eb26fdSRiver Riddle return true; 2200b6eb26fdSRiver Riddle } 2201b6eb26fdSRiver Riddle 2202b6eb26fdSRiver Riddle LogicalResult 2203b6eb26fdSRiver Riddle OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, 2204b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter, 2205b6eb26fdSRiver Riddle RewriterState &curState) { 2206b6eb26fdSRiver Riddle auto &impl = rewriter.getImpl(); 2207b6eb26fdSRiver Riddle assert(impl.pendingRootUpdates.empty() && "dangling root updates"); 220879f41434SMatthias Springer 220979f41434SMatthias Springer #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 2210b6eb26fdSRiver Riddle // Check that the root was either replaced or updated in place. 2211d68d2951SMatthias Springer auto newRewrites = llvm::drop_begin(impl.rewrites, curState.numRewrites); 2212b6eb26fdSRiver Riddle auto replacedRoot = [&] { 2213d68d2951SMatthias Springer return hasRewrite<ReplaceOperationRewrite>(newRewrites, op); 2214b6eb26fdSRiver Riddle }; 2215b6eb26fdSRiver Riddle auto updatedRootInPlace = [&] { 2216d68d2951SMatthias Springer return hasRewrite<ModifyOperationRewrite>(newRewrites, op); 2217b6eb26fdSRiver Riddle }; 221879f41434SMatthias Springer if (!replacedRoot() && !updatedRootInPlace()) 221979f41434SMatthias Springer llvm::report_fatal_error("expected pattern to replace the root operation"); 222079f41434SMatthias Springer #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 2221b6eb26fdSRiver Riddle 2222b6eb26fdSRiver Riddle // Legalize each of the actions registered during application. 2223b6eb26fdSRiver Riddle RewriterState newState = impl.getCurrentState(); 22248faefe36SMatthias Springer if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState, 2225b6eb26fdSRiver Riddle newState)) || 2226b6eb26fdSRiver Riddle failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) || 2227b6eb26fdSRiver Riddle failed(legalizePatternCreatedOperations(rewriter, impl, curState, 2228b6eb26fdSRiver Riddle newState))) { 2229b6eb26fdSRiver Riddle return failure(); 2230b6eb26fdSRiver Riddle } 2231b6eb26fdSRiver Riddle 2232b6eb26fdSRiver Riddle LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully")); 2233b6eb26fdSRiver Riddle return success(); 2234b6eb26fdSRiver Riddle } 2235b6eb26fdSRiver Riddle 22368faefe36SMatthias Springer LogicalResult OperationLegalizer::legalizePatternBlockRewrites( 2237b6eb26fdSRiver Riddle Operation *op, ConversionPatternRewriter &rewriter, 2238b6eb26fdSRiver Riddle ConversionPatternRewriterImpl &impl, RewriterState &state, 2239b6eb26fdSRiver Riddle RewriterState &newState) { 2240b6eb26fdSRiver Riddle SmallPtrSet<Operation *, 16> operationsToIgnore; 2241b6eb26fdSRiver Riddle 2242b6eb26fdSRiver Riddle // If the pattern moved or created any blocks, make sure the types of block 2243b6eb26fdSRiver Riddle // arguments get legalized. 22448faefe36SMatthias Springer for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { 22458faefe36SMatthias Springer BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get()); 22468faefe36SMatthias Springer if (!rewrite) 22478faefe36SMatthias Springer continue; 22488faefe36SMatthias Springer Block *block = rewrite->getBlock(); 2249d68d2951SMatthias Springer if (isa<BlockTypeConversionRewrite, EraseBlockRewrite, 2250d68d2951SMatthias Springer ReplaceBlockArgRewrite>(rewrite)) 2251b6eb26fdSRiver Riddle continue; 2252b6eb26fdSRiver Riddle // Only check blocks outside of the current operation. 22538faefe36SMatthias Springer Operation *parentOp = block->getParentOp(); 22548faefe36SMatthias Springer if (!parentOp || parentOp == op || block->getNumArguments() == 0) 2255b6eb26fdSRiver Riddle continue; 2256b6eb26fdSRiver Riddle 2257b6eb26fdSRiver Riddle // If the region of the block has a type converter, try to convert the block 2258b6eb26fdSRiver Riddle // directly. 2259b49f155cSMatthias Springer if (auto *converter = impl.regionToConverter.lookup(block->getParent())) { 226052050f3fSMatthias Springer std::optional<TypeConverter::SignatureConversion> conversion = 226152050f3fSMatthias Springer converter->convertBlockSignature(block); 226252050f3fSMatthias Springer if (!conversion) { 2263b6eb26fdSRiver Riddle LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved " 2264b6eb26fdSRiver Riddle "block")); 2265b6eb26fdSRiver Riddle return failure(); 2266b6eb26fdSRiver Riddle } 226752050f3fSMatthias Springer impl.applySignatureConversion(rewriter, block, converter, *conversion); 2268b6eb26fdSRiver Riddle continue; 2269b6eb26fdSRiver Riddle } 2270b6eb26fdSRiver Riddle 2271b6eb26fdSRiver Riddle // Otherwise, check that this operation isn't one generated by this pattern. 2272b6eb26fdSRiver Riddle // This is because we will attempt to legalize the parent operation, and 2273b6eb26fdSRiver Riddle // blocks in regions created by this pattern will already be legalized later 2274b6eb26fdSRiver Riddle // on. If we haven't built the set yet, build it now. 2275b6eb26fdSRiver Riddle if (operationsToIgnore.empty()) { 22769ca70d72SMatthias Springer for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e; 22779ca70d72SMatthias Springer ++i) { 22789ca70d72SMatthias Springer auto *createOp = 22799ca70d72SMatthias Springer dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get()); 22809ca70d72SMatthias Springer if (!createOp) 22819ca70d72SMatthias Springer continue; 22829ca70d72SMatthias Springer operationsToIgnore.insert(createOp->getOperation()); 22839ca70d72SMatthias Springer } 2284b6eb26fdSRiver Riddle } 2285b6eb26fdSRiver Riddle 2286b6eb26fdSRiver Riddle // If this operation should be considered for re-legalization, try it. 2287b6eb26fdSRiver Riddle if (operationsToIgnore.insert(parentOp).second && 2288b6eb26fdSRiver Riddle failed(legalize(parentOp, rewriter))) { 22898faefe36SMatthias Springer LLVM_DEBUG(logFailure(impl.logger, 22908faefe36SMatthias Springer "operation '{0}'({1}) became illegal after rewrite", 2291b6eb26fdSRiver Riddle parentOp->getName(), parentOp)); 2292b6eb26fdSRiver Riddle return failure(); 2293b6eb26fdSRiver Riddle } 2294b6eb26fdSRiver Riddle } 2295b6eb26fdSRiver Riddle return success(); 2296b6eb26fdSRiver Riddle } 229701b55f16SRiver Riddle 2298b6eb26fdSRiver Riddle LogicalResult OperationLegalizer::legalizePatternCreatedOperations( 2299b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, 2300b6eb26fdSRiver Riddle RewriterState &state, RewriterState &newState) { 23019ca70d72SMatthias Springer for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { 23029ca70d72SMatthias Springer auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get()); 23039ca70d72SMatthias Springer if (!createOp) 23049ca70d72SMatthias Springer continue; 23059ca70d72SMatthias Springer Operation *op = createOp->getOperation(); 2306b6eb26fdSRiver Riddle if (failed(legalize(op, rewriter))) { 2307b6eb26fdSRiver Riddle LLVM_DEBUG(logFailure(impl.logger, 2308e49c0e48SUday Bondhugula "failed to legalize generated operation '{0}'({1})", 2309b6eb26fdSRiver Riddle op->getName(), op)); 2310b6eb26fdSRiver Riddle return failure(); 2311b6eb26fdSRiver Riddle } 2312b6eb26fdSRiver Riddle } 2313b6eb26fdSRiver Riddle return success(); 2314b6eb26fdSRiver Riddle } 231501b55f16SRiver Riddle 2316b6eb26fdSRiver Riddle LogicalResult OperationLegalizer::legalizePatternRootUpdates( 2317b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, 2318b6eb26fdSRiver Riddle RewriterState &state, RewriterState &newState) { 2319e214f004SMatthias Springer for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { 2320e214f004SMatthias Springer auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get()); 2321e214f004SMatthias Springer if (!rewrite) 2322e214f004SMatthias Springer continue; 2323e214f004SMatthias Springer Operation *op = rewrite->getOperation(); 2324b6eb26fdSRiver Riddle if (failed(legalize(op, rewriter))) { 2325e49c0e48SUday Bondhugula LLVM_DEBUG(logFailure( 2326e49c0e48SUday Bondhugula impl.logger, "failed to legalize operation updated in-place '{0}'", 2327b6eb26fdSRiver Riddle op->getName())); 2328b6eb26fdSRiver Riddle return failure(); 2329b6eb26fdSRiver Riddle } 2330b6eb26fdSRiver Riddle } 2331b6eb26fdSRiver Riddle return success(); 2332b6eb26fdSRiver Riddle } 2333b6eb26fdSRiver Riddle 2334b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 2335b6eb26fdSRiver Riddle // Cost Model 2336b6eb26fdSRiver Riddle 2337b6eb26fdSRiver Riddle void OperationLegalizer::buildLegalizationGraph( 2338b6eb26fdSRiver Riddle LegalizationPatterns &anyOpLegalizerPatterns, 2339b6eb26fdSRiver Riddle DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { 2340b6eb26fdSRiver Riddle // A mapping between an operation and a set of operations that can be used to 2341b6eb26fdSRiver Riddle // generate it. 2342b6eb26fdSRiver Riddle DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps; 2343b6eb26fdSRiver Riddle // A mapping between an operation and any currently invalid patterns it has. 2344b6eb26fdSRiver Riddle DenseMap<OperationName, SmallPtrSet<const Pattern *, 2>> invalidPatterns; 2345b6eb26fdSRiver Riddle // A worklist of patterns to consider for legality. 23464efb7754SRiver Riddle SetVector<const Pattern *> patternWorklist; 2347b6eb26fdSRiver Riddle 2348b6eb26fdSRiver Riddle // Build the mapping from operations to the parent ops that may generate them. 2349b6eb26fdSRiver Riddle applicator.walkAllPatterns([&](const Pattern &pattern) { 2350bef481dfSFangrui Song std::optional<OperationName> root = pattern.getRootKind(); 2351b6eb26fdSRiver Riddle 2352b6eb26fdSRiver Riddle // If the pattern has no specific root, we can't analyze the relationship 2353b6eb26fdSRiver Riddle // between the root op and generated operations. Given that, add all such 2354b6eb26fdSRiver Riddle // patterns to the legalization set. 2355b6eb26fdSRiver Riddle if (!root) { 2356b6eb26fdSRiver Riddle anyOpLegalizerPatterns.push_back(&pattern); 2357b6eb26fdSRiver Riddle return; 2358b6eb26fdSRiver Riddle } 2359b6eb26fdSRiver Riddle 2360b6eb26fdSRiver Riddle // Skip operations that are always known to be legal. 2361b6eb26fdSRiver Riddle if (target.getOpAction(*root) == LegalizationAction::Legal) 2362b6eb26fdSRiver Riddle return; 2363b6eb26fdSRiver Riddle 2364b6eb26fdSRiver Riddle // Add this pattern to the invalid set for the root op and record this root 2365b6eb26fdSRiver Riddle // as a parent for any generated operations. 2366b6eb26fdSRiver Riddle invalidPatterns[*root].insert(&pattern); 2367b6eb26fdSRiver Riddle for (auto op : pattern.getGeneratedOps()) 2368b6eb26fdSRiver Riddle parentOps[op].insert(*root); 2369b6eb26fdSRiver Riddle 2370b6eb26fdSRiver Riddle // Add this pattern to the worklist. 2371b6eb26fdSRiver Riddle patternWorklist.insert(&pattern); 2372b6eb26fdSRiver Riddle }); 2373b6eb26fdSRiver Riddle 2374b6eb26fdSRiver Riddle // If there are any patterns that don't have a specific root kind, we can't 2375b6eb26fdSRiver Riddle // make direct assumptions about what operations will never be legalized. 2376b6eb26fdSRiver Riddle // Note: Technically we could, but it would require an analysis that may 2377b6eb26fdSRiver Riddle // recurse into itself. It would be better to perform this kind of filtering 2378b6eb26fdSRiver Riddle // at a higher level than here anyways. 2379b6eb26fdSRiver Riddle if (!anyOpLegalizerPatterns.empty()) { 2380b6eb26fdSRiver Riddle for (const Pattern *pattern : patternWorklist) 2381b6eb26fdSRiver Riddle legalizerPatterns[*pattern->getRootKind()].push_back(pattern); 2382b6eb26fdSRiver Riddle return; 2383b6eb26fdSRiver Riddle } 2384b6eb26fdSRiver Riddle 2385b6eb26fdSRiver Riddle while (!patternWorklist.empty()) { 2386b6eb26fdSRiver Riddle auto *pattern = patternWorklist.pop_back_val(); 2387b6eb26fdSRiver Riddle 2388b6eb26fdSRiver Riddle // Check to see if any of the generated operations are invalid. 2389b6eb26fdSRiver Riddle if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) { 23900de16fafSRamkumar Ramachandra std::optional<LegalizationAction> action = target.getOpAction(op); 2391b6eb26fdSRiver Riddle return !legalizerPatterns.count(op) && 2392b6eb26fdSRiver Riddle (!action || action == LegalizationAction::Illegal); 2393b6eb26fdSRiver Riddle })) 2394b6eb26fdSRiver Riddle continue; 2395b6eb26fdSRiver Riddle 2396b6eb26fdSRiver Riddle // Otherwise, if all of the generated operation are valid, this op is now 2397b6eb26fdSRiver Riddle // legal so add all of the child patterns to the worklist. 2398b6eb26fdSRiver Riddle legalizerPatterns[*pattern->getRootKind()].push_back(pattern); 2399b6eb26fdSRiver Riddle invalidPatterns[*pattern->getRootKind()].erase(pattern); 2400b6eb26fdSRiver Riddle 2401b6eb26fdSRiver Riddle // Add any invalid patterns of the parent operations to see if they have now 2402b6eb26fdSRiver Riddle // become legal. 2403b6eb26fdSRiver Riddle for (auto op : parentOps[*pattern->getRootKind()]) 2404b6eb26fdSRiver Riddle patternWorklist.set_union(invalidPatterns[op]); 2405b6eb26fdSRiver Riddle } 2406b6eb26fdSRiver Riddle } 2407b6eb26fdSRiver Riddle 2408b6eb26fdSRiver Riddle void OperationLegalizer::computeLegalizationGraphBenefit( 2409b6eb26fdSRiver Riddle LegalizationPatterns &anyOpLegalizerPatterns, 2410b6eb26fdSRiver Riddle DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { 2411b6eb26fdSRiver Riddle // The smallest pattern depth, when legalizing an operation. 2412b6eb26fdSRiver Riddle DenseMap<OperationName, unsigned> minOpPatternDepth; 2413b6eb26fdSRiver Riddle 2414b6eb26fdSRiver Riddle // For each operation that is transitively legal, compute a cost for it. 2415b6eb26fdSRiver Riddle for (auto &opIt : legalizerPatterns) 2416b6eb26fdSRiver Riddle if (!minOpPatternDepth.count(opIt.first)) 2417b6eb26fdSRiver Riddle computeOpLegalizationDepth(opIt.first, minOpPatternDepth, 2418b6eb26fdSRiver Riddle legalizerPatterns); 2419b6eb26fdSRiver Riddle 2420b6eb26fdSRiver Riddle // Apply the cost model to the patterns that can match any operation. Those 2421b6eb26fdSRiver Riddle // with a specific operation type are already resolved when computing the op 2422b6eb26fdSRiver Riddle // legalization depth. 2423b6eb26fdSRiver Riddle if (!anyOpLegalizerPatterns.empty()) 2424b6eb26fdSRiver Riddle applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth, 2425b6eb26fdSRiver Riddle legalizerPatterns); 2426b6eb26fdSRiver Riddle 2427b6eb26fdSRiver Riddle // Apply a cost model to the pattern applicator. We order patterns first by 2428b6eb26fdSRiver Riddle // depth then benefit. `legalizerPatterns` contains per-op patterns by 2429b6eb26fdSRiver Riddle // decreasing benefit. 2430b6eb26fdSRiver Riddle applicator.applyCostModel([&](const Pattern &pattern) { 2431b6eb26fdSRiver Riddle ArrayRef<const Pattern *> orderedPatternList; 2432bef481dfSFangrui Song if (std::optional<OperationName> rootName = pattern.getRootKind()) 2433b6eb26fdSRiver Riddle orderedPatternList = legalizerPatterns[*rootName]; 2434b6eb26fdSRiver Riddle else 2435b6eb26fdSRiver Riddle orderedPatternList = anyOpLegalizerPatterns; 2436b6eb26fdSRiver Riddle 2437b6eb26fdSRiver Riddle // If the pattern is not found, then it was removed and cannot be matched. 24380c29f45aSUday Bondhugula auto *it = llvm::find(orderedPatternList, &pattern); 2439b6eb26fdSRiver Riddle if (it == orderedPatternList.end()) 2440b6eb26fdSRiver Riddle return PatternBenefit::impossibleToMatch(); 2441b6eb26fdSRiver Riddle 2442b6eb26fdSRiver Riddle // Patterns found earlier in the list have higher benefit. 2443b6eb26fdSRiver Riddle return PatternBenefit(std::distance(it, orderedPatternList.end())); 2444b6eb26fdSRiver Riddle }); 2445b6eb26fdSRiver Riddle } 2446b6eb26fdSRiver Riddle 2447b6eb26fdSRiver Riddle unsigned OperationLegalizer::computeOpLegalizationDepth( 2448b6eb26fdSRiver Riddle OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth, 2449b6eb26fdSRiver Riddle DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { 2450b6eb26fdSRiver Riddle // Check for existing depth. 2451b6eb26fdSRiver Riddle auto depthIt = minOpPatternDepth.find(op); 2452b6eb26fdSRiver Riddle if (depthIt != minOpPatternDepth.end()) 2453b6eb26fdSRiver Riddle return depthIt->second; 2454b6eb26fdSRiver Riddle 2455b6eb26fdSRiver Riddle // If a mapping for this operation does not exist, then this operation 2456b6eb26fdSRiver Riddle // is always legal. Return 0 as the depth for a directly legal operation. 2457b6eb26fdSRiver Riddle auto opPatternsIt = legalizerPatterns.find(op); 2458b6eb26fdSRiver Riddle if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty()) 2459b6eb26fdSRiver Riddle return 0u; 2460b6eb26fdSRiver Riddle 2461b6eb26fdSRiver Riddle // Record this initial depth in case we encounter this op again when 2462b6eb26fdSRiver Riddle // recursively computing the depth. 2463b6eb26fdSRiver Riddle minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max()); 2464b6eb26fdSRiver Riddle 2465b6eb26fdSRiver Riddle // Apply the cost model to the operation patterns, and update the minimum 2466b6eb26fdSRiver Riddle // depth. 2467b6eb26fdSRiver Riddle unsigned minDepth = applyCostModelToPatterns( 2468b6eb26fdSRiver Riddle opPatternsIt->second, minOpPatternDepth, legalizerPatterns); 2469b6eb26fdSRiver Riddle minOpPatternDepth[op] = minDepth; 2470b6eb26fdSRiver Riddle return minDepth; 2471b6eb26fdSRiver Riddle } 2472b6eb26fdSRiver Riddle 2473b6eb26fdSRiver Riddle unsigned OperationLegalizer::applyCostModelToPatterns( 2474b6eb26fdSRiver Riddle LegalizationPatterns &patterns, 2475b6eb26fdSRiver Riddle DenseMap<OperationName, unsigned> &minOpPatternDepth, 2476b6eb26fdSRiver Riddle DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { 2477b6eb26fdSRiver Riddle unsigned minDepth = std::numeric_limits<unsigned>::max(); 2478b6eb26fdSRiver Riddle 2479b6eb26fdSRiver Riddle // Compute the depth for each pattern within the set. 2480b6eb26fdSRiver Riddle SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth; 2481b6eb26fdSRiver Riddle patternsByDepth.reserve(patterns.size()); 2482b6eb26fdSRiver Riddle for (const Pattern *pattern : patterns) { 2483015192c6SRiver Riddle unsigned depth = 1; 2484b6eb26fdSRiver Riddle for (auto generatedOp : pattern->getGeneratedOps()) { 2485b6eb26fdSRiver Riddle unsigned generatedOpDepth = computeOpLegalizationDepth( 2486b6eb26fdSRiver Riddle generatedOp, minOpPatternDepth, legalizerPatterns); 2487b6eb26fdSRiver Riddle depth = std::max(depth, generatedOpDepth + 1); 2488b6eb26fdSRiver Riddle } 2489b6eb26fdSRiver Riddle patternsByDepth.emplace_back(pattern, depth); 2490b6eb26fdSRiver Riddle 2491b6eb26fdSRiver Riddle // Update the minimum depth of the pattern list. 2492b6eb26fdSRiver Riddle minDepth = std::min(minDepth, depth); 2493b6eb26fdSRiver Riddle } 2494b6eb26fdSRiver Riddle 2495b6eb26fdSRiver Riddle // If the operation only has one legalization pattern, there is no need to 2496b6eb26fdSRiver Riddle // sort them. 2497b6eb26fdSRiver Riddle if (patternsByDepth.size() == 1) 2498b6eb26fdSRiver Riddle return minDepth; 2499b6eb26fdSRiver Riddle 2500b6eb26fdSRiver Riddle // Sort the patterns by those likely to be the most beneficial. 2501ee3c6de7SXiang Li std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(), 2502ee3c6de7SXiang Li [](const std::pair<const Pattern *, unsigned> &lhs, 2503ee3c6de7SXiang Li const std::pair<const Pattern *, unsigned> &rhs) { 2504b6eb26fdSRiver Riddle // First sort by the smaller pattern legalization 2505b6eb26fdSRiver Riddle // depth. 2506ee3c6de7SXiang Li if (lhs.second != rhs.second) 2507ee3c6de7SXiang Li return lhs.second < rhs.second; 2508b6eb26fdSRiver Riddle 2509b6eb26fdSRiver Riddle // Then sort by the larger pattern benefit. 2510ee3c6de7SXiang Li auto lhsBenefit = lhs.first->getBenefit(); 2511ee3c6de7SXiang Li auto rhsBenefit = rhs.first->getBenefit(); 2512ee3c6de7SXiang Li return lhsBenefit > rhsBenefit; 2513b6eb26fdSRiver Riddle }); 2514b6eb26fdSRiver Riddle 2515b6eb26fdSRiver Riddle // Update the legalization pattern to use the new sorted list. 2516b6eb26fdSRiver Riddle patterns.clear(); 2517b6eb26fdSRiver Riddle for (auto &patternIt : patternsByDepth) 2518b6eb26fdSRiver Riddle patterns.push_back(patternIt.first); 2519b6eb26fdSRiver Riddle return minDepth; 2520b6eb26fdSRiver Riddle } 2521b6eb26fdSRiver Riddle 2522b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 2523b6eb26fdSRiver Riddle // OperationConverter 2524b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 2525b6eb26fdSRiver Riddle namespace { 2526b6eb26fdSRiver Riddle enum OpConversionMode { 252701b55f16SRiver Riddle /// In this mode, the conversion will ignore failed conversions to allow 252801b55f16SRiver Riddle /// illegal operations to co-exist in the IR. 2529b6eb26fdSRiver Riddle Partial, 2530b6eb26fdSRiver Riddle 253101b55f16SRiver Riddle /// In this mode, all operations must be legal for the given target for the 253201b55f16SRiver Riddle /// conversion to succeed. 2533b6eb26fdSRiver Riddle Full, 2534b6eb26fdSRiver Riddle 253501b55f16SRiver Riddle /// In this mode, operations are analyzed for legality. No actual rewrites are 253601b55f16SRiver Riddle /// applied to the operations on success. 2537b6eb26fdSRiver Riddle Analysis, 2538b6eb26fdSRiver Riddle }; 2539a622b21fSMatthias Springer } // namespace 2540b6eb26fdSRiver Riddle 2541a622b21fSMatthias Springer namespace mlir { 2542b6eb26fdSRiver Riddle // This class converts operations to a given conversion target via a set of 2543b6eb26fdSRiver Riddle // rewrite patterns. The conversion behaves differently depending on the 2544b6eb26fdSRiver Riddle // conversion mode. 2545b6eb26fdSRiver Riddle struct OperationConverter { 2546370a6f09SMehdi Amini explicit OperationConverter(const ConversionTarget &target, 254779d7f618SChris Lattner const FrozenRewritePatternSet &patterns, 2548a2821094SMatthias Springer const ConversionConfig &config, 2549a2821094SMatthias Springer OpConversionMode mode) 25509b6bd709SMatthias Springer : config(config), opLegalizer(target, patterns, this->config), 25519b6bd709SMatthias Springer mode(mode) {} 2552b6eb26fdSRiver Riddle 2553b6eb26fdSRiver Riddle /// Converts the given operations to the conversion target. 2554a2821094SMatthias Springer LogicalResult convertOperations(ArrayRef<Operation *> ops); 2555b6eb26fdSRiver Riddle 2556b6eb26fdSRiver Riddle private: 2557b6eb26fdSRiver Riddle /// Converts an operation with the given rewriter. 2558b6eb26fdSRiver Riddle LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op); 2559b6eb26fdSRiver Riddle 2560a2821094SMatthias Springer /// Dialect conversion configuration. 2561a2821094SMatthias Springer ConversionConfig config; 2562a2821094SMatthias Springer 25639b6bd709SMatthias Springer /// The legalizer to use when converting operations. 25649b6bd709SMatthias Springer OperationLegalizer opLegalizer; 25659b6bd709SMatthias Springer 2566b6eb26fdSRiver Riddle /// The conversion mode to use when legalizing operations. 2567b6eb26fdSRiver Riddle OpConversionMode mode; 2568b6eb26fdSRiver Riddle }; 2569a622b21fSMatthias Springer } // namespace mlir 2570b6eb26fdSRiver Riddle 2571b6eb26fdSRiver Riddle LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, 2572b6eb26fdSRiver Riddle Operation *op) { 2573b6eb26fdSRiver Riddle // Legalize the given operation. 2574b6eb26fdSRiver Riddle if (failed(opLegalizer.legalize(op, rewriter))) { 2575b6eb26fdSRiver Riddle // Handle the case of a failed conversion for each of the different modes. 2576b6eb26fdSRiver Riddle // Full conversions expect all operations to be converted. 2577b6eb26fdSRiver Riddle if (mode == OpConversionMode::Full) 2578b6eb26fdSRiver Riddle return op->emitError() 2579b6eb26fdSRiver Riddle << "failed to legalize operation '" << op->getName() << "'"; 2580b6eb26fdSRiver Riddle // Partial conversions allow conversions to fail iff the operation was not 2581a2821094SMatthias Springer // explicitly marked as illegal. If the user provided a `unlegalizedOps` 2582a2821094SMatthias Springer // set, non-legalizable ops are added to that set. 2583b6eb26fdSRiver Riddle if (mode == OpConversionMode::Partial) { 2584b6eb26fdSRiver Riddle if (opLegalizer.isIllegal(op)) 2585b6eb26fdSRiver Riddle return op->emitError() 2586b6eb26fdSRiver Riddle << "failed to legalize operation '" << op->getName() 2587b6eb26fdSRiver Riddle << "' that was explicitly marked illegal"; 2588a2821094SMatthias Springer if (config.unlegalizedOps) 2589a2821094SMatthias Springer config.unlegalizedOps->insert(op); 2590b6eb26fdSRiver Riddle } 2591b6eb26fdSRiver Riddle } else if (mode == OpConversionMode::Analysis) { 2592b6eb26fdSRiver Riddle // Analysis conversions don't fail if any operations fail to legalize, 2593b6eb26fdSRiver Riddle // they are only interested in the operations that were successfully 2594b6eb26fdSRiver Riddle // legalized. 2595a2821094SMatthias Springer if (config.legalizableOps) 2596a2821094SMatthias Springer config.legalizableOps->insert(op); 2597b6eb26fdSRiver Riddle } 2598b6eb26fdSRiver Riddle return success(); 2599b6eb26fdSRiver Riddle } 2600b6eb26fdSRiver Riddle 26013815f478SMatthias Springer static LogicalResult 26023815f478SMatthias Springer legalizeUnresolvedMaterialization(RewriterBase &rewriter, 26033815f478SMatthias Springer UnresolvedMaterializationRewrite *rewrite) { 26043815f478SMatthias Springer UnrealizedConversionCastOp op = rewrite->getOperation(); 26053815f478SMatthias Springer assert(!op.use_empty() && 26063815f478SMatthias Springer "expected that dead materializations have already been DCE'd"); 26073815f478SMatthias Springer Operation::operand_range inputOperands = op.getOperands(); 26083815f478SMatthias Springer 26093815f478SMatthias Springer // Try to materialize the conversion. 26103815f478SMatthias Springer if (const TypeConverter *converter = rewrite->getConverter()) { 26113815f478SMatthias Springer rewriter.setInsertionPoint(op); 26129df63b26SMatthias Springer SmallVector<Value> newMaterialization; 26133815f478SMatthias Springer switch (rewrite->getMaterializationKind()) { 26143815f478SMatthias Springer case MaterializationKind::Target: 26153815f478SMatthias Springer newMaterialization = converter->materializeTargetConversion( 26169df63b26SMatthias Springer rewriter, op->getLoc(), op.getResultTypes(), inputOperands, 26170d906a42SMatthias Springer rewrite->getOriginalType()); 26183815f478SMatthias Springer break; 26193815f478SMatthias Springer case MaterializationKind::Source: 26209df63b26SMatthias Springer assert(op->getNumResults() == 1 && "expected single result"); 26219df63b26SMatthias Springer Value sourceMat = converter->materializeSourceConversion( 26229df63b26SMatthias Springer rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands); 26239df63b26SMatthias Springer if (sourceMat) 26249df63b26SMatthias Springer newMaterialization.push_back(sourceMat); 26253815f478SMatthias Springer break; 26263815f478SMatthias Springer } 26279df63b26SMatthias Springer if (!newMaterialization.empty()) { 262858389b22SMatthias Springer #ifndef NDEBUG 262958389b22SMatthias Springer ValueRange newMaterializationRange(newMaterialization); 263058389b22SMatthias Springer assert(TypeRange(newMaterializationRange) == op.getResultTypes() && 26313815f478SMatthias Springer "materialization callback produced value of incorrect type"); 263258389b22SMatthias Springer #endif // NDEBUG 26333815f478SMatthias Springer rewriter.replaceOp(op, newMaterialization); 26343815f478SMatthias Springer return success(); 26353815f478SMatthias Springer } 26363815f478SMatthias Springer } 26373815f478SMatthias Springer 26389df63b26SMatthias Springer InFlightDiagnostic diag = op->emitError() 26399df63b26SMatthias Springer << "failed to legalize unresolved materialization " 26403815f478SMatthias Springer "from (" 26419df63b26SMatthias Springer << inputOperands.getTypes() << ") to (" 26429df63b26SMatthias Springer << op.getResultTypes() 2643ea050ab1SMatthias Springer << ") that remained live after conversion"; 26443815f478SMatthias Springer diag.attachNote(op->getUsers().begin()->getLoc()) 26453815f478SMatthias Springer << "see existing live user here: " << *op->getUsers().begin(); 26463815f478SMatthias Springer return failure(); 26473815f478SMatthias Springer } 26483815f478SMatthias Springer 2649a2821094SMatthias Springer LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) { 2650b6eb26fdSRiver Riddle if (ops.empty()) 2651b6eb26fdSRiver Riddle return success(); 2652370a6f09SMehdi Amini const ConversionTarget &target = opLegalizer.getTarget(); 2653b6eb26fdSRiver Riddle 2654b6eb26fdSRiver Riddle // Compute the set of operations and blocks to convert. 2655015192c6SRiver Riddle SmallVector<Operation *> toConvert; 2656b6eb26fdSRiver Riddle for (auto *op : ops) { 2657b884f4efSMatthias Springer op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>( 2658b884f4efSMatthias Springer [&](Operation *op) { 2659b884f4efSMatthias Springer toConvert.push_back(op); 2660b884f4efSMatthias Springer // Don't check this operation's children for conversion if the 2661b884f4efSMatthias Springer // operation is recursively legal. 2662b884f4efSMatthias Springer auto legalityInfo = target.isLegal(op); 2663b884f4efSMatthias Springer if (legalityInfo && legalityInfo->isRecursivelyLegal) 2664b884f4efSMatthias Springer return WalkResult::skip(); 2665b884f4efSMatthias Springer return WalkResult::advance(); 2666b884f4efSMatthias Springer }); 2667b6eb26fdSRiver Riddle } 2668b6eb26fdSRiver Riddle 2669b6eb26fdSRiver Riddle // Convert each operation and discard rewrites on failure. 2670a2821094SMatthias Springer ConversionPatternRewriter rewriter(ops.front()->getContext(), config); 2671b6eb26fdSRiver Riddle ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); 2672b8c6b152SChia-hung Duan 2673b6eb26fdSRiver Riddle for (auto *op : toConvert) 2674b6eb26fdSRiver Riddle if (failed(convert(rewriter, op))) 267559ff4d13SMatthias Springer return rewriterImpl.undoRewrites(), failure(); 2676b6eb26fdSRiver Riddle 26775030deadSMatthias Springer // After a successful conversion, apply rewrites. 2678b6eb26fdSRiver Riddle rewriterImpl.applyRewrites(); 26793815f478SMatthias Springer 26803815f478SMatthias Springer // Gather all unresolved materializations. 26813815f478SMatthias Springer SmallVector<UnrealizedConversionCastOp> allCastOps; 2682d588e49aSMatthias Springer const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *> 2683d588e49aSMatthias Springer &materializations = rewriterImpl.unresolvedMaterializations; 2684d588e49aSMatthias Springer for (auto it : materializations) { 2685d588e49aSMatthias Springer if (rewriterImpl.eraseRewriter.wasErased(it.first)) 26863815f478SMatthias Springer continue; 2687d588e49aSMatthias Springer allCastOps.push_back(it.first); 26883815f478SMatthias Springer } 26893815f478SMatthias Springer 26903815f478SMatthias Springer // Reconcile all UnrealizedConversionCastOps that were inserted by the 26913815f478SMatthias Springer // dialect conversion frameworks. (Not the one that were inserted by 26923815f478SMatthias Springer // patterns.) 26936093c26aSMatthias Springer SmallVector<UnrealizedConversionCastOp> remainingCastOps; 26946093c26aSMatthias Springer reconcileUnrealizedCasts(allCastOps, &remainingCastOps); 26953815f478SMatthias Springer 26963815f478SMatthias Springer // Try to legalize all unresolved materializations. 26973815f478SMatthias Springer if (config.buildMaterializations) { 26983815f478SMatthias Springer IRRewriter rewriter(rewriterImpl.context, config.listener); 26996093c26aSMatthias Springer for (UnrealizedConversionCastOp castOp : remainingCastOps) { 2700d588e49aSMatthias Springer auto it = materializations.find(castOp); 2701d588e49aSMatthias Springer assert(it != materializations.end() && "inconsistent state"); 27023815f478SMatthias Springer if (failed(legalizeUnresolvedMaterialization(rewriter, it->second))) 27033815f478SMatthias Springer return failure(); 27043815f478SMatthias Springer } 27053815f478SMatthias Springer } 27063815f478SMatthias Springer 2707b6eb26fdSRiver Riddle return success(); 2708b6eb26fdSRiver Riddle } 2709b6eb26fdSRiver Riddle 2710b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 2711a9f62244SMatthias Springer // Reconcile Unrealized Casts 2712a9f62244SMatthias Springer //===----------------------------------------------------------------------===// 2713a9f62244SMatthias Springer 2714a9f62244SMatthias Springer void mlir::reconcileUnrealizedCasts( 2715a9f62244SMatthias Springer ArrayRef<UnrealizedConversionCastOp> castOps, 2716a9f62244SMatthias Springer SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) { 2717a9f62244SMatthias Springer SetVector<UnrealizedConversionCastOp> worklist(castOps.begin(), 2718a9f62244SMatthias Springer castOps.end()); 2719a9f62244SMatthias Springer // This set is maintained only if `remainingCastOps` is provided. 2720a9f62244SMatthias Springer DenseSet<Operation *> erasedOps; 2721a9f62244SMatthias Springer 2722a9f62244SMatthias Springer // Helper function that adds all operands to the worklist that are an 2723a9f62244SMatthias Springer // unrealized_conversion_cast op result. 2724a9f62244SMatthias Springer auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) { 2725a9f62244SMatthias Springer for (Value v : castOp.getInputs()) 2726a9f62244SMatthias Springer if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>()) 2727a9f62244SMatthias Springer worklist.insert(inputCastOp); 2728a9f62244SMatthias Springer }; 2729a9f62244SMatthias Springer 2730a9f62244SMatthias Springer // Helper function that return the unrealized_conversion_cast op that 2731a9f62244SMatthias Springer // defines all inputs of the given op (in the same order). Return "nullptr" 2732a9f62244SMatthias Springer // if there is no such op. 2733a9f62244SMatthias Springer auto getInputCast = 2734a9f62244SMatthias Springer [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp { 2735a9f62244SMatthias Springer if (castOp.getInputs().empty()) 2736a9f62244SMatthias Springer return {}; 2737a9f62244SMatthias Springer auto inputCastOp = 2738a9f62244SMatthias Springer castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>(); 2739a9f62244SMatthias Springer if (!inputCastOp) 2740a9f62244SMatthias Springer return {}; 2741a9f62244SMatthias Springer if (inputCastOp.getOutputs() != castOp.getInputs()) 2742a9f62244SMatthias Springer return {}; 2743a9f62244SMatthias Springer return inputCastOp; 2744a9f62244SMatthias Springer }; 2745a9f62244SMatthias Springer 2746a9f62244SMatthias Springer // Process ops in the worklist bottom-to-top. 2747a9f62244SMatthias Springer while (!worklist.empty()) { 2748a9f62244SMatthias Springer UnrealizedConversionCastOp castOp = worklist.pop_back_val(); 2749a9f62244SMatthias Springer if (castOp->use_empty()) { 2750a9f62244SMatthias Springer // DCE: If the op has no users, erase it. Add the operands to the 2751a9f62244SMatthias Springer // worklist to find additional DCE opportunities. 2752a9f62244SMatthias Springer enqueueOperands(castOp); 2753a9f62244SMatthias Springer if (remainingCastOps) 2754a9f62244SMatthias Springer erasedOps.insert(castOp.getOperation()); 2755a9f62244SMatthias Springer castOp->erase(); 2756a9f62244SMatthias Springer continue; 2757a9f62244SMatthias Springer } 2758a9f62244SMatthias Springer 2759a9f62244SMatthias Springer // Traverse the chain of input cast ops to see if an op with the same 2760a9f62244SMatthias Springer // input types can be found. 2761a9f62244SMatthias Springer UnrealizedConversionCastOp nextCast = castOp; 2762a9f62244SMatthias Springer while (nextCast) { 2763a9f62244SMatthias Springer if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) { 2764a9f62244SMatthias Springer // Found a cast where the input types match the output types of the 2765a9f62244SMatthias Springer // matched op. We can directly use those inputs and the matched op can 2766a9f62244SMatthias Springer // be removed. 2767a9f62244SMatthias Springer enqueueOperands(castOp); 2768a9f62244SMatthias Springer castOp.replaceAllUsesWith(nextCast.getInputs()); 2769a9f62244SMatthias Springer if (remainingCastOps) 2770a9f62244SMatthias Springer erasedOps.insert(castOp.getOperation()); 2771a9f62244SMatthias Springer castOp->erase(); 2772a9f62244SMatthias Springer break; 2773a9f62244SMatthias Springer } 2774a9f62244SMatthias Springer nextCast = getInputCast(nextCast); 2775a9f62244SMatthias Springer } 2776a9f62244SMatthias Springer } 2777a9f62244SMatthias Springer 2778a9f62244SMatthias Springer if (remainingCastOps) 2779a9f62244SMatthias Springer for (UnrealizedConversionCastOp op : castOps) 2780a9f62244SMatthias Springer if (!erasedOps.contains(op.getOperation())) 2781a9f62244SMatthias Springer remainingCastOps->push_back(op); 2782a9f62244SMatthias Springer } 2783a9f62244SMatthias Springer 2784a9f62244SMatthias Springer //===----------------------------------------------------------------------===// 2785b6eb26fdSRiver Riddle // Type Conversion 2786b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 2787b6eb26fdSRiver Riddle 2788b6eb26fdSRiver Riddle void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo, 2789b6eb26fdSRiver Riddle ArrayRef<Type> types) { 2790b6eb26fdSRiver Riddle assert(!types.empty() && "expected valid types"); 2791b6eb26fdSRiver Riddle remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size()); 2792b6eb26fdSRiver Riddle addInputs(types); 2793b6eb26fdSRiver Riddle } 2794b6eb26fdSRiver Riddle 2795b6eb26fdSRiver Riddle void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) { 2796b6eb26fdSRiver Riddle assert(!types.empty() && 2797b6eb26fdSRiver Riddle "1->0 type remappings don't need to be added explicitly"); 2798b6eb26fdSRiver Riddle argTypes.append(types.begin(), types.end()); 2799b6eb26fdSRiver Riddle } 2800b6eb26fdSRiver Riddle 2801b6eb26fdSRiver Riddle void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, 2802b6eb26fdSRiver Riddle unsigned newInputNo, 2803b6eb26fdSRiver Riddle unsigned newInputCount) { 2804b6eb26fdSRiver Riddle assert(!remappedInputs[origInputNo] && "input has already been remapped"); 2805b6eb26fdSRiver Riddle assert(newInputCount != 0 && "expected valid input count"); 2806b6eb26fdSRiver Riddle remappedInputs[origInputNo] = 2807b6eb26fdSRiver Riddle InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr}; 2808b6eb26fdSRiver Riddle } 2809b6eb26fdSRiver Riddle 2810b6eb26fdSRiver Riddle void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, 2811b6eb26fdSRiver Riddle Value replacementValue) { 2812b6eb26fdSRiver Riddle assert(!remappedInputs[origInputNo] && "input has already been remapped"); 2813b6eb26fdSRiver Riddle remappedInputs[origInputNo] = 2814b6eb26fdSRiver Riddle InputMapping{origInputNo, /*size=*/0, replacementValue}; 2815b6eb26fdSRiver Riddle } 2816b6eb26fdSRiver Riddle 2817b6eb26fdSRiver Riddle LogicalResult TypeConverter::convertType(Type t, 28183dd58333SMatthias Springer SmallVectorImpl<Type> &results) const { 28193cc311abSMatthias Springer assert(t && "expected non-null type"); 28203cc311abSMatthias Springer 2821a8daefedSMehdi Amini { 2822a8daefedSMehdi Amini std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex, 2823a8daefedSMehdi Amini std::defer_lock); 2824a8daefedSMehdi Amini if (t.getContext()->isMultithreadingEnabled()) 2825a8daefedSMehdi Amini cacheReadLock.lock(); 2826b6eb26fdSRiver Riddle auto existingIt = cachedDirectConversions.find(t); 2827b6eb26fdSRiver Riddle if (existingIt != cachedDirectConversions.end()) { 2828b6eb26fdSRiver Riddle if (existingIt->second) 2829b6eb26fdSRiver Riddle results.push_back(existingIt->second); 2830b6eb26fdSRiver Riddle return success(existingIt->second != nullptr); 2831b6eb26fdSRiver Riddle } 2832b6eb26fdSRiver Riddle auto multiIt = cachedMultiConversions.find(t); 2833b6eb26fdSRiver Riddle if (multiIt != cachedMultiConversions.end()) { 2834b6eb26fdSRiver Riddle results.append(multiIt->second.begin(), multiIt->second.end()); 2835b6eb26fdSRiver Riddle return success(); 2836b6eb26fdSRiver Riddle } 2837a8daefedSMehdi Amini } 2838b6eb26fdSRiver Riddle // Walk the added converters in reverse order to apply the most recently 2839b6eb26fdSRiver Riddle // registered first. 2840b6eb26fdSRiver Riddle size_t currentCount = results.size(); 2841dc3dc974SMehdi Amini 2842a8daefedSMehdi Amini std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex, 2843a8daefedSMehdi Amini std::defer_lock); 2844a8daefedSMehdi Amini 28453dd58333SMatthias Springer for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) { 2846dc3dc974SMehdi Amini if (std::optional<LogicalResult> result = converter(t, results)) { 2847a8daefedSMehdi Amini if (t.getContext()->isMultithreadingEnabled()) 2848a8daefedSMehdi Amini cacheWriteLock.lock(); 2849b6eb26fdSRiver Riddle if (!succeeded(*result)) { 2850b6eb26fdSRiver Riddle cachedDirectConversions.try_emplace(t, nullptr); 2851b6eb26fdSRiver Riddle return failure(); 2852b6eb26fdSRiver Riddle } 2853b6eb26fdSRiver Riddle auto newTypes = ArrayRef<Type>(results).drop_front(currentCount); 2854b6eb26fdSRiver Riddle if (newTypes.size() == 1) 2855b6eb26fdSRiver Riddle cachedDirectConversions.try_emplace(t, newTypes.front()); 2856b6eb26fdSRiver Riddle else 2857b6eb26fdSRiver Riddle cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes)); 2858b6eb26fdSRiver Riddle return success(); 2859b6eb26fdSRiver Riddle } 2860b6eb26fdSRiver Riddle } 2861b6eb26fdSRiver Riddle return failure(); 2862b6eb26fdSRiver Riddle } 2863b6eb26fdSRiver Riddle 28643dd58333SMatthias Springer Type TypeConverter::convertType(Type t) const { 2865b6eb26fdSRiver Riddle // Use the multi-type result version to convert the type. 2866b6eb26fdSRiver Riddle SmallVector<Type, 1> results; 2867b6eb26fdSRiver Riddle if (failed(convertType(t, results))) 2868b6eb26fdSRiver Riddle return nullptr; 2869b6eb26fdSRiver Riddle 2870b6eb26fdSRiver Riddle // Check to ensure that only one type was produced. 2871b6eb26fdSRiver Riddle return results.size() == 1 ? results.front() : nullptr; 2872b6eb26fdSRiver Riddle } 2873b6eb26fdSRiver Riddle 28743dd58333SMatthias Springer LogicalResult 28753dd58333SMatthias Springer TypeConverter::convertTypes(TypeRange types, 28763dd58333SMatthias Springer SmallVectorImpl<Type> &results) const { 28773dfa8614SRiver Riddle for (Type type : types) 2878b6eb26fdSRiver Riddle if (failed(convertType(type, results))) 2879b6eb26fdSRiver Riddle return failure(); 2880b6eb26fdSRiver Riddle return success(); 2881b6eb26fdSRiver Riddle } 2882b6eb26fdSRiver Riddle 28833dd58333SMatthias Springer bool TypeConverter::isLegal(Type type) const { 28843dd58333SMatthias Springer return convertType(type) == type; 28853dd58333SMatthias Springer } 28863dd58333SMatthias Springer bool TypeConverter::isLegal(Operation *op) const { 2887b6eb26fdSRiver Riddle return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes()); 2888b6eb26fdSRiver Riddle } 2889b6eb26fdSRiver Riddle 28903dd58333SMatthias Springer bool TypeConverter::isLegal(Region *region) const { 2891b6eb26fdSRiver Riddle return llvm::all_of(*region, [this](Block &block) { 2892b6eb26fdSRiver Riddle return isLegal(block.getArgumentTypes()); 2893b6eb26fdSRiver Riddle }); 2894b6eb26fdSRiver Riddle } 2895b6eb26fdSRiver Riddle 28963dd58333SMatthias Springer bool TypeConverter::isSignatureLegal(FunctionType ty) const { 2897b6eb26fdSRiver Riddle return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults())); 2898b6eb26fdSRiver Riddle } 2899b6eb26fdSRiver Riddle 29003dd58333SMatthias Springer LogicalResult 29013dd58333SMatthias Springer TypeConverter::convertSignatureArg(unsigned inputNo, Type type, 29023dd58333SMatthias Springer SignatureConversion &result) const { 2903b6eb26fdSRiver Riddle // Try to convert the given input type. 2904b6eb26fdSRiver Riddle SmallVector<Type, 1> convertedTypes; 2905b6eb26fdSRiver Riddle if (failed(convertType(type, convertedTypes))) 2906b6eb26fdSRiver Riddle return failure(); 2907b6eb26fdSRiver Riddle 2908b6eb26fdSRiver Riddle // If this argument is being dropped, there is nothing left to do. 2909b6eb26fdSRiver Riddle if (convertedTypes.empty()) 2910b6eb26fdSRiver Riddle return success(); 2911b6eb26fdSRiver Riddle 2912b6eb26fdSRiver Riddle // Otherwise, add the new inputs. 2913b6eb26fdSRiver Riddle result.addInputs(inputNo, convertedTypes); 2914b6eb26fdSRiver Riddle return success(); 2915b6eb26fdSRiver Riddle } 29163dd58333SMatthias Springer LogicalResult 29173dd58333SMatthias Springer TypeConverter::convertSignatureArgs(TypeRange types, 2918b6eb26fdSRiver Riddle SignatureConversion &result, 29193dd58333SMatthias Springer unsigned origInputOffset) const { 2920b6eb26fdSRiver Riddle for (unsigned i = 0, e = types.size(); i != e; ++i) 2921b6eb26fdSRiver Riddle if (failed(convertSignatureArg(origInputOffset + i, types[i], result))) 2922b6eb26fdSRiver Riddle return failure(); 2923b6eb26fdSRiver Riddle return success(); 2924b6eb26fdSRiver Riddle } 2925b6eb26fdSRiver Riddle 29260d906a42SMatthias Springer Value TypeConverter::materializeArgumentConversion(OpBuilder &builder, 29270d906a42SMatthias Springer Location loc, 29280d906a42SMatthias Springer Type resultType, 29290d906a42SMatthias Springer ValueRange inputs) const { 29300d906a42SMatthias Springer for (const MaterializationCallbackFn &fn : 29310d906a42SMatthias Springer llvm::reverse(argumentMaterializations)) 2932f18c3e4eSMatthias Springer if (Value result = fn(builder, resultType, inputs, loc)) 2933f18c3e4eSMatthias Springer return result; 2934b6eb26fdSRiver Riddle return nullptr; 2935b6eb26fdSRiver Riddle } 2936b6eb26fdSRiver Riddle 29370d906a42SMatthias Springer Value TypeConverter::materializeSourceConversion(OpBuilder &builder, 29380d906a42SMatthias Springer Location loc, Type resultType, 29390d906a42SMatthias Springer ValueRange inputs) const { 29400d906a42SMatthias Springer for (const MaterializationCallbackFn &fn : 29410d906a42SMatthias Springer llvm::reverse(sourceMaterializations)) 2942f18c3e4eSMatthias Springer if (Value result = fn(builder, resultType, inputs, loc)) 2943f18c3e4eSMatthias Springer return result; 29440d906a42SMatthias Springer return nullptr; 29450d906a42SMatthias Springer } 29460d906a42SMatthias Springer 29470d906a42SMatthias Springer Value TypeConverter::materializeTargetConversion(OpBuilder &builder, 29480d906a42SMatthias Springer Location loc, Type resultType, 29490d906a42SMatthias Springer ValueRange inputs, 29500d906a42SMatthias Springer Type originalType) const { 29518c4bc1e7SMatthias Springer SmallVector<Value> result = materializeTargetConversion( 29528c4bc1e7SMatthias Springer builder, loc, TypeRange(resultType), inputs, originalType); 29538c4bc1e7SMatthias Springer if (result.empty()) 29540d906a42SMatthias Springer return nullptr; 29558c4bc1e7SMatthias Springer assert(result.size() == 1 && "expected single result"); 29568c4bc1e7SMatthias Springer return result.front(); 29578c4bc1e7SMatthias Springer } 29588c4bc1e7SMatthias Springer 29598c4bc1e7SMatthias Springer SmallVector<Value> TypeConverter::materializeTargetConversion( 29608c4bc1e7SMatthias Springer OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs, 29618c4bc1e7SMatthias Springer Type originalType) const { 29628c4bc1e7SMatthias Springer for (const TargetMaterializationCallbackFn &fn : 29638c4bc1e7SMatthias Springer llvm::reverse(targetMaterializations)) { 29648c4bc1e7SMatthias Springer SmallVector<Value> result = 29658c4bc1e7SMatthias Springer fn(builder, resultTypes, inputs, loc, originalType); 29668c4bc1e7SMatthias Springer if (result.empty()) 29678c4bc1e7SMatthias Springer continue; 2968a8ef0b33SMatthias Springer assert(TypeRange(ValueRange(result)) == resultTypes && 29698c4bc1e7SMatthias Springer "callback produced incorrect number of values or values with " 29708c4bc1e7SMatthias Springer "incorrect types"); 29718c4bc1e7SMatthias Springer return result; 29728c4bc1e7SMatthias Springer } 29738c4bc1e7SMatthias Springer return {}; 29740d906a42SMatthias Springer } 29750d906a42SMatthias Springer 29763dd58333SMatthias Springer std::optional<TypeConverter::SignatureConversion> 29773dd58333SMatthias Springer TypeConverter::convertBlockSignature(Block *block) const { 2978b6eb26fdSRiver Riddle SignatureConversion conversion(block->getNumArguments()); 2979b6eb26fdSRiver Riddle if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion))) 29801a36588eSKazu Hirata return std::nullopt; 2981b6eb26fdSRiver Riddle return conversion; 2982b6eb26fdSRiver Riddle } 2983b6eb26fdSRiver Riddle 298401b55f16SRiver Riddle //===----------------------------------------------------------------------===// 2985499abb24SKrzysztof Drewniak // Type attribute conversion 2986499abb24SKrzysztof Drewniak //===----------------------------------------------------------------------===// 2987499abb24SKrzysztof Drewniak TypeConverter::AttributeConversionResult 2988499abb24SKrzysztof Drewniak TypeConverter::AttributeConversionResult::result(Attribute attr) { 2989499abb24SKrzysztof Drewniak return AttributeConversionResult(attr, resultTag); 2990499abb24SKrzysztof Drewniak } 2991499abb24SKrzysztof Drewniak 2992499abb24SKrzysztof Drewniak TypeConverter::AttributeConversionResult 2993499abb24SKrzysztof Drewniak TypeConverter::AttributeConversionResult::na() { 2994499abb24SKrzysztof Drewniak return AttributeConversionResult(nullptr, naTag); 2995499abb24SKrzysztof Drewniak } 2996499abb24SKrzysztof Drewniak 2997499abb24SKrzysztof Drewniak TypeConverter::AttributeConversionResult 2998499abb24SKrzysztof Drewniak TypeConverter::AttributeConversionResult::abort() { 2999499abb24SKrzysztof Drewniak return AttributeConversionResult(nullptr, abortTag); 3000499abb24SKrzysztof Drewniak } 3001499abb24SKrzysztof Drewniak 3002499abb24SKrzysztof Drewniak bool TypeConverter::AttributeConversionResult::hasResult() const { 3003499abb24SKrzysztof Drewniak return impl.getInt() == resultTag; 3004499abb24SKrzysztof Drewniak } 3005499abb24SKrzysztof Drewniak 3006499abb24SKrzysztof Drewniak bool TypeConverter::AttributeConversionResult::isNa() const { 3007499abb24SKrzysztof Drewniak return impl.getInt() == naTag; 3008499abb24SKrzysztof Drewniak } 3009499abb24SKrzysztof Drewniak 3010499abb24SKrzysztof Drewniak bool TypeConverter::AttributeConversionResult::isAbort() const { 3011499abb24SKrzysztof Drewniak return impl.getInt() == abortTag; 3012499abb24SKrzysztof Drewniak } 3013499abb24SKrzysztof Drewniak 3014499abb24SKrzysztof Drewniak Attribute TypeConverter::AttributeConversionResult::getResult() const { 3015499abb24SKrzysztof Drewniak assert(hasResult() && "Cannot get result from N/A or abort"); 3016499abb24SKrzysztof Drewniak return impl.getPointer(); 3017499abb24SKrzysztof Drewniak } 3018499abb24SKrzysztof Drewniak 30193dd58333SMatthias Springer std::optional<Attribute> 30203dd58333SMatthias Springer TypeConverter::convertTypeAttribute(Type type, Attribute attr) const { 30213dd58333SMatthias Springer for (const TypeAttributeConversionCallbackFn &fn : 3022499abb24SKrzysztof Drewniak llvm::reverse(typeAttributeConversions)) { 3023499abb24SKrzysztof Drewniak AttributeConversionResult res = fn(type, attr); 3024499abb24SKrzysztof Drewniak if (res.hasResult()) 3025499abb24SKrzysztof Drewniak return res.getResult(); 3026499abb24SKrzysztof Drewniak if (res.isAbort()) 3027499abb24SKrzysztof Drewniak return std::nullopt; 3028499abb24SKrzysztof Drewniak } 3029499abb24SKrzysztof Drewniak return std::nullopt; 3030499abb24SKrzysztof Drewniak } 3031499abb24SKrzysztof Drewniak 3032499abb24SKrzysztof Drewniak //===----------------------------------------------------------------------===// 30337ceffae1SRiver Riddle // FunctionOpInterfaceSignatureConversion 303401b55f16SRiver Riddle //===----------------------------------------------------------------------===// 303501b55f16SRiver Riddle 3036ed4749f9SIvan Butygin static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, 3037ce254598SMatthias Springer const TypeConverter &typeConverter, 3038ed4749f9SIvan Butygin ConversionPatternRewriter &rewriter) { 3039e5f8cdd6SKai Sasaki FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType()); 3040e5f8cdd6SKai Sasaki if (!type) 3041e5f8cdd6SKai Sasaki return failure(); 3042ed4749f9SIvan Butygin 3043ed4749f9SIvan Butygin // Convert the original function types. 3044ed4749f9SIvan Butygin TypeConverter::SignatureConversion result(type.getNumInputs()); 3045ed4749f9SIvan Butygin SmallVector<Type, 1> newResults; 3046ed4749f9SIvan Butygin if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) || 3047ed4749f9SIvan Butygin failed(typeConverter.convertTypes(type.getResults(), newResults)) || 3048ed4749f9SIvan Butygin failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(), 3049ed4749f9SIvan Butygin typeConverter, &result))) 3050ed4749f9SIvan Butygin return failure(); 3051ed4749f9SIvan Butygin 3052ed4749f9SIvan Butygin // Update the function signature in-place. 3053ed4749f9SIvan Butygin auto newType = FunctionType::get(rewriter.getContext(), 3054ed4749f9SIvan Butygin result.getConvertedTypes(), newResults); 3055ed4749f9SIvan Butygin 30565fcf907bSMatthias Springer rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); }); 3057ed4749f9SIvan Butygin 3058ed4749f9SIvan Butygin return success(); 3059ed4749f9SIvan Butygin } 3060ed4749f9SIvan Butygin 3061b6eb26fdSRiver Riddle /// Create a default conversion pattern that rewrites the type signature of a 30627ceffae1SRiver Riddle /// FunctionOpInterface op. This only supports ops which use FunctionType to 30637ceffae1SRiver Riddle /// represent their type. 3064b6eb26fdSRiver Riddle namespace { 30657ceffae1SRiver Riddle struct FunctionOpInterfaceSignatureConversion : public ConversionPattern { 30667ceffae1SRiver Riddle FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName, 30677ceffae1SRiver Riddle MLIRContext *ctx, 3068ce254598SMatthias Springer const TypeConverter &converter) 306976f3c2f3SRiver Riddle : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {} 3070b6eb26fdSRiver Riddle 3071b6eb26fdSRiver Riddle LogicalResult 3072ed4749f9SIvan Butygin matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/, 3073b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter) const override { 30747ceffae1SRiver Riddle FunctionOpInterface funcOp = cast<FunctionOpInterface>(op); 3075ed4749f9SIvan Butygin return convertFuncOpTypes(funcOp, *typeConverter, rewriter); 3076ed4749f9SIvan Butygin } 3077ed4749f9SIvan Butygin }; 3078b6eb26fdSRiver Riddle 3079ed4749f9SIvan Butygin struct AnyFunctionOpInterfaceSignatureConversion 3080ed4749f9SIvan Butygin : public OpInterfaceConversionPattern<FunctionOpInterface> { 3081ed4749f9SIvan Butygin using OpInterfaceConversionPattern::OpInterfaceConversionPattern; 3082b6eb26fdSRiver Riddle 3083ed4749f9SIvan Butygin LogicalResult 3084ed4749f9SIvan Butygin matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/, 3085ed4749f9SIvan Butygin ConversionPatternRewriter &rewriter) const override { 3086ed4749f9SIvan Butygin return convertFuncOpTypes(funcOp, *typeConverter, rewriter); 3087b6eb26fdSRiver Riddle } 3088b6eb26fdSRiver Riddle }; 3089be0a7e9fSMehdi Amini } // namespace 3090b6eb26fdSRiver Riddle 309135ef3994SIvan Butygin FailureOr<Operation *> 309235ef3994SIvan Butygin mlir::convertOpResultTypes(Operation *op, ValueRange operands, 309335ef3994SIvan Butygin const TypeConverter &converter, 309435ef3994SIvan Butygin ConversionPatternRewriter &rewriter) { 309535ef3994SIvan Butygin assert(op && "Invalid op"); 309635ef3994SIvan Butygin Location loc = op->getLoc(); 309735ef3994SIvan Butygin if (converter.isLegal(op)) 309835ef3994SIvan Butygin return rewriter.notifyMatchFailure(loc, "op already legal"); 309935ef3994SIvan Butygin 310035ef3994SIvan Butygin OperationState newOp(loc, op->getName()); 310135ef3994SIvan Butygin newOp.addOperands(operands); 310235ef3994SIvan Butygin 310335ef3994SIvan Butygin SmallVector<Type> newResultTypes; 310435ef3994SIvan Butygin if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes))) 310535ef3994SIvan Butygin return rewriter.notifyMatchFailure(loc, "couldn't convert return types"); 310635ef3994SIvan Butygin 310735ef3994SIvan Butygin newOp.addTypes(newResultTypes); 310835ef3994SIvan Butygin newOp.addAttributes(op->getAttrs()); 310935ef3994SIvan Butygin return rewriter.create(newOp); 311035ef3994SIvan Butygin } 311135ef3994SIvan Butygin 31127ceffae1SRiver Riddle void mlir::populateFunctionOpInterfaceTypeConversionPattern( 3113dc4e913bSChris Lattner StringRef functionLikeOpName, RewritePatternSet &patterns, 3114ce254598SMatthias Springer const TypeConverter &converter) { 31157ceffae1SRiver Riddle patterns.add<FunctionOpInterfaceSignatureConversion>( 31163a506b31SChris Lattner functionLikeOpName, patterns.getContext(), converter); 31170a7a1ac7Smikeurbach } 31180a7a1ac7Smikeurbach 3119ed4749f9SIvan Butygin void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern( 3120ce254598SMatthias Springer RewritePatternSet &patterns, const TypeConverter &converter) { 3121ed4749f9SIvan Butygin patterns.add<AnyFunctionOpInterfaceSignatureConversion>( 3122ed4749f9SIvan Butygin converter, patterns.getContext()); 3123ed4749f9SIvan Butygin } 3124ed4749f9SIvan Butygin 3125b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 3126b6eb26fdSRiver Riddle // ConversionTarget 3127b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 3128b6eb26fdSRiver Riddle 3129b6eb26fdSRiver Riddle void ConversionTarget::setOpAction(OperationName op, 3130b6eb26fdSRiver Riddle LegalizationAction action) { 3131c6828e0cSCaitlyn Cano legalOperations[op].action = action; 3132b6eb26fdSRiver Riddle } 3133b6eb26fdSRiver Riddle 3134b6eb26fdSRiver Riddle void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames, 3135b6eb26fdSRiver Riddle LegalizationAction action) { 3136b6eb26fdSRiver Riddle for (StringRef dialect : dialectNames) 3137b6eb26fdSRiver Riddle legalDialects[dialect] = action; 3138b6eb26fdSRiver Riddle } 3139b6eb26fdSRiver Riddle 3140b6eb26fdSRiver Riddle auto ConversionTarget::getOpAction(OperationName op) const 31410de16fafSRamkumar Ramachandra -> std::optional<LegalizationAction> { 31420de16fafSRamkumar Ramachandra std::optional<LegalizationInfo> info = getOpInfo(op); 31430de16fafSRamkumar Ramachandra return info ? info->action : std::optional<LegalizationAction>(); 3144b6eb26fdSRiver Riddle } 3145b6eb26fdSRiver Riddle 3146b6eb26fdSRiver Riddle auto ConversionTarget::isLegal(Operation *op) const 31470de16fafSRamkumar Ramachandra -> std::optional<LegalOpDetails> { 31480de16fafSRamkumar Ramachandra std::optional<LegalizationInfo> info = getOpInfo(op->getName()); 3149b6eb26fdSRiver Riddle if (!info) 31501a36588eSKazu Hirata return std::nullopt; 3151b6eb26fdSRiver Riddle 3152b6eb26fdSRiver Riddle // Returns true if this operation instance is known to be legal. 3153b6eb26fdSRiver Riddle auto isOpLegal = [&] { 31541c9c2c91SBenjamin Kramer // Handle dynamic legality either with the provided legality function. 3155c6828e0cSCaitlyn Cano if (info->action == LegalizationAction::Dynamic) { 31560de16fafSRamkumar Ramachandra std::optional<bool> result = info->legalityFn(op); 3157c6828e0cSCaitlyn Cano if (result) 3158c6828e0cSCaitlyn Cano return *result; 3159c6828e0cSCaitlyn Cano } 3160b6eb26fdSRiver Riddle 3161b6eb26fdSRiver Riddle // Otherwise, the operation is only legal if it was marked 'Legal'. 3162b6eb26fdSRiver Riddle return info->action == LegalizationAction::Legal; 3163b6eb26fdSRiver Riddle }; 3164b6eb26fdSRiver Riddle if (!isOpLegal()) 31651a36588eSKazu Hirata return std::nullopt; 3166b6eb26fdSRiver Riddle 3167b6eb26fdSRiver Riddle // This operation is legal, compute any additional legality information. 3168b6eb26fdSRiver Riddle LegalOpDetails legalityDetails; 3169b6eb26fdSRiver Riddle if (info->isRecursivelyLegal) { 3170b6eb26fdSRiver Riddle auto legalityFnIt = opRecursiveLegalityFns.find(op->getName()); 3171c6828e0cSCaitlyn Cano if (legalityFnIt != opRecursiveLegalityFns.end()) { 3172c6828e0cSCaitlyn Cano legalityDetails.isRecursivelyLegal = 317330c67587SKazu Hirata legalityFnIt->second(op).value_or(true); 3174c6828e0cSCaitlyn Cano } else { 3175b6eb26fdSRiver Riddle legalityDetails.isRecursivelyLegal = true; 3176b6eb26fdSRiver Riddle } 3177c6828e0cSCaitlyn Cano } 3178b6eb26fdSRiver Riddle return legalityDetails; 3179b6eb26fdSRiver Riddle } 3180b6eb26fdSRiver Riddle 31812a3878eaSButygin bool ConversionTarget::isIllegal(Operation *op) const { 31820de16fafSRamkumar Ramachandra std::optional<LegalizationInfo> info = getOpInfo(op->getName()); 31832a3878eaSButygin if (!info) 31842a3878eaSButygin return false; 31852a3878eaSButygin 31862a3878eaSButygin if (info->action == LegalizationAction::Dynamic) { 31870de16fafSRamkumar Ramachandra std::optional<bool> result = info->legalityFn(op); 31882a3878eaSButygin if (!result) 31892a3878eaSButygin return false; 31902a3878eaSButygin 31912a3878eaSButygin return !(*result); 31922a3878eaSButygin } 31932a3878eaSButygin 31942a3878eaSButygin return info->action == LegalizationAction::Illegal; 31952a3878eaSButygin } 31962a3878eaSButygin 3197c6828e0cSCaitlyn Cano static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks( 3198c6828e0cSCaitlyn Cano ConversionTarget::DynamicLegalityCallbackFn oldCallback, 3199c6828e0cSCaitlyn Cano ConversionTarget::DynamicLegalityCallbackFn newCallback) { 3200c6828e0cSCaitlyn Cano if (!oldCallback) 3201c6828e0cSCaitlyn Cano return newCallback; 3202c6828e0cSCaitlyn Cano 3203c6828e0cSCaitlyn Cano auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)]( 32040de16fafSRamkumar Ramachandra Operation *op) -> std::optional<bool> { 32050de16fafSRamkumar Ramachandra if (std::optional<bool> result = newCl(op)) 3206c6828e0cSCaitlyn Cano return *result; 3207c6828e0cSCaitlyn Cano 3208c6828e0cSCaitlyn Cano return oldCl(op); 3209c6828e0cSCaitlyn Cano }; 3210c6828e0cSCaitlyn Cano return chain; 3211c6828e0cSCaitlyn Cano } 3212c6828e0cSCaitlyn Cano 3213b6eb26fdSRiver Riddle void ConversionTarget::setLegalityCallback( 3214b6eb26fdSRiver Riddle OperationName name, const DynamicLegalityCallbackFn &callback) { 3215b6eb26fdSRiver Riddle assert(callback && "expected valid legality callback"); 32167dad59f0SMehdi Amini auto *infoIt = legalOperations.find(name); 3217b6eb26fdSRiver Riddle assert(infoIt != legalOperations.end() && 3218b6eb26fdSRiver Riddle infoIt->second.action == LegalizationAction::Dynamic && 3219b6eb26fdSRiver Riddle "expected operation to already be marked as dynamically legal"); 3220c6828e0cSCaitlyn Cano infoIt->second.legalityFn = 3221c6828e0cSCaitlyn Cano composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback); 3222b6eb26fdSRiver Riddle } 3223b6eb26fdSRiver Riddle 3224b6eb26fdSRiver Riddle void ConversionTarget::markOpRecursivelyLegal( 3225b6eb26fdSRiver Riddle OperationName name, const DynamicLegalityCallbackFn &callback) { 32267dad59f0SMehdi Amini auto *infoIt = legalOperations.find(name); 3227b6eb26fdSRiver Riddle assert(infoIt != legalOperations.end() && 3228b6eb26fdSRiver Riddle infoIt->second.action != LegalizationAction::Illegal && 3229b6eb26fdSRiver Riddle "expected operation to already be marked as legal"); 3230b6eb26fdSRiver Riddle infoIt->second.isRecursivelyLegal = true; 3231b6eb26fdSRiver Riddle if (callback) 3232c6828e0cSCaitlyn Cano opRecursiveLegalityFns[name] = composeLegalityCallbacks( 3233c6828e0cSCaitlyn Cano std::move(opRecursiveLegalityFns[name]), callback); 3234b6eb26fdSRiver Riddle else 3235b6eb26fdSRiver Riddle opRecursiveLegalityFns.erase(name); 3236b6eb26fdSRiver Riddle } 3237b6eb26fdSRiver Riddle 3238b6eb26fdSRiver Riddle void ConversionTarget::setLegalityCallback( 3239b6eb26fdSRiver Riddle ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) { 3240b6eb26fdSRiver Riddle assert(callback && "expected valid legality callback"); 3241b6eb26fdSRiver Riddle for (StringRef dialect : dialects) 3242c6828e0cSCaitlyn Cano dialectLegalityFns[dialect] = composeLegalityCallbacks( 3243c6828e0cSCaitlyn Cano std::move(dialectLegalityFns[dialect]), callback); 3244b6eb26fdSRiver Riddle } 3245b6eb26fdSRiver Riddle 3246b7a46498SButygin void ConversionTarget::setLegalityCallback( 3247b7a46498SButygin const DynamicLegalityCallbackFn &callback) { 3248b7a46498SButygin assert(callback && "expected valid legality callback"); 3249c6828e0cSCaitlyn Cano unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback); 3250b7a46498SButygin } 3251b7a46498SButygin 3252b6eb26fdSRiver Riddle auto ConversionTarget::getOpInfo(OperationName op) const 32530de16fafSRamkumar Ramachandra -> std::optional<LegalizationInfo> { 3254b6eb26fdSRiver Riddle // Check for info for this specific operation. 32557dad59f0SMehdi Amini const auto *it = legalOperations.find(op); 3256b6eb26fdSRiver Riddle if (it != legalOperations.end()) 3257b6eb26fdSRiver Riddle return it->second; 3258b6eb26fdSRiver Riddle // Check for info for the parent dialect. 3259e6260ad0SRiver Riddle auto dialectIt = legalDialects.find(op.getDialectNamespace()); 3260b6eb26fdSRiver Riddle if (dialectIt != legalDialects.end()) { 3261b7a46498SButygin DynamicLegalityCallbackFn callback; 3262e6260ad0SRiver Riddle auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace()); 3263b6eb26fdSRiver Riddle if (dialectFn != dialectLegalityFns.end()) 3264b6eb26fdSRiver Riddle callback = dialectFn->second; 3265b6eb26fdSRiver Riddle return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false, 3266b6eb26fdSRiver Riddle callback}; 3267b6eb26fdSRiver Riddle } 3268b6eb26fdSRiver Riddle // Otherwise, check if we mark unknown operations as dynamic. 3269b7a46498SButygin if (unknownLegalityFn) 3270b6eb26fdSRiver Riddle return LegalizationInfo{LegalizationAction::Dynamic, 3271b6eb26fdSRiver Riddle /*isRecursivelyLegal=*/false, unknownLegalityFn}; 32721a36588eSKazu Hirata return std::nullopt; 3273b6eb26fdSRiver Riddle } 3274b6eb26fdSRiver Riddle 32756ae7f66fSJacques Pienaar #if MLIR_ENABLE_PDL_IN_PATTERNMATCH 3276b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 32778c66344eSRiver Riddle // PDL Configuration 32788c66344eSRiver Riddle //===----------------------------------------------------------------------===// 32798c66344eSRiver Riddle 32808c66344eSRiver Riddle void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) { 32818c66344eSRiver Riddle auto &rewriterImpl = 32828c66344eSRiver Riddle static_cast<ConversionPatternRewriter &>(rewriter).getImpl(); 32838c66344eSRiver Riddle rewriterImpl.currentTypeConverter = getTypeConverter(); 32848c66344eSRiver Riddle } 32858c66344eSRiver Riddle 32868c66344eSRiver Riddle void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) { 32878c66344eSRiver Riddle auto &rewriterImpl = 32888c66344eSRiver Riddle static_cast<ConversionPatternRewriter &>(rewriter).getImpl(); 32898c66344eSRiver Riddle rewriterImpl.currentTypeConverter = nullptr; 32908c66344eSRiver Riddle } 32918c66344eSRiver Riddle 32928c66344eSRiver Riddle /// Remap the given value using the rewriter and the type converter in the 32938c66344eSRiver Riddle /// provided config. 32948c66344eSRiver Riddle static FailureOr<SmallVector<Value>> 32958c66344eSRiver Riddle pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values) { 32968c66344eSRiver Riddle SmallVector<Value> mappedValues; 32978c66344eSRiver Riddle if (failed(rewriter.getRemappedValues(values, mappedValues))) 32988c66344eSRiver Riddle return failure(); 32998c66344eSRiver Riddle return std::move(mappedValues); 33008c66344eSRiver Riddle } 33018c66344eSRiver Riddle 33028c66344eSRiver Riddle void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) { 33038c66344eSRiver Riddle patterns.getPDLPatterns().registerRewriteFunction( 33048c66344eSRiver Riddle "convertValue", 33058c66344eSRiver Riddle [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> { 33068c66344eSRiver Riddle auto results = pdllConvertValues( 33078c66344eSRiver Riddle static_cast<ConversionPatternRewriter &>(rewriter), value); 33088c66344eSRiver Riddle if (failed(results)) 33098c66344eSRiver Riddle return failure(); 33108c66344eSRiver Riddle return results->front(); 33118c66344eSRiver Riddle }); 33128c66344eSRiver Riddle patterns.getPDLPatterns().registerRewriteFunction( 33138c66344eSRiver Riddle "convertValues", [](PatternRewriter &rewriter, ValueRange values) { 33148c66344eSRiver Riddle return pdllConvertValues( 33158c66344eSRiver Riddle static_cast<ConversionPatternRewriter &>(rewriter), values); 33168c66344eSRiver Riddle }); 33178c66344eSRiver Riddle patterns.getPDLPatterns().registerRewriteFunction( 33188c66344eSRiver Riddle "convertType", 33198c66344eSRiver Riddle [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> { 33208c66344eSRiver Riddle auto &rewriterImpl = 33218c66344eSRiver Riddle static_cast<ConversionPatternRewriter &>(rewriter).getImpl(); 3322ce254598SMatthias Springer if (const TypeConverter *converter = 3323ce254598SMatthias Springer rewriterImpl.currentTypeConverter) { 33248c66344eSRiver Riddle if (Type newType = converter->convertType(type)) 33258c66344eSRiver Riddle return newType; 33268c66344eSRiver Riddle return failure(); 33278c66344eSRiver Riddle } 33288c66344eSRiver Riddle return type; 33298c66344eSRiver Riddle }); 33308c66344eSRiver Riddle patterns.getPDLPatterns().registerRewriteFunction( 33318c66344eSRiver Riddle "convertTypes", 33328c66344eSRiver Riddle [](PatternRewriter &rewriter, 33338c66344eSRiver Riddle TypeRange types) -> FailureOr<SmallVector<Type>> { 33348c66344eSRiver Riddle auto &rewriterImpl = 33358c66344eSRiver Riddle static_cast<ConversionPatternRewriter &>(rewriter).getImpl(); 3336ce254598SMatthias Springer const TypeConverter *converter = rewriterImpl.currentTypeConverter; 33378c66344eSRiver Riddle if (!converter) 33388c66344eSRiver Riddle return SmallVector<Type>(types); 33398c66344eSRiver Riddle 33408c66344eSRiver Riddle SmallVector<Type> remappedTypes; 33418c66344eSRiver Riddle if (failed(converter->convertTypes(types, remappedTypes))) 33428c66344eSRiver Riddle return failure(); 33438c66344eSRiver Riddle return std::move(remappedTypes); 33448c66344eSRiver Riddle }); 33458c66344eSRiver Riddle } 33466ae7f66fSJacques Pienaar #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH 33478c66344eSRiver Riddle 33488c66344eSRiver Riddle //===----------------------------------------------------------------------===// 3349b6eb26fdSRiver Riddle // Op Conversion Entry Points 3350b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 3351b6eb26fdSRiver Riddle 335201b55f16SRiver Riddle //===----------------------------------------------------------------------===// 335301b55f16SRiver Riddle // Partial Conversion 335401b55f16SRiver Riddle 3355a2821094SMatthias Springer LogicalResult mlir::applyPartialConversion( 3356a2821094SMatthias Springer ArrayRef<Operation *> ops, const ConversionTarget &target, 3357a2821094SMatthias Springer const FrozenRewritePatternSet &patterns, ConversionConfig config) { 3358a2821094SMatthias Springer OperationConverter opConverter(target, patterns, config, 3359a2821094SMatthias Springer OpConversionMode::Partial); 3360b6eb26fdSRiver Riddle return opConverter.convertOperations(ops); 3361b6eb26fdSRiver Riddle } 3362b6eb26fdSRiver Riddle LogicalResult 3363370a6f09SMehdi Amini mlir::applyPartialConversion(Operation *op, const ConversionTarget &target, 336479d7f618SChris Lattner const FrozenRewritePatternSet &patterns, 3365a2821094SMatthias Springer ConversionConfig config) { 3366a2821094SMatthias Springer return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config); 3367b6eb26fdSRiver Riddle } 3368b6eb26fdSRiver Riddle 336901b55f16SRiver Riddle //===----------------------------------------------------------------------===// 337001b55f16SRiver Riddle // Full Conversion 337101b55f16SRiver Riddle 3372a2821094SMatthias Springer LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops, 3373e214f004SMatthias Springer const ConversionTarget &target, 3374a2821094SMatthias Springer const FrozenRewritePatternSet &patterns, 3375a2821094SMatthias Springer ConversionConfig config) { 3376a2821094SMatthias Springer OperationConverter opConverter(target, patterns, config, 3377a2821094SMatthias Springer OpConversionMode::Full); 3378b6eb26fdSRiver Riddle return opConverter.convertOperations(ops); 3379b6eb26fdSRiver Riddle } 3380a2821094SMatthias Springer LogicalResult mlir::applyFullConversion(Operation *op, 3381a2821094SMatthias Springer const ConversionTarget &target, 3382a2821094SMatthias Springer const FrozenRewritePatternSet &patterns, 3383a2821094SMatthias Springer ConversionConfig config) { 3384a2821094SMatthias Springer return applyFullConversion(llvm::ArrayRef(op), target, patterns, config); 3385b6eb26fdSRiver Riddle } 3386b6eb26fdSRiver Riddle 338701b55f16SRiver Riddle //===----------------------------------------------------------------------===// 338801b55f16SRiver Riddle // Analysis Conversion 338901b55f16SRiver Riddle 33905030deadSMatthias Springer /// Find a common IsolatedFromAbove ancestor of the given ops. If at least one 33915030deadSMatthias Springer /// op is a top-level module op (which is expected to be isolated from above), 33925030deadSMatthias Springer /// return that op. 33935030deadSMatthias Springer static Operation *findCommonAncestor(ArrayRef<Operation *> ops) { 33945030deadSMatthias Springer // Check if there is a top-level operation within `ops`. If so, return that 33955030deadSMatthias Springer // op. 33965030deadSMatthias Springer for (Operation *op : ops) { 33975030deadSMatthias Springer if (!op->getParentOp()) { 33985030deadSMatthias Springer #ifndef NDEBUG 33995030deadSMatthias Springer assert(op->hasTrait<OpTrait::IsIsolatedFromAbove>() && 34005030deadSMatthias Springer "expected top-level op to be isolated from above"); 34015030deadSMatthias Springer for (Operation *other : ops) 34025030deadSMatthias Springer assert(op->isAncestor(other) && 34035030deadSMatthias Springer "expected ops to have a common ancestor"); 34045030deadSMatthias Springer #endif // NDEBUG 34055030deadSMatthias Springer return op; 34065030deadSMatthias Springer } 34075030deadSMatthias Springer } 34085030deadSMatthias Springer 34095030deadSMatthias Springer // No top-level op. Find a common ancestor. 34105030deadSMatthias Springer Operation *commonAncestor = 34115030deadSMatthias Springer ops.front()->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(); 34125030deadSMatthias Springer for (Operation *op : ops.drop_front()) { 34135030deadSMatthias Springer while (!commonAncestor->isProperAncestor(op)) { 34145030deadSMatthias Springer commonAncestor = 34155030deadSMatthias Springer commonAncestor->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(); 34165030deadSMatthias Springer assert(commonAncestor && 34175030deadSMatthias Springer "expected to find a common isolated from above ancestor"); 34185030deadSMatthias Springer } 34195030deadSMatthias Springer } 34205030deadSMatthias Springer 34215030deadSMatthias Springer return commonAncestor; 34225030deadSMatthias Springer } 34235030deadSMatthias Springer 3424a2821094SMatthias Springer LogicalResult mlir::applyAnalysisConversion( 3425a2821094SMatthias Springer ArrayRef<Operation *> ops, ConversionTarget &target, 3426a2821094SMatthias Springer const FrozenRewritePatternSet &patterns, ConversionConfig config) { 34275030deadSMatthias Springer #ifndef NDEBUG 34285030deadSMatthias Springer if (config.legalizableOps) 34295030deadSMatthias Springer assert(config.legalizableOps->empty() && "expected empty set"); 34305030deadSMatthias Springer #endif // NDEBUG 34315030deadSMatthias Springer 34325030deadSMatthias Springer // Clone closted common ancestor that is isolated from above. 34335030deadSMatthias Springer Operation *commonAncestor = findCommonAncestor(ops); 34345030deadSMatthias Springer IRMapping mapping; 34355030deadSMatthias Springer Operation *clonedAncestor = commonAncestor->clone(mapping); 34365030deadSMatthias Springer // Compute inverse IR mapping. 34375030deadSMatthias Springer DenseMap<Operation *, Operation *> inverseOperationMap; 34385030deadSMatthias Springer for (auto &it : mapping.getOperationMap()) 34395030deadSMatthias Springer inverseOperationMap[it.second] = it.first; 34405030deadSMatthias Springer 34415030deadSMatthias Springer // Convert the cloned operations. The original IR will remain unchanged. 34425030deadSMatthias Springer SmallVector<Operation *> opsToConvert = llvm::map_to_vector( 34435030deadSMatthias Springer ops, [&](Operation *op) { return mapping.lookup(op); }); 3444a2821094SMatthias Springer OperationConverter opConverter(target, patterns, config, 3445a2821094SMatthias Springer OpConversionMode::Analysis); 34465030deadSMatthias Springer LogicalResult status = opConverter.convertOperations(opsToConvert); 34475030deadSMatthias Springer 34485030deadSMatthias Springer // Remap `legalizableOps`, so that they point to the original ops and not the 34495030deadSMatthias Springer // cloned ops. 34505030deadSMatthias Springer if (config.legalizableOps) { 34515030deadSMatthias Springer DenseSet<Operation *> originalLegalizableOps; 34525030deadSMatthias Springer for (Operation *op : *config.legalizableOps) 34535030deadSMatthias Springer originalLegalizableOps.insert(inverseOperationMap[op]); 34545030deadSMatthias Springer *config.legalizableOps = std::move(originalLegalizableOps); 3455b6eb26fdSRiver Riddle } 34565030deadSMatthias Springer 34575030deadSMatthias Springer // Erase the cloned IR. 34585030deadSMatthias Springer clonedAncestor->erase(); 34595030deadSMatthias Springer return status; 34605030deadSMatthias Springer } 34615030deadSMatthias Springer 3462b6eb26fdSRiver Riddle LogicalResult 3463b6eb26fdSRiver Riddle mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target, 346479d7f618SChris Lattner const FrozenRewritePatternSet &patterns, 3465a2821094SMatthias Springer ConversionConfig config) { 3466a2821094SMatthias Springer return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config); 3467b6eb26fdSRiver Riddle } 3468