//===-- PolymorphicOpConversion.cpp ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "flang/Optimizer/Support/InternalNames.h"
#include "flang/Optimizer/Support/TypeCode.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "flang/Runtime/derived-api.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/CommandLine.h"
#include <memory>
#include <mutex>

namespace fir {
#define GEN_PASS_DEF_POLYMORPHICOPCONVERSION
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

using namespace fir;
using namespace mlir;

namespace {

/// SelectTypeOp converted to an if-then-else chain
///
/// This lowers the test conditions to calls into the runtime.
40 class SelectTypeConv : public OpConversionPattern<fir::SelectTypeOp> { 41 public: 42 using OpConversionPattern<fir::SelectTypeOp>::OpConversionPattern; 43 44 SelectTypeConv(mlir::MLIRContext *ctx, std::mutex *moduleMutex) 45 : mlir::OpConversionPattern<fir::SelectTypeOp>(ctx), 46 moduleMutex(moduleMutex) {} 47 48 mlir::LogicalResult 49 matchAndRewrite(fir::SelectTypeOp selectType, OpAdaptor adaptor, 50 mlir::ConversionPatternRewriter &rewriter) const override; 51 52 private: 53 // Generate comparison of type descriptor addresses. 54 mlir::Value genTypeDescCompare(mlir::Location loc, mlir::Value selector, 55 mlir::Type ty, mlir::ModuleOp mod, 56 mlir::PatternRewriter &rewriter) const; 57 58 static int getTypeCode(mlir::Type ty, fir::KindMapping &kindMap); 59 60 mlir::LogicalResult genTypeLadderStep(mlir::Location loc, 61 mlir::Value selector, 62 mlir::Attribute attr, mlir::Block *dest, 63 std::optional<mlir::ValueRange> destOps, 64 mlir::ModuleOp mod, 65 mlir::PatternRewriter &rewriter, 66 fir::KindMapping &kindMap) const; 67 68 llvm::SmallSet<llvm::StringRef, 4> collectAncestors(fir::DispatchTableOp dt, 69 mlir::ModuleOp mod) const; 70 71 // Mutex used to guard insertion of mlir::func::FuncOp in the module. 72 std::mutex *moduleMutex; 73 }; 74 75 /// Convert FIR structured control flow ops to CFG ops. 
76 class PolymorphicOpConversion 77 : public fir::impl::PolymorphicOpConversionBase<PolymorphicOpConversion> { 78 public: 79 mlir::LogicalResult initialize(mlir::MLIRContext *ctx) override { 80 moduleMutex = new std::mutex(); 81 return mlir::success(); 82 } 83 84 void runOnOperation() override { 85 auto *context = &getContext(); 86 mlir::RewritePatternSet patterns(context); 87 patterns.insert<SelectTypeConv>(context, moduleMutex); 88 mlir::ConversionTarget target(*context); 89 target.addLegalDialect<mlir::AffineDialect, mlir::cf::ControlFlowDialect, 90 FIROpsDialect, mlir::func::FuncDialect>(); 91 92 // apply the patterns 93 target.addIllegalOp<SelectTypeOp>(); 94 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); 95 if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, 96 std::move(patterns)))) { 97 mlir::emitError(mlir::UnknownLoc::get(context), 98 "error in converting to CFG\n"); 99 signalPassFailure(); 100 } 101 } 102 103 private: 104 std::mutex *moduleMutex; 105 }; 106 } // namespace 107 108 mlir::LogicalResult SelectTypeConv::matchAndRewrite( 109 fir::SelectTypeOp selectType, OpAdaptor adaptor, 110 mlir::ConversionPatternRewriter &rewriter) const { 111 auto operands = adaptor.getOperands(); 112 auto typeGuards = selectType.getCases(); 113 unsigned typeGuardNum = typeGuards.size(); 114 auto selector = selectType.getSelector(); 115 auto loc = selectType.getLoc(); 116 auto mod = selectType.getOperation()->getParentOfType<mlir::ModuleOp>(); 117 fir::KindMapping kindMap = fir::getKindMapping(mod); 118 119 // Order type guards so the condition and branches are done to respect the 120 // Execution of SELECT TYPE construct as described in the Fortran 2018 121 // standard 11.1.11.2 point 4. 122 // 1. If a TYPE IS type guard statement matches the selector, the block 123 // following that statement is executed. 124 // 2. 
Otherwise, if exactly one CLASS IS type guard statement matches the 125 // selector, the block following that statement is executed. 126 // 3. Otherwise, if several CLASS IS type guard statements match the 127 // selector, one of these statements will inevitably specify a type that 128 // is an extension of all the types specified in the others; the block 129 // following that statement is executed. 130 // 4. Otherwise, if there is a CLASS DEFAULT type guard statement, the block 131 // following that statement is executed. 132 // 5. Otherwise, no block is executed. 133 134 llvm::SmallVector<unsigned> orderedTypeGuards; 135 llvm::SmallVector<unsigned> orderedClassIsGuards; 136 unsigned defaultGuard = typeGuardNum - 1; 137 138 // The following loop go through the type guards in the fir.select_type 139 // operation and sort them into two lists. 140 // - All the TYPE IS type guard are added in order to the orderedTypeGuards 141 // list. This list is used at the end to generate the if-then-else ladder. 142 // - CLASS IS type guard are added in a separate list. If a CLASS IS type 143 // guard type extends a type already present, the type guard is inserted 144 // before in the list to respect point 3. above. Otherwise it is just 145 // added in order at the end. 
146 for (unsigned t = 0; t < typeGuardNum; ++t) { 147 if (auto a = typeGuards[t].dyn_cast<fir::ExactTypeAttr>()) { 148 orderedTypeGuards.push_back(t); 149 continue; 150 } 151 152 if (auto a = typeGuards[t].dyn_cast<fir::SubclassAttr>()) { 153 if (auto recTy = a.getType().dyn_cast<fir::RecordType>()) { 154 auto dt = mod.lookupSymbol<fir::DispatchTableOp>(recTy.getName()); 155 assert(dt && "dispatch table not found"); 156 llvm::SmallSet<llvm::StringRef, 4> ancestors = 157 collectAncestors(dt, mod); 158 if (!ancestors.empty()) { 159 auto it = orderedClassIsGuards.begin(); 160 while (it != orderedClassIsGuards.end()) { 161 fir::SubclassAttr sAttr = 162 typeGuards[*it].dyn_cast<fir::SubclassAttr>(); 163 if (auto ty = sAttr.getType().dyn_cast<fir::RecordType>()) { 164 if (ancestors.contains(ty.getName())) 165 break; 166 } 167 ++it; 168 } 169 if (it != orderedClassIsGuards.end()) { 170 // Parent type is present so place it before. 171 orderedClassIsGuards.insert(it, t); 172 continue; 173 } 174 } 175 } 176 orderedClassIsGuards.push_back(t); 177 } 178 } 179 orderedTypeGuards.append(orderedClassIsGuards); 180 orderedTypeGuards.push_back(defaultGuard); 181 assert(orderedTypeGuards.size() == typeGuardNum && 182 "ordered type guard size doesn't match number of type guards"); 183 184 for (unsigned idx : orderedTypeGuards) { 185 auto *dest = selectType.getSuccessor(idx); 186 std::optional<mlir::ValueRange> destOps = 187 selectType.getSuccessorOperands(operands, idx); 188 if (typeGuards[idx].dyn_cast<mlir::UnitAttr>()) 189 rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(selectType, dest); 190 else if (mlir::failed(genTypeLadderStep(loc, selector, typeGuards[idx], 191 dest, destOps, mod, rewriter, 192 kindMap))) 193 return mlir::failure(); 194 } 195 return mlir::success(); 196 } 197 198 mlir::LogicalResult SelectTypeConv::genTypeLadderStep( 199 mlir::Location loc, mlir::Value selector, mlir::Attribute attr, 200 mlir::Block *dest, std::optional<mlir::ValueRange> destOps, 201 
mlir::ModuleOp mod, mlir::PatternRewriter &rewriter, 202 fir::KindMapping &kindMap) const { 203 mlir::Value cmp; 204 // TYPE IS type guard comparison are all done inlined. 205 if (auto a = attr.dyn_cast<fir::ExactTypeAttr>()) { 206 if (fir::isa_trivial(a.getType()) || 207 a.getType().isa<fir::CharacterType>()) { 208 // For type guard statement with Intrinsic type spec the type code of 209 // the descriptor is compared. 210 int code = getTypeCode(a.getType(), kindMap); 211 if (code == 0) 212 return mlir::emitError(loc) 213 << "type code unavailable for " << a.getType(); 214 mlir::Value typeCode = rewriter.create<mlir::arith::ConstantOp>( 215 loc, rewriter.getI8IntegerAttr(code)); 216 mlir::Value selectorTypeCode = rewriter.create<fir::BoxTypeCodeOp>( 217 loc, rewriter.getI8Type(), selector); 218 cmp = rewriter.create<mlir::arith::CmpIOp>( 219 loc, mlir::arith::CmpIPredicate::eq, selectorTypeCode, typeCode); 220 } else { 221 // Flang inline the kind parameter in the type descriptor so we can 222 // directly check if the type descriptor addresses are identical for 223 // the TYPE IS type guard statement. 224 mlir::Value res = 225 genTypeDescCompare(loc, selector, a.getType(), mod, rewriter); 226 if (!res) 227 return mlir::failure(); 228 cmp = res; 229 } 230 // CLASS IS type guard statement is done with a runtime call. 231 } else if (auto a = attr.dyn_cast<fir::SubclassAttr>()) { 232 // Retrieve the type descriptor from the type guard statement record type. 
233 assert(a.getType().isa<fir::RecordType>() && "expect fir.record type"); 234 fir::RecordType recTy = a.getType().dyn_cast<fir::RecordType>(); 235 std::string typeDescName = 236 fir::NameUniquer::getTypeDescriptorName(recTy.getName()); 237 auto typeDescGlobal = mod.lookupSymbol<fir::GlobalOp>(typeDescName); 238 auto typeDescAddr = rewriter.create<fir::AddrOfOp>( 239 loc, fir::ReferenceType::get(typeDescGlobal.getType()), 240 typeDescGlobal.getSymbol()); 241 mlir::Type typeDescTy = ReferenceType::get(rewriter.getNoneType()); 242 mlir::Value typeDesc = 243 rewriter.create<ConvertOp>(loc, typeDescTy, typeDescAddr); 244 245 // Prepare the selector descriptor for the runtime call. 246 mlir::Type descNoneTy = fir::BoxType::get(rewriter.getNoneType()); 247 mlir::Value descSelector = 248 rewriter.create<ConvertOp>(loc, descNoneTy, selector); 249 250 // Generate runtime call. 251 llvm::StringRef fctName = RTNAME_STRING(ClassIs); 252 mlir::func::FuncOp callee; 253 { 254 // Since conversion is done in parallel for each fir.select_type 255 // operation, the runtime function insertion must be threadsafe. 
256 std::lock_guard<std::mutex> lock(*moduleMutex); 257 callee = 258 fir::createFuncOp(rewriter.getUnknownLoc(), mod, fctName, 259 rewriter.getFunctionType({descNoneTy, typeDescTy}, 260 rewriter.getI1Type())); 261 } 262 cmp = rewriter 263 .create<fir::CallOp>(loc, callee, 264 mlir::ValueRange{descSelector, typeDesc}) 265 .getResult(0); 266 } 267 268 auto *thisBlock = rewriter.getInsertionBlock(); 269 auto *newBlock = 270 rewriter.createBlock(dest->getParent(), mlir::Region::iterator(dest)); 271 rewriter.setInsertionPointToEnd(thisBlock); 272 if (destOps.has_value()) 273 rewriter.create<mlir::cf::CondBranchOp>(loc, cmp, dest, destOps.value(), 274 newBlock, std::nullopt); 275 else 276 rewriter.create<mlir::cf::CondBranchOp>(loc, cmp, dest, newBlock); 277 rewriter.setInsertionPointToEnd(newBlock); 278 return mlir::success(); 279 } 280 281 // Generate comparison of type descriptor addresses. 282 mlir::Value 283 SelectTypeConv::genTypeDescCompare(mlir::Location loc, mlir::Value selector, 284 mlir::Type ty, mlir::ModuleOp mod, 285 mlir::PatternRewriter &rewriter) const { 286 assert(ty.isa<fir::RecordType>() && "expect fir.record type"); 287 fir::RecordType recTy = ty.dyn_cast<fir::RecordType>(); 288 std::string typeDescName = 289 fir::NameUniquer::getTypeDescriptorName(recTy.getName()); 290 auto typeDescGlobal = mod.lookupSymbol<fir::GlobalOp>(typeDescName); 291 if (!typeDescGlobal) 292 return {}; 293 auto typeDescAddr = rewriter.create<fir::AddrOfOp>( 294 loc, fir::ReferenceType::get(typeDescGlobal.getType()), 295 typeDescGlobal.getSymbol()); 296 auto intPtrTy = rewriter.getIndexType(); 297 mlir::Type tdescType = 298 fir::TypeDescType::get(mlir::NoneType::get(rewriter.getContext())); 299 mlir::Value selectorTdescAddr = 300 rewriter.create<fir::BoxTypeDescOp>(loc, tdescType, selector); 301 auto typeDescInt = 302 rewriter.create<fir::ConvertOp>(loc, intPtrTy, typeDescAddr); 303 auto selectorTdescInt = 304 rewriter.create<fir::ConvertOp>(loc, intPtrTy, selectorTdescAddr); 
305 return rewriter.create<mlir::arith::CmpIOp>( 306 loc, mlir::arith::CmpIPredicate::eq, typeDescInt, selectorTdescInt); 307 } 308 309 int SelectTypeConv::getTypeCode(mlir::Type ty, fir::KindMapping &kindMap) { 310 if (auto intTy = ty.dyn_cast<mlir::IntegerType>()) 311 return fir::integerBitsToTypeCode(intTy.getWidth()); 312 if (auto floatTy = ty.dyn_cast<mlir::FloatType>()) 313 return fir::realBitsToTypeCode(floatTy.getWidth()); 314 if (auto logicalTy = ty.dyn_cast<fir::LogicalType>()) 315 return fir::logicalBitsToTypeCode( 316 kindMap.getLogicalBitsize(logicalTy.getFKind())); 317 if (fir::isa_complex(ty)) { 318 if (auto cmplxTy = ty.dyn_cast<mlir::ComplexType>()) 319 return fir::complexBitsToTypeCode( 320 cmplxTy.getElementType().cast<mlir::FloatType>().getWidth()); 321 auto cmplxTy = ty.cast<fir::ComplexType>(); 322 return fir::complexBitsToTypeCode( 323 kindMap.getRealBitsize(cmplxTy.getFKind())); 324 } 325 if (auto charTy = ty.dyn_cast<fir::CharacterType>()) 326 return fir::characterBitsToTypeCode( 327 kindMap.getCharacterBitsize(charTy.getFKind())); 328 return 0; 329 } 330 331 llvm::SmallSet<llvm::StringRef, 4> 332 SelectTypeConv::collectAncestors(fir::DispatchTableOp dt, 333 mlir::ModuleOp mod) const { 334 llvm::SmallSet<llvm::StringRef, 4> ancestors; 335 if (!dt.getParent().has_value()) 336 return ancestors; 337 while (dt.getParent().has_value()) { 338 ancestors.insert(*dt.getParent()); 339 dt = mod.lookupSymbol<fir::DispatchTableOp>(*dt.getParent()); 340 } 341 return ancestors; 342 } 343 344 std::unique_ptr<mlir::Pass> fir::createPolymorphicOpConversionPass() { 345 return std::make_unique<PolymorphicOpConversion>(); 346 } 347