xref: /llvm-project/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp (revision 915fce040271c77df1ff9b2c8797c441cec0d18d)
1 //===- AffineToStandard.cpp - Lower affine constructs to primitives -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file lowers affine constructs (If and For statements, AffineApply
10 // operations) within a function into their standard If and For equivalent ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
15 
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
18 #include "mlir/Dialect/Affine/Utils.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/SCF/IR/SCF.h"
21 #include "mlir/Dialect/Vector/IR/VectorOps.h"
22 #include "mlir/IR/IRMapping.h"
23 #include "mlir/IR/IntegerSet.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/Transforms/DialectConversion.h"
26 #include "mlir/Transforms/Passes.h"
27 
28 namespace mlir {
29 #define GEN_PASS_DEF_CONVERTAFFINETOSTANDARD
30 #include "mlir/Conversion/Passes.h.inc"
31 } // namespace mlir
32 
33 using namespace mlir;
34 using namespace mlir::affine;
35 using namespace mlir::vector;
36 
37 /// Given a range of values, emit the code that reduces them with "min" or "max"
38 /// depending on the provided comparison predicate, sgt for max and slt for min.
39 ///
40 /// Multiple values are scanned in a linear sequence.  This creates a data
41 /// dependences that wouldn't exist in a tree reduction, but is easier to
42 /// recognize as a reduction by the subsequent passes.
buildMinMaxReductionSeq(Location loc,arith::CmpIPredicate predicate,ValueRange values,OpBuilder & builder)43 static Value buildMinMaxReductionSeq(Location loc,
44                                      arith::CmpIPredicate predicate,
45                                      ValueRange values, OpBuilder &builder) {
46   assert(!values.empty() && "empty min/max chain");
47   assert(predicate == arith::CmpIPredicate::sgt ||
48          predicate == arith::CmpIPredicate::slt);
49 
50   auto valueIt = values.begin();
51   Value value = *valueIt++;
52   for (; valueIt != values.end(); ++valueIt) {
53     if (predicate == arith::CmpIPredicate::sgt)
54       value = builder.create<arith::MaxSIOp>(loc, value, *valueIt);
55     else
56       value = builder.create<arith::MinSIOp>(loc, value, *valueIt);
57   }
58 
59   return value;
60 }
61 
62 /// Emit instructions that correspond to computing the maximum value among the
63 /// values of a (potentially) multi-output affine map applied to `operands`.
lowerAffineMapMax(OpBuilder & builder,Location loc,AffineMap map,ValueRange operands)64 static Value lowerAffineMapMax(OpBuilder &builder, Location loc, AffineMap map,
65                                ValueRange operands) {
66   if (auto values = expandAffineMap(builder, loc, map, operands))
67     return buildMinMaxReductionSeq(loc, arith::CmpIPredicate::sgt, *values,
68                                    builder);
69   return nullptr;
70 }
71 
72 /// Emit instructions that correspond to computing the minimum value among the
73 /// values of a (potentially) multi-output affine map applied to `operands`.
lowerAffineMapMin(OpBuilder & builder,Location loc,AffineMap map,ValueRange operands)74 static Value lowerAffineMapMin(OpBuilder &builder, Location loc, AffineMap map,
75                                ValueRange operands) {
76   if (auto values = expandAffineMap(builder, loc, map, operands))
77     return buildMinMaxReductionSeq(loc, arith::CmpIPredicate::slt, *values,
78                                    builder);
79   return nullptr;
80 }
81 
82 /// Emit instructions that correspond to the affine map in the upper bound
83 /// applied to the respective operands, and compute the minimum value across
84 /// the results.
lowerAffineUpperBound(AffineForOp op,OpBuilder & builder)85 Value mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) {
86   return lowerAffineMapMin(builder, op.getLoc(), op.getUpperBoundMap(),
87                            op.getUpperBoundOperands());
88 }
89 
90 /// Emit instructions that correspond to the affine map in the lower bound
91 /// applied to the respective operands, and compute the maximum value across
92 /// the results.
lowerAffineLowerBound(AffineForOp op,OpBuilder & builder)93 Value mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
94   return lowerAffineMapMax(builder, op.getLoc(), op.getLowerBoundMap(),
95                            op.getLowerBoundOperands());
96 }
97 
98 namespace {
99 class AffineMinLowering : public OpRewritePattern<AffineMinOp> {
100 public:
101   using OpRewritePattern<AffineMinOp>::OpRewritePattern;
102 
matchAndRewrite(AffineMinOp op,PatternRewriter & rewriter) const103   LogicalResult matchAndRewrite(AffineMinOp op,
104                                 PatternRewriter &rewriter) const override {
105     Value reduced =
106         lowerAffineMapMin(rewriter, op.getLoc(), op.getMap(), op.getOperands());
107     if (!reduced)
108       return failure();
109 
110     rewriter.replaceOp(op, reduced);
111     return success();
112   }
113 };
114 
115 class AffineMaxLowering : public OpRewritePattern<AffineMaxOp> {
116 public:
117   using OpRewritePattern<AffineMaxOp>::OpRewritePattern;
118 
matchAndRewrite(AffineMaxOp op,PatternRewriter & rewriter) const119   LogicalResult matchAndRewrite(AffineMaxOp op,
120                                 PatternRewriter &rewriter) const override {
121     Value reduced =
122         lowerAffineMapMax(rewriter, op.getLoc(), op.getMap(), op.getOperands());
123     if (!reduced)
124       return failure();
125 
126     rewriter.replaceOp(op, reduced);
127     return success();
128   }
129 };
130 
131 /// Affine yields ops are removed.
132 class AffineYieldOpLowering : public OpRewritePattern<AffineYieldOp> {
133 public:
134   using OpRewritePattern<AffineYieldOp>::OpRewritePattern;
135 
matchAndRewrite(AffineYieldOp op,PatternRewriter & rewriter) const136   LogicalResult matchAndRewrite(AffineYieldOp op,
137                                 PatternRewriter &rewriter) const override {
138     if (isa<scf::ParallelOp>(op->getParentOp())) {
139       // Terminator is rewritten as part of the "affine.parallel" lowering
140       // pattern.
141       return failure();
142     }
143     rewriter.replaceOpWithNewOp<scf::YieldOp>(op, op.getOperands());
144     return success();
145   }
146 };
147 
148 class AffineForLowering : public OpRewritePattern<AffineForOp> {
149 public:
150   using OpRewritePattern<AffineForOp>::OpRewritePattern;
151 
matchAndRewrite(AffineForOp op,PatternRewriter & rewriter) const152   LogicalResult matchAndRewrite(AffineForOp op,
153                                 PatternRewriter &rewriter) const override {
154     Location loc = op.getLoc();
155     Value lowerBound = lowerAffineLowerBound(op, rewriter);
156     Value upperBound = lowerAffineUpperBound(op, rewriter);
157     Value step =
158         rewriter.create<arith::ConstantIndexOp>(loc, op.getStepAsInt());
159     auto scfForOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound,
160                                                 step, op.getInits());
161     rewriter.eraseBlock(scfForOp.getBody());
162     rewriter.inlineRegionBefore(op.getRegion(), scfForOp.getRegion(),
163                                 scfForOp.getRegion().end());
164     rewriter.replaceOp(op, scfForOp.getResults());
165     return success();
166   }
167 };
168 
169 /// Convert an `affine.parallel` (loop nest) operation into a `scf.parallel`
170 /// operation.
171 class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
172 public:
173   using OpRewritePattern<AffineParallelOp>::OpRewritePattern;
174 
matchAndRewrite(AffineParallelOp op,PatternRewriter & rewriter) const175   LogicalResult matchAndRewrite(AffineParallelOp op,
176                                 PatternRewriter &rewriter) const override {
177     Location loc = op.getLoc();
178     SmallVector<Value, 8> steps;
179     SmallVector<Value, 8> upperBoundTuple;
180     SmallVector<Value, 8> lowerBoundTuple;
181     SmallVector<Value, 8> identityVals;
182     // Emit IR computing the lower and upper bound by expanding the map
183     // expression.
184     lowerBoundTuple.reserve(op.getNumDims());
185     upperBoundTuple.reserve(op.getNumDims());
186     for (unsigned i = 0, e = op.getNumDims(); i < e; ++i) {
187       Value lower = lowerAffineMapMax(rewriter, loc, op.getLowerBoundMap(i),
188                                       op.getLowerBoundsOperands());
189       if (!lower)
190         return rewriter.notifyMatchFailure(op, "couldn't convert lower bounds");
191       lowerBoundTuple.push_back(lower);
192 
193       Value upper = lowerAffineMapMin(rewriter, loc, op.getUpperBoundMap(i),
194                                       op.getUpperBoundsOperands());
195       if (!upper)
196         return rewriter.notifyMatchFailure(op, "couldn't convert upper bounds");
197       upperBoundTuple.push_back(upper);
198     }
199     steps.reserve(op.getSteps().size());
200     for (int64_t step : op.getSteps())
201       steps.push_back(rewriter.create<arith::ConstantIndexOp>(loc, step));
202 
203     // Get the terminator op.
204     auto affineParOpTerminator =
205         cast<AffineYieldOp>(op.getBody()->getTerminator());
206     scf::ParallelOp parOp;
207     if (op.getResults().empty()) {
208       // Case with no reduction operations/return values.
209       parOp = rewriter.create<scf::ParallelOp>(loc, lowerBoundTuple,
210                                                upperBoundTuple, steps,
211                                                /*bodyBuilderFn=*/nullptr);
212       rewriter.eraseBlock(parOp.getBody());
213       rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(),
214                                   parOp.getRegion().end());
215       rewriter.replaceOp(op, parOp.getResults());
216       rewriter.setInsertionPoint(affineParOpTerminator);
217       rewriter.replaceOpWithNewOp<scf::ReduceOp>(affineParOpTerminator);
218       return success();
219     }
220     // Case with affine.parallel with reduction operations/return values.
221     // scf.parallel handles the reduction operation differently unlike
222     // affine.parallel.
223     ArrayRef<Attribute> reductions = op.getReductions().getValue();
224     for (auto pair : llvm::zip(reductions, op.getResultTypes())) {
225       // For each of the reduction operations get the identity values for
226       // initialization of the result values.
227       Attribute reduction = std::get<0>(pair);
228       Type resultType = std::get<1>(pair);
229       std::optional<arith::AtomicRMWKind> reductionOp =
230           arith::symbolizeAtomicRMWKind(
231               static_cast<uint64_t>(cast<IntegerAttr>(reduction).getInt()));
232       assert(reductionOp && "Reduction operation cannot be of None Type");
233       arith::AtomicRMWKind reductionOpValue = *reductionOp;
234       identityVals.push_back(
235           arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc));
236     }
237     parOp = rewriter.create<scf::ParallelOp>(
238         loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,
239         /*bodyBuilderFn=*/nullptr);
240 
241     //  Copy the body of the affine.parallel op.
242     rewriter.eraseBlock(parOp.getBody());
243     rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(),
244                                 parOp.getRegion().end());
245     assert(reductions.size() == affineParOpTerminator->getNumOperands() &&
246            "Unequal number of reductions and operands.");
247 
248     // Emit new "scf.reduce" terminator.
249     rewriter.setInsertionPoint(affineParOpTerminator);
250     auto reduceOp = rewriter.replaceOpWithNewOp<scf::ReduceOp>(
251         affineParOpTerminator, affineParOpTerminator->getOperands());
252     for (unsigned i = 0, end = reductions.size(); i < end; i++) {
253       // For each of the reduction operations get the respective mlir::Value.
254       std::optional<arith::AtomicRMWKind> reductionOp =
255           arith::symbolizeAtomicRMWKind(
256               cast<IntegerAttr>(reductions[i]).getInt());
257       assert(reductionOp && "Reduction Operation cannot be of None Type");
258       arith::AtomicRMWKind reductionOpValue = *reductionOp;
259       rewriter.setInsertionPoint(&parOp.getBody()->back());
260       Block &reductionBody = reduceOp.getReductions()[i].front();
261       rewriter.setInsertionPointToEnd(&reductionBody);
262       Value reductionResult = arith::getReductionOp(
263           reductionOpValue, rewriter, loc, reductionBody.getArgument(0),
264           reductionBody.getArgument(1));
265       rewriter.create<scf::ReduceReturnOp>(loc, reductionResult);
266     }
267     rewriter.replaceOp(op, parOp.getResults());
268     return success();
269   }
270 };
271 
272 class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
273 public:
274   using OpRewritePattern<AffineIfOp>::OpRewritePattern;
275 
matchAndRewrite(AffineIfOp op,PatternRewriter & rewriter) const276   LogicalResult matchAndRewrite(AffineIfOp op,
277                                 PatternRewriter &rewriter) const override {
278     auto loc = op.getLoc();
279 
280     // Now we just have to handle the condition logic.
281     auto integerSet = op.getIntegerSet();
282     Value zeroConstant = rewriter.create<arith::ConstantIndexOp>(loc, 0);
283     SmallVector<Value, 8> operands(op.getOperands());
284     auto operandsRef = llvm::ArrayRef(operands);
285 
286     // Calculate cond as a conjunction without short-circuiting.
287     Value cond = nullptr;
288     for (unsigned i = 0, e = integerSet.getNumConstraints(); i < e; ++i) {
289       AffineExpr constraintExpr = integerSet.getConstraint(i);
290       bool isEquality = integerSet.isEq(i);
291 
292       // Build and apply an affine expression
293       auto numDims = integerSet.getNumDims();
294       Value affResult = expandAffineExpr(rewriter, loc, constraintExpr,
295                                          operandsRef.take_front(numDims),
296                                          operandsRef.drop_front(numDims));
297       if (!affResult)
298         return failure();
299       auto pred =
300           isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge;
301       Value cmpVal =
302           rewriter.create<arith::CmpIOp>(loc, pred, affResult, zeroConstant);
303       cond = cond
304                  ? rewriter.create<arith::AndIOp>(loc, cond, cmpVal).getResult()
305                  : cmpVal;
306     }
307     cond = cond ? cond
308                 : rewriter.create<arith::ConstantIntOp>(loc, /*value=*/1,
309                                                         /*width=*/1);
310 
311     bool hasElseRegion = !op.getElseRegion().empty();
312     auto ifOp = rewriter.create<scf::IfOp>(loc, op.getResultTypes(), cond,
313                                            hasElseRegion);
314     rewriter.inlineRegionBefore(op.getThenRegion(),
315                                 &ifOp.getThenRegion().back());
316     rewriter.eraseBlock(&ifOp.getThenRegion().back());
317     if (hasElseRegion) {
318       rewriter.inlineRegionBefore(op.getElseRegion(),
319                                   &ifOp.getElseRegion().back());
320       rewriter.eraseBlock(&ifOp.getElseRegion().back());
321     }
322 
323     // Replace the Affine IfOp finally.
324     rewriter.replaceOp(op, ifOp.getResults());
325     return success();
326   }
327 };
328 
329 /// Convert an "affine.apply" operation into a sequence of arithmetic
330 /// operations using the StandardOps dialect.
331 class AffineApplyLowering : public OpRewritePattern<AffineApplyOp> {
332 public:
333   using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
334 
matchAndRewrite(AffineApplyOp op,PatternRewriter & rewriter) const335   LogicalResult matchAndRewrite(AffineApplyOp op,
336                                 PatternRewriter &rewriter) const override {
337     auto maybeExpandedMap =
338         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
339                         llvm::to_vector<8>(op.getOperands()));
340     if (!maybeExpandedMap)
341       return failure();
342     rewriter.replaceOp(op, *maybeExpandedMap);
343     return success();
344   }
345 };
346 
347 /// Apply the affine map from an 'affine.load' operation to its operands, and
348 /// feed the results to a newly created 'memref.load' operation (which replaces
349 /// the original 'affine.load').
350 class AffineLoadLowering : public OpRewritePattern<AffineLoadOp> {
351 public:
352   using OpRewritePattern<AffineLoadOp>::OpRewritePattern;
353 
matchAndRewrite(AffineLoadOp op,PatternRewriter & rewriter) const354   LogicalResult matchAndRewrite(AffineLoadOp op,
355                                 PatternRewriter &rewriter) const override {
356     // Expand affine map from 'affineLoadOp'.
357     SmallVector<Value, 8> indices(op.getMapOperands());
358     auto resultOperands =
359         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
360     if (!resultOperands)
361       return failure();
362 
363     // Build vector.load memref[expandedMap.results].
364     rewriter.replaceOpWithNewOp<memref::LoadOp>(op, op.getMemRef(),
365                                                 *resultOperands);
366     return success();
367   }
368 };
369 
370 /// Apply the affine map from an 'affine.prefetch' operation to its operands,
371 /// and feed the results to a newly created 'memref.prefetch' operation (which
372 /// replaces the original 'affine.prefetch').
373 class AffinePrefetchLowering : public OpRewritePattern<AffinePrefetchOp> {
374 public:
375   using OpRewritePattern<AffinePrefetchOp>::OpRewritePattern;
376 
matchAndRewrite(AffinePrefetchOp op,PatternRewriter & rewriter) const377   LogicalResult matchAndRewrite(AffinePrefetchOp op,
378                                 PatternRewriter &rewriter) const override {
379     // Expand affine map from 'affinePrefetchOp'.
380     SmallVector<Value, 8> indices(op.getMapOperands());
381     auto resultOperands =
382         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
383     if (!resultOperands)
384       return failure();
385 
386     // Build memref.prefetch memref[expandedMap.results].
387     rewriter.replaceOpWithNewOp<memref::PrefetchOp>(
388         op, op.getMemref(), *resultOperands, op.getIsWrite(),
389         op.getLocalityHint(), op.getIsDataCache());
390     return success();
391   }
392 };
393 
394 /// Apply the affine map from an 'affine.store' operation to its operands, and
395 /// feed the results to a newly created 'memref.store' operation (which replaces
396 /// the original 'affine.store').
397 class AffineStoreLowering : public OpRewritePattern<AffineStoreOp> {
398 public:
399   using OpRewritePattern<AffineStoreOp>::OpRewritePattern;
400 
matchAndRewrite(AffineStoreOp op,PatternRewriter & rewriter) const401   LogicalResult matchAndRewrite(AffineStoreOp op,
402                                 PatternRewriter &rewriter) const override {
403     // Expand affine map from 'affineStoreOp'.
404     SmallVector<Value, 8> indices(op.getMapOperands());
405     auto maybeExpandedMap =
406         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
407     if (!maybeExpandedMap)
408       return failure();
409 
410     // Build memref.store valueToStore, memref[expandedMap.results].
411     rewriter.replaceOpWithNewOp<memref::StoreOp>(
412         op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
413     return success();
414   }
415 };
416 
417 /// Apply the affine maps from an 'affine.dma_start' operation to each of their
418 /// respective map operands, and feed the results to a newly created
419 /// 'memref.dma_start' operation (which replaces the original
420 /// 'affine.dma_start').
421 class AffineDmaStartLowering : public OpRewritePattern<AffineDmaStartOp> {
422 public:
423   using OpRewritePattern<AffineDmaStartOp>::OpRewritePattern;
424 
matchAndRewrite(AffineDmaStartOp op,PatternRewriter & rewriter) const425   LogicalResult matchAndRewrite(AffineDmaStartOp op,
426                                 PatternRewriter &rewriter) const override {
427     SmallVector<Value, 8> operands(op.getOperands());
428     auto operandsRef = llvm::ArrayRef(operands);
429 
430     // Expand affine map for DMA source memref.
431     auto maybeExpandedSrcMap = expandAffineMap(
432         rewriter, op.getLoc(), op.getSrcMap(),
433         operandsRef.drop_front(op.getSrcMemRefOperandIndex() + 1));
434     if (!maybeExpandedSrcMap)
435       return failure();
436     // Expand affine map for DMA destination memref.
437     auto maybeExpandedDstMap = expandAffineMap(
438         rewriter, op.getLoc(), op.getDstMap(),
439         operandsRef.drop_front(op.getDstMemRefOperandIndex() + 1));
440     if (!maybeExpandedDstMap)
441       return failure();
442     // Expand affine map for DMA tag memref.
443     auto maybeExpandedTagMap = expandAffineMap(
444         rewriter, op.getLoc(), op.getTagMap(),
445         operandsRef.drop_front(op.getTagMemRefOperandIndex() + 1));
446     if (!maybeExpandedTagMap)
447       return failure();
448 
449     // Build memref.dma_start operation with affine map results.
450     rewriter.replaceOpWithNewOp<memref::DmaStartOp>(
451         op, op.getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(),
452         *maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(),
453         *maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride());
454     return success();
455   }
456 };
457 
458 /// Apply the affine map from an 'affine.dma_wait' operation tag memref,
459 /// and feed the results to a newly created 'memref.dma_wait' operation (which
460 /// replaces the original 'affine.dma_wait').
461 class AffineDmaWaitLowering : public OpRewritePattern<AffineDmaWaitOp> {
462 public:
463   using OpRewritePattern<AffineDmaWaitOp>::OpRewritePattern;
464 
matchAndRewrite(AffineDmaWaitOp op,PatternRewriter & rewriter) const465   LogicalResult matchAndRewrite(AffineDmaWaitOp op,
466                                 PatternRewriter &rewriter) const override {
467     // Expand affine map for DMA tag memref.
468     SmallVector<Value, 8> indices(op.getTagIndices());
469     auto maybeExpandedTagMap =
470         expandAffineMap(rewriter, op.getLoc(), op.getTagMap(), indices);
471     if (!maybeExpandedTagMap)
472       return failure();
473 
474     // Build memref.dma_wait operation with affine map results.
475     rewriter.replaceOpWithNewOp<memref::DmaWaitOp>(
476         op, op.getTagMemRef(), *maybeExpandedTagMap, op.getNumElements());
477     return success();
478   }
479 };
480 
481 /// Apply the affine map from an 'affine.vector_load' operation to its operands,
482 /// and feed the results to a newly created 'vector.load' operation (which
483 /// replaces the original 'affine.vector_load').
484 class AffineVectorLoadLowering : public OpRewritePattern<AffineVectorLoadOp> {
485 public:
486   using OpRewritePattern<AffineVectorLoadOp>::OpRewritePattern;
487 
matchAndRewrite(AffineVectorLoadOp op,PatternRewriter & rewriter) const488   LogicalResult matchAndRewrite(AffineVectorLoadOp op,
489                                 PatternRewriter &rewriter) const override {
490     // Expand affine map from 'affineVectorLoadOp'.
491     SmallVector<Value, 8> indices(op.getMapOperands());
492     auto resultOperands =
493         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
494     if (!resultOperands)
495       return failure();
496 
497     // Build vector.load memref[expandedMap.results].
498     rewriter.replaceOpWithNewOp<vector::LoadOp>(
499         op, op.getVectorType(), op.getMemRef(), *resultOperands);
500     return success();
501   }
502 };
503 
504 /// Apply the affine map from an 'affine.vector_store' operation to its
505 /// operands, and feed the results to a newly created 'vector.store' operation
506 /// (which replaces the original 'affine.vector_store').
507 class AffineVectorStoreLowering : public OpRewritePattern<AffineVectorStoreOp> {
508 public:
509   using OpRewritePattern<AffineVectorStoreOp>::OpRewritePattern;
510 
matchAndRewrite(AffineVectorStoreOp op,PatternRewriter & rewriter) const511   LogicalResult matchAndRewrite(AffineVectorStoreOp op,
512                                 PatternRewriter &rewriter) const override {
513     // Expand affine map from 'affineVectorStoreOp'.
514     SmallVector<Value, 8> indices(op.getMapOperands());
515     auto maybeExpandedMap =
516         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
517     if (!maybeExpandedMap)
518       return failure();
519 
520     rewriter.replaceOpWithNewOp<vector::StoreOp>(
521         op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
522     return success();
523   }
524 };
525 
526 } // namespace
527 
populateAffineToStdConversionPatterns(RewritePatternSet & patterns)528 void mlir::populateAffineToStdConversionPatterns(RewritePatternSet &patterns) {
529   // clang-format off
530   patterns.add<
531       AffineApplyLowering,
532       AffineDmaStartLowering,
533       AffineDmaWaitLowering,
534       AffineLoadLowering,
535       AffineMinLowering,
536       AffineMaxLowering,
537       AffineParallelLowering,
538       AffinePrefetchLowering,
539       AffineStoreLowering,
540       AffineForLowering,
541       AffineIfLowering,
542       AffineYieldOpLowering>(patterns.getContext());
543   // clang-format on
544 }
545 
populateAffineToVectorConversionPatterns(RewritePatternSet & patterns)546 void mlir::populateAffineToVectorConversionPatterns(
547     RewritePatternSet &patterns) {
548   // clang-format off
549   patterns.add<
550       AffineVectorLoadLowering,
551       AffineVectorStoreLowering>(patterns.getContext());
552   // clang-format on
553 }
554 
555 namespace {
556 class LowerAffinePass
557     : public impl::ConvertAffineToStandardBase<LowerAffinePass> {
runOnOperation()558   void runOnOperation() override {
559     RewritePatternSet patterns(&getContext());
560     populateAffineToStdConversionPatterns(patterns);
561     populateAffineToVectorConversionPatterns(patterns);
562     populateAffineExpandIndexOpsPatterns(patterns);
563     ConversionTarget target(getContext());
564     target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
565                            scf::SCFDialect, VectorDialect>();
566     if (failed(applyPartialConversion(getOperation(), target,
567                                       std::move(patterns))))
568       signalPassFailure();
569   }
570 };
571 } // namespace
572 
573 /// Lowers If and For operations within a function into their lower level CFG
574 /// equivalent blocks.
createLowerAffinePass()575 std::unique_ptr<Pass> mlir::createLowerAffinePass() {
576   return std::make_unique<LowerAffinePass>();
577 }
578