//===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h" #include "mlir/Analysis/DataLayoutAnalysis.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/Debug.h" using namespace mlir; #define DEBUG_TYPE "memref-transforms" #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") //===----------------------------------------------------------------------===// // Apply...ConversionPatternsOp //===----------------------------------------------------------------------===// std::unique_ptr transform::MemrefToLLVMTypeConverterOp::getTypeConverter() { LowerToLLVMOptions options(getContext()); options.allocLowering = (getUseAlignedAlloc() ? LowerToLLVMOptions::AllocLowering::AlignedAlloc : LowerToLLVMOptions::AllocLowering::Malloc); options.useGenericFunctions = getUseGenericFunctions(); if (getIndexBitwidth() != kDeriveIndexBitwidthFromDataLayout) options.overrideIndexBitwidth(getIndexBitwidth()); // TODO: the following two options don't really make sense for // memref_to_llvm_type_converter specifically but we should have a single // to_llvm_type_converter. if (getDataLayout().has_value()) options.dataLayout = llvm::DataLayout(getDataLayout().value()); options.useBarePtrCallConv = getUseBarePtrCallConv(); return std::make_unique(getContext(), options); } StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() { return "LLVMTypeConverter"; } //===----------------------------------------------------------------------===// // Apply...PatternsOp //===----------------------------------------------------------------------===// namespace { class AllocToAllocaPattern : public OpRewritePattern { public: explicit AllocToAllocaPattern(Operation *analysisRoot, int64_t maxSize = 0) : OpRewritePattern(analysisRoot->getContext()), dataLayoutAnalysis(analysisRoot), maxSize(maxSize) {} LogicalResult matchAndRewrite(memref::AllocOp op, PatternRewriter &rewriter) const override { return success(memref::allocToAlloca( rewriter, op, [this](memref::AllocOp alloc, memref::DeallocOp dealloc) { MemRefType type = alloc.getMemref().getType(); if (!type.hasStaticShape()) return false; const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(alloc); int64_t elementSize = dataLayout.getTypeSize(type.getElementType()); return maxSize == 0 || type.getNumElements() * elementSize < maxSize; })); } private: DataLayoutAnalysis dataLayoutAnalysis; int64_t maxSize; }; } // namespace void transform::ApplyAllocToAllocaOp::populatePatterns( RewritePatternSet &patterns) {} void transform::ApplyAllocToAllocaOp::populatePatternsWithState( RewritePatternSet &patterns, transform::TransformState &state) { patterns.insert( state.getTopLevel(), static_cast(getSizeLimit().value_or(0))); } void transform::ApplyExpandOpsPatternsOp::populatePatterns( RewritePatternSet &patterns) { memref::populateExpandOpsPatterns(patterns); } void transform::ApplyExpandStridedMetadataPatternsOp::populatePatterns( RewritePatternSet &patterns) { memref::populateExpandStridedMetadataPatterns(patterns); } void transform::ApplyExtractAddressComputationsPatternsOp::populatePatterns( RewritePatternSet &patterns) { memref::populateExtractAddressComputationsPatterns(patterns); } void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns( RewritePatternSet &patterns) { memref::populateFoldMemRefAliasOpPatterns(patterns); } void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp:: populatePatterns(RewritePatternSet &patterns) { memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); } //===----------------------------------------------------------------------===// // AllocaToGlobalOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto allocaOps = state.getPayloadOps(getAlloca()); SmallVector globalOps; SmallVector getGlobalOps; // Transform `memref.alloca`s. for (auto *op : allocaOps) { auto alloca = cast(op); MLIRContext *ctx = rewriter.getContext(); Location loc = alloca->getLoc(); memref::GlobalOp globalOp; { // Find nearest symbol table. Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op); assert(symbolTableOp && "expected alloca payload to be in symbol table"); SymbolTable symbolTable(symbolTableOp); // Insert a `memref.global` into the symbol table. Type resultType = alloca.getResult().getType(); OpBuilder builder(rewriter.getContext()); // TODO: Add a better builder for this. globalOp = builder.create( loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"), TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{}); symbolTable.insert(globalOp); } // Replace the `memref.alloca` with a `memref.get_global` accessing the // global symbol inserted above. rewriter.setInsertionPoint(alloca); auto getGlobalOp = rewriter.replaceOpWithNewOp( alloca, globalOp.getType(), globalOp.getName()); globalOps.push_back(globalOp); getGlobalOps.push_back(getGlobalOp); } // Assemble results. results.set(cast(getGlobal()), globalOps); results.set(cast(getGetGlobal()), getGlobalOps); return DiagnosedSilenceableFailure::success(); } void transform::MemRefAllocaToGlobalOp::getEffects( SmallVectorImpl &effects) { producesHandle(getOperation()->getOpResults(), effects); consumesHandle(getAllocaMutable(), effects); modifiesPayload(effects); } //===----------------------------------------------------------------------===// // MemRefMultiBufferOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &transformResults, transform::TransformState &state) { SmallVector results; for (Operation *op : state.getPayloadOps(getTarget())) { bool canApplyMultiBuffer = true; auto target = cast(op); LLVM_DEBUG(DBGS() << "Start multibuffer transform op: " << target << "\n";); // Skip allocations not used in a loop. for (Operation *user : target->getUsers()) { if (isa(user)) continue; auto loop = user->getParentOfType(); if (!loop) { LLVM_DEBUG(DBGS() << "--allocation not used in a loop\n"; DBGS() << "----due to user: " << *user;); canApplyMultiBuffer = false; break; } } if (!canApplyMultiBuffer) { LLVM_DEBUG(DBGS() << "--cannot apply multibuffering -> Skip\n";); continue; } auto newBuffer = memref::multiBuffer(rewriter, target, getFactor(), getSkipAnalysis()); if (failed(newBuffer)) { LLVM_DEBUG(DBGS() << "--op failed to multibuffer\n";); return emitSilenceableFailure(target->getLoc()) << "op failed to multibuffer"; } results.push_back(*newBuffer); } transformResults.set(cast(getResult()), results); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // MemRefEraseDeadAllocAndStoresOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MemRefEraseDeadAllocAndStoresOp::applyToOne( transform::TransformRewriter &rewriter, Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { // Apply store to load forwarding and dead store elimination. vector::transferOpflowOpt(rewriter, target); memref::eraseDeadAllocAndStores(rewriter, target); return DiagnosedSilenceableFailure::success(); } void transform::MemRefEraseDeadAllocAndStoresOp::getEffects( SmallVectorImpl &effects) { transform::onlyReadsHandle(getTargetMutable(), effects); transform::modifiesPayload(effects); } void transform::MemRefEraseDeadAllocAndStoresOp::build(OpBuilder &builder, OperationState &result, Value target) { result.addOperands(target); } //===----------------------------------------------------------------------===// // MemRefMakeLoopIndependentOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MemRefMakeLoopIndependentOp::applyToOne( transform::TransformRewriter &rewriter, Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { // Gather IVs. SmallVector ivs; Operation *nextOp = target; for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) { nextOp = nextOp->getParentOfType(); if (!nextOp) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "could not find " << i << "-th enclosing loop"; diag.attachNote(target->getLoc()) << "target op"; return diag; } ivs.push_back(cast(nextOp).getInductionVar()); } // Rewrite IR. FailureOr replacement = failure(); if (auto allocaOp = dyn_cast(target)) { replacement = memref::replaceWithIndependentOp(rewriter, allocaOp, ivs); } else { DiagnosedSilenceableFailure diag = emitSilenceableError() << "unsupported target op"; diag.attachNote(target->getLoc()) << "target op"; return diag; } if (failed(replacement)) { DiagnosedSilenceableFailure diag = emitSilenceableError() << "could not make target op loop-independent"; diag.attachNote(target->getLoc()) << "target op"; return diag; } results.push_back(replacement->getDefiningOp()); return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// namespace { class MemRefTransformDialectExtension : public transform::TransformDialectExtension< MemRefTransformDialectExtension> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MemRefTransformDialectExtension) using Base::Base; void init() { declareGeneratedDialect(); declareGeneratedDialect(); declareGeneratedDialect(); declareGeneratedDialect(); declareGeneratedDialect(); registerTransformOps< #define GET_OP_LIST #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc" >(); } }; } // namespace #define GET_OP_CLASSES #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc" void mlir::memref::registerTransformDialectExtension( DialectRegistry ®istry) { registry.addExtensions(); }