xref: /llvm-project/flang/lib/Optimizer/Transforms/AffineDemotion.cpp (revision 6f8ef5ad2f35321257adbe353f86027bf5209023)
1 //===-- AffineDemotion.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This transformation is a prototype that demotes affine dialect operations
// back to FIR loop operations after optimizations have been applied.
11 // It is used after the AffinePromotion pass.
12 // It is not part of the production pipeline and would need more work in order
13 // to be used in production.
14 // More information can be found in this presentation:
15 // https://slides.com/rajanwalia/deck
16 //
17 //===----------------------------------------------------------------------===//
18 
19 #include "flang/Optimizer/Dialect/FIRDialect.h"
20 #include "flang/Optimizer/Dialect/FIROps.h"
21 #include "flang/Optimizer/Dialect/FIRType.h"
22 #include "flang/Optimizer/Transforms/Passes.h"
23 #include "mlir/Dialect/Affine/IR/AffineOps.h"
24 #include "mlir/Dialect/Affine/Utils.h"
25 #include "mlir/Dialect/Func/IR/FuncOps.h"
26 #include "mlir/Dialect/MemRef/IR/MemRef.h"
27 #include "mlir/Dialect/SCF/IR/SCF.h"
28 #include "mlir/IR/BuiltinAttributes.h"
29 #include "mlir/IR/IntegerSet.h"
30 #include "mlir/IR/Visitors.h"
31 #include "mlir/Pass/Pass.h"
32 #include "mlir/Transforms/DialectConversion.h"
33 #include "llvm/ADT/DenseMap.h"
34 #include "llvm/Support/CommandLine.h"
35 #include "llvm/Support/Debug.h"
36 
37 namespace fir {
38 #define GEN_PASS_DEF_AFFINEDIALECTDEMOTION
39 #include "flang/Optimizer/Transforms/Passes.h.inc"
40 } // namespace fir
41 
42 #define DEBUG_TYPE "flang-affine-demotion"
43 
44 using namespace fir;
45 using namespace mlir;
46 
47 namespace {
48 
49 class AffineLoadConversion
50     : public OpConversionPattern<mlir::affine::AffineLoadOp> {
51 public:
52   using OpConversionPattern<mlir::affine::AffineLoadOp>::OpConversionPattern;
53 
54   LogicalResult
55   matchAndRewrite(mlir::affine::AffineLoadOp op, OpAdaptor adaptor,
56                   ConversionPatternRewriter &rewriter) const override {
57     SmallVector<Value> indices(adaptor.getIndices());
58     auto maybeExpandedMap = affine::expandAffineMap(rewriter, op.getLoc(),
59                                                     op.getAffineMap(), indices);
60     if (!maybeExpandedMap)
61       return failure();
62 
63     auto coorOp = rewriter.create<fir::CoordinateOp>(
64         op.getLoc(), fir::ReferenceType::get(op.getResult().getType()),
65         adaptor.getMemref(), *maybeExpandedMap);
66 
67     rewriter.replaceOpWithNewOp<fir::LoadOp>(op, coorOp.getResult());
68     return success();
69   }
70 };
71 
72 class AffineStoreConversion
73     : public OpConversionPattern<mlir::affine::AffineStoreOp> {
74 public:
75   using OpConversionPattern<mlir::affine::AffineStoreOp>::OpConversionPattern;
76 
77   LogicalResult
78   matchAndRewrite(mlir::affine::AffineStoreOp op, OpAdaptor adaptor,
79                   ConversionPatternRewriter &rewriter) const override {
80     SmallVector<Value> indices(op.getIndices());
81     auto maybeExpandedMap = affine::expandAffineMap(rewriter, op.getLoc(),
82                                                     op.getAffineMap(), indices);
83     if (!maybeExpandedMap)
84       return failure();
85 
86     auto coorOp = rewriter.create<fir::CoordinateOp>(
87         op.getLoc(), fir::ReferenceType::get(op.getValueToStore().getType()),
88         adaptor.getMemref(), *maybeExpandedMap);
89     rewriter.replaceOpWithNewOp<fir::StoreOp>(op, adaptor.getValue(),
90                                               coorOp.getResult());
91     return success();
92   }
93 };
94 
95 class ConvertConversion : public mlir::OpRewritePattern<fir::ConvertOp> {
96 public:
97   using OpRewritePattern::OpRewritePattern;
98   llvm::LogicalResult
99   matchAndRewrite(fir::ConvertOp op,
100                   mlir::PatternRewriter &rewriter) const override {
101     if (mlir::isa<mlir::MemRefType>(op.getRes().getType())) {
102       // due to index calculation moving to affine maps we still need to
103       // add converts for sequence types this has a side effect of losing
104       // some information about arrays with known dimensions by creating:
105       // fir.convert %arg0 : (!fir.ref<!fir.array<5xi32>>) ->
106       // !fir.ref<!fir.array<?xi32>>
107       if (auto refTy =
108               mlir::dyn_cast<fir::ReferenceType>(op.getValue().getType()))
109         if (auto arrTy = mlir::dyn_cast<fir::SequenceType>(refTy.getEleTy())) {
110           fir::SequenceType::Shape flatShape = {
111               fir::SequenceType::getUnknownExtent()};
112           auto flatArrTy = fir::SequenceType::get(flatShape, arrTy.getEleTy());
113           auto flatTy = fir::ReferenceType::get(flatArrTy);
114           rewriter.replaceOpWithNewOp<fir::ConvertOp>(op, flatTy,
115                                                       op.getValue());
116           return success();
117         }
118       rewriter.startOpModification(op->getParentOp());
119       op.getResult().replaceAllUsesWith(op.getValue());
120       rewriter.finalizeOpModification(op->getParentOp());
121       rewriter.eraseOp(op);
122     }
123     return success();
124   }
125 };
126 
127 mlir::Type convertMemRef(mlir::MemRefType type) {
128   return fir::SequenceType::get(SmallVector<int64_t>(type.getShape()),
129                                 type.getElementType());
130 }
131 
132 class StdAllocConversion : public mlir::OpRewritePattern<memref::AllocOp> {
133 public:
134   using OpRewritePattern::OpRewritePattern;
135   llvm::LogicalResult
136   matchAndRewrite(memref::AllocOp op,
137                   mlir::PatternRewriter &rewriter) const override {
138     rewriter.replaceOpWithNewOp<fir::AllocaOp>(op, convertMemRef(op.getType()),
139                                                op.getMemref());
140     return success();
141   }
142 };
143 
144 class AffineDialectDemotion
145     : public fir::impl::AffineDialectDemotionBase<AffineDialectDemotion> {
146 public:
147   void runOnOperation() override {
148     auto *context = &getContext();
149     auto function = getOperation();
150     LLVM_DEBUG(llvm::dbgs() << "AffineDemotion: running on function:\n";
151                function.print(llvm::dbgs()););
152 
153     mlir::RewritePatternSet patterns(context);
154     patterns.insert<ConvertConversion>(context);
155     patterns.insert<AffineLoadConversion>(context);
156     patterns.insert<AffineStoreConversion>(context);
157     patterns.insert<StdAllocConversion>(context);
158     mlir::ConversionTarget target(*context);
159     target.addIllegalOp<memref::AllocOp>();
160     target.addDynamicallyLegalOp<fir::ConvertOp>([](fir::ConvertOp op) {
161       if (mlir::isa<mlir::MemRefType>(op.getRes().getType()))
162         return false;
163       return true;
164     });
165     target
166         .addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
167                          mlir::arith::ArithDialect, mlir::func::FuncDialect>();
168 
169     if (mlir::failed(mlir::applyPartialConversion(function, target,
170                                                   std::move(patterns)))) {
171       mlir::emitError(mlir::UnknownLoc::get(context),
172                       "error in converting affine dialect\n");
173       signalPassFailure();
174     }
175   }
176 };
177 
178 } // namespace
179 
180 std::unique_ptr<mlir::Pass> fir::createAffineDemotionPass() {
181   return std::make_unique<AffineDialectDemotion>();
182 }
183