1 //===-- AffineDemotion.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This transformation is a prototype that demote affine dialects operations 10 // after optimizations to FIR loops operations. 11 // It is used after the AffinePromotion pass. 12 // It is not part of the production pipeline and would need more work in order 13 // to be used in production. 14 // More information can be found in this presentation: 15 // https://slides.com/rajanwalia/deck 16 // 17 //===----------------------------------------------------------------------===// 18 19 #include "flang/Optimizer/Dialect/FIRDialect.h" 20 #include "flang/Optimizer/Dialect/FIROps.h" 21 #include "flang/Optimizer/Dialect/FIRType.h" 22 #include "flang/Optimizer/Transforms/Passes.h" 23 #include "mlir/Dialect/Affine/IR/AffineOps.h" 24 #include "mlir/Dialect/Affine/Utils.h" 25 #include "mlir/Dialect/Func/IR/FuncOps.h" 26 #include "mlir/Dialect/MemRef/IR/MemRef.h" 27 #include "mlir/Dialect/SCF/IR/SCF.h" 28 #include "mlir/IR/BuiltinAttributes.h" 29 #include "mlir/IR/IntegerSet.h" 30 #include "mlir/IR/Visitors.h" 31 #include "mlir/Pass/Pass.h" 32 #include "mlir/Transforms/DialectConversion.h" 33 #include "llvm/ADT/DenseMap.h" 34 #include "llvm/Support/CommandLine.h" 35 #include "llvm/Support/Debug.h" 36 37 namespace fir { 38 #define GEN_PASS_DEF_AFFINEDIALECTDEMOTION 39 #include "flang/Optimizer/Transforms/Passes.h.inc" 40 } // namespace fir 41 42 #define DEBUG_TYPE "flang-affine-demotion" 43 44 using namespace fir; 45 using namespace mlir; 46 47 namespace { 48 49 class AffineLoadConversion 50 : public OpConversionPattern<mlir::affine::AffineLoadOp> { 51 public: 52 using OpConversionPattern<mlir::affine::AffineLoadOp>::OpConversionPattern; 53 54 LogicalResult 55 matchAndRewrite(mlir::affine::AffineLoadOp op, OpAdaptor adaptor, 56 ConversionPatternRewriter &rewriter) const override { 57 SmallVector<Value> indices(adaptor.getIndices()); 58 auto maybeExpandedMap = affine::expandAffineMap(rewriter, op.getLoc(), 59 op.getAffineMap(), indices); 60 if (!maybeExpandedMap) 61 return failure(); 62 63 auto coorOp = rewriter.create<fir::CoordinateOp>( 64 op.getLoc(), fir::ReferenceType::get(op.getResult().getType()), 65 adaptor.getMemref(), *maybeExpandedMap); 66 67 rewriter.replaceOpWithNewOp<fir::LoadOp>(op, coorOp.getResult()); 68 return success(); 69 } 70 }; 71 72 class AffineStoreConversion 73 : public OpConversionPattern<mlir::affine::AffineStoreOp> { 74 public: 75 using OpConversionPattern<mlir::affine::AffineStoreOp>::OpConversionPattern; 76 77 LogicalResult 78 matchAndRewrite(mlir::affine::AffineStoreOp op, OpAdaptor adaptor, 79 ConversionPatternRewriter &rewriter) const override { 80 SmallVector<Value> indices(op.getIndices()); 81 auto maybeExpandedMap = affine::expandAffineMap(rewriter, op.getLoc(), 82 op.getAffineMap(), indices); 83 if (!maybeExpandedMap) 84 return failure(); 85 86 auto coorOp = rewriter.create<fir::CoordinateOp>( 87 op.getLoc(), fir::ReferenceType::get(op.getValueToStore().getType()), 88 adaptor.getMemref(), *maybeExpandedMap); 89 rewriter.replaceOpWithNewOp<fir::StoreOp>(op, adaptor.getValue(), 90 coorOp.getResult()); 91 return success(); 92 } 93 }; 94 95 class ConvertConversion : public mlir::OpRewritePattern<fir::ConvertOp> { 96 public: 97 using OpRewritePattern::OpRewritePattern; 98 llvm::LogicalResult 99 matchAndRewrite(fir::ConvertOp op, 100 mlir::PatternRewriter &rewriter) const override { 101 if (mlir::isa<mlir::MemRefType>(op.getRes().getType())) { 102 // due to index calculation moving to affine maps we still need to 103 // add converts for sequence types this has a side effect of losing 104 // some information about arrays with known dimensions by creating: 105 // fir.convert %arg0 : (!fir.ref<!fir.array<5xi32>>) -> 106 // !fir.ref<!fir.array<?xi32>> 107 if (auto refTy = 108 mlir::dyn_cast<fir::ReferenceType>(op.getValue().getType())) 109 if (auto arrTy = mlir::dyn_cast<fir::SequenceType>(refTy.getEleTy())) { 110 fir::SequenceType::Shape flatShape = { 111 fir::SequenceType::getUnknownExtent()}; 112 auto flatArrTy = fir::SequenceType::get(flatShape, arrTy.getEleTy()); 113 auto flatTy = fir::ReferenceType::get(flatArrTy); 114 rewriter.replaceOpWithNewOp<fir::ConvertOp>(op, flatTy, 115 op.getValue()); 116 return success(); 117 } 118 rewriter.startOpModification(op->getParentOp()); 119 op.getResult().replaceAllUsesWith(op.getValue()); 120 rewriter.finalizeOpModification(op->getParentOp()); 121 rewriter.eraseOp(op); 122 } 123 return success(); 124 } 125 }; 126 127 mlir::Type convertMemRef(mlir::MemRefType type) { 128 return fir::SequenceType::get(SmallVector<int64_t>(type.getShape()), 129 type.getElementType()); 130 } 131 132 class StdAllocConversion : public mlir::OpRewritePattern<memref::AllocOp> { 133 public: 134 using OpRewritePattern::OpRewritePattern; 135 llvm::LogicalResult 136 matchAndRewrite(memref::AllocOp op, 137 mlir::PatternRewriter &rewriter) const override { 138 rewriter.replaceOpWithNewOp<fir::AllocaOp>(op, convertMemRef(op.getType()), 139 op.getMemref()); 140 return success(); 141 } 142 }; 143 144 class AffineDialectDemotion 145 : public fir::impl::AffineDialectDemotionBase<AffineDialectDemotion> { 146 public: 147 void runOnOperation() override { 148 auto *context = &getContext(); 149 auto function = getOperation(); 150 LLVM_DEBUG(llvm::dbgs() << "AffineDemotion: running on function:\n"; 151 function.print(llvm::dbgs());); 152 153 mlir::RewritePatternSet patterns(context); 154 patterns.insert<ConvertConversion>(context); 155 patterns.insert<AffineLoadConversion>(context); 156 patterns.insert<AffineStoreConversion>(context); 157 patterns.insert<StdAllocConversion>(context); 158 mlir::ConversionTarget target(*context); 159 target.addIllegalOp<memref::AllocOp>(); 160 target.addDynamicallyLegalOp<fir::ConvertOp>([](fir::ConvertOp op) { 161 if (mlir::isa<mlir::MemRefType>(op.getRes().getType())) 162 return false; 163 return true; 164 }); 165 target 166 .addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect, 167 mlir::arith::ArithDialect, mlir::func::FuncDialect>(); 168 169 if (mlir::failed(mlir::applyPartialConversion(function, target, 170 std::move(patterns)))) { 171 mlir::emitError(mlir::UnknownLoc::get(context), 172 "error in converting affine dialect\n"); 173 signalPassFailure(); 174 } 175 } 176 }; 177 178 } // namespace 179 180 std::unique_ptr<mlir::Pass> fir::createAffineDemotionPass() { 181 return std::make_unique<AffineDialectDemotion>(); 182 } 183