15f0d4f20SAlex Zinenko //===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===// 25f0d4f20SAlex Zinenko // 35f0d4f20SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 45f0d4f20SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 55f0d4f20SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 65f0d4f20SAlex Zinenko // 75f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===// 85f0d4f20SAlex Zinenko 95f0d4f20SAlex Zinenko #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" 10acc159aeSMatthias Springer 11acc159aeSMatthias Springer #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" 125f0d4f20SAlex Zinenko #include "mlir/Dialect/Affine/IR/AffineOps.h" 13e4e64eaaSAmy Wang #include "mlir/Dialect/Affine/LoopUtils.h" 144fbb5f93SOleksandr "Alex" Zinenko #include "mlir/Dialect/Arith/IR/Arith.h" 154fbb5f93SOleksandr "Alex" Zinenko #include "mlir/Dialect/Arith/Utils/Utils.h" 165f0d4f20SAlex Zinenko #include "mlir/Dialect/Func/IR/FuncOps.h" 178b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 188b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Patterns.h" 198b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Transforms.h" 205f0d4f20SAlex Zinenko #include "mlir/Dialect/SCF/Utils/Utils.h" 215f0d4f20SAlex Zinenko #include "mlir/Dialect/Transform/IR/TransformDialect.h" 2288b7e8e0SNicolas Vasilache #include "mlir/Dialect/Transform/IR/TransformOps.h" 235a9bdd85SOleksandr "Alex" Zinenko #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 244fbb5f93SOleksandr "Alex" Zinenko #include "mlir/Dialect/Utils/StaticValueUtils.h" 255f0d4f20SAlex Zinenko #include "mlir/Dialect/Vector/IR/VectorOps.h" 264fbb5f93SOleksandr "Alex" Zinenko #include "mlir/IR/BuiltinAttributes.h" 279aaf007aSGroverkss #include "mlir/IR/Dominance.h" 284fbb5f93SOleksandr "Alex" Zinenko #include 
"mlir/IR/OpDefinition.h" 295f0d4f20SAlex Zinenko 305f0d4f20SAlex Zinenko using namespace mlir; 314c48f016SMatthias Springer using namespace mlir::affine; 325f0d4f20SAlex Zinenko 335f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===// 34cc7f5243SMatthias Springer // Apply...PatternsOp 35cc7f5243SMatthias Springer //===----------------------------------------------------------------------===// 36cc7f5243SMatthias Springer 37cc7f5243SMatthias Springer void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns( 38cc7f5243SMatthias Springer RewritePatternSet &patterns) { 39cc7f5243SMatthias Springer scf::populateSCFForLoopCanonicalizationPatterns(patterns); 40cc7f5243SMatthias Springer } 41cc7f5243SMatthias Springer 42e2d39f79SChristopher Bate void transform::ApplySCFStructuralConversionPatternsOp::populatePatterns( 43e2d39f79SChristopher Bate TypeConverter &typeConverter, RewritePatternSet &patterns) { 44e2d39f79SChristopher Bate scf::populateSCFStructuralTypeConversions(typeConverter, patterns); 45e2d39f79SChristopher Bate } 46e2d39f79SChristopher Bate 47e2d39f79SChristopher Bate void transform::ApplySCFStructuralConversionPatternsOp:: 48e2d39f79SChristopher Bate populateConversionTargetRules(const TypeConverter &typeConverter, 49e2d39f79SChristopher Bate ConversionTarget &conversionTarget) { 50e2d39f79SChristopher Bate scf::populateSCFStructuralTypeConversionTarget(typeConverter, 51e2d39f79SChristopher Bate conversionTarget); 52e2d39f79SChristopher Bate } 53e2d39f79SChristopher Bate 54acc159aeSMatthias Springer void transform::ApplySCFToControlFlowPatternsOp::populatePatterns( 55acc159aeSMatthias Springer TypeConverter &typeConverter, RewritePatternSet &patterns) { 56acc159aeSMatthias Springer populateSCFToControlFlowConversionPatterns(patterns); 57acc159aeSMatthias Springer } 58acc159aeSMatthias Springer 59cc7f5243SMatthias Springer 
//===----------------------------------------------------------------------===//
// ForallToForOp
//===----------------------------------------------------------------------===//

/// Rewrites the single scf.forall payload op via `scf::forallToForLoop` and
/// associates each produced operation with the corresponding result handle.
///
/// Emits a silenceable error when the target handle does not map to exactly
/// one payload op, when that op is not an scf.forall, when it has shared
/// outputs, when the number of transform results differs from the number of
/// lower bounds of the payload, or when the conversion itself fails.
DiagnosedSilenceableFailure
transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
                                transform::TransformResults &results,
                                transform::TransformState &state) {
  auto payloadOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(payloadOps))
    return emitSilenceableError() << "expected a single payload op";

  Operation *payloadOp = *payloadOps.begin();
  auto forallOp = dyn_cast<scf::ForallOp>(payloadOp);
  if (!forallOp) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError() << "expected the payload to be scf.forall";
    diag.attachNote(payloadOp->getLoc()) << "payload op";
    return diag;
  }

  if (!forallOp.getOutputs().empty()) {
    return emitSilenceableError()
           << "unsupported shared outputs (didn't bufferize?)";
  }

  // One result handle is expected per induction variable of the payload.
  SmallVector<OpFoldResult> lowerBounds = forallOp.getMixedLowerBound();
  if (getNumResults() != lowerBounds.size()) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "op expects as many results (" << getNumResults()
        << ") as payload has induction variables (" << lowerBounds.size()
        << ")";
    diag.attachNote(forallOp.getLoc()) << "payload op";
    return diag;
  }

  SmallVector<Operation *> loops;
  if (failed(scf::forallToForLoop(rewriter, forallOp, &loops))) {
    DiagnosedSilenceableFailure diag = emitSilenceableError()
                                       << "failed to convert forall into for";
    return diag;
  }

  // Map the i-th produced operation onto the i-th result handle.
  for (size_t i = 0, e = loops.size(); i < e; ++i)
    results.set(cast<OpResult>(getTransformed()[i]), {loops[i]});
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// ForallToParallelOp
//===----------------------------------------------------------------------===//

/// Rewrites the single scf.forall payload op via `scf::forallToParallelLoop`
/// and associates the produced scf.parallel with the only result handle.
///
/// Emits a silenceable error when the target handle does not map to exactly
/// one payload op, when that op is not an scf.forall, when it has shared
/// outputs, when the transform op does not have exactly one result, or when
/// the conversion itself fails.
DiagnosedSilenceableFailure
transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
                                     transform::TransformResults &results,
                                     transform::TransformState &state) {
  auto payloadOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(payloadOps))
    return emitSilenceableError() << "expected a single payload op";

  Operation *payloadOp = *payloadOps.begin();
  auto forallOp = dyn_cast<scf::ForallOp>(payloadOp);
  if (!forallOp) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError() << "expected the payload to be scf.forall";
    diag.attachNote(payloadOp->getLoc()) << "payload op";
    return diag;
  }

  if (!forallOp.getOutputs().empty()) {
    return emitSilenceableError()
           << "unsupported shared outputs (didn't bufferize?)";
  }

  if (getNumResults() != 1) {
    DiagnosedSilenceableFailure diag = emitSilenceableError()
                                       << "op expects one result, given "
                                       << getNumResults();
    diag.attachNote(forallOp.getLoc()) << "payload op";
    return diag;
  }

  scf::ParallelOp parallelOp;
  if (failed(scf::forallToParallelLoop(rewriter, forallOp, &parallelOp))) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError() << "failed to convert forall into parallel";
    return diag;
  }

  results.set(cast<OpResult>(getTransformed()[0]), {parallelOp});
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
LoopOutlineOp 1545f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===// 1555f0d4f20SAlex Zinenko 1565f0d4f20SAlex Zinenko /// Wraps the given operation `op` into an `scf.execute_region` operation. Uses 1575f0d4f20SAlex Zinenko /// the provided rewriter for all operations to remain compatible with the 1585f0d4f20SAlex Zinenko /// rewriting infra, as opposed to just splicing the op in place. 1595f0d4f20SAlex Zinenko static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b, 1605f0d4f20SAlex Zinenko Operation *op) { 1615f0d4f20SAlex Zinenko if (op->getNumRegions() != 1) 1625f0d4f20SAlex Zinenko return nullptr; 1635f0d4f20SAlex Zinenko OpBuilder::InsertionGuard g(b); 1645f0d4f20SAlex Zinenko b.setInsertionPoint(op); 1655f0d4f20SAlex Zinenko scf::ExecuteRegionOp executeRegionOp = 1665f0d4f20SAlex Zinenko b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes()); 1675f0d4f20SAlex Zinenko { 1685f0d4f20SAlex Zinenko OpBuilder::InsertionGuard g(b); 1695f0d4f20SAlex Zinenko b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock()); 1705f0d4f20SAlex Zinenko Operation *clonedOp = b.cloneWithoutRegions(*op); 1715f0d4f20SAlex Zinenko Region &clonedRegion = clonedOp->getRegions().front(); 1725f0d4f20SAlex Zinenko assert(clonedRegion.empty() && "expected empty region"); 1735f0d4f20SAlex Zinenko b.inlineRegionBefore(op->getRegions().front(), clonedRegion, 1745f0d4f20SAlex Zinenko clonedRegion.end()); 1755f0d4f20SAlex Zinenko b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults()); 1765f0d4f20SAlex Zinenko } 1775f0d4f20SAlex Zinenko b.replaceOp(op, executeRegionOp.getResults()); 1785f0d4f20SAlex Zinenko return executeRegionOp; 1795f0d4f20SAlex Zinenko } 1805f0d4f20SAlex Zinenko 1811d45282aSAlex Zinenko DiagnosedSilenceableFailure 182c63d2b2cSMatthias Springer transform::LoopOutlineOp::apply(transform::TransformRewriter &rewriter, 183c63d2b2cSMatthias Springer transform::TransformResults &results, 
1845f0d4f20SAlex Zinenko transform::TransformState &state) { 185d9064269SAlex Zinenko SmallVector<Operation *> functions; 186d9064269SAlex Zinenko SmallVector<Operation *> calls; 1875f0d4f20SAlex Zinenko DenseMap<Operation *, SymbolTable> symbolTables; 1885f0d4f20SAlex Zinenko for (Operation *target : state.getPayloadOps(getTarget())) { 1895f0d4f20SAlex Zinenko Location location = target->getLoc(); 1905f0d4f20SAlex Zinenko Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target); 1915f0d4f20SAlex Zinenko scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target); 1925f0d4f20SAlex Zinenko if (!exec) { 1931d45282aSAlex Zinenko DiagnosedSilenceableFailure diag = emitSilenceableError() 194e3890b7fSAlex Zinenko << "failed to outline"; 1955f0d4f20SAlex Zinenko diag.attachNote(target->getLoc()) << "target op"; 1965f0d4f20SAlex Zinenko return diag; 1975f0d4f20SAlex Zinenko } 1985f0d4f20SAlex Zinenko func::CallOp call; 1995f0d4f20SAlex Zinenko FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion( 2005f0d4f20SAlex Zinenko rewriter, location, exec.getRegion(), getFuncName(), &call); 2015f0d4f20SAlex Zinenko 2027d5bef77SAlex Zinenko if (failed(outlined)) 2037d5bef77SAlex Zinenko return emitDefaultDefiniteFailure(target); 2045f0d4f20SAlex Zinenko 2055f0d4f20SAlex Zinenko if (symbolTableOp) { 2065f0d4f20SAlex Zinenko SymbolTable &symbolTable = 2075f0d4f20SAlex Zinenko symbolTables.try_emplace(symbolTableOp, symbolTableOp) 2085f0d4f20SAlex Zinenko .first->getSecond(); 2095f0d4f20SAlex Zinenko symbolTable.insert(*outlined); 2105f0d4f20SAlex Zinenko call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined)); 2115f0d4f20SAlex Zinenko } 212d9064269SAlex Zinenko functions.push_back(*outlined); 213d9064269SAlex Zinenko calls.push_back(call); 2145f0d4f20SAlex Zinenko } 2155550c821STres Popp results.set(cast<OpResult>(getFunction()), functions); 2165550c821STres Popp results.set(cast<OpResult>(getCall()), calls); 2171d45282aSAlex Zinenko return 
DiagnosedSilenceableFailure::success(); 2185f0d4f20SAlex Zinenko } 2195f0d4f20SAlex Zinenko 2205f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===// 2215f0d4f20SAlex Zinenko // LoopPeelOp 2225f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===// 2235f0d4f20SAlex Zinenko 22452307109SNicolas Vasilache DiagnosedSilenceableFailure 225c63d2b2cSMatthias Springer transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter, 226c63d2b2cSMatthias Springer scf::ForOp target, 2274b455a71SAlex Zinenko transform::ApplyToEachResultList &results, 22852307109SNicolas Vasilache transform::TransformState &state) { 2295f0d4f20SAlex Zinenko scf::ForOp result; 230bd6a2452SVivian if (getPeelFront()) { 231bd6a2452SVivian LogicalResult status = 232bd6a2452SVivian scf::peelForLoopFirstIteration(rewriter, target, result); 233bd6a2452SVivian if (failed(status)) { 234bd6a2452SVivian DiagnosedSilenceableFailure diag = 235bd6a2452SVivian emitSilenceableError() << "failed to peel the first iteration"; 236bd6a2452SVivian return diag; 237bd6a2452SVivian } 238bd6a2452SVivian } else { 2395f0d4f20SAlex Zinenko LogicalResult status = 240bb2ae985SNicolas Vasilache scf::peelForLoopAndSimplifyBounds(rewriter, target, result); 2411e70ab5fSAndrzej Warzynski if (failed(status)) { 2421e70ab5fSAndrzej Warzynski DiagnosedSilenceableFailure diag = emitSilenceableError() 243bd6a2452SVivian << "failed to peel the last iteration"; 2441e70ab5fSAndrzej Warzynski return diag; 2451e70ab5fSAndrzej Warzynski } 246bd6a2452SVivian } 247bd6a2452SVivian 2481e70ab5fSAndrzej Warzynski results.push_back(target); 2491e70ab5fSAndrzej Warzynski results.push_back(result); 2501e70ab5fSAndrzej Warzynski 2517d5bef77SAlex Zinenko return DiagnosedSilenceableFailure::success(); 2525f0d4f20SAlex Zinenko } 2535f0d4f20SAlex Zinenko 2545f0d4f20SAlex Zinenko 
//===----------------------------------------------------------------------===// 2555f0d4f20SAlex Zinenko // LoopPipelineOp 2565f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===// 2575f0d4f20SAlex Zinenko 2585f0d4f20SAlex Zinenko /// Callback for PipeliningOption. Populates `schedule` with the mapping from an 2595f0d4f20SAlex Zinenko /// operation to its logical time position given the iteration interval and the 2605f0d4f20SAlex Zinenko /// read latency. The latter is only relevant for vector transfers. 2615f0d4f20SAlex Zinenko static void 2625f0d4f20SAlex Zinenko loopScheduling(scf::ForOp forOp, 2635f0d4f20SAlex Zinenko std::vector<std::pair<Operation *, unsigned>> &schedule, 2645f0d4f20SAlex Zinenko unsigned iterationInterval, unsigned readLatency) { 2655f0d4f20SAlex Zinenko auto getLatency = [&](Operation *op) -> unsigned { 2665f0d4f20SAlex Zinenko if (isa<vector::TransferReadOp>(op)) 2675f0d4f20SAlex Zinenko return readLatency; 2685f0d4f20SAlex Zinenko return 1; 2695f0d4f20SAlex Zinenko }; 2705f0d4f20SAlex Zinenko 271acc159aeSMatthias Springer std::optional<int64_t> ubConstant = 272acc159aeSMatthias Springer getConstantIntValue(forOp.getUpperBound()); 273acc159aeSMatthias Springer std::optional<int64_t> lbConstant = 274acc159aeSMatthias Springer getConstantIntValue(forOp.getLowerBound()); 2755f0d4f20SAlex Zinenko DenseMap<Operation *, unsigned> opCycles; 2765f0d4f20SAlex Zinenko std::map<unsigned, std::vector<Operation *>> wrappedSchedule; 2775f0d4f20SAlex Zinenko for (Operation &op : forOp.getBody()->getOperations()) { 2785f0d4f20SAlex Zinenko if (isa<scf::YieldOp>(op)) 2795f0d4f20SAlex Zinenko continue; 2805f0d4f20SAlex Zinenko unsigned earlyCycle = 0; 2815f0d4f20SAlex Zinenko for (Value operand : op.getOperands()) { 2825f0d4f20SAlex Zinenko Operation *def = operand.getDefiningOp(); 2835f0d4f20SAlex Zinenko if (!def) 2845f0d4f20SAlex Zinenko continue; 285192cd685SFotis Kounelis if (ubConstant && lbConstant) { 
286192cd685SFotis Kounelis unsigned ubInt = ubConstant.value(); 287192cd685SFotis Kounelis unsigned lbInt = lbConstant.value(); 288192cd685SFotis Kounelis auto minLatency = std::min(ubInt - lbInt - 1, getLatency(def)); 289192cd685SFotis Kounelis earlyCycle = std::max(earlyCycle, opCycles[def] + minLatency); 290192cd685SFotis Kounelis } else { 2915f0d4f20SAlex Zinenko earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def)); 2925f0d4f20SAlex Zinenko } 293192cd685SFotis Kounelis } 2945f0d4f20SAlex Zinenko opCycles[&op] = earlyCycle; 2955f0d4f20SAlex Zinenko wrappedSchedule[earlyCycle % iterationInterval].push_back(&op); 2965f0d4f20SAlex Zinenko } 2978ab925a2SAdrian Kuegel for (const auto &it : wrappedSchedule) { 2985f0d4f20SAlex Zinenko for (Operation *op : it.second) { 2995f0d4f20SAlex Zinenko unsigned cycle = opCycles[op]; 300cd417c6aSMehdi Amini schedule.emplace_back(op, cycle / iterationInterval); 3015f0d4f20SAlex Zinenko } 3025f0d4f20SAlex Zinenko } 3035f0d4f20SAlex Zinenko } 3045f0d4f20SAlex Zinenko 30552307109SNicolas Vasilache DiagnosedSilenceableFailure 306c63d2b2cSMatthias Springer transform::LoopPipelineOp::applyToOne(transform::TransformRewriter &rewriter, 307c63d2b2cSMatthias Springer scf::ForOp target, 3084b455a71SAlex Zinenko transform::ApplyToEachResultList &results, 30952307109SNicolas Vasilache transform::TransformState &state) { 3105f0d4f20SAlex Zinenko scf::PipeliningOption options; 3115f0d4f20SAlex Zinenko options.getScheduleFn = 3125f0d4f20SAlex Zinenko [this](scf::ForOp forOp, 3135f0d4f20SAlex Zinenko std::vector<std::pair<Operation *, unsigned>> &schedule) mutable { 3145f0d4f20SAlex Zinenko loopScheduling(forOp, schedule, getIterationInterval(), 3155f0d4f20SAlex Zinenko getReadLatency()); 3165f0d4f20SAlex Zinenko }; 31752307109SNicolas Vasilache scf::ForLoopPipeliningPattern pattern(options, target->getContext()); 31852307109SNicolas Vasilache rewriter.setInsertionPoint(target); 3195f0d4f20SAlex Zinenko FailureOr<scf::ForOp> 
patternResult = 3201cff4cbdSNicolas Vasilache scf::pipelineForLoop(rewriter, target, options); 32152307109SNicolas Vasilache if (succeeded(patternResult)) { 32252307109SNicolas Vasilache results.push_back(*patternResult); 3237d5bef77SAlex Zinenko return DiagnosedSilenceableFailure::success(); 32452307109SNicolas Vasilache } 32552307109SNicolas Vasilache return emitDefaultSilenceableFailure(target); 3265f0d4f20SAlex Zinenko } 3275f0d4f20SAlex Zinenko 3285f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===// 3298b8e62d3SMatthias Springer // LoopPromoteIfOneIterationOp 3308b8e62d3SMatthias Springer //===----------------------------------------------------------------------===// 3318b8e62d3SMatthias Springer 3328b8e62d3SMatthias Springer DiagnosedSilenceableFailure transform::LoopPromoteIfOneIterationOp::applyToOne( 3338b8e62d3SMatthias Springer transform::TransformRewriter &rewriter, LoopLikeOpInterface target, 3348b8e62d3SMatthias Springer transform::ApplyToEachResultList &results, 3358b8e62d3SMatthias Springer transform::TransformState &state) { 3368b8e62d3SMatthias Springer (void)target.promoteIfSingleIteration(rewriter); 3378b8e62d3SMatthias Springer return DiagnosedSilenceableFailure::success(); 3388b8e62d3SMatthias Springer } 3398b8e62d3SMatthias Springer 3408b8e62d3SMatthias Springer void transform::LoopPromoteIfOneIterationOp::getEffects( 3418b8e62d3SMatthias Springer SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 3422c1ae801Sdonald chen consumesHandle(getTargetMutable(), effects); 3438b8e62d3SMatthias Springer modifiesPayload(effects); 3448b8e62d3SMatthias Springer } 3458b8e62d3SMatthias Springer 3468b8e62d3SMatthias Springer //===----------------------------------------------------------------------===// 3475f0d4f20SAlex Zinenko // LoopUnrollOp 3485f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===// 3495f0d4f20SAlex Zinenko 35052307109SNicolas 
Vasilache DiagnosedSilenceableFailure 351c63d2b2cSMatthias Springer transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter, 352c63d2b2cSMatthias Springer Operation *op, 3534b455a71SAlex Zinenko transform::ApplyToEachResultList &results, 35452307109SNicolas Vasilache transform::TransformState &state) { 355e4e64eaaSAmy Wang LogicalResult result(failure()); 356e4e64eaaSAmy Wang if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) 357e4e64eaaSAmy Wang result = loopUnrollByFactor(scfFor, getFactor()); 358e4e64eaaSAmy Wang else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op)) 359e4e64eaaSAmy Wang result = loopUnrollByFactor(affineFor, getFactor()); 360cb8bd6f7SAviad Cohen else 361cb8bd6f7SAviad Cohen return emitSilenceableError() 362cb8bd6f7SAviad Cohen << "failed to unroll, incorrect type of payload"; 363e4e64eaaSAmy Wang 364cb8bd6f7SAviad Cohen if (failed(result)) 365cb8bd6f7SAviad Cohen return emitSilenceableError() << "failed to unroll"; 366cb8bd6f7SAviad Cohen 367cb8bd6f7SAviad Cohen return DiagnosedSilenceableFailure::success(); 368efc0ba02SAmy Wang } 369cb8bd6f7SAviad Cohen 370cb8bd6f7SAviad Cohen //===----------------------------------------------------------------------===// 371cb8bd6f7SAviad Cohen // LoopUnrollAndJamOp 372cb8bd6f7SAviad Cohen //===----------------------------------------------------------------------===// 373cb8bd6f7SAviad Cohen 374cb8bd6f7SAviad Cohen DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne( 375cb8bd6f7SAviad Cohen transform::TransformRewriter &rewriter, Operation *op, 376cb8bd6f7SAviad Cohen transform::ApplyToEachResultList &results, 377cb8bd6f7SAviad Cohen transform::TransformState &state) { 378cb8bd6f7SAviad Cohen LogicalResult result(failure()); 379cb8bd6f7SAviad Cohen if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) 380cb8bd6f7SAviad Cohen result = loopUnrollJamByFactor(scfFor, getFactor()); 381cb8bd6f7SAviad Cohen else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op)) 
382cb8bd6f7SAviad Cohen result = loopUnrollJamByFactor(affineFor, getFactor()); 383cb8bd6f7SAviad Cohen else 384cb8bd6f7SAviad Cohen return emitSilenceableError() 385cb8bd6f7SAviad Cohen << "failed to unroll and jam, incorrect type of payload"; 386cb8bd6f7SAviad Cohen 387cb8bd6f7SAviad Cohen if (failed(result)) 388cb8bd6f7SAviad Cohen return emitSilenceableError() << "failed to unroll and jam"; 389cb8bd6f7SAviad Cohen 390efc0ba02SAmy Wang return DiagnosedSilenceableFailure::success(); 391efc0ba02SAmy Wang } 392efc0ba02SAmy Wang 393efc0ba02SAmy Wang //===----------------------------------------------------------------------===// 394efc0ba02SAmy Wang // LoopCoalesceOp 395efc0ba02SAmy Wang //===----------------------------------------------------------------------===// 396efc0ba02SAmy Wang 397efc0ba02SAmy Wang DiagnosedSilenceableFailure 398c63d2b2cSMatthias Springer transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter, 399c63d2b2cSMatthias Springer Operation *op, 400efc0ba02SAmy Wang transform::ApplyToEachResultList &results, 401efc0ba02SAmy Wang transform::TransformState &state) { 402efc0ba02SAmy Wang LogicalResult result(failure()); 403efc0ba02SAmy Wang if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op)) 4045aeb604cSMaheshRavishankar result = coalescePerfectlyNestedSCFForLoops(scfForOp); 405efc0ba02SAmy Wang else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op)) 4065aeb604cSMaheshRavishankar result = coalescePerfectlyNestedAffineLoops(affineForOp); 407efc0ba02SAmy Wang 408efc0ba02SAmy Wang results.push_back(op); 409efc0ba02SAmy Wang if (failed(result)) { 410efc0ba02SAmy Wang DiagnosedSilenceableFailure diag = emitSilenceableError() 411efc0ba02SAmy Wang << "failed to coalesce"; 412efc0ba02SAmy Wang return diag; 41352307109SNicolas Vasilache } 4147d5bef77SAlex Zinenko return DiagnosedSilenceableFailure::success(); 4155f0d4f20SAlex Zinenko } 4165f0d4f20SAlex Zinenko 4175f0d4f20SAlex Zinenko 
//===----------------------------------------------------------------------===// 41888b7e8e0SNicolas Vasilache // TakeAssumedBranchOp 41988b7e8e0SNicolas Vasilache //===----------------------------------------------------------------------===// 42088b7e8e0SNicolas Vasilache /// Replaces the given op with the contents of the given single-block region, 42188b7e8e0SNicolas Vasilache /// using the operands of the block terminator to replace operation results. 42288b7e8e0SNicolas Vasilache static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op, 42388b7e8e0SNicolas Vasilache Region ®ion) { 42488b7e8e0SNicolas Vasilache assert(llvm::hasSingleElement(region) && "expected single-region block"); 42588b7e8e0SNicolas Vasilache Block *block = ®ion.front(); 42688b7e8e0SNicolas Vasilache Operation *terminator = block->getTerminator(); 42788b7e8e0SNicolas Vasilache ValueRange results = terminator->getOperands(); 42888b7e8e0SNicolas Vasilache rewriter.inlineBlockBefore(block, op, /*blockArgs=*/{}); 42988b7e8e0SNicolas Vasilache rewriter.replaceOp(op, results); 43088b7e8e0SNicolas Vasilache rewriter.eraseOp(terminator); 43188b7e8e0SNicolas Vasilache } 43288b7e8e0SNicolas Vasilache 43388b7e8e0SNicolas Vasilache DiagnosedSilenceableFailure transform::TakeAssumedBranchOp::applyToOne( 434c63d2b2cSMatthias Springer transform::TransformRewriter &rewriter, scf::IfOp ifOp, 435c63d2b2cSMatthias Springer transform::ApplyToEachResultList &results, 43688b7e8e0SNicolas Vasilache transform::TransformState &state) { 43788b7e8e0SNicolas Vasilache rewriter.setInsertionPoint(ifOp); 43888b7e8e0SNicolas Vasilache Region ®ion = 43988b7e8e0SNicolas Vasilache getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion(); 44088b7e8e0SNicolas Vasilache if (!llvm::hasSingleElement(region)) { 44188b7e8e0SNicolas Vasilache return emitDefiniteFailure() 44288b7e8e0SNicolas Vasilache << "requires an scf.if op with a single-block " 44388b7e8e0SNicolas Vasilache << ((getTakeElseBranch()) ? 
"`else`" : "`then`") << " region"; 44488b7e8e0SNicolas Vasilache } 44588b7e8e0SNicolas Vasilache replaceOpWithRegion(rewriter, ifOp, region); 44688b7e8e0SNicolas Vasilache return DiagnosedSilenceableFailure::success(); 44788b7e8e0SNicolas Vasilache } 44888b7e8e0SNicolas Vasilache 44988b7e8e0SNicolas Vasilache void transform::TakeAssumedBranchOp::getEffects( 45088b7e8e0SNicolas Vasilache SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 4512c1ae801Sdonald chen onlyReadsHandle(getTargetMutable(), effects); 45288b7e8e0SNicolas Vasilache modifiesPayload(effects); 45388b7e8e0SNicolas Vasilache } 45488b7e8e0SNicolas Vasilache 45588b7e8e0SNicolas Vasilache //===----------------------------------------------------------------------===// 456eacda36cSRolf Morel // LoopFuseSiblingOp 4579aaf007aSGroverkss //===----------------------------------------------------------------------===// 4589aaf007aSGroverkss 45997a2bd84SAlexander Belyaev /// Check if `target` and `source` are siblings, in the context that `target` 46097a2bd84SAlexander Belyaev /// is being fused into `source`. 46197a2bd84SAlexander Belyaev /// 46297a2bd84SAlexander Belyaev /// This is a simple check that just checks if both operations are in the same 46397a2bd84SAlexander Belyaev /// block and some checks to ensure that the fused IR does not violate 46497a2bd84SAlexander Belyaev /// dominance. 46597a2bd84SAlexander Belyaev static DiagnosedSilenceableFailure isOpSibling(Operation *target, 46697a2bd84SAlexander Belyaev Operation *source) { 46797a2bd84SAlexander Belyaev // Check if both operations are same. 46897a2bd84SAlexander Belyaev if (target == source) 46997a2bd84SAlexander Belyaev return emitSilenceableFailure(source) 47097a2bd84SAlexander Belyaev << "target and source need to be different loops"; 47197a2bd84SAlexander Belyaev 47297a2bd84SAlexander Belyaev // Check if both operations are in the same block. 
47397a2bd84SAlexander Belyaev if (target->getBlock() != source->getBlock()) 47497a2bd84SAlexander Belyaev return emitSilenceableFailure(source) 47597a2bd84SAlexander Belyaev << "target and source are not in the same block"; 47697a2bd84SAlexander Belyaev 47797a2bd84SAlexander Belyaev // Check if fusion will violate dominance. 47897a2bd84SAlexander Belyaev DominanceInfo domInfo(source); 47997a2bd84SAlexander Belyaev if (target->isBeforeInBlock(source)) { 48097a2bd84SAlexander Belyaev // Since `target` is before `source`, all users of results of `target` 48197a2bd84SAlexander Belyaev // need to be dominated by `source`. 48297a2bd84SAlexander Belyaev for (Operation *user : target->getUsers()) { 48397a2bd84SAlexander Belyaev if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) { 48497a2bd84SAlexander Belyaev return emitSilenceableFailure(target) 48597a2bd84SAlexander Belyaev << "user of results of target should be properly dominated by " 48697a2bd84SAlexander Belyaev "source"; 48797a2bd84SAlexander Belyaev } 48897a2bd84SAlexander Belyaev } 48997a2bd84SAlexander Belyaev } else { 49097a2bd84SAlexander Belyaev // Since `target` is after `source`, all values used by `target` need 49197a2bd84SAlexander Belyaev // to dominate `source`. 49297a2bd84SAlexander Belyaev 49397a2bd84SAlexander Belyaev // Check if operands of `target` are dominated by `source`. 49497a2bd84SAlexander Belyaev for (Value operand : target->getOperands()) { 49597a2bd84SAlexander Belyaev Operation *operandOp = operand.getDefiningOp(); 49697a2bd84SAlexander Belyaev // Operands without defining operations are block arguments. When `target` 49797a2bd84SAlexander Belyaev // and `source` occur in the same block, these operands dominate `source`. 49897a2bd84SAlexander Belyaev if (!operandOp) 49997a2bd84SAlexander Belyaev continue; 50097a2bd84SAlexander Belyaev 50197a2bd84SAlexander Belyaev // Operand's defining operation should properly dominate `source`. 
50297a2bd84SAlexander Belyaev if (!domInfo.properlyDominates(operandOp, source, 50397a2bd84SAlexander Belyaev /*enclosingOpOk=*/false)) 50497a2bd84SAlexander Belyaev return emitSilenceableFailure(target) 50597a2bd84SAlexander Belyaev << "operands of target should be properly dominated by source"; 50697a2bd84SAlexander Belyaev } 50797a2bd84SAlexander Belyaev 50897a2bd84SAlexander Belyaev // Check if values used by `target` are dominated by `source`. 50997a2bd84SAlexander Belyaev bool failed = false; 51097a2bd84SAlexander Belyaev OpOperand *failedValue = nullptr; 51197a2bd84SAlexander Belyaev visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) { 51297a2bd84SAlexander Belyaev Operation *operandOp = operand->get().getDefiningOp(); 51397a2bd84SAlexander Belyaev if (operandOp && !domInfo.properlyDominates(operandOp, source, 51497a2bd84SAlexander Belyaev /*enclosingOpOk=*/false)) { 51597a2bd84SAlexander Belyaev // `operand` is not an argument of an enclosing block and the defining 51697a2bd84SAlexander Belyaev // op of `operand` is outside `target` but does not dominate `source`. 51797a2bd84SAlexander Belyaev failed = true; 51897a2bd84SAlexander Belyaev failedValue = operand; 51997a2bd84SAlexander Belyaev } 52097a2bd84SAlexander Belyaev }); 52197a2bd84SAlexander Belyaev 52297a2bd84SAlexander Belyaev if (failed) 52397a2bd84SAlexander Belyaev return emitSilenceableFailure(failedValue->getOwner()) 52497a2bd84SAlexander Belyaev << "values used inside regions of target should be properly " 52597a2bd84SAlexander Belyaev "dominated by source"; 52697a2bd84SAlexander Belyaev } 52797a2bd84SAlexander Belyaev 52897a2bd84SAlexander Belyaev return DiagnosedSilenceableFailure::success(); 52997a2bd84SAlexander Belyaev } 53097a2bd84SAlexander Belyaev 53197a2bd84SAlexander Belyaev /// Check if `target` scf.forall can be fused into `source` scf.forall. 
53297a2bd84SAlexander Belyaev /// 53397a2bd84SAlexander Belyaev /// This simply checks if both loops have the same bounds, steps and mapping. 53497a2bd84SAlexander Belyaev /// No attempt is made at checking that the side effects of `target` and 53597a2bd84SAlexander Belyaev /// `source` are independent of each other. 53697a2bd84SAlexander Belyaev static bool isForallWithIdenticalConfiguration(Operation *target, 53797a2bd84SAlexander Belyaev Operation *source) { 53897a2bd84SAlexander Belyaev auto targetOp = dyn_cast<scf::ForallOp>(target); 53997a2bd84SAlexander Belyaev auto sourceOp = dyn_cast<scf::ForallOp>(source); 54097a2bd84SAlexander Belyaev if (!targetOp || !sourceOp) 54197a2bd84SAlexander Belyaev return false; 54297a2bd84SAlexander Belyaev 54397a2bd84SAlexander Belyaev return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() && 54497a2bd84SAlexander Belyaev targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() && 54597a2bd84SAlexander Belyaev targetOp.getMixedStep() == sourceOp.getMixedStep() && 54697a2bd84SAlexander Belyaev targetOp.getMapping() == sourceOp.getMapping(); 54797a2bd84SAlexander Belyaev } 54897a2bd84SAlexander Belyaev 54997a2bd84SAlexander Belyaev /// Check if `target` scf.for can be fused into `source` scf.for. 55097a2bd84SAlexander Belyaev /// 55197a2bd84SAlexander Belyaev /// This simply checks if both loops have the same bounds and steps. No attempt 55297a2bd84SAlexander Belyaev /// is made at checking that the side effects of `target` and `source` are 55397a2bd84SAlexander Belyaev /// independent of each other. 
55497a2bd84SAlexander Belyaev static bool isForWithIdenticalConfiguration(Operation *target, 55597a2bd84SAlexander Belyaev Operation *source) { 55697a2bd84SAlexander Belyaev auto targetOp = dyn_cast<scf::ForOp>(target); 55797a2bd84SAlexander Belyaev auto sourceOp = dyn_cast<scf::ForOp>(source); 55897a2bd84SAlexander Belyaev if (!targetOp || !sourceOp) 55997a2bd84SAlexander Belyaev return false; 56097a2bd84SAlexander Belyaev 56197a2bd84SAlexander Belyaev return targetOp.getLowerBound() == sourceOp.getLowerBound() && 56297a2bd84SAlexander Belyaev targetOp.getUpperBound() == sourceOp.getUpperBound() && 56397a2bd84SAlexander Belyaev targetOp.getStep() == sourceOp.getStep(); 56497a2bd84SAlexander Belyaev } 56597a2bd84SAlexander Belyaev 5669aaf007aSGroverkss DiagnosedSilenceableFailure 567eacda36cSRolf Morel transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, 5689aaf007aSGroverkss transform::TransformResults &results, 5699aaf007aSGroverkss transform::TransformState &state) { 5709aaf007aSGroverkss auto targetOps = state.getPayloadOps(getTarget()); 5719aaf007aSGroverkss auto sourceOps = state.getPayloadOps(getSource()); 5729aaf007aSGroverkss 5739aaf007aSGroverkss if (!llvm::hasSingleElement(targetOps) || 5749aaf007aSGroverkss !llvm::hasSingleElement(sourceOps)) { 5759aaf007aSGroverkss return emitDefiniteFailure() 5769aaf007aSGroverkss << "requires exactly one target handle (got " 5779aaf007aSGroverkss << llvm::range_size(targetOps) << ") and exactly one " 5789aaf007aSGroverkss << "source handle (got " << llvm::range_size(sourceOps) << ")"; 5799aaf007aSGroverkss } 5809aaf007aSGroverkss 58197a2bd84SAlexander Belyaev Operation *target = *targetOps.begin(); 58297a2bd84SAlexander Belyaev Operation *source = *sourceOps.begin(); 5839aaf007aSGroverkss 58497a2bd84SAlexander Belyaev // Check if the target and source are siblings. 
58597a2bd84SAlexander Belyaev DiagnosedSilenceableFailure diag = isOpSibling(target, source); 58697a2bd84SAlexander Belyaev if (!diag.succeeded()) 58797a2bd84SAlexander Belyaev return diag; 5889aaf007aSGroverkss 589eacda36cSRolf Morel Operation *fusedLoop; 59097a2bd84SAlexander Belyaev /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall. 59197a2bd84SAlexander Belyaev if (isForWithIdenticalConfiguration(target, source)) { 592eacda36cSRolf Morel fusedLoop = fuseIndependentSiblingForLoops( 593eacda36cSRolf Morel cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter); 59497a2bd84SAlexander Belyaev } else if (isForallWithIdenticalConfiguration(target, source)) { 595eacda36cSRolf Morel fusedLoop = fuseIndependentSiblingForallLoops( 596eacda36cSRolf Morel cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter); 597eacda36cSRolf Morel } else 5989aaf007aSGroverkss return emitSilenceableFailure(target->getLoc()) 59997a2bd84SAlexander Belyaev << "operations cannot be fused"; 6009aaf007aSGroverkss 6019aaf007aSGroverkss assert(fusedLoop && "failed to fuse operations"); 6029aaf007aSGroverkss 6039aaf007aSGroverkss results.set(cast<OpResult>(getFusedLoop()), {fusedLoop}); 6049aaf007aSGroverkss return DiagnosedSilenceableFailure::success(); 6059aaf007aSGroverkss } 6069aaf007aSGroverkss 6079aaf007aSGroverkss //===----------------------------------------------------------------------===// 6085f0d4f20SAlex Zinenko // Transform op registration 6095f0d4f20SAlex Zinenko //===----------------------------------------------------------------------===// 6105f0d4f20SAlex Zinenko 6115f0d4f20SAlex Zinenko namespace { 6125f0d4f20SAlex Zinenko class SCFTransformDialectExtension 6135f0d4f20SAlex Zinenko : public transform::TransformDialectExtension< 6145f0d4f20SAlex Zinenko SCFTransformDialectExtension> { 6155f0d4f20SAlex Zinenko public: 616*84cc1865SNikhil Kalra MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFTransformDialectExtension) 617*84cc1865SNikhil 
Kalra 618333ee218SAlex Zinenko using Base::Base; 619333ee218SAlex Zinenko 620333ee218SAlex Zinenko void init() { 6214c48f016SMatthias Springer declareGeneratedDialect<affine::AffineDialect>(); 622333ee218SAlex Zinenko declareGeneratedDialect<func::FuncDialect>(); 623333ee218SAlex Zinenko 6245f0d4f20SAlex Zinenko registerTransformOps< 6255f0d4f20SAlex Zinenko #define GET_OP_LIST 6265f0d4f20SAlex Zinenko #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" 6275f0d4f20SAlex Zinenko >(); 6285f0d4f20SAlex Zinenko } 6295f0d4f20SAlex Zinenko }; 6305f0d4f20SAlex Zinenko } // namespace 6315f0d4f20SAlex Zinenko 6325f0d4f20SAlex Zinenko #define GET_OP_CLASSES 6335f0d4f20SAlex Zinenko #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" 6345f0d4f20SAlex Zinenko 6355f0d4f20SAlex Zinenko void mlir::scf::registerTransformDialectExtension(DialectRegistry ®istry) { 6365f0d4f20SAlex Zinenko registry.addExtensions<SCFTransformDialectExtension>(); 6375f0d4f20SAlex Zinenko } 638