xref: /llvm-project/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp (revision 129f1001c3b1b5200de43917d53c0efbdf08f11f)
1 //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
10 
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/Operation.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/Interfaces/CastInterfaces.h"
15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/ScopeExit.h"
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/ErrorHandling.h"
20 
21 #define DEBUG_TYPE "transform-dialect"
22 #define DEBUG_TYPE_FULL "transform-dialect-full"
23 #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
25 #define LDBG(X) LLVM_DEBUG(DBGS() << (X))
26 #define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X)))
27 
28 using namespace mlir;
29 
30 //===----------------------------------------------------------------------===//
31 // Helper functions
32 //===----------------------------------------------------------------------===//
33 
34 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
35 /// properly dominates `b` and `b` is not inside `a`.
36 static bool happensBefore(Operation *a, Operation *b) {
37   do {
38     if (a->isProperAncestor(b))
39       return false;
40     if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
41       return a->isBeforeInBlock(bAncestor);
42     }
43   } while ((a = a->getParentOp()));
44   return false;
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // TransformState
49 //===----------------------------------------------------------------------===//
50 
51 constexpr const Value transform::TransformState::kTopLevelValue;
52 
53 transform::TransformState::TransformState(
54     Region *region, Operation *payloadRoot,
55     const RaggedArray<MappedValue> &extraMappings,
56     const TransformOptions &options)
57     : topLevel(payloadRoot), options(options) {
58   topLevelMappedValues.reserve(extraMappings.size());
59   for (ArrayRef<MappedValue> mapping : extraMappings)
60     topLevelMappedValues.push_back(mapping);
61   if (region) {
62     RegionScope *scope = new RegionScope(*this, *region);
63     topLevelRegionScope.reset(scope);
64   }
65 }
66 
67 Operation *transform::TransformState::getTopLevel() const { return topLevel; }
68 
69 ArrayRef<Operation *>
70 transform::TransformState::getPayloadOpsView(Value value) const {
71   const TransformOpMapping &operationMapping = getMapping(value).direct;
72   auto iter = operationMapping.find(value);
73   assert(iter != operationMapping.end() &&
74          "cannot find mapping for payload handle (param/value handle "
75          "provided?)");
76   return iter->getSecond();
77 }
78 
79 ArrayRef<Attribute> transform::TransformState::getParams(Value value) const {
80   const ParamMapping &mapping = getMapping(value).params;
81   auto iter = mapping.find(value);
82   assert(iter != mapping.end() && "cannot find mapping for param handle "
83                                   "(operation/value handle provided?)");
84   return iter->getSecond();
85 }
86 
87 ArrayRef<Value>
88 transform::TransformState::getPayloadValuesView(Value handleValue) const {
89   const ValueMapping &mapping = getMapping(handleValue).values;
90   auto iter = mapping.find(handleValue);
91   assert(iter != mapping.end() && "cannot find mapping for value handle "
92                                   "(param/operation handle provided?)");
93   return iter->getSecond();
94 }
95 
96 LogicalResult transform::TransformState::getHandlesForPayloadOp(
97     Operation *op, SmallVectorImpl<Value> &handles,
98     bool includeOutOfScope) const {
99   bool found = false;
100   for (const auto &[region, mapping] : llvm::reverse(mappings)) {
101     auto iterator = mapping->reverse.find(op);
102     if (iterator != mapping->reverse.end()) {
103       llvm::append_range(handles, iterator->getSecond());
104       found = true;
105     }
106     // Stop looking when reaching a region that is isolated from above.
107     if (!includeOutOfScope &&
108         region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
109       break;
110   }
111 
112   return success(found);
113 }
114 
115 LogicalResult transform::TransformState::getHandlesForPayloadValue(
116     Value payloadValue, SmallVectorImpl<Value> &handles,
117     bool includeOutOfScope) const {
118   bool found = false;
119   for (const auto &[region, mapping] : llvm::reverse(mappings)) {
120     auto iterator = mapping->reverseValues.find(payloadValue);
121     if (iterator != mapping->reverseValues.end()) {
122       llvm::append_range(handles, iterator->getSecond());
123       found = true;
124     }
125     // Stop looking when reaching a region that is isolated from above.
126     if (!includeOutOfScope &&
127         region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
128       break;
129   }
130 
131   return success(found);
132 }
133 
134 /// Given a list of MappedValues, cast them to the value kind implied by the
135 /// interface of the handle type, and dispatch to one of the callbacks.
136 static DiagnosedSilenceableFailure dispatchMappedValues(
137     Value handle, ArrayRef<transform::MappedValue> values,
138     function_ref<LogicalResult(ArrayRef<Operation *>)> operationsFn,
139     function_ref<LogicalResult(ArrayRef<transform::Param>)> paramsFn,
140     function_ref<LogicalResult(ValueRange)> valuesFn) {
141   if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) {
142     SmallVector<Operation *> operations;
143     operations.reserve(values.size());
144     for (transform::MappedValue value : values) {
145       if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
146         operations.push_back(op);
147         continue;
148       }
149       return emitSilenceableFailure(handle.getLoc())
150              << "wrong kind of value provided for top-level operation handle";
151     }
152     if (failed(operationsFn(operations)))
153       return DiagnosedSilenceableFailure::definiteFailure();
154     return DiagnosedSilenceableFailure::success();
155   }
156 
157   if (llvm::isa<transform::TransformValueHandleTypeInterface>(
158           handle.getType())) {
159     SmallVector<Value> payloadValues;
160     payloadValues.reserve(values.size());
161     for (transform::MappedValue value : values) {
162       if (auto v = llvm::dyn_cast_if_present<Value>(value)) {
163         payloadValues.push_back(v);
164         continue;
165       }
166       return emitSilenceableFailure(handle.getLoc())
167              << "wrong kind of value provided for the top-level value handle";
168     }
169     if (failed(valuesFn(payloadValues)))
170       return DiagnosedSilenceableFailure::definiteFailure();
171     return DiagnosedSilenceableFailure::success();
172   }
173 
174   assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) &&
175          "unsupported kind of block argument");
176   SmallVector<transform::Param> parameters;
177   parameters.reserve(values.size());
178   for (transform::MappedValue value : values) {
179     if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
180       parameters.push_back(attr);
181       continue;
182     }
183     return emitSilenceableFailure(handle.getLoc())
184            << "wrong kind of value provided for top-level parameter";
185   }
186   if (failed(paramsFn(parameters)))
187     return DiagnosedSilenceableFailure::definiteFailure();
188   return DiagnosedSilenceableFailure::success();
189 }
190 
191 LogicalResult
192 transform::TransformState::mapBlockArgument(BlockArgument argument,
193                                             ArrayRef<MappedValue> values) {
194   return dispatchMappedValues(
195              argument, values,
196              [&](ArrayRef<Operation *> operations) {
197                return setPayloadOps(argument, operations);
198              },
199              [&](ArrayRef<Param> params) {
200                return setParams(argument, params);
201              },
202              [&](ValueRange payloadValues) {
203                return setPayloadValues(argument, payloadValues);
204              })
205       .checkAndReport();
206 }
207 
208 LogicalResult transform::TransformState::mapBlockArguments(
209     Block::BlockArgListType arguments,
210     ArrayRef<SmallVector<MappedValue>> mapping) {
211   for (auto &&[argument, values] : llvm::zip_equal(arguments, mapping))
212     if (failed(mapBlockArgument(argument, values)))
213       return failure();
214   return success();
215 }
216 
217 LogicalResult
218 transform::TransformState::setPayloadOps(Value value,
219                                          ArrayRef<Operation *> targets) {
220   assert(value != kTopLevelValue &&
221          "attempting to reset the transformation root");
222   assert(llvm::isa<TransformHandleTypeInterface>(value.getType()) &&
223          "wrong handle type");
224 
225   for (Operation *target : targets) {
226     if (target)
227       continue;
228     return emitError(value.getLoc())
229            << "attempting to assign a null payload op to this transform value";
230   }
231 
232   auto iface = llvm::cast<TransformHandleTypeInterface>(value.getType());
233   DiagnosedSilenceableFailure result =
234       iface.checkPayload(value.getLoc(), targets);
235   if (failed(result.checkAndReport()))
236     return failure();
237 
238   // Setting new payload for the value without cleaning it first is a misuse of
239   // the API, assert here.
240   SmallVector<Operation *> storedTargets(targets);
241   Mappings &mappings = getMapping(value);
242   bool inserted =
243       mappings.direct.insert({value, std::move(storedTargets)}).second;
244   assert(inserted && "value is already associated with another list");
245   (void)inserted;
246 
247   for (Operation *op : targets)
248     mappings.reverse[op].push_back(value);
249 
250   return success();
251 }
252 
253 LogicalResult
254 transform::TransformState::setPayloadValues(Value handle,
255                                             ValueRange payloadValues) {
256   assert(handle != nullptr && "attempting to set params for a null value");
257   assert(llvm::isa<TransformValueHandleTypeInterface>(handle.getType()) &&
258          "wrong handle type");
259 
260   for (Value payload : payloadValues) {
261     if (payload)
262       continue;
263     return emitError(handle.getLoc()) << "attempting to assign a null payload "
264                                          "value to this transform handle";
265   }
266 
267   auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.getType());
268   SmallVector<Value> payloadValueVector = llvm::to_vector(payloadValues);
269   DiagnosedSilenceableFailure result =
270       iface.checkPayload(handle.getLoc(), payloadValueVector);
271   if (failed(result.checkAndReport()))
272     return failure();
273 
274   Mappings &mappings = getMapping(handle);
275   bool inserted =
276       mappings.values.insert({handle, std::move(payloadValueVector)}).second;
277   assert(
278       inserted &&
279       "value handle is already associated with another list of payload values");
280   (void)inserted;
281 
282   for (Value payload : payloadValues)
283     mappings.reverseValues[payload].push_back(handle);
284 
285   return success();
286 }
287 
288 LogicalResult transform::TransformState::setParams(Value value,
289                                                    ArrayRef<Param> params) {
290   assert(value != nullptr && "attempting to set params for a null value");
291 
292   for (Attribute attr : params) {
293     if (attr)
294       continue;
295     return emitError(value.getLoc())
296            << "attempting to assign a null parameter to this transform value";
297   }
298 
299   auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.getType());
300   assert(value &&
301          "cannot associate parameter with a value of non-parameter type");
302   DiagnosedSilenceableFailure result =
303       valueType.checkPayload(value.getLoc(), params);
304   if (failed(result.checkAndReport()))
305     return failure();
306 
307   Mappings &mappings = getMapping(value);
308   bool inserted =
309       mappings.params.insert({value, llvm::to_vector(params)}).second;
310   assert(inserted && "value is already associated with another list of params");
311   (void)inserted;
312   return success();
313 }
314 
315 template <typename Mapping, typename Key, typename Mapped>
316 void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
317   auto it = mapping.find(key);
318   if (it == mapping.end())
319     return;
320 
321   llvm::erase(it->getSecond(), mapped);
322   if (it->getSecond().empty())
323     mapping.erase(it);
324 }
325 
326 void transform::TransformState::forgetMapping(Value opHandle,
327                                               ValueRange origOpFlatResults,
328                                               bool allowOutOfScope) {
329   Mappings &mappings = getMapping(opHandle, allowOutOfScope);
330   for (Operation *op : mappings.direct[opHandle])
331     dropMappingEntry(mappings.reverse, op, opHandle);
332   mappings.direct.erase(opHandle);
333 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
334   // Payload IR is removed from the mapping. This invalidates the respective
335   // iterators.
336   mappings.incrementTimestamp(opHandle);
337 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
338 
339   for (Value opResult : origOpFlatResults) {
340     SmallVector<Value> resultHandles;
341     (void)getHandlesForPayloadValue(opResult, resultHandles);
342     for (Value resultHandle : resultHandles) {
343       Mappings &localMappings = getMapping(resultHandle);
344       dropMappingEntry(localMappings.values, resultHandle, opResult);
345 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
346       // Payload IR is removed from the mapping. This invalidates the respective
347       // iterators.
348       mappings.incrementTimestamp(resultHandle);
349 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
350       dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
351     }
352   }
353 }
354 
355 void transform::TransformState::forgetValueMapping(
356     Value valueHandle, ArrayRef<Operation *> payloadOperations) {
357   Mappings &mappings = getMapping(valueHandle);
358   for (Value payloadValue : mappings.reverseValues[valueHandle])
359     dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
360   mappings.values.erase(valueHandle);
361 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
362   // Payload IR is removed from the mapping. This invalidates the respective
363   // iterators.
364   mappings.incrementTimestamp(valueHandle);
365 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
366 
367   for (Operation *payloadOp : payloadOperations) {
368     SmallVector<Value> opHandles;
369     (void)getHandlesForPayloadOp(payloadOp, opHandles);
370     for (Value opHandle : opHandles) {
371       Mappings &localMappings = getMapping(opHandle);
372       dropMappingEntry(localMappings.direct, opHandle, payloadOp);
373       dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
374 
375 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
376       // Payload IR is removed from the mapping. This invalidates the respective
377       // iterators.
378       localMappings.incrementTimestamp(opHandle);
379 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
380     }
381   }
382 }
383 
384 LogicalResult
385 transform::TransformState::replacePayloadOp(Operation *op,
386                                             Operation *replacement) {
387   // TODO: consider invalidating the handles to nested objects here.
388 
389 #ifndef NDEBUG
390   for (Value opResult : op->getResults()) {
391     SmallVector<Value> valueHandles;
392     (void)getHandlesForPayloadValue(opResult, valueHandles,
393                                     /*includeOutOfScope=*/true);
394     assert(valueHandles.empty() && "expected no mapping to old results");
395   }
396 #endif // NDEBUG
397 
398   // Drop the mapping between the op and all handles that point to it. Fail if
399   // there are no handles.
400   SmallVector<Value> opHandles;
401   if (failed(getHandlesForPayloadOp(op, opHandles, /*includeOutOfScope=*/true)))
402     return failure();
403   for (Value handle : opHandles) {
404     Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
405     dropMappingEntry(mappings.reverse, op, handle);
406   }
407 
408   // Replace the pointed-to object of all handles with the replacement object.
409   // In case a payload op was erased (replacement object is nullptr), a nullptr
410   // is stored in the mapping. These nullptrs are removed after each transform.
411   // Furthermore, nullptrs are not enumerated by payload op iterators. The
412   // relative order of ops is preserved.
413   //
414   // Removing an op from the mapping would be problematic because removing an
415   // element from an array invalidates iterators; merely changing the value of
416   // elements does not.
417   for (Value handle : opHandles) {
418     Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
419     auto it = mappings.direct.find(handle);
420     if (it == mappings.direct.end())
421       continue;
422 
423     SmallVector<Operation *, 2> &association = it->getSecond();
424     // Note that an operation may be associated with the handle more than once.
425     for (Operation *&mapped : association) {
426       if (mapped == op)
427         mapped = replacement;
428     }
429 
430     if (replacement) {
431       mappings.reverse[replacement].push_back(handle);
432     } else {
433       opHandlesToCompact.insert(handle);
434     }
435   }
436 
437   return success();
438 }
439 
440 LogicalResult
441 transform::TransformState::replacePayloadValue(Value value, Value replacement) {
442   SmallVector<Value> valueHandles;
443   if (failed(getHandlesForPayloadValue(value, valueHandles,
444                                        /*includeOutOfScope=*/true)))
445     return failure();
446 
447   for (Value handle : valueHandles) {
448     Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
449     dropMappingEntry(mappings.reverseValues, value, handle);
450 
451     // If replacing with null, that is erasing the mapping, drop the mapping
452     // between the handles and the IR objects
453     if (!replacement) {
454       dropMappingEntry(mappings.values, handle, value);
455 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
456       // Payload IR is removed from the mapping. This invalidates the respective
457       // iterators.
458       mappings.incrementTimestamp(handle);
459 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
460     } else {
461       auto it = mappings.values.find(handle);
462       if (it == mappings.values.end())
463         continue;
464 
465       SmallVector<Value> &association = it->getSecond();
466       for (Value &mapped : association) {
467         if (mapped == value)
468           mapped = replacement;
469       }
470       mappings.reverseValues[replacement].push_back(handle);
471     }
472   }
473 
474   return success();
475 }
476 
477 void transform::TransformState::recordOpHandleInvalidationOne(
478     OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
479     Operation *payloadOp, Value otherHandle, Value throughValue,
480     transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
481   // If the op is associated with invalidated handle, skip the check as it
482   // may be reading invalid IR. This also ensures we report the first
483   // invalidation and not the last one.
484   if (invalidatedHandles.count(otherHandle) ||
485       newlyInvalidated.count(otherHandle))
486     return;
487 
488   FULL_LDBG("--recordOpHandleInvalidationOne\n");
489   DEBUG_WITH_TYPE(
490       DEBUG_TYPE_FULL,
491       llvm::interleaveComma(potentialAncestors, DBGS() << "--ancestors: ",
492                             [](Operation *op) { llvm::dbgs() << *op; });
493       llvm::dbgs() << "\n");
494 
495   Operation *owner = consumingHandle.getOwner();
496   unsigned operandNo = consumingHandle.getOperandNumber();
497   for (Operation *ancestor : potentialAncestors) {
498     // clang-format off
499     DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
500       { (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); });
501     DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
502       { (DBGS() << "----of payload with name: "
503                 << payloadOp->getName().getIdentifier() << "\n"); });
504     DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
505       { (DBGS() << "----of payload: " << *payloadOp << "\n"); });
506     // clang-format on
507     if (!ancestor->isAncestor(payloadOp))
508       continue;
509 
510     // Make sure the error-reporting lambda doesn't capture anything
511     // by-reference because it will go out of scope. Additionally, extract
512     // location from Payload IR ops because the ops themselves may be
513     // deleted before the lambda gets called.
514     Location ancestorLoc = ancestor->getLoc();
515     Location opLoc = payloadOp->getLoc();
516     std::optional<Location> throughValueLoc =
517         throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt;
518     newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
519                                      otherHandle,
520                                      throughValueLoc](Location currentLoc) {
521       InFlightDiagnostic diag = emitError(currentLoc)
522                                 << "op uses a handle invalidated by a "
523                                    "previously executed transform op";
524       diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops";
525       diag.attachNote(owner->getLoc())
526           << "invalidated by this transform op that consumes its operand #"
527           << operandNo
528           << " and invalidates all handles to payload IR entities associated "
529              "with this operand and entities nested in them";
530       diag.attachNote(ancestorLoc) << "ancestor payload op";
531       diag.attachNote(opLoc) << "nested payload op";
532       if (throughValueLoc) {
533         diag.attachNote(*throughValueLoc)
534             << "consumed handle points to this payload value";
535       }
536     };
537   }
538 }
539 
540 void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
541     OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
542     Value payloadValue, Value valueHandle,
543     transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
544   // If the op is associated with invalidated handle, skip the check as it
545   // may be reading invalid IR. This also ensures we report the first
546   // invalidation and not the last one.
547   if (invalidatedHandles.count(valueHandle) ||
548       newlyInvalidated.count(valueHandle))
549     return;
550 
551   for (Operation *ancestor : potentialAncestors) {
552     Operation *definingOp;
553     std::optional<unsigned> resultNo;
554     unsigned argumentNo = std::numeric_limits<unsigned>::max();
555     unsigned blockNo = std::numeric_limits<unsigned>::max();
556     unsigned regionNo = std::numeric_limits<unsigned>::max();
557     if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
558       definingOp = opResult.getOwner();
559       resultNo = opResult.getResultNumber();
560     } else {
561       auto arg = llvm::cast<BlockArgument>(payloadValue);
562       definingOp = arg.getParentBlock()->getParentOp();
563       argumentNo = arg.getArgNumber();
564       blockNo = std::distance(arg.getOwner()->getParent()->begin(),
565                               arg.getOwner()->getIterator());
566       regionNo = arg.getOwner()->getParent()->getRegionNumber();
567     }
568     assert(definingOp && "expected the value to be defined by an op as result "
569                          "or block argument");
570     if (!ancestor->isAncestor(definingOp))
571       continue;
572 
573     Operation *owner = opHandle.getOwner();
574     unsigned operandNo = opHandle.getOperandNumber();
575     Location ancestorLoc = ancestor->getLoc();
576     Location opLoc = definingOp->getLoc();
577     Location valueLoc = payloadValue.getLoc();
578     newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
579                                      argumentNo, blockNo, regionNo, ancestorLoc,
580                                      opLoc, valueLoc](Location currentLoc) {
581       InFlightDiagnostic diag = emitError(currentLoc)
582                                 << "op uses a handle invalidated by a "
583                                    "previously executed transform op";
584       diag.attachNote(valueHandle.getLoc()) << "invalidated handle";
585       diag.attachNote(owner->getLoc())
586           << "invalidated by this transform op that consumes its operand #"
587           << operandNo
588           << " and invalidates all handles to payload IR entities "
589              "associated with this operand and entities nested in them";
590       diag.attachNote(ancestorLoc)
591           << "ancestor op associated with the consumed handle";
592       if (resultNo) {
593         diag.attachNote(opLoc)
594             << "op defining the value as result #" << *resultNo;
595       } else {
596         diag.attachNote(opLoc)
597             << "op defining the value as block argument #" << argumentNo
598             << " of block #" << blockNo << " in region #" << regionNo;
599       }
600       diag.attachNote(valueLoc) << "payload value";
601     };
602   }
603 }
604 
605 void transform::TransformState::recordOpHandleInvalidation(
606     OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
607     Value throughValue,
608     transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
609 
610   if (potentialAncestors.empty()) {
611     DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
612       (DBGS() << "----recording invalidation for empty handle: " << handle.get()
613               << "\n");
614     });
615 
616     Operation *owner = handle.getOwner();
617     unsigned operandNo = handle.getOperandNumber();
618     newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) {
619       InFlightDiagnostic diag = emitError(currentLoc)
620                                 << "op uses a handle associated with empty "
621                                    "payload and invalidated by a "
622                                    "previously executed transform op";
623       diag.attachNote(owner->getLoc())
624           << "invalidated by this transform op that consumes its operand #"
625           << operandNo;
626     };
627     return;
628   }
629 
630   // Iterate over the mapping and invalidate aliasing handles. This is quite
631   // expensive and only necessary for error reporting in case of transform
632   // dialect misuse with dangling handles. Iteration over the handles is based
633   // on the assumption that the number of handles is significantly less than the
634   // number of IR objects (operations and values). Alternatively, we could walk
635   // the IR nested in each payload op associated with the given handle and look
636   // for handles associated with each operation and value.
637   for (const auto &[region, mapping] : llvm::reverse(mappings)) {
638     // Go over all op handle mappings and mark as invalidated any handle
639     // pointing to any of the payload ops associated with the given handle or
640     // any op nested in them.
641     for (const auto &[payloadOp, otherHandles] : mapping->reverse) {
642       for (Value otherHandle : otherHandles)
643         recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
644                                       otherHandle, throughValue,
645                                       newlyInvalidated);
646     }
647     // Go over all value handle mappings and mark as invalidated any handle
648     // pointing to any result of the payload op associated with the given handle
649     // or any op nested in them. Similarly invalidate handles to argument of
650     // blocks belonging to any region of any payload op associated with the
651     // given handle or any op nested in them.
652     for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
653       for (Value valueHandle : valueHandles)
654         recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
655                                                    payloadValue, valueHandle,
656                                                    newlyInvalidated);
657     }
658 
659     // Stop lookup when reaching a region that is isolated from above.
660     if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
661       break;
662   }
663 }
664 
665 void transform::TransformState::recordValueHandleInvalidation(
666     OpOperand &valueHandle,
667     transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
668   // Invalidate other handles to the same value.
669   for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
670     SmallVector<Value> otherValueHandles;
671     (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
672     for (Value otherHandle : otherValueHandles) {
673       Operation *owner = valueHandle.getOwner();
674       unsigned operandNo = valueHandle.getOperandNumber();
675       Location valueLoc = payloadValue.getLoc();
676       newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
677                                        valueLoc](Location currentLoc) {
678         InFlightDiagnostic diag = emitError(currentLoc)
679                                   << "op uses a handle invalidated by a "
680                                      "previously executed transform op";
681         diag.attachNote(otherHandle.getLoc()) << "invalidated handle";
682         diag.attachNote(owner->getLoc())
683             << "invalidated by this transform op that consumes its operand #"
684             << operandNo
685             << " and invalidates handles to the same values as associated with "
686                "it";
687         diag.attachNote(valueLoc) << "payload value";
688       };
689     }
690 
691     if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
692       Operation *payloadOp = opResult.getOwner();
693       recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
694                                  newlyInvalidated);
695     } else {
696       auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
697       for (Operation &payloadOp : *arg.getOwner())
698         recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
699                                    newlyInvalidated);
700     }
701   }
702 }
703 
704 /// Checks that the operation does not use invalidated handles as operands.
705 /// Reports errors and returns failure if it does. Otherwise, invalidates the
706 /// handles consumed by the operation as well as any handles pointing to payload
707 /// IR operations nested in the operations associated with the consumed handles.
708 LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
709     transform::TransformOpInterface transform,
710     transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
711   FULL_LDBG("--Start checkAndRecordHandleInvalidation\n");
712   auto memoryEffectsIface =
713       cast<MemoryEffectOpInterface>(transform.getOperation());
714   SmallVector<MemoryEffects::EffectInstance> effects;
715   memoryEffectsIface.getEffectsOnResource(
716       transform::TransformMappingResource::get(), effects);
717 
718   for (OpOperand &target : transform->getOpOperands()) {
719     DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
720       (DBGS() << "----iterate on handle: " << target.get() << "\n");
721     });
722     // If the operand uses an invalidated handle, report it. If the operation
723     // allows handles to point to repeated payload operations, only report
724     // pre-existing invalidation errors. Otherwise, also report invalidations
725     // caused by the current transform operation affecting its other operands.
726     auto it = invalidatedHandles.find(target.get());
727     auto nit = newlyInvalidated.find(target.get());
728     if (it != invalidatedHandles.end()) {
729       FULL_LDBG("--End checkAndRecordHandleInvalidation, found already "
730                 "invalidated -> FAILURE\n");
731       return it->getSecond()(transform->getLoc()), failure();
732     }
733     if (!transform.allowsRepeatedHandleOperands() &&
734         nit != newlyInvalidated.end()) {
735       FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly "
736                 "invalidated (by this op) -> FAILURE\n");
737       return nit->getSecond()(transform->getLoc()), failure();
738     }
739 
740     // Invalidate handles pointing to the operations nested in the operation
741     // associated with the handle consumed by this operation.
742     auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
743       return isa<MemoryEffects::Free>(effect.getEffect()) &&
744              effect.getValue() == target.get();
745     };
746     if (llvm::any_of(effects, consumesTarget)) {
747       FULL_LDBG("----found consume effect\n");
748       if (llvm::isa<transform::TransformHandleTypeInterface>(
749               target.get().getType())) {
750         FULL_LDBG("----recordOpHandleInvalidation\n");
751         SmallVector<Operation *> payloadOps =
752             llvm::to_vector(getPayloadOps(target.get()));
753         recordOpHandleInvalidation(target, payloadOps, nullptr,
754                                    newlyInvalidated);
755       } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
756                      target.get().getType())) {
757         FULL_LDBG("----recordValueHandleInvalidation\n");
758         recordValueHandleInvalidation(target, newlyInvalidated);
759       } else {
760         FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
761       }
762     } else {
763       FULL_LDBG("----no consume effect -> SKIP\n");
764     }
765   }
766 
767   FULL_LDBG("--End checkAndRecordHandleInvalidation -> SUCCESS\n");
768   return success();
769 }
770 
771 LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
772     transform::TransformOpInterface transform) {
773   InvalidatedHandleMap newlyInvalidated;
774   LogicalResult checkResult =
775       checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
776   invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
777                             std::make_move_iterator(newlyInvalidated.end()));
778   return checkResult;
779 }
780 
781 template <typename T>
782 DiagnosedSilenceableFailure
783 checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
784                                   transform::TransformOpInterface transform,
785                                   unsigned operandNumber) {
786   DenseSet<T> seen;
787   for (T p : payload) {
788     if (!seen.insert(p).second) {
789       DiagnosedSilenceableFailure diag =
790           transform.emitSilenceableError()
791           << "a handle passed as operand #" << operandNumber
792           << " and consumed by this operation points to a payload "
793              "entity more than once";
794       if constexpr (std::is_pointer_v<T>)
795         diag.attachNote(p->getLoc()) << "repeated target op";
796       else
797         diag.attachNote(p.getLoc()) << "repeated target value";
798       return diag;
799     }
800   }
801   return DiagnosedSilenceableFailure::success();
802 }
803 
804 void transform::TransformState::compactOpHandles() {
805   for (Value handle : opHandlesToCompact) {
806     Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
807 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
808     if (llvm::find(mappings.direct[handle], nullptr) !=
809         mappings.direct[handle].end())
810       // Payload IR is removed from the mapping. This invalidates the respective
811       // iterators.
812       mappings.incrementTimestamp(handle);
813 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
814     llvm::erase(mappings.direct[handle], nullptr);
815   }
816   opHandlesToCompact.clear();
817 }
818 
/// Applies `transform` to the payload IR and updates the transform state with
/// the results it produces.
///
/// Expensive checks (only when enabled):
///   - an operand using a handle invalidated by a previously executed
///     transform is a definite failure; the handles this op invalidates are
///     recorded;
///   - unless the op allows repeated handle operands, a consumed handle that
///     points to the same payload entity more than once is a silenceable
///     failure.
///
/// The op is then applied with a rewriter whose tracking listener updates
/// handles when payload ops are replaced. A silenceable failure does not
/// short-circuit. Unset results are set to empty lists, consumed handles are
/// forgotten and the results are recorded, so that the "suppress" mode can
/// proceed on a best-effort basis. Definite failures are returned as soon as
/// they are detected.
DiagnosedSilenceableFailure
transform::TransformState::applyTransform(TransformOpInterface transform) {
  LLVM_DEBUG({
    DBGS() << "applying: ";
    transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
    llvm::dbgs() << "\n";
  });
  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
                  DBGS() << "Top-level payload before application:\n"
                         << *getTopLevel() << "\n");
  // Dumps the payload on every early return; released right before the final
  // return of this function.
  auto printOnFailureRAII = llvm::make_scope_exit([this] {
    // Keeps the capture used when LLVM_DEBUG expands to nothing.
    (void)this;
    LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
        llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
  });

  // Set current transform op.
  regionStack.back()->currentTransform = transform;

  // Expensive checks to detect invalid transform IR.
  if (options.getExpensiveChecksEnabled()) {
    FULL_LDBG("ExpensiveChecksEnabled\n");
    if (failed(checkAndRecordHandleInvalidation(transform)))
      return DiagnosedSilenceableFailure::definiteFailure();

    // A consumed handle must not point to the same payload entity twice,
    // unless the op explicitly allows repeated handle operands.
    for (OpOperand &operand : transform->getOpOperands()) {
      DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
        (DBGS() << "iterate on handle: " << operand.get() << "\n");
      });
      if (!isHandleConsumed(operand.get(), transform)) {
        FULL_LDBG("--handle not consumed -> SKIP\n");
        continue;
      }
      if (transform.allowsRepeatedHandleOperands()) {
        FULL_LDBG("--op allows repeated handles -> SKIP\n");
        continue;
      }
      FULL_LDBG("--handle is consumed\n");

      Type operandType = operand.get().getType();
      if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
        FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n");
        DiagnosedSilenceableFailure check =
            checkRepeatedConsumptionInOperand<Operation *>(
                getPayloadOpsView(operand.get()), transform,
                operand.getOperandNumber());
        if (!check.succeeded()) {
          FULL_LDBG("----FAILED\n");
          return check;
        }
      } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
        FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n");
        DiagnosedSilenceableFailure check =
            checkRepeatedConsumptionInOperand<Value>(
                getPayloadValuesView(operand.get()), transform,
                operand.getOperandNumber());
        if (!check.succeeded()) {
          FULL_LDBG("----FAILED\n");
          return check;
        }
      } else {
        FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
      }
    }
  }

  // Find which operands are consumed.
  SmallVector<OpOperand *> consumedOperands =
      transform.getConsumedHandleOpOperands();

  // Remember the results of the payload ops associated with the consumed
  // op handles or the ops defining the value handles so we can drop the
  // association with them later. This must happen here because the
  // transformation may destroy or mutate them so we cannot traverse the payload
  // IR after that.
  SmallVector<Value> origOpFlatResults;
  SmallVector<Operation *> origAssociatedOps;
  for (OpOperand *opOperand : consumedOperands) {
    Value operand = opOperand->get();
    if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
      for (Operation *payloadOp : getPayloadOps(operand)) {
        llvm::append_range(origOpFlatResults, payloadOp->getResults());
      }
      continue;
    }
    if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
      for (Value payloadValue : getPayloadValuesView(operand)) {
        if (llvm::isa<OpResult>(payloadValue)) {
          origAssociatedOps.push_back(payloadValue.getDefiningOp());
          continue;
        }
        // For a block argument, remember every op of the owning block.
        llvm::append_range(
            origAssociatedOps,
            llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
                            [](Operation &op) { return &op; }));
      }
      continue;
    }
    DiagnosedDefiniteFailure diag =
        emitDefiniteFailure(transform->getLoc())
        << "unexpectedly consumed a value that is not a handle as operand #"
        << opOperand->getOperandNumber();
    diag.attachNote(operand.getLoc())
        << "value defined here with type " << operand.getType();
    return diag;
  }

  // Prepare rewriter and listener.
  TrackingListenerConfig config;
  config.skipHandleFn = [&](Value handle) {
    // Skip handle if it is dead: every user is either the current transform
    // of the innermost scope for the handle's region or executes before it.
    auto scopeIt =
        llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
          return handle.getParentRegion() == scope->region;
        });
    assert(scopeIt != regionStack.rend() &&
           "could not find region scope for handle");
    RegionScope *scope = *scopeIt;
    return llvm::all_of(handle.getUsers(), [&](Operation *user) {
      return user == scope->currentTransform ||
             happensBefore(user, scope->currentTransform);
    });
  };
  transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
                                                            config);
  transform::TransformRewriter rewriter(transform->getContext(),
                                        &trackingListener);

  // Compute the result but do not short-circuit the silenceable failure case as
  // we still want the handles to propagate properly so the "suppress" mode can
  // proceed on a best effort basis.
  transform::TransformResults results(transform->getNumResults());
  DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this));
  // Drop the null entries left in op handles during application.
  compactOpHandles();

  // Error handling: fail if transform or listener failed.
  DiagnosedSilenceableFailure trackingFailure =
      trackingListener.checkAndResetError();
  if (!transform->hasTrait<ReportTrackingListenerFailuresOpTrait>() ||
      transform->hasAttr(FindPayloadReplacementOpInterface::
                             kSilenceTrackingFailuresAttrName)) {
    // Only report failures for ReportTrackingListenerFailuresOpTrait ops. Also
    // do not report failures if the above mentioned attribute is set.
    if (trackingFailure.isSilenceableFailure())
      (void)trackingFailure.silence();
    trackingFailure = DiagnosedSilenceableFailure::success();
  }
  if (!trackingFailure.succeeded()) {
    if (result.succeeded()) {
      result = std::move(trackingFailure);
    } else {
      // Transform op errors have precedence, report those first.
      if (result.isSilenceableFailure())
        result.attachNote() << "tracking listener also failed: "
                            << trackingFailure.getMessage();
      (void)trackingFailure.silence();
    }
  }
  if (result.isDefiniteFailure())
    return result;

  // If a silenceable failure was produced, some results may be unset, set them
  // to empty lists.
  if (result.isSilenceableFailure())
    results.setRemainingToEmpty(transform);

  // Remove the mapping for the operand if it is consumed by the operation. This
  // allows us to catch use-after-free with assertions later on.
  for (OpOperand *opOperand : consumedOperands) {
    Value operand = opOperand->get();
    if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
      forgetMapping(operand, origOpFlatResults);
    } else if (llvm::isa<TransformValueHandleTypeInterface>(
                   operand.getType())) {
      forgetValueMapping(operand, origAssociatedOps);
    }
  }

  if (failed(updateStateFromResults(results, transform->getResults())))
    return DiagnosedSilenceableFailure::definiteFailure();

  printOnFailureRAII.release();
  DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
    DBGS() << "Top-level payload:\n";
    getTopLevel()->print(llvm::dbgs());
  });
  return result;
}
1007 
1008 LogicalResult transform::TransformState::updateStateFromResults(
1009     const TransformResults &results, ResultRange opResults) {
1010   for (OpResult result : opResults) {
1011     if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
1012       assert(results.isParam(result.getResultNumber()) &&
1013              "expected parameters for the parameter-typed result");
1014       if (failed(
1015               setParams(result, results.getParams(result.getResultNumber())))) {
1016         return failure();
1017       }
1018     } else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
1019       assert(results.isValue(result.getResultNumber()) &&
1020              "expected values for value-type-result");
1021       if (failed(setPayloadValues(
1022               result, results.getValues(result.getResultNumber())))) {
1023         return failure();
1024       }
1025     } else {
1026       assert(!results.isParam(result.getResultNumber()) &&
1027              "expected payload ops for the non-parameter typed result");
1028       if (failed(
1029               setPayloadOps(result, results.get(result.getResultNumber())))) {
1030         return failure();
1031       }
1032     }
1033   }
1034   return success();
1035 }
1036 
1037 //===----------------------------------------------------------------------===//
1038 // TransformState::Extension
1039 //===----------------------------------------------------------------------===//
1040 
// Out-of-line defaulted destructor of the transform state extension base.
transform::TransformState::Extension::~Extension() = default;
1042 
/// Forwards the replacement of payload op `op` with `replacement` to the
/// transform state this extension is attached to and returns its result.
LogicalResult
transform::TransformState::Extension::replacePayloadOp(Operation *op,
                                                       Operation *replacement) {
  // TODO: we may need to invalidate handles to operations and values nested in
  // the operation being replaced.
  return state.replacePayloadOp(op, replacement);
}
1050 
/// Forwards the replacement of payload value `value` with `replacement` to the
/// transform state this extension is attached to and returns its result.
LogicalResult
transform::TransformState::Extension::replacePayloadValue(Value value,
                                                          Value replacement) {
  return state.replacePayloadValue(value, replacement);
}
1056 
1057 //===----------------------------------------------------------------------===//
1058 // TransformState::RegionScope
1059 //===----------------------------------------------------------------------===//
1060 
1061 transform::TransformState::RegionScope::~RegionScope() {
1062   // Remove handle invalidation notices as handles are going out of scope.
1063   // The same region may be re-entered leading to incorrect invalidation
1064   // errors.
1065   for (Block &block : *region) {
1066     for (Value handle : block.getArguments()) {
1067       state.invalidatedHandles.erase(handle);
1068     }
1069     for (Operation &op : block) {
1070       for (Value handle : op.getResults()) {
1071         state.invalidatedHandles.erase(handle);
1072       }
1073     }
1074   }
1075 
1076 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1077   // Remember pointers to payload ops referenced by the handles going out of
1078   // scope.
1079   SmallVector<Operation *> referencedOps =
1080       llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
1081 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
1082 
1083   state.mappings.erase(region);
1084   state.regionStack.pop_back();
1085 }
1086 
1087 //===----------------------------------------------------------------------===//
1088 // TransformResults
1089 //===----------------------------------------------------------------------===//
1090 
1091 transform::TransformResults::TransformResults(unsigned numSegments) {
1092   operations.appendEmptyRows(numSegments);
1093   params.appendEmptyRows(numSegments);
1094   values.appendEmptyRows(numSegments);
1095 }
1096 
1097 void transform::TransformResults::setParams(
1098     OpResult value, ArrayRef<transform::TransformState::Param> params) {
1099   int64_t position = value.getResultNumber();
1100   assert(position < static_cast<int64_t>(this->params.size()) &&
1101          "setting params for a non-existent handle");
1102   assert(this->params[position].data() == nullptr && "params already set");
1103   assert(operations[position].data() == nullptr &&
1104          "another kind of results already set");
1105   assert(values[position].data() == nullptr &&
1106          "another kind of results already set");
1107   this->params.replace(position, params);
1108 }
1109 
1110 void transform::TransformResults::setMappedValues(
1111     OpResult handle, ArrayRef<MappedValue> values) {
1112   DiagnosedSilenceableFailure diag = dispatchMappedValues(
1113       handle, values,
1114       [&](ArrayRef<Operation *> operations) {
1115         return set(handle, operations), success();
1116       },
1117       [&](ArrayRef<Param> params) {
1118         return setParams(handle, params), success();
1119       },
1120       [&](ValueRange payloadValues) {
1121         return setValues(handle, payloadValues), success();
1122       });
1123 #ifndef NDEBUG
1124   if (!diag.succeeded())
1125     llvm::dbgs() << diag.getStatusString() << "\n";
1126   assert(diag.succeeded() && "incorrect mapping");
1127 #endif // NDEBUG
1128   (void)diag.silence();
1129 }
1130 
1131 void transform::TransformResults::setRemainingToEmpty(
1132     transform::TransformOpInterface transform) {
1133   for (OpResult opResult : transform->getResults()) {
1134     if (!isSet(opResult.getResultNumber()))
1135       setMappedValues(opResult, {});
1136   }
1137 }
1138 
1139 ArrayRef<Operation *>
1140 transform::TransformResults::get(unsigned resultNumber) const {
1141   assert(resultNumber < operations.size() &&
1142          "querying results for a non-existent handle");
1143   assert(operations[resultNumber].data() != nullptr &&
1144          "querying unset results (values or params expected?)");
1145   return operations[resultNumber];
1146 }
1147 
1148 ArrayRef<transform::TransformState::Param>
1149 transform::TransformResults::getParams(unsigned resultNumber) const {
1150   assert(resultNumber < params.size() &&
1151          "querying params for a non-existent handle");
1152   assert(params[resultNumber].data() != nullptr &&
1153          "querying unset params (ops or values expected?)");
1154   return params[resultNumber];
1155 }
1156 
1157 ArrayRef<Value>
1158 transform::TransformResults::getValues(unsigned resultNumber) const {
1159   assert(resultNumber < values.size() &&
1160          "querying values for a non-existent handle");
1161   assert(values[resultNumber].data() != nullptr &&
1162          "querying unset values (ops or params expected?)");
1163   return values[resultNumber];
1164 }
1165 
1166 bool transform::TransformResults::isParam(unsigned resultNumber) const {
1167   assert(resultNumber < params.size() &&
1168          "querying association for a non-existent handle");
1169   return params[resultNumber].data() != nullptr;
1170 }
1171 
1172 bool transform::TransformResults::isValue(unsigned resultNumber) const {
1173   assert(resultNumber < values.size() &&
1174          "querying association for a non-existent handle");
1175   return values[resultNumber].data() != nullptr;
1176 }
1177 
1178 bool transform::TransformResults::isSet(unsigned resultNumber) const {
1179   assert(resultNumber < params.size() &&
1180          "querying association for a non-existent handle");
1181   return params[resultNumber].data() != nullptr ||
1182          operations[resultNumber].data() != nullptr ||
1183          values[resultNumber].data() != nullptr;
1184 }
1185 
1186 //===----------------------------------------------------------------------===//
1187 // TrackingListener
1188 //===----------------------------------------------------------------------===//
1189 
1190 transform::TrackingListener::TrackingListener(TransformState &state,
1191                                               TransformOpInterface op,
1192                                               TrackingListenerConfig config)
1193     : TransformState::Extension(state), transformOp(op), config(config) {
1194   if (op) {
1195     for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
1196       consumedHandles.insert(opOperand->get());
1197     }
1198   }
1199 }
1200 
1201 Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) {
1202   Operation *defOp = nullptr;
1203   for (Value v : values) {
1204     // Skip empty values.
1205     if (!v)
1206       continue;
1207     if (!defOp) {
1208       defOp = v.getDefiningOp();
1209       continue;
1210     }
1211     if (defOp != v.getDefiningOp())
1212       return nullptr;
1213   }
1214   return defOp;
1215 }
1216 
/// Searches for a payload op that can replace `op`, whose results are being
/// replaced with `newValues`. On success, stores it in `result`.
///
/// Starting from the common defining op of the replacement values, the search:
///   - looks through `CastOpInterface` ops (when `config.skipCastOps` is set)
///     by continuing with their operands;
///   - accepts the op if its name matches `op`'s, if matching names are not
///     required, or if it is constant-like;
///   - otherwise continues with the operands reported by
///     `FindPayloadReplacementOpInterface`, if implemented.
/// The search fails if the values have no common defining op or if no values
/// remain to inspect. The returned silenceable failure carries one note per
/// step taken.
DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
    Operation *&result, Operation *op, ValueRange newValues) const {
  assert(op->getNumResults() == newValues.size() &&
         "invalid number of replacement values");
  SmallVector<Value> values(newValues.begin(), newValues.end());

  // Built eagerly so that each step can attach a note to it.
  DiagnosedSilenceableFailure diag = emitSilenceableFailure(
      getTransformOp(), "tracking listener failed to find replacement op "
                        "during application of this transform op");

  do {
    // If the replacement values belong to different ops, drop the mapping.
    Operation *defOp = getCommonDefiningOp(values);
    if (!defOp) {
      diag.attachNote() << "replacement values belong to different ops";
      return diag;
    }

    // Skip through ops that implement CastOpInterface.
    if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
      values.clear();
      values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
      diag.attachNote(defOp->getLoc())
          << "using output of 'CastOpInterface' op";
      continue;
    }

    // If the defining op has the same name or we do not care about the name of
    // op replacements at all, we take it as a replacement.
    if (!config.requireMatchingReplacementOpName ||
        op->getName() == defOp->getName()) {
      result = defOp;
      return DiagnosedSilenceableFailure::success();
    }

    // Replacing an op with a constant-like equivalent is a common
    // canonicalization.
    if (defOp->hasTrait<OpTrait::ConstantLike>()) {
      result = defOp;
      return DiagnosedSilenceableFailure::success();
    }

    values.clear();

    // Skip through ops that implement FindPayloadReplacementOpInterface.
    if (auto findReplacementOpInterface =
            dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
      values.assign(findReplacementOpInterface.getNextOperands());
      diag.attachNote(defOp->getLoc()) << "using operands provided by "
                                          "'FindPayloadReplacementOpInterface'";
      continue;
    }
    // `continue` jumps to the loop condition: the search ends once no values
    // are left to inspect.
  } while (!values.empty());

  diag.attachNote() << "ran out of suitable replacement values";
  return diag;
}
1274 
1275 void transform::TrackingListener::notifyMatchFailure(
1276     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
1277   LLVM_DEBUG({
1278     Diagnostic diag(loc, DiagnosticSeverity::Remark);
1279     reasonCallback(diag);
1280     DBGS() << "Match Failure : " << diag.str() << "\n";
1281   });
1282 }
1283 
1284 void transform::TrackingListener::notifyOperationErased(Operation *op) {
1285   // Remove mappings for result values.
1286   for (OpResult value : op->getResults())
1287     (void)replacePayloadValue(value, nullptr);
1288   // Remove mapping for op.
1289   (void)replacePayloadOp(op, nullptr);
1290 }
1291 
/// Handles the replacement of the results of payload op `op` with
/// `newValues`.
///
/// Value handles pointing to the old results are redirected to the new values.
/// If `op` is not tracked by any op handle, nothing else happens. Otherwise:
///   - if every handle to `op` is dead (per `config.skipHandleFn`), or one of
///     them is consumed by the current transform op, `op` is dropped from the
///     mapping;
///   - otherwise `findReplacementOp` searches for a replacement op. If one is
///     found, `op` is replaced by it. If not, the failure is reported through
///     `notifyPayloadReplacementNotFound` and `op` is dropped.
void transform::TrackingListener::notifyOperationReplaced(
    Operation *op, ValueRange newValues) {
  assert(op->getNumResults() == newValues.size() &&
         "invalid number of replacement values");

  // Replace value handles.
  for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
    (void)replacePayloadValue(oldValue, newValue);

  // Replace op handle.
  SmallVector<Value> opHandles;
  if (failed(getTransformState().getHandlesForPayloadOp(
          op, opHandles, /*includeOutOfScope=*/true))) {
    // Op is not tracked.
    return;
  }

  // Helper function to check if the current transform op consumes any handle
  // that is mapped to `op`.
  //
  // Note: If a handle was consumed, there shouldn't be any alive users, so it
  // is not really necessary to check for consumed handles. However, in case
  // there are indeed alive handles that were consumed (which is undefined
  // behavior) and a replacement op could not be found, we want to fail with a
  // nicer error message: "op uses a handle invalidated..." instead of "could
  // not find replacement op". This nicer error is produced later.
  auto handleWasConsumed = [&] {
    return llvm::any_of(opHandles,
                        [&](Value h) { return consumedHandles.contains(h); });
  };

  // Check if there are any handles that must be updated. Without a
  // `skipHandleFn`, every handle is considered alive.
  Value aliveHandle;
  if (config.skipHandleFn) {
    auto it = llvm::find_if(opHandles,
                            [&](Value v) { return !config.skipHandleFn(v); });
    if (it != opHandles.end())
      aliveHandle = *it;
  } else if (!opHandles.empty()) {
    aliveHandle = opHandles.front();
  }
  if (!aliveHandle || handleWasConsumed()) {
    // The op is tracked but the corresponding handles are dead or were
    // consumed. Drop the op from the mapping.
    (void)replacePayloadOp(op, nullptr);
    return;
  }

  // Only read below when `findReplacementOp` succeeds, which sets it.
  Operation *replacement;
  DiagnosedSilenceableFailure diag =
      findReplacementOp(replacement, op, newValues);
  // If the op is tracked but no replacement op was found, send a
  // notification.
  if (!diag.succeeded()) {
    diag.attachNote(aliveHandle.getLoc())
        << "replacement is required because this handle must be updated";
    notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
    (void)replacePayloadOp(op, nullptr);
    return;
  }

  (void)replacePayloadOp(op, replacement);
}
1355 
1356 transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {
1357   // The state of the ErrorCheckingTrackingListener must be checked and reset
1358   // if there was an error. This is to prevent errors from accidentally being
1359   // missed.
1360   assert(status.succeeded() && "listener state was not checked");
1361 }
1362 
1363 DiagnosedSilenceableFailure
1364 transform::ErrorCheckingTrackingListener::checkAndResetError() {
1365   DiagnosedSilenceableFailure s = std::move(status);
1366   status = DiagnosedSilenceableFailure::success();
1367   errorCounter = 0;
1368   return s;
1369 }
1370 
/// Returns true if a tracking error has been recorded and not yet consumed by
/// `checkAndResetError`.
bool transform::ErrorCheckingTrackingListener::failed() const {
  return !status.succeeded();
}
1374 
/// Records that no replacement was found for payload op `op` (whose results
/// were replaced with `values`).
///
/// The diagnostics of `diag` are collected first, followed by those of any
/// previously recorded failure. They are stored together as a silenceable
/// failure in `status`. Notes tagged with the current `errorCounter` then
/// point at the replaced op and at each replacement value, and the counter is
/// incremented.
void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
    Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) {

  // Merge potentially existing diags and store the result in the listener.
  SmallVector<Diagnostic> diags;
  diag.takeDiagnostics(diags);
  if (!status.succeeded())
    status.takeDiagnostics(diags);
  status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));

  // Report more details.
  status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
  for (auto &&[index, value] : llvm::enumerate(values))
    status.attachNote(value.getLoc())
        << "[" << errorCounter << "] replacement value " << index;
  ++errorCounter;
}
1392 
1393 //===----------------------------------------------------------------------===//
1394 // TransformRewriter
1395 //===----------------------------------------------------------------------===//
1396 
/// Creates a rewriter that notifies `listener` of IR modifications. The
/// listener is kept so that tracking failures can be queried and silenced.
transform::TransformRewriter::TransformRewriter(
    MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
    : RewriterBase(ctx), listener(listener) {
  setListener(listener);
}
1402 
/// Returns true if the tracking listener has recorded a failure that has not
/// been reset yet.
bool transform::TransformRewriter::hasTrackingFailures() const {
  return listener->failed();
}
1406 
1407 /// Silence all tracking failures that have been encountered so far.
1408 void transform::TransformRewriter::silenceTrackingFailure() {
1409   if (hasTrackingFailures()) {
1410     DiagnosedSilenceableFailure status = listener->checkAndResetError();
1411     (void)status.silence();
1412   }
1413 }
1414 
/// Explicitly informs the tracking listener that payload op `op` is replaced
/// by `replacement` and returns the result of the update.
LogicalResult transform::TransformRewriter::notifyPayloadOperationReplaced(
    Operation *op, Operation *replacement) {
  return listener->replacePayloadOp(op, replacement);
}
1419 
1420 //===----------------------------------------------------------------------===//
1421 // Utilities for TransformEachOpTrait.
1422 //===----------------------------------------------------------------------===//
1423 
1424 LogicalResult
1425 transform::detail::checkNestedConsumption(Location loc,
1426                                           ArrayRef<Operation *> targets) {
1427   for (auto &&[position, parent] : llvm::enumerate(targets)) {
1428     for (Operation *child : targets.drop_front(position + 1)) {
1429       if (parent->isAncestor(child)) {
1430         InFlightDiagnostic diag =
1431             emitError(loc)
1432             << "transform operation consumes a handle pointing to an ancestor "
1433                "payload operation before its descendant";
1434         diag.attachNote()
1435             << "the ancestor is likely erased or rewritten before the "
1436                "descendant is accessed, leading to undefined behavior";
1437         diag.attachNote(parent->getLoc()) << "ancestor payload op";
1438         diag.attachNote(child->getLoc()) << "descendant payload op";
1439         return diag;
1440       }
1441     }
1442   }
1443   return success();
1444 }
1445 
1446 LogicalResult
1447 transform::detail::checkApplyToOne(Operation *transformOp,
1448                                    Location payloadOpLoc,
1449                                    const ApplyToEachResultList &partialResult) {
1450   Location transformOpLoc = transformOp->getLoc();
1451   StringRef transformOpName = transformOp->getName().getStringRef();
1452   unsigned expectedNumResults = transformOp->getNumResults();
1453 
1454   // Reuse the emission of the diagnostic note.
1455   auto emitDiag = [&]() {
1456     auto diag = mlir::emitError(transformOpLoc);
1457     diag.attachNote(payloadOpLoc) << "when applied to this op";
1458     return diag;
1459   };
1460 
1461   if (partialResult.size() != expectedNumResults) {
1462     auto diag = emitDiag() << "application of " << transformOpName
1463                            << " expected to produce " << expectedNumResults
1464                            << " results (actually produced "
1465                            << partialResult.size() << ").";
1466     diag.attachNote(transformOpLoc)
1467         << "if you need variadic results, consider a generic `apply` "
1468         << "instead of the specialized `applyToOne`.";
1469     return failure();
1470   }
1471 
1472   // Check that the right kind of value was produced.
1473   for (const auto &[ptr, res] :
1474        llvm::zip(partialResult, transformOp->getResults())) {
1475     if (ptr.isNull())
1476       continue;
1477     if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1478         !isa<Operation *>(ptr)) {
1479       return emitDiag() << "application of " << transformOpName
1480                         << " expected to produce an Operation * for result #"
1481                         << res.getResultNumber();
1482     }
1483     if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1484         !isa<Attribute>(ptr)) {
1485       return emitDiag() << "application of " << transformOpName
1486                         << " expected to produce an Attribute for result #"
1487                         << res.getResultNumber();
1488     }
1489     if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1490         !isa<Value>(ptr)) {
1491       return emitDiag() << "application of " << transformOpName
1492                         << " expected to produce a Value for result #"
1493                         << res.getResultNumber();
1494     }
1495   }
1496   return success();
1497 }
1498 
1499 template <typename T>
1500 static SmallVector<T> castVector(ArrayRef<transform::MappedValue> range) {
1501   return llvm::to_vector(llvm::map_range(
1502       range, [](transform::MappedValue value) { return cast<T>(value); }));
1503 }
1504 
1505 void transform::detail::setApplyToOneResults(
1506     Operation *transformOp, TransformResults &transformResults,
1507     ArrayRef<ApplyToEachResultList> results) {
1508   SmallVector<SmallVector<MappedValue>> transposed;
1509   transposed.resize(transformOp->getNumResults());
1510   for (const ApplyToEachResultList &partialResults : results) {
1511     if (llvm::any_of(partialResults,
1512                      [](MappedValue value) { return value.isNull(); }))
1513       continue;
1514     assert(transformOp->getNumResults() == partialResults.size() &&
1515            "expected as many partial results as op as results");
1516     for (auto [i, value] : llvm::enumerate(partialResults))
1517       transposed[i].push_back(value);
1518   }
1519 
1520   for (OpResult r : transformOp->getResults()) {
1521     unsigned position = r.getResultNumber();
1522     if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
1523       transformResults.setParams(r,
1524                                  castVector<Attribute>(transposed[position]));
1525     } else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
1526       transformResults.setValues(r, castVector<Value>(transposed[position]));
1527     } else {
1528       transformResults.set(r, castVector<Operation *>(transposed[position]));
1529     }
1530   }
1531 }
1532 
1533 //===----------------------------------------------------------------------===//
1534 // Utilities for implementing transform ops with regions.
1535 //===----------------------------------------------------------------------===//
1536 
1537 LogicalResult transform::detail::appendValueMappings(
1538     MutableArrayRef<SmallVector<transform::MappedValue>> mappings,
1539     ValueRange values, const transform::TransformState &state, bool flatten) {
1540   assert(mappings.size() == values.size() && "mismatching number of mappings");
1541   for (auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) {
1542     size_t mappedSize = mapped.size();
1543     if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
1544       llvm::append_range(mapped, state.getPayloadOps(operand));
1545     } else if (llvm::isa<TransformValueHandleTypeInterface>(
1546                    operand.getType())) {
1547       llvm::append_range(mapped, state.getPayloadValues(operand));
1548     } else {
1549       assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
1550              "unsupported kind of transform dialect value");
1551       llvm::append_range(mapped, state.getParams(operand));
1552     }
1553 
1554     if (mapped.size() - mappedSize != 1 && !flatten)
1555       return failure();
1556   }
1557   return success();
1558 }
1559 
1560 void transform::detail::prepareValueMappings(
1561     SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings,
1562     ValueRange values, const transform::TransformState &state) {
1563   mappings.resize(mappings.size() + values.size());
1564   (void)appendValueMappings(
1565       MutableArrayRef<SmallVector<transform::MappedValue>>(mappings).take_back(
1566           values.size()),
1567       values, state);
1568 }
1569 
1570 void transform::detail::forwardTerminatorOperands(
1571     Block *block, transform::TransformState &state,
1572     transform::TransformResults &results) {
1573   for (auto &&[terminatorOperand, result] :
1574        llvm::zip(block->getTerminator()->getOperands(),
1575                  block->getParentOp()->getOpResults())) {
1576     if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
1577       results.set(result, state.getPayloadOps(terminatorOperand));
1578     } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1579                    result.getType())) {
1580       results.setValues(result, state.getPayloadValues(terminatorOperand));
1581     } else {
1582       assert(
1583           llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
1584           "unhandled transform type interface");
1585       results.setParams(result, state.getParams(terminatorOperand));
1586     }
1587   }
1588 }
1589 
1590 transform::TransformState
1591 transform::detail::makeTransformStateForTesting(Region *region,
1592                                                 Operation *payloadRoot) {
1593   return TransformState(region, payloadRoot);
1594 }
1595 
1596 //===----------------------------------------------------------------------===//
1597 // Utilities for PossibleTopLevelTransformOpTrait.
1598 //===----------------------------------------------------------------------===//
1599 
1600 /// Appends to `effects` the memory effect instances on `target` with the same
1601 /// resource and effect as the ones the operation `iface` having on `source`.
1602 static void
1603 remapEffects(MemoryEffectOpInterface iface, BlockArgument source,
1604              OpOperand *target,
1605              SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1606   SmallVector<MemoryEffects::EffectInstance> nestedEffects;
1607   iface.getEffectsOnValue(source, nestedEffects);
1608   for (const auto &effect : nestedEffects)
1609     effects.emplace_back(effect.getEffect(), target, effect.getResource());
1610 }
1611 
1612 /// Appends to `effects` the same effects as the operations of `block` have on
1613 /// block arguments but associated with `operands.`
1614 static void
1615 remapArgumentEffects(Block &block, MutableArrayRef<OpOperand> operands,
1616                      SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1617   for (Operation &op : block) {
1618     auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1619     if (!iface)
1620       continue;
1621 
1622     for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) {
1623       remapEffects(iface, source, &target, effects);
1624     }
1625 
1626     SmallVector<MemoryEffects::EffectInstance> nestedEffects;
1627     iface.getEffectsOnResource(transform::PayloadIRResource::get(),
1628                                nestedEffects);
1629     llvm::append_range(effects, nestedEffects);
1630   }
1631 }
1632 
1633 void transform::detail::getPotentialTopLevelEffects(
1634     Operation *operation, Value root, Block &body,
1635     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1636   transform::onlyReadsHandle(operation->getOpOperands(), effects);
1637   transform::producesHandle(operation->getOpResults(), effects);
1638 
1639   if (!root) {
1640     for (Operation &op : body) {
1641       auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
1642       if (!iface)
1643         continue;
1644 
1645       SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
1646       iface.getEffects(effects);
1647     }
1648     return;
1649   }
1650 
1651   // Carry over all effects on arguments of the entry block as those on the
1652   // operands, this is the same value just remapped.
1653   remapArgumentEffects(body, operation->getOpOperands(), effects);
1654 }
1655 
1656 LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
1657     TransformState &state, Operation *op, Region &region) {
1658   SmallVector<Operation *> targets;
1659   SmallVector<SmallVector<MappedValue>> extraMappings;
1660   if (op->getNumOperands() != 0) {
1661     llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
1662     prepareValueMappings(extraMappings, op->getOperands().drop_front(), state);
1663   } else {
1664     if (state.getNumTopLevelMappings() !=
1665         region.front().getNumArguments() - 1) {
1666       return emitError(op->getLoc())
1667              << "operation expects " << region.front().getNumArguments() - 1
1668              << " extra value bindings, but " << state.getNumTopLevelMappings()
1669              << " were provided to the interpreter";
1670     }
1671 
1672     targets.push_back(state.getTopLevel());
1673 
1674     for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
1675       extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
1676   }
1677 
1678   if (failed(state.mapBlockArguments(region.front().getArgument(0), targets)))
1679     return failure();
1680 
1681   for (BlockArgument argument : region.front().getArguments().drop_front()) {
1682     if (failed(state.mapBlockArgument(
1683             argument, extraMappings[argument.getArgNumber() - 1])))
1684       return failure();
1685   }
1686 
1687   return success();
1688 }
1689 
1690 LogicalResult
1691 transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
1692   // Attaching this trait without the interface is a misuse of the API, but it
1693   // cannot be caught via a static_assert because interface registration is
1694   // dynamic.
1695   assert(isa<TransformOpInterface>(op) &&
1696          "should implement TransformOpInterface to have "
1697          "PossibleTopLevelTransformOpTrait");
1698 
1699   if (op->getNumRegions() < 1)
1700     return op->emitOpError() << "expects at least one region";
1701 
1702   Region *bodyRegion = &op->getRegion(0);
1703   if (!llvm::hasNItems(*bodyRegion, 1))
1704     return op->emitOpError() << "expects a single-block region";
1705 
1706   Block *body = &bodyRegion->front();
1707   if (body->getNumArguments() == 0) {
1708     return op->emitOpError()
1709            << "expects the entry block to have at least one argument";
1710   }
1711   if (!llvm::isa<TransformHandleTypeInterface>(
1712           body->getArgument(0).getType())) {
1713     return op->emitOpError()
1714            << "expects the first entry block argument to be of type "
1715               "implementing TransformHandleTypeInterface";
1716   }
1717   BlockArgument arg = body->getArgument(0);
1718   if (op->getNumOperands() != 0) {
1719     if (arg.getType() != op->getOperand(0).getType()) {
1720       return op->emitOpError()
1721              << "expects the type of the block argument to match "
1722                 "the type of the operand";
1723     }
1724   }
1725   for (BlockArgument arg : body->getArguments().drop_front()) {
1726     if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
1727                   TransformValueHandleTypeInterface>(arg.getType()))
1728       continue;
1729 
1730     InFlightDiagnostic diag =
1731         op->emitOpError()
1732         << "expects trailing entry block arguments to be of type implementing "
1733            "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
1734            "TransformParamTypeInterface";
1735     diag.attachNote() << "argument #" << arg.getArgNumber() << " does not";
1736     return diag;
1737   }
1738 
1739   if (auto *parent =
1740           op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) {
1741     if (op->getNumOperands() != body->getNumArguments()) {
1742       InFlightDiagnostic diag =
1743           op->emitOpError()
1744           << "expects operands to be provided for a nested op";
1745       diag.attachNote(parent->getLoc())
1746           << "nested in another possible top-level op";
1747       return diag;
1748     }
1749   }
1750 
1751   return success();
1752 }
1753 
1754 //===----------------------------------------------------------------------===//
// Utilities for ParamProducerTransformOpTrait.
1756 //===----------------------------------------------------------------------===//
1757 
1758 void transform::detail::getParamProducerTransformOpTraitEffects(
1759     Operation *op, SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1760   producesHandle(op->getResults(), effects);
1761   bool hasPayloadOperands = false;
1762   for (OpOperand &operand : op->getOpOperands()) {
1763     onlyReadsHandle(operand, effects);
1764     if (llvm::isa<TransformHandleTypeInterface,
1765                   TransformValueHandleTypeInterface>(operand.get().getType()))
1766       hasPayloadOperands = true;
1767   }
1768   if (hasPayloadOperands)
1769     onlyReadsPayload(effects);
1770 }
1771 
1772 LogicalResult
1773 transform::detail::verifyParamProducerTransformOpTrait(Operation *op) {
1774   // Interfaces can be attached dynamically, so this cannot be a static
1775   // assert.
1776   if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
1777     llvm::report_fatal_error(
1778         Twine("ParamProducerTransformOpTrait must be attached to an op that "
1779               "implements MemoryEffectsOpInterface, found on ") +
1780         op->getName().getStringRef());
1781   }
1782   for (Value result : op->getResults()) {
1783     if (llvm::isa<TransformParamTypeInterface>(result.getType()))
1784       continue;
1785     return op->emitOpError()
1786            << "ParamProducerTransformOpTrait attached to this op expects "
1787               "result types to implement TransformParamTypeInterface";
1788   }
1789   return success();
1790 }
1791 
1792 //===----------------------------------------------------------------------===//
1793 // Memory effects.
1794 //===----------------------------------------------------------------------===//
1795 
1796 void transform::consumesHandle(
1797     MutableArrayRef<OpOperand> handles,
1798     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1799   for (OpOperand &handle : handles) {
1800     effects.emplace_back(MemoryEffects::Read::get(), &handle,
1801                          TransformMappingResource::get());
1802     effects.emplace_back(MemoryEffects::Free::get(), &handle,
1803                          TransformMappingResource::get());
1804   }
1805 }
1806 
1807 /// Returns `true` if the given list of effects instances contains an instance
1808 /// with the effect type specified as template parameter.
1809 template <typename EffectTy, typename ResourceTy, typename Range>
1810 static bool hasEffect(Range &&effects) {
1811   return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1812     return isa<EffectTy>(effect.getEffect()) &&
1813            isa<ResourceTy>(effect.getResource());
1814   });
1815 }
1816 
1817 bool transform::isHandleConsumed(Value handle,
1818                                  transform::TransformOpInterface transform) {
1819   auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1820   SmallVector<MemoryEffects::EffectInstance> effects;
1821   iface.getEffectsOnValue(handle, effects);
1822   return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
1823          ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
1824 }
1825 
1826 void transform::producesHandle(
1827     ResultRange handles,
1828     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1829   for (OpResult handle : handles) {
1830     effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1831                          TransformMappingResource::get());
1832     effects.emplace_back(MemoryEffects::Write::get(), handle,
1833                          TransformMappingResource::get());
1834   }
1835 }
1836 
1837 void transform::producesHandle(
1838     MutableArrayRef<BlockArgument> handles,
1839     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1840   for (BlockArgument handle : handles) {
1841     effects.emplace_back(MemoryEffects::Allocate::get(), handle,
1842                          TransformMappingResource::get());
1843     effects.emplace_back(MemoryEffects::Write::get(), handle,
1844                          TransformMappingResource::get());
1845   }
1846 }
1847 
1848 void transform::onlyReadsHandle(
1849     MutableArrayRef<OpOperand> handles,
1850     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1851   for (OpOperand &handle : handles) {
1852     effects.emplace_back(MemoryEffects::Read::get(), &handle,
1853                          TransformMappingResource::get());
1854   }
1855 }
1856 
1857 void transform::modifiesPayload(
1858     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1859   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
1860   effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
1861 }
1862 
1863 void transform::onlyReadsPayload(
1864     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1865   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
1866 }
1867 
1868 bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
1869   auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1870   SmallVector<MemoryEffects::EffectInstance> effects;
1871   iface.getEffects(effects);
1872   return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
1873 }
1874 
1875 bool transform::doesReadPayload(transform::TransformOpInterface transform) {
1876   auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
1877   SmallVector<MemoryEffects::EffectInstance> effects;
1878   iface.getEffects(effects);
1879   return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
1880 }
1881 
1882 void transform::getConsumedBlockArguments(
1883     Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
1884   SmallVector<MemoryEffects::EffectInstance> effects;
1885   for (Operation &nested : block) {
1886     auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
1887     if (!iface)
1888       continue;
1889 
1890     effects.clear();
1891     iface.getEffects(effects);
1892     for (const MemoryEffects::EffectInstance &effect : effects) {
1893       BlockArgument argument =
1894           dyn_cast_or_null<BlockArgument>(effect.getValue());
1895       if (!argument || argument.getOwner() != &block ||
1896           !isa<MemoryEffects::Free>(effect.getEffect()) ||
1897           effect.getResource() != transform::TransformMappingResource::get()) {
1898         continue;
1899       }
1900       consumedArguments.insert(argument.getArgNumber());
1901     }
1902   }
1903 }
1904 
1905 //===----------------------------------------------------------------------===//
1906 // Utilities for TransformOpInterface.
1907 //===----------------------------------------------------------------------===//
1908 
1909 SmallVector<OpOperand *> transform::detail::getConsumedHandleOpOperands(
1910     TransformOpInterface transformOp) {
1911   SmallVector<OpOperand *> consumedOperands;
1912   consumedOperands.reserve(transformOp->getNumOperands());
1913   auto memEffectInterface =
1914       cast<MemoryEffectOpInterface>(transformOp.getOperation());
1915   SmallVector<MemoryEffects::EffectInstance, 2> effects;
1916   for (OpOperand &target : transformOp->getOpOperands()) {
1917     effects.clear();
1918     memEffectInterface.getEffectsOnValue(target.get(), effects);
1919     if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
1920           return isa<transform::TransformMappingResource>(
1921                      effect.getResource()) &&
1922                  isa<MemoryEffects::Free>(effect.getEffect());
1923         })) {
1924       consumedOperands.push_back(&target);
1925     }
1926   }
1927   return consumedOperands;
1928 }
1929 
1930 LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
1931   auto iface = cast<MemoryEffectOpInterface>(op);
1932   SmallVector<MemoryEffects::EffectInstance> effects;
1933   iface.getEffects(effects);
1934 
1935   auto effectsOn = [&](Value value) {
1936     return llvm::make_filter_range(
1937         effects, [value](const MemoryEffects::EffectInstance &instance) {
1938           return instance.getValue() == value;
1939         });
1940   };
1941 
1942   std::optional<unsigned> firstConsumedOperand;
1943   for (OpOperand &operand : op->getOpOperands()) {
1944     auto range = effectsOn(operand.get());
1945     if (range.empty()) {
1946       InFlightDiagnostic diag =
1947           op->emitError() << "TransformOpInterface requires memory effects "
1948                              "on operands to be specified";
1949       diag.attachNote() << "no effects specified for operand #"
1950                         << operand.getOperandNumber();
1951       return diag;
1952     }
1953     if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(range)) {
1954       InFlightDiagnostic diag = op->emitError()
1955                                 << "TransformOpInterface did not expect "
1956                                    "'allocate' memory effect on an operand";
1957       diag.attachNote() << "specified for operand #"
1958                         << operand.getOperandNumber();
1959       return diag;
1960     }
1961     if (!firstConsumedOperand &&
1962         ::hasEffect<MemoryEffects::Free, TransformMappingResource>(range)) {
1963       firstConsumedOperand = operand.getOperandNumber();
1964     }
1965   }
1966 
1967   if (firstConsumedOperand &&
1968       !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) {
1969     InFlightDiagnostic diag =
1970         op->emitError()
1971         << "TransformOpInterface expects ops consuming operands to have a "
1972            "'write' effect on the payload resource";
1973     diag.attachNote() << "consumes operand #" << *firstConsumedOperand;
1974     return diag;
1975   }
1976 
1977   for (OpResult result : op->getResults()) {
1978     auto range = effectsOn(result);
1979     if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
1980             range)) {
1981       InFlightDiagnostic diag =
1982           op->emitError() << "TransformOpInterface requires 'allocate' memory "
1983                              "effect to be specified for results";
1984       diag.attachNote() << "no 'allocate' effect specified for result #"
1985                         << result.getResultNumber();
1986       return diag;
1987     }
1988   }
1989 
1990   return success();
1991 }
1992 
1993 //===----------------------------------------------------------------------===//
1994 // Entry point.
1995 //===----------------------------------------------------------------------===//
1996 
1997 LogicalResult transform::applyTransforms(
1998     Operation *payloadRoot, TransformOpInterface transform,
1999     const RaggedArray<MappedValue> &extraMapping,
2000     const TransformOptions &options, bool enforceToplevelTransformOp,
2001     function_ref<void(TransformState &)> stateInitializer,
2002     function_ref<LogicalResult(TransformState &)> stateExporter) {
2003   if (enforceToplevelTransformOp) {
2004     if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
2005         transform->getNumOperands() != 0) {
2006       return transform->emitError()
2007              << "expected transform to start at the top-level transform op";
2008     }
2009   } else if (failed(
2010                  detail::verifyPossibleTopLevelTransformOpTrait(transform))) {
2011     return failure();
2012   }
2013 
2014   TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
2015                        options);
2016   if (stateInitializer)
2017     stateInitializer(state);
2018   if (state.applyTransform(transform).checkAndReport().failed())
2019     return failure();
2020   if (stateExporter)
2021     return stateExporter(state);
2022   return success();
2023 }
2024 
2025 //===----------------------------------------------------------------------===//
2026 // Generated interface implementation.
2027 //===----------------------------------------------------------------------===//
2028 
2029 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
2030 #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
2031