180c27abbSRajan Walia //===-- AffineDemotion.cpp -----------------------------------------------===// 280c27abbSRajan Walia // 380c27abbSRajan Walia // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 480c27abbSRajan Walia // See https://llvm.org/LICENSE.txt for license information. 580c27abbSRajan Walia // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 680c27abbSRajan Walia // 780c27abbSRajan Walia //===----------------------------------------------------------------------===// 8b2169992SValentin Clement // 9b2169992SValentin Clement // This transformation is a prototype that demote affine dialects operations 10b2169992SValentin Clement // after optimizations to FIR loops operations. 11b2169992SValentin Clement // It is used after the AffinePromotion pass. 12b2169992SValentin Clement // It is not part of the production pipeline and would need more work in order 13b2169992SValentin Clement // to be used in production. 14b2169992SValentin Clement // More information can be found in this presentation: 15b2169992SValentin Clement // https://slides.com/rajanwalia/deck 16b2169992SValentin Clement // 17b2169992SValentin Clement //===----------------------------------------------------------------------===// 1880c27abbSRajan Walia 1980c27abbSRajan Walia #include "flang/Optimizer/Dialect/FIRDialect.h" 2080c27abbSRajan Walia #include "flang/Optimizer/Dialect/FIROps.h" 2180c27abbSRajan Walia #include "flang/Optimizer/Dialect/FIRType.h" 2280c27abbSRajan Walia #include "flang/Optimizer/Transforms/Passes.h" 2380c27abbSRajan Walia #include "mlir/Dialect/Affine/IR/AffineOps.h" 247d7ebf3cSUday Bondhugula #include "mlir/Dialect/Affine/Utils.h" 2523aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h" 2680c27abbSRajan Walia #include "mlir/Dialect/MemRef/IR/MemRef.h" 278b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 2880c27abbSRajan Walia #include "mlir/IR/BuiltinAttributes.h" 2980c27abbSRajan Walia #include "mlir/IR/IntegerSet.h" 3080c27abbSRajan Walia #include "mlir/IR/Visitors.h" 3180c27abbSRajan Walia #include "mlir/Pass/Pass.h" 3280c27abbSRajan Walia #include "mlir/Transforms/DialectConversion.h" 3380c27abbSRajan Walia #include "llvm/ADT/DenseMap.h" 3480c27abbSRajan Walia #include "llvm/Support/CommandLine.h" 3580c27abbSRajan Walia #include "llvm/Support/Debug.h" 3680c27abbSRajan Walia 3767d0d7acSMichele Scuttari namespace fir { 3867d0d7acSMichele Scuttari #define GEN_PASS_DEF_AFFINEDIALECTDEMOTION 3967d0d7acSMichele Scuttari #include "flang/Optimizer/Transforms/Passes.h.inc" 4067d0d7acSMichele Scuttari } // namespace fir 4167d0d7acSMichele Scuttari 4280c27abbSRajan Walia #define DEBUG_TYPE "flang-affine-demotion" 4380c27abbSRajan Walia 4480c27abbSRajan Walia using namespace fir; 45092601d4SAndrzej Warzynski using namespace mlir; 4680c27abbSRajan Walia 4780c27abbSRajan Walia namespace { 4880c27abbSRajan Walia 494c48f016SMatthias Springer class AffineLoadConversion 504c48f016SMatthias Springer : public OpConversionPattern<mlir::affine::AffineLoadOp> { 5180c27abbSRajan Walia public: 524c48f016SMatthias Springer using OpConversionPattern<mlir::affine::AffineLoadOp>::OpConversionPattern; 5380c27abbSRajan Walia 5442831686SRiver Riddle LogicalResult 554c48f016SMatthias Springer matchAndRewrite(mlir::affine::AffineLoadOp op, OpAdaptor adaptor, 5642831686SRiver Riddle ConversionPatternRewriter &rewriter) const override { 578df54a6aSJacques Pienaar SmallVector<Value> indices(adaptor.getIndices()); 584c48f016SMatthias Springer auto maybeExpandedMap = affine::expandAffineMap(rewriter, op.getLoc(), 594c48f016SMatthias Springer op.getAffineMap(), indices); 6080c27abbSRajan Walia if (!maybeExpandedMap) 6180c27abbSRajan Walia return failure(); 6280c27abbSRajan Walia 6380c27abbSRajan Walia auto coorOp = rewriter.create<fir::CoordinateOp>( 6480c27abbSRajan Walia op.getLoc(), fir::ReferenceType::get(op.getResult().getType()), 658df54a6aSJacques Pienaar adaptor.getMemref(), *maybeExpandedMap); 6680c27abbSRajan Walia 6780c27abbSRajan Walia rewriter.replaceOpWithNewOp<fir::LoadOp>(op, coorOp.getResult()); 6880c27abbSRajan Walia return success(); 6980c27abbSRajan Walia } 7080c27abbSRajan Walia }; 7180c27abbSRajan Walia 724c48f016SMatthias Springer class AffineStoreConversion 734c48f016SMatthias Springer : public OpConversionPattern<mlir::affine::AffineStoreOp> { 7480c27abbSRajan Walia public: 754c48f016SMatthias Springer using OpConversionPattern<mlir::affine::AffineStoreOp>::OpConversionPattern; 7680c27abbSRajan Walia 7742831686SRiver Riddle LogicalResult 784c48f016SMatthias Springer matchAndRewrite(mlir::affine::AffineStoreOp op, OpAdaptor adaptor, 7942831686SRiver Riddle ConversionPatternRewriter &rewriter) const override { 808df54a6aSJacques Pienaar SmallVector<Value> indices(op.getIndices()); 814c48f016SMatthias Springer auto maybeExpandedMap = affine::expandAffineMap(rewriter, op.getLoc(), 824c48f016SMatthias Springer op.getAffineMap(), indices); 8380c27abbSRajan Walia if (!maybeExpandedMap) 8480c27abbSRajan Walia return failure(); 8580c27abbSRajan Walia 8680c27abbSRajan Walia auto coorOp = rewriter.create<fir::CoordinateOp>( 8780c27abbSRajan Walia op.getLoc(), fir::ReferenceType::get(op.getValueToStore().getType()), 888df54a6aSJacques Pienaar adaptor.getMemref(), *maybeExpandedMap); 898df54a6aSJacques Pienaar rewriter.replaceOpWithNewOp<fir::StoreOp>(op, adaptor.getValue(), 9080c27abbSRajan Walia coorOp.getResult()); 9180c27abbSRajan Walia return success(); 9280c27abbSRajan Walia } 9380c27abbSRajan Walia }; 9480c27abbSRajan Walia 9580c27abbSRajan Walia class ConvertConversion : public mlir::OpRewritePattern<fir::ConvertOp> { 9680c27abbSRajan Walia public: 9780c27abbSRajan Walia using OpRewritePattern::OpRewritePattern; 98db791b27SRamkumar Ramachandra llvm::LogicalResult 9980c27abbSRajan Walia matchAndRewrite(fir::ConvertOp op, 10080c27abbSRajan Walia mlir::PatternRewriter &rewriter) const override { 101fac349a1SChristian Sigg if (mlir::isa<mlir::MemRefType>(op.getRes().getType())) { 10280c27abbSRajan Walia // due to index calculation moving to affine maps we still need to 10380c27abbSRajan Walia // add converts for sequence types this has a side effect of losing 10480c27abbSRajan Walia // some information about arrays with known dimensions by creating: 10580c27abbSRajan Walia // fir.convert %arg0 : (!fir.ref<!fir.array<5xi32>>) -> 10680c27abbSRajan Walia // !fir.ref<!fir.array<?xi32>> 107fac349a1SChristian Sigg if (auto refTy = 108fac349a1SChristian Sigg mlir::dyn_cast<fir::ReferenceType>(op.getValue().getType())) 109fac349a1SChristian Sigg if (auto arrTy = mlir::dyn_cast<fir::SequenceType>(refTy.getEleTy())) { 11080c27abbSRajan Walia fir::SequenceType::Shape flatShape = { 11180c27abbSRajan Walia fir::SequenceType::getUnknownExtent()}; 11280c27abbSRajan Walia auto flatArrTy = fir::SequenceType::get(flatShape, arrTy.getEleTy()); 11380c27abbSRajan Walia auto flatTy = fir::ReferenceType::get(flatArrTy); 114149ad3d5SShraiysh Vaishay rewriter.replaceOpWithNewOp<fir::ConvertOp>(op, flatTy, 115149ad3d5SShraiysh Vaishay op.getValue()); 11680c27abbSRajan Walia return success(); 11780c27abbSRajan Walia } 1185fcf907bSMatthias Springer rewriter.startOpModification(op->getParentOp()); 119149ad3d5SShraiysh Vaishay op.getResult().replaceAllUsesWith(op.getValue()); 1205fcf907bSMatthias Springer rewriter.finalizeOpModification(op->getParentOp()); 12180c27abbSRajan Walia rewriter.eraseOp(op); 12280c27abbSRajan Walia } 12380c27abbSRajan Walia return success(); 12480c27abbSRajan Walia } 12580c27abbSRajan Walia }; 12680c27abbSRajan Walia 12780c27abbSRajan Walia mlir::Type convertMemRef(mlir::MemRefType type) { 128*6f8ef5adSKazu Hirata return fir::SequenceType::get(SmallVector<int64_t>(type.getShape()), 12980c27abbSRajan Walia type.getElementType()); 13080c27abbSRajan Walia } 13180c27abbSRajan Walia 13280c27abbSRajan Walia class StdAllocConversion : public mlir::OpRewritePattern<memref::AllocOp> { 13380c27abbSRajan Walia public: 13480c27abbSRajan Walia using OpRewritePattern::OpRewritePattern; 135db791b27SRamkumar Ramachandra llvm::LogicalResult 13680c27abbSRajan Walia matchAndRewrite(memref::AllocOp op, 13780c27abbSRajan Walia mlir::PatternRewriter &rewriter) const override { 13880c27abbSRajan Walia rewriter.replaceOpWithNewOp<fir::AllocaOp>(op, convertMemRef(op.getType()), 139c692a11eSRiver Riddle op.getMemref()); 14080c27abbSRajan Walia return success(); 14180c27abbSRajan Walia } 14280c27abbSRajan Walia }; 14380c27abbSRajan Walia 14480c27abbSRajan Walia class AffineDialectDemotion 14567d0d7acSMichele Scuttari : public fir::impl::AffineDialectDemotionBase<AffineDialectDemotion> { 14680c27abbSRajan Walia public: 147196c4279SRiver Riddle void runOnOperation() override { 14880c27abbSRajan Walia auto *context = &getContext(); 149196c4279SRiver Riddle auto function = getOperation(); 15080c27abbSRajan Walia LLVM_DEBUG(llvm::dbgs() << "AffineDemotion: running on function:\n"; 15180c27abbSRajan Walia function.print(llvm::dbgs());); 15280c27abbSRajan Walia 1539f85c198SRiver Riddle mlir::RewritePatternSet patterns(context); 15480c27abbSRajan Walia patterns.insert<ConvertConversion>(context); 15580c27abbSRajan Walia patterns.insert<AffineLoadConversion>(context); 15680c27abbSRajan Walia patterns.insert<AffineStoreConversion>(context); 15780c27abbSRajan Walia patterns.insert<StdAllocConversion>(context); 15880c27abbSRajan Walia mlir::ConversionTarget target(*context); 15980c27abbSRajan Walia target.addIllegalOp<memref::AllocOp>(); 16080c27abbSRajan Walia target.addDynamicallyLegalOp<fir::ConvertOp>([](fir::ConvertOp op) { 161fac349a1SChristian Sigg if (mlir::isa<mlir::MemRefType>(op.getRes().getType())) 16280c27abbSRajan Walia return false; 16380c27abbSRajan Walia return true; 16480c27abbSRajan Walia }); 1656c8d8d10SJakub Kuderski target 1666c8d8d10SJakub Kuderski .addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect, 1676c8d8d10SJakub Kuderski mlir::arith::ArithDialect, mlir::func::FuncDialect>(); 16880c27abbSRajan Walia 16980c27abbSRajan Walia if (mlir::failed(mlir::applyPartialConversion(function, target, 17080c27abbSRajan Walia std::move(patterns)))) { 17180c27abbSRajan Walia mlir::emitError(mlir::UnknownLoc::get(context), 17280c27abbSRajan Walia "error in converting affine dialect\n"); 17380c27abbSRajan Walia signalPassFailure(); 17480c27abbSRajan Walia } 17580c27abbSRajan Walia } 17680c27abbSRajan Walia }; 17780c27abbSRajan Walia 17880c27abbSRajan Walia } // namespace 17980c27abbSRajan Walia 18080c27abbSRajan Walia std::unique_ptr<mlir::Pass> fir::createAffineDemotionPass() { 18180c27abbSRajan Walia return std::make_unique<AffineDialectDemotion>(); 18280c27abbSRajan Walia } 183