xref: /llvm-project/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp (revision 84cc1865ef9202af39404ff4524a9b13df80cfc1)
1 //===- MemRefTransformOps.cpp - Implementation of Memref transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
10 
11 #include "mlir/Analysis/DataLayoutAnalysis.h"
12 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Arith/IR/Arith.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
17 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
18 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
19 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
20 #include "mlir/Dialect/SCF/IR/SCF.h"
21 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
22 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
23 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
24 #include "mlir/Dialect/Vector/IR/VectorOps.h"
25 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
26 #include "mlir/Interfaces/LoopLikeInterface.h"
27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28 #include "llvm/Support/Debug.h"
29 
30 using namespace mlir;
31 
32 #define DEBUG_TYPE "memref-transforms"
33 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
34 
35 //===----------------------------------------------------------------------===//
36 // Apply...ConversionPatternsOp
37 //===----------------------------------------------------------------------===//
38 
39 std::unique_ptr<TypeConverter>
40 transform::MemrefToLLVMTypeConverterOp::getTypeConverter() {
41   LowerToLLVMOptions options(getContext());
42   options.allocLowering =
43       (getUseAlignedAlloc() ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
44                             : LowerToLLVMOptions::AllocLowering::Malloc);
45   options.useGenericFunctions = getUseGenericFunctions();
46 
47   if (getIndexBitwidth() != kDeriveIndexBitwidthFromDataLayout)
48     options.overrideIndexBitwidth(getIndexBitwidth());
49 
50   // TODO: the following two options don't really make sense for
51   // memref_to_llvm_type_converter specifically but we should have a single
52   // to_llvm_type_converter.
53   if (getDataLayout().has_value())
54     options.dataLayout = llvm::DataLayout(getDataLayout().value());
55   options.useBarePtrCallConv = getUseBarePtrCallConv();
56 
57   return std::make_unique<LLVMTypeConverter>(getContext(), options);
58 }
59 
60 StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
61   return "LLVMTypeConverter";
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // Apply...PatternsOp
66 //===----------------------------------------------------------------------===//
67 
68 namespace {
69 class AllocToAllocaPattern : public OpRewritePattern<memref::AllocOp> {
70 public:
71   explicit AllocToAllocaPattern(Operation *analysisRoot, int64_t maxSize = 0)
72       : OpRewritePattern<memref::AllocOp>(analysisRoot->getContext()),
73         dataLayoutAnalysis(analysisRoot), maxSize(maxSize) {}
74 
75   LogicalResult matchAndRewrite(memref::AllocOp op,
76                                 PatternRewriter &rewriter) const override {
77     return success(memref::allocToAlloca(
78         rewriter, op, [this](memref::AllocOp alloc, memref::DeallocOp dealloc) {
79           MemRefType type = alloc.getMemref().getType();
80           if (!type.hasStaticShape())
81             return false;
82 
83           const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(alloc);
84           int64_t elementSize = dataLayout.getTypeSize(type.getElementType());
85           return maxSize == 0 || type.getNumElements() * elementSize < maxSize;
86         }));
87   }
88 
89 private:
90   DataLayoutAnalysis dataLayoutAnalysis;
91   int64_t maxSize;
92 };
93 } // namespace
94 
95 void transform::ApplyAllocToAllocaOp::populatePatterns(
96     RewritePatternSet &patterns) {}
97 
98 void transform::ApplyAllocToAllocaOp::populatePatternsWithState(
99     RewritePatternSet &patterns, transform::TransformState &state) {
100   patterns.insert<AllocToAllocaPattern>(
101       state.getTopLevel(), static_cast<int64_t>(getSizeLimit().value_or(0)));
102 }
103 
104 void transform::ApplyExpandOpsPatternsOp::populatePatterns(
105     RewritePatternSet &patterns) {
106   memref::populateExpandOpsPatterns(patterns);
107 }
108 
109 void transform::ApplyExpandStridedMetadataPatternsOp::populatePatterns(
110     RewritePatternSet &patterns) {
111   memref::populateExpandStridedMetadataPatterns(patterns);
112 }
113 
114 void transform::ApplyExtractAddressComputationsPatternsOp::populatePatterns(
115     RewritePatternSet &patterns) {
116   memref::populateExtractAddressComputationsPatterns(patterns);
117 }
118 
119 void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns(
120     RewritePatternSet &patterns) {
121   memref::populateFoldMemRefAliasOpPatterns(patterns);
122 }
123 
124 void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp::
125     populatePatterns(RewritePatternSet &patterns) {
126   memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
127 }
128 
129 //===----------------------------------------------------------------------===//
130 // AllocaToGlobalOp
131 //===----------------------------------------------------------------------===//
132 
133 DiagnosedSilenceableFailure
134 transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter,
135                                          transform::TransformResults &results,
136                                          transform::TransformState &state) {
137   auto allocaOps = state.getPayloadOps(getAlloca());
138 
139   SmallVector<memref::GlobalOp> globalOps;
140   SmallVector<memref::GetGlobalOp> getGlobalOps;
141 
142   // Transform `memref.alloca`s.
143   for (auto *op : allocaOps) {
144     auto alloca = cast<memref::AllocaOp>(op);
145     MLIRContext *ctx = rewriter.getContext();
146     Location loc = alloca->getLoc();
147 
148     memref::GlobalOp globalOp;
149     {
150       // Find nearest symbol table.
151       Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op);
152       assert(symbolTableOp && "expected alloca payload to be in symbol table");
153       SymbolTable symbolTable(symbolTableOp);
154 
155       // Insert a `memref.global` into the symbol table.
156       Type resultType = alloca.getResult().getType();
157       OpBuilder builder(rewriter.getContext());
158       // TODO: Add a better builder for this.
159       globalOp = builder.create<memref::GlobalOp>(
160           loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"),
161           TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{});
162       symbolTable.insert(globalOp);
163     }
164 
165     // Replace the `memref.alloca` with a `memref.get_global` accessing the
166     // global symbol inserted above.
167     rewriter.setInsertionPoint(alloca);
168     auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(
169         alloca, globalOp.getType(), globalOp.getName());
170 
171     globalOps.push_back(globalOp);
172     getGlobalOps.push_back(getGlobalOp);
173   }
174 
175   // Assemble results.
176   results.set(cast<OpResult>(getGlobal()), globalOps);
177   results.set(cast<OpResult>(getGetGlobal()), getGlobalOps);
178 
179   return DiagnosedSilenceableFailure::success();
180 }
181 
182 void transform::MemRefAllocaToGlobalOp::getEffects(
183     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
184   producesHandle(getOperation()->getOpResults(), effects);
185   consumesHandle(getAllocaMutable(), effects);
186   modifiesPayload(effects);
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // MemRefMultiBufferOp
191 //===----------------------------------------------------------------------===//
192 
193 DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
194     transform::TransformRewriter &rewriter,
195     transform::TransformResults &transformResults,
196     transform::TransformState &state) {
197   SmallVector<Operation *> results;
198   for (Operation *op : state.getPayloadOps(getTarget())) {
199     bool canApplyMultiBuffer = true;
200     auto target = cast<memref::AllocOp>(op);
201     LLVM_DEBUG(DBGS() << "Start multibuffer transform op: " << target << "\n";);
202     // Skip allocations not used in a loop.
203     for (Operation *user : target->getUsers()) {
204       if (isa<memref::DeallocOp>(user))
205         continue;
206       auto loop = user->getParentOfType<LoopLikeOpInterface>();
207       if (!loop) {
208         LLVM_DEBUG(DBGS() << "--allocation not used in a loop\n";
209                    DBGS() << "----due to user: " << *user;);
210         canApplyMultiBuffer = false;
211         break;
212       }
213     }
214     if (!canApplyMultiBuffer) {
215       LLVM_DEBUG(DBGS() << "--cannot apply multibuffering -> Skip\n";);
216       continue;
217     }
218 
219     auto newBuffer =
220         memref::multiBuffer(rewriter, target, getFactor(), getSkipAnalysis());
221 
222     if (failed(newBuffer)) {
223       LLVM_DEBUG(DBGS() << "--op failed to multibuffer\n";);
224       return emitSilenceableFailure(target->getLoc())
225              << "op failed to multibuffer";
226     }
227 
228     results.push_back(*newBuffer);
229   }
230   transformResults.set(cast<OpResult>(getResult()), results);
231   return DiagnosedSilenceableFailure::success();
232 }
233 
234 //===----------------------------------------------------------------------===//
235 // MemRefEraseDeadAllocAndStoresOp
236 //===----------------------------------------------------------------------===//
237 
238 DiagnosedSilenceableFailure
239 transform::MemRefEraseDeadAllocAndStoresOp::applyToOne(
240     transform::TransformRewriter &rewriter, Operation *target,
241     transform::ApplyToEachResultList &results,
242     transform::TransformState &state) {
243   // Apply store to load forwarding and dead store elimination.
244   vector::transferOpflowOpt(rewriter, target);
245   memref::eraseDeadAllocAndStores(rewriter, target);
246   return DiagnosedSilenceableFailure::success();
247 }
248 
249 void transform::MemRefEraseDeadAllocAndStoresOp::getEffects(
250     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
251   transform::onlyReadsHandle(getTargetMutable(), effects);
252   transform::modifiesPayload(effects);
253 }
254 void transform::MemRefEraseDeadAllocAndStoresOp::build(OpBuilder &builder,
255                                                        OperationState &result,
256                                                        Value target) {
257   result.addOperands(target);
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // MemRefMakeLoopIndependentOp
262 //===----------------------------------------------------------------------===//
263 
264 DiagnosedSilenceableFailure transform::MemRefMakeLoopIndependentOp::applyToOne(
265     transform::TransformRewriter &rewriter, Operation *target,
266     transform::ApplyToEachResultList &results,
267     transform::TransformState &state) {
268   // Gather IVs.
269   SmallVector<Value> ivs;
270   Operation *nextOp = target;
271   for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
272     nextOp = nextOp->getParentOfType<scf::ForOp>();
273     if (!nextOp) {
274       DiagnosedSilenceableFailure diag = emitSilenceableError()
275                                          << "could not find " << i
276                                          << "-th enclosing loop";
277       diag.attachNote(target->getLoc()) << "target op";
278       return diag;
279     }
280     ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
281   }
282 
283   // Rewrite IR.
284   FailureOr<Value> replacement = failure();
285   if (auto allocaOp = dyn_cast<memref::AllocaOp>(target)) {
286     replacement = memref::replaceWithIndependentOp(rewriter, allocaOp, ivs);
287   } else {
288     DiagnosedSilenceableFailure diag = emitSilenceableError()
289                                        << "unsupported target op";
290     diag.attachNote(target->getLoc()) << "target op";
291     return diag;
292   }
293   if (failed(replacement)) {
294     DiagnosedSilenceableFailure diag =
295         emitSilenceableError() << "could not make target op loop-independent";
296     diag.attachNote(target->getLoc()) << "target op";
297     return diag;
298   }
299   results.push_back(replacement->getDefiningOp());
300   return DiagnosedSilenceableFailure::success();
301 }
302 
303 //===----------------------------------------------------------------------===//
304 // Transform op registration
305 //===----------------------------------------------------------------------===//
306 
307 namespace {
308 class MemRefTransformDialectExtension
309     : public transform::TransformDialectExtension<
310           MemRefTransformDialectExtension> {
311 public:
312   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MemRefTransformDialectExtension)
313 
314   using Base::Base;
315 
316   void init() {
317     declareGeneratedDialect<affine::AffineDialect>();
318     declareGeneratedDialect<arith::ArithDialect>();
319     declareGeneratedDialect<memref::MemRefDialect>();
320     declareGeneratedDialect<nvgpu::NVGPUDialect>();
321     declareGeneratedDialect<vector::VectorDialect>();
322 
323     registerTransformOps<
324 #define GET_OP_LIST
325 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
326         >();
327   }
328 };
329 } // namespace
330 
331 #define GET_OP_CLASSES
332 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc"
333 
334 void mlir::memref::registerTransformDialectExtension(
335     DialectRegistry &registry) {
336   registry.addExtensions<MemRefTransformDialectExtension>();
337 }
338