1b6eb26fdSRiver Riddle //===- DialectConversion.cpp - MLIR dialect conversion generic pass -------===// 2b6eb26fdSRiver Riddle // 3b6eb26fdSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4b6eb26fdSRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 5b6eb26fdSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6b6eb26fdSRiver Riddle // 7b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 8b6eb26fdSRiver Riddle 9b6eb26fdSRiver Riddle #include "mlir/Transforms/DialectConversion.h" 106ae7f66fSJacques Pienaar #include "mlir/Config/mlir-config.h" 11b6eb26fdSRiver Riddle #include "mlir/IR/Block.h" 12b6eb26fdSRiver Riddle #include "mlir/IR/Builders.h" 1365fcddffSRiver Riddle #include "mlir/IR/BuiltinOps.h" 144d67b278SJeff Niu #include "mlir/IR/IRMapping.h" 15b884f4efSMatthias Springer #include "mlir/IR/Iterators.h" 1634a35a8bSMartin Erhart #include "mlir/Interfaces/FunctionInterfaces.h" 17b6eb26fdSRiver Riddle #include "mlir/Rewrite/PatternApplicator.h" 189c5982efSAlex Zinenko #include "llvm/ADT/ScopeExit.h" 19b6eb26fdSRiver Riddle #include "llvm/ADT/SetVector.h" 20b6eb26fdSRiver Riddle #include "llvm/ADT/SmallPtrSet.h" 21b6eb26fdSRiver Riddle #include "llvm/Support/Debug.h" 22b6eb26fdSRiver Riddle #include "llvm/Support/FormatVariadic.h" 23b6eb26fdSRiver Riddle #include "llvm/Support/SaveAndRestore.h" 24b6eb26fdSRiver Riddle #include "llvm/Support/ScopedPrinter.h" 2505423905SKazu Hirata #include <optional> 26b6eb26fdSRiver Riddle 27b6eb26fdSRiver Riddle using namespace mlir; 28b6eb26fdSRiver Riddle using namespace mlir::detail; 29b6eb26fdSRiver Riddle 30b6eb26fdSRiver Riddle #define DEBUG_TYPE "dialect-conversion" 31b6eb26fdSRiver Riddle 32b6eb26fdSRiver Riddle /// A utility function to log a successful result for the given reason. 33b6eb26fdSRiver Riddle template <typename... Args> 344efb7754SRiver Riddle static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) { 35b6eb26fdSRiver Riddle LLVM_DEBUG({ 36b6eb26fdSRiver Riddle os.unindent(); 37b6eb26fdSRiver Riddle os.startLine() << "} -> SUCCESS"; 38b6eb26fdSRiver Riddle if (!fmt.empty()) 39b6eb26fdSRiver Riddle os.getOStream() << " : " 40b6eb26fdSRiver Riddle << llvm::formatv(fmt.data(), std::forward<Args>(args)...); 41b6eb26fdSRiver Riddle os.getOStream() << "\n"; 42b6eb26fdSRiver Riddle }); 43b6eb26fdSRiver Riddle } 44b6eb26fdSRiver Riddle 45b6eb26fdSRiver Riddle /// A utility function to log a failure result for the given reason. 46b6eb26fdSRiver Riddle template <typename... Args> 474efb7754SRiver Riddle static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) { 48b6eb26fdSRiver Riddle LLVM_DEBUG({ 49b6eb26fdSRiver Riddle os.unindent(); 50b6eb26fdSRiver Riddle os.startLine() << "} -> FAILURE : " 51b6eb26fdSRiver Riddle << llvm::formatv(fmt.data(), std::forward<Args>(args)...) 52b6eb26fdSRiver Riddle << "\n"; 53b6eb26fdSRiver Riddle }); 54b6eb26fdSRiver Riddle } 55b6eb26fdSRiver Riddle 56b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 57b6eb26fdSRiver Riddle // ConversionValueMapping 58b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 59b6eb26fdSRiver Riddle 60b6eb26fdSRiver Riddle namespace { 614d67b278SJeff Niu /// This class wraps a IRMapping to provide recursive lookup 62b6eb26fdSRiver Riddle /// functionality, i.e. we will traverse if the mapped value also has a mapping. 63b6eb26fdSRiver Riddle struct ConversionValueMapping { 64b6eb26fdSRiver Riddle /// Lookup a mapped value within the map. If a mapping for the provided value 659ca67d7fSAlexander Belyaev /// does not exist then return the provided value. If `desiredType` is 669ca67d7fSAlexander Belyaev /// non-null, returns the most recently mapped value with that type. If an 679ca67d7fSAlexander Belyaev /// operand of that type does not exist, defaults to normal behavior. 689ca67d7fSAlexander Belyaev Value lookupOrDefault(Value from, Type desiredType = nullptr) const; 69b6eb26fdSRiver Riddle 70b6eb26fdSRiver Riddle /// Lookup a mapped value within the map, or return null if a mapping does not 71b6eb26fdSRiver Riddle /// exist. If a mapping exists, this follows the same behavior of 72b6eb26fdSRiver Riddle /// `lookupOrDefault`. 73015192c6SRiver Riddle Value lookupOrNull(Value from, Type desiredType = nullptr) const; 74b6eb26fdSRiver Riddle 75b6eb26fdSRiver Riddle /// Map a value to the one provided. 76015192c6SRiver Riddle void map(Value oldVal, Value newVal) { 77015192c6SRiver Riddle LLVM_DEBUG({ 78015192c6SRiver Riddle for (Value it = newVal; it; it = mapping.lookupOrNull(it)) 79015192c6SRiver Riddle assert(it != oldVal && "inserting cyclic mapping"); 80015192c6SRiver Riddle }); 81015192c6SRiver Riddle mapping.map(oldVal, newVal); 82015192c6SRiver Riddle } 83015192c6SRiver Riddle 84015192c6SRiver Riddle /// Try to map a value to the one provided. Returns false if a transitive 85015192c6SRiver Riddle /// mapping from the new value to the old value already exists, true if the 86015192c6SRiver Riddle /// map was updated. 87015192c6SRiver Riddle bool tryMap(Value oldVal, Value newVal); 88b6eb26fdSRiver Riddle 89b6eb26fdSRiver Riddle /// Drop the last mapping for the given value. 90b6eb26fdSRiver Riddle void erase(Value value) { mapping.erase(value); } 91b6eb26fdSRiver Riddle 925b91060dSAlex Zinenko /// Returns the inverse raw value mapping (without recursive query support). 93015192c6SRiver Riddle DenseMap<Value, SmallVector<Value>> getInverse() const { 94015192c6SRiver Riddle DenseMap<Value, SmallVector<Value>> inverse; 95015192c6SRiver Riddle for (auto &it : mapping.getValueMap()) 96015192c6SRiver Riddle inverse[it.second].push_back(it.first); 97015192c6SRiver Riddle return inverse; 98015192c6SRiver Riddle } 995b91060dSAlex Zinenko 100b6eb26fdSRiver Riddle private: 101b6eb26fdSRiver Riddle /// Current value mappings. 1024d67b278SJeff Niu IRMapping mapping; 103b6eb26fdSRiver Riddle }; 104be0a7e9fSMehdi Amini } // namespace 105b6eb26fdSRiver Riddle 1069ca67d7fSAlexander Belyaev Value ConversionValueMapping::lookupOrDefault(Value from, 1079ca67d7fSAlexander Belyaev Type desiredType) const { 1089ca67d7fSAlexander Belyaev // If there was no desired type, simply find the leaf value. 1099ca67d7fSAlexander Belyaev if (!desiredType) { 110b6eb26fdSRiver Riddle // If this value had a valid mapping, unmap that value as well in the case 111b6eb26fdSRiver Riddle // that it was also replaced. 112b6eb26fdSRiver Riddle while (auto mappedValue = mapping.lookupOrNull(from)) 113b6eb26fdSRiver Riddle from = mappedValue; 114b6eb26fdSRiver Riddle return from; 115b6eb26fdSRiver Riddle } 116b6eb26fdSRiver Riddle 1179ca67d7fSAlexander Belyaev // Otherwise, try to find the deepest value that has the desired type. 1189ca67d7fSAlexander Belyaev Value desiredValue; 119b6eb26fdSRiver Riddle do { 1209ca67d7fSAlexander Belyaev if (from.getType() == desiredType) 1219ca67d7fSAlexander Belyaev desiredValue = from; 122b6eb26fdSRiver Riddle 123b6eb26fdSRiver Riddle Value mappedValue = mapping.lookupOrNull(from); 124b6eb26fdSRiver Riddle if (!mappedValue) 125b6eb26fdSRiver Riddle break; 126b6eb26fdSRiver Riddle from = mappedValue; 127b6eb26fdSRiver Riddle } while (true); 128b6eb26fdSRiver Riddle 129b6eb26fdSRiver Riddle // If the desired value was found use it, otherwise default to the leaf value. 1309ca67d7fSAlexander Belyaev return desiredValue ? desiredValue : from; 131b6eb26fdSRiver Riddle } 132b6eb26fdSRiver Riddle 133015192c6SRiver Riddle Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const { 134015192c6SRiver Riddle Value result = lookupOrDefault(from, desiredType); 135015192c6SRiver Riddle if (result == from || (desiredType && result.getType() != desiredType)) 136015192c6SRiver Riddle return nullptr; 137015192c6SRiver Riddle return result; 138015192c6SRiver Riddle } 139015192c6SRiver Riddle 140015192c6SRiver Riddle bool ConversionValueMapping::tryMap(Value oldVal, Value newVal) { 141015192c6SRiver Riddle for (Value it = newVal; it; it = mapping.lookupOrNull(it)) 142015192c6SRiver Riddle if (it == oldVal) 143015192c6SRiver Riddle return false; 144015192c6SRiver Riddle map(oldVal, newVal); 145015192c6SRiver Riddle return true; 146b6eb26fdSRiver Riddle } 147b6eb26fdSRiver Riddle 148b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 14901b55f16SRiver Riddle // Rewriter and Translation State 15001b55f16SRiver Riddle //===----------------------------------------------------------------------===// 15101b55f16SRiver Riddle namespace { 15201b55f16SRiver Riddle /// This class contains a snapshot of the current conversion rewriter state. 15301b55f16SRiver Riddle /// This is useful when saving and undoing a set of rewrites. 15401b55f16SRiver Riddle struct RewriterState { 155015192c6SRiver Riddle RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations, 156015192c6SRiver Riddle unsigned numReplacements, unsigned numArgReplacements, 157e214f004SMatthias Springer unsigned numRewrites, unsigned numIgnoredOperations) 158015192c6SRiver Riddle : numCreatedOps(numCreatedOps), 159015192c6SRiver Riddle numUnresolvedMaterializations(numUnresolvedMaterializations), 160015192c6SRiver Riddle numReplacements(numReplacements), 1618faefe36SMatthias Springer numArgReplacements(numArgReplacements), numRewrites(numRewrites), 162e214f004SMatthias Springer numIgnoredOperations(numIgnoredOperations) {} 16301b55f16SRiver Riddle 16401b55f16SRiver Riddle /// The current number of created operations. 16501b55f16SRiver Riddle unsigned numCreatedOps; 16601b55f16SRiver Riddle 167015192c6SRiver Riddle /// The current number of unresolved materializations. 168015192c6SRiver Riddle unsigned numUnresolvedMaterializations; 169015192c6SRiver Riddle 17001b55f16SRiver Riddle /// The current number of replacements queued. 17101b55f16SRiver Riddle unsigned numReplacements; 17201b55f16SRiver Riddle 17301b55f16SRiver Riddle /// The current number of argument replacements queued. 17401b55f16SRiver Riddle unsigned numArgReplacements; 17501b55f16SRiver Riddle 1768faefe36SMatthias Springer /// The current number of rewrites performed. 1778faefe36SMatthias Springer unsigned numRewrites; 17801b55f16SRiver Riddle 17901b55f16SRiver Riddle /// The current number of ignored operations. 18001b55f16SRiver Riddle unsigned numIgnoredOperations; 18101b55f16SRiver Riddle }; 18201b55f16SRiver Riddle 18301b55f16SRiver Riddle //===----------------------------------------------------------------------===// 18401b55f16SRiver Riddle // OpReplacement 18501b55f16SRiver Riddle 18601b55f16SRiver Riddle /// This class represents one requested operation replacement via 'replaceOp' or 18701b55f16SRiver Riddle /// 'eraseOp`. 18801b55f16SRiver Riddle struct OpReplacement { 189ce254598SMatthias Springer OpReplacement(const TypeConverter *converter = nullptr) 190ce254598SMatthias Springer : converter(converter) {} 19101b55f16SRiver Riddle 19201b55f16SRiver Riddle /// An optional type converter that can be used to materialize conversions 19301b55f16SRiver Riddle /// between the new and old values if necessary. 194ce254598SMatthias Springer const TypeConverter *converter; 19501b55f16SRiver Riddle }; 19601b55f16SRiver Riddle 19701b55f16SRiver Riddle //===----------------------------------------------------------------------===// 198015192c6SRiver Riddle // UnresolvedMaterialization 199015192c6SRiver Riddle 200015192c6SRiver Riddle /// This class represents an unresolved materialization, i.e. a materialization 201015192c6SRiver Riddle /// that was inserted during conversion that needs to be legalized at the end of 202015192c6SRiver Riddle /// the conversion process. 203015192c6SRiver Riddle class UnresolvedMaterialization { 204015192c6SRiver Riddle public: 205015192c6SRiver Riddle /// The type of materialization. 206015192c6SRiver Riddle enum Kind { 207015192c6SRiver Riddle /// This materialization materializes a conversion for an illegal block 208015192c6SRiver Riddle /// argument type, to a legal one. 209015192c6SRiver Riddle Argument, 210015192c6SRiver Riddle 211015192c6SRiver Riddle /// This materialization materializes a conversion from an illegal type to a 212015192c6SRiver Riddle /// legal one. 213015192c6SRiver Riddle Target 214015192c6SRiver Riddle }; 215015192c6SRiver Riddle 216015192c6SRiver Riddle UnresolvedMaterialization(UnrealizedConversionCastOp op = nullptr, 217ce254598SMatthias Springer const TypeConverter *converter = nullptr, 218015192c6SRiver Riddle Kind kind = Target, Type origOutputType = nullptr) 219015192c6SRiver Riddle : op(op), converterAndKind(converter, kind), 220015192c6SRiver Riddle origOutputType(origOutputType) {} 221015192c6SRiver Riddle 222015192c6SRiver Riddle /// Return the temporary conversion operation inserted for this 223015192c6SRiver Riddle /// materialization. 224015192c6SRiver Riddle UnrealizedConversionCastOp getOp() const { return op; } 225015192c6SRiver Riddle 226015192c6SRiver Riddle /// Return the type converter of this materialization (which may be null). 227ce254598SMatthias Springer const TypeConverter *getConverter() const { 228ce254598SMatthias Springer return converterAndKind.getPointer(); 229ce254598SMatthias Springer } 230015192c6SRiver Riddle 231015192c6SRiver Riddle /// Return the kind of this materialization. 232015192c6SRiver Riddle Kind getKind() const { return converterAndKind.getInt(); } 233015192c6SRiver Riddle 234015192c6SRiver Riddle /// Set the kind of this materialization. 235015192c6SRiver Riddle void setKind(Kind kind) { converterAndKind.setInt(kind); } 236015192c6SRiver Riddle 237015192c6SRiver Riddle /// Return the original illegal output type of the input values. 238015192c6SRiver Riddle Type getOrigOutputType() const { return origOutputType; } 239015192c6SRiver Riddle 240015192c6SRiver Riddle private: 241015192c6SRiver Riddle /// The unresolved materialization operation created during conversion. 242015192c6SRiver Riddle UnrealizedConversionCastOp op; 243015192c6SRiver Riddle 244015192c6SRiver Riddle /// The corresponding type converter to use when resolving this 245015192c6SRiver Riddle /// materialization, and the kind of this materialization. 246ce254598SMatthias Springer llvm::PointerIntPair<const TypeConverter *, 1, Kind> converterAndKind; 247015192c6SRiver Riddle 248015192c6SRiver Riddle /// The original output type. This is only used for argument conversions. 249015192c6SRiver Riddle Type origOutputType; 250015192c6SRiver Riddle }; 251be0a7e9fSMehdi Amini } // namespace 25201b55f16SRiver Riddle 253015192c6SRiver Riddle /// Build an unresolved materialization operation given an output type and set 254015192c6SRiver Riddle /// of input operands. 255015192c6SRiver Riddle static Value buildUnresolvedMaterialization( 256015192c6SRiver Riddle UnresolvedMaterialization::Kind kind, Block *insertBlock, 257015192c6SRiver Riddle Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType, 258ce254598SMatthias Springer Type origOutputType, const TypeConverter *converter, 259015192c6SRiver Riddle SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) { 260015192c6SRiver Riddle // Avoid materializing an unnecessary cast. 261015192c6SRiver Riddle if (inputs.size() == 1 && inputs.front().getType() == outputType) 262015192c6SRiver Riddle return inputs.front(); 263015192c6SRiver Riddle 264015192c6SRiver Riddle // Create an unresolved materialization. We use a new OpBuilder to avoid 265015192c6SRiver Riddle // tracking the materialization like we do for other operations. 266015192c6SRiver Riddle OpBuilder builder(insertBlock, insertPt); 267015192c6SRiver Riddle auto convertOp = 268015192c6SRiver Riddle builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs); 269015192c6SRiver Riddle unresolvedMaterializations.emplace_back(convertOp, converter, kind, 270015192c6SRiver Riddle origOutputType); 271015192c6SRiver Riddle return convertOp.getResult(0); 272015192c6SRiver Riddle } 273015192c6SRiver Riddle static Value buildUnresolvedArgumentMaterialization( 274015192c6SRiver Riddle PatternRewriter &rewriter, Location loc, ValueRange inputs, 275ce254598SMatthias Springer Type origOutputType, Type outputType, const TypeConverter *converter, 276015192c6SRiver Riddle SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) { 277015192c6SRiver Riddle return buildUnresolvedMaterialization( 278015192c6SRiver Riddle UnresolvedMaterialization::Argument, rewriter.getInsertionBlock(), 279015192c6SRiver Riddle rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType, 280015192c6SRiver Riddle converter, unresolvedMaterializations); 281015192c6SRiver Riddle } 282015192c6SRiver Riddle static Value buildUnresolvedTargetMaterialization( 283ce254598SMatthias Springer Location loc, Value input, Type outputType, const TypeConverter *converter, 284015192c6SRiver Riddle SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) { 285015192c6SRiver Riddle Block *insertBlock = input.getParentBlock(); 286015192c6SRiver Riddle Block::iterator insertPt = insertBlock->begin(); 2875550c821STres Popp if (OpResult inputRes = dyn_cast<OpResult>(input)) 288015192c6SRiver Riddle insertPt = ++inputRes.getOwner()->getIterator(); 289015192c6SRiver Riddle 290015192c6SRiver Riddle return buildUnresolvedMaterialization( 291015192c6SRiver Riddle UnresolvedMaterialization::Target, insertBlock, insertPt, loc, input, 292015192c6SRiver Riddle outputType, outputType, converter, unresolvedMaterializations); 293015192c6SRiver Riddle } 294015192c6SRiver Riddle 29501b55f16SRiver Riddle //===----------------------------------------------------------------------===// 296b6eb26fdSRiver Riddle // ArgConverter 297b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 298b6eb26fdSRiver Riddle namespace { 299b6eb26fdSRiver Riddle /// This class provides a simple interface for converting the types of block 300b6eb26fdSRiver Riddle /// arguments. This is done by creating a new block that contains the new legal 301b6eb26fdSRiver Riddle /// types and extracting the block that contains the old illegal types to allow 302b6eb26fdSRiver Riddle /// for undoing pending rewrites in the case of failure. 303b6eb26fdSRiver Riddle struct ArgConverter { 304015192c6SRiver Riddle ArgConverter( 305015192c6SRiver Riddle PatternRewriter &rewriter, 306015192c6SRiver Riddle SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) 307015192c6SRiver Riddle : rewriter(rewriter), 308015192c6SRiver Riddle unresolvedMaterializations(unresolvedMaterializations) {} 309b6eb26fdSRiver Riddle 310b6eb26fdSRiver Riddle /// This structure contains the information pertaining to an argument that has 311b6eb26fdSRiver Riddle /// been converted. 312b6eb26fdSRiver Riddle struct ConvertedArgInfo { 313b6eb26fdSRiver Riddle ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize, 314b6eb26fdSRiver Riddle Value castValue = nullptr) 315b6eb26fdSRiver Riddle : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {} 316b6eb26fdSRiver Riddle 317b6eb26fdSRiver Riddle /// The start index of in the new argument list that contains arguments that 318b6eb26fdSRiver Riddle /// replace the original. 319b6eb26fdSRiver Riddle unsigned newArgIdx; 320b6eb26fdSRiver Riddle 321b6eb26fdSRiver Riddle /// The number of arguments that replaced the original argument. 322b6eb26fdSRiver Riddle unsigned newArgSize; 323b6eb26fdSRiver Riddle 324b6eb26fdSRiver Riddle /// The cast value that was created to cast from the new arguments to the 325b6eb26fdSRiver Riddle /// old. This only used if 'newArgSize' > 1. 326b6eb26fdSRiver Riddle Value castValue; 327b6eb26fdSRiver Riddle }; 328b6eb26fdSRiver Riddle 329b6eb26fdSRiver Riddle /// This structure contains information pertaining to a block that has had its 330b6eb26fdSRiver Riddle /// signature converted. 331b6eb26fdSRiver Riddle struct ConvertedBlockInfo { 332ce254598SMatthias Springer ConvertedBlockInfo(Block *origBlock, const TypeConverter *converter) 333015192c6SRiver Riddle : origBlock(origBlock), converter(converter) {} 334b6eb26fdSRiver Riddle 335b6eb26fdSRiver Riddle /// The original block that was requested to have its signature converted. 336b6eb26fdSRiver Riddle Block *origBlock; 337b6eb26fdSRiver Riddle 338b6eb26fdSRiver Riddle /// The conversion information for each of the arguments. The information is 3394f81805aSKazu Hirata /// std::nullopt if the argument was dropped during conversion. 3400de16fafSRamkumar Ramachandra SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo; 341b6eb26fdSRiver Riddle 342b6eb26fdSRiver Riddle /// The type converter used to convert the arguments. 343ce254598SMatthias Springer const TypeConverter *converter; 344b6eb26fdSRiver Riddle }; 345b6eb26fdSRiver Riddle 346b6eb26fdSRiver Riddle /// Return if the signature of the given block has already been converted. 347b6eb26fdSRiver Riddle bool hasBeenConverted(Block *block) const { 348b6eb26fdSRiver Riddle return conversionInfo.count(block) || convertedBlocks.count(block); 349b6eb26fdSRiver Riddle } 350b6eb26fdSRiver Riddle 351b6eb26fdSRiver Riddle /// Set the type converter to use for the given region. 352ce254598SMatthias Springer void setConverter(Region *region, const TypeConverter *typeConverter) { 353b6eb26fdSRiver Riddle assert(typeConverter && "expected valid type converter"); 354b6eb26fdSRiver Riddle regionToConverter[region] = typeConverter; 355b6eb26fdSRiver Riddle } 356b6eb26fdSRiver Riddle 357b6eb26fdSRiver Riddle /// Return the type converter to use for the given region, or null if there 358b6eb26fdSRiver Riddle /// isn't one. 359ce254598SMatthias Springer const TypeConverter *getConverter(Region *region) { 360b6eb26fdSRiver Riddle return regionToConverter.lookup(region); 361b6eb26fdSRiver Riddle } 362b6eb26fdSRiver Riddle 363b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 364b6eb26fdSRiver Riddle // Rewrite Application 365b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 366b6eb26fdSRiver Riddle 367b6eb26fdSRiver Riddle /// Erase any rewrites registered for the blocks within the given operation 368b6eb26fdSRiver Riddle /// which is about to be removed. This merely drops the rewrites without 369b6eb26fdSRiver Riddle /// undoing them. 370b6eb26fdSRiver Riddle void notifyOpRemoved(Operation *op); 371b6eb26fdSRiver Riddle 372b6eb26fdSRiver Riddle /// Cleanup and undo any generated conversions for the arguments of block. 373b6eb26fdSRiver Riddle /// This method replaces the new block with the original, reverting the IR to 374b6eb26fdSRiver Riddle /// its original state. 375b6eb26fdSRiver Riddle void discardRewrites(Block *block); 376b6eb26fdSRiver Riddle 377b6eb26fdSRiver Riddle /// Fully replace uses of the old arguments with the new. 378b6eb26fdSRiver Riddle void applyRewrites(ConversionValueMapping &mapping); 379b6eb26fdSRiver Riddle 380b6eb26fdSRiver Riddle /// Materialize any necessary conversions for converted arguments that have 381b6eb26fdSRiver Riddle /// live users, using the provided `findLiveUser` to search for a user that 382b6eb26fdSRiver Riddle /// survives the conversion process. 383b6eb26fdSRiver Riddle LogicalResult 384b6eb26fdSRiver Riddle materializeLiveConversions(ConversionValueMapping &mapping, 385b6eb26fdSRiver Riddle OpBuilder &builder, 386b6eb26fdSRiver Riddle function_ref<Operation *(Value)> findLiveUser); 387b6eb26fdSRiver Riddle 388b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 389b6eb26fdSRiver Riddle // Conversion 390b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 391b6eb26fdSRiver Riddle 392b6eb26fdSRiver Riddle /// Attempt to convert the signature of the given block, if successful a new 393b6eb26fdSRiver Riddle /// block is returned containing the new arguments. Returns `block` if it did 394b6eb26fdSRiver Riddle /// not require conversion. 3950409eb28SAlex Zinenko FailureOr<Block *> 396ce254598SMatthias Springer convertSignature(Block *block, const TypeConverter *converter, 3970409eb28SAlex Zinenko ConversionValueMapping &mapping, 3980409eb28SAlex Zinenko SmallVectorImpl<BlockArgument> &argReplacements); 399b6eb26fdSRiver Riddle 400b6eb26fdSRiver Riddle /// Apply the given signature conversion on the given block. The new block 401b6eb26fdSRiver Riddle /// containing the updated signature is returned. If no conversions were 402b6eb26fdSRiver Riddle /// necessary, e.g. if the block has no arguments, `block` is returned. 403b6eb26fdSRiver Riddle /// `converter` is used to generate any necessary cast operations that 404b6eb26fdSRiver Riddle /// translate between the origin argument types and those specified in the 405b6eb26fdSRiver Riddle /// signature conversion. 406b6eb26fdSRiver Riddle Block *applySignatureConversion( 407ce254598SMatthias Springer Block *block, const TypeConverter *converter, 408b6eb26fdSRiver Riddle TypeConverter::SignatureConversion &signatureConversion, 4090409eb28SAlex Zinenko ConversionValueMapping &mapping, 4100409eb28SAlex Zinenko SmallVectorImpl<BlockArgument> &argReplacements); 411b6eb26fdSRiver Riddle 412b6eb26fdSRiver Riddle /// Insert a new conversion into the cache. 413b6eb26fdSRiver Riddle void insertConversion(Block *newBlock, ConvertedBlockInfo &&info); 414b6eb26fdSRiver Riddle 415b6eb26fdSRiver Riddle /// A collection of blocks that have had their arguments converted. This is a 416b6eb26fdSRiver Riddle /// map from the new replacement block, back to the original block. 417b6eb26fdSRiver Riddle llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo; 418b6eb26fdSRiver Riddle 419b6eb26fdSRiver Riddle /// The set of original blocks that were converted. 420b6eb26fdSRiver Riddle DenseSet<Block *> convertedBlocks; 421b6eb26fdSRiver Riddle 422b6eb26fdSRiver Riddle /// A mapping from valid regions, to those containing the original blocks of a 423b6eb26fdSRiver Riddle /// conversion. 424b6eb26fdSRiver Riddle DenseMap<Region *, std::unique_ptr<Region>> regionMapping; 425b6eb26fdSRiver Riddle 426b6eb26fdSRiver Riddle /// A mapping of regions to type converters that should be used when 427b6eb26fdSRiver Riddle /// converting the arguments of blocks within that region. 428ce254598SMatthias Springer DenseMap<Region *, const TypeConverter *> regionToConverter; 429b6eb26fdSRiver Riddle 430b6eb26fdSRiver Riddle /// The pattern rewriter to use when materializing conversions. 431b6eb26fdSRiver Riddle PatternRewriter &rewriter; 432015192c6SRiver Riddle 433015192c6SRiver Riddle /// An ordered set of unresolved materializations during conversion. 434015192c6SRiver Riddle SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations; 435b6eb26fdSRiver Riddle }; 436be0a7e9fSMehdi Amini } // namespace 437b6eb26fdSRiver Riddle 438b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 439b6eb26fdSRiver Riddle // Rewrite Application 440b6eb26fdSRiver Riddle 441b6eb26fdSRiver Riddle void ArgConverter::notifyOpRemoved(Operation *op) { 442b6eb26fdSRiver Riddle if (conversionInfo.empty()) 443b6eb26fdSRiver Riddle return; 444b6eb26fdSRiver Riddle 445b6eb26fdSRiver Riddle for (Region ®ion : op->getRegions()) { 446b6eb26fdSRiver Riddle for (Block &block : region) { 447b6eb26fdSRiver Riddle // Drop any rewrites from within. 448b6eb26fdSRiver Riddle for (Operation &nestedOp : block) 449b6eb26fdSRiver Riddle if (nestedOp.getNumRegions()) 450b6eb26fdSRiver Riddle notifyOpRemoved(&nestedOp); 451b6eb26fdSRiver Riddle 452b6eb26fdSRiver Riddle // Check if this block was converted. 4537dad59f0SMehdi Amini auto *it = conversionInfo.find(&block); 454b6eb26fdSRiver Riddle if (it == conversionInfo.end()) 455b6eb26fdSRiver Riddle continue; 456b6eb26fdSRiver Riddle 457b6eb26fdSRiver Riddle // Drop all uses of the original arguments and delete the original block. 458b6eb26fdSRiver Riddle Block *origBlock = it->second.origBlock; 459b6eb26fdSRiver Riddle for (BlockArgument arg : origBlock->getArguments()) 460b6eb26fdSRiver Riddle arg.dropAllUses(); 461b6eb26fdSRiver Riddle conversionInfo.erase(it); 462b6eb26fdSRiver Riddle } 463b6eb26fdSRiver Riddle } 464b6eb26fdSRiver Riddle } 465b6eb26fdSRiver Riddle 466b6eb26fdSRiver Riddle void ArgConverter::discardRewrites(Block *block) { 4677dad59f0SMehdi Amini auto *it = conversionInfo.find(block); 468b6eb26fdSRiver Riddle if (it == conversionInfo.end()) 469b6eb26fdSRiver Riddle return; 470b6eb26fdSRiver Riddle Block *origBlock = it->second.origBlock; 471b6eb26fdSRiver Riddle 472b6eb26fdSRiver Riddle // Drop all uses of the new block arguments and replace uses of the new block. 473b6eb26fdSRiver Riddle for (int i = block->getNumArguments() - 1; i >= 0; --i) 474b6eb26fdSRiver Riddle block->getArgument(i).dropAllUses(); 475b6eb26fdSRiver Riddle block->replaceAllUsesWith(origBlock); 476b6eb26fdSRiver Riddle 477b6eb26fdSRiver Riddle // Move the operations back the original block and the delete the new block. 478b6eb26fdSRiver Riddle origBlock->getOperations().splice(origBlock->end(), block->getOperations()); 479b6eb26fdSRiver Riddle origBlock->moveBefore(block); 480b6eb26fdSRiver Riddle block->erase(); 481b6eb26fdSRiver Riddle 482b6eb26fdSRiver Riddle convertedBlocks.erase(origBlock); 483b6eb26fdSRiver Riddle conversionInfo.erase(it); 484b6eb26fdSRiver Riddle } 485b6eb26fdSRiver Riddle 486b6eb26fdSRiver Riddle void ArgConverter::applyRewrites(ConversionValueMapping &mapping) { 487b6eb26fdSRiver Riddle for (auto &info : conversionInfo) { 488b6eb26fdSRiver Riddle ConvertedBlockInfo &blockInfo = info.second; 489b6eb26fdSRiver Riddle Block *origBlock = blockInfo.origBlock; 490b6eb26fdSRiver Riddle 491b6eb26fdSRiver Riddle // Process the remapping for each of the original arguments. 492b6eb26fdSRiver Riddle for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) { 4930de16fafSRamkumar Ramachandra std::optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i]; 494b6eb26fdSRiver Riddle BlockArgument origArg = origBlock->getArgument(i); 495b6eb26fdSRiver Riddle 496b6eb26fdSRiver Riddle // Handle the case of a 1->0 value mapping. 497b6eb26fdSRiver Riddle if (!argInfo) { 498015192c6SRiver Riddle if (Value newArg = mapping.lookupOrNull(origArg, origArg.getType())) 499b6eb26fdSRiver Riddle origArg.replaceAllUsesWith(newArg); 500b6eb26fdSRiver Riddle continue; 501b6eb26fdSRiver Riddle } 502b6eb26fdSRiver Riddle 503b6eb26fdSRiver Riddle // Otherwise this is a 1->1+ value mapping. 504b6eb26fdSRiver Riddle Value castValue = argInfo->castValue; 505b6eb26fdSRiver Riddle assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping"); 506b6eb26fdSRiver Riddle 507b6eb26fdSRiver Riddle // If the argument is still used, replace it with the generated cast. 508015192c6SRiver Riddle if (!origArg.use_empty()) { 509015192c6SRiver Riddle origArg.replaceAllUsesWith( 510015192c6SRiver Riddle mapping.lookupOrDefault(castValue, origArg.getType())); 511015192c6SRiver Riddle } 512b6eb26fdSRiver Riddle } 513b6eb26fdSRiver Riddle } 514b6eb26fdSRiver Riddle } 515b6eb26fdSRiver Riddle 516b6eb26fdSRiver Riddle LogicalResult ArgConverter::materializeLiveConversions( 517b6eb26fdSRiver Riddle ConversionValueMapping &mapping, OpBuilder &builder, 518b6eb26fdSRiver Riddle function_ref<Operation *(Value)> findLiveUser) { 519b6eb26fdSRiver Riddle for (auto &info : conversionInfo) { 520b6eb26fdSRiver Riddle Block *newBlock = info.first; 521b6eb26fdSRiver Riddle ConvertedBlockInfo &blockInfo = info.second; 522b6eb26fdSRiver Riddle Block *origBlock = blockInfo.origBlock; 523b6eb26fdSRiver Riddle 524b6eb26fdSRiver Riddle // Process the remapping for each of the original arguments. 525b6eb26fdSRiver Riddle for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) { 526b6eb26fdSRiver Riddle // If the type of this argument changed and the argument is still live, we 527b6eb26fdSRiver Riddle // need to materialize a conversion. 528b6eb26fdSRiver Riddle BlockArgument origArg = origBlock->getArgument(i); 529015192c6SRiver Riddle if (mapping.lookupOrNull(origArg, origArg.getType())) 530b6eb26fdSRiver Riddle continue; 531b6eb26fdSRiver Riddle Operation *liveUser = findLiveUser(origArg); 532b6eb26fdSRiver Riddle if (!liveUser) 533b6eb26fdSRiver Riddle continue; 534b6eb26fdSRiver Riddle 535015192c6SRiver Riddle Value replacementValue = mapping.lookupOrDefault(origArg); 536015192c6SRiver Riddle bool isDroppedArg = replacementValue == origArg; 537015192c6SRiver Riddle if (isDroppedArg) 538b6eb26fdSRiver Riddle rewriter.setInsertionPointToStart(newBlock); 539015192c6SRiver Riddle else 540015192c6SRiver Riddle rewriter.setInsertionPointAfterValue(replacementValue); 541015192c6SRiver Riddle Value newArg; 542015192c6SRiver Riddle if (blockInfo.converter) { 543015192c6SRiver Riddle newArg = blockInfo.converter->materializeSourceConversion( 544b6eb26fdSRiver Riddle rewriter, origArg.getLoc(), origArg.getType(), 545015192c6SRiver Riddle isDroppedArg ? ValueRange() : ValueRange(replacementValue)); 546015192c6SRiver Riddle assert((!newArg || newArg.getType() == origArg.getType()) && 547015192c6SRiver Riddle "materialization hook did not provide a value of the expected " 548015192c6SRiver Riddle "type"); 549015192c6SRiver Riddle } 550b6eb26fdSRiver Riddle if (!newArg) { 551b6eb26fdSRiver Riddle InFlightDiagnostic diag = 552b6eb26fdSRiver Riddle emitError(origArg.getLoc()) 553b6eb26fdSRiver Riddle << "failed to materialize conversion for block argument #" << i 554b6eb26fdSRiver Riddle << " that remained live after conversion, type was " 555b6eb26fdSRiver Riddle << origArg.getType(); 556b6eb26fdSRiver Riddle if (!isDroppedArg) 557015192c6SRiver Riddle diag << ", with target type " << replacementValue.getType(); 558b6eb26fdSRiver Riddle diag.attachNote(liveUser->getLoc()) 559b6eb26fdSRiver Riddle << "see existing live user here: " << *liveUser; 560b6eb26fdSRiver Riddle return failure(); 561b6eb26fdSRiver Riddle } 562b6eb26fdSRiver Riddle mapping.map(origArg, newArg); 563b6eb26fdSRiver Riddle } 564b6eb26fdSRiver Riddle } 565b6eb26fdSRiver Riddle return success(); 566b6eb26fdSRiver Riddle } 567b6eb26fdSRiver Riddle 568b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 569b6eb26fdSRiver Riddle // Conversion 570b6eb26fdSRiver Riddle 5710409eb28SAlex Zinenko FailureOr<Block *> ArgConverter::convertSignature( 572ce254598SMatthias Springer Block *block, const TypeConverter *converter, 573ce254598SMatthias Springer ConversionValueMapping &mapping, 5740409eb28SAlex Zinenko SmallVectorImpl<BlockArgument> &argReplacements) { 575b6eb26fdSRiver Riddle // Check if the block was already converted. If the block is detached, 576b6eb26fdSRiver Riddle // conservatively assume it is going to be deleted. 577b6eb26fdSRiver Riddle if (hasBeenConverted(block) || !block->getParent()) 578b6eb26fdSRiver Riddle return block; 579015192c6SRiver Riddle // If a converter wasn't provided, and the block wasn't already converted, 580015192c6SRiver Riddle // there is nothing we can do. 581015192c6SRiver Riddle if (!converter) 582015192c6SRiver Riddle return failure(); 583b6eb26fdSRiver Riddle 584b6eb26fdSRiver Riddle // Try to convert the signature for the block with the provided converter. 585015192c6SRiver Riddle if (auto conversion = converter->convertBlockSignature(block)) 5860409eb28SAlex Zinenko return applySignatureConversion(block, converter, *conversion, mapping, 5870409eb28SAlex Zinenko argReplacements); 588b6eb26fdSRiver Riddle return failure(); 589b6eb26fdSRiver Riddle } 590b6eb26fdSRiver Riddle 591b6eb26fdSRiver Riddle Block *ArgConverter::applySignatureConversion( 592ce254598SMatthias Springer Block *block, const TypeConverter *converter, 593b6eb26fdSRiver Riddle TypeConverter::SignatureConversion &signatureConversion, 5940409eb28SAlex Zinenko ConversionValueMapping &mapping, 5950409eb28SAlex Zinenko SmallVectorImpl<BlockArgument> &argReplacements) { 596b6eb26fdSRiver Riddle // If no arguments are being changed or added, there is nothing to do. 597b6eb26fdSRiver Riddle unsigned origArgCount = block->getNumArguments(); 598b6eb26fdSRiver Riddle auto convertedTypes = signatureConversion.getConvertedTypes(); 599b6eb26fdSRiver Riddle if (origArgCount == 0 && convertedTypes.empty()) 600b6eb26fdSRiver Riddle return block; 601b6eb26fdSRiver Riddle 602b6eb26fdSRiver Riddle // Split the block at the beginning to get a new block to use for the updated 603b6eb26fdSRiver Riddle // signature. 604b6eb26fdSRiver Riddle Block *newBlock = block->splitBlock(block->begin()); 605b6eb26fdSRiver Riddle block->replaceAllUsesWith(newBlock); 606b6eb26fdSRiver Riddle 6070c46a918SNandor Licker // Map all new arguments to the location of the argument they originate from. 608e084679fSRiver Riddle SmallVector<Location> newLocs(convertedTypes.size(), 609e084679fSRiver Riddle rewriter.getUnknownLoc()); 6100c46a918SNandor Licker for (unsigned i = 0; i < origArgCount; ++i) { 6110c46a918SNandor Licker auto inputMap = signatureConversion.getInputMapping(i); 6120c46a918SNandor Licker if (!inputMap || inputMap->replacementValue) 6130c46a918SNandor Licker continue; 6140c46a918SNandor Licker Location origLoc = block->getArgument(i).getLoc(); 6150c46a918SNandor Licker for (unsigned j = 0; j < inputMap->size; ++j) 6160c46a918SNandor Licker newLocs[inputMap->inputNo + j] = origLoc; 6170c46a918SNandor Licker } 6180c46a918SNandor Licker 619e084679fSRiver Riddle SmallVector<Value, 4> newArgRange( 620e084679fSRiver Riddle newBlock->addArguments(convertedTypes, newLocs)); 621b6eb26fdSRiver Riddle ArrayRef<Value> newArgs(newArgRange); 622b6eb26fdSRiver Riddle 623b6eb26fdSRiver Riddle // Remap each of the original arguments as determined by the signature 624b6eb26fdSRiver Riddle // conversion. 625b6eb26fdSRiver Riddle ConvertedBlockInfo info(block, converter); 626b6eb26fdSRiver Riddle info.argInfo.resize(origArgCount); 627b6eb26fdSRiver Riddle 628b6eb26fdSRiver Riddle OpBuilder::InsertionGuard guard(rewriter); 629b6eb26fdSRiver Riddle rewriter.setInsertionPointToStart(newBlock); 630b6eb26fdSRiver Riddle for (unsigned i = 0; i != origArgCount; ++i) { 631b6eb26fdSRiver Riddle auto inputMap = signatureConversion.getInputMapping(i); 632b6eb26fdSRiver Riddle if (!inputMap) 633b6eb26fdSRiver Riddle continue; 634b6eb26fdSRiver Riddle BlockArgument origArg = block->getArgument(i); 635b6eb26fdSRiver Riddle 636b6eb26fdSRiver Riddle // If inputMap->replacementValue is not nullptr, then the argument is 637b6eb26fdSRiver Riddle // dropped and a replacement value is provided to be the remappedValue. 638b6eb26fdSRiver Riddle if (inputMap->replacementValue) { 639b6eb26fdSRiver Riddle assert(inputMap->size == 0 && 640b6eb26fdSRiver Riddle "invalid to provide a replacement value when the argument isn't " 641b6eb26fdSRiver Riddle "dropped"); 642b6eb26fdSRiver Riddle mapping.map(origArg, inputMap->replacementValue); 6430409eb28SAlex Zinenko argReplacements.push_back(origArg); 644b6eb26fdSRiver Riddle continue; 645b6eb26fdSRiver Riddle } 646b6eb26fdSRiver Riddle 647015192c6SRiver Riddle // Otherwise, this is a 1->1+ mapping. 648b6eb26fdSRiver Riddle auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size); 649aa6eb2afSKareemErgawy-TomTom Value newArg; 650aa6eb2afSKareemErgawy-TomTom 651aa6eb2afSKareemErgawy-TomTom // If this is a 1->1 mapping and the types of new and replacement arguments 652aa6eb2afSKareemErgawy-TomTom // match (i.e. it's an identity map), then the argument is mapped to its 653aa6eb2afSKareemErgawy-TomTom // original type. 654015192c6SRiver Riddle // FIXME: We simply pass through the replacement argument if there wasn't a 655015192c6SRiver Riddle // converter, which isn't great as it allows implicit type conversions to 656015192c6SRiver Riddle // appear. We should properly restructure this code to handle cases where a 657015192c6SRiver Riddle // converter isn't provided and also to properly handle the case where an 658015192c6SRiver Riddle // argument materialization is actually a temporary source materialization 659015192c6SRiver Riddle // (e.g. in the case of 1->N). 660015192c6SRiver Riddle if (replArgs.size() == 1 && 661015192c6SRiver Riddle (!converter || replArgs[0].getType() == origArg.getType())) { 66201b55f16SRiver Riddle newArg = replArgs.front(); 663015192c6SRiver Riddle } else { 664015192c6SRiver Riddle Type origOutputType = origArg.getType(); 665aa6eb2afSKareemErgawy-TomTom 666015192c6SRiver Riddle // Legalize the argument output type. 667015192c6SRiver Riddle Type outputType = origOutputType; 668015192c6SRiver Riddle if (Type legalOutputType = converter->convertType(outputType)) 669015192c6SRiver Riddle outputType = legalOutputType; 670015192c6SRiver Riddle 671015192c6SRiver Riddle newArg = buildUnresolvedArgumentMaterialization( 672015192c6SRiver Riddle rewriter, origArg.getLoc(), replArgs, origOutputType, outputType, 673015192c6SRiver Riddle converter, unresolvedMaterializations); 674b6eb26fdSRiver Riddle } 675015192c6SRiver Riddle 676b6eb26fdSRiver Riddle mapping.map(origArg, newArg); 6770409eb28SAlex Zinenko argReplacements.push_back(origArg); 678b6eb26fdSRiver Riddle info.argInfo[i] = 679b6eb26fdSRiver Riddle ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg); 680b6eb26fdSRiver Riddle } 681b6eb26fdSRiver Riddle 682b6eb26fdSRiver Riddle // Remove the original block from the region and return the new one. 683b6eb26fdSRiver Riddle insertConversion(newBlock, std::move(info)); 684b6eb26fdSRiver Riddle return newBlock; 685b6eb26fdSRiver Riddle } 686b6eb26fdSRiver Riddle 687b6eb26fdSRiver Riddle void ArgConverter::insertConversion(Block *newBlock, 688b6eb26fdSRiver Riddle ConvertedBlockInfo &&info) { 689b6eb26fdSRiver Riddle // Get a region to insert the old block. 690b6eb26fdSRiver Riddle Region *region = newBlock->getParent(); 691b6eb26fdSRiver Riddle std::unique_ptr<Region> &mappedRegion = regionMapping[region]; 692b6eb26fdSRiver Riddle if (!mappedRegion) 693b6eb26fdSRiver Riddle mappedRegion = std::make_unique<Region>(region->getParentOp()); 694b6eb26fdSRiver Riddle 695b6eb26fdSRiver Riddle // Move the original block to the mapped region and emplace the conversion. 696b6eb26fdSRiver Riddle mappedRegion->getBlocks().splice(mappedRegion->end(), region->getBlocks(), 697b6eb26fdSRiver Riddle info.origBlock->getIterator()); 698b6eb26fdSRiver Riddle convertedBlocks.insert(info.origBlock); 699b6eb26fdSRiver Riddle conversionInfo.insert({newBlock, std::move(info)}); 700b6eb26fdSRiver Riddle } 701b6eb26fdSRiver Riddle 702b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 7038faefe36SMatthias Springer // IR rewrites 7048faefe36SMatthias Springer //===----------------------------------------------------------------------===// 7058faefe36SMatthias Springer 7068faefe36SMatthias Springer namespace { 7078faefe36SMatthias Springer /// An IR rewrite that can be committed (upon success) or rolled back (upon 7088faefe36SMatthias Springer /// failure). 7098faefe36SMatthias Springer /// 7108faefe36SMatthias Springer /// The dialect conversion keeps track of IR modifications (requested by the 7118faefe36SMatthias Springer /// user through the rewriter API) in `IRRewrite` objects. Some kind of rewrites 7128faefe36SMatthias Springer /// are directly applied to the IR as the rewriter API is used, some are applied 7138faefe36SMatthias Springer /// partially, and some are delayed until the `IRRewrite` objects are committed. 7148faefe36SMatthias Springer class IRRewrite { 7158faefe36SMatthias Springer public: 7168faefe36SMatthias Springer /// The kind of the rewrite. Rewrites can be undone if the conversion fails. 717e214f004SMatthias Springer /// Enum values are ordered, so that they can be used in `classof`: first all 718e214f004SMatthias Springer /// block rewrites, then all operation rewrites. 7198faefe36SMatthias Springer enum class Kind { 720e214f004SMatthias Springer // Block rewrites 7218faefe36SMatthias Springer CreateBlock, 7228faefe36SMatthias Springer EraseBlock, 7238faefe36SMatthias Springer InlineBlock, 7248faefe36SMatthias Springer MoveBlock, 7258faefe36SMatthias Springer SplitBlock, 7268f4cd2c7SMatthias Springer BlockTypeConversion, 727e214f004SMatthias Springer // Operation rewrites 728e214f004SMatthias Springer MoveOperation, 729e214f004SMatthias Springer ModifyOperation 7308faefe36SMatthias Springer }; 7318faefe36SMatthias Springer 7328faefe36SMatthias Springer virtual ~IRRewrite() = default; 7338faefe36SMatthias Springer 7348faefe36SMatthias Springer /// Roll back the rewrite. 7358faefe36SMatthias Springer virtual void rollback() = 0; 7368faefe36SMatthias Springer 7378faefe36SMatthias Springer /// Commit the rewrite. 7388faefe36SMatthias Springer virtual void commit() {} 7398faefe36SMatthias Springer 7408faefe36SMatthias Springer Kind getKind() const { return kind; } 7418faefe36SMatthias Springer 7428faefe36SMatthias Springer static bool classof(const IRRewrite *rewrite) { return true; } 7438faefe36SMatthias Springer 7448faefe36SMatthias Springer protected: 7458faefe36SMatthias Springer IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl) 7468faefe36SMatthias Springer : kind(kind), rewriterImpl(rewriterImpl) {} 7478faefe36SMatthias Springer 7488faefe36SMatthias Springer const Kind kind; 7498faefe36SMatthias Springer ConversionPatternRewriterImpl &rewriterImpl; 7508faefe36SMatthias Springer }; 7518faefe36SMatthias Springer 7528faefe36SMatthias Springer /// A block rewrite. 7538faefe36SMatthias Springer class BlockRewrite : public IRRewrite { 7548faefe36SMatthias Springer public: 7558faefe36SMatthias Springer /// Return the block that this rewrite operates on. 7568faefe36SMatthias Springer Block *getBlock() const { return block; } 7578faefe36SMatthias Springer 7588faefe36SMatthias Springer static bool classof(const IRRewrite *rewrite) { 7598faefe36SMatthias Springer return rewrite->getKind() >= Kind::CreateBlock && 7608faefe36SMatthias Springer rewrite->getKind() <= Kind::BlockTypeConversion; 7618faefe36SMatthias Springer } 7628faefe36SMatthias Springer 7638faefe36SMatthias Springer protected: 7648faefe36SMatthias Springer BlockRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl, 7658faefe36SMatthias Springer Block *block) 7668faefe36SMatthias Springer : IRRewrite(kind, rewriterImpl), block(block) {} 7678faefe36SMatthias Springer 7688faefe36SMatthias Springer // The block that this rewrite operates on. 7698faefe36SMatthias Springer Block *block; 7708faefe36SMatthias Springer }; 7718faefe36SMatthias Springer 7728faefe36SMatthias Springer /// Creation of a block. Block creations are immediately reflected in the IR. 7738faefe36SMatthias Springer /// There is no extra work to commit the rewrite. During rollback, the newly 7748faefe36SMatthias Springer /// created block is erased. 7758faefe36SMatthias Springer class CreateBlockRewrite : public BlockRewrite { 7768faefe36SMatthias Springer public: 7778faefe36SMatthias Springer CreateBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block) 7788faefe36SMatthias Springer : BlockRewrite(Kind::CreateBlock, rewriterImpl, block) {} 7798faefe36SMatthias Springer 7808faefe36SMatthias Springer static bool classof(const IRRewrite *rewrite) { 7818faefe36SMatthias Springer return rewrite->getKind() == Kind::CreateBlock; 7828faefe36SMatthias Springer } 7838faefe36SMatthias Springer 7848faefe36SMatthias Springer void rollback() override { 7858faefe36SMatthias Springer // Unlink all of the operations within this block, they will be deleted 7868faefe36SMatthias Springer // separately. 7878faefe36SMatthias Springer auto &blockOps = block->getOperations(); 7888faefe36SMatthias Springer while (!blockOps.empty()) 7898faefe36SMatthias Springer blockOps.remove(blockOps.begin()); 7908faefe36SMatthias Springer block->dropAllDefinedValueUses(); 7918faefe36SMatthias Springer block->erase(); 7928faefe36SMatthias Springer } 7938faefe36SMatthias Springer }; 7948faefe36SMatthias Springer 7958faefe36SMatthias Springer /// Erasure of a block. Block erasures are partially reflected in the IR. Erased 7968faefe36SMatthias Springer /// blocks are immediately unlinked, but only erased when the rewrite is 7978faefe36SMatthias Springer /// committed. This makes it easier to rollback a block erasure: the block is 7988faefe36SMatthias Springer /// simply inserted into its original location. 7998faefe36SMatthias Springer class EraseBlockRewrite : public BlockRewrite { 8008faefe36SMatthias Springer public: 8018faefe36SMatthias Springer EraseBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, 8028faefe36SMatthias Springer Region *region, Block *insertBeforeBlock) 8038faefe36SMatthias Springer : BlockRewrite(Kind::EraseBlock, rewriterImpl, block), region(region), 8048faefe36SMatthias Springer insertBeforeBlock(insertBeforeBlock) {} 8058faefe36SMatthias Springer 8068faefe36SMatthias Springer static bool classof(const IRRewrite *rewrite) { 8078faefe36SMatthias Springer return rewrite->getKind() == Kind::EraseBlock; 8088faefe36SMatthias Springer } 8098faefe36SMatthias Springer 8108faefe36SMatthias Springer ~EraseBlockRewrite() override { 8118faefe36SMatthias Springer assert(!block && "rewrite was neither rolled back nor committed"); 8128faefe36SMatthias Springer } 8138faefe36SMatthias Springer 8148faefe36SMatthias Springer void rollback() override { 8158faefe36SMatthias Springer // The block (owned by this rewrite) was not actually erased yet. It was 8168faefe36SMatthias Springer // just unlinked. Put it back into its original position. 8178faefe36SMatthias Springer assert(block && "expected block"); 8188faefe36SMatthias Springer auto &blockList = region->getBlocks(); 8198faefe36SMatthias Springer Region::iterator before = insertBeforeBlock 8208faefe36SMatthias Springer ? Region::iterator(insertBeforeBlock) 8218faefe36SMatthias Springer : blockList.end(); 8228faefe36SMatthias Springer blockList.insert(before, block); 8238faefe36SMatthias Springer block = nullptr; 8248faefe36SMatthias Springer } 8258faefe36SMatthias Springer 8268faefe36SMatthias Springer void commit() override { 8278faefe36SMatthias Springer // Erase the block. 8288faefe36SMatthias Springer assert(block && "expected block"); 8298faefe36SMatthias Springer delete block; 8308faefe36SMatthias Springer block = nullptr; 8318faefe36SMatthias Springer } 8328faefe36SMatthias Springer 8338faefe36SMatthias Springer private: 8348faefe36SMatthias Springer // The region in which this block was previously contained. 8358faefe36SMatthias Springer Region *region; 8368faefe36SMatthias Springer 8378faefe36SMatthias Springer // The original successor of this block before it was unlinked. "nullptr" if 8388faefe36SMatthias Springer // this block was the only block in the region. 8398faefe36SMatthias Springer Block *insertBeforeBlock; 8408faefe36SMatthias Springer }; 8418faefe36SMatthias Springer 8428faefe36SMatthias Springer /// Inlining of a block. This rewrite is immediately reflected in the IR. 8438faefe36SMatthias Springer /// Note: This rewrite represents only the inlining of the operations. The 8448faefe36SMatthias Springer /// erasure of the inlined block is a separate rewrite. 8458faefe36SMatthias Springer class InlineBlockRewrite : public BlockRewrite { 8468faefe36SMatthias Springer public: 8478faefe36SMatthias Springer InlineBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, 8488faefe36SMatthias Springer Block *sourceBlock, Block::iterator before) 8498faefe36SMatthias Springer : BlockRewrite(Kind::InlineBlock, rewriterImpl, block), 8508faefe36SMatthias Springer sourceBlock(sourceBlock), 8518faefe36SMatthias Springer firstInlinedInst(sourceBlock->empty() ? nullptr 8528faefe36SMatthias Springer : &sourceBlock->front()), 8538faefe36SMatthias Springer lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) { 8548faefe36SMatthias Springer } 8558faefe36SMatthias Springer 8568faefe36SMatthias Springer static bool classof(const IRRewrite *rewrite) { 8578faefe36SMatthias Springer return rewrite->getKind() == Kind::InlineBlock; 8588faefe36SMatthias Springer } 8598faefe36SMatthias Springer 8608faefe36SMatthias Springer void rollback() override { 8618faefe36SMatthias Springer // Put the operations from the destination block (owned by the rewrite) 8628faefe36SMatthias Springer // back into the source block. 8638faefe36SMatthias Springer if (firstInlinedInst) { 8648faefe36SMatthias Springer assert(lastInlinedInst && "expected operation"); 8658faefe36SMatthias Springer sourceBlock->getOperations().splice(sourceBlock->begin(), 8668faefe36SMatthias Springer block->getOperations(), 8678faefe36SMatthias Springer Block::iterator(firstInlinedInst), 8688faefe36SMatthias Springer ++Block::iterator(lastInlinedInst)); 8698faefe36SMatthias Springer } 8708faefe36SMatthias Springer } 8718faefe36SMatthias Springer 8728faefe36SMatthias Springer private: 8738faefe36SMatthias Springer // The block that originally contained the operations. 8748faefe36SMatthias Springer Block *sourceBlock; 8758faefe36SMatthias Springer 8768faefe36SMatthias Springer // The first inlined operation. 8778faefe36SMatthias Springer Operation *firstInlinedInst; 8788faefe36SMatthias Springer 8798faefe36SMatthias Springer // The last inlined operation. 8808faefe36SMatthias Springer Operation *lastInlinedInst; 8818faefe36SMatthias Springer }; 8828faefe36SMatthias Springer 8838faefe36SMatthias Springer /// Moving of a block. This rewrite is immediately reflected in the IR. 8848faefe36SMatthias Springer class MoveBlockRewrite : public BlockRewrite { 8858faefe36SMatthias Springer public: 8868faefe36SMatthias Springer MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, 8878faefe36SMatthias Springer Region *region, Block *insertBeforeBlock) 8888faefe36SMatthias Springer : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region), 8898faefe36SMatthias Springer insertBeforeBlock(insertBeforeBlock) {} 8908faefe36SMatthias Springer 8918faefe36SMatthias Springer static bool classof(const IRRewrite *rewrite) { 8928faefe36SMatthias Springer return rewrite->getKind() == Kind::MoveBlock; 8938faefe36SMatthias Springer } 8948faefe36SMatthias Springer 8958faefe36SMatthias Springer void rollback() override { 8968faefe36SMatthias Springer // Move the block back to its original position. 8978faefe36SMatthias Springer Region::iterator before = 8988faefe36SMatthias Springer insertBeforeBlock ? Region::iterator(insertBeforeBlock) : region->end(); 8998faefe36SMatthias Springer region->getBlocks().splice(before, block->getParent()->getBlocks(), block); 9008faefe36SMatthias Springer } 9018faefe36SMatthias Springer 9028faefe36SMatthias Springer private: 9038faefe36SMatthias Springer // The region in which this block was previously contained. 9048faefe36SMatthias Springer Region *region; 9058faefe36SMatthias Springer 9068faefe36SMatthias Springer // The original successor of this block before it was moved. "nullptr" if 9078faefe36SMatthias Springer // this block was the only block in the region. 9088faefe36SMatthias Springer Block *insertBeforeBlock; 9098faefe36SMatthias Springer }; 9108faefe36SMatthias Springer 9118faefe36SMatthias Springer /// Splitting of a block. This rewrite is immediately reflected in the IR. 9128faefe36SMatthias Springer class SplitBlockRewrite : public BlockRewrite { 9138faefe36SMatthias Springer public: 9148faefe36SMatthias Springer SplitBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, 9158faefe36SMatthias Springer Block *originalBlock) 9168faefe36SMatthias Springer : BlockRewrite(Kind::SplitBlock, rewriterImpl, block), 9178faefe36SMatthias Springer originalBlock(originalBlock) {} 9188faefe36SMatthias Springer 9198faefe36SMatthias Springer static bool classof(const IRRewrite *rewrite) { 9208faefe36SMatthias Springer return rewrite->getKind() == Kind::SplitBlock; 9218faefe36SMatthias Springer } 9228faefe36SMatthias Springer 9238faefe36SMatthias Springer void rollback() override { 9248faefe36SMatthias Springer // Merge back the block that was split out. 9258faefe36SMatthias Springer originalBlock->getOperations().splice(originalBlock->end(), 9268faefe36SMatthias Springer block->getOperations()); 9278faefe36SMatthias Springer block->dropAllDefinedValueUses(); 9288faefe36SMatthias Springer block->erase(); 9298faefe36SMatthias Springer } 9308faefe36SMatthias Springer 9318faefe36SMatthias Springer private: 9328faefe36SMatthias Springer // The original block from which this block was split. 9338faefe36SMatthias Springer Block *originalBlock; 9348faefe36SMatthias Springer }; 9358faefe36SMatthias Springer 9368faefe36SMatthias Springer /// Block type conversion. This rewrite is partially reflected in the IR. 9378faefe36SMatthias Springer class BlockTypeConversionRewrite : public BlockRewrite { 9388faefe36SMatthias Springer public: 9398faefe36SMatthias Springer BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl, 9408faefe36SMatthias Springer Block *block) 9418faefe36SMatthias Springer : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block) {} 9428faefe36SMatthias Springer 9438faefe36SMatthias Springer static bool classof(const IRRewrite *rewrite) { 9448faefe36SMatthias Springer return rewrite->getKind() == Kind::BlockTypeConversion; 9458faefe36SMatthias Springer } 9468faefe36SMatthias Springer 9478faefe36SMatthias Springer // TODO: Block type conversions are currently committed in 9488faefe36SMatthias Springer // `ArgConverter::applyRewrites`. This should be done in the "commit" method. 9498faefe36SMatthias Springer void rollback() override; 9508faefe36SMatthias Springer }; 9518f4cd2c7SMatthias Springer 9528f4cd2c7SMatthias Springer /// An operation rewrite. 9538f4cd2c7SMatthias Springer class OperationRewrite : public IRRewrite { 9548f4cd2c7SMatthias Springer public: 9558f4cd2c7SMatthias Springer /// Return the operation that this rewrite operates on. 9568f4cd2c7SMatthias Springer Operation *getOperation() const { return op; } 9578f4cd2c7SMatthias Springer 9588f4cd2c7SMatthias Springer static bool classof(const IRRewrite *rewrite) { 9598f4cd2c7SMatthias Springer return rewrite->getKind() >= Kind::MoveOperation && 960e214f004SMatthias Springer rewrite->getKind() <= Kind::ModifyOperation; 9618f4cd2c7SMatthias Springer } 9628f4cd2c7SMatthias Springer 9638f4cd2c7SMatthias Springer protected: 9648f4cd2c7SMatthias Springer OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl, 9658f4cd2c7SMatthias Springer Operation *op) 9668f4cd2c7SMatthias Springer : IRRewrite(kind, rewriterImpl), op(op) {} 9678f4cd2c7SMatthias Springer 9688f4cd2c7SMatthias Springer // The operation that this rewrite operates on. 9698f4cd2c7SMatthias Springer Operation *op; 9708f4cd2c7SMatthias Springer }; 9718f4cd2c7SMatthias Springer 9728f4cd2c7SMatthias Springer /// Moving of an operation. This rewrite is immediately reflected in the IR. 9738f4cd2c7SMatthias Springer class MoveOperationRewrite : public OperationRewrite { 9748f4cd2c7SMatthias Springer public: 9758f4cd2c7SMatthias Springer MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, 9768f4cd2c7SMatthias Springer Operation *op, Block *block, Operation *insertBeforeOp) 9778f4cd2c7SMatthias Springer : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block), 9788f4cd2c7SMatthias Springer insertBeforeOp(insertBeforeOp) {} 9798f4cd2c7SMatthias Springer 9808f4cd2c7SMatthias Springer static bool classof(const IRRewrite *rewrite) { 9818f4cd2c7SMatthias Springer return rewrite->getKind() == Kind::MoveOperation; 9828f4cd2c7SMatthias Springer } 9838f4cd2c7SMatthias Springer 9848f4cd2c7SMatthias Springer void rollback() override { 9858f4cd2c7SMatthias Springer // Move the operation back to its original position. 9868f4cd2c7SMatthias Springer Block::iterator before = 9878f4cd2c7SMatthias Springer insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end(); 9888f4cd2c7SMatthias Springer block->getOperations().splice(before, op->getBlock()->getOperations(), op); 9898f4cd2c7SMatthias Springer } 9908f4cd2c7SMatthias Springer 9918f4cd2c7SMatthias Springer private: 9928f4cd2c7SMatthias Springer // The block in which this operation was previously contained. 9938f4cd2c7SMatthias Springer Block *block; 9948f4cd2c7SMatthias Springer 9958f4cd2c7SMatthias Springer // The original successor of this operation before it was moved. "nullptr" if 9968f4cd2c7SMatthias Springer // this operation was the only operation in the region. 9978f4cd2c7SMatthias Springer Operation *insertBeforeOp; 9988f4cd2c7SMatthias Springer }; 999e214f004SMatthias Springer 1000e214f004SMatthias Springer /// In-place modification of an op. This rewrite is immediately reflected in 1001e214f004SMatthias Springer /// the IR. The previous state of the operation is stored in this object. 1002e214f004SMatthias Springer class ModifyOperationRewrite : public OperationRewrite { 1003e214f004SMatthias Springer public: 1004e214f004SMatthias Springer ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, 1005e214f004SMatthias Springer Operation *op) 1006e214f004SMatthias Springer : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op), 1007e214f004SMatthias Springer loc(op->getLoc()), attrs(op->getAttrDictionary()), 1008e214f004SMatthias Springer operands(op->operand_begin(), op->operand_end()), 1009*3a70335bSMatthias Springer successors(op->successor_begin(), op->successor_end()) { 1010*3a70335bSMatthias Springer if (OpaqueProperties prop = op->getPropertiesStorage()) { 1011*3a70335bSMatthias Springer // Make a copy of the properties. 1012*3a70335bSMatthias Springer propertiesStorage = operator new(op->getPropertiesStorageSize()); 1013*3a70335bSMatthias Springer OpaqueProperties propCopy(propertiesStorage); 1014*3a70335bSMatthias Springer op->getName().initOpProperties(propCopy, /*init=*/prop); 1015*3a70335bSMatthias Springer } 1016*3a70335bSMatthias Springer } 1017e214f004SMatthias Springer 1018e214f004SMatthias Springer static bool classof(const IRRewrite *rewrite) { 1019e214f004SMatthias Springer return rewrite->getKind() == Kind::ModifyOperation; 1020e214f004SMatthias Springer } 1021e214f004SMatthias Springer 1022*3a70335bSMatthias Springer ~ModifyOperationRewrite() override { 1023*3a70335bSMatthias Springer assert(!propertiesStorage && 1024*3a70335bSMatthias Springer "rewrite was neither committed nor rolled back"); 1025*3a70335bSMatthias Springer } 1026*3a70335bSMatthias Springer 1027*3a70335bSMatthias Springer void commit() override { 1028*3a70335bSMatthias Springer if (propertiesStorage) { 1029*3a70335bSMatthias Springer OpaqueProperties propCopy(propertiesStorage); 1030*3a70335bSMatthias Springer op->getName().destroyOpProperties(propCopy); 1031*3a70335bSMatthias Springer operator delete(propertiesStorage); 1032*3a70335bSMatthias Springer propertiesStorage = nullptr; 1033*3a70335bSMatthias Springer } 1034*3a70335bSMatthias Springer } 1035*3a70335bSMatthias Springer 1036e214f004SMatthias Springer void rollback() override { 1037e214f004SMatthias Springer op->setLoc(loc); 1038e214f004SMatthias Springer op->setAttrs(attrs); 1039e214f004SMatthias Springer op->setOperands(operands); 1040e214f004SMatthias Springer for (const auto &it : llvm::enumerate(successors)) 1041e214f004SMatthias Springer op->setSuccessor(it.value(), it.index()); 1042*3a70335bSMatthias Springer if (propertiesStorage) { 1043*3a70335bSMatthias Springer OpaqueProperties propCopy(propertiesStorage); 1044*3a70335bSMatthias Springer op->copyProperties(propCopy); 1045*3a70335bSMatthias Springer op->getName().destroyOpProperties(propCopy); 1046*3a70335bSMatthias Springer operator delete(propertiesStorage); 1047*3a70335bSMatthias Springer propertiesStorage = nullptr; 1048*3a70335bSMatthias Springer } 1049e214f004SMatthias Springer } 1050e214f004SMatthias Springer 1051e214f004SMatthias Springer private: 1052e214f004SMatthias Springer LocationAttr loc; 1053e214f004SMatthias Springer DictionaryAttr attrs; 1054e214f004SMatthias Springer SmallVector<Value, 8> operands; 1055e214f004SMatthias Springer SmallVector<Block *, 2> successors; 1056*3a70335bSMatthias Springer void *propertiesStorage = nullptr; 1057e214f004SMatthias Springer }; 10588faefe36SMatthias Springer } // namespace 10598faefe36SMatthias Springer 1060e214f004SMatthias Springer /// Return "true" if there is an operation rewrite that matches the specified 1061e214f004SMatthias Springer /// rewrite type and operation among the given rewrites. 1062e214f004SMatthias Springer template <typename RewriteTy, typename R> 1063e214f004SMatthias Springer static bool hasRewrite(R &&rewrites, Operation *op) { 1064e214f004SMatthias Springer return any_of(std::move(rewrites), [&](auto &rewrite) { 1065e214f004SMatthias Springer auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get()); 1066e214f004SMatthias Springer return rewriteTy && rewriteTy->getOperation() == op; 1067e214f004SMatthias Springer }); 1068e214f004SMatthias Springer } 1069e214f004SMatthias Springer 10708faefe36SMatthias Springer //===----------------------------------------------------------------------===// 1071b6eb26fdSRiver Riddle // ConversionPatternRewriterImpl 1072b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 1073b6eb26fdSRiver Riddle namespace mlir { 1074b6eb26fdSRiver Riddle namespace detail { 1075ea2d9383SMatthias Springer struct ConversionPatternRewriterImpl : public RewriterBase::Listener { 1076b8c6b152SChia-hung Duan explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter) 1077b8c6b152SChia-hung Duan : argConverter(rewriter, unresolvedMaterializations), 1078b8c6b152SChia-hung Duan notifyCallback(nullptr) {} 1079b6eb26fdSRiver Riddle 1080b6eb26fdSRiver Riddle /// Cleanup and destroy any generated rewrite operations. This method is 1081b6eb26fdSRiver Riddle /// invoked when the conversion process fails. 1082b6eb26fdSRiver Riddle void discardRewrites(); 1083b6eb26fdSRiver Riddle 1084b6eb26fdSRiver Riddle /// Apply all requested operation rewrites. This method is invoked when the 1085b6eb26fdSRiver Riddle /// conversion process succeeds. 1086b6eb26fdSRiver Riddle void applyRewrites(); 1087b6eb26fdSRiver Riddle 1088b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 1089b6eb26fdSRiver Riddle // State Management 1090b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 1091b6eb26fdSRiver Riddle 1092b6eb26fdSRiver Riddle /// Return the current state of the rewriter. 1093b6eb26fdSRiver Riddle RewriterState getCurrentState(); 1094b6eb26fdSRiver Riddle 1095b6eb26fdSRiver Riddle /// Reset the state of the rewriter to a previously saved point. 1096b6eb26fdSRiver Riddle void resetState(RewriterState state); 1097b6eb26fdSRiver Riddle 10988faefe36SMatthias Springer /// Append a rewrite. Rewrites are committed upon success and rolled back upon 10998faefe36SMatthias Springer /// failure. 11008faefe36SMatthias Springer template <typename RewriteTy, typename... Args> 11018faefe36SMatthias Springer void appendRewrite(Args &&...args) { 11028faefe36SMatthias Springer rewrites.push_back( 11038faefe36SMatthias Springer std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...)); 11048faefe36SMatthias Springer } 1105b6eb26fdSRiver Riddle 11068faefe36SMatthias Springer /// Undo the rewrites (motions, splits) one by one in reverse order until 11078faefe36SMatthias Springer /// "numRewritesToKeep" rewrites remains. 11088faefe36SMatthias Springer void undoRewrites(unsigned numRewritesToKeep = 0); 1109b6eb26fdSRiver Riddle 1110015192c6SRiver Riddle /// Remap the given values to those with potentially different types. Returns 1111015192c6SRiver Riddle /// success if the values could be remapped, failure otherwise. `valueDiagTag` 1112015192c6SRiver Riddle /// is the tag used when describing a value within a diagnostic, e.g. 1113015192c6SRiver Riddle /// "operand". 11140de16fafSRamkumar Ramachandra LogicalResult remapValues(StringRef valueDiagTag, 11150de16fafSRamkumar Ramachandra std::optional<Location> inputLoc, 1116015192c6SRiver Riddle PatternRewriter &rewriter, ValueRange values, 1117b6eb26fdSRiver Riddle SmallVectorImpl<Value> &remapped); 1118b6eb26fdSRiver Riddle 1119b6eb26fdSRiver Riddle /// Returns true if the given operation is ignored, and does not need to be 1120b6eb26fdSRiver Riddle /// converted. 1121b6eb26fdSRiver Riddle bool isOpIgnored(Operation *op) const; 1122b6eb26fdSRiver Riddle 1123b6eb26fdSRiver Riddle /// Recursively marks the nested operations under 'op' as ignored. This 1124b6eb26fdSRiver Riddle /// removes them from being considered for legalization. 1125b6eb26fdSRiver Riddle void markNestedOpsIgnored(Operation *op); 1126b6eb26fdSRiver Riddle 1127b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 1128b6eb26fdSRiver Riddle // Type Conversion 1129b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 1130b6eb26fdSRiver Riddle 1131b6eb26fdSRiver Riddle /// Convert the signature of the given block. 1132b6eb26fdSRiver Riddle FailureOr<Block *> convertBlockSignature( 1133ce254598SMatthias Springer Block *block, const TypeConverter *converter, 1134b6eb26fdSRiver Riddle TypeConverter::SignatureConversion *conversion = nullptr); 1135b6eb26fdSRiver Riddle 1136223dcdcfSSean Silva /// Apply a signature conversion on the given region, using `converter` for 1137223dcdcfSSean Silva /// materializations if not null. 1138b6eb26fdSRiver Riddle Block * 1139b6eb26fdSRiver Riddle applySignatureConversion(Region *region, 1140223dcdcfSSean Silva TypeConverter::SignatureConversion &conversion, 1141ce254598SMatthias Springer const TypeConverter *converter); 1142b6eb26fdSRiver Riddle 1143b6eb26fdSRiver Riddle /// Convert the types of block arguments within the given region. 1144b6eb26fdSRiver Riddle FailureOr<Block *> 1145ce254598SMatthias Springer convertRegionTypes(Region *region, const TypeConverter &converter, 1146b6eb26fdSRiver Riddle TypeConverter::SignatureConversion *entryConversion); 1147b6eb26fdSRiver Riddle 11483b021fbdSKareemErgawy-TomTom /// Convert the types of non-entry block arguments within the given region. 1149aa6eb2afSKareemErgawy-TomTom LogicalResult convertNonEntryRegionTypes( 1150ce254598SMatthias Springer Region *region, const TypeConverter &converter, 1151aa6eb2afSKareemErgawy-TomTom ArrayRef<TypeConverter::SignatureConversion> blockConversions = {}); 11523b021fbdSKareemErgawy-TomTom 1153b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 1154b6eb26fdSRiver Riddle // Rewriter Notification Hooks 1155b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 1156b6eb26fdSRiver Riddle 1157ea2d9383SMatthias Springer //// Notifies that an op was inserted. 1158ea2d9383SMatthias Springer void notifyOperationInserted(Operation *op, 1159ea2d9383SMatthias Springer OpBuilder::InsertPoint previous) override; 1160ea2d9383SMatthias Springer 1161ea2d9383SMatthias Springer /// Notifies that an op is about to be replaced with the given values. 1162b6eb26fdSRiver Riddle void notifyOpReplaced(Operation *op, ValueRange newValues); 1163b6eb26fdSRiver Riddle 1164b6eb26fdSRiver Riddle /// Notifies that a block is about to be erased. 1165b6eb26fdSRiver Riddle void notifyBlockIsBeingErased(Block *block); 1166b6eb26fdSRiver Riddle 1167ea2d9383SMatthias Springer /// Notifies that a block was inserted. 1168ea2d9383SMatthias Springer void notifyBlockInserted(Block *block, Region *previous, 1169ea2d9383SMatthias Springer Region::iterator previousIt) override; 1170b6eb26fdSRiver Riddle 1171b6eb26fdSRiver Riddle /// Notifies that a block was split. 1172b6eb26fdSRiver Riddle void notifySplitBlock(Block *block, Block *continuation); 1173b6eb26fdSRiver Riddle 117442c31d83SMatthias Springer /// Notifies that a block is being inlined into another block. 117542c31d83SMatthias Springer void notifyBlockBeingInlined(Block *block, Block *srcBlock, 117642c31d83SMatthias Springer Block::iterator before); 1177b6eb26fdSRiver Riddle 1178b6eb26fdSRiver Riddle /// Notifies that a pattern match failed for the given reason. 1179ea2d9383SMatthias Springer void 1180ea2d9383SMatthias Springer notifyMatchFailure(Location loc, 1181ea2d9383SMatthias Springer function_ref<void(Diagnostic &)> reasonCallback) override; 1182b6eb26fdSRiver Riddle 1183b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 1184b6eb26fdSRiver Riddle // State 1185b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 1186b6eb26fdSRiver Riddle 1187b6eb26fdSRiver Riddle // Mapping between replaced values that differ in type. This happens when 1188b6eb26fdSRiver Riddle // replacing a value with one of a different type. 1189b6eb26fdSRiver Riddle ConversionValueMapping mapping; 1190b6eb26fdSRiver Riddle 1191b6eb26fdSRiver Riddle /// Utility used to convert block arguments. 1192b6eb26fdSRiver Riddle ArgConverter argConverter; 1193b6eb26fdSRiver Riddle 1194b6eb26fdSRiver Riddle /// Ordered vector of all of the newly created operations during conversion. 1195015192c6SRiver Riddle SmallVector<Operation *> createdOps; 1196015192c6SRiver Riddle 1197015192c6SRiver Riddle /// Ordered vector of all unresolved type conversion materializations during 1198015192c6SRiver Riddle /// conversion. 1199015192c6SRiver Riddle SmallVector<UnresolvedMaterialization> unresolvedMaterializations; 1200b6eb26fdSRiver Riddle 1201b6eb26fdSRiver Riddle /// Ordered map of requested operation replacements. 1202b6eb26fdSRiver Riddle llvm::MapVector<Operation *, OpReplacement> replacements; 1203b6eb26fdSRiver Riddle 1204b6eb26fdSRiver Riddle /// Ordered vector of any requested block argument replacements. 1205b6eb26fdSRiver Riddle SmallVector<BlockArgument, 4> argReplacements; 1206b6eb26fdSRiver Riddle 1207b6eb26fdSRiver Riddle /// Ordered list of block operations (creations, splits, motions). 12088faefe36SMatthias Springer SmallVector<std::unique_ptr<IRRewrite>> rewrites; 1209b6eb26fdSRiver Riddle 1210b6eb26fdSRiver Riddle /// A set of operations that should no longer be considered for legalization, 1211b6eb26fdSRiver Riddle /// but were not directly replace/erased/etc. by a pattern. These are 1212b6eb26fdSRiver Riddle /// generally child operations of other operations who were 1213b6eb26fdSRiver Riddle /// replaced/erased/etc. This is not meant to be an exhaustive list of all 1214b6eb26fdSRiver Riddle /// operations, but the minimal set that can be used to detect if a given 1215b6eb26fdSRiver Riddle /// operation should be `ignored`. For example, we may add the operations that 1216b6eb26fdSRiver Riddle /// define non-empty regions to the set, but not any of the others. This 1217b6eb26fdSRiver Riddle /// simplifies the amount of memory needed as we can query if the parent 1218b6eb26fdSRiver Riddle /// operation was ignored. 12194efb7754SRiver Riddle SetVector<Operation *> ignoredOps; 1220b6eb26fdSRiver Riddle 1221b6eb26fdSRiver Riddle /// A vector of indices into `replacements` of operations that were replaced 1222b6eb26fdSRiver Riddle /// with values with different result types than the original operation, e.g. 1223b6eb26fdSRiver Riddle /// 1->N conversion of some kind. 1224b6eb26fdSRiver Riddle SmallVector<unsigned, 4> operationsWithChangedResults; 1225b6eb26fdSRiver Riddle 122601b55f16SRiver Riddle /// The current type converter, or nullptr if no type converter is currently 122701b55f16SRiver Riddle /// active. 1228ce254598SMatthias Springer const TypeConverter *currentTypeConverter = nullptr; 1229b6eb26fdSRiver Riddle 1230b8c6b152SChia-hung Duan /// This allows the user to collect the match failure message. 1231b8c6b152SChia-hung Duan function_ref<void(Diagnostic &)> notifyCallback; 1232b8c6b152SChia-hung Duan 1233b6eb26fdSRiver Riddle #ifndef NDEBUG 1234b6eb26fdSRiver Riddle /// A set of operations that have pending updates. This tracking isn't 1235b6eb26fdSRiver Riddle /// strictly necessary, and is thus only active during debug builds for extra 1236b6eb26fdSRiver Riddle /// verification. 1237b6eb26fdSRiver Riddle SmallPtrSet<Operation *, 1> pendingRootUpdates; 1238b6eb26fdSRiver Riddle 1239b6eb26fdSRiver Riddle /// A logger used to emit diagnostics during the conversion process. 1240b6eb26fdSRiver Riddle llvm::ScopedPrinter logger{llvm::dbgs()}; 1241b6eb26fdSRiver Riddle #endif 1242b6eb26fdSRiver Riddle }; 1243be0a7e9fSMehdi Amini } // namespace detail 1244be0a7e9fSMehdi Amini } // namespace mlir 1245b6eb26fdSRiver Riddle 12468faefe36SMatthias Springer void BlockTypeConversionRewrite::rollback() { 12478faefe36SMatthias Springer // Undo the type conversion. 12488faefe36SMatthias Springer rewriterImpl.argConverter.discardRewrites(block); 12498faefe36SMatthias Springer } 12508faefe36SMatthias Springer 1251b6eb26fdSRiver Riddle /// Detach any operations nested in the given operation from their parent 1252b6eb26fdSRiver Riddle /// blocks, and erase the given operation. This can be used when the nested 1253b6eb26fdSRiver Riddle /// operations are scheduled for erasure themselves, so deleting the regions of 1254b6eb26fdSRiver Riddle /// the given operation together with their content would result in double-free. 1255b6eb26fdSRiver Riddle /// This happens, for example, when rolling back op creation in the reverse 1256b6eb26fdSRiver Riddle /// order and if the nested ops were created before the parent op. This function 1257b6eb26fdSRiver Riddle /// does not need to collect nested ops recursively because it is expected to 1258b6eb26fdSRiver Riddle /// also be called for each nested op when it is about to be deleted. 1259b6eb26fdSRiver Riddle static void detachNestedAndErase(Operation *op) { 1260b6eb26fdSRiver Riddle for (Region ®ion : op->getRegions()) { 1261b6eb26fdSRiver Riddle for (Block &block : region.getBlocks()) { 1262b6eb26fdSRiver Riddle while (!block.getOperations().empty()) 1263b6eb26fdSRiver Riddle block.getOperations().remove(block.getOperations().begin()); 1264b6eb26fdSRiver Riddle block.dropAllDefinedValueUses(); 1265b6eb26fdSRiver Riddle } 1266b6eb26fdSRiver Riddle } 12673bd620d4STres Popp op->dropAllUses(); 1268b6eb26fdSRiver Riddle op->erase(); 1269b6eb26fdSRiver Riddle } 1270b6eb26fdSRiver Riddle 1271b6eb26fdSRiver Riddle void ConversionPatternRewriterImpl::discardRewrites() { 12728faefe36SMatthias Springer undoRewrites(); 1273b6eb26fdSRiver Riddle 1274b6eb26fdSRiver Riddle // Remove any newly created ops. 1275015192c6SRiver Riddle for (UnresolvedMaterialization &materialization : unresolvedMaterializations) 1276015192c6SRiver Riddle detachNestedAndErase(materialization.getOp()); 1277b6eb26fdSRiver Riddle for (auto *op : llvm::reverse(createdOps)) 1278b6eb26fdSRiver Riddle detachNestedAndErase(op); 1279b6eb26fdSRiver Riddle } 1280b6eb26fdSRiver Riddle 1281b6eb26fdSRiver Riddle void ConversionPatternRewriterImpl::applyRewrites() { 1282b6eb26fdSRiver Riddle // Apply all of the rewrites replacements requested during conversion. 1283b6eb26fdSRiver Riddle for (auto &repl : replacements) { 1284b6eb26fdSRiver Riddle for (OpResult result : repl.first->getResults()) 1285015192c6SRiver Riddle if (Value newValue = mapping.lookupOrNull(result, result.getType())) 1286b6eb26fdSRiver Riddle result.replaceAllUsesWith(newValue); 1287b6eb26fdSRiver Riddle 1288b6eb26fdSRiver Riddle // If this operation defines any regions, drop any pending argument 1289b6eb26fdSRiver Riddle // rewrites. 1290b6eb26fdSRiver Riddle if (repl.first->getNumRegions()) 1291b6eb26fdSRiver Riddle argConverter.notifyOpRemoved(repl.first); 1292b6eb26fdSRiver Riddle } 1293b6eb26fdSRiver Riddle 1294b6eb26fdSRiver Riddle // Apply all of the requested argument replacements. 1295b6eb26fdSRiver Riddle for (BlockArgument arg : argReplacements) { 1296015192c6SRiver Riddle Value repl = mapping.lookupOrNull(arg, arg.getType()); 1297015192c6SRiver Riddle if (!repl) 1298015192c6SRiver Riddle continue; 1299015192c6SRiver Riddle 13005550c821STres Popp if (isa<BlockArgument>(repl)) { 1301b6eb26fdSRiver Riddle arg.replaceAllUsesWith(repl); 1302b6eb26fdSRiver Riddle continue; 1303b6eb26fdSRiver Riddle } 1304b6eb26fdSRiver Riddle 1305b6eb26fdSRiver Riddle // If the replacement value is an operation, we check to make sure that we 1306b6eb26fdSRiver Riddle // don't replace uses that are within the parent operation of the 1307b6eb26fdSRiver Riddle // replacement value. 13085550c821STres Popp Operation *replOp = cast<OpResult>(repl).getOwner(); 1309b6eb26fdSRiver Riddle Block *replBlock = replOp->getBlock(); 1310b6eb26fdSRiver Riddle arg.replaceUsesWithIf(repl, [&](OpOperand &operand) { 1311b6eb26fdSRiver Riddle Operation *user = operand.getOwner(); 1312b6eb26fdSRiver Riddle return user->getBlock() != replBlock || replOp->isBeforeInBlock(user); 1313b6eb26fdSRiver Riddle }); 1314b6eb26fdSRiver Riddle } 1315b6eb26fdSRiver Riddle 1316015192c6SRiver Riddle // Drop all of the unresolved materialization operations created during 1317015192c6SRiver Riddle // conversion. 1318015192c6SRiver Riddle for (auto &mat : unresolvedMaterializations) { 1319015192c6SRiver Riddle mat.getOp()->dropAllUses(); 1320015192c6SRiver Riddle mat.getOp()->erase(); 1321015192c6SRiver Riddle } 1322015192c6SRiver Riddle 1323b6eb26fdSRiver Riddle // In a second pass, erase all of the replaced operations in reverse. This 1324b6eb26fdSRiver Riddle // allows processing nested operations before their parent region is 1325a360a978SMehdi Amini // destroyed. Because we process in reverse order, producers may be deleted 1326a360a978SMehdi Amini // before their users (a pattern deleting a producer and then the consumer) 1327a360a978SMehdi Amini // so we first drop all uses explicitly. 1328a360a978SMehdi Amini for (auto &repl : llvm::reverse(replacements)) { 1329a360a978SMehdi Amini repl.first->dropAllUses(); 1330b6eb26fdSRiver Riddle repl.first->erase(); 1331a360a978SMehdi Amini } 1332b6eb26fdSRiver Riddle 1333b6eb26fdSRiver Riddle argConverter.applyRewrites(mapping); 1334b6eb26fdSRiver Riddle 13358faefe36SMatthias Springer // Commit all rewrites. 13368faefe36SMatthias Springer for (auto &rewrite : rewrites) 13378faefe36SMatthias Springer rewrite->commit(); 1338b6eb26fdSRiver Riddle } 1339b6eb26fdSRiver Riddle 1340b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 1341b6eb26fdSRiver Riddle // State Management 1342b6eb26fdSRiver Riddle 1343b6eb26fdSRiver Riddle RewriterState ConversionPatternRewriterImpl::getCurrentState() { 1344015192c6SRiver Riddle return RewriterState(createdOps.size(), unresolvedMaterializations.size(), 1345015192c6SRiver Riddle replacements.size(), argReplacements.size(), 1346e214f004SMatthias Springer rewrites.size(), ignoredOps.size()); 1347b6eb26fdSRiver Riddle } 1348b6eb26fdSRiver Riddle 1349b6eb26fdSRiver Riddle void ConversionPatternRewriterImpl::resetState(RewriterState state) { 1350b6eb26fdSRiver Riddle // Reset any replaced arguments. 1351b6eb26fdSRiver Riddle for (BlockArgument replacedArg : 1352b6eb26fdSRiver Riddle llvm::drop_begin(argReplacements, state.numArgReplacements)) 1353b6eb26fdSRiver Riddle mapping.erase(replacedArg); 1354b6eb26fdSRiver Riddle argReplacements.resize(state.numArgReplacements); 1355b6eb26fdSRiver Riddle 13568faefe36SMatthias Springer // Undo any rewrites. 13578faefe36SMatthias Springer undoRewrites(state.numRewrites); 1358b6eb26fdSRiver Riddle 1359b6eb26fdSRiver Riddle // Reset any replaced operations and undo any saved mappings. 1360b6eb26fdSRiver Riddle for (auto &repl : llvm::drop_begin(replacements, state.numReplacements)) 1361b6eb26fdSRiver Riddle for (auto result : repl.first->getResults()) 1362b6eb26fdSRiver Riddle mapping.erase(result); 1363b6eb26fdSRiver Riddle while (replacements.size() != state.numReplacements) 1364b6eb26fdSRiver Riddle replacements.pop_back(); 1365b6eb26fdSRiver Riddle 1366015192c6SRiver Riddle // Pop all of the newly inserted materializations. 1367015192c6SRiver Riddle while (unresolvedMaterializations.size() != 1368015192c6SRiver Riddle state.numUnresolvedMaterializations) { 1369015192c6SRiver Riddle UnresolvedMaterialization mat = unresolvedMaterializations.pop_back_val(); 1370015192c6SRiver Riddle UnrealizedConversionCastOp op = mat.getOp(); 1371015192c6SRiver Riddle 1372015192c6SRiver Riddle // If this was a target materialization, drop the mapping that was inserted. 1373015192c6SRiver Riddle if (mat.getKind() == UnresolvedMaterialization::Target) { 1374015192c6SRiver Riddle for (Value input : op->getOperands()) 1375015192c6SRiver Riddle mapping.erase(input); 1376015192c6SRiver Riddle } 1377015192c6SRiver Riddle detachNestedAndErase(op); 1378015192c6SRiver Riddle } 1379015192c6SRiver Riddle 1380b6eb26fdSRiver Riddle // Pop all of the newly created operations. 1381b6eb26fdSRiver Riddle while (createdOps.size() != state.numCreatedOps) { 1382b6eb26fdSRiver Riddle detachNestedAndErase(createdOps.back()); 1383b6eb26fdSRiver Riddle createdOps.pop_back(); 1384b6eb26fdSRiver Riddle } 1385b6eb26fdSRiver Riddle 1386b6eb26fdSRiver Riddle // Pop all of the recorded ignored operations that are no longer valid. 1387b6eb26fdSRiver Riddle while (ignoredOps.size() != state.numIgnoredOperations) 1388b6eb26fdSRiver Riddle ignoredOps.pop_back(); 1389b6eb26fdSRiver Riddle 1390b6eb26fdSRiver Riddle // Reset operations with changed results. 1391b6eb26fdSRiver Riddle while (!operationsWithChangedResults.empty() && 1392b6eb26fdSRiver Riddle operationsWithChangedResults.back() >= state.numReplacements) 1393b6eb26fdSRiver Riddle operationsWithChangedResults.pop_back(); 1394b6eb26fdSRiver Riddle } 1395b6eb26fdSRiver Riddle 13968faefe36SMatthias Springer void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) { 13978faefe36SMatthias Springer for (auto &rewrite : 13988faefe36SMatthias Springer llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) 13998faefe36SMatthias Springer rewrite->rollback(); 14008faefe36SMatthias Springer rewrites.resize(numRewritesToKeep); 1401b6eb26fdSRiver Riddle } 1402b6eb26fdSRiver Riddle 1403b6eb26fdSRiver Riddle LogicalResult ConversionPatternRewriterImpl::remapValues( 14040de16fafSRamkumar Ramachandra StringRef valueDiagTag, std::optional<Location> inputLoc, 1405015192c6SRiver Riddle PatternRewriter &rewriter, ValueRange values, 1406015192c6SRiver Riddle SmallVectorImpl<Value> &remapped) { 1407015192c6SRiver Riddle remapped.reserve(llvm::size(values)); 1408b6eb26fdSRiver Riddle 1409b6eb26fdSRiver Riddle SmallVector<Type, 1> legalTypes; 1410e4853be2SMehdi Amini for (const auto &it : llvm::enumerate(values)) { 1411b6eb26fdSRiver Riddle Value operand = it.value(); 1412b6eb26fdSRiver Riddle Type origType = operand.getType(); 1413b6eb26fdSRiver Riddle 14149ca67d7fSAlexander Belyaev // If a converter was provided, get the desired legal types for this 14159ca67d7fSAlexander Belyaev // operand. 14169ca67d7fSAlexander Belyaev Type desiredType; 1417015192c6SRiver Riddle if (currentTypeConverter) { 1418f8184d4cSAlexander Belyaev // If there is no legal conversion, fail to match this pattern. 14199ca67d7fSAlexander Belyaev legalTypes.clear(); 1420015192c6SRiver Riddle if (failed(currentTypeConverter->convertType(origType, legalTypes))) { 1421015192c6SRiver Riddle Location operandLoc = inputLoc ? *inputLoc : operand.getLoc(); 14229a028afdSMatthias Springer notifyMatchFailure(operandLoc, [=](Diagnostic &diag) { 1423015192c6SRiver Riddle diag << "unable to convert type for " << valueDiagTag << " #" 1424015192c6SRiver Riddle << it.index() << ", type was " << origType; 1425b6eb26fdSRiver Riddle }); 14269a028afdSMatthias Springer return failure(); 1427b6eb26fdSRiver Riddle } 1428b6eb26fdSRiver Riddle // TODO: There currently isn't any mechanism to do 1->N type conversion 14299ca67d7fSAlexander Belyaev // via the PatternRewriter replacement API, so for now we just ignore it. 14309ca67d7fSAlexander Belyaev if (legalTypes.size() == 1) 14319ca67d7fSAlexander Belyaev desiredType = legalTypes.front(); 14329ca67d7fSAlexander Belyaev } else { 14339ca67d7fSAlexander Belyaev // TODO: What we should do here is just set `desiredType` to `origType` 14349ca67d7fSAlexander Belyaev // and then handle the necessary type conversions after the conversion 14359ca67d7fSAlexander Belyaev // process has finished. Unfortunately a lot of patterns currently rely on 14369ca67d7fSAlexander Belyaev // receiving the new operands even if the types change, so we keep the 14379ca67d7fSAlexander Belyaev // original behavior here for now until all of the patterns relying on 14389ca67d7fSAlexander Belyaev // this get updated. 1439b6eb26fdSRiver Riddle } 14409ca67d7fSAlexander Belyaev Value newOperand = mapping.lookupOrDefault(operand, desiredType); 14419ca67d7fSAlexander Belyaev 14429ca67d7fSAlexander Belyaev // Handle the case where the conversion was 1->1 and the new operand type 14439ca67d7fSAlexander Belyaev // isn't legal. 14449ca67d7fSAlexander Belyaev Type newOperandType = newOperand.getType(); 1445015192c6SRiver Riddle if (currentTypeConverter && desiredType && newOperandType != desiredType) { 1446015192c6SRiver Riddle Location operandLoc = inputLoc ? *inputLoc : operand.getLoc(); 1447015192c6SRiver Riddle Value castValue = buildUnresolvedTargetMaterialization( 1448015192c6SRiver Riddle operandLoc, newOperand, desiredType, currentTypeConverter, 1449015192c6SRiver Riddle unresolvedMaterializations); 1450015192c6SRiver Riddle mapping.map(mapping.lookupOrDefault(newOperand), castValue); 1451015192c6SRiver Riddle newOperand = castValue; 1452b6eb26fdSRiver Riddle } 1453b6eb26fdSRiver Riddle remapped.push_back(newOperand); 1454b6eb26fdSRiver Riddle } 1455b6eb26fdSRiver Riddle return success(); 1456b6eb26fdSRiver Riddle } 1457b6eb26fdSRiver Riddle 1458b6eb26fdSRiver Riddle bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { 1459b6eb26fdSRiver Riddle // Check to see if this operation was replaced or its parent ignored. 1460b6eb26fdSRiver Riddle return replacements.count(op) || ignoredOps.count(op->getParentOp()); 1461b6eb26fdSRiver Riddle } 1462b6eb26fdSRiver Riddle 1463b6eb26fdSRiver Riddle void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) { 1464b6eb26fdSRiver Riddle // Walk this operation and collect nested operations that define non-empty 1465b6eb26fdSRiver Riddle // regions. We mark such operations as 'ignored' so that we know we don't have 1466b6eb26fdSRiver Riddle // to convert them, or their nested ops. 1467b6eb26fdSRiver Riddle if (op->getNumRegions() == 0) 1468b6eb26fdSRiver Riddle return; 1469b6eb26fdSRiver Riddle op->walk([&](Operation *op) { 1470b6eb26fdSRiver Riddle if (llvm::any_of(op->getRegions(), 1471b6eb26fdSRiver Riddle [](Region ®ion) { return !region.empty(); })) 1472b6eb26fdSRiver Riddle ignoredOps.insert(op); 1473b6eb26fdSRiver Riddle }); 1474b6eb26fdSRiver Riddle } 1475b6eb26fdSRiver Riddle 1476b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 1477b6eb26fdSRiver Riddle // Type Conversion 1478b6eb26fdSRiver Riddle 1479b6eb26fdSRiver Riddle FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature( 1480ce254598SMatthias Springer Block *block, const TypeConverter *converter, 1481b6eb26fdSRiver Riddle TypeConverter::SignatureConversion *conversion) { 1482b6eb26fdSRiver Riddle FailureOr<Block *> result = 14830409eb28SAlex Zinenko conversion ? argConverter.applySignatureConversion( 14840409eb28SAlex Zinenko block, converter, *conversion, mapping, argReplacements) 14850409eb28SAlex Zinenko : argConverter.convertSignature(block, converter, mapping, 14860409eb28SAlex Zinenko argReplacements); 148756272257SStella Laurenzo if (failed(result)) 148856272257SStella Laurenzo return failure(); 14896d5fc1e3SKazu Hirata if (Block *newBlock = *result) { 1490b6eb26fdSRiver Riddle if (newBlock != block) 14918faefe36SMatthias Springer appendRewrite<BlockTypeConversionRewrite>(newBlock); 1492b6eb26fdSRiver Riddle } 1493b6eb26fdSRiver Riddle return result; 1494b6eb26fdSRiver Riddle } 1495b6eb26fdSRiver Riddle 1496b6eb26fdSRiver Riddle Block *ConversionPatternRewriterImpl::applySignatureConversion( 1497223dcdcfSSean Silva Region *region, TypeConverter::SignatureConversion &conversion, 1498ce254598SMatthias Springer const TypeConverter *converter) { 1499015192c6SRiver Riddle if (!region->empty()) 1500015192c6SRiver Riddle return *convertBlockSignature(®ion->front(), converter, &conversion); 1501b6eb26fdSRiver Riddle return nullptr; 1502b6eb26fdSRiver Riddle } 1503b6eb26fdSRiver Riddle 1504b6eb26fdSRiver Riddle FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes( 1505ce254598SMatthias Springer Region *region, const TypeConverter &converter, 1506b6eb26fdSRiver Riddle TypeConverter::SignatureConversion *entryConversion) { 1507b6eb26fdSRiver Riddle argConverter.setConverter(region, &converter); 1508b6eb26fdSRiver Riddle if (region->empty()) 1509b6eb26fdSRiver Riddle return nullptr; 1510b6eb26fdSRiver Riddle 15113b021fbdSKareemErgawy-TomTom if (failed(convertNonEntryRegionTypes(region, converter))) 15123b021fbdSKareemErgawy-TomTom return failure(); 15133b021fbdSKareemErgawy-TomTom 1514b6eb26fdSRiver Riddle FailureOr<Block *> newEntry = 1515015192c6SRiver Riddle convertBlockSignature(®ion->front(), &converter, entryConversion); 15163b021fbdSKareemErgawy-TomTom return newEntry; 15173b021fbdSKareemErgawy-TomTom } 15183b021fbdSKareemErgawy-TomTom 15193b021fbdSKareemErgawy-TomTom LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes( 1520ce254598SMatthias Springer Region *region, const TypeConverter &converter, 1521aa6eb2afSKareemErgawy-TomTom ArrayRef<TypeConverter::SignatureConversion> blockConversions) { 15223b021fbdSKareemErgawy-TomTom argConverter.setConverter(region, &converter); 15233b021fbdSKareemErgawy-TomTom if (region->empty()) 15243b021fbdSKareemErgawy-TomTom return success(); 15253b021fbdSKareemErgawy-TomTom 15263b021fbdSKareemErgawy-TomTom // Convert the arguments of each block within the region. 1527aa6eb2afSKareemErgawy-TomTom int blockIdx = 0; 1528aa6eb2afSKareemErgawy-TomTom assert((blockConversions.empty() || 1529aa6eb2afSKareemErgawy-TomTom blockConversions.size() == region->getBlocks().size() - 1) && 1530aa6eb2afSKareemErgawy-TomTom "expected either to provide no SignatureConversions at all or to " 1531aa6eb2afSKareemErgawy-TomTom "provide a SignatureConversion for each non-entry block"); 1532aa6eb2afSKareemErgawy-TomTom 1533aa6eb2afSKareemErgawy-TomTom for (Block &block : 1534aa6eb2afSKareemErgawy-TomTom llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) { 1535aa6eb2afSKareemErgawy-TomTom TypeConverter::SignatureConversion *blockConversion = 1536aa6eb2afSKareemErgawy-TomTom blockConversions.empty() 1537aa6eb2afSKareemErgawy-TomTom ? nullptr 1538aa6eb2afSKareemErgawy-TomTom : const_cast<TypeConverter::SignatureConversion *>( 1539aa6eb2afSKareemErgawy-TomTom &blockConversions[blockIdx++]); 1540aa6eb2afSKareemErgawy-TomTom 1541015192c6SRiver Riddle if (failed(convertBlockSignature(&block, &converter, blockConversion))) 1542b6eb26fdSRiver Riddle return failure(); 1543aa6eb2afSKareemErgawy-TomTom } 15443b021fbdSKareemErgawy-TomTom return success(); 1545b6eb26fdSRiver Riddle } 1546b6eb26fdSRiver Riddle 1547b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 1548b6eb26fdSRiver Riddle // Rewriter Notification Hooks 1549b6eb26fdSRiver Riddle 1550ea2d9383SMatthias Springer void ConversionPatternRewriterImpl::notifyOperationInserted( 1551ea2d9383SMatthias Springer Operation *op, OpBuilder::InsertPoint previous) { 1552ea2d9383SMatthias Springer LLVM_DEBUG({ 1553ea2d9383SMatthias Springer logger.startLine() << "** Insert : '" << op->getName() << "'(" << op 1554ea2d9383SMatthias Springer << ")\n"; 1555ea2d9383SMatthias Springer }); 15568f4cd2c7SMatthias Springer if (!previous.isSet()) { 15578f4cd2c7SMatthias Springer // This is a newly created op. 1558ea2d9383SMatthias Springer createdOps.push_back(op); 15598f4cd2c7SMatthias Springer return; 15608f4cd2c7SMatthias Springer } 15618f4cd2c7SMatthias Springer Operation *prevOp = previous.getPoint() == previous.getBlock()->end() 15628f4cd2c7SMatthias Springer ? nullptr 15638f4cd2c7SMatthias Springer : &*previous.getPoint(); 15648f4cd2c7SMatthias Springer appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp); 1565ea2d9383SMatthias Springer } 1566ea2d9383SMatthias Springer 1567b6eb26fdSRiver Riddle void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op, 1568b6eb26fdSRiver Riddle ValueRange newValues) { 1569b6eb26fdSRiver Riddle assert(newValues.size() == op->getNumResults()); 1570b6eb26fdSRiver Riddle assert(!replacements.count(op) && "operation was already replaced"); 1571b6eb26fdSRiver Riddle 1572b6eb26fdSRiver Riddle // Track if any of the results changed, e.g. erased and replaced with null. 1573b6eb26fdSRiver Riddle bool resultChanged = false; 1574b6eb26fdSRiver Riddle 1575b6eb26fdSRiver Riddle // Create mappings for each of the new result values. 15769fa59e76SBenjamin Kramer for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) { 1577b6eb26fdSRiver Riddle if (!newValue) { 1578b6eb26fdSRiver Riddle resultChanged = true; 1579b6eb26fdSRiver Riddle continue; 1580b6eb26fdSRiver Riddle } 1581b6eb26fdSRiver Riddle // Remap, and check for any result type changes. 1582b6eb26fdSRiver Riddle mapping.map(result, newValue); 1583b6eb26fdSRiver Riddle resultChanged |= (newValue.getType() != result.getType()); 1584b6eb26fdSRiver Riddle } 1585b6eb26fdSRiver Riddle if (resultChanged) 1586b6eb26fdSRiver Riddle operationsWithChangedResults.push_back(replacements.size()); 1587b6eb26fdSRiver Riddle 1588b6eb26fdSRiver Riddle // Record the requested operation replacement. 158901b55f16SRiver Riddle replacements.insert(std::make_pair(op, OpReplacement(currentTypeConverter))); 1590b6eb26fdSRiver Riddle 1591b6eb26fdSRiver Riddle // Mark this operation as recursively ignored so that we don't need to 1592b6eb26fdSRiver Riddle // convert any nested operations. 1593b6eb26fdSRiver Riddle markNestedOpsIgnored(op); 1594b6eb26fdSRiver Riddle } 1595b6eb26fdSRiver Riddle 1596b6eb26fdSRiver Riddle void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) { 1597b6eb26fdSRiver Riddle Region *region = block->getParent(); 15983ed98cb3SMatthias Springer Block *origNextBlock = block->getNextNode(); 15998faefe36SMatthias Springer appendRewrite<EraseBlockRewrite>(block, region, origNextBlock); 1600b6eb26fdSRiver Riddle } 1601b6eb26fdSRiver Riddle 1602ea2d9383SMatthias Springer void ConversionPatternRewriterImpl::notifyBlockInserted( 16033ed98cb3SMatthias Springer Block *block, Region *previous, Region::iterator previousIt) { 16043ed98cb3SMatthias Springer if (!previous) { 16053ed98cb3SMatthias Springer // This is a newly created block. 16068faefe36SMatthias Springer appendRewrite<CreateBlockRewrite>(block); 16073ed98cb3SMatthias Springer return; 16083ed98cb3SMatthias Springer } 16093ed98cb3SMatthias Springer Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt; 16108faefe36SMatthias Springer appendRewrite<MoveBlockRewrite>(block, previous, prevBlock); 1611b6eb26fdSRiver Riddle } 1612b6eb26fdSRiver Riddle 1613b6eb26fdSRiver Riddle void ConversionPatternRewriterImpl::notifySplitBlock(Block *block, 1614b6eb26fdSRiver Riddle Block *continuation) { 16158faefe36SMatthias Springer appendRewrite<SplitBlockRewrite>(continuation, block); 1616b6eb26fdSRiver Riddle } 1617b6eb26fdSRiver Riddle 161842c31d83SMatthias Springer void ConversionPatternRewriterImpl::notifyBlockBeingInlined( 161942c31d83SMatthias Springer Block *block, Block *srcBlock, Block::iterator before) { 16208faefe36SMatthias Springer appendRewrite<InlineBlockRewrite>(block, srcBlock, before); 1621b6eb26fdSRiver Riddle } 1622b6eb26fdSRiver Riddle 16239a028afdSMatthias Springer void ConversionPatternRewriterImpl::notifyMatchFailure( 1624b6eb26fdSRiver Riddle Location loc, function_ref<void(Diagnostic &)> reasonCallback) { 1625b6eb26fdSRiver Riddle LLVM_DEBUG({ 1626b6eb26fdSRiver Riddle Diagnostic diag(loc, DiagnosticSeverity::Remark); 1627b6eb26fdSRiver Riddle reasonCallback(diag); 1628b6eb26fdSRiver Riddle logger.startLine() << "** Failure : " << diag.str() << "\n"; 1629b8c6b152SChia-hung Duan if (notifyCallback) 1630b8c6b152SChia-hung Duan notifyCallback(diag); 1631b6eb26fdSRiver Riddle }); 1632b6eb26fdSRiver Riddle } 1633b6eb26fdSRiver Riddle 1634b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 1635b6eb26fdSRiver Riddle // ConversionPatternRewriter 1636b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 1637b6eb26fdSRiver Riddle 1638b6eb26fdSRiver Riddle ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx) 1639b6eb26fdSRiver Riddle : PatternRewriter(ctx), 1640c6532830SMatthias Springer impl(new detail::ConversionPatternRewriterImpl(*this)) { 1641ea2d9383SMatthias Springer setListener(impl.get()); 1642c6532830SMatthias Springer } 1643c6532830SMatthias Springer 1644e5639b3fSMehdi Amini ConversionPatternRewriter::~ConversionPatternRewriter() = default; 1645b6eb26fdSRiver Riddle 1646c8fb6ee3SRiver Riddle void ConversionPatternRewriter::replaceOpWithIf( 1647c8fb6ee3SRiver Riddle Operation *op, ValueRange newValues, bool *allUsesReplaced, 1648c8fb6ee3SRiver Riddle llvm::unique_function<bool(OpOperand &) const> functor) { 1649c8fb6ee3SRiver Riddle // TODO: To support this we will need to rework a bit of how replacements are 1650c8fb6ee3SRiver Riddle // tracked, given that this isn't guranteed to replace all of the uses of an 1651c8fb6ee3SRiver Riddle // operation. The main change is that now an operation can be replaced 1652c8fb6ee3SRiver Riddle // multiple times, in parts. The current "set" based tracking is mainly useful 1653c8fb6ee3SRiver Riddle // for tracking if a replaced operation should be ignored, i.e. if all of the 1654c8fb6ee3SRiver Riddle // uses will be replaced. 1655c8fb6ee3SRiver Riddle llvm_unreachable( 1656c8fb6ee3SRiver Riddle "replaceOpWithIf is currently not supported by DialectConversion"); 1657c8fb6ee3SRiver Riddle } 1658c8fb6ee3SRiver Riddle 165971d50c89SMatthias Springer void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) { 166071d50c89SMatthias Springer assert(op && newOp && "expected non-null op"); 166171d50c89SMatthias Springer replaceOp(op, newOp->getResults()); 166271d50c89SMatthias Springer } 166371d50c89SMatthias Springer 1664b6eb26fdSRiver Riddle void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { 166571d50c89SMatthias Springer assert(op->getNumResults() == newValues.size() && 166671d50c89SMatthias Springer "incorrect # of replacement values"); 1667b6eb26fdSRiver Riddle LLVM_DEBUG({ 1668b6eb26fdSRiver Riddle impl->logger.startLine() 1669b6eb26fdSRiver Riddle << "** Replace : '" << op->getName() << "'(" << op << ")\n"; 1670b6eb26fdSRiver Riddle }); 1671b6eb26fdSRiver Riddle impl->notifyOpReplaced(op, newValues); 1672b6eb26fdSRiver Riddle } 1673b6eb26fdSRiver Riddle 1674b6eb26fdSRiver Riddle void ConversionPatternRewriter::eraseOp(Operation *op) { 1675b6eb26fdSRiver Riddle LLVM_DEBUG({ 1676b6eb26fdSRiver Riddle impl->logger.startLine() 1677b6eb26fdSRiver Riddle << "** Erase : '" << op->getName() << "'(" << op << ")\n"; 1678b6eb26fdSRiver Riddle }); 1679b6eb26fdSRiver Riddle SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr); 1680b6eb26fdSRiver Riddle impl->notifyOpReplaced(op, nullRepls); 1681b6eb26fdSRiver Riddle } 1682b6eb26fdSRiver Riddle 1683b6eb26fdSRiver Riddle void ConversionPatternRewriter::eraseBlock(Block *block) { 1684b6eb26fdSRiver Riddle impl->notifyBlockIsBeingErased(block); 1685b6eb26fdSRiver Riddle 1686b6eb26fdSRiver Riddle // Mark all ops for erasure. 1687b6eb26fdSRiver Riddle for (Operation &op : *block) 1688b6eb26fdSRiver Riddle eraseOp(&op); 1689b6eb26fdSRiver Riddle 16908faefe36SMatthias Springer // Unlink the block from its parent region. The block is kept in the rewrite 16918faefe36SMatthias Springer // object and will be actually destroyed when rewrites are applied. This 1692b6eb26fdSRiver Riddle // allows us to keep the operations in the block live and undo the removal by 1693b6eb26fdSRiver Riddle // re-inserting the block. 1694b6eb26fdSRiver Riddle block->getParent()->getBlocks().remove(block); 1695b6eb26fdSRiver Riddle } 1696b6eb26fdSRiver Riddle 1697b6eb26fdSRiver Riddle Block *ConversionPatternRewriter::applySignatureConversion( 1698223dcdcfSSean Silva Region *region, TypeConverter::SignatureConversion &conversion, 1699ce254598SMatthias Springer const TypeConverter *converter) { 1700223dcdcfSSean Silva return impl->applySignatureConversion(region, conversion, converter); 1701b6eb26fdSRiver Riddle } 1702b6eb26fdSRiver Riddle 1703b6eb26fdSRiver Riddle FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes( 1704ce254598SMatthias Springer Region *region, const TypeConverter &converter, 1705b6eb26fdSRiver Riddle TypeConverter::SignatureConversion *entryConversion) { 1706b6eb26fdSRiver Riddle return impl->convertRegionTypes(region, converter, entryConversion); 1707b6eb26fdSRiver Riddle } 1708b6eb26fdSRiver Riddle 17093b021fbdSKareemErgawy-TomTom LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes( 1710ce254598SMatthias Springer Region *region, const TypeConverter &converter, 1711aa6eb2afSKareemErgawy-TomTom ArrayRef<TypeConverter::SignatureConversion> blockConversions) { 1712aa6eb2afSKareemErgawy-TomTom return impl->convertNonEntryRegionTypes(region, converter, blockConversions); 17133b021fbdSKareemErgawy-TomTom } 17143b021fbdSKareemErgawy-TomTom 1715b6eb26fdSRiver Riddle void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, 1716b6eb26fdSRiver Riddle Value to) { 1717b6eb26fdSRiver Riddle LLVM_DEBUG({ 1718b6eb26fdSRiver Riddle Operation *parentOp = from.getOwner()->getParentOp(); 1719b6eb26fdSRiver Riddle impl->logger.startLine() << "** Replace Argument : '" << from 1720b6eb26fdSRiver Riddle << "'(in region of '" << parentOp->getName() 1721b6eb26fdSRiver Riddle << "'(" << from.getOwner()->getParentOp() << ")\n"; 1722b6eb26fdSRiver Riddle }); 1723b6eb26fdSRiver Riddle impl->argReplacements.push_back(from); 1724b6eb26fdSRiver Riddle impl->mapping.map(impl->mapping.lookupOrDefault(from), to); 1725b6eb26fdSRiver Riddle } 1726b6eb26fdSRiver Riddle 1727b6eb26fdSRiver Riddle Value ConversionPatternRewriter::getRemappedValue(Value key) { 1728015192c6SRiver Riddle SmallVector<Value> remappedValues; 17291a36588eSKazu Hirata if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key, 1730015192c6SRiver Riddle remappedValues))) 1731015192c6SRiver Riddle return nullptr; 1732015192c6SRiver Riddle return remappedValues.front(); 1733015192c6SRiver Riddle } 1734015192c6SRiver Riddle 1735015192c6SRiver Riddle LogicalResult 1736015192c6SRiver Riddle ConversionPatternRewriter::getRemappedValues(ValueRange keys, 1737015192c6SRiver Riddle SmallVectorImpl<Value> &results) { 1738015192c6SRiver Riddle if (keys.empty()) 1739015192c6SRiver Riddle return success(); 17401a36588eSKazu Hirata return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys, 1741015192c6SRiver Riddle results); 1742b6eb26fdSRiver Riddle } 1743b6eb26fdSRiver Riddle 1744b6eb26fdSRiver Riddle Block *ConversionPatternRewriter::splitBlock(Block *block, 1745b6eb26fdSRiver Riddle Block::iterator before) { 1746c2675ba9SMatthias Springer auto *continuation = block->splitBlock(before); 1747b6eb26fdSRiver Riddle impl->notifySplitBlock(block, continuation); 1748b6eb26fdSRiver Riddle return continuation; 1749b6eb26fdSRiver Riddle } 1750b6eb26fdSRiver Riddle 175142c31d83SMatthias Springer void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, 175242c31d83SMatthias Springer Block::iterator before, 1753b6eb26fdSRiver Riddle ValueRange argValues) { 1754b6eb26fdSRiver Riddle assert(argValues.size() == source->getNumArguments() && 1755b6eb26fdSRiver Riddle "incorrect # of argument replacement values"); 175642c31d83SMatthias Springer #ifndef NDEBUG 175742c31d83SMatthias Springer auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); }; 175842c31d83SMatthias Springer #endif // NDEBUG 175942c31d83SMatthias Springer // The source block will be deleted, so it should not have any users (i.e., 176042c31d83SMatthias Springer // there should be no predecessors). 176142c31d83SMatthias Springer assert(llvm::all_of(source->getUsers(), opIgnored) && 176242c31d83SMatthias Springer "expected 'source' to have no predecessors"); 176342c31d83SMatthias Springer 176442c31d83SMatthias Springer impl->notifyBlockBeingInlined(dest, source, before); 1765b6eb26fdSRiver Riddle for (auto it : llvm::zip(source->getArguments(), argValues)) 1766b6eb26fdSRiver Riddle replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it)); 176742c31d83SMatthias Springer dest->getOperations().splice(before, source->getOperations()); 1768b6eb26fdSRiver Riddle eraseBlock(source); 1769b6eb26fdSRiver Riddle } 1770b6eb26fdSRiver Riddle 17715fcf907bSMatthias Springer void ConversionPatternRewriter::startOpModification(Operation *op) { 1772b6eb26fdSRiver Riddle #ifndef NDEBUG 1773b6eb26fdSRiver Riddle impl->pendingRootUpdates.insert(op); 1774b6eb26fdSRiver Riddle #endif 1775e214f004SMatthias Springer impl->appendRewrite<ModifyOperationRewrite>(op); 1776b6eb26fdSRiver Riddle } 1777b6eb26fdSRiver Riddle 17785fcf907bSMatthias Springer void ConversionPatternRewriter::finalizeOpModification(Operation *op) { 17795fcf907bSMatthias Springer PatternRewriter::finalizeOpModification(op); 1780b6eb26fdSRiver Riddle // There is nothing to do here, we only need to track the operation at the 1781b6eb26fdSRiver Riddle // start of the update. 1782b6eb26fdSRiver Riddle #ifndef NDEBUG 1783b6eb26fdSRiver Riddle assert(impl->pendingRootUpdates.erase(op) && 1784b6eb26fdSRiver Riddle "operation did not have a pending in-place update"); 1785b6eb26fdSRiver Riddle #endif 1786b6eb26fdSRiver Riddle } 1787b6eb26fdSRiver Riddle 17885fcf907bSMatthias Springer void ConversionPatternRewriter::cancelOpModification(Operation *op) { 1789b6eb26fdSRiver Riddle #ifndef NDEBUG 1790b6eb26fdSRiver Riddle assert(impl->pendingRootUpdates.erase(op) && 1791b6eb26fdSRiver Riddle "operation did not have a pending in-place update"); 1792b6eb26fdSRiver Riddle #endif 1793b6eb26fdSRiver Riddle // Erase the last update for this operation. 1794e214f004SMatthias Springer auto it = llvm::find_if( 1795e214f004SMatthias Springer llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) { 1796e214f004SMatthias Springer auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get()); 1797e214f004SMatthias Springer return modifyRewrite && modifyRewrite->getOperation() == op; 1798e214f004SMatthias Springer }); 1799e214f004SMatthias Springer assert(it != impl->rewrites.rend() && "no root update started on op"); 1800e214f004SMatthias Springer (*it)->rollback(); 1801e214f004SMatthias Springer int updateIdx = std::prev(impl->rewrites.rend()) - it; 1802e214f004SMatthias Springer impl->rewrites.erase(impl->rewrites.begin() + updateIdx); 1803b6eb26fdSRiver Riddle } 1804b6eb26fdSRiver Riddle 1805b6eb26fdSRiver Riddle detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { 1806b6eb26fdSRiver Riddle return *impl; 1807b6eb26fdSRiver Riddle } 1808b6eb26fdSRiver Riddle 1809b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 1810b6eb26fdSRiver Riddle // ConversionPattern 1811b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 1812b6eb26fdSRiver Riddle 1813b6eb26fdSRiver Riddle LogicalResult 1814b6eb26fdSRiver Riddle ConversionPattern::matchAndRewrite(Operation *op, 1815b6eb26fdSRiver Riddle PatternRewriter &rewriter) const { 1816b6eb26fdSRiver Riddle auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter); 1817b6eb26fdSRiver Riddle auto &rewriterImpl = dialectRewriter.getImpl(); 1818b6eb26fdSRiver Riddle 181901b55f16SRiver Riddle // Track the current conversion pattern type converter in the rewriter. 1820abf0c6c0SJan Svoboda llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter, 1821abf0c6c0SJan Svoboda getTypeConverter()); 1822b6eb26fdSRiver Riddle 1823b6eb26fdSRiver Riddle // Remap the operands of the operation. 1824b6eb26fdSRiver Riddle SmallVector<Value, 4> operands; 1825015192c6SRiver Riddle if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter, 1826015192c6SRiver Riddle op->getOperands(), operands))) { 1827b6eb26fdSRiver Riddle return failure(); 1828b6eb26fdSRiver Riddle } 1829b6eb26fdSRiver Riddle return matchAndRewrite(op, operands, dialectRewriter); 1830b6eb26fdSRiver Riddle } 1831b6eb26fdSRiver Riddle 1832b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 1833b6eb26fdSRiver Riddle // OperationLegalizer 1834b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 1835b6eb26fdSRiver Riddle 1836b6eb26fdSRiver Riddle namespace { 1837b6eb26fdSRiver Riddle /// A set of rewrite patterns that can be used to legalize a given operation. 1838b6eb26fdSRiver Riddle using LegalizationPatterns = SmallVector<const Pattern *, 1>; 1839b6eb26fdSRiver Riddle 1840b6eb26fdSRiver Riddle /// This class defines a recursive operation legalizer. 1841b6eb26fdSRiver Riddle class OperationLegalizer { 1842b6eb26fdSRiver Riddle public: 1843b6eb26fdSRiver Riddle using LegalizationAction = ConversionTarget::LegalizationAction; 1844b6eb26fdSRiver Riddle 1845370a6f09SMehdi Amini OperationLegalizer(const ConversionTarget &targetInfo, 184679d7f618SChris Lattner const FrozenRewritePatternSet &patterns); 1847b6eb26fdSRiver Riddle 1848b6eb26fdSRiver Riddle /// Returns true if the given operation is known to be illegal on the target. 1849b6eb26fdSRiver Riddle bool isIllegal(Operation *op) const; 1850b6eb26fdSRiver Riddle 1851b6eb26fdSRiver Riddle /// Attempt to legalize the given operation. Returns success if the operation 1852b6eb26fdSRiver Riddle /// was legalized, failure otherwise. 1853b6eb26fdSRiver Riddle LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter); 1854b6eb26fdSRiver Riddle 1855b6eb26fdSRiver Riddle /// Returns the conversion target in use by the legalizer. 1856370a6f09SMehdi Amini const ConversionTarget &getTarget() { return target; } 1857b6eb26fdSRiver Riddle 1858b6eb26fdSRiver Riddle private: 1859b6eb26fdSRiver Riddle /// Attempt to legalize the given operation by folding it. 1860b6eb26fdSRiver Riddle LogicalResult legalizeWithFold(Operation *op, 1861b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter); 1862b6eb26fdSRiver Riddle 1863b6eb26fdSRiver Riddle /// Attempt to legalize the given operation by applying a pattern. Returns 1864b6eb26fdSRiver Riddle /// success if the operation was legalized, failure otherwise. 1865b6eb26fdSRiver Riddle LogicalResult legalizeWithPattern(Operation *op, 1866b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter); 1867b6eb26fdSRiver Riddle 1868b6eb26fdSRiver Riddle /// Return true if the given pattern may be applied to the given operation, 1869b6eb26fdSRiver Riddle /// false otherwise. 1870b6eb26fdSRiver Riddle bool canApplyPattern(Operation *op, const Pattern &pattern, 1871b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter); 1872b6eb26fdSRiver Riddle 1873b6eb26fdSRiver Riddle /// Legalize the resultant IR after successfully applying the given pattern. 1874b6eb26fdSRiver Riddle LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern, 1875b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter, 1876b6eb26fdSRiver Riddle RewriterState &curState); 1877b6eb26fdSRiver Riddle 1878b6eb26fdSRiver Riddle /// Legalizes the actions registered during the execution of a pattern. 18798faefe36SMatthias Springer LogicalResult 18808faefe36SMatthias Springer legalizePatternBlockRewrites(Operation *op, 1881b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter, 1882b6eb26fdSRiver Riddle ConversionPatternRewriterImpl &impl, 18838faefe36SMatthias Springer RewriterState &state, RewriterState &newState); 1884b6eb26fdSRiver Riddle LogicalResult legalizePatternCreatedOperations( 1885b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, 1886b6eb26fdSRiver Riddle RewriterState &state, RewriterState &newState); 1887b6eb26fdSRiver Riddle LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter, 1888b6eb26fdSRiver Riddle ConversionPatternRewriterImpl &impl, 1889b6eb26fdSRiver Riddle RewriterState &state, 1890b6eb26fdSRiver Riddle RewriterState &newState); 1891b6eb26fdSRiver Riddle 1892b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 1893b6eb26fdSRiver Riddle // Cost Model 1894b6eb26fdSRiver Riddle //===--------------------------------------------------------------------===// 1895b6eb26fdSRiver Riddle 1896b6eb26fdSRiver Riddle /// Build an optimistic legalization graph given the provided patterns. This 1897b6eb26fdSRiver Riddle /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with 1898b6eb26fdSRiver Riddle /// patterns for operations that are not directly legal, but may be 1899b6eb26fdSRiver Riddle /// transitively legal for the current target given the provided patterns. 1900b6eb26fdSRiver Riddle void buildLegalizationGraph( 1901b6eb26fdSRiver Riddle LegalizationPatterns &anyOpLegalizerPatterns, 1902b6eb26fdSRiver Riddle DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); 1903b6eb26fdSRiver Riddle 1904b6eb26fdSRiver Riddle /// Compute the benefit of each node within the computed legalization graph. 1905b6eb26fdSRiver Riddle /// This orders the patterns within 'legalizerPatterns' based upon two 1906b6eb26fdSRiver Riddle /// criteria: 1907b6eb26fdSRiver Riddle /// 1) Prefer patterns that have the lowest legalization depth, i.e. 1908b6eb26fdSRiver Riddle /// represent the more direct mapping to the target. 1909b6eb26fdSRiver Riddle /// 2) When comparing patterns with the same legalization depth, prefer the 1910b6eb26fdSRiver Riddle /// pattern with the highest PatternBenefit. This allows for users to 1911b6eb26fdSRiver Riddle /// prefer specific legalizations over others. 1912b6eb26fdSRiver Riddle void computeLegalizationGraphBenefit( 1913b6eb26fdSRiver Riddle LegalizationPatterns &anyOpLegalizerPatterns, 1914b6eb26fdSRiver Riddle DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); 1915b6eb26fdSRiver Riddle 1916b6eb26fdSRiver Riddle /// Compute the legalization depth when legalizing an operation of the given 1917b6eb26fdSRiver Riddle /// type. 1918b6eb26fdSRiver Riddle unsigned computeOpLegalizationDepth( 1919b6eb26fdSRiver Riddle OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth, 1920b6eb26fdSRiver Riddle DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); 1921b6eb26fdSRiver Riddle 1922b6eb26fdSRiver Riddle /// Apply the conversion cost model to the given set of patterns, and return 1923b6eb26fdSRiver Riddle /// the smallest legalization depth of any of the patterns. See 1924b6eb26fdSRiver Riddle /// `computeLegalizationGraphBenefit` for the breakdown of the cost model. 1925b6eb26fdSRiver Riddle unsigned applyCostModelToPatterns( 1926b6eb26fdSRiver Riddle LegalizationPatterns &patterns, 1927b6eb26fdSRiver Riddle DenseMap<OperationName, unsigned> &minOpPatternDepth, 1928b6eb26fdSRiver Riddle DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns); 1929b6eb26fdSRiver Riddle 1930b6eb26fdSRiver Riddle /// The current set of patterns that have been applied. 1931b6eb26fdSRiver Riddle SmallPtrSet<const Pattern *, 8> appliedPatterns; 1932b6eb26fdSRiver Riddle 1933b6eb26fdSRiver Riddle /// The legalization information provided by the target. 1934370a6f09SMehdi Amini const ConversionTarget ⌖ 1935b6eb26fdSRiver Riddle 1936b6eb26fdSRiver Riddle /// The pattern applicator to use for conversions. 1937b6eb26fdSRiver Riddle PatternApplicator applicator; 1938b6eb26fdSRiver Riddle }; 1939b6eb26fdSRiver Riddle } // namespace 1940b6eb26fdSRiver Riddle 1941370a6f09SMehdi Amini OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo, 194279d7f618SChris Lattner const FrozenRewritePatternSet &patterns) 1943b6eb26fdSRiver Riddle : target(targetInfo), applicator(patterns) { 1944b6eb26fdSRiver Riddle // The set of patterns that can be applied to illegal operations to transform 1945b6eb26fdSRiver Riddle // them into legal ones. 1946b6eb26fdSRiver Riddle DenseMap<OperationName, LegalizationPatterns> legalizerPatterns; 1947b6eb26fdSRiver Riddle LegalizationPatterns anyOpLegalizerPatterns; 1948b6eb26fdSRiver Riddle 1949b6eb26fdSRiver Riddle buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns); 1950b6eb26fdSRiver Riddle computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns); 1951b6eb26fdSRiver Riddle } 1952b6eb26fdSRiver Riddle 1953b6eb26fdSRiver Riddle bool OperationLegalizer::isIllegal(Operation *op) const { 19542a3878eaSButygin return target.isIllegal(op); 1955b6eb26fdSRiver Riddle } 1956b6eb26fdSRiver Riddle 1957b6eb26fdSRiver Riddle LogicalResult 1958b6eb26fdSRiver Riddle OperationLegalizer::legalize(Operation *op, 1959b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter) { 1960b6eb26fdSRiver Riddle #ifndef NDEBUG 1961b6eb26fdSRiver Riddle const char *logLineComment = 1962b6eb26fdSRiver Riddle "//===-------------------------------------------===//\n"; 1963b6eb26fdSRiver Riddle 196401b55f16SRiver Riddle auto &logger = rewriter.getImpl().logger; 1965b6eb26fdSRiver Riddle #endif 1966b6eb26fdSRiver Riddle LLVM_DEBUG({ 196701b55f16SRiver Riddle logger.getOStream() << "\n"; 196801b55f16SRiver Riddle logger.startLine() << logLineComment; 196901b55f16SRiver Riddle logger.startLine() << "Legalizing operation : '" << op->getName() << "'(" 197001b55f16SRiver Riddle << op << ") {\n"; 197101b55f16SRiver Riddle logger.indent(); 1972b6eb26fdSRiver Riddle 1973b6eb26fdSRiver Riddle // If the operation has no regions, just print it here. 1974b6eb26fdSRiver Riddle if (op->getNumRegions() == 0) { 197501b55f16SRiver Riddle op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm()); 197601b55f16SRiver Riddle logger.getOStream() << "\n\n"; 1977b6eb26fdSRiver Riddle } 1978b6eb26fdSRiver Riddle }); 1979b6eb26fdSRiver Riddle 1980b6eb26fdSRiver Riddle // Check if this operation is legal on the target. 1981b6eb26fdSRiver Riddle if (auto legalityInfo = target.isLegal(op)) { 1982b6eb26fdSRiver Riddle LLVM_DEBUG({ 1983b6eb26fdSRiver Riddle logSuccess( 198401b55f16SRiver Riddle logger, "operation marked legal by the target{0}", 1985b6eb26fdSRiver Riddle legalityInfo->isRecursivelyLegal 1986b6eb26fdSRiver Riddle ? "; NOTE: operation is recursively legal; skipping internals" 1987b6eb26fdSRiver Riddle : ""); 198801b55f16SRiver Riddle logger.startLine() << logLineComment; 1989b6eb26fdSRiver Riddle }); 1990b6eb26fdSRiver Riddle 1991b6eb26fdSRiver Riddle // If this operation is recursively legal, mark its children as ignored so 1992b6eb26fdSRiver Riddle // that we don't consider them for legalization. 1993b6eb26fdSRiver Riddle if (legalityInfo->isRecursivelyLegal) 1994b6eb26fdSRiver Riddle rewriter.getImpl().markNestedOpsIgnored(op); 1995b6eb26fdSRiver Riddle return success(); 1996b6eb26fdSRiver Riddle } 1997b6eb26fdSRiver Riddle 1998b6eb26fdSRiver Riddle // Check to see if the operation is ignored and doesn't need to be converted. 1999b6eb26fdSRiver Riddle if (rewriter.getImpl().isOpIgnored(op)) { 2000b6eb26fdSRiver Riddle LLVM_DEBUG({ 200101b55f16SRiver Riddle logSuccess(logger, "operation marked 'ignored' during conversion"); 200201b55f16SRiver Riddle logger.startLine() << logLineComment; 2003b6eb26fdSRiver Riddle }); 2004b6eb26fdSRiver Riddle return success(); 2005b6eb26fdSRiver Riddle } 2006b6eb26fdSRiver Riddle 2007b6eb26fdSRiver Riddle // If the operation isn't legal, try to fold it in-place. 2008b6eb26fdSRiver Riddle // TODO: Should we always try to do this, even if the op is 2009b6eb26fdSRiver Riddle // already legal? 2010b6eb26fdSRiver Riddle if (succeeded(legalizeWithFold(op, rewriter))) { 2011b6eb26fdSRiver Riddle LLVM_DEBUG({ 201201b55f16SRiver Riddle logSuccess(logger, "operation was folded"); 201301b55f16SRiver Riddle logger.startLine() << logLineComment; 2014b6eb26fdSRiver Riddle }); 2015b6eb26fdSRiver Riddle return success(); 2016b6eb26fdSRiver Riddle } 2017b6eb26fdSRiver Riddle 2018b6eb26fdSRiver Riddle // Otherwise, we need to apply a legalization pattern to this operation. 2019b6eb26fdSRiver Riddle if (succeeded(legalizeWithPattern(op, rewriter))) { 2020b6eb26fdSRiver Riddle LLVM_DEBUG({ 202101b55f16SRiver Riddle logSuccess(logger, ""); 202201b55f16SRiver Riddle logger.startLine() << logLineComment; 2023b6eb26fdSRiver Riddle }); 2024b6eb26fdSRiver Riddle return success(); 2025b6eb26fdSRiver Riddle } 2026b6eb26fdSRiver Riddle 2027b6eb26fdSRiver Riddle LLVM_DEBUG({ 202801b55f16SRiver Riddle logFailure(logger, "no matched legalization pattern"); 202901b55f16SRiver Riddle logger.startLine() << logLineComment; 2030b6eb26fdSRiver Riddle }); 2031b6eb26fdSRiver Riddle return failure(); 2032b6eb26fdSRiver Riddle } 2033b6eb26fdSRiver Riddle 2034b6eb26fdSRiver Riddle LogicalResult 2035b6eb26fdSRiver Riddle OperationLegalizer::legalizeWithFold(Operation *op, 2036b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter) { 2037b6eb26fdSRiver Riddle auto &rewriterImpl = rewriter.getImpl(); 2038b6eb26fdSRiver Riddle RewriterState curState = rewriterImpl.getCurrentState(); 2039b6eb26fdSRiver Riddle 2040b6eb26fdSRiver Riddle LLVM_DEBUG({ 2041b6eb26fdSRiver Riddle rewriterImpl.logger.startLine() << "* Fold {\n"; 2042b6eb26fdSRiver Riddle rewriterImpl.logger.indent(); 2043b6eb26fdSRiver Riddle }); 2044b6eb26fdSRiver Riddle 2045b6eb26fdSRiver Riddle // Try to fold the operation. 2046b6eb26fdSRiver Riddle SmallVector<Value, 2> replacementValues; 2047b6eb26fdSRiver Riddle rewriter.setInsertionPoint(op); 2048b6eb26fdSRiver Riddle if (failed(rewriter.tryFold(op, replacementValues))) { 2049b6eb26fdSRiver Riddle LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold")); 2050b6eb26fdSRiver Riddle return failure(); 2051b6eb26fdSRiver Riddle } 2052b6eb26fdSRiver Riddle 2053b6eb26fdSRiver Riddle // Insert a replacement for 'op' with the folded replacement values. 2054b6eb26fdSRiver Riddle rewriter.replaceOp(op, replacementValues); 2055b6eb26fdSRiver Riddle 2056b6eb26fdSRiver Riddle // Recursively legalize any new constant operations. 2057b6eb26fdSRiver Riddle for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size(); 2058b6eb26fdSRiver Riddle i != e; ++i) { 2059b6eb26fdSRiver Riddle Operation *cstOp = rewriterImpl.createdOps[i]; 2060b6eb26fdSRiver Riddle if (failed(legalize(cstOp, rewriter))) { 2061b6eb26fdSRiver Riddle LLVM_DEBUG(logFailure(rewriterImpl.logger, 2062e49c0e48SUday Bondhugula "failed to legalize generated constant '{0}'", 2063b6eb26fdSRiver Riddle cstOp->getName())); 2064b6eb26fdSRiver Riddle rewriterImpl.resetState(curState); 2065b6eb26fdSRiver Riddle return failure(); 2066b6eb26fdSRiver Riddle } 2067b6eb26fdSRiver Riddle } 2068b6eb26fdSRiver Riddle 2069b6eb26fdSRiver Riddle LLVM_DEBUG(logSuccess(rewriterImpl.logger, "")); 2070b6eb26fdSRiver Riddle return success(); 2071b6eb26fdSRiver Riddle } 2072b6eb26fdSRiver Riddle 2073b6eb26fdSRiver Riddle LogicalResult 2074b6eb26fdSRiver Riddle OperationLegalizer::legalizeWithPattern(Operation *op, 2075b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter) { 2076b6eb26fdSRiver Riddle auto &rewriterImpl = rewriter.getImpl(); 2077b6eb26fdSRiver Riddle 2078b6eb26fdSRiver Riddle // Functor that returns if the given pattern may be applied. 2079b6eb26fdSRiver Riddle auto canApply = [&](const Pattern &pattern) { 2080b6eb26fdSRiver Riddle return canApplyPattern(op, pattern, rewriter); 2081b6eb26fdSRiver Riddle }; 2082b6eb26fdSRiver Riddle 2083b6eb26fdSRiver Riddle // Functor that cleans up the rewriter state after a pattern failed to match. 2084b6eb26fdSRiver Riddle RewriterState curState = rewriterImpl.getCurrentState(); 2085b6eb26fdSRiver Riddle auto onFailure = [&](const Pattern &pattern) { 2086e214f004SMatthias Springer assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); 2087b8c6b152SChia-hung Duan LLVM_DEBUG({ 2088b8c6b152SChia-hung Duan logFailure(rewriterImpl.logger, "pattern failed to match"); 2089b8c6b152SChia-hung Duan if (rewriterImpl.notifyCallback) { 2090b8c6b152SChia-hung Duan Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark); 2091b8c6b152SChia-hung Duan diag << "Failed to apply pattern \"" << pattern.getDebugName() 2092b8c6b152SChia-hung Duan << "\" on op:\n" 2093b8c6b152SChia-hung Duan << *op; 2094b8c6b152SChia-hung Duan rewriterImpl.notifyCallback(diag); 2095b8c6b152SChia-hung Duan } 2096b8c6b152SChia-hung Duan }); 2097b6eb26fdSRiver Riddle rewriterImpl.resetState(curState); 2098b6eb26fdSRiver Riddle appliedPatterns.erase(&pattern); 2099b6eb26fdSRiver Riddle }; 2100b6eb26fdSRiver Riddle 2101b6eb26fdSRiver Riddle // Functor that performs additional legalization when a pattern is 2102b6eb26fdSRiver Riddle // successfully applied. 2103b6eb26fdSRiver Riddle auto onSuccess = [&](const Pattern &pattern) { 2104e214f004SMatthias Springer assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); 2105b6eb26fdSRiver Riddle auto result = legalizePatternResult(op, pattern, rewriter, curState); 2106b6eb26fdSRiver Riddle appliedPatterns.erase(&pattern); 2107b6eb26fdSRiver Riddle if (failed(result)) 2108b6eb26fdSRiver Riddle rewriterImpl.resetState(curState); 2109b6eb26fdSRiver Riddle return result; 2110b6eb26fdSRiver Riddle }; 2111b6eb26fdSRiver Riddle 2112b6eb26fdSRiver Riddle // Try to match and rewrite a pattern on this operation. 2113b6eb26fdSRiver Riddle return applicator.matchAndRewrite(op, rewriter, canApply, onFailure, 2114b6eb26fdSRiver Riddle onSuccess); 2115b6eb26fdSRiver Riddle } 2116b6eb26fdSRiver Riddle 2117b6eb26fdSRiver Riddle bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern, 2118b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter) { 2119b6eb26fdSRiver Riddle LLVM_DEBUG({ 2120b6eb26fdSRiver Riddle auto &os = rewriter.getImpl().logger; 2121b6eb26fdSRiver Riddle os.getOStream() << "\n"; 2122b6eb26fdSRiver Riddle os.startLine() << "* Pattern : '" << op->getName() << " -> ("; 2123015192c6SRiver Riddle llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream()); 2124b6eb26fdSRiver Riddle os.getOStream() << ")' {\n"; 2125b6eb26fdSRiver Riddle os.indent(); 2126b6eb26fdSRiver Riddle }); 2127b6eb26fdSRiver Riddle 2128b6eb26fdSRiver Riddle // Ensure that we don't cycle by not allowing the same pattern to be 2129b6eb26fdSRiver Riddle // applied twice in the same recursion stack if it is not known to be safe. 2130b6eb26fdSRiver Riddle if (!pattern.hasBoundedRewriteRecursion() && 2131b6eb26fdSRiver Riddle !appliedPatterns.insert(&pattern).second) { 2132b6eb26fdSRiver Riddle LLVM_DEBUG( 2133b6eb26fdSRiver Riddle logFailure(rewriter.getImpl().logger, "pattern was already applied")); 2134b6eb26fdSRiver Riddle return false; 2135b6eb26fdSRiver Riddle } 2136b6eb26fdSRiver Riddle return true; 2137b6eb26fdSRiver Riddle } 2138b6eb26fdSRiver Riddle 2139b6eb26fdSRiver Riddle LogicalResult 2140b6eb26fdSRiver Riddle OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, 2141b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter, 2142b6eb26fdSRiver Riddle RewriterState &curState) { 2143b6eb26fdSRiver Riddle auto &impl = rewriter.getImpl(); 2144b6eb26fdSRiver Riddle 2145b6eb26fdSRiver Riddle #ifndef NDEBUG 2146b6eb26fdSRiver Riddle assert(impl.pendingRootUpdates.empty() && "dangling root updates"); 2147b6eb26fdSRiver Riddle 2148b6eb26fdSRiver Riddle // Check that the root was either replaced or updated in place. 2149b6eb26fdSRiver Riddle auto replacedRoot = [&] { 2150b6eb26fdSRiver Riddle return llvm::any_of( 2151b6eb26fdSRiver Riddle llvm::drop_begin(impl.replacements, curState.numReplacements), 2152b6eb26fdSRiver Riddle [op](auto &it) { return it.first == op; }); 2153b6eb26fdSRiver Riddle }; 2154b6eb26fdSRiver Riddle auto updatedRootInPlace = [&] { 2155e214f004SMatthias Springer return hasRewrite<ModifyOperationRewrite>( 2156e214f004SMatthias Springer llvm::drop_begin(impl.rewrites, curState.numRewrites), op); 2157b6eb26fdSRiver Riddle }; 2158b6eb26fdSRiver Riddle assert((replacedRoot() || updatedRootInPlace()) && 2159b6eb26fdSRiver Riddle "expected pattern to replace the root operation"); 2160e214f004SMatthias Springer #endif // NDEBUG 2161b6eb26fdSRiver Riddle 2162b6eb26fdSRiver Riddle // Legalize each of the actions registered during application. 2163b6eb26fdSRiver Riddle RewriterState newState = impl.getCurrentState(); 21648faefe36SMatthias Springer if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState, 2165b6eb26fdSRiver Riddle newState)) || 2166b6eb26fdSRiver Riddle failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) || 2167b6eb26fdSRiver Riddle failed(legalizePatternCreatedOperations(rewriter, impl, curState, 2168b6eb26fdSRiver Riddle newState))) { 2169b6eb26fdSRiver Riddle return failure(); 2170b6eb26fdSRiver Riddle } 2171b6eb26fdSRiver Riddle 2172b6eb26fdSRiver Riddle LLVM_DEBUG(logSuccess(impl.logger, "pattern applied successfully")); 2173b6eb26fdSRiver Riddle return success(); 2174b6eb26fdSRiver Riddle } 2175b6eb26fdSRiver Riddle 21768faefe36SMatthias Springer LogicalResult OperationLegalizer::legalizePatternBlockRewrites( 2177b6eb26fdSRiver Riddle Operation *op, ConversionPatternRewriter &rewriter, 2178b6eb26fdSRiver Riddle ConversionPatternRewriterImpl &impl, RewriterState &state, 2179b6eb26fdSRiver Riddle RewriterState &newState) { 2180b6eb26fdSRiver Riddle SmallPtrSet<Operation *, 16> operationsToIgnore; 2181b6eb26fdSRiver Riddle 2182b6eb26fdSRiver Riddle // If the pattern moved or created any blocks, make sure the types of block 2183b6eb26fdSRiver Riddle // arguments get legalized. 21848faefe36SMatthias Springer for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { 21858faefe36SMatthias Springer BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get()); 21868faefe36SMatthias Springer if (!rewrite) 21878faefe36SMatthias Springer continue; 21888faefe36SMatthias Springer Block *block = rewrite->getBlock(); 21898faefe36SMatthias Springer if (isa<BlockTypeConversionRewrite, EraseBlockRewrite>(rewrite)) 2190b6eb26fdSRiver Riddle continue; 2191b6eb26fdSRiver Riddle // Only check blocks outside of the current operation. 21928faefe36SMatthias Springer Operation *parentOp = block->getParentOp(); 21938faefe36SMatthias Springer if (!parentOp || parentOp == op || block->getNumArguments() == 0) 2194b6eb26fdSRiver Riddle continue; 2195b6eb26fdSRiver Riddle 2196b6eb26fdSRiver Riddle // If the region of the block has a type converter, try to convert the block 2197b6eb26fdSRiver Riddle // directly. 21988faefe36SMatthias Springer if (auto *converter = impl.argConverter.getConverter(block->getParent())) { 21998faefe36SMatthias Springer if (failed(impl.convertBlockSignature(block, converter))) { 2200b6eb26fdSRiver Riddle LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved " 2201b6eb26fdSRiver Riddle "block")); 2202b6eb26fdSRiver Riddle return failure(); 2203b6eb26fdSRiver Riddle } 2204b6eb26fdSRiver Riddle continue; 2205b6eb26fdSRiver Riddle } 2206b6eb26fdSRiver Riddle 2207b6eb26fdSRiver Riddle // Otherwise, check that this operation isn't one generated by this pattern. 2208b6eb26fdSRiver Riddle // This is because we will attempt to legalize the parent operation, and 2209b6eb26fdSRiver Riddle // blocks in regions created by this pattern will already be legalized later 2210b6eb26fdSRiver Riddle // on. If we haven't built the set yet, build it now. 2211b6eb26fdSRiver Riddle if (operationsToIgnore.empty()) { 2212b6eb26fdSRiver Riddle auto createdOps = ArrayRef<Operation *>(impl.createdOps) 2213b6eb26fdSRiver Riddle .drop_front(state.numCreatedOps); 2214b6eb26fdSRiver Riddle operationsToIgnore.insert(createdOps.begin(), createdOps.end()); 2215b6eb26fdSRiver Riddle } 2216b6eb26fdSRiver Riddle 2217b6eb26fdSRiver Riddle // If this operation should be considered for re-legalization, try it. 2218b6eb26fdSRiver Riddle if (operationsToIgnore.insert(parentOp).second && 2219b6eb26fdSRiver Riddle failed(legalize(parentOp, rewriter))) { 22208faefe36SMatthias Springer LLVM_DEBUG(logFailure(impl.logger, 22218faefe36SMatthias Springer "operation '{0}'({1}) became illegal after rewrite", 2222b6eb26fdSRiver Riddle parentOp->getName(), parentOp)); 2223b6eb26fdSRiver Riddle return failure(); 2224b6eb26fdSRiver Riddle } 2225b6eb26fdSRiver Riddle } 2226b6eb26fdSRiver Riddle return success(); 2227b6eb26fdSRiver Riddle } 222801b55f16SRiver Riddle 2229b6eb26fdSRiver Riddle LogicalResult OperationLegalizer::legalizePatternCreatedOperations( 2230b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, 2231b6eb26fdSRiver Riddle RewriterState &state, RewriterState &newState) { 2232b6eb26fdSRiver Riddle for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) { 2233b6eb26fdSRiver Riddle Operation *op = impl.createdOps[i]; 2234b6eb26fdSRiver Riddle if (failed(legalize(op, rewriter))) { 2235b6eb26fdSRiver Riddle LLVM_DEBUG(logFailure(impl.logger, 2236e49c0e48SUday Bondhugula "failed to legalize generated operation '{0}'({1})", 2237b6eb26fdSRiver Riddle op->getName(), op)); 2238b6eb26fdSRiver Riddle return failure(); 2239b6eb26fdSRiver Riddle } 2240b6eb26fdSRiver Riddle } 2241b6eb26fdSRiver Riddle return success(); 2242b6eb26fdSRiver Riddle } 224301b55f16SRiver Riddle 2244b6eb26fdSRiver Riddle LogicalResult OperationLegalizer::legalizePatternRootUpdates( 2245b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, 2246b6eb26fdSRiver Riddle RewriterState &state, RewriterState &newState) { 2247e214f004SMatthias Springer for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) { 2248e214f004SMatthias Springer auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get()); 2249e214f004SMatthias Springer if (!rewrite) 2250e214f004SMatthias Springer continue; 2251e214f004SMatthias Springer Operation *op = rewrite->getOperation(); 2252b6eb26fdSRiver Riddle if (failed(legalize(op, rewriter))) { 2253e49c0e48SUday Bondhugula LLVM_DEBUG(logFailure( 2254e49c0e48SUday Bondhugula impl.logger, "failed to legalize operation updated in-place '{0}'", 2255b6eb26fdSRiver Riddle op->getName())); 2256b6eb26fdSRiver Riddle return failure(); 2257b6eb26fdSRiver Riddle } 2258b6eb26fdSRiver Riddle } 2259b6eb26fdSRiver Riddle return success(); 2260b6eb26fdSRiver Riddle } 2261b6eb26fdSRiver Riddle 2262b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 2263b6eb26fdSRiver Riddle // Cost Model 2264b6eb26fdSRiver Riddle 2265b6eb26fdSRiver Riddle void OperationLegalizer::buildLegalizationGraph( 2266b6eb26fdSRiver Riddle LegalizationPatterns &anyOpLegalizerPatterns, 2267b6eb26fdSRiver Riddle DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { 2268b6eb26fdSRiver Riddle // A mapping between an operation and a set of operations that can be used to 2269b6eb26fdSRiver Riddle // generate it. 2270b6eb26fdSRiver Riddle DenseMap<OperationName, SmallPtrSet<OperationName, 2>> parentOps; 2271b6eb26fdSRiver Riddle // A mapping between an operation and any currently invalid patterns it has. 2272b6eb26fdSRiver Riddle DenseMap<OperationName, SmallPtrSet<const Pattern *, 2>> invalidPatterns; 2273b6eb26fdSRiver Riddle // A worklist of patterns to consider for legality. 22744efb7754SRiver Riddle SetVector<const Pattern *> patternWorklist; 2275b6eb26fdSRiver Riddle 2276b6eb26fdSRiver Riddle // Build the mapping from operations to the parent ops that may generate them. 2277b6eb26fdSRiver Riddle applicator.walkAllPatterns([&](const Pattern &pattern) { 2278bef481dfSFangrui Song std::optional<OperationName> root = pattern.getRootKind(); 2279b6eb26fdSRiver Riddle 2280b6eb26fdSRiver Riddle // If the pattern has no specific root, we can't analyze the relationship 2281b6eb26fdSRiver Riddle // between the root op and generated operations. Given that, add all such 2282b6eb26fdSRiver Riddle // patterns to the legalization set. 2283b6eb26fdSRiver Riddle if (!root) { 2284b6eb26fdSRiver Riddle anyOpLegalizerPatterns.push_back(&pattern); 2285b6eb26fdSRiver Riddle return; 2286b6eb26fdSRiver Riddle } 2287b6eb26fdSRiver Riddle 2288b6eb26fdSRiver Riddle // Skip operations that are always known to be legal. 2289b6eb26fdSRiver Riddle if (target.getOpAction(*root) == LegalizationAction::Legal) 2290b6eb26fdSRiver Riddle return; 2291b6eb26fdSRiver Riddle 2292b6eb26fdSRiver Riddle // Add this pattern to the invalid set for the root op and record this root 2293b6eb26fdSRiver Riddle // as a parent for any generated operations. 2294b6eb26fdSRiver Riddle invalidPatterns[*root].insert(&pattern); 2295b6eb26fdSRiver Riddle for (auto op : pattern.getGeneratedOps()) 2296b6eb26fdSRiver Riddle parentOps[op].insert(*root); 2297b6eb26fdSRiver Riddle 2298b6eb26fdSRiver Riddle // Add this pattern to the worklist. 2299b6eb26fdSRiver Riddle patternWorklist.insert(&pattern); 2300b6eb26fdSRiver Riddle }); 2301b6eb26fdSRiver Riddle 2302b6eb26fdSRiver Riddle // If there are any patterns that don't have a specific root kind, we can't 2303b6eb26fdSRiver Riddle // make direct assumptions about what operations will never be legalized. 2304b6eb26fdSRiver Riddle // Note: Technically we could, but it would require an analysis that may 2305b6eb26fdSRiver Riddle // recurse into itself. It would be better to perform this kind of filtering 2306b6eb26fdSRiver Riddle // at a higher level than here anyways. 2307b6eb26fdSRiver Riddle if (!anyOpLegalizerPatterns.empty()) { 2308b6eb26fdSRiver Riddle for (const Pattern *pattern : patternWorklist) 2309b6eb26fdSRiver Riddle legalizerPatterns[*pattern->getRootKind()].push_back(pattern); 2310b6eb26fdSRiver Riddle return; 2311b6eb26fdSRiver Riddle } 2312b6eb26fdSRiver Riddle 2313b6eb26fdSRiver Riddle while (!patternWorklist.empty()) { 2314b6eb26fdSRiver Riddle auto *pattern = patternWorklist.pop_back_val(); 2315b6eb26fdSRiver Riddle 2316b6eb26fdSRiver Riddle // Check to see if any of the generated operations are invalid. 2317b6eb26fdSRiver Riddle if (llvm::any_of(pattern->getGeneratedOps(), [&](OperationName op) { 23180de16fafSRamkumar Ramachandra std::optional<LegalizationAction> action = target.getOpAction(op); 2319b6eb26fdSRiver Riddle return !legalizerPatterns.count(op) && 2320b6eb26fdSRiver Riddle (!action || action == LegalizationAction::Illegal); 2321b6eb26fdSRiver Riddle })) 2322b6eb26fdSRiver Riddle continue; 2323b6eb26fdSRiver Riddle 2324b6eb26fdSRiver Riddle // Otherwise, if all of the generated operation are valid, this op is now 2325b6eb26fdSRiver Riddle // legal so add all of the child patterns to the worklist. 2326b6eb26fdSRiver Riddle legalizerPatterns[*pattern->getRootKind()].push_back(pattern); 2327b6eb26fdSRiver Riddle invalidPatterns[*pattern->getRootKind()].erase(pattern); 2328b6eb26fdSRiver Riddle 2329b6eb26fdSRiver Riddle // Add any invalid patterns of the parent operations to see if they have now 2330b6eb26fdSRiver Riddle // become legal. 2331b6eb26fdSRiver Riddle for (auto op : parentOps[*pattern->getRootKind()]) 2332b6eb26fdSRiver Riddle patternWorklist.set_union(invalidPatterns[op]); 2333b6eb26fdSRiver Riddle } 2334b6eb26fdSRiver Riddle } 2335b6eb26fdSRiver Riddle 2336b6eb26fdSRiver Riddle void OperationLegalizer::computeLegalizationGraphBenefit( 2337b6eb26fdSRiver Riddle LegalizationPatterns &anyOpLegalizerPatterns, 2338b6eb26fdSRiver Riddle DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { 2339b6eb26fdSRiver Riddle // The smallest pattern depth, when legalizing an operation. 2340b6eb26fdSRiver Riddle DenseMap<OperationName, unsigned> minOpPatternDepth; 2341b6eb26fdSRiver Riddle 2342b6eb26fdSRiver Riddle // For each operation that is transitively legal, compute a cost for it. 2343b6eb26fdSRiver Riddle for (auto &opIt : legalizerPatterns) 2344b6eb26fdSRiver Riddle if (!minOpPatternDepth.count(opIt.first)) 2345b6eb26fdSRiver Riddle computeOpLegalizationDepth(opIt.first, minOpPatternDepth, 2346b6eb26fdSRiver Riddle legalizerPatterns); 2347b6eb26fdSRiver Riddle 2348b6eb26fdSRiver Riddle // Apply the cost model to the patterns that can match any operation. Those 2349b6eb26fdSRiver Riddle // with a specific operation type are already resolved when computing the op 2350b6eb26fdSRiver Riddle // legalization depth. 2351b6eb26fdSRiver Riddle if (!anyOpLegalizerPatterns.empty()) 2352b6eb26fdSRiver Riddle applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth, 2353b6eb26fdSRiver Riddle legalizerPatterns); 2354b6eb26fdSRiver Riddle 2355b6eb26fdSRiver Riddle // Apply a cost model to the pattern applicator. We order patterns first by 2356b6eb26fdSRiver Riddle // depth then benefit. `legalizerPatterns` contains per-op patterns by 2357b6eb26fdSRiver Riddle // decreasing benefit. 2358b6eb26fdSRiver Riddle applicator.applyCostModel([&](const Pattern &pattern) { 2359b6eb26fdSRiver Riddle ArrayRef<const Pattern *> orderedPatternList; 2360bef481dfSFangrui Song if (std::optional<OperationName> rootName = pattern.getRootKind()) 2361b6eb26fdSRiver Riddle orderedPatternList = legalizerPatterns[*rootName]; 2362b6eb26fdSRiver Riddle else 2363b6eb26fdSRiver Riddle orderedPatternList = anyOpLegalizerPatterns; 2364b6eb26fdSRiver Riddle 2365b6eb26fdSRiver Riddle // If the pattern is not found, then it was removed and cannot be matched. 23660c29f45aSUday Bondhugula auto *it = llvm::find(orderedPatternList, &pattern); 2367b6eb26fdSRiver Riddle if (it == orderedPatternList.end()) 2368b6eb26fdSRiver Riddle return PatternBenefit::impossibleToMatch(); 2369b6eb26fdSRiver Riddle 2370b6eb26fdSRiver Riddle // Patterns found earlier in the list have higher benefit. 2371b6eb26fdSRiver Riddle return PatternBenefit(std::distance(it, orderedPatternList.end())); 2372b6eb26fdSRiver Riddle }); 2373b6eb26fdSRiver Riddle } 2374b6eb26fdSRiver Riddle 2375b6eb26fdSRiver Riddle unsigned OperationLegalizer::computeOpLegalizationDepth( 2376b6eb26fdSRiver Riddle OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth, 2377b6eb26fdSRiver Riddle DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { 2378b6eb26fdSRiver Riddle // Check for existing depth. 2379b6eb26fdSRiver Riddle auto depthIt = minOpPatternDepth.find(op); 2380b6eb26fdSRiver Riddle if (depthIt != minOpPatternDepth.end()) 2381b6eb26fdSRiver Riddle return depthIt->second; 2382b6eb26fdSRiver Riddle 2383b6eb26fdSRiver Riddle // If a mapping for this operation does not exist, then this operation 2384b6eb26fdSRiver Riddle // is always legal. Return 0 as the depth for a directly legal operation. 2385b6eb26fdSRiver Riddle auto opPatternsIt = legalizerPatterns.find(op); 2386b6eb26fdSRiver Riddle if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty()) 2387b6eb26fdSRiver Riddle return 0u; 2388b6eb26fdSRiver Riddle 2389b6eb26fdSRiver Riddle // Record this initial depth in case we encounter this op again when 2390b6eb26fdSRiver Riddle // recursively computing the depth. 2391b6eb26fdSRiver Riddle minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max()); 2392b6eb26fdSRiver Riddle 2393b6eb26fdSRiver Riddle // Apply the cost model to the operation patterns, and update the minimum 2394b6eb26fdSRiver Riddle // depth. 2395b6eb26fdSRiver Riddle unsigned minDepth = applyCostModelToPatterns( 2396b6eb26fdSRiver Riddle opPatternsIt->second, minOpPatternDepth, legalizerPatterns); 2397b6eb26fdSRiver Riddle minOpPatternDepth[op] = minDepth; 2398b6eb26fdSRiver Riddle return minDepth; 2399b6eb26fdSRiver Riddle } 2400b6eb26fdSRiver Riddle 2401b6eb26fdSRiver Riddle unsigned OperationLegalizer::applyCostModelToPatterns( 2402b6eb26fdSRiver Riddle LegalizationPatterns &patterns, 2403b6eb26fdSRiver Riddle DenseMap<OperationName, unsigned> &minOpPatternDepth, 2404b6eb26fdSRiver Riddle DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) { 2405b6eb26fdSRiver Riddle unsigned minDepth = std::numeric_limits<unsigned>::max(); 2406b6eb26fdSRiver Riddle 2407b6eb26fdSRiver Riddle // Compute the depth for each pattern within the set. 2408b6eb26fdSRiver Riddle SmallVector<std::pair<const Pattern *, unsigned>, 4> patternsByDepth; 2409b6eb26fdSRiver Riddle patternsByDepth.reserve(patterns.size()); 2410b6eb26fdSRiver Riddle for (const Pattern *pattern : patterns) { 2411015192c6SRiver Riddle unsigned depth = 1; 2412b6eb26fdSRiver Riddle for (auto generatedOp : pattern->getGeneratedOps()) { 2413b6eb26fdSRiver Riddle unsigned generatedOpDepth = computeOpLegalizationDepth( 2414b6eb26fdSRiver Riddle generatedOp, minOpPatternDepth, legalizerPatterns); 2415b6eb26fdSRiver Riddle depth = std::max(depth, generatedOpDepth + 1); 2416b6eb26fdSRiver Riddle } 2417b6eb26fdSRiver Riddle patternsByDepth.emplace_back(pattern, depth); 2418b6eb26fdSRiver Riddle 2419b6eb26fdSRiver Riddle // Update the minimum depth of the pattern list. 2420b6eb26fdSRiver Riddle minDepth = std::min(minDepth, depth); 2421b6eb26fdSRiver Riddle } 2422b6eb26fdSRiver Riddle 2423b6eb26fdSRiver Riddle // If the operation only has one legalization pattern, there is no need to 2424b6eb26fdSRiver Riddle // sort them. 2425b6eb26fdSRiver Riddle if (patternsByDepth.size() == 1) 2426b6eb26fdSRiver Riddle return minDepth; 2427b6eb26fdSRiver Riddle 2428b6eb26fdSRiver Riddle // Sort the patterns by those likely to be the most beneficial. 2429ee3c6de7SXiang Li std::stable_sort(patternsByDepth.begin(), patternsByDepth.end(), 2430ee3c6de7SXiang Li [](const std::pair<const Pattern *, unsigned> &lhs, 2431ee3c6de7SXiang Li const std::pair<const Pattern *, unsigned> &rhs) { 2432b6eb26fdSRiver Riddle // First sort by the smaller pattern legalization 2433b6eb26fdSRiver Riddle // depth. 2434ee3c6de7SXiang Li if (lhs.second != rhs.second) 2435ee3c6de7SXiang Li return lhs.second < rhs.second; 2436b6eb26fdSRiver Riddle 2437b6eb26fdSRiver Riddle // Then sort by the larger pattern benefit. 2438ee3c6de7SXiang Li auto lhsBenefit = lhs.first->getBenefit(); 2439ee3c6de7SXiang Li auto rhsBenefit = rhs.first->getBenefit(); 2440ee3c6de7SXiang Li return lhsBenefit > rhsBenefit; 2441b6eb26fdSRiver Riddle }); 2442b6eb26fdSRiver Riddle 2443b6eb26fdSRiver Riddle // Update the legalization pattern to use the new sorted list. 2444b6eb26fdSRiver Riddle patterns.clear(); 2445b6eb26fdSRiver Riddle for (auto &patternIt : patternsByDepth) 2446b6eb26fdSRiver Riddle patterns.push_back(patternIt.first); 2447b6eb26fdSRiver Riddle return minDepth; 2448b6eb26fdSRiver Riddle } 2449b6eb26fdSRiver Riddle 2450b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 2451b6eb26fdSRiver Riddle // OperationConverter 2452b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 2453b6eb26fdSRiver Riddle namespace { 2454b6eb26fdSRiver Riddle enum OpConversionMode { 245501b55f16SRiver Riddle /// In this mode, the conversion will ignore failed conversions to allow 245601b55f16SRiver Riddle /// illegal operations to co-exist in the IR. 2457b6eb26fdSRiver Riddle Partial, 2458b6eb26fdSRiver Riddle 245901b55f16SRiver Riddle /// In this mode, all operations must be legal for the given target for the 246001b55f16SRiver Riddle /// conversion to succeed. 2461b6eb26fdSRiver Riddle Full, 2462b6eb26fdSRiver Riddle 246301b55f16SRiver Riddle /// In this mode, operations are analyzed for legality. No actual rewrites are 246401b55f16SRiver Riddle /// applied to the operations on success. 2465b6eb26fdSRiver Riddle Analysis, 2466b6eb26fdSRiver Riddle }; 2467b6eb26fdSRiver Riddle 2468b6eb26fdSRiver Riddle // This class converts operations to a given conversion target via a set of 2469b6eb26fdSRiver Riddle // rewrite patterns. The conversion behaves differently depending on the 2470b6eb26fdSRiver Riddle // conversion mode. 2471b6eb26fdSRiver Riddle struct OperationConverter { 2472370a6f09SMehdi Amini explicit OperationConverter(const ConversionTarget &target, 247379d7f618SChris Lattner const FrozenRewritePatternSet &patterns, 2474b6eb26fdSRiver Riddle OpConversionMode mode, 2475b6eb26fdSRiver Riddle DenseSet<Operation *> *trackedOps = nullptr) 2476b6eb26fdSRiver Riddle : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {} 2477b6eb26fdSRiver Riddle 2478b6eb26fdSRiver Riddle /// Converts the given operations to the conversion target. 2479b8c6b152SChia-hung Duan LogicalResult 2480b8c6b152SChia-hung Duan convertOperations(ArrayRef<Operation *> ops, 2481b8c6b152SChia-hung Duan function_ref<void(Diagnostic &)> notifyCallback = nullptr); 2482b6eb26fdSRiver Riddle 2483b6eb26fdSRiver Riddle private: 2484b6eb26fdSRiver Riddle /// Converts an operation with the given rewriter. 2485b6eb26fdSRiver Riddle LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op); 2486b6eb26fdSRiver Riddle 2487b6eb26fdSRiver Riddle /// This method is called after the conversion process to legalize any 2488b6eb26fdSRiver Riddle /// remaining artifacts and complete the conversion. 2489b6eb26fdSRiver Riddle LogicalResult finalize(ConversionPatternRewriter &rewriter); 2490b6eb26fdSRiver Riddle 2491b6eb26fdSRiver Riddle /// Legalize the types of converted block arguments. 2492b6eb26fdSRiver Riddle LogicalResult 2493b6eb26fdSRiver Riddle legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter, 2494b6eb26fdSRiver Riddle ConversionPatternRewriterImpl &rewriterImpl); 2495b6eb26fdSRiver Riddle 2496015192c6SRiver Riddle /// Legalize any unresolved type materializations. 2497015192c6SRiver Riddle LogicalResult legalizeUnresolvedMaterializations( 2498015192c6SRiver Riddle ConversionPatternRewriter &rewriter, 2499015192c6SRiver Riddle ConversionPatternRewriterImpl &rewriterImpl, 25000de16fafSRamkumar Ramachandra std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping); 2501015192c6SRiver Riddle 2502b6eb26fdSRiver Riddle /// Legalize an operation result that was marked as "erased". 2503b6eb26fdSRiver Riddle LogicalResult 2504b6eb26fdSRiver Riddle legalizeErasedResult(Operation *op, OpResult result, 2505b6eb26fdSRiver Riddle ConversionPatternRewriterImpl &rewriterImpl); 2506b6eb26fdSRiver Riddle 2507b6eb26fdSRiver Riddle /// Legalize an operation result that was replaced with a value of a different 2508b6eb26fdSRiver Riddle /// type. 2509015192c6SRiver Riddle LogicalResult legalizeChangedResultType( 2510015192c6SRiver Riddle Operation *op, OpResult result, Value newValue, 2511ce254598SMatthias Springer const TypeConverter *replConverter, ConversionPatternRewriter &rewriter, 25125b91060dSAlex Zinenko ConversionPatternRewriterImpl &rewriterImpl, 2513015192c6SRiver Riddle const DenseMap<Value, SmallVector<Value>> &inverseMapping); 2514b6eb26fdSRiver Riddle 2515b6eb26fdSRiver Riddle /// The legalizer to use when converting operations. 2516b6eb26fdSRiver Riddle OperationLegalizer opLegalizer; 2517b6eb26fdSRiver Riddle 2518b6eb26fdSRiver Riddle /// The conversion mode to use when legalizing operations. 2519b6eb26fdSRiver Riddle OpConversionMode mode; 2520b6eb26fdSRiver Riddle 2521b6eb26fdSRiver Riddle /// A set of pre-existing operations. When mode == OpConversionMode::Analysis, 2522b6eb26fdSRiver Riddle /// this is populated with ops found to be legalizable to the target. 2523b6eb26fdSRiver Riddle /// When mode == OpConversionMode::Partial, this is populated with ops found 2524b6eb26fdSRiver Riddle /// *not* to be legalizable to the target. 2525b6eb26fdSRiver Riddle DenseSet<Operation *> *trackedOps; 2526b6eb26fdSRiver Riddle }; 2527be0a7e9fSMehdi Amini } // namespace 2528b6eb26fdSRiver Riddle 2529b6eb26fdSRiver Riddle LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, 2530b6eb26fdSRiver Riddle Operation *op) { 2531b6eb26fdSRiver Riddle // Legalize the given operation. 2532b6eb26fdSRiver Riddle if (failed(opLegalizer.legalize(op, rewriter))) { 2533b6eb26fdSRiver Riddle // Handle the case of a failed conversion for each of the different modes. 2534b6eb26fdSRiver Riddle // Full conversions expect all operations to be converted. 2535b6eb26fdSRiver Riddle if (mode == OpConversionMode::Full) 2536b6eb26fdSRiver Riddle return op->emitError() 2537b6eb26fdSRiver Riddle << "failed to legalize operation '" << op->getName() << "'"; 2538b6eb26fdSRiver Riddle // Partial conversions allow conversions to fail iff the operation was not 2539b6eb26fdSRiver Riddle // explicitly marked as illegal. If the user provided a nonlegalizableOps 2540b6eb26fdSRiver Riddle // set, non-legalizable ops are included. 2541b6eb26fdSRiver Riddle if (mode == OpConversionMode::Partial) { 2542b6eb26fdSRiver Riddle if (opLegalizer.isIllegal(op)) 2543b6eb26fdSRiver Riddle return op->emitError() 2544b6eb26fdSRiver Riddle << "failed to legalize operation '" << op->getName() 2545b6eb26fdSRiver Riddle << "' that was explicitly marked illegal"; 2546b6eb26fdSRiver Riddle if (trackedOps) 2547b6eb26fdSRiver Riddle trackedOps->insert(op); 2548b6eb26fdSRiver Riddle } 2549b6eb26fdSRiver Riddle } else if (mode == OpConversionMode::Analysis) { 2550b6eb26fdSRiver Riddle // Analysis conversions don't fail if any operations fail to legalize, 2551b6eb26fdSRiver Riddle // they are only interested in the operations that were successfully 2552b6eb26fdSRiver Riddle // legalized. 2553b6eb26fdSRiver Riddle trackedOps->insert(op); 2554b6eb26fdSRiver Riddle } 2555b6eb26fdSRiver Riddle return success(); 2556b6eb26fdSRiver Riddle } 2557b6eb26fdSRiver Riddle 2558b8c6b152SChia-hung Duan LogicalResult OperationConverter::convertOperations( 2559b8c6b152SChia-hung Duan ArrayRef<Operation *> ops, 2560b8c6b152SChia-hung Duan function_ref<void(Diagnostic &)> notifyCallback) { 2561b6eb26fdSRiver Riddle if (ops.empty()) 2562b6eb26fdSRiver Riddle return success(); 2563370a6f09SMehdi Amini const ConversionTarget &target = opLegalizer.getTarget(); 2564b6eb26fdSRiver Riddle 2565b6eb26fdSRiver Riddle // Compute the set of operations and blocks to convert. 2566015192c6SRiver Riddle SmallVector<Operation *> toConvert; 2567b6eb26fdSRiver Riddle for (auto *op : ops) { 2568b884f4efSMatthias Springer op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>( 2569b884f4efSMatthias Springer [&](Operation *op) { 2570b884f4efSMatthias Springer toConvert.push_back(op); 2571b884f4efSMatthias Springer // Don't check this operation's children for conversion if the 2572b884f4efSMatthias Springer // operation is recursively legal. 2573b884f4efSMatthias Springer auto legalityInfo = target.isLegal(op); 2574b884f4efSMatthias Springer if (legalityInfo && legalityInfo->isRecursivelyLegal) 2575b884f4efSMatthias Springer return WalkResult::skip(); 2576b884f4efSMatthias Springer return WalkResult::advance(); 2577b884f4efSMatthias Springer }); 2578b6eb26fdSRiver Riddle } 2579b6eb26fdSRiver Riddle 2580b6eb26fdSRiver Riddle // Convert each operation and discard rewrites on failure. 2581b6eb26fdSRiver Riddle ConversionPatternRewriter rewriter(ops.front()->getContext()); 2582b6eb26fdSRiver Riddle ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); 2583b8c6b152SChia-hung Duan rewriterImpl.notifyCallback = notifyCallback; 2584b8c6b152SChia-hung Duan 2585b6eb26fdSRiver Riddle for (auto *op : toConvert) 2586b6eb26fdSRiver Riddle if (failed(convert(rewriter, op))) 2587b6eb26fdSRiver Riddle return rewriterImpl.discardRewrites(), failure(); 2588b6eb26fdSRiver Riddle 2589b6eb26fdSRiver Riddle // Now that all of the operations have been converted, finalize the conversion 2590b6eb26fdSRiver Riddle // process to ensure any lingering conversion artifacts are cleaned up and 2591b6eb26fdSRiver Riddle // legalized. 2592b6eb26fdSRiver Riddle if (failed(finalize(rewriter))) 2593b6eb26fdSRiver Riddle return rewriterImpl.discardRewrites(), failure(); 259401b55f16SRiver Riddle 2595b6eb26fdSRiver Riddle // After a successful conversion, apply rewrites if this is not an analysis 2596b6eb26fdSRiver Riddle // conversion. 259701b55f16SRiver Riddle if (mode == OpConversionMode::Analysis) { 2598b6eb26fdSRiver Riddle rewriterImpl.discardRewrites(); 259901b55f16SRiver Riddle } else { 2600b6eb26fdSRiver Riddle rewriterImpl.applyRewrites(); 2601a360a978SMehdi Amini 2602a360a978SMehdi Amini // It is possible for a later pattern to erase an op that was originally 2603a360a978SMehdi Amini // identified as illegal and added to the trackedOps, remove it now after 2604a360a978SMehdi Amini // replacements have been computed. 2605a360a978SMehdi Amini if (trackedOps) 2606a360a978SMehdi Amini for (auto &repl : rewriterImpl.replacements) 2607a360a978SMehdi Amini trackedOps->erase(repl.first); 2608a360a978SMehdi Amini } 2609b6eb26fdSRiver Riddle return success(); 2610b6eb26fdSRiver Riddle } 2611b6eb26fdSRiver Riddle 2612b6eb26fdSRiver Riddle LogicalResult 2613b6eb26fdSRiver Riddle OperationConverter::finalize(ConversionPatternRewriter &rewriter) { 26140de16fafSRamkumar Ramachandra std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping; 2615b6eb26fdSRiver Riddle ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); 2616015192c6SRiver Riddle if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl, 2617015192c6SRiver Riddle inverseMapping)) || 2618015192c6SRiver Riddle failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl))) 2619b6eb26fdSRiver Riddle return failure(); 2620b6eb26fdSRiver Riddle 26215b91060dSAlex Zinenko if (rewriterImpl.operationsWithChangedResults.empty()) 26225b91060dSAlex Zinenko return success(); 26235b91060dSAlex Zinenko 2624b6eb26fdSRiver Riddle // Process requested operation replacements. 2625b6eb26fdSRiver Riddle for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size(); 2626b6eb26fdSRiver Riddle i != e; ++i) { 2627b6eb26fdSRiver Riddle unsigned replIdx = rewriterImpl.operationsWithChangedResults[i]; 2628b6eb26fdSRiver Riddle auto &repl = *(rewriterImpl.replacements.begin() + replIdx); 2629b6eb26fdSRiver Riddle for (OpResult result : repl.first->getResults()) { 2630b6eb26fdSRiver Riddle Value newValue = rewriterImpl.mapping.lookupOrNull(result); 2631b6eb26fdSRiver Riddle 2632b6eb26fdSRiver Riddle // If the operation result was replaced with null, all of the uses of this 2633b6eb26fdSRiver Riddle // value should be replaced. 2634b6eb26fdSRiver Riddle if (!newValue) { 2635b6eb26fdSRiver Riddle if (failed(legalizeErasedResult(repl.first, result, rewriterImpl))) 2636b6eb26fdSRiver Riddle return failure(); 2637b6eb26fdSRiver Riddle continue; 2638b6eb26fdSRiver Riddle } 2639b6eb26fdSRiver Riddle 2640b6eb26fdSRiver Riddle // Otherwise, check to see if the type of the result changed. 2641b6eb26fdSRiver Riddle if (result.getType() == newValue.getType()) 2642b6eb26fdSRiver Riddle continue; 2643b6eb26fdSRiver Riddle 26445b91060dSAlex Zinenko // Compute the inverse mapping only if it is really needed. 26455b91060dSAlex Zinenko if (!inverseMapping) 26465b91060dSAlex Zinenko inverseMapping = rewriterImpl.mapping.getInverse(); 26475b91060dSAlex Zinenko 2648b6eb26fdSRiver Riddle // Legalize this result. 2649b6eb26fdSRiver Riddle rewriter.setInsertionPoint(repl.first); 2650b6eb26fdSRiver Riddle if (failed(legalizeChangedResultType(repl.first, result, newValue, 2651b6eb26fdSRiver Riddle repl.second.converter, rewriter, 26525b91060dSAlex Zinenko rewriterImpl, *inverseMapping))) 2653b6eb26fdSRiver Riddle return failure(); 2654b6eb26fdSRiver Riddle 2655b6eb26fdSRiver Riddle // Update the end iterator for this loop in the case it was updated 2656b6eb26fdSRiver Riddle // when legalizing generated conversion operations. 2657b6eb26fdSRiver Riddle e = rewriterImpl.operationsWithChangedResults.size(); 2658b6eb26fdSRiver Riddle } 2659b6eb26fdSRiver Riddle } 2660b6eb26fdSRiver Riddle return success(); 2661b6eb26fdSRiver Riddle } 2662b6eb26fdSRiver Riddle 2663b6eb26fdSRiver Riddle LogicalResult OperationConverter::legalizeConvertedArgumentTypes( 2664b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter, 2665b6eb26fdSRiver Riddle ConversionPatternRewriterImpl &rewriterImpl) { 2666b6eb26fdSRiver Riddle // Functor used to check if all users of a value will be dead after 2667b6eb26fdSRiver Riddle // conversion. 2668b6eb26fdSRiver Riddle auto findLiveUser = [&](Value val) { 2669b6eb26fdSRiver Riddle auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) { 2670b6eb26fdSRiver Riddle return rewriterImpl.isOpIgnored(user); 2671b6eb26fdSRiver Riddle }); 2672b6eb26fdSRiver Riddle return liveUserIt == val.user_end() ? nullptr : *liveUserIt; 2673b6eb26fdSRiver Riddle }; 2674015192c6SRiver Riddle return rewriterImpl.argConverter.materializeLiveConversions( 2675015192c6SRiver Riddle rewriterImpl.mapping, rewriter, findLiveUser); 2676b6eb26fdSRiver Riddle } 2677015192c6SRiver Riddle 2678015192c6SRiver Riddle /// Replace the results of a materialization operation with the given values. 2679015192c6SRiver Riddle static void 2680015192c6SRiver Riddle replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl, 2681015192c6SRiver Riddle ResultRange matResults, ValueRange values, 2682015192c6SRiver Riddle DenseMap<Value, SmallVector<Value>> &inverseMapping) { 2683015192c6SRiver Riddle matResults.replaceAllUsesWith(values); 2684015192c6SRiver Riddle 2685015192c6SRiver Riddle // For each of the materialization results, update the inverse mappings to 2686015192c6SRiver Riddle // point to the replacement values. 26879fa59e76SBenjamin Kramer for (auto [matResult, newValue] : llvm::zip(matResults, values)) { 2688015192c6SRiver Riddle auto inverseMapIt = inverseMapping.find(matResult); 2689015192c6SRiver Riddle if (inverseMapIt == inverseMapping.end()) 2690015192c6SRiver Riddle continue; 2691015192c6SRiver Riddle 2692015192c6SRiver Riddle // Update the reverse mapping, or remove the mapping if we couldn't update 2693015192c6SRiver Riddle // it. Not being able to update signals that the mapping would have become 2694015192c6SRiver Riddle // circular (i.e. %foo -> newValue -> %foo), which may occur as values are 2695015192c6SRiver Riddle // propagated through temporary materializations. We simply drop the 2696015192c6SRiver Riddle // mapping, and let the post-conversion replacement logic handle updating 2697015192c6SRiver Riddle // uses. 2698015192c6SRiver Riddle for (Value inverseMapVal : inverseMapIt->second) 2699015192c6SRiver Riddle if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue)) 2700015192c6SRiver Riddle rewriterImpl.mapping.erase(inverseMapVal); 2701015192c6SRiver Riddle } 2702015192c6SRiver Riddle } 2703015192c6SRiver Riddle 2704015192c6SRiver Riddle /// Compute all of the unresolved materializations that will persist beyond the 2705015192c6SRiver Riddle /// conversion process, and require inserting a proper user materialization for. 2706015192c6SRiver Riddle static void computeNecessaryMaterializations( 2707015192c6SRiver Riddle DenseMap<Operation *, UnresolvedMaterialization *> &materializationOps, 2708015192c6SRiver Riddle ConversionPatternRewriter &rewriter, 2709015192c6SRiver Riddle ConversionPatternRewriterImpl &rewriterImpl, 2710015192c6SRiver Riddle DenseMap<Value, SmallVector<Value>> &inverseMapping, 2711015192c6SRiver Riddle SetVector<UnresolvedMaterialization *> &necessaryMaterializations) { 2712015192c6SRiver Riddle auto isLive = [&](Value value) { 2713015192c6SRiver Riddle auto findFn = [&](Operation *user) { 2714015192c6SRiver Riddle auto matIt = materializationOps.find(user); 2715015192c6SRiver Riddle if (matIt != materializationOps.end()) 2716015192c6SRiver Riddle return !necessaryMaterializations.count(matIt->second); 2717015192c6SRiver Riddle return rewriterImpl.isOpIgnored(user); 2718015192c6SRiver Riddle }; 2719d4a53f3bSAlex Zinenko // This value may be replacing another value that has a live user. 2720d4a53f3bSAlex Zinenko for (Value inv : inverseMapping.lookup(value)) 2721d4a53f3bSAlex Zinenko if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end()) 2722d4a53f3bSAlex Zinenko return true; 2723d4a53f3bSAlex Zinenko // Or have live users itself. 2724015192c6SRiver Riddle return llvm::find_if_not(value.getUsers(), findFn) != value.user_end(); 2725015192c6SRiver Riddle }; 2726015192c6SRiver Riddle 2727015192c6SRiver Riddle llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue = 2728015192c6SRiver Riddle [&](Value invalidRoot, Value value, Type type) { 2729015192c6SRiver Riddle // Check to see if the input operation was remapped to a variant of the 2730015192c6SRiver Riddle // output. 2731015192c6SRiver Riddle Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type); 2732015192c6SRiver Riddle if (remappedValue.getType() == type && remappedValue != invalidRoot) 2733015192c6SRiver Riddle return remappedValue; 2734015192c6SRiver Riddle 2735015192c6SRiver Riddle // Check to see if the input is a materialization operation that 2736015192c6SRiver Riddle // provides an inverse conversion. We just check blindly for 2737015192c6SRiver Riddle // UnrealizedConversionCastOp here, but it has no effect on correctness. 2738015192c6SRiver Riddle auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>(); 2739015192c6SRiver Riddle if (inputCastOp && inputCastOp->getNumOperands() == 1) 2740015192c6SRiver Riddle return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0), 2741015192c6SRiver Riddle type); 2742015192c6SRiver Riddle 2743015192c6SRiver Riddle return Value(); 2744015192c6SRiver Riddle }; 2745015192c6SRiver Riddle 2746015192c6SRiver Riddle SetVector<UnresolvedMaterialization *> worklist; 2747015192c6SRiver Riddle for (auto &mat : rewriterImpl.unresolvedMaterializations) { 2748015192c6SRiver Riddle materializationOps.try_emplace(mat.getOp(), &mat); 2749015192c6SRiver Riddle worklist.insert(&mat); 2750015192c6SRiver Riddle } 2751015192c6SRiver Riddle while (!worklist.empty()) { 2752015192c6SRiver Riddle UnresolvedMaterialization *mat = worklist.pop_back_val(); 2753015192c6SRiver Riddle UnrealizedConversionCastOp op = mat->getOp(); 2754015192c6SRiver Riddle 2755015192c6SRiver Riddle // We currently only handle target materializations here. 2756015192c6SRiver Riddle assert(op->getNumResults() == 1 && "unexpected materialization type"); 2757015192c6SRiver Riddle OpResult opResult = op->getOpResult(0); 2758015192c6SRiver Riddle Type outputType = opResult.getType(); 2759015192c6SRiver Riddle Operation::operand_range inputOperands = op.getOperands(); 2760015192c6SRiver Riddle 2761015192c6SRiver Riddle // Try to forward propagate operands for user conversion casts that result 2762015192c6SRiver Riddle // in the input types of the current cast. 2763015192c6SRiver Riddle for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) { 2764015192c6SRiver Riddle auto castOp = dyn_cast<UnrealizedConversionCastOp>(user); 2765015192c6SRiver Riddle if (!castOp) 2766015192c6SRiver Riddle continue; 2767015192c6SRiver Riddle if (castOp->getResultTypes() == inputOperands.getTypes()) { 2768015192c6SRiver Riddle replaceMaterialization(rewriterImpl, opResult, inputOperands, 2769015192c6SRiver Riddle inverseMapping); 2770015192c6SRiver Riddle necessaryMaterializations.remove(materializationOps.lookup(user)); 2771015192c6SRiver Riddle } 2772015192c6SRiver Riddle } 2773015192c6SRiver Riddle 2774015192c6SRiver Riddle // Try to avoid materializing a resolved materialization if possible. 2775015192c6SRiver Riddle // Handle the case of a 1-1 materialization. 2776015192c6SRiver Riddle if (inputOperands.size() == 1) { 2777015192c6SRiver Riddle // Check to see if the input operation was remapped to a variant of the 2778015192c6SRiver Riddle // output. 2779015192c6SRiver Riddle Value remappedValue = 2780015192c6SRiver Riddle lookupRemappedValue(opResult, inputOperands[0], outputType); 2781015192c6SRiver Riddle if (remappedValue && remappedValue != opResult) { 2782015192c6SRiver Riddle replaceMaterialization(rewriterImpl, opResult, remappedValue, 2783015192c6SRiver Riddle inverseMapping); 2784015192c6SRiver Riddle necessaryMaterializations.remove(mat); 2785015192c6SRiver Riddle continue; 2786015192c6SRiver Riddle } 2787015192c6SRiver Riddle } else { 2788015192c6SRiver Riddle // TODO: Avoid materializing other types of conversions here. 2789015192c6SRiver Riddle } 2790015192c6SRiver Riddle 2791015192c6SRiver Riddle // Check to see if this is an argument materialization. 27925550c821STres Popp auto isBlockArg = [](Value v) { return isa<BlockArgument>(v); }; 2793015192c6SRiver Riddle if (llvm::any_of(op->getOperands(), isBlockArg) || 2794015192c6SRiver Riddle llvm::any_of(inverseMapping[op->getResult(0)], isBlockArg)) { 2795015192c6SRiver Riddle mat->setKind(UnresolvedMaterialization::Argument); 2796015192c6SRiver Riddle } 2797015192c6SRiver Riddle 2798015192c6SRiver Riddle // If the materialization does not have any live users, we don't need to 2799015192c6SRiver Riddle // generate a user materialization for it. 2800015192c6SRiver Riddle // FIXME: For argument materializations, we currently need to check if any 2801015192c6SRiver Riddle // of the inverse mapped values are used because some patterns expect blind 2802015192c6SRiver Riddle // value replacement even if the types differ in some cases. When those 2803015192c6SRiver Riddle // patterns are fixed, we can drop the argument special case here. 2804015192c6SRiver Riddle bool isMaterializationLive = isLive(opResult); 2805015192c6SRiver Riddle if (mat->getKind() == UnresolvedMaterialization::Argument) 2806015192c6SRiver Riddle isMaterializationLive |= llvm::any_of(inverseMapping[opResult], isLive); 2807015192c6SRiver Riddle if (!isMaterializationLive) 2808015192c6SRiver Riddle continue; 2809015192c6SRiver Riddle if (!necessaryMaterializations.insert(mat)) 2810015192c6SRiver Riddle continue; 2811015192c6SRiver Riddle 2812015192c6SRiver Riddle // Reprocess input materializations to see if they have an updated status. 2813015192c6SRiver Riddle for (Value input : inputOperands) { 2814015192c6SRiver Riddle if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) { 2815015192c6SRiver Riddle if (auto *mat = materializationOps.lookup(parentOp)) 2816015192c6SRiver Riddle worklist.insert(mat); 2817015192c6SRiver Riddle } 2818015192c6SRiver Riddle } 2819015192c6SRiver Riddle } 2820015192c6SRiver Riddle } 2821015192c6SRiver Riddle 2822015192c6SRiver Riddle /// Legalize the given unresolved materialization. Returns success if the 2823015192c6SRiver Riddle /// materialization was legalized, failure otherise. 2824015192c6SRiver Riddle static LogicalResult legalizeUnresolvedMaterialization( 2825015192c6SRiver Riddle UnresolvedMaterialization &mat, 2826015192c6SRiver Riddle DenseMap<Operation *, UnresolvedMaterialization *> &materializationOps, 2827015192c6SRiver Riddle ConversionPatternRewriter &rewriter, 2828015192c6SRiver Riddle ConversionPatternRewriterImpl &rewriterImpl, 2829015192c6SRiver Riddle DenseMap<Value, SmallVector<Value>> &inverseMapping) { 2830015192c6SRiver Riddle auto findLiveUser = [&](auto &&users) { 2831015192c6SRiver Riddle auto liveUserIt = llvm::find_if_not( 2832015192c6SRiver Riddle users, [&](Operation *user) { return rewriterImpl.isOpIgnored(user); }); 2833015192c6SRiver Riddle return liveUserIt == users.end() ? nullptr : *liveUserIt; 2834015192c6SRiver Riddle }; 2835015192c6SRiver Riddle 2836015192c6SRiver Riddle llvm::unique_function<Value(Value, Type)> lookupRemappedValue = 2837015192c6SRiver Riddle [&](Value value, Type type) { 2838015192c6SRiver Riddle // Check to see if the input operation was remapped to a variant of the 2839015192c6SRiver Riddle // output. 2840015192c6SRiver Riddle Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type); 2841015192c6SRiver Riddle if (remappedValue.getType() == type) 2842015192c6SRiver Riddle return remappedValue; 2843015192c6SRiver Riddle return Value(); 2844015192c6SRiver Riddle }; 2845015192c6SRiver Riddle 2846015192c6SRiver Riddle UnrealizedConversionCastOp op = mat.getOp(); 2847015192c6SRiver Riddle if (!rewriterImpl.ignoredOps.insert(op)) 2848015192c6SRiver Riddle return success(); 2849015192c6SRiver Riddle 2850015192c6SRiver Riddle // We currently only handle target materializations here. 2851015192c6SRiver Riddle OpResult opResult = op->getOpResult(0); 2852015192c6SRiver Riddle Operation::operand_range inputOperands = op.getOperands(); 2853015192c6SRiver Riddle Type outputType = opResult.getType(); 2854015192c6SRiver Riddle 2855015192c6SRiver Riddle // If any input to this materialization is another materialization, resolve 2856015192c6SRiver Riddle // the input first. 2857015192c6SRiver Riddle for (Value value : op->getOperands()) { 2858015192c6SRiver Riddle auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>(); 2859015192c6SRiver Riddle if (!valueCast) 2860015192c6SRiver Riddle continue; 2861015192c6SRiver Riddle 2862015192c6SRiver Riddle auto matIt = materializationOps.find(valueCast); 2863015192c6SRiver Riddle if (matIt != materializationOps.end()) 2864015192c6SRiver Riddle if (failed(legalizeUnresolvedMaterialization( 2865015192c6SRiver Riddle *matIt->second, materializationOps, rewriter, rewriterImpl, 2866015192c6SRiver Riddle inverseMapping))) 2867015192c6SRiver Riddle return failure(); 2868015192c6SRiver Riddle } 2869015192c6SRiver Riddle 2870015192c6SRiver Riddle // Perform a last ditch attempt to avoid materializing a resolved 2871015192c6SRiver Riddle // materialization if possible. 2872015192c6SRiver Riddle // Handle the case of a 1-1 materialization. 2873015192c6SRiver Riddle if (inputOperands.size() == 1) { 2874015192c6SRiver Riddle // Check to see if the input operation was remapped to a variant of the 2875015192c6SRiver Riddle // output. 2876015192c6SRiver Riddle Value remappedValue = lookupRemappedValue(inputOperands[0], outputType); 2877015192c6SRiver Riddle if (remappedValue && remappedValue != opResult) { 2878015192c6SRiver Riddle replaceMaterialization(rewriterImpl, opResult, remappedValue, 2879015192c6SRiver Riddle inverseMapping); 2880015192c6SRiver Riddle return success(); 2881015192c6SRiver Riddle } 2882015192c6SRiver Riddle } else { 2883015192c6SRiver Riddle // TODO: Avoid materializing other types of conversions here. 2884015192c6SRiver Riddle } 2885015192c6SRiver Riddle 2886015192c6SRiver Riddle // Try to materialize the conversion. 2887ce254598SMatthias Springer if (const TypeConverter *converter = mat.getConverter()) { 2888015192c6SRiver Riddle // FIXME: Determine a suitable insertion location when there are multiple 2889015192c6SRiver Riddle // inputs. 2890015192c6SRiver Riddle if (inputOperands.size() == 1) 2891015192c6SRiver Riddle rewriter.setInsertionPointAfterValue(inputOperands.front()); 2892015192c6SRiver Riddle else 2893015192c6SRiver Riddle rewriter.setInsertionPoint(op); 2894015192c6SRiver Riddle 2895015192c6SRiver Riddle Value newMaterialization; 2896015192c6SRiver Riddle switch (mat.getKind()) { 2897015192c6SRiver Riddle case UnresolvedMaterialization::Argument: 2898015192c6SRiver Riddle // Try to materialize an argument conversion. 2899015192c6SRiver Riddle // FIXME: The current argument materialization hook expects the original 2900015192c6SRiver Riddle // output type, even though it doesn't use that as the actual output type 2901015192c6SRiver Riddle // of the generated IR. The output type is just used as an indicator of 2902015192c6SRiver Riddle // the type of materialization to do. This behavior is really awkward in 2903015192c6SRiver Riddle // that it diverges from the behavior of the other hooks, and can be 2904015192c6SRiver Riddle // easily misunderstood. We should clean up the argument hooks to better 2905015192c6SRiver Riddle // represent the desired invariants we actually care about. 2906015192c6SRiver Riddle newMaterialization = converter->materializeArgumentConversion( 2907015192c6SRiver Riddle rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands); 2908015192c6SRiver Riddle if (newMaterialization) 2909015192c6SRiver Riddle break; 2910015192c6SRiver Riddle 2911015192c6SRiver Riddle // If an argument materialization failed, fallback to trying a target 2912015192c6SRiver Riddle // materialization. 2913fc63c054SFangrui Song [[fallthrough]]; 2914015192c6SRiver Riddle case UnresolvedMaterialization::Target: 2915015192c6SRiver Riddle newMaterialization = converter->materializeTargetConversion( 2916015192c6SRiver Riddle rewriter, op->getLoc(), outputType, inputOperands); 2917015192c6SRiver Riddle break; 2918015192c6SRiver Riddle } 2919015192c6SRiver Riddle if (newMaterialization) { 2920015192c6SRiver Riddle replaceMaterialization(rewriterImpl, opResult, newMaterialization, 2921015192c6SRiver Riddle inverseMapping); 2922015192c6SRiver Riddle return success(); 2923015192c6SRiver Riddle } 2924015192c6SRiver Riddle } 2925015192c6SRiver Riddle 2926015192c6SRiver Riddle InFlightDiagnostic diag = op->emitError() 2927015192c6SRiver Riddle << "failed to legalize unresolved materialization " 2928015192c6SRiver Riddle "from " 2929015192c6SRiver Riddle << inputOperands.getTypes() << " to " << outputType 2930015192c6SRiver Riddle << " that remained live after conversion"; 2931015192c6SRiver Riddle if (Operation *liveUser = findLiveUser(op->getUsers())) { 2932015192c6SRiver Riddle diag.attachNote(liveUser->getLoc()) 2933015192c6SRiver Riddle << "see existing live user here: " << *liveUser; 2934015192c6SRiver Riddle } 2935015192c6SRiver Riddle return failure(); 2936015192c6SRiver Riddle } 2937015192c6SRiver Riddle 2938015192c6SRiver Riddle LogicalResult OperationConverter::legalizeUnresolvedMaterializations( 2939015192c6SRiver Riddle ConversionPatternRewriter &rewriter, 2940015192c6SRiver Riddle ConversionPatternRewriterImpl &rewriterImpl, 29410de16fafSRamkumar Ramachandra std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping) { 2942015192c6SRiver Riddle if (rewriterImpl.unresolvedMaterializations.empty()) 2943015192c6SRiver Riddle return success(); 2944015192c6SRiver Riddle inverseMapping = rewriterImpl.mapping.getInverse(); 2945015192c6SRiver Riddle 2946015192c6SRiver Riddle // As an initial step, compute all of the inserted materializations that we 2947015192c6SRiver Riddle // expect to persist beyond the conversion process. 2948015192c6SRiver Riddle DenseMap<Operation *, UnresolvedMaterialization *> materializationOps; 2949015192c6SRiver Riddle SetVector<UnresolvedMaterialization *> necessaryMaterializations; 2950015192c6SRiver Riddle computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl, 2951015192c6SRiver Riddle *inverseMapping, necessaryMaterializations); 2952015192c6SRiver Riddle 2953015192c6SRiver Riddle // Once computed, legalize any necessary materializations. 2954015192c6SRiver Riddle for (auto *mat : necessaryMaterializations) { 2955015192c6SRiver Riddle if (failed(legalizeUnresolvedMaterialization( 2956015192c6SRiver Riddle *mat, materializationOps, rewriter, rewriterImpl, *inverseMapping))) 2957015192c6SRiver Riddle return failure(); 2958b6eb26fdSRiver Riddle } 2959b6eb26fdSRiver Riddle return success(); 2960b6eb26fdSRiver Riddle } 2961b6eb26fdSRiver Riddle 2962b6eb26fdSRiver Riddle LogicalResult OperationConverter::legalizeErasedResult( 2963b6eb26fdSRiver Riddle Operation *op, OpResult result, 2964b6eb26fdSRiver Riddle ConversionPatternRewriterImpl &rewriterImpl) { 2965b6eb26fdSRiver Riddle // If the operation result was replaced with null, all of the uses of this 2966b6eb26fdSRiver Riddle // value should be replaced. 2967b6eb26fdSRiver Riddle auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) { 2968b6eb26fdSRiver Riddle return rewriterImpl.isOpIgnored(user); 2969b6eb26fdSRiver Riddle }); 2970b6eb26fdSRiver Riddle if (liveUserIt != result.user_end()) { 2971b6eb26fdSRiver Riddle InFlightDiagnostic diag = op->emitError("failed to legalize operation '") 2972b6eb26fdSRiver Riddle << op->getName() << "' marked as erased"; 2973b6eb26fdSRiver Riddle diag.attachNote(liveUserIt->getLoc()) 2974b6eb26fdSRiver Riddle << "found live user of result #" << result.getResultNumber() << ": " 2975b6eb26fdSRiver Riddle << *liveUserIt; 2976b6eb26fdSRiver Riddle return failure(); 2977b6eb26fdSRiver Riddle } 2978b6eb26fdSRiver Riddle return success(); 2979b6eb26fdSRiver Riddle } 2980b6eb26fdSRiver Riddle 29815b91060dSAlex Zinenko /// Finds a user of the given value, or of any other value that the given value 29825b91060dSAlex Zinenko /// replaced, that was not replaced in the conversion process. 2983015192c6SRiver Riddle static Operation *findLiveUserOfReplaced( 2984015192c6SRiver Riddle Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, 2985015192c6SRiver Riddle const DenseMap<Value, SmallVector<Value>> &inverseMapping) { 2986015192c6SRiver Riddle SmallVector<Value> worklist(1, initialValue); 2987015192c6SRiver Riddle while (!worklist.empty()) { 2988015192c6SRiver Riddle Value value = worklist.pop_back_val(); 2989015192c6SRiver Riddle 29905b91060dSAlex Zinenko // Walk the users of this value to see if there are any live users that 29915b91060dSAlex Zinenko // weren't replaced during conversion. 29925b91060dSAlex Zinenko auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) { 29935b91060dSAlex Zinenko return rewriterImpl.isOpIgnored(user); 29945b91060dSAlex Zinenko }); 29955b91060dSAlex Zinenko if (liveUserIt != value.user_end()) 29965b91060dSAlex Zinenko return *liveUserIt; 2997015192c6SRiver Riddle auto mapIt = inverseMapping.find(value); 2998015192c6SRiver Riddle if (mapIt != inverseMapping.end()) 2999015192c6SRiver Riddle worklist.append(mapIt->second); 3000015192c6SRiver Riddle } 30015b91060dSAlex Zinenko return nullptr; 30025b91060dSAlex Zinenko } 30035b91060dSAlex Zinenko 3004b6eb26fdSRiver Riddle LogicalResult OperationConverter::legalizeChangedResultType( 3005b6eb26fdSRiver Riddle Operation *op, OpResult result, Value newValue, 3006ce254598SMatthias Springer const TypeConverter *replConverter, ConversionPatternRewriter &rewriter, 30075b91060dSAlex Zinenko ConversionPatternRewriterImpl &rewriterImpl, 3008015192c6SRiver Riddle const DenseMap<Value, SmallVector<Value>> &inverseMapping) { 30095b91060dSAlex Zinenko Operation *liveUser = 30105b91060dSAlex Zinenko findLiveUserOfReplaced(result, rewriterImpl, inverseMapping); 30115b91060dSAlex Zinenko if (!liveUser) 3012b6eb26fdSRiver Riddle return success(); 3013b6eb26fdSRiver Riddle 3014015192c6SRiver Riddle // Functor used to emit a conversion error for a failed materialization. 3015015192c6SRiver Riddle auto emitConversionError = [&] { 3016b6eb26fdSRiver Riddle InFlightDiagnostic diag = op->emitError() 3017b6eb26fdSRiver Riddle << "failed to materialize conversion for result #" 3018b6eb26fdSRiver Riddle << result.getResultNumber() << " of operation '" 3019b6eb26fdSRiver Riddle << op->getName() 3020b6eb26fdSRiver Riddle << "' that remained live after conversion"; 30215b91060dSAlex Zinenko diag.attachNote(liveUser->getLoc()) 30225b91060dSAlex Zinenko << "see existing live user here: " << *liveUser; 3023b6eb26fdSRiver Riddle return failure(); 3024015192c6SRiver Riddle }; 3025b6eb26fdSRiver Riddle 3026015192c6SRiver Riddle // If the replacement has a type converter, attempt to materialize a 3027015192c6SRiver Riddle // conversion back to the original type. 3028015192c6SRiver Riddle if (!replConverter) 3029015192c6SRiver Riddle return emitConversionError(); 3030015192c6SRiver Riddle 3031015192c6SRiver Riddle // Materialize a conversion for this live result value. 3032015192c6SRiver Riddle Type resultType = result.getType(); 3033015192c6SRiver Riddle Value convertedValue = replConverter->materializeSourceConversion( 3034015192c6SRiver Riddle rewriter, op->getLoc(), resultType, newValue); 3035015192c6SRiver Riddle if (!convertedValue) 3036015192c6SRiver Riddle return emitConversionError(); 3037b6eb26fdSRiver Riddle 3038b6eb26fdSRiver Riddle rewriterImpl.mapping.map(result, convertedValue); 3039b6eb26fdSRiver Riddle return success(); 3040b6eb26fdSRiver Riddle } 3041b6eb26fdSRiver Riddle 3042b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 3043b6eb26fdSRiver Riddle // Type Conversion 3044b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 3045b6eb26fdSRiver Riddle 3046b6eb26fdSRiver Riddle void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo, 3047b6eb26fdSRiver Riddle ArrayRef<Type> types) { 3048b6eb26fdSRiver Riddle assert(!types.empty() && "expected valid types"); 3049b6eb26fdSRiver Riddle remapInput(origInputNo, /*newInputNo=*/argTypes.size(), types.size()); 3050b6eb26fdSRiver Riddle addInputs(types); 3051b6eb26fdSRiver Riddle } 3052b6eb26fdSRiver Riddle 3053b6eb26fdSRiver Riddle void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) { 3054b6eb26fdSRiver Riddle assert(!types.empty() && 3055b6eb26fdSRiver Riddle "1->0 type remappings don't need to be added explicitly"); 3056b6eb26fdSRiver Riddle argTypes.append(types.begin(), types.end()); 3057b6eb26fdSRiver Riddle } 3058b6eb26fdSRiver Riddle 3059b6eb26fdSRiver Riddle void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, 3060b6eb26fdSRiver Riddle unsigned newInputNo, 3061b6eb26fdSRiver Riddle unsigned newInputCount) { 3062b6eb26fdSRiver Riddle assert(!remappedInputs[origInputNo] && "input has already been remapped"); 3063b6eb26fdSRiver Riddle assert(newInputCount != 0 && "expected valid input count"); 3064b6eb26fdSRiver Riddle remappedInputs[origInputNo] = 3065b6eb26fdSRiver Riddle InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr}; 3066b6eb26fdSRiver Riddle } 3067b6eb26fdSRiver Riddle 3068b6eb26fdSRiver Riddle void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo, 3069b6eb26fdSRiver Riddle Value replacementValue) { 3070b6eb26fdSRiver Riddle assert(!remappedInputs[origInputNo] && "input has already been remapped"); 3071b6eb26fdSRiver Riddle remappedInputs[origInputNo] = 3072b6eb26fdSRiver Riddle InputMapping{origInputNo, /*size=*/0, replacementValue}; 3073b6eb26fdSRiver Riddle } 3074b6eb26fdSRiver Riddle 3075b6eb26fdSRiver Riddle LogicalResult TypeConverter::convertType(Type t, 30763dd58333SMatthias Springer SmallVectorImpl<Type> &results) const { 3077a8daefedSMehdi Amini { 3078a8daefedSMehdi Amini std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex, 3079a8daefedSMehdi Amini std::defer_lock); 3080a8daefedSMehdi Amini if (t.getContext()->isMultithreadingEnabled()) 3081a8daefedSMehdi Amini cacheReadLock.lock(); 3082b6eb26fdSRiver Riddle auto existingIt = cachedDirectConversions.find(t); 3083b6eb26fdSRiver Riddle if (existingIt != cachedDirectConversions.end()) { 3084b6eb26fdSRiver Riddle if (existingIt->second) 3085b6eb26fdSRiver Riddle results.push_back(existingIt->second); 3086b6eb26fdSRiver Riddle return success(existingIt->second != nullptr); 3087b6eb26fdSRiver Riddle } 3088b6eb26fdSRiver Riddle auto multiIt = cachedMultiConversions.find(t); 3089b6eb26fdSRiver Riddle if (multiIt != cachedMultiConversions.end()) { 3090b6eb26fdSRiver Riddle results.append(multiIt->second.begin(), multiIt->second.end()); 3091b6eb26fdSRiver Riddle return success(); 3092b6eb26fdSRiver Riddle } 3093a8daefedSMehdi Amini } 3094b6eb26fdSRiver Riddle // Walk the added converters in reverse order to apply the most recently 3095b6eb26fdSRiver Riddle // registered first. 3096b6eb26fdSRiver Riddle size_t currentCount = results.size(); 3097dc3dc974SMehdi Amini 3098a8daefedSMehdi Amini std::unique_lock<decltype(cacheMutex)> cacheWriteLock(cacheMutex, 3099a8daefedSMehdi Amini std::defer_lock); 3100a8daefedSMehdi Amini 31013dd58333SMatthias Springer for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) { 3102dc3dc974SMehdi Amini if (std::optional<LogicalResult> result = converter(t, results)) { 3103a8daefedSMehdi Amini if (t.getContext()->isMultithreadingEnabled()) 3104a8daefedSMehdi Amini cacheWriteLock.lock(); 3105b6eb26fdSRiver Riddle if (!succeeded(*result)) { 3106b6eb26fdSRiver Riddle cachedDirectConversions.try_emplace(t, nullptr); 3107b6eb26fdSRiver Riddle return failure(); 3108b6eb26fdSRiver Riddle } 3109b6eb26fdSRiver Riddle auto newTypes = ArrayRef<Type>(results).drop_front(currentCount); 3110b6eb26fdSRiver Riddle if (newTypes.size() == 1) 3111b6eb26fdSRiver Riddle cachedDirectConversions.try_emplace(t, newTypes.front()); 3112b6eb26fdSRiver Riddle else 3113b6eb26fdSRiver Riddle cachedMultiConversions.try_emplace(t, llvm::to_vector<2>(newTypes)); 3114b6eb26fdSRiver Riddle return success(); 3115b6eb26fdSRiver Riddle } 3116b6eb26fdSRiver Riddle } 3117b6eb26fdSRiver Riddle return failure(); 3118b6eb26fdSRiver Riddle } 3119b6eb26fdSRiver Riddle 31203dd58333SMatthias Springer Type TypeConverter::convertType(Type t) const { 3121b6eb26fdSRiver Riddle // Use the multi-type result version to convert the type. 3122b6eb26fdSRiver Riddle SmallVector<Type, 1> results; 3123b6eb26fdSRiver Riddle if (failed(convertType(t, results))) 3124b6eb26fdSRiver Riddle return nullptr; 3125b6eb26fdSRiver Riddle 3126b6eb26fdSRiver Riddle // Check to ensure that only one type was produced. 3127b6eb26fdSRiver Riddle return results.size() == 1 ? results.front() : nullptr; 3128b6eb26fdSRiver Riddle } 3129b6eb26fdSRiver Riddle 31303dd58333SMatthias Springer LogicalResult 31313dd58333SMatthias Springer TypeConverter::convertTypes(TypeRange types, 31323dd58333SMatthias Springer SmallVectorImpl<Type> &results) const { 31333dfa8614SRiver Riddle for (Type type : types) 3134b6eb26fdSRiver Riddle if (failed(convertType(type, results))) 3135b6eb26fdSRiver Riddle return failure(); 3136b6eb26fdSRiver Riddle return success(); 3137b6eb26fdSRiver Riddle } 3138b6eb26fdSRiver Riddle 31393dd58333SMatthias Springer bool TypeConverter::isLegal(Type type) const { 31403dd58333SMatthias Springer return convertType(type) == type; 31413dd58333SMatthias Springer } 31423dd58333SMatthias Springer bool TypeConverter::isLegal(Operation *op) const { 3143b6eb26fdSRiver Riddle return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes()); 3144b6eb26fdSRiver Riddle } 3145b6eb26fdSRiver Riddle 31463dd58333SMatthias Springer bool TypeConverter::isLegal(Region *region) const { 3147b6eb26fdSRiver Riddle return llvm::all_of(*region, [this](Block &block) { 3148b6eb26fdSRiver Riddle return isLegal(block.getArgumentTypes()); 3149b6eb26fdSRiver Riddle }); 3150b6eb26fdSRiver Riddle } 3151b6eb26fdSRiver Riddle 31523dd58333SMatthias Springer bool TypeConverter::isSignatureLegal(FunctionType ty) const { 3153b6eb26fdSRiver Riddle return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults())); 3154b6eb26fdSRiver Riddle } 3155b6eb26fdSRiver Riddle 31563dd58333SMatthias Springer LogicalResult 31573dd58333SMatthias Springer TypeConverter::convertSignatureArg(unsigned inputNo, Type type, 31583dd58333SMatthias Springer SignatureConversion &result) const { 3159b6eb26fdSRiver Riddle // Try to convert the given input type. 3160b6eb26fdSRiver Riddle SmallVector<Type, 1> convertedTypes; 3161b6eb26fdSRiver Riddle if (failed(convertType(type, convertedTypes))) 3162b6eb26fdSRiver Riddle return failure(); 3163b6eb26fdSRiver Riddle 3164b6eb26fdSRiver Riddle // If this argument is being dropped, there is nothing left to do. 3165b6eb26fdSRiver Riddle if (convertedTypes.empty()) 3166b6eb26fdSRiver Riddle return success(); 3167b6eb26fdSRiver Riddle 3168b6eb26fdSRiver Riddle // Otherwise, add the new inputs. 3169b6eb26fdSRiver Riddle result.addInputs(inputNo, convertedTypes); 3170b6eb26fdSRiver Riddle return success(); 3171b6eb26fdSRiver Riddle } 31723dd58333SMatthias Springer LogicalResult 31733dd58333SMatthias Springer TypeConverter::convertSignatureArgs(TypeRange types, 3174b6eb26fdSRiver Riddle SignatureConversion &result, 31753dd58333SMatthias Springer unsigned origInputOffset) const { 3176b6eb26fdSRiver Riddle for (unsigned i = 0, e = types.size(); i != e; ++i) 3177b6eb26fdSRiver Riddle if (failed(convertSignatureArg(origInputOffset + i, types[i], result))) 3178b6eb26fdSRiver Riddle return failure(); 3179b6eb26fdSRiver Riddle return success(); 3180b6eb26fdSRiver Riddle } 3181b6eb26fdSRiver Riddle 3182b6eb26fdSRiver Riddle Value TypeConverter::materializeConversion( 31833dd58333SMatthias Springer ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder, 31843dd58333SMatthias Springer Location loc, Type resultType, ValueRange inputs) const { 31853dd58333SMatthias Springer for (const MaterializationCallbackFn &fn : llvm::reverse(materializations)) 31860de16fafSRamkumar Ramachandra if (std::optional<Value> result = fn(builder, resultType, inputs, loc)) 31876d5fc1e3SKazu Hirata return *result; 3188b6eb26fdSRiver Riddle return nullptr; 3189b6eb26fdSRiver Riddle } 3190b6eb26fdSRiver Riddle 31913dd58333SMatthias Springer std::optional<TypeConverter::SignatureConversion> 31923dd58333SMatthias Springer TypeConverter::convertBlockSignature(Block *block) const { 3193b6eb26fdSRiver Riddle SignatureConversion conversion(block->getNumArguments()); 3194b6eb26fdSRiver Riddle if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion))) 31951a36588eSKazu Hirata return std::nullopt; 3196b6eb26fdSRiver Riddle return conversion; 3197b6eb26fdSRiver Riddle } 3198b6eb26fdSRiver Riddle 319901b55f16SRiver Riddle //===----------------------------------------------------------------------===// 3200499abb24SKrzysztof Drewniak // Type attribute conversion 3201499abb24SKrzysztof Drewniak //===----------------------------------------------------------------------===// 3202499abb24SKrzysztof Drewniak TypeConverter::AttributeConversionResult 3203499abb24SKrzysztof Drewniak TypeConverter::AttributeConversionResult::result(Attribute attr) { 3204499abb24SKrzysztof Drewniak return AttributeConversionResult(attr, resultTag); 3205499abb24SKrzysztof Drewniak } 3206499abb24SKrzysztof Drewniak 3207499abb24SKrzysztof Drewniak TypeConverter::AttributeConversionResult 3208499abb24SKrzysztof Drewniak TypeConverter::AttributeConversionResult::na() { 3209499abb24SKrzysztof Drewniak return AttributeConversionResult(nullptr, naTag); 3210499abb24SKrzysztof Drewniak } 3211499abb24SKrzysztof Drewniak 3212499abb24SKrzysztof Drewniak TypeConverter::AttributeConversionResult 3213499abb24SKrzysztof Drewniak TypeConverter::AttributeConversionResult::abort() { 3214499abb24SKrzysztof Drewniak return AttributeConversionResult(nullptr, abortTag); 3215499abb24SKrzysztof Drewniak } 3216499abb24SKrzysztof Drewniak 3217499abb24SKrzysztof Drewniak bool TypeConverter::AttributeConversionResult::hasResult() const { 3218499abb24SKrzysztof Drewniak return impl.getInt() == resultTag; 3219499abb24SKrzysztof Drewniak } 3220499abb24SKrzysztof Drewniak 3221499abb24SKrzysztof Drewniak bool TypeConverter::AttributeConversionResult::isNa() const { 3222499abb24SKrzysztof Drewniak return impl.getInt() == naTag; 3223499abb24SKrzysztof Drewniak } 3224499abb24SKrzysztof Drewniak 3225499abb24SKrzysztof Drewniak bool TypeConverter::AttributeConversionResult::isAbort() const { 3226499abb24SKrzysztof Drewniak return impl.getInt() == abortTag; 3227499abb24SKrzysztof Drewniak } 3228499abb24SKrzysztof Drewniak 3229499abb24SKrzysztof Drewniak Attribute TypeConverter::AttributeConversionResult::getResult() const { 3230499abb24SKrzysztof Drewniak assert(hasResult() && "Cannot get result from N/A or abort"); 3231499abb24SKrzysztof Drewniak return impl.getPointer(); 3232499abb24SKrzysztof Drewniak } 3233499abb24SKrzysztof Drewniak 32343dd58333SMatthias Springer std::optional<Attribute> 32353dd58333SMatthias Springer TypeConverter::convertTypeAttribute(Type type, Attribute attr) const { 32363dd58333SMatthias Springer for (const TypeAttributeConversionCallbackFn &fn : 3237499abb24SKrzysztof Drewniak llvm::reverse(typeAttributeConversions)) { 3238499abb24SKrzysztof Drewniak AttributeConversionResult res = fn(type, attr); 3239499abb24SKrzysztof Drewniak if (res.hasResult()) 3240499abb24SKrzysztof Drewniak return res.getResult(); 3241499abb24SKrzysztof Drewniak if (res.isAbort()) 3242499abb24SKrzysztof Drewniak return std::nullopt; 3243499abb24SKrzysztof Drewniak } 3244499abb24SKrzysztof Drewniak return std::nullopt; 3245499abb24SKrzysztof Drewniak } 3246499abb24SKrzysztof Drewniak 3247499abb24SKrzysztof Drewniak //===----------------------------------------------------------------------===// 32487ceffae1SRiver Riddle // FunctionOpInterfaceSignatureConversion 324901b55f16SRiver Riddle //===----------------------------------------------------------------------===// 325001b55f16SRiver Riddle 3251ed4749f9SIvan Butygin static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, 3252ce254598SMatthias Springer const TypeConverter &typeConverter, 3253ed4749f9SIvan Butygin ConversionPatternRewriter &rewriter) { 3254e5f8cdd6SKai Sasaki FunctionType type = dyn_cast<FunctionType>(funcOp.getFunctionType()); 3255e5f8cdd6SKai Sasaki if (!type) 3256e5f8cdd6SKai Sasaki return failure(); 3257ed4749f9SIvan Butygin 3258ed4749f9SIvan Butygin // Convert the original function types. 3259ed4749f9SIvan Butygin TypeConverter::SignatureConversion result(type.getNumInputs()); 3260ed4749f9SIvan Butygin SmallVector<Type, 1> newResults; 3261ed4749f9SIvan Butygin if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) || 3262ed4749f9SIvan Butygin failed(typeConverter.convertTypes(type.getResults(), newResults)) || 3263ed4749f9SIvan Butygin failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(), 3264ed4749f9SIvan Butygin typeConverter, &result))) 3265ed4749f9SIvan Butygin return failure(); 3266ed4749f9SIvan Butygin 3267ed4749f9SIvan Butygin // Update the function signature in-place. 3268ed4749f9SIvan Butygin auto newType = FunctionType::get(rewriter.getContext(), 3269ed4749f9SIvan Butygin result.getConvertedTypes(), newResults); 3270ed4749f9SIvan Butygin 32715fcf907bSMatthias Springer rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); }); 3272ed4749f9SIvan Butygin 3273ed4749f9SIvan Butygin return success(); 3274ed4749f9SIvan Butygin } 3275ed4749f9SIvan Butygin 3276b6eb26fdSRiver Riddle /// Create a default conversion pattern that rewrites the type signature of a 32777ceffae1SRiver Riddle /// FunctionOpInterface op. This only supports ops which use FunctionType to 32787ceffae1SRiver Riddle /// represent their type. 3279b6eb26fdSRiver Riddle namespace { 32807ceffae1SRiver Riddle struct FunctionOpInterfaceSignatureConversion : public ConversionPattern { 32817ceffae1SRiver Riddle FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName, 32827ceffae1SRiver Riddle MLIRContext *ctx, 3283ce254598SMatthias Springer const TypeConverter &converter) 328476f3c2f3SRiver Riddle : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {} 3285b6eb26fdSRiver Riddle 3286b6eb26fdSRiver Riddle LogicalResult 3287ed4749f9SIvan Butygin matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/, 3288b6eb26fdSRiver Riddle ConversionPatternRewriter &rewriter) const override { 32897ceffae1SRiver Riddle FunctionOpInterface funcOp = cast<FunctionOpInterface>(op); 3290ed4749f9SIvan Butygin return convertFuncOpTypes(funcOp, *typeConverter, rewriter); 3291ed4749f9SIvan Butygin } 3292ed4749f9SIvan Butygin }; 3293b6eb26fdSRiver Riddle 3294ed4749f9SIvan Butygin struct AnyFunctionOpInterfaceSignatureConversion 3295ed4749f9SIvan Butygin : public OpInterfaceConversionPattern<FunctionOpInterface> { 3296ed4749f9SIvan Butygin using OpInterfaceConversionPattern::OpInterfaceConversionPattern; 3297b6eb26fdSRiver Riddle 3298ed4749f9SIvan Butygin LogicalResult 3299ed4749f9SIvan Butygin matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/, 3300ed4749f9SIvan Butygin ConversionPatternRewriter &rewriter) const override { 3301ed4749f9SIvan Butygin return convertFuncOpTypes(funcOp, *typeConverter, rewriter); 3302b6eb26fdSRiver Riddle } 3303b6eb26fdSRiver Riddle }; 3304be0a7e9fSMehdi Amini } // namespace 3305b6eb26fdSRiver Riddle 330635ef3994SIvan Butygin FailureOr<Operation *> 330735ef3994SIvan Butygin mlir::convertOpResultTypes(Operation *op, ValueRange operands, 330835ef3994SIvan Butygin const TypeConverter &converter, 330935ef3994SIvan Butygin ConversionPatternRewriter &rewriter) { 331035ef3994SIvan Butygin assert(op && "Invalid op"); 331135ef3994SIvan Butygin Location loc = op->getLoc(); 331235ef3994SIvan Butygin if (converter.isLegal(op)) 331335ef3994SIvan Butygin return rewriter.notifyMatchFailure(loc, "op already legal"); 331435ef3994SIvan Butygin 331535ef3994SIvan Butygin OperationState newOp(loc, op->getName()); 331635ef3994SIvan Butygin newOp.addOperands(operands); 331735ef3994SIvan Butygin 331835ef3994SIvan Butygin SmallVector<Type> newResultTypes; 331935ef3994SIvan Butygin if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes))) 332035ef3994SIvan Butygin return rewriter.notifyMatchFailure(loc, "couldn't convert return types"); 332135ef3994SIvan Butygin 332235ef3994SIvan Butygin newOp.addTypes(newResultTypes); 332335ef3994SIvan Butygin newOp.addAttributes(op->getAttrs()); 332435ef3994SIvan Butygin return rewriter.create(newOp); 332535ef3994SIvan Butygin } 332635ef3994SIvan Butygin 33277ceffae1SRiver Riddle void mlir::populateFunctionOpInterfaceTypeConversionPattern( 3328dc4e913bSChris Lattner StringRef functionLikeOpName, RewritePatternSet &patterns, 3329ce254598SMatthias Springer const TypeConverter &converter) { 33307ceffae1SRiver Riddle patterns.add<FunctionOpInterfaceSignatureConversion>( 33313a506b31SChris Lattner functionLikeOpName, patterns.getContext(), converter); 33320a7a1ac7Smikeurbach } 33330a7a1ac7Smikeurbach 3334ed4749f9SIvan Butygin void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern( 3335ce254598SMatthias Springer RewritePatternSet &patterns, const TypeConverter &converter) { 3336ed4749f9SIvan Butygin patterns.add<AnyFunctionOpInterfaceSignatureConversion>( 3337ed4749f9SIvan Butygin converter, patterns.getContext()); 3338ed4749f9SIvan Butygin } 3339ed4749f9SIvan Butygin 3340b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 3341b6eb26fdSRiver Riddle // ConversionTarget 3342b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 3343b6eb26fdSRiver Riddle 3344b6eb26fdSRiver Riddle void ConversionTarget::setOpAction(OperationName op, 3345b6eb26fdSRiver Riddle LegalizationAction action) { 3346c6828e0cSCaitlyn Cano legalOperations[op].action = action; 3347b6eb26fdSRiver Riddle } 3348b6eb26fdSRiver Riddle 3349b6eb26fdSRiver Riddle void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames, 3350b6eb26fdSRiver Riddle LegalizationAction action) { 3351b6eb26fdSRiver Riddle for (StringRef dialect : dialectNames) 3352b6eb26fdSRiver Riddle legalDialects[dialect] = action; 3353b6eb26fdSRiver Riddle } 3354b6eb26fdSRiver Riddle 3355b6eb26fdSRiver Riddle auto ConversionTarget::getOpAction(OperationName op) const 33560de16fafSRamkumar Ramachandra -> std::optional<LegalizationAction> { 33570de16fafSRamkumar Ramachandra std::optional<LegalizationInfo> info = getOpInfo(op); 33580de16fafSRamkumar Ramachandra return info ? info->action : std::optional<LegalizationAction>(); 3359b6eb26fdSRiver Riddle } 3360b6eb26fdSRiver Riddle 3361b6eb26fdSRiver Riddle auto ConversionTarget::isLegal(Operation *op) const 33620de16fafSRamkumar Ramachandra -> std::optional<LegalOpDetails> { 33630de16fafSRamkumar Ramachandra std::optional<LegalizationInfo> info = getOpInfo(op->getName()); 3364b6eb26fdSRiver Riddle if (!info) 33651a36588eSKazu Hirata return std::nullopt; 3366b6eb26fdSRiver Riddle 3367b6eb26fdSRiver Riddle // Returns true if this operation instance is known to be legal. 3368b6eb26fdSRiver Riddle auto isOpLegal = [&] { 33691c9c2c91SBenjamin Kramer // Handle dynamic legality either with the provided legality function. 3370c6828e0cSCaitlyn Cano if (info->action == LegalizationAction::Dynamic) { 33710de16fafSRamkumar Ramachandra std::optional<bool> result = info->legalityFn(op); 3372c6828e0cSCaitlyn Cano if (result) 3373c6828e0cSCaitlyn Cano return *result; 3374c6828e0cSCaitlyn Cano } 3375b6eb26fdSRiver Riddle 3376b6eb26fdSRiver Riddle // Otherwise, the operation is only legal if it was marked 'Legal'. 3377b6eb26fdSRiver Riddle return info->action == LegalizationAction::Legal; 3378b6eb26fdSRiver Riddle }; 3379b6eb26fdSRiver Riddle if (!isOpLegal()) 33801a36588eSKazu Hirata return std::nullopt; 3381b6eb26fdSRiver Riddle 3382b6eb26fdSRiver Riddle // This operation is legal, compute any additional legality information. 3383b6eb26fdSRiver Riddle LegalOpDetails legalityDetails; 3384b6eb26fdSRiver Riddle if (info->isRecursivelyLegal) { 3385b6eb26fdSRiver Riddle auto legalityFnIt = opRecursiveLegalityFns.find(op->getName()); 3386c6828e0cSCaitlyn Cano if (legalityFnIt != opRecursiveLegalityFns.end()) { 3387c6828e0cSCaitlyn Cano legalityDetails.isRecursivelyLegal = 338830c67587SKazu Hirata legalityFnIt->second(op).value_or(true); 3389c6828e0cSCaitlyn Cano } else { 3390b6eb26fdSRiver Riddle legalityDetails.isRecursivelyLegal = true; 3391b6eb26fdSRiver Riddle } 3392c6828e0cSCaitlyn Cano } 3393b6eb26fdSRiver Riddle return legalityDetails; 3394b6eb26fdSRiver Riddle } 3395b6eb26fdSRiver Riddle 33962a3878eaSButygin bool ConversionTarget::isIllegal(Operation *op) const { 33970de16fafSRamkumar Ramachandra std::optional<LegalizationInfo> info = getOpInfo(op->getName()); 33982a3878eaSButygin if (!info) 33992a3878eaSButygin return false; 34002a3878eaSButygin 34012a3878eaSButygin if (info->action == LegalizationAction::Dynamic) { 34020de16fafSRamkumar Ramachandra std::optional<bool> result = info->legalityFn(op); 34032a3878eaSButygin if (!result) 34042a3878eaSButygin return false; 34052a3878eaSButygin 34062a3878eaSButygin return !(*result); 34072a3878eaSButygin } 34082a3878eaSButygin 34092a3878eaSButygin return info->action == LegalizationAction::Illegal; 34102a3878eaSButygin } 34112a3878eaSButygin 3412c6828e0cSCaitlyn Cano static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks( 3413c6828e0cSCaitlyn Cano ConversionTarget::DynamicLegalityCallbackFn oldCallback, 3414c6828e0cSCaitlyn Cano ConversionTarget::DynamicLegalityCallbackFn newCallback) { 3415c6828e0cSCaitlyn Cano if (!oldCallback) 3416c6828e0cSCaitlyn Cano return newCallback; 3417c6828e0cSCaitlyn Cano 3418c6828e0cSCaitlyn Cano auto chain = [oldCl = std::move(oldCallback), newCl = std::move(newCallback)]( 34190de16fafSRamkumar Ramachandra Operation *op) -> std::optional<bool> { 34200de16fafSRamkumar Ramachandra if (std::optional<bool> result = newCl(op)) 3421c6828e0cSCaitlyn Cano return *result; 3422c6828e0cSCaitlyn Cano 3423c6828e0cSCaitlyn Cano return oldCl(op); 3424c6828e0cSCaitlyn Cano }; 3425c6828e0cSCaitlyn Cano return chain; 3426c6828e0cSCaitlyn Cano } 3427c6828e0cSCaitlyn Cano 3428b6eb26fdSRiver Riddle void ConversionTarget::setLegalityCallback( 3429b6eb26fdSRiver Riddle OperationName name, const DynamicLegalityCallbackFn &callback) { 3430b6eb26fdSRiver Riddle assert(callback && "expected valid legality callback"); 34317dad59f0SMehdi Amini auto *infoIt = legalOperations.find(name); 3432b6eb26fdSRiver Riddle assert(infoIt != legalOperations.end() && 3433b6eb26fdSRiver Riddle infoIt->second.action == LegalizationAction::Dynamic && 3434b6eb26fdSRiver Riddle "expected operation to already be marked as dynamically legal"); 3435c6828e0cSCaitlyn Cano infoIt->second.legalityFn = 3436c6828e0cSCaitlyn Cano composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback); 3437b6eb26fdSRiver Riddle } 3438b6eb26fdSRiver Riddle 3439b6eb26fdSRiver Riddle void ConversionTarget::markOpRecursivelyLegal( 3440b6eb26fdSRiver Riddle OperationName name, const DynamicLegalityCallbackFn &callback) { 34417dad59f0SMehdi Amini auto *infoIt = legalOperations.find(name); 3442b6eb26fdSRiver Riddle assert(infoIt != legalOperations.end() && 3443b6eb26fdSRiver Riddle infoIt->second.action != LegalizationAction::Illegal && 3444b6eb26fdSRiver Riddle "expected operation to already be marked as legal"); 3445b6eb26fdSRiver Riddle infoIt->second.isRecursivelyLegal = true; 3446b6eb26fdSRiver Riddle if (callback) 3447c6828e0cSCaitlyn Cano opRecursiveLegalityFns[name] = composeLegalityCallbacks( 3448c6828e0cSCaitlyn Cano std::move(opRecursiveLegalityFns[name]), callback); 3449b6eb26fdSRiver Riddle else 3450b6eb26fdSRiver Riddle opRecursiveLegalityFns.erase(name); 3451b6eb26fdSRiver Riddle } 3452b6eb26fdSRiver Riddle 3453b6eb26fdSRiver Riddle void ConversionTarget::setLegalityCallback( 3454b6eb26fdSRiver Riddle ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) { 3455b6eb26fdSRiver Riddle assert(callback && "expected valid legality callback"); 3456b6eb26fdSRiver Riddle for (StringRef dialect : dialects) 3457c6828e0cSCaitlyn Cano dialectLegalityFns[dialect] = composeLegalityCallbacks( 3458c6828e0cSCaitlyn Cano std::move(dialectLegalityFns[dialect]), callback); 3459b6eb26fdSRiver Riddle } 3460b6eb26fdSRiver Riddle 3461b7a46498SButygin void ConversionTarget::setLegalityCallback( 3462b7a46498SButygin const DynamicLegalityCallbackFn &callback) { 3463b7a46498SButygin assert(callback && "expected valid legality callback"); 3464c6828e0cSCaitlyn Cano unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback); 3465b7a46498SButygin } 3466b7a46498SButygin 3467b6eb26fdSRiver Riddle auto ConversionTarget::getOpInfo(OperationName op) const 34680de16fafSRamkumar Ramachandra -> std::optional<LegalizationInfo> { 3469b6eb26fdSRiver Riddle // Check for info for this specific operation. 34707dad59f0SMehdi Amini const auto *it = legalOperations.find(op); 3471b6eb26fdSRiver Riddle if (it != legalOperations.end()) 3472b6eb26fdSRiver Riddle return it->second; 3473b6eb26fdSRiver Riddle // Check for info for the parent dialect. 3474e6260ad0SRiver Riddle auto dialectIt = legalDialects.find(op.getDialectNamespace()); 3475b6eb26fdSRiver Riddle if (dialectIt != legalDialects.end()) { 3476b7a46498SButygin DynamicLegalityCallbackFn callback; 3477e6260ad0SRiver Riddle auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace()); 3478b6eb26fdSRiver Riddle if (dialectFn != dialectLegalityFns.end()) 3479b6eb26fdSRiver Riddle callback = dialectFn->second; 3480b6eb26fdSRiver Riddle return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false, 3481b6eb26fdSRiver Riddle callback}; 3482b6eb26fdSRiver Riddle } 3483b6eb26fdSRiver Riddle // Otherwise, check if we mark unknown operations as dynamic. 3484b7a46498SButygin if (unknownLegalityFn) 3485b6eb26fdSRiver Riddle return LegalizationInfo{LegalizationAction::Dynamic, 3486b6eb26fdSRiver Riddle /*isRecursivelyLegal=*/false, unknownLegalityFn}; 34871a36588eSKazu Hirata return std::nullopt; 3488b6eb26fdSRiver Riddle } 3489b6eb26fdSRiver Riddle 34906ae7f66fSJacques Pienaar #if MLIR_ENABLE_PDL_IN_PATTERNMATCH 3491b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 34928c66344eSRiver Riddle // PDL Configuration 34938c66344eSRiver Riddle //===----------------------------------------------------------------------===// 34948c66344eSRiver Riddle 34958c66344eSRiver Riddle void PDLConversionConfig::notifyRewriteBegin(PatternRewriter &rewriter) { 34968c66344eSRiver Riddle auto &rewriterImpl = 34978c66344eSRiver Riddle static_cast<ConversionPatternRewriter &>(rewriter).getImpl(); 34988c66344eSRiver Riddle rewriterImpl.currentTypeConverter = getTypeConverter(); 34998c66344eSRiver Riddle } 35008c66344eSRiver Riddle 35018c66344eSRiver Riddle void PDLConversionConfig::notifyRewriteEnd(PatternRewriter &rewriter) { 35028c66344eSRiver Riddle auto &rewriterImpl = 35038c66344eSRiver Riddle static_cast<ConversionPatternRewriter &>(rewriter).getImpl(); 35048c66344eSRiver Riddle rewriterImpl.currentTypeConverter = nullptr; 35058c66344eSRiver Riddle } 35068c66344eSRiver Riddle 35078c66344eSRiver Riddle /// Remap the given value using the rewriter and the type converter in the 35088c66344eSRiver Riddle /// provided config. 35098c66344eSRiver Riddle static FailureOr<SmallVector<Value>> 35108c66344eSRiver Riddle pdllConvertValues(ConversionPatternRewriter &rewriter, ValueRange values) { 35118c66344eSRiver Riddle SmallVector<Value> mappedValues; 35128c66344eSRiver Riddle if (failed(rewriter.getRemappedValues(values, mappedValues))) 35138c66344eSRiver Riddle return failure(); 35148c66344eSRiver Riddle return std::move(mappedValues); 35158c66344eSRiver Riddle } 35168c66344eSRiver Riddle 35178c66344eSRiver Riddle void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) { 35188c66344eSRiver Riddle patterns.getPDLPatterns().registerRewriteFunction( 35198c66344eSRiver Riddle "convertValue", 35208c66344eSRiver Riddle [](PatternRewriter &rewriter, Value value) -> FailureOr<Value> { 35218c66344eSRiver Riddle auto results = pdllConvertValues( 35228c66344eSRiver Riddle static_cast<ConversionPatternRewriter &>(rewriter), value); 35238c66344eSRiver Riddle if (failed(results)) 35248c66344eSRiver Riddle return failure(); 35258c66344eSRiver Riddle return results->front(); 35268c66344eSRiver Riddle }); 35278c66344eSRiver Riddle patterns.getPDLPatterns().registerRewriteFunction( 35288c66344eSRiver Riddle "convertValues", [](PatternRewriter &rewriter, ValueRange values) { 35298c66344eSRiver Riddle return pdllConvertValues( 35308c66344eSRiver Riddle static_cast<ConversionPatternRewriter &>(rewriter), values); 35318c66344eSRiver Riddle }); 35328c66344eSRiver Riddle patterns.getPDLPatterns().registerRewriteFunction( 35338c66344eSRiver Riddle "convertType", 35348c66344eSRiver Riddle [](PatternRewriter &rewriter, Type type) -> FailureOr<Type> { 35358c66344eSRiver Riddle auto &rewriterImpl = 35368c66344eSRiver Riddle static_cast<ConversionPatternRewriter &>(rewriter).getImpl(); 3537ce254598SMatthias Springer if (const TypeConverter *converter = 3538ce254598SMatthias Springer rewriterImpl.currentTypeConverter) { 35398c66344eSRiver Riddle if (Type newType = converter->convertType(type)) 35408c66344eSRiver Riddle return newType; 35418c66344eSRiver Riddle return failure(); 35428c66344eSRiver Riddle } 35438c66344eSRiver Riddle return type; 35448c66344eSRiver Riddle }); 35458c66344eSRiver Riddle patterns.getPDLPatterns().registerRewriteFunction( 35468c66344eSRiver Riddle "convertTypes", 35478c66344eSRiver Riddle [](PatternRewriter &rewriter, 35488c66344eSRiver Riddle TypeRange types) -> FailureOr<SmallVector<Type>> { 35498c66344eSRiver Riddle auto &rewriterImpl = 35508c66344eSRiver Riddle static_cast<ConversionPatternRewriter &>(rewriter).getImpl(); 3551ce254598SMatthias Springer const TypeConverter *converter = rewriterImpl.currentTypeConverter; 35528c66344eSRiver Riddle if (!converter) 35538c66344eSRiver Riddle return SmallVector<Type>(types); 35548c66344eSRiver Riddle 35558c66344eSRiver Riddle SmallVector<Type> remappedTypes; 35568c66344eSRiver Riddle if (failed(converter->convertTypes(types, remappedTypes))) 35578c66344eSRiver Riddle return failure(); 35588c66344eSRiver Riddle return std::move(remappedTypes); 35598c66344eSRiver Riddle }); 35608c66344eSRiver Riddle } 35616ae7f66fSJacques Pienaar #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH 35628c66344eSRiver Riddle 35638c66344eSRiver Riddle //===----------------------------------------------------------------------===// 3564b6eb26fdSRiver Riddle // Op Conversion Entry Points 3565b6eb26fdSRiver Riddle //===----------------------------------------------------------------------===// 3566b6eb26fdSRiver Riddle 356701b55f16SRiver Riddle //===----------------------------------------------------------------------===// 356801b55f16SRiver Riddle // Partial Conversion 356901b55f16SRiver Riddle 3570b6eb26fdSRiver Riddle LogicalResult 3571b6eb26fdSRiver Riddle mlir::applyPartialConversion(ArrayRef<Operation *> ops, 3572370a6f09SMehdi Amini const ConversionTarget &target, 357379d7f618SChris Lattner const FrozenRewritePatternSet &patterns, 3574b6eb26fdSRiver Riddle DenseSet<Operation *> *unconvertedOps) { 3575b6eb26fdSRiver Riddle OperationConverter opConverter(target, patterns, OpConversionMode::Partial, 3576b6eb26fdSRiver Riddle unconvertedOps); 3577b6eb26fdSRiver Riddle return opConverter.convertOperations(ops); 3578b6eb26fdSRiver Riddle } 3579b6eb26fdSRiver Riddle LogicalResult 3580370a6f09SMehdi Amini mlir::applyPartialConversion(Operation *op, const ConversionTarget &target, 358179d7f618SChris Lattner const FrozenRewritePatternSet &patterns, 3582b6eb26fdSRiver Riddle DenseSet<Operation *> *unconvertedOps) { 3583984b800aSserge-sans-paille return applyPartialConversion(llvm::ArrayRef(op), target, patterns, 3584b6eb26fdSRiver Riddle unconvertedOps); 3585b6eb26fdSRiver Riddle } 3586b6eb26fdSRiver Riddle 358701b55f16SRiver Riddle //===----------------------------------------------------------------------===// 358801b55f16SRiver Riddle // Full Conversion 358901b55f16SRiver Riddle 3590b6eb26fdSRiver Riddle LogicalResult 3591e214f004SMatthias Springer mlir::applyFullConversion(ArrayRef<Operation *> ops, 3592e214f004SMatthias Springer const ConversionTarget &target, 359379d7f618SChris Lattner const FrozenRewritePatternSet &patterns) { 3594b6eb26fdSRiver Riddle OperationConverter opConverter(target, patterns, OpConversionMode::Full); 3595b6eb26fdSRiver Riddle return opConverter.convertOperations(ops); 3596b6eb26fdSRiver Riddle } 3597b6eb26fdSRiver Riddle LogicalResult 3598370a6f09SMehdi Amini mlir::applyFullConversion(Operation *op, const ConversionTarget &target, 359979d7f618SChris Lattner const FrozenRewritePatternSet &patterns) { 3600984b800aSserge-sans-paille return applyFullConversion(llvm::ArrayRef(op), target, patterns); 3601b6eb26fdSRiver Riddle } 3602b6eb26fdSRiver Riddle 360301b55f16SRiver Riddle //===----------------------------------------------------------------------===// 360401b55f16SRiver Riddle // Analysis Conversion 360501b55f16SRiver Riddle 3606b6eb26fdSRiver Riddle LogicalResult 3607b6eb26fdSRiver Riddle mlir::applyAnalysisConversion(ArrayRef<Operation *> ops, 3608b6eb26fdSRiver Riddle ConversionTarget &target, 360979d7f618SChris Lattner const FrozenRewritePatternSet &patterns, 3610b8c6b152SChia-hung Duan DenseSet<Operation *> &convertedOps, 3611b8c6b152SChia-hung Duan function_ref<void(Diagnostic &)> notifyCallback) { 3612b6eb26fdSRiver Riddle OperationConverter opConverter(target, patterns, OpConversionMode::Analysis, 3613b6eb26fdSRiver Riddle &convertedOps); 3614b8c6b152SChia-hung Duan return opConverter.convertOperations(ops, notifyCallback); 3615b6eb26fdSRiver Riddle } 3616b6eb26fdSRiver Riddle LogicalResult 3617b6eb26fdSRiver Riddle mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target, 361879d7f618SChris Lattner const FrozenRewritePatternSet &patterns, 3619b8c6b152SChia-hung Duan DenseSet<Operation *> &convertedOps, 3620b8c6b152SChia-hung Duan function_ref<void(Diagnostic &)> notifyCallback) { 3621984b800aSserge-sans-paille return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, 3622b8c6b152SChia-hung Duan convertedOps, notifyCallback); 3623b6eb26fdSRiver Riddle } 3624