xref: /llvm-project/flang/lib/Optimizer/Transforms/AffineDemotion.cpp (revision 6f8ef5ad2f35321257adbe353f86027bf5209023)
180c27abbSRajan Walia //===-- AffineDemotion.cpp -----------------------------------------------===//
280c27abbSRajan Walia //
380c27abbSRajan Walia // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
480c27abbSRajan Walia // See https://llvm.org/LICENSE.txt for license information.
580c27abbSRajan Walia // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
680c27abbSRajan Walia //
780c27abbSRajan Walia //===----------------------------------------------------------------------===//
8b2169992SValentin Clement //
// This transformation is a prototype that demotes affine dialect operations
// to FIR loop operations after optimization.
11b2169992SValentin Clement // It is used after the AffinePromotion pass.
12b2169992SValentin Clement // It is not part of the production pipeline and would need more work in order
13b2169992SValentin Clement // to be used in production.
14b2169992SValentin Clement // More information can be found in this presentation:
15b2169992SValentin Clement // https://slides.com/rajanwalia/deck
16b2169992SValentin Clement //
17b2169992SValentin Clement //===----------------------------------------------------------------------===//
1880c27abbSRajan Walia 
1980c27abbSRajan Walia #include "flang/Optimizer/Dialect/FIRDialect.h"
2080c27abbSRajan Walia #include "flang/Optimizer/Dialect/FIROps.h"
2180c27abbSRajan Walia #include "flang/Optimizer/Dialect/FIRType.h"
2280c27abbSRajan Walia #include "flang/Optimizer/Transforms/Passes.h"
2380c27abbSRajan Walia #include "mlir/Dialect/Affine/IR/AffineOps.h"
247d7ebf3cSUday Bondhugula #include "mlir/Dialect/Affine/Utils.h"
2523aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
2680c27abbSRajan Walia #include "mlir/Dialect/MemRef/IR/MemRef.h"
278b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
2880c27abbSRajan Walia #include "mlir/IR/BuiltinAttributes.h"
2980c27abbSRajan Walia #include "mlir/IR/IntegerSet.h"
3080c27abbSRajan Walia #include "mlir/IR/Visitors.h"
3180c27abbSRajan Walia #include "mlir/Pass/Pass.h"
3280c27abbSRajan Walia #include "mlir/Transforms/DialectConversion.h"
3380c27abbSRajan Walia #include "llvm/ADT/DenseMap.h"
3480c27abbSRajan Walia #include "llvm/Support/CommandLine.h"
3580c27abbSRajan Walia #include "llvm/Support/Debug.h"
3680c27abbSRajan Walia 
3767d0d7acSMichele Scuttari namespace fir {
3867d0d7acSMichele Scuttari #define GEN_PASS_DEF_AFFINEDIALECTDEMOTION
3967d0d7acSMichele Scuttari #include "flang/Optimizer/Transforms/Passes.h.inc"
4067d0d7acSMichele Scuttari } // namespace fir
4167d0d7acSMichele Scuttari 
4280c27abbSRajan Walia #define DEBUG_TYPE "flang-affine-demotion"
4380c27abbSRajan Walia 
4480c27abbSRajan Walia using namespace fir;
45092601d4SAndrzej Warzynski using namespace mlir;
4680c27abbSRajan Walia 
4780c27abbSRajan Walia namespace {
4880c27abbSRajan Walia 
494c48f016SMatthias Springer class AffineLoadConversion
504c48f016SMatthias Springer     : public OpConversionPattern<mlir::affine::AffineLoadOp> {
5180c27abbSRajan Walia public:
524c48f016SMatthias Springer   using OpConversionPattern<mlir::affine::AffineLoadOp>::OpConversionPattern;
5380c27abbSRajan Walia 
5442831686SRiver Riddle   LogicalResult
554c48f016SMatthias Springer   matchAndRewrite(mlir::affine::AffineLoadOp op, OpAdaptor adaptor,
5642831686SRiver Riddle                   ConversionPatternRewriter &rewriter) const override {
578df54a6aSJacques Pienaar     SmallVector<Value> indices(adaptor.getIndices());
584c48f016SMatthias Springer     auto maybeExpandedMap = affine::expandAffineMap(rewriter, op.getLoc(),
594c48f016SMatthias Springer                                                     op.getAffineMap(), indices);
6080c27abbSRajan Walia     if (!maybeExpandedMap)
6180c27abbSRajan Walia       return failure();
6280c27abbSRajan Walia 
6380c27abbSRajan Walia     auto coorOp = rewriter.create<fir::CoordinateOp>(
6480c27abbSRajan Walia         op.getLoc(), fir::ReferenceType::get(op.getResult().getType()),
658df54a6aSJacques Pienaar         adaptor.getMemref(), *maybeExpandedMap);
6680c27abbSRajan Walia 
6780c27abbSRajan Walia     rewriter.replaceOpWithNewOp<fir::LoadOp>(op, coorOp.getResult());
6880c27abbSRajan Walia     return success();
6980c27abbSRajan Walia   }
7080c27abbSRajan Walia };
7180c27abbSRajan Walia 
724c48f016SMatthias Springer class AffineStoreConversion
734c48f016SMatthias Springer     : public OpConversionPattern<mlir::affine::AffineStoreOp> {
7480c27abbSRajan Walia public:
754c48f016SMatthias Springer   using OpConversionPattern<mlir::affine::AffineStoreOp>::OpConversionPattern;
7680c27abbSRajan Walia 
7742831686SRiver Riddle   LogicalResult
784c48f016SMatthias Springer   matchAndRewrite(mlir::affine::AffineStoreOp op, OpAdaptor adaptor,
7942831686SRiver Riddle                   ConversionPatternRewriter &rewriter) const override {
808df54a6aSJacques Pienaar     SmallVector<Value> indices(op.getIndices());
814c48f016SMatthias Springer     auto maybeExpandedMap = affine::expandAffineMap(rewriter, op.getLoc(),
824c48f016SMatthias Springer                                                     op.getAffineMap(), indices);
8380c27abbSRajan Walia     if (!maybeExpandedMap)
8480c27abbSRajan Walia       return failure();
8580c27abbSRajan Walia 
8680c27abbSRajan Walia     auto coorOp = rewriter.create<fir::CoordinateOp>(
8780c27abbSRajan Walia         op.getLoc(), fir::ReferenceType::get(op.getValueToStore().getType()),
888df54a6aSJacques Pienaar         adaptor.getMemref(), *maybeExpandedMap);
898df54a6aSJacques Pienaar     rewriter.replaceOpWithNewOp<fir::StoreOp>(op, adaptor.getValue(),
9080c27abbSRajan Walia                                               coorOp.getResult());
9180c27abbSRajan Walia     return success();
9280c27abbSRajan Walia   }
9380c27abbSRajan Walia };
9480c27abbSRajan Walia 
9580c27abbSRajan Walia class ConvertConversion : public mlir::OpRewritePattern<fir::ConvertOp> {
9680c27abbSRajan Walia public:
9780c27abbSRajan Walia   using OpRewritePattern::OpRewritePattern;
98db791b27SRamkumar Ramachandra   llvm::LogicalResult
9980c27abbSRajan Walia   matchAndRewrite(fir::ConvertOp op,
10080c27abbSRajan Walia                   mlir::PatternRewriter &rewriter) const override {
101fac349a1SChristian Sigg     if (mlir::isa<mlir::MemRefType>(op.getRes().getType())) {
10280c27abbSRajan Walia       // due to index calculation moving to affine maps we still need to
10380c27abbSRajan Walia       // add converts for sequence types this has a side effect of losing
10480c27abbSRajan Walia       // some information about arrays with known dimensions by creating:
10580c27abbSRajan Walia       // fir.convert %arg0 : (!fir.ref<!fir.array<5xi32>>) ->
10680c27abbSRajan Walia       // !fir.ref<!fir.array<?xi32>>
107fac349a1SChristian Sigg       if (auto refTy =
108fac349a1SChristian Sigg               mlir::dyn_cast<fir::ReferenceType>(op.getValue().getType()))
109fac349a1SChristian Sigg         if (auto arrTy = mlir::dyn_cast<fir::SequenceType>(refTy.getEleTy())) {
11080c27abbSRajan Walia           fir::SequenceType::Shape flatShape = {
11180c27abbSRajan Walia               fir::SequenceType::getUnknownExtent()};
11280c27abbSRajan Walia           auto flatArrTy = fir::SequenceType::get(flatShape, arrTy.getEleTy());
11380c27abbSRajan Walia           auto flatTy = fir::ReferenceType::get(flatArrTy);
114149ad3d5SShraiysh Vaishay           rewriter.replaceOpWithNewOp<fir::ConvertOp>(op, flatTy,
115149ad3d5SShraiysh Vaishay                                                       op.getValue());
11680c27abbSRajan Walia           return success();
11780c27abbSRajan Walia         }
1185fcf907bSMatthias Springer       rewriter.startOpModification(op->getParentOp());
119149ad3d5SShraiysh Vaishay       op.getResult().replaceAllUsesWith(op.getValue());
1205fcf907bSMatthias Springer       rewriter.finalizeOpModification(op->getParentOp());
12180c27abbSRajan Walia       rewriter.eraseOp(op);
12280c27abbSRajan Walia     }
12380c27abbSRajan Walia     return success();
12480c27abbSRajan Walia   }
12580c27abbSRajan Walia };
12680c27abbSRajan Walia 
12780c27abbSRajan Walia mlir::Type convertMemRef(mlir::MemRefType type) {
128*6f8ef5adSKazu Hirata   return fir::SequenceType::get(SmallVector<int64_t>(type.getShape()),
12980c27abbSRajan Walia                                 type.getElementType());
13080c27abbSRajan Walia }
13180c27abbSRajan Walia 
13280c27abbSRajan Walia class StdAllocConversion : public mlir::OpRewritePattern<memref::AllocOp> {
13380c27abbSRajan Walia public:
13480c27abbSRajan Walia   using OpRewritePattern::OpRewritePattern;
135db791b27SRamkumar Ramachandra   llvm::LogicalResult
13680c27abbSRajan Walia   matchAndRewrite(memref::AllocOp op,
13780c27abbSRajan Walia                   mlir::PatternRewriter &rewriter) const override {
13880c27abbSRajan Walia     rewriter.replaceOpWithNewOp<fir::AllocaOp>(op, convertMemRef(op.getType()),
139c692a11eSRiver Riddle                                                op.getMemref());
14080c27abbSRajan Walia     return success();
14180c27abbSRajan Walia   }
14280c27abbSRajan Walia };
14380c27abbSRajan Walia 
14480c27abbSRajan Walia class AffineDialectDemotion
14567d0d7acSMichele Scuttari     : public fir::impl::AffineDialectDemotionBase<AffineDialectDemotion> {
14680c27abbSRajan Walia public:
147196c4279SRiver Riddle   void runOnOperation() override {
14880c27abbSRajan Walia     auto *context = &getContext();
149196c4279SRiver Riddle     auto function = getOperation();
15080c27abbSRajan Walia     LLVM_DEBUG(llvm::dbgs() << "AffineDemotion: running on function:\n";
15180c27abbSRajan Walia                function.print(llvm::dbgs()););
15280c27abbSRajan Walia 
1539f85c198SRiver Riddle     mlir::RewritePatternSet patterns(context);
15480c27abbSRajan Walia     patterns.insert<ConvertConversion>(context);
15580c27abbSRajan Walia     patterns.insert<AffineLoadConversion>(context);
15680c27abbSRajan Walia     patterns.insert<AffineStoreConversion>(context);
15780c27abbSRajan Walia     patterns.insert<StdAllocConversion>(context);
15880c27abbSRajan Walia     mlir::ConversionTarget target(*context);
15980c27abbSRajan Walia     target.addIllegalOp<memref::AllocOp>();
16080c27abbSRajan Walia     target.addDynamicallyLegalOp<fir::ConvertOp>([](fir::ConvertOp op) {
161fac349a1SChristian Sigg       if (mlir::isa<mlir::MemRefType>(op.getRes().getType()))
16280c27abbSRajan Walia         return false;
16380c27abbSRajan Walia       return true;
16480c27abbSRajan Walia     });
1656c8d8d10SJakub Kuderski     target
1666c8d8d10SJakub Kuderski         .addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
1676c8d8d10SJakub Kuderski                          mlir::arith::ArithDialect, mlir::func::FuncDialect>();
16880c27abbSRajan Walia 
16980c27abbSRajan Walia     if (mlir::failed(mlir::applyPartialConversion(function, target,
17080c27abbSRajan Walia                                                   std::move(patterns)))) {
17180c27abbSRajan Walia       mlir::emitError(mlir::UnknownLoc::get(context),
17280c27abbSRajan Walia                       "error in converting affine dialect\n");
17380c27abbSRajan Walia       signalPassFailure();
17480c27abbSRajan Walia     }
17580c27abbSRajan Walia   }
17680c27abbSRajan Walia };
17780c27abbSRajan Walia 
17880c27abbSRajan Walia } // namespace
17980c27abbSRajan Walia 
18080c27abbSRajan Walia std::unique_ptr<mlir::Pass> fir::createAffineDemotionPass() {
18180c27abbSRajan Walia   return std::make_unique<AffineDialectDemotion>();
18280c27abbSRajan Walia }
183