1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 10 11 #include "mlir/IR/Diagnostics.h" 12 #include "mlir/IR/Operation.h" 13 #include "mlir/IR/PatternMatch.h" 14 #include "mlir/Interfaces/CastInterfaces.h" 15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 16 #include "llvm/ADT/STLExtras.h" 17 #include "llvm/ADT/ScopeExit.h" 18 #include "llvm/Support/Debug.h" 19 #include "llvm/Support/ErrorHandling.h" 20 21 #define DEBUG_TYPE "transform-dialect" 22 #define DEBUG_TYPE_FULL "transform-dialect-full" 23 #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all" 24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") 25 #define LDBG(X) LLVM_DEBUG(DBGS() << (X)) 26 #define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X))) 27 28 using namespace mlir; 29 30 //===----------------------------------------------------------------------===// 31 // Helper functions 32 //===----------------------------------------------------------------------===// 33 34 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors 35 /// properly dominates `b` and `b` is not inside `a`. 36 static bool happensBefore(Operation *a, Operation *b) { 37 do { 38 if (a->isProperAncestor(b)) 39 return false; 40 if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) { 41 return a->isBeforeInBlock(bAncestor); 42 } 43 } while ((a = a->getParentOp())); 44 return false; 45 } 46 47 //===----------------------------------------------------------------------===// 48 // TransformState 49 //===----------------------------------------------------------------------===// 50 51 constexpr const Value transform::TransformState::kTopLevelValue; 52 53 transform::TransformState::TransformState( 54 Region *region, Operation *payloadRoot, 55 const RaggedArray<MappedValue> &extraMappings, 56 const TransformOptions &options) 57 : topLevel(payloadRoot), options(options) { 58 topLevelMappedValues.reserve(extraMappings.size()); 59 for (ArrayRef<MappedValue> mapping : extraMappings) 60 topLevelMappedValues.push_back(mapping); 61 if (region) { 62 RegionScope *scope = new RegionScope(*this, *region); 63 topLevelRegionScope.reset(scope); 64 } 65 } 66 67 Operation *transform::TransformState::getTopLevel() const { return topLevel; } 68 69 ArrayRef<Operation *> 70 transform::TransformState::getPayloadOpsView(Value value) const { 71 const TransformOpMapping &operationMapping = getMapping(value).direct; 72 auto iter = operationMapping.find(value); 73 assert(iter != operationMapping.end() && 74 "cannot find mapping for payload handle (param/value handle " 75 "provided?)"); 76 return iter->getSecond(); 77 } 78 79 ArrayRef<Attribute> transform::TransformState::getParams(Value value) const { 80 const ParamMapping &mapping = getMapping(value).params; 81 auto iter = mapping.find(value); 82 assert(iter != mapping.end() && "cannot find mapping for param handle " 83 "(operation/value handle provided?)"); 84 return iter->getSecond(); 85 } 86 87 ArrayRef<Value> 88 transform::TransformState::getPayloadValuesView(Value handleValue) const { 89 const ValueMapping &mapping = getMapping(handleValue).values; 90 auto iter = mapping.find(handleValue); 91 assert(iter != mapping.end() && "cannot find mapping for value handle " 92 "(param/operation handle provided?)"); 93 return iter->getSecond(); 94 } 95 96 LogicalResult transform::TransformState::getHandlesForPayloadOp( 97 Operation *op, SmallVectorImpl<Value> &handles, 98 bool includeOutOfScope) const { 99 bool found = false; 100 for (const auto &[region, mapping] : llvm::reverse(mappings)) { 101 auto iterator = mapping->reverse.find(op); 102 if (iterator != mapping->reverse.end()) { 103 llvm::append_range(handles, iterator->getSecond()); 104 found = true; 105 } 106 // Stop looking when reaching a region that is isolated from above. 107 if (!includeOutOfScope && 108 region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>()) 109 break; 110 } 111 112 return success(found); 113 } 114 115 LogicalResult transform::TransformState::getHandlesForPayloadValue( 116 Value payloadValue, SmallVectorImpl<Value> &handles, 117 bool includeOutOfScope) const { 118 bool found = false; 119 for (const auto &[region, mapping] : llvm::reverse(mappings)) { 120 auto iterator = mapping->reverseValues.find(payloadValue); 121 if (iterator != mapping->reverseValues.end()) { 122 llvm::append_range(handles, iterator->getSecond()); 123 found = true; 124 } 125 // Stop looking when reaching a region that is isolated from above. 126 if (!includeOutOfScope && 127 region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>()) 128 break; 129 } 130 131 return success(found); 132 } 133 134 /// Given a list of MappedValues, cast them to the value kind implied by the 135 /// interface of the handle type, and dispatch to one of the callbacks. 136 static DiagnosedSilenceableFailure dispatchMappedValues( 137 Value handle, ArrayRef<transform::MappedValue> values, 138 function_ref<LogicalResult(ArrayRef<Operation *>)> operationsFn, 139 function_ref<LogicalResult(ArrayRef<transform::Param>)> paramsFn, 140 function_ref<LogicalResult(ValueRange)> valuesFn) { 141 if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) { 142 SmallVector<Operation *> operations; 143 operations.reserve(values.size()); 144 for (transform::MappedValue value : values) { 145 if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) { 146 operations.push_back(op); 147 continue; 148 } 149 return emitSilenceableFailure(handle.getLoc()) 150 << "wrong kind of value provided for top-level operation handle"; 151 } 152 if (failed(operationsFn(operations))) 153 return DiagnosedSilenceableFailure::definiteFailure(); 154 return DiagnosedSilenceableFailure::success(); 155 } 156 157 if (llvm::isa<transform::TransformValueHandleTypeInterface>( 158 handle.getType())) { 159 SmallVector<Value> payloadValues; 160 payloadValues.reserve(values.size()); 161 for (transform::MappedValue value : values) { 162 if (auto v = llvm::dyn_cast_if_present<Value>(value)) { 163 payloadValues.push_back(v); 164 continue; 165 } 166 return emitSilenceableFailure(handle.getLoc()) 167 << "wrong kind of value provided for the top-level value handle"; 168 } 169 if (failed(valuesFn(payloadValues))) 170 return DiagnosedSilenceableFailure::definiteFailure(); 171 return DiagnosedSilenceableFailure::success(); 172 } 173 174 assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) && 175 "unsupported kind of block argument"); 176 SmallVector<transform::Param> parameters; 177 parameters.reserve(values.size()); 178 for (transform::MappedValue value : values) { 179 if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) { 180 parameters.push_back(attr); 181 continue; 182 } 183 return emitSilenceableFailure(handle.getLoc()) 184 << "wrong kind of value provided for top-level parameter"; 185 } 186 if (failed(paramsFn(parameters))) 187 return DiagnosedSilenceableFailure::definiteFailure(); 188 return DiagnosedSilenceableFailure::success(); 189 } 190 191 LogicalResult 192 transform::TransformState::mapBlockArgument(BlockArgument argument, 193 ArrayRef<MappedValue> values) { 194 return dispatchMappedValues( 195 argument, values, 196 [&](ArrayRef<Operation *> operations) { 197 return setPayloadOps(argument, operations); 198 }, 199 [&](ArrayRef<Param> params) { 200 return setParams(argument, params); 201 }, 202 [&](ValueRange payloadValues) { 203 return setPayloadValues(argument, payloadValues); 204 }) 205 .checkAndReport(); 206 } 207 208 LogicalResult transform::TransformState::mapBlockArguments( 209 Block::BlockArgListType arguments, 210 ArrayRef<SmallVector<MappedValue>> mapping) { 211 for (auto &&[argument, values] : llvm::zip_equal(arguments, mapping)) 212 if (failed(mapBlockArgument(argument, values))) 213 return failure(); 214 return success(); 215 } 216 217 LogicalResult 218 transform::TransformState::setPayloadOps(Value value, 219 ArrayRef<Operation *> targets) { 220 assert(value != kTopLevelValue && 221 "attempting to reset the transformation root"); 222 assert(llvm::isa<TransformHandleTypeInterface>(value.getType()) && 223 "wrong handle type"); 224 225 for (Operation *target : targets) { 226 if (target) 227 continue; 228 return emitError(value.getLoc()) 229 << "attempting to assign a null payload op to this transform value"; 230 } 231 232 auto iface = llvm::cast<TransformHandleTypeInterface>(value.getType()); 233 DiagnosedSilenceableFailure result = 234 iface.checkPayload(value.getLoc(), targets); 235 if (failed(result.checkAndReport())) 236 return failure(); 237 238 // Setting new payload for the value without cleaning it first is a misuse of 239 // the API, assert here. 240 SmallVector<Operation *> storedTargets(targets); 241 Mappings &mappings = getMapping(value); 242 bool inserted = 243 mappings.direct.insert({value, std::move(storedTargets)}).second; 244 assert(inserted && "value is already associated with another list"); 245 (void)inserted; 246 247 for (Operation *op : targets) 248 mappings.reverse[op].push_back(value); 249 250 return success(); 251 } 252 253 LogicalResult 254 transform::TransformState::setPayloadValues(Value handle, 255 ValueRange payloadValues) { 256 assert(handle != nullptr && "attempting to set params for a null value"); 257 assert(llvm::isa<TransformValueHandleTypeInterface>(handle.getType()) && 258 "wrong handle type"); 259 260 for (Value payload : payloadValues) { 261 if (payload) 262 continue; 263 return emitError(handle.getLoc()) << "attempting to assign a null payload " 264 "value to this transform handle"; 265 } 266 267 auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.getType()); 268 SmallVector<Value> payloadValueVector = llvm::to_vector(payloadValues); 269 DiagnosedSilenceableFailure result = 270 iface.checkPayload(handle.getLoc(), payloadValueVector); 271 if (failed(result.checkAndReport())) 272 return failure(); 273 274 Mappings &mappings = getMapping(handle); 275 bool inserted = 276 mappings.values.insert({handle, std::move(payloadValueVector)}).second; 277 assert( 278 inserted && 279 "value handle is already associated with another list of payload values"); 280 (void)inserted; 281 282 for (Value payload : payloadValues) 283 mappings.reverseValues[payload].push_back(handle); 284 285 return success(); 286 } 287 288 LogicalResult transform::TransformState::setParams(Value value, 289 ArrayRef<Param> params) { 290 assert(value != nullptr && "attempting to set params for a null value"); 291 292 for (Attribute attr : params) { 293 if (attr) 294 continue; 295 return emitError(value.getLoc()) 296 << "attempting to assign a null parameter to this transform value"; 297 } 298 299 auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.getType()); 300 assert(value && 301 "cannot associate parameter with a value of non-parameter type"); 302 DiagnosedSilenceableFailure result = 303 valueType.checkPayload(value.getLoc(), params); 304 if (failed(result.checkAndReport())) 305 return failure(); 306 307 Mappings &mappings = getMapping(value); 308 bool inserted = 309 mappings.params.insert({value, llvm::to_vector(params)}).second; 310 assert(inserted && "value is already associated with another list of params"); 311 (void)inserted; 312 return success(); 313 } 314 315 template <typename Mapping, typename Key, typename Mapped> 316 void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) { 317 auto it = mapping.find(key); 318 if (it == mapping.end()) 319 return; 320 321 llvm::erase(it->getSecond(), mapped); 322 if (it->getSecond().empty()) 323 mapping.erase(it); 324 } 325 326 void transform::TransformState::forgetMapping(Value opHandle, 327 ValueRange origOpFlatResults, 328 bool allowOutOfScope) { 329 Mappings &mappings = getMapping(opHandle, allowOutOfScope); 330 for (Operation *op : mappings.direct[opHandle]) 331 dropMappingEntry(mappings.reverse, op, opHandle); 332 mappings.direct.erase(opHandle); 333 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 334 // Payload IR is removed from the mapping. This invalidates the respective 335 // iterators. 336 mappings.incrementTimestamp(opHandle); 337 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 338 339 for (Value opResult : origOpFlatResults) { 340 SmallVector<Value> resultHandles; 341 (void)getHandlesForPayloadValue(opResult, resultHandles); 342 for (Value resultHandle : resultHandles) { 343 Mappings &localMappings = getMapping(resultHandle); 344 dropMappingEntry(localMappings.values, resultHandle, opResult); 345 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 346 // Payload IR is removed from the mapping. This invalidates the respective 347 // iterators. 348 mappings.incrementTimestamp(resultHandle); 349 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 350 dropMappingEntry(localMappings.reverseValues, opResult, resultHandle); 351 } 352 } 353 } 354 355 void transform::TransformState::forgetValueMapping( 356 Value valueHandle, ArrayRef<Operation *> payloadOperations) { 357 Mappings &mappings = getMapping(valueHandle); 358 for (Value payloadValue : mappings.reverseValues[valueHandle]) 359 dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle); 360 mappings.values.erase(valueHandle); 361 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 362 // Payload IR is removed from the mapping. This invalidates the respective 363 // iterators. 364 mappings.incrementTimestamp(valueHandle); 365 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 366 367 for (Operation *payloadOp : payloadOperations) { 368 SmallVector<Value> opHandles; 369 (void)getHandlesForPayloadOp(payloadOp, opHandles); 370 for (Value opHandle : opHandles) { 371 Mappings &localMappings = getMapping(opHandle); 372 dropMappingEntry(localMappings.direct, opHandle, payloadOp); 373 dropMappingEntry(localMappings.reverse, payloadOp, opHandle); 374 375 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 376 // Payload IR is removed from the mapping. This invalidates the respective 377 // iterators. 378 localMappings.incrementTimestamp(opHandle); 379 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 380 } 381 } 382 } 383 384 LogicalResult 385 transform::TransformState::replacePayloadOp(Operation *op, 386 Operation *replacement) { 387 // TODO: consider invalidating the handles to nested objects here. 388 389 #ifndef NDEBUG 390 for (Value opResult : op->getResults()) { 391 SmallVector<Value> valueHandles; 392 (void)getHandlesForPayloadValue(opResult, valueHandles, 393 /*includeOutOfScope=*/true); 394 assert(valueHandles.empty() && "expected no mapping to old results"); 395 } 396 #endif // NDEBUG 397 398 // Drop the mapping between the op and all handles that point to it. Fail if 399 // there are no handles. 400 SmallVector<Value> opHandles; 401 if (failed(getHandlesForPayloadOp(op, opHandles, /*includeOutOfScope=*/true))) 402 return failure(); 403 for (Value handle : opHandles) { 404 Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true); 405 dropMappingEntry(mappings.reverse, op, handle); 406 } 407 408 // Replace the pointed-to object of all handles with the replacement object. 409 // In case a payload op was erased (replacement object is nullptr), a nullptr 410 // is stored in the mapping. These nullptrs are removed after each transform. 411 // Furthermore, nullptrs are not enumerated by payload op iterators. The 412 // relative order of ops is preserved. 413 // 414 // Removing an op from the mapping would be problematic because removing an 415 // element from an array invalidates iterators; merely changing the value of 416 // elements does not. 417 for (Value handle : opHandles) { 418 Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true); 419 auto it = mappings.direct.find(handle); 420 if (it == mappings.direct.end()) 421 continue; 422 423 SmallVector<Operation *, 2> &association = it->getSecond(); 424 // Note that an operation may be associated with the handle more than once. 425 for (Operation *&mapped : association) { 426 if (mapped == op) 427 mapped = replacement; 428 } 429 430 if (replacement) { 431 mappings.reverse[replacement].push_back(handle); 432 } else { 433 opHandlesToCompact.insert(handle); 434 } 435 } 436 437 return success(); 438 } 439 440 LogicalResult 441 transform::TransformState::replacePayloadValue(Value value, Value replacement) { 442 SmallVector<Value> valueHandles; 443 if (failed(getHandlesForPayloadValue(value, valueHandles, 444 /*includeOutOfScope=*/true))) 445 return failure(); 446 447 for (Value handle : valueHandles) { 448 Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true); 449 dropMappingEntry(mappings.reverseValues, value, handle); 450 451 // If replacing with null, that is erasing the mapping, drop the mapping 452 // between the handles and the IR objects 453 if (!replacement) { 454 dropMappingEntry(mappings.values, handle, value); 455 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 456 // Payload IR is removed from the mapping. This invalidates the respective 457 // iterators. 458 mappings.incrementTimestamp(handle); 459 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 460 } else { 461 auto it = mappings.values.find(handle); 462 if (it == mappings.values.end()) 463 continue; 464 465 SmallVector<Value> &association = it->getSecond(); 466 for (Value &mapped : association) { 467 if (mapped == value) 468 mapped = replacement; 469 } 470 mappings.reverseValues[replacement].push_back(handle); 471 } 472 } 473 474 return success(); 475 } 476 477 void transform::TransformState::recordOpHandleInvalidationOne( 478 OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors, 479 Operation *payloadOp, Value otherHandle, Value throughValue, 480 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { 481 // If the op is associated with invalidated handle, skip the check as it 482 // may be reading invalid IR. This also ensures we report the first 483 // invalidation and not the last one. 484 if (invalidatedHandles.count(otherHandle) || 485 newlyInvalidated.count(otherHandle)) 486 return; 487 488 FULL_LDBG("--recordOpHandleInvalidationOne\n"); 489 DEBUG_WITH_TYPE( 490 DEBUG_TYPE_FULL, 491 llvm::interleaveComma(potentialAncestors, DBGS() << "--ancestors: ", 492 [](Operation *op) { llvm::dbgs() << *op; }); 493 llvm::dbgs() << "\n"); 494 495 Operation *owner = consumingHandle.getOwner(); 496 unsigned operandNo = consumingHandle.getOperandNumber(); 497 for (Operation *ancestor : potentialAncestors) { 498 // clang-format off 499 DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, 500 { (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); }); 501 DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, 502 { (DBGS() << "----of payload with name: " 503 << payloadOp->getName().getIdentifier() << "\n"); }); 504 DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, 505 { (DBGS() << "----of payload: " << *payloadOp << "\n"); }); 506 // clang-format on 507 if (!ancestor->isAncestor(payloadOp)) 508 continue; 509 510 // Make sure the error-reporting lambda doesn't capture anything 511 // by-reference because it will go out of scope. Additionally, extract 512 // location from Payload IR ops because the ops themselves may be 513 // deleted before the lambda gets called. 514 Location ancestorLoc = ancestor->getLoc(); 515 Location opLoc = payloadOp->getLoc(); 516 std::optional<Location> throughValueLoc = 517 throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt; 518 newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo, 519 otherHandle, 520 throughValueLoc](Location currentLoc) { 521 InFlightDiagnostic diag = emitError(currentLoc) 522 << "op uses a handle invalidated by a " 523 "previously executed transform op"; 524 diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops"; 525 diag.attachNote(owner->getLoc()) 526 << "invalidated by this transform op that consumes its operand #" 527 << operandNo 528 << " and invalidates all handles to payload IR entities associated " 529 "with this operand and entities nested in them"; 530 diag.attachNote(ancestorLoc) << "ancestor payload op"; 531 diag.attachNote(opLoc) << "nested payload op"; 532 if (throughValueLoc) { 533 diag.attachNote(*throughValueLoc) 534 << "consumed handle points to this payload value"; 535 } 536 }; 537 } 538 } 539 540 void transform::TransformState::recordValueHandleInvalidationByOpHandleOne( 541 OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors, 542 Value payloadValue, Value valueHandle, 543 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { 544 // If the op is associated with invalidated handle, skip the check as it 545 // may be reading invalid IR. This also ensures we report the first 546 // invalidation and not the last one. 547 if (invalidatedHandles.count(valueHandle) || 548 newlyInvalidated.count(valueHandle)) 549 return; 550 551 for (Operation *ancestor : potentialAncestors) { 552 Operation *definingOp; 553 std::optional<unsigned> resultNo; 554 unsigned argumentNo = std::numeric_limits<unsigned>::max(); 555 unsigned blockNo = std::numeric_limits<unsigned>::max(); 556 unsigned regionNo = std::numeric_limits<unsigned>::max(); 557 if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) { 558 definingOp = opResult.getOwner(); 559 resultNo = opResult.getResultNumber(); 560 } else { 561 auto arg = llvm::cast<BlockArgument>(payloadValue); 562 definingOp = arg.getParentBlock()->getParentOp(); 563 argumentNo = arg.getArgNumber(); 564 blockNo = std::distance(arg.getOwner()->getParent()->begin(), 565 arg.getOwner()->getIterator()); 566 regionNo = arg.getOwner()->getParent()->getRegionNumber(); 567 } 568 assert(definingOp && "expected the value to be defined by an op as result " 569 "or block argument"); 570 if (!ancestor->isAncestor(definingOp)) 571 continue; 572 573 Operation *owner = opHandle.getOwner(); 574 unsigned operandNo = opHandle.getOperandNumber(); 575 Location ancestorLoc = ancestor->getLoc(); 576 Location opLoc = definingOp->getLoc(); 577 Location valueLoc = payloadValue.getLoc(); 578 newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo, 579 argumentNo, blockNo, regionNo, ancestorLoc, 580 opLoc, valueLoc](Location currentLoc) { 581 InFlightDiagnostic diag = emitError(currentLoc) 582 << "op uses a handle invalidated by a " 583 "previously executed transform op"; 584 diag.attachNote(valueHandle.getLoc()) << "invalidated handle"; 585 diag.attachNote(owner->getLoc()) 586 << "invalidated by this transform op that consumes its operand #" 587 << operandNo 588 << " and invalidates all handles to payload IR entities " 589 "associated with this operand and entities nested in them"; 590 diag.attachNote(ancestorLoc) 591 << "ancestor op associated with the consumed handle"; 592 if (resultNo) { 593 diag.attachNote(opLoc) 594 << "op defining the value as result #" << *resultNo; 595 } else { 596 diag.attachNote(opLoc) 597 << "op defining the value as block argument #" << argumentNo 598 << " of block #" << blockNo << " in region #" << regionNo; 599 } 600 diag.attachNote(valueLoc) << "payload value"; 601 }; 602 } 603 } 604 605 void transform::TransformState::recordOpHandleInvalidation( 606 OpOperand &handle, ArrayRef<Operation *> potentialAncestors, 607 Value throughValue, 608 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { 609 610 if (potentialAncestors.empty()) { 611 DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, { 612 (DBGS() << "----recording invalidation for empty handle: " << handle.get() 613 << "\n"); 614 }); 615 616 Operation *owner = handle.getOwner(); 617 unsigned operandNo = handle.getOperandNumber(); 618 newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) { 619 InFlightDiagnostic diag = emitError(currentLoc) 620 << "op uses a handle associated with empty " 621 "payload and invalidated by a " 622 "previously executed transform op"; 623 diag.attachNote(owner->getLoc()) 624 << "invalidated by this transform op that consumes its operand #" 625 << operandNo; 626 }; 627 return; 628 } 629 630 // Iterate over the mapping and invalidate aliasing handles. This is quite 631 // expensive and only necessary for error reporting in case of transform 632 // dialect misuse with dangling handles. Iteration over the handles is based 633 // on the assumption that the number of handles is significantly less than the 634 // number of IR objects (operations and values). Alternatively, we could walk 635 // the IR nested in each payload op associated with the given handle and look 636 // for handles associated with each operation and value. 637 for (const auto &[region, mapping] : llvm::reverse(mappings)) { 638 // Go over all op handle mappings and mark as invalidated any handle 639 // pointing to any of the payload ops associated with the given handle or 640 // any op nested in them. 641 for (const auto &[payloadOp, otherHandles] : mapping->reverse) { 642 for (Value otherHandle : otherHandles) 643 recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp, 644 otherHandle, throughValue, 645 newlyInvalidated); 646 } 647 // Go over all value handle mappings and mark as invalidated any handle 648 // pointing to any result of the payload op associated with the given handle 649 // or any op nested in them. Similarly invalidate handles to argument of 650 // blocks belonging to any region of any payload op associated with the 651 // given handle or any op nested in them. 652 for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) { 653 for (Value valueHandle : valueHandles) 654 recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors, 655 payloadValue, valueHandle, 656 newlyInvalidated); 657 } 658 659 // Stop lookup when reaching a region that is isolated from above. 660 if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>()) 661 break; 662 } 663 } 664 665 void transform::TransformState::recordValueHandleInvalidation( 666 OpOperand &valueHandle, 667 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { 668 // Invalidate other handles to the same value. 669 for (Value payloadValue : getPayloadValuesView(valueHandle.get())) { 670 SmallVector<Value> otherValueHandles; 671 (void)getHandlesForPayloadValue(payloadValue, otherValueHandles); 672 for (Value otherHandle : otherValueHandles) { 673 Operation *owner = valueHandle.getOwner(); 674 unsigned operandNo = valueHandle.getOperandNumber(); 675 Location valueLoc = payloadValue.getLoc(); 676 newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo, 677 valueLoc](Location currentLoc) { 678 InFlightDiagnostic diag = emitError(currentLoc) 679 << "op uses a handle invalidated by a " 680 "previously executed transform op"; 681 diag.attachNote(otherHandle.getLoc()) << "invalidated handle"; 682 diag.attachNote(owner->getLoc()) 683 << "invalidated by this transform op that consumes its operand #" 684 << operandNo 685 << " and invalidates handles to the same values as associated with " 686 "it"; 687 diag.attachNote(valueLoc) << "payload value"; 688 }; 689 } 690 691 if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) { 692 Operation *payloadOp = opResult.getOwner(); 693 recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue, 694 newlyInvalidated); 695 } else { 696 auto arg = llvm::dyn_cast<BlockArgument>(payloadValue); 697 for (Operation &payloadOp : *arg.getOwner()) 698 recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue, 699 newlyInvalidated); 700 } 701 } 702 } 703 704 /// Checks that the operation does not use invalidated handles as operands. 705 /// Reports errors and returns failure if it does. Otherwise, invalidates the 706 /// handles consumed by the operation as well as any handles pointing to payload 707 /// IR operations nested in the operations associated with the consumed handles. 708 LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl( 709 transform::TransformOpInterface transform, 710 transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { 711 FULL_LDBG("--Start checkAndRecordHandleInvalidation\n"); 712 auto memoryEffectsIface = 713 cast<MemoryEffectOpInterface>(transform.getOperation()); 714 SmallVector<MemoryEffects::EffectInstance> effects; 715 memoryEffectsIface.getEffectsOnResource( 716 transform::TransformMappingResource::get(), effects); 717 718 for (OpOperand &target : transform->getOpOperands()) { 719 DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, { 720 (DBGS() << "----iterate on handle: " << target.get() << "\n"); 721 }); 722 // If the operand uses an invalidated handle, report it. If the operation 723 // allows handles to point to repeated payload operations, only report 724 // pre-existing invalidation errors. Otherwise, also report invalidations 725 // caused by the current transform operation affecting its other operands. 726 auto it = invalidatedHandles.find(target.get()); 727 auto nit = newlyInvalidated.find(target.get()); 728 if (it != invalidatedHandles.end()) { 729 FULL_LDBG("--End checkAndRecordHandleInvalidation, found already " 730 "invalidated -> FAILURE\n"); 731 return it->getSecond()(transform->getLoc()), failure(); 732 } 733 if (!transform.allowsRepeatedHandleOperands() && 734 nit != newlyInvalidated.end()) { 735 FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly " 736 "invalidated (by this op) -> FAILURE\n"); 737 return nit->getSecond()(transform->getLoc()), failure(); 738 } 739 740 // Invalidate handles pointing to the operations nested in the operation 741 // associated with the handle consumed by this operation. 742 auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) { 743 return isa<MemoryEffects::Free>(effect.getEffect()) && 744 effect.getValue() == target.get(); 745 }; 746 if (llvm::any_of(effects, consumesTarget)) { 747 FULL_LDBG("----found consume effect\n"); 748 if (llvm::isa<transform::TransformHandleTypeInterface>( 749 target.get().getType())) { 750 FULL_LDBG("----recordOpHandleInvalidation\n"); 751 SmallVector<Operation *> payloadOps = 752 llvm::to_vector(getPayloadOps(target.get())); 753 recordOpHandleInvalidation(target, payloadOps, nullptr, 754 newlyInvalidated); 755 } else if (llvm::isa<transform::TransformValueHandleTypeInterface>( 756 target.get().getType())) { 757 FULL_LDBG("----recordValueHandleInvalidation\n"); 758 recordValueHandleInvalidation(target, newlyInvalidated); 759 } else { 760 FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n"); 761 } 762 } else { 763 FULL_LDBG("----no consume effect -> SKIP\n"); 764 } 765 } 766 767 FULL_LDBG("--End checkAndRecordHandleInvalidation -> SUCCESS\n"); 768 return success(); 769 } 770 771 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation( 772 transform::TransformOpInterface transform) { 773 InvalidatedHandleMap newlyInvalidated; 774 LogicalResult checkResult = 775 checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated); 776 invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()), 777 std::make_move_iterator(newlyInvalidated.end())); 778 return checkResult; 779 } 780 781 template <typename T> 782 DiagnosedSilenceableFailure 783 checkRepeatedConsumptionInOperand(ArrayRef<T> payload, 784 transform::TransformOpInterface transform, 785 unsigned operandNumber) { 786 DenseSet<T> seen; 787 for (T p : payload) { 788 if (!seen.insert(p).second) { 789 DiagnosedSilenceableFailure diag = 790 transform.emitSilenceableError() 791 << "a handle passed as operand #" << operandNumber 792 << " and consumed by this operation points to a payload " 793 "entity more than once"; 794 if constexpr (std::is_pointer_v<T>) 795 diag.attachNote(p->getLoc()) << "repeated target op"; 796 else 797 diag.attachNote(p.getLoc()) << "repeated target value"; 798 return diag; 799 } 800 } 801 return DiagnosedSilenceableFailure::success(); 802 } 803 804 void transform::TransformState::compactOpHandles() { 805 for (Value handle : opHandlesToCompact) { 806 Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true); 807 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 808 if (llvm::find(mappings.direct[handle], nullptr) != 809 mappings.direct[handle].end()) 810 // Payload IR is removed from the mapping. This invalidates the respective 811 // iterators. 812 mappings.incrementTimestamp(handle); 813 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 814 llvm::erase(mappings.direct[handle], nullptr); 815 } 816 opHandlesToCompact.clear(); 817 } 818 819 DiagnosedSilenceableFailure 820 transform::TransformState::applyTransform(TransformOpInterface transform) { 821 LLVM_DEBUG({ 822 DBGS() << "applying: "; 823 transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions()); 824 llvm::dbgs() << "\n"; 825 }); 826 DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, 827 DBGS() << "Top-level payload before application:\n" 828 << *getTopLevel() << "\n"); 829 auto printOnFailureRAII = llvm::make_scope_exit([this] { 830 (void)this; 831 LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print( 832 llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm());); 833 }); 834 835 // Set current transform op. 836 regionStack.back()->currentTransform = transform; 837 838 // Expensive checks to detect invalid transform IR. 839 if (options.getExpensiveChecksEnabled()) { 840 FULL_LDBG("ExpensiveChecksEnabled\n"); 841 if (failed(checkAndRecordHandleInvalidation(transform))) 842 return DiagnosedSilenceableFailure::definiteFailure(); 843 844 for (OpOperand &operand : transform->getOpOperands()) { 845 DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, { 846 (DBGS() << "iterate on handle: " << operand.get() << "\n"); 847 }); 848 if (!isHandleConsumed(operand.get(), transform)) { 849 FULL_LDBG("--handle not consumed -> SKIP\n"); 850 continue; 851 } 852 if (transform.allowsRepeatedHandleOperands()) { 853 FULL_LDBG("--op allows repeated handles -> SKIP\n"); 854 continue; 855 } 856 FULL_LDBG("--handle is consumed\n"); 857 858 Type operandType = operand.get().getType(); 859 if (llvm::isa<TransformHandleTypeInterface>(operandType)) { 860 FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n"); 861 DiagnosedSilenceableFailure check = 862 checkRepeatedConsumptionInOperand<Operation *>( 863 getPayloadOpsView(operand.get()), transform, 864 operand.getOperandNumber()); 865 if (!check.succeeded()) { 866 FULL_LDBG("----FAILED\n"); 867 return check; 868 } 869 } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) { 870 FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n"); 871 DiagnosedSilenceableFailure check = 872 checkRepeatedConsumptionInOperand<Value>( 873 getPayloadValuesView(operand.get()), transform, 874 operand.getOperandNumber()); 875 if (!check.succeeded()) { 876 FULL_LDBG("----FAILED\n"); 877 return check; 878 } 879 } else { 880 FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n"); 881 } 882 } 883 } 884 885 // Find which operands are consumed. 886 SmallVector<OpOperand *> consumedOperands = 887 transform.getConsumedHandleOpOperands(); 888 889 // Remember the results of the payload ops associated with the consumed 890 // op handles or the ops defining the value handles so we can drop the 891 // association with them later. This must happen here because the 892 // transformation may destroy or mutate them so we cannot traverse the payload 893 // IR after that. 894 SmallVector<Value> origOpFlatResults; 895 SmallVector<Operation *> origAssociatedOps; 896 for (OpOperand *opOperand : consumedOperands) { 897 Value operand = opOperand->get(); 898 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) { 899 for (Operation *payloadOp : getPayloadOps(operand)) { 900 llvm::append_range(origOpFlatResults, payloadOp->getResults()); 901 } 902 continue; 903 } 904 if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) { 905 for (Value payloadValue : getPayloadValuesView(operand)) { 906 if (llvm::isa<OpResult>(payloadValue)) { 907 origAssociatedOps.push_back(payloadValue.getDefiningOp()); 908 continue; 909 } 910 llvm::append_range( 911 origAssociatedOps, 912 llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(), 913 [](Operation &op) { return &op; })); 914 } 915 continue; 916 } 917 DiagnosedDefiniteFailure diag = 918 emitDefiniteFailure(transform->getLoc()) 919 << "unexpectedly consumed a value that is not a handle as operand #" 920 << opOperand->getOperandNumber(); 921 diag.attachNote(operand.getLoc()) 922 << "value defined here with type " << operand.getType(); 923 return diag; 924 } 925 926 // Prepare rewriter and listener. 927 TrackingListenerConfig config; 928 config.skipHandleFn = [&](Value handle) { 929 // Skip handle if it is dead. 930 auto scopeIt = 931 llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) { 932 return handle.getParentRegion() == scope->region; 933 }); 934 assert(scopeIt != regionStack.rend() && 935 "could not find region scope for handle"); 936 RegionScope *scope = *scopeIt; 937 return llvm::all_of(handle.getUsers(), [&](Operation *user) { 938 return user == scope->currentTransform || 939 happensBefore(user, scope->currentTransform); 940 }); 941 }; 942 transform::ErrorCheckingTrackingListener trackingListener(*this, transform, 943 config); 944 transform::TransformRewriter rewriter(transform->getContext(), 945 &trackingListener); 946 947 // Compute the result but do not short-circuit the silenceable failure case as 948 // we still want the handles to propagate properly so the "suppress" mode can 949 // proceed on a best effort basis. 950 transform::TransformResults results(transform->getNumResults()); 951 DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this)); 952 compactOpHandles(); 953 954 // Error handling: fail if transform or listener failed. 955 DiagnosedSilenceableFailure trackingFailure = 956 trackingListener.checkAndResetError(); 957 if (!transform->hasTrait<ReportTrackingListenerFailuresOpTrait>() || 958 transform->hasAttr(FindPayloadReplacementOpInterface:: 959 kSilenceTrackingFailuresAttrName)) { 960 // Only report failures for ReportTrackingListenerFailuresOpTrait ops. Also 961 // do not report failures if the above mentioned attribute is set. 962 if (trackingFailure.isSilenceableFailure()) 963 (void)trackingFailure.silence(); 964 trackingFailure = DiagnosedSilenceableFailure::success(); 965 } 966 if (!trackingFailure.succeeded()) { 967 if (result.succeeded()) { 968 result = std::move(trackingFailure); 969 } else { 970 // Transform op errors have precedence, report those first. 971 if (result.isSilenceableFailure()) 972 result.attachNote() << "tracking listener also failed: " 973 << trackingFailure.getMessage(); 974 (void)trackingFailure.silence(); 975 } 976 } 977 if (result.isDefiniteFailure()) 978 return result; 979 980 // If a silenceable failure was produced, some results may be unset, set them 981 // to empty lists. 982 if (result.isSilenceableFailure()) 983 results.setRemainingToEmpty(transform); 984 985 // Remove the mapping for the operand if it is consumed by the operation. This 986 // allows us to catch use-after-free with assertions later on. 987 for (OpOperand *opOperand : consumedOperands) { 988 Value operand = opOperand->get(); 989 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) { 990 forgetMapping(operand, origOpFlatResults); 991 } else if (llvm::isa<TransformValueHandleTypeInterface>( 992 operand.getType())) { 993 forgetValueMapping(operand, origAssociatedOps); 994 } 995 } 996 997 if (failed(updateStateFromResults(results, transform->getResults()))) 998 return DiagnosedSilenceableFailure::definiteFailure(); 999 1000 printOnFailureRAII.release(); 1001 DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, { 1002 DBGS() << "Top-level payload:\n"; 1003 getTopLevel()->print(llvm::dbgs()); 1004 }); 1005 return result; 1006 } 1007 1008 LogicalResult transform::TransformState::updateStateFromResults( 1009 const TransformResults &results, ResultRange opResults) { 1010 for (OpResult result : opResults) { 1011 if (llvm::isa<TransformParamTypeInterface>(result.getType())) { 1012 assert(results.isParam(result.getResultNumber()) && 1013 "expected parameters for the parameter-typed result"); 1014 if (failed( 1015 setParams(result, results.getParams(result.getResultNumber())))) { 1016 return failure(); 1017 } 1018 } else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) { 1019 assert(results.isValue(result.getResultNumber()) && 1020 "expected values for value-type-result"); 1021 if (failed(setPayloadValues( 1022 result, results.getValues(result.getResultNumber())))) { 1023 return failure(); 1024 } 1025 } else { 1026 assert(!results.isParam(result.getResultNumber()) && 1027 "expected payload ops for the non-parameter typed result"); 1028 if (failed( 1029 setPayloadOps(result, results.get(result.getResultNumber())))) { 1030 return failure(); 1031 } 1032 } 1033 } 1034 return success(); 1035 } 1036 1037 //===----------------------------------------------------------------------===// 1038 // TransformState::Extension 1039 //===----------------------------------------------------------------------===// 1040 1041 transform::TransformState::Extension::~Extension() = default; 1042 1043 LogicalResult 1044 transform::TransformState::Extension::replacePayloadOp(Operation *op, 1045 Operation *replacement) { 1046 // TODO: we may need to invalidate handles to operations and values nested in 1047 // the operation being replaced. 1048 return state.replacePayloadOp(op, replacement); 1049 } 1050 1051 LogicalResult 1052 transform::TransformState::Extension::replacePayloadValue(Value value, 1053 Value replacement) { 1054 return state.replacePayloadValue(value, replacement); 1055 } 1056 1057 //===----------------------------------------------------------------------===// 1058 // TransformState::RegionScope 1059 //===----------------------------------------------------------------------===// 1060 1061 transform::TransformState::RegionScope::~RegionScope() { 1062 // Remove handle invalidation notices as handles are going out of scope. 1063 // The same region may be re-entered leading to incorrect invalidation 1064 // errors. 1065 for (Block &block : *region) { 1066 for (Value handle : block.getArguments()) { 1067 state.invalidatedHandles.erase(handle); 1068 } 1069 for (Operation &op : block) { 1070 for (Value handle : op.getResults()) { 1071 state.invalidatedHandles.erase(handle); 1072 } 1073 } 1074 } 1075 1076 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 1077 // Remember pointers to payload ops referenced by the handles going out of 1078 // scope. 1079 SmallVector<Operation *> referencedOps = 1080 llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse)); 1081 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 1082 1083 state.mappings.erase(region); 1084 state.regionStack.pop_back(); 1085 } 1086 1087 //===----------------------------------------------------------------------===// 1088 // TransformResults 1089 //===----------------------------------------------------------------------===// 1090 1091 transform::TransformResults::TransformResults(unsigned numSegments) { 1092 operations.appendEmptyRows(numSegments); 1093 params.appendEmptyRows(numSegments); 1094 values.appendEmptyRows(numSegments); 1095 } 1096 1097 void transform::TransformResults::setParams( 1098 OpResult value, ArrayRef<transform::TransformState::Param> params) { 1099 int64_t position = value.getResultNumber(); 1100 assert(position < static_cast<int64_t>(this->params.size()) && 1101 "setting params for a non-existent handle"); 1102 assert(this->params[position].data() == nullptr && "params already set"); 1103 assert(operations[position].data() == nullptr && 1104 "another kind of results already set"); 1105 assert(values[position].data() == nullptr && 1106 "another kind of results already set"); 1107 this->params.replace(position, params); 1108 } 1109 1110 void transform::TransformResults::setMappedValues( 1111 OpResult handle, ArrayRef<MappedValue> values) { 1112 DiagnosedSilenceableFailure diag = dispatchMappedValues( 1113 handle, values, 1114 [&](ArrayRef<Operation *> operations) { 1115 return set(handle, operations), success(); 1116 }, 1117 [&](ArrayRef<Param> params) { 1118 return setParams(handle, params), success(); 1119 }, 1120 [&](ValueRange payloadValues) { 1121 return setValues(handle, payloadValues), success(); 1122 }); 1123 #ifndef NDEBUG 1124 if (!diag.succeeded()) 1125 llvm::dbgs() << diag.getStatusString() << "\n"; 1126 assert(diag.succeeded() && "incorrect mapping"); 1127 #endif // NDEBUG 1128 (void)diag.silence(); 1129 } 1130 1131 void transform::TransformResults::setRemainingToEmpty( 1132 transform::TransformOpInterface transform) { 1133 for (OpResult opResult : transform->getResults()) { 1134 if (!isSet(opResult.getResultNumber())) 1135 setMappedValues(opResult, {}); 1136 } 1137 } 1138 1139 ArrayRef<Operation *> 1140 transform::TransformResults::get(unsigned resultNumber) const { 1141 assert(resultNumber < operations.size() && 1142 "querying results for a non-existent handle"); 1143 assert(operations[resultNumber].data() != nullptr && 1144 "querying unset results (values or params expected?)"); 1145 return operations[resultNumber]; 1146 } 1147 1148 ArrayRef<transform::TransformState::Param> 1149 transform::TransformResults::getParams(unsigned resultNumber) const { 1150 assert(resultNumber < params.size() && 1151 "querying params for a non-existent handle"); 1152 assert(params[resultNumber].data() != nullptr && 1153 "querying unset params (ops or values expected?)"); 1154 return params[resultNumber]; 1155 } 1156 1157 ArrayRef<Value> 1158 transform::TransformResults::getValues(unsigned resultNumber) const { 1159 assert(resultNumber < values.size() && 1160 "querying values for a non-existent handle"); 1161 assert(values[resultNumber].data() != nullptr && 1162 "querying unset values (ops or params expected?)"); 1163 return values[resultNumber]; 1164 } 1165 1166 bool transform::TransformResults::isParam(unsigned resultNumber) const { 1167 assert(resultNumber < params.size() && 1168 "querying association for a non-existent handle"); 1169 return params[resultNumber].data() != nullptr; 1170 } 1171 1172 bool transform::TransformResults::isValue(unsigned resultNumber) const { 1173 assert(resultNumber < values.size() && 1174 "querying association for a non-existent handle"); 1175 return values[resultNumber].data() != nullptr; 1176 } 1177 1178 bool transform::TransformResults::isSet(unsigned resultNumber) const { 1179 assert(resultNumber < params.size() && 1180 "querying association for a non-existent handle"); 1181 return params[resultNumber].data() != nullptr || 1182 operations[resultNumber].data() != nullptr || 1183 values[resultNumber].data() != nullptr; 1184 } 1185 1186 //===----------------------------------------------------------------------===// 1187 // TrackingListener 1188 //===----------------------------------------------------------------------===// 1189 1190 transform::TrackingListener::TrackingListener(TransformState &state, 1191 TransformOpInterface op, 1192 TrackingListenerConfig config) 1193 : TransformState::Extension(state), transformOp(op), config(config) { 1194 if (op) { 1195 for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) { 1196 consumedHandles.insert(opOperand->get()); 1197 } 1198 } 1199 } 1200 1201 Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) { 1202 Operation *defOp = nullptr; 1203 for (Value v : values) { 1204 // Skip empty values. 1205 if (!v) 1206 continue; 1207 if (!defOp) { 1208 defOp = v.getDefiningOp(); 1209 continue; 1210 } 1211 if (defOp != v.getDefiningOp()) 1212 return nullptr; 1213 } 1214 return defOp; 1215 } 1216 1217 DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp( 1218 Operation *&result, Operation *op, ValueRange newValues) const { 1219 assert(op->getNumResults() == newValues.size() && 1220 "invalid number of replacement values"); 1221 SmallVector<Value> values(newValues.begin(), newValues.end()); 1222 1223 DiagnosedSilenceableFailure diag = emitSilenceableFailure( 1224 getTransformOp(), "tracking listener failed to find replacement op " 1225 "during application of this transform op"); 1226 1227 do { 1228 // If the replacement values belong to different ops, drop the mapping. 1229 Operation *defOp = getCommonDefiningOp(values); 1230 if (!defOp) { 1231 diag.attachNote() << "replacement values belong to different ops"; 1232 return diag; 1233 } 1234 1235 // Skip through ops that implement CastOpInterface. 1236 if (config.skipCastOps && isa<CastOpInterface>(defOp)) { 1237 values.clear(); 1238 values.assign(defOp->getOperands().begin(), defOp->getOperands().end()); 1239 diag.attachNote(defOp->getLoc()) 1240 << "using output of 'CastOpInterface' op"; 1241 continue; 1242 } 1243 1244 // If the defining op has the same name or we do not care about the name of 1245 // op replacements at all, we take it as a replacement. 1246 if (!config.requireMatchingReplacementOpName || 1247 op->getName() == defOp->getName()) { 1248 result = defOp; 1249 return DiagnosedSilenceableFailure::success(); 1250 } 1251 1252 // Replacing an op with a constant-like equivalent is a common 1253 // canonicalization. 1254 if (defOp->hasTrait<OpTrait::ConstantLike>()) { 1255 result = defOp; 1256 return DiagnosedSilenceableFailure::success(); 1257 } 1258 1259 values.clear(); 1260 1261 // Skip through ops that implement FindPayloadReplacementOpInterface. 1262 if (auto findReplacementOpInterface = 1263 dyn_cast<FindPayloadReplacementOpInterface>(defOp)) { 1264 values.assign(findReplacementOpInterface.getNextOperands()); 1265 diag.attachNote(defOp->getLoc()) << "using operands provided by " 1266 "'FindPayloadReplacementOpInterface'"; 1267 continue; 1268 } 1269 } while (!values.empty()); 1270 1271 diag.attachNote() << "ran out of suitable replacement values"; 1272 return diag; 1273 } 1274 1275 void transform::TrackingListener::notifyMatchFailure( 1276 Location loc, function_ref<void(Diagnostic &)> reasonCallback) { 1277 LLVM_DEBUG({ 1278 Diagnostic diag(loc, DiagnosticSeverity::Remark); 1279 reasonCallback(diag); 1280 DBGS() << "Match Failure : " << diag.str() << "\n"; 1281 }); 1282 } 1283 1284 void transform::TrackingListener::notifyOperationErased(Operation *op) { 1285 // Remove mappings for result values. 1286 for (OpResult value : op->getResults()) 1287 (void)replacePayloadValue(value, nullptr); 1288 // Remove mapping for op. 1289 (void)replacePayloadOp(op, nullptr); 1290 } 1291 1292 void transform::TrackingListener::notifyOperationReplaced( 1293 Operation *op, ValueRange newValues) { 1294 assert(op->getNumResults() == newValues.size() && 1295 "invalid number of replacement values"); 1296 1297 // Replace value handles. 1298 for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues)) 1299 (void)replacePayloadValue(oldValue, newValue); 1300 1301 // Replace op handle. 1302 SmallVector<Value> opHandles; 1303 if (failed(getTransformState().getHandlesForPayloadOp( 1304 op, opHandles, /*includeOutOfScope=*/true))) { 1305 // Op is not tracked. 1306 return; 1307 } 1308 1309 // Helper function to check if the current transform op consumes any handle 1310 // that is mapped to `op`. 1311 // 1312 // Note: If a handle was consumed, there shouldn't be any alive users, so it 1313 // is not really necessary to check for consumed handles. However, in case 1314 // there are indeed alive handles that were consumed (which is undefined 1315 // behavior) and a replacement op could not be found, we want to fail with a 1316 // nicer error message: "op uses a handle invalidated..." instead of "could 1317 // not find replacement op". This nicer error is produced later. 1318 auto handleWasConsumed = [&] { 1319 return llvm::any_of(opHandles, 1320 [&](Value h) { return consumedHandles.contains(h); }); 1321 }; 1322 1323 // Check if there are any handles that must be updated. 1324 Value aliveHandle; 1325 if (config.skipHandleFn) { 1326 auto it = llvm::find_if(opHandles, 1327 [&](Value v) { return !config.skipHandleFn(v); }); 1328 if (it != opHandles.end()) 1329 aliveHandle = *it; 1330 } else if (!opHandles.empty()) { 1331 aliveHandle = opHandles.front(); 1332 } 1333 if (!aliveHandle || handleWasConsumed()) { 1334 // The op is tracked but the corresponding handles are dead or were 1335 // consumed. Drop the op form the mapping. 1336 (void)replacePayloadOp(op, nullptr); 1337 return; 1338 } 1339 1340 Operation *replacement; 1341 DiagnosedSilenceableFailure diag = 1342 findReplacementOp(replacement, op, newValues); 1343 // If the op is tracked but no replacement op was found, send a 1344 // notification. 1345 if (!diag.succeeded()) { 1346 diag.attachNote(aliveHandle.getLoc()) 1347 << "replacement is required because this handle must be updated"; 1348 notifyPayloadReplacementNotFound(op, newValues, std::move(diag)); 1349 (void)replacePayloadOp(op, nullptr); 1350 return; 1351 } 1352 1353 (void)replacePayloadOp(op, replacement); 1354 } 1355 1356 transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() { 1357 // The state of the ErrorCheckingTrackingListener must be checked and reset 1358 // if there was an error. This is to prevent errors from accidentally being 1359 // missed. 1360 assert(status.succeeded() && "listener state was not checked"); 1361 } 1362 1363 DiagnosedSilenceableFailure 1364 transform::ErrorCheckingTrackingListener::checkAndResetError() { 1365 DiagnosedSilenceableFailure s = std::move(status); 1366 status = DiagnosedSilenceableFailure::success(); 1367 errorCounter = 0; 1368 return s; 1369 } 1370 1371 bool transform::ErrorCheckingTrackingListener::failed() const { 1372 return !status.succeeded(); 1373 } 1374 1375 void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound( 1376 Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) { 1377 1378 // Merge potentially existing diags and store the result in the listener. 1379 SmallVector<Diagnostic> diags; 1380 diag.takeDiagnostics(diags); 1381 if (!status.succeeded()) 1382 status.takeDiagnostics(diags); 1383 status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags)); 1384 1385 // Report more details. 1386 status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op"; 1387 for (auto &&[index, value] : llvm::enumerate(values)) 1388 status.attachNote(value.getLoc()) 1389 << "[" << errorCounter << "] replacement value " << index; 1390 ++errorCounter; 1391 } 1392 1393 //===----------------------------------------------------------------------===// 1394 // TransformRewriter 1395 //===----------------------------------------------------------------------===// 1396 1397 transform::TransformRewriter::TransformRewriter( 1398 MLIRContext *ctx, ErrorCheckingTrackingListener *listener) 1399 : RewriterBase(ctx), listener(listener) { 1400 setListener(listener); 1401 } 1402 1403 bool transform::TransformRewriter::hasTrackingFailures() const { 1404 return listener->failed(); 1405 } 1406 1407 /// Silence all tracking failures that have been encountered so far. 1408 void transform::TransformRewriter::silenceTrackingFailure() { 1409 if (hasTrackingFailures()) { 1410 DiagnosedSilenceableFailure status = listener->checkAndResetError(); 1411 (void)status.silence(); 1412 } 1413 } 1414 1415 LogicalResult transform::TransformRewriter::notifyPayloadOperationReplaced( 1416 Operation *op, Operation *replacement) { 1417 return listener->replacePayloadOp(op, replacement); 1418 } 1419 1420 //===----------------------------------------------------------------------===// 1421 // Utilities for TransformEachOpTrait. 1422 //===----------------------------------------------------------------------===// 1423 1424 LogicalResult 1425 transform::detail::checkNestedConsumption(Location loc, 1426 ArrayRef<Operation *> targets) { 1427 for (auto &&[position, parent] : llvm::enumerate(targets)) { 1428 for (Operation *child : targets.drop_front(position + 1)) { 1429 if (parent->isAncestor(child)) { 1430 InFlightDiagnostic diag = 1431 emitError(loc) 1432 << "transform operation consumes a handle pointing to an ancestor " 1433 "payload operation before its descendant"; 1434 diag.attachNote() 1435 << "the ancestor is likely erased or rewritten before the " 1436 "descendant is accessed, leading to undefined behavior"; 1437 diag.attachNote(parent->getLoc()) << "ancestor payload op"; 1438 diag.attachNote(child->getLoc()) << "descendant payload op"; 1439 return diag; 1440 } 1441 } 1442 } 1443 return success(); 1444 } 1445 1446 LogicalResult 1447 transform::detail::checkApplyToOne(Operation *transformOp, 1448 Location payloadOpLoc, 1449 const ApplyToEachResultList &partialResult) { 1450 Location transformOpLoc = transformOp->getLoc(); 1451 StringRef transformOpName = transformOp->getName().getStringRef(); 1452 unsigned expectedNumResults = transformOp->getNumResults(); 1453 1454 // Reuse the emission of the diagnostic note. 1455 auto emitDiag = [&]() { 1456 auto diag = mlir::emitError(transformOpLoc); 1457 diag.attachNote(payloadOpLoc) << "when applied to this op"; 1458 return diag; 1459 }; 1460 1461 if (partialResult.size() != expectedNumResults) { 1462 auto diag = emitDiag() << "application of " << transformOpName 1463 << " expected to produce " << expectedNumResults 1464 << " results (actually produced " 1465 << partialResult.size() << ")."; 1466 diag.attachNote(transformOpLoc) 1467 << "if you need variadic results, consider a generic `apply` " 1468 << "instead of the specialized `applyToOne`."; 1469 return failure(); 1470 } 1471 1472 // Check that the right kind of value was produced. 1473 for (const auto &[ptr, res] : 1474 llvm::zip(partialResult, transformOp->getResults())) { 1475 if (ptr.isNull()) 1476 continue; 1477 if (llvm::isa<TransformHandleTypeInterface>(res.getType()) && 1478 !isa<Operation *>(ptr)) { 1479 return emitDiag() << "application of " << transformOpName 1480 << " expected to produce an Operation * for result #" 1481 << res.getResultNumber(); 1482 } 1483 if (llvm::isa<TransformParamTypeInterface>(res.getType()) && 1484 !isa<Attribute>(ptr)) { 1485 return emitDiag() << "application of " << transformOpName 1486 << " expected to produce an Attribute for result #" 1487 << res.getResultNumber(); 1488 } 1489 if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) && 1490 !isa<Value>(ptr)) { 1491 return emitDiag() << "application of " << transformOpName 1492 << " expected to produce a Value for result #" 1493 << res.getResultNumber(); 1494 } 1495 } 1496 return success(); 1497 } 1498 1499 template <typename T> 1500 static SmallVector<T> castVector(ArrayRef<transform::MappedValue> range) { 1501 return llvm::to_vector(llvm::map_range( 1502 range, [](transform::MappedValue value) { return cast<T>(value); })); 1503 } 1504 1505 void transform::detail::setApplyToOneResults( 1506 Operation *transformOp, TransformResults &transformResults, 1507 ArrayRef<ApplyToEachResultList> results) { 1508 SmallVector<SmallVector<MappedValue>> transposed; 1509 transposed.resize(transformOp->getNumResults()); 1510 for (const ApplyToEachResultList &partialResults : results) { 1511 if (llvm::any_of(partialResults, 1512 [](MappedValue value) { return value.isNull(); })) 1513 continue; 1514 assert(transformOp->getNumResults() == partialResults.size() && 1515 "expected as many partial results as op as results"); 1516 for (auto [i, value] : llvm::enumerate(partialResults)) 1517 transposed[i].push_back(value); 1518 } 1519 1520 for (OpResult r : transformOp->getResults()) { 1521 unsigned position = r.getResultNumber(); 1522 if (llvm::isa<TransformParamTypeInterface>(r.getType())) { 1523 transformResults.setParams(r, 1524 castVector<Attribute>(transposed[position])); 1525 } else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) { 1526 transformResults.setValues(r, castVector<Value>(transposed[position])); 1527 } else { 1528 transformResults.set(r, castVector<Operation *>(transposed[position])); 1529 } 1530 } 1531 } 1532 1533 //===----------------------------------------------------------------------===// 1534 // Utilities for implementing transform ops with regions. 1535 //===----------------------------------------------------------------------===// 1536 1537 LogicalResult transform::detail::appendValueMappings( 1538 MutableArrayRef<SmallVector<transform::MappedValue>> mappings, 1539 ValueRange values, const transform::TransformState &state, bool flatten) { 1540 assert(mappings.size() == values.size() && "mismatching number of mappings"); 1541 for (auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) { 1542 size_t mappedSize = mapped.size(); 1543 if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) { 1544 llvm::append_range(mapped, state.getPayloadOps(operand)); 1545 } else if (llvm::isa<TransformValueHandleTypeInterface>( 1546 operand.getType())) { 1547 llvm::append_range(mapped, state.getPayloadValues(operand)); 1548 } else { 1549 assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) && 1550 "unsupported kind of transform dialect value"); 1551 llvm::append_range(mapped, state.getParams(operand)); 1552 } 1553 1554 if (mapped.size() - mappedSize != 1 && !flatten) 1555 return failure(); 1556 } 1557 return success(); 1558 } 1559 1560 void transform::detail::prepareValueMappings( 1561 SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings, 1562 ValueRange values, const transform::TransformState &state) { 1563 mappings.resize(mappings.size() + values.size()); 1564 (void)appendValueMappings( 1565 MutableArrayRef<SmallVector<transform::MappedValue>>(mappings).take_back( 1566 values.size()), 1567 values, state); 1568 } 1569 1570 void transform::detail::forwardTerminatorOperands( 1571 Block *block, transform::TransformState &state, 1572 transform::TransformResults &results) { 1573 for (auto &&[terminatorOperand, result] : 1574 llvm::zip(block->getTerminator()->getOperands(), 1575 block->getParentOp()->getOpResults())) { 1576 if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) { 1577 results.set(result, state.getPayloadOps(terminatorOperand)); 1578 } else if (llvm::isa<transform::TransformValueHandleTypeInterface>( 1579 result.getType())) { 1580 results.setValues(result, state.getPayloadValues(terminatorOperand)); 1581 } else { 1582 assert( 1583 llvm::isa<transform::TransformParamTypeInterface>(result.getType()) && 1584 "unhandled transform type interface"); 1585 results.setParams(result, state.getParams(terminatorOperand)); 1586 } 1587 } 1588 } 1589 1590 transform::TransformState 1591 transform::detail::makeTransformStateForTesting(Region *region, 1592 Operation *payloadRoot) { 1593 return TransformState(region, payloadRoot); 1594 } 1595 1596 //===----------------------------------------------------------------------===// 1597 // Utilities for PossibleTopLevelTransformOpTrait. 1598 //===----------------------------------------------------------------------===// 1599 1600 /// Appends to `effects` the memory effect instances on `target` with the same 1601 /// resource and effect as the ones the operation `iface` having on `source`. 1602 static void 1603 remapEffects(MemoryEffectOpInterface iface, BlockArgument source, 1604 OpOperand *target, 1605 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1606 SmallVector<MemoryEffects::EffectInstance> nestedEffects; 1607 iface.getEffectsOnValue(source, nestedEffects); 1608 for (const auto &effect : nestedEffects) 1609 effects.emplace_back(effect.getEffect(), target, effect.getResource()); 1610 } 1611 1612 /// Appends to `effects` the same effects as the operations of `block` have on 1613 /// block arguments but associated with `operands.` 1614 static void 1615 remapArgumentEffects(Block &block, MutableArrayRef<OpOperand> operands, 1616 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1617 for (Operation &op : block) { 1618 auto iface = dyn_cast<MemoryEffectOpInterface>(&op); 1619 if (!iface) 1620 continue; 1621 1622 for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) { 1623 remapEffects(iface, source, &target, effects); 1624 } 1625 1626 SmallVector<MemoryEffects::EffectInstance> nestedEffects; 1627 iface.getEffectsOnResource(transform::PayloadIRResource::get(), 1628 nestedEffects); 1629 llvm::append_range(effects, nestedEffects); 1630 } 1631 } 1632 1633 void transform::detail::getPotentialTopLevelEffects( 1634 Operation *operation, Value root, Block &body, 1635 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1636 transform::onlyReadsHandle(operation->getOpOperands(), effects); 1637 transform::producesHandle(operation->getOpResults(), effects); 1638 1639 if (!root) { 1640 for (Operation &op : body) { 1641 auto iface = dyn_cast<MemoryEffectOpInterface>(&op); 1642 if (!iface) 1643 continue; 1644 1645 SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects; 1646 iface.getEffects(effects); 1647 } 1648 return; 1649 } 1650 1651 // Carry over all effects on arguments of the entry block as those on the 1652 // operands, this is the same value just remapped. 1653 remapArgumentEffects(body, operation->getOpOperands(), effects); 1654 } 1655 1656 LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( 1657 TransformState &state, Operation *op, Region ®ion) { 1658 SmallVector<Operation *> targets; 1659 SmallVector<SmallVector<MappedValue>> extraMappings; 1660 if (op->getNumOperands() != 0) { 1661 llvm::append_range(targets, state.getPayloadOps(op->getOperand(0))); 1662 prepareValueMappings(extraMappings, op->getOperands().drop_front(), state); 1663 } else { 1664 if (state.getNumTopLevelMappings() != 1665 region.front().getNumArguments() - 1) { 1666 return emitError(op->getLoc()) 1667 << "operation expects " << region.front().getNumArguments() - 1 1668 << " extra value bindings, but " << state.getNumTopLevelMappings() 1669 << " were provided to the interpreter"; 1670 } 1671 1672 targets.push_back(state.getTopLevel()); 1673 1674 for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i) 1675 extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i))); 1676 } 1677 1678 if (failed(state.mapBlockArguments(region.front().getArgument(0), targets))) 1679 return failure(); 1680 1681 for (BlockArgument argument : region.front().getArguments().drop_front()) { 1682 if (failed(state.mapBlockArgument( 1683 argument, extraMappings[argument.getArgNumber() - 1]))) 1684 return failure(); 1685 } 1686 1687 return success(); 1688 } 1689 1690 LogicalResult 1691 transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) { 1692 // Attaching this trait without the interface is a misuse of the API, but it 1693 // cannot be caught via a static_assert because interface registration is 1694 // dynamic. 1695 assert(isa<TransformOpInterface>(op) && 1696 "should implement TransformOpInterface to have " 1697 "PossibleTopLevelTransformOpTrait"); 1698 1699 if (op->getNumRegions() < 1) 1700 return op->emitOpError() << "expects at least one region"; 1701 1702 Region *bodyRegion = &op->getRegion(0); 1703 if (!llvm::hasNItems(*bodyRegion, 1)) 1704 return op->emitOpError() << "expects a single-block region"; 1705 1706 Block *body = &bodyRegion->front(); 1707 if (body->getNumArguments() == 0) { 1708 return op->emitOpError() 1709 << "expects the entry block to have at least one argument"; 1710 } 1711 if (!llvm::isa<TransformHandleTypeInterface>( 1712 body->getArgument(0).getType())) { 1713 return op->emitOpError() 1714 << "expects the first entry block argument to be of type " 1715 "implementing TransformHandleTypeInterface"; 1716 } 1717 BlockArgument arg = body->getArgument(0); 1718 if (op->getNumOperands() != 0) { 1719 if (arg.getType() != op->getOperand(0).getType()) { 1720 return op->emitOpError() 1721 << "expects the type of the block argument to match " 1722 "the type of the operand"; 1723 } 1724 } 1725 for (BlockArgument arg : body->getArguments().drop_front()) { 1726 if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface, 1727 TransformValueHandleTypeInterface>(arg.getType())) 1728 continue; 1729 1730 InFlightDiagnostic diag = 1731 op->emitOpError() 1732 << "expects trailing entry block arguments to be of type implementing " 1733 "TransformHandleTypeInterface, TransformValueHandleTypeInterface or " 1734 "TransformParamTypeInterface"; 1735 diag.attachNote() << "argument #" << arg.getArgNumber() << " does not"; 1736 return diag; 1737 } 1738 1739 if (auto *parent = 1740 op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) { 1741 if (op->getNumOperands() != body->getNumArguments()) { 1742 InFlightDiagnostic diag = 1743 op->emitOpError() 1744 << "expects operands to be provided for a nested op"; 1745 diag.attachNote(parent->getLoc()) 1746 << "nested in another possible top-level op"; 1747 return diag; 1748 } 1749 } 1750 1751 return success(); 1752 } 1753 1754 //===----------------------------------------------------------------------===// 1755 // Utilities for ParamProducedTransformOpTrait. 1756 //===----------------------------------------------------------------------===// 1757 1758 void transform::detail::getParamProducerTransformOpTraitEffects( 1759 Operation *op, SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1760 producesHandle(op->getResults(), effects); 1761 bool hasPayloadOperands = false; 1762 for (OpOperand &operand : op->getOpOperands()) { 1763 onlyReadsHandle(operand, effects); 1764 if (llvm::isa<TransformHandleTypeInterface, 1765 TransformValueHandleTypeInterface>(operand.get().getType())) 1766 hasPayloadOperands = true; 1767 } 1768 if (hasPayloadOperands) 1769 onlyReadsPayload(effects); 1770 } 1771 1772 LogicalResult 1773 transform::detail::verifyParamProducerTransformOpTrait(Operation *op) { 1774 // Interfaces can be attached dynamically, so this cannot be a static 1775 // assert. 1776 if (!op->getName().getInterface<MemoryEffectOpInterface>()) { 1777 llvm::report_fatal_error( 1778 Twine("ParamProducerTransformOpTrait must be attached to an op that " 1779 "implements MemoryEffectsOpInterface, found on ") + 1780 op->getName().getStringRef()); 1781 } 1782 for (Value result : op->getResults()) { 1783 if (llvm::isa<TransformParamTypeInterface>(result.getType())) 1784 continue; 1785 return op->emitOpError() 1786 << "ParamProducerTransformOpTrait attached to this op expects " 1787 "result types to implement TransformParamTypeInterface"; 1788 } 1789 return success(); 1790 } 1791 1792 //===----------------------------------------------------------------------===// 1793 // Memory effects. 1794 //===----------------------------------------------------------------------===// 1795 1796 void transform::consumesHandle( 1797 MutableArrayRef<OpOperand> handles, 1798 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1799 for (OpOperand &handle : handles) { 1800 effects.emplace_back(MemoryEffects::Read::get(), &handle, 1801 TransformMappingResource::get()); 1802 effects.emplace_back(MemoryEffects::Free::get(), &handle, 1803 TransformMappingResource::get()); 1804 } 1805 } 1806 1807 /// Returns `true` if the given list of effects instances contains an instance 1808 /// with the effect type specified as template parameter. 1809 template <typename EffectTy, typename ResourceTy, typename Range> 1810 static bool hasEffect(Range &&effects) { 1811 return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { 1812 return isa<EffectTy>(effect.getEffect()) && 1813 isa<ResourceTy>(effect.getResource()); 1814 }); 1815 } 1816 1817 bool transform::isHandleConsumed(Value handle, 1818 transform::TransformOpInterface transform) { 1819 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation()); 1820 SmallVector<MemoryEffects::EffectInstance> effects; 1821 iface.getEffectsOnValue(handle, effects); 1822 return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) && 1823 ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects); 1824 } 1825 1826 void transform::producesHandle( 1827 ResultRange handles, 1828 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1829 for (OpResult handle : handles) { 1830 effects.emplace_back(MemoryEffects::Allocate::get(), handle, 1831 TransformMappingResource::get()); 1832 effects.emplace_back(MemoryEffects::Write::get(), handle, 1833 TransformMappingResource::get()); 1834 } 1835 } 1836 1837 void transform::producesHandle( 1838 MutableArrayRef<BlockArgument> handles, 1839 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1840 for (BlockArgument handle : handles) { 1841 effects.emplace_back(MemoryEffects::Allocate::get(), handle, 1842 TransformMappingResource::get()); 1843 effects.emplace_back(MemoryEffects::Write::get(), handle, 1844 TransformMappingResource::get()); 1845 } 1846 } 1847 1848 void transform::onlyReadsHandle( 1849 MutableArrayRef<OpOperand> handles, 1850 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1851 for (OpOperand &handle : handles) { 1852 effects.emplace_back(MemoryEffects::Read::get(), &handle, 1853 TransformMappingResource::get()); 1854 } 1855 } 1856 1857 void transform::modifiesPayload( 1858 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1859 effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); 1860 effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get()); 1861 } 1862 1863 void transform::onlyReadsPayload( 1864 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 1865 effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); 1866 } 1867 1868 bool transform::doesModifyPayload(transform::TransformOpInterface transform) { 1869 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation()); 1870 SmallVector<MemoryEffects::EffectInstance> effects; 1871 iface.getEffects(effects); 1872 return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects); 1873 } 1874 1875 bool transform::doesReadPayload(transform::TransformOpInterface transform) { 1876 auto iface = cast<MemoryEffectOpInterface>(transform.getOperation()); 1877 SmallVector<MemoryEffects::EffectInstance> effects; 1878 iface.getEffects(effects); 1879 return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects); 1880 } 1881 1882 void transform::getConsumedBlockArguments( 1883 Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) { 1884 SmallVector<MemoryEffects::EffectInstance> effects; 1885 for (Operation &nested : block) { 1886 auto iface = dyn_cast<MemoryEffectOpInterface>(nested); 1887 if (!iface) 1888 continue; 1889 1890 effects.clear(); 1891 iface.getEffects(effects); 1892 for (const MemoryEffects::EffectInstance &effect : effects) { 1893 BlockArgument argument = 1894 dyn_cast_or_null<BlockArgument>(effect.getValue()); 1895 if (!argument || argument.getOwner() != &block || 1896 !isa<MemoryEffects::Free>(effect.getEffect()) || 1897 effect.getResource() != transform::TransformMappingResource::get()) { 1898 continue; 1899 } 1900 consumedArguments.insert(argument.getArgNumber()); 1901 } 1902 } 1903 } 1904 1905 //===----------------------------------------------------------------------===// 1906 // Utilities for TransformOpInterface. 1907 //===----------------------------------------------------------------------===// 1908 1909 SmallVector<OpOperand *> transform::detail::getConsumedHandleOpOperands( 1910 TransformOpInterface transformOp) { 1911 SmallVector<OpOperand *> consumedOperands; 1912 consumedOperands.reserve(transformOp->getNumOperands()); 1913 auto memEffectInterface = 1914 cast<MemoryEffectOpInterface>(transformOp.getOperation()); 1915 SmallVector<MemoryEffects::EffectInstance, 2> effects; 1916 for (OpOperand &target : transformOp->getOpOperands()) { 1917 effects.clear(); 1918 memEffectInterface.getEffectsOnValue(target.get(), effects); 1919 if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { 1920 return isa<transform::TransformMappingResource>( 1921 effect.getResource()) && 1922 isa<MemoryEffects::Free>(effect.getEffect()); 1923 })) { 1924 consumedOperands.push_back(&target); 1925 } 1926 } 1927 return consumedOperands; 1928 } 1929 1930 LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) { 1931 auto iface = cast<MemoryEffectOpInterface>(op); 1932 SmallVector<MemoryEffects::EffectInstance> effects; 1933 iface.getEffects(effects); 1934 1935 auto effectsOn = [&](Value value) { 1936 return llvm::make_filter_range( 1937 effects, [value](const MemoryEffects::EffectInstance &instance) { 1938 return instance.getValue() == value; 1939 }); 1940 }; 1941 1942 std::optional<unsigned> firstConsumedOperand; 1943 for (OpOperand &operand : op->getOpOperands()) { 1944 auto range = effectsOn(operand.get()); 1945 if (range.empty()) { 1946 InFlightDiagnostic diag = 1947 op->emitError() << "TransformOpInterface requires memory effects " 1948 "on operands to be specified"; 1949 diag.attachNote() << "no effects specified for operand #" 1950 << operand.getOperandNumber(); 1951 return diag; 1952 } 1953 if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(range)) { 1954 InFlightDiagnostic diag = op->emitError() 1955 << "TransformOpInterface did not expect " 1956 "'allocate' memory effect on an operand"; 1957 diag.attachNote() << "specified for operand #" 1958 << operand.getOperandNumber(); 1959 return diag; 1960 } 1961 if (!firstConsumedOperand && 1962 ::hasEffect<MemoryEffects::Free, TransformMappingResource>(range)) { 1963 firstConsumedOperand = operand.getOperandNumber(); 1964 } 1965 } 1966 1967 if (firstConsumedOperand && 1968 !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) { 1969 InFlightDiagnostic diag = 1970 op->emitError() 1971 << "TransformOpInterface expects ops consuming operands to have a " 1972 "'write' effect on the payload resource"; 1973 diag.attachNote() << "consumes operand #" << *firstConsumedOperand; 1974 return diag; 1975 } 1976 1977 for (OpResult result : op->getResults()) { 1978 auto range = effectsOn(result); 1979 if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>( 1980 range)) { 1981 InFlightDiagnostic diag = 1982 op->emitError() << "TransformOpInterface requires 'allocate' memory " 1983 "effect to be specified for results"; 1984 diag.attachNote() << "no 'allocate' effect specified for result #" 1985 << result.getResultNumber(); 1986 return diag; 1987 } 1988 } 1989 1990 return success(); 1991 } 1992 1993 //===----------------------------------------------------------------------===// 1994 // Entry point. 1995 //===----------------------------------------------------------------------===// 1996 1997 LogicalResult transform::applyTransforms( 1998 Operation *payloadRoot, TransformOpInterface transform, 1999 const RaggedArray<MappedValue> &extraMapping, 2000 const TransformOptions &options, bool enforceToplevelTransformOp, 2001 function_ref<void(TransformState &)> stateInitializer, 2002 function_ref<LogicalResult(TransformState &)> stateExporter) { 2003 if (enforceToplevelTransformOp) { 2004 if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() || 2005 transform->getNumOperands() != 0) { 2006 return transform->emitError() 2007 << "expected transform to start at the top-level transform op"; 2008 } 2009 } else if (failed( 2010 detail::verifyPossibleTopLevelTransformOpTrait(transform))) { 2011 return failure(); 2012 } 2013 2014 TransformState state(transform->getParentRegion(), payloadRoot, extraMapping, 2015 options); 2016 if (stateInitializer) 2017 stateInitializer(state); 2018 if (state.applyTransform(transform).checkAndReport().failed()) 2019 return failure(); 2020 if (stateExporter) 2021 return stateExporter(state); 2022 return success(); 2023 } 2024 2025 //===----------------------------------------------------------------------===// 2026 // Generated interface implementation. 2027 //===----------------------------------------------------------------------===// 2028 2029 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc" 2030 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc" 2031