xref: /llvm-project/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp (revision 129f1001c3b1b5200de43917d53c0efbdf08f11f)
15a9bdd85SOleksandr "Alex" Zinenko //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===//
25a9bdd85SOleksandr "Alex" Zinenko //
35a9bdd85SOleksandr "Alex" Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
45a9bdd85SOleksandr "Alex" Zinenko // See https://llvm.org/LICENSE.txt for license information.
55a9bdd85SOleksandr "Alex" Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65a9bdd85SOleksandr "Alex" Zinenko //
75a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
85a9bdd85SOleksandr "Alex" Zinenko 
95a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
105a9bdd85SOleksandr "Alex" Zinenko 
115a9bdd85SOleksandr "Alex" Zinenko #include "mlir/IR/Diagnostics.h"
125a9bdd85SOleksandr "Alex" Zinenko #include "mlir/IR/Operation.h"
135a9bdd85SOleksandr "Alex" Zinenko #include "mlir/IR/PatternMatch.h"
145a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Interfaces/CastInterfaces.h"
155a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
165a9bdd85SOleksandr "Alex" Zinenko #include "llvm/ADT/STLExtras.h"
175a9bdd85SOleksandr "Alex" Zinenko #include "llvm/ADT/ScopeExit.h"
185a9bdd85SOleksandr "Alex" Zinenko #include "llvm/Support/Debug.h"
195a9bdd85SOleksandr "Alex" Zinenko #include "llvm/Support/ErrorHandling.h"
205a9bdd85SOleksandr "Alex" Zinenko 
// Debug channel names used by this file. DEBUG_TYPE is the regular channel
// and DEBUG_TYPE_FULL a more verbose one. DEBUG_PRINT_AFTER_ALL names a channel
// presumably used to print the top-level payload after every transform
// (not referenced in this part of the file).
#define DEBUG_TYPE "transform-dialect"
#define DEBUG_TYPE_FULL "transform-dialect-full"
#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"

// Debug stream with a "[<DEBUG_TYPE>] " prefix.
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
// Print MSG on the regular debug channel.
#define LDBG(MSG) LLVM_DEBUG(DBGS() << (MSG))
// Print MSG only when the verbose DEBUG_TYPE_FULL channel is enabled.
#define FULL_LDBG(MSG) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (MSG)))
275a9bdd85SOleksandr "Alex" Zinenko 
285a9bdd85SOleksandr "Alex" Zinenko using namespace mlir;
295a9bdd85SOleksandr "Alex" Zinenko 
305a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
315a9bdd85SOleksandr "Alex" Zinenko // Helper functions
325a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
335a9bdd85SOleksandr "Alex" Zinenko 
345a9bdd85SOleksandr "Alex" Zinenko /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
355a9bdd85SOleksandr "Alex" Zinenko /// properly dominates `b` and `b` is not inside `a`.
365a9bdd85SOleksandr "Alex" Zinenko static bool happensBefore(Operation *a, Operation *b) {
375a9bdd85SOleksandr "Alex" Zinenko   do {
385a9bdd85SOleksandr "Alex" Zinenko     if (a->isProperAncestor(b))
395a9bdd85SOleksandr "Alex" Zinenko       return false;
405a9bdd85SOleksandr "Alex" Zinenko     if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) {
415a9bdd85SOleksandr "Alex" Zinenko       return a->isBeforeInBlock(bAncestor);
425a9bdd85SOleksandr "Alex" Zinenko     }
435a9bdd85SOleksandr "Alex" Zinenko   } while ((a = a->getParentOp()));
445a9bdd85SOleksandr "Alex" Zinenko   return false;
455a9bdd85SOleksandr "Alex" Zinenko }
465a9bdd85SOleksandr "Alex" Zinenko 
475a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
485a9bdd85SOleksandr "Alex" Zinenko // TransformState
495a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
505a9bdd85SOleksandr "Alex" Zinenko 
515a9bdd85SOleksandr "Alex" Zinenko constexpr const Value transform::TransformState::kTopLevelValue;
525a9bdd85SOleksandr "Alex" Zinenko 
535a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::TransformState(
545a9bdd85SOleksandr "Alex" Zinenko     Region *region, Operation *payloadRoot,
555a9bdd85SOleksandr "Alex" Zinenko     const RaggedArray<MappedValue> &extraMappings,
565a9bdd85SOleksandr "Alex" Zinenko     const TransformOptions &options)
575a9bdd85SOleksandr "Alex" Zinenko     : topLevel(payloadRoot), options(options) {
585a9bdd85SOleksandr "Alex" Zinenko   topLevelMappedValues.reserve(extraMappings.size());
595a9bdd85SOleksandr "Alex" Zinenko   for (ArrayRef<MappedValue> mapping : extraMappings)
605a9bdd85SOleksandr "Alex" Zinenko     topLevelMappedValues.push_back(mapping);
615a9bdd85SOleksandr "Alex" Zinenko   if (region) {
625a9bdd85SOleksandr "Alex" Zinenko     RegionScope *scope = new RegionScope(*this, *region);
635a9bdd85SOleksandr "Alex" Zinenko     topLevelRegionScope.reset(scope);
645a9bdd85SOleksandr "Alex" Zinenko   }
655a9bdd85SOleksandr "Alex" Zinenko }
665a9bdd85SOleksandr "Alex" Zinenko 
675a9bdd85SOleksandr "Alex" Zinenko Operation *transform::TransformState::getTopLevel() const { return topLevel; }
685a9bdd85SOleksandr "Alex" Zinenko 
695a9bdd85SOleksandr "Alex" Zinenko ArrayRef<Operation *>
705a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::getPayloadOpsView(Value value) const {
715a9bdd85SOleksandr "Alex" Zinenko   const TransformOpMapping &operationMapping = getMapping(value).direct;
725a9bdd85SOleksandr "Alex" Zinenko   auto iter = operationMapping.find(value);
735a9bdd85SOleksandr "Alex" Zinenko   assert(iter != operationMapping.end() &&
745a9bdd85SOleksandr "Alex" Zinenko          "cannot find mapping for payload handle (param/value handle "
755a9bdd85SOleksandr "Alex" Zinenko          "provided?)");
765a9bdd85SOleksandr "Alex" Zinenko   return iter->getSecond();
775a9bdd85SOleksandr "Alex" Zinenko }
785a9bdd85SOleksandr "Alex" Zinenko 
795a9bdd85SOleksandr "Alex" Zinenko ArrayRef<Attribute> transform::TransformState::getParams(Value value) const {
805a9bdd85SOleksandr "Alex" Zinenko   const ParamMapping &mapping = getMapping(value).params;
815a9bdd85SOleksandr "Alex" Zinenko   auto iter = mapping.find(value);
825a9bdd85SOleksandr "Alex" Zinenko   assert(iter != mapping.end() && "cannot find mapping for param handle "
835a9bdd85SOleksandr "Alex" Zinenko                                   "(operation/value handle provided?)");
845a9bdd85SOleksandr "Alex" Zinenko   return iter->getSecond();
855a9bdd85SOleksandr "Alex" Zinenko }
865a9bdd85SOleksandr "Alex" Zinenko 
875a9bdd85SOleksandr "Alex" Zinenko ArrayRef<Value>
885a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::getPayloadValuesView(Value handleValue) const {
895a9bdd85SOleksandr "Alex" Zinenko   const ValueMapping &mapping = getMapping(handleValue).values;
905a9bdd85SOleksandr "Alex" Zinenko   auto iter = mapping.find(handleValue);
915a9bdd85SOleksandr "Alex" Zinenko   assert(iter != mapping.end() && "cannot find mapping for value handle "
925a9bdd85SOleksandr "Alex" Zinenko                                   "(param/operation handle provided?)");
935a9bdd85SOleksandr "Alex" Zinenko   return iter->getSecond();
945a9bdd85SOleksandr "Alex" Zinenko }
955a9bdd85SOleksandr "Alex" Zinenko 
965a9bdd85SOleksandr "Alex" Zinenko LogicalResult transform::TransformState::getHandlesForPayloadOp(
975a9bdd85SOleksandr "Alex" Zinenko     Operation *op, SmallVectorImpl<Value> &handles,
985a9bdd85SOleksandr "Alex" Zinenko     bool includeOutOfScope) const {
995a9bdd85SOleksandr "Alex" Zinenko   bool found = false;
1005a9bdd85SOleksandr "Alex" Zinenko   for (const auto &[region, mapping] : llvm::reverse(mappings)) {
1015a9bdd85SOleksandr "Alex" Zinenko     auto iterator = mapping->reverse.find(op);
1025a9bdd85SOleksandr "Alex" Zinenko     if (iterator != mapping->reverse.end()) {
1035a9bdd85SOleksandr "Alex" Zinenko       llvm::append_range(handles, iterator->getSecond());
1045a9bdd85SOleksandr "Alex" Zinenko       found = true;
1055a9bdd85SOleksandr "Alex" Zinenko     }
1065a9bdd85SOleksandr "Alex" Zinenko     // Stop looking when reaching a region that is isolated from above.
1075a9bdd85SOleksandr "Alex" Zinenko     if (!includeOutOfScope &&
1085a9bdd85SOleksandr "Alex" Zinenko         region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
1095a9bdd85SOleksandr "Alex" Zinenko       break;
1105a9bdd85SOleksandr "Alex" Zinenko   }
1115a9bdd85SOleksandr "Alex" Zinenko 
1125a9bdd85SOleksandr "Alex" Zinenko   return success(found);
1135a9bdd85SOleksandr "Alex" Zinenko }
1145a9bdd85SOleksandr "Alex" Zinenko 
1155a9bdd85SOleksandr "Alex" Zinenko LogicalResult transform::TransformState::getHandlesForPayloadValue(
1165a9bdd85SOleksandr "Alex" Zinenko     Value payloadValue, SmallVectorImpl<Value> &handles,
1175a9bdd85SOleksandr "Alex" Zinenko     bool includeOutOfScope) const {
1185a9bdd85SOleksandr "Alex" Zinenko   bool found = false;
1195a9bdd85SOleksandr "Alex" Zinenko   for (const auto &[region, mapping] : llvm::reverse(mappings)) {
1205a9bdd85SOleksandr "Alex" Zinenko     auto iterator = mapping->reverseValues.find(payloadValue);
1215a9bdd85SOleksandr "Alex" Zinenko     if (iterator != mapping->reverseValues.end()) {
1225a9bdd85SOleksandr "Alex" Zinenko       llvm::append_range(handles, iterator->getSecond());
1235a9bdd85SOleksandr "Alex" Zinenko       found = true;
1245a9bdd85SOleksandr "Alex" Zinenko     }
1255a9bdd85SOleksandr "Alex" Zinenko     // Stop looking when reaching a region that is isolated from above.
1265a9bdd85SOleksandr "Alex" Zinenko     if (!includeOutOfScope &&
1275a9bdd85SOleksandr "Alex" Zinenko         region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
1285a9bdd85SOleksandr "Alex" Zinenko       break;
1295a9bdd85SOleksandr "Alex" Zinenko   }
1305a9bdd85SOleksandr "Alex" Zinenko 
1315a9bdd85SOleksandr "Alex" Zinenko   return success(found);
1325a9bdd85SOleksandr "Alex" Zinenko }
1335a9bdd85SOleksandr "Alex" Zinenko 
1345a9bdd85SOleksandr "Alex" Zinenko /// Given a list of MappedValues, cast them to the value kind implied by the
1355a9bdd85SOleksandr "Alex" Zinenko /// interface of the handle type, and dispatch to one of the callbacks.
1365a9bdd85SOleksandr "Alex" Zinenko static DiagnosedSilenceableFailure dispatchMappedValues(
1375a9bdd85SOleksandr "Alex" Zinenko     Value handle, ArrayRef<transform::MappedValue> values,
1385a9bdd85SOleksandr "Alex" Zinenko     function_ref<LogicalResult(ArrayRef<Operation *>)> operationsFn,
1395a9bdd85SOleksandr "Alex" Zinenko     function_ref<LogicalResult(ArrayRef<transform::Param>)> paramsFn,
1405a9bdd85SOleksandr "Alex" Zinenko     function_ref<LogicalResult(ValueRange)> valuesFn) {
1415a9bdd85SOleksandr "Alex" Zinenko   if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) {
1425a9bdd85SOleksandr "Alex" Zinenko     SmallVector<Operation *> operations;
1435a9bdd85SOleksandr "Alex" Zinenko     operations.reserve(values.size());
1445a9bdd85SOleksandr "Alex" Zinenko     for (transform::MappedValue value : values) {
1455a9bdd85SOleksandr "Alex" Zinenko       if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) {
1465a9bdd85SOleksandr "Alex" Zinenko         operations.push_back(op);
1475a9bdd85SOleksandr "Alex" Zinenko         continue;
1485a9bdd85SOleksandr "Alex" Zinenko       }
1495a9bdd85SOleksandr "Alex" Zinenko       return emitSilenceableFailure(handle.getLoc())
1505a9bdd85SOleksandr "Alex" Zinenko              << "wrong kind of value provided for top-level operation handle";
1515a9bdd85SOleksandr "Alex" Zinenko     }
1525a9bdd85SOleksandr "Alex" Zinenko     if (failed(operationsFn(operations)))
1535a9bdd85SOleksandr "Alex" Zinenko       return DiagnosedSilenceableFailure::definiteFailure();
1545a9bdd85SOleksandr "Alex" Zinenko     return DiagnosedSilenceableFailure::success();
1555a9bdd85SOleksandr "Alex" Zinenko   }
1565a9bdd85SOleksandr "Alex" Zinenko 
1575a9bdd85SOleksandr "Alex" Zinenko   if (llvm::isa<transform::TransformValueHandleTypeInterface>(
1585a9bdd85SOleksandr "Alex" Zinenko           handle.getType())) {
1595a9bdd85SOleksandr "Alex" Zinenko     SmallVector<Value> payloadValues;
1605a9bdd85SOleksandr "Alex" Zinenko     payloadValues.reserve(values.size());
1615a9bdd85SOleksandr "Alex" Zinenko     for (transform::MappedValue value : values) {
1625a9bdd85SOleksandr "Alex" Zinenko       if (auto v = llvm::dyn_cast_if_present<Value>(value)) {
1635a9bdd85SOleksandr "Alex" Zinenko         payloadValues.push_back(v);
1645a9bdd85SOleksandr "Alex" Zinenko         continue;
1655a9bdd85SOleksandr "Alex" Zinenko       }
1665a9bdd85SOleksandr "Alex" Zinenko       return emitSilenceableFailure(handle.getLoc())
1675a9bdd85SOleksandr "Alex" Zinenko              << "wrong kind of value provided for the top-level value handle";
1685a9bdd85SOleksandr "Alex" Zinenko     }
1695a9bdd85SOleksandr "Alex" Zinenko     if (failed(valuesFn(payloadValues)))
1705a9bdd85SOleksandr "Alex" Zinenko       return DiagnosedSilenceableFailure::definiteFailure();
1715a9bdd85SOleksandr "Alex" Zinenko     return DiagnosedSilenceableFailure::success();
1725a9bdd85SOleksandr "Alex" Zinenko   }
1735a9bdd85SOleksandr "Alex" Zinenko 
1745a9bdd85SOleksandr "Alex" Zinenko   assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) &&
1755a9bdd85SOleksandr "Alex" Zinenko          "unsupported kind of block argument");
1765a9bdd85SOleksandr "Alex" Zinenko   SmallVector<transform::Param> parameters;
1775a9bdd85SOleksandr "Alex" Zinenko   parameters.reserve(values.size());
1785a9bdd85SOleksandr "Alex" Zinenko   for (transform::MappedValue value : values) {
1795a9bdd85SOleksandr "Alex" Zinenko     if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) {
1805a9bdd85SOleksandr "Alex" Zinenko       parameters.push_back(attr);
1815a9bdd85SOleksandr "Alex" Zinenko       continue;
1825a9bdd85SOleksandr "Alex" Zinenko     }
1835a9bdd85SOleksandr "Alex" Zinenko     return emitSilenceableFailure(handle.getLoc())
1845a9bdd85SOleksandr "Alex" Zinenko            << "wrong kind of value provided for top-level parameter";
1855a9bdd85SOleksandr "Alex" Zinenko   }
1865a9bdd85SOleksandr "Alex" Zinenko   if (failed(paramsFn(parameters)))
1875a9bdd85SOleksandr "Alex" Zinenko     return DiagnosedSilenceableFailure::definiteFailure();
1885a9bdd85SOleksandr "Alex" Zinenko   return DiagnosedSilenceableFailure::success();
1895a9bdd85SOleksandr "Alex" Zinenko }
1905a9bdd85SOleksandr "Alex" Zinenko 
1915a9bdd85SOleksandr "Alex" Zinenko LogicalResult
1925a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::mapBlockArgument(BlockArgument argument,
1935a9bdd85SOleksandr "Alex" Zinenko                                             ArrayRef<MappedValue> values) {
1945a9bdd85SOleksandr "Alex" Zinenko   return dispatchMappedValues(
1955a9bdd85SOleksandr "Alex" Zinenko              argument, values,
1965a9bdd85SOleksandr "Alex" Zinenko              [&](ArrayRef<Operation *> operations) {
1975a9bdd85SOleksandr "Alex" Zinenko                return setPayloadOps(argument, operations);
1985a9bdd85SOleksandr "Alex" Zinenko              },
1995a9bdd85SOleksandr "Alex" Zinenko              [&](ArrayRef<Param> params) {
2005a9bdd85SOleksandr "Alex" Zinenko                return setParams(argument, params);
2015a9bdd85SOleksandr "Alex" Zinenko              },
2025a9bdd85SOleksandr "Alex" Zinenko              [&](ValueRange payloadValues) {
2035a9bdd85SOleksandr "Alex" Zinenko                return setPayloadValues(argument, payloadValues);
2045a9bdd85SOleksandr "Alex" Zinenko              })
2055a9bdd85SOleksandr "Alex" Zinenko       .checkAndReport();
2065a9bdd85SOleksandr "Alex" Zinenko }
2075a9bdd85SOleksandr "Alex" Zinenko 
208e4b04b39SOleksandr "Alex" Zinenko LogicalResult transform::TransformState::mapBlockArguments(
209e4b04b39SOleksandr "Alex" Zinenko     Block::BlockArgListType arguments,
210e4b04b39SOleksandr "Alex" Zinenko     ArrayRef<SmallVector<MappedValue>> mapping) {
211e4b04b39SOleksandr "Alex" Zinenko   for (auto &&[argument, values] : llvm::zip_equal(arguments, mapping))
212e4b04b39SOleksandr "Alex" Zinenko     if (failed(mapBlockArgument(argument, values)))
213e4b04b39SOleksandr "Alex" Zinenko       return failure();
214e4b04b39SOleksandr "Alex" Zinenko   return success();
215e4b04b39SOleksandr "Alex" Zinenko }
216e4b04b39SOleksandr "Alex" Zinenko 
2175a9bdd85SOleksandr "Alex" Zinenko LogicalResult
2185a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::setPayloadOps(Value value,
2195a9bdd85SOleksandr "Alex" Zinenko                                          ArrayRef<Operation *> targets) {
2205a9bdd85SOleksandr "Alex" Zinenko   assert(value != kTopLevelValue &&
2215a9bdd85SOleksandr "Alex" Zinenko          "attempting to reset the transformation root");
2225a9bdd85SOleksandr "Alex" Zinenko   assert(llvm::isa<TransformHandleTypeInterface>(value.getType()) &&
2235a9bdd85SOleksandr "Alex" Zinenko          "wrong handle type");
2245a9bdd85SOleksandr "Alex" Zinenko 
2255a9bdd85SOleksandr "Alex" Zinenko   for (Operation *target : targets) {
2265a9bdd85SOleksandr "Alex" Zinenko     if (target)
2275a9bdd85SOleksandr "Alex" Zinenko       continue;
2285a9bdd85SOleksandr "Alex" Zinenko     return emitError(value.getLoc())
2295a9bdd85SOleksandr "Alex" Zinenko            << "attempting to assign a null payload op to this transform value";
2305a9bdd85SOleksandr "Alex" Zinenko   }
2315a9bdd85SOleksandr "Alex" Zinenko 
2325a9bdd85SOleksandr "Alex" Zinenko   auto iface = llvm::cast<TransformHandleTypeInterface>(value.getType());
2335a9bdd85SOleksandr "Alex" Zinenko   DiagnosedSilenceableFailure result =
2345a9bdd85SOleksandr "Alex" Zinenko       iface.checkPayload(value.getLoc(), targets);
2355a9bdd85SOleksandr "Alex" Zinenko   if (failed(result.checkAndReport()))
2365a9bdd85SOleksandr "Alex" Zinenko     return failure();
2375a9bdd85SOleksandr "Alex" Zinenko 
2385a9bdd85SOleksandr "Alex" Zinenko   // Setting new payload for the value without cleaning it first is a misuse of
2395a9bdd85SOleksandr "Alex" Zinenko   // the API, assert here.
2405262865aSKazu Hirata   SmallVector<Operation *> storedTargets(targets);
2415a9bdd85SOleksandr "Alex" Zinenko   Mappings &mappings = getMapping(value);
2425a9bdd85SOleksandr "Alex" Zinenko   bool inserted =
2435a9bdd85SOleksandr "Alex" Zinenko       mappings.direct.insert({value, std::move(storedTargets)}).second;
2445a9bdd85SOleksandr "Alex" Zinenko   assert(inserted && "value is already associated with another list");
2455a9bdd85SOleksandr "Alex" Zinenko   (void)inserted;
2465a9bdd85SOleksandr "Alex" Zinenko 
2475a9bdd85SOleksandr "Alex" Zinenko   for (Operation *op : targets)
2485a9bdd85SOleksandr "Alex" Zinenko     mappings.reverse[op].push_back(value);
2495a9bdd85SOleksandr "Alex" Zinenko 
2505a9bdd85SOleksandr "Alex" Zinenko   return success();
2515a9bdd85SOleksandr "Alex" Zinenko }
2525a9bdd85SOleksandr "Alex" Zinenko 
2535a9bdd85SOleksandr "Alex" Zinenko LogicalResult
2545a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::setPayloadValues(Value handle,
2555a9bdd85SOleksandr "Alex" Zinenko                                             ValueRange payloadValues) {
2565a9bdd85SOleksandr "Alex" Zinenko   assert(handle != nullptr && "attempting to set params for a null value");
2575a9bdd85SOleksandr "Alex" Zinenko   assert(llvm::isa<TransformValueHandleTypeInterface>(handle.getType()) &&
2585a9bdd85SOleksandr "Alex" Zinenko          "wrong handle type");
2595a9bdd85SOleksandr "Alex" Zinenko 
2605a9bdd85SOleksandr "Alex" Zinenko   for (Value payload : payloadValues) {
2615a9bdd85SOleksandr "Alex" Zinenko     if (payload)
2625a9bdd85SOleksandr "Alex" Zinenko       continue;
2635a9bdd85SOleksandr "Alex" Zinenko     return emitError(handle.getLoc()) << "attempting to assign a null payload "
2645a9bdd85SOleksandr "Alex" Zinenko                                          "value to this transform handle";
2655a9bdd85SOleksandr "Alex" Zinenko   }
2665a9bdd85SOleksandr "Alex" Zinenko 
2675a9bdd85SOleksandr "Alex" Zinenko   auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.getType());
2685a9bdd85SOleksandr "Alex" Zinenko   SmallVector<Value> payloadValueVector = llvm::to_vector(payloadValues);
2695a9bdd85SOleksandr "Alex" Zinenko   DiagnosedSilenceableFailure result =
2705a9bdd85SOleksandr "Alex" Zinenko       iface.checkPayload(handle.getLoc(), payloadValueVector);
2715a9bdd85SOleksandr "Alex" Zinenko   if (failed(result.checkAndReport()))
2725a9bdd85SOleksandr "Alex" Zinenko     return failure();
2735a9bdd85SOleksandr "Alex" Zinenko 
2745a9bdd85SOleksandr "Alex" Zinenko   Mappings &mappings = getMapping(handle);
2755a9bdd85SOleksandr "Alex" Zinenko   bool inserted =
2765a9bdd85SOleksandr "Alex" Zinenko       mappings.values.insert({handle, std::move(payloadValueVector)}).second;
2775a9bdd85SOleksandr "Alex" Zinenko   assert(
2785a9bdd85SOleksandr "Alex" Zinenko       inserted &&
2795a9bdd85SOleksandr "Alex" Zinenko       "value handle is already associated with another list of payload values");
2805a9bdd85SOleksandr "Alex" Zinenko   (void)inserted;
2815a9bdd85SOleksandr "Alex" Zinenko 
2825a9bdd85SOleksandr "Alex" Zinenko   for (Value payload : payloadValues)
2835a9bdd85SOleksandr "Alex" Zinenko     mappings.reverseValues[payload].push_back(handle);
2845a9bdd85SOleksandr "Alex" Zinenko 
2855a9bdd85SOleksandr "Alex" Zinenko   return success();
2865a9bdd85SOleksandr "Alex" Zinenko }
2875a9bdd85SOleksandr "Alex" Zinenko 
2885a9bdd85SOleksandr "Alex" Zinenko LogicalResult transform::TransformState::setParams(Value value,
2895a9bdd85SOleksandr "Alex" Zinenko                                                    ArrayRef<Param> params) {
2905a9bdd85SOleksandr "Alex" Zinenko   assert(value != nullptr && "attempting to set params for a null value");
2915a9bdd85SOleksandr "Alex" Zinenko 
2925a9bdd85SOleksandr "Alex" Zinenko   for (Attribute attr : params) {
2935a9bdd85SOleksandr "Alex" Zinenko     if (attr)
2945a9bdd85SOleksandr "Alex" Zinenko       continue;
2955a9bdd85SOleksandr "Alex" Zinenko     return emitError(value.getLoc())
2965a9bdd85SOleksandr "Alex" Zinenko            << "attempting to assign a null parameter to this transform value";
2975a9bdd85SOleksandr "Alex" Zinenko   }
2985a9bdd85SOleksandr "Alex" Zinenko 
2995a9bdd85SOleksandr "Alex" Zinenko   auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.getType());
3005a9bdd85SOleksandr "Alex" Zinenko   assert(value &&
3015a9bdd85SOleksandr "Alex" Zinenko          "cannot associate parameter with a value of non-parameter type");
3025a9bdd85SOleksandr "Alex" Zinenko   DiagnosedSilenceableFailure result =
3035a9bdd85SOleksandr "Alex" Zinenko       valueType.checkPayload(value.getLoc(), params);
3045a9bdd85SOleksandr "Alex" Zinenko   if (failed(result.checkAndReport()))
3055a9bdd85SOleksandr "Alex" Zinenko     return failure();
3065a9bdd85SOleksandr "Alex" Zinenko 
3075a9bdd85SOleksandr "Alex" Zinenko   Mappings &mappings = getMapping(value);
3085a9bdd85SOleksandr "Alex" Zinenko   bool inserted =
3095a9bdd85SOleksandr "Alex" Zinenko       mappings.params.insert({value, llvm::to_vector(params)}).second;
3105a9bdd85SOleksandr "Alex" Zinenko   assert(inserted && "value is already associated with another list of params");
3115a9bdd85SOleksandr "Alex" Zinenko   (void)inserted;
3125a9bdd85SOleksandr "Alex" Zinenko   return success();
3135a9bdd85SOleksandr "Alex" Zinenko }
3145a9bdd85SOleksandr "Alex" Zinenko 
3155a9bdd85SOleksandr "Alex" Zinenko template <typename Mapping, typename Key, typename Mapped>
3165a9bdd85SOleksandr "Alex" Zinenko void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) {
3175a9bdd85SOleksandr "Alex" Zinenko   auto it = mapping.find(key);
3185a9bdd85SOleksandr "Alex" Zinenko   if (it == mapping.end())
3195a9bdd85SOleksandr "Alex" Zinenko     return;
3205a9bdd85SOleksandr "Alex" Zinenko 
3215a9bdd85SOleksandr "Alex" Zinenko   llvm::erase(it->getSecond(), mapped);
3225a9bdd85SOleksandr "Alex" Zinenko   if (it->getSecond().empty())
3235a9bdd85SOleksandr "Alex" Zinenko     mapping.erase(it);
3245a9bdd85SOleksandr "Alex" Zinenko }
3255a9bdd85SOleksandr "Alex" Zinenko 
3265a9bdd85SOleksandr "Alex" Zinenko void transform::TransformState::forgetMapping(Value opHandle,
3275a9bdd85SOleksandr "Alex" Zinenko                                               ValueRange origOpFlatResults,
3285a9bdd85SOleksandr "Alex" Zinenko                                               bool allowOutOfScope) {
3295a9bdd85SOleksandr "Alex" Zinenko   Mappings &mappings = getMapping(opHandle, allowOutOfScope);
3305a9bdd85SOleksandr "Alex" Zinenko   for (Operation *op : mappings.direct[opHandle])
3315a9bdd85SOleksandr "Alex" Zinenko     dropMappingEntry(mappings.reverse, op, opHandle);
3325a9bdd85SOleksandr "Alex" Zinenko   mappings.direct.erase(opHandle);
3336c7a3f80SMehdi Amini #if LLVM_ENABLE_ABI_BREAKING_CHECKS
3345a9bdd85SOleksandr "Alex" Zinenko   // Payload IR is removed from the mapping. This invalidates the respective
3355a9bdd85SOleksandr "Alex" Zinenko   // iterators.
3365a9bdd85SOleksandr "Alex" Zinenko   mappings.incrementTimestamp(opHandle);
3375a9bdd85SOleksandr "Alex" Zinenko #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
3385a9bdd85SOleksandr "Alex" Zinenko 
3395a9bdd85SOleksandr "Alex" Zinenko   for (Value opResult : origOpFlatResults) {
3405a9bdd85SOleksandr "Alex" Zinenko     SmallVector<Value> resultHandles;
3415a9bdd85SOleksandr "Alex" Zinenko     (void)getHandlesForPayloadValue(opResult, resultHandles);
3425a9bdd85SOleksandr "Alex" Zinenko     for (Value resultHandle : resultHandles) {
3435a9bdd85SOleksandr "Alex" Zinenko       Mappings &localMappings = getMapping(resultHandle);
3445a9bdd85SOleksandr "Alex" Zinenko       dropMappingEntry(localMappings.values, resultHandle, opResult);
3456c7a3f80SMehdi Amini #if LLVM_ENABLE_ABI_BREAKING_CHECKS
3465a9bdd85SOleksandr "Alex" Zinenko       // Payload IR is removed from the mapping. This invalidates the respective
3475a9bdd85SOleksandr "Alex" Zinenko       // iterators.
3485a9bdd85SOleksandr "Alex" Zinenko       mappings.incrementTimestamp(resultHandle);
3495a9bdd85SOleksandr "Alex" Zinenko #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
3505a9bdd85SOleksandr "Alex" Zinenko       dropMappingEntry(localMappings.reverseValues, opResult, resultHandle);
3515a9bdd85SOleksandr "Alex" Zinenko     }
3525a9bdd85SOleksandr "Alex" Zinenko   }
3535a9bdd85SOleksandr "Alex" Zinenko }
3545a9bdd85SOleksandr "Alex" Zinenko 
3555a9bdd85SOleksandr "Alex" Zinenko void transform::TransformState::forgetValueMapping(
3565a9bdd85SOleksandr "Alex" Zinenko     Value valueHandle, ArrayRef<Operation *> payloadOperations) {
3575a9bdd85SOleksandr "Alex" Zinenko   Mappings &mappings = getMapping(valueHandle);
3585a9bdd85SOleksandr "Alex" Zinenko   for (Value payloadValue : mappings.reverseValues[valueHandle])
3595a9bdd85SOleksandr "Alex" Zinenko     dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle);
3605a9bdd85SOleksandr "Alex" Zinenko   mappings.values.erase(valueHandle);
3616c7a3f80SMehdi Amini #if LLVM_ENABLE_ABI_BREAKING_CHECKS
3625a9bdd85SOleksandr "Alex" Zinenko   // Payload IR is removed from the mapping. This invalidates the respective
3635a9bdd85SOleksandr "Alex" Zinenko   // iterators.
3645a9bdd85SOleksandr "Alex" Zinenko   mappings.incrementTimestamp(valueHandle);
3655a9bdd85SOleksandr "Alex" Zinenko #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
3665a9bdd85SOleksandr "Alex" Zinenko 
3675a9bdd85SOleksandr "Alex" Zinenko   for (Operation *payloadOp : payloadOperations) {
3685a9bdd85SOleksandr "Alex" Zinenko     SmallVector<Value> opHandles;
3695a9bdd85SOleksandr "Alex" Zinenko     (void)getHandlesForPayloadOp(payloadOp, opHandles);
3705a9bdd85SOleksandr "Alex" Zinenko     for (Value opHandle : opHandles) {
3715a9bdd85SOleksandr "Alex" Zinenko       Mappings &localMappings = getMapping(opHandle);
3725a9bdd85SOleksandr "Alex" Zinenko       dropMappingEntry(localMappings.direct, opHandle, payloadOp);
3735a9bdd85SOleksandr "Alex" Zinenko       dropMappingEntry(localMappings.reverse, payloadOp, opHandle);
3745a9bdd85SOleksandr "Alex" Zinenko 
3756c7a3f80SMehdi Amini #if LLVM_ENABLE_ABI_BREAKING_CHECKS
3765a9bdd85SOleksandr "Alex" Zinenko       // Payload IR is removed from the mapping. This invalidates the respective
3775a9bdd85SOleksandr "Alex" Zinenko       // iterators.
3785a9bdd85SOleksandr "Alex" Zinenko       localMappings.incrementTimestamp(opHandle);
3795a9bdd85SOleksandr "Alex" Zinenko #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
3805a9bdd85SOleksandr "Alex" Zinenko     }
3815a9bdd85SOleksandr "Alex" Zinenko   }
3825a9bdd85SOleksandr "Alex" Zinenko }
3835a9bdd85SOleksandr "Alex" Zinenko 
3845a9bdd85SOleksandr "Alex" Zinenko LogicalResult
3855a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::replacePayloadOp(Operation *op,
3865a9bdd85SOleksandr "Alex" Zinenko                                             Operation *replacement) {
3875a9bdd85SOleksandr "Alex" Zinenko   // TODO: consider invalidating the handles to nested objects here.
3885a9bdd85SOleksandr "Alex" Zinenko 
3895a9bdd85SOleksandr "Alex" Zinenko #ifndef NDEBUG
3905a9bdd85SOleksandr "Alex" Zinenko   for (Value opResult : op->getResults()) {
3915a9bdd85SOleksandr "Alex" Zinenko     SmallVector<Value> valueHandles;
3925a9bdd85SOleksandr "Alex" Zinenko     (void)getHandlesForPayloadValue(opResult, valueHandles,
3935a9bdd85SOleksandr "Alex" Zinenko                                     /*includeOutOfScope=*/true);
3945a9bdd85SOleksandr "Alex" Zinenko     assert(valueHandles.empty() && "expected no mapping to old results");
3955a9bdd85SOleksandr "Alex" Zinenko   }
3965a9bdd85SOleksandr "Alex" Zinenko #endif // NDEBUG
3975a9bdd85SOleksandr "Alex" Zinenko 
3985a9bdd85SOleksandr "Alex" Zinenko   // Drop the mapping between the op and all handles that point to it. Fail if
3995a9bdd85SOleksandr "Alex" Zinenko   // there are no handles.
4005a9bdd85SOleksandr "Alex" Zinenko   SmallVector<Value> opHandles;
4015a9bdd85SOleksandr "Alex" Zinenko   if (failed(getHandlesForPayloadOp(op, opHandles, /*includeOutOfScope=*/true)))
4025a9bdd85SOleksandr "Alex" Zinenko     return failure();
4035a9bdd85SOleksandr "Alex" Zinenko   for (Value handle : opHandles) {
4045a9bdd85SOleksandr "Alex" Zinenko     Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
4055a9bdd85SOleksandr "Alex" Zinenko     dropMappingEntry(mappings.reverse, op, handle);
4065a9bdd85SOleksandr "Alex" Zinenko   }
4075a9bdd85SOleksandr "Alex" Zinenko 
4085a9bdd85SOleksandr "Alex" Zinenko   // Replace the pointed-to object of all handles with the replacement object.
4095a9bdd85SOleksandr "Alex" Zinenko   // In case a payload op was erased (replacement object is nullptr), a nullptr
4105a9bdd85SOleksandr "Alex" Zinenko   // is stored in the mapping. These nullptrs are removed after each transform.
4115a9bdd85SOleksandr "Alex" Zinenko   // Furthermore, nullptrs are not enumerated by payload op iterators. The
4125a9bdd85SOleksandr "Alex" Zinenko   // relative order of ops is preserved.
4135a9bdd85SOleksandr "Alex" Zinenko   //
4145a9bdd85SOleksandr "Alex" Zinenko   // Removing an op from the mapping would be problematic because removing an
4155a9bdd85SOleksandr "Alex" Zinenko   // element from an array invalidates iterators; merely changing the value of
4165a9bdd85SOleksandr "Alex" Zinenko   // elements does not.
4175a9bdd85SOleksandr "Alex" Zinenko   for (Value handle : opHandles) {
4185a9bdd85SOleksandr "Alex" Zinenko     Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
4195a9bdd85SOleksandr "Alex" Zinenko     auto it = mappings.direct.find(handle);
4205a9bdd85SOleksandr "Alex" Zinenko     if (it == mappings.direct.end())
4215a9bdd85SOleksandr "Alex" Zinenko       continue;
4225a9bdd85SOleksandr "Alex" Zinenko 
4235a9bdd85SOleksandr "Alex" Zinenko     SmallVector<Operation *, 2> &association = it->getSecond();
4245a9bdd85SOleksandr "Alex" Zinenko     // Note that an operation may be associated with the handle more than once.
4255a9bdd85SOleksandr "Alex" Zinenko     for (Operation *&mapped : association) {
4265a9bdd85SOleksandr "Alex" Zinenko       if (mapped == op)
4275a9bdd85SOleksandr "Alex" Zinenko         mapped = replacement;
4285a9bdd85SOleksandr "Alex" Zinenko     }
4295a9bdd85SOleksandr "Alex" Zinenko 
4305a9bdd85SOleksandr "Alex" Zinenko     if (replacement) {
4315a9bdd85SOleksandr "Alex" Zinenko       mappings.reverse[replacement].push_back(handle);
4325a9bdd85SOleksandr "Alex" Zinenko     } else {
4335a9bdd85SOleksandr "Alex" Zinenko       opHandlesToCompact.insert(handle);
4345a9bdd85SOleksandr "Alex" Zinenko     }
4355a9bdd85SOleksandr "Alex" Zinenko   }
4365a9bdd85SOleksandr "Alex" Zinenko 
4375a9bdd85SOleksandr "Alex" Zinenko   return success();
4385a9bdd85SOleksandr "Alex" Zinenko }
4395a9bdd85SOleksandr "Alex" Zinenko 
4405a9bdd85SOleksandr "Alex" Zinenko LogicalResult
4415a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::replacePayloadValue(Value value, Value replacement) {
4425a9bdd85SOleksandr "Alex" Zinenko   SmallVector<Value> valueHandles;
4435a9bdd85SOleksandr "Alex" Zinenko   if (failed(getHandlesForPayloadValue(value, valueHandles,
4445a9bdd85SOleksandr "Alex" Zinenko                                        /*includeOutOfScope=*/true)))
4455a9bdd85SOleksandr "Alex" Zinenko     return failure();
4465a9bdd85SOleksandr "Alex" Zinenko 
4475a9bdd85SOleksandr "Alex" Zinenko   for (Value handle : valueHandles) {
4485a9bdd85SOleksandr "Alex" Zinenko     Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
4495a9bdd85SOleksandr "Alex" Zinenko     dropMappingEntry(mappings.reverseValues, value, handle);
4505a9bdd85SOleksandr "Alex" Zinenko 
4515a9bdd85SOleksandr "Alex" Zinenko     // If replacing with null, that is erasing the mapping, drop the mapping
4525a9bdd85SOleksandr "Alex" Zinenko     // between the handles and the IR objects
4535a9bdd85SOleksandr "Alex" Zinenko     if (!replacement) {
4545a9bdd85SOleksandr "Alex" Zinenko       dropMappingEntry(mappings.values, handle, value);
4556c7a3f80SMehdi Amini #if LLVM_ENABLE_ABI_BREAKING_CHECKS
4565a9bdd85SOleksandr "Alex" Zinenko       // Payload IR is removed from the mapping. This invalidates the respective
4575a9bdd85SOleksandr "Alex" Zinenko       // iterators.
4585a9bdd85SOleksandr "Alex" Zinenko       mappings.incrementTimestamp(handle);
4595a9bdd85SOleksandr "Alex" Zinenko #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
4605a9bdd85SOleksandr "Alex" Zinenko     } else {
4615a9bdd85SOleksandr "Alex" Zinenko       auto it = mappings.values.find(handle);
4625a9bdd85SOleksandr "Alex" Zinenko       if (it == mappings.values.end())
4635a9bdd85SOleksandr "Alex" Zinenko         continue;
4645a9bdd85SOleksandr "Alex" Zinenko 
4655a9bdd85SOleksandr "Alex" Zinenko       SmallVector<Value> &association = it->getSecond();
4665a9bdd85SOleksandr "Alex" Zinenko       for (Value &mapped : association) {
4675a9bdd85SOleksandr "Alex" Zinenko         if (mapped == value)
4685a9bdd85SOleksandr "Alex" Zinenko           mapped = replacement;
4695a9bdd85SOleksandr "Alex" Zinenko       }
4705a9bdd85SOleksandr "Alex" Zinenko       mappings.reverseValues[replacement].push_back(handle);
4715a9bdd85SOleksandr "Alex" Zinenko     }
4725a9bdd85SOleksandr "Alex" Zinenko   }
4735a9bdd85SOleksandr "Alex" Zinenko 
4745a9bdd85SOleksandr "Alex" Zinenko   return success();
4755a9bdd85SOleksandr "Alex" Zinenko }
4765a9bdd85SOleksandr "Alex" Zinenko 
4775a9bdd85SOleksandr "Alex" Zinenko void transform::TransformState::recordOpHandleInvalidationOne(
4785a9bdd85SOleksandr "Alex" Zinenko     OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors,
4795a9bdd85SOleksandr "Alex" Zinenko     Operation *payloadOp, Value otherHandle, Value throughValue,
4805a9bdd85SOleksandr "Alex" Zinenko     transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
4815a9bdd85SOleksandr "Alex" Zinenko   // If the op is associated with invalidated handle, skip the check as it
4825a9bdd85SOleksandr "Alex" Zinenko   // may be reading invalid IR. This also ensures we report the first
4835a9bdd85SOleksandr "Alex" Zinenko   // invalidation and not the last one.
4845a9bdd85SOleksandr "Alex" Zinenko   if (invalidatedHandles.count(otherHandle) ||
4855a9bdd85SOleksandr "Alex" Zinenko       newlyInvalidated.count(otherHandle))
4865a9bdd85SOleksandr "Alex" Zinenko     return;
4875a9bdd85SOleksandr "Alex" Zinenko 
4885a9bdd85SOleksandr "Alex" Zinenko   FULL_LDBG("--recordOpHandleInvalidationOne\n");
4895a9bdd85SOleksandr "Alex" Zinenko   DEBUG_WITH_TYPE(
4905a9bdd85SOleksandr "Alex" Zinenko       DEBUG_TYPE_FULL,
4915a9bdd85SOleksandr "Alex" Zinenko       llvm::interleaveComma(potentialAncestors, DBGS() << "--ancestors: ",
4925a9bdd85SOleksandr "Alex" Zinenko                             [](Operation *op) { llvm::dbgs() << *op; });
4935a9bdd85SOleksandr "Alex" Zinenko       llvm::dbgs() << "\n");
4945a9bdd85SOleksandr "Alex" Zinenko 
4955a9bdd85SOleksandr "Alex" Zinenko   Operation *owner = consumingHandle.getOwner();
4965a9bdd85SOleksandr "Alex" Zinenko   unsigned operandNo = consumingHandle.getOperandNumber();
4975a9bdd85SOleksandr "Alex" Zinenko   for (Operation *ancestor : potentialAncestors) {
4985a9bdd85SOleksandr "Alex" Zinenko     // clang-format off
4995a9bdd85SOleksandr "Alex" Zinenko     DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
5005a9bdd85SOleksandr "Alex" Zinenko       { (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); });
5015a9bdd85SOleksandr "Alex" Zinenko     DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
5025a9bdd85SOleksandr "Alex" Zinenko       { (DBGS() << "----of payload with name: "
5035a9bdd85SOleksandr "Alex" Zinenko                 << payloadOp->getName().getIdentifier() << "\n"); });
5045a9bdd85SOleksandr "Alex" Zinenko     DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
5055a9bdd85SOleksandr "Alex" Zinenko       { (DBGS() << "----of payload: " << *payloadOp << "\n"); });
5065a9bdd85SOleksandr "Alex" Zinenko     // clang-format on
5075a9bdd85SOleksandr "Alex" Zinenko     if (!ancestor->isAncestor(payloadOp))
5085a9bdd85SOleksandr "Alex" Zinenko       continue;
5095a9bdd85SOleksandr "Alex" Zinenko 
5105a9bdd85SOleksandr "Alex" Zinenko     // Make sure the error-reporting lambda doesn't capture anything
5115a9bdd85SOleksandr "Alex" Zinenko     // by-reference because it will go out of scope. Additionally, extract
5125a9bdd85SOleksandr "Alex" Zinenko     // location from Payload IR ops because the ops themselves may be
5135a9bdd85SOleksandr "Alex" Zinenko     // deleted before the lambda gets called.
5145a9bdd85SOleksandr "Alex" Zinenko     Location ancestorLoc = ancestor->getLoc();
5155a9bdd85SOleksandr "Alex" Zinenko     Location opLoc = payloadOp->getLoc();
5165a9bdd85SOleksandr "Alex" Zinenko     std::optional<Location> throughValueLoc =
5175a9bdd85SOleksandr "Alex" Zinenko         throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt;
5185a9bdd85SOleksandr "Alex" Zinenko     newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo,
5195a9bdd85SOleksandr "Alex" Zinenko                                      otherHandle,
5205a9bdd85SOleksandr "Alex" Zinenko                                      throughValueLoc](Location currentLoc) {
5215a9bdd85SOleksandr "Alex" Zinenko       InFlightDiagnostic diag = emitError(currentLoc)
5225a9bdd85SOleksandr "Alex" Zinenko                                 << "op uses a handle invalidated by a "
5235a9bdd85SOleksandr "Alex" Zinenko                                    "previously executed transform op";
5245a9bdd85SOleksandr "Alex" Zinenko       diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops";
5255a9bdd85SOleksandr "Alex" Zinenko       diag.attachNote(owner->getLoc())
5265a9bdd85SOleksandr "Alex" Zinenko           << "invalidated by this transform op that consumes its operand #"
5275a9bdd85SOleksandr "Alex" Zinenko           << operandNo
5285a9bdd85SOleksandr "Alex" Zinenko           << " and invalidates all handles to payload IR entities associated "
5295a9bdd85SOleksandr "Alex" Zinenko              "with this operand and entities nested in them";
5305a9bdd85SOleksandr "Alex" Zinenko       diag.attachNote(ancestorLoc) << "ancestor payload op";
5315a9bdd85SOleksandr "Alex" Zinenko       diag.attachNote(opLoc) << "nested payload op";
5325a9bdd85SOleksandr "Alex" Zinenko       if (throughValueLoc) {
5335a9bdd85SOleksandr "Alex" Zinenko         diag.attachNote(*throughValueLoc)
5345a9bdd85SOleksandr "Alex" Zinenko             << "consumed handle points to this payload value";
5355a9bdd85SOleksandr "Alex" Zinenko       }
5365a9bdd85SOleksandr "Alex" Zinenko     };
5375a9bdd85SOleksandr "Alex" Zinenko   }
5385a9bdd85SOleksandr "Alex" Zinenko }
5395a9bdd85SOleksandr "Alex" Zinenko 
5405a9bdd85SOleksandr "Alex" Zinenko void transform::TransformState::recordValueHandleInvalidationByOpHandleOne(
5415a9bdd85SOleksandr "Alex" Zinenko     OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors,
5425a9bdd85SOleksandr "Alex" Zinenko     Value payloadValue, Value valueHandle,
5435a9bdd85SOleksandr "Alex" Zinenko     transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
5445a9bdd85SOleksandr "Alex" Zinenko   // If the op is associated with invalidated handle, skip the check as it
5455a9bdd85SOleksandr "Alex" Zinenko   // may be reading invalid IR. This also ensures we report the first
5465a9bdd85SOleksandr "Alex" Zinenko   // invalidation and not the last one.
5475a9bdd85SOleksandr "Alex" Zinenko   if (invalidatedHandles.count(valueHandle) ||
5485a9bdd85SOleksandr "Alex" Zinenko       newlyInvalidated.count(valueHandle))
5495a9bdd85SOleksandr "Alex" Zinenko     return;
5505a9bdd85SOleksandr "Alex" Zinenko 
5515a9bdd85SOleksandr "Alex" Zinenko   for (Operation *ancestor : potentialAncestors) {
5525a9bdd85SOleksandr "Alex" Zinenko     Operation *definingOp;
5535a9bdd85SOleksandr "Alex" Zinenko     std::optional<unsigned> resultNo;
5545a9bdd85SOleksandr "Alex" Zinenko     unsigned argumentNo = std::numeric_limits<unsigned>::max();
5555a9bdd85SOleksandr "Alex" Zinenko     unsigned blockNo = std::numeric_limits<unsigned>::max();
5565a9bdd85SOleksandr "Alex" Zinenko     unsigned regionNo = std::numeric_limits<unsigned>::max();
5575a9bdd85SOleksandr "Alex" Zinenko     if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
5585a9bdd85SOleksandr "Alex" Zinenko       definingOp = opResult.getOwner();
5595a9bdd85SOleksandr "Alex" Zinenko       resultNo = opResult.getResultNumber();
5605a9bdd85SOleksandr "Alex" Zinenko     } else {
5615a9bdd85SOleksandr "Alex" Zinenko       auto arg = llvm::cast<BlockArgument>(payloadValue);
5625a9bdd85SOleksandr "Alex" Zinenko       definingOp = arg.getParentBlock()->getParentOp();
5635a9bdd85SOleksandr "Alex" Zinenko       argumentNo = arg.getArgNumber();
5645a9bdd85SOleksandr "Alex" Zinenko       blockNo = std::distance(arg.getOwner()->getParent()->begin(),
5655a9bdd85SOleksandr "Alex" Zinenko                               arg.getOwner()->getIterator());
5665a9bdd85SOleksandr "Alex" Zinenko       regionNo = arg.getOwner()->getParent()->getRegionNumber();
5675a9bdd85SOleksandr "Alex" Zinenko     }
5685a9bdd85SOleksandr "Alex" Zinenko     assert(definingOp && "expected the value to be defined by an op as result "
5695a9bdd85SOleksandr "Alex" Zinenko                          "or block argument");
5705a9bdd85SOleksandr "Alex" Zinenko     if (!ancestor->isAncestor(definingOp))
5715a9bdd85SOleksandr "Alex" Zinenko       continue;
5725a9bdd85SOleksandr "Alex" Zinenko 
5735a9bdd85SOleksandr "Alex" Zinenko     Operation *owner = opHandle.getOwner();
5745a9bdd85SOleksandr "Alex" Zinenko     unsigned operandNo = opHandle.getOperandNumber();
5755a9bdd85SOleksandr "Alex" Zinenko     Location ancestorLoc = ancestor->getLoc();
5765a9bdd85SOleksandr "Alex" Zinenko     Location opLoc = definingOp->getLoc();
5775a9bdd85SOleksandr "Alex" Zinenko     Location valueLoc = payloadValue.getLoc();
5785a9bdd85SOleksandr "Alex" Zinenko     newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo,
5795a9bdd85SOleksandr "Alex" Zinenko                                      argumentNo, blockNo, regionNo, ancestorLoc,
5805a9bdd85SOleksandr "Alex" Zinenko                                      opLoc, valueLoc](Location currentLoc) {
5815a9bdd85SOleksandr "Alex" Zinenko       InFlightDiagnostic diag = emitError(currentLoc)
5825a9bdd85SOleksandr "Alex" Zinenko                                 << "op uses a handle invalidated by a "
5835a9bdd85SOleksandr "Alex" Zinenko                                    "previously executed transform op";
5845a9bdd85SOleksandr "Alex" Zinenko       diag.attachNote(valueHandle.getLoc()) << "invalidated handle";
5855a9bdd85SOleksandr "Alex" Zinenko       diag.attachNote(owner->getLoc())
5865a9bdd85SOleksandr "Alex" Zinenko           << "invalidated by this transform op that consumes its operand #"
5875a9bdd85SOleksandr "Alex" Zinenko           << operandNo
5885a9bdd85SOleksandr "Alex" Zinenko           << " and invalidates all handles to payload IR entities "
5895a9bdd85SOleksandr "Alex" Zinenko              "associated with this operand and entities nested in them";
5905a9bdd85SOleksandr "Alex" Zinenko       diag.attachNote(ancestorLoc)
5915a9bdd85SOleksandr "Alex" Zinenko           << "ancestor op associated with the consumed handle";
5925a9bdd85SOleksandr "Alex" Zinenko       if (resultNo) {
5935a9bdd85SOleksandr "Alex" Zinenko         diag.attachNote(opLoc)
5945a9bdd85SOleksandr "Alex" Zinenko             << "op defining the value as result #" << *resultNo;
5955a9bdd85SOleksandr "Alex" Zinenko       } else {
5965a9bdd85SOleksandr "Alex" Zinenko         diag.attachNote(opLoc)
5975a9bdd85SOleksandr "Alex" Zinenko             << "op defining the value as block argument #" << argumentNo
5985a9bdd85SOleksandr "Alex" Zinenko             << " of block #" << blockNo << " in region #" << regionNo;
5995a9bdd85SOleksandr "Alex" Zinenko       }
6005a9bdd85SOleksandr "Alex" Zinenko       diag.attachNote(valueLoc) << "payload value";
6015a9bdd85SOleksandr "Alex" Zinenko     };
6025a9bdd85SOleksandr "Alex" Zinenko   }
6035a9bdd85SOleksandr "Alex" Zinenko }
6045a9bdd85SOleksandr "Alex" Zinenko 
6055a9bdd85SOleksandr "Alex" Zinenko void transform::TransformState::recordOpHandleInvalidation(
6065a9bdd85SOleksandr "Alex" Zinenko     OpOperand &handle, ArrayRef<Operation *> potentialAncestors,
6075a9bdd85SOleksandr "Alex" Zinenko     Value throughValue,
6085a9bdd85SOleksandr "Alex" Zinenko     transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
6095a9bdd85SOleksandr "Alex" Zinenko 
6105a9bdd85SOleksandr "Alex" Zinenko   if (potentialAncestors.empty()) {
6115a9bdd85SOleksandr "Alex" Zinenko     DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
6125a9bdd85SOleksandr "Alex" Zinenko       (DBGS() << "----recording invalidation for empty handle: " << handle.get()
6135a9bdd85SOleksandr "Alex" Zinenko               << "\n");
6145a9bdd85SOleksandr "Alex" Zinenko     });
6155a9bdd85SOleksandr "Alex" Zinenko 
6165a9bdd85SOleksandr "Alex" Zinenko     Operation *owner = handle.getOwner();
6175a9bdd85SOleksandr "Alex" Zinenko     unsigned operandNo = handle.getOperandNumber();
6185a9bdd85SOleksandr "Alex" Zinenko     newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) {
6195a9bdd85SOleksandr "Alex" Zinenko       InFlightDiagnostic diag = emitError(currentLoc)
6205a9bdd85SOleksandr "Alex" Zinenko                                 << "op uses a handle associated with empty "
6215a9bdd85SOleksandr "Alex" Zinenko                                    "payload and invalidated by a "
6225a9bdd85SOleksandr "Alex" Zinenko                                    "previously executed transform op";
6235a9bdd85SOleksandr "Alex" Zinenko       diag.attachNote(owner->getLoc())
6245a9bdd85SOleksandr "Alex" Zinenko           << "invalidated by this transform op that consumes its operand #"
6255a9bdd85SOleksandr "Alex" Zinenko           << operandNo;
6265a9bdd85SOleksandr "Alex" Zinenko     };
6275a9bdd85SOleksandr "Alex" Zinenko     return;
6285a9bdd85SOleksandr "Alex" Zinenko   }
6295a9bdd85SOleksandr "Alex" Zinenko 
6305a9bdd85SOleksandr "Alex" Zinenko   // Iterate over the mapping and invalidate aliasing handles. This is quite
6315a9bdd85SOleksandr "Alex" Zinenko   // expensive and only necessary for error reporting in case of transform
6325a9bdd85SOleksandr "Alex" Zinenko   // dialect misuse with dangling handles. Iteration over the handles is based
6335a9bdd85SOleksandr "Alex" Zinenko   // on the assumption that the number of handles is significantly less than the
6345a9bdd85SOleksandr "Alex" Zinenko   // number of IR objects (operations and values). Alternatively, we could walk
6355a9bdd85SOleksandr "Alex" Zinenko   // the IR nested in each payload op associated with the given handle and look
6365a9bdd85SOleksandr "Alex" Zinenko   // for handles associated with each operation and value.
6375a9bdd85SOleksandr "Alex" Zinenko   for (const auto &[region, mapping] : llvm::reverse(mappings)) {
6385a9bdd85SOleksandr "Alex" Zinenko     // Go over all op handle mappings and mark as invalidated any handle
6395a9bdd85SOleksandr "Alex" Zinenko     // pointing to any of the payload ops associated with the given handle or
6405a9bdd85SOleksandr "Alex" Zinenko     // any op nested in them.
6415a9bdd85SOleksandr "Alex" Zinenko     for (const auto &[payloadOp, otherHandles] : mapping->reverse) {
6425a9bdd85SOleksandr "Alex" Zinenko       for (Value otherHandle : otherHandles)
6435a9bdd85SOleksandr "Alex" Zinenko         recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp,
6445a9bdd85SOleksandr "Alex" Zinenko                                       otherHandle, throughValue,
6455a9bdd85SOleksandr "Alex" Zinenko                                       newlyInvalidated);
6465a9bdd85SOleksandr "Alex" Zinenko     }
6475a9bdd85SOleksandr "Alex" Zinenko     // Go over all value handle mappings and mark as invalidated any handle
6485a9bdd85SOleksandr "Alex" Zinenko     // pointing to any result of the payload op associated with the given handle
6495a9bdd85SOleksandr "Alex" Zinenko     // or any op nested in them. Similarly invalidate handles to argument of
6505a9bdd85SOleksandr "Alex" Zinenko     // blocks belonging to any region of any payload op associated with the
6515a9bdd85SOleksandr "Alex" Zinenko     // given handle or any op nested in them.
6525a9bdd85SOleksandr "Alex" Zinenko     for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) {
6535a9bdd85SOleksandr "Alex" Zinenko       for (Value valueHandle : valueHandles)
6545a9bdd85SOleksandr "Alex" Zinenko         recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors,
6555a9bdd85SOleksandr "Alex" Zinenko                                                    payloadValue, valueHandle,
6565a9bdd85SOleksandr "Alex" Zinenko                                                    newlyInvalidated);
6575a9bdd85SOleksandr "Alex" Zinenko     }
6585a9bdd85SOleksandr "Alex" Zinenko 
6595a9bdd85SOleksandr "Alex" Zinenko     // Stop lookup when reaching a region that is isolated from above.
6605a9bdd85SOleksandr "Alex" Zinenko     if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>())
6615a9bdd85SOleksandr "Alex" Zinenko       break;
6625a9bdd85SOleksandr "Alex" Zinenko   }
6635a9bdd85SOleksandr "Alex" Zinenko }
6645a9bdd85SOleksandr "Alex" Zinenko 
6655a9bdd85SOleksandr "Alex" Zinenko void transform::TransformState::recordValueHandleInvalidation(
6665a9bdd85SOleksandr "Alex" Zinenko     OpOperand &valueHandle,
6675a9bdd85SOleksandr "Alex" Zinenko     transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
6685a9bdd85SOleksandr "Alex" Zinenko   // Invalidate other handles to the same value.
6695a9bdd85SOleksandr "Alex" Zinenko   for (Value payloadValue : getPayloadValuesView(valueHandle.get())) {
6705a9bdd85SOleksandr "Alex" Zinenko     SmallVector<Value> otherValueHandles;
6715a9bdd85SOleksandr "Alex" Zinenko     (void)getHandlesForPayloadValue(payloadValue, otherValueHandles);
6725a9bdd85SOleksandr "Alex" Zinenko     for (Value otherHandle : otherValueHandles) {
6735a9bdd85SOleksandr "Alex" Zinenko       Operation *owner = valueHandle.getOwner();
6745a9bdd85SOleksandr "Alex" Zinenko       unsigned operandNo = valueHandle.getOperandNumber();
6755a9bdd85SOleksandr "Alex" Zinenko       Location valueLoc = payloadValue.getLoc();
6765a9bdd85SOleksandr "Alex" Zinenko       newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo,
6775a9bdd85SOleksandr "Alex" Zinenko                                        valueLoc](Location currentLoc) {
6785a9bdd85SOleksandr "Alex" Zinenko         InFlightDiagnostic diag = emitError(currentLoc)
6795a9bdd85SOleksandr "Alex" Zinenko                                   << "op uses a handle invalidated by a "
6805a9bdd85SOleksandr "Alex" Zinenko                                      "previously executed transform op";
6815a9bdd85SOleksandr "Alex" Zinenko         diag.attachNote(otherHandle.getLoc()) << "invalidated handle";
6825a9bdd85SOleksandr "Alex" Zinenko         diag.attachNote(owner->getLoc())
6835a9bdd85SOleksandr "Alex" Zinenko             << "invalidated by this transform op that consumes its operand #"
6845a9bdd85SOleksandr "Alex" Zinenko             << operandNo
6855a9bdd85SOleksandr "Alex" Zinenko             << " and invalidates handles to the same values as associated with "
6865a9bdd85SOleksandr "Alex" Zinenko                "it";
6875a9bdd85SOleksandr "Alex" Zinenko         diag.attachNote(valueLoc) << "payload value";
6885a9bdd85SOleksandr "Alex" Zinenko       };
6895a9bdd85SOleksandr "Alex" Zinenko     }
6905a9bdd85SOleksandr "Alex" Zinenko 
6915a9bdd85SOleksandr "Alex" Zinenko     if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) {
6925a9bdd85SOleksandr "Alex" Zinenko       Operation *payloadOp = opResult.getOwner();
6935a9bdd85SOleksandr "Alex" Zinenko       recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue,
6945a9bdd85SOleksandr "Alex" Zinenko                                  newlyInvalidated);
6955a9bdd85SOleksandr "Alex" Zinenko     } else {
6965a9bdd85SOleksandr "Alex" Zinenko       auto arg = llvm::dyn_cast<BlockArgument>(payloadValue);
6975a9bdd85SOleksandr "Alex" Zinenko       for (Operation &payloadOp : *arg.getOwner())
6985a9bdd85SOleksandr "Alex" Zinenko         recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue,
6995a9bdd85SOleksandr "Alex" Zinenko                                    newlyInvalidated);
7005a9bdd85SOleksandr "Alex" Zinenko     }
7015a9bdd85SOleksandr "Alex" Zinenko   }
7025a9bdd85SOleksandr "Alex" Zinenko }
7035a9bdd85SOleksandr "Alex" Zinenko 
7045a9bdd85SOleksandr "Alex" Zinenko /// Checks that the operation does not use invalidated handles as operands.
7055a9bdd85SOleksandr "Alex" Zinenko /// Reports errors and returns failure if it does. Otherwise, invalidates the
7065a9bdd85SOleksandr "Alex" Zinenko /// handles consumed by the operation as well as any handles pointing to payload
7075a9bdd85SOleksandr "Alex" Zinenko /// IR operations nested in the operations associated with the consumed handles.
7085a9bdd85SOleksandr "Alex" Zinenko LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl(
7095a9bdd85SOleksandr "Alex" Zinenko     transform::TransformOpInterface transform,
7105a9bdd85SOleksandr "Alex" Zinenko     transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const {
7115a9bdd85SOleksandr "Alex" Zinenko   FULL_LDBG("--Start checkAndRecordHandleInvalidation\n");
7125a9bdd85SOleksandr "Alex" Zinenko   auto memoryEffectsIface =
7135a9bdd85SOleksandr "Alex" Zinenko       cast<MemoryEffectOpInterface>(transform.getOperation());
7145a9bdd85SOleksandr "Alex" Zinenko   SmallVector<MemoryEffects::EffectInstance> effects;
7155a9bdd85SOleksandr "Alex" Zinenko   memoryEffectsIface.getEffectsOnResource(
7165a9bdd85SOleksandr "Alex" Zinenko       transform::TransformMappingResource::get(), effects);
7175a9bdd85SOleksandr "Alex" Zinenko 
7185a9bdd85SOleksandr "Alex" Zinenko   for (OpOperand &target : transform->getOpOperands()) {
7195a9bdd85SOleksandr "Alex" Zinenko     DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
7205a9bdd85SOleksandr "Alex" Zinenko       (DBGS() << "----iterate on handle: " << target.get() << "\n");
7215a9bdd85SOleksandr "Alex" Zinenko     });
7225a9bdd85SOleksandr "Alex" Zinenko     // If the operand uses an invalidated handle, report it. If the operation
7235a9bdd85SOleksandr "Alex" Zinenko     // allows handles to point to repeated payload operations, only report
7245a9bdd85SOleksandr "Alex" Zinenko     // pre-existing invalidation errors. Otherwise, also report invalidations
7255a9bdd85SOleksandr "Alex" Zinenko     // caused by the current transform operation affecting its other operands.
7265a9bdd85SOleksandr "Alex" Zinenko     auto it = invalidatedHandles.find(target.get());
7275a9bdd85SOleksandr "Alex" Zinenko     auto nit = newlyInvalidated.find(target.get());
7285a9bdd85SOleksandr "Alex" Zinenko     if (it != invalidatedHandles.end()) {
7295a9bdd85SOleksandr "Alex" Zinenko       FULL_LDBG("--End checkAndRecordHandleInvalidation, found already "
7305a9bdd85SOleksandr "Alex" Zinenko                 "invalidated -> FAILURE\n");
7315a9bdd85SOleksandr "Alex" Zinenko       return it->getSecond()(transform->getLoc()), failure();
7325a9bdd85SOleksandr "Alex" Zinenko     }
7335a9bdd85SOleksandr "Alex" Zinenko     if (!transform.allowsRepeatedHandleOperands() &&
7345a9bdd85SOleksandr "Alex" Zinenko         nit != newlyInvalidated.end()) {
7355a9bdd85SOleksandr "Alex" Zinenko       FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly "
7365a9bdd85SOleksandr "Alex" Zinenko                 "invalidated (by this op) -> FAILURE\n");
7375a9bdd85SOleksandr "Alex" Zinenko       return nit->getSecond()(transform->getLoc()), failure();
7385a9bdd85SOleksandr "Alex" Zinenko     }
7395a9bdd85SOleksandr "Alex" Zinenko 
7405a9bdd85SOleksandr "Alex" Zinenko     // Invalidate handles pointing to the operations nested in the operation
7415a9bdd85SOleksandr "Alex" Zinenko     // associated with the handle consumed by this operation.
7425a9bdd85SOleksandr "Alex" Zinenko     auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) {
7435a9bdd85SOleksandr "Alex" Zinenko       return isa<MemoryEffects::Free>(effect.getEffect()) &&
7445a9bdd85SOleksandr "Alex" Zinenko              effect.getValue() == target.get();
7455a9bdd85SOleksandr "Alex" Zinenko     };
7465a9bdd85SOleksandr "Alex" Zinenko     if (llvm::any_of(effects, consumesTarget)) {
7475a9bdd85SOleksandr "Alex" Zinenko       FULL_LDBG("----found consume effect\n");
7485a9bdd85SOleksandr "Alex" Zinenko       if (llvm::isa<transform::TransformHandleTypeInterface>(
7495a9bdd85SOleksandr "Alex" Zinenko               target.get().getType())) {
7505a9bdd85SOleksandr "Alex" Zinenko         FULL_LDBG("----recordOpHandleInvalidation\n");
7515a9bdd85SOleksandr "Alex" Zinenko         SmallVector<Operation *> payloadOps =
7525a9bdd85SOleksandr "Alex" Zinenko             llvm::to_vector(getPayloadOps(target.get()));
7535a9bdd85SOleksandr "Alex" Zinenko         recordOpHandleInvalidation(target, payloadOps, nullptr,
7545a9bdd85SOleksandr "Alex" Zinenko                                    newlyInvalidated);
7555a9bdd85SOleksandr "Alex" Zinenko       } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
7565a9bdd85SOleksandr "Alex" Zinenko                      target.get().getType())) {
7575a9bdd85SOleksandr "Alex" Zinenko         FULL_LDBG("----recordValueHandleInvalidation\n");
7585a9bdd85SOleksandr "Alex" Zinenko         recordValueHandleInvalidation(target, newlyInvalidated);
7595a9bdd85SOleksandr "Alex" Zinenko       } else {
7605a9bdd85SOleksandr "Alex" Zinenko         FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
7615a9bdd85SOleksandr "Alex" Zinenko       }
7625a9bdd85SOleksandr "Alex" Zinenko     } else {
7635a9bdd85SOleksandr "Alex" Zinenko       FULL_LDBG("----no consume effect -> SKIP\n");
7645a9bdd85SOleksandr "Alex" Zinenko     }
7655a9bdd85SOleksandr "Alex" Zinenko   }
7665a9bdd85SOleksandr "Alex" Zinenko 
7675a9bdd85SOleksandr "Alex" Zinenko   FULL_LDBG("--End checkAndRecordHandleInvalidation -> SUCCESS\n");
7685a9bdd85SOleksandr "Alex" Zinenko   return success();
7695a9bdd85SOleksandr "Alex" Zinenko }
7705a9bdd85SOleksandr "Alex" Zinenko 
7715a9bdd85SOleksandr "Alex" Zinenko LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
7725a9bdd85SOleksandr "Alex" Zinenko     transform::TransformOpInterface transform) {
7735a9bdd85SOleksandr "Alex" Zinenko   InvalidatedHandleMap newlyInvalidated;
7745a9bdd85SOleksandr "Alex" Zinenko   LogicalResult checkResult =
7755a9bdd85SOleksandr "Alex" Zinenko       checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated);
7765a9bdd85SOleksandr "Alex" Zinenko   invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()),
7775a9bdd85SOleksandr "Alex" Zinenko                             std::make_move_iterator(newlyInvalidated.end()));
7785a9bdd85SOleksandr "Alex" Zinenko   return checkResult;
7795a9bdd85SOleksandr "Alex" Zinenko }
7805a9bdd85SOleksandr "Alex" Zinenko 
7815a9bdd85SOleksandr "Alex" Zinenko template <typename T>
7825a9bdd85SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure
7835a9bdd85SOleksandr "Alex" Zinenko checkRepeatedConsumptionInOperand(ArrayRef<T> payload,
7845a9bdd85SOleksandr "Alex" Zinenko                                   transform::TransformOpInterface transform,
7855a9bdd85SOleksandr "Alex" Zinenko                                   unsigned operandNumber) {
7865a9bdd85SOleksandr "Alex" Zinenko   DenseSet<T> seen;
7875a9bdd85SOleksandr "Alex" Zinenko   for (T p : payload) {
7885a9bdd85SOleksandr "Alex" Zinenko     if (!seen.insert(p).second) {
7895a9bdd85SOleksandr "Alex" Zinenko       DiagnosedSilenceableFailure diag =
7905a9bdd85SOleksandr "Alex" Zinenko           transform.emitSilenceableError()
7915a9bdd85SOleksandr "Alex" Zinenko           << "a handle passed as operand #" << operandNumber
7925a9bdd85SOleksandr "Alex" Zinenko           << " and consumed by this operation points to a payload "
7935a9bdd85SOleksandr "Alex" Zinenko              "entity more than once";
7945a9bdd85SOleksandr "Alex" Zinenko       if constexpr (std::is_pointer_v<T>)
7955a9bdd85SOleksandr "Alex" Zinenko         diag.attachNote(p->getLoc()) << "repeated target op";
7965a9bdd85SOleksandr "Alex" Zinenko       else
7975a9bdd85SOleksandr "Alex" Zinenko         diag.attachNote(p.getLoc()) << "repeated target value";
7985a9bdd85SOleksandr "Alex" Zinenko       return diag;
7995a9bdd85SOleksandr "Alex" Zinenko     }
8005a9bdd85SOleksandr "Alex" Zinenko   }
8015a9bdd85SOleksandr "Alex" Zinenko   return DiagnosedSilenceableFailure::success();
8025a9bdd85SOleksandr "Alex" Zinenko }
8035a9bdd85SOleksandr "Alex" Zinenko 
8045a9bdd85SOleksandr "Alex" Zinenko void transform::TransformState::compactOpHandles() {
8055a9bdd85SOleksandr "Alex" Zinenko   for (Value handle : opHandlesToCompact) {
8065a9bdd85SOleksandr "Alex" Zinenko     Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true);
8076c7a3f80SMehdi Amini #if LLVM_ENABLE_ABI_BREAKING_CHECKS
8085a9bdd85SOleksandr "Alex" Zinenko     if (llvm::find(mappings.direct[handle], nullptr) !=
8095a9bdd85SOleksandr "Alex" Zinenko         mappings.direct[handle].end())
8105a9bdd85SOleksandr "Alex" Zinenko       // Payload IR is removed from the mapping. This invalidates the respective
8115a9bdd85SOleksandr "Alex" Zinenko       // iterators.
8125a9bdd85SOleksandr "Alex" Zinenko       mappings.incrementTimestamp(handle);
8135a9bdd85SOleksandr "Alex" Zinenko #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
8145a9bdd85SOleksandr "Alex" Zinenko     llvm::erase(mappings.direct[handle], nullptr);
8155a9bdd85SOleksandr "Alex" Zinenko   }
8165a9bdd85SOleksandr "Alex" Zinenko   opHandlesToCompact.clear();
8175a9bdd85SOleksandr "Alex" Zinenko }
8185a9bdd85SOleksandr "Alex" Zinenko 
/// Applies `transform` to the payload IR and updates the transform state.
///
/// Steps, in order:
///  - record `transform` as the current transform of the innermost scope;
///  - optionally run expensive handle-invalidation / repeated-consumption
///    checks;
///  - snapshot payload information for consumed handles;
///  - run `transform.apply` with a tracking rewriter;
///  - merge transform and tracking-listener errors;
///  - forget mappings of consumed handles;
///  - record the mappings of the op results.
///
/// Returns a definite failure immediately on hard errors. A silenceable
/// failure from the transform is propagated only after the state has been
/// updated, with unset results mapped to empty lists.
DiagnosedSilenceableFailure
transform::TransformState::applyTransform(TransformOpInterface transform) {
  LLVM_DEBUG({
    DBGS() << "applying: ";
    transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
    llvm::dbgs() << "\n";
  });
  DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
                  DBGS() << "Top-level payload before application:\n"
                         << *getTopLevel() << "\n");
  // Scope guard: dumps the top-level payload in debug output on every exit
  // path that happens before `printOnFailureRAII.release()` near the end of
  // this function.
  auto printOnFailureRAII = llvm::make_scope_exit([this] {
    // Presumably keeps the `this` capture "used" when LLVM_DEBUG expands to
    // nothing, avoiding an unused-capture warning.
    (void)this;
    LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
        llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
  });

  // Set current transform op.
  regionStack.back()->currentTransform = transform;

  // Expensive checks to detect invalid transform IR.
  if (options.getExpensiveChecksEnabled()) {
    FULL_LDBG("ExpensiveChecksEnabled\n");
    if (failed(checkAndRecordHandleInvalidation(transform)))
      return DiagnosedSilenceableFailure::definiteFailure();

    // For each consumed handle operand, check that it does not point to the
    // same payload entity more than once, unless the op allows repeated
    // handle operands.
    for (OpOperand &operand : transform->getOpOperands()) {
      DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, {
        (DBGS() << "iterate on handle: " << operand.get() << "\n");
      });
      if (!isHandleConsumed(operand.get(), transform)) {
        FULL_LDBG("--handle not consumed -> SKIP\n");
        continue;
      }
      if (transform.allowsRepeatedHandleOperands()) {
        FULL_LDBG("--op allows repeated handles -> SKIP\n");
        continue;
      }
      FULL_LDBG("--handle is consumed\n");

      Type operandType = operand.get().getType();
      if (llvm::isa<TransformHandleTypeInterface>(operandType)) {
        FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n");
        DiagnosedSilenceableFailure check =
            checkRepeatedConsumptionInOperand<Operation *>(
                getPayloadOpsView(operand.get()), transform,
                operand.getOperandNumber());
        if (!check.succeeded()) {
          FULL_LDBG("----FAILED\n");
          return check;
        }
      } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) {
        FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n");
        DiagnosedSilenceableFailure check =
            checkRepeatedConsumptionInOperand<Value>(
                getPayloadValuesView(operand.get()), transform,
                operand.getOperandNumber());
        if (!check.succeeded()) {
          FULL_LDBG("----FAILED\n");
          return check;
        }
      } else {
        // Params and other non-handle operands are not checked here.
        FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
      }
    }
  }

  // Find which operands are consumed.
  SmallVector<OpOperand *> consumedOperands =
      transform.getConsumedHandleOpOperands();

  // Remember the results of the payload ops associated with the consumed
  // op handles or the ops defining the value handles so we can drop the
  // association with them later. This must happen here because the
  // transformation may destroy or mutate them so we cannot traverse the payload
  // IR after that.
  SmallVector<Value> origOpFlatResults;
  SmallVector<Operation *> origAssociatedOps;
  for (OpOperand *opOperand : consumedOperands) {
    Value operand = opOperand->get();
    if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
      for (Operation *payloadOp : getPayloadOps(operand)) {
        llvm::append_range(origOpFlatResults, payloadOp->getResults());
      }
      continue;
    }
    if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) {
      for (Value payloadValue : getPayloadValuesView(operand)) {
        if (llvm::isa<OpResult>(payloadValue)) {
          origAssociatedOps.push_back(payloadValue.getDefiningOp());
          continue;
        }
        // Block argument: every operation of its owning block is recorded.
        llvm::append_range(
            origAssociatedOps,
            llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(),
                            [](Operation &op) { return &op; }));
      }
      continue;
    }
    // A consumed operand that is neither an op handle nor a value handle is a
    // hard error.
    DiagnosedDefiniteFailure diag =
        emitDefiniteFailure(transform->getLoc())
        << "unexpectedly consumed a value that is not a handle as operand #"
        << opOperand->getOperandNumber();
    diag.attachNote(operand.getLoc())
        << "value defined here with type " << operand.getType();
    return diag;
  }

  // Prepare rewriter and listener.
  TrackingListenerConfig config;
  config.skipHandleFn = [&](Value handle) {
    // Skip handle if it is dead.
    // Locate the innermost scope on the region stack whose region contains
    // the handle.
    auto scopeIt =
        llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) {
          return handle.getParentRegion() == scope->region;
        });
    assert(scopeIt != regionStack.rend() &&
           "could not find region scope for handle");
    RegionScope *scope = *scopeIt;
    // The handle is considered dead when all of its users are the scope's
    // current transform or happen before it.
    return llvm::all_of(handle.getUsers(), [&](Operation *user) {
      return user == scope->currentTransform ||
             happensBefore(user, scope->currentTransform);
    });
  };
  transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
                                                            config);
  transform::TransformRewriter rewriter(transform->getContext(),
                                        &trackingListener);

  // Compute the result but do not short-circuit the silenceable failure case as
  // we still want the handles to propagate properly so the "suppress" mode can
  // proceed on a best effort basis.
  transform::TransformResults results(transform->getNumResults());
  DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this));
  // Erase null entries from the op handle mappings queued for compaction.
  compactOpHandles();

  // Error handling: fail if transform or listener failed.
  DiagnosedSilenceableFailure trackingFailure =
      trackingListener.checkAndResetError();
  if (!transform->hasTrait<ReportTrackingListenerFailuresOpTrait>() ||
      transform->hasAttr(FindPayloadReplacementOpInterface::
                             kSilenceTrackingFailuresAttrName)) {
    // Only report failures for ReportTrackingListenerFailuresOpTrait ops. Also
    // do not report failures if the above mentioned attribute is set.
    if (trackingFailure.isSilenceableFailure())
      (void)trackingFailure.silence();
    trackingFailure = DiagnosedSilenceableFailure::success();
  }
  if (!trackingFailure.succeeded()) {
    if (result.succeeded()) {
      result = std::move(trackingFailure);
    } else {
      // Transform op errors have precedence, report those first.
      if (result.isSilenceableFailure())
        result.attachNote() << "tracking listener also failed: "
                            << trackingFailure.getMessage();
      (void)trackingFailure.silence();
    }
  }
  if (result.isDefiniteFailure())
    return result;

  // If a silenceable failure was produced, some results may be unset, set them
  // to empty lists.
  if (result.isSilenceableFailure())
    results.setRemainingToEmpty(transform);

  // Remove the mapping for the operand if it is consumed by the operation. This
  // allows us to catch use-after-free with assertions later on.
  // Uses the payload snapshot taken before the transform was applied.
  for (OpOperand *opOperand : consumedOperands) {
    Value operand = opOperand->get();
    if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
      forgetMapping(operand, origOpFlatResults);
    } else if (llvm::isa<TransformValueHandleTypeInterface>(
                   operand.getType())) {
      forgetValueMapping(operand, origAssociatedOps);
    }
  }

  // Record the payload associated with each result of the transform op.
  if (failed(updateStateFromResults(results, transform->getResults())))
    return DiagnosedSilenceableFailure::definiteFailure();

  // Reached the end normally: disarm the failure dump.
  printOnFailureRAII.release();
  DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
    DBGS() << "Top-level payload:\n";
    getTopLevel()->print(llvm::dbgs());
  });
  return result;
}
10075a9bdd85SOleksandr "Alex" Zinenko 
10085a9bdd85SOleksandr "Alex" Zinenko LogicalResult transform::TransformState::updateStateFromResults(
10095a9bdd85SOleksandr "Alex" Zinenko     const TransformResults &results, ResultRange opResults) {
10105a9bdd85SOleksandr "Alex" Zinenko   for (OpResult result : opResults) {
10115a9bdd85SOleksandr "Alex" Zinenko     if (llvm::isa<TransformParamTypeInterface>(result.getType())) {
10125a9bdd85SOleksandr "Alex" Zinenko       assert(results.isParam(result.getResultNumber()) &&
10135a9bdd85SOleksandr "Alex" Zinenko              "expected parameters for the parameter-typed result");
10145a9bdd85SOleksandr "Alex" Zinenko       if (failed(
10155a9bdd85SOleksandr "Alex" Zinenko               setParams(result, results.getParams(result.getResultNumber())))) {
10165a9bdd85SOleksandr "Alex" Zinenko         return failure();
10175a9bdd85SOleksandr "Alex" Zinenko       }
10185a9bdd85SOleksandr "Alex" Zinenko     } else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) {
10195a9bdd85SOleksandr "Alex" Zinenko       assert(results.isValue(result.getResultNumber()) &&
10205a9bdd85SOleksandr "Alex" Zinenko              "expected values for value-type-result");
10215a9bdd85SOleksandr "Alex" Zinenko       if (failed(setPayloadValues(
10225a9bdd85SOleksandr "Alex" Zinenko               result, results.getValues(result.getResultNumber())))) {
10235a9bdd85SOleksandr "Alex" Zinenko         return failure();
10245a9bdd85SOleksandr "Alex" Zinenko       }
10255a9bdd85SOleksandr "Alex" Zinenko     } else {
10265a9bdd85SOleksandr "Alex" Zinenko       assert(!results.isParam(result.getResultNumber()) &&
10275a9bdd85SOleksandr "Alex" Zinenko              "expected payload ops for the non-parameter typed result");
10285a9bdd85SOleksandr "Alex" Zinenko       if (failed(
10295a9bdd85SOleksandr "Alex" Zinenko               setPayloadOps(result, results.get(result.getResultNumber())))) {
10305a9bdd85SOleksandr "Alex" Zinenko         return failure();
10315a9bdd85SOleksandr "Alex" Zinenko       }
10325a9bdd85SOleksandr "Alex" Zinenko     }
10335a9bdd85SOleksandr "Alex" Zinenko   }
10345a9bdd85SOleksandr "Alex" Zinenko   return success();
10355a9bdd85SOleksandr "Alex" Zinenko }
10365a9bdd85SOleksandr "Alex" Zinenko 
10375a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
10385a9bdd85SOleksandr "Alex" Zinenko // TransformState::Extension
10395a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
10405a9bdd85SOleksandr "Alex" Zinenko 
// Out-of-line defaulted destructor of the state extension.
// NOTE(review): presumably declared virtual in the header so that this
// definition anchors the class's vtable in this translation unit — confirm
// against TransformInterfaces.h.
transform::TransformState::Extension::~Extension() = default;
10425a9bdd85SOleksandr "Alex" Zinenko 
/// Delegates to `state.replacePayloadOp(op, replacement)` on the transform
/// state this extension was created for, and returns its result unchanged.
LogicalResult
transform::TransformState::Extension::replacePayloadOp(Operation *op,
                                                       Operation *replacement) {
  // TODO: we may need to invalidate handles to operations and values nested in
  // the operation being replaced.
  return state.replacePayloadOp(op, replacement);
}
10505a9bdd85SOleksandr "Alex" Zinenko 
/// Delegates to `state.replacePayloadValue(value, replacement)` on the
/// transform state this extension was created for, and returns its result
/// unchanged.
LogicalResult
transform::TransformState::Extension::replacePayloadValue(Value value,
                                                          Value replacement) {
  return state.replacePayloadValue(value, replacement);
}
10565a9bdd85SOleksandr "Alex" Zinenko 
10575a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
10585a9bdd85SOleksandr "Alex" Zinenko // TransformState::RegionScope
10595a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
10605a9bdd85SOleksandr "Alex" Zinenko 
10615a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::RegionScope::~RegionScope() {
10625a9bdd85SOleksandr "Alex" Zinenko   // Remove handle invalidation notices as handles are going out of scope.
10635a9bdd85SOleksandr "Alex" Zinenko   // The same region may be re-entered leading to incorrect invalidation
10645a9bdd85SOleksandr "Alex" Zinenko   // errors.
10655a9bdd85SOleksandr "Alex" Zinenko   for (Block &block : *region) {
10665a9bdd85SOleksandr "Alex" Zinenko     for (Value handle : block.getArguments()) {
10675a9bdd85SOleksandr "Alex" Zinenko       state.invalidatedHandles.erase(handle);
10685a9bdd85SOleksandr "Alex" Zinenko     }
10695a9bdd85SOleksandr "Alex" Zinenko     for (Operation &op : block) {
10705a9bdd85SOleksandr "Alex" Zinenko       for (Value handle : op.getResults()) {
10715a9bdd85SOleksandr "Alex" Zinenko         state.invalidatedHandles.erase(handle);
10725a9bdd85SOleksandr "Alex" Zinenko       }
10735a9bdd85SOleksandr "Alex" Zinenko     }
10745a9bdd85SOleksandr "Alex" Zinenko   }
10755a9bdd85SOleksandr "Alex" Zinenko 
10765a9bdd85SOleksandr "Alex" Zinenko #if LLVM_ENABLE_ABI_BREAKING_CHECKS
10775a9bdd85SOleksandr "Alex" Zinenko   // Remember pointers to payload ops referenced by the handles going out of
10785a9bdd85SOleksandr "Alex" Zinenko   // scope.
10795a9bdd85SOleksandr "Alex" Zinenko   SmallVector<Operation *> referencedOps =
10805a9bdd85SOleksandr "Alex" Zinenko       llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse));
10815a9bdd85SOleksandr "Alex" Zinenko #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
10825a9bdd85SOleksandr "Alex" Zinenko 
10835a9bdd85SOleksandr "Alex" Zinenko   state.mappings.erase(region);
10845a9bdd85SOleksandr "Alex" Zinenko   state.regionStack.pop_back();
10855a9bdd85SOleksandr "Alex" Zinenko }
10865a9bdd85SOleksandr "Alex" Zinenko 
10875a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
10885a9bdd85SOleksandr "Alex" Zinenko // TransformResults
10895a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
10905a9bdd85SOleksandr "Alex" Zinenko 
10915a9bdd85SOleksandr "Alex" Zinenko transform::TransformResults::TransformResults(unsigned numSegments) {
10925a9bdd85SOleksandr "Alex" Zinenko   operations.appendEmptyRows(numSegments);
10935a9bdd85SOleksandr "Alex" Zinenko   params.appendEmptyRows(numSegments);
10945a9bdd85SOleksandr "Alex" Zinenko   values.appendEmptyRows(numSegments);
10955a9bdd85SOleksandr "Alex" Zinenko }
10965a9bdd85SOleksandr "Alex" Zinenko 
10975a9bdd85SOleksandr "Alex" Zinenko void transform::TransformResults::setParams(
10985a9bdd85SOleksandr "Alex" Zinenko     OpResult value, ArrayRef<transform::TransformState::Param> params) {
10995a9bdd85SOleksandr "Alex" Zinenko   int64_t position = value.getResultNumber();
11005a9bdd85SOleksandr "Alex" Zinenko   assert(position < static_cast<int64_t>(this->params.size()) &&
11015a9bdd85SOleksandr "Alex" Zinenko          "setting params for a non-existent handle");
11025a9bdd85SOleksandr "Alex" Zinenko   assert(this->params[position].data() == nullptr && "params already set");
11035a9bdd85SOleksandr "Alex" Zinenko   assert(operations[position].data() == nullptr &&
11045a9bdd85SOleksandr "Alex" Zinenko          "another kind of results already set");
11055a9bdd85SOleksandr "Alex" Zinenko   assert(values[position].data() == nullptr &&
11065a9bdd85SOleksandr "Alex" Zinenko          "another kind of results already set");
11075a9bdd85SOleksandr "Alex" Zinenko   this->params.replace(position, params);
11085a9bdd85SOleksandr "Alex" Zinenko }
11095a9bdd85SOleksandr "Alex" Zinenko 
11105a9bdd85SOleksandr "Alex" Zinenko void transform::TransformResults::setMappedValues(
11115a9bdd85SOleksandr "Alex" Zinenko     OpResult handle, ArrayRef<MappedValue> values) {
11125a9bdd85SOleksandr "Alex" Zinenko   DiagnosedSilenceableFailure diag = dispatchMappedValues(
11135a9bdd85SOleksandr "Alex" Zinenko       handle, values,
11145a9bdd85SOleksandr "Alex" Zinenko       [&](ArrayRef<Operation *> operations) {
11155a9bdd85SOleksandr "Alex" Zinenko         return set(handle, operations), success();
11165a9bdd85SOleksandr "Alex" Zinenko       },
11175a9bdd85SOleksandr "Alex" Zinenko       [&](ArrayRef<Param> params) {
11185a9bdd85SOleksandr "Alex" Zinenko         return setParams(handle, params), success();
11195a9bdd85SOleksandr "Alex" Zinenko       },
11205a9bdd85SOleksandr "Alex" Zinenko       [&](ValueRange payloadValues) {
11215a9bdd85SOleksandr "Alex" Zinenko         return setValues(handle, payloadValues), success();
11225a9bdd85SOleksandr "Alex" Zinenko       });
11235a9bdd85SOleksandr "Alex" Zinenko #ifndef NDEBUG
11245a9bdd85SOleksandr "Alex" Zinenko   if (!diag.succeeded())
11255a9bdd85SOleksandr "Alex" Zinenko     llvm::dbgs() << diag.getStatusString() << "\n";
11265a9bdd85SOleksandr "Alex" Zinenko   assert(diag.succeeded() && "incorrect mapping");
11275a9bdd85SOleksandr "Alex" Zinenko #endif // NDEBUG
11285a9bdd85SOleksandr "Alex" Zinenko   (void)diag.silence();
11295a9bdd85SOleksandr "Alex" Zinenko }
11305a9bdd85SOleksandr "Alex" Zinenko 
11315a9bdd85SOleksandr "Alex" Zinenko void transform::TransformResults::setRemainingToEmpty(
11325a9bdd85SOleksandr "Alex" Zinenko     transform::TransformOpInterface transform) {
11335a9bdd85SOleksandr "Alex" Zinenko   for (OpResult opResult : transform->getResults()) {
11345a9bdd85SOleksandr "Alex" Zinenko     if (!isSet(opResult.getResultNumber()))
11355a9bdd85SOleksandr "Alex" Zinenko       setMappedValues(opResult, {});
11365a9bdd85SOleksandr "Alex" Zinenko   }
11375a9bdd85SOleksandr "Alex" Zinenko }
11385a9bdd85SOleksandr "Alex" Zinenko 
11395a9bdd85SOleksandr "Alex" Zinenko ArrayRef<Operation *>
11405a9bdd85SOleksandr "Alex" Zinenko transform::TransformResults::get(unsigned resultNumber) const {
11415a9bdd85SOleksandr "Alex" Zinenko   assert(resultNumber < operations.size() &&
11425a9bdd85SOleksandr "Alex" Zinenko          "querying results for a non-existent handle");
11435a9bdd85SOleksandr "Alex" Zinenko   assert(operations[resultNumber].data() != nullptr &&
11445a9bdd85SOleksandr "Alex" Zinenko          "querying unset results (values or params expected?)");
11455a9bdd85SOleksandr "Alex" Zinenko   return operations[resultNumber];
11465a9bdd85SOleksandr "Alex" Zinenko }
11475a9bdd85SOleksandr "Alex" Zinenko 
11485a9bdd85SOleksandr "Alex" Zinenko ArrayRef<transform::TransformState::Param>
11495a9bdd85SOleksandr "Alex" Zinenko transform::TransformResults::getParams(unsigned resultNumber) const {
11505a9bdd85SOleksandr "Alex" Zinenko   assert(resultNumber < params.size() &&
11515a9bdd85SOleksandr "Alex" Zinenko          "querying params for a non-existent handle");
11525a9bdd85SOleksandr "Alex" Zinenko   assert(params[resultNumber].data() != nullptr &&
11535a9bdd85SOleksandr "Alex" Zinenko          "querying unset params (ops or values expected?)");
11545a9bdd85SOleksandr "Alex" Zinenko   return params[resultNumber];
11555a9bdd85SOleksandr "Alex" Zinenko }
11565a9bdd85SOleksandr "Alex" Zinenko 
11575a9bdd85SOleksandr "Alex" Zinenko ArrayRef<Value>
11585a9bdd85SOleksandr "Alex" Zinenko transform::TransformResults::getValues(unsigned resultNumber) const {
11595a9bdd85SOleksandr "Alex" Zinenko   assert(resultNumber < values.size() &&
11605a9bdd85SOleksandr "Alex" Zinenko          "querying values for a non-existent handle");
11615a9bdd85SOleksandr "Alex" Zinenko   assert(values[resultNumber].data() != nullptr &&
11625a9bdd85SOleksandr "Alex" Zinenko          "querying unset values (ops or params expected?)");
11635a9bdd85SOleksandr "Alex" Zinenko   return values[resultNumber];
11645a9bdd85SOleksandr "Alex" Zinenko }
11655a9bdd85SOleksandr "Alex" Zinenko 
11665a9bdd85SOleksandr "Alex" Zinenko bool transform::TransformResults::isParam(unsigned resultNumber) const {
11675a9bdd85SOleksandr "Alex" Zinenko   assert(resultNumber < params.size() &&
11685a9bdd85SOleksandr "Alex" Zinenko          "querying association for a non-existent handle");
11695a9bdd85SOleksandr "Alex" Zinenko   return params[resultNumber].data() != nullptr;
11705a9bdd85SOleksandr "Alex" Zinenko }
11715a9bdd85SOleksandr "Alex" Zinenko 
11725a9bdd85SOleksandr "Alex" Zinenko bool transform::TransformResults::isValue(unsigned resultNumber) const {
11735a9bdd85SOleksandr "Alex" Zinenko   assert(resultNumber < values.size() &&
11745a9bdd85SOleksandr "Alex" Zinenko          "querying association for a non-existent handle");
11755a9bdd85SOleksandr "Alex" Zinenko   return values[resultNumber].data() != nullptr;
11765a9bdd85SOleksandr "Alex" Zinenko }
11775a9bdd85SOleksandr "Alex" Zinenko 
11785a9bdd85SOleksandr "Alex" Zinenko bool transform::TransformResults::isSet(unsigned resultNumber) const {
11795a9bdd85SOleksandr "Alex" Zinenko   assert(resultNumber < params.size() &&
11805a9bdd85SOleksandr "Alex" Zinenko          "querying association for a non-existent handle");
11815a9bdd85SOleksandr "Alex" Zinenko   return params[resultNumber].data() != nullptr ||
11825a9bdd85SOleksandr "Alex" Zinenko          operations[resultNumber].data() != nullptr ||
11835a9bdd85SOleksandr "Alex" Zinenko          values[resultNumber].data() != nullptr;
11845a9bdd85SOleksandr "Alex" Zinenko }
11855a9bdd85SOleksandr "Alex" Zinenko 
11865a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
11875a9bdd85SOleksandr "Alex" Zinenko // TrackingListener
11885a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
11895a9bdd85SOleksandr "Alex" Zinenko 
11905a9bdd85SOleksandr "Alex" Zinenko transform::TrackingListener::TrackingListener(TransformState &state,
11915a9bdd85SOleksandr "Alex" Zinenko                                               TransformOpInterface op,
11925a9bdd85SOleksandr "Alex" Zinenko                                               TrackingListenerConfig config)
11935a9bdd85SOleksandr "Alex" Zinenko     : TransformState::Extension(state), transformOp(op), config(config) {
11945a9bdd85SOleksandr "Alex" Zinenko   if (op) {
11955a9bdd85SOleksandr "Alex" Zinenko     for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) {
11965a9bdd85SOleksandr "Alex" Zinenko       consumedHandles.insert(opOperand->get());
11975a9bdd85SOleksandr "Alex" Zinenko     }
11985a9bdd85SOleksandr "Alex" Zinenko   }
11995a9bdd85SOleksandr "Alex" Zinenko }
12005a9bdd85SOleksandr "Alex" Zinenko 
12015a9bdd85SOleksandr "Alex" Zinenko Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) {
12025a9bdd85SOleksandr "Alex" Zinenko   Operation *defOp = nullptr;
12035a9bdd85SOleksandr "Alex" Zinenko   for (Value v : values) {
12045a9bdd85SOleksandr "Alex" Zinenko     // Skip empty values.
12055a9bdd85SOleksandr "Alex" Zinenko     if (!v)
12065a9bdd85SOleksandr "Alex" Zinenko       continue;
12075a9bdd85SOleksandr "Alex" Zinenko     if (!defOp) {
12085a9bdd85SOleksandr "Alex" Zinenko       defOp = v.getDefiningOp();
12095a9bdd85SOleksandr "Alex" Zinenko       continue;
12105a9bdd85SOleksandr "Alex" Zinenko     }
12115a9bdd85SOleksandr "Alex" Zinenko     if (defOp != v.getDefiningOp())
12125a9bdd85SOleksandr "Alex" Zinenko       return nullptr;
12135a9bdd85SOleksandr "Alex" Zinenko   }
12145a9bdd85SOleksandr "Alex" Zinenko   return defOp;
12155a9bdd85SOleksandr "Alex" Zinenko }
12165a9bdd85SOleksandr "Alex" Zinenko 
12175a9bdd85SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
12185a9bdd85SOleksandr "Alex" Zinenko     Operation *&result, Operation *op, ValueRange newValues) const {
12195a9bdd85SOleksandr "Alex" Zinenko   assert(op->getNumResults() == newValues.size() &&
12205a9bdd85SOleksandr "Alex" Zinenko          "invalid number of replacement values");
12215a9bdd85SOleksandr "Alex" Zinenko   SmallVector<Value> values(newValues.begin(), newValues.end());
12225a9bdd85SOleksandr "Alex" Zinenko 
12235a9bdd85SOleksandr "Alex" Zinenko   DiagnosedSilenceableFailure diag = emitSilenceableFailure(
12245a9bdd85SOleksandr "Alex" Zinenko       getTransformOp(), "tracking listener failed to find replacement op "
12255a9bdd85SOleksandr "Alex" Zinenko                         "during application of this transform op");
12265a9bdd85SOleksandr "Alex" Zinenko 
12275a9bdd85SOleksandr "Alex" Zinenko   do {
12285a9bdd85SOleksandr "Alex" Zinenko     // If the replacement values belong to different ops, drop the mapping.
12295a9bdd85SOleksandr "Alex" Zinenko     Operation *defOp = getCommonDefiningOp(values);
12305a9bdd85SOleksandr "Alex" Zinenko     if (!defOp) {
12315a9bdd85SOleksandr "Alex" Zinenko       diag.attachNote() << "replacement values belong to different ops";
12325a9bdd85SOleksandr "Alex" Zinenko       return diag;
12335a9bdd85SOleksandr "Alex" Zinenko     }
12345a9bdd85SOleksandr "Alex" Zinenko 
12355a9bdd85SOleksandr "Alex" Zinenko     // Skip through ops that implement CastOpInterface.
12365a9bdd85SOleksandr "Alex" Zinenko     if (config.skipCastOps && isa<CastOpInterface>(defOp)) {
12375a9bdd85SOleksandr "Alex" Zinenko       values.clear();
12385a9bdd85SOleksandr "Alex" Zinenko       values.assign(defOp->getOperands().begin(), defOp->getOperands().end());
12395a9bdd85SOleksandr "Alex" Zinenko       diag.attachNote(defOp->getLoc())
12405a9bdd85SOleksandr "Alex" Zinenko           << "using output of 'CastOpInterface' op";
12415a9bdd85SOleksandr "Alex" Zinenko       continue;
12425a9bdd85SOleksandr "Alex" Zinenko     }
12435a9bdd85SOleksandr "Alex" Zinenko 
12445a9bdd85SOleksandr "Alex" Zinenko     // If the defining op has the same name or we do not care about the name of
12455a9bdd85SOleksandr "Alex" Zinenko     // op replacements at all, we take it as a replacement.
12465a9bdd85SOleksandr "Alex" Zinenko     if (!config.requireMatchingReplacementOpName ||
12475a9bdd85SOleksandr "Alex" Zinenko         op->getName() == defOp->getName()) {
12485a9bdd85SOleksandr "Alex" Zinenko       result = defOp;
12495a9bdd85SOleksandr "Alex" Zinenko       return DiagnosedSilenceableFailure::success();
12505a9bdd85SOleksandr "Alex" Zinenko     }
12515a9bdd85SOleksandr "Alex" Zinenko 
12525a9bdd85SOleksandr "Alex" Zinenko     // Replacing an op with a constant-like equivalent is a common
12535a9bdd85SOleksandr "Alex" Zinenko     // canonicalization.
12545a9bdd85SOleksandr "Alex" Zinenko     if (defOp->hasTrait<OpTrait::ConstantLike>()) {
12555a9bdd85SOleksandr "Alex" Zinenko       result = defOp;
12565a9bdd85SOleksandr "Alex" Zinenko       return DiagnosedSilenceableFailure::success();
12575a9bdd85SOleksandr "Alex" Zinenko     }
12585a9bdd85SOleksandr "Alex" Zinenko 
12595a9bdd85SOleksandr "Alex" Zinenko     values.clear();
12605a9bdd85SOleksandr "Alex" Zinenko 
12615a9bdd85SOleksandr "Alex" Zinenko     // Skip through ops that implement FindPayloadReplacementOpInterface.
12625a9bdd85SOleksandr "Alex" Zinenko     if (auto findReplacementOpInterface =
12635a9bdd85SOleksandr "Alex" Zinenko             dyn_cast<FindPayloadReplacementOpInterface>(defOp)) {
12645a9bdd85SOleksandr "Alex" Zinenko       values.assign(findReplacementOpInterface.getNextOperands());
12655a9bdd85SOleksandr "Alex" Zinenko       diag.attachNote(defOp->getLoc()) << "using operands provided by "
12665a9bdd85SOleksandr "Alex" Zinenko                                           "'FindPayloadReplacementOpInterface'";
12675a9bdd85SOleksandr "Alex" Zinenko       continue;
12685a9bdd85SOleksandr "Alex" Zinenko     }
12695a9bdd85SOleksandr "Alex" Zinenko   } while (!values.empty());
12705a9bdd85SOleksandr "Alex" Zinenko 
12715a9bdd85SOleksandr "Alex" Zinenko   diag.attachNote() << "ran out of suitable replacement values";
12725a9bdd85SOleksandr "Alex" Zinenko   return diag;
12735a9bdd85SOleksandr "Alex" Zinenko }
12745a9bdd85SOleksandr "Alex" Zinenko 
12755a9bdd85SOleksandr "Alex" Zinenko void transform::TrackingListener::notifyMatchFailure(
12765a9bdd85SOleksandr "Alex" Zinenko     Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
12775a9bdd85SOleksandr "Alex" Zinenko   LLVM_DEBUG({
12785a9bdd85SOleksandr "Alex" Zinenko     Diagnostic diag(loc, DiagnosticSeverity::Remark);
12795a9bdd85SOleksandr "Alex" Zinenko     reasonCallback(diag);
12805a9bdd85SOleksandr "Alex" Zinenko     DBGS() << "Match Failure : " << diag.str() << "\n";
12815a9bdd85SOleksandr "Alex" Zinenko   });
12825a9bdd85SOleksandr "Alex" Zinenko }
12835a9bdd85SOleksandr "Alex" Zinenko 
12845a9bdd85SOleksandr "Alex" Zinenko void transform::TrackingListener::notifyOperationErased(Operation *op) {
12855a9bdd85SOleksandr "Alex" Zinenko   // Remove mappings for result values.
12865a9bdd85SOleksandr "Alex" Zinenko   for (OpResult value : op->getResults())
12875a9bdd85SOleksandr "Alex" Zinenko     (void)replacePayloadValue(value, nullptr);
12885a9bdd85SOleksandr "Alex" Zinenko   // Remove mapping for op.
12895a9bdd85SOleksandr "Alex" Zinenko   (void)replacePayloadOp(op, nullptr);
12905a9bdd85SOleksandr "Alex" Zinenko }
12915a9bdd85SOleksandr "Alex" Zinenko 
12925a9bdd85SOleksandr "Alex" Zinenko void transform::TrackingListener::notifyOperationReplaced(
12935a9bdd85SOleksandr "Alex" Zinenko     Operation *op, ValueRange newValues) {
12945a9bdd85SOleksandr "Alex" Zinenko   assert(op->getNumResults() == newValues.size() &&
12955a9bdd85SOleksandr "Alex" Zinenko          "invalid number of replacement values");
12965a9bdd85SOleksandr "Alex" Zinenko 
12975a9bdd85SOleksandr "Alex" Zinenko   // Replace value handles.
12985a9bdd85SOleksandr "Alex" Zinenko   for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues))
12995a9bdd85SOleksandr "Alex" Zinenko     (void)replacePayloadValue(oldValue, newValue);
13005a9bdd85SOleksandr "Alex" Zinenko 
13015a9bdd85SOleksandr "Alex" Zinenko   // Replace op handle.
13025a9bdd85SOleksandr "Alex" Zinenko   SmallVector<Value> opHandles;
13035a9bdd85SOleksandr "Alex" Zinenko   if (failed(getTransformState().getHandlesForPayloadOp(
13045a9bdd85SOleksandr "Alex" Zinenko           op, opHandles, /*includeOutOfScope=*/true))) {
13055a9bdd85SOleksandr "Alex" Zinenko     // Op is not tracked.
13065a9bdd85SOleksandr "Alex" Zinenko     return;
13075a9bdd85SOleksandr "Alex" Zinenko   }
13085a9bdd85SOleksandr "Alex" Zinenko 
13095a9bdd85SOleksandr "Alex" Zinenko   // Helper function to check if the current transform op consumes any handle
13105a9bdd85SOleksandr "Alex" Zinenko   // that is mapped to `op`.
13115a9bdd85SOleksandr "Alex" Zinenko   //
13125a9bdd85SOleksandr "Alex" Zinenko   // Note: If a handle was consumed, there shouldn't be any alive users, so it
13135a9bdd85SOleksandr "Alex" Zinenko   // is not really necessary to check for consumed handles. However, in case
13145a9bdd85SOleksandr "Alex" Zinenko   // there are indeed alive handles that were consumed (which is undefined
13155a9bdd85SOleksandr "Alex" Zinenko   // behavior) and a replacement op could not be found, we want to fail with a
13165a9bdd85SOleksandr "Alex" Zinenko   // nicer error message: "op uses a handle invalidated..." instead of "could
13175a9bdd85SOleksandr "Alex" Zinenko   // not find replacement op". This nicer error is produced later.
13185a9bdd85SOleksandr "Alex" Zinenko   auto handleWasConsumed = [&] {
13195a9bdd85SOleksandr "Alex" Zinenko     return llvm::any_of(opHandles,
13205a9bdd85SOleksandr "Alex" Zinenko                         [&](Value h) { return consumedHandles.contains(h); });
13215a9bdd85SOleksandr "Alex" Zinenko   };
13225a9bdd85SOleksandr "Alex" Zinenko 
13235a9bdd85SOleksandr "Alex" Zinenko   // Check if there are any handles that must be updated.
13245a9bdd85SOleksandr "Alex" Zinenko   Value aliveHandle;
13255a9bdd85SOleksandr "Alex" Zinenko   if (config.skipHandleFn) {
13265a9bdd85SOleksandr "Alex" Zinenko     auto it = llvm::find_if(opHandles,
13275a9bdd85SOleksandr "Alex" Zinenko                             [&](Value v) { return !config.skipHandleFn(v); });
13285a9bdd85SOleksandr "Alex" Zinenko     if (it != opHandles.end())
13295a9bdd85SOleksandr "Alex" Zinenko       aliveHandle = *it;
13305a9bdd85SOleksandr "Alex" Zinenko   } else if (!opHandles.empty()) {
13315a9bdd85SOleksandr "Alex" Zinenko     aliveHandle = opHandles.front();
13325a9bdd85SOleksandr "Alex" Zinenko   }
13335a9bdd85SOleksandr "Alex" Zinenko   if (!aliveHandle || handleWasConsumed()) {
13345a9bdd85SOleksandr "Alex" Zinenko     // The op is tracked but the corresponding handles are dead or were
13355a9bdd85SOleksandr "Alex" Zinenko     // consumed. Drop the op form the mapping.
13365a9bdd85SOleksandr "Alex" Zinenko     (void)replacePayloadOp(op, nullptr);
13375a9bdd85SOleksandr "Alex" Zinenko     return;
13385a9bdd85SOleksandr "Alex" Zinenko   }
13395a9bdd85SOleksandr "Alex" Zinenko 
13405a9bdd85SOleksandr "Alex" Zinenko   Operation *replacement;
13415a9bdd85SOleksandr "Alex" Zinenko   DiagnosedSilenceableFailure diag =
13425a9bdd85SOleksandr "Alex" Zinenko       findReplacementOp(replacement, op, newValues);
13435a9bdd85SOleksandr "Alex" Zinenko   // If the op is tracked but no replacement op was found, send a
13445a9bdd85SOleksandr "Alex" Zinenko   // notification.
13455a9bdd85SOleksandr "Alex" Zinenko   if (!diag.succeeded()) {
13465a9bdd85SOleksandr "Alex" Zinenko     diag.attachNote(aliveHandle.getLoc())
13475a9bdd85SOleksandr "Alex" Zinenko         << "replacement is required because this handle must be updated";
13485a9bdd85SOleksandr "Alex" Zinenko     notifyPayloadReplacementNotFound(op, newValues, std::move(diag));
13495a9bdd85SOleksandr "Alex" Zinenko     (void)replacePayloadOp(op, nullptr);
13505a9bdd85SOleksandr "Alex" Zinenko     return;
13515a9bdd85SOleksandr "Alex" Zinenko   }
13525a9bdd85SOleksandr "Alex" Zinenko 
13535a9bdd85SOleksandr "Alex" Zinenko   (void)replacePayloadOp(op, replacement);
13545a9bdd85SOleksandr "Alex" Zinenko }
13555a9bdd85SOleksandr "Alex" Zinenko 
13565a9bdd85SOleksandr "Alex" Zinenko transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {
13575a9bdd85SOleksandr "Alex" Zinenko   // The state of the ErrorCheckingTrackingListener must be checked and reset
13585a9bdd85SOleksandr "Alex" Zinenko   // if there was an error. This is to prevent errors from accidentally being
13595a9bdd85SOleksandr "Alex" Zinenko   // missed.
13605a9bdd85SOleksandr "Alex" Zinenko   assert(status.succeeded() && "listener state was not checked");
13615a9bdd85SOleksandr "Alex" Zinenko }
13625a9bdd85SOleksandr "Alex" Zinenko 
13635a9bdd85SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure
13645a9bdd85SOleksandr "Alex" Zinenko transform::ErrorCheckingTrackingListener::checkAndResetError() {
13655a9bdd85SOleksandr "Alex" Zinenko   DiagnosedSilenceableFailure s = std::move(status);
13665a9bdd85SOleksandr "Alex" Zinenko   status = DiagnosedSilenceableFailure::success();
13675a9bdd85SOleksandr "Alex" Zinenko   errorCounter = 0;
13685a9bdd85SOleksandr "Alex" Zinenko   return s;
13695a9bdd85SOleksandr "Alex" Zinenko }
13705a9bdd85SOleksandr "Alex" Zinenko 
13715a9bdd85SOleksandr "Alex" Zinenko bool transform::ErrorCheckingTrackingListener::failed() const {
13725a9bdd85SOleksandr "Alex" Zinenko   return !status.succeeded();
13735a9bdd85SOleksandr "Alex" Zinenko }
13745a9bdd85SOleksandr "Alex" Zinenko 
13755a9bdd85SOleksandr "Alex" Zinenko void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
13765a9bdd85SOleksandr "Alex" Zinenko     Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) {
13775a9bdd85SOleksandr "Alex" Zinenko 
13785a9bdd85SOleksandr "Alex" Zinenko   // Merge potentially existing diags and store the result in the listener.
13795a9bdd85SOleksandr "Alex" Zinenko   SmallVector<Diagnostic> diags;
13805a9bdd85SOleksandr "Alex" Zinenko   diag.takeDiagnostics(diags);
13815a9bdd85SOleksandr "Alex" Zinenko   if (!status.succeeded())
13825a9bdd85SOleksandr "Alex" Zinenko     status.takeDiagnostics(diags);
13835a9bdd85SOleksandr "Alex" Zinenko   status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
13845a9bdd85SOleksandr "Alex" Zinenko 
13855a9bdd85SOleksandr "Alex" Zinenko   // Report more details.
13865a9bdd85SOleksandr "Alex" Zinenko   status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
13875a9bdd85SOleksandr "Alex" Zinenko   for (auto &&[index, value] : llvm::enumerate(values))
13885a9bdd85SOleksandr "Alex" Zinenko     status.attachNote(value.getLoc())
13895a9bdd85SOleksandr "Alex" Zinenko         << "[" << errorCounter << "] replacement value " << index;
13905a9bdd85SOleksandr "Alex" Zinenko   ++errorCounter;
13915a9bdd85SOleksandr "Alex" Zinenko }
13925a9bdd85SOleksandr "Alex" Zinenko 
13935a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
13945a9bdd85SOleksandr "Alex" Zinenko // TransformRewriter
13955a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
13965a9bdd85SOleksandr "Alex" Zinenko 
13975a9bdd85SOleksandr "Alex" Zinenko transform::TransformRewriter::TransformRewriter(
13985a9bdd85SOleksandr "Alex" Zinenko     MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
13995a9bdd85SOleksandr "Alex" Zinenko     : RewriterBase(ctx), listener(listener) {
14005a9bdd85SOleksandr "Alex" Zinenko   setListener(listener);
14015a9bdd85SOleksandr "Alex" Zinenko }
14025a9bdd85SOleksandr "Alex" Zinenko 
14035a9bdd85SOleksandr "Alex" Zinenko bool transform::TransformRewriter::hasTrackingFailures() const {
14045a9bdd85SOleksandr "Alex" Zinenko   return listener->failed();
14055a9bdd85SOleksandr "Alex" Zinenko }
14065a9bdd85SOleksandr "Alex" Zinenko 
14075a9bdd85SOleksandr "Alex" Zinenko /// Silence all tracking failures that have been encountered so far.
14085a9bdd85SOleksandr "Alex" Zinenko void transform::TransformRewriter::silenceTrackingFailure() {
14095a9bdd85SOleksandr "Alex" Zinenko   if (hasTrackingFailures()) {
14105a9bdd85SOleksandr "Alex" Zinenko     DiagnosedSilenceableFailure status = listener->checkAndResetError();
14115a9bdd85SOleksandr "Alex" Zinenko     (void)status.silence();
14125a9bdd85SOleksandr "Alex" Zinenko   }
14135a9bdd85SOleksandr "Alex" Zinenko }
14145a9bdd85SOleksandr "Alex" Zinenko 
14155a9bdd85SOleksandr "Alex" Zinenko LogicalResult transform::TransformRewriter::notifyPayloadOperationReplaced(
14165a9bdd85SOleksandr "Alex" Zinenko     Operation *op, Operation *replacement) {
14175a9bdd85SOleksandr "Alex" Zinenko   return listener->replacePayloadOp(op, replacement);
14185a9bdd85SOleksandr "Alex" Zinenko }
14195a9bdd85SOleksandr "Alex" Zinenko 
14205a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
14215a9bdd85SOleksandr "Alex" Zinenko // Utilities for TransformEachOpTrait.
14225a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
14235a9bdd85SOleksandr "Alex" Zinenko 
14245a9bdd85SOleksandr "Alex" Zinenko LogicalResult
14255a9bdd85SOleksandr "Alex" Zinenko transform::detail::checkNestedConsumption(Location loc,
14265a9bdd85SOleksandr "Alex" Zinenko                                           ArrayRef<Operation *> targets) {
14275a9bdd85SOleksandr "Alex" Zinenko   for (auto &&[position, parent] : llvm::enumerate(targets)) {
14285a9bdd85SOleksandr "Alex" Zinenko     for (Operation *child : targets.drop_front(position + 1)) {
14295a9bdd85SOleksandr "Alex" Zinenko       if (parent->isAncestor(child)) {
14305a9bdd85SOleksandr "Alex" Zinenko         InFlightDiagnostic diag =
14315a9bdd85SOleksandr "Alex" Zinenko             emitError(loc)
14325a9bdd85SOleksandr "Alex" Zinenko             << "transform operation consumes a handle pointing to an ancestor "
14335a9bdd85SOleksandr "Alex" Zinenko                "payload operation before its descendant";
14345a9bdd85SOleksandr "Alex" Zinenko         diag.attachNote()
14355a9bdd85SOleksandr "Alex" Zinenko             << "the ancestor is likely erased or rewritten before the "
14365a9bdd85SOleksandr "Alex" Zinenko                "descendant is accessed, leading to undefined behavior";
14375a9bdd85SOleksandr "Alex" Zinenko         diag.attachNote(parent->getLoc()) << "ancestor payload op";
14385a9bdd85SOleksandr "Alex" Zinenko         diag.attachNote(child->getLoc()) << "descendant payload op";
14395a9bdd85SOleksandr "Alex" Zinenko         return diag;
14405a9bdd85SOleksandr "Alex" Zinenko       }
14415a9bdd85SOleksandr "Alex" Zinenko     }
14425a9bdd85SOleksandr "Alex" Zinenko   }
14435a9bdd85SOleksandr "Alex" Zinenko   return success();
14445a9bdd85SOleksandr "Alex" Zinenko }
14455a9bdd85SOleksandr "Alex" Zinenko 
14465a9bdd85SOleksandr "Alex" Zinenko LogicalResult
14475a9bdd85SOleksandr "Alex" Zinenko transform::detail::checkApplyToOne(Operation *transformOp,
14485a9bdd85SOleksandr "Alex" Zinenko                                    Location payloadOpLoc,
14495a9bdd85SOleksandr "Alex" Zinenko                                    const ApplyToEachResultList &partialResult) {
14505a9bdd85SOleksandr "Alex" Zinenko   Location transformOpLoc = transformOp->getLoc();
14515a9bdd85SOleksandr "Alex" Zinenko   StringRef transformOpName = transformOp->getName().getStringRef();
14525a9bdd85SOleksandr "Alex" Zinenko   unsigned expectedNumResults = transformOp->getNumResults();
14535a9bdd85SOleksandr "Alex" Zinenko 
14545a9bdd85SOleksandr "Alex" Zinenko   // Reuse the emission of the diagnostic note.
14555a9bdd85SOleksandr "Alex" Zinenko   auto emitDiag = [&]() {
14565a9bdd85SOleksandr "Alex" Zinenko     auto diag = mlir::emitError(transformOpLoc);
14575a9bdd85SOleksandr "Alex" Zinenko     diag.attachNote(payloadOpLoc) << "when applied to this op";
14585a9bdd85SOleksandr "Alex" Zinenko     return diag;
14595a9bdd85SOleksandr "Alex" Zinenko   };
14605a9bdd85SOleksandr "Alex" Zinenko 
14615a9bdd85SOleksandr "Alex" Zinenko   if (partialResult.size() != expectedNumResults) {
14625a9bdd85SOleksandr "Alex" Zinenko     auto diag = emitDiag() << "application of " << transformOpName
14635a9bdd85SOleksandr "Alex" Zinenko                            << " expected to produce " << expectedNumResults
14645a9bdd85SOleksandr "Alex" Zinenko                            << " results (actually produced "
14655a9bdd85SOleksandr "Alex" Zinenko                            << partialResult.size() << ").";
14665a9bdd85SOleksandr "Alex" Zinenko     diag.attachNote(transformOpLoc)
14675a9bdd85SOleksandr "Alex" Zinenko         << "if you need variadic results, consider a generic `apply` "
14685a9bdd85SOleksandr "Alex" Zinenko         << "instead of the specialized `applyToOne`.";
14695a9bdd85SOleksandr "Alex" Zinenko     return failure();
14705a9bdd85SOleksandr "Alex" Zinenko   }
14715a9bdd85SOleksandr "Alex" Zinenko 
14725a9bdd85SOleksandr "Alex" Zinenko   // Check that the right kind of value was produced.
14735a9bdd85SOleksandr "Alex" Zinenko   for (const auto &[ptr, res] :
14745a9bdd85SOleksandr "Alex" Zinenko        llvm::zip(partialResult, transformOp->getResults())) {
14755a9bdd85SOleksandr "Alex" Zinenko     if (ptr.isNull())
14765a9bdd85SOleksandr "Alex" Zinenko       continue;
14775a9bdd85SOleksandr "Alex" Zinenko     if (llvm::isa<TransformHandleTypeInterface>(res.getType()) &&
1478*129f1001SKazu Hirata         !isa<Operation *>(ptr)) {
14795a9bdd85SOleksandr "Alex" Zinenko       return emitDiag() << "application of " << transformOpName
14805a9bdd85SOleksandr "Alex" Zinenko                         << " expected to produce an Operation * for result #"
14815a9bdd85SOleksandr "Alex" Zinenko                         << res.getResultNumber();
14825a9bdd85SOleksandr "Alex" Zinenko     }
14835a9bdd85SOleksandr "Alex" Zinenko     if (llvm::isa<TransformParamTypeInterface>(res.getType()) &&
1484*129f1001SKazu Hirata         !isa<Attribute>(ptr)) {
14855a9bdd85SOleksandr "Alex" Zinenko       return emitDiag() << "application of " << transformOpName
14865a9bdd85SOleksandr "Alex" Zinenko                         << " expected to produce an Attribute for result #"
14875a9bdd85SOleksandr "Alex" Zinenko                         << res.getResultNumber();
14885a9bdd85SOleksandr "Alex" Zinenko     }
14895a9bdd85SOleksandr "Alex" Zinenko     if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) &&
1490*129f1001SKazu Hirata         !isa<Value>(ptr)) {
14915a9bdd85SOleksandr "Alex" Zinenko       return emitDiag() << "application of " << transformOpName
14925a9bdd85SOleksandr "Alex" Zinenko                         << " expected to produce a Value for result #"
14935a9bdd85SOleksandr "Alex" Zinenko                         << res.getResultNumber();
14945a9bdd85SOleksandr "Alex" Zinenko     }
14955a9bdd85SOleksandr "Alex" Zinenko   }
14965a9bdd85SOleksandr "Alex" Zinenko   return success();
14975a9bdd85SOleksandr "Alex" Zinenko }
14985a9bdd85SOleksandr "Alex" Zinenko 
14995a9bdd85SOleksandr "Alex" Zinenko template <typename T>
15005a9bdd85SOleksandr "Alex" Zinenko static SmallVector<T> castVector(ArrayRef<transform::MappedValue> range) {
15015a9bdd85SOleksandr "Alex" Zinenko   return llvm::to_vector(llvm::map_range(
1502*129f1001SKazu Hirata       range, [](transform::MappedValue value) { return cast<T>(value); }));
15035a9bdd85SOleksandr "Alex" Zinenko }
15045a9bdd85SOleksandr "Alex" Zinenko 
15055a9bdd85SOleksandr "Alex" Zinenko void transform::detail::setApplyToOneResults(
15065a9bdd85SOleksandr "Alex" Zinenko     Operation *transformOp, TransformResults &transformResults,
15075a9bdd85SOleksandr "Alex" Zinenko     ArrayRef<ApplyToEachResultList> results) {
15085a9bdd85SOleksandr "Alex" Zinenko   SmallVector<SmallVector<MappedValue>> transposed;
15095a9bdd85SOleksandr "Alex" Zinenko   transposed.resize(transformOp->getNumResults());
15105a9bdd85SOleksandr "Alex" Zinenko   for (const ApplyToEachResultList &partialResults : results) {
15115a9bdd85SOleksandr "Alex" Zinenko     if (llvm::any_of(partialResults,
15125a9bdd85SOleksandr "Alex" Zinenko                      [](MappedValue value) { return value.isNull(); }))
15135a9bdd85SOleksandr "Alex" Zinenko       continue;
15145a9bdd85SOleksandr "Alex" Zinenko     assert(transformOp->getNumResults() == partialResults.size() &&
15155a9bdd85SOleksandr "Alex" Zinenko            "expected as many partial results as op as results");
15165a9bdd85SOleksandr "Alex" Zinenko     for (auto [i, value] : llvm::enumerate(partialResults))
15175a9bdd85SOleksandr "Alex" Zinenko       transposed[i].push_back(value);
15185a9bdd85SOleksandr "Alex" Zinenko   }
15195a9bdd85SOleksandr "Alex" Zinenko 
15205a9bdd85SOleksandr "Alex" Zinenko   for (OpResult r : transformOp->getResults()) {
15215a9bdd85SOleksandr "Alex" Zinenko     unsigned position = r.getResultNumber();
15225a9bdd85SOleksandr "Alex" Zinenko     if (llvm::isa<TransformParamTypeInterface>(r.getType())) {
15235a9bdd85SOleksandr "Alex" Zinenko       transformResults.setParams(r,
15245a9bdd85SOleksandr "Alex" Zinenko                                  castVector<Attribute>(transposed[position]));
15255a9bdd85SOleksandr "Alex" Zinenko     } else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) {
15265a9bdd85SOleksandr "Alex" Zinenko       transformResults.setValues(r, castVector<Value>(transposed[position]));
15275a9bdd85SOleksandr "Alex" Zinenko     } else {
15285a9bdd85SOleksandr "Alex" Zinenko       transformResults.set(r, castVector<Operation *>(transposed[position]));
15295a9bdd85SOleksandr "Alex" Zinenko     }
15305a9bdd85SOleksandr "Alex" Zinenko   }
15315a9bdd85SOleksandr "Alex" Zinenko }
15325a9bdd85SOleksandr "Alex" Zinenko 
15335a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
15345a9bdd85SOleksandr "Alex" Zinenko // Utilities for implementing transform ops with regions.
15355a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
15365a9bdd85SOleksandr "Alex" Zinenko 
1537e4b04b39SOleksandr "Alex" Zinenko LogicalResult transform::detail::appendValueMappings(
1538e4b04b39SOleksandr "Alex" Zinenko     MutableArrayRef<SmallVector<transform::MappedValue>> mappings,
1539e4b04b39SOleksandr "Alex" Zinenko     ValueRange values, const transform::TransformState &state, bool flatten) {
1540e4b04b39SOleksandr "Alex" Zinenko   assert(mappings.size() == values.size() && "mismatching number of mappings");
1541e4b04b39SOleksandr "Alex" Zinenko   for (auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) {
1542e4b04b39SOleksandr "Alex" Zinenko     size_t mappedSize = mapped.size();
15435a9bdd85SOleksandr "Alex" Zinenko     if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
15445a9bdd85SOleksandr "Alex" Zinenko       llvm::append_range(mapped, state.getPayloadOps(operand));
15455a9bdd85SOleksandr "Alex" Zinenko     } else if (llvm::isa<TransformValueHandleTypeInterface>(
15465a9bdd85SOleksandr "Alex" Zinenko                    operand.getType())) {
15475a9bdd85SOleksandr "Alex" Zinenko       llvm::append_range(mapped, state.getPayloadValues(operand));
15485a9bdd85SOleksandr "Alex" Zinenko     } else {
15495a9bdd85SOleksandr "Alex" Zinenko       assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) &&
15505a9bdd85SOleksandr "Alex" Zinenko              "unsupported kind of transform dialect value");
15515a9bdd85SOleksandr "Alex" Zinenko       llvm::append_range(mapped, state.getParams(operand));
15525a9bdd85SOleksandr "Alex" Zinenko     }
1553e4b04b39SOleksandr "Alex" Zinenko 
1554e4b04b39SOleksandr "Alex" Zinenko     if (mapped.size() - mappedSize != 1 && !flatten)
1555e4b04b39SOleksandr "Alex" Zinenko       return failure();
15565a9bdd85SOleksandr "Alex" Zinenko   }
1557e4b04b39SOleksandr "Alex" Zinenko   return success();
1558e4b04b39SOleksandr "Alex" Zinenko }
1559e4b04b39SOleksandr "Alex" Zinenko 
1560e4b04b39SOleksandr "Alex" Zinenko void transform::detail::prepareValueMappings(
1561e4b04b39SOleksandr "Alex" Zinenko     SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings,
1562e4b04b39SOleksandr "Alex" Zinenko     ValueRange values, const transform::TransformState &state) {
1563e4b04b39SOleksandr "Alex" Zinenko   mappings.resize(mappings.size() + values.size());
1564e4b04b39SOleksandr "Alex" Zinenko   (void)appendValueMappings(
1565e4b04b39SOleksandr "Alex" Zinenko       MutableArrayRef<SmallVector<transform::MappedValue>>(mappings).take_back(
1566e4b04b39SOleksandr "Alex" Zinenko           values.size()),
1567e4b04b39SOleksandr "Alex" Zinenko       values, state);
15685a9bdd85SOleksandr "Alex" Zinenko }
15695a9bdd85SOleksandr "Alex" Zinenko 
15705a9bdd85SOleksandr "Alex" Zinenko void transform::detail::forwardTerminatorOperands(
15715a9bdd85SOleksandr "Alex" Zinenko     Block *block, transform::TransformState &state,
15725a9bdd85SOleksandr "Alex" Zinenko     transform::TransformResults &results) {
15735a9bdd85SOleksandr "Alex" Zinenko   for (auto &&[terminatorOperand, result] :
15745a9bdd85SOleksandr "Alex" Zinenko        llvm::zip(block->getTerminator()->getOperands(),
15755a9bdd85SOleksandr "Alex" Zinenko                  block->getParentOp()->getOpResults())) {
15765a9bdd85SOleksandr "Alex" Zinenko     if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) {
15775a9bdd85SOleksandr "Alex" Zinenko       results.set(result, state.getPayloadOps(terminatorOperand));
15785a9bdd85SOleksandr "Alex" Zinenko     } else if (llvm::isa<transform::TransformValueHandleTypeInterface>(
15795a9bdd85SOleksandr "Alex" Zinenko                    result.getType())) {
15805a9bdd85SOleksandr "Alex" Zinenko       results.setValues(result, state.getPayloadValues(terminatorOperand));
15815a9bdd85SOleksandr "Alex" Zinenko     } else {
15825a9bdd85SOleksandr "Alex" Zinenko       assert(
15835a9bdd85SOleksandr "Alex" Zinenko           llvm::isa<transform::TransformParamTypeInterface>(result.getType()) &&
15845a9bdd85SOleksandr "Alex" Zinenko           "unhandled transform type interface");
15855a9bdd85SOleksandr "Alex" Zinenko       results.setParams(result, state.getParams(terminatorOperand));
15865a9bdd85SOleksandr "Alex" Zinenko     }
15875a9bdd85SOleksandr "Alex" Zinenko   }
15885a9bdd85SOleksandr "Alex" Zinenko }
15895a9bdd85SOleksandr "Alex" Zinenko 
15905a9bdd85SOleksandr "Alex" Zinenko transform::TransformState
15915a9bdd85SOleksandr "Alex" Zinenko transform::detail::makeTransformStateForTesting(Region *region,
15925a9bdd85SOleksandr "Alex" Zinenko                                                 Operation *payloadRoot) {
15935a9bdd85SOleksandr "Alex" Zinenko   return TransformState(region, payloadRoot);
15945a9bdd85SOleksandr "Alex" Zinenko }
15955a9bdd85SOleksandr "Alex" Zinenko 
15965a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
15975a9bdd85SOleksandr "Alex" Zinenko // Utilities for PossibleTopLevelTransformOpTrait.
15985a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
15995a9bdd85SOleksandr "Alex" Zinenko 
16005a9bdd85SOleksandr "Alex" Zinenko /// Appends to `effects` the memory effect instances on `target` with the same
16015a9bdd85SOleksandr "Alex" Zinenko /// resource and effect as the ones the operation `iface` having on `source`.
16025a9bdd85SOleksandr "Alex" Zinenko static void
16032c1ae801Sdonald chen remapEffects(MemoryEffectOpInterface iface, BlockArgument source,
16042c1ae801Sdonald chen              OpOperand *target,
16055a9bdd85SOleksandr "Alex" Zinenko              SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
16065a9bdd85SOleksandr "Alex" Zinenko   SmallVector<MemoryEffects::EffectInstance> nestedEffects;
16075a9bdd85SOleksandr "Alex" Zinenko   iface.getEffectsOnValue(source, nestedEffects);
16085a9bdd85SOleksandr "Alex" Zinenko   for (const auto &effect : nestedEffects)
16095a9bdd85SOleksandr "Alex" Zinenko     effects.emplace_back(effect.getEffect(), target, effect.getResource());
16105a9bdd85SOleksandr "Alex" Zinenko }
16115a9bdd85SOleksandr "Alex" Zinenko 
16125a9bdd85SOleksandr "Alex" Zinenko /// Appends to `effects` the same effects as the operations of `block` have on
16135a9bdd85SOleksandr "Alex" Zinenko /// block arguments but associated with `operands.`
16145a9bdd85SOleksandr "Alex" Zinenko static void
16152c1ae801Sdonald chen remapArgumentEffects(Block &block, MutableArrayRef<OpOperand> operands,
16165a9bdd85SOleksandr "Alex" Zinenko                      SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
16175a9bdd85SOleksandr "Alex" Zinenko   for (Operation &op : block) {
16185a9bdd85SOleksandr "Alex" Zinenko     auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
16195a9bdd85SOleksandr "Alex" Zinenko     if (!iface)
16205a9bdd85SOleksandr "Alex" Zinenko       continue;
16215a9bdd85SOleksandr "Alex" Zinenko 
16225a9bdd85SOleksandr "Alex" Zinenko     for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) {
16232c1ae801Sdonald chen       remapEffects(iface, source, &target, effects);
16245a9bdd85SOleksandr "Alex" Zinenko     }
16255a9bdd85SOleksandr "Alex" Zinenko 
16265a9bdd85SOleksandr "Alex" Zinenko     SmallVector<MemoryEffects::EffectInstance> nestedEffects;
16275a9bdd85SOleksandr "Alex" Zinenko     iface.getEffectsOnResource(transform::PayloadIRResource::get(),
16285a9bdd85SOleksandr "Alex" Zinenko                                nestedEffects);
16295a9bdd85SOleksandr "Alex" Zinenko     llvm::append_range(effects, nestedEffects);
16305a9bdd85SOleksandr "Alex" Zinenko   }
16315a9bdd85SOleksandr "Alex" Zinenko }
16325a9bdd85SOleksandr "Alex" Zinenko 
16335a9bdd85SOleksandr "Alex" Zinenko void transform::detail::getPotentialTopLevelEffects(
16345a9bdd85SOleksandr "Alex" Zinenko     Operation *operation, Value root, Block &body,
16355a9bdd85SOleksandr "Alex" Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
16362c1ae801Sdonald chen   transform::onlyReadsHandle(operation->getOpOperands(), effects);
16372c1ae801Sdonald chen   transform::producesHandle(operation->getOpResults(), effects);
16385a9bdd85SOleksandr "Alex" Zinenko 
16395a9bdd85SOleksandr "Alex" Zinenko   if (!root) {
16405a9bdd85SOleksandr "Alex" Zinenko     for (Operation &op : body) {
16415a9bdd85SOleksandr "Alex" Zinenko       auto iface = dyn_cast<MemoryEffectOpInterface>(&op);
16425a9bdd85SOleksandr "Alex" Zinenko       if (!iface)
16435a9bdd85SOleksandr "Alex" Zinenko         continue;
16445a9bdd85SOleksandr "Alex" Zinenko 
16455a9bdd85SOleksandr "Alex" Zinenko       SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects;
16465a9bdd85SOleksandr "Alex" Zinenko       iface.getEffects(effects);
16475a9bdd85SOleksandr "Alex" Zinenko     }
16485a9bdd85SOleksandr "Alex" Zinenko     return;
16495a9bdd85SOleksandr "Alex" Zinenko   }
16505a9bdd85SOleksandr "Alex" Zinenko 
16515a9bdd85SOleksandr "Alex" Zinenko   // Carry over all effects on arguments of the entry block as those on the
16525a9bdd85SOleksandr "Alex" Zinenko   // operands, this is the same value just remapped.
16532c1ae801Sdonald chen   remapArgumentEffects(body, operation->getOpOperands(), effects);
16545a9bdd85SOleksandr "Alex" Zinenko }
16555a9bdd85SOleksandr "Alex" Zinenko 
16565a9bdd85SOleksandr "Alex" Zinenko LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
16575a9bdd85SOleksandr "Alex" Zinenko     TransformState &state, Operation *op, Region &region) {
16585a9bdd85SOleksandr "Alex" Zinenko   SmallVector<Operation *> targets;
16595a9bdd85SOleksandr "Alex" Zinenko   SmallVector<SmallVector<MappedValue>> extraMappings;
16605a9bdd85SOleksandr "Alex" Zinenko   if (op->getNumOperands() != 0) {
16615a9bdd85SOleksandr "Alex" Zinenko     llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
16625a9bdd85SOleksandr "Alex" Zinenko     prepareValueMappings(extraMappings, op->getOperands().drop_front(), state);
16635a9bdd85SOleksandr "Alex" Zinenko   } else {
16645a9bdd85SOleksandr "Alex" Zinenko     if (state.getNumTopLevelMappings() !=
16655a9bdd85SOleksandr "Alex" Zinenko         region.front().getNumArguments() - 1) {
16665a9bdd85SOleksandr "Alex" Zinenko       return emitError(op->getLoc())
16675a9bdd85SOleksandr "Alex" Zinenko              << "operation expects " << region.front().getNumArguments() - 1
16685a9bdd85SOleksandr "Alex" Zinenko              << " extra value bindings, but " << state.getNumTopLevelMappings()
16695a9bdd85SOleksandr "Alex" Zinenko              << " were provided to the interpreter";
16705a9bdd85SOleksandr "Alex" Zinenko     }
16715a9bdd85SOleksandr "Alex" Zinenko 
16725a9bdd85SOleksandr "Alex" Zinenko     targets.push_back(state.getTopLevel());
16735a9bdd85SOleksandr "Alex" Zinenko 
16745a9bdd85SOleksandr "Alex" Zinenko     for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i)
16755a9bdd85SOleksandr "Alex" Zinenko       extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i)));
16765a9bdd85SOleksandr "Alex" Zinenko   }
16775a9bdd85SOleksandr "Alex" Zinenko 
16785a9bdd85SOleksandr "Alex" Zinenko   if (failed(state.mapBlockArguments(region.front().getArgument(0), targets)))
16795a9bdd85SOleksandr "Alex" Zinenko     return failure();
16805a9bdd85SOleksandr "Alex" Zinenko 
16815a9bdd85SOleksandr "Alex" Zinenko   for (BlockArgument argument : region.front().getArguments().drop_front()) {
16825a9bdd85SOleksandr "Alex" Zinenko     if (failed(state.mapBlockArgument(
16835a9bdd85SOleksandr "Alex" Zinenko             argument, extraMappings[argument.getArgNumber() - 1])))
16845a9bdd85SOleksandr "Alex" Zinenko       return failure();
16855a9bdd85SOleksandr "Alex" Zinenko   }
16865a9bdd85SOleksandr "Alex" Zinenko 
16875a9bdd85SOleksandr "Alex" Zinenko   return success();
16885a9bdd85SOleksandr "Alex" Zinenko }
16895a9bdd85SOleksandr "Alex" Zinenko 
16905a9bdd85SOleksandr "Alex" Zinenko LogicalResult
16915a9bdd85SOleksandr "Alex" Zinenko transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
16925a9bdd85SOleksandr "Alex" Zinenko   // Attaching this trait without the interface is a misuse of the API, but it
16935a9bdd85SOleksandr "Alex" Zinenko   // cannot be caught via a static_assert because interface registration is
16945a9bdd85SOleksandr "Alex" Zinenko   // dynamic.
16955a9bdd85SOleksandr "Alex" Zinenko   assert(isa<TransformOpInterface>(op) &&
16965a9bdd85SOleksandr "Alex" Zinenko          "should implement TransformOpInterface to have "
16975a9bdd85SOleksandr "Alex" Zinenko          "PossibleTopLevelTransformOpTrait");
16985a9bdd85SOleksandr "Alex" Zinenko 
16995a9bdd85SOleksandr "Alex" Zinenko   if (op->getNumRegions() < 1)
17005a9bdd85SOleksandr "Alex" Zinenko     return op->emitOpError() << "expects at least one region";
17015a9bdd85SOleksandr "Alex" Zinenko 
17025a9bdd85SOleksandr "Alex" Zinenko   Region *bodyRegion = &op->getRegion(0);
17035a9bdd85SOleksandr "Alex" Zinenko   if (!llvm::hasNItems(*bodyRegion, 1))
17045a9bdd85SOleksandr "Alex" Zinenko     return op->emitOpError() << "expects a single-block region";
17055a9bdd85SOleksandr "Alex" Zinenko 
17065a9bdd85SOleksandr "Alex" Zinenko   Block *body = &bodyRegion->front();
17075a9bdd85SOleksandr "Alex" Zinenko   if (body->getNumArguments() == 0) {
17085a9bdd85SOleksandr "Alex" Zinenko     return op->emitOpError()
17095a9bdd85SOleksandr "Alex" Zinenko            << "expects the entry block to have at least one argument";
17105a9bdd85SOleksandr "Alex" Zinenko   }
17115a9bdd85SOleksandr "Alex" Zinenko   if (!llvm::isa<TransformHandleTypeInterface>(
17125a9bdd85SOleksandr "Alex" Zinenko           body->getArgument(0).getType())) {
17135a9bdd85SOleksandr "Alex" Zinenko     return op->emitOpError()
17145a9bdd85SOleksandr "Alex" Zinenko            << "expects the first entry block argument to be of type "
17155a9bdd85SOleksandr "Alex" Zinenko               "implementing TransformHandleTypeInterface";
17165a9bdd85SOleksandr "Alex" Zinenko   }
17175a9bdd85SOleksandr "Alex" Zinenko   BlockArgument arg = body->getArgument(0);
17185a9bdd85SOleksandr "Alex" Zinenko   if (op->getNumOperands() != 0) {
17195a9bdd85SOleksandr "Alex" Zinenko     if (arg.getType() != op->getOperand(0).getType()) {
17205a9bdd85SOleksandr "Alex" Zinenko       return op->emitOpError()
17215a9bdd85SOleksandr "Alex" Zinenko              << "expects the type of the block argument to match "
17225a9bdd85SOleksandr "Alex" Zinenko                 "the type of the operand";
17235a9bdd85SOleksandr "Alex" Zinenko     }
17245a9bdd85SOleksandr "Alex" Zinenko   }
17255a9bdd85SOleksandr "Alex" Zinenko   for (BlockArgument arg : body->getArguments().drop_front()) {
17265a9bdd85SOleksandr "Alex" Zinenko     if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface,
17275a9bdd85SOleksandr "Alex" Zinenko                   TransformValueHandleTypeInterface>(arg.getType()))
17285a9bdd85SOleksandr "Alex" Zinenko       continue;
17295a9bdd85SOleksandr "Alex" Zinenko 
17305a9bdd85SOleksandr "Alex" Zinenko     InFlightDiagnostic diag =
17315a9bdd85SOleksandr "Alex" Zinenko         op->emitOpError()
17325a9bdd85SOleksandr "Alex" Zinenko         << "expects trailing entry block arguments to be of type implementing "
17335a9bdd85SOleksandr "Alex" Zinenko            "TransformHandleTypeInterface, TransformValueHandleTypeInterface or "
17345a9bdd85SOleksandr "Alex" Zinenko            "TransformParamTypeInterface";
17355a9bdd85SOleksandr "Alex" Zinenko     diag.attachNote() << "argument #" << arg.getArgNumber() << " does not";
17365a9bdd85SOleksandr "Alex" Zinenko     return diag;
17375a9bdd85SOleksandr "Alex" Zinenko   }
17385a9bdd85SOleksandr "Alex" Zinenko 
17395a9bdd85SOleksandr "Alex" Zinenko   if (auto *parent =
17405a9bdd85SOleksandr "Alex" Zinenko           op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) {
17415a9bdd85SOleksandr "Alex" Zinenko     if (op->getNumOperands() != body->getNumArguments()) {
17425a9bdd85SOleksandr "Alex" Zinenko       InFlightDiagnostic diag =
17435a9bdd85SOleksandr "Alex" Zinenko           op->emitOpError()
17445a9bdd85SOleksandr "Alex" Zinenko           << "expects operands to be provided for a nested op";
17455a9bdd85SOleksandr "Alex" Zinenko       diag.attachNote(parent->getLoc())
17465a9bdd85SOleksandr "Alex" Zinenko           << "nested in another possible top-level op";
17475a9bdd85SOleksandr "Alex" Zinenko       return diag;
17485a9bdd85SOleksandr "Alex" Zinenko     }
17495a9bdd85SOleksandr "Alex" Zinenko   }
17505a9bdd85SOleksandr "Alex" Zinenko 
17515a9bdd85SOleksandr "Alex" Zinenko   return success();
17525a9bdd85SOleksandr "Alex" Zinenko }
17535a9bdd85SOleksandr "Alex" Zinenko 
17545a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
17555a9bdd85SOleksandr "Alex" Zinenko // Utilities for ParamProducedTransformOpTrait.
17565a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
17575a9bdd85SOleksandr "Alex" Zinenko 
17585a9bdd85SOleksandr "Alex" Zinenko void transform::detail::getParamProducerTransformOpTraitEffects(
17595a9bdd85SOleksandr "Alex" Zinenko     Operation *op, SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
17605a9bdd85SOleksandr "Alex" Zinenko   producesHandle(op->getResults(), effects);
17615a9bdd85SOleksandr "Alex" Zinenko   bool hasPayloadOperands = false;
17622c1ae801Sdonald chen   for (OpOperand &operand : op->getOpOperands()) {
17635a9bdd85SOleksandr "Alex" Zinenko     onlyReadsHandle(operand, effects);
17645a9bdd85SOleksandr "Alex" Zinenko     if (llvm::isa<TransformHandleTypeInterface,
17652c1ae801Sdonald chen                   TransformValueHandleTypeInterface>(operand.get().getType()))
17665a9bdd85SOleksandr "Alex" Zinenko       hasPayloadOperands = true;
17675a9bdd85SOleksandr "Alex" Zinenko   }
17685a9bdd85SOleksandr "Alex" Zinenko   if (hasPayloadOperands)
17695a9bdd85SOleksandr "Alex" Zinenko     onlyReadsPayload(effects);
17705a9bdd85SOleksandr "Alex" Zinenko }
17715a9bdd85SOleksandr "Alex" Zinenko 
17725a9bdd85SOleksandr "Alex" Zinenko LogicalResult
17735a9bdd85SOleksandr "Alex" Zinenko transform::detail::verifyParamProducerTransformOpTrait(Operation *op) {
17745a9bdd85SOleksandr "Alex" Zinenko   // Interfaces can be attached dynamically, so this cannot be a static
17755a9bdd85SOleksandr "Alex" Zinenko   // assert.
17765a9bdd85SOleksandr "Alex" Zinenko   if (!op->getName().getInterface<MemoryEffectOpInterface>()) {
17775a9bdd85SOleksandr "Alex" Zinenko     llvm::report_fatal_error(
17785a9bdd85SOleksandr "Alex" Zinenko         Twine("ParamProducerTransformOpTrait must be attached to an op that "
17795a9bdd85SOleksandr "Alex" Zinenko               "implements MemoryEffectsOpInterface, found on ") +
17805a9bdd85SOleksandr "Alex" Zinenko         op->getName().getStringRef());
17815a9bdd85SOleksandr "Alex" Zinenko   }
17825a9bdd85SOleksandr "Alex" Zinenko   for (Value result : op->getResults()) {
17835a9bdd85SOleksandr "Alex" Zinenko     if (llvm::isa<TransformParamTypeInterface>(result.getType()))
17845a9bdd85SOleksandr "Alex" Zinenko       continue;
17855a9bdd85SOleksandr "Alex" Zinenko     return op->emitOpError()
17865a9bdd85SOleksandr "Alex" Zinenko            << "ParamProducerTransformOpTrait attached to this op expects "
17875a9bdd85SOleksandr "Alex" Zinenko               "result types to implement TransformParamTypeInterface";
17885a9bdd85SOleksandr "Alex" Zinenko   }
17895a9bdd85SOleksandr "Alex" Zinenko   return success();
17905a9bdd85SOleksandr "Alex" Zinenko }
17915a9bdd85SOleksandr "Alex" Zinenko 
17925a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
17935a9bdd85SOleksandr "Alex" Zinenko // Memory effects.
17945a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
17955a9bdd85SOleksandr "Alex" Zinenko 
17965a9bdd85SOleksandr "Alex" Zinenko void transform::consumesHandle(
17972c1ae801Sdonald chen     MutableArrayRef<OpOperand> handles,
17985a9bdd85SOleksandr "Alex" Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
17992c1ae801Sdonald chen   for (OpOperand &handle : handles) {
18002c1ae801Sdonald chen     effects.emplace_back(MemoryEffects::Read::get(), &handle,
18015a9bdd85SOleksandr "Alex" Zinenko                          TransformMappingResource::get());
18022c1ae801Sdonald chen     effects.emplace_back(MemoryEffects::Free::get(), &handle,
18035a9bdd85SOleksandr "Alex" Zinenko                          TransformMappingResource::get());
18045a9bdd85SOleksandr "Alex" Zinenko   }
18055a9bdd85SOleksandr "Alex" Zinenko }
18065a9bdd85SOleksandr "Alex" Zinenko 
18075a9bdd85SOleksandr "Alex" Zinenko /// Returns `true` if the given list of effects instances contains an instance
18085a9bdd85SOleksandr "Alex" Zinenko /// with the effect type specified as template parameter.
18095a9bdd85SOleksandr "Alex" Zinenko template <typename EffectTy, typename ResourceTy, typename Range>
18105a9bdd85SOleksandr "Alex" Zinenko static bool hasEffect(Range &&effects) {
18115a9bdd85SOleksandr "Alex" Zinenko   return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
18125a9bdd85SOleksandr "Alex" Zinenko     return isa<EffectTy>(effect.getEffect()) &&
18135a9bdd85SOleksandr "Alex" Zinenko            isa<ResourceTy>(effect.getResource());
18145a9bdd85SOleksandr "Alex" Zinenko   });
18155a9bdd85SOleksandr "Alex" Zinenko }
18165a9bdd85SOleksandr "Alex" Zinenko 
18175a9bdd85SOleksandr "Alex" Zinenko bool transform::isHandleConsumed(Value handle,
18185a9bdd85SOleksandr "Alex" Zinenko                                  transform::TransformOpInterface transform) {
18195a9bdd85SOleksandr "Alex" Zinenko   auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
18205a9bdd85SOleksandr "Alex" Zinenko   SmallVector<MemoryEffects::EffectInstance> effects;
18215a9bdd85SOleksandr "Alex" Zinenko   iface.getEffectsOnValue(handle, effects);
18225a9bdd85SOleksandr "Alex" Zinenko   return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) &&
18235a9bdd85SOleksandr "Alex" Zinenko          ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects);
18245a9bdd85SOleksandr "Alex" Zinenko }
18255a9bdd85SOleksandr "Alex" Zinenko 
18265a9bdd85SOleksandr "Alex" Zinenko void transform::producesHandle(
18272c1ae801Sdonald chen     ResultRange handles,
18285a9bdd85SOleksandr "Alex" Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
18292c1ae801Sdonald chen   for (OpResult handle : handles) {
18302c1ae801Sdonald chen     effects.emplace_back(MemoryEffects::Allocate::get(), handle,
18312c1ae801Sdonald chen                          TransformMappingResource::get());
18322c1ae801Sdonald chen     effects.emplace_back(MemoryEffects::Write::get(), handle,
18332c1ae801Sdonald chen                          TransformMappingResource::get());
18342c1ae801Sdonald chen   }
18352c1ae801Sdonald chen }
18362c1ae801Sdonald chen 
18372c1ae801Sdonald chen void transform::producesHandle(
18382c1ae801Sdonald chen     MutableArrayRef<BlockArgument> handles,
18392c1ae801Sdonald chen     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
18402c1ae801Sdonald chen   for (BlockArgument handle : handles) {
18415a9bdd85SOleksandr "Alex" Zinenko     effects.emplace_back(MemoryEffects::Allocate::get(), handle,
18425a9bdd85SOleksandr "Alex" Zinenko                          TransformMappingResource::get());
18435a9bdd85SOleksandr "Alex" Zinenko     effects.emplace_back(MemoryEffects::Write::get(), handle,
18445a9bdd85SOleksandr "Alex" Zinenko                          TransformMappingResource::get());
18455a9bdd85SOleksandr "Alex" Zinenko   }
18465a9bdd85SOleksandr "Alex" Zinenko }
18475a9bdd85SOleksandr "Alex" Zinenko 
18485a9bdd85SOleksandr "Alex" Zinenko void transform::onlyReadsHandle(
18492c1ae801Sdonald chen     MutableArrayRef<OpOperand> handles,
18505a9bdd85SOleksandr "Alex" Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
18512c1ae801Sdonald chen   for (OpOperand &handle : handles) {
18522c1ae801Sdonald chen     effects.emplace_back(MemoryEffects::Read::get(), &handle,
18535a9bdd85SOleksandr "Alex" Zinenko                          TransformMappingResource::get());
18545a9bdd85SOleksandr "Alex" Zinenko   }
18555a9bdd85SOleksandr "Alex" Zinenko }
18565a9bdd85SOleksandr "Alex" Zinenko 
18575a9bdd85SOleksandr "Alex" Zinenko void transform::modifiesPayload(
18585a9bdd85SOleksandr "Alex" Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
18595a9bdd85SOleksandr "Alex" Zinenko   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
18605a9bdd85SOleksandr "Alex" Zinenko   effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
18615a9bdd85SOleksandr "Alex" Zinenko }
18625a9bdd85SOleksandr "Alex" Zinenko 
18635a9bdd85SOleksandr "Alex" Zinenko void transform::onlyReadsPayload(
18645a9bdd85SOleksandr "Alex" Zinenko     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
18655a9bdd85SOleksandr "Alex" Zinenko   effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
18665a9bdd85SOleksandr "Alex" Zinenko }
18675a9bdd85SOleksandr "Alex" Zinenko 
18685a9bdd85SOleksandr "Alex" Zinenko bool transform::doesModifyPayload(transform::TransformOpInterface transform) {
18695a9bdd85SOleksandr "Alex" Zinenko   auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
18705a9bdd85SOleksandr "Alex" Zinenko   SmallVector<MemoryEffects::EffectInstance> effects;
18715a9bdd85SOleksandr "Alex" Zinenko   iface.getEffects(effects);
18725a9bdd85SOleksandr "Alex" Zinenko   return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects);
18735a9bdd85SOleksandr "Alex" Zinenko }
18745a9bdd85SOleksandr "Alex" Zinenko 
18755a9bdd85SOleksandr "Alex" Zinenko bool transform::doesReadPayload(transform::TransformOpInterface transform) {
18765a9bdd85SOleksandr "Alex" Zinenko   auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
18775a9bdd85SOleksandr "Alex" Zinenko   SmallVector<MemoryEffects::EffectInstance> effects;
18785a9bdd85SOleksandr "Alex" Zinenko   iface.getEffects(effects);
18795a9bdd85SOleksandr "Alex" Zinenko   return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects);
18805a9bdd85SOleksandr "Alex" Zinenko }
18815a9bdd85SOleksandr "Alex" Zinenko 
18825a9bdd85SOleksandr "Alex" Zinenko void transform::getConsumedBlockArguments(
18835a9bdd85SOleksandr "Alex" Zinenko     Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) {
18845a9bdd85SOleksandr "Alex" Zinenko   SmallVector<MemoryEffects::EffectInstance> effects;
18855a9bdd85SOleksandr "Alex" Zinenko   for (Operation &nested : block) {
18865a9bdd85SOleksandr "Alex" Zinenko     auto iface = dyn_cast<MemoryEffectOpInterface>(nested);
18875a9bdd85SOleksandr "Alex" Zinenko     if (!iface)
18885a9bdd85SOleksandr "Alex" Zinenko       continue;
18895a9bdd85SOleksandr "Alex" Zinenko 
18905a9bdd85SOleksandr "Alex" Zinenko     effects.clear();
18915a9bdd85SOleksandr "Alex" Zinenko     iface.getEffects(effects);
18925a9bdd85SOleksandr "Alex" Zinenko     for (const MemoryEffects::EffectInstance &effect : effects) {
18935a9bdd85SOleksandr "Alex" Zinenko       BlockArgument argument =
18945a9bdd85SOleksandr "Alex" Zinenko           dyn_cast_or_null<BlockArgument>(effect.getValue());
18955a9bdd85SOleksandr "Alex" Zinenko       if (!argument || argument.getOwner() != &block ||
18965a9bdd85SOleksandr "Alex" Zinenko           !isa<MemoryEffects::Free>(effect.getEffect()) ||
18975a9bdd85SOleksandr "Alex" Zinenko           effect.getResource() != transform::TransformMappingResource::get()) {
18985a9bdd85SOleksandr "Alex" Zinenko         continue;
18995a9bdd85SOleksandr "Alex" Zinenko       }
19005a9bdd85SOleksandr "Alex" Zinenko       consumedArguments.insert(argument.getArgNumber());
19015a9bdd85SOleksandr "Alex" Zinenko     }
19025a9bdd85SOleksandr "Alex" Zinenko   }
19035a9bdd85SOleksandr "Alex" Zinenko }
19045a9bdd85SOleksandr "Alex" Zinenko 
19055a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
19065a9bdd85SOleksandr "Alex" Zinenko // Utilities for TransformOpInterface.
19075a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
19085a9bdd85SOleksandr "Alex" Zinenko 
19095a9bdd85SOleksandr "Alex" Zinenko SmallVector<OpOperand *> transform::detail::getConsumedHandleOpOperands(
19105a9bdd85SOleksandr "Alex" Zinenko     TransformOpInterface transformOp) {
19115a9bdd85SOleksandr "Alex" Zinenko   SmallVector<OpOperand *> consumedOperands;
19125a9bdd85SOleksandr "Alex" Zinenko   consumedOperands.reserve(transformOp->getNumOperands());
19135a9bdd85SOleksandr "Alex" Zinenko   auto memEffectInterface =
19145a9bdd85SOleksandr "Alex" Zinenko       cast<MemoryEffectOpInterface>(transformOp.getOperation());
19155a9bdd85SOleksandr "Alex" Zinenko   SmallVector<MemoryEffects::EffectInstance, 2> effects;
19165a9bdd85SOleksandr "Alex" Zinenko   for (OpOperand &target : transformOp->getOpOperands()) {
19175a9bdd85SOleksandr "Alex" Zinenko     effects.clear();
19185a9bdd85SOleksandr "Alex" Zinenko     memEffectInterface.getEffectsOnValue(target.get(), effects);
19195a9bdd85SOleksandr "Alex" Zinenko     if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
19205a9bdd85SOleksandr "Alex" Zinenko           return isa<transform::TransformMappingResource>(
19215a9bdd85SOleksandr "Alex" Zinenko                      effect.getResource()) &&
19225a9bdd85SOleksandr "Alex" Zinenko                  isa<MemoryEffects::Free>(effect.getEffect());
19235a9bdd85SOleksandr "Alex" Zinenko         })) {
19245a9bdd85SOleksandr "Alex" Zinenko       consumedOperands.push_back(&target);
19255a9bdd85SOleksandr "Alex" Zinenko     }
19265a9bdd85SOleksandr "Alex" Zinenko   }
19275a9bdd85SOleksandr "Alex" Zinenko   return consumedOperands;
19285a9bdd85SOleksandr "Alex" Zinenko }
19295a9bdd85SOleksandr "Alex" Zinenko 
19305a9bdd85SOleksandr "Alex" Zinenko LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
19315a9bdd85SOleksandr "Alex" Zinenko   auto iface = cast<MemoryEffectOpInterface>(op);
19325a9bdd85SOleksandr "Alex" Zinenko   SmallVector<MemoryEffects::EffectInstance> effects;
19335a9bdd85SOleksandr "Alex" Zinenko   iface.getEffects(effects);
19345a9bdd85SOleksandr "Alex" Zinenko 
19355a9bdd85SOleksandr "Alex" Zinenko   auto effectsOn = [&](Value value) {
19365a9bdd85SOleksandr "Alex" Zinenko     return llvm::make_filter_range(
19375a9bdd85SOleksandr "Alex" Zinenko         effects, [value](const MemoryEffects::EffectInstance &instance) {
19385a9bdd85SOleksandr "Alex" Zinenko           return instance.getValue() == value;
19395a9bdd85SOleksandr "Alex" Zinenko         });
19405a9bdd85SOleksandr "Alex" Zinenko   };
19415a9bdd85SOleksandr "Alex" Zinenko 
19425a9bdd85SOleksandr "Alex" Zinenko   std::optional<unsigned> firstConsumedOperand;
19435a9bdd85SOleksandr "Alex" Zinenko   for (OpOperand &operand : op->getOpOperands()) {
19445a9bdd85SOleksandr "Alex" Zinenko     auto range = effectsOn(operand.get());
19455a9bdd85SOleksandr "Alex" Zinenko     if (range.empty()) {
19465a9bdd85SOleksandr "Alex" Zinenko       InFlightDiagnostic diag =
19475a9bdd85SOleksandr "Alex" Zinenko           op->emitError() << "TransformOpInterface requires memory effects "
19485a9bdd85SOleksandr "Alex" Zinenko                              "on operands to be specified";
19495a9bdd85SOleksandr "Alex" Zinenko       diag.attachNote() << "no effects specified for operand #"
19505a9bdd85SOleksandr "Alex" Zinenko                         << operand.getOperandNumber();
19515a9bdd85SOleksandr "Alex" Zinenko       return diag;
19525a9bdd85SOleksandr "Alex" Zinenko     }
19535a9bdd85SOleksandr "Alex" Zinenko     if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(range)) {
19545a9bdd85SOleksandr "Alex" Zinenko       InFlightDiagnostic diag = op->emitError()
19555a9bdd85SOleksandr "Alex" Zinenko                                 << "TransformOpInterface did not expect "
19565a9bdd85SOleksandr "Alex" Zinenko                                    "'allocate' memory effect on an operand";
19575a9bdd85SOleksandr "Alex" Zinenko       diag.attachNote() << "specified for operand #"
19585a9bdd85SOleksandr "Alex" Zinenko                         << operand.getOperandNumber();
19595a9bdd85SOleksandr "Alex" Zinenko       return diag;
19605a9bdd85SOleksandr "Alex" Zinenko     }
19615a9bdd85SOleksandr "Alex" Zinenko     if (!firstConsumedOperand &&
19625a9bdd85SOleksandr "Alex" Zinenko         ::hasEffect<MemoryEffects::Free, TransformMappingResource>(range)) {
19635a9bdd85SOleksandr "Alex" Zinenko       firstConsumedOperand = operand.getOperandNumber();
19645a9bdd85SOleksandr "Alex" Zinenko     }
19655a9bdd85SOleksandr "Alex" Zinenko   }
19665a9bdd85SOleksandr "Alex" Zinenko 
19675a9bdd85SOleksandr "Alex" Zinenko   if (firstConsumedOperand &&
19685a9bdd85SOleksandr "Alex" Zinenko       !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) {
19695a9bdd85SOleksandr "Alex" Zinenko     InFlightDiagnostic diag =
19705a9bdd85SOleksandr "Alex" Zinenko         op->emitError()
19715a9bdd85SOleksandr "Alex" Zinenko         << "TransformOpInterface expects ops consuming operands to have a "
19725a9bdd85SOleksandr "Alex" Zinenko            "'write' effect on the payload resource";
19735a9bdd85SOleksandr "Alex" Zinenko     diag.attachNote() << "consumes operand #" << *firstConsumedOperand;
19745a9bdd85SOleksandr "Alex" Zinenko     return diag;
19755a9bdd85SOleksandr "Alex" Zinenko   }
19765a9bdd85SOleksandr "Alex" Zinenko 
19775a9bdd85SOleksandr "Alex" Zinenko   for (OpResult result : op->getResults()) {
19785a9bdd85SOleksandr "Alex" Zinenko     auto range = effectsOn(result);
19795a9bdd85SOleksandr "Alex" Zinenko     if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(
19805a9bdd85SOleksandr "Alex" Zinenko             range)) {
19815a9bdd85SOleksandr "Alex" Zinenko       InFlightDiagnostic diag =
19825a9bdd85SOleksandr "Alex" Zinenko           op->emitError() << "TransformOpInterface requires 'allocate' memory "
19835a9bdd85SOleksandr "Alex" Zinenko                              "effect to be specified for results";
19845a9bdd85SOleksandr "Alex" Zinenko       diag.attachNote() << "no 'allocate' effect specified for result #"
19855a9bdd85SOleksandr "Alex" Zinenko                         << result.getResultNumber();
19865a9bdd85SOleksandr "Alex" Zinenko       return diag;
19875a9bdd85SOleksandr "Alex" Zinenko     }
19885a9bdd85SOleksandr "Alex" Zinenko   }
19895a9bdd85SOleksandr "Alex" Zinenko 
19905a9bdd85SOleksandr "Alex" Zinenko   return success();
19915a9bdd85SOleksandr "Alex" Zinenko }
19925a9bdd85SOleksandr "Alex" Zinenko 
19935a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
19945a9bdd85SOleksandr "Alex" Zinenko // Entry point.
19955a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
19965a9bdd85SOleksandr "Alex" Zinenko 
19975a9bdd85SOleksandr "Alex" Zinenko LogicalResult transform::applyTransforms(
19985a9bdd85SOleksandr "Alex" Zinenko     Operation *payloadRoot, TransformOpInterface transform,
19995a9bdd85SOleksandr "Alex" Zinenko     const RaggedArray<MappedValue> &extraMapping,
20006634d44eSAmy Wang     const TransformOptions &options, bool enforceToplevelTransformOp,
20016634d44eSAmy Wang     function_ref<void(TransformState &)> stateInitializer,
20026634d44eSAmy Wang     function_ref<LogicalResult(TransformState &)> stateExporter) {
20035a9bdd85SOleksandr "Alex" Zinenko   if (enforceToplevelTransformOp) {
20045a9bdd85SOleksandr "Alex" Zinenko     if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
20055a9bdd85SOleksandr "Alex" Zinenko         transform->getNumOperands() != 0) {
20065a9bdd85SOleksandr "Alex" Zinenko       return transform->emitError()
20075a9bdd85SOleksandr "Alex" Zinenko              << "expected transform to start at the top-level transform op";
20085a9bdd85SOleksandr "Alex" Zinenko     }
20095a9bdd85SOleksandr "Alex" Zinenko   } else if (failed(
20105a9bdd85SOleksandr "Alex" Zinenko                  detail::verifyPossibleTopLevelTransformOpTrait(transform))) {
20115a9bdd85SOleksandr "Alex" Zinenko     return failure();
20125a9bdd85SOleksandr "Alex" Zinenko   }
20135a9bdd85SOleksandr "Alex" Zinenko 
20145a9bdd85SOleksandr "Alex" Zinenko   TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
20155a9bdd85SOleksandr "Alex" Zinenko                        options);
20166634d44eSAmy Wang   if (stateInitializer)
20176634d44eSAmy Wang     stateInitializer(state);
20186634d44eSAmy Wang   if (state.applyTransform(transform).checkAndReport().failed())
20196634d44eSAmy Wang     return failure();
20206634d44eSAmy Wang   if (stateExporter)
20216634d44eSAmy Wang     return stateExporter(state);
20226634d44eSAmy Wang   return success();
20235a9bdd85SOleksandr "Alex" Zinenko }
20245a9bdd85SOleksandr "Alex" Zinenko 
20255a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
20265a9bdd85SOleksandr "Alex" Zinenko // Generated interface implementation.
20275a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===//
20285a9bdd85SOleksandr "Alex" Zinenko 
20295a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc"
20305a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc"
2031