xref: /llvm-project/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp (revision b07ef9e7cd6f5348df0a4f63e70a60491427ff64)
1 //===-- PolymorphicOpConversion.cpp ---------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Dialect/FIRDialect.h"
10 #include "flang/Optimizer/Dialect/FIROps.h"
11 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
12 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
13 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
14 #include "flang/Optimizer/Support/InternalNames.h"
15 #include "flang/Optimizer/Support/TypeCode.h"
16 #include "flang/Optimizer/Transforms/Passes.h"
17 #include "flang/Runtime/derived-api.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "llvm/ADT/SmallSet.h"
24 #include "llvm/Support/CommandLine.h"
25 #include <mutex>
26 
27 namespace fir {
28 #define GEN_PASS_DEF_POLYMORPHICOPCONVERSION
29 #include "flang/Optimizer/Transforms/Passes.h.inc"
30 } // namespace fir
31 
32 using namespace fir;
33 using namespace mlir;
34 
35 namespace {
36 
37 /// SelectTypeOp converted to an if-then-else chain
38 ///
39 /// This lowers the test conditions to calls into the runtime.
40 class SelectTypeConv : public OpConversionPattern<fir::SelectTypeOp> {
41 public:
42   using OpConversionPattern<fir::SelectTypeOp>::OpConversionPattern;
43 
44   SelectTypeConv(mlir::MLIRContext *ctx, std::mutex *moduleMutex)
45       : mlir::OpConversionPattern<fir::SelectTypeOp>(ctx),
46         moduleMutex(moduleMutex) {}
47 
48   mlir::LogicalResult
49   matchAndRewrite(fir::SelectTypeOp selectType, OpAdaptor adaptor,
50                   mlir::ConversionPatternRewriter &rewriter) const override;
51 
52 private:
53   // Generate comparison of type descriptor addresses.
54   mlir::Value genTypeDescCompare(mlir::Location loc, mlir::Value selector,
55                                  mlir::Type ty, mlir::ModuleOp mod,
56                                  mlir::PatternRewriter &rewriter) const;
57 
58   static int getTypeCode(mlir::Type ty, fir::KindMapping &kindMap);
59 
60   mlir::LogicalResult genTypeLadderStep(mlir::Location loc,
61                                         mlir::Value selector,
62                                         mlir::Attribute attr, mlir::Block *dest,
63                                         std::optional<mlir::ValueRange> destOps,
64                                         mlir::ModuleOp mod,
65                                         mlir::PatternRewriter &rewriter,
66                                         fir::KindMapping &kindMap) const;
67 
68   llvm::SmallSet<llvm::StringRef, 4> collectAncestors(fir::DispatchTableOp dt,
69                                                       mlir::ModuleOp mod) const;
70 
71   // Mutex used to guard insertion of mlir::func::FuncOp in the module.
72   std::mutex *moduleMutex;
73 };
74 
75 /// Convert FIR structured control flow ops to CFG ops.
76 class PolymorphicOpConversion
77     : public fir::impl::PolymorphicOpConversionBase<PolymorphicOpConversion> {
78 public:
79   mlir::LogicalResult initialize(mlir::MLIRContext *ctx) override {
80     moduleMutex = new std::mutex();
81     return mlir::success();
82   }
83 
84   void runOnOperation() override {
85     auto *context = &getContext();
86     mlir::RewritePatternSet patterns(context);
87     patterns.insert<SelectTypeConv>(context, moduleMutex);
88     mlir::ConversionTarget target(*context);
89     target.addLegalDialect<mlir::AffineDialect, mlir::cf::ControlFlowDialect,
90                            FIROpsDialect, mlir::func::FuncDialect>();
91 
92     // apply the patterns
93     target.addIllegalOp<SelectTypeOp>();
94     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
95     if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
96                                                   std::move(patterns)))) {
97       mlir::emitError(mlir::UnknownLoc::get(context),
98                       "error in converting to CFG\n");
99       signalPassFailure();
100     }
101   }
102 
103 private:
104   std::mutex *moduleMutex;
105 };
106 } // namespace
107 
108 mlir::LogicalResult SelectTypeConv::matchAndRewrite(
109     fir::SelectTypeOp selectType, OpAdaptor adaptor,
110     mlir::ConversionPatternRewriter &rewriter) const {
111   auto operands = adaptor.getOperands();
112   auto typeGuards = selectType.getCases();
113   unsigned typeGuardNum = typeGuards.size();
114   auto selector = selectType.getSelector();
115   auto loc = selectType.getLoc();
116   auto mod = selectType.getOperation()->getParentOfType<mlir::ModuleOp>();
117   fir::KindMapping kindMap = fir::getKindMapping(mod);
118 
119   // Order type guards so the condition and branches are done to respect the
120   // Execution of SELECT TYPE construct as described in the Fortran 2018
121   // standard 11.1.11.2 point 4.
122   // 1. If a TYPE IS type guard statement matches the selector, the block
123   //    following that statement is executed.
124   // 2. Otherwise, if exactly one CLASS IS type guard statement matches the
125   //    selector, the block following that statement is executed.
126   // 3. Otherwise, if several CLASS IS type guard statements match the
127   //    selector, one of these statements will inevitably specify a type that
128   //    is an extension of all the types specified in the others; the block
129   //    following that statement is executed.
130   // 4. Otherwise, if there is a CLASS DEFAULT type guard statement, the block
131   //    following that statement is executed.
132   // 5. Otherwise, no block is executed.
133 
134   llvm::SmallVector<unsigned> orderedTypeGuards;
135   llvm::SmallVector<unsigned> orderedClassIsGuards;
136   unsigned defaultGuard = typeGuardNum - 1;
137 
138   // The following loop go through the type guards in the fir.select_type
139   // operation and sort them into two lists.
140   // - All the TYPE IS type guard are added in order to the orderedTypeGuards
141   //   list. This list is used at the end to generate the if-then-else ladder.
142   // - CLASS IS type guard are added in a separate list. If a CLASS IS type
143   //   guard type extends a type already present, the type guard is inserted
144   //   before in the list to respect point 3. above. Otherwise it is just
145   //   added in order at the end.
146   for (unsigned t = 0; t < typeGuardNum; ++t) {
147     if (auto a = typeGuards[t].dyn_cast<fir::ExactTypeAttr>()) {
148       orderedTypeGuards.push_back(t);
149       continue;
150     }
151 
152     if (auto a = typeGuards[t].dyn_cast<fir::SubclassAttr>()) {
153       if (auto recTy = a.getType().dyn_cast<fir::RecordType>()) {
154         auto dt = mod.lookupSymbol<fir::DispatchTableOp>(recTy.getName());
155         assert(dt && "dispatch table not found");
156         llvm::SmallSet<llvm::StringRef, 4> ancestors =
157             collectAncestors(dt, mod);
158         if (!ancestors.empty()) {
159           auto it = orderedClassIsGuards.begin();
160           while (it != orderedClassIsGuards.end()) {
161             fir::SubclassAttr sAttr =
162                 typeGuards[*it].dyn_cast<fir::SubclassAttr>();
163             if (auto ty = sAttr.getType().dyn_cast<fir::RecordType>()) {
164               if (ancestors.contains(ty.getName()))
165                 break;
166             }
167             ++it;
168           }
169           if (it != orderedClassIsGuards.end()) {
170             // Parent type is present so place it before.
171             orderedClassIsGuards.insert(it, t);
172             continue;
173           }
174         }
175       }
176       orderedClassIsGuards.push_back(t);
177     }
178   }
179   orderedTypeGuards.append(orderedClassIsGuards);
180   orderedTypeGuards.push_back(defaultGuard);
181   assert(orderedTypeGuards.size() == typeGuardNum &&
182          "ordered type guard size doesn't match number of type guards");
183 
184   for (unsigned idx : orderedTypeGuards) {
185     auto *dest = selectType.getSuccessor(idx);
186     std::optional<mlir::ValueRange> destOps =
187         selectType.getSuccessorOperands(operands, idx);
188     if (typeGuards[idx].dyn_cast<mlir::UnitAttr>())
189       rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(selectType, dest);
190     else if (mlir::failed(genTypeLadderStep(loc, selector, typeGuards[idx],
191                                             dest, destOps, mod, rewriter,
192                                             kindMap)))
193       return mlir::failure();
194   }
195   return mlir::success();
196 }
197 
198 mlir::LogicalResult SelectTypeConv::genTypeLadderStep(
199     mlir::Location loc, mlir::Value selector, mlir::Attribute attr,
200     mlir::Block *dest, std::optional<mlir::ValueRange> destOps,
201     mlir::ModuleOp mod, mlir::PatternRewriter &rewriter,
202     fir::KindMapping &kindMap) const {
203   mlir::Value cmp;
204   // TYPE IS type guard comparison are all done inlined.
205   if (auto a = attr.dyn_cast<fir::ExactTypeAttr>()) {
206     if (fir::isa_trivial(a.getType()) ||
207         a.getType().isa<fir::CharacterType>()) {
208       // For type guard statement with Intrinsic type spec the type code of
209       // the descriptor is compared.
210       int code = getTypeCode(a.getType(), kindMap);
211       if (code == 0)
212         return mlir::emitError(loc)
213                << "type code unavailable for " << a.getType();
214       mlir::Value typeCode = rewriter.create<mlir::arith::ConstantOp>(
215           loc, rewriter.getI8IntegerAttr(code));
216       mlir::Value selectorTypeCode = rewriter.create<fir::BoxTypeCodeOp>(
217           loc, rewriter.getI8Type(), selector);
218       cmp = rewriter.create<mlir::arith::CmpIOp>(
219           loc, mlir::arith::CmpIPredicate::eq, selectorTypeCode, typeCode);
220     } else {
221       // Flang inline the kind parameter in the type descriptor so we can
222       // directly check if the type descriptor addresses are identical for
223       // the TYPE IS type guard statement.
224       mlir::Value res =
225           genTypeDescCompare(loc, selector, a.getType(), mod, rewriter);
226       if (!res)
227         return mlir::failure();
228       cmp = res;
229     }
230     // CLASS IS type guard statement is done with a runtime call.
231   } else if (auto a = attr.dyn_cast<fir::SubclassAttr>()) {
232     // Retrieve the type descriptor from the type guard statement record type.
233     assert(a.getType().isa<fir::RecordType>() && "expect fir.record type");
234     fir::RecordType recTy = a.getType().dyn_cast<fir::RecordType>();
235     std::string typeDescName =
236         fir::NameUniquer::getTypeDescriptorName(recTy.getName());
237     auto typeDescGlobal = mod.lookupSymbol<fir::GlobalOp>(typeDescName);
238     auto typeDescAddr = rewriter.create<fir::AddrOfOp>(
239         loc, fir::ReferenceType::get(typeDescGlobal.getType()),
240         typeDescGlobal.getSymbol());
241     mlir::Type typeDescTy = ReferenceType::get(rewriter.getNoneType());
242     mlir::Value typeDesc =
243         rewriter.create<ConvertOp>(loc, typeDescTy, typeDescAddr);
244 
245     // Prepare the selector descriptor for the runtime call.
246     mlir::Type descNoneTy = fir::BoxType::get(rewriter.getNoneType());
247     mlir::Value descSelector =
248         rewriter.create<ConvertOp>(loc, descNoneTy, selector);
249 
250     // Generate runtime call.
251     llvm::StringRef fctName = RTNAME_STRING(ClassIs);
252     mlir::func::FuncOp callee;
253     {
254       // Since conversion is done in parallel for each fir.select_type
255       // operation, the runtime function insertion must be threadsafe.
256       std::lock_guard<std::mutex> lock(*moduleMutex);
257       callee =
258           fir::createFuncOp(rewriter.getUnknownLoc(), mod, fctName,
259                             rewriter.getFunctionType({descNoneTy, typeDescTy},
260                                                      rewriter.getI1Type()));
261     }
262     cmp = rewriter
263               .create<fir::CallOp>(loc, callee,
264                                    mlir::ValueRange{descSelector, typeDesc})
265               .getResult(0);
266   }
267 
268   auto *thisBlock = rewriter.getInsertionBlock();
269   auto *newBlock =
270       rewriter.createBlock(dest->getParent(), mlir::Region::iterator(dest));
271   rewriter.setInsertionPointToEnd(thisBlock);
272   if (destOps.has_value())
273     rewriter.create<mlir::cf::CondBranchOp>(loc, cmp, dest, destOps.value(),
274                                             newBlock, std::nullopt);
275   else
276     rewriter.create<mlir::cf::CondBranchOp>(loc, cmp, dest, newBlock);
277   rewriter.setInsertionPointToEnd(newBlock);
278   return mlir::success();
279 }
280 
281 // Generate comparison of type descriptor addresses.
282 mlir::Value
283 SelectTypeConv::genTypeDescCompare(mlir::Location loc, mlir::Value selector,
284                                    mlir::Type ty, mlir::ModuleOp mod,
285                                    mlir::PatternRewriter &rewriter) const {
286   assert(ty.isa<fir::RecordType>() && "expect fir.record type");
287   fir::RecordType recTy = ty.dyn_cast<fir::RecordType>();
288   std::string typeDescName =
289       fir::NameUniquer::getTypeDescriptorName(recTy.getName());
290   auto typeDescGlobal = mod.lookupSymbol<fir::GlobalOp>(typeDescName);
291   if (!typeDescGlobal)
292     return {};
293   auto typeDescAddr = rewriter.create<fir::AddrOfOp>(
294       loc, fir::ReferenceType::get(typeDescGlobal.getType()),
295       typeDescGlobal.getSymbol());
296   auto intPtrTy = rewriter.getIndexType();
297   mlir::Type tdescType =
298       fir::TypeDescType::get(mlir::NoneType::get(rewriter.getContext()));
299   mlir::Value selectorTdescAddr =
300       rewriter.create<fir::BoxTypeDescOp>(loc, tdescType, selector);
301   auto typeDescInt =
302       rewriter.create<fir::ConvertOp>(loc, intPtrTy, typeDescAddr);
303   auto selectorTdescInt =
304       rewriter.create<fir::ConvertOp>(loc, intPtrTy, selectorTdescAddr);
305   return rewriter.create<mlir::arith::CmpIOp>(
306       loc, mlir::arith::CmpIPredicate::eq, typeDescInt, selectorTdescInt);
307 }
308 
309 int SelectTypeConv::getTypeCode(mlir::Type ty, fir::KindMapping &kindMap) {
310   if (auto intTy = ty.dyn_cast<mlir::IntegerType>())
311     return fir::integerBitsToTypeCode(intTy.getWidth());
312   if (auto floatTy = ty.dyn_cast<mlir::FloatType>())
313     return fir::realBitsToTypeCode(floatTy.getWidth());
314   if (auto logicalTy = ty.dyn_cast<fir::LogicalType>())
315     return fir::logicalBitsToTypeCode(
316         kindMap.getLogicalBitsize(logicalTy.getFKind()));
317   if (fir::isa_complex(ty)) {
318     if (auto cmplxTy = ty.dyn_cast<mlir::ComplexType>())
319       return fir::complexBitsToTypeCode(
320           cmplxTy.getElementType().cast<mlir::FloatType>().getWidth());
321     auto cmplxTy = ty.cast<fir::ComplexType>();
322     return fir::complexBitsToTypeCode(
323         kindMap.getRealBitsize(cmplxTy.getFKind()));
324   }
325   if (auto charTy = ty.dyn_cast<fir::CharacterType>())
326     return fir::characterBitsToTypeCode(
327         kindMap.getCharacterBitsize(charTy.getFKind()));
328   return 0;
329 }
330 
331 llvm::SmallSet<llvm::StringRef, 4>
332 SelectTypeConv::collectAncestors(fir::DispatchTableOp dt,
333                                  mlir::ModuleOp mod) const {
334   llvm::SmallSet<llvm::StringRef, 4> ancestors;
335   if (!dt.getParent().has_value())
336     return ancestors;
337   while (dt.getParent().has_value()) {
338     ancestors.insert(*dt.getParent());
339     dt = mod.lookupSymbol<fir::DispatchTableOp>(*dt.getParent());
340   }
341   return ancestors;
342 }
343 
344 std::unique_ptr<mlir::Pass> fir::createPolymorphicOpConversionPass() {
345   return std::make_unique<PolymorphicOpConversion>();
346 }
347