//===- SCFTransformOps.cpp - Implementation of SCF transformation ops -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/LoopUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/SCF/Transforms/Transforms.h" #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/OpDefinition.h" using namespace mlir; using namespace mlir::affine; //===----------------------------------------------------------------------===// // Apply...PatternsOp //===----------------------------------------------------------------------===// void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns( RewritePatternSet &patterns) { scf::populateSCFForLoopCanonicalizationPatterns(patterns); } void transform::ApplySCFStructuralConversionPatternsOp::populatePatterns( TypeConverter &typeConverter, RewritePatternSet &patterns) { scf::populateSCFStructuralTypeConversions(typeConverter, patterns); } void transform::ApplySCFStructuralConversionPatternsOp:: populateConversionTargetRules(const TypeConverter &typeConverter, 
ConversionTarget &conversionTarget) { scf::populateSCFStructuralTypeConversionTarget(typeConverter, conversionTarget); } void transform::ApplySCFToControlFlowPatternsOp::populatePatterns( TypeConverter &typeConverter, RewritePatternSet &patterns) { populateSCFToControlFlowConversionPatterns(patterns); } //===----------------------------------------------------------------------===// // ForallToForOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ForallToForOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto payload = state.getPayloadOps(getTarget()); if (!llvm::hasSingleElement(payload)) return emitSilenceableError() << "expected a single payload op"; auto target = dyn_cast(*payload.begin()); if (!target) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected the payload to be scf.forall"; diag.attachNote((*payload.begin())->getLoc()) << "payload op"; return diag; } if (!target.getOutputs().empty()) { return emitSilenceableError() << "unsupported shared outputs (didn't bufferize?)"; } SmallVector lbs = target.getMixedLowerBound(); if (getNumResults() != lbs.size()) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "op expects as many results (" << getNumResults() << ") as payload has induction variables (" << lbs.size() << ")"; diag.attachNote(target.getLoc()) << "payload op"; return diag; } SmallVector opResults; if (failed(scf::forallToForLoop(rewriter, target, &opResults))) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "failed to convert forall into for"; return diag; } for (auto &&[i, res] : llvm::enumerate(opResults)) { results.set(cast(getTransformed()[i]), {res}); } return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // ForallToForOp 
//===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto payload = state.getPayloadOps(getTarget()); if (!llvm::hasSingleElement(payload)) return emitSilenceableError() << "expected a single payload op"; auto target = dyn_cast(*payload.begin()); if (!target) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "expected the payload to be scf.forall"; diag.attachNote((*payload.begin())->getLoc()) << "payload op"; return diag; } if (!target.getOutputs().empty()) { return emitSilenceableError() << "unsupported shared outputs (didn't bufferize?)"; } if (getNumResults() != 1) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "op expects one result, given " << getNumResults(); diag.attachNote(target.getLoc()) << "payload op"; return diag; } scf::ParallelOp opResult; if (failed(scf::forallToParallelLoop(rewriter, target, &opResult))) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "failed to convert forall into parallel"; return diag; } results.set(cast(getTransformed()[0]), {opResult}); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // LoopOutlineOp //===----------------------------------------------------------------------===// /// Wraps the given operation `op` into an `scf.execute_region` operation. Uses /// the provided rewriter for all operations to remain compatible with the /// rewriting infra, as opposed to just splicing the op in place. 
static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
                                                Operation *op) {
  // NOTE(review): the `b.create<...>` template arguments were reconstructed
  // from context after being lost in extraction; confirm against upstream.

  // Only single-region ops are handled; bail out before touching the IR so
  // that the caller can still inspect `op`.
  if (op->getNumRegions() != 1)
    return nullptr;

  OpBuilder::InsertionGuard outerGuard(b);
  b.setInsertionPoint(op);
  scf::ExecuteRegionOp executeRegionOp =
      b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes());
  {
    // Build the body of the execute_region: a region-less clone of `op` that
    // takes over the original region, followed by a yield of its results.
    OpBuilder::InsertionGuard innerGuard(b);
    b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock());
    Operation *clonedOp = b.cloneWithoutRegions(*op);
    Region &clonedRegion = clonedOp->getRegions().front();
    assert(clonedRegion.empty() && "expected empty region");
    b.inlineRegionBefore(op->getRegions().front(), clonedRegion,
                         clonedRegion.end());
    b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults());
  }
  b.replaceOp(op, executeRegionOp.getResults());
  return executeRegionOp;
}

/// Outlines every payload op into a function and replaces it with a call.
/// Each payload op is first wrapped into an `scf.execute_region`, whose
/// single-block region is then outlined. When a nearest symbol table exists,
/// the new function is inserted into it and the call is updated to reference
/// the function's symbol after insertion. The `function` and `call` results
/// are set to the created functions and calls, in payload order.
DiagnosedSilenceableFailure
transform::LoopOutlineOp::apply(transform::TransformRewriter &rewriter,
                                transform::TransformResults &results,
                                transform::TransformState &state) {
  // NOTE(review): the container element types and the `FailureOr`/`cast`
  // template arguments were reconstructed from context after being lost in
  // extraction; confirm against upstream.
  SmallVector<Operation *> functions;
  SmallVector<Operation *> calls;
  DenseMap<Operation *, SymbolTable> symbolTables;
  for (Operation *target : state.getPayloadOps(getTarget())) {
    // Capture everything needed from `target` up front: it is replaced by
    // `wrapInExecuteRegion` on success and must not be dereferenced after.
    Location location = target->getLoc();
    Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target);
    scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target);
    if (!exec) {
      DiagnosedSilenceableFailure diag = emitSilenceableError()
                                         << "failed to outline";
      diag.attachNote(location) << "target op";
      return diag;
    }
    func::CallOp call;
    FailureOr<func::FuncOp> outlined = outlineSingleBlockRegion(
        rewriter, location, exec.getRegion(), getFuncName(), &call);

    // Report against the execute_region (created at the same location)
    // rather than `target`, which was replaced by `wrapInExecuteRegion`.
    if (failed(outlined))
      return emitDefaultDefiniteFailure(exec.getOperation());

    if (symbolTableOp) {
      SymbolTable &symbolTable =
          symbolTables.try_emplace(symbolTableOp, symbolTableOp)
              .first->getSecond();
      symbolTable.insert(*outlined);
      call.setCalleeAttr(FlatSymbolRefAttr::get(*outlined));
    }
    functions.push_back(*outlined);
    calls.push_back(call);
  }
  results.set(cast<OpResult>(getFunction()), functions);
  results.set(cast<OpResult>(getCall()), calls);
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===// // LoopPeelOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter, scf::ForOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { scf::ForOp result; if (getPeelFront()) { LogicalResult status = scf::peelForLoopFirstIteration(rewriter, target, result); if (failed(status)) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "failed to peel the first iteration"; return diag; } } else { LogicalResult status = scf::peelForLoopAndSimplifyBounds(rewriter, target, result); if (failed(status)) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "failed to peel the last iteration"; return diag; } } results.push_back(target); results.push_back(result); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // LoopPipelineOp //===----------------------------------------------------------------------===// /// Callback for PipeliningOption. Populates `schedule` with the mapping from an /// operation to its logical time position given the iteration interval and the /// read latency. The latter is only relevant for vector transfers. 
static void loopScheduling(scf::ForOp forOp, std::vector> &schedule, unsigned iterationInterval, unsigned readLatency) { auto getLatency = [&](Operation *op) -> unsigned { if (isa(op)) return readLatency; return 1; }; std::optional ubConstant = getConstantIntValue(forOp.getUpperBound()); std::optional lbConstant = getConstantIntValue(forOp.getLowerBound()); DenseMap opCycles; std::map> wrappedSchedule; for (Operation &op : forOp.getBody()->getOperations()) { if (isa(op)) continue; unsigned earlyCycle = 0; for (Value operand : op.getOperands()) { Operation *def = operand.getDefiningOp(); if (!def) continue; if (ubConstant && lbConstant) { unsigned ubInt = ubConstant.value(); unsigned lbInt = lbConstant.value(); auto minLatency = std::min(ubInt - lbInt - 1, getLatency(def)); earlyCycle = std::max(earlyCycle, opCycles[def] + minLatency); } else { earlyCycle = std::max(earlyCycle, opCycles[def] + getLatency(def)); } } opCycles[&op] = earlyCycle; wrappedSchedule[earlyCycle % iterationInterval].push_back(&op); } for (const auto &it : wrappedSchedule) { for (Operation *op : it.second) { unsigned cycle = opCycles[op]; schedule.emplace_back(op, cycle / iterationInterval); } } } DiagnosedSilenceableFailure transform::LoopPipelineOp::applyToOne(transform::TransformRewriter &rewriter, scf::ForOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { scf::PipeliningOption options; options.getScheduleFn = [this](scf::ForOp forOp, std::vector> &schedule) mutable { loopScheduling(forOp, schedule, getIterationInterval(), getReadLatency()); }; scf::ForLoopPipeliningPattern pattern(options, target->getContext()); rewriter.setInsertionPoint(target); FailureOr patternResult = scf::pipelineForLoop(rewriter, target, options); if (succeeded(patternResult)) { results.push_back(*patternResult); return DiagnosedSilenceableFailure::success(); } return emitDefaultSilenceableFailure(target); } 
//===----------------------------------------------------------------------===// // LoopPromoteIfOneIterationOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::LoopPromoteIfOneIterationOp::applyToOne( transform::TransformRewriter &rewriter, LoopLikeOpInterface target, transform::ApplyToEachResultList &results, transform::TransformState &state) { (void)target.promoteIfSingleIteration(rewriter); return DiagnosedSilenceableFailure::success(); } void transform::LoopPromoteIfOneIterationOp::getEffects( SmallVectorImpl &effects) { consumesHandle(getTargetMutable(), effects); modifiesPayload(effects); } //===----------------------------------------------------------------------===// // LoopUnrollOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter, Operation *op, transform::ApplyToEachResultList &results, transform::TransformState &state) { LogicalResult result(failure()); if (scf::ForOp scfFor = dyn_cast(op)) result = loopUnrollByFactor(scfFor, getFactor()); else if (AffineForOp affineFor = dyn_cast(op)) result = loopUnrollByFactor(affineFor, getFactor()); else return emitSilenceableError() << "failed to unroll, incorrect type of payload"; if (failed(result)) return emitSilenceableError() << "failed to unroll"; return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // LoopUnrollAndJamOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne( transform::TransformRewriter &rewriter, Operation *op, transform::ApplyToEachResultList &results, transform::TransformState &state) { LogicalResult result(failure()); if (scf::ForOp scfFor = dyn_cast(op)) result = loopUnrollJamByFactor(scfFor, getFactor()); else 
if (AffineForOp affineFor = dyn_cast(op)) result = loopUnrollJamByFactor(affineFor, getFactor()); else return emitSilenceableError() << "failed to unroll and jam, incorrect type of payload"; if (failed(result)) return emitSilenceableError() << "failed to unroll and jam"; return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // LoopCoalesceOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter, Operation *op, transform::ApplyToEachResultList &results, transform::TransformState &state) { LogicalResult result(failure()); if (scf::ForOp scfForOp = dyn_cast(op)) result = coalescePerfectlyNestedSCFForLoops(scfForOp); else if (AffineForOp affineForOp = dyn_cast(op)) result = coalescePerfectlyNestedAffineLoops(affineForOp); results.push_back(op); if (failed(result)) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "failed to coalesce"; return diag; } return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // TakeAssumedBranchOp //===----------------------------------------------------------------------===// /// Replaces the given op with the contents of the given single-block region, /// using the operands of the block terminator to replace operation results. 
static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op, Region ®ion) { assert(llvm::hasSingleElement(region) && "expected single-region block"); Block *block = ®ion.front(); Operation *terminator = block->getTerminator(); ValueRange results = terminator->getOperands(); rewriter.inlineBlockBefore(block, op, /*blockArgs=*/{}); rewriter.replaceOp(op, results); rewriter.eraseOp(terminator); } DiagnosedSilenceableFailure transform::TakeAssumedBranchOp::applyToOne( transform::TransformRewriter &rewriter, scf::IfOp ifOp, transform::ApplyToEachResultList &results, transform::TransformState &state) { rewriter.setInsertionPoint(ifOp); Region ®ion = getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion(); if (!llvm::hasSingleElement(region)) { return emitDefiniteFailure() << "requires an scf.if op with a single-block " << ((getTakeElseBranch()) ? "`else`" : "`then`") << " region"; } replaceOpWithRegion(rewriter, ifOp, region); return DiagnosedSilenceableFailure::success(); } void transform::TakeAssumedBranchOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getTargetMutable(), effects); modifiesPayload(effects); } //===----------------------------------------------------------------------===// // LoopFuseSiblingOp //===----------------------------------------------------------------------===// /// Check if `target` and `source` are siblings, in the context that `target` /// is being fused into `source`. /// /// This is a simple check that just checks if both operations are in the same /// block and some checks to ensure that the fused IR does not violate /// dominance. static DiagnosedSilenceableFailure isOpSibling(Operation *target, Operation *source) { // Check if both operations are same. if (target == source) return emitSilenceableFailure(source) << "target and source need to be different loops"; // Check if both operations are in the same block. 
if (target->getBlock() != source->getBlock()) return emitSilenceableFailure(source) << "target and source are not in the same block"; // Check if fusion will violate dominance. DominanceInfo domInfo(source); if (target->isBeforeInBlock(source)) { // Since `target` is before `source`, all users of results of `target` // need to be dominated by `source`. for (Operation *user : target->getUsers()) { if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) { return emitSilenceableFailure(target) << "user of results of target should be properly dominated by " "source"; } } } else { // Since `target` is after `source`, all values used by `target` need // to dominate `source`. // Check if operands of `target` are dominated by `source`. for (Value operand : target->getOperands()) { Operation *operandOp = operand.getDefiningOp(); // Operands without defining operations are block arguments. When `target` // and `source` occur in the same block, these operands dominate `source`. if (!operandOp) continue; // Operand's defining operation should properly dominate `source`. if (!domInfo.properlyDominates(operandOp, source, /*enclosingOpOk=*/false)) return emitSilenceableFailure(target) << "operands of target should be properly dominated by source"; } // Check if values used by `target` are dominated by `source`. bool failed = false; OpOperand *failedValue = nullptr; visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) { Operation *operandOp = operand->get().getDefiningOp(); if (operandOp && !domInfo.properlyDominates(operandOp, source, /*enclosingOpOk=*/false)) { // `operand` is not an argument of an enclosing block and the defining // op of `operand` is outside `target` but does not dominate `source`. 
failed = true; failedValue = operand; } }); if (failed) return emitSilenceableFailure(failedValue->getOwner()) << "values used inside regions of target should be properly " "dominated by source"; } return DiagnosedSilenceableFailure::success(); } /// Check if `target` scf.forall can be fused into `source` scf.forall. /// /// This simply checks if both loops have the same bounds, steps and mapping. /// No attempt is made at checking that the side effects of `target` and /// `source` are independent of each other. static bool isForallWithIdenticalConfiguration(Operation *target, Operation *source) { auto targetOp = dyn_cast(target); auto sourceOp = dyn_cast(source); if (!targetOp || !sourceOp) return false; return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() && targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() && targetOp.getMixedStep() == sourceOp.getMixedStep() && targetOp.getMapping() == sourceOp.getMapping(); } /// Check if `target` scf.for can be fused into `source` scf.for. /// /// This simply checks if both loops have the same bounds and steps. No attempt /// is made at checking that the side effects of `target` and `source` are /// independent of each other. 
static bool isForWithIdenticalConfiguration(Operation *target, Operation *source) { auto targetOp = dyn_cast(target); auto sourceOp = dyn_cast(source); if (!targetOp || !sourceOp) return false; return targetOp.getLowerBound() == sourceOp.getLowerBound() && targetOp.getUpperBound() == sourceOp.getUpperBound() && targetOp.getStep() == sourceOp.getStep(); } DiagnosedSilenceableFailure transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto targetOps = state.getPayloadOps(getTarget()); auto sourceOps = state.getPayloadOps(getSource()); if (!llvm::hasSingleElement(targetOps) || !llvm::hasSingleElement(sourceOps)) { return emitDefiniteFailure() << "requires exactly one target handle (got " << llvm::range_size(targetOps) << ") and exactly one " << "source handle (got " << llvm::range_size(sourceOps) << ")"; } Operation *target = *targetOps.begin(); Operation *source = *sourceOps.begin(); // Check if the target and source are siblings. DiagnosedSilenceableFailure diag = isOpSibling(target, source); if (!diag.succeeded()) return diag; Operation *fusedLoop; /// TODO: Support fusion for loop-like ops besides scf.for and scf.forall. 
if (isForWithIdenticalConfiguration(target, source)) { fusedLoop = fuseIndependentSiblingForLoops( cast(target), cast(source), rewriter); } else if (isForallWithIdenticalConfiguration(target, source)) { fusedLoop = fuseIndependentSiblingForallLoops( cast(target), cast(source), rewriter); } else return emitSilenceableFailure(target->getLoc()) << "operations cannot be fused"; assert(fusedLoop && "failed to fuse operations"); results.set(cast(getFusedLoop()), {fusedLoop}); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// namespace { class SCFTransformDialectExtension : public transform::TransformDialectExtension< SCFTransformDialectExtension> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFTransformDialectExtension) using Base::Base; void init() { declareGeneratedDialect(); declareGeneratedDialect(); registerTransformOps< #define GET_OP_LIST #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" >(); } }; } // namespace #define GET_OP_CLASSES #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" void mlir::scf::registerTransformDialectExtension(DialectRegistry ®istry) { registry.addExtensions(); }