15a9bdd85SOleksandr "Alex" Zinenko //===- TransformInterfaces.cpp - Transform Dialect Interfaces -------------===// 25a9bdd85SOleksandr "Alex" Zinenko // 35a9bdd85SOleksandr "Alex" Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 45a9bdd85SOleksandr "Alex" Zinenko // See https://llvm.org/LICENSE.txt for license information. 55a9bdd85SOleksandr "Alex" Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 65a9bdd85SOleksandr "Alex" Zinenko // 75a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 85a9bdd85SOleksandr "Alex" Zinenko 95a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 105a9bdd85SOleksandr "Alex" Zinenko 115a9bdd85SOleksandr "Alex" Zinenko #include "mlir/IR/Diagnostics.h" 125a9bdd85SOleksandr "Alex" Zinenko #include "mlir/IR/Operation.h" 135a9bdd85SOleksandr "Alex" Zinenko #include "mlir/IR/PatternMatch.h" 145a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Interfaces/CastInterfaces.h" 155a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 165a9bdd85SOleksandr "Alex" Zinenko #include "llvm/ADT/STLExtras.h" 175a9bdd85SOleksandr "Alex" Zinenko #include "llvm/ADT/ScopeExit.h" 185a9bdd85SOleksandr "Alex" Zinenko #include "llvm/Support/Debug.h" 195a9bdd85SOleksandr "Alex" Zinenko #include "llvm/Support/ErrorHandling.h" 205a9bdd85SOleksandr "Alex" Zinenko 215a9bdd85SOleksandr "Alex" Zinenko #define DEBUG_TYPE "transform-dialect" 225a9bdd85SOleksandr "Alex" Zinenko #define DEBUG_TYPE_FULL "transform-dialect-full" 235a9bdd85SOleksandr "Alex" Zinenko #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all" 245a9bdd85SOleksandr "Alex" Zinenko #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") 255a9bdd85SOleksandr "Alex" Zinenko #define LDBG(X) LLVM_DEBUG(DBGS() << (X)) 265a9bdd85SOleksandr "Alex" Zinenko #define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X))) 275a9bdd85SOleksandr "Alex" Zinenko 285a9bdd85SOleksandr "Alex" Zinenko using namespace mlir; 295a9bdd85SOleksandr "Alex" Zinenko 305a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 315a9bdd85SOleksandr "Alex" Zinenko // Helper functions 325a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 335a9bdd85SOleksandr "Alex" Zinenko 345a9bdd85SOleksandr "Alex" Zinenko /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors 355a9bdd85SOleksandr "Alex" Zinenko /// properly dominates `b` and `b` is not inside `a`. 365a9bdd85SOleksandr "Alex" Zinenko static bool happensBefore(Operation *a, Operation *b) { 375a9bdd85SOleksandr "Alex" Zinenko do { 385a9bdd85SOleksandr "Alex" Zinenko if (a->isProperAncestor(b)) 395a9bdd85SOleksandr "Alex" Zinenko return false; 405a9bdd85SOleksandr "Alex" Zinenko if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) { 415a9bdd85SOleksandr "Alex" Zinenko return a->isBeforeInBlock(bAncestor); 425a9bdd85SOleksandr "Alex" Zinenko } 435a9bdd85SOleksandr "Alex" Zinenko } while ((a = a->getParentOp())); 445a9bdd85SOleksandr "Alex" Zinenko return false; 455a9bdd85SOleksandr "Alex" Zinenko } 465a9bdd85SOleksandr "Alex" Zinenko 475a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 485a9bdd85SOleksandr "Alex" Zinenko // TransformState 495a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 505a9bdd85SOleksandr "Alex" Zinenko 515a9bdd85SOleksandr "Alex" Zinenko constexpr const Value transform::TransformState::kTopLevelValue; 525a9bdd85SOleksandr "Alex" Zinenko 535a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::TransformState( 545a9bdd85SOleksandr "Alex" Zinenko Region *region, Operation *payloadRoot, 555a9bdd85SOleksandr "Alex" Zinenko const RaggedArray<MappedValue> &extraMappings, 565a9bdd85SOleksandr "Alex" Zinenko const TransformOptions &options) 575a9bdd85SOleksandr "Alex" Zinenko : topLevel(payloadRoot), options(options) { 585a9bdd85SOleksandr "Alex" Zinenko topLevelMappedValues.reserve(extraMappings.size()); 595a9bdd85SOleksandr "Alex" Zinenko for (ArrayRef<MappedValue> mapping : extraMappings) 605a9bdd85SOleksandr "Alex" Zinenko topLevelMappedValues.push_back(mapping); 615a9bdd85SOleksandr "Alex" Zinenko if (region) { 625a9bdd85SOleksandr "Alex" Zinenko RegionScope *scope = new RegionScope(*this, *region); 635a9bdd85SOleksandr "Alex" Zinenko topLevelRegionScope.reset(scope); 645a9bdd85SOleksandr "Alex" Zinenko } 655a9bdd85SOleksandr "Alex" Zinenko } 665a9bdd85SOleksandr "Alex" Zinenko 675a9bdd85SOleksandr "Alex" Zinenko Operation *transform::TransformState::getTopLevel() const { return topLevel; } 685a9bdd85SOleksandr "Alex" Zinenko 695a9bdd85SOleksandr "Alex" Zinenko ArrayRef<Operation *> 705a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::getPayloadOpsView(Value value) const { 715a9bdd85SOleksandr "Alex" Zinenko const TransformOpMapping &operationMapping = getMapping(value).direct; 725a9bdd85SOleksandr "Alex" Zinenko auto iter = operationMapping.find(value); 735a9bdd85SOleksandr "Alex" Zinenko assert(iter != operationMapping.end() && 745a9bdd85SOleksandr "Alex" Zinenko "cannot find mapping for payload handle (param/value handle " 755a9bdd85SOleksandr "Alex" Zinenko "provided?)"); 765a9bdd85SOleksandr "Alex" Zinenko return iter->getSecond(); 775a9bdd85SOleksandr "Alex" Zinenko } 785a9bdd85SOleksandr "Alex" Zinenko 795a9bdd85SOleksandr "Alex" Zinenko ArrayRef<Attribute> transform::TransformState::getParams(Value value) const { 805a9bdd85SOleksandr "Alex" Zinenko const ParamMapping &mapping = getMapping(value).params; 815a9bdd85SOleksandr "Alex" Zinenko auto iter = mapping.find(value); 825a9bdd85SOleksandr "Alex" Zinenko assert(iter != mapping.end() && "cannot find mapping for param handle " 835a9bdd85SOleksandr "Alex" Zinenko "(operation/value handle provided?)"); 845a9bdd85SOleksandr "Alex" Zinenko return iter->getSecond(); 855a9bdd85SOleksandr "Alex" Zinenko } 865a9bdd85SOleksandr "Alex" Zinenko 875a9bdd85SOleksandr "Alex" Zinenko ArrayRef<Value> 885a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::getPayloadValuesView(Value handleValue) const { 895a9bdd85SOleksandr "Alex" Zinenko const ValueMapping &mapping = getMapping(handleValue).values; 905a9bdd85SOleksandr "Alex" Zinenko auto iter = mapping.find(handleValue); 915a9bdd85SOleksandr "Alex" Zinenko assert(iter != mapping.end() && "cannot find mapping for value handle " 925a9bdd85SOleksandr "Alex" Zinenko "(param/operation handle provided?)"); 935a9bdd85SOleksandr "Alex" Zinenko return iter->getSecond(); 945a9bdd85SOleksandr "Alex" Zinenko } 955a9bdd85SOleksandr "Alex" Zinenko 965a9bdd85SOleksandr "Alex" Zinenko LogicalResult transform::TransformState::getHandlesForPayloadOp( 975a9bdd85SOleksandr "Alex" Zinenko Operation *op, SmallVectorImpl<Value> &handles, 985a9bdd85SOleksandr "Alex" Zinenko bool includeOutOfScope) const { 995a9bdd85SOleksandr "Alex" Zinenko bool found = false; 1005a9bdd85SOleksandr "Alex" Zinenko for (const auto &[region, mapping] : llvm::reverse(mappings)) { 1015a9bdd85SOleksandr "Alex" Zinenko auto iterator = mapping->reverse.find(op); 1025a9bdd85SOleksandr "Alex" Zinenko if (iterator != mapping->reverse.end()) { 1035a9bdd85SOleksandr "Alex" Zinenko llvm::append_range(handles, iterator->getSecond()); 1045a9bdd85SOleksandr "Alex" Zinenko found = true; 1055a9bdd85SOleksandr "Alex" Zinenko } 1065a9bdd85SOleksandr "Alex" Zinenko // Stop looking when reaching a region that is isolated from above. 1075a9bdd85SOleksandr "Alex" Zinenko if (!includeOutOfScope && 1085a9bdd85SOleksandr "Alex" Zinenko region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>()) 1095a9bdd85SOleksandr "Alex" Zinenko break; 1105a9bdd85SOleksandr "Alex" Zinenko } 1115a9bdd85SOleksandr "Alex" Zinenko 1125a9bdd85SOleksandr "Alex" Zinenko return success(found); 1135a9bdd85SOleksandr "Alex" Zinenko } 1145a9bdd85SOleksandr "Alex" Zinenko 1155a9bdd85SOleksandr "Alex" Zinenko LogicalResult transform::TransformState::getHandlesForPayloadValue( 1165a9bdd85SOleksandr "Alex" Zinenko Value payloadValue, SmallVectorImpl<Value> &handles, 1175a9bdd85SOleksandr "Alex" Zinenko bool includeOutOfScope) const { 1185a9bdd85SOleksandr "Alex" Zinenko bool found = false; 1195a9bdd85SOleksandr "Alex" Zinenko for (const auto &[region, mapping] : llvm::reverse(mappings)) { 1205a9bdd85SOleksandr "Alex" Zinenko auto iterator = mapping->reverseValues.find(payloadValue); 1215a9bdd85SOleksandr "Alex" Zinenko if (iterator != mapping->reverseValues.end()) { 1225a9bdd85SOleksandr "Alex" Zinenko llvm::append_range(handles, iterator->getSecond()); 1235a9bdd85SOleksandr "Alex" Zinenko found = true; 1245a9bdd85SOleksandr "Alex" Zinenko } 1255a9bdd85SOleksandr "Alex" Zinenko // Stop looking when reaching a region that is isolated from above. 1265a9bdd85SOleksandr "Alex" Zinenko if (!includeOutOfScope && 1275a9bdd85SOleksandr "Alex" Zinenko region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>()) 1285a9bdd85SOleksandr "Alex" Zinenko break; 1295a9bdd85SOleksandr "Alex" Zinenko } 1305a9bdd85SOleksandr "Alex" Zinenko 1315a9bdd85SOleksandr "Alex" Zinenko return success(found); 1325a9bdd85SOleksandr "Alex" Zinenko } 1335a9bdd85SOleksandr "Alex" Zinenko 1345a9bdd85SOleksandr "Alex" Zinenko /// Given a list of MappedValues, cast them to the value kind implied by the 1355a9bdd85SOleksandr "Alex" Zinenko /// interface of the handle type, and dispatch to one of the callbacks. 1365a9bdd85SOleksandr "Alex" Zinenko static DiagnosedSilenceableFailure dispatchMappedValues( 1375a9bdd85SOleksandr "Alex" Zinenko Value handle, ArrayRef<transform::MappedValue> values, 1385a9bdd85SOleksandr "Alex" Zinenko function_ref<LogicalResult(ArrayRef<Operation *>)> operationsFn, 1395a9bdd85SOleksandr "Alex" Zinenko function_ref<LogicalResult(ArrayRef<transform::Param>)> paramsFn, 1405a9bdd85SOleksandr "Alex" Zinenko function_ref<LogicalResult(ValueRange)> valuesFn) { 1415a9bdd85SOleksandr "Alex" Zinenko if (llvm::isa<transform::TransformHandleTypeInterface>(handle.getType())) { 1425a9bdd85SOleksandr "Alex" Zinenko SmallVector<Operation *> operations; 1435a9bdd85SOleksandr "Alex" Zinenko operations.reserve(values.size()); 1445a9bdd85SOleksandr "Alex" Zinenko for (transform::MappedValue value : values) { 1455a9bdd85SOleksandr "Alex" Zinenko if (auto *op = llvm::dyn_cast_if_present<Operation *>(value)) { 1465a9bdd85SOleksandr "Alex" Zinenko operations.push_back(op); 1475a9bdd85SOleksandr "Alex" Zinenko continue; 1485a9bdd85SOleksandr "Alex" Zinenko } 1495a9bdd85SOleksandr "Alex" Zinenko return emitSilenceableFailure(handle.getLoc()) 1505a9bdd85SOleksandr "Alex" Zinenko << "wrong kind of value provided for top-level operation handle"; 1515a9bdd85SOleksandr "Alex" Zinenko } 1525a9bdd85SOleksandr "Alex" Zinenko if (failed(operationsFn(operations))) 1535a9bdd85SOleksandr "Alex" Zinenko return DiagnosedSilenceableFailure::definiteFailure(); 1545a9bdd85SOleksandr "Alex" Zinenko return DiagnosedSilenceableFailure::success(); 1555a9bdd85SOleksandr "Alex" Zinenko } 1565a9bdd85SOleksandr "Alex" Zinenko 1575a9bdd85SOleksandr "Alex" Zinenko if (llvm::isa<transform::TransformValueHandleTypeInterface>( 1585a9bdd85SOleksandr "Alex" Zinenko handle.getType())) { 1595a9bdd85SOleksandr "Alex" Zinenko SmallVector<Value> payloadValues; 1605a9bdd85SOleksandr "Alex" Zinenko payloadValues.reserve(values.size()); 1615a9bdd85SOleksandr "Alex" Zinenko for (transform::MappedValue value : values) { 1625a9bdd85SOleksandr "Alex" Zinenko if (auto v = llvm::dyn_cast_if_present<Value>(value)) { 1635a9bdd85SOleksandr "Alex" Zinenko payloadValues.push_back(v); 1645a9bdd85SOleksandr "Alex" Zinenko continue; 1655a9bdd85SOleksandr "Alex" Zinenko } 1665a9bdd85SOleksandr "Alex" Zinenko return emitSilenceableFailure(handle.getLoc()) 1675a9bdd85SOleksandr "Alex" Zinenko << "wrong kind of value provided for the top-level value handle"; 1685a9bdd85SOleksandr "Alex" Zinenko } 1695a9bdd85SOleksandr "Alex" Zinenko if (failed(valuesFn(payloadValues))) 1705a9bdd85SOleksandr "Alex" Zinenko return DiagnosedSilenceableFailure::definiteFailure(); 1715a9bdd85SOleksandr "Alex" Zinenko return DiagnosedSilenceableFailure::success(); 1725a9bdd85SOleksandr "Alex" Zinenko } 1735a9bdd85SOleksandr "Alex" Zinenko 1745a9bdd85SOleksandr "Alex" Zinenko assert(llvm::isa<transform::TransformParamTypeInterface>(handle.getType()) && 1755a9bdd85SOleksandr "Alex" Zinenko "unsupported kind of block argument"); 1765a9bdd85SOleksandr "Alex" Zinenko SmallVector<transform::Param> parameters; 1775a9bdd85SOleksandr "Alex" Zinenko parameters.reserve(values.size()); 1785a9bdd85SOleksandr "Alex" Zinenko for (transform::MappedValue value : values) { 1795a9bdd85SOleksandr "Alex" Zinenko if (auto attr = llvm::dyn_cast_if_present<Attribute>(value)) { 1805a9bdd85SOleksandr "Alex" Zinenko parameters.push_back(attr); 1815a9bdd85SOleksandr "Alex" Zinenko continue; 1825a9bdd85SOleksandr "Alex" Zinenko } 1835a9bdd85SOleksandr "Alex" Zinenko return emitSilenceableFailure(handle.getLoc()) 1845a9bdd85SOleksandr "Alex" Zinenko << "wrong kind of value provided for top-level parameter"; 1855a9bdd85SOleksandr "Alex" Zinenko } 1865a9bdd85SOleksandr "Alex" Zinenko if (failed(paramsFn(parameters))) 1875a9bdd85SOleksandr "Alex" Zinenko return DiagnosedSilenceableFailure::definiteFailure(); 1885a9bdd85SOleksandr "Alex" Zinenko return DiagnosedSilenceableFailure::success(); 1895a9bdd85SOleksandr "Alex" Zinenko } 1905a9bdd85SOleksandr "Alex" Zinenko 1915a9bdd85SOleksandr "Alex" Zinenko LogicalResult 1925a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::mapBlockArgument(BlockArgument argument, 1935a9bdd85SOleksandr "Alex" Zinenko ArrayRef<MappedValue> values) { 1945a9bdd85SOleksandr "Alex" Zinenko return dispatchMappedValues( 1955a9bdd85SOleksandr "Alex" Zinenko argument, values, 1965a9bdd85SOleksandr "Alex" Zinenko [&](ArrayRef<Operation *> operations) { 1975a9bdd85SOleksandr "Alex" Zinenko return setPayloadOps(argument, operations); 1985a9bdd85SOleksandr "Alex" Zinenko }, 1995a9bdd85SOleksandr "Alex" Zinenko [&](ArrayRef<Param> params) { 2005a9bdd85SOleksandr "Alex" Zinenko return setParams(argument, params); 2015a9bdd85SOleksandr "Alex" Zinenko }, 2025a9bdd85SOleksandr "Alex" Zinenko [&](ValueRange payloadValues) { 2035a9bdd85SOleksandr "Alex" Zinenko return setPayloadValues(argument, payloadValues); 2045a9bdd85SOleksandr "Alex" Zinenko }) 2055a9bdd85SOleksandr "Alex" Zinenko .checkAndReport(); 2065a9bdd85SOleksandr "Alex" Zinenko } 2075a9bdd85SOleksandr "Alex" Zinenko 208e4b04b39SOleksandr "Alex" Zinenko LogicalResult transform::TransformState::mapBlockArguments( 209e4b04b39SOleksandr "Alex" Zinenko Block::BlockArgListType arguments, 210e4b04b39SOleksandr "Alex" Zinenko ArrayRef<SmallVector<MappedValue>> mapping) { 211e4b04b39SOleksandr "Alex" Zinenko for (auto &&[argument, values] : llvm::zip_equal(arguments, mapping)) 212e4b04b39SOleksandr "Alex" Zinenko if (failed(mapBlockArgument(argument, values))) 213e4b04b39SOleksandr "Alex" Zinenko return failure(); 214e4b04b39SOleksandr "Alex" Zinenko return success(); 215e4b04b39SOleksandr "Alex" Zinenko } 216e4b04b39SOleksandr "Alex" Zinenko 2175a9bdd85SOleksandr "Alex" Zinenko LogicalResult 2185a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::setPayloadOps(Value value, 2195a9bdd85SOleksandr "Alex" Zinenko ArrayRef<Operation *> targets) { 2205a9bdd85SOleksandr "Alex" Zinenko assert(value != kTopLevelValue && 2215a9bdd85SOleksandr "Alex" Zinenko "attempting to reset the transformation root"); 2225a9bdd85SOleksandr "Alex" Zinenko assert(llvm::isa<TransformHandleTypeInterface>(value.getType()) && 2235a9bdd85SOleksandr "Alex" Zinenko "wrong handle type"); 2245a9bdd85SOleksandr "Alex" Zinenko 2255a9bdd85SOleksandr "Alex" Zinenko for (Operation *target : targets) { 2265a9bdd85SOleksandr "Alex" Zinenko if (target) 2275a9bdd85SOleksandr "Alex" Zinenko continue; 2285a9bdd85SOleksandr "Alex" Zinenko return emitError(value.getLoc()) 2295a9bdd85SOleksandr "Alex" Zinenko << "attempting to assign a null payload op to this transform value"; 2305a9bdd85SOleksandr "Alex" Zinenko } 2315a9bdd85SOleksandr "Alex" Zinenko 2325a9bdd85SOleksandr "Alex" Zinenko auto iface = llvm::cast<TransformHandleTypeInterface>(value.getType()); 2335a9bdd85SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure result = 2345a9bdd85SOleksandr "Alex" Zinenko iface.checkPayload(value.getLoc(), targets); 2355a9bdd85SOleksandr "Alex" Zinenko if (failed(result.checkAndReport())) 2365a9bdd85SOleksandr "Alex" Zinenko return failure(); 2375a9bdd85SOleksandr "Alex" Zinenko 2385a9bdd85SOleksandr "Alex" Zinenko // Setting new payload for the value without cleaning it first is a misuse of 2395a9bdd85SOleksandr "Alex" Zinenko // the API, assert here. 2405262865aSKazu Hirata SmallVector<Operation *> storedTargets(targets); 2415a9bdd85SOleksandr "Alex" Zinenko Mappings &mappings = getMapping(value); 2425a9bdd85SOleksandr "Alex" Zinenko bool inserted = 2435a9bdd85SOleksandr "Alex" Zinenko mappings.direct.insert({value, std::move(storedTargets)}).second; 2445a9bdd85SOleksandr "Alex" Zinenko assert(inserted && "value is already associated with another list"); 2455a9bdd85SOleksandr "Alex" Zinenko (void)inserted; 2465a9bdd85SOleksandr "Alex" Zinenko 2475a9bdd85SOleksandr "Alex" Zinenko for (Operation *op : targets) 2485a9bdd85SOleksandr "Alex" Zinenko mappings.reverse[op].push_back(value); 2495a9bdd85SOleksandr "Alex" Zinenko 2505a9bdd85SOleksandr "Alex" Zinenko return success(); 2515a9bdd85SOleksandr "Alex" Zinenko } 2525a9bdd85SOleksandr "Alex" Zinenko 2535a9bdd85SOleksandr "Alex" Zinenko LogicalResult 2545a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::setPayloadValues(Value handle, 2555a9bdd85SOleksandr "Alex" Zinenko ValueRange payloadValues) { 2565a9bdd85SOleksandr "Alex" Zinenko assert(handle != nullptr && "attempting to set params for a null value"); 2575a9bdd85SOleksandr "Alex" Zinenko assert(llvm::isa<TransformValueHandleTypeInterface>(handle.getType()) && 2585a9bdd85SOleksandr "Alex" Zinenko "wrong handle type"); 2595a9bdd85SOleksandr "Alex" Zinenko 2605a9bdd85SOleksandr "Alex" Zinenko for (Value payload : payloadValues) { 2615a9bdd85SOleksandr "Alex" Zinenko if (payload) 2625a9bdd85SOleksandr "Alex" Zinenko continue; 2635a9bdd85SOleksandr "Alex" Zinenko return emitError(handle.getLoc()) << "attempting to assign a null payload " 2645a9bdd85SOleksandr "Alex" Zinenko "value to this transform handle"; 2655a9bdd85SOleksandr "Alex" Zinenko } 2665a9bdd85SOleksandr "Alex" Zinenko 2675a9bdd85SOleksandr "Alex" Zinenko auto iface = llvm::cast<TransformValueHandleTypeInterface>(handle.getType()); 2685a9bdd85SOleksandr "Alex" Zinenko SmallVector<Value> payloadValueVector = llvm::to_vector(payloadValues); 2695a9bdd85SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure result = 2705a9bdd85SOleksandr "Alex" Zinenko iface.checkPayload(handle.getLoc(), payloadValueVector); 2715a9bdd85SOleksandr "Alex" Zinenko if (failed(result.checkAndReport())) 2725a9bdd85SOleksandr "Alex" Zinenko return failure(); 2735a9bdd85SOleksandr "Alex" Zinenko 2745a9bdd85SOleksandr "Alex" Zinenko Mappings &mappings = getMapping(handle); 2755a9bdd85SOleksandr "Alex" Zinenko bool inserted = 2765a9bdd85SOleksandr "Alex" Zinenko mappings.values.insert({handle, std::move(payloadValueVector)}).second; 2775a9bdd85SOleksandr "Alex" Zinenko assert( 2785a9bdd85SOleksandr "Alex" Zinenko inserted && 2795a9bdd85SOleksandr "Alex" Zinenko "value handle is already associated with another list of payload values"); 2805a9bdd85SOleksandr "Alex" Zinenko (void)inserted; 2815a9bdd85SOleksandr "Alex" Zinenko 2825a9bdd85SOleksandr "Alex" Zinenko for (Value payload : payloadValues) 2835a9bdd85SOleksandr "Alex" Zinenko mappings.reverseValues[payload].push_back(handle); 2845a9bdd85SOleksandr "Alex" Zinenko 2855a9bdd85SOleksandr "Alex" Zinenko return success(); 2865a9bdd85SOleksandr "Alex" Zinenko } 2875a9bdd85SOleksandr "Alex" Zinenko 2885a9bdd85SOleksandr "Alex" Zinenko LogicalResult transform::TransformState::setParams(Value value, 2895a9bdd85SOleksandr "Alex" Zinenko ArrayRef<Param> params) { 2905a9bdd85SOleksandr "Alex" Zinenko assert(value != nullptr && "attempting to set params for a null value"); 2915a9bdd85SOleksandr "Alex" Zinenko 2925a9bdd85SOleksandr "Alex" Zinenko for (Attribute attr : params) { 2935a9bdd85SOleksandr "Alex" Zinenko if (attr) 2945a9bdd85SOleksandr "Alex" Zinenko continue; 2955a9bdd85SOleksandr "Alex" Zinenko return emitError(value.getLoc()) 2965a9bdd85SOleksandr "Alex" Zinenko << "attempting to assign a null parameter to this transform value"; 2975a9bdd85SOleksandr "Alex" Zinenko } 2985a9bdd85SOleksandr "Alex" Zinenko 2995a9bdd85SOleksandr "Alex" Zinenko auto valueType = llvm::dyn_cast<TransformParamTypeInterface>(value.getType()); 3005a9bdd85SOleksandr "Alex" Zinenko assert(value && 3015a9bdd85SOleksandr "Alex" Zinenko "cannot associate parameter with a value of non-parameter type"); 3025a9bdd85SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure result = 3035a9bdd85SOleksandr "Alex" Zinenko valueType.checkPayload(value.getLoc(), params); 3045a9bdd85SOleksandr "Alex" Zinenko if (failed(result.checkAndReport())) 3055a9bdd85SOleksandr "Alex" Zinenko return failure(); 3065a9bdd85SOleksandr "Alex" Zinenko 3075a9bdd85SOleksandr "Alex" Zinenko Mappings &mappings = getMapping(value); 3085a9bdd85SOleksandr "Alex" Zinenko bool inserted = 3095a9bdd85SOleksandr "Alex" Zinenko mappings.params.insert({value, llvm::to_vector(params)}).second; 3105a9bdd85SOleksandr "Alex" Zinenko assert(inserted && "value is already associated with another list of params"); 3115a9bdd85SOleksandr "Alex" Zinenko (void)inserted; 3125a9bdd85SOleksandr "Alex" Zinenko return success(); 3135a9bdd85SOleksandr "Alex" Zinenko } 3145a9bdd85SOleksandr "Alex" Zinenko 3155a9bdd85SOleksandr "Alex" Zinenko template <typename Mapping, typename Key, typename Mapped> 3165a9bdd85SOleksandr "Alex" Zinenko void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) { 3175a9bdd85SOleksandr "Alex" Zinenko auto it = mapping.find(key); 3185a9bdd85SOleksandr "Alex" Zinenko if (it == mapping.end()) 3195a9bdd85SOleksandr "Alex" Zinenko return; 3205a9bdd85SOleksandr "Alex" Zinenko 3215a9bdd85SOleksandr "Alex" Zinenko llvm::erase(it->getSecond(), mapped); 3225a9bdd85SOleksandr "Alex" Zinenko if (it->getSecond().empty()) 3235a9bdd85SOleksandr "Alex" Zinenko mapping.erase(it); 3245a9bdd85SOleksandr "Alex" Zinenko } 3255a9bdd85SOleksandr "Alex" Zinenko 3265a9bdd85SOleksandr "Alex" Zinenko void transform::TransformState::forgetMapping(Value opHandle, 3275a9bdd85SOleksandr "Alex" Zinenko ValueRange origOpFlatResults, 3285a9bdd85SOleksandr "Alex" Zinenko bool allowOutOfScope) { 3295a9bdd85SOleksandr "Alex" Zinenko Mappings &mappings = getMapping(opHandle, allowOutOfScope); 3305a9bdd85SOleksandr "Alex" Zinenko for (Operation *op : mappings.direct[opHandle]) 3315a9bdd85SOleksandr "Alex" Zinenko dropMappingEntry(mappings.reverse, op, opHandle); 3325a9bdd85SOleksandr "Alex" Zinenko mappings.direct.erase(opHandle); 3336c7a3f80SMehdi Amini #if LLVM_ENABLE_ABI_BREAKING_CHECKS 3345a9bdd85SOleksandr "Alex" Zinenko // Payload IR is removed from the mapping. This invalidates the respective 3355a9bdd85SOleksandr "Alex" Zinenko // iterators. 3365a9bdd85SOleksandr "Alex" Zinenko mappings.incrementTimestamp(opHandle); 3375a9bdd85SOleksandr "Alex" Zinenko #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 3385a9bdd85SOleksandr "Alex" Zinenko 3395a9bdd85SOleksandr "Alex" Zinenko for (Value opResult : origOpFlatResults) { 3405a9bdd85SOleksandr "Alex" Zinenko SmallVector<Value> resultHandles; 3415a9bdd85SOleksandr "Alex" Zinenko (void)getHandlesForPayloadValue(opResult, resultHandles); 3425a9bdd85SOleksandr "Alex" Zinenko for (Value resultHandle : resultHandles) { 3435a9bdd85SOleksandr "Alex" Zinenko Mappings &localMappings = getMapping(resultHandle); 3445a9bdd85SOleksandr "Alex" Zinenko dropMappingEntry(localMappings.values, resultHandle, opResult); 3456c7a3f80SMehdi Amini #if LLVM_ENABLE_ABI_BREAKING_CHECKS 3465a9bdd85SOleksandr "Alex" Zinenko // Payload IR is removed from the mapping. This invalidates the respective 3475a9bdd85SOleksandr "Alex" Zinenko // iterators. 3485a9bdd85SOleksandr "Alex" Zinenko mappings.incrementTimestamp(resultHandle); 3495a9bdd85SOleksandr "Alex" Zinenko #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 3505a9bdd85SOleksandr "Alex" Zinenko dropMappingEntry(localMappings.reverseValues, opResult, resultHandle); 3515a9bdd85SOleksandr "Alex" Zinenko } 3525a9bdd85SOleksandr "Alex" Zinenko } 3535a9bdd85SOleksandr "Alex" Zinenko } 3545a9bdd85SOleksandr "Alex" Zinenko 3555a9bdd85SOleksandr "Alex" Zinenko void transform::TransformState::forgetValueMapping( 3565a9bdd85SOleksandr "Alex" Zinenko Value valueHandle, ArrayRef<Operation *> payloadOperations) { 3575a9bdd85SOleksandr "Alex" Zinenko Mappings &mappings = getMapping(valueHandle); 3585a9bdd85SOleksandr "Alex" Zinenko for (Value payloadValue : mappings.reverseValues[valueHandle]) 3595a9bdd85SOleksandr "Alex" Zinenko dropMappingEntry(mappings.reverseValues, payloadValue, valueHandle); 3605a9bdd85SOleksandr "Alex" Zinenko mappings.values.erase(valueHandle); 3616c7a3f80SMehdi Amini #if LLVM_ENABLE_ABI_BREAKING_CHECKS 3625a9bdd85SOleksandr "Alex" Zinenko // Payload IR is removed from the mapping. This invalidates the respective 3635a9bdd85SOleksandr "Alex" Zinenko // iterators. 3645a9bdd85SOleksandr "Alex" Zinenko mappings.incrementTimestamp(valueHandle); 3655a9bdd85SOleksandr "Alex" Zinenko #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 3665a9bdd85SOleksandr "Alex" Zinenko 3675a9bdd85SOleksandr "Alex" Zinenko for (Operation *payloadOp : payloadOperations) { 3685a9bdd85SOleksandr "Alex" Zinenko SmallVector<Value> opHandles; 3695a9bdd85SOleksandr "Alex" Zinenko (void)getHandlesForPayloadOp(payloadOp, opHandles); 3705a9bdd85SOleksandr "Alex" Zinenko for (Value opHandle : opHandles) { 3715a9bdd85SOleksandr "Alex" Zinenko Mappings &localMappings = getMapping(opHandle); 3725a9bdd85SOleksandr "Alex" Zinenko dropMappingEntry(localMappings.direct, opHandle, payloadOp); 3735a9bdd85SOleksandr "Alex" Zinenko dropMappingEntry(localMappings.reverse, payloadOp, opHandle); 3745a9bdd85SOleksandr "Alex" Zinenko 3756c7a3f80SMehdi Amini #if LLVM_ENABLE_ABI_BREAKING_CHECKS 3765a9bdd85SOleksandr "Alex" Zinenko // Payload IR is removed from the mapping. This invalidates the respective 3775a9bdd85SOleksandr "Alex" Zinenko // iterators. 3785a9bdd85SOleksandr "Alex" Zinenko localMappings.incrementTimestamp(opHandle); 3795a9bdd85SOleksandr "Alex" Zinenko #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 3805a9bdd85SOleksandr "Alex" Zinenko } 3815a9bdd85SOleksandr "Alex" Zinenko } 3825a9bdd85SOleksandr "Alex" Zinenko } 3835a9bdd85SOleksandr "Alex" Zinenko 3845a9bdd85SOleksandr "Alex" Zinenko LogicalResult 3855a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::replacePayloadOp(Operation *op, 3865a9bdd85SOleksandr "Alex" Zinenko Operation *replacement) { 3875a9bdd85SOleksandr "Alex" Zinenko // TODO: consider invalidating the handles to nested objects here. 3885a9bdd85SOleksandr "Alex" Zinenko 3895a9bdd85SOleksandr "Alex" Zinenko #ifndef NDEBUG 3905a9bdd85SOleksandr "Alex" Zinenko for (Value opResult : op->getResults()) { 3915a9bdd85SOleksandr "Alex" Zinenko SmallVector<Value> valueHandles; 3925a9bdd85SOleksandr "Alex" Zinenko (void)getHandlesForPayloadValue(opResult, valueHandles, 3935a9bdd85SOleksandr "Alex" Zinenko /*includeOutOfScope=*/true); 3945a9bdd85SOleksandr "Alex" Zinenko assert(valueHandles.empty() && "expected no mapping to old results"); 3955a9bdd85SOleksandr "Alex" Zinenko } 3965a9bdd85SOleksandr "Alex" Zinenko #endif // NDEBUG 3975a9bdd85SOleksandr "Alex" Zinenko 3985a9bdd85SOleksandr "Alex" Zinenko // Drop the mapping between the op and all handles that point to it. Fail if 3995a9bdd85SOleksandr "Alex" Zinenko // there are no handles. 4005a9bdd85SOleksandr "Alex" Zinenko SmallVector<Value> opHandles; 4015a9bdd85SOleksandr "Alex" Zinenko if (failed(getHandlesForPayloadOp(op, opHandles, /*includeOutOfScope=*/true))) 4025a9bdd85SOleksandr "Alex" Zinenko return failure(); 4035a9bdd85SOleksandr "Alex" Zinenko for (Value handle : opHandles) { 4045a9bdd85SOleksandr "Alex" Zinenko Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true); 4055a9bdd85SOleksandr "Alex" Zinenko dropMappingEntry(mappings.reverse, op, handle); 4065a9bdd85SOleksandr "Alex" Zinenko } 4075a9bdd85SOleksandr "Alex" Zinenko 4085a9bdd85SOleksandr "Alex" Zinenko // Replace the pointed-to object of all handles with the replacement object. 4095a9bdd85SOleksandr "Alex" Zinenko // In case a payload op was erased (replacement object is nullptr), a nullptr 4105a9bdd85SOleksandr "Alex" Zinenko // is stored in the mapping. These nullptrs are removed after each transform. 4115a9bdd85SOleksandr "Alex" Zinenko // Furthermore, nullptrs are not enumerated by payload op iterators. The 4125a9bdd85SOleksandr "Alex" Zinenko // relative order of ops is preserved. 4135a9bdd85SOleksandr "Alex" Zinenko // 4145a9bdd85SOleksandr "Alex" Zinenko // Removing an op from the mapping would be problematic because removing an 4155a9bdd85SOleksandr "Alex" Zinenko // element from an array invalidates iterators; merely changing the value of 4165a9bdd85SOleksandr "Alex" Zinenko // elements does not. 4175a9bdd85SOleksandr "Alex" Zinenko for (Value handle : opHandles) { 4185a9bdd85SOleksandr "Alex" Zinenko Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true); 4195a9bdd85SOleksandr "Alex" Zinenko auto it = mappings.direct.find(handle); 4205a9bdd85SOleksandr "Alex" Zinenko if (it == mappings.direct.end()) 4215a9bdd85SOleksandr "Alex" Zinenko continue; 4225a9bdd85SOleksandr "Alex" Zinenko 4235a9bdd85SOleksandr "Alex" Zinenko SmallVector<Operation *, 2> &association = it->getSecond(); 4245a9bdd85SOleksandr "Alex" Zinenko // Note that an operation may be associated with the handle more than once. 4255a9bdd85SOleksandr "Alex" Zinenko for (Operation *&mapped : association) { 4265a9bdd85SOleksandr "Alex" Zinenko if (mapped == op) 4275a9bdd85SOleksandr "Alex" Zinenko mapped = replacement; 4285a9bdd85SOleksandr "Alex" Zinenko } 4295a9bdd85SOleksandr "Alex" Zinenko 4305a9bdd85SOleksandr "Alex" Zinenko if (replacement) { 4315a9bdd85SOleksandr "Alex" Zinenko mappings.reverse[replacement].push_back(handle); 4325a9bdd85SOleksandr "Alex" Zinenko } else { 4335a9bdd85SOleksandr "Alex" Zinenko opHandlesToCompact.insert(handle); 4345a9bdd85SOleksandr "Alex" Zinenko } 4355a9bdd85SOleksandr "Alex" Zinenko } 4365a9bdd85SOleksandr "Alex" Zinenko 4375a9bdd85SOleksandr "Alex" Zinenko return success(); 4385a9bdd85SOleksandr "Alex" Zinenko } 4395a9bdd85SOleksandr "Alex" Zinenko 4405a9bdd85SOleksandr "Alex" Zinenko LogicalResult 4415a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::replacePayloadValue(Value value, Value replacement) { 4425a9bdd85SOleksandr "Alex" Zinenko SmallVector<Value> valueHandles; 4435a9bdd85SOleksandr "Alex" Zinenko if (failed(getHandlesForPayloadValue(value, valueHandles, 4445a9bdd85SOleksandr "Alex" Zinenko /*includeOutOfScope=*/true))) 4455a9bdd85SOleksandr "Alex" Zinenko return failure(); 4465a9bdd85SOleksandr "Alex" Zinenko 4475a9bdd85SOleksandr "Alex" Zinenko for (Value handle : valueHandles) { 4485a9bdd85SOleksandr "Alex" Zinenko Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true); 4495a9bdd85SOleksandr "Alex" Zinenko dropMappingEntry(mappings.reverseValues, value, handle); 4505a9bdd85SOleksandr "Alex" Zinenko 4515a9bdd85SOleksandr "Alex" Zinenko // If replacing with null, that is erasing the mapping, drop the mapping 4525a9bdd85SOleksandr "Alex" Zinenko // between the handles and the IR objects 4535a9bdd85SOleksandr "Alex" Zinenko if (!replacement) { 4545a9bdd85SOleksandr "Alex" Zinenko dropMappingEntry(mappings.values, handle, value); 4556c7a3f80SMehdi Amini #if LLVM_ENABLE_ABI_BREAKING_CHECKS 4565a9bdd85SOleksandr "Alex" Zinenko // Payload IR is removed from the mapping. This invalidates the respective 4575a9bdd85SOleksandr "Alex" Zinenko // iterators. 4585a9bdd85SOleksandr "Alex" Zinenko mappings.incrementTimestamp(handle); 4595a9bdd85SOleksandr "Alex" Zinenko #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 4605a9bdd85SOleksandr "Alex" Zinenko } else { 4615a9bdd85SOleksandr "Alex" Zinenko auto it = mappings.values.find(handle); 4625a9bdd85SOleksandr "Alex" Zinenko if (it == mappings.values.end()) 4635a9bdd85SOleksandr "Alex" Zinenko continue; 4645a9bdd85SOleksandr "Alex" Zinenko 4655a9bdd85SOleksandr "Alex" Zinenko SmallVector<Value> &association = it->getSecond(); 4665a9bdd85SOleksandr "Alex" Zinenko for (Value &mapped : association) { 4675a9bdd85SOleksandr "Alex" Zinenko if (mapped == value) 4685a9bdd85SOleksandr "Alex" Zinenko mapped = replacement; 4695a9bdd85SOleksandr "Alex" Zinenko } 4705a9bdd85SOleksandr "Alex" Zinenko mappings.reverseValues[replacement].push_back(handle); 4715a9bdd85SOleksandr "Alex" Zinenko } 4725a9bdd85SOleksandr "Alex" Zinenko } 4735a9bdd85SOleksandr "Alex" Zinenko 4745a9bdd85SOleksandr "Alex" Zinenko return success(); 4755a9bdd85SOleksandr "Alex" Zinenko } 4765a9bdd85SOleksandr "Alex" Zinenko 4775a9bdd85SOleksandr "Alex" Zinenko void transform::TransformState::recordOpHandleInvalidationOne( 4785a9bdd85SOleksandr "Alex" Zinenko OpOperand &consumingHandle, ArrayRef<Operation *> potentialAncestors, 4795a9bdd85SOleksandr "Alex" Zinenko Operation *payloadOp, Value otherHandle, Value throughValue, 4805a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { 4815a9bdd85SOleksandr "Alex" Zinenko // If the op is associated with invalidated handle, skip the check as it 4825a9bdd85SOleksandr "Alex" Zinenko // may be reading invalid IR. This also ensures we report the first 4835a9bdd85SOleksandr "Alex" Zinenko // invalidation and not the last one. 4845a9bdd85SOleksandr "Alex" Zinenko if (invalidatedHandles.count(otherHandle) || 4855a9bdd85SOleksandr "Alex" Zinenko newlyInvalidated.count(otherHandle)) 4865a9bdd85SOleksandr "Alex" Zinenko return; 4875a9bdd85SOleksandr "Alex" Zinenko 4885a9bdd85SOleksandr "Alex" Zinenko FULL_LDBG("--recordOpHandleInvalidationOne\n"); 4895a9bdd85SOleksandr "Alex" Zinenko DEBUG_WITH_TYPE( 4905a9bdd85SOleksandr "Alex" Zinenko DEBUG_TYPE_FULL, 4915a9bdd85SOleksandr "Alex" Zinenko llvm::interleaveComma(potentialAncestors, DBGS() << "--ancestors: ", 4925a9bdd85SOleksandr "Alex" Zinenko [](Operation *op) { llvm::dbgs() << *op; }); 4935a9bdd85SOleksandr "Alex" Zinenko llvm::dbgs() << "\n"); 4945a9bdd85SOleksandr "Alex" Zinenko 4955a9bdd85SOleksandr "Alex" Zinenko Operation *owner = consumingHandle.getOwner(); 4965a9bdd85SOleksandr "Alex" Zinenko unsigned operandNo = consumingHandle.getOperandNumber(); 4975a9bdd85SOleksandr "Alex" Zinenko for (Operation *ancestor : potentialAncestors) { 4985a9bdd85SOleksandr "Alex" Zinenko // clang-format off 4995a9bdd85SOleksandr "Alex" Zinenko DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, 5005a9bdd85SOleksandr "Alex" Zinenko { (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); }); 5015a9bdd85SOleksandr "Alex" Zinenko DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, 5025a9bdd85SOleksandr "Alex" Zinenko { (DBGS() << "----of payload with name: " 5035a9bdd85SOleksandr "Alex" Zinenko << payloadOp->getName().getIdentifier() << "\n"); }); 5045a9bdd85SOleksandr "Alex" Zinenko DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, 5055a9bdd85SOleksandr "Alex" Zinenko { (DBGS() << "----of payload: " << *payloadOp << "\n"); }); 5065a9bdd85SOleksandr "Alex" Zinenko // clang-format on 5075a9bdd85SOleksandr "Alex" Zinenko if (!ancestor->isAncestor(payloadOp)) 5085a9bdd85SOleksandr "Alex" Zinenko continue; 5095a9bdd85SOleksandr "Alex" Zinenko 5105a9bdd85SOleksandr "Alex" Zinenko // Make sure the error-reporting lambda doesn't capture anything 5115a9bdd85SOleksandr "Alex" Zinenko // by-reference because it will go out of scope. Additionally, extract 5125a9bdd85SOleksandr "Alex" Zinenko // location from Payload IR ops because the ops themselves may be 5135a9bdd85SOleksandr "Alex" Zinenko // deleted before the lambda gets called. 5145a9bdd85SOleksandr "Alex" Zinenko Location ancestorLoc = ancestor->getLoc(); 5155a9bdd85SOleksandr "Alex" Zinenko Location opLoc = payloadOp->getLoc(); 5165a9bdd85SOleksandr "Alex" Zinenko std::optional<Location> throughValueLoc = 5175a9bdd85SOleksandr "Alex" Zinenko throughValue ? std::make_optional(throughValue.getLoc()) : std::nullopt; 5185a9bdd85SOleksandr "Alex" Zinenko newlyInvalidated[otherHandle] = [ancestorLoc, opLoc, owner, operandNo, 5195a9bdd85SOleksandr "Alex" Zinenko otherHandle, 5205a9bdd85SOleksandr "Alex" Zinenko throughValueLoc](Location currentLoc) { 5215a9bdd85SOleksandr "Alex" Zinenko InFlightDiagnostic diag = emitError(currentLoc) 5225a9bdd85SOleksandr "Alex" Zinenko << "op uses a handle invalidated by a " 5235a9bdd85SOleksandr "Alex" Zinenko "previously executed transform op"; 5245a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(otherHandle.getLoc()) << "handle to invalidated ops"; 5255a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(owner->getLoc()) 5265a9bdd85SOleksandr "Alex" Zinenko << "invalidated by this transform op that consumes its operand #" 5275a9bdd85SOleksandr "Alex" Zinenko << operandNo 5285a9bdd85SOleksandr "Alex" Zinenko << " and invalidates all handles to payload IR entities associated " 5295a9bdd85SOleksandr "Alex" Zinenko "with this operand and entities nested in them"; 5305a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(ancestorLoc) << "ancestor payload op"; 5315a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(opLoc) << "nested payload op"; 5325a9bdd85SOleksandr "Alex" Zinenko if (throughValueLoc) { 5335a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(*throughValueLoc) 5345a9bdd85SOleksandr "Alex" Zinenko << "consumed handle points to this payload value"; 5355a9bdd85SOleksandr "Alex" Zinenko } 5365a9bdd85SOleksandr "Alex" Zinenko }; 5375a9bdd85SOleksandr "Alex" Zinenko } 5385a9bdd85SOleksandr "Alex" Zinenko } 5395a9bdd85SOleksandr "Alex" Zinenko 5405a9bdd85SOleksandr "Alex" Zinenko void transform::TransformState::recordValueHandleInvalidationByOpHandleOne( 5415a9bdd85SOleksandr "Alex" Zinenko OpOperand &opHandle, ArrayRef<Operation *> potentialAncestors, 5425a9bdd85SOleksandr "Alex" Zinenko Value payloadValue, Value valueHandle, 5435a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { 5445a9bdd85SOleksandr "Alex" Zinenko // If the op is associated with invalidated handle, skip the check as it 5455a9bdd85SOleksandr "Alex" Zinenko // may be reading invalid IR. This also ensures we report the first 5465a9bdd85SOleksandr "Alex" Zinenko // invalidation and not the last one. 5475a9bdd85SOleksandr "Alex" Zinenko if (invalidatedHandles.count(valueHandle) || 5485a9bdd85SOleksandr "Alex" Zinenko newlyInvalidated.count(valueHandle)) 5495a9bdd85SOleksandr "Alex" Zinenko return; 5505a9bdd85SOleksandr "Alex" Zinenko 5515a9bdd85SOleksandr "Alex" Zinenko for (Operation *ancestor : potentialAncestors) { 5525a9bdd85SOleksandr "Alex" Zinenko Operation *definingOp; 5535a9bdd85SOleksandr "Alex" Zinenko std::optional<unsigned> resultNo; 5545a9bdd85SOleksandr "Alex" Zinenko unsigned argumentNo = std::numeric_limits<unsigned>::max(); 5555a9bdd85SOleksandr "Alex" Zinenko unsigned blockNo = std::numeric_limits<unsigned>::max(); 5565a9bdd85SOleksandr "Alex" Zinenko unsigned regionNo = std::numeric_limits<unsigned>::max(); 5575a9bdd85SOleksandr "Alex" Zinenko if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) { 5585a9bdd85SOleksandr "Alex" Zinenko definingOp = opResult.getOwner(); 5595a9bdd85SOleksandr "Alex" Zinenko resultNo = opResult.getResultNumber(); 5605a9bdd85SOleksandr "Alex" Zinenko } else { 5615a9bdd85SOleksandr "Alex" Zinenko auto arg = llvm::cast<BlockArgument>(payloadValue); 5625a9bdd85SOleksandr "Alex" Zinenko definingOp = arg.getParentBlock()->getParentOp(); 5635a9bdd85SOleksandr "Alex" Zinenko argumentNo = arg.getArgNumber(); 5645a9bdd85SOleksandr "Alex" Zinenko blockNo = std::distance(arg.getOwner()->getParent()->begin(), 5655a9bdd85SOleksandr "Alex" Zinenko arg.getOwner()->getIterator()); 5665a9bdd85SOleksandr "Alex" Zinenko regionNo = arg.getOwner()->getParent()->getRegionNumber(); 5675a9bdd85SOleksandr "Alex" Zinenko } 5685a9bdd85SOleksandr "Alex" Zinenko assert(definingOp && "expected the value to be defined by an op as result " 5695a9bdd85SOleksandr "Alex" Zinenko "or block argument"); 5705a9bdd85SOleksandr "Alex" Zinenko if (!ancestor->isAncestor(definingOp)) 5715a9bdd85SOleksandr "Alex" Zinenko continue; 5725a9bdd85SOleksandr "Alex" Zinenko 5735a9bdd85SOleksandr "Alex" Zinenko Operation *owner = opHandle.getOwner(); 5745a9bdd85SOleksandr "Alex" Zinenko unsigned operandNo = opHandle.getOperandNumber(); 5755a9bdd85SOleksandr "Alex" Zinenko Location ancestorLoc = ancestor->getLoc(); 5765a9bdd85SOleksandr "Alex" Zinenko Location opLoc = definingOp->getLoc(); 5775a9bdd85SOleksandr "Alex" Zinenko Location valueLoc = payloadValue.getLoc(); 5785a9bdd85SOleksandr "Alex" Zinenko newlyInvalidated[valueHandle] = [valueHandle, owner, operandNo, resultNo, 5795a9bdd85SOleksandr "Alex" Zinenko argumentNo, blockNo, regionNo, ancestorLoc, 5805a9bdd85SOleksandr "Alex" Zinenko opLoc, valueLoc](Location currentLoc) { 5815a9bdd85SOleksandr "Alex" Zinenko InFlightDiagnostic diag = emitError(currentLoc) 5825a9bdd85SOleksandr "Alex" Zinenko << "op uses a handle invalidated by a " 5835a9bdd85SOleksandr "Alex" Zinenko "previously executed transform op"; 5845a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(valueHandle.getLoc()) << "invalidated handle"; 5855a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(owner->getLoc()) 5865a9bdd85SOleksandr "Alex" Zinenko << "invalidated by this transform op that consumes its operand #" 5875a9bdd85SOleksandr "Alex" Zinenko << operandNo 5885a9bdd85SOleksandr "Alex" Zinenko << " and invalidates all handles to payload IR entities " 5895a9bdd85SOleksandr "Alex" Zinenko "associated with this operand and entities nested in them"; 5905a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(ancestorLoc) 5915a9bdd85SOleksandr "Alex" Zinenko << "ancestor op associated with the consumed handle"; 5925a9bdd85SOleksandr "Alex" Zinenko if (resultNo) { 5935a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(opLoc) 5945a9bdd85SOleksandr "Alex" Zinenko << "op defining the value as result #" << *resultNo; 5955a9bdd85SOleksandr "Alex" Zinenko } else { 5965a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(opLoc) 5975a9bdd85SOleksandr "Alex" Zinenko << "op defining the value as block argument #" << argumentNo 5985a9bdd85SOleksandr "Alex" Zinenko << " of block #" << blockNo << " in region #" << regionNo; 5995a9bdd85SOleksandr "Alex" Zinenko } 6005a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(valueLoc) << "payload value"; 6015a9bdd85SOleksandr "Alex" Zinenko }; 6025a9bdd85SOleksandr "Alex" Zinenko } 6035a9bdd85SOleksandr "Alex" Zinenko } 6045a9bdd85SOleksandr "Alex" Zinenko 6055a9bdd85SOleksandr "Alex" Zinenko void transform::TransformState::recordOpHandleInvalidation( 6065a9bdd85SOleksandr "Alex" Zinenko OpOperand &handle, ArrayRef<Operation *> potentialAncestors, 6075a9bdd85SOleksandr "Alex" Zinenko Value throughValue, 6085a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { 6095a9bdd85SOleksandr "Alex" Zinenko 6105a9bdd85SOleksandr "Alex" Zinenko if (potentialAncestors.empty()) { 6115a9bdd85SOleksandr "Alex" Zinenko DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, { 6125a9bdd85SOleksandr "Alex" Zinenko (DBGS() << "----recording invalidation for empty handle: " << handle.get() 6135a9bdd85SOleksandr "Alex" Zinenko << "\n"); 6145a9bdd85SOleksandr "Alex" Zinenko }); 6155a9bdd85SOleksandr "Alex" Zinenko 6165a9bdd85SOleksandr "Alex" Zinenko Operation *owner = handle.getOwner(); 6175a9bdd85SOleksandr "Alex" Zinenko unsigned operandNo = handle.getOperandNumber(); 6185a9bdd85SOleksandr "Alex" Zinenko newlyInvalidated[handle.get()] = [owner, operandNo](Location currentLoc) { 6195a9bdd85SOleksandr "Alex" Zinenko InFlightDiagnostic diag = emitError(currentLoc) 6205a9bdd85SOleksandr "Alex" Zinenko << "op uses a handle associated with empty " 6215a9bdd85SOleksandr "Alex" Zinenko "payload and invalidated by a " 6225a9bdd85SOleksandr "Alex" Zinenko "previously executed transform op"; 6235a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(owner->getLoc()) 6245a9bdd85SOleksandr "Alex" Zinenko << "invalidated by this transform op that consumes its operand #" 6255a9bdd85SOleksandr "Alex" Zinenko << operandNo; 6265a9bdd85SOleksandr "Alex" Zinenko }; 6275a9bdd85SOleksandr "Alex" Zinenko return; 6285a9bdd85SOleksandr "Alex" Zinenko } 6295a9bdd85SOleksandr "Alex" Zinenko 6305a9bdd85SOleksandr "Alex" Zinenko // Iterate over the mapping and invalidate aliasing handles. This is quite 6315a9bdd85SOleksandr "Alex" Zinenko // expensive and only necessary for error reporting in case of transform 6325a9bdd85SOleksandr "Alex" Zinenko // dialect misuse with dangling handles. Iteration over the handles is based 6335a9bdd85SOleksandr "Alex" Zinenko // on the assumption that the number of handles is significantly less than the 6345a9bdd85SOleksandr "Alex" Zinenko // number of IR objects (operations and values). Alternatively, we could walk 6355a9bdd85SOleksandr "Alex" Zinenko // the IR nested in each payload op associated with the given handle and look 6365a9bdd85SOleksandr "Alex" Zinenko // for handles associated with each operation and value. 6375a9bdd85SOleksandr "Alex" Zinenko for (const auto &[region, mapping] : llvm::reverse(mappings)) { 6385a9bdd85SOleksandr "Alex" Zinenko // Go over all op handle mappings and mark as invalidated any handle 6395a9bdd85SOleksandr "Alex" Zinenko // pointing to any of the payload ops associated with the given handle or 6405a9bdd85SOleksandr "Alex" Zinenko // any op nested in them. 6415a9bdd85SOleksandr "Alex" Zinenko for (const auto &[payloadOp, otherHandles] : mapping->reverse) { 6425a9bdd85SOleksandr "Alex" Zinenko for (Value otherHandle : otherHandles) 6435a9bdd85SOleksandr "Alex" Zinenko recordOpHandleInvalidationOne(handle, potentialAncestors, payloadOp, 6445a9bdd85SOleksandr "Alex" Zinenko otherHandle, throughValue, 6455a9bdd85SOleksandr "Alex" Zinenko newlyInvalidated); 6465a9bdd85SOleksandr "Alex" Zinenko } 6475a9bdd85SOleksandr "Alex" Zinenko // Go over all value handle mappings and mark as invalidated any handle 6485a9bdd85SOleksandr "Alex" Zinenko // pointing to any result of the payload op associated with the given handle 6495a9bdd85SOleksandr "Alex" Zinenko // or any op nested in them. Similarly invalidate handles to argument of 6505a9bdd85SOleksandr "Alex" Zinenko // blocks belonging to any region of any payload op associated with the 6515a9bdd85SOleksandr "Alex" Zinenko // given handle or any op nested in them. 6525a9bdd85SOleksandr "Alex" Zinenko for (const auto &[payloadValue, valueHandles] : mapping->reverseValues) { 6535a9bdd85SOleksandr "Alex" Zinenko for (Value valueHandle : valueHandles) 6545a9bdd85SOleksandr "Alex" Zinenko recordValueHandleInvalidationByOpHandleOne(handle, potentialAncestors, 6555a9bdd85SOleksandr "Alex" Zinenko payloadValue, valueHandle, 6565a9bdd85SOleksandr "Alex" Zinenko newlyInvalidated); 6575a9bdd85SOleksandr "Alex" Zinenko } 6585a9bdd85SOleksandr "Alex" Zinenko 6595a9bdd85SOleksandr "Alex" Zinenko // Stop lookup when reaching a region that is isolated from above. 6605a9bdd85SOleksandr "Alex" Zinenko if (region->getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>()) 6615a9bdd85SOleksandr "Alex" Zinenko break; 6625a9bdd85SOleksandr "Alex" Zinenko } 6635a9bdd85SOleksandr "Alex" Zinenko } 6645a9bdd85SOleksandr "Alex" Zinenko 6655a9bdd85SOleksandr "Alex" Zinenko void transform::TransformState::recordValueHandleInvalidation( 6665a9bdd85SOleksandr "Alex" Zinenko OpOperand &valueHandle, 6675a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { 6685a9bdd85SOleksandr "Alex" Zinenko // Invalidate other handles to the same value. 6695a9bdd85SOleksandr "Alex" Zinenko for (Value payloadValue : getPayloadValuesView(valueHandle.get())) { 6705a9bdd85SOleksandr "Alex" Zinenko SmallVector<Value> otherValueHandles; 6715a9bdd85SOleksandr "Alex" Zinenko (void)getHandlesForPayloadValue(payloadValue, otherValueHandles); 6725a9bdd85SOleksandr "Alex" Zinenko for (Value otherHandle : otherValueHandles) { 6735a9bdd85SOleksandr "Alex" Zinenko Operation *owner = valueHandle.getOwner(); 6745a9bdd85SOleksandr "Alex" Zinenko unsigned operandNo = valueHandle.getOperandNumber(); 6755a9bdd85SOleksandr "Alex" Zinenko Location valueLoc = payloadValue.getLoc(); 6765a9bdd85SOleksandr "Alex" Zinenko newlyInvalidated[otherHandle] = [otherHandle, owner, operandNo, 6775a9bdd85SOleksandr "Alex" Zinenko valueLoc](Location currentLoc) { 6785a9bdd85SOleksandr "Alex" Zinenko InFlightDiagnostic diag = emitError(currentLoc) 6795a9bdd85SOleksandr "Alex" Zinenko << "op uses a handle invalidated by a " 6805a9bdd85SOleksandr "Alex" Zinenko "previously executed transform op"; 6815a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(otherHandle.getLoc()) << "invalidated handle"; 6825a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(owner->getLoc()) 6835a9bdd85SOleksandr "Alex" Zinenko << "invalidated by this transform op that consumes its operand #" 6845a9bdd85SOleksandr "Alex" Zinenko << operandNo 6855a9bdd85SOleksandr "Alex" Zinenko << " and invalidates handles to the same values as associated with " 6865a9bdd85SOleksandr "Alex" Zinenko "it"; 6875a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(valueLoc) << "payload value"; 6885a9bdd85SOleksandr "Alex" Zinenko }; 6895a9bdd85SOleksandr "Alex" Zinenko } 6905a9bdd85SOleksandr "Alex" Zinenko 6915a9bdd85SOleksandr "Alex" Zinenko if (auto opResult = llvm::dyn_cast<OpResult>(payloadValue)) { 6925a9bdd85SOleksandr "Alex" Zinenko Operation *payloadOp = opResult.getOwner(); 6935a9bdd85SOleksandr "Alex" Zinenko recordOpHandleInvalidation(valueHandle, payloadOp, payloadValue, 6945a9bdd85SOleksandr "Alex" Zinenko newlyInvalidated); 6955a9bdd85SOleksandr "Alex" Zinenko } else { 6965a9bdd85SOleksandr "Alex" Zinenko auto arg = llvm::dyn_cast<BlockArgument>(payloadValue); 6975a9bdd85SOleksandr "Alex" Zinenko for (Operation &payloadOp : *arg.getOwner()) 6985a9bdd85SOleksandr "Alex" Zinenko recordOpHandleInvalidation(valueHandle, &payloadOp, payloadValue, 6995a9bdd85SOleksandr "Alex" Zinenko newlyInvalidated); 7005a9bdd85SOleksandr "Alex" Zinenko } 7015a9bdd85SOleksandr "Alex" Zinenko } 7025a9bdd85SOleksandr "Alex" Zinenko } 7035a9bdd85SOleksandr "Alex" Zinenko 7045a9bdd85SOleksandr "Alex" Zinenko /// Checks that the operation does not use invalidated handles as operands. 7055a9bdd85SOleksandr "Alex" Zinenko /// Reports errors and returns failure if it does. Otherwise, invalidates the 7065a9bdd85SOleksandr "Alex" Zinenko /// handles consumed by the operation as well as any handles pointing to payload 7075a9bdd85SOleksandr "Alex" Zinenko /// IR operations nested in the operations associated with the consumed handles. 7085a9bdd85SOleksandr "Alex" Zinenko LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl( 7095a9bdd85SOleksandr "Alex" Zinenko transform::TransformOpInterface transform, 7105a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { 7115a9bdd85SOleksandr "Alex" Zinenko FULL_LDBG("--Start checkAndRecordHandleInvalidation\n"); 7125a9bdd85SOleksandr "Alex" Zinenko auto memoryEffectsIface = 7135a9bdd85SOleksandr "Alex" Zinenko cast<MemoryEffectOpInterface>(transform.getOperation()); 7145a9bdd85SOleksandr "Alex" Zinenko SmallVector<MemoryEffects::EffectInstance> effects; 7155a9bdd85SOleksandr "Alex" Zinenko memoryEffectsIface.getEffectsOnResource( 7165a9bdd85SOleksandr "Alex" Zinenko transform::TransformMappingResource::get(), effects); 7175a9bdd85SOleksandr "Alex" Zinenko 7185a9bdd85SOleksandr "Alex" Zinenko for (OpOperand &target : transform->getOpOperands()) { 7195a9bdd85SOleksandr "Alex" Zinenko DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, { 7205a9bdd85SOleksandr "Alex" Zinenko (DBGS() << "----iterate on handle: " << target.get() << "\n"); 7215a9bdd85SOleksandr "Alex" Zinenko }); 7225a9bdd85SOleksandr "Alex" Zinenko // If the operand uses an invalidated handle, report it. If the operation 7235a9bdd85SOleksandr "Alex" Zinenko // allows handles to point to repeated payload operations, only report 7245a9bdd85SOleksandr "Alex" Zinenko // pre-existing invalidation errors. Otherwise, also report invalidations 7255a9bdd85SOleksandr "Alex" Zinenko // caused by the current transform operation affecting its other operands. 7265a9bdd85SOleksandr "Alex" Zinenko auto it = invalidatedHandles.find(target.get()); 7275a9bdd85SOleksandr "Alex" Zinenko auto nit = newlyInvalidated.find(target.get()); 7285a9bdd85SOleksandr "Alex" Zinenko if (it != invalidatedHandles.end()) { 7295a9bdd85SOleksandr "Alex" Zinenko FULL_LDBG("--End checkAndRecordHandleInvalidation, found already " 7305a9bdd85SOleksandr "Alex" Zinenko "invalidated -> FAILURE\n"); 7315a9bdd85SOleksandr "Alex" Zinenko return it->getSecond()(transform->getLoc()), failure(); 7325a9bdd85SOleksandr "Alex" Zinenko } 7335a9bdd85SOleksandr "Alex" Zinenko if (!transform.allowsRepeatedHandleOperands() && 7345a9bdd85SOleksandr "Alex" Zinenko nit != newlyInvalidated.end()) { 7355a9bdd85SOleksandr "Alex" Zinenko FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly " 7365a9bdd85SOleksandr "Alex" Zinenko "invalidated (by this op) -> FAILURE\n"); 7375a9bdd85SOleksandr "Alex" Zinenko return nit->getSecond()(transform->getLoc()), failure(); 7385a9bdd85SOleksandr "Alex" Zinenko } 7395a9bdd85SOleksandr "Alex" Zinenko 7405a9bdd85SOleksandr "Alex" Zinenko // Invalidate handles pointing to the operations nested in the operation 7415a9bdd85SOleksandr "Alex" Zinenko // associated with the handle consumed by this operation. 7425a9bdd85SOleksandr "Alex" Zinenko auto consumesTarget = [&](const MemoryEffects::EffectInstance &effect) { 7435a9bdd85SOleksandr "Alex" Zinenko return isa<MemoryEffects::Free>(effect.getEffect()) && 7445a9bdd85SOleksandr "Alex" Zinenko effect.getValue() == target.get(); 7455a9bdd85SOleksandr "Alex" Zinenko }; 7465a9bdd85SOleksandr "Alex" Zinenko if (llvm::any_of(effects, consumesTarget)) { 7475a9bdd85SOleksandr "Alex" Zinenko FULL_LDBG("----found consume effect\n"); 7485a9bdd85SOleksandr "Alex" Zinenko if (llvm::isa<transform::TransformHandleTypeInterface>( 7495a9bdd85SOleksandr "Alex" Zinenko target.get().getType())) { 7505a9bdd85SOleksandr "Alex" Zinenko FULL_LDBG("----recordOpHandleInvalidation\n"); 7515a9bdd85SOleksandr "Alex" Zinenko SmallVector<Operation *> payloadOps = 7525a9bdd85SOleksandr "Alex" Zinenko llvm::to_vector(getPayloadOps(target.get())); 7535a9bdd85SOleksandr "Alex" Zinenko recordOpHandleInvalidation(target, payloadOps, nullptr, 7545a9bdd85SOleksandr "Alex" Zinenko newlyInvalidated); 7555a9bdd85SOleksandr "Alex" Zinenko } else if (llvm::isa<transform::TransformValueHandleTypeInterface>( 7565a9bdd85SOleksandr "Alex" Zinenko target.get().getType())) { 7575a9bdd85SOleksandr "Alex" Zinenko FULL_LDBG("----recordValueHandleInvalidation\n"); 7585a9bdd85SOleksandr "Alex" Zinenko recordValueHandleInvalidation(target, newlyInvalidated); 7595a9bdd85SOleksandr "Alex" Zinenko } else { 7605a9bdd85SOleksandr "Alex" Zinenko FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n"); 7615a9bdd85SOleksandr "Alex" Zinenko } 7625a9bdd85SOleksandr "Alex" Zinenko } else { 7635a9bdd85SOleksandr "Alex" Zinenko FULL_LDBG("----no consume effect -> SKIP\n"); 7645a9bdd85SOleksandr "Alex" Zinenko } 7655a9bdd85SOleksandr "Alex" Zinenko } 7665a9bdd85SOleksandr "Alex" Zinenko 7675a9bdd85SOleksandr "Alex" Zinenko FULL_LDBG("--End checkAndRecordHandleInvalidation -> SUCCESS\n"); 7685a9bdd85SOleksandr "Alex" Zinenko return success(); 7695a9bdd85SOleksandr "Alex" Zinenko } 7705a9bdd85SOleksandr "Alex" Zinenko 7715a9bdd85SOleksandr "Alex" Zinenko LogicalResult transform::TransformState::checkAndRecordHandleInvalidation( 7725a9bdd85SOleksandr "Alex" Zinenko transform::TransformOpInterface transform) { 7735a9bdd85SOleksandr "Alex" Zinenko InvalidatedHandleMap newlyInvalidated; 7745a9bdd85SOleksandr "Alex" Zinenko LogicalResult checkResult = 7755a9bdd85SOleksandr "Alex" Zinenko checkAndRecordHandleInvalidationImpl(transform, newlyInvalidated); 7765a9bdd85SOleksandr "Alex" Zinenko invalidatedHandles.insert(std::make_move_iterator(newlyInvalidated.begin()), 7775a9bdd85SOleksandr "Alex" Zinenko std::make_move_iterator(newlyInvalidated.end())); 7785a9bdd85SOleksandr "Alex" Zinenko return checkResult; 7795a9bdd85SOleksandr "Alex" Zinenko } 7805a9bdd85SOleksandr "Alex" Zinenko 7815a9bdd85SOleksandr "Alex" Zinenko template <typename T> 7825a9bdd85SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure 7835a9bdd85SOleksandr "Alex" Zinenko checkRepeatedConsumptionInOperand(ArrayRef<T> payload, 7845a9bdd85SOleksandr "Alex" Zinenko transform::TransformOpInterface transform, 7855a9bdd85SOleksandr "Alex" Zinenko unsigned operandNumber) { 7865a9bdd85SOleksandr "Alex" Zinenko DenseSet<T> seen; 7875a9bdd85SOleksandr "Alex" Zinenko for (T p : payload) { 7885a9bdd85SOleksandr "Alex" Zinenko if (!seen.insert(p).second) { 7895a9bdd85SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure diag = 7905a9bdd85SOleksandr "Alex" Zinenko transform.emitSilenceableError() 7915a9bdd85SOleksandr "Alex" Zinenko << "a handle passed as operand #" << operandNumber 7925a9bdd85SOleksandr "Alex" Zinenko << " and consumed by this operation points to a payload " 7935a9bdd85SOleksandr "Alex" Zinenko "entity more than once"; 7945a9bdd85SOleksandr "Alex" Zinenko if constexpr (std::is_pointer_v<T>) 7955a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(p->getLoc()) << "repeated target op"; 7965a9bdd85SOleksandr "Alex" Zinenko else 7975a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(p.getLoc()) << "repeated target value"; 7985a9bdd85SOleksandr "Alex" Zinenko return diag; 7995a9bdd85SOleksandr "Alex" Zinenko } 8005a9bdd85SOleksandr "Alex" Zinenko } 8015a9bdd85SOleksandr "Alex" Zinenko return DiagnosedSilenceableFailure::success(); 8025a9bdd85SOleksandr "Alex" Zinenko } 8035a9bdd85SOleksandr "Alex" Zinenko 8045a9bdd85SOleksandr "Alex" Zinenko void transform::TransformState::compactOpHandles() { 8055a9bdd85SOleksandr "Alex" Zinenko for (Value handle : opHandlesToCompact) { 8065a9bdd85SOleksandr "Alex" Zinenko Mappings &mappings = getMapping(handle, /*allowOutOfScope=*/true); 8076c7a3f80SMehdi Amini #if LLVM_ENABLE_ABI_BREAKING_CHECKS 8085a9bdd85SOleksandr "Alex" Zinenko if (llvm::find(mappings.direct[handle], nullptr) != 8095a9bdd85SOleksandr "Alex" Zinenko mappings.direct[handle].end()) 8105a9bdd85SOleksandr "Alex" Zinenko // Payload IR is removed from the mapping. This invalidates the respective 8115a9bdd85SOleksandr "Alex" Zinenko // iterators. 8125a9bdd85SOleksandr "Alex" Zinenko mappings.incrementTimestamp(handle); 8135a9bdd85SOleksandr "Alex" Zinenko #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 8145a9bdd85SOleksandr "Alex" Zinenko llvm::erase(mappings.direct[handle], nullptr); 8155a9bdd85SOleksandr "Alex" Zinenko } 8165a9bdd85SOleksandr "Alex" Zinenko opHandlesToCompact.clear(); 8175a9bdd85SOleksandr "Alex" Zinenko } 8185a9bdd85SOleksandr "Alex" Zinenko 8195a9bdd85SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure 8205a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::applyTransform(TransformOpInterface transform) { 8215a9bdd85SOleksandr "Alex" Zinenko LLVM_DEBUG({ 8225a9bdd85SOleksandr "Alex" Zinenko DBGS() << "applying: "; 8235a9bdd85SOleksandr "Alex" Zinenko transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions()); 8245a9bdd85SOleksandr "Alex" Zinenko llvm::dbgs() << "\n"; 8255a9bdd85SOleksandr "Alex" Zinenko }); 8265a9bdd85SOleksandr "Alex" Zinenko DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, 8275a9bdd85SOleksandr "Alex" Zinenko DBGS() << "Top-level payload before application:\n" 8285a9bdd85SOleksandr "Alex" Zinenko << *getTopLevel() << "\n"); 8295a9bdd85SOleksandr "Alex" Zinenko auto printOnFailureRAII = llvm::make_scope_exit([this] { 8305a9bdd85SOleksandr "Alex" Zinenko (void)this; 8315a9bdd85SOleksandr "Alex" Zinenko LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print( 8325a9bdd85SOleksandr "Alex" Zinenko llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm());); 8335a9bdd85SOleksandr "Alex" Zinenko }); 8345a9bdd85SOleksandr "Alex" Zinenko 8355a9bdd85SOleksandr "Alex" Zinenko // Set current transform op. 8365a9bdd85SOleksandr "Alex" Zinenko regionStack.back()->currentTransform = transform; 8375a9bdd85SOleksandr "Alex" Zinenko 8385a9bdd85SOleksandr "Alex" Zinenko // Expensive checks to detect invalid transform IR. 8395a9bdd85SOleksandr "Alex" Zinenko if (options.getExpensiveChecksEnabled()) { 8405a9bdd85SOleksandr "Alex" Zinenko FULL_LDBG("ExpensiveChecksEnabled\n"); 8415a9bdd85SOleksandr "Alex" Zinenko if (failed(checkAndRecordHandleInvalidation(transform))) 8425a9bdd85SOleksandr "Alex" Zinenko return DiagnosedSilenceableFailure::definiteFailure(); 8435a9bdd85SOleksandr "Alex" Zinenko 8445a9bdd85SOleksandr "Alex" Zinenko for (OpOperand &operand : transform->getOpOperands()) { 8455a9bdd85SOleksandr "Alex" Zinenko DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, { 8465a9bdd85SOleksandr "Alex" Zinenko (DBGS() << "iterate on handle: " << operand.get() << "\n"); 8475a9bdd85SOleksandr "Alex" Zinenko }); 8485a9bdd85SOleksandr "Alex" Zinenko if (!isHandleConsumed(operand.get(), transform)) { 8495a9bdd85SOleksandr "Alex" Zinenko FULL_LDBG("--handle not consumed -> SKIP\n"); 8505a9bdd85SOleksandr "Alex" Zinenko continue; 8515a9bdd85SOleksandr "Alex" Zinenko } 8525a9bdd85SOleksandr "Alex" Zinenko if (transform.allowsRepeatedHandleOperands()) { 8535a9bdd85SOleksandr "Alex" Zinenko FULL_LDBG("--op allows repeated handles -> SKIP\n"); 8545a9bdd85SOleksandr "Alex" Zinenko continue; 8555a9bdd85SOleksandr "Alex" Zinenko } 8565a9bdd85SOleksandr "Alex" Zinenko FULL_LDBG("--handle is consumed\n"); 8575a9bdd85SOleksandr "Alex" Zinenko 8585a9bdd85SOleksandr "Alex" Zinenko Type operandType = operand.get().getType(); 8595a9bdd85SOleksandr "Alex" Zinenko if (llvm::isa<TransformHandleTypeInterface>(operandType)) { 8605a9bdd85SOleksandr "Alex" Zinenko FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n"); 8615a9bdd85SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure check = 8625a9bdd85SOleksandr "Alex" Zinenko checkRepeatedConsumptionInOperand<Operation *>( 8635a9bdd85SOleksandr "Alex" Zinenko getPayloadOpsView(operand.get()), transform, 8645a9bdd85SOleksandr "Alex" Zinenko operand.getOperandNumber()); 8655a9bdd85SOleksandr "Alex" Zinenko if (!check.succeeded()) { 8665a9bdd85SOleksandr "Alex" Zinenko FULL_LDBG("----FAILED\n"); 8675a9bdd85SOleksandr "Alex" Zinenko return check; 8685a9bdd85SOleksandr "Alex" Zinenko } 8695a9bdd85SOleksandr "Alex" Zinenko } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) { 8705a9bdd85SOleksandr "Alex" Zinenko FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n"); 8715a9bdd85SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure check = 8725a9bdd85SOleksandr "Alex" Zinenko checkRepeatedConsumptionInOperand<Value>( 8735a9bdd85SOleksandr "Alex" Zinenko getPayloadValuesView(operand.get()), transform, 8745a9bdd85SOleksandr "Alex" Zinenko operand.getOperandNumber()); 8755a9bdd85SOleksandr "Alex" Zinenko if (!check.succeeded()) { 8765a9bdd85SOleksandr "Alex" Zinenko FULL_LDBG("----FAILED\n"); 8775a9bdd85SOleksandr "Alex" Zinenko return check; 8785a9bdd85SOleksandr "Alex" Zinenko } 8795a9bdd85SOleksandr "Alex" Zinenko } else { 8805a9bdd85SOleksandr "Alex" Zinenko FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n"); 8815a9bdd85SOleksandr "Alex" Zinenko } 8825a9bdd85SOleksandr "Alex" Zinenko } 8835a9bdd85SOleksandr "Alex" Zinenko } 8845a9bdd85SOleksandr "Alex" Zinenko 8855a9bdd85SOleksandr "Alex" Zinenko // Find which operands are consumed. 8865a9bdd85SOleksandr "Alex" Zinenko SmallVector<OpOperand *> consumedOperands = 8875a9bdd85SOleksandr "Alex" Zinenko transform.getConsumedHandleOpOperands(); 8885a9bdd85SOleksandr "Alex" Zinenko 8895a9bdd85SOleksandr "Alex" Zinenko // Remember the results of the payload ops associated with the consumed 8905a9bdd85SOleksandr "Alex" Zinenko // op handles or the ops defining the value handles so we can drop the 8915a9bdd85SOleksandr "Alex" Zinenko // association with them later. This must happen here because the 8925a9bdd85SOleksandr "Alex" Zinenko // transformation may destroy or mutate them so we cannot traverse the payload 8935a9bdd85SOleksandr "Alex" Zinenko // IR after that. 8945a9bdd85SOleksandr "Alex" Zinenko SmallVector<Value> origOpFlatResults; 8955a9bdd85SOleksandr "Alex" Zinenko SmallVector<Operation *> origAssociatedOps; 8965a9bdd85SOleksandr "Alex" Zinenko for (OpOperand *opOperand : consumedOperands) { 8975a9bdd85SOleksandr "Alex" Zinenko Value operand = opOperand->get(); 8985a9bdd85SOleksandr "Alex" Zinenko if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) { 8995a9bdd85SOleksandr "Alex" Zinenko for (Operation *payloadOp : getPayloadOps(operand)) { 9005a9bdd85SOleksandr "Alex" Zinenko llvm::append_range(origOpFlatResults, payloadOp->getResults()); 9015a9bdd85SOleksandr "Alex" Zinenko } 9025a9bdd85SOleksandr "Alex" Zinenko continue; 9035a9bdd85SOleksandr "Alex" Zinenko } 9045a9bdd85SOleksandr "Alex" Zinenko if (llvm::isa<TransformValueHandleTypeInterface>(operand.getType())) { 9055a9bdd85SOleksandr "Alex" Zinenko for (Value payloadValue : getPayloadValuesView(operand)) { 9065a9bdd85SOleksandr "Alex" Zinenko if (llvm::isa<OpResult>(payloadValue)) { 9075a9bdd85SOleksandr "Alex" Zinenko origAssociatedOps.push_back(payloadValue.getDefiningOp()); 9085a9bdd85SOleksandr "Alex" Zinenko continue; 9095a9bdd85SOleksandr "Alex" Zinenko } 9105a9bdd85SOleksandr "Alex" Zinenko llvm::append_range( 9115a9bdd85SOleksandr "Alex" Zinenko origAssociatedOps, 9125a9bdd85SOleksandr "Alex" Zinenko llvm::map_range(*llvm::cast<BlockArgument>(payloadValue).getOwner(), 9135a9bdd85SOleksandr "Alex" Zinenko [](Operation &op) { return &op; })); 9145a9bdd85SOleksandr "Alex" Zinenko } 9155a9bdd85SOleksandr "Alex" Zinenko continue; 9165a9bdd85SOleksandr "Alex" Zinenko } 9175a9bdd85SOleksandr "Alex" Zinenko DiagnosedDefiniteFailure diag = 9185a9bdd85SOleksandr "Alex" Zinenko emitDefiniteFailure(transform->getLoc()) 9195a9bdd85SOleksandr "Alex" Zinenko << "unexpectedly consumed a value that is not a handle as operand #" 9205a9bdd85SOleksandr "Alex" Zinenko << opOperand->getOperandNumber(); 9215a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(operand.getLoc()) 9225a9bdd85SOleksandr "Alex" Zinenko << "value defined here with type " << operand.getType(); 9235a9bdd85SOleksandr "Alex" Zinenko return diag; 9245a9bdd85SOleksandr "Alex" Zinenko } 9255a9bdd85SOleksandr "Alex" Zinenko 9265a9bdd85SOleksandr "Alex" Zinenko // Prepare rewriter and listener. 9275a9bdd85SOleksandr "Alex" Zinenko TrackingListenerConfig config; 9285a9bdd85SOleksandr "Alex" Zinenko config.skipHandleFn = [&](Value handle) { 9295a9bdd85SOleksandr "Alex" Zinenko // Skip handle if it is dead. 9305a9bdd85SOleksandr "Alex" Zinenko auto scopeIt = 9315a9bdd85SOleksandr "Alex" Zinenko llvm::find_if(llvm::reverse(regionStack), [&](RegionScope *scope) { 9325a9bdd85SOleksandr "Alex" Zinenko return handle.getParentRegion() == scope->region; 9335a9bdd85SOleksandr "Alex" Zinenko }); 9345a9bdd85SOleksandr "Alex" Zinenko assert(scopeIt != regionStack.rend() && 9355a9bdd85SOleksandr "Alex" Zinenko "could not find region scope for handle"); 9365a9bdd85SOleksandr "Alex" Zinenko RegionScope *scope = *scopeIt; 937e323b40bSJianjian Guan return llvm::all_of(handle.getUsers(), [&](Operation *user) { 938e323b40bSJianjian Guan return user == scope->currentTransform || 939e323b40bSJianjian Guan happensBefore(user, scope->currentTransform); 940e323b40bSJianjian Guan }); 9415a9bdd85SOleksandr "Alex" Zinenko }; 9425a9bdd85SOleksandr "Alex" Zinenko transform::ErrorCheckingTrackingListener trackingListener(*this, transform, 9435a9bdd85SOleksandr "Alex" Zinenko config); 9445a9bdd85SOleksandr "Alex" Zinenko transform::TransformRewriter rewriter(transform->getContext(), 9455a9bdd85SOleksandr "Alex" Zinenko &trackingListener); 9465a9bdd85SOleksandr "Alex" Zinenko 9475a9bdd85SOleksandr "Alex" Zinenko // Compute the result but do not short-circuit the silenceable failure case as 9485a9bdd85SOleksandr "Alex" Zinenko // we still want the handles to propagate properly so the "suppress" mode can 9495a9bdd85SOleksandr "Alex" Zinenko // proceed on a best effort basis. 9505a9bdd85SOleksandr "Alex" Zinenko transform::TransformResults results(transform->getNumResults()); 9515a9bdd85SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this)); 9525a9bdd85SOleksandr "Alex" Zinenko compactOpHandles(); 9535a9bdd85SOleksandr "Alex" Zinenko 9545a9bdd85SOleksandr "Alex" Zinenko // Error handling: fail if transform or listener failed. 9555a9bdd85SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure trackingFailure = 9565a9bdd85SOleksandr "Alex" Zinenko trackingListener.checkAndResetError(); 9575a9bdd85SOleksandr "Alex" Zinenko if (!transform->hasTrait<ReportTrackingListenerFailuresOpTrait>() || 9585a9bdd85SOleksandr "Alex" Zinenko transform->hasAttr(FindPayloadReplacementOpInterface:: 9595a9bdd85SOleksandr "Alex" Zinenko kSilenceTrackingFailuresAttrName)) { 9605a9bdd85SOleksandr "Alex" Zinenko // Only report failures for ReportTrackingListenerFailuresOpTrait ops. Also 9615a9bdd85SOleksandr "Alex" Zinenko // do not report failures if the above mentioned attribute is set. 9625a9bdd85SOleksandr "Alex" Zinenko if (trackingFailure.isSilenceableFailure()) 9635a9bdd85SOleksandr "Alex" Zinenko (void)trackingFailure.silence(); 9645a9bdd85SOleksandr "Alex" Zinenko trackingFailure = DiagnosedSilenceableFailure::success(); 9655a9bdd85SOleksandr "Alex" Zinenko } 9665a9bdd85SOleksandr "Alex" Zinenko if (!trackingFailure.succeeded()) { 9675a9bdd85SOleksandr "Alex" Zinenko if (result.succeeded()) { 9685a9bdd85SOleksandr "Alex" Zinenko result = std::move(trackingFailure); 9695a9bdd85SOleksandr "Alex" Zinenko } else { 9705a9bdd85SOleksandr "Alex" Zinenko // Transform op errors have precedence, report those first. 9715a9bdd85SOleksandr "Alex" Zinenko if (result.isSilenceableFailure()) 9725a9bdd85SOleksandr "Alex" Zinenko result.attachNote() << "tracking listener also failed: " 9735a9bdd85SOleksandr "Alex" Zinenko << trackingFailure.getMessage(); 9745a9bdd85SOleksandr "Alex" Zinenko (void)trackingFailure.silence(); 9755a9bdd85SOleksandr "Alex" Zinenko } 9765a9bdd85SOleksandr "Alex" Zinenko } 9775a9bdd85SOleksandr "Alex" Zinenko if (result.isDefiniteFailure()) 9785a9bdd85SOleksandr "Alex" Zinenko return result; 9795a9bdd85SOleksandr "Alex" Zinenko 9805a9bdd85SOleksandr "Alex" Zinenko // If a silenceable failure was produced, some results may be unset, set them 9815a9bdd85SOleksandr "Alex" Zinenko // to empty lists. 9825a9bdd85SOleksandr "Alex" Zinenko if (result.isSilenceableFailure()) 9835a9bdd85SOleksandr "Alex" Zinenko results.setRemainingToEmpty(transform); 9845a9bdd85SOleksandr "Alex" Zinenko 9855a9bdd85SOleksandr "Alex" Zinenko // Remove the mapping for the operand if it is consumed by the operation. This 9865a9bdd85SOleksandr "Alex" Zinenko // allows us to catch use-after-free with assertions later on. 9875a9bdd85SOleksandr "Alex" Zinenko for (OpOperand *opOperand : consumedOperands) { 9885a9bdd85SOleksandr "Alex" Zinenko Value operand = opOperand->get(); 9895a9bdd85SOleksandr "Alex" Zinenko if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) { 9905a9bdd85SOleksandr "Alex" Zinenko forgetMapping(operand, origOpFlatResults); 9915a9bdd85SOleksandr "Alex" Zinenko } else if (llvm::isa<TransformValueHandleTypeInterface>( 9925a9bdd85SOleksandr "Alex" Zinenko operand.getType())) { 9935a9bdd85SOleksandr "Alex" Zinenko forgetValueMapping(operand, origAssociatedOps); 9945a9bdd85SOleksandr "Alex" Zinenko } 9955a9bdd85SOleksandr "Alex" Zinenko } 9965a9bdd85SOleksandr "Alex" Zinenko 9975a9bdd85SOleksandr "Alex" Zinenko if (failed(updateStateFromResults(results, transform->getResults()))) 9985a9bdd85SOleksandr "Alex" Zinenko return DiagnosedSilenceableFailure::definiteFailure(); 9995a9bdd85SOleksandr "Alex" Zinenko 10005a9bdd85SOleksandr "Alex" Zinenko printOnFailureRAII.release(); 10015a9bdd85SOleksandr "Alex" Zinenko DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, { 10025a9bdd85SOleksandr "Alex" Zinenko DBGS() << "Top-level payload:\n"; 10035a9bdd85SOleksandr "Alex" Zinenko getTopLevel()->print(llvm::dbgs()); 10045a9bdd85SOleksandr "Alex" Zinenko }); 10055a9bdd85SOleksandr "Alex" Zinenko return result; 10065a9bdd85SOleksandr "Alex" Zinenko } 10075a9bdd85SOleksandr "Alex" Zinenko 10085a9bdd85SOleksandr "Alex" Zinenko LogicalResult transform::TransformState::updateStateFromResults( 10095a9bdd85SOleksandr "Alex" Zinenko const TransformResults &results, ResultRange opResults) { 10105a9bdd85SOleksandr "Alex" Zinenko for (OpResult result : opResults) { 10115a9bdd85SOleksandr "Alex" Zinenko if (llvm::isa<TransformParamTypeInterface>(result.getType())) { 10125a9bdd85SOleksandr "Alex" Zinenko assert(results.isParam(result.getResultNumber()) && 10135a9bdd85SOleksandr "Alex" Zinenko "expected parameters for the parameter-typed result"); 10145a9bdd85SOleksandr "Alex" Zinenko if (failed( 10155a9bdd85SOleksandr "Alex" Zinenko setParams(result, results.getParams(result.getResultNumber())))) { 10165a9bdd85SOleksandr "Alex" Zinenko return failure(); 10175a9bdd85SOleksandr "Alex" Zinenko } 10185a9bdd85SOleksandr "Alex" Zinenko } else if (llvm::isa<TransformValueHandleTypeInterface>(result.getType())) { 10195a9bdd85SOleksandr "Alex" Zinenko assert(results.isValue(result.getResultNumber()) && 10205a9bdd85SOleksandr "Alex" Zinenko "expected values for value-type-result"); 10215a9bdd85SOleksandr "Alex" Zinenko if (failed(setPayloadValues( 10225a9bdd85SOleksandr "Alex" Zinenko result, results.getValues(result.getResultNumber())))) { 10235a9bdd85SOleksandr "Alex" Zinenko return failure(); 10245a9bdd85SOleksandr "Alex" Zinenko } 10255a9bdd85SOleksandr "Alex" Zinenko } else { 10265a9bdd85SOleksandr "Alex" Zinenko assert(!results.isParam(result.getResultNumber()) && 10275a9bdd85SOleksandr "Alex" Zinenko "expected payload ops for the non-parameter typed result"); 10285a9bdd85SOleksandr "Alex" Zinenko if (failed( 10295a9bdd85SOleksandr "Alex" Zinenko setPayloadOps(result, results.get(result.getResultNumber())))) { 10305a9bdd85SOleksandr "Alex" Zinenko return failure(); 10315a9bdd85SOleksandr "Alex" Zinenko } 10325a9bdd85SOleksandr "Alex" Zinenko } 10335a9bdd85SOleksandr "Alex" Zinenko } 10345a9bdd85SOleksandr "Alex" Zinenko return success(); 10355a9bdd85SOleksandr "Alex" Zinenko } 10365a9bdd85SOleksandr "Alex" Zinenko 10375a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 10385a9bdd85SOleksandr "Alex" Zinenko // TransformState::Extension 10395a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 10405a9bdd85SOleksandr "Alex" Zinenko 10415a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::Extension::~Extension() = default; 10425a9bdd85SOleksandr "Alex" Zinenko 10435a9bdd85SOleksandr "Alex" Zinenko LogicalResult 10445a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::Extension::replacePayloadOp(Operation *op, 10455a9bdd85SOleksandr "Alex" Zinenko Operation *replacement) { 10465a9bdd85SOleksandr "Alex" Zinenko // TODO: we may need to invalidate handles to operations and values nested in 10475a9bdd85SOleksandr "Alex" Zinenko // the operation being replaced. 10485a9bdd85SOleksandr "Alex" Zinenko return state.replacePayloadOp(op, replacement); 10495a9bdd85SOleksandr "Alex" Zinenko } 10505a9bdd85SOleksandr "Alex" Zinenko 10515a9bdd85SOleksandr "Alex" Zinenko LogicalResult 10525a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::Extension::replacePayloadValue(Value value, 10535a9bdd85SOleksandr "Alex" Zinenko Value replacement) { 10545a9bdd85SOleksandr "Alex" Zinenko return state.replacePayloadValue(value, replacement); 10555a9bdd85SOleksandr "Alex" Zinenko } 10565a9bdd85SOleksandr "Alex" Zinenko 10575a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 10585a9bdd85SOleksandr "Alex" Zinenko // TransformState::RegionScope 10595a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 10605a9bdd85SOleksandr "Alex" Zinenko 10615a9bdd85SOleksandr "Alex" Zinenko transform::TransformState::RegionScope::~RegionScope() { 10625a9bdd85SOleksandr "Alex" Zinenko // Remove handle invalidation notices as handles are going out of scope. 10635a9bdd85SOleksandr "Alex" Zinenko // The same region may be re-entered leading to incorrect invalidation 10645a9bdd85SOleksandr "Alex" Zinenko // errors. 10655a9bdd85SOleksandr "Alex" Zinenko for (Block &block : *region) { 10665a9bdd85SOleksandr "Alex" Zinenko for (Value handle : block.getArguments()) { 10675a9bdd85SOleksandr "Alex" Zinenko state.invalidatedHandles.erase(handle); 10685a9bdd85SOleksandr "Alex" Zinenko } 10695a9bdd85SOleksandr "Alex" Zinenko for (Operation &op : block) { 10705a9bdd85SOleksandr "Alex" Zinenko for (Value handle : op.getResults()) { 10715a9bdd85SOleksandr "Alex" Zinenko state.invalidatedHandles.erase(handle); 10725a9bdd85SOleksandr "Alex" Zinenko } 10735a9bdd85SOleksandr "Alex" Zinenko } 10745a9bdd85SOleksandr "Alex" Zinenko } 10755a9bdd85SOleksandr "Alex" Zinenko 10765a9bdd85SOleksandr "Alex" Zinenko #if LLVM_ENABLE_ABI_BREAKING_CHECKS 10775a9bdd85SOleksandr "Alex" Zinenko // Remember pointers to payload ops referenced by the handles going out of 10785a9bdd85SOleksandr "Alex" Zinenko // scope. 10795a9bdd85SOleksandr "Alex" Zinenko SmallVector<Operation *> referencedOps = 10805a9bdd85SOleksandr "Alex" Zinenko llvm::to_vector(llvm::make_first_range(state.mappings[region]->reverse)); 10815a9bdd85SOleksandr "Alex" Zinenko #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS 10825a9bdd85SOleksandr "Alex" Zinenko 10835a9bdd85SOleksandr "Alex" Zinenko state.mappings.erase(region); 10845a9bdd85SOleksandr "Alex" Zinenko state.regionStack.pop_back(); 10855a9bdd85SOleksandr "Alex" Zinenko } 10865a9bdd85SOleksandr "Alex" Zinenko 10875a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 10885a9bdd85SOleksandr "Alex" Zinenko // TransformResults 10895a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 10905a9bdd85SOleksandr "Alex" Zinenko 10915a9bdd85SOleksandr "Alex" Zinenko transform::TransformResults::TransformResults(unsigned numSegments) { 10925a9bdd85SOleksandr "Alex" Zinenko operations.appendEmptyRows(numSegments); 10935a9bdd85SOleksandr "Alex" Zinenko params.appendEmptyRows(numSegments); 10945a9bdd85SOleksandr "Alex" Zinenko values.appendEmptyRows(numSegments); 10955a9bdd85SOleksandr "Alex" Zinenko } 10965a9bdd85SOleksandr "Alex" Zinenko 10975a9bdd85SOleksandr "Alex" Zinenko void transform::TransformResults::setParams( 10985a9bdd85SOleksandr "Alex" Zinenko OpResult value, ArrayRef<transform::TransformState::Param> params) { 10995a9bdd85SOleksandr "Alex" Zinenko int64_t position = value.getResultNumber(); 11005a9bdd85SOleksandr "Alex" Zinenko assert(position < static_cast<int64_t>(this->params.size()) && 11015a9bdd85SOleksandr "Alex" Zinenko "setting params for a non-existent handle"); 11025a9bdd85SOleksandr "Alex" Zinenko assert(this->params[position].data() == nullptr && "params already set"); 11035a9bdd85SOleksandr "Alex" Zinenko assert(operations[position].data() == nullptr && 11045a9bdd85SOleksandr "Alex" Zinenko "another kind of results already set"); 11055a9bdd85SOleksandr "Alex" Zinenko assert(values[position].data() == nullptr && 11065a9bdd85SOleksandr "Alex" Zinenko "another kind of results already set"); 11075a9bdd85SOleksandr "Alex" Zinenko this->params.replace(position, params); 11085a9bdd85SOleksandr "Alex" Zinenko } 11095a9bdd85SOleksandr "Alex" Zinenko 11105a9bdd85SOleksandr "Alex" Zinenko void transform::TransformResults::setMappedValues( 11115a9bdd85SOleksandr "Alex" Zinenko OpResult handle, ArrayRef<MappedValue> values) { 11125a9bdd85SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure diag = dispatchMappedValues( 11135a9bdd85SOleksandr "Alex" Zinenko handle, values, 11145a9bdd85SOleksandr "Alex" Zinenko [&](ArrayRef<Operation *> operations) { 11155a9bdd85SOleksandr "Alex" Zinenko return set(handle, operations), success(); 11165a9bdd85SOleksandr "Alex" Zinenko }, 11175a9bdd85SOleksandr "Alex" Zinenko [&](ArrayRef<Param> params) { 11185a9bdd85SOleksandr "Alex" Zinenko return setParams(handle, params), success(); 11195a9bdd85SOleksandr "Alex" Zinenko }, 11205a9bdd85SOleksandr "Alex" Zinenko [&](ValueRange payloadValues) { 11215a9bdd85SOleksandr "Alex" Zinenko return setValues(handle, payloadValues), success(); 11225a9bdd85SOleksandr "Alex" Zinenko }); 11235a9bdd85SOleksandr "Alex" Zinenko #ifndef NDEBUG 11245a9bdd85SOleksandr "Alex" Zinenko if (!diag.succeeded()) 11255a9bdd85SOleksandr "Alex" Zinenko llvm::dbgs() << diag.getStatusString() << "\n"; 11265a9bdd85SOleksandr "Alex" Zinenko assert(diag.succeeded() && "incorrect mapping"); 11275a9bdd85SOleksandr "Alex" Zinenko #endif // NDEBUG 11285a9bdd85SOleksandr "Alex" Zinenko (void)diag.silence(); 11295a9bdd85SOleksandr "Alex" Zinenko } 11305a9bdd85SOleksandr "Alex" Zinenko 11315a9bdd85SOleksandr "Alex" Zinenko void transform::TransformResults::setRemainingToEmpty( 11325a9bdd85SOleksandr "Alex" Zinenko transform::TransformOpInterface transform) { 11335a9bdd85SOleksandr "Alex" Zinenko for (OpResult opResult : transform->getResults()) { 11345a9bdd85SOleksandr "Alex" Zinenko if (!isSet(opResult.getResultNumber())) 11355a9bdd85SOleksandr "Alex" Zinenko setMappedValues(opResult, {}); 11365a9bdd85SOleksandr "Alex" Zinenko } 11375a9bdd85SOleksandr "Alex" Zinenko } 11385a9bdd85SOleksandr "Alex" Zinenko 11395a9bdd85SOleksandr "Alex" Zinenko ArrayRef<Operation *> 11405a9bdd85SOleksandr "Alex" Zinenko transform::TransformResults::get(unsigned resultNumber) const { 11415a9bdd85SOleksandr "Alex" Zinenko assert(resultNumber < operations.size() && 11425a9bdd85SOleksandr "Alex" Zinenko "querying results for a non-existent handle"); 11435a9bdd85SOleksandr "Alex" Zinenko assert(operations[resultNumber].data() != nullptr && 11445a9bdd85SOleksandr "Alex" Zinenko "querying unset results (values or params expected?)"); 11455a9bdd85SOleksandr "Alex" Zinenko return operations[resultNumber]; 11465a9bdd85SOleksandr "Alex" Zinenko } 11475a9bdd85SOleksandr "Alex" Zinenko 11485a9bdd85SOleksandr "Alex" Zinenko ArrayRef<transform::TransformState::Param> 11495a9bdd85SOleksandr "Alex" Zinenko transform::TransformResults::getParams(unsigned resultNumber) const { 11505a9bdd85SOleksandr "Alex" Zinenko assert(resultNumber < params.size() && 11515a9bdd85SOleksandr "Alex" Zinenko "querying params for a non-existent handle"); 11525a9bdd85SOleksandr "Alex" Zinenko assert(params[resultNumber].data() != nullptr && 11535a9bdd85SOleksandr "Alex" Zinenko "querying unset params (ops or values expected?)"); 11545a9bdd85SOleksandr "Alex" Zinenko return params[resultNumber]; 11555a9bdd85SOleksandr "Alex" Zinenko } 11565a9bdd85SOleksandr "Alex" Zinenko 11575a9bdd85SOleksandr "Alex" Zinenko ArrayRef<Value> 11585a9bdd85SOleksandr "Alex" Zinenko transform::TransformResults::getValues(unsigned resultNumber) const { 11595a9bdd85SOleksandr "Alex" Zinenko assert(resultNumber < values.size() && 11605a9bdd85SOleksandr "Alex" Zinenko "querying values for a non-existent handle"); 11615a9bdd85SOleksandr "Alex" Zinenko assert(values[resultNumber].data() != nullptr && 11625a9bdd85SOleksandr "Alex" Zinenko "querying unset values (ops or params expected?)"); 11635a9bdd85SOleksandr "Alex" Zinenko return values[resultNumber]; 11645a9bdd85SOleksandr "Alex" Zinenko } 11655a9bdd85SOleksandr "Alex" Zinenko 11665a9bdd85SOleksandr "Alex" Zinenko bool transform::TransformResults::isParam(unsigned resultNumber) const { 11675a9bdd85SOleksandr "Alex" Zinenko assert(resultNumber < params.size() && 11685a9bdd85SOleksandr "Alex" Zinenko "querying association for a non-existent handle"); 11695a9bdd85SOleksandr "Alex" Zinenko return params[resultNumber].data() != nullptr; 11705a9bdd85SOleksandr "Alex" Zinenko } 11715a9bdd85SOleksandr "Alex" Zinenko 11725a9bdd85SOleksandr "Alex" Zinenko bool transform::TransformResults::isValue(unsigned resultNumber) const { 11735a9bdd85SOleksandr "Alex" Zinenko assert(resultNumber < values.size() && 11745a9bdd85SOleksandr "Alex" Zinenko "querying association for a non-existent handle"); 11755a9bdd85SOleksandr "Alex" Zinenko return values[resultNumber].data() != nullptr; 11765a9bdd85SOleksandr "Alex" Zinenko } 11775a9bdd85SOleksandr "Alex" Zinenko 11785a9bdd85SOleksandr "Alex" Zinenko bool transform::TransformResults::isSet(unsigned resultNumber) const { 11795a9bdd85SOleksandr "Alex" Zinenko assert(resultNumber < params.size() && 11805a9bdd85SOleksandr "Alex" Zinenko "querying association for a non-existent handle"); 11815a9bdd85SOleksandr "Alex" Zinenko return params[resultNumber].data() != nullptr || 11825a9bdd85SOleksandr "Alex" Zinenko operations[resultNumber].data() != nullptr || 11835a9bdd85SOleksandr "Alex" Zinenko values[resultNumber].data() != nullptr; 11845a9bdd85SOleksandr "Alex" Zinenko } 11855a9bdd85SOleksandr "Alex" Zinenko 11865a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 11875a9bdd85SOleksandr "Alex" Zinenko // TrackingListener 11885a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 11895a9bdd85SOleksandr "Alex" Zinenko 11905a9bdd85SOleksandr "Alex" Zinenko transform::TrackingListener::TrackingListener(TransformState &state, 11915a9bdd85SOleksandr "Alex" Zinenko TransformOpInterface op, 11925a9bdd85SOleksandr "Alex" Zinenko TrackingListenerConfig config) 11935a9bdd85SOleksandr "Alex" Zinenko : TransformState::Extension(state), transformOp(op), config(config) { 11945a9bdd85SOleksandr "Alex" Zinenko if (op) { 11955a9bdd85SOleksandr "Alex" Zinenko for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) { 11965a9bdd85SOleksandr "Alex" Zinenko consumedHandles.insert(opOperand->get()); 11975a9bdd85SOleksandr "Alex" Zinenko } 11985a9bdd85SOleksandr "Alex" Zinenko } 11995a9bdd85SOleksandr "Alex" Zinenko } 12005a9bdd85SOleksandr "Alex" Zinenko 12015a9bdd85SOleksandr "Alex" Zinenko Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) { 12025a9bdd85SOleksandr "Alex" Zinenko Operation *defOp = nullptr; 12035a9bdd85SOleksandr "Alex" Zinenko for (Value v : values) { 12045a9bdd85SOleksandr "Alex" Zinenko // Skip empty values. 12055a9bdd85SOleksandr "Alex" Zinenko if (!v) 12065a9bdd85SOleksandr "Alex" Zinenko continue; 12075a9bdd85SOleksandr "Alex" Zinenko if (!defOp) { 12085a9bdd85SOleksandr "Alex" Zinenko defOp = v.getDefiningOp(); 12095a9bdd85SOleksandr "Alex" Zinenko continue; 12105a9bdd85SOleksandr "Alex" Zinenko } 12115a9bdd85SOleksandr "Alex" Zinenko if (defOp != v.getDefiningOp()) 12125a9bdd85SOleksandr "Alex" Zinenko return nullptr; 12135a9bdd85SOleksandr "Alex" Zinenko } 12145a9bdd85SOleksandr "Alex" Zinenko return defOp; 12155a9bdd85SOleksandr "Alex" Zinenko } 12165a9bdd85SOleksandr "Alex" Zinenko 12175a9bdd85SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp( 12185a9bdd85SOleksandr "Alex" Zinenko Operation *&result, Operation *op, ValueRange newValues) const { 12195a9bdd85SOleksandr "Alex" Zinenko assert(op->getNumResults() == newValues.size() && 12205a9bdd85SOleksandr "Alex" Zinenko "invalid number of replacement values"); 12215a9bdd85SOleksandr "Alex" Zinenko SmallVector<Value> values(newValues.begin(), newValues.end()); 12225a9bdd85SOleksandr "Alex" Zinenko 12235a9bdd85SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure diag = emitSilenceableFailure( 12245a9bdd85SOleksandr "Alex" Zinenko getTransformOp(), "tracking listener failed to find replacement op " 12255a9bdd85SOleksandr "Alex" Zinenko "during application of this transform op"); 12265a9bdd85SOleksandr "Alex" Zinenko 12275a9bdd85SOleksandr "Alex" Zinenko do { 12285a9bdd85SOleksandr "Alex" Zinenko // If the replacement values belong to different ops, drop the mapping. 12295a9bdd85SOleksandr "Alex" Zinenko Operation *defOp = getCommonDefiningOp(values); 12305a9bdd85SOleksandr "Alex" Zinenko if (!defOp) { 12315a9bdd85SOleksandr "Alex" Zinenko diag.attachNote() << "replacement values belong to different ops"; 12325a9bdd85SOleksandr "Alex" Zinenko return diag; 12335a9bdd85SOleksandr "Alex" Zinenko } 12345a9bdd85SOleksandr "Alex" Zinenko 12355a9bdd85SOleksandr "Alex" Zinenko // Skip through ops that implement CastOpInterface. 12365a9bdd85SOleksandr "Alex" Zinenko if (config.skipCastOps && isa<CastOpInterface>(defOp)) { 12375a9bdd85SOleksandr "Alex" Zinenko values.clear(); 12385a9bdd85SOleksandr "Alex" Zinenko values.assign(defOp->getOperands().begin(), defOp->getOperands().end()); 12395a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(defOp->getLoc()) 12405a9bdd85SOleksandr "Alex" Zinenko << "using output of 'CastOpInterface' op"; 12415a9bdd85SOleksandr "Alex" Zinenko continue; 12425a9bdd85SOleksandr "Alex" Zinenko } 12435a9bdd85SOleksandr "Alex" Zinenko 12445a9bdd85SOleksandr "Alex" Zinenko // If the defining op has the same name or we do not care about the name of 12455a9bdd85SOleksandr "Alex" Zinenko // op replacements at all, we take it as a replacement. 12465a9bdd85SOleksandr "Alex" Zinenko if (!config.requireMatchingReplacementOpName || 12475a9bdd85SOleksandr "Alex" Zinenko op->getName() == defOp->getName()) { 12485a9bdd85SOleksandr "Alex" Zinenko result = defOp; 12495a9bdd85SOleksandr "Alex" Zinenko return DiagnosedSilenceableFailure::success(); 12505a9bdd85SOleksandr "Alex" Zinenko } 12515a9bdd85SOleksandr "Alex" Zinenko 12525a9bdd85SOleksandr "Alex" Zinenko // Replacing an op with a constant-like equivalent is a common 12535a9bdd85SOleksandr "Alex" Zinenko // canonicalization. 12545a9bdd85SOleksandr "Alex" Zinenko if (defOp->hasTrait<OpTrait::ConstantLike>()) { 12555a9bdd85SOleksandr "Alex" Zinenko result = defOp; 12565a9bdd85SOleksandr "Alex" Zinenko return DiagnosedSilenceableFailure::success(); 12575a9bdd85SOleksandr "Alex" Zinenko } 12585a9bdd85SOleksandr "Alex" Zinenko 12595a9bdd85SOleksandr "Alex" Zinenko values.clear(); 12605a9bdd85SOleksandr "Alex" Zinenko 12615a9bdd85SOleksandr "Alex" Zinenko // Skip through ops that implement FindPayloadReplacementOpInterface. 12625a9bdd85SOleksandr "Alex" Zinenko if (auto findReplacementOpInterface = 12635a9bdd85SOleksandr "Alex" Zinenko dyn_cast<FindPayloadReplacementOpInterface>(defOp)) { 12645a9bdd85SOleksandr "Alex" Zinenko values.assign(findReplacementOpInterface.getNextOperands()); 12655a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(defOp->getLoc()) << "using operands provided by " 12665a9bdd85SOleksandr "Alex" Zinenko "'FindPayloadReplacementOpInterface'"; 12675a9bdd85SOleksandr "Alex" Zinenko continue; 12685a9bdd85SOleksandr "Alex" Zinenko } 12695a9bdd85SOleksandr "Alex" Zinenko } while (!values.empty()); 12705a9bdd85SOleksandr "Alex" Zinenko 12715a9bdd85SOleksandr "Alex" Zinenko diag.attachNote() << "ran out of suitable replacement values"; 12725a9bdd85SOleksandr "Alex" Zinenko return diag; 12735a9bdd85SOleksandr "Alex" Zinenko } 12745a9bdd85SOleksandr "Alex" Zinenko 12755a9bdd85SOleksandr "Alex" Zinenko void transform::TrackingListener::notifyMatchFailure( 12765a9bdd85SOleksandr "Alex" Zinenko Location loc, function_ref<void(Diagnostic &)> reasonCallback) { 12775a9bdd85SOleksandr "Alex" Zinenko LLVM_DEBUG({ 12785a9bdd85SOleksandr "Alex" Zinenko Diagnostic diag(loc, DiagnosticSeverity::Remark); 12795a9bdd85SOleksandr "Alex" Zinenko reasonCallback(diag); 12805a9bdd85SOleksandr "Alex" Zinenko DBGS() << "Match Failure : " << diag.str() << "\n"; 12815a9bdd85SOleksandr "Alex" Zinenko }); 12825a9bdd85SOleksandr "Alex" Zinenko } 12835a9bdd85SOleksandr "Alex" Zinenko 12845a9bdd85SOleksandr "Alex" Zinenko void transform::TrackingListener::notifyOperationErased(Operation *op) { 12855a9bdd85SOleksandr "Alex" Zinenko // Remove mappings for result values. 12865a9bdd85SOleksandr "Alex" Zinenko for (OpResult value : op->getResults()) 12875a9bdd85SOleksandr "Alex" Zinenko (void)replacePayloadValue(value, nullptr); 12885a9bdd85SOleksandr "Alex" Zinenko // Remove mapping for op. 12895a9bdd85SOleksandr "Alex" Zinenko (void)replacePayloadOp(op, nullptr); 12905a9bdd85SOleksandr "Alex" Zinenko } 12915a9bdd85SOleksandr "Alex" Zinenko 12925a9bdd85SOleksandr "Alex" Zinenko void transform::TrackingListener::notifyOperationReplaced( 12935a9bdd85SOleksandr "Alex" Zinenko Operation *op, ValueRange newValues) { 12945a9bdd85SOleksandr "Alex" Zinenko assert(op->getNumResults() == newValues.size() && 12955a9bdd85SOleksandr "Alex" Zinenko "invalid number of replacement values"); 12965a9bdd85SOleksandr "Alex" Zinenko 12975a9bdd85SOleksandr "Alex" Zinenko // Replace value handles. 12985a9bdd85SOleksandr "Alex" Zinenko for (auto [oldValue, newValue] : llvm::zip(op->getResults(), newValues)) 12995a9bdd85SOleksandr "Alex" Zinenko (void)replacePayloadValue(oldValue, newValue); 13005a9bdd85SOleksandr "Alex" Zinenko 13015a9bdd85SOleksandr "Alex" Zinenko // Replace op handle. 13025a9bdd85SOleksandr "Alex" Zinenko SmallVector<Value> opHandles; 13035a9bdd85SOleksandr "Alex" Zinenko if (failed(getTransformState().getHandlesForPayloadOp( 13045a9bdd85SOleksandr "Alex" Zinenko op, opHandles, /*includeOutOfScope=*/true))) { 13055a9bdd85SOleksandr "Alex" Zinenko // Op is not tracked. 13065a9bdd85SOleksandr "Alex" Zinenko return; 13075a9bdd85SOleksandr "Alex" Zinenko } 13085a9bdd85SOleksandr "Alex" Zinenko 13095a9bdd85SOleksandr "Alex" Zinenko // Helper function to check if the current transform op consumes any handle 13105a9bdd85SOleksandr "Alex" Zinenko // that is mapped to `op`. 13115a9bdd85SOleksandr "Alex" Zinenko // 13125a9bdd85SOleksandr "Alex" Zinenko // Note: If a handle was consumed, there shouldn't be any alive users, so it 13135a9bdd85SOleksandr "Alex" Zinenko // is not really necessary to check for consumed handles. However, in case 13145a9bdd85SOleksandr "Alex" Zinenko // there are indeed alive handles that were consumed (which is undefined 13155a9bdd85SOleksandr "Alex" Zinenko // behavior) and a replacement op could not be found, we want to fail with a 13165a9bdd85SOleksandr "Alex" Zinenko // nicer error message: "op uses a handle invalidated..." instead of "could 13175a9bdd85SOleksandr "Alex" Zinenko // not find replacement op". This nicer error is produced later. 13185a9bdd85SOleksandr "Alex" Zinenko auto handleWasConsumed = [&] { 13195a9bdd85SOleksandr "Alex" Zinenko return llvm::any_of(opHandles, 13205a9bdd85SOleksandr "Alex" Zinenko [&](Value h) { return consumedHandles.contains(h); }); 13215a9bdd85SOleksandr "Alex" Zinenko }; 13225a9bdd85SOleksandr "Alex" Zinenko 13235a9bdd85SOleksandr "Alex" Zinenko // Check if there are any handles that must be updated. 13245a9bdd85SOleksandr "Alex" Zinenko Value aliveHandle; 13255a9bdd85SOleksandr "Alex" Zinenko if (config.skipHandleFn) { 13265a9bdd85SOleksandr "Alex" Zinenko auto it = llvm::find_if(opHandles, 13275a9bdd85SOleksandr "Alex" Zinenko [&](Value v) { return !config.skipHandleFn(v); }); 13285a9bdd85SOleksandr "Alex" Zinenko if (it != opHandles.end()) 13295a9bdd85SOleksandr "Alex" Zinenko aliveHandle = *it; 13305a9bdd85SOleksandr "Alex" Zinenko } else if (!opHandles.empty()) { 13315a9bdd85SOleksandr "Alex" Zinenko aliveHandle = opHandles.front(); 13325a9bdd85SOleksandr "Alex" Zinenko } 13335a9bdd85SOleksandr "Alex" Zinenko if (!aliveHandle || handleWasConsumed()) { 13345a9bdd85SOleksandr "Alex" Zinenko // The op is tracked but the corresponding handles are dead or were 13355a9bdd85SOleksandr "Alex" Zinenko // consumed. Drop the op form the mapping. 13365a9bdd85SOleksandr "Alex" Zinenko (void)replacePayloadOp(op, nullptr); 13375a9bdd85SOleksandr "Alex" Zinenko return; 13385a9bdd85SOleksandr "Alex" Zinenko } 13395a9bdd85SOleksandr "Alex" Zinenko 13405a9bdd85SOleksandr "Alex" Zinenko Operation *replacement; 13415a9bdd85SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure diag = 13425a9bdd85SOleksandr "Alex" Zinenko findReplacementOp(replacement, op, newValues); 13435a9bdd85SOleksandr "Alex" Zinenko // If the op is tracked but no replacement op was found, send a 13445a9bdd85SOleksandr "Alex" Zinenko // notification. 13455a9bdd85SOleksandr "Alex" Zinenko if (!diag.succeeded()) { 13465a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(aliveHandle.getLoc()) 13475a9bdd85SOleksandr "Alex" Zinenko << "replacement is required because this handle must be updated"; 13485a9bdd85SOleksandr "Alex" Zinenko notifyPayloadReplacementNotFound(op, newValues, std::move(diag)); 13495a9bdd85SOleksandr "Alex" Zinenko (void)replacePayloadOp(op, nullptr); 13505a9bdd85SOleksandr "Alex" Zinenko return; 13515a9bdd85SOleksandr "Alex" Zinenko } 13525a9bdd85SOleksandr "Alex" Zinenko 13535a9bdd85SOleksandr "Alex" Zinenko (void)replacePayloadOp(op, replacement); 13545a9bdd85SOleksandr "Alex" Zinenko } 13555a9bdd85SOleksandr "Alex" Zinenko 13565a9bdd85SOleksandr "Alex" Zinenko transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() { 13575a9bdd85SOleksandr "Alex" Zinenko // The state of the ErrorCheckingTrackingListener must be checked and reset 13585a9bdd85SOleksandr "Alex" Zinenko // if there was an error. This is to prevent errors from accidentally being 13595a9bdd85SOleksandr "Alex" Zinenko // missed. 13605a9bdd85SOleksandr "Alex" Zinenko assert(status.succeeded() && "listener state was not checked"); 13615a9bdd85SOleksandr "Alex" Zinenko } 13625a9bdd85SOleksandr "Alex" Zinenko 13635a9bdd85SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure 13645a9bdd85SOleksandr "Alex" Zinenko transform::ErrorCheckingTrackingListener::checkAndResetError() { 13655a9bdd85SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure s = std::move(status); 13665a9bdd85SOleksandr "Alex" Zinenko status = DiagnosedSilenceableFailure::success(); 13675a9bdd85SOleksandr "Alex" Zinenko errorCounter = 0; 13685a9bdd85SOleksandr "Alex" Zinenko return s; 13695a9bdd85SOleksandr "Alex" Zinenko } 13705a9bdd85SOleksandr "Alex" Zinenko 13715a9bdd85SOleksandr "Alex" Zinenko bool transform::ErrorCheckingTrackingListener::failed() const { 13725a9bdd85SOleksandr "Alex" Zinenko return !status.succeeded(); 13735a9bdd85SOleksandr "Alex" Zinenko } 13745a9bdd85SOleksandr "Alex" Zinenko 13755a9bdd85SOleksandr "Alex" Zinenko void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound( 13765a9bdd85SOleksandr "Alex" Zinenko Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) { 13775a9bdd85SOleksandr "Alex" Zinenko 13785a9bdd85SOleksandr "Alex" Zinenko // Merge potentially existing diags and store the result in the listener. 13795a9bdd85SOleksandr "Alex" Zinenko SmallVector<Diagnostic> diags; 13805a9bdd85SOleksandr "Alex" Zinenko diag.takeDiagnostics(diags); 13815a9bdd85SOleksandr "Alex" Zinenko if (!status.succeeded()) 13825a9bdd85SOleksandr "Alex" Zinenko status.takeDiagnostics(diags); 13835a9bdd85SOleksandr "Alex" Zinenko status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags)); 13845a9bdd85SOleksandr "Alex" Zinenko 13855a9bdd85SOleksandr "Alex" Zinenko // Report more details. 13865a9bdd85SOleksandr "Alex" Zinenko status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op"; 13875a9bdd85SOleksandr "Alex" Zinenko for (auto &&[index, value] : llvm::enumerate(values)) 13885a9bdd85SOleksandr "Alex" Zinenko status.attachNote(value.getLoc()) 13895a9bdd85SOleksandr "Alex" Zinenko << "[" << errorCounter << "] replacement value " << index; 13905a9bdd85SOleksandr "Alex" Zinenko ++errorCounter; 13915a9bdd85SOleksandr "Alex" Zinenko } 13925a9bdd85SOleksandr "Alex" Zinenko 13935a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 13945a9bdd85SOleksandr "Alex" Zinenko // TransformRewriter 13955a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 13965a9bdd85SOleksandr "Alex" Zinenko 13975a9bdd85SOleksandr "Alex" Zinenko transform::TransformRewriter::TransformRewriter( 13985a9bdd85SOleksandr "Alex" Zinenko MLIRContext *ctx, ErrorCheckingTrackingListener *listener) 13995a9bdd85SOleksandr "Alex" Zinenko : RewriterBase(ctx), listener(listener) { 14005a9bdd85SOleksandr "Alex" Zinenko setListener(listener); 14015a9bdd85SOleksandr "Alex" Zinenko } 14025a9bdd85SOleksandr "Alex" Zinenko 14035a9bdd85SOleksandr "Alex" Zinenko bool transform::TransformRewriter::hasTrackingFailures() const { 14045a9bdd85SOleksandr "Alex" Zinenko return listener->failed(); 14055a9bdd85SOleksandr "Alex" Zinenko } 14065a9bdd85SOleksandr "Alex" Zinenko 14075a9bdd85SOleksandr "Alex" Zinenko /// Silence all tracking failures that have been encountered so far. 14085a9bdd85SOleksandr "Alex" Zinenko void transform::TransformRewriter::silenceTrackingFailure() { 14095a9bdd85SOleksandr "Alex" Zinenko if (hasTrackingFailures()) { 14105a9bdd85SOleksandr "Alex" Zinenko DiagnosedSilenceableFailure status = listener->checkAndResetError(); 14115a9bdd85SOleksandr "Alex" Zinenko (void)status.silence(); 14125a9bdd85SOleksandr "Alex" Zinenko } 14135a9bdd85SOleksandr "Alex" Zinenko } 14145a9bdd85SOleksandr "Alex" Zinenko 14155a9bdd85SOleksandr "Alex" Zinenko LogicalResult transform::TransformRewriter::notifyPayloadOperationReplaced( 14165a9bdd85SOleksandr "Alex" Zinenko Operation *op, Operation *replacement) { 14175a9bdd85SOleksandr "Alex" Zinenko return listener->replacePayloadOp(op, replacement); 14185a9bdd85SOleksandr "Alex" Zinenko } 14195a9bdd85SOleksandr "Alex" Zinenko 14205a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 14215a9bdd85SOleksandr "Alex" Zinenko // Utilities for TransformEachOpTrait. 14225a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 14235a9bdd85SOleksandr "Alex" Zinenko 14245a9bdd85SOleksandr "Alex" Zinenko LogicalResult 14255a9bdd85SOleksandr "Alex" Zinenko transform::detail::checkNestedConsumption(Location loc, 14265a9bdd85SOleksandr "Alex" Zinenko ArrayRef<Operation *> targets) { 14275a9bdd85SOleksandr "Alex" Zinenko for (auto &&[position, parent] : llvm::enumerate(targets)) { 14285a9bdd85SOleksandr "Alex" Zinenko for (Operation *child : targets.drop_front(position + 1)) { 14295a9bdd85SOleksandr "Alex" Zinenko if (parent->isAncestor(child)) { 14305a9bdd85SOleksandr "Alex" Zinenko InFlightDiagnostic diag = 14315a9bdd85SOleksandr "Alex" Zinenko emitError(loc) 14325a9bdd85SOleksandr "Alex" Zinenko << "transform operation consumes a handle pointing to an ancestor " 14335a9bdd85SOleksandr "Alex" Zinenko "payload operation before its descendant"; 14345a9bdd85SOleksandr "Alex" Zinenko diag.attachNote() 14355a9bdd85SOleksandr "Alex" Zinenko << "the ancestor is likely erased or rewritten before the " 14365a9bdd85SOleksandr "Alex" Zinenko "descendant is accessed, leading to undefined behavior"; 14375a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(parent->getLoc()) << "ancestor payload op"; 14385a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(child->getLoc()) << "descendant payload op"; 14395a9bdd85SOleksandr "Alex" Zinenko return diag; 14405a9bdd85SOleksandr "Alex" Zinenko } 14415a9bdd85SOleksandr "Alex" Zinenko } 14425a9bdd85SOleksandr "Alex" Zinenko } 14435a9bdd85SOleksandr "Alex" Zinenko return success(); 14445a9bdd85SOleksandr "Alex" Zinenko } 14455a9bdd85SOleksandr "Alex" Zinenko 14465a9bdd85SOleksandr "Alex" Zinenko LogicalResult 14475a9bdd85SOleksandr "Alex" Zinenko transform::detail::checkApplyToOne(Operation *transformOp, 14485a9bdd85SOleksandr "Alex" Zinenko Location payloadOpLoc, 14495a9bdd85SOleksandr "Alex" Zinenko const ApplyToEachResultList &partialResult) { 14505a9bdd85SOleksandr "Alex" Zinenko Location transformOpLoc = transformOp->getLoc(); 14515a9bdd85SOleksandr "Alex" Zinenko StringRef transformOpName = transformOp->getName().getStringRef(); 14525a9bdd85SOleksandr "Alex" Zinenko unsigned expectedNumResults = transformOp->getNumResults(); 14535a9bdd85SOleksandr "Alex" Zinenko 14545a9bdd85SOleksandr "Alex" Zinenko // Reuse the emission of the diagnostic note. 14555a9bdd85SOleksandr "Alex" Zinenko auto emitDiag = [&]() { 14565a9bdd85SOleksandr "Alex" Zinenko auto diag = mlir::emitError(transformOpLoc); 14575a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(payloadOpLoc) << "when applied to this op"; 14585a9bdd85SOleksandr "Alex" Zinenko return diag; 14595a9bdd85SOleksandr "Alex" Zinenko }; 14605a9bdd85SOleksandr "Alex" Zinenko 14615a9bdd85SOleksandr "Alex" Zinenko if (partialResult.size() != expectedNumResults) { 14625a9bdd85SOleksandr "Alex" Zinenko auto diag = emitDiag() << "application of " << transformOpName 14635a9bdd85SOleksandr "Alex" Zinenko << " expected to produce " << expectedNumResults 14645a9bdd85SOleksandr "Alex" Zinenko << " results (actually produced " 14655a9bdd85SOleksandr "Alex" Zinenko << partialResult.size() << ")."; 14665a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(transformOpLoc) 14675a9bdd85SOleksandr "Alex" Zinenko << "if you need variadic results, consider a generic `apply` " 14685a9bdd85SOleksandr "Alex" Zinenko << "instead of the specialized `applyToOne`."; 14695a9bdd85SOleksandr "Alex" Zinenko return failure(); 14705a9bdd85SOleksandr "Alex" Zinenko } 14715a9bdd85SOleksandr "Alex" Zinenko 14725a9bdd85SOleksandr "Alex" Zinenko // Check that the right kind of value was produced. 14735a9bdd85SOleksandr "Alex" Zinenko for (const auto &[ptr, res] : 14745a9bdd85SOleksandr "Alex" Zinenko llvm::zip(partialResult, transformOp->getResults())) { 14755a9bdd85SOleksandr "Alex" Zinenko if (ptr.isNull()) 14765a9bdd85SOleksandr "Alex" Zinenko continue; 14775a9bdd85SOleksandr "Alex" Zinenko if (llvm::isa<TransformHandleTypeInterface>(res.getType()) && 1478*129f1001SKazu Hirata !isa<Operation *>(ptr)) { 14795a9bdd85SOleksandr "Alex" Zinenko return emitDiag() << "application of " << transformOpName 14805a9bdd85SOleksandr "Alex" Zinenko << " expected to produce an Operation * for result #" 14815a9bdd85SOleksandr "Alex" Zinenko << res.getResultNumber(); 14825a9bdd85SOleksandr "Alex" Zinenko } 14835a9bdd85SOleksandr "Alex" Zinenko if (llvm::isa<TransformParamTypeInterface>(res.getType()) && 1484*129f1001SKazu Hirata !isa<Attribute>(ptr)) { 14855a9bdd85SOleksandr "Alex" Zinenko return emitDiag() << "application of " << transformOpName 14865a9bdd85SOleksandr "Alex" Zinenko << " expected to produce an Attribute for result #" 14875a9bdd85SOleksandr "Alex" Zinenko << res.getResultNumber(); 14885a9bdd85SOleksandr "Alex" Zinenko } 14895a9bdd85SOleksandr "Alex" Zinenko if (llvm::isa<TransformValueHandleTypeInterface>(res.getType()) && 1490*129f1001SKazu Hirata !isa<Value>(ptr)) { 14915a9bdd85SOleksandr "Alex" Zinenko return emitDiag() << "application of " << transformOpName 14925a9bdd85SOleksandr "Alex" Zinenko << " expected to produce a Value for result #" 14935a9bdd85SOleksandr "Alex" Zinenko << res.getResultNumber(); 14945a9bdd85SOleksandr "Alex" Zinenko } 14955a9bdd85SOleksandr "Alex" Zinenko } 14965a9bdd85SOleksandr "Alex" Zinenko return success(); 14975a9bdd85SOleksandr "Alex" Zinenko } 14985a9bdd85SOleksandr "Alex" Zinenko 14995a9bdd85SOleksandr "Alex" Zinenko template <typename T> 15005a9bdd85SOleksandr "Alex" Zinenko static SmallVector<T> castVector(ArrayRef<transform::MappedValue> range) { 15015a9bdd85SOleksandr "Alex" Zinenko return llvm::to_vector(llvm::map_range( 1502*129f1001SKazu Hirata range, [](transform::MappedValue value) { return cast<T>(value); })); 15035a9bdd85SOleksandr "Alex" Zinenko } 15045a9bdd85SOleksandr "Alex" Zinenko 15055a9bdd85SOleksandr "Alex" Zinenko void transform::detail::setApplyToOneResults( 15065a9bdd85SOleksandr "Alex" Zinenko Operation *transformOp, TransformResults &transformResults, 15075a9bdd85SOleksandr "Alex" Zinenko ArrayRef<ApplyToEachResultList> results) { 15085a9bdd85SOleksandr "Alex" Zinenko SmallVector<SmallVector<MappedValue>> transposed; 15095a9bdd85SOleksandr "Alex" Zinenko transposed.resize(transformOp->getNumResults()); 15105a9bdd85SOleksandr "Alex" Zinenko for (const ApplyToEachResultList &partialResults : results) { 15115a9bdd85SOleksandr "Alex" Zinenko if (llvm::any_of(partialResults, 15125a9bdd85SOleksandr "Alex" Zinenko [](MappedValue value) { return value.isNull(); })) 15135a9bdd85SOleksandr "Alex" Zinenko continue; 15145a9bdd85SOleksandr "Alex" Zinenko assert(transformOp->getNumResults() == partialResults.size() && 15155a9bdd85SOleksandr "Alex" Zinenko "expected as many partial results as op as results"); 15165a9bdd85SOleksandr "Alex" Zinenko for (auto [i, value] : llvm::enumerate(partialResults)) 15175a9bdd85SOleksandr "Alex" Zinenko transposed[i].push_back(value); 15185a9bdd85SOleksandr "Alex" Zinenko } 15195a9bdd85SOleksandr "Alex" Zinenko 15205a9bdd85SOleksandr "Alex" Zinenko for (OpResult r : transformOp->getResults()) { 15215a9bdd85SOleksandr "Alex" Zinenko unsigned position = r.getResultNumber(); 15225a9bdd85SOleksandr "Alex" Zinenko if (llvm::isa<TransformParamTypeInterface>(r.getType())) { 15235a9bdd85SOleksandr "Alex" Zinenko transformResults.setParams(r, 15245a9bdd85SOleksandr "Alex" Zinenko castVector<Attribute>(transposed[position])); 15255a9bdd85SOleksandr "Alex" Zinenko } else if (llvm::isa<TransformValueHandleTypeInterface>(r.getType())) { 15265a9bdd85SOleksandr "Alex" Zinenko transformResults.setValues(r, castVector<Value>(transposed[position])); 15275a9bdd85SOleksandr "Alex" Zinenko } else { 15285a9bdd85SOleksandr "Alex" Zinenko transformResults.set(r, castVector<Operation *>(transposed[position])); 15295a9bdd85SOleksandr "Alex" Zinenko } 15305a9bdd85SOleksandr "Alex" Zinenko } 15315a9bdd85SOleksandr "Alex" Zinenko } 15325a9bdd85SOleksandr "Alex" Zinenko 15335a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 15345a9bdd85SOleksandr "Alex" Zinenko // Utilities for implementing transform ops with regions. 15355a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 15365a9bdd85SOleksandr "Alex" Zinenko 1537e4b04b39SOleksandr "Alex" Zinenko LogicalResult transform::detail::appendValueMappings( 1538e4b04b39SOleksandr "Alex" Zinenko MutableArrayRef<SmallVector<transform::MappedValue>> mappings, 1539e4b04b39SOleksandr "Alex" Zinenko ValueRange values, const transform::TransformState &state, bool flatten) { 1540e4b04b39SOleksandr "Alex" Zinenko assert(mappings.size() == values.size() && "mismatching number of mappings"); 1541e4b04b39SOleksandr "Alex" Zinenko for (auto &&[operand, mapped] : llvm::zip_equal(values, mappings)) { 1542e4b04b39SOleksandr "Alex" Zinenko size_t mappedSize = mapped.size(); 15435a9bdd85SOleksandr "Alex" Zinenko if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) { 15445a9bdd85SOleksandr "Alex" Zinenko llvm::append_range(mapped, state.getPayloadOps(operand)); 15455a9bdd85SOleksandr "Alex" Zinenko } else if (llvm::isa<TransformValueHandleTypeInterface>( 15465a9bdd85SOleksandr "Alex" Zinenko operand.getType())) { 15475a9bdd85SOleksandr "Alex" Zinenko llvm::append_range(mapped, state.getPayloadValues(operand)); 15485a9bdd85SOleksandr "Alex" Zinenko } else { 15495a9bdd85SOleksandr "Alex" Zinenko assert(llvm::isa<TransformParamTypeInterface>(operand.getType()) && 15505a9bdd85SOleksandr "Alex" Zinenko "unsupported kind of transform dialect value"); 15515a9bdd85SOleksandr "Alex" Zinenko llvm::append_range(mapped, state.getParams(operand)); 15525a9bdd85SOleksandr "Alex" Zinenko } 1553e4b04b39SOleksandr "Alex" Zinenko 1554e4b04b39SOleksandr "Alex" Zinenko if (mapped.size() - mappedSize != 1 && !flatten) 1555e4b04b39SOleksandr "Alex" Zinenko return failure(); 15565a9bdd85SOleksandr "Alex" Zinenko } 1557e4b04b39SOleksandr "Alex" Zinenko return success(); 1558e4b04b39SOleksandr "Alex" Zinenko } 1559e4b04b39SOleksandr "Alex" Zinenko 1560e4b04b39SOleksandr "Alex" Zinenko void transform::detail::prepareValueMappings( 1561e4b04b39SOleksandr "Alex" Zinenko SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings, 1562e4b04b39SOleksandr "Alex" Zinenko ValueRange values, const transform::TransformState &state) { 1563e4b04b39SOleksandr "Alex" Zinenko mappings.resize(mappings.size() + values.size()); 1564e4b04b39SOleksandr "Alex" Zinenko (void)appendValueMappings( 1565e4b04b39SOleksandr "Alex" Zinenko MutableArrayRef<SmallVector<transform::MappedValue>>(mappings).take_back( 1566e4b04b39SOleksandr "Alex" Zinenko values.size()), 1567e4b04b39SOleksandr "Alex" Zinenko values, state); 15685a9bdd85SOleksandr "Alex" Zinenko } 15695a9bdd85SOleksandr "Alex" Zinenko 15705a9bdd85SOleksandr "Alex" Zinenko void transform::detail::forwardTerminatorOperands( 15715a9bdd85SOleksandr "Alex" Zinenko Block *block, transform::TransformState &state, 15725a9bdd85SOleksandr "Alex" Zinenko transform::TransformResults &results) { 15735a9bdd85SOleksandr "Alex" Zinenko for (auto &&[terminatorOperand, result] : 15745a9bdd85SOleksandr "Alex" Zinenko llvm::zip(block->getTerminator()->getOperands(), 15755a9bdd85SOleksandr "Alex" Zinenko block->getParentOp()->getOpResults())) { 15765a9bdd85SOleksandr "Alex" Zinenko if (llvm::isa<transform::TransformHandleTypeInterface>(result.getType())) { 15775a9bdd85SOleksandr "Alex" Zinenko results.set(result, state.getPayloadOps(terminatorOperand)); 15785a9bdd85SOleksandr "Alex" Zinenko } else if (llvm::isa<transform::TransformValueHandleTypeInterface>( 15795a9bdd85SOleksandr "Alex" Zinenko result.getType())) { 15805a9bdd85SOleksandr "Alex" Zinenko results.setValues(result, state.getPayloadValues(terminatorOperand)); 15815a9bdd85SOleksandr "Alex" Zinenko } else { 15825a9bdd85SOleksandr "Alex" Zinenko assert( 15835a9bdd85SOleksandr "Alex" Zinenko llvm::isa<transform::TransformParamTypeInterface>(result.getType()) && 15845a9bdd85SOleksandr "Alex" Zinenko "unhandled transform type interface"); 15855a9bdd85SOleksandr "Alex" Zinenko results.setParams(result, state.getParams(terminatorOperand)); 15865a9bdd85SOleksandr "Alex" Zinenko } 15875a9bdd85SOleksandr "Alex" Zinenko } 15885a9bdd85SOleksandr "Alex" Zinenko } 15895a9bdd85SOleksandr "Alex" Zinenko 15905a9bdd85SOleksandr "Alex" Zinenko transform::TransformState 15915a9bdd85SOleksandr "Alex" Zinenko transform::detail::makeTransformStateForTesting(Region *region, 15925a9bdd85SOleksandr "Alex" Zinenko Operation *payloadRoot) { 15935a9bdd85SOleksandr "Alex" Zinenko return TransformState(region, payloadRoot); 15945a9bdd85SOleksandr "Alex" Zinenko } 15955a9bdd85SOleksandr "Alex" Zinenko 15965a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 15975a9bdd85SOleksandr "Alex" Zinenko // Utilities for PossibleTopLevelTransformOpTrait. 15985a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 15995a9bdd85SOleksandr "Alex" Zinenko 16005a9bdd85SOleksandr "Alex" Zinenko /// Appends to `effects` the memory effect instances on `target` with the same 16015a9bdd85SOleksandr "Alex" Zinenko /// resource and effect as the ones the operation `iface` having on `source`. 16025a9bdd85SOleksandr "Alex" Zinenko static void 16032c1ae801Sdonald chen remapEffects(MemoryEffectOpInterface iface, BlockArgument source, 16042c1ae801Sdonald chen OpOperand *target, 16055a9bdd85SOleksandr "Alex" Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 16065a9bdd85SOleksandr "Alex" Zinenko SmallVector<MemoryEffects::EffectInstance> nestedEffects; 16075a9bdd85SOleksandr "Alex" Zinenko iface.getEffectsOnValue(source, nestedEffects); 16085a9bdd85SOleksandr "Alex" Zinenko for (const auto &effect : nestedEffects) 16095a9bdd85SOleksandr "Alex" Zinenko effects.emplace_back(effect.getEffect(), target, effect.getResource()); 16105a9bdd85SOleksandr "Alex" Zinenko } 16115a9bdd85SOleksandr "Alex" Zinenko 16125a9bdd85SOleksandr "Alex" Zinenko /// Appends to `effects` the same effects as the operations of `block` have on 16135a9bdd85SOleksandr "Alex" Zinenko /// block arguments but associated with `operands.` 16145a9bdd85SOleksandr "Alex" Zinenko static void 16152c1ae801Sdonald chen remapArgumentEffects(Block &block, MutableArrayRef<OpOperand> operands, 16165a9bdd85SOleksandr "Alex" Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 16175a9bdd85SOleksandr "Alex" Zinenko for (Operation &op : block) { 16185a9bdd85SOleksandr "Alex" Zinenko auto iface = dyn_cast<MemoryEffectOpInterface>(&op); 16195a9bdd85SOleksandr "Alex" Zinenko if (!iface) 16205a9bdd85SOleksandr "Alex" Zinenko continue; 16215a9bdd85SOleksandr "Alex" Zinenko 16225a9bdd85SOleksandr "Alex" Zinenko for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) { 16232c1ae801Sdonald chen remapEffects(iface, source, &target, effects); 16245a9bdd85SOleksandr "Alex" Zinenko } 16255a9bdd85SOleksandr "Alex" Zinenko 16265a9bdd85SOleksandr "Alex" Zinenko SmallVector<MemoryEffects::EffectInstance> nestedEffects; 16275a9bdd85SOleksandr "Alex" Zinenko iface.getEffectsOnResource(transform::PayloadIRResource::get(), 16285a9bdd85SOleksandr "Alex" Zinenko nestedEffects); 16295a9bdd85SOleksandr "Alex" Zinenko llvm::append_range(effects, nestedEffects); 16305a9bdd85SOleksandr "Alex" Zinenko } 16315a9bdd85SOleksandr "Alex" Zinenko } 16325a9bdd85SOleksandr "Alex" Zinenko 16335a9bdd85SOleksandr "Alex" Zinenko void transform::detail::getPotentialTopLevelEffects( 16345a9bdd85SOleksandr "Alex" Zinenko Operation *operation, Value root, Block &body, 16355a9bdd85SOleksandr "Alex" Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 16362c1ae801Sdonald chen transform::onlyReadsHandle(operation->getOpOperands(), effects); 16372c1ae801Sdonald chen transform::producesHandle(operation->getOpResults(), effects); 16385a9bdd85SOleksandr "Alex" Zinenko 16395a9bdd85SOleksandr "Alex" Zinenko if (!root) { 16405a9bdd85SOleksandr "Alex" Zinenko for (Operation &op : body) { 16415a9bdd85SOleksandr "Alex" Zinenko auto iface = dyn_cast<MemoryEffectOpInterface>(&op); 16425a9bdd85SOleksandr "Alex" Zinenko if (!iface) 16435a9bdd85SOleksandr "Alex" Zinenko continue; 16445a9bdd85SOleksandr "Alex" Zinenko 16455a9bdd85SOleksandr "Alex" Zinenko SmallVector<MemoryEffects::EffectInstance, 2> nestedEffects; 16465a9bdd85SOleksandr "Alex" Zinenko iface.getEffects(effects); 16475a9bdd85SOleksandr "Alex" Zinenko } 16485a9bdd85SOleksandr "Alex" Zinenko return; 16495a9bdd85SOleksandr "Alex" Zinenko } 16505a9bdd85SOleksandr "Alex" Zinenko 16515a9bdd85SOleksandr "Alex" Zinenko // Carry over all effects on arguments of the entry block as those on the 16525a9bdd85SOleksandr "Alex" Zinenko // operands, this is the same value just remapped. 16532c1ae801Sdonald chen remapArgumentEffects(body, operation->getOpOperands(), effects); 16545a9bdd85SOleksandr "Alex" Zinenko } 16555a9bdd85SOleksandr "Alex" Zinenko 16565a9bdd85SOleksandr "Alex" Zinenko LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( 16575a9bdd85SOleksandr "Alex" Zinenko TransformState &state, Operation *op, Region ®ion) { 16585a9bdd85SOleksandr "Alex" Zinenko SmallVector<Operation *> targets; 16595a9bdd85SOleksandr "Alex" Zinenko SmallVector<SmallVector<MappedValue>> extraMappings; 16605a9bdd85SOleksandr "Alex" Zinenko if (op->getNumOperands() != 0) { 16615a9bdd85SOleksandr "Alex" Zinenko llvm::append_range(targets, state.getPayloadOps(op->getOperand(0))); 16625a9bdd85SOleksandr "Alex" Zinenko prepareValueMappings(extraMappings, op->getOperands().drop_front(), state); 16635a9bdd85SOleksandr "Alex" Zinenko } else { 16645a9bdd85SOleksandr "Alex" Zinenko if (state.getNumTopLevelMappings() != 16655a9bdd85SOleksandr "Alex" Zinenko region.front().getNumArguments() - 1) { 16665a9bdd85SOleksandr "Alex" Zinenko return emitError(op->getLoc()) 16675a9bdd85SOleksandr "Alex" Zinenko << "operation expects " << region.front().getNumArguments() - 1 16685a9bdd85SOleksandr "Alex" Zinenko << " extra value bindings, but " << state.getNumTopLevelMappings() 16695a9bdd85SOleksandr "Alex" Zinenko << " were provided to the interpreter"; 16705a9bdd85SOleksandr "Alex" Zinenko } 16715a9bdd85SOleksandr "Alex" Zinenko 16725a9bdd85SOleksandr "Alex" Zinenko targets.push_back(state.getTopLevel()); 16735a9bdd85SOleksandr "Alex" Zinenko 16745a9bdd85SOleksandr "Alex" Zinenko for (unsigned i = 0, e = state.getNumTopLevelMappings(); i < e; ++i) 16755a9bdd85SOleksandr "Alex" Zinenko extraMappings.push_back(llvm::to_vector(state.getTopLevelMapping(i))); 16765a9bdd85SOleksandr "Alex" Zinenko } 16775a9bdd85SOleksandr "Alex" Zinenko 16785a9bdd85SOleksandr "Alex" Zinenko if (failed(state.mapBlockArguments(region.front().getArgument(0), targets))) 16795a9bdd85SOleksandr "Alex" Zinenko return failure(); 16805a9bdd85SOleksandr "Alex" Zinenko 16815a9bdd85SOleksandr "Alex" Zinenko for (BlockArgument argument : region.front().getArguments().drop_front()) { 16825a9bdd85SOleksandr "Alex" Zinenko if (failed(state.mapBlockArgument( 16835a9bdd85SOleksandr "Alex" Zinenko argument, extraMappings[argument.getArgNumber() - 1]))) 16845a9bdd85SOleksandr "Alex" Zinenko return failure(); 16855a9bdd85SOleksandr "Alex" Zinenko } 16865a9bdd85SOleksandr "Alex" Zinenko 16875a9bdd85SOleksandr "Alex" Zinenko return success(); 16885a9bdd85SOleksandr "Alex" Zinenko } 16895a9bdd85SOleksandr "Alex" Zinenko 16905a9bdd85SOleksandr "Alex" Zinenko LogicalResult 16915a9bdd85SOleksandr "Alex" Zinenko transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) { 16925a9bdd85SOleksandr "Alex" Zinenko // Attaching this trait without the interface is a misuse of the API, but it 16935a9bdd85SOleksandr "Alex" Zinenko // cannot be caught via a static_assert because interface registration is 16945a9bdd85SOleksandr "Alex" Zinenko // dynamic. 16955a9bdd85SOleksandr "Alex" Zinenko assert(isa<TransformOpInterface>(op) && 16965a9bdd85SOleksandr "Alex" Zinenko "should implement TransformOpInterface to have " 16975a9bdd85SOleksandr "Alex" Zinenko "PossibleTopLevelTransformOpTrait"); 16985a9bdd85SOleksandr "Alex" Zinenko 16995a9bdd85SOleksandr "Alex" Zinenko if (op->getNumRegions() < 1) 17005a9bdd85SOleksandr "Alex" Zinenko return op->emitOpError() << "expects at least one region"; 17015a9bdd85SOleksandr "Alex" Zinenko 17025a9bdd85SOleksandr "Alex" Zinenko Region *bodyRegion = &op->getRegion(0); 17035a9bdd85SOleksandr "Alex" Zinenko if (!llvm::hasNItems(*bodyRegion, 1)) 17045a9bdd85SOleksandr "Alex" Zinenko return op->emitOpError() << "expects a single-block region"; 17055a9bdd85SOleksandr "Alex" Zinenko 17065a9bdd85SOleksandr "Alex" Zinenko Block *body = &bodyRegion->front(); 17075a9bdd85SOleksandr "Alex" Zinenko if (body->getNumArguments() == 0) { 17085a9bdd85SOleksandr "Alex" Zinenko return op->emitOpError() 17095a9bdd85SOleksandr "Alex" Zinenko << "expects the entry block to have at least one argument"; 17105a9bdd85SOleksandr "Alex" Zinenko } 17115a9bdd85SOleksandr "Alex" Zinenko if (!llvm::isa<TransformHandleTypeInterface>( 17125a9bdd85SOleksandr "Alex" Zinenko body->getArgument(0).getType())) { 17135a9bdd85SOleksandr "Alex" Zinenko return op->emitOpError() 17145a9bdd85SOleksandr "Alex" Zinenko << "expects the first entry block argument to be of type " 17155a9bdd85SOleksandr "Alex" Zinenko "implementing TransformHandleTypeInterface"; 17165a9bdd85SOleksandr "Alex" Zinenko } 17175a9bdd85SOleksandr "Alex" Zinenko BlockArgument arg = body->getArgument(0); 17185a9bdd85SOleksandr "Alex" Zinenko if (op->getNumOperands() != 0) { 17195a9bdd85SOleksandr "Alex" Zinenko if (arg.getType() != op->getOperand(0).getType()) { 17205a9bdd85SOleksandr "Alex" Zinenko return op->emitOpError() 17215a9bdd85SOleksandr "Alex" Zinenko << "expects the type of the block argument to match " 17225a9bdd85SOleksandr "Alex" Zinenko "the type of the operand"; 17235a9bdd85SOleksandr "Alex" Zinenko } 17245a9bdd85SOleksandr "Alex" Zinenko } 17255a9bdd85SOleksandr "Alex" Zinenko for (BlockArgument arg : body->getArguments().drop_front()) { 17265a9bdd85SOleksandr "Alex" Zinenko if (llvm::isa<TransformHandleTypeInterface, TransformParamTypeInterface, 17275a9bdd85SOleksandr "Alex" Zinenko TransformValueHandleTypeInterface>(arg.getType())) 17285a9bdd85SOleksandr "Alex" Zinenko continue; 17295a9bdd85SOleksandr "Alex" Zinenko 17305a9bdd85SOleksandr "Alex" Zinenko InFlightDiagnostic diag = 17315a9bdd85SOleksandr "Alex" Zinenko op->emitOpError() 17325a9bdd85SOleksandr "Alex" Zinenko << "expects trailing entry block arguments to be of type implementing " 17335a9bdd85SOleksandr "Alex" Zinenko "TransformHandleTypeInterface, TransformValueHandleTypeInterface or " 17345a9bdd85SOleksandr "Alex" Zinenko "TransformParamTypeInterface"; 17355a9bdd85SOleksandr "Alex" Zinenko diag.attachNote() << "argument #" << arg.getArgNumber() << " does not"; 17365a9bdd85SOleksandr "Alex" Zinenko return diag; 17375a9bdd85SOleksandr "Alex" Zinenko } 17385a9bdd85SOleksandr "Alex" Zinenko 17395a9bdd85SOleksandr "Alex" Zinenko if (auto *parent = 17405a9bdd85SOleksandr "Alex" Zinenko op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) { 17415a9bdd85SOleksandr "Alex" Zinenko if (op->getNumOperands() != body->getNumArguments()) { 17425a9bdd85SOleksandr "Alex" Zinenko InFlightDiagnostic diag = 17435a9bdd85SOleksandr "Alex" Zinenko op->emitOpError() 17445a9bdd85SOleksandr "Alex" Zinenko << "expects operands to be provided for a nested op"; 17455a9bdd85SOleksandr "Alex" Zinenko diag.attachNote(parent->getLoc()) 17465a9bdd85SOleksandr "Alex" Zinenko << "nested in another possible top-level op"; 17475a9bdd85SOleksandr "Alex" Zinenko return diag; 17485a9bdd85SOleksandr "Alex" Zinenko } 17495a9bdd85SOleksandr "Alex" Zinenko } 17505a9bdd85SOleksandr "Alex" Zinenko 17515a9bdd85SOleksandr "Alex" Zinenko return success(); 17525a9bdd85SOleksandr "Alex" Zinenko } 17535a9bdd85SOleksandr "Alex" Zinenko 17545a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 17555a9bdd85SOleksandr "Alex" Zinenko // Utilities for ParamProducedTransformOpTrait. 17565a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 17575a9bdd85SOleksandr "Alex" Zinenko 17585a9bdd85SOleksandr "Alex" Zinenko void transform::detail::getParamProducerTransformOpTraitEffects( 17595a9bdd85SOleksandr "Alex" Zinenko Operation *op, SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 17605a9bdd85SOleksandr "Alex" Zinenko producesHandle(op->getResults(), effects); 17615a9bdd85SOleksandr "Alex" Zinenko bool hasPayloadOperands = false; 17622c1ae801Sdonald chen for (OpOperand &operand : op->getOpOperands()) { 17635a9bdd85SOleksandr "Alex" Zinenko onlyReadsHandle(operand, effects); 17645a9bdd85SOleksandr "Alex" Zinenko if (llvm::isa<TransformHandleTypeInterface, 17652c1ae801Sdonald chen TransformValueHandleTypeInterface>(operand.get().getType())) 17665a9bdd85SOleksandr "Alex" Zinenko hasPayloadOperands = true; 17675a9bdd85SOleksandr "Alex" Zinenko } 17685a9bdd85SOleksandr "Alex" Zinenko if (hasPayloadOperands) 17695a9bdd85SOleksandr "Alex" Zinenko onlyReadsPayload(effects); 17705a9bdd85SOleksandr "Alex" Zinenko } 17715a9bdd85SOleksandr "Alex" Zinenko 17725a9bdd85SOleksandr "Alex" Zinenko LogicalResult 17735a9bdd85SOleksandr "Alex" Zinenko transform::detail::verifyParamProducerTransformOpTrait(Operation *op) { 17745a9bdd85SOleksandr "Alex" Zinenko // Interfaces can be attached dynamically, so this cannot be a static 17755a9bdd85SOleksandr "Alex" Zinenko // assert. 17765a9bdd85SOleksandr "Alex" Zinenko if (!op->getName().getInterface<MemoryEffectOpInterface>()) { 17775a9bdd85SOleksandr "Alex" Zinenko llvm::report_fatal_error( 17785a9bdd85SOleksandr "Alex" Zinenko Twine("ParamProducerTransformOpTrait must be attached to an op that " 17795a9bdd85SOleksandr "Alex" Zinenko "implements MemoryEffectsOpInterface, found on ") + 17805a9bdd85SOleksandr "Alex" Zinenko op->getName().getStringRef()); 17815a9bdd85SOleksandr "Alex" Zinenko } 17825a9bdd85SOleksandr "Alex" Zinenko for (Value result : op->getResults()) { 17835a9bdd85SOleksandr "Alex" Zinenko if (llvm::isa<TransformParamTypeInterface>(result.getType())) 17845a9bdd85SOleksandr "Alex" Zinenko continue; 17855a9bdd85SOleksandr "Alex" Zinenko return op->emitOpError() 17865a9bdd85SOleksandr "Alex" Zinenko << "ParamProducerTransformOpTrait attached to this op expects " 17875a9bdd85SOleksandr "Alex" Zinenko "result types to implement TransformParamTypeInterface"; 17885a9bdd85SOleksandr "Alex" Zinenko } 17895a9bdd85SOleksandr "Alex" Zinenko return success(); 17905a9bdd85SOleksandr "Alex" Zinenko } 17915a9bdd85SOleksandr "Alex" Zinenko 17925a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 17935a9bdd85SOleksandr "Alex" Zinenko // Memory effects. 17945a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 17955a9bdd85SOleksandr "Alex" Zinenko 17965a9bdd85SOleksandr "Alex" Zinenko void transform::consumesHandle( 17972c1ae801Sdonald chen MutableArrayRef<OpOperand> handles, 17985a9bdd85SOleksandr "Alex" Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 17992c1ae801Sdonald chen for (OpOperand &handle : handles) { 18002c1ae801Sdonald chen effects.emplace_back(MemoryEffects::Read::get(), &handle, 18015a9bdd85SOleksandr "Alex" Zinenko TransformMappingResource::get()); 18022c1ae801Sdonald chen effects.emplace_back(MemoryEffects::Free::get(), &handle, 18035a9bdd85SOleksandr "Alex" Zinenko TransformMappingResource::get()); 18045a9bdd85SOleksandr "Alex" Zinenko } 18055a9bdd85SOleksandr "Alex" Zinenko } 18065a9bdd85SOleksandr "Alex" Zinenko 18075a9bdd85SOleksandr "Alex" Zinenko /// Returns `true` if the given list of effects instances contains an instance 18085a9bdd85SOleksandr "Alex" Zinenko /// with the effect type specified as template parameter. 18095a9bdd85SOleksandr "Alex" Zinenko template <typename EffectTy, typename ResourceTy, typename Range> 18105a9bdd85SOleksandr "Alex" Zinenko static bool hasEffect(Range &&effects) { 18115a9bdd85SOleksandr "Alex" Zinenko return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { 18125a9bdd85SOleksandr "Alex" Zinenko return isa<EffectTy>(effect.getEffect()) && 18135a9bdd85SOleksandr "Alex" Zinenko isa<ResourceTy>(effect.getResource()); 18145a9bdd85SOleksandr "Alex" Zinenko }); 18155a9bdd85SOleksandr "Alex" Zinenko } 18165a9bdd85SOleksandr "Alex" Zinenko 18175a9bdd85SOleksandr "Alex" Zinenko bool transform::isHandleConsumed(Value handle, 18185a9bdd85SOleksandr "Alex" Zinenko transform::TransformOpInterface transform) { 18195a9bdd85SOleksandr "Alex" Zinenko auto iface = cast<MemoryEffectOpInterface>(transform.getOperation()); 18205a9bdd85SOleksandr "Alex" Zinenko SmallVector<MemoryEffects::EffectInstance> effects; 18215a9bdd85SOleksandr "Alex" Zinenko iface.getEffectsOnValue(handle, effects); 18225a9bdd85SOleksandr "Alex" Zinenko return ::hasEffect<MemoryEffects::Read, TransformMappingResource>(effects) && 18235a9bdd85SOleksandr "Alex" Zinenko ::hasEffect<MemoryEffects::Free, TransformMappingResource>(effects); 18245a9bdd85SOleksandr "Alex" Zinenko } 18255a9bdd85SOleksandr "Alex" Zinenko 18265a9bdd85SOleksandr "Alex" Zinenko void transform::producesHandle( 18272c1ae801Sdonald chen ResultRange handles, 18285a9bdd85SOleksandr "Alex" Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 18292c1ae801Sdonald chen for (OpResult handle : handles) { 18302c1ae801Sdonald chen effects.emplace_back(MemoryEffects::Allocate::get(), handle, 18312c1ae801Sdonald chen TransformMappingResource::get()); 18322c1ae801Sdonald chen effects.emplace_back(MemoryEffects::Write::get(), handle, 18332c1ae801Sdonald chen TransformMappingResource::get()); 18342c1ae801Sdonald chen } 18352c1ae801Sdonald chen } 18362c1ae801Sdonald chen 18372c1ae801Sdonald chen void transform::producesHandle( 18382c1ae801Sdonald chen MutableArrayRef<BlockArgument> handles, 18392c1ae801Sdonald chen SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 18402c1ae801Sdonald chen for (BlockArgument handle : handles) { 18415a9bdd85SOleksandr "Alex" Zinenko effects.emplace_back(MemoryEffects::Allocate::get(), handle, 18425a9bdd85SOleksandr "Alex" Zinenko TransformMappingResource::get()); 18435a9bdd85SOleksandr "Alex" Zinenko effects.emplace_back(MemoryEffects::Write::get(), handle, 18445a9bdd85SOleksandr "Alex" Zinenko TransformMappingResource::get()); 18455a9bdd85SOleksandr "Alex" Zinenko } 18465a9bdd85SOleksandr "Alex" Zinenko } 18475a9bdd85SOleksandr "Alex" Zinenko 18485a9bdd85SOleksandr "Alex" Zinenko void transform::onlyReadsHandle( 18492c1ae801Sdonald chen MutableArrayRef<OpOperand> handles, 18505a9bdd85SOleksandr "Alex" Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 18512c1ae801Sdonald chen for (OpOperand &handle : handles) { 18522c1ae801Sdonald chen effects.emplace_back(MemoryEffects::Read::get(), &handle, 18535a9bdd85SOleksandr "Alex" Zinenko TransformMappingResource::get()); 18545a9bdd85SOleksandr "Alex" Zinenko } 18555a9bdd85SOleksandr "Alex" Zinenko } 18565a9bdd85SOleksandr "Alex" Zinenko 18575a9bdd85SOleksandr "Alex" Zinenko void transform::modifiesPayload( 18585a9bdd85SOleksandr "Alex" Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 18595a9bdd85SOleksandr "Alex" Zinenko effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); 18605a9bdd85SOleksandr "Alex" Zinenko effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get()); 18615a9bdd85SOleksandr "Alex" Zinenko } 18625a9bdd85SOleksandr "Alex" Zinenko 18635a9bdd85SOleksandr "Alex" Zinenko void transform::onlyReadsPayload( 18645a9bdd85SOleksandr "Alex" Zinenko SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 18655a9bdd85SOleksandr "Alex" Zinenko effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get()); 18665a9bdd85SOleksandr "Alex" Zinenko } 18675a9bdd85SOleksandr "Alex" Zinenko 18685a9bdd85SOleksandr "Alex" Zinenko bool transform::doesModifyPayload(transform::TransformOpInterface transform) { 18695a9bdd85SOleksandr "Alex" Zinenko auto iface = cast<MemoryEffectOpInterface>(transform.getOperation()); 18705a9bdd85SOleksandr "Alex" Zinenko SmallVector<MemoryEffects::EffectInstance> effects; 18715a9bdd85SOleksandr "Alex" Zinenko iface.getEffects(effects); 18725a9bdd85SOleksandr "Alex" Zinenko return ::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects); 18735a9bdd85SOleksandr "Alex" Zinenko } 18745a9bdd85SOleksandr "Alex" Zinenko 18755a9bdd85SOleksandr "Alex" Zinenko bool transform::doesReadPayload(transform::TransformOpInterface transform) { 18765a9bdd85SOleksandr "Alex" Zinenko auto iface = cast<MemoryEffectOpInterface>(transform.getOperation()); 18775a9bdd85SOleksandr "Alex" Zinenko SmallVector<MemoryEffects::EffectInstance> effects; 18785a9bdd85SOleksandr "Alex" Zinenko iface.getEffects(effects); 18795a9bdd85SOleksandr "Alex" Zinenko return ::hasEffect<MemoryEffects::Read, PayloadIRResource>(effects); 18805a9bdd85SOleksandr "Alex" Zinenko } 18815a9bdd85SOleksandr "Alex" Zinenko 18825a9bdd85SOleksandr "Alex" Zinenko void transform::getConsumedBlockArguments( 18835a9bdd85SOleksandr "Alex" Zinenko Block &block, llvm::SmallDenseSet<unsigned int> &consumedArguments) { 18845a9bdd85SOleksandr "Alex" Zinenko SmallVector<MemoryEffects::EffectInstance> effects; 18855a9bdd85SOleksandr "Alex" Zinenko for (Operation &nested : block) { 18865a9bdd85SOleksandr "Alex" Zinenko auto iface = dyn_cast<MemoryEffectOpInterface>(nested); 18875a9bdd85SOleksandr "Alex" Zinenko if (!iface) 18885a9bdd85SOleksandr "Alex" Zinenko continue; 18895a9bdd85SOleksandr "Alex" Zinenko 18905a9bdd85SOleksandr "Alex" Zinenko effects.clear(); 18915a9bdd85SOleksandr "Alex" Zinenko iface.getEffects(effects); 18925a9bdd85SOleksandr "Alex" Zinenko for (const MemoryEffects::EffectInstance &effect : effects) { 18935a9bdd85SOleksandr "Alex" Zinenko BlockArgument argument = 18945a9bdd85SOleksandr "Alex" Zinenko dyn_cast_or_null<BlockArgument>(effect.getValue()); 18955a9bdd85SOleksandr "Alex" Zinenko if (!argument || argument.getOwner() != &block || 18965a9bdd85SOleksandr "Alex" Zinenko !isa<MemoryEffects::Free>(effect.getEffect()) || 18975a9bdd85SOleksandr "Alex" Zinenko effect.getResource() != transform::TransformMappingResource::get()) { 18985a9bdd85SOleksandr "Alex" Zinenko continue; 18995a9bdd85SOleksandr "Alex" Zinenko } 19005a9bdd85SOleksandr "Alex" Zinenko consumedArguments.insert(argument.getArgNumber()); 19015a9bdd85SOleksandr "Alex" Zinenko } 19025a9bdd85SOleksandr "Alex" Zinenko } 19035a9bdd85SOleksandr "Alex" Zinenko } 19045a9bdd85SOleksandr "Alex" Zinenko 19055a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 19065a9bdd85SOleksandr "Alex" Zinenko // Utilities for TransformOpInterface. 19075a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 19085a9bdd85SOleksandr "Alex" Zinenko 19095a9bdd85SOleksandr "Alex" Zinenko SmallVector<OpOperand *> transform::detail::getConsumedHandleOpOperands( 19105a9bdd85SOleksandr "Alex" Zinenko TransformOpInterface transformOp) { 19115a9bdd85SOleksandr "Alex" Zinenko SmallVector<OpOperand *> consumedOperands; 19125a9bdd85SOleksandr "Alex" Zinenko consumedOperands.reserve(transformOp->getNumOperands()); 19135a9bdd85SOleksandr "Alex" Zinenko auto memEffectInterface = 19145a9bdd85SOleksandr "Alex" Zinenko cast<MemoryEffectOpInterface>(transformOp.getOperation()); 19155a9bdd85SOleksandr "Alex" Zinenko SmallVector<MemoryEffects::EffectInstance, 2> effects; 19165a9bdd85SOleksandr "Alex" Zinenko for (OpOperand &target : transformOp->getOpOperands()) { 19175a9bdd85SOleksandr "Alex" Zinenko effects.clear(); 19185a9bdd85SOleksandr "Alex" Zinenko memEffectInterface.getEffectsOnValue(target.get(), effects); 19195a9bdd85SOleksandr "Alex" Zinenko if (llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) { 19205a9bdd85SOleksandr "Alex" Zinenko return isa<transform::TransformMappingResource>( 19215a9bdd85SOleksandr "Alex" Zinenko effect.getResource()) && 19225a9bdd85SOleksandr "Alex" Zinenko isa<MemoryEffects::Free>(effect.getEffect()); 19235a9bdd85SOleksandr "Alex" Zinenko })) { 19245a9bdd85SOleksandr "Alex" Zinenko consumedOperands.push_back(&target); 19255a9bdd85SOleksandr "Alex" Zinenko } 19265a9bdd85SOleksandr "Alex" Zinenko } 19275a9bdd85SOleksandr "Alex" Zinenko return consumedOperands; 19285a9bdd85SOleksandr "Alex" Zinenko } 19295a9bdd85SOleksandr "Alex" Zinenko 19305a9bdd85SOleksandr "Alex" Zinenko LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) { 19315a9bdd85SOleksandr "Alex" Zinenko auto iface = cast<MemoryEffectOpInterface>(op); 19325a9bdd85SOleksandr "Alex" Zinenko SmallVector<MemoryEffects::EffectInstance> effects; 19335a9bdd85SOleksandr "Alex" Zinenko iface.getEffects(effects); 19345a9bdd85SOleksandr "Alex" Zinenko 19355a9bdd85SOleksandr "Alex" Zinenko auto effectsOn = [&](Value value) { 19365a9bdd85SOleksandr "Alex" Zinenko return llvm::make_filter_range( 19375a9bdd85SOleksandr "Alex" Zinenko effects, [value](const MemoryEffects::EffectInstance &instance) { 19385a9bdd85SOleksandr "Alex" Zinenko return instance.getValue() == value; 19395a9bdd85SOleksandr "Alex" Zinenko }); 19405a9bdd85SOleksandr "Alex" Zinenko }; 19415a9bdd85SOleksandr "Alex" Zinenko 19425a9bdd85SOleksandr "Alex" Zinenko std::optional<unsigned> firstConsumedOperand; 19435a9bdd85SOleksandr "Alex" Zinenko for (OpOperand &operand : op->getOpOperands()) { 19445a9bdd85SOleksandr "Alex" Zinenko auto range = effectsOn(operand.get()); 19455a9bdd85SOleksandr "Alex" Zinenko if (range.empty()) { 19465a9bdd85SOleksandr "Alex" Zinenko InFlightDiagnostic diag = 19475a9bdd85SOleksandr "Alex" Zinenko op->emitError() << "TransformOpInterface requires memory effects " 19485a9bdd85SOleksandr "Alex" Zinenko "on operands to be specified"; 19495a9bdd85SOleksandr "Alex" Zinenko diag.attachNote() << "no effects specified for operand #" 19505a9bdd85SOleksandr "Alex" Zinenko << operand.getOperandNumber(); 19515a9bdd85SOleksandr "Alex" Zinenko return diag; 19525a9bdd85SOleksandr "Alex" Zinenko } 19535a9bdd85SOleksandr "Alex" Zinenko if (::hasEffect<MemoryEffects::Allocate, TransformMappingResource>(range)) { 19545a9bdd85SOleksandr "Alex" Zinenko InFlightDiagnostic diag = op->emitError() 19555a9bdd85SOleksandr "Alex" Zinenko << "TransformOpInterface did not expect " 19565a9bdd85SOleksandr "Alex" Zinenko "'allocate' memory effect on an operand"; 19575a9bdd85SOleksandr "Alex" Zinenko diag.attachNote() << "specified for operand #" 19585a9bdd85SOleksandr "Alex" Zinenko << operand.getOperandNumber(); 19595a9bdd85SOleksandr "Alex" Zinenko return diag; 19605a9bdd85SOleksandr "Alex" Zinenko } 19615a9bdd85SOleksandr "Alex" Zinenko if (!firstConsumedOperand && 19625a9bdd85SOleksandr "Alex" Zinenko ::hasEffect<MemoryEffects::Free, TransformMappingResource>(range)) { 19635a9bdd85SOleksandr "Alex" Zinenko firstConsumedOperand = operand.getOperandNumber(); 19645a9bdd85SOleksandr "Alex" Zinenko } 19655a9bdd85SOleksandr "Alex" Zinenko } 19665a9bdd85SOleksandr "Alex" Zinenko 19675a9bdd85SOleksandr "Alex" Zinenko if (firstConsumedOperand && 19685a9bdd85SOleksandr "Alex" Zinenko !::hasEffect<MemoryEffects::Write, PayloadIRResource>(effects)) { 19695a9bdd85SOleksandr "Alex" Zinenko InFlightDiagnostic diag = 19705a9bdd85SOleksandr "Alex" Zinenko op->emitError() 19715a9bdd85SOleksandr "Alex" Zinenko << "TransformOpInterface expects ops consuming operands to have a " 19725a9bdd85SOleksandr "Alex" Zinenko "'write' effect on the payload resource"; 19735a9bdd85SOleksandr "Alex" Zinenko diag.attachNote() << "consumes operand #" << *firstConsumedOperand; 19745a9bdd85SOleksandr "Alex" Zinenko return diag; 19755a9bdd85SOleksandr "Alex" Zinenko } 19765a9bdd85SOleksandr "Alex" Zinenko 19775a9bdd85SOleksandr "Alex" Zinenko for (OpResult result : op->getResults()) { 19785a9bdd85SOleksandr "Alex" Zinenko auto range = effectsOn(result); 19795a9bdd85SOleksandr "Alex" Zinenko if (!::hasEffect<MemoryEffects::Allocate, TransformMappingResource>( 19805a9bdd85SOleksandr "Alex" Zinenko range)) { 19815a9bdd85SOleksandr "Alex" Zinenko InFlightDiagnostic diag = 19825a9bdd85SOleksandr "Alex" Zinenko op->emitError() << "TransformOpInterface requires 'allocate' memory " 19835a9bdd85SOleksandr "Alex" Zinenko "effect to be specified for results"; 19845a9bdd85SOleksandr "Alex" Zinenko diag.attachNote() << "no 'allocate' effect specified for result #" 19855a9bdd85SOleksandr "Alex" Zinenko << result.getResultNumber(); 19865a9bdd85SOleksandr "Alex" Zinenko return diag; 19875a9bdd85SOleksandr "Alex" Zinenko } 19885a9bdd85SOleksandr "Alex" Zinenko } 19895a9bdd85SOleksandr "Alex" Zinenko 19905a9bdd85SOleksandr "Alex" Zinenko return success(); 19915a9bdd85SOleksandr "Alex" Zinenko } 19925a9bdd85SOleksandr "Alex" Zinenko 19935a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 19945a9bdd85SOleksandr "Alex" Zinenko // Entry point. 19955a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 19965a9bdd85SOleksandr "Alex" Zinenko 19975a9bdd85SOleksandr "Alex" Zinenko LogicalResult transform::applyTransforms( 19985a9bdd85SOleksandr "Alex" Zinenko Operation *payloadRoot, TransformOpInterface transform, 19995a9bdd85SOleksandr "Alex" Zinenko const RaggedArray<MappedValue> &extraMapping, 20006634d44eSAmy Wang const TransformOptions &options, bool enforceToplevelTransformOp, 20016634d44eSAmy Wang function_ref<void(TransformState &)> stateInitializer, 20026634d44eSAmy Wang function_ref<LogicalResult(TransformState &)> stateExporter) { 20035a9bdd85SOleksandr "Alex" Zinenko if (enforceToplevelTransformOp) { 20045a9bdd85SOleksandr "Alex" Zinenko if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() || 20055a9bdd85SOleksandr "Alex" Zinenko transform->getNumOperands() != 0) { 20065a9bdd85SOleksandr "Alex" Zinenko return transform->emitError() 20075a9bdd85SOleksandr "Alex" Zinenko << "expected transform to start at the top-level transform op"; 20085a9bdd85SOleksandr "Alex" Zinenko } 20095a9bdd85SOleksandr "Alex" Zinenko } else if (failed( 20105a9bdd85SOleksandr "Alex" Zinenko detail::verifyPossibleTopLevelTransformOpTrait(transform))) { 20115a9bdd85SOleksandr "Alex" Zinenko return failure(); 20125a9bdd85SOleksandr "Alex" Zinenko } 20135a9bdd85SOleksandr "Alex" Zinenko 20145a9bdd85SOleksandr "Alex" Zinenko TransformState state(transform->getParentRegion(), payloadRoot, extraMapping, 20155a9bdd85SOleksandr "Alex" Zinenko options); 20166634d44eSAmy Wang if (stateInitializer) 20176634d44eSAmy Wang stateInitializer(state); 20186634d44eSAmy Wang if (state.applyTransform(transform).checkAndReport().failed()) 20196634d44eSAmy Wang return failure(); 20206634d44eSAmy Wang if (stateExporter) 20216634d44eSAmy Wang return stateExporter(state); 20226634d44eSAmy Wang return success(); 20235a9bdd85SOleksandr "Alex" Zinenko } 20245a9bdd85SOleksandr "Alex" Zinenko 20255a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 20265a9bdd85SOleksandr "Alex" Zinenko // Generated interface implementation. 20275a9bdd85SOleksandr "Alex" Zinenko //===----------------------------------------------------------------------===// 20285a9bdd85SOleksandr "Alex" Zinenko 20295a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc" 20305a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformTypeInterfaces.cpp.inc" 2031