xref: /llvm-project/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp (revision 039b969b32b64b64123dce30dd28ec4e343d893f)
1 //===-- PreCGRewrite.cpp --------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CGOps.h"
14 #include "PassDetail.h"
15 #include "flang/Optimizer/CodeGen/CodeGen.h"
16 #include "flang/Optimizer/Dialect/FIRDialect.h"
17 #include "flang/Optimizer/Dialect/FIROps.h"
18 #include "flang/Optimizer/Dialect/FIRType.h"
19 #include "flang/Optimizer/Support/FIRContext.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/Debug.h"
23 
24 //===----------------------------------------------------------------------===//
25 // Codegen rewrite: rewriting of subgraphs of ops
26 //===----------------------------------------------------------------------===//
27 
28 #define DEBUG_TYPE "flang-codegen-rewrite"
29 
30 static void populateShape(llvm::SmallVectorImpl<mlir::Value> &vec,
31                           fir::ShapeOp shape) {
32   vec.append(shape.getExtents().begin(), shape.getExtents().end());
33 }
34 
35 // Operands of fir.shape_shift split into two vectors.
36 static void populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> &shapeVec,
37                                   llvm::SmallVectorImpl<mlir::Value> &shiftVec,
38                                   fir::ShapeShiftOp shift) {
39   for (auto i = shift.getPairs().begin(), endIter = shift.getPairs().end();
40        i != endIter;) {
41     shiftVec.push_back(*i++);
42     shapeVec.push_back(*i++);
43   }
44 }
45 
46 static void populateShift(llvm::SmallVectorImpl<mlir::Value> &vec,
47                           fir::ShiftOp shift) {
48   vec.append(shift.getOrigins().begin(), shift.getOrigins().end());
49 }
50 
51 namespace {
52 
53 /// Convert fir.embox to the extended form where necessary.
54 ///
55 /// The embox operation can take arguments that specify multidimensional array
56 /// properties at runtime. These properties may be shared between distinct
57 /// objects that have the same properties. Before we lower these small DAGs to
58 /// LLVM-IR, we gather all the information into a single extended operation. For
59 /// example,
60 /// ```
61 /// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1>
62 /// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1>
63 /// %3 = fir.embox %0 (%1) [%2] : (!fir.ref<!fir.array<?xi32>>,
64 /// !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>>
65 /// ```
66 /// can be rewritten as
67 /// ```
68 /// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] :
69 /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) ->
70 /// !fir.box<!fir.array<?xi32>>
71 /// ```
72 class EmboxConversion : public mlir::OpRewritePattern<fir::EmboxOp> {
73 public:
74   using OpRewritePattern::OpRewritePattern;
75 
76   mlir::LogicalResult
77   matchAndRewrite(fir::EmboxOp embox,
78                   mlir::PatternRewriter &rewriter) const override {
79     // If the embox does not include a shape, then do not convert it
80     if (auto shapeVal = embox.getShape())
81       return rewriteDynamicShape(embox, rewriter, shapeVal);
82     if (auto boxTy = embox.getType().dyn_cast<fir::BoxType>())
83       if (auto seqTy = boxTy.getEleTy().dyn_cast<fir::SequenceType>())
84         if (!seqTy.hasDynamicExtents())
85           return rewriteStaticShape(embox, rewriter, seqTy);
86     return mlir::failure();
87   }
88 
89   mlir::LogicalResult rewriteStaticShape(fir::EmboxOp embox,
90                                          mlir::PatternRewriter &rewriter,
91                                          fir::SequenceType seqTy) const {
92     auto loc = embox.getLoc();
93     llvm::SmallVector<mlir::Value> shapeOpers;
94     auto idxTy = rewriter.getIndexType();
95     for (auto ext : seqTy.getShape()) {
96       auto iAttr = rewriter.getIndexAttr(ext);
97       auto extVal = rewriter.create<mlir::arith::ConstantOp>(loc, idxTy, iAttr);
98       shapeOpers.push_back(extVal);
99     }
100     auto xbox = rewriter.create<fir::cg::XEmboxOp>(
101         loc, embox.getType(), embox.getMemref(), shapeOpers, llvm::None,
102         llvm::None, llvm::None, llvm::None, embox.getTypeparams());
103     LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
104     rewriter.replaceOp(embox, xbox.getOperation()->getResults());
105     return mlir::success();
106   }
107 
108   mlir::LogicalResult rewriteDynamicShape(fir::EmboxOp embox,
109                                           mlir::PatternRewriter &rewriter,
110                                           mlir::Value shapeVal) const {
111     auto loc = embox.getLoc();
112     llvm::SmallVector<mlir::Value> shapeOpers;
113     llvm::SmallVector<mlir::Value> shiftOpers;
114     if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp())) {
115       populateShape(shapeOpers, shapeOp);
116     } else {
117       auto shiftOp =
118           mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp());
119       assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift");
120       populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
121     }
122     llvm::SmallVector<mlir::Value> sliceOpers;
123     llvm::SmallVector<mlir::Value> subcompOpers;
124     llvm::SmallVector<mlir::Value> substrOpers;
125     if (auto s = embox.getSlice())
126       if (auto sliceOp =
127               mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) {
128         sliceOpers.assign(sliceOp.getTriples().begin(),
129                           sliceOp.getTriples().end());
130         subcompOpers.assign(sliceOp.getFields().begin(),
131                             sliceOp.getFields().end());
132         substrOpers.assign(sliceOp.getSubstr().begin(),
133                            sliceOp.getSubstr().end());
134       }
135     auto xbox = rewriter.create<fir::cg::XEmboxOp>(
136         loc, embox.getType(), embox.getMemref(), shapeOpers, shiftOpers,
137         sliceOpers, subcompOpers, substrOpers, embox.getTypeparams());
138     LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
139     rewriter.replaceOp(embox, xbox.getOperation()->getResults());
140     return mlir::success();
141   }
142 };
143 
144 /// Convert fir.rebox to the extended form where necessary.
145 ///
146 /// For example,
147 /// ```
148 /// %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) ->
149 /// !fir.box<!fir.array<?xi32>>
150 /// ```
151 /// converted to
152 /// ```
153 /// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>,
154 /// index, index) -> !fir.box<!fir.array<?xi32>>
155 /// ```
156 class ReboxConversion : public mlir::OpRewritePattern<fir::ReboxOp> {
157 public:
158   using OpRewritePattern::OpRewritePattern;
159 
160   mlir::LogicalResult
161   matchAndRewrite(fir::ReboxOp rebox,
162                   mlir::PatternRewriter &rewriter) const override {
163     auto loc = rebox.getLoc();
164     llvm::SmallVector<mlir::Value> shapeOpers;
165     llvm::SmallVector<mlir::Value> shiftOpers;
166     if (auto shapeVal = rebox.getShape()) {
167       if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp()))
168         populateShape(shapeOpers, shapeOp);
169       else if (auto shiftOp =
170                    mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp()))
171         populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
172       else if (auto shiftOp =
173                    mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp()))
174         populateShift(shiftOpers, shiftOp);
175       else
176         return mlir::failure();
177     }
178     llvm::SmallVector<mlir::Value> sliceOpers;
179     llvm::SmallVector<mlir::Value> subcompOpers;
180     llvm::SmallVector<mlir::Value> substrOpers;
181     if (auto s = rebox.getSlice())
182       if (auto sliceOp =
183               mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) {
184         sliceOpers.append(sliceOp.getTriples().begin(),
185                           sliceOp.getTriples().end());
186         subcompOpers.append(sliceOp.getFields().begin(),
187                             sliceOp.getFields().end());
188         substrOpers.append(sliceOp.getSubstr().begin(),
189                            sliceOp.getSubstr().end());
190       }
191 
192     auto xRebox = rewriter.create<fir::cg::XReboxOp>(
193         loc, rebox.getType(), rebox.getBox(), shapeOpers, shiftOpers,
194         sliceOpers, subcompOpers, substrOpers);
195     LLVM_DEBUG(llvm::dbgs()
196                << "rewriting " << rebox << " to " << xRebox << '\n');
197     rewriter.replaceOp(rebox, xRebox.getOperation()->getResults());
198     return mlir::success();
199   }
200 };
201 
202 /// Convert all fir.array_coor to the extended form.
203 ///
204 /// For example,
205 /// ```
206 ///  %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>,
207 ///  !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32>
208 /// ```
209 /// converted to
210 /// ```
211 /// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> :
212 /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) ->
213 /// !fir.ref<i32>
214 /// ```
215 class ArrayCoorConversion : public mlir::OpRewritePattern<fir::ArrayCoorOp> {
216 public:
217   using OpRewritePattern::OpRewritePattern;
218 
219   mlir::LogicalResult
220   matchAndRewrite(fir::ArrayCoorOp arrCoor,
221                   mlir::PatternRewriter &rewriter) const override {
222     auto loc = arrCoor.getLoc();
223     llvm::SmallVector<mlir::Value> shapeOpers;
224     llvm::SmallVector<mlir::Value> shiftOpers;
225     if (auto shapeVal = arrCoor.getShape()) {
226       if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp()))
227         populateShape(shapeOpers, shapeOp);
228       else if (auto shiftOp =
229                    mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp()))
230         populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
231       else if (auto shiftOp =
232                    mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp()))
233         populateShift(shiftOpers, shiftOp);
234       else
235         return mlir::failure();
236     }
237     llvm::SmallVector<mlir::Value> sliceOpers;
238     llvm::SmallVector<mlir::Value> subcompOpers;
239     if (auto s = arrCoor.getSlice())
240       if (auto sliceOp =
241               mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) {
242         sliceOpers.append(sliceOp.getTriples().begin(),
243                           sliceOp.getTriples().end());
244         subcompOpers.append(sliceOp.getFields().begin(),
245                             sliceOp.getFields().end());
246         assert(sliceOp.getSubstr().empty() &&
247                "Don't allow substring operations on array_coor. This "
248                "restriction may be lifted in the future.");
249       }
250     auto xArrCoor = rewriter.create<fir::cg::XArrayCoorOp>(
251         loc, arrCoor.getType(), arrCoor.getMemref(), shapeOpers, shiftOpers,
252         sliceOpers, subcompOpers, arrCoor.getIndices(),
253         arrCoor.getTypeparams());
254     LLVM_DEBUG(llvm::dbgs()
255                << "rewriting " << arrCoor << " to " << xArrCoor << '\n');
256     rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults());
257     return mlir::success();
258   }
259 };
260 
261 class CodeGenRewrite : public fir::CodeGenRewriteBase<CodeGenRewrite> {
262 public:
263   void runOn(mlir::Operation *op, mlir::Region &region) {
264     auto &context = getContext();
265     mlir::OpBuilder rewriter(&context);
266     mlir::ConversionTarget target(context);
267     target.addLegalDialect<mlir::arith::ArithmeticDialect, fir::FIROpsDialect,
268                            fir::FIRCodeGenDialect, mlir::func::FuncDialect>();
269     target.addIllegalOp<fir::ArrayCoorOp>();
270     target.addIllegalOp<fir::ReboxOp>();
271     target.addDynamicallyLegalOp<fir::EmboxOp>([](fir::EmboxOp embox) {
272       return !(embox.getShape() || embox.getType()
273                                        .cast<fir::BoxType>()
274                                        .getEleTy()
275                                        .isa<fir::SequenceType>());
276     });
277     mlir::RewritePatternSet patterns(&context);
278     patterns.insert<EmboxConversion, ArrayCoorConversion, ReboxConversion>(
279         &context);
280     if (mlir::failed(
281             mlir::applyPartialConversion(op, target, std::move(patterns)))) {
282       mlir::emitError(mlir::UnknownLoc::get(&context),
283                       "error in running the pre-codegen conversions");
284       signalPassFailure();
285     }
286     // Erase any residual.
287     simplifyRegion(region);
288   }
289 
290   void runOnOperation() override final {
291     // Call runOn on all top level regions that may contain emboxOp/arrayCoorOp.
292     auto mod = getOperation();
293     for (auto func : mod.getOps<mlir::func::FuncOp>())
294       runOn(func, func.getBody());
295     for (auto global : mod.getOps<fir::GlobalOp>())
296       runOn(global, global.getRegion());
297   }
298 
299   // Clean up the region.
300   void simplifyRegion(mlir::Region &region) {
301     for (auto &block : region.getBlocks())
302       for (auto &op : block.getOperations()) {
303         for (auto &reg : op.getRegions())
304           simplifyRegion(reg);
305         maybeEraseOp(&op);
306       }
307     doDCE();
308   }
309 
310   /// Run a simple DCE cleanup to remove any dead code after the rewrites.
311   void doDCE() {
312     std::vector<mlir::Operation *> workList;
313     workList.swap(opsToErase);
314     while (!workList.empty()) {
315       for (auto *op : workList) {
316         std::vector<mlir::Value> opOperands(op->operand_begin(),
317                                             op->operand_end());
318         LLVM_DEBUG(llvm::dbgs() << "DCE on " << *op << '\n');
319         ++numDCE;
320         op->erase();
321         for (auto opnd : opOperands)
322           maybeEraseOp(opnd.getDefiningOp());
323       }
324       workList.clear();
325       workList.swap(opsToErase);
326     }
327   }
328 
329   void maybeEraseOp(mlir::Operation *op) {
330     if (!op)
331       return;
332     if (op->hasTrait<mlir::OpTrait::IsTerminator>())
333       return;
334     if (mlir::isOpTriviallyDead(op))
335       opsToErase.push_back(op);
336   }
337 
338 private:
339   std::vector<mlir::Operation *> opsToErase;
340 };
341 
342 } // namespace
343 
344 std::unique_ptr<mlir::Pass> fir::createFirCodeGenRewritePass() {
345   return std::make_unique<CodeGenRewrite>();
346 }
347