1 //===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h" 10 11 #include "mlir/Analysis/DataLayoutAnalysis.h" 12 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Arith/IR/Arith.h" 15 #include "mlir/Dialect/MemRef/IR/MemRef.h" 16 #include "mlir/Dialect/MemRef/Transforms/Passes.h" 17 #include "mlir/Dialect/MemRef/Transforms/Transforms.h" 18 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" 19 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" 20 #include "mlir/Dialect/SCF/IR/SCF.h" 21 #include "mlir/Dialect/Transform/IR/TransformDialect.h" 22 #include "mlir/Dialect/Transform/IR/TransformTypes.h" 23 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" 24 #include "mlir/Dialect/Vector/IR/VectorOps.h" 25 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 26 #include "mlir/Interfaces/LoopLikeInterface.h" 27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 28 #include "llvm/Support/Debug.h" 29 30 using namespace mlir; 31 32 #define DEBUG_TYPE "memref-transforms" 33 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") 34 35 //===----------------------------------------------------------------------===// 36 // Apply...ConversionPatternsOp 37 //===----------------------------------------------------------------------===// 38 39 std::unique_ptr<TypeConverter> 40 transform::MemrefToLLVMTypeConverterOp::getTypeConverter() { 41 LowerToLLVMOptions options(getContext()); 42 options.allocLowering = 43 (getUseAlignedAlloc() ? LowerToLLVMOptions::AllocLowering::AlignedAlloc 44 : LowerToLLVMOptions::AllocLowering::Malloc); 45 options.useGenericFunctions = getUseGenericFunctions(); 46 47 if (getIndexBitwidth() != kDeriveIndexBitwidthFromDataLayout) 48 options.overrideIndexBitwidth(getIndexBitwidth()); 49 50 // TODO: the following two options don't really make sense for 51 // memref_to_llvm_type_converter specifically but we should have a single 52 // to_llvm_type_converter. 53 if (getDataLayout().has_value()) 54 options.dataLayout = llvm::DataLayout(getDataLayout().value()); 55 options.useBarePtrCallConv = getUseBarePtrCallConv(); 56 57 return std::make_unique<LLVMTypeConverter>(getContext(), options); 58 } 59 60 StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() { 61 return "LLVMTypeConverter"; 62 } 63 64 //===----------------------------------------------------------------------===// 65 // Apply...PatternsOp 66 //===----------------------------------------------------------------------===// 67 68 namespace { 69 class AllocToAllocaPattern : public OpRewritePattern<memref::AllocOp> { 70 public: 71 explicit AllocToAllocaPattern(Operation *analysisRoot, int64_t maxSize = 0) 72 : OpRewritePattern<memref::AllocOp>(analysisRoot->getContext()), 73 dataLayoutAnalysis(analysisRoot), maxSize(maxSize) {} 74 75 LogicalResult matchAndRewrite(memref::AllocOp op, 76 PatternRewriter &rewriter) const override { 77 return success(memref::allocToAlloca( 78 rewriter, op, [this](memref::AllocOp alloc, memref::DeallocOp dealloc) { 79 MemRefType type = alloc.getMemref().getType(); 80 if (!type.hasStaticShape()) 81 return false; 82 83 const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(alloc); 84 int64_t elementSize = dataLayout.getTypeSize(type.getElementType()); 85 return maxSize == 0 || type.getNumElements() * elementSize < maxSize; 86 })); 87 } 88 89 private: 90 DataLayoutAnalysis dataLayoutAnalysis; 91 int64_t maxSize; 92 }; 93 } // namespace 94 95 void transform::ApplyAllocToAllocaOp::populatePatterns( 96 RewritePatternSet &patterns) {} 97 98 void transform::ApplyAllocToAllocaOp::populatePatternsWithState( 99 RewritePatternSet &patterns, transform::TransformState &state) { 100 patterns.insert<AllocToAllocaPattern>( 101 state.getTopLevel(), static_cast<int64_t>(getSizeLimit().value_or(0))); 102 } 103 104 void transform::ApplyExpandOpsPatternsOp::populatePatterns( 105 RewritePatternSet &patterns) { 106 memref::populateExpandOpsPatterns(patterns); 107 } 108 109 void transform::ApplyExpandStridedMetadataPatternsOp::populatePatterns( 110 RewritePatternSet &patterns) { 111 memref::populateExpandStridedMetadataPatterns(patterns); 112 } 113 114 void transform::ApplyExtractAddressComputationsPatternsOp::populatePatterns( 115 RewritePatternSet &patterns) { 116 memref::populateExtractAddressComputationsPatterns(patterns); 117 } 118 119 void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns( 120 RewritePatternSet &patterns) { 121 memref::populateFoldMemRefAliasOpPatterns(patterns); 122 } 123 124 void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp:: 125 populatePatterns(RewritePatternSet &patterns) { 126 memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); 127 } 128 129 //===----------------------------------------------------------------------===// 130 // AllocaToGlobalOp 131 //===----------------------------------------------------------------------===// 132 133 DiagnosedSilenceableFailure 134 transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter, 135 transform::TransformResults &results, 136 transform::TransformState &state) { 137 auto allocaOps = state.getPayloadOps(getAlloca()); 138 139 SmallVector<memref::GlobalOp> globalOps; 140 SmallVector<memref::GetGlobalOp> getGlobalOps; 141 142 // Transform `memref.alloca`s. 143 for (auto *op : allocaOps) { 144 auto alloca = cast<memref::AllocaOp>(op); 145 MLIRContext *ctx = rewriter.getContext(); 146 Location loc = alloca->getLoc(); 147 148 memref::GlobalOp globalOp; 149 { 150 // Find nearest symbol table. 151 Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op); 152 assert(symbolTableOp && "expected alloca payload to be in symbol table"); 153 SymbolTable symbolTable(symbolTableOp); 154 155 // Insert a `memref.global` into the symbol table. 156 Type resultType = alloca.getResult().getType(); 157 OpBuilder builder(rewriter.getContext()); 158 // TODO: Add a better builder for this. 159 globalOp = builder.create<memref::GlobalOp>( 160 loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"), 161 TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{}); 162 symbolTable.insert(globalOp); 163 } 164 165 // Replace the `memref.alloca` with a `memref.get_global` accessing the 166 // global symbol inserted above. 167 rewriter.setInsertionPoint(alloca); 168 auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>( 169 alloca, globalOp.getType(), globalOp.getName()); 170 171 globalOps.push_back(globalOp); 172 getGlobalOps.push_back(getGlobalOp); 173 } 174 175 // Assemble results. 176 results.set(cast<OpResult>(getGlobal()), globalOps); 177 results.set(cast<OpResult>(getGetGlobal()), getGlobalOps); 178 179 return DiagnosedSilenceableFailure::success(); 180 } 181 182 void transform::MemRefAllocaToGlobalOp::getEffects( 183 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 184 producesHandle(getOperation()->getOpResults(), effects); 185 consumesHandle(getAllocaMutable(), effects); 186 modifiesPayload(effects); 187 } 188 189 //===----------------------------------------------------------------------===// 190 // MemRefMultiBufferOp 191 //===----------------------------------------------------------------------===// 192 193 DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply( 194 transform::TransformRewriter &rewriter, 195 transform::TransformResults &transformResults, 196 transform::TransformState &state) { 197 SmallVector<Operation *> results; 198 for (Operation *op : state.getPayloadOps(getTarget())) { 199 bool canApplyMultiBuffer = true; 200 auto target = cast<memref::AllocOp>(op); 201 LLVM_DEBUG(DBGS() << "Start multibuffer transform op: " << target << "\n";); 202 // Skip allocations not used in a loop. 203 for (Operation *user : target->getUsers()) { 204 if (isa<memref::DeallocOp>(user)) 205 continue; 206 auto loop = user->getParentOfType<LoopLikeOpInterface>(); 207 if (!loop) { 208 LLVM_DEBUG(DBGS() << "--allocation not used in a loop\n"; 209 DBGS() << "----due to user: " << *user;); 210 canApplyMultiBuffer = false; 211 break; 212 } 213 } 214 if (!canApplyMultiBuffer) { 215 LLVM_DEBUG(DBGS() << "--cannot apply multibuffering -> Skip\n";); 216 continue; 217 } 218 219 auto newBuffer = 220 memref::multiBuffer(rewriter, target, getFactor(), getSkipAnalysis()); 221 222 if (failed(newBuffer)) { 223 LLVM_DEBUG(DBGS() << "--op failed to multibuffer\n";); 224 return emitSilenceableFailure(target->getLoc()) 225 << "op failed to multibuffer"; 226 } 227 228 results.push_back(*newBuffer); 229 } 230 transformResults.set(cast<OpResult>(getResult()), results); 231 return DiagnosedSilenceableFailure::success(); 232 } 233 234 //===----------------------------------------------------------------------===// 235 // MemRefEraseDeadAllocAndStoresOp 236 //===----------------------------------------------------------------------===// 237 238 DiagnosedSilenceableFailure 239 transform::MemRefEraseDeadAllocAndStoresOp::applyToOne( 240 transform::TransformRewriter &rewriter, Operation *target, 241 transform::ApplyToEachResultList &results, 242 transform::TransformState &state) { 243 // Apply store to load forwarding and dead store elimination. 244 vector::transferOpflowOpt(rewriter, target); 245 memref::eraseDeadAllocAndStores(rewriter, target); 246 return DiagnosedSilenceableFailure::success(); 247 } 248 249 void transform::MemRefEraseDeadAllocAndStoresOp::getEffects( 250 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { 251 transform::onlyReadsHandle(getTargetMutable(), effects); 252 transform::modifiesPayload(effects); 253 } 254 void transform::MemRefEraseDeadAllocAndStoresOp::build(OpBuilder &builder, 255 OperationState &result, 256 Value target) { 257 result.addOperands(target); 258 } 259 260 //===----------------------------------------------------------------------===// 261 // MemRefMakeLoopIndependentOp 262 //===----------------------------------------------------------------------===// 263 264 DiagnosedSilenceableFailure transform::MemRefMakeLoopIndependentOp::applyToOne( 265 transform::TransformRewriter &rewriter, Operation *target, 266 transform::ApplyToEachResultList &results, 267 transform::TransformState &state) { 268 // Gather IVs. 269 SmallVector<Value> ivs; 270 Operation *nextOp = target; 271 for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) { 272 nextOp = nextOp->getParentOfType<scf::ForOp>(); 273 if (!nextOp) { 274 DiagnosedSilenceableFailure diag = emitSilenceableError() 275 << "could not find " << i 276 << "-th enclosing loop"; 277 diag.attachNote(target->getLoc()) << "target op"; 278 return diag; 279 } 280 ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar()); 281 } 282 283 // Rewrite IR. 284 FailureOr<Value> replacement = failure(); 285 if (auto allocaOp = dyn_cast<memref::AllocaOp>(target)) { 286 replacement = memref::replaceWithIndependentOp(rewriter, allocaOp, ivs); 287 } else { 288 DiagnosedSilenceableFailure diag = emitSilenceableError() 289 << "unsupported target op"; 290 diag.attachNote(target->getLoc()) << "target op"; 291 return diag; 292 } 293 if (failed(replacement)) { 294 DiagnosedSilenceableFailure diag = 295 emitSilenceableError() << "could not make target op loop-independent"; 296 diag.attachNote(target->getLoc()) << "target op"; 297 return diag; 298 } 299 results.push_back(replacement->getDefiningOp()); 300 return DiagnosedSilenceableFailure::success(); 301 } 302 303 //===----------------------------------------------------------------------===// 304 // Transform op registration 305 //===----------------------------------------------------------------------===// 306 307 namespace { 308 class MemRefTransformDialectExtension 309 : public transform::TransformDialectExtension< 310 MemRefTransformDialectExtension> { 311 public: 312 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MemRefTransformDialectExtension) 313 314 using Base::Base; 315 316 void init() { 317 declareGeneratedDialect<affine::AffineDialect>(); 318 declareGeneratedDialect<arith::ArithDialect>(); 319 declareGeneratedDialect<memref::MemRefDialect>(); 320 declareGeneratedDialect<nvgpu::NVGPUDialect>(); 321 declareGeneratedDialect<vector::VectorDialect>(); 322 323 registerTransformOps< 324 #define GET_OP_LIST 325 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc" 326 >(); 327 } 328 }; 329 } // namespace 330 331 #define GET_OP_CLASSES 332 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc" 333 334 void mlir::memref::registerTransformDialectExtension( 335 DialectRegistry ®istry) { 336 registry.addExtensions<MemRefTransformDialectExtension>(); 337 } 338