xref: /llvm-project/flang/lib/Optimizer/CodeGen/PreCGRewrite.cpp (revision 0def9a923dadc2b2b3dd067eefcef541e475594c)
1 //===-- PreCGRewrite.cpp --------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "flang/Optimizer/CodeGen/CodeGen.h"
14 
15 #include "flang/Optimizer/Builder/Todo.h" // remove when TODO's are done
16 #include "flang/Optimizer/CodeGen/CGOps.h"
17 #include "flang/Optimizer/Dialect/FIRDialect.h"
18 #include "flang/Optimizer/Dialect/FIROps.h"
19 #include "flang/Optimizer/Dialect/FIRType.h"
20 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
21 #include "mlir/IR/Iterators.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/Support/Debug.h"
25 
26 namespace fir {
27 #define GEN_PASS_DEF_CODEGENREWRITE
28 #include "flang/Optimizer/CodeGen/CGPasses.h.inc"
29 } // namespace fir
30 
31 //===----------------------------------------------------------------------===//
32 // Codegen rewrite: rewriting of subgraphs of ops
33 //===----------------------------------------------------------------------===//
34 
35 #define DEBUG_TYPE "flang-codegen-rewrite"
36 
37 static void populateShape(llvm::SmallVectorImpl<mlir::Value> &vec,
38                           fir::ShapeOp shape) {
39   vec.append(shape.getExtents().begin(), shape.getExtents().end());
40 }
41 
42 // Operands of fir.shape_shift split into two vectors.
43 static void populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> &shapeVec,
44                                   llvm::SmallVectorImpl<mlir::Value> &shiftVec,
45                                   fir::ShapeShiftOp shift) {
46   for (auto i = shift.getPairs().begin(), endIter = shift.getPairs().end();
47        i != endIter;) {
48     shiftVec.push_back(*i++);
49     shapeVec.push_back(*i++);
50   }
51 }
52 
53 static void populateShift(llvm::SmallVectorImpl<mlir::Value> &vec,
54                           fir::ShiftOp shift) {
55   vec.append(shift.getOrigins().begin(), shift.getOrigins().end());
56 }
57 
58 namespace {
59 
60 /// Convert fir.embox to the extended form where necessary.
61 ///
62 /// The embox operation can take arguments that specify multidimensional array
63 /// properties at runtime. These properties may be shared between distinct
64 /// objects that have the same properties. Before we lower these small DAGs to
65 /// LLVM-IR, we gather all the information into a single extended operation. For
66 /// example,
67 /// ```
68 /// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1>
69 /// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1>
70 /// %3 = fir.embox %0 (%1) [%2] : (!fir.ref<!fir.array<?xi32>>,
71 /// !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>>
72 /// ```
73 /// can be rewritten as
74 /// ```
75 /// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] :
76 /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) ->
77 /// !fir.box<!fir.array<?xi32>>
78 /// ```
79 class EmboxConversion : public mlir::OpRewritePattern<fir::EmboxOp> {
80 public:
81   using OpRewritePattern::OpRewritePattern;
82 
83   llvm::LogicalResult
84   matchAndRewrite(fir::EmboxOp embox,
85                   mlir::PatternRewriter &rewriter) const override {
86     // If the embox does not include a shape, then do not convert it
87     if (auto shapeVal = embox.getShape())
88       return rewriteDynamicShape(embox, rewriter, shapeVal);
89     if (mlir::isa<fir::ClassType>(embox.getType()))
90       TODO(embox.getLoc(), "embox conversion for fir.class type");
91     if (auto boxTy = mlir::dyn_cast<fir::BoxType>(embox.getType()))
92       if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(boxTy.getEleTy()))
93         if (!seqTy.hasDynamicExtents())
94           return rewriteStaticShape(embox, rewriter, seqTy);
95     return mlir::failure();
96   }
97 
98   llvm::LogicalResult rewriteStaticShape(fir::EmboxOp embox,
99                                          mlir::PatternRewriter &rewriter,
100                                          fir::SequenceType seqTy) const {
101     auto loc = embox.getLoc();
102     llvm::SmallVector<mlir::Value> shapeOpers;
103     auto idxTy = rewriter.getIndexType();
104     for (auto ext : seqTy.getShape()) {
105       auto iAttr = rewriter.getIndexAttr(ext);
106       auto extVal = rewriter.create<mlir::arith::ConstantOp>(loc, idxTy, iAttr);
107       shapeOpers.push_back(extVal);
108     }
109     auto xbox = rewriter.create<fir::cg::XEmboxOp>(
110         loc, embox.getType(), embox.getMemref(), shapeOpers, std::nullopt,
111         std::nullopt, std::nullopt, std::nullopt, embox.getTypeparams(),
112         embox.getSourceBox(), embox.getAllocatorIdxAttr());
113     LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
114     rewriter.replaceOp(embox, xbox.getOperation()->getResults());
115     return mlir::success();
116   }
117 
118   llvm::LogicalResult rewriteDynamicShape(fir::EmboxOp embox,
119                                           mlir::PatternRewriter &rewriter,
120                                           mlir::Value shapeVal) const {
121     auto loc = embox.getLoc();
122     llvm::SmallVector<mlir::Value> shapeOpers;
123     llvm::SmallVector<mlir::Value> shiftOpers;
124     if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp())) {
125       populateShape(shapeOpers, shapeOp);
126     } else {
127       auto shiftOp =
128           mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp());
129       assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift");
130       populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
131     }
132     llvm::SmallVector<mlir::Value> sliceOpers;
133     llvm::SmallVector<mlir::Value> subcompOpers;
134     llvm::SmallVector<mlir::Value> substrOpers;
135     if (auto s = embox.getSlice())
136       if (auto sliceOp =
137               mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) {
138         sliceOpers.assign(sliceOp.getTriples().begin(),
139                           sliceOp.getTriples().end());
140         subcompOpers.assign(sliceOp.getFields().begin(),
141                             sliceOp.getFields().end());
142         substrOpers.assign(sliceOp.getSubstr().begin(),
143                            sliceOp.getSubstr().end());
144       }
145     auto xbox = rewriter.create<fir::cg::XEmboxOp>(
146         loc, embox.getType(), embox.getMemref(), shapeOpers, shiftOpers,
147         sliceOpers, subcompOpers, substrOpers, embox.getTypeparams(),
148         embox.getSourceBox(), embox.getAllocatorIdxAttr());
149     LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
150     rewriter.replaceOp(embox, xbox.getOperation()->getResults());
151     return mlir::success();
152   }
153 };
154 
155 /// Convert fir.rebox to the extended form where necessary.
156 ///
157 /// For example,
158 /// ```
159 /// %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) ->
160 /// !fir.box<!fir.array<?xi32>>
161 /// ```
162 /// converted to
163 /// ```
164 /// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>,
165 /// index, index) -> !fir.box<!fir.array<?xi32>>
166 /// ```
167 class ReboxConversion : public mlir::OpRewritePattern<fir::ReboxOp> {
168 public:
169   using OpRewritePattern::OpRewritePattern;
170 
171   llvm::LogicalResult
172   matchAndRewrite(fir::ReboxOp rebox,
173                   mlir::PatternRewriter &rewriter) const override {
174     auto loc = rebox.getLoc();
175     llvm::SmallVector<mlir::Value> shapeOpers;
176     llvm::SmallVector<mlir::Value> shiftOpers;
177     if (auto shapeVal = rebox.getShape()) {
178       if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp()))
179         populateShape(shapeOpers, shapeOp);
180       else if (auto shiftOp =
181                    mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp()))
182         populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
183       else if (auto shiftOp =
184                    mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp()))
185         populateShift(shiftOpers, shiftOp);
186       else
187         return mlir::failure();
188     }
189     llvm::SmallVector<mlir::Value> sliceOpers;
190     llvm::SmallVector<mlir::Value> subcompOpers;
191     llvm::SmallVector<mlir::Value> substrOpers;
192     if (auto s = rebox.getSlice())
193       if (auto sliceOp =
194               mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) {
195         sliceOpers.append(sliceOp.getTriples().begin(),
196                           sliceOp.getTriples().end());
197         subcompOpers.append(sliceOp.getFields().begin(),
198                             sliceOp.getFields().end());
199         substrOpers.append(sliceOp.getSubstr().begin(),
200                            sliceOp.getSubstr().end());
201       }
202 
203     auto xRebox = rewriter.create<fir::cg::XReboxOp>(
204         loc, rebox.getType(), rebox.getBox(), shapeOpers, shiftOpers,
205         sliceOpers, subcompOpers, substrOpers);
206     LLVM_DEBUG(llvm::dbgs()
207                << "rewriting " << rebox << " to " << xRebox << '\n');
208     rewriter.replaceOp(rebox, xRebox.getOperation()->getResults());
209     return mlir::success();
210   }
211 };
212 
213 /// Convert all fir.array_coor to the extended form.
214 ///
215 /// For example,
216 /// ```
217 ///  %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>,
218 ///  !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32>
219 /// ```
220 /// converted to
221 /// ```
222 /// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> :
223 /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) ->
224 /// !fir.ref<i32>
225 /// ```
226 class ArrayCoorConversion : public mlir::OpRewritePattern<fir::ArrayCoorOp> {
227 public:
228   using OpRewritePattern::OpRewritePattern;
229 
230   llvm::LogicalResult
231   matchAndRewrite(fir::ArrayCoorOp arrCoor,
232                   mlir::PatternRewriter &rewriter) const override {
233     auto loc = arrCoor.getLoc();
234     llvm::SmallVector<mlir::Value> shapeOpers;
235     llvm::SmallVector<mlir::Value> shiftOpers;
236     if (auto shapeVal = arrCoor.getShape()) {
237       if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp()))
238         populateShape(shapeOpers, shapeOp);
239       else if (auto shiftOp =
240                    mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp()))
241         populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
242       else if (auto shiftOp =
243                    mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp()))
244         populateShift(shiftOpers, shiftOp);
245       else
246         return mlir::failure();
247     }
248     llvm::SmallVector<mlir::Value> sliceOpers;
249     llvm::SmallVector<mlir::Value> subcompOpers;
250     if (auto s = arrCoor.getSlice())
251       if (auto sliceOp =
252               mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) {
253         sliceOpers.append(sliceOp.getTriples().begin(),
254                           sliceOp.getTriples().end());
255         subcompOpers.append(sliceOp.getFields().begin(),
256                             sliceOp.getFields().end());
257         assert(sliceOp.getSubstr().empty() &&
258                "Don't allow substring operations on array_coor. This "
259                "restriction may be lifted in the future.");
260       }
261     auto xArrCoor = rewriter.create<fir::cg::XArrayCoorOp>(
262         loc, arrCoor.getType(), arrCoor.getMemref(), shapeOpers, shiftOpers,
263         sliceOpers, subcompOpers, arrCoor.getIndices(),
264         arrCoor.getTypeparams());
265     LLVM_DEBUG(llvm::dbgs()
266                << "rewriting " << arrCoor << " to " << xArrCoor << '\n');
267     rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults());
268     return mlir::success();
269   }
270 };
271 
272 class DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
273   bool preserveDeclare;
274 
275 public:
276   using OpRewritePattern::OpRewritePattern;
277   DeclareOpConversion(mlir::MLIRContext *ctx, bool preserveDecl)
278       : OpRewritePattern(ctx), preserveDeclare(preserveDecl) {}
279 
280   llvm::LogicalResult
281   matchAndRewrite(fir::DeclareOp declareOp,
282                   mlir::PatternRewriter &rewriter) const override {
283     if (!preserveDeclare) {
284       rewriter.replaceOp(declareOp, declareOp.getMemref());
285       return mlir::success();
286     }
287     auto loc = declareOp.getLoc();
288     llvm::SmallVector<mlir::Value> shapeOpers;
289     llvm::SmallVector<mlir::Value> shiftOpers;
290     if (auto shapeVal = declareOp.getShape()) {
291       if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp()))
292         populateShape(shapeOpers, shapeOp);
293       else if (auto shiftOp =
294                    mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp()))
295         populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
296       else if (auto shiftOp =
297                    mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp()))
298         populateShift(shiftOpers, shiftOp);
299       else
300         return mlir::failure();
301     }
302     // FIXME: Add FortranAttrs and CudaAttrs
303     auto xDeclOp = rewriter.create<fir::cg::XDeclareOp>(
304         loc, declareOp.getType(), declareOp.getMemref(), shapeOpers, shiftOpers,
305         declareOp.getTypeparams(), declareOp.getDummyScope(),
306         declareOp.getUniqName());
307     LLVM_DEBUG(llvm::dbgs()
308                << "rewriting " << declareOp << " to " << xDeclOp << '\n');
309     rewriter.replaceOp(declareOp, xDeclOp.getOperation()->getResults());
310     return mlir::success();
311   }
312 };
313 
314 class DummyScopeOpConversion
315     : public mlir::OpRewritePattern<fir::DummyScopeOp> {
316 public:
317   using OpRewritePattern::OpRewritePattern;
318 
319   llvm::LogicalResult
320   matchAndRewrite(fir::DummyScopeOp dummyScopeOp,
321                   mlir::PatternRewriter &rewriter) const override {
322     rewriter.replaceOpWithNewOp<fir::UndefOp>(dummyScopeOp,
323                                               dummyScopeOp.getType());
324     return mlir::success();
325   }
326 };
327 
328 /// Simple DCE to erase fir.shape/shift/slice/unused shape operands after this
329 /// pass (fir.shape and like have no codegen).
330 /// mlir::RegionDCE is expensive and requires running
331 /// mlir::eraseUnreachableBlocks. It does things that are not needed here, like
332 /// removing unused block arguments. fir.shape/shift/slice cannot be block
333 /// arguments.
334 /// This helper does a naive backward walk of the IR. It is not even guaranteed
335 /// to walk blocks according to backward dominance, but that is good enough for
336 /// what is done here, fir.shape/shift/slice have no usages anymore. The
337 /// backward walk allows getting rid of most of the unused operands, it is not a
338 /// problem to leave some in the weird cases.
339 static void simpleDCE(mlir::RewriterBase &rewriter, mlir::Operation *op) {
340   op->walk<mlir::WalkOrder::PostOrder, mlir::ReverseIterator>(
341       [&](mlir::Operation *subOp) {
342         if (mlir::isOpTriviallyDead(subOp))
343           rewriter.eraseOp(subOp);
344       });
345 }
346 
347 class CodeGenRewrite : public fir::impl::CodeGenRewriteBase<CodeGenRewrite> {
348 public:
349   using CodeGenRewriteBase<CodeGenRewrite>::CodeGenRewriteBase;
350 
351   void runOnOperation() override final {
352     mlir::ModuleOp mod = getOperation();
353 
354     auto &context = getContext();
355     mlir::ConversionTarget target(context);
356     target.addLegalDialect<mlir::arith::ArithDialect, fir::FIROpsDialect,
357                            fir::FIRCodeGenDialect, mlir::func::FuncDialect>();
358     target.addIllegalOp<fir::ArrayCoorOp>();
359     target.addIllegalOp<fir::ReboxOp>();
360     target.addIllegalOp<fir::DeclareOp>();
361     target.addIllegalOp<fir::DummyScopeOp>();
362     target.addDynamicallyLegalOp<fir::EmboxOp>([](fir::EmboxOp embox) {
363       return !(embox.getShape() ||
364                mlir::isa<fir::SequenceType>(
365                    mlir::cast<fir::BaseBoxType>(embox.getType()).getEleTy()));
366     });
367     mlir::RewritePatternSet patterns(&context);
368     fir::populatePreCGRewritePatterns(patterns, preserveDeclare);
369     if (mlir::failed(
370             mlir::applyPartialConversion(mod, target, std::move(patterns)))) {
371       mlir::emitError(mlir::UnknownLoc::get(&context),
372                       "error in running the pre-codegen conversions");
373       signalPassFailure();
374       return;
375     }
376     // Erase any residual (fir.shape, fir.slice...).
377     mlir::IRRewriter rewriter(&context);
378     simpleDCE(rewriter, mod.getOperation());
379   }
380 };
381 
382 } // namespace
383 
384 void fir::populatePreCGRewritePatterns(mlir::RewritePatternSet &patterns,
385                                        bool preserveDeclare) {
386   patterns.insert<EmboxConversion, ArrayCoorConversion, ReboxConversion,
387                   DummyScopeOpConversion>(patterns.getContext());
388   patterns.add<DeclareOpConversion>(patterns.getContext(), preserveDeclare);
389 }
390