1 //===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file declares a generic pass for converting between MLIR dialects. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_ 14 #define MLIR_TRANSFORMS_DIALECTCONVERSION_H_ 15 16 #include "mlir/Config/mlir-config.h" 17 #include "mlir/Rewrite/FrozenRewritePatternSet.h" 18 #include "llvm/ADT/MapVector.h" 19 #include "llvm/ADT/StringMap.h" 20 #include <type_traits> 21 22 namespace mlir { 23 24 // Forward declarations. 25 class Attribute; 26 class Block; 27 struct ConversionConfig; 28 class ConversionPatternRewriter; 29 class MLIRContext; 30 class Operation; 31 struct OperationConverter; 32 class Type; 33 class Value; 34 35 //===----------------------------------------------------------------------===// 36 // Type Conversion 37 //===----------------------------------------------------------------------===// 38 39 /// Type conversion class. Specific conversions and materializations can be 40 /// registered using addConversion and addMaterialization, respectively. 
41 class TypeConverter { 42 public: 43 virtual ~TypeConverter() = default; 44 TypeConverter() = default; 45 // Copy the registered conversions, but not the caches 46 TypeConverter(const TypeConverter &other) 47 : conversions(other.conversions), 48 argumentMaterializations(other.argumentMaterializations), 49 sourceMaterializations(other.sourceMaterializations), 50 targetMaterializations(other.targetMaterializations), 51 typeAttributeConversions(other.typeAttributeConversions) {} 52 TypeConverter &operator=(const TypeConverter &other) { 53 conversions = other.conversions; 54 argumentMaterializations = other.argumentMaterializations; 55 sourceMaterializations = other.sourceMaterializations; 56 targetMaterializations = other.targetMaterializations; 57 typeAttributeConversions = other.typeAttributeConversions; 58 return *this; 59 } 60 61 /// This class provides all of the information necessary to convert a type 62 /// signature. 63 class SignatureConversion { 64 public: 65 SignatureConversion(unsigned numOrigInputs) 66 : remappedInputs(numOrigInputs) {} 67 68 /// This struct represents a range of new types or a single value that 69 /// remaps an existing signature input. 70 struct InputMapping { 71 size_t inputNo, size; 72 Value replacementValue; 73 }; 74 75 /// Return the argument types for the new signature. 76 ArrayRef<Type> getConvertedTypes() const { return argTypes; } 77 78 /// Get the input mapping for the given argument. 79 std::optional<InputMapping> getInputMapping(unsigned input) const { 80 return remappedInputs[input]; 81 } 82 83 //===------------------------------------------------------------------===// 84 // Conversion Hooks 85 //===------------------------------------------------------------------===// 86 87 /// Remap an input of the original signature with a new set of types. The 88 /// new types are appended to the new signature conversion. 
89 void addInputs(unsigned origInputNo, ArrayRef<Type> types); 90 91 /// Append new input types to the signature conversion, this should only be 92 /// used if the new types are not intended to remap an existing input. 93 void addInputs(ArrayRef<Type> types); 94 95 /// Remap an input of the original signature to another `replacement` 96 /// value. This drops the original argument. 97 void remapInput(unsigned origInputNo, Value replacement); 98 99 private: 100 /// Remap an input of the original signature with a range of types in the 101 /// new signature. 102 void remapInput(unsigned origInputNo, unsigned newInputNo, 103 unsigned newInputCount = 1); 104 105 /// The remapping information for each of the original arguments. 106 SmallVector<std::optional<InputMapping>, 4> remappedInputs; 107 108 /// The set of new argument types. 109 SmallVector<Type, 4> argTypes; 110 }; 111 112 /// The general result of a type attribute conversion callback, allowing 113 /// for early termination. The default constructor creates the na case. 114 class AttributeConversionResult { 115 public: 116 constexpr AttributeConversionResult() : impl() {} 117 AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {} 118 119 static AttributeConversionResult result(Attribute attr); 120 static AttributeConversionResult na(); 121 static AttributeConversionResult abort(); 122 123 bool hasResult() const; 124 bool isNa() const; 125 bool isAbort() const; 126 127 Attribute getResult() const; 128 129 private: 130 AttributeConversionResult(Attribute attr, unsigned tag) : impl(attr, tag) {} 131 132 llvm::PointerIntPair<Attribute, 2> impl; 133 // Note that na is 0 so that we can use PointerIntPair's default 134 // constructor. 135 static constexpr unsigned naTag = 0; 136 static constexpr unsigned resultTag = 1; 137 static constexpr unsigned abortTag = 2; 138 }; 139 140 /// Register a conversion function. 
A conversion function must be convertible 141 /// to any of the following forms (where `T` is a class derived from `Type`): 142 /// 143 /// * std::optional<Type>(T) 144 /// - This form represents a 1-1 type conversion. It should return nullptr 145 /// or `std::nullopt` to signify failure. If `std::nullopt` is returned, 146 /// the converter is allowed to try another conversion function to 147 /// perform the conversion. 148 /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &) 149 /// - This form represents a 1-N type conversion. It should return 150 /// `failure` or `std::nullopt` to signify a failed conversion. If the 151 /// new set of types is empty, the type is removed and any usages of the 152 /// existing value are expected to be removed during conversion. If 153 /// `std::nullopt` is returned, the converter is allowed to try another 154 /// conversion function to perform the conversion. 155 /// 156 /// Note: When attempting to convert a type, e.g. via 'convertType', the 157 /// mostly recently added conversions will be invoked first. 158 template <typename FnT, typename T = typename llvm::function_traits< 159 std::decay_t<FnT>>::template arg_t<0>> 160 void addConversion(FnT &&callback) { 161 registerConversion(wrapCallback<T>(std::forward<FnT>(callback))); 162 } 163 164 /// All of the following materializations require function objects that are 165 /// convertible to the following form: 166 /// `Value(OpBuilder &, T, ValueRange, Location)`, 167 /// where `T` is any subclass of `Type`. This function is responsible for 168 /// creating an operation, using the OpBuilder and Location provided, that 169 /// "casts" a range of values into a single value of the given type `T`. It 170 /// must return a Value of the type `T` on success and `nullptr` if 171 /// it failed but other materialization should be attempted. Materialization 172 /// functions must be provided when a type conversion may persist after the 173 /// conversion has finished. 
174 /// 175 /// Note: Target materializations may optionally accept an additional Type 176 /// parameter, which is the original type of the SSA value. Furthermore, `T` 177 /// can be a TypeRange; in that case, the function must return a 178 /// SmallVector<Value>. 179 180 /// This method registers a materialization that will be called when 181 /// converting (potentially multiple) block arguments that were the result of 182 /// a signature conversion of a single block argument, to a single SSA value 183 /// with the old block argument type. 184 /// 185 /// Note: Argument materializations are used only with the 1:N dialect 186 /// conversion driver. The 1:N dialect conversion driver will be removed soon 187 /// and so will be argument materializations. 188 template <typename FnT, typename T = typename llvm::function_traits< 189 std::decay_t<FnT>>::template arg_t<1>> 190 void addArgumentMaterialization(FnT &&callback) { 191 argumentMaterializations.emplace_back( 192 wrapMaterialization<T>(std::forward<FnT>(callback))); 193 } 194 195 /// This method registers a materialization that will be called when 196 /// converting a replacement value back to its original source type. 197 /// This is used when some uses of the original value persist beyond the main 198 /// conversion. 199 template <typename FnT, typename T = typename llvm::function_traits< 200 std::decay_t<FnT>>::template arg_t<1>> 201 void addSourceMaterialization(FnT &&callback) { 202 sourceMaterializations.emplace_back( 203 wrapMaterialization<T>(std::forward<FnT>(callback))); 204 } 205 206 /// This method registers a materialization that will be called when 207 /// converting a value to a target type according to a pattern's type 208 /// converter. 209 /// 210 /// Note: Target materializations can optionally inspect the "original" 211 /// type. This type may be different from the type of the input value. 
212 /// For example, let's assume that a conversion pattern "P1" replaced an SSA 213 /// value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion 214 /// pattern "P2" matches an op that has "v1" as an operand. Let's furthermore 215 /// assume that "P2" determines that the converted target type of "t1" is 216 /// "t3", which may be different from "t2". In this example, the target 217 /// materialization will be invoked with: outputType = "t3", inputs = "v2", 218 /// originalType = "t1". Note that the original type "t1" cannot be recovered 219 /// from just "t3" and "v2"; that's why the originalType parameter exists. 220 /// 221 /// Note: During a 1:N conversion, the result types can be a TypeRange. In 222 /// that case the materialization produces a SmallVector<Value>. 223 template <typename FnT, typename T = typename llvm::function_traits< 224 std::decay_t<FnT>>::template arg_t<1>> 225 void addTargetMaterialization(FnT &&callback) { 226 targetMaterializations.emplace_back( 227 wrapTargetMaterialization<T>(std::forward<FnT>(callback))); 228 } 229 230 /// Register a conversion function for attributes within types. Type 231 /// converters may call this function in order to allow hoking into the 232 /// translation of attributes that exist within types. For example, a type 233 /// converter for the `memref` type could use these conversions to convert 234 /// memory spaces or layouts in an extensible way. 235 /// 236 /// The conversion functions take a non-null Type or subclass of Type and a 237 /// non-null Attribute (or subclass of Attribute), and returns a 238 /// `AttributeConversionResult`. 
This result can either contan an `Attribute`, 239 /// which may be `nullptr`, representing the conversion's success, 240 /// `AttributeConversionResult::na()` (the default empty value), indicating 241 /// that the conversion function did not apply and that further conversion 242 /// functions should be checked, or `AttributeConversionResult::abort()` 243 /// indicating that the conversion process should be aborted. 244 /// 245 /// Registered conversion functions are callled in the reverse of the order in 246 /// which they were registered. 247 template < 248 typename FnT, 249 typename T = 250 typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<0>, 251 typename A = 252 typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<1>> 253 void addTypeAttributeConversion(FnT &&callback) { 254 registerTypeAttributeConversion( 255 wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback))); 256 } 257 258 /// Convert the given type. This function should return failure if no valid 259 /// conversion exists, success otherwise. If the new set of types is empty, 260 /// the type is removed and any usages of the existing value are expected to 261 /// be removed during conversion. 262 LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const; 263 264 /// This hook simplifies defining 1-1 type conversions. This function returns 265 /// the type to convert to on success, and a null type on failure. 266 Type convertType(Type t) const; 267 268 /// Attempts a 1-1 type conversion, expecting the result type to be 269 /// `TargetType`. Returns the converted type cast to `TargetType` on success, 270 /// and a null type on conversion or cast failure. 271 template <typename TargetType> 272 TargetType convertType(Type t) const { 273 return dyn_cast_or_null<TargetType>(convertType(t)); 274 } 275 276 /// Convert the given set of types, filling 'results' as necessary. 
This 277 /// returns failure if the conversion of any of the types fails, success 278 /// otherwise. 279 LogicalResult convertTypes(TypeRange types, 280 SmallVectorImpl<Type> &results) const; 281 282 /// Return true if the given type is legal for this type converter, i.e. the 283 /// type converts to itself. 284 bool isLegal(Type type) const; 285 286 /// Return true if all of the given types are legal for this type converter. 287 template <typename RangeT> 288 std::enable_if_t<!std::is_convertible<RangeT, Type>::value && 289 !std::is_convertible<RangeT, Operation *>::value, 290 bool> 291 isLegal(RangeT &&range) const { 292 return llvm::all_of(range, [this](Type type) { return isLegal(type); }); 293 } 294 /// Return true if the given operation has legal operand and result types. 295 bool isLegal(Operation *op) const; 296 297 /// Return true if the types of block arguments within the region are legal. 298 bool isLegal(Region *region) const; 299 300 /// Return true if the inputs and outputs of the given function type are 301 /// legal. 302 bool isSignatureLegal(FunctionType ty) const; 303 304 /// This method allows for converting a specific argument of a signature. It 305 /// takes as inputs the original argument input number, type. 306 /// On success, it populates 'result' with any new mappings. 307 LogicalResult convertSignatureArg(unsigned inputNo, Type type, 308 SignatureConversion &result) const; 309 LogicalResult convertSignatureArgs(TypeRange types, 310 SignatureConversion &result, 311 unsigned origInputOffset = 0) const; 312 313 /// This function converts the type signature of the given block, by invoking 314 /// 'convertSignatureArg' for each argument. This function should return a 315 /// valid conversion for the signature on success, std::nullopt otherwise. 
316 std::optional<SignatureConversion> convertBlockSignature(Block *block) const; 317 318 /// Materialize a conversion from a set of types into one result type by 319 /// generating a cast sequence of some kind. See the respective 320 /// `add*Materialization` for more information on the context for these 321 /// methods. 322 Value materializeArgumentConversion(OpBuilder &builder, Location loc, 323 Type resultType, ValueRange inputs) const; 324 Value materializeSourceConversion(OpBuilder &builder, Location loc, 325 Type resultType, ValueRange inputs) const; 326 Value materializeTargetConversion(OpBuilder &builder, Location loc, 327 Type resultType, ValueRange inputs, 328 Type originalType = {}) const; 329 SmallVector<Value> materializeTargetConversion(OpBuilder &builder, 330 Location loc, 331 TypeRange resultType, 332 ValueRange inputs, 333 Type originalType = {}) const; 334 335 /// Convert an attribute present `attr` from within the type `type` using 336 /// the registered conversion functions. If no applicable conversion has been 337 /// registered, return std::nullopt. Note that the empty attribute/`nullptr` 338 /// is a valid return value for this function. 339 std::optional<Attribute> convertTypeAttribute(Type type, 340 Attribute attr) const; 341 342 private: 343 /// The signature of the callback used to convert a type. If the new set of 344 /// types is empty, the type is removed and any usages of the existing value 345 /// are expected to be removed during conversion. 346 using ConversionCallbackFn = std::function<std::optional<LogicalResult>( 347 Type, SmallVectorImpl<Type> &)>; 348 349 /// The signature of the callback used to materialize a source/argument 350 /// conversion. 351 /// 352 /// Arguments: builder, result type, inputs, location 353 using MaterializationCallbackFn = 354 std::function<Value(OpBuilder &, Type, ValueRange, Location)>; 355 356 /// The signature of the callback used to materialize a target conversion. 
357 /// 358 /// Arguments: builder, result types, inputs, location, original type 359 using TargetMaterializationCallbackFn = std::function<SmallVector<Value>( 360 OpBuilder &, TypeRange, ValueRange, Location, Type)>; 361 362 /// The signature of the callback used to convert a type attribute. 363 using TypeAttributeConversionCallbackFn = 364 std::function<AttributeConversionResult(Type, Attribute)>; 365 366 /// Generate a wrapper for the given callback. This allows for accepting 367 /// different callback forms, that all compose into a single version. 368 /// With callback of form: `std::optional<Type>(T)` 369 template <typename T, typename FnT> 370 std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn> 371 wrapCallback(FnT &&callback) const { 372 return wrapCallback<T>([callback = std::forward<FnT>(callback)]( 373 T type, SmallVectorImpl<Type> &results) { 374 if (std::optional<Type> resultOpt = callback(type)) { 375 bool wasSuccess = static_cast<bool>(*resultOpt); 376 if (wasSuccess) 377 results.push_back(*resultOpt); 378 return std::optional<LogicalResult>(success(wasSuccess)); 379 } 380 return std::optional<LogicalResult>(); 381 }); 382 } 383 /// With callback of form: `std::optional<LogicalResult>( 384 /// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`. 385 template <typename T, typename FnT> 386 std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>, 387 ConversionCallbackFn> 388 wrapCallback(FnT &&callback) const { 389 return [callback = std::forward<FnT>(callback)]( 390 Type type, 391 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> { 392 T derivedType = dyn_cast<T>(type); 393 if (!derivedType) 394 return std::nullopt; 395 return callback(derivedType, results); 396 }; 397 } 398 399 /// Register a type conversion. 
400 void registerConversion(ConversionCallbackFn callback) { 401 conversions.emplace_back(std::move(callback)); 402 cachedDirectConversions.clear(); 403 cachedMultiConversions.clear(); 404 } 405 406 /// Generate a wrapper for the given argument/source materialization 407 /// callback. The callback may take any subclass of `Type` and the 408 /// wrapper will check for the target type to be of the expected class 409 /// before calling the callback. 410 template <typename T, typename FnT> 411 MaterializationCallbackFn wrapMaterialization(FnT &&callback) const { 412 return [callback = std::forward<FnT>(callback)]( 413 OpBuilder &builder, Type resultType, ValueRange inputs, 414 Location loc) -> Value { 415 if (T derivedType = dyn_cast<T>(resultType)) 416 return callback(builder, derivedType, inputs, loc); 417 return Value(); 418 }; 419 } 420 421 /// Generate a wrapper for the given target materialization callback. 422 /// The callback may take any subclass of `Type` and the wrapper will check 423 /// for the target type to be of the expected class before calling the 424 /// callback. 425 /// 426 /// With callback of form: 427 /// - Value(OpBuilder &, T, ValueRange, Location, Type) 428 /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type) 429 template <typename T, typename FnT> 430 std::enable_if_t< 431 std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>, 432 TargetMaterializationCallbackFn> 433 wrapTargetMaterialization(FnT &&callback) const { 434 return [callback = std::forward<FnT>(callback)]( 435 OpBuilder &builder, TypeRange resultTypes, ValueRange inputs, 436 Location loc, Type originalType) -> SmallVector<Value> { 437 SmallVector<Value> result; 438 if constexpr (std::is_same<T, TypeRange>::value) { 439 // This is a 1:N target materialization. Return the produces values 440 // directly. 
441 result = callback(builder, resultTypes, inputs, loc, originalType); 442 } else if constexpr (std::is_assignable<Type, T>::value) { 443 // This is a 1:1 target materialization. Invoke the callback only if a 444 // single SSA value is requested. 445 if (resultTypes.size() == 1) { 446 // Invoke the callback only if the type class of the callback matches 447 // the requested result type. 448 if (T derivedType = dyn_cast<T>(resultTypes.front())) { 449 // 1:1 materializations produce single values, but we store 1:N 450 // target materialization functions in the type converter. Wrap the 451 // result value in a SmallVector<Value>. 452 Value val = 453 callback(builder, derivedType, inputs, loc, originalType); 454 if (val) 455 result.push_back(val); 456 } 457 } 458 } else { 459 static_assert(sizeof(T) == 0, "T must be a Type or a TypeRange"); 460 } 461 return result; 462 }; 463 } 464 /// With callback of form: 465 /// - Value(OpBuilder &, T, ValueRange, Location) 466 /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location) 467 template <typename T, typename FnT> 468 std::enable_if_t< 469 std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>, 470 TargetMaterializationCallbackFn> 471 wrapTargetMaterialization(FnT &&callback) const { 472 return wrapTargetMaterialization<T>( 473 [callback = std::forward<FnT>(callback)]( 474 OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc, 475 Type originalType) { 476 return callback(builder, resultTypes, inputs, loc); 477 }); 478 } 479 480 /// Generate a wrapper for the given memory space conversion callback. The 481 /// callback may take any subclass of `Attribute` and the wrapper will check 482 /// for the target attribute to be of the expected class before calling the 483 /// callback. 
484 template <typename T, typename A, typename FnT> 485 TypeAttributeConversionCallbackFn 486 wrapTypeAttributeConversion(FnT &&callback) const { 487 return [callback = std::forward<FnT>(callback)]( 488 Type type, Attribute attr) -> AttributeConversionResult { 489 if (T derivedType = dyn_cast<T>(type)) { 490 if (A derivedAttr = dyn_cast_or_null<A>(attr)) 491 return callback(derivedType, derivedAttr); 492 } 493 return AttributeConversionResult::na(); 494 }; 495 } 496 497 /// Register a memory space conversion, clearing caches. 498 void 499 registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) { 500 typeAttributeConversions.emplace_back(std::move(callback)); 501 // Clear type conversions in case a memory space is lingering inside. 502 cachedDirectConversions.clear(); 503 cachedMultiConversions.clear(); 504 } 505 506 /// The set of registered conversion functions. 507 SmallVector<ConversionCallbackFn, 4> conversions; 508 509 /// The list of registered materialization functions. 510 SmallVector<MaterializationCallbackFn, 2> argumentMaterializations; 511 SmallVector<MaterializationCallbackFn, 2> sourceMaterializations; 512 SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations; 513 514 /// The list of registered type attribute conversion functions. 515 SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions; 516 517 /// A set of cached conversions to avoid recomputing in the common case. 518 /// Direct 1-1 conversions are the most common, so this cache stores the 519 /// successful 1-1 conversions as well as all failed conversions. 520 mutable DenseMap<Type, Type> cachedDirectConversions; 521 /// This cache stores the successful 1->N conversions, where N != 1. 
522 mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions; 523 /// A mutex used for cache access 524 mutable llvm::sys::SmartRWMutex<true> cacheMutex; 525 }; 526 527 //===----------------------------------------------------------------------===// 528 // Conversion Patterns 529 //===----------------------------------------------------------------------===// 530 531 /// Base class for the conversion patterns. This pattern class enables type 532 /// conversions, and other uses specific to the conversion framework. As such, 533 /// patterns of this type can only be used with the 'apply*' methods below. 534 class ConversionPattern : public RewritePattern { 535 public: 536 /// Hook for derived classes to implement rewriting. `op` is the (first) 537 /// operation matched by the pattern, `operands` is a list of the rewritten 538 /// operand values that are passed to `op`, `rewriter` can be used to emit the 539 /// new operations. This function should not fail. If some specific cases of 540 /// the operation are not supported, these cases should not be matched. 541 virtual void rewrite(Operation *op, ArrayRef<Value> operands, 542 ConversionPatternRewriter &rewriter) const { 543 llvm_unreachable("unimplemented rewrite"); 544 } 545 virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands, 546 ConversionPatternRewriter &rewriter) const { 547 rewrite(op, getOneToOneAdaptorOperands(operands), rewriter); 548 } 549 550 /// Hook for derived classes to implement combined matching and rewriting. 551 /// This overload supports only 1:1 replacements. The 1:N overload is called 552 /// by the driver. By default, it calls this 1:1 overload or reports a fatal 553 /// error if 1:N replacements were found. 
554 virtual LogicalResult 555 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 556 ConversionPatternRewriter &rewriter) const { 557 if (failed(match(op))) 558 return failure(); 559 rewrite(op, operands, rewriter); 560 return success(); 561 } 562 563 /// Hook for derived classes to implement combined matching and rewriting. 564 /// This overload supports 1:N replacements. 565 virtual LogicalResult 566 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, 567 ConversionPatternRewriter &rewriter) const { 568 return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter); 569 } 570 571 /// Attempt to match and rewrite the IR root at the specified operation. 572 LogicalResult matchAndRewrite(Operation *op, 573 PatternRewriter &rewriter) const final; 574 575 /// Return the type converter held by this pattern, or nullptr if the pattern 576 /// does not require type conversion. 577 const TypeConverter *getTypeConverter() const { return typeConverter; } 578 579 template <typename ConverterTy> 580 std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value, 581 const ConverterTy *> 582 getTypeConverter() const { 583 return static_cast<const ConverterTy *>(typeConverter); 584 } 585 586 protected: 587 /// See `RewritePattern::RewritePattern` for information on the other 588 /// available constructors. 589 using RewritePattern::RewritePattern; 590 /// Construct a conversion pattern with the given converter, and forward the 591 /// remaining arguments to RewritePattern. 592 template <typename... Args> 593 ConversionPattern(const TypeConverter &typeConverter, Args &&...args) 594 : RewritePattern(std::forward<Args>(args)...), 595 typeConverter(&typeConverter) {} 596 597 /// Given an array of value ranges, which are the inputs to a 1:N adaptor, 598 /// try to extract the single value of each range to construct a the inputs 599 /// for a 1:1 adaptor. 
600 /// 601 /// This function produces a fatal error if at least one range has 0 or 602 /// more than 1 value: "pattern 'name' does not support 1:N conversion" 603 SmallVector<Value> 604 getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const; 605 606 protected: 607 /// An optional type converter for use by this pattern. 608 const TypeConverter *typeConverter = nullptr; 609 610 private: 611 using RewritePattern::rewrite; 612 }; 613 614 /// OpConversionPattern is a wrapper around ConversionPattern that allows for 615 /// matching and rewriting against an instance of a derived operation class as 616 /// opposed to a raw Operation. 617 template <typename SourceOp> 618 class OpConversionPattern : public ConversionPattern { 619 public: 620 using OpAdaptor = typename SourceOp::Adaptor; 621 using OneToNOpAdaptor = 622 typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>; 623 624 OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) 625 : ConversionPattern(SourceOp::getOperationName(), benefit, context) {} 626 OpConversionPattern(const TypeConverter &typeConverter, MLIRContext *context, 627 PatternBenefit benefit = 1) 628 : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit, 629 context) {} 630 631 /// Wrappers around the ConversionPattern methods that pass the derived op 632 /// type. 
633 LogicalResult match(Operation *op) const final { 634 return match(cast<SourceOp>(op)); 635 } 636 void rewrite(Operation *op, ArrayRef<Value> operands, 637 ConversionPatternRewriter &rewriter) const final { 638 auto sourceOp = cast<SourceOp>(op); 639 rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); 640 } 641 void rewrite(Operation *op, ArrayRef<ValueRange> operands, 642 ConversionPatternRewriter &rewriter) const final { 643 auto sourceOp = cast<SourceOp>(op); 644 rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter); 645 } 646 LogicalResult 647 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 648 ConversionPatternRewriter &rewriter) const final { 649 auto sourceOp = cast<SourceOp>(op); 650 return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); 651 } 652 LogicalResult 653 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, 654 ConversionPatternRewriter &rewriter) const final { 655 auto sourceOp = cast<SourceOp>(op); 656 return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), 657 rewriter); 658 } 659 660 /// Rewrite and Match methods that operate on the SourceOp type. These must be 661 /// overridden by the derived pattern class. 
662 virtual LogicalResult match(SourceOp op) const { 663 llvm_unreachable("must override match or matchAndRewrite"); 664 } 665 virtual void rewrite(SourceOp op, OpAdaptor adaptor, 666 ConversionPatternRewriter &rewriter) const { 667 llvm_unreachable("must override matchAndRewrite or a rewrite method"); 668 } 669 virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor, 670 ConversionPatternRewriter &rewriter) const { 671 SmallVector<Value> oneToOneOperands = 672 getOneToOneAdaptorOperands(adaptor.getOperands()); 673 rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); 674 } 675 virtual LogicalResult 676 matchAndRewrite(SourceOp op, OpAdaptor adaptor, 677 ConversionPatternRewriter &rewriter) const { 678 if (failed(match(op))) 679 return failure(); 680 rewrite(op, adaptor, rewriter); 681 return success(); 682 } 683 virtual LogicalResult 684 matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, 685 ConversionPatternRewriter &rewriter) const { 686 SmallVector<Value> oneToOneOperands = 687 getOneToOneAdaptorOperands(adaptor.getOperands()); 688 return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); 689 } 690 691 private: 692 using ConversionPattern::matchAndRewrite; 693 }; 694 695 /// OpInterfaceConversionPattern is a wrapper around ConversionPattern that 696 /// allows for matching and rewriting against an instance of an OpInterface 697 /// class as opposed to a raw Operation. 
698 template <typename SourceOp> 699 class OpInterfaceConversionPattern : public ConversionPattern { 700 public: 701 OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) 702 : ConversionPattern(Pattern::MatchInterfaceOpTypeTag(), 703 SourceOp::getInterfaceID(), benefit, context) {} 704 OpInterfaceConversionPattern(const TypeConverter &typeConverter, 705 MLIRContext *context, PatternBenefit benefit = 1) 706 : ConversionPattern(typeConverter, Pattern::MatchInterfaceOpTypeTag(), 707 SourceOp::getInterfaceID(), benefit, context) {} 708 709 /// Wrappers around the ConversionPattern methods that pass the derived op 710 /// type. 711 void rewrite(Operation *op, ArrayRef<Value> operands, 712 ConversionPatternRewriter &rewriter) const final { 713 rewrite(cast<SourceOp>(op), operands, rewriter); 714 } 715 void rewrite(Operation *op, ArrayRef<ValueRange> operands, 716 ConversionPatternRewriter &rewriter) const final { 717 rewrite(cast<SourceOp>(op), operands, rewriter); 718 } 719 LogicalResult 720 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 721 ConversionPatternRewriter &rewriter) const final { 722 return matchAndRewrite(cast<SourceOp>(op), operands, rewriter); 723 } 724 LogicalResult 725 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, 726 ConversionPatternRewriter &rewriter) const final { 727 return matchAndRewrite(cast<SourceOp>(op), operands, rewriter); 728 } 729 730 /// Rewrite and Match methods that operate on the SourceOp type. These must be 731 /// overridden by the derived pattern class. 
732 virtual void rewrite(SourceOp op, ArrayRef<Value> operands, 733 ConversionPatternRewriter &rewriter) const { 734 llvm_unreachable("must override matchAndRewrite or a rewrite method"); 735 } 736 virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands, 737 ConversionPatternRewriter &rewriter) const { 738 rewrite(op, getOneToOneAdaptorOperands(operands), rewriter); 739 } 740 virtual LogicalResult 741 matchAndRewrite(SourceOp op, ArrayRef<Value> operands, 742 ConversionPatternRewriter &rewriter) const { 743 if (failed(match(op))) 744 return failure(); 745 rewrite(op, operands, rewriter); 746 return success(); 747 } 748 virtual LogicalResult 749 matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands, 750 ConversionPatternRewriter &rewriter) const { 751 return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter); 752 } 753 754 private: 755 using ConversionPattern::matchAndRewrite; 756 }; 757 758 /// OpTraitConversionPattern is a wrapper around ConversionPattern that allows 759 /// for matching and rewriting against instances of an operation that possess a 760 /// given trait. 761 template <template <typename> class TraitType> 762 class OpTraitConversionPattern : public ConversionPattern { 763 public: 764 OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) 765 : ConversionPattern(Pattern::MatchTraitOpTypeTag(), 766 TypeID::get<TraitType>(), benefit, context) {} 767 OpTraitConversionPattern(const TypeConverter &typeConverter, 768 MLIRContext *context, PatternBenefit benefit = 1) 769 : ConversionPattern(typeConverter, Pattern::MatchTraitOpTypeTag(), 770 TypeID::get<TraitType>(), benefit, context) {} 771 }; 772 773 /// Generic utility to convert op result types according to type converter 774 /// without knowing exact op type. 775 /// Clones existing op with new result types and returns it. 
FailureOr<Operation *>
convertOpResultTypes(Operation *op, ValueRange operands,
                     const TypeConverter &converter,
                     ConversionPatternRewriter &rewriter);

/// Add a pattern to the given pattern list to convert the signature of a
/// FunctionOpInterface op with the given type converter. This only supports
/// ops which use FunctionType to represent their type.
void populateFunctionOpInterfaceTypeConversionPattern(
    StringRef functionLikeOpName, RewritePatternSet &patterns,
    const TypeConverter &converter);

/// Typed convenience wrapper around the overload above: the op name is taken
/// from `FuncOpT::getOperationName()`.
template <typename FuncOpT>
void populateFunctionOpInterfaceTypeConversionPattern(
    RewritePatternSet &patterns, const TypeConverter &converter) {
  populateFunctionOpInterfaceTypeConversionPattern(FuncOpT::getOperationName(),
                                                   patterns, converter);
}

/// Add a signature-conversion pattern to `patterns` that uses `converter`.
/// NOTE(review): unlike the overloads above this takes no op name; presumably
/// it applies to any FunctionOpInterface op — confirm against the definition.
void populateAnyFunctionOpInterfaceTypeConversionPattern(
    RewritePatternSet &patterns, const TypeConverter &converter);

//===----------------------------------------------------------------------===//
// Conversion PatternRewriter
//===----------------------------------------------------------------------===//

namespace detail {
struct ConversionPatternRewriterImpl;
} // namespace detail

/// This class implements a pattern rewriter for use with ConversionPatterns. It
/// extends the base PatternRewriter and provides special conversion specific
/// hooks.
class ConversionPatternRewriter final : public PatternRewriter {
public:
  ~ConversionPatternRewriter() override;

  /// Apply a signature conversion to the given block. This replaces the block
  /// with a new block containing the updated signature. The operations of the
  /// given block are inlined into the newly-created block, which is returned.
  ///
  /// If no block argument types are changing, the original block will be
  /// left in place and returned.
  ///
  /// A signature conversion must be provided. (Type converters can construct
  /// a signature conversion with `convertBlockSignature`.)
  ///
  /// Optionally, a type converter can be provided to build materializations.
  /// Note: If no type converter was provided or the type converter does not
  /// specify any suitable argument/target materialization rules, the dialect
  /// conversion may fail to legalize unresolved materializations.
  Block *
  applySignatureConversion(Block *block,
                           TypeConverter::SignatureConversion &conversion,
                           const TypeConverter *converter = nullptr);

  /// Apply a signature conversion to each block in the given region. This
  /// replaces each block with a new block containing the updated signature. If
  /// an updated signature would match the current signature, the respective
  /// block is left in place as is. (See `applySignatureConversion` for
  /// details.) The new entry block of the region is returned.
  ///
  /// SignatureConversions are computed with the specified type converter.
  /// This function returns "failure" if the type converter failed to compute
  /// a SignatureConversion for at least one block.
  ///
  /// Optionally, a special SignatureConversion can be specified for the entry
  /// block. This is because the types of the entry block arguments are often
  /// tied semantically to the operation.
  FailureOr<Block *> convertRegionTypes(
      Region *region, const TypeConverter &converter,
      TypeConverter::SignatureConversion *entryConversion = nullptr);

  /// Replace all the uses of the block argument `from` with value `to`.
  void replaceUsesOfBlockArgument(BlockArgument from, Value to);

  /// Return the converted value of 'key' with a type defined by the type
  /// converter of the currently executing pattern. Return nullptr in the case
  /// of failure, the remapped value otherwise.
  Value getRemappedValue(Value key);

  /// Return the converted values that replace 'keys' with types defined by the
  /// type converter of the currently executing pattern. Returns failure if the
  /// remap failed, success otherwise.
  LogicalResult getRemappedValues(ValueRange keys,
                                  SmallVectorImpl<Value> &results);

  //===--------------------------------------------------------------------===//
  // PatternRewriter Hooks
  //===--------------------------------------------------------------------===//

  /// Indicate that the conversion rewriter can recover from rewrite failure.
  /// Recovery is supported via rollback, allowing for continued processing of
  /// patterns even if a failure is encountered during the rewrite step.
  bool canRecoverFromRewriteFailure() const override { return true; }

  /// Replace the given operation with the new values. The number of op results
  /// and replacement values must match. The types may differ: the dialect
  /// conversion driver will reconcile any surviving type mismatches at the end
  /// of the conversion process with source materializations. The given
  /// operation is erased.
  void replaceOp(Operation *op, ValueRange newValues) override;

  /// Replace the given operation with the results of the new op. The number of
  /// op results must match. The types may differ: the dialect conversion
  /// driver will reconcile any surviving type mismatches at the end of the
  /// conversion process with source materializations. The original operation
  /// is erased.
  void replaceOp(Operation *op, Operation *newOp) override;

  /// Replace the given operation with the new value ranges. The number of op
  /// results and value ranges must match. The given operation is erased.
  void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);

  /// PatternRewriter hook for erasing a dead operation. The uses of this
  /// operation *must* be made dead by the end of the conversion process,
  /// otherwise an assert will be issued.
  void eraseOp(Operation *op) override;

  /// PatternRewriter hook for erasing all operations in a block. This is not
  /// yet implemented for dialect conversion.
  void eraseBlock(Block *block) override;

  /// PatternRewriter hook for inlining the ops of a block into another block.
  void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
                         ValueRange argValues = std::nullopt) override;
  // Keep the other PatternRewriter::inlineBlockBefore overloads visible; the
  // override above would otherwise hide them.
  using PatternRewriter::inlineBlockBefore;

  /// PatternRewriter hook for updating the given operation in-place.
  /// Note: These methods only track updates to the given operation itself,
  /// and not nested regions. Updates to regions will still require notification
  /// through other more specific hooks above.
  void startOpModification(Operation *op) override;

  /// PatternRewriter hook for updating the given operation in-place.
  void finalizeOpModification(Operation *op) override;

  /// PatternRewriter hook for updating the given operation in-place.
  void cancelOpModification(Operation *op) override;

  /// Return a reference to the internal implementation.
  detail::ConversionPatternRewriterImpl &getImpl();

private:
  // Allow OperationConverter to construct new rewriters.
  friend struct OperationConverter;

  /// Conversion pattern rewriters must not be used outside of dialect
  /// conversions. They apply some IR rewrites in a delayed fashion and could
  /// bring the IR into an inconsistent state when used standalone.
  explicit ConversionPatternRewriter(MLIRContext *ctx,
                                     const ConversionConfig &config);

  // Hide unsupported pattern rewriter API.
  using OpBuilder::setListener;

  std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
};

//===----------------------------------------------------------------------===//
// ConversionTarget
//===----------------------------------------------------------------------===//

/// This class describes a specific conversion target.
class ConversionTarget {
public:
  /// This enumeration corresponds to the specific action to take when
  /// considering an operation legal for this conversion target.
  enum class LegalizationAction {
    /// The target supports this operation.
    Legal,

    /// This operation has dynamic legalization constraints that must be checked
    /// by the target.
    Dynamic,

    /// The target explicitly does not support this operation.
    Illegal,
  };

  /// A structure containing additional information describing a specific legal
  /// operation instance.
  struct LegalOpDetails {
    /// A flag that indicates if this operation is 'recursively' legal. This
    /// means that if an operation is legal, either statically or dynamically,
    /// all of the operations nested within are also considered legal.
    bool isRecursivelyLegal = false;
  };

  /// The signature of the callback used to determine if an operation is
  /// dynamically legal on the target. The callback may return std::nullopt
  /// to express "no opinion" rather than true/false.
  using DynamicLegalityCallbackFn =
      std::function<std::optional<bool>(Operation *)>;

  ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
  virtual ~ConversionTarget() = default;

  //===--------------------------------------------------------------------===//
  // Legality Registration
  //===--------------------------------------------------------------------===//

  /// Register a legality action for the given operation.
  void setOpAction(OperationName op, LegalizationAction action);
  template <typename OpT>
  void setOpAction(LegalizationAction action) {
    setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
  }

  /// Register the given operations as legal.
  void addLegalOp(OperationName op) {
    setOpAction(op, LegalizationAction::Legal);
  }
  template <typename OpT>
  void addLegalOp() {
    addLegalOp(OperationName(OpT::getOperationName(), &ctx));
  }
  // Variadic form: registers each listed op type in turn.
  template <typename OpT, typename OpT2, typename... OpTs>
  void addLegalOp() {
    addLegalOp<OpT>();
    addLegalOp<OpT2, OpTs...>();
  }

  /// Register the given operation as dynamically legal and set the dynamic
  /// legalization callback to the one provided.
  void addDynamicallyLegalOp(OperationName op,
                             const DynamicLegalityCallbackFn &callback) {
    setOpAction(op, LegalizationAction::Dynamic);
    setLegalityCallback(op, callback);
  }
  template <typename OpT>
  void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
    addDynamicallyLegalOp(OperationName(OpT::getOperationName(), &ctx),
                          callback);
  }
  // Variadic form: registers the same callback for each listed op type.
  template <typename OpT, typename OpT2, typename... OpTs>
  void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
    addDynamicallyLegalOp<OpT>(callback);
    addDynamicallyLegalOp<OpT2, OpTs...>(callback);
  }
  // Typed-callback form: selected only when `callback` is not invocable with
  // `Operation *`; the op is cast to `OpT` before the callback is invoked.
  template <typename OpT, class Callable>
  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
  addDynamicallyLegalOp(Callable &&callback) {
    addDynamicallyLegalOp<OpT>(
        [=](Operation *op) { return callback(cast<OpT>(op)); });
  }

  /// Register the given operation as illegal, i.e. this operation is known to
  /// not be supported by this target.
  void addIllegalOp(OperationName op) {
    setOpAction(op, LegalizationAction::Illegal);
  }
  template <typename OpT>
  void addIllegalOp() {
    addIllegalOp(OperationName(OpT::getOperationName(), &ctx));
  }
  // Variadic form: registers each listed op type in turn.
  template <typename OpT, typename OpT2, typename... OpTs>
  void addIllegalOp() {
    addIllegalOp<OpT>();
    addIllegalOp<OpT2, OpTs...>();
  }

  /// Mark an operation, that *must* have either been set as `Legal` or
  /// `DynamicallyLegal`, as being recursively legal. This means that in
  /// addition to the operation itself, all of the operations nested within are
  /// also considered legal. An optional dynamic legality callback may be
  /// provided to mark subsets of legal instances as recursively legal.
  void markOpRecursivelyLegal(OperationName name,
                              const DynamicLegalityCallbackFn &callback);
  template <typename OpT>
  void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback = {}) {
    OperationName opName(OpT::getOperationName(), &ctx);
    markOpRecursivelyLegal(opName, callback);
  }
  // Variadic form: applies the same (optional) callback to each listed op type.
  template <typename OpT, typename OpT2, typename... OpTs>
  void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback = {}) {
    markOpRecursivelyLegal<OpT>(callback);
    markOpRecursivelyLegal<OpT2, OpTs...>(callback);
  }
  // Typed-callback form: selected only when `callback` is not invocable with
  // `Operation *`; the op is cast to `OpT` before the callback is invoked.
  template <typename OpT, class Callable>
  std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
  markOpRecursivelyLegal(Callable &&callback) {
    markOpRecursivelyLegal<OpT>(
        [=](Operation *op) { return callback(cast<OpT>(op)); });
  }

  /// Register a legality action for the given dialects.
  void setDialectAction(ArrayRef<StringRef> dialectNames,
                        LegalizationAction action);

  /// Register the operations of the given dialects as legal.
  template <typename... Names>
  void addLegalDialect(StringRef name, Names... names) {
    SmallVector<StringRef, 2> dialectNames({name, names...});
    setDialectAction(dialectNames, LegalizationAction::Legal);
  }
  // Typed form: dialect names come from `Args::getDialectNamespace()`.
  template <typename... Args>
  void addLegalDialect() {
    SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
    setDialectAction(dialectNames, LegalizationAction::Legal);
  }

  /// Register the operations of the given dialects as dynamically legal, i.e.
  /// requiring custom handling by the callback.
  template <typename... Names>
  void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback,
                                  StringRef name, Names... names) {
    SmallVector<StringRef, 2> dialectNames({name, names...});
    setDialectAction(dialectNames, LegalizationAction::Dynamic);
    setLegalityCallback(dialectNames, callback);
  }
  // Typed form: dialect names come from `Args::getDialectNamespace()`.
  template <typename... Args>
  void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback) {
    addDynamicallyLegalDialect(std::move(callback),
                               Args::getDialectNamespace()...);
  }

  /// Register unknown operations as dynamically legal. For operations (and
  /// dialects) that do not have a set legalization action, treat them as
  /// dynamically legal and invoke the given callback.
  void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn) {
    setLegalityCallback(fn);
  }

  /// Register the operations of the given dialects as illegal, i.e.
  /// operations of this dialect are not supported by the target.
  template <typename... Names>
  void addIllegalDialect(StringRef name, Names... names) {
    SmallVector<StringRef, 2> dialectNames({name, names...});
    setDialectAction(dialectNames, LegalizationAction::Illegal);
  }
  // Typed form: dialect names come from `Args::getDialectNamespace()`.
  template <typename... Args>
  void addIllegalDialect() {
    SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
    setDialectAction(dialectNames, LegalizationAction::Illegal);
  }

  //===--------------------------------------------------------------------===//
  // Legality Querying
  //===--------------------------------------------------------------------===//

  /// Get the legality action for the given operation.
  std::optional<LegalizationAction> getOpAction(OperationName op) const;

  /// If the given operation instance is legal on this target, a structure
  /// containing legality information is returned. If the operation is not
  /// legal, std::nullopt is returned. Also returns std::nullopt if operation
  /// legality wasn't registered by user or dynamic legality callbacks returned
  /// std::nullopt.
  ///
  /// Note: Legality is actually a 4-state: Legal(recursive=true),
  /// Legal(recursive=false), Illegal or Unknown, where Unknown is treated
  /// either as Legal or Illegal depending on context.
  std::optional<LegalOpDetails> isLegal(Operation *op) const;

  /// Returns true if the operation instance is illegal on this target. Returns
  /// false if the operation is legal, operation legality wasn't registered by
  /// user or dynamic legality callbacks returned std::nullopt.
  bool isIllegal(Operation *op) const;

private:
  /// Set the dynamic legality callback for the given operation.
  void setLegalityCallback(OperationName name,
                           const DynamicLegalityCallbackFn &callback);

  /// Set the dynamic legality callback for the given dialects.
  void setLegalityCallback(ArrayRef<StringRef> dialects,
                           const DynamicLegalityCallbackFn &callback);

  /// Set the dynamic legality callback for the unknown ops.
  void setLegalityCallback(const DynamicLegalityCallbackFn &callback);

  /// The set of information that configures the legalization of an operation.
  struct LegalizationInfo {
    /// The legality action this operation was given.
    LegalizationAction action = LegalizationAction::Illegal;

    /// If some legal instances of this operation may also be recursively legal.
    bool isRecursivelyLegal = false;

    /// The legality callback if this operation is dynamically legal.
    DynamicLegalityCallbackFn legalityFn;
  };

  /// Get the legalization information for the given operation.
  std::optional<LegalizationInfo> getOpInfo(OperationName op) const;

  /// A deterministic mapping of operation name and its respective legality
  /// information.
  llvm::MapVector<OperationName, LegalizationInfo> legalOperations;

  /// A set of legality callbacks for given operation names that are used to
  /// check if an operation instance is recursively legal.
  DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;

  /// A deterministic mapping of dialect name to the specific legality action to
  /// take.
  llvm::StringMap<LegalizationAction> legalDialects;

  /// A set of dynamic legality callbacks for given dialect names.
  llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;

  /// An optional legality callback for unknown operations.
  DynamicLegalityCallbackFn unknownLegalityFn;

  /// The current context this target applies to.
  MLIRContext &ctx;
};

#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
//===----------------------------------------------------------------------===//
// PDL Configuration
//===----------------------------------------------------------------------===//

/// A PDL configuration that is used to support dialect conversion
/// functionality.
class PDLConversionConfig final
    : public PDLPatternConfigBase<PDLConversionConfig> {
public:
  PDLConversionConfig(const TypeConverter *converter) : converter(converter) {}
  ~PDLConversionConfig() final = default;

  /// Return the type converter used by this configuration, which may be nullptr
  /// if no type conversions are expected.
  const TypeConverter *getTypeConverter() const { return converter; }

  /// Hooks that are invoked at the beginning and end of a rewrite of a matched
  /// pattern.
  void notifyRewriteBegin(PatternRewriter &rewriter) final;
  void notifyRewriteEnd(PatternRewriter &rewriter) final;

private:
  /// An optional type converter to use for the pattern.
  const TypeConverter *converter;
};

/// Register the dialect conversion PDL functions with the given pattern set.
void registerConversionPDLFunctions(RewritePatternSet &patterns);

#else

// Stubs for when PDL in rewriting is not enabled.

inline void registerConversionPDLFunctions(RewritePatternSet &patterns) {}

class PDLConversionConfig final {
public:
  PDLConversionConfig(const TypeConverter * /*converter*/) {}
};

#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH

//===----------------------------------------------------------------------===//
// ConversionConfig
//===----------------------------------------------------------------------===//

/// Dialect conversion configuration.
struct ConversionConfig {
  /// An optional callback used to notify about match failure diagnostics during
  /// the conversion. Diagnostics reported to this callback may only be
  /// available in debug mode.
  function_ref<void(Diagnostic &)> notifyCallback = nullptr;

  /// Partial conversion only. All operations that are found not to be
  /// legalizable are placed in this set. (Note that if there is an op
  /// explicitly marked as illegal, the conversion terminates and the set will
  /// not necessarily be complete.)
  DenseSet<Operation *> *unlegalizedOps = nullptr;

  /// Analysis conversion only. All operations that are found to be legalizable
  /// are placed in this set. Note that no actual rewrites are applied to the
  /// IR during an analysis conversion and only pre-existing operations are
  /// added to the set.
  DenseSet<Operation *> *legalizableOps = nullptr;

  /// An optional listener that is notified about all IR modifications in case
  /// dialect conversion succeeds. If the dialect conversion fails and no IR
  /// modifications are visible (i.e., they were all rolled back), or if the
  /// dialect conversion is an "analysis conversion", no notifications are
  /// sent (apart from `notifyPatternBegin`/`notifyPatternEnd`).
  ///
  /// Note: Notifications are sent in a delayed fashion, when the dialect
  /// conversion is guaranteed to succeed. At that point, some IR modifications
  /// may already have been materialized. Consequently, operations/blocks that
  /// are passed to listener callbacks should not be accessed. (Ops/blocks are
  /// guaranteed to be valid pointers and accessing op names is allowed. But
  /// there are no guarantees about the state of ops/blocks at the time that a
  /// callback is triggered.)
  ///
  /// Example: Consider a dialect conversion in which a new op ("test.foo") is
  /// created and inserted, and later moved to another block. (Moving ops also
  /// triggers "notifyOperationInserted".)
  ///
  /// (1) notifyOperationInserted: "test.foo" (into block "b1")
  /// (2) notifyOperationInserted: "test.foo" (moved to another block "b2")
  ///
  /// When querying "op->getBlock()" during the first "notifyOperationInserted",
  /// "b2" would be returned because "moving an op" is a kind of rewrite that is
  /// immediately performed by the dialect conversion (and rolled back upon
  /// failure).
  //
  // Note: When receiving a "notifyBlockInserted"/"notifyOperationInserted"
  // callback, the previous region/block is provided to the callback, but not
  // the iterator pointing to the exact location within the region/block. That
  // is because these notifications are sent with a delay (after the IR has
  // already been modified) and iterators into past IR state cannot be
  // represented at the moment.
  RewriterBase::Listener *listener = nullptr;

  /// If set to "true", the dialect conversion attempts to build source/target
  /// materializations through the type converter API in lieu of
  /// "builtin.unrealized_conversion_cast" ops. The conversion process fails if
  /// at least one materialization could not be built.
  ///
  /// If set to "false", the dialect conversion does not build any custom
  /// materializations and instead inserts "builtin.unrealized_conversion_cast"
  /// ops to ensure that the resulting IR is valid.
  bool buildMaterializations = true;
};

//===----------------------------------------------------------------------===//
// Reconcile Unrealized Casts
//===----------------------------------------------------------------------===//

/// Try to reconcile all given UnrealizedConversionCastOps and store the
/// left-over ops in `remainingCastOps` (if provided).
///
/// This function processes cast ops in a worklist-driven fashion. For each
/// cast op, if the chain of input casts eventually reaches a cast op where the
/// input types match the output types of the matched op, replace the matched
/// op with the inputs.
///
/// Example:
/// %1 = unrealized_conversion_cast %0 : !A to !B
/// %2 = unrealized_conversion_cast %1 : !B to !C
/// %3 = unrealized_conversion_cast %2 : !C to !A
///
/// In the above example, %0 can be used instead of %3 and all cast ops are
/// folded away.
void reconcileUnrealizedCasts(
    ArrayRef<UnrealizedConversionCastOp> castOps,
    SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);

//===----------------------------------------------------------------------===//
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//

/// Below we define several entry points for operation conversion. It is
/// important to note that the patterns provided to the conversion framework may
/// have additional constraints. See the `PatternRewriter Hooks` section of the
/// ConversionPatternRewriter, to see what additional constraints are imposed on
/// the use of the PatternRewriter.

/// Apply a partial conversion on the given operations and all nested
/// operations. This method converts as many operations to the target as
/// possible, ignoring operations that failed to legalize. This method only
/// returns failure if there are ops explicitly marked as illegal.
LogicalResult
applyPartialConversion(ArrayRef<Operation *> ops,
                       const ConversionTarget &target,
                       const FrozenRewritePatternSet &patterns,
                       ConversionConfig config = ConversionConfig());
LogicalResult
applyPartialConversion(Operation *op, const ConversionTarget &target,
                       const FrozenRewritePatternSet &patterns,
                       ConversionConfig config = ConversionConfig());

/// Apply a complete conversion on the given operations, and all nested
/// operations. This method returns failure if the conversion of any operation
/// fails, or if there are unreachable blocks in any of the regions nested
/// within 'ops'.
LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
                                  const ConversionTarget &target,
                                  const FrozenRewritePatternSet &patterns,
                                  ConversionConfig config = ConversionConfig());
LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
                                  const FrozenRewritePatternSet &patterns,
                                  ConversionConfig config = ConversionConfig());

/// Apply an analysis conversion on the given operations, and all nested
/// operations. This method analyzes which operations would be successfully
/// converted to the target if a conversion was applied. All operations that
/// were found to be legalizable to the given 'target' are placed within the
/// provided 'config.legalizableOps' set; note that no actual rewrites are
/// applied to the operations on success. This method only returns failure if
/// there are unreachable blocks in any of the regions nested within 'ops'.
// NOTE(review): unlike the partial/full entry points, these take a non-const
// ConversionTarget; the reason is not visible from this header.
LogicalResult
applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
                        const FrozenRewritePatternSet &patterns,
                        ConversionConfig config = ConversionConfig());
LogicalResult
applyAnalysisConversion(Operation *op, ConversionTarget &target,
                        const FrozenRewritePatternSet &patterns,
                        ConversionConfig config = ConversionConfig());
} // namespace mlir

#endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_