1 //===-- FIROps.cpp --------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Optimizer/Dialect/FIROps.h" 14 #include "flang/Optimizer/Dialect/FIRAttr.h" 15 #include "flang/Optimizer/Dialect/FIRDialect.h" 16 #include "flang/Optimizer/Dialect/FIROpsSupport.h" 17 #include "flang/Optimizer/Dialect/FIRType.h" 18 #include "flang/Optimizer/Dialect/Support/FIRContext.h" 19 #include "flang/Optimizer/Dialect/Support/KindMapping.h" 20 #include "flang/Optimizer/Support/Utils.h" 21 #include "mlir/Dialect/CommonFolders.h" 22 #include "mlir/Dialect/Func/IR/FuncOps.h" 23 #include "mlir/Dialect/OpenACC/OpenACC.h" 24 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 25 #include "mlir/IR/Attributes.h" 26 #include "mlir/IR/BuiltinAttributes.h" 27 #include "mlir/IR/BuiltinOps.h" 28 #include "mlir/IR/Diagnostics.h" 29 #include "mlir/IR/Matchers.h" 30 #include "mlir/IR/OpDefinition.h" 31 #include "mlir/IR/PatternMatch.h" 32 #include "llvm/ADT/STLExtras.h" 33 #include "llvm/ADT/SmallVector.h" 34 #include "llvm/ADT/TypeSwitch.h" 35 36 namespace { 37 #include "flang/Optimizer/Dialect/CanonicalizationPatterns.inc" 38 } // namespace 39 40 static void propagateAttributes(mlir::Operation *fromOp, 41 mlir::Operation *toOp) { 42 if (!fromOp || !toOp) 43 return; 44 45 for (mlir::NamedAttribute attr : fromOp->getAttrs()) { 46 if (attr.getName().getValue().starts_with( 47 mlir::acc::OpenACCDialect::getDialectNamespace())) 48 toOp->setAttr(attr.getName(), attr.getValue()); 49 } 50 } 51 52 /// Return true if a sequence type is of some incomplete size or a record type 53 /// is malformed or contains an incomplete sequence type. An incomplete sequence 54 /// type is one with more unknown extents in the type than have been provided 55 /// via `dynamicExtents`. Sequence types with an unknown rank are incomplete by 56 /// definition. 57 static bool verifyInType(mlir::Type inType, 58 llvm::SmallVectorImpl<llvm::StringRef> &visited, 59 unsigned dynamicExtents = 0) { 60 if (auto st = mlir::dyn_cast<fir::SequenceType>(inType)) { 61 auto shape = st.getShape(); 62 if (shape.size() == 0) 63 return true; 64 for (std::size_t i = 0, end = shape.size(); i < end; ++i) { 65 if (shape[i] != fir::SequenceType::getUnknownExtent()) 66 continue; 67 if (dynamicExtents-- == 0) 68 return true; 69 } 70 } else if (auto rt = mlir::dyn_cast<fir::RecordType>(inType)) { 71 // don't recurse if we're already visiting this one 72 if (llvm::is_contained(visited, rt.getName())) 73 return false; 74 // keep track of record types currently being visited 75 visited.push_back(rt.getName()); 76 for (auto &field : rt.getTypeList()) 77 if (verifyInType(field.second, visited)) 78 return true; 79 visited.pop_back(); 80 } 81 return false; 82 } 83 84 static bool verifyTypeParamCount(mlir::Type inType, unsigned numParams) { 85 auto ty = fir::unwrapSequenceType(inType); 86 if (numParams > 0) { 87 if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty)) 88 return numParams != recTy.getNumLenParams(); 89 if (auto chrTy = mlir::dyn_cast<fir::CharacterType>(ty)) 90 return !(numParams == 1 && chrTy.hasDynamicLen()); 91 return true; 92 } 93 if (auto chrTy = mlir::dyn_cast<fir::CharacterType>(ty)) 94 return !chrTy.hasConstantLen(); 95 return false; 96 } 97 98 /// Parser shared by Alloca and Allocmem 99 /// 100 /// operation ::= %res = (`fir.alloca` | `fir.allocmem`) $in_type 101 /// ( `(` $typeparams `)` )? ( `,` $shape )? 102 /// attr-dict-without-keyword 103 template <typename FN> 104 static mlir::ParseResult parseAllocatableOp(FN wrapResultType, 105 mlir::OpAsmParser &parser, 106 mlir::OperationState &result) { 107 mlir::Type intype; 108 if (parser.parseType(intype)) 109 return mlir::failure(); 110 auto &builder = parser.getBuilder(); 111 result.addAttribute("in_type", mlir::TypeAttr::get(intype)); 112 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands; 113 llvm::SmallVector<mlir::Type> typeVec; 114 bool hasOperands = false; 115 std::int32_t typeparamsSize = 0; 116 if (!parser.parseOptionalLParen()) { 117 // parse the LEN params of the derived type. (<params> : <types>) 118 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || 119 parser.parseColonTypeList(typeVec) || parser.parseRParen()) 120 return mlir::failure(); 121 typeparamsSize = operands.size(); 122 hasOperands = true; 123 } 124 std::int32_t shapeSize = 0; 125 if (!parser.parseOptionalComma()) { 126 // parse size to scale by, vector of n dimensions of type index 127 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None)) 128 return mlir::failure(); 129 shapeSize = operands.size() - typeparamsSize; 130 auto idxTy = builder.getIndexType(); 131 for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i) 132 typeVec.push_back(idxTy); 133 hasOperands = true; 134 } 135 if (hasOperands && 136 parser.resolveOperands(operands, typeVec, parser.getNameLoc(), 137 result.operands)) 138 return mlir::failure(); 139 mlir::Type restype = wrapResultType(intype); 140 if (!restype) { 141 parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype; 142 return mlir::failure(); 143 } 144 result.addAttribute("operandSegmentSizes", builder.getDenseI32ArrayAttr( 145 {typeparamsSize, shapeSize})); 146 if (parser.parseOptionalAttrDict(result.attributes) || 147 parser.addTypeToList(restype, result.types)) 148 return mlir::failure(); 149 return mlir::success(); 150 } 151 152 template <typename OP> 153 static void printAllocatableOp(mlir::OpAsmPrinter &p, OP &op) { 154 p << ' ' << op.getInType(); 155 if (!op.getTypeparams().empty()) { 156 p << '(' << op.getTypeparams() << " : " << op.getTypeparams().getTypes() 157 << ')'; 158 } 159 // print the shape of the allocation (if any); all must be index type 160 for (auto sh : op.getShape()) { 161 p << ", "; 162 p.printOperand(sh); 163 } 164 p.printOptionalAttrDict(op->getAttrs(), {"in_type", "operandSegmentSizes"}); 165 } 166 167 //===----------------------------------------------------------------------===// 168 // AllocaOp 169 //===----------------------------------------------------------------------===// 170 171 /// Create a legal memory reference as return type 172 static mlir::Type wrapAllocaResultType(mlir::Type intype) { 173 // FIR semantics: memory references to memory references are disallowed 174 if (mlir::isa<fir::ReferenceType>(intype)) 175 return {}; 176 return fir::ReferenceType::get(intype); 177 } 178 179 mlir::Type fir::AllocaOp::getAllocatedType() { 180 return mlir::cast<fir::ReferenceType>(getType()).getEleTy(); 181 } 182 183 mlir::Type fir::AllocaOp::getRefTy(mlir::Type ty) { 184 return fir::ReferenceType::get(ty); 185 } 186 187 void fir::AllocaOp::build(mlir::OpBuilder &builder, 188 mlir::OperationState &result, mlir::Type inType, 189 llvm::StringRef uniqName, mlir::ValueRange typeparams, 190 mlir::ValueRange shape, 191 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 192 auto nameAttr = builder.getStringAttr(uniqName); 193 build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, {}, 194 /*pinned=*/false, typeparams, shape); 195 result.addAttributes(attributes); 196 } 197 198 void fir::AllocaOp::build(mlir::OpBuilder &builder, 199 mlir::OperationState &result, mlir::Type inType, 200 llvm::StringRef uniqName, bool pinned, 201 mlir::ValueRange typeparams, mlir::ValueRange shape, 202 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 203 auto nameAttr = builder.getStringAttr(uniqName); 204 build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, {}, 205 pinned, typeparams, shape); 206 result.addAttributes(attributes); 207 } 208 209 void fir::AllocaOp::build(mlir::OpBuilder &builder, 210 mlir::OperationState &result, mlir::Type inType, 211 llvm::StringRef uniqName, llvm::StringRef bindcName, 212 mlir::ValueRange typeparams, mlir::ValueRange shape, 213 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 214 auto nameAttr = 215 uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName); 216 auto bindcAttr = 217 bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName); 218 build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, 219 bindcAttr, /*pinned=*/false, typeparams, shape); 220 result.addAttributes(attributes); 221 } 222 223 void fir::AllocaOp::build(mlir::OpBuilder &builder, 224 mlir::OperationState &result, mlir::Type inType, 225 llvm::StringRef uniqName, llvm::StringRef bindcName, 226 bool pinned, mlir::ValueRange typeparams, 227 mlir::ValueRange shape, 228 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 229 auto nameAttr = 230 uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName); 231 auto bindcAttr = 232 bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName); 233 build(builder, result, wrapAllocaResultType(inType), inType, nameAttr, 234 bindcAttr, pinned, typeparams, shape); 235 result.addAttributes(attributes); 236 } 237 238 void fir::AllocaOp::build(mlir::OpBuilder &builder, 239 mlir::OperationState &result, mlir::Type inType, 240 mlir::ValueRange typeparams, mlir::ValueRange shape, 241 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 242 build(builder, result, wrapAllocaResultType(inType), inType, {}, {}, 243 /*pinned=*/false, typeparams, shape); 244 result.addAttributes(attributes); 245 } 246 247 void fir::AllocaOp::build(mlir::OpBuilder &builder, 248 mlir::OperationState &result, mlir::Type inType, 249 bool pinned, mlir::ValueRange typeparams, 250 mlir::ValueRange shape, 251 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 252 build(builder, result, wrapAllocaResultType(inType), inType, {}, {}, pinned, 253 typeparams, shape); 254 result.addAttributes(attributes); 255 } 256 257 mlir::ParseResult fir::AllocaOp::parse(mlir::OpAsmParser &parser, 258 mlir::OperationState &result) { 259 return parseAllocatableOp(wrapAllocaResultType, parser, result); 260 } 261 262 void fir::AllocaOp::print(mlir::OpAsmPrinter &p) { 263 printAllocatableOp(p, *this); 264 } 265 266 llvm::LogicalResult fir::AllocaOp::verify() { 267 llvm::SmallVector<llvm::StringRef> visited; 268 if (verifyInType(getInType(), visited, numShapeOperands())) 269 return emitOpError("invalid type for allocation"); 270 if (verifyTypeParamCount(getInType(), numLenParams())) 271 return emitOpError("LEN params do not correspond to type"); 272 mlir::Type outType = getType(); 273 if (!mlir::isa<fir::ReferenceType>(outType)) 274 return emitOpError("must be a !fir.ref type"); 275 return mlir::success(); 276 } 277 278 bool fir::AllocaOp::ownsNestedAlloca(mlir::Operation *op) { 279 return op->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>() || 280 op->hasTrait<mlir::OpTrait::AutomaticAllocationScope>() || 281 mlir::isa<mlir::LoopLikeOpInterface>(*op); 282 } 283 284 mlir::Region *fir::AllocaOp::getOwnerRegion() { 285 mlir::Operation *currentOp = getOperation(); 286 while (mlir::Operation *parentOp = currentOp->getParentOp()) { 287 // If the operation was not registered, inquiries about its traits will be 288 // incorrect and it is not possible to reason about the operation. This 289 // should not happen in a normal Fortran compilation flow, but be foolproof. 290 if (!parentOp->isRegistered()) 291 return nullptr; 292 if (fir::AllocaOp::ownsNestedAlloca(parentOp)) 293 return currentOp->getParentRegion(); 294 currentOp = parentOp; 295 } 296 return nullptr; 297 } 298 299 //===----------------------------------------------------------------------===// 300 // AllocMemOp 301 //===----------------------------------------------------------------------===// 302 303 /// Create a legal heap reference as return type 304 static mlir::Type wrapAllocMemResultType(mlir::Type intype) { 305 // Fortran semantics: C852 an entity cannot be both ALLOCATABLE and POINTER 306 // 8.5.3 note 1 prohibits ALLOCATABLE procedures as well 307 // FIR semantics: one may not allocate a memory reference value 308 if (mlir::isa<fir::ReferenceType, fir::HeapType, fir::PointerType, 309 mlir::FunctionType>(intype)) 310 return {}; 311 return fir::HeapType::get(intype); 312 } 313 314 mlir::Type fir::AllocMemOp::getAllocatedType() { 315 return mlir::cast<fir::HeapType>(getType()).getEleTy(); 316 } 317 318 mlir::Type fir::AllocMemOp::getRefTy(mlir::Type ty) { 319 return fir::HeapType::get(ty); 320 } 321 322 void fir::AllocMemOp::build(mlir::OpBuilder &builder, 323 mlir::OperationState &result, mlir::Type inType, 324 llvm::StringRef uniqName, 325 mlir::ValueRange typeparams, mlir::ValueRange shape, 326 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 327 auto nameAttr = builder.getStringAttr(uniqName); 328 build(builder, result, wrapAllocMemResultType(inType), inType, nameAttr, {}, 329 typeparams, shape); 330 result.addAttributes(attributes); 331 } 332 333 void fir::AllocMemOp::build(mlir::OpBuilder &builder, 334 mlir::OperationState &result, mlir::Type inType, 335 llvm::StringRef uniqName, llvm::StringRef bindcName, 336 mlir::ValueRange typeparams, mlir::ValueRange shape, 337 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 338 auto nameAttr = builder.getStringAttr(uniqName); 339 auto bindcAttr = builder.getStringAttr(bindcName); 340 build(builder, result, wrapAllocMemResultType(inType), inType, nameAttr, 341 bindcAttr, typeparams, shape); 342 result.addAttributes(attributes); 343 } 344 345 void fir::AllocMemOp::build(mlir::OpBuilder &builder, 346 mlir::OperationState &result, mlir::Type inType, 347 mlir::ValueRange typeparams, mlir::ValueRange shape, 348 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 349 build(builder, result, wrapAllocMemResultType(inType), inType, {}, {}, 350 typeparams, shape); 351 result.addAttributes(attributes); 352 } 353 354 mlir::ParseResult fir::AllocMemOp::parse(mlir::OpAsmParser &parser, 355 mlir::OperationState &result) { 356 return parseAllocatableOp(wrapAllocMemResultType, parser, result); 357 } 358 359 void fir::AllocMemOp::print(mlir::OpAsmPrinter &p) { 360 printAllocatableOp(p, *this); 361 } 362 363 llvm::LogicalResult fir::AllocMemOp::verify() { 364 llvm::SmallVector<llvm::StringRef> visited; 365 if (verifyInType(getInType(), visited, numShapeOperands())) 366 return emitOpError("invalid type for allocation"); 367 if (verifyTypeParamCount(getInType(), numLenParams())) 368 return emitOpError("LEN params do not correspond to type"); 369 mlir::Type outType = getType(); 370 if (!mlir::dyn_cast<fir::HeapType>(outType)) 371 return emitOpError("must be a !fir.heap type"); 372 if (fir::isa_unknown_size_box(fir::dyn_cast_ptrEleTy(outType))) 373 return emitOpError("cannot allocate !fir.box of unknown rank or type"); 374 return mlir::success(); 375 } 376 377 //===----------------------------------------------------------------------===// 378 // ArrayCoorOp 379 //===----------------------------------------------------------------------===// 380 381 // CHARACTERs and derived types with LEN PARAMETERs are dependent types that 382 // require runtime values to fully define the type of an object. 383 static bool validTypeParams(mlir::Type dynTy, mlir::ValueRange typeParams) { 384 dynTy = fir::unwrapAllRefAndSeqType(dynTy); 385 // A box value will contain type parameter values itself. 386 if (mlir::isa<fir::BoxType>(dynTy)) 387 return typeParams.size() == 0; 388 // Derived type must have all type parameters satisfied. 389 if (auto recTy = mlir::dyn_cast<fir::RecordType>(dynTy)) 390 return typeParams.size() == recTy.getNumLenParams(); 391 // Characters with non-constant LEN must have a type parameter value. 392 if (auto charTy = mlir::dyn_cast<fir::CharacterType>(dynTy)) 393 if (charTy.hasDynamicLen()) 394 return typeParams.size() == 1; 395 // Otherwise, any type parameters are invalid. 396 return typeParams.size() == 0; 397 } 398 399 llvm::LogicalResult fir::ArrayCoorOp::verify() { 400 auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); 401 auto arrTy = mlir::dyn_cast<fir::SequenceType>(eleTy); 402 if (!arrTy) 403 return emitOpError("must be a reference to an array"); 404 auto arrDim = arrTy.getDimension(); 405 406 if (auto shapeOp = getShape()) { 407 auto shapeTy = shapeOp.getType(); 408 unsigned shapeTyRank = 0; 409 if (auto s = mlir::dyn_cast<fir::ShapeType>(shapeTy)) { 410 shapeTyRank = s.getRank(); 411 } else if (auto ss = mlir::dyn_cast<fir::ShapeShiftType>(shapeTy)) { 412 shapeTyRank = ss.getRank(); 413 } else { 414 auto s = mlir::cast<fir::ShiftType>(shapeTy); 415 shapeTyRank = s.getRank(); 416 // TODO: it looks like PreCGRewrite and CodeGen can support 417 // fir.shift with plain array reference, so we may consider 418 // removing this check. 419 if (!mlir::isa<fir::BaseBoxType>(getMemref().getType())) 420 return emitOpError("shift can only be provided with fir.box memref"); 421 } 422 if (arrDim && arrDim != shapeTyRank) 423 return emitOpError("rank of dimension mismatched"); 424 // TODO: support slicing with changing the number of dimensions, 425 // e.g. when array_coor represents an element access to array(:,1,:) 426 // slice: the shape is 3D and the number of indices is 2 in this case. 427 if (shapeTyRank != getIndices().size()) 428 return emitOpError("number of indices do not match dim rank"); 429 } 430 431 if (auto sliceOp = getSlice()) { 432 if (auto sl = mlir::dyn_cast_or_null<fir::SliceOp>(sliceOp.getDefiningOp())) 433 if (!sl.getSubstr().empty()) 434 return emitOpError("array_coor cannot take a slice with substring"); 435 if (auto sliceTy = mlir::dyn_cast<fir::SliceType>(sliceOp.getType())) 436 if (sliceTy.getRank() != arrDim) 437 return emitOpError("rank of dimension in slice mismatched"); 438 } 439 if (!validTypeParams(getMemref().getType(), getTypeparams())) 440 return emitOpError("invalid type parameters"); 441 442 return mlir::success(); 443 } 444 445 // Pull in fir.embox and fir.rebox into fir.array_coor when possible. 446 struct SimplifyArrayCoorOp : public mlir::OpRewritePattern<fir::ArrayCoorOp> { 447 using mlir::OpRewritePattern<fir::ArrayCoorOp>::OpRewritePattern; 448 llvm::LogicalResult 449 matchAndRewrite(fir::ArrayCoorOp op, 450 mlir::PatternRewriter &rewriter) const override { 451 mlir::Value memref = op.getMemref(); 452 if (!mlir::isa<fir::BaseBoxType>(memref.getType())) 453 return mlir::failure(); 454 455 mlir::Value boxedMemref, boxedShape, boxedSlice; 456 if (auto emboxOp = 457 mlir::dyn_cast_or_null<fir::EmboxOp>(memref.getDefiningOp())) { 458 boxedMemref = emboxOp.getMemref(); 459 boxedShape = emboxOp.getShape(); 460 boxedSlice = emboxOp.getSlice(); 461 // If any of operands, that are not currently supported for migration 462 // to ArrayCoorOp, is present, don't rewrite. 463 if (!emboxOp.getTypeparams().empty() || emboxOp.getSourceBox() || 464 emboxOp.getAccessMap()) 465 return mlir::failure(); 466 } else if (auto reboxOp = mlir::dyn_cast_or_null<fir::ReboxOp>( 467 memref.getDefiningOp())) { 468 boxedMemref = reboxOp.getBox(); 469 boxedShape = reboxOp.getShape(); 470 // Avoid pulling in rebox that performs reshaping. 471 // There is no way to represent box reshaping with array_coor. 472 if (boxedShape && !mlir::isa<fir::ShiftType>(boxedShape.getType())) 473 return mlir::failure(); 474 boxedSlice = reboxOp.getSlice(); 475 } else { 476 return mlir::failure(); 477 } 478 479 bool boxedShapeIsShift = 480 boxedShape && mlir::isa<fir::ShiftType>(boxedShape.getType()); 481 bool boxedShapeIsShape = 482 boxedShape && mlir::isa<fir::ShapeType>(boxedShape.getType()); 483 bool boxedShapeIsShapeShift = 484 boxedShape && mlir::isa<fir::ShapeShiftType>(boxedShape.getType()); 485 486 // Slices changing the number of dimensions are not supported 487 // for array_coor yet. 488 unsigned origBoxRank; 489 if (mlir::isa<fir::BaseBoxType>(boxedMemref.getType())) 490 origBoxRank = fir::getBoxRank(boxedMemref.getType()); 491 else if (auto arrTy = mlir::dyn_cast<fir::SequenceType>( 492 fir::unwrapRefType(boxedMemref.getType()))) 493 origBoxRank = arrTy.getDimension(); 494 else 495 return mlir::failure(); 496 497 if (fir::getBoxRank(memref.getType()) != origBoxRank) 498 return mlir::failure(); 499 500 // Slices with substring are not supported by array_coor. 501 if (boxedSlice) 502 if (auto sliceOp = 503 mlir::dyn_cast_or_null<fir::SliceOp>(boxedSlice.getDefiningOp())) 504 if (!sliceOp.getSubstr().empty()) 505 return mlir::failure(); 506 507 // If embox/rebox and array_coor have conflicting shapes or slices, 508 // do nothing. 509 if (op.getShape() && boxedShape && boxedShape != op.getShape()) 510 return mlir::failure(); 511 if (op.getSlice() && boxedSlice && boxedSlice != op.getSlice()) 512 return mlir::failure(); 513 514 std::optional<IndicesVectorTy> shiftedIndices; 515 // The embox/rebox and array_coor either have compatible 516 // shape/slice at this point or shape/slice is null 517 // in one of them but not in the other. 518 // The compatibility means they are equal or both null. 519 if (!op.getShape()) { 520 if (boxedShape) { 521 if (op.getSlice()) { 522 if (!boxedSlice) { 523 if (boxedShapeIsShift) { 524 // %0 = fir.rebox %arg(%shift) 525 // %1 = fir.array_coor %0 [%slice] %idx 526 // Both the slice indices and %idx are 1-based, so the rebox 527 // may be pulled in as: 528 // %1 = fir.array_coor %arg [%slice] %idx 529 boxedShape = nullptr; 530 } else if (boxedShapeIsShape) { 531 // %0 = fir.embox %arg(%shape) 532 // %1 = fir.array_coor %0 [%slice] %idx 533 // Pull in as: 534 // %1 = fir.array_coor %arg(%shape) [%slice] %idx 535 } else if (boxedShapeIsShapeShift) { 536 // %0 = fir.embox %arg(%shapeshift) 537 // %1 = fir.array_coor %0 [%slice] %idx 538 // Pull in as: 539 // %shape = fir.shape <extents from the %shapeshift> 540 // %1 = fir.array_coor %arg(%shape) [%slice] %idx 541 boxedShape = getShapeFromShapeShift(boxedShape, rewriter); 542 if (!boxedShape) 543 return mlir::failure(); 544 } else { 545 return mlir::failure(); 546 } 547 } else { 548 if (boxedShapeIsShift) { 549 // %0 = fir.rebox %arg(%shift) [%slice] 550 // %1 = fir.array_coor %0 [%slice] %idx 551 // This FIR may only be valid if the shape specifies 552 // that all lower bounds are 1s and the slice's start indices 553 // and strides are all 1s. 554 // We could pull in the rebox as: 555 // %1 = fir.array_coor %arg [%slice] %idx 556 // Do not do anything for the time being. 557 return mlir::failure(); 558 } else if (boxedShapeIsShape) { 559 // %0 = fir.embox %arg(%shape) [%slice] 560 // %1 = fir.array_coor %0 [%slice] %idx 561 // This FIR may only be valid if the slice's start indices 562 // and strides are all 1s. 563 // We could pull in the embox as: 564 // %1 = fir.array_coor %arg(%shape) [%slice] %idx 565 return mlir::failure(); 566 } else if (boxedShapeIsShapeShift) { 567 // %0 = fir.embox %arg(%shapeshift) [%slice] 568 // %1 = fir.array_coor %0 [%slice] %idx 569 // This FIR may only be valid if the shape specifies 570 // that all lower bounds are 1s and the slice's start indices 571 // and strides are all 1s. 572 // We could pull in the embox as: 573 // %shape = fir.shape <extents from the %shapeshift> 574 // %1 = fir.array_coor %arg(%shape) [%slice] %idx 575 return mlir::failure(); 576 } else { 577 return mlir::failure(); 578 } 579 } 580 } else { // !op.getSlice() 581 if (!boxedSlice) { 582 if (boxedShapeIsShift) { 583 // %0 = fir.rebox %arg(%shift) 584 // %1 = fir.array_coor %0 %idx 585 // Pull in as: 586 // %1 = fir.array_coor %arg %idx 587 boxedShape = nullptr; 588 } else if (boxedShapeIsShape) { 589 // %0 = fir.embox %arg(%shape) 590 // %1 = fir.array_coor %0 %idx 591 // Pull in as: 592 // %1 = fir.array_coor %arg(%shape) %idx 593 } else if (boxedShapeIsShapeShift) { 594 // %0 = fir.embox %arg(%shapeshift) 595 // %1 = fir.array_coor %0 %idx 596 // Pull in as: 597 // %shape = fir.shape <extents from the %shapeshift> 598 // %1 = fir.array_coor %arg(%shape) %idx 599 boxedShape = getShapeFromShapeShift(boxedShape, rewriter); 600 if (!boxedShape) 601 return mlir::failure(); 602 } else { 603 return mlir::failure(); 604 } 605 } else { 606 if (boxedShapeIsShift) { 607 // %0 = fir.embox %arg(%shift) [%slice] 608 // %1 = fir.array_coor %0 %idx 609 // Pull in as: 610 // %tmp = arith.addi %idx, %shift.origin 611 // %idx_shifted = arith.subi %tmp, 1 612 // %1 = fir.array_coor %arg(%shift) %[slice] %idx_shifted 613 shiftedIndices = 614 getShiftedIndices(boxedShape, op.getIndices(), rewriter); 615 if (!shiftedIndices) 616 return mlir::failure(); 617 } else if (boxedShapeIsShape) { 618 // %0 = fir.embox %arg(%shape) [%slice] 619 // %1 = fir.array_coor %0 %idx 620 // Pull in as: 621 // %1 = fir.array_coor %arg(%shape) %[slice] %idx 622 } else if (boxedShapeIsShapeShift) { 623 // %0 = fir.embox %arg(%shapeshift) [%slice] 624 // %1 = fir.array_coor %0 %idx 625 // Pull in as: 626 // %tmp = arith.addi %idx, %shapeshift.lb 627 // %idx_shifted = arith.subi %tmp, 1 628 // %1 = fir.array_coor %arg(%shapeshift) %[slice] %idx_shifted 629 shiftedIndices = 630 getShiftedIndices(boxedShape, op.getIndices(), rewriter); 631 if (!shiftedIndices) 632 return mlir::failure(); 633 } else { 634 return mlir::failure(); 635 } 636 } 637 } 638 } else { // !boxedShape 639 if (op.getSlice()) { 640 if (!boxedSlice) { 641 // %0 = fir.rebox %arg 642 // %1 = fir.array_coor %0 [%slice] %idx 643 // Pull in as: 644 // %1 = fir.array_coor %arg [%slice] %idx 645 } else { 646 // %0 = fir.rebox %arg [%slice] 647 // %1 = fir.array_coor %0 [%slice] %idx 648 // This is a valid FIR iff the slice's lower bounds 649 // and strides are all 1s. 650 // Pull in as: 651 // %1 = fir.array_coor %arg [%slice] %idx 652 } 653 } else { // !op.getSlice() 654 if (!boxedSlice) { 655 // %0 = fir.rebox %arg 656 // %1 = fir.array_coor %0 %idx 657 // Pull in as: 658 // %1 = fir.array_coor %arg %idx 659 } else { 660 // %0 = fir.rebox %arg [%slice] 661 // %1 = fir.array_coor %0 %idx 662 // Pull in as: 663 // %1 = fir.array_coor %arg [%slice] %idx 664 } 665 } 666 } 667 } else { // op.getShape() 668 if (boxedShape) { 669 // Check if pulling in non-default shape is correct. 670 if (op.getSlice()) { 671 if (!boxedSlice) { 672 // %0 = fir.embox %arg(%shape) 673 // %1 = fir.array_coor %0(%shape) [%slice] %idx 674 // Pull in as: 675 // %1 = fir.array_coor %arg(%shape) [%slice] %idx 676 } else { 677 // %0 = fir.embox %arg(%shape) [%slice] 678 // %1 = fir.array_coor %0(%shape) [%slice] %idx 679 // Pull in as: 680 // %1 = fir.array_coor %arg(%shape) [%slice] %idx 681 } 682 } else { // !op.getSlice() 683 if (!boxedSlice) { 684 // %0 = fir.embox %arg(%shape) 685 // %1 = fir.array_coor %0(%shape) %idx 686 // Pull in as: 687 // %1 = fir.array_coor %arg(%shape) %idx 688 } else { 689 // %0 = fir.embox %arg(%shape) [%slice] 690 // %1 = fir.array_coor %0(%shape) %idx 691 // Pull in as: 692 // %1 = fir.array_coor %arg(%shape) [%slice] %idx 693 } 694 } 695 } else { // !boxedShape 696 if (op.getSlice()) { 697 if (!boxedSlice) { 698 // %0 = fir.rebox %arg 699 // %1 = fir.array_coor %0(%shape) [%slice] %idx 700 // Pull in as: 701 // %1 = fir.array_coor %arg(%shape) [%slice] %idx 702 } else { 703 // %0 = fir.rebox %arg [%slice] 704 // %1 = fir.array_coor %0(%shape) [%slice] %idx 705 return mlir::failure(); 706 } 707 } else { // !op.getSlice() 708 if (!boxedSlice) { 709 // %0 = fir.rebox %arg 710 // %1 = fir.array_coor %0(%shape) %idx 711 // Pull in as: 712 // %1 = fir.array_coor %arg(%shape) %idx 713 } else { 714 // %0 = fir.rebox %arg [%slice] 715 // %1 = fir.array_coor %0(%shape) %idx 716 // Cannot pull in without adjusting the slice indices. 717 return mlir::failure(); 718 } 719 } 720 } 721 } 722 723 // TODO: temporarily avoid producing array_coor with the shape shift 724 // and plain array reference (it seems to be a limitation of 725 // ArrayCoorOp verifier). 726 if (!mlir::isa<fir::BaseBoxType>(boxedMemref.getType())) { 727 if (boxedShape) { 728 if (mlir::isa<fir::ShiftType>(boxedShape.getType())) 729 return mlir::failure(); 730 } else if (op.getShape() && 731 mlir::isa<fir::ShiftType>(op.getShape().getType())) { 732 return mlir::failure(); 733 } 734 } 735 736 rewriter.modifyOpInPlace(op, [&]() { 737 op.getMemrefMutable().assign(boxedMemref); 738 if (boxedShape) 739 op.getShapeMutable().assign(boxedShape); 740 if (boxedSlice) 741 op.getSliceMutable().assign(boxedSlice); 742 if (shiftedIndices) 743 op.getIndicesMutable().assign(*shiftedIndices); 744 }); 745 return mlir::success(); 746 } 747 748 private: 749 using IndicesVectorTy = std::vector<mlir::Value>; 750 751 // If v is a shape_shift operation: 752 // fir.shape_shift %l1, %e1, %l2, %e2, ... 753 // create: 754 // fir.shape %e1, %e2, ... 755 static mlir::Value getShapeFromShapeShift(mlir::Value v, 756 mlir::PatternRewriter &rewriter) { 757 auto shapeShiftOp = 758 mlir::dyn_cast_or_null<fir::ShapeShiftOp>(v.getDefiningOp()); 759 if (!shapeShiftOp) 760 return nullptr; 761 mlir::OpBuilder::InsertionGuard guard(rewriter); 762 rewriter.setInsertionPoint(shapeShiftOp); 763 return rewriter.create<fir::ShapeOp>(shapeShiftOp.getLoc(), 764 shapeShiftOp.getExtents()); 765 } 766 767 static std::optional<IndicesVectorTy> 768 getShiftedIndices(mlir::Value v, mlir::ValueRange indices, 769 mlir::PatternRewriter &rewriter) { 770 auto insertAdjustments = [&](mlir::Operation *op, mlir::ValueRange lbs) { 771 // Compute the shifted indices using the extended type. 772 // Note that this can probably result in less efficient 773 // MLIR and further LLVM IR due to the extra conversions. 774 mlir::OpBuilder::InsertPoint savedIP = rewriter.saveInsertionPoint(); 775 rewriter.setInsertionPoint(op); 776 mlir::Location loc = op->getLoc(); 777 mlir::Type idxTy = rewriter.getIndexType(); 778 mlir::Value one = rewriter.create<mlir::arith::ConstantOp>( 779 loc, idxTy, rewriter.getIndexAttr(1)); 780 rewriter.restoreInsertionPoint(savedIP); 781 auto nsw = mlir::arith::IntegerOverflowFlags::nsw; 782 783 IndicesVectorTy shiftedIndices; 784 for (auto [lb, idx] : llvm::zip(lbs, indices)) { 785 mlir::Value extLb = rewriter.create<fir::ConvertOp>(loc, idxTy, lb); 786 mlir::Value extIdx = rewriter.create<fir::ConvertOp>(loc, idxTy, idx); 787 mlir::Value add = 788 rewriter.create<mlir::arith::AddIOp>(loc, extIdx, extLb, nsw); 789 mlir::Value sub = 790 rewriter.create<mlir::arith::SubIOp>(loc, add, one, nsw); 791 shiftedIndices.push_back(sub); 792 } 793 794 return shiftedIndices; 795 }; 796 797 if (auto shiftOp = 798 mlir::dyn_cast_or_null<fir::ShiftOp>(v.getDefiningOp())) { 799 return insertAdjustments(shiftOp.getOperation(), shiftOp.getOrigins()); 800 } else if (auto shapeShiftOp = mlir::dyn_cast_or_null<fir::ShapeShiftOp>( 801 v.getDefiningOp())) { 802 return insertAdjustments(shapeShiftOp.getOperation(), 803 shapeShiftOp.getOrigins()); 804 } 805 806 return std::nullopt; 807 } 808 }; 809 810 void fir::ArrayCoorOp::getCanonicalizationPatterns( 811 mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { 812 // TODO: !fir.shape<1> operand may be removed from array_coor always. 813 patterns.add<SimplifyArrayCoorOp>(context); 814 } 815 816 //===----------------------------------------------------------------------===// 817 // ArrayLoadOp 818 //===----------------------------------------------------------------------===// 819 820 static mlir::Type adjustedElementType(mlir::Type t) { 821 if (auto ty = mlir::dyn_cast<fir::ReferenceType>(t)) { 822 auto eleTy = ty.getEleTy(); 823 if (fir::isa_char(eleTy)) 824 return eleTy; 825 if (fir::isa_derived(eleTy)) 826 return eleTy; 827 if (mlir::isa<fir::SequenceType>(eleTy)) 828 return eleTy; 829 } 830 return t; 831 } 832 833 std::vector<mlir::Value> fir::ArrayLoadOp::getExtents() { 834 if (auto sh = getShape()) 835 if (auto *op = sh.getDefiningOp()) { 836 if (auto shOp = mlir::dyn_cast<fir::ShapeOp>(op)) { 837 auto extents = shOp.getExtents(); 838 return {extents.begin(), extents.end()}; 839 } 840 return mlir::cast<fir::ShapeShiftOp>(op).getExtents(); 841 } 842 return {}; 843 } 844 845 llvm::LogicalResult fir::ArrayLoadOp::verify() { 846 auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); 847 auto arrTy = mlir::dyn_cast<fir::SequenceType>(eleTy); 848 if (!arrTy) 849 return emitOpError("must be a reference to an array"); 850 auto arrDim = arrTy.getDimension(); 851 852 if (auto shapeOp = getShape()) { 853 auto shapeTy = shapeOp.getType(); 854 unsigned shapeTyRank = 0u; 855 if (auto s = mlir::dyn_cast<fir::ShapeType>(shapeTy)) { 856 shapeTyRank = s.getRank(); 857 } else if (auto ss = mlir::dyn_cast<fir::ShapeShiftType>(shapeTy)) { 858 shapeTyRank = ss.getRank(); 859 } else { 860 auto s = mlir::cast<fir::ShiftType>(shapeTy); 861 shapeTyRank = s.getRank(); 862 if (!mlir::isa<fir::BaseBoxType>(getMemref().getType())) 863 return emitOpError("shift can only be provided with fir.box memref"); 864 } 865 if (arrDim && arrDim != shapeTyRank) 866 return emitOpError("rank of dimension mismatched"); 867 } 868 869 if (auto sliceOp = getSlice()) { 870 if (auto sl = mlir::dyn_cast_or_null<fir::SliceOp>(sliceOp.getDefiningOp())) 871 if (!sl.getSubstr().empty()) 872 return emitOpError("array_load cannot take a slice with substring"); 873 if (auto sliceTy = mlir::dyn_cast<fir::SliceType>(sliceOp.getType())) 874 if (sliceTy.getRank() != arrDim) 875 return emitOpError("rank of dimension in slice mismatched"); 876 } 877 878 if (!validTypeParams(getMemref().getType(), getTypeparams())) 879 return emitOpError("invalid type parameters"); 880 881 return mlir::success(); 882 } 883 884 //===----------------------------------------------------------------------===// 885 // ArrayMergeStoreOp 886 //===----------------------------------------------------------------------===// 887 888 llvm::LogicalResult fir::ArrayMergeStoreOp::verify() { 889 if (!mlir::isa<fir::ArrayLoadOp>(getOriginal().getDefiningOp())) 890 return emitOpError("operand #0 must be result of a fir.array_load op"); 891 if (auto sl = getSlice()) { 892 if (auto sliceOp = 893 mlir::dyn_cast_or_null<fir::SliceOp>(sl.getDefiningOp())) { 894 if (!sliceOp.getSubstr().empty()) 895 return emitOpError( 896 "array_merge_store cannot take a slice with substring"); 897 if (!sliceOp.getFields().empty()) { 898 // This is an intra-object merge, where the slice is projecting the 899 // subfields that are to be overwritten by the merge operation. 900 auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); 901 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy)) { 902 auto projTy = 903 fir::applyPathToType(seqTy.getEleTy(), sliceOp.getFields()); 904 if (fir::unwrapSequenceType(getOriginal().getType()) != projTy) 905 return emitOpError( 906 "type of origin does not match sliced memref type"); 907 if (fir::unwrapSequenceType(getSequence().getType()) != projTy) 908 return emitOpError( 909 "type of sequence does not match sliced memref type"); 910 return mlir::success(); 911 } 912 return emitOpError("referenced type is not an array"); 913 } 914 } 915 return mlir::success(); 916 } 917 auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(getMemref().getType()); 918 if (getOriginal().getType() != eleTy) 919 return emitOpError("type of origin does not match memref element type"); 920 if (getSequence().getType() != eleTy) 921 return emitOpError("type of sequence does not match memref element type"); 922 if (!validTypeParams(getMemref().getType(), getTypeparams())) 923 return emitOpError("invalid type parameters"); 924 return mlir::success(); 925 } 926 927 //===----------------------------------------------------------------------===// 928 // ArrayFetchOp 929 //===----------------------------------------------------------------------===// 930 931 // Template function used for both array_fetch and array_update verification. 932 template <typename A> 933 mlir::Type validArraySubobject(A op) { 934 auto ty = op.getSequence().getType(); 935 return fir::applyPathToType(ty, op.getIndices()); 936 } 937 938 llvm::LogicalResult fir::ArrayFetchOp::verify() { 939 auto arrTy = mlir::cast<fir::SequenceType>(getSequence().getType()); 940 auto indSize = getIndices().size(); 941 if (indSize < arrTy.getDimension()) 942 return emitOpError("number of indices != dimension of array"); 943 if (indSize == arrTy.getDimension() && 944 ::adjustedElementType(getElement().getType()) != arrTy.getEleTy()) 945 return emitOpError("return type does not match array"); 946 auto ty = validArraySubobject(*this); 947 if (!ty || ty != ::adjustedElementType(getType())) 948 return emitOpError("return type and/or indices do not type check"); 949 if (!mlir::isa<fir::ArrayLoadOp>(getSequence().getDefiningOp())) 950 return emitOpError("argument #0 must be result of fir.array_load"); 951 if (!validTypeParams(arrTy, getTypeparams())) 952 return emitOpError("invalid type parameters"); 953 return mlir::success(); 954 } 955 956 //===----------------------------------------------------------------------===// 957 // ArrayAccessOp 958 //===----------------------------------------------------------------------===// 959 960 llvm::LogicalResult fir::ArrayAccessOp::verify() { 961 auto arrTy = mlir::cast<fir::SequenceType>(getSequence().getType()); 962 std::size_t indSize = getIndices().size(); 963 if (indSize < arrTy.getDimension()) 964 return emitOpError("number of indices != dimension of array"); 965 if (indSize == arrTy.getDimension() && 966 getElement().getType() != fir::ReferenceType::get(arrTy.getEleTy())) 967 return emitOpError("return type does not match array"); 968 mlir::Type ty = validArraySubobject(*this); 969 if (!ty || fir::ReferenceType::get(ty) != getType()) 970 return emitOpError("return type and/or indices do not type check"); 971 if (!validTypeParams(arrTy, getTypeparams())) 972 return emitOpError("invalid type parameters"); 973 return mlir::success(); 974 } 975 976 //===----------------------------------------------------------------------===// 977 // ArrayUpdateOp 978 //===----------------------------------------------------------------------===// 979 980 llvm::LogicalResult fir::ArrayUpdateOp::verify() { 981 if (fir::isa_ref_type(getMerge().getType())) 982 return emitOpError("does not support reference type for merge"); 983 auto arrTy = mlir::cast<fir::SequenceType>(getSequence().getType()); 984 auto indSize = getIndices().size(); 985 if (indSize < arrTy.getDimension()) 986 return emitOpError("number of indices != dimension of array"); 987 if (indSize == arrTy.getDimension() && 988 ::adjustedElementType(getMerge().getType()) != arrTy.getEleTy()) 989 return emitOpError("merged value does not have element type"); 990 auto ty = validArraySubobject(*this); 991 if (!ty || ty != ::adjustedElementType(getMerge().getType())) 992 return emitOpError("merged value and/or indices do not type check"); 993 if (!validTypeParams(arrTy, getTypeparams())) 994 return emitOpError("invalid type parameters"); 995 return mlir::success(); 996 } 997 998 //===----------------------------------------------------------------------===// 999 // ArrayModifyOp 1000 //===----------------------------------------------------------------------===// 1001 1002 llvm::LogicalResult fir::ArrayModifyOp::verify() { 1003 auto arrTy = mlir::cast<fir::SequenceType>(getSequence().getType()); 1004 auto indSize = getIndices().size(); 1005 if (indSize < arrTy.getDimension()) 1006 return emitOpError("number of indices must match array dimension"); 1007 return mlir::success(); 1008 } 1009 1010 //===----------------------------------------------------------------------===// 1011 // BoxAddrOp 1012 //===----------------------------------------------------------------------===// 1013 1014 void fir::BoxAddrOp::build(mlir::OpBuilder &builder, 1015 mlir::OperationState &result, mlir::Value val) { 1016 mlir::Type type = 1017 llvm::TypeSwitch<mlir::Type, mlir::Type>(val.getType()) 1018 .Case<fir::BaseBoxType>([&](fir::BaseBoxType ty) -> mlir::Type { 1019 mlir::Type eleTy = ty.getEleTy(); 1020 if (fir::isa_ref_type(eleTy)) 1021 return eleTy; 1022 return fir::ReferenceType::get(eleTy); 1023 }) 1024 .Case<fir::BoxCharType>([&](fir::BoxCharType ty) -> mlir::Type { 1025 return fir::ReferenceType::get(ty.getEleTy()); 1026 }) 1027 .Case<fir::BoxProcType>( 1028 [&](fir::BoxProcType ty) { return ty.getEleTy(); }) 1029 .Default([&](const auto &) { return mlir::Type{}; }); 1030 assert(type && "bad val type"); 1031 build(builder, result, type, val); 1032 } 1033 1034 mlir::OpFoldResult fir::BoxAddrOp::fold(FoldAdaptor adaptor) { 1035 if (auto *v = getVal().getDefiningOp()) { 1036 if (auto box = mlir::dyn_cast<fir::EmboxOp>(v)) { 1037 // Fold only if not sliced 1038 if (!box.getSlice() && box.getMemref().getType() == getType()) { 1039 propagateAttributes(getOperation(), box.getMemref().getDefiningOp()); 1040 return box.getMemref(); 1041 } 1042 } 1043 if (auto box = mlir::dyn_cast<fir::EmboxCharOp>(v)) 1044 if (box.getMemref().getType() == getType()) 1045 return box.getMemref(); 1046 } 1047 return {}; 1048 } 1049 1050 //===----------------------------------------------------------------------===// 1051 // BoxCharLenOp 1052 //===----------------------------------------------------------------------===// 1053 1054 mlir::OpFoldResult fir::BoxCharLenOp::fold(FoldAdaptor adaptor) { 1055 if (auto v = getVal().getDefiningOp()) { 1056 if (auto box = mlir::dyn_cast<fir::EmboxCharOp>(v)) 1057 return box.getLen(); 1058 } 1059 return {}; 1060 } 1061 1062 //===----------------------------------------------------------------------===// 1063 // BoxDimsOp 1064 //===----------------------------------------------------------------------===// 1065 1066 /// Get the result types packed in a tuple tuple 1067 mlir::Type fir::BoxDimsOp::getTupleType() { 1068 // note: triple, but 4 is nearest power of 2 1069 llvm::SmallVector<mlir::Type> triple{ 1070 getResult(0).getType(), getResult(1).getType(), getResult(2).getType()}; 1071 return mlir::TupleType::get(getContext(), triple); 1072 } 1073 1074 //===----------------------------------------------------------------------===// 1075 // BoxRankOp 1076 //===----------------------------------------------------------------------===// 1077 1078 void fir::BoxRankOp::getEffects( 1079 llvm::SmallVectorImpl< 1080 mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> 1081 &effects) { 1082 mlir::OpOperand &inputBox = getBoxMutable(); 1083 if (fir::isBoxAddress(inputBox.get().getType())) 1084 effects.emplace_back(mlir::MemoryEffects::Read::get(), &inputBox, 1085 mlir::SideEffects::DefaultResource::get()); 1086 } 1087 1088 //===----------------------------------------------------------------------===// 1089 // CallOp 1090 //===----------------------------------------------------------------------===// 1091 1092 mlir::FunctionType fir::CallOp::getFunctionType() { 1093 return mlir::FunctionType::get(getContext(), getOperandTypes(), 1094 getResultTypes()); 1095 } 1096 1097 void fir::CallOp::print(mlir::OpAsmPrinter &p) { 1098 bool isDirect = getCallee().has_value(); 1099 p << ' '; 1100 if (isDirect) 1101 p << *getCallee(); 1102 else 1103 p << getOperand(0); 1104 p << '(' << (*this)->getOperands().drop_front(isDirect ? 0 : 1) << ')'; 1105 1106 // Print `proc_attrs<...>`, if present. 1107 fir::FortranProcedureFlagsEnumAttr procAttrs = getProcedureAttrsAttr(); 1108 if (procAttrs && 1109 procAttrs.getValue() != fir::FortranProcedureFlagsEnum::none) { 1110 p << ' ' << fir::FortranProcedureFlagsEnumAttr::getMnemonic(); 1111 p.printStrippedAttrOrType(procAttrs); 1112 } 1113 1114 // Print 'fastmath<...>' (if it has non-default value) before 1115 // any other attributes. 1116 mlir::arith::FastMathFlagsAttr fmfAttr = getFastmathAttr(); 1117 if (fmfAttr.getValue() != mlir::arith::FastMathFlags::none) { 1118 p << ' ' << mlir::arith::FastMathFlagsAttr::getMnemonic(); 1119 p.printStrippedAttrOrType(fmfAttr); 1120 } 1121 1122 p.printOptionalAttrDict((*this)->getAttrs(), 1123 {fir::CallOp::getCalleeAttrNameStr(), 1124 getFastmathAttrName(), getProcedureAttrsAttrName()}); 1125 auto resultTypes{getResultTypes()}; 1126 llvm::SmallVector<mlir::Type> argTypes( 1127 llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1)); 1128 p << " : " << mlir::FunctionType::get(getContext(), argTypes, resultTypes); 1129 } 1130 1131 mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser, 1132 mlir::OperationState &result) { 1133 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands; 1134 if (parser.parseOperandList(operands)) 1135 return mlir::failure(); 1136 1137 mlir::NamedAttrList attrs; 1138 mlir::SymbolRefAttr funcAttr; 1139 bool isDirect = operands.empty(); 1140 if (isDirect) 1141 if (parser.parseAttribute(funcAttr, fir::CallOp::getCalleeAttrNameStr(), 1142 attrs)) 1143 return mlir::failure(); 1144 1145 mlir::Type type; 1146 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren)) 1147 return mlir::failure(); 1148 1149 // Parse `proc_attrs<...>`, if present. 1150 fir::FortranProcedureFlagsEnumAttr procAttr; 1151 if (mlir::succeeded(parser.parseOptionalKeyword( 1152 fir::FortranProcedureFlagsEnumAttr::getMnemonic()))) 1153 if (parser.parseCustomAttributeWithFallback( 1154 procAttr, mlir::Type{}, getProcedureAttrsAttrName(result.name), 1155 attrs)) 1156 return mlir::failure(); 1157 1158 // Parse 'fastmath<...>', if present. 1159 mlir::arith::FastMathFlagsAttr fmfAttr; 1160 llvm::StringRef fmfAttrName = getFastmathAttrName(result.name); 1161 if (mlir::succeeded(parser.parseOptionalKeyword(fmfAttrName))) 1162 if (parser.parseCustomAttributeWithFallback(fmfAttr, mlir::Type{}, 1163 fmfAttrName, attrs)) 1164 return mlir::failure(); 1165 1166 if (parser.parseOptionalAttrDict(attrs) || parser.parseColon() || 1167 parser.parseType(type)) 1168 return mlir::failure(); 1169 1170 auto funcType = mlir::dyn_cast<mlir::FunctionType>(type); 1171 if (!funcType) 1172 return parser.emitError(parser.getNameLoc(), "expected function type"); 1173 if (isDirect) { 1174 if (parser.resolveOperands(operands, funcType.getInputs(), 1175 parser.getNameLoc(), result.operands)) 1176 return mlir::failure(); 1177 } else { 1178 auto funcArgs = 1179 llvm::ArrayRef<mlir::OpAsmParser::UnresolvedOperand>(operands) 1180 .drop_front(); 1181 if (parser.resolveOperand(operands[0], funcType, result.operands) || 1182 parser.resolveOperands(funcArgs, funcType.getInputs(), 1183 parser.getNameLoc(), result.operands)) 1184 return mlir::failure(); 1185 } 1186 result.addTypes(funcType.getResults()); 1187 result.attributes = attrs; 1188 return mlir::success(); 1189 } 1190 1191 void fir::CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, 1192 mlir::func::FuncOp callee, mlir::ValueRange operands) { 1193 result.addOperands(operands); 1194 result.addAttribute(getCalleeAttrNameStr(), mlir::SymbolRefAttr::get(callee)); 1195 result.addTypes(callee.getFunctionType().getResults()); 1196 } 1197 1198 void fir::CallOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, 1199 mlir::SymbolRefAttr callee, 1200 llvm::ArrayRef<mlir::Type> results, 1201 mlir::ValueRange operands) { 1202 result.addOperands(operands); 1203 if (callee) 1204 result.addAttribute(getCalleeAttrNameStr(), callee); 1205 result.addTypes(results); 1206 } 1207 1208 //===----------------------------------------------------------------------===// 1209 // CharConvertOp 1210 //===----------------------------------------------------------------------===// 1211 1212 llvm::LogicalResult fir::CharConvertOp::verify() { 1213 auto unwrap = [&](mlir::Type t) { 1214 t = fir::unwrapSequenceType(fir::dyn_cast_ptrEleTy(t)); 1215 return mlir::dyn_cast<fir::CharacterType>(t); 1216 }; 1217 auto inTy = unwrap(getFrom().getType()); 1218 auto outTy = unwrap(getTo().getType()); 1219 if (!(inTy && outTy)) 1220 return emitOpError("not a reference to a character"); 1221 if (inTy.getFKind() == outTy.getFKind()) 1222 return emitOpError("buffers must have different KIND values"); 1223 return mlir::success(); 1224 } 1225 1226 //===----------------------------------------------------------------------===// 1227 // CmpOp 1228 //===----------------------------------------------------------------------===// 1229 1230 template <typename OPTY> 1231 static void printCmpOp(mlir::OpAsmPrinter &p, OPTY op) { 1232 p << ' '; 1233 auto predSym = mlir::arith::symbolizeCmpFPredicate( 1234 op->template getAttrOfType<mlir::IntegerAttr>( 1235 OPTY::getPredicateAttrName()) 1236 .getInt()); 1237 assert(predSym.has_value() && "invalid symbol value for predicate"); 1238 p << '"' << mlir::arith::stringifyCmpFPredicate(predSym.value()) << '"' 1239 << ", "; 1240 p.printOperand(op.getLhs()); 1241 p << ", "; 1242 p.printOperand(op.getRhs()); 1243 p.printOptionalAttrDict(op->getAttrs(), 1244 /*elidedAttrs=*/{OPTY::getPredicateAttrName()}); 1245 p << " : " << op.getLhs().getType(); 1246 } 1247 1248 template <typename OPTY> 1249 static mlir::ParseResult parseCmpOp(mlir::OpAsmParser &parser, 1250 mlir::OperationState &result) { 1251 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> ops; 1252 mlir::NamedAttrList attrs; 1253 mlir::Attribute predicateNameAttr; 1254 mlir::Type type; 1255 if (parser.parseAttribute(predicateNameAttr, OPTY::getPredicateAttrName(), 1256 attrs) || 1257 parser.parseComma() || parser.parseOperandList(ops, 2) || 1258 parser.parseOptionalAttrDict(attrs) || parser.parseColonType(type) || 1259 parser.resolveOperands(ops, type, result.operands)) 1260 return mlir::failure(); 1261 1262 if (!mlir::isa<mlir::StringAttr>(predicateNameAttr)) 1263 return parser.emitError(parser.getNameLoc(), 1264 "expected string comparison predicate attribute"); 1265 1266 // Rewrite string attribute to an enum value. 1267 llvm::StringRef predicateName = 1268 mlir::cast<mlir::StringAttr>(predicateNameAttr).getValue(); 1269 auto predicate = fir::CmpcOp::getPredicateByName(predicateName); 1270 auto builder = parser.getBuilder(); 1271 mlir::Type i1Type = builder.getI1Type(); 1272 attrs.set(OPTY::getPredicateAttrName(), 1273 builder.getI64IntegerAttr(static_cast<std::int64_t>(predicate))); 1274 result.attributes = attrs; 1275 result.addTypes({i1Type}); 1276 return mlir::success(); 1277 } 1278 1279 //===----------------------------------------------------------------------===// 1280 // CmpcOp 1281 //===----------------------------------------------------------------------===// 1282 1283 void fir::buildCmpCOp(mlir::OpBuilder &builder, mlir::OperationState &result, 1284 mlir::arith::CmpFPredicate predicate, mlir::Value lhs, 1285 mlir::Value rhs) { 1286 result.addOperands({lhs, rhs}); 1287 result.types.push_back(builder.getI1Type()); 1288 result.addAttribute( 1289 fir::CmpcOp::getPredicateAttrName(), 1290 builder.getI64IntegerAttr(static_cast<std::int64_t>(predicate))); 1291 } 1292 1293 mlir::arith::CmpFPredicate 1294 fir::CmpcOp::getPredicateByName(llvm::StringRef name) { 1295 auto pred = mlir::arith::symbolizeCmpFPredicate(name); 1296 assert(pred.has_value() && "invalid predicate name"); 1297 return pred.value(); 1298 } 1299 1300 void fir::CmpcOp::print(mlir::OpAsmPrinter &p) { printCmpOp(p, *this); } 1301 1302 mlir::ParseResult fir::CmpcOp::parse(mlir::OpAsmParser &parser, 1303 mlir::OperationState &result) { 1304 return parseCmpOp<fir::CmpcOp>(parser, result); 1305 } 1306 1307 //===----------------------------------------------------------------------===// 1308 // ConvertOp 1309 //===----------------------------------------------------------------------===// 1310 1311 void fir::ConvertOp::getCanonicalizationPatterns( 1312 mlir::RewritePatternSet &results, mlir::MLIRContext *context) { 1313 results.insert<ConvertConvertOptPattern, ConvertAscendingIndexOptPattern, 1314 ConvertDescendingIndexOptPattern, RedundantConvertOptPattern, 1315 CombineConvertOptPattern, CombineConvertTruncOptPattern, 1316 ForwardConstantConvertPattern, ChainedPointerConvertsPattern>( 1317 context); 1318 } 1319 1320 mlir::OpFoldResult fir::ConvertOp::fold(FoldAdaptor adaptor) { 1321 if (getValue().getType() == getType()) 1322 return getValue(); 1323 if (matchPattern(getValue(), mlir::m_Op<fir::ConvertOp>())) { 1324 auto inner = mlir::cast<fir::ConvertOp>(getValue().getDefiningOp()); 1325 // (convert (convert 'a : logical -> i1) : i1 -> logical) ==> forward 'a 1326 if (auto toTy = mlir::dyn_cast<fir::LogicalType>(getType())) 1327 if (auto fromTy = 1328 mlir::dyn_cast<fir::LogicalType>(inner.getValue().getType())) 1329 if (mlir::isa<mlir::IntegerType>(inner.getType()) && (toTy == fromTy)) 1330 return inner.getValue(); 1331 // (convert (convert 'a : i1 -> logical) : logical -> i1) ==> forward 'a 1332 if (auto toTy = mlir::dyn_cast<mlir::IntegerType>(getType())) 1333 if (auto fromTy = 1334 mlir::dyn_cast<mlir::IntegerType>(inner.getValue().getType())) 1335 if (mlir::isa<fir::LogicalType>(inner.getType()) && (toTy == fromTy) && 1336 (fromTy.getWidth() == 1)) 1337 return inner.getValue(); 1338 } 1339 return {}; 1340 } 1341 1342 bool fir::ConvertOp::isInteger(mlir::Type ty) { 1343 return mlir::isa<mlir::IntegerType, mlir::IndexType, fir::IntegerType>(ty); 1344 } 1345 1346 bool fir::ConvertOp::isIntegerCompatible(mlir::Type ty) { 1347 return isInteger(ty) || mlir::isa<fir::LogicalType>(ty); 1348 } 1349 1350 bool fir::ConvertOp::isFloatCompatible(mlir::Type ty) { 1351 return mlir::isa<mlir::FloatType>(ty); 1352 } 1353 1354 bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) { 1355 return mlir::isa<fir::ReferenceType, fir::PointerType, fir::HeapType, 1356 fir::LLVMPointerType, mlir::MemRefType, mlir::FunctionType, 1357 fir::TypeDescType, mlir::LLVM::LLVMPointerType>(ty); 1358 } 1359 1360 static std::optional<mlir::Type> getVectorElementType(mlir::Type ty) { 1361 mlir::Type elemTy; 1362 if (mlir::isa<fir::VectorType>(ty)) 1363 elemTy = mlir::dyn_cast<fir::VectorType>(ty).getElementType(); 1364 else if (mlir::isa<mlir::VectorType>(ty)) 1365 elemTy = mlir::dyn_cast<mlir::VectorType>(ty).getElementType(); 1366 else 1367 return std::nullopt; 1368 1369 // e.g. fir.vector<4:ui32> => mlir.vector<4xi32> 1370 // e.g. mlir.vector<4xui32> => mlir.vector<4xi32> 1371 if (elemTy.isUnsignedInteger()) { 1372 elemTy = mlir::IntegerType::get( 1373 ty.getContext(), mlir::dyn_cast<mlir::IntegerType>(elemTy).getWidth()); 1374 } 1375 return elemTy; 1376 } 1377 1378 static std::optional<uint64_t> getVectorLen(mlir::Type ty) { 1379 if (mlir::isa<fir::VectorType>(ty)) 1380 return mlir::dyn_cast<fir::VectorType>(ty).getLen(); 1381 else if (mlir::isa<mlir::VectorType>(ty)) { 1382 // fir.vector only supports 1-D vector 1383 if (!(mlir::dyn_cast<mlir::VectorType>(ty).isScalable())) 1384 return mlir::dyn_cast<mlir::VectorType>(ty).getShape()[0]; 1385 } 1386 1387 return std::nullopt; 1388 } 1389 1390 bool fir::ConvertOp::areVectorsCompatible(mlir::Type inTy, mlir::Type outTy) { 1391 if (!(mlir::isa<fir::VectorType>(inTy) && 1392 mlir::isa<mlir::VectorType>(outTy)) && 1393 !(mlir::isa<mlir::VectorType>(inTy) && mlir::isa<fir::VectorType>(outTy))) 1394 return false; 1395 1396 // Only support integer, unsigned and real vector 1397 // Both vectors must have the same element type 1398 std::optional<mlir::Type> inElemTy = getVectorElementType(inTy); 1399 std::optional<mlir::Type> outElemTy = getVectorElementType(outTy); 1400 if (!inElemTy.has_value() || !outElemTy.has_value() || 1401 inElemTy.value() != outElemTy.value()) 1402 return false; 1403 1404 // Both vectors must have the same number of elements 1405 std::optional<uint64_t> inLen = getVectorLen(inTy); 1406 std::optional<uint64_t> outLen = getVectorLen(outTy); 1407 if (!inLen.has_value() || !outLen.has_value() || 1408 inLen.value() != outLen.value()) 1409 return false; 1410 1411 return true; 1412 } 1413 1414 static bool areRecordsCompatible(mlir::Type inTy, mlir::Type outTy) { 1415 // Both records must have the same field types. 1416 // Trust frontend semantics for in-depth checks, such as if both records 1417 // have the BIND(C) attribute. 1418 auto inRecTy = mlir::dyn_cast<fir::RecordType>(inTy); 1419 auto outRecTy = mlir::dyn_cast<fir::RecordType>(outTy); 1420 return inRecTy && outRecTy && inRecTy.getTypeList() == outRecTy.getTypeList(); 1421 } 1422 1423 bool fir::ConvertOp::canBeConverted(mlir::Type inType, mlir::Type outType) { 1424 if (inType == outType) 1425 return true; 1426 return (isPointerCompatible(inType) && isPointerCompatible(outType)) || 1427 (isIntegerCompatible(inType) && isIntegerCompatible(outType)) || 1428 (isInteger(inType) && isFloatCompatible(outType)) || 1429 (isFloatCompatible(inType) && isInteger(outType)) || 1430 (isFloatCompatible(inType) && isFloatCompatible(outType)) || 1431 (isIntegerCompatible(inType) && isPointerCompatible(outType)) || 1432 (isPointerCompatible(inType) && isIntegerCompatible(outType)) || 1433 (mlir::isa<fir::BoxType>(inType) && 1434 mlir::isa<fir::BoxType>(outType)) || 1435 (mlir::isa<fir::BoxProcType>(inType) && 1436 mlir::isa<fir::BoxProcType>(outType)) || 1437 (fir::isa_complex(inType) && fir::isa_complex(outType)) || 1438 (fir::isBoxedRecordType(inType) && fir::isPolymorphicType(outType)) || 1439 (fir::isPolymorphicType(inType) && fir::isPolymorphicType(outType)) || 1440 (fir::isPolymorphicType(inType) && mlir::isa<BoxType>(outType)) || 1441 areVectorsCompatible(inType, outType) || 1442 areRecordsCompatible(inType, outType); 1443 } 1444 1445 llvm::LogicalResult fir::ConvertOp::verify() { 1446 if (canBeConverted(getValue().getType(), getType())) 1447 return mlir::success(); 1448 return emitOpError("invalid type conversion") 1449 << getValue().getType() << " / " << getType(); 1450 } 1451 1452 //===----------------------------------------------------------------------===// 1453 // CoordinateOp 1454 //===----------------------------------------------------------------------===// 1455 1456 void fir::CoordinateOp::print(mlir::OpAsmPrinter &p) { 1457 p << ' ' << getRef() << ", " << getCoor(); 1458 p.printOptionalAttrDict((*this)->getAttrs(), /*elideAttrs=*/{"baseType"}); 1459 p << " : "; 1460 p.printFunctionalType(getOperandTypes(), (*this)->getResultTypes()); 1461 } 1462 1463 mlir::ParseResult fir::CoordinateOp::parse(mlir::OpAsmParser &parser, 1464 mlir::OperationState &result) { 1465 mlir::OpAsmParser::UnresolvedOperand memref; 1466 if (parser.parseOperand(memref) || parser.parseComma()) 1467 return mlir::failure(); 1468 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> coorOperands; 1469 if (parser.parseOperandList(coorOperands)) 1470 return mlir::failure(); 1471 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> allOperands; 1472 allOperands.push_back(memref); 1473 allOperands.append(coorOperands.begin(), coorOperands.end()); 1474 mlir::FunctionType funcTy; 1475 auto loc = parser.getCurrentLocation(); 1476 if (parser.parseOptionalAttrDict(result.attributes) || 1477 parser.parseColonType(funcTy) || 1478 parser.resolveOperands(allOperands, funcTy.getInputs(), loc, 1479 result.operands) || 1480 parser.addTypesToList(funcTy.getResults(), result.types)) 1481 return mlir::failure(); 1482 result.addAttribute("baseType", mlir::TypeAttr::get(funcTy.getInput(0))); 1483 return mlir::success(); 1484 } 1485 1486 llvm::LogicalResult fir::CoordinateOp::verify() { 1487 const mlir::Type refTy = getRef().getType(); 1488 if (fir::isa_ref_type(refTy)) { 1489 auto eleTy = fir::dyn_cast_ptrEleTy(refTy); 1490 if (auto arrTy = mlir::dyn_cast<fir::SequenceType>(eleTy)) { 1491 if (arrTy.hasUnknownShape()) 1492 return emitOpError("cannot find coordinate in unknown shape"); 1493 if (arrTy.getConstantRows() < arrTy.getDimension() - 1) 1494 return emitOpError("cannot find coordinate with unknown extents"); 1495 } 1496 if (!(fir::isa_aggregate(eleTy) || fir::isa_complex(eleTy) || 1497 fir::isa_char_string(eleTy))) 1498 return emitOpError("cannot apply to this element type"); 1499 } 1500 auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(refTy); 1501 unsigned dimension = 0; 1502 const unsigned numCoors = getCoor().size(); 1503 for (auto coorOperand : llvm::enumerate(getCoor())) { 1504 auto co = coorOperand.value(); 1505 if (dimension == 0 && mlir::isa<fir::SequenceType>(eleTy)) { 1506 dimension = mlir::cast<fir::SequenceType>(eleTy).getDimension(); 1507 if (dimension == 0) 1508 return emitOpError("cannot apply to array of unknown rank"); 1509 } 1510 if (auto *defOp = co.getDefiningOp()) { 1511 if (auto index = mlir::dyn_cast<fir::LenParamIndexOp>(defOp)) { 1512 // Recovering a LEN type parameter only makes sense from a boxed 1513 // value. For a bare reference, the LEN type parameters must be 1514 // passed as additional arguments to `index`. 1515 if (mlir::isa<fir::BoxType>(refTy)) { 1516 if (coorOperand.index() != numCoors - 1) 1517 return emitOpError("len_param_index must be last argument"); 1518 if (getNumOperands() != 2) 1519 return emitOpError("too many operands for len_param_index case"); 1520 } 1521 if (eleTy != index.getOnType()) 1522 emitOpError( 1523 "len_param_index type not compatible with reference type"); 1524 return mlir::success(); 1525 } else if (auto index = mlir::dyn_cast<fir::FieldIndexOp>(defOp)) { 1526 if (eleTy != index.getOnType()) 1527 emitOpError("field_index type not compatible with reference type"); 1528 if (auto recTy = mlir::dyn_cast<fir::RecordType>(eleTy)) { 1529 eleTy = recTy.getType(index.getFieldName()); 1530 continue; 1531 } 1532 return emitOpError("field_index not applied to !fir.type"); 1533 } 1534 } 1535 if (dimension) { 1536 if (--dimension == 0) 1537 eleTy = mlir::cast<fir::SequenceType>(eleTy).getElementType(); 1538 } else { 1539 if (auto t = mlir::dyn_cast<mlir::TupleType>(eleTy)) { 1540 // FIXME: Generally, we don't know which field of the tuple is being 1541 // referred to unless the operand is a constant. Just assume everything 1542 // is good in the tuple case for now. 1543 return mlir::success(); 1544 } else if (auto t = mlir::dyn_cast<fir::RecordType>(eleTy)) { 1545 // FIXME: This is the same as the tuple case. 1546 return mlir::success(); 1547 } else if (auto t = mlir::dyn_cast<mlir::ComplexType>(eleTy)) { 1548 eleTy = t.getElementType(); 1549 } else if (auto t = mlir::dyn_cast<fir::CharacterType>(eleTy)) { 1550 if (t.getLen() == fir::CharacterType::singleton()) 1551 return emitOpError("cannot apply to character singleton"); 1552 eleTy = fir::CharacterType::getSingleton(t.getContext(), t.getFKind()); 1553 if (fir::unwrapRefType(getType()) != eleTy) 1554 return emitOpError("character type mismatch"); 1555 } else { 1556 return emitOpError("invalid parameters (too many)"); 1557 } 1558 } 1559 } 1560 return mlir::success(); 1561 } 1562 1563 //===----------------------------------------------------------------------===// 1564 // DispatchOp 1565 //===----------------------------------------------------------------------===// 1566 1567 llvm::LogicalResult fir::DispatchOp::verify() { 1568 // Check that pass_arg_pos is in range of actual operands. pass_arg_pos is 1569 // unsigned so check for less than zero is not needed. 1570 if (getPassArgPos() && *getPassArgPos() > (getArgOperands().size() - 1)) 1571 return emitOpError( 1572 "pass_arg_pos must be smaller than the number of operands"); 1573 1574 // Operand pointed by pass_arg_pos must have polymorphic type. 1575 if (getPassArgPos() && 1576 !fir::isPolymorphicType(getArgOperands()[*getPassArgPos()].getType())) 1577 return emitOpError("pass_arg_pos must be a polymorphic operand"); 1578 return mlir::success(); 1579 } 1580 1581 mlir::FunctionType fir::DispatchOp::getFunctionType() { 1582 return mlir::FunctionType::get(getContext(), getOperandTypes(), 1583 getResultTypes()); 1584 } 1585 1586 //===----------------------------------------------------------------------===// 1587 // TypeInfoOp 1588 //===----------------------------------------------------------------------===// 1589 1590 void fir::TypeInfoOp::build(mlir::OpBuilder &builder, 1591 mlir::OperationState &result, fir::RecordType type, 1592 fir::RecordType parentType, 1593 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 1594 result.addRegion(); 1595 result.addRegion(); 1596 result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), 1597 builder.getStringAttr(type.getName())); 1598 result.addAttribute(getTypeAttrName(result.name), mlir::TypeAttr::get(type)); 1599 if (parentType) 1600 result.addAttribute(getParentTypeAttrName(result.name), 1601 mlir::TypeAttr::get(parentType)); 1602 result.addAttributes(attrs); 1603 } 1604 1605 llvm::LogicalResult fir::TypeInfoOp::verify() { 1606 if (!getDispatchTable().empty()) 1607 for (auto &op : getDispatchTable().front().without_terminator()) 1608 if (!mlir::isa<fir::DTEntryOp>(op)) 1609 return op.emitOpError("dispatch table must contain dt_entry"); 1610 1611 if (!mlir::isa<fir::RecordType>(getType())) 1612 return emitOpError("type must be a fir.type"); 1613 1614 if (getParentType() && !mlir::isa<fir::RecordType>(*getParentType())) 1615 return emitOpError("parent_type must be a fir.type"); 1616 return mlir::success(); 1617 } 1618 1619 //===----------------------------------------------------------------------===// 1620 // EmboxOp 1621 //===----------------------------------------------------------------------===// 1622 1623 llvm::LogicalResult fir::EmboxOp::verify() { 1624 auto eleTy = fir::dyn_cast_ptrEleTy(getMemref().getType()); 1625 bool isArray = false; 1626 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy)) { 1627 eleTy = seqTy.getEleTy(); 1628 isArray = true; 1629 } 1630 if (hasLenParams()) { 1631 auto lenPs = numLenParams(); 1632 if (auto rt = mlir::dyn_cast<fir::RecordType>(eleTy)) { 1633 if (lenPs != rt.getNumLenParams()) 1634 return emitOpError("number of LEN params does not correspond" 1635 " to the !fir.type type"); 1636 } else if (auto strTy = mlir::dyn_cast<fir::CharacterType>(eleTy)) { 1637 if (strTy.getLen() != fir::CharacterType::unknownLen()) 1638 return emitOpError("CHARACTER already has static LEN"); 1639 } else { 1640 return emitOpError("LEN parameters require CHARACTER or derived type"); 1641 } 1642 for (auto lp : getTypeparams()) 1643 if (!fir::isa_integer(lp.getType())) 1644 return emitOpError("LEN parameters must be integral type"); 1645 } 1646 if (getShape() && !isArray) 1647 return emitOpError("shape must not be provided for a scalar"); 1648 if (getSlice() && !isArray) 1649 return emitOpError("slice must not be provided for a scalar"); 1650 if (getSourceBox() && !mlir::isa<fir::ClassType>(getResult().getType())) 1651 return emitOpError("source_box must be used with fir.class result type"); 1652 return mlir::success(); 1653 } 1654 1655 //===----------------------------------------------------------------------===// 1656 // EmboxCharOp 1657 //===----------------------------------------------------------------------===// 1658 1659 llvm::LogicalResult fir::EmboxCharOp::verify() { 1660 auto eleTy = fir::dyn_cast_ptrEleTy(getMemref().getType()); 1661 if (!mlir::dyn_cast_or_null<fir::CharacterType>(eleTy)) 1662 return mlir::failure(); 1663 return mlir::success(); 1664 } 1665 1666 //===----------------------------------------------------------------------===// 1667 // EmboxProcOp 1668 //===----------------------------------------------------------------------===// 1669 1670 llvm::LogicalResult fir::EmboxProcOp::verify() { 1671 // host bindings (optional) must be a reference to a tuple 1672 if (auto h = getHost()) { 1673 if (auto r = mlir::dyn_cast<fir::ReferenceType>(h.getType())) 1674 if (mlir::isa<mlir::TupleType>(r.getEleTy())) 1675 return mlir::success(); 1676 return mlir::failure(); 1677 } 1678 return mlir::success(); 1679 } 1680 1681 //===----------------------------------------------------------------------===// 1682 // TypeDescOp 1683 //===----------------------------------------------------------------------===// 1684 1685 void fir::TypeDescOp::build(mlir::OpBuilder &, mlir::OperationState &result, 1686 mlir::TypeAttr inty) { 1687 result.addAttribute("in_type", inty); 1688 result.addTypes(TypeDescType::get(inty.getValue())); 1689 } 1690 1691 mlir::ParseResult fir::TypeDescOp::parse(mlir::OpAsmParser &parser, 1692 mlir::OperationState &result) { 1693 mlir::Type intype; 1694 if (parser.parseType(intype)) 1695 return mlir::failure(); 1696 result.addAttribute("in_type", mlir::TypeAttr::get(intype)); 1697 mlir::Type restype = fir::TypeDescType::get(intype); 1698 if (parser.addTypeToList(restype, result.types)) 1699 return mlir::failure(); 1700 return mlir::success(); 1701 } 1702 1703 void fir::TypeDescOp::print(mlir::OpAsmPrinter &p) { 1704 p << ' ' << getOperation()->getAttr("in_type"); 1705 p.printOptionalAttrDict(getOperation()->getAttrs(), {"in_type"}); 1706 } 1707 1708 llvm::LogicalResult fir::TypeDescOp::verify() { 1709 mlir::Type resultTy = getType(); 1710 if (auto tdesc = mlir::dyn_cast<fir::TypeDescType>(resultTy)) { 1711 if (tdesc.getOfTy() != getInType()) 1712 return emitOpError("wrapped type mismatched"); 1713 return mlir::success(); 1714 } 1715 return emitOpError("must be !fir.tdesc type"); 1716 } 1717 1718 //===----------------------------------------------------------------------===// 1719 // GlobalOp 1720 //===----------------------------------------------------------------------===// 1721 1722 mlir::Type fir::GlobalOp::resultType() { 1723 return wrapAllocaResultType(getType()); 1724 } 1725 1726 mlir::ParseResult fir::GlobalOp::parse(mlir::OpAsmParser &parser, 1727 mlir::OperationState &result) { 1728 // Parse the optional linkage 1729 llvm::StringRef linkage; 1730 auto &builder = parser.getBuilder(); 1731 if (mlir::succeeded(parser.parseOptionalKeyword(&linkage))) { 1732 if (fir::GlobalOp::verifyValidLinkage(linkage)) 1733 return mlir::failure(); 1734 mlir::StringAttr linkAttr = builder.getStringAttr(linkage); 1735 result.addAttribute(fir::GlobalOp::getLinkNameAttrName(result.name), 1736 linkAttr); 1737 } 1738 1739 // Parse the name as a symbol reference attribute. 1740 mlir::SymbolRefAttr nameAttr; 1741 if (parser.parseAttribute(nameAttr, 1742 fir::GlobalOp::getSymrefAttrName(result.name), 1743 result.attributes)) 1744 return mlir::failure(); 1745 result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), 1746 nameAttr.getRootReference()); 1747 1748 bool simpleInitializer = false; 1749 if (mlir::succeeded(parser.parseOptionalLParen())) { 1750 mlir::Attribute attr; 1751 if (parser.parseAttribute(attr, getInitValAttrName(result.name), 1752 result.attributes) || 1753 parser.parseRParen()) 1754 return mlir::failure(); 1755 simpleInitializer = true; 1756 } 1757 1758 if (parser.parseOptionalAttrDict(result.attributes)) 1759 return mlir::failure(); 1760 1761 if (succeeded( 1762 parser.parseOptionalKeyword(getConstantAttrName(result.name)))) { 1763 // if "constant" keyword then mark this as a constant, not a variable 1764 result.addAttribute(getConstantAttrName(result.name), 1765 builder.getUnitAttr()); 1766 } 1767 1768 if (succeeded(parser.parseOptionalKeyword(getTargetAttrName(result.name)))) 1769 result.addAttribute(getTargetAttrName(result.name), builder.getUnitAttr()); 1770 1771 mlir::Type globalType; 1772 if (parser.parseColonType(globalType)) 1773 return mlir::failure(); 1774 1775 result.addAttribute(fir::GlobalOp::getTypeAttrName(result.name), 1776 mlir::TypeAttr::get(globalType)); 1777 1778 if (simpleInitializer) { 1779 result.addRegion(); 1780 } else { 1781 // Parse the optional initializer body. 1782 auto parseResult = 1783 parser.parseOptionalRegion(*result.addRegion(), /*arguments=*/{}); 1784 if (parseResult.has_value() && mlir::failed(*parseResult)) 1785 return mlir::failure(); 1786 } 1787 return mlir::success(); 1788 } 1789 1790 void fir::GlobalOp::print(mlir::OpAsmPrinter &p) { 1791 if (getLinkName()) 1792 p << ' ' << *getLinkName(); 1793 p << ' '; 1794 p.printAttributeWithoutType(getSymrefAttr()); 1795 if (auto val = getValueOrNull()) 1796 p << '(' << val << ')'; 1797 // Print all other attributes that are not pretty printed here. 1798 p.printOptionalAttrDict((*this)->getAttrs(), /*elideAttrs=*/{ 1799 getSymNameAttrName(), getSymrefAttrName(), 1800 getTypeAttrName(), getConstantAttrName(), 1801 getTargetAttrName(), getLinkNameAttrName(), 1802 getInitValAttrName()}); 1803 if (getOperation()->getAttr(getConstantAttrName())) 1804 p << " " << getConstantAttrName().strref(); 1805 if (getOperation()->getAttr(getTargetAttrName())) 1806 p << " " << getTargetAttrName().strref(); 1807 p << " : "; 1808 p.printType(getType()); 1809 if (hasInitializationBody()) { 1810 p << ' '; 1811 p.printRegion(getOperation()->getRegion(0), 1812 /*printEntryBlockArgs=*/false, 1813 /*printBlockTerminators=*/true); 1814 } 1815 } 1816 1817 void fir::GlobalOp::appendInitialValue(mlir::Operation *op) { 1818 getBlock().getOperations().push_back(op); 1819 } 1820 1821 void fir::GlobalOp::build(mlir::OpBuilder &builder, 1822 mlir::OperationState &result, llvm::StringRef name, 1823 bool isConstant, bool isTarget, mlir::Type type, 1824 mlir::Attribute initialVal, mlir::StringAttr linkage, 1825 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 1826 result.addRegion(); 1827 result.addAttribute(getTypeAttrName(result.name), mlir::TypeAttr::get(type)); 1828 result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), 1829 builder.getStringAttr(name)); 1830 result.addAttribute(getSymrefAttrName(result.name), 1831 mlir::SymbolRefAttr::get(builder.getContext(), name)); 1832 if (isConstant) 1833 result.addAttribute(getConstantAttrName(result.name), 1834 builder.getUnitAttr()); 1835 if (isTarget) 1836 result.addAttribute(getTargetAttrName(result.name), builder.getUnitAttr()); 1837 if (initialVal) 1838 result.addAttribute(getInitValAttrName(result.name), initialVal); 1839 if (linkage) 1840 result.addAttribute(getLinkNameAttrName(result.name), linkage); 1841 result.attributes.append(attrs.begin(), attrs.end()); 1842 } 1843 1844 void fir::GlobalOp::build(mlir::OpBuilder &builder, 1845 mlir::OperationState &result, llvm::StringRef name, 1846 mlir::Type type, mlir::Attribute initialVal, 1847 mlir::StringAttr linkage, 1848 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 1849 build(builder, result, name, /*isConstant=*/false, /*isTarget=*/false, type, 1850 {}, linkage, attrs); 1851 } 1852 1853 void fir::GlobalOp::build(mlir::OpBuilder &builder, 1854 mlir::OperationState &result, llvm::StringRef name, 1855 bool isConstant, bool isTarget, mlir::Type type, 1856 mlir::StringAttr linkage, 1857 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 1858 build(builder, result, name, isConstant, isTarget, type, {}, linkage, attrs); 1859 } 1860 1861 void fir::GlobalOp::build(mlir::OpBuilder &builder, 1862 mlir::OperationState &result, llvm::StringRef name, 1863 mlir::Type type, mlir::StringAttr linkage, 1864 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 1865 build(builder, result, name, /*isConstant=*/false, /*isTarget=*/false, type, 1866 {}, linkage, attrs); 1867 } 1868 1869 void fir::GlobalOp::build(mlir::OpBuilder &builder, 1870 mlir::OperationState &result, llvm::StringRef name, 1871 bool isConstant, bool isTarget, mlir::Type type, 1872 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 1873 build(builder, result, name, isConstant, isTarget, type, mlir::StringAttr{}, 1874 attrs); 1875 } 1876 1877 void fir::GlobalOp::build(mlir::OpBuilder &builder, 1878 mlir::OperationState &result, llvm::StringRef name, 1879 mlir::Type type, 1880 llvm::ArrayRef<mlir::NamedAttribute> attrs) { 1881 build(builder, result, name, /*isConstant=*/false, /*isTarget=*/false, type, 1882 attrs); 1883 } 1884 1885 mlir::ParseResult fir::GlobalOp::verifyValidLinkage(llvm::StringRef linkage) { 1886 // Supporting only a subset of the LLVM linkage types for now 1887 static const char *validNames[] = {"common", "internal", "linkonce", 1888 "linkonce_odr", "weak"}; 1889 return mlir::success(llvm::is_contained(validNames, linkage)); 1890 } 1891 1892 //===----------------------------------------------------------------------===// 1893 // GlobalLenOp 1894 //===----------------------------------------------------------------------===// 1895 1896 mlir::ParseResult fir::GlobalLenOp::parse(mlir::OpAsmParser &parser, 1897 mlir::OperationState &result) { 1898 llvm::StringRef fieldName; 1899 if (failed(parser.parseOptionalKeyword(&fieldName))) { 1900 mlir::StringAttr fieldAttr; 1901 if (parser.parseAttribute(fieldAttr, 1902 fir::GlobalLenOp::getLenParamAttrName(), 1903 result.attributes)) 1904 return mlir::failure(); 1905 } else { 1906 result.addAttribute(fir::GlobalLenOp::getLenParamAttrName(), 1907 parser.getBuilder().getStringAttr(fieldName)); 1908 } 1909 mlir::IntegerAttr constant; 1910 if (parser.parseComma() || 1911 parser.parseAttribute(constant, fir::GlobalLenOp::getIntAttrName(), 1912 result.attributes)) 1913 return mlir::failure(); 1914 return mlir::success(); 1915 } 1916 1917 void fir::GlobalLenOp::print(mlir::OpAsmPrinter &p) { 1918 p << ' ' << getOperation()->getAttr(fir::GlobalLenOp::getLenParamAttrName()) 1919 << ", " << getOperation()->getAttr(fir::GlobalLenOp::getIntAttrName()); 1920 } 1921 1922 //===----------------------------------------------------------------------===// 1923 // FieldIndexOp 1924 //===----------------------------------------------------------------------===// 1925 1926 template <typename TY> 1927 mlir::ParseResult parseFieldLikeOp(mlir::OpAsmParser &parser, 1928 mlir::OperationState &result) { 1929 llvm::StringRef fieldName; 1930 auto &builder = parser.getBuilder(); 1931 mlir::Type recty; 1932 if (parser.parseOptionalKeyword(&fieldName) || parser.parseComma() || 1933 parser.parseType(recty)) 1934 return mlir::failure(); 1935 result.addAttribute(fir::FieldIndexOp::getFieldAttrName(), 1936 builder.getStringAttr(fieldName)); 1937 if (!mlir::dyn_cast<fir::RecordType>(recty)) 1938 return mlir::failure(); 1939 result.addAttribute(fir::FieldIndexOp::getTypeAttrName(), 1940 mlir::TypeAttr::get(recty)); 1941 if (!parser.parseOptionalLParen()) { 1942 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands; 1943 llvm::SmallVector<mlir::Type> types; 1944 auto loc = parser.getNameLoc(); 1945 if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || 1946 parser.parseColonTypeList(types) || parser.parseRParen() || 1947 parser.resolveOperands(operands, types, loc, result.operands)) 1948 return mlir::failure(); 1949 } 1950 mlir::Type fieldType = TY::get(builder.getContext()); 1951 if (parser.addTypeToList(fieldType, result.types)) 1952 return mlir::failure(); 1953 return mlir::success(); 1954 } 1955 1956 mlir::ParseResult fir::FieldIndexOp::parse(mlir::OpAsmParser &parser, 1957 mlir::OperationState &result) { 1958 return parseFieldLikeOp<fir::FieldType>(parser, result); 1959 } 1960 1961 template <typename OP> 1962 void printFieldLikeOp(mlir::OpAsmPrinter &p, OP &op) { 1963 p << ' ' 1964 << op.getOperation() 1965 ->template getAttrOfType<mlir::StringAttr>( 1966 fir::FieldIndexOp::getFieldAttrName()) 1967 .getValue() 1968 << ", " << op.getOperation()->getAttr(fir::FieldIndexOp::getTypeAttrName()); 1969 if (op.getNumOperands()) { 1970 p << '('; 1971 p.printOperands(op.getTypeparams()); 1972 auto sep = ") : "; 1973 for (auto op : op.getTypeparams()) { 1974 p << sep; 1975 if (op) 1976 p.printType(op.getType()); 1977 else 1978 p << "()"; 1979 sep = ", "; 1980 } 1981 } 1982 } 1983 1984 void fir::FieldIndexOp::print(mlir::OpAsmPrinter &p) { 1985 printFieldLikeOp(p, *this); 1986 } 1987 1988 void fir::FieldIndexOp::build(mlir::OpBuilder &builder, 1989 mlir::OperationState &result, 1990 llvm::StringRef fieldName, mlir::Type recTy, 1991 mlir::ValueRange operands) { 1992 result.addAttribute(getFieldAttrName(), builder.getStringAttr(fieldName)); 1993 result.addAttribute(getTypeAttrName(), mlir::TypeAttr::get(recTy)); 1994 result.addOperands(operands); 1995 } 1996 1997 llvm::SmallVector<mlir::Attribute> fir::FieldIndexOp::getAttributes() { 1998 llvm::SmallVector<mlir::Attribute> attrs; 1999 attrs.push_back(getFieldIdAttr()); 2000 attrs.push_back(getOnTypeAttr()); 2001 return attrs; 2002 } 2003 2004 //===----------------------------------------------------------------------===// 2005 // InsertOnRangeOp 2006 //===----------------------------------------------------------------------===// 2007 2008 static mlir::ParseResult 2009 parseCustomRangeSubscript(mlir::OpAsmParser &parser, 2010 mlir::DenseIntElementsAttr &coord) { 2011 llvm::SmallVector<std::int64_t> lbounds; 2012 llvm::SmallVector<std::int64_t> ubounds; 2013 if (parser.parseKeyword("from") || 2014 parser.parseCommaSeparatedList( 2015 mlir::AsmParser::Delimiter::Paren, 2016 [&] { return parser.parseInteger(lbounds.emplace_back(0)); }) || 2017 parser.parseKeyword("to") || 2018 parser.parseCommaSeparatedList(mlir::AsmParser::Delimiter::Paren, [&] { 2019 return parser.parseInteger(ubounds.emplace_back(0)); 2020 })) 2021 return mlir::failure(); 2022 llvm::SmallVector<std::int64_t> zippedBounds; 2023 for (auto zip : llvm::zip(lbounds, ubounds)) { 2024 zippedBounds.push_back(std::get<0>(zip)); 2025 zippedBounds.push_back(std::get<1>(zip)); 2026 } 2027 coord = mlir::Builder(parser.getContext()).getIndexTensorAttr(zippedBounds); 2028 return mlir::success(); 2029 } 2030 2031 static void printCustomRangeSubscript(mlir::OpAsmPrinter &printer, 2032 fir::InsertOnRangeOp op, 2033 mlir::DenseIntElementsAttr coord) { 2034 printer << "from ("; 2035 auto enumerate = llvm::enumerate(coord.getValues<std::int64_t>()); 2036 // Even entries are the lower bounds. 2037 llvm::interleaveComma( 2038 make_filter_range( 2039 enumerate, 2040 [](auto indexed_value) { return indexed_value.index() % 2 == 0; }), 2041 printer, [&](auto indexed_value) { printer << indexed_value.value(); }); 2042 printer << ") to ("; 2043 // Odd entries are the upper bounds. 2044 llvm::interleaveComma( 2045 make_filter_range( 2046 enumerate, 2047 [](auto indexed_value) { return indexed_value.index() % 2 != 0; }), 2048 printer, [&](auto indexed_value) { printer << indexed_value.value(); }); 2049 printer << ")"; 2050 } 2051 2052 /// Range bounds must be nonnegative, and the range must not be empty. 2053 llvm::LogicalResult fir::InsertOnRangeOp::verify() { 2054 if (fir::hasDynamicSize(getSeq().getType())) 2055 return emitOpError("must have constant shape and size"); 2056 mlir::DenseIntElementsAttr coorAttr = getCoor(); 2057 if (coorAttr.size() < 2 || coorAttr.size() % 2 != 0) 2058 return emitOpError("has uneven number of values in ranges"); 2059 bool rangeIsKnownToBeNonempty = false; 2060 for (auto i = coorAttr.getValues<std::int64_t>().end(), 2061 b = coorAttr.getValues<std::int64_t>().begin(); 2062 i != b;) { 2063 int64_t ub = (*--i); 2064 int64_t lb = (*--i); 2065 if (lb < 0 || ub < 0) 2066 return emitOpError("negative range bound"); 2067 if (rangeIsKnownToBeNonempty) 2068 continue; 2069 if (lb > ub) 2070 return emitOpError("empty range"); 2071 rangeIsKnownToBeNonempty = lb < ub; 2072 } 2073 return mlir::success(); 2074 } 2075 2076 //===----------------------------------------------------------------------===// 2077 // InsertValueOp 2078 //===----------------------------------------------------------------------===// 2079 2080 static bool checkIsIntegerConstant(mlir::Attribute attr, std::int64_t conVal) { 2081 if (auto iattr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) 2082 return iattr.getInt() == conVal; 2083 return false; 2084 } 2085 2086 static bool isZero(mlir::Attribute a) { return checkIsIntegerConstant(a, 0); } 2087 static bool isOne(mlir::Attribute a) { return checkIsIntegerConstant(a, 1); } 2088 2089 // Undo some complex patterns created in the front-end and turn them back into 2090 // complex ops. 2091 template <typename FltOp, typename CpxOp> 2092 struct UndoComplexPattern : public mlir::RewritePattern { 2093 UndoComplexPattern(mlir::MLIRContext *ctx) 2094 : mlir::RewritePattern("fir.insert_value", 2, ctx) {} 2095 2096 llvm::LogicalResult 2097 matchAndRewrite(mlir::Operation *op, 2098 mlir::PatternRewriter &rewriter) const override { 2099 auto insval = mlir::dyn_cast_or_null<fir::InsertValueOp>(op); 2100 if (!insval || !mlir::isa<mlir::ComplexType>(insval.getType())) 2101 return mlir::failure(); 2102 auto insval2 = mlir::dyn_cast_or_null<fir::InsertValueOp>( 2103 insval.getAdt().getDefiningOp()); 2104 if (!insval2) 2105 return mlir::failure(); 2106 auto binf = mlir::dyn_cast_or_null<FltOp>(insval.getVal().getDefiningOp()); 2107 auto binf2 = 2108 mlir::dyn_cast_or_null<FltOp>(insval2.getVal().getDefiningOp()); 2109 if (!binf || !binf2 || insval.getCoor().size() != 1 || 2110 !isOne(insval.getCoor()[0]) || insval2.getCoor().size() != 1 || 2111 !isZero(insval2.getCoor()[0])) 2112 return mlir::failure(); 2113 auto eai = mlir::dyn_cast_or_null<fir::ExtractValueOp>( 2114 binf.getLhs().getDefiningOp()); 2115 auto ebi = mlir::dyn_cast_or_null<fir::ExtractValueOp>( 2116 binf.getRhs().getDefiningOp()); 2117 auto ear = mlir::dyn_cast_or_null<fir::ExtractValueOp>( 2118 binf2.getLhs().getDefiningOp()); 2119 auto ebr = mlir::dyn_cast_or_null<fir::ExtractValueOp>( 2120 binf2.getRhs().getDefiningOp()); 2121 if (!eai || !ebi || !ear || !ebr || ear.getAdt() != eai.getAdt() || 2122 ebr.getAdt() != ebi.getAdt() || eai.getCoor().size() != 1 || 2123 !isOne(eai.getCoor()[0]) || ebi.getCoor().size() != 1 || 2124 !isOne(ebi.getCoor()[0]) || ear.getCoor().size() != 1 || 2125 !isZero(ear.getCoor()[0]) || ebr.getCoor().size() != 1 || 2126 !isZero(ebr.getCoor()[0])) 2127 return mlir::failure(); 2128 rewriter.replaceOpWithNewOp<CpxOp>(op, ear.getAdt(), ebr.getAdt()); 2129 return mlir::success(); 2130 } 2131 }; 2132 2133 void fir::InsertValueOp::getCanonicalizationPatterns( 2134 mlir::RewritePatternSet &results, mlir::MLIRContext *context) { 2135 results.insert<UndoComplexPattern<mlir::arith::AddFOp, fir::AddcOp>, 2136 UndoComplexPattern<mlir::arith::SubFOp, fir::SubcOp>>(context); 2137 } 2138 2139 //===----------------------------------------------------------------------===// 2140 // IterWhileOp 2141 //===----------------------------------------------------------------------===// 2142 2143 void fir::IterWhileOp::build(mlir::OpBuilder &builder, 2144 mlir::OperationState &result, mlir::Value lb, 2145 mlir::Value ub, mlir::Value step, 2146 mlir::Value iterate, bool finalCountValue, 2147 mlir::ValueRange iterArgs, 2148 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 2149 result.addOperands({lb, ub, step, iterate}); 2150 if (finalCountValue) { 2151 result.addTypes(builder.getIndexType()); 2152 result.addAttribute(getFinalValueAttrNameStr(), builder.getUnitAttr()); 2153 } 2154 result.addTypes(iterate.getType()); 2155 result.addOperands(iterArgs); 2156 for (auto v : iterArgs) 2157 result.addTypes(v.getType()); 2158 mlir::Region *bodyRegion = result.addRegion(); 2159 bodyRegion->push_back(new mlir::Block{}); 2160 bodyRegion->front().addArgument(builder.getIndexType(), result.location); 2161 bodyRegion->front().addArgument(iterate.getType(), result.location); 2162 bodyRegion->front().addArguments( 2163 iterArgs.getTypes(), 2164 llvm::SmallVector<mlir::Location>(iterArgs.size(), result.location)); 2165 result.addAttributes(attributes); 2166 } 2167 2168 mlir::ParseResult fir::IterWhileOp::parse(mlir::OpAsmParser &parser, 2169 mlir::OperationState &result) { 2170 auto &builder = parser.getBuilder(); 2171 mlir::OpAsmParser::Argument inductionVariable, iterateVar; 2172 mlir::OpAsmParser::UnresolvedOperand lb, ub, step, iterateInput; 2173 if (parser.parseLParen() || parser.parseArgument(inductionVariable) || 2174 parser.parseEqual()) 2175 return mlir::failure(); 2176 2177 // Parse loop bounds. 2178 auto indexType = builder.getIndexType(); 2179 auto i1Type = builder.getIntegerType(1); 2180 if (parser.parseOperand(lb) || 2181 parser.resolveOperand(lb, indexType, result.operands) || 2182 parser.parseKeyword("to") || parser.parseOperand(ub) || 2183 parser.resolveOperand(ub, indexType, result.operands) || 2184 parser.parseKeyword("step") || parser.parseOperand(step) || 2185 parser.parseRParen() || 2186 parser.resolveOperand(step, indexType, result.operands) || 2187 parser.parseKeyword("and") || parser.parseLParen() || 2188 parser.parseArgument(iterateVar) || parser.parseEqual() || 2189 parser.parseOperand(iterateInput) || parser.parseRParen() || 2190 parser.resolveOperand(iterateInput, i1Type, result.operands)) 2191 return mlir::failure(); 2192 2193 // Parse the initial iteration arguments. 2194 auto prependCount = false; 2195 2196 // Induction variable. 2197 llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs; 2198 regionArgs.push_back(inductionVariable); 2199 regionArgs.push_back(iterateVar); 2200 2201 if (succeeded(parser.parseOptionalKeyword("iter_args"))) { 2202 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands; 2203 llvm::SmallVector<mlir::Type> regionTypes; 2204 // Parse assignment list and results type list. 2205 if (parser.parseAssignmentList(regionArgs, operands) || 2206 parser.parseArrowTypeList(regionTypes)) 2207 return mlir::failure(); 2208 if (regionTypes.size() == operands.size() + 2) 2209 prependCount = true; 2210 llvm::ArrayRef<mlir::Type> resTypes = regionTypes; 2211 resTypes = prependCount ? resTypes.drop_front(2) : resTypes; 2212 // Resolve input operands. 2213 for (auto operandType : llvm::zip(operands, resTypes)) 2214 if (parser.resolveOperand(std::get<0>(operandType), 2215 std::get<1>(operandType), result.operands)) 2216 return mlir::failure(); 2217 if (prependCount) { 2218 result.addTypes(regionTypes); 2219 } else { 2220 result.addTypes(i1Type); 2221 result.addTypes(resTypes); 2222 } 2223 } else if (succeeded(parser.parseOptionalArrow())) { 2224 llvm::SmallVector<mlir::Type> typeList; 2225 if (parser.parseLParen() || parser.parseTypeList(typeList) || 2226 parser.parseRParen()) 2227 return mlir::failure(); 2228 // Type list must be "(index, i1)". 2229 if (typeList.size() != 2 || !mlir::isa<mlir::IndexType>(typeList[0]) || 2230 !typeList[1].isSignlessInteger(1)) 2231 return mlir::failure(); 2232 result.addTypes(typeList); 2233 prependCount = true; 2234 } else { 2235 result.addTypes(i1Type); 2236 } 2237 2238 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 2239 return mlir::failure(); 2240 2241 llvm::SmallVector<mlir::Type> argTypes; 2242 // Induction variable (hidden) 2243 if (prependCount) 2244 result.addAttribute(IterWhileOp::getFinalValueAttrNameStr(), 2245 builder.getUnitAttr()); 2246 else 2247 argTypes.push_back(indexType); 2248 // Loop carried variables (including iterate) 2249 argTypes.append(result.types.begin(), result.types.end()); 2250 // Parse the body region. 2251 auto *body = result.addRegion(); 2252 if (regionArgs.size() != argTypes.size()) 2253 return parser.emitError( 2254 parser.getNameLoc(), 2255 "mismatch in number of loop-carried values and defined values"); 2256 2257 for (size_t i = 0, e = regionArgs.size(); i != e; ++i) 2258 regionArgs[i].type = argTypes[i]; 2259 2260 if (parser.parseRegion(*body, regionArgs)) 2261 return mlir::failure(); 2262 2263 fir::IterWhileOp::ensureTerminator(*body, builder, result.location); 2264 return mlir::success(); 2265 } 2266 2267 llvm::LogicalResult fir::IterWhileOp::verify() { 2268 // Check that the body defines as single block argument for the induction 2269 // variable. 2270 auto *body = getBody(); 2271 if (!body->getArgument(1).getType().isInteger(1)) 2272 return emitOpError( 2273 "expected body second argument to be an index argument for " 2274 "the induction variable"); 2275 if (!body->getArgument(0).getType().isIndex()) 2276 return emitOpError( 2277 "expected body first argument to be an index argument for " 2278 "the induction variable"); 2279 2280 auto opNumResults = getNumResults(); 2281 if (getFinalValue()) { 2282 // Result type must be "(index, i1, ...)". 2283 if (!mlir::isa<mlir::IndexType>(getResult(0).getType())) 2284 return emitOpError("result #0 expected to be index"); 2285 if (!getResult(1).getType().isSignlessInteger(1)) 2286 return emitOpError("result #1 expected to be i1"); 2287 opNumResults--; 2288 } else { 2289 // iterate_while always returns the early exit induction value. 2290 // Result type must be "(i1, ...)" 2291 if (!getResult(0).getType().isSignlessInteger(1)) 2292 return emitOpError("result #0 expected to be i1"); 2293 } 2294 if (opNumResults == 0) 2295 return mlir::failure(); 2296 if (getNumIterOperands() != opNumResults) 2297 return emitOpError( 2298 "mismatch in number of loop-carried values and defined values"); 2299 if (getNumRegionIterArgs() != opNumResults) 2300 return emitOpError( 2301 "mismatch in number of basic block args and defined values"); 2302 auto iterOperands = getIterOperands(); 2303 auto iterArgs = getRegionIterArgs(); 2304 auto opResults = getFinalValue() ? getResults().drop_front() : getResults(); 2305 unsigned i = 0u; 2306 for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { 2307 if (std::get<0>(e).getType() != std::get<2>(e).getType()) 2308 return emitOpError() << "types mismatch between " << i 2309 << "th iter operand and defined value"; 2310 if (std::get<1>(e).getType() != std::get<2>(e).getType()) 2311 return emitOpError() << "types mismatch between " << i 2312 << "th iter region arg and defined value"; 2313 2314 i++; 2315 } 2316 return mlir::success(); 2317 } 2318 2319 void fir::IterWhileOp::print(mlir::OpAsmPrinter &p) { 2320 p << " (" << getInductionVar() << " = " << getLowerBound() << " to " 2321 << getUpperBound() << " step " << getStep() << ") and ("; 2322 assert(hasIterOperands()); 2323 auto regionArgs = getRegionIterArgs(); 2324 auto operands = getIterOperands(); 2325 p << regionArgs.front() << " = " << *operands.begin() << ")"; 2326 if (regionArgs.size() > 1) { 2327 p << " iter_args("; 2328 llvm::interleaveComma( 2329 llvm::zip(regionArgs.drop_front(), operands.drop_front()), p, 2330 [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); 2331 p << ") -> ("; 2332 llvm::interleaveComma( 2333 llvm::drop_begin(getResultTypes(), getFinalValue() ? 0 : 1), p); 2334 p << ")"; 2335 } else if (getFinalValue()) { 2336 p << " -> (" << getResultTypes() << ')'; 2337 } 2338 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), 2339 {getFinalValueAttrNameStr()}); 2340 p << ' '; 2341 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 2342 /*printBlockTerminators=*/true); 2343 } 2344 2345 llvm::SmallVector<mlir::Region *> fir::IterWhileOp::getLoopRegions() { 2346 return {&getRegion()}; 2347 } 2348 2349 mlir::BlockArgument fir::IterWhileOp::iterArgToBlockArg(mlir::Value iterArg) { 2350 for (auto i : llvm::enumerate(getInitArgs())) 2351 if (iterArg == i.value()) 2352 return getRegion().front().getArgument(i.index() + 1); 2353 return {}; 2354 } 2355 2356 void fir::IterWhileOp::resultToSourceOps( 2357 llvm::SmallVectorImpl<mlir::Value> &results, unsigned resultNum) { 2358 auto oper = getFinalValue() ? resultNum + 1 : resultNum; 2359 auto *term = getRegion().front().getTerminator(); 2360 if (oper < term->getNumOperands()) 2361 results.push_back(term->getOperand(oper)); 2362 } 2363 2364 mlir::Value fir::IterWhileOp::blockArgToSourceOp(unsigned blockArgNum) { 2365 if (blockArgNum > 0 && blockArgNum <= getInitArgs().size()) 2366 return getInitArgs()[blockArgNum - 1]; 2367 return {}; 2368 } 2369 2370 std::optional<llvm::MutableArrayRef<mlir::OpOperand>> 2371 fir::IterWhileOp::getYieldedValuesMutable() { 2372 auto *term = getRegion().front().getTerminator(); 2373 return getFinalValue() ? term->getOpOperands().drop_front() 2374 : term->getOpOperands(); 2375 } 2376 2377 //===----------------------------------------------------------------------===// 2378 // LenParamIndexOp 2379 //===----------------------------------------------------------------------===// 2380 2381 mlir::ParseResult fir::LenParamIndexOp::parse(mlir::OpAsmParser &parser, 2382 mlir::OperationState &result) { 2383 return parseFieldLikeOp<fir::LenType>(parser, result); 2384 } 2385 2386 void fir::LenParamIndexOp::print(mlir::OpAsmPrinter &p) { 2387 printFieldLikeOp(p, *this); 2388 } 2389 2390 void fir::LenParamIndexOp::build(mlir::OpBuilder &builder, 2391 mlir::OperationState &result, 2392 llvm::StringRef fieldName, mlir::Type recTy, 2393 mlir::ValueRange operands) { 2394 result.addAttribute(getFieldAttrName(), builder.getStringAttr(fieldName)); 2395 result.addAttribute(getTypeAttrName(), mlir::TypeAttr::get(recTy)); 2396 result.addOperands(operands); 2397 } 2398 2399 llvm::SmallVector<mlir::Attribute> fir::LenParamIndexOp::getAttributes() { 2400 llvm::SmallVector<mlir::Attribute> attrs; 2401 attrs.push_back(getFieldIdAttr()); 2402 attrs.push_back(getOnTypeAttr()); 2403 return attrs; 2404 } 2405 2406 //===----------------------------------------------------------------------===// 2407 // LoadOp 2408 //===----------------------------------------------------------------------===// 2409 2410 void fir::LoadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, 2411 mlir::Value refVal) { 2412 if (!refVal) { 2413 mlir::emitError(result.location, "LoadOp has null argument"); 2414 return; 2415 } 2416 auto eleTy = fir::dyn_cast_ptrEleTy(refVal.getType()); 2417 if (!eleTy) { 2418 mlir::emitError(result.location, "not a memory reference type"); 2419 return; 2420 } 2421 build(builder, result, eleTy, refVal); 2422 } 2423 2424 void fir::LoadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, 2425 mlir::Type resTy, mlir::Value refVal) { 2426 2427 if (!refVal) { 2428 mlir::emitError(result.location, "LoadOp has null argument"); 2429 return; 2430 } 2431 result.addOperands(refVal); 2432 result.addTypes(resTy); 2433 } 2434 2435 mlir::ParseResult fir::LoadOp::getElementOf(mlir::Type &ele, mlir::Type ref) { 2436 if ((ele = fir::dyn_cast_ptrEleTy(ref))) 2437 return mlir::success(); 2438 return mlir::failure(); 2439 } 2440 2441 mlir::ParseResult fir::LoadOp::parse(mlir::OpAsmParser &parser, 2442 mlir::OperationState &result) { 2443 mlir::Type type; 2444 mlir::OpAsmParser::UnresolvedOperand oper; 2445 if (parser.parseOperand(oper) || 2446 parser.parseOptionalAttrDict(result.attributes) || 2447 parser.parseColonType(type) || 2448 parser.resolveOperand(oper, type, result.operands)) 2449 return mlir::failure(); 2450 mlir::Type eleTy; 2451 if (fir::LoadOp::getElementOf(eleTy, type) || 2452 parser.addTypeToList(eleTy, result.types)) 2453 return mlir::failure(); 2454 return mlir::success(); 2455 } 2456 2457 void fir::LoadOp::print(mlir::OpAsmPrinter &p) { 2458 p << ' '; 2459 p.printOperand(getMemref()); 2460 p.printOptionalAttrDict(getOperation()->getAttrs(), {}); 2461 p << " : " << getMemref().getType(); 2462 } 2463 2464 //===----------------------------------------------------------------------===// 2465 // DoLoopOp 2466 //===----------------------------------------------------------------------===// 2467 2468 void fir::DoLoopOp::build(mlir::OpBuilder &builder, 2469 mlir::OperationState &result, mlir::Value lb, 2470 mlir::Value ub, mlir::Value step, bool unordered, 2471 bool finalCountValue, mlir::ValueRange iterArgs, 2472 mlir::ValueRange reduceOperands, 2473 llvm::ArrayRef<mlir::Attribute> reduceAttrs, 2474 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 2475 result.addOperands({lb, ub, step}); 2476 result.addOperands(reduceOperands); 2477 result.addOperands(iterArgs); 2478 result.addAttribute(getOperandSegmentSizeAttr(), 2479 builder.getDenseI32ArrayAttr( 2480 {1, 1, 1, static_cast<int32_t>(reduceOperands.size()), 2481 static_cast<int32_t>(iterArgs.size())})); 2482 if (finalCountValue) { 2483 result.addTypes(builder.getIndexType()); 2484 result.addAttribute(getFinalValueAttrName(result.name), 2485 builder.getUnitAttr()); 2486 } 2487 for (auto v : iterArgs) 2488 result.addTypes(v.getType()); 2489 mlir::Region *bodyRegion = result.addRegion(); 2490 bodyRegion->push_back(new mlir::Block{}); 2491 if (iterArgs.empty() && !finalCountValue) 2492 fir::DoLoopOp::ensureTerminator(*bodyRegion, builder, result.location); 2493 bodyRegion->front().addArgument(builder.getIndexType(), result.location); 2494 bodyRegion->front().addArguments( 2495 iterArgs.getTypes(), 2496 llvm::SmallVector<mlir::Location>(iterArgs.size(), result.location)); 2497 if (unordered) 2498 result.addAttribute(getUnorderedAttrName(result.name), 2499 builder.getUnitAttr()); 2500 if (!reduceAttrs.empty()) 2501 result.addAttribute(getReduceAttrsAttrName(result.name), 2502 builder.getArrayAttr(reduceAttrs)); 2503 result.addAttributes(attributes); 2504 } 2505 2506 mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser, 2507 mlir::OperationState &result) { 2508 auto &builder = parser.getBuilder(); 2509 mlir::OpAsmParser::Argument inductionVariable; 2510 mlir::OpAsmParser::UnresolvedOperand lb, ub, step; 2511 // Parse the induction variable followed by '='. 2512 if (parser.parseArgument(inductionVariable) || parser.parseEqual()) 2513 return mlir::failure(); 2514 2515 // Parse loop bounds. 2516 auto indexType = builder.getIndexType(); 2517 if (parser.parseOperand(lb) || 2518 parser.resolveOperand(lb, indexType, result.operands) || 2519 parser.parseKeyword("to") || parser.parseOperand(ub) || 2520 parser.resolveOperand(ub, indexType, result.operands) || 2521 parser.parseKeyword("step") || parser.parseOperand(step) || 2522 parser.resolveOperand(step, indexType, result.operands)) 2523 return mlir::failure(); 2524 2525 if (mlir::succeeded(parser.parseOptionalKeyword("unordered"))) 2526 result.addAttribute("unordered", builder.getUnitAttr()); 2527 2528 // Parse the reduction arguments. 2529 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> reduceOperands; 2530 llvm::SmallVector<mlir::Type> reduceArgTypes; 2531 if (succeeded(parser.parseOptionalKeyword("reduce"))) { 2532 // Parse reduction attributes and variables. 2533 llvm::SmallVector<ReduceAttr> attributes; 2534 if (failed(parser.parseCommaSeparatedList( 2535 mlir::AsmParser::Delimiter::Paren, [&]() { 2536 if (parser.parseAttribute(attributes.emplace_back()) || 2537 parser.parseArrow() || 2538 parser.parseOperand(reduceOperands.emplace_back()) || 2539 parser.parseColonType(reduceArgTypes.emplace_back())) 2540 return mlir::failure(); 2541 return mlir::success(); 2542 }))) 2543 return mlir::failure(); 2544 // Resolve input operands. 2545 for (auto operand_type : llvm::zip(reduceOperands, reduceArgTypes)) 2546 if (parser.resolveOperand(std::get<0>(operand_type), 2547 std::get<1>(operand_type), result.operands)) 2548 return mlir::failure(); 2549 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(), 2550 attributes.end()); 2551 result.addAttribute(getReduceAttrsAttrName(result.name), 2552 builder.getArrayAttr(arrayAttr)); 2553 } 2554 2555 // Parse the optional initial iteration arguments. 2556 llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs; 2557 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands; 2558 llvm::SmallVector<mlir::Type> argTypes; 2559 bool prependCount = false; 2560 regionArgs.push_back(inductionVariable); 2561 2562 if (succeeded(parser.parseOptionalKeyword("iter_args"))) { 2563 // Parse assignment list and results type list. 2564 if (parser.parseAssignmentList(regionArgs, iterOperands) || 2565 parser.parseArrowTypeList(result.types)) 2566 return mlir::failure(); 2567 if (result.types.size() == iterOperands.size() + 1) 2568 prependCount = true; 2569 // Resolve input operands. 2570 llvm::ArrayRef<mlir::Type> resTypes = result.types; 2571 for (auto operand_type : llvm::zip( 2572 iterOperands, prependCount ? resTypes.drop_front() : resTypes)) 2573 if (parser.resolveOperand(std::get<0>(operand_type), 2574 std::get<1>(operand_type), result.operands)) 2575 return mlir::failure(); 2576 } else if (succeeded(parser.parseOptionalArrow())) { 2577 if (parser.parseKeyword("index")) 2578 return mlir::failure(); 2579 result.types.push_back(indexType); 2580 prependCount = true; 2581 } 2582 2583 // Set the operandSegmentSizes attribute 2584 result.addAttribute(getOperandSegmentSizeAttr(), 2585 builder.getDenseI32ArrayAttr( 2586 {1, 1, 1, static_cast<int32_t>(reduceOperands.size()), 2587 static_cast<int32_t>(iterOperands.size())})); 2588 2589 if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) 2590 return mlir::failure(); 2591 2592 // Induction variable. 2593 if (prependCount) 2594 result.addAttribute(DoLoopOp::getFinalValueAttrName(result.name), 2595 builder.getUnitAttr()); 2596 else 2597 argTypes.push_back(indexType); 2598 // Loop carried variables 2599 argTypes.append(result.types.begin(), result.types.end()); 2600 // Parse the body region. 2601 auto *body = result.addRegion(); 2602 if (regionArgs.size() != argTypes.size()) 2603 return parser.emitError( 2604 parser.getNameLoc(), 2605 "mismatch in number of loop-carried values and defined values"); 2606 for (size_t i = 0, e = regionArgs.size(); i != e; ++i) 2607 regionArgs[i].type = argTypes[i]; 2608 2609 if (parser.parseRegion(*body, regionArgs)) 2610 return mlir::failure(); 2611 2612 DoLoopOp::ensureTerminator(*body, builder, result.location); 2613 2614 return mlir::success(); 2615 } 2616 2617 fir::DoLoopOp fir::getForInductionVarOwner(mlir::Value val) { 2618 auto ivArg = mlir::dyn_cast<mlir::BlockArgument>(val); 2619 if (!ivArg) 2620 return {}; 2621 assert(ivArg.getOwner() && "unlinked block argument"); 2622 auto *containingInst = ivArg.getOwner()->getParentOp(); 2623 return mlir::dyn_cast_or_null<fir::DoLoopOp>(containingInst); 2624 } 2625 2626 // Lifted from loop.loop 2627 llvm::LogicalResult fir::DoLoopOp::verify() { 2628 // Check that the body defines as single block argument for the induction 2629 // variable. 2630 auto *body = getBody(); 2631 if (!body->getArgument(0).getType().isIndex()) 2632 return emitOpError( 2633 "expected body first argument to be an index argument for " 2634 "the induction variable"); 2635 2636 auto opNumResults = getNumResults(); 2637 if (opNumResults == 0) 2638 return mlir::success(); 2639 2640 if (getFinalValue()) { 2641 if (getUnordered()) 2642 return emitOpError("unordered loop has no final value"); 2643 opNumResults--; 2644 } 2645 if (getNumIterOperands() != opNumResults) 2646 return emitOpError( 2647 "mismatch in number of loop-carried values and defined values"); 2648 if (getNumRegionIterArgs() != opNumResults) 2649 return emitOpError( 2650 "mismatch in number of basic block args and defined values"); 2651 auto iterOperands = getIterOperands(); 2652 auto iterArgs = getRegionIterArgs(); 2653 auto opResults = getFinalValue() ? getResults().drop_front() : getResults(); 2654 unsigned i = 0u; 2655 for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) { 2656 if (std::get<0>(e).getType() != std::get<2>(e).getType()) 2657 return emitOpError() << "types mismatch between " << i 2658 << "th iter operand and defined value"; 2659 if (std::get<1>(e).getType() != std::get<2>(e).getType()) 2660 return emitOpError() << "types mismatch between " << i 2661 << "th iter region arg and defined value"; 2662 2663 i++; 2664 } 2665 auto reduceAttrs = getReduceAttrsAttr(); 2666 if (getNumReduceOperands() != (reduceAttrs ? reduceAttrs.size() : 0)) 2667 return emitOpError( 2668 "mismatch in number of reduction variables and reduction attributes"); 2669 return mlir::success(); 2670 } 2671 2672 void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) { 2673 bool printBlockTerminators = false; 2674 p << ' ' << getInductionVar() << " = " << getLowerBound() << " to " 2675 << getUpperBound() << " step " << getStep(); 2676 if (getUnordered()) 2677 p << " unordered"; 2678 if (hasReduceOperands()) { 2679 p << " reduce("; 2680 auto attrs = getReduceAttrsAttr(); 2681 auto operands = getReduceOperands(); 2682 llvm::interleaveComma(llvm::zip(attrs, operands), p, [&](auto it) { 2683 p << std::get<0>(it) << " -> " << std::get<1>(it) << " : " 2684 << std::get<1>(it).getType(); 2685 }); 2686 p << ')'; 2687 printBlockTerminators = true; 2688 } 2689 if (hasIterOperands()) { 2690 p << " iter_args("; 2691 auto regionArgs = getRegionIterArgs(); 2692 auto operands = getIterOperands(); 2693 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { 2694 p << std::get<0>(it) << " = " << std::get<1>(it); 2695 }); 2696 p << ") -> (" << getResultTypes() << ')'; 2697 printBlockTerminators = true; 2698 } else if (getFinalValue()) { 2699 p << " -> " << getResultTypes(); 2700 printBlockTerminators = true; 2701 } 2702 p.printOptionalAttrDictWithKeyword( 2703 (*this)->getAttrs(), 2704 {"unordered", "finalValue", "reduceAttrs", "operandSegmentSizes"}); 2705 p << ' '; 2706 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, 2707 printBlockTerminators); 2708 } 2709 2710 llvm::SmallVector<mlir::Region *> fir::DoLoopOp::getLoopRegions() { 2711 return {&getRegion()}; 2712 } 2713 2714 /// Translate a value passed as an iter_arg to the corresponding block 2715 /// argument in the body of the loop. 2716 mlir::BlockArgument fir::DoLoopOp::iterArgToBlockArg(mlir::Value iterArg) { 2717 for (auto i : llvm::enumerate(getInitArgs())) 2718 if (iterArg == i.value()) 2719 return getRegion().front().getArgument(i.index() + 1); 2720 return {}; 2721 } 2722 2723 /// Translate the result vector (by index number) to the corresponding value 2724 /// to the `fir.result` Op. 2725 void fir::DoLoopOp::resultToSourceOps( 2726 llvm::SmallVectorImpl<mlir::Value> &results, unsigned resultNum) { 2727 auto oper = getFinalValue() ? resultNum + 1 : resultNum; 2728 auto *term = getRegion().front().getTerminator(); 2729 if (oper < term->getNumOperands()) 2730 results.push_back(term->getOperand(oper)); 2731 } 2732 2733 /// Translate the block argument (by index number) to the corresponding value 2734 /// passed as an iter_arg to the parent DoLoopOp. 2735 mlir::Value fir::DoLoopOp::blockArgToSourceOp(unsigned blockArgNum) { 2736 if (blockArgNum > 0 && blockArgNum <= getInitArgs().size()) 2737 return getInitArgs()[blockArgNum - 1]; 2738 return {}; 2739 } 2740 2741 std::optional<llvm::MutableArrayRef<mlir::OpOperand>> 2742 fir::DoLoopOp::getYieldedValuesMutable() { 2743 auto *term = getRegion().front().getTerminator(); 2744 return getFinalValue() ? term->getOpOperands().drop_front() 2745 : term->getOpOperands(); 2746 } 2747 2748 //===----------------------------------------------------------------------===// 2749 // DTEntryOp 2750 //===----------------------------------------------------------------------===// 2751 2752 mlir::ParseResult fir::DTEntryOp::parse(mlir::OpAsmParser &parser, 2753 mlir::OperationState &result) { 2754 llvm::StringRef methodName; 2755 // allow `methodName` or `"methodName"` 2756 if (failed(parser.parseOptionalKeyword(&methodName))) { 2757 mlir::StringAttr methodAttr; 2758 if (parser.parseAttribute(methodAttr, getMethodAttrName(result.name), 2759 result.attributes)) 2760 return mlir::failure(); 2761 } else { 2762 result.addAttribute(getMethodAttrName(result.name), 2763 parser.getBuilder().getStringAttr(methodName)); 2764 } 2765 mlir::SymbolRefAttr calleeAttr; 2766 if (parser.parseComma() || 2767 parser.parseAttribute(calleeAttr, fir::DTEntryOp::getProcAttrNameStr(), 2768 result.attributes)) 2769 return mlir::failure(); 2770 return mlir::success(); 2771 } 2772 2773 void fir::DTEntryOp::print(mlir::OpAsmPrinter &p) { 2774 p << ' ' << getMethodAttr() << ", " << getProcAttr(); 2775 } 2776 2777 //===----------------------------------------------------------------------===// 2778 // ReboxOp 2779 //===----------------------------------------------------------------------===// 2780 2781 /// Get the scalar type related to a fir.box type. 2782 /// Example: return f32 for !fir.box<!fir.heap<!fir.array<?x?xf32>>. 2783 static mlir::Type getBoxScalarEleTy(mlir::Type boxTy) { 2784 auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(boxTy); 2785 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy)) 2786 return seqTy.getEleTy(); 2787 return eleTy; 2788 } 2789 2790 /// Test if \p t1 and \p t2 are compatible character types (if they can 2791 /// represent the same type at runtime). 2792 static bool areCompatibleCharacterTypes(mlir::Type t1, mlir::Type t2) { 2793 auto c1 = mlir::dyn_cast<fir::CharacterType>(t1); 2794 auto c2 = mlir::dyn_cast<fir::CharacterType>(t2); 2795 if (!c1 || !c2) 2796 return false; 2797 if (c1.hasDynamicLen() || c2.hasDynamicLen()) 2798 return true; 2799 return c1.getLen() == c2.getLen(); 2800 } 2801 2802 llvm::LogicalResult fir::ReboxOp::verify() { 2803 auto inputBoxTy = getBox().getType(); 2804 if (fir::isa_unknown_size_box(inputBoxTy)) 2805 return emitOpError("box operand must not have unknown rank or type"); 2806 auto outBoxTy = getType(); 2807 if (fir::isa_unknown_size_box(outBoxTy)) 2808 return emitOpError("result type must not have unknown rank or type"); 2809 auto inputRank = fir::getBoxRank(inputBoxTy); 2810 auto inputEleTy = getBoxScalarEleTy(inputBoxTy); 2811 auto outRank = fir::getBoxRank(outBoxTy); 2812 auto outEleTy = getBoxScalarEleTy(outBoxTy); 2813 2814 if (auto sliceVal = getSlice()) { 2815 // Slicing case 2816 if (mlir::cast<fir::SliceType>(sliceVal.getType()).getRank() != inputRank) 2817 return emitOpError("slice operand rank must match box operand rank"); 2818 if (auto shapeVal = getShape()) { 2819 if (auto shiftTy = mlir::dyn_cast<fir::ShiftType>(shapeVal.getType())) { 2820 if (shiftTy.getRank() != inputRank) 2821 return emitOpError("shape operand and input box ranks must match " 2822 "when there is a slice"); 2823 } else { 2824 return emitOpError("shape operand must absent or be a fir.shift " 2825 "when there is a slice"); 2826 } 2827 } 2828 if (auto sliceOp = sliceVal.getDefiningOp()) { 2829 auto slicedRank = mlir::cast<fir::SliceOp>(sliceOp).getOutRank(); 2830 if (slicedRank != outRank) 2831 return emitOpError("result type rank and rank after applying slice " 2832 "operand must match"); 2833 } 2834 } else { 2835 // Reshaping case 2836 unsigned shapeRank = inputRank; 2837 if (auto shapeVal = getShape()) { 2838 auto ty = shapeVal.getType(); 2839 if (auto shapeTy = mlir::dyn_cast<fir::ShapeType>(ty)) { 2840 shapeRank = shapeTy.getRank(); 2841 } else if (auto shapeShiftTy = mlir::dyn_cast<fir::ShapeShiftType>(ty)) { 2842 shapeRank = shapeShiftTy.getRank(); 2843 } else { 2844 auto shiftTy = mlir::cast<fir::ShiftType>(ty); 2845 shapeRank = shiftTy.getRank(); 2846 if (shapeRank != inputRank) 2847 return emitOpError("shape operand and input box ranks must match " 2848 "when the shape is a fir.shift"); 2849 } 2850 } 2851 if (shapeRank != outRank) 2852 return emitOpError("result type and shape operand ranks must match"); 2853 } 2854 2855 if (inputEleTy != outEleTy) { 2856 // TODO: check that outBoxTy is a parent type of inputBoxTy for derived 2857 // types. 2858 // Character input and output types with constant length may be different if 2859 // there is a substring in the slice, otherwise, they must match. If any of 2860 // the types is a character with dynamic length, the other type can be any 2861 // character type. 2862 const bool typeCanMismatch = 2863 mlir::isa<fir::RecordType>(inputEleTy) || 2864 mlir::isa<mlir::NoneType>(outEleTy) || 2865 (mlir::isa<mlir::NoneType>(inputEleTy) && 2866 mlir::isa<fir::RecordType>(outEleTy)) || 2867 (getSlice() && mlir::isa<fir::CharacterType>(inputEleTy)) || 2868 (getSlice() && fir::isa_complex(inputEleTy) && 2869 mlir::isa<mlir::FloatType>(outEleTy)) || 2870 areCompatibleCharacterTypes(inputEleTy, outEleTy); 2871 if (!typeCanMismatch) 2872 return emitOpError( 2873 "op input and output element types must match for intrinsic types"); 2874 } 2875 return mlir::success(); 2876 } 2877 2878 //===----------------------------------------------------------------------===// 2879 // ReboxAssumedRankOp 2880 //===----------------------------------------------------------------------===// 2881 2882 static bool areCompatibleAssumedRankElementType(mlir::Type inputEleTy, 2883 mlir::Type outEleTy) { 2884 if (inputEleTy == outEleTy) 2885 return true; 2886 // Output is unlimited polymorphic -> output dynamic type is the same as input 2887 // type. 2888 if (mlir::isa<mlir::NoneType>(outEleTy)) 2889 return true; 2890 // Output/Input are derived types. Assuming input extends output type, output 2891 // dynamic type is the output static type, unless output is polymorphic. 2892 if (mlir::isa<fir::RecordType>(inputEleTy) && 2893 mlir::isa<fir::RecordType>(outEleTy)) 2894 return true; 2895 if (areCompatibleCharacterTypes(inputEleTy, outEleTy)) 2896 return true; 2897 return false; 2898 } 2899 2900 llvm::LogicalResult fir::ReboxAssumedRankOp::verify() { 2901 mlir::Type inputType = getBox().getType(); 2902 if (!mlir::isa<fir::BaseBoxType>(inputType) && !fir::isBoxAddress(inputType)) 2903 return emitOpError("input must be a box or box address"); 2904 mlir::Type inputEleTy = 2905 mlir::cast<fir::BaseBoxType>(fir::unwrapRefType(inputType)) 2906 .unwrapInnerType(); 2907 mlir::Type outEleTy = 2908 mlir::cast<fir::BaseBoxType>(getType()).unwrapInnerType(); 2909 if (!areCompatibleAssumedRankElementType(inputEleTy, outEleTy)) 2910 return emitOpError("input and output element types are incompatible"); 2911 return mlir::success(); 2912 } 2913 2914 void fir::ReboxAssumedRankOp::getEffects( 2915 llvm::SmallVectorImpl< 2916 mlir::SideEffects::EffectInstance<mlir::MemoryEffects::Effect>> 2917 &effects) { 2918 mlir::OpOperand &inputBox = getBoxMutable(); 2919 if (fir::isBoxAddress(inputBox.get().getType())) 2920 effects.emplace_back(mlir::MemoryEffects::Read::get(), &inputBox, 2921 mlir::SideEffects::DefaultResource::get()); 2922 } 2923 2924 //===----------------------------------------------------------------------===// 2925 // ResultOp 2926 //===----------------------------------------------------------------------===// 2927 2928 llvm::LogicalResult fir::ResultOp::verify() { 2929 auto *parentOp = (*this)->getParentOp(); 2930 auto results = parentOp->getResults(); 2931 auto operands = (*this)->getOperands(); 2932 2933 if (parentOp->getNumResults() != getNumOperands()) 2934 return emitOpError() << "parent of result must have same arity"; 2935 for (auto e : llvm::zip(results, operands)) 2936 if (std::get<0>(e).getType() != std::get<1>(e).getType()) 2937 return emitOpError() << "types mismatch between result op and its parent"; 2938 return mlir::success(); 2939 } 2940 2941 //===----------------------------------------------------------------------===// 2942 // SaveResultOp 2943 //===----------------------------------------------------------------------===// 2944 2945 llvm::LogicalResult fir::SaveResultOp::verify() { 2946 auto resultType = getValue().getType(); 2947 if (resultType != fir::dyn_cast_ptrEleTy(getMemref().getType())) 2948 return emitOpError("value type must match memory reference type"); 2949 if (fir::isa_unknown_size_box(resultType)) 2950 return emitOpError("cannot save !fir.box of unknown rank or type"); 2951 2952 if (mlir::isa<fir::BoxType>(resultType)) { 2953 if (getShape() || !getTypeparams().empty()) 2954 return emitOpError( 2955 "must not have shape or length operands if the value is a fir.box"); 2956 return mlir::success(); 2957 } 2958 2959 // fir.record or fir.array case. 2960 unsigned shapeTyRank = 0; 2961 if (auto shapeVal = getShape()) { 2962 auto shapeTy = shapeVal.getType(); 2963 if (auto s = mlir::dyn_cast<fir::ShapeType>(shapeTy)) 2964 shapeTyRank = s.getRank(); 2965 else 2966 shapeTyRank = mlir::cast<fir::ShapeShiftType>(shapeTy).getRank(); 2967 } 2968 2969 auto eleTy = resultType; 2970 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(resultType)) { 2971 if (seqTy.getDimension() != shapeTyRank) 2972 emitOpError("shape operand must be provided and have the value rank " 2973 "when the value is a fir.array"); 2974 eleTy = seqTy.getEleTy(); 2975 } else { 2976 if (shapeTyRank != 0) 2977 emitOpError( 2978 "shape operand should only be provided if the value is a fir.array"); 2979 } 2980 2981 if (auto recTy = mlir::dyn_cast<fir::RecordType>(eleTy)) { 2982 if (recTy.getNumLenParams() != getTypeparams().size()) 2983 emitOpError("length parameters number must match with the value type " 2984 "length parameters"); 2985 } else if (auto charTy = mlir::dyn_cast<fir::CharacterType>(eleTy)) { 2986 if (getTypeparams().size() > 1) 2987 emitOpError("no more than one length parameter must be provided for " 2988 "character value"); 2989 } else { 2990 if (!getTypeparams().empty()) 2991 emitOpError("length parameters must not be provided for this value type"); 2992 } 2993 2994 return mlir::success(); 2995 } 2996 2997 //===----------------------------------------------------------------------===// 2998 // IntegralSwitchTerminator 2999 //===----------------------------------------------------------------------===// 3000 static constexpr llvm::StringRef getCompareOffsetAttr() { 3001 return "compare_operand_offsets"; 3002 } 3003 3004 static constexpr llvm::StringRef getTargetOffsetAttr() { 3005 return "target_operand_offsets"; 3006 } 3007 3008 template <typename OpT> 3009 static llvm::LogicalResult verifyIntegralSwitchTerminator(OpT op) { 3010 if (!mlir::isa<mlir::IntegerType, mlir::IndexType, fir::IntegerType>( 3011 op.getSelector().getType())) 3012 return op.emitOpError("must be an integer"); 3013 auto cases = 3014 op->template getAttrOfType<mlir::ArrayAttr>(op.getCasesAttr()).getValue(); 3015 auto count = op.getNumDest(); 3016 if (count == 0) 3017 return op.emitOpError("must have at least one successor"); 3018 if (op.getNumConditions() != count) 3019 return op.emitOpError("number of cases and targets don't match"); 3020 if (op.targetOffsetSize() != count) 3021 return op.emitOpError("incorrect number of successor operand groups"); 3022 for (decltype(count) i = 0; i != count; ++i) { 3023 if (!mlir::isa<mlir::IntegerAttr, mlir::UnitAttr>(cases[i])) 3024 return op.emitOpError("invalid case alternative"); 3025 } 3026 return mlir::success(); 3027 } 3028 3029 static mlir::ParseResult parseIntegralSwitchTerminator( 3030 mlir::OpAsmParser &parser, mlir::OperationState &result, 3031 llvm::StringRef casesAttr, llvm::StringRef operandSegmentAttr) { 3032 mlir::OpAsmParser::UnresolvedOperand selector; 3033 mlir::Type type; 3034 if (fir::parseSelector(parser, result, selector, type)) 3035 return mlir::failure(); 3036 3037 llvm::SmallVector<mlir::Attribute> ivalues; 3038 llvm::SmallVector<mlir::Block *> dests; 3039 llvm::SmallVector<llvm::SmallVector<mlir::Value>> destArgs; 3040 while (true) { 3041 mlir::Attribute ivalue; // Integer or Unit 3042 mlir::Block *dest; 3043 llvm::SmallVector<mlir::Value> destArg; 3044 mlir::NamedAttrList temp; 3045 if (parser.parseAttribute(ivalue, "i", temp) || parser.parseComma() || 3046 parser.parseSuccessorAndUseList(dest, destArg)) 3047 return mlir::failure(); 3048 ivalues.push_back(ivalue); 3049 dests.push_back(dest); 3050 destArgs.push_back(destArg); 3051 if (!parser.parseOptionalRSquare()) 3052 break; 3053 if (parser.parseComma()) 3054 return mlir::failure(); 3055 } 3056 auto &bld = parser.getBuilder(); 3057 result.addAttribute(casesAttr, bld.getArrayAttr(ivalues)); 3058 llvm::SmallVector<int32_t> argOffs; 3059 int32_t sumArgs = 0; 3060 const auto count = dests.size(); 3061 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 3062 result.addSuccessors(dests[i]); 3063 result.addOperands(destArgs[i]); 3064 auto argSize = destArgs[i].size(); 3065 argOffs.push_back(argSize); 3066 sumArgs += argSize; 3067 } 3068 result.addAttribute(operandSegmentAttr, 3069 bld.getDenseI32ArrayAttr({1, 0, sumArgs})); 3070 result.addAttribute(getTargetOffsetAttr(), bld.getDenseI32ArrayAttr(argOffs)); 3071 return mlir::success(); 3072 } 3073 3074 template <typename OpT> 3075 static void printIntegralSwitchTerminator(OpT op, mlir::OpAsmPrinter &p) { 3076 p << ' '; 3077 p.printOperand(op.getSelector()); 3078 p << " : " << op.getSelector().getType() << " ["; 3079 auto cases = 3080 op->template getAttrOfType<mlir::ArrayAttr>(op.getCasesAttr()).getValue(); 3081 auto count = op.getNumConditions(); 3082 for (decltype(count) i = 0; i != count; ++i) { 3083 if (i) 3084 p << ", "; 3085 auto &attr = cases[i]; 3086 if (auto intAttr = mlir::dyn_cast_or_null<mlir::IntegerAttr>(attr)) 3087 p << intAttr.getValue(); 3088 else 3089 p.printAttribute(attr); 3090 p << ", "; 3091 op.printSuccessorAtIndex(p, i); 3092 } 3093 p << ']'; 3094 p.printOptionalAttrDict( 3095 op->getAttrs(), {op.getCasesAttr(), getCompareOffsetAttr(), 3096 getTargetOffsetAttr(), op.getOperandSegmentSizeAttr()}); 3097 } 3098 3099 //===----------------------------------------------------------------------===// 3100 // SelectOp 3101 //===----------------------------------------------------------------------===// 3102 3103 llvm::LogicalResult fir::SelectOp::verify() { 3104 return verifyIntegralSwitchTerminator(*this); 3105 } 3106 3107 mlir::ParseResult fir::SelectOp::parse(mlir::OpAsmParser &parser, 3108 mlir::OperationState &result) { 3109 return parseIntegralSwitchTerminator(parser, result, getCasesAttr(), 3110 getOperandSegmentSizeAttr()); 3111 } 3112 3113 void fir::SelectOp::print(mlir::OpAsmPrinter &p) { 3114 printIntegralSwitchTerminator(*this, p); 3115 } 3116 3117 template <typename A, typename... AdditionalArgs> 3118 static A getSubOperands(unsigned pos, A allArgs, mlir::DenseI32ArrayAttr ranges, 3119 AdditionalArgs &&...additionalArgs) { 3120 unsigned start = 0; 3121 for (unsigned i = 0; i < pos; ++i) 3122 start += ranges[i]; 3123 return allArgs.slice(start, ranges[pos], 3124 std::forward<AdditionalArgs>(additionalArgs)...); 3125 } 3126 3127 static mlir::MutableOperandRange 3128 getMutableSuccessorOperands(unsigned pos, mlir::MutableOperandRange operands, 3129 llvm::StringRef offsetAttr) { 3130 mlir::Operation *owner = operands.getOwner(); 3131 mlir::NamedAttribute targetOffsetAttr = 3132 *owner->getAttrDictionary().getNamed(offsetAttr); 3133 return getSubOperands( 3134 pos, operands, 3135 mlir::cast<mlir::DenseI32ArrayAttr>(targetOffsetAttr.getValue()), 3136 mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr)); 3137 } 3138 3139 std::optional<mlir::OperandRange> fir::SelectOp::getCompareOperands(unsigned) { 3140 return {}; 3141 } 3142 3143 std::optional<llvm::ArrayRef<mlir::Value>> 3144 fir::SelectOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 3145 return {}; 3146 } 3147 3148 mlir::SuccessorOperands fir::SelectOp::getSuccessorOperands(unsigned oper) { 3149 return mlir::SuccessorOperands(::getMutableSuccessorOperands( 3150 oper, getTargetArgsMutable(), getTargetOffsetAttr())); 3151 } 3152 3153 std::optional<llvm::ArrayRef<mlir::Value>> 3154 fir::SelectOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 3155 unsigned oper) { 3156 auto a = 3157 (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); 3158 auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( 3159 getOperandSegmentSizeAttr()); 3160 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 3161 } 3162 3163 std::optional<mlir::ValueRange> 3164 fir::SelectOp::getSuccessorOperands(mlir::ValueRange operands, unsigned oper) { 3165 auto a = 3166 (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); 3167 auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( 3168 getOperandSegmentSizeAttr()); 3169 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 3170 } 3171 3172 unsigned fir::SelectOp::targetOffsetSize() { 3173 return (*this) 3174 ->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()) 3175 .size(); 3176 } 3177 3178 //===----------------------------------------------------------------------===// 3179 // SelectCaseOp 3180 //===----------------------------------------------------------------------===// 3181 3182 std::optional<mlir::OperandRange> 3183 fir::SelectCaseOp::getCompareOperands(unsigned cond) { 3184 auto a = 3185 (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getCompareOffsetAttr()); 3186 return {getSubOperands(cond, getCompareArgs(), a)}; 3187 } 3188 3189 std::optional<llvm::ArrayRef<mlir::Value>> 3190 fir::SelectCaseOp::getCompareOperands(llvm::ArrayRef<mlir::Value> operands, 3191 unsigned cond) { 3192 auto a = 3193 (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getCompareOffsetAttr()); 3194 auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( 3195 getOperandSegmentSizeAttr()); 3196 return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; 3197 } 3198 3199 std::optional<mlir::ValueRange> 3200 fir::SelectCaseOp::getCompareOperands(mlir::ValueRange operands, 3201 unsigned cond) { 3202 auto a = 3203 (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getCompareOffsetAttr()); 3204 auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( 3205 getOperandSegmentSizeAttr()); 3206 return {getSubOperands(cond, getSubOperands(1, operands, segments), a)}; 3207 } 3208 3209 mlir::SuccessorOperands fir::SelectCaseOp::getSuccessorOperands(unsigned oper) { 3210 return mlir::SuccessorOperands(::getMutableSuccessorOperands( 3211 oper, getTargetArgsMutable(), getTargetOffsetAttr())); 3212 } 3213 3214 std::optional<llvm::ArrayRef<mlir::Value>> 3215 fir::SelectCaseOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 3216 unsigned oper) { 3217 auto a = 3218 (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); 3219 auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( 3220 getOperandSegmentSizeAttr()); 3221 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 3222 } 3223 3224 std::optional<mlir::ValueRange> 3225 fir::SelectCaseOp::getSuccessorOperands(mlir::ValueRange operands, 3226 unsigned oper) { 3227 auto a = 3228 (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); 3229 auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( 3230 getOperandSegmentSizeAttr()); 3231 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 3232 } 3233 3234 // parser for fir.select_case Op 3235 mlir::ParseResult fir::SelectCaseOp::parse(mlir::OpAsmParser &parser, 3236 mlir::OperationState &result) { 3237 mlir::OpAsmParser::UnresolvedOperand selector; 3238 mlir::Type type; 3239 if (fir::parseSelector(parser, result, selector, type)) 3240 return mlir::failure(); 3241 3242 llvm::SmallVector<mlir::Attribute> attrs; 3243 llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> opers; 3244 llvm::SmallVector<mlir::Block *> dests; 3245 llvm::SmallVector<llvm::SmallVector<mlir::Value>> destArgs; 3246 llvm::SmallVector<std::int32_t> argOffs; 3247 std::int32_t offSize = 0; 3248 while (true) { 3249 mlir::Attribute attr; 3250 mlir::Block *dest; 3251 llvm::SmallVector<mlir::Value> destArg; 3252 mlir::NamedAttrList temp; 3253 if (parser.parseAttribute(attr, "a", temp) || isValidCaseAttr(attr) || 3254 parser.parseComma()) 3255 return mlir::failure(); 3256 attrs.push_back(attr); 3257 if (mlir::dyn_cast_or_null<mlir::UnitAttr>(attr)) { 3258 argOffs.push_back(0); 3259 } else if (mlir::dyn_cast_or_null<fir::ClosedIntervalAttr>(attr)) { 3260 mlir::OpAsmParser::UnresolvedOperand oper1; 3261 mlir::OpAsmParser::UnresolvedOperand oper2; 3262 if (parser.parseOperand(oper1) || parser.parseComma() || 3263 parser.parseOperand(oper2) || parser.parseComma()) 3264 return mlir::failure(); 3265 opers.push_back(oper1); 3266 opers.push_back(oper2); 3267 argOffs.push_back(2); 3268 offSize += 2; 3269 } else { 3270 mlir::OpAsmParser::UnresolvedOperand oper; 3271 if (parser.parseOperand(oper) || parser.parseComma()) 3272 return mlir::failure(); 3273 opers.push_back(oper); 3274 argOffs.push_back(1); 3275 ++offSize; 3276 } 3277 if (parser.parseSuccessorAndUseList(dest, destArg)) 3278 return mlir::failure(); 3279 dests.push_back(dest); 3280 destArgs.push_back(destArg); 3281 if (mlir::succeeded(parser.parseOptionalRSquare())) 3282 break; 3283 if (parser.parseComma()) 3284 return mlir::failure(); 3285 } 3286 result.addAttribute(fir::SelectCaseOp::getCasesAttr(), 3287 parser.getBuilder().getArrayAttr(attrs)); 3288 if (parser.resolveOperands(opers, type, result.operands)) 3289 return mlir::failure(); 3290 llvm::SmallVector<int32_t> targOffs; 3291 int32_t toffSize = 0; 3292 const auto count = dests.size(); 3293 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 3294 result.addSuccessors(dests[i]); 3295 result.addOperands(destArgs[i]); 3296 auto argSize = destArgs[i].size(); 3297 targOffs.push_back(argSize); 3298 toffSize += argSize; 3299 } 3300 auto &bld = parser.getBuilder(); 3301 result.addAttribute(fir::SelectCaseOp::getOperandSegmentSizeAttr(), 3302 bld.getDenseI32ArrayAttr({1, offSize, toffSize})); 3303 result.addAttribute(getCompareOffsetAttr(), 3304 bld.getDenseI32ArrayAttr(argOffs)); 3305 result.addAttribute(getTargetOffsetAttr(), 3306 bld.getDenseI32ArrayAttr(targOffs)); 3307 return mlir::success(); 3308 } 3309 3310 void fir::SelectCaseOp::print(mlir::OpAsmPrinter &p) { 3311 p << ' '; 3312 p.printOperand(getSelector()); 3313 p << " : " << getSelector().getType() << " ["; 3314 auto cases = 3315 getOperation()->getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).getValue(); 3316 auto count = getNumConditions(); 3317 for (decltype(count) i = 0; i != count; ++i) { 3318 if (i) 3319 p << ", "; 3320 p << cases[i] << ", "; 3321 if (!mlir::isa<mlir::UnitAttr>(cases[i])) { 3322 auto caseArgs = *getCompareOperands(i); 3323 p.printOperand(*caseArgs.begin()); 3324 p << ", "; 3325 if (mlir::isa<fir::ClosedIntervalAttr>(cases[i])) { 3326 p.printOperand(*(++caseArgs.begin())); 3327 p << ", "; 3328 } 3329 } 3330 printSuccessorAtIndex(p, i); 3331 } 3332 p << ']'; 3333 p.printOptionalAttrDict(getOperation()->getAttrs(), 3334 {getCasesAttr(), getCompareOffsetAttr(), 3335 getTargetOffsetAttr(), getOperandSegmentSizeAttr()}); 3336 } 3337 3338 unsigned fir::SelectCaseOp::compareOffsetSize() { 3339 return (*this) 3340 ->getAttrOfType<mlir::DenseI32ArrayAttr>(getCompareOffsetAttr()) 3341 .size(); 3342 } 3343 3344 unsigned fir::SelectCaseOp::targetOffsetSize() { 3345 return (*this) 3346 ->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()) 3347 .size(); 3348 } 3349 3350 void fir::SelectCaseOp::build(mlir::OpBuilder &builder, 3351 mlir::OperationState &result, 3352 mlir::Value selector, 3353 llvm::ArrayRef<mlir::Attribute> compareAttrs, 3354 llvm::ArrayRef<mlir::ValueRange> cmpOperands, 3355 llvm::ArrayRef<mlir::Block *> destinations, 3356 llvm::ArrayRef<mlir::ValueRange> destOperands, 3357 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 3358 result.addOperands(selector); 3359 result.addAttribute(getCasesAttr(), builder.getArrayAttr(compareAttrs)); 3360 llvm::SmallVector<int32_t> operOffs; 3361 int32_t operSize = 0; 3362 for (auto attr : compareAttrs) { 3363 if (mlir::isa<fir::ClosedIntervalAttr>(attr)) { 3364 operOffs.push_back(2); 3365 operSize += 2; 3366 } else if (mlir::isa<mlir::UnitAttr>(attr)) { 3367 operOffs.push_back(0); 3368 } else { 3369 operOffs.push_back(1); 3370 ++operSize; 3371 } 3372 } 3373 for (auto ops : cmpOperands) 3374 result.addOperands(ops); 3375 result.addAttribute(getCompareOffsetAttr(), 3376 builder.getDenseI32ArrayAttr(operOffs)); 3377 const auto count = destinations.size(); 3378 for (auto d : destinations) 3379 result.addSuccessors(d); 3380 const auto opCount = destOperands.size(); 3381 llvm::SmallVector<std::int32_t> argOffs; 3382 std::int32_t sumArgs = 0; 3383 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 3384 if (i < opCount) { 3385 result.addOperands(destOperands[i]); 3386 const auto argSz = destOperands[i].size(); 3387 argOffs.push_back(argSz); 3388 sumArgs += argSz; 3389 } else { 3390 argOffs.push_back(0); 3391 } 3392 } 3393 result.addAttribute(getOperandSegmentSizeAttr(), 3394 builder.getDenseI32ArrayAttr({1, operSize, sumArgs})); 3395 result.addAttribute(getTargetOffsetAttr(), 3396 builder.getDenseI32ArrayAttr(argOffs)); 3397 result.addAttributes(attributes); 3398 } 3399 3400 /// This builder has a slightly simplified interface in that the list of 3401 /// operands need not be partitioned by the builder. Instead the operands are 3402 /// partitioned here, before being passed to the default builder. This 3403 /// partitioning is unchecked, so can go awry on bad input. 3404 void fir::SelectCaseOp::build(mlir::OpBuilder &builder, 3405 mlir::OperationState &result, 3406 mlir::Value selector, 3407 llvm::ArrayRef<mlir::Attribute> compareAttrs, 3408 llvm::ArrayRef<mlir::Value> cmpOpList, 3409 llvm::ArrayRef<mlir::Block *> destinations, 3410 llvm::ArrayRef<mlir::ValueRange> destOperands, 3411 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 3412 llvm::SmallVector<mlir::ValueRange> cmpOpers; 3413 auto iter = cmpOpList.begin(); 3414 for (auto &attr : compareAttrs) { 3415 if (mlir::isa<fir::ClosedIntervalAttr>(attr)) { 3416 cmpOpers.push_back(mlir::ValueRange({iter, iter + 2})); 3417 iter += 2; 3418 } else if (mlir::isa<mlir::UnitAttr>(attr)) { 3419 cmpOpers.push_back(mlir::ValueRange{}); 3420 } else { 3421 cmpOpers.push_back(mlir::ValueRange({iter, iter + 1})); 3422 ++iter; 3423 } 3424 } 3425 build(builder, result, selector, compareAttrs, cmpOpers, destinations, 3426 destOperands, attributes); 3427 } 3428 3429 llvm::LogicalResult fir::SelectCaseOp::verify() { 3430 if (!mlir::isa<mlir::IntegerType, mlir::IndexType, fir::IntegerType, 3431 fir::LogicalType, fir::CharacterType>(getSelector().getType())) 3432 return emitOpError("must be an integer, character, or logical"); 3433 auto cases = 3434 getOperation()->getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).getValue(); 3435 auto count = getNumDest(); 3436 if (count == 0) 3437 return emitOpError("must have at least one successor"); 3438 if (getNumConditions() != count) 3439 return emitOpError("number of conditions and successors don't match"); 3440 if (compareOffsetSize() != count) 3441 return emitOpError("incorrect number of compare operand groups"); 3442 if (targetOffsetSize() != count) 3443 return emitOpError("incorrect number of successor operand groups"); 3444 for (decltype(count) i = 0; i != count; ++i) { 3445 auto &attr = cases[i]; 3446 if (!(mlir::isa<fir::PointIntervalAttr>(attr) || 3447 mlir::isa<fir::LowerBoundAttr>(attr) || 3448 mlir::isa<fir::UpperBoundAttr>(attr) || 3449 mlir::isa<fir::ClosedIntervalAttr>(attr) || 3450 mlir::isa<mlir::UnitAttr>(attr))) 3451 return emitOpError("incorrect select case attribute type"); 3452 } 3453 return mlir::success(); 3454 } 3455 3456 //===----------------------------------------------------------------------===// 3457 // SelectRankOp 3458 //===----------------------------------------------------------------------===// 3459 3460 llvm::LogicalResult fir::SelectRankOp::verify() { 3461 return verifyIntegralSwitchTerminator(*this); 3462 } 3463 3464 mlir::ParseResult fir::SelectRankOp::parse(mlir::OpAsmParser &parser, 3465 mlir::OperationState &result) { 3466 return parseIntegralSwitchTerminator(parser, result, getCasesAttr(), 3467 getOperandSegmentSizeAttr()); 3468 } 3469 3470 void fir::SelectRankOp::print(mlir::OpAsmPrinter &p) { 3471 printIntegralSwitchTerminator(*this, p); 3472 } 3473 3474 std::optional<mlir::OperandRange> 3475 fir::SelectRankOp::getCompareOperands(unsigned) { 3476 return {}; 3477 } 3478 3479 std::optional<llvm::ArrayRef<mlir::Value>> 3480 fir::SelectRankOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 3481 return {}; 3482 } 3483 3484 mlir::SuccessorOperands fir::SelectRankOp::getSuccessorOperands(unsigned oper) { 3485 return mlir::SuccessorOperands(::getMutableSuccessorOperands( 3486 oper, getTargetArgsMutable(), getTargetOffsetAttr())); 3487 } 3488 3489 std::optional<llvm::ArrayRef<mlir::Value>> 3490 fir::SelectRankOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 3491 unsigned oper) { 3492 auto a = 3493 (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); 3494 auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( 3495 getOperandSegmentSizeAttr()); 3496 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 3497 } 3498 3499 std::optional<mlir::ValueRange> 3500 fir::SelectRankOp::getSuccessorOperands(mlir::ValueRange operands, 3501 unsigned oper) { 3502 auto a = 3503 (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); 3504 auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( 3505 getOperandSegmentSizeAttr()); 3506 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 3507 } 3508 3509 unsigned fir::SelectRankOp::targetOffsetSize() { 3510 return (*this) 3511 ->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()) 3512 .size(); 3513 } 3514 3515 //===----------------------------------------------------------------------===// 3516 // SelectTypeOp 3517 //===----------------------------------------------------------------------===// 3518 3519 std::optional<mlir::OperandRange> 3520 fir::SelectTypeOp::getCompareOperands(unsigned) { 3521 return {}; 3522 } 3523 3524 std::optional<llvm::ArrayRef<mlir::Value>> 3525 fir::SelectTypeOp::getCompareOperands(llvm::ArrayRef<mlir::Value>, unsigned) { 3526 return {}; 3527 } 3528 3529 mlir::SuccessorOperands fir::SelectTypeOp::getSuccessorOperands(unsigned oper) { 3530 return mlir::SuccessorOperands(::getMutableSuccessorOperands( 3531 oper, getTargetArgsMutable(), getTargetOffsetAttr())); 3532 } 3533 3534 std::optional<llvm::ArrayRef<mlir::Value>> 3535 fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands, 3536 unsigned oper) { 3537 auto a = 3538 (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); 3539 auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( 3540 getOperandSegmentSizeAttr()); 3541 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 3542 } 3543 3544 std::optional<mlir::ValueRange> 3545 fir::SelectTypeOp::getSuccessorOperands(mlir::ValueRange operands, 3546 unsigned oper) { 3547 auto a = 3548 (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()); 3549 auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>( 3550 getOperandSegmentSizeAttr()); 3551 return {getSubOperands(oper, getSubOperands(2, operands, segments), a)}; 3552 } 3553 3554 mlir::ParseResult fir::SelectTypeOp::parse(mlir::OpAsmParser &parser, 3555 mlir::OperationState &result) { 3556 mlir::OpAsmParser::UnresolvedOperand selector; 3557 mlir::Type type; 3558 if (fir::parseSelector(parser, result, selector, type)) 3559 return mlir::failure(); 3560 3561 llvm::SmallVector<mlir::Attribute> attrs; 3562 llvm::SmallVector<mlir::Block *> dests; 3563 llvm::SmallVector<llvm::SmallVector<mlir::Value>> destArgs; 3564 while (true) { 3565 mlir::Attribute attr; 3566 mlir::Block *dest; 3567 llvm::SmallVector<mlir::Value> destArg; 3568 mlir::NamedAttrList temp; 3569 if (parser.parseAttribute(attr, "a", temp) || parser.parseComma() || 3570 parser.parseSuccessorAndUseList(dest, destArg)) 3571 return mlir::failure(); 3572 attrs.push_back(attr); 3573 dests.push_back(dest); 3574 destArgs.push_back(destArg); 3575 if (mlir::succeeded(parser.parseOptionalRSquare())) 3576 break; 3577 if (parser.parseComma()) 3578 return mlir::failure(); 3579 } 3580 auto &bld = parser.getBuilder(); 3581 result.addAttribute(fir::SelectTypeOp::getCasesAttr(), 3582 bld.getArrayAttr(attrs)); 3583 llvm::SmallVector<int32_t> argOffs; 3584 int32_t offSize = 0; 3585 const auto count = dests.size(); 3586 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 3587 result.addSuccessors(dests[i]); 3588 result.addOperands(destArgs[i]); 3589 auto argSize = destArgs[i].size(); 3590 argOffs.push_back(argSize); 3591 offSize += argSize; 3592 } 3593 result.addAttribute(fir::SelectTypeOp::getOperandSegmentSizeAttr(), 3594 bld.getDenseI32ArrayAttr({1, 0, offSize})); 3595 result.addAttribute(getTargetOffsetAttr(), bld.getDenseI32ArrayAttr(argOffs)); 3596 return mlir::success(); 3597 } 3598 3599 unsigned fir::SelectTypeOp::targetOffsetSize() { 3600 return (*this) 3601 ->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr()) 3602 .size(); 3603 } 3604 3605 void fir::SelectTypeOp::print(mlir::OpAsmPrinter &p) { 3606 p << ' '; 3607 p.printOperand(getSelector()); 3608 p << " : " << getSelector().getType() << " ["; 3609 auto cases = 3610 getOperation()->getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).getValue(); 3611 auto count = getNumConditions(); 3612 for (decltype(count) i = 0; i != count; ++i) { 3613 if (i) 3614 p << ", "; 3615 p << cases[i] << ", "; 3616 printSuccessorAtIndex(p, i); 3617 } 3618 p << ']'; 3619 p.printOptionalAttrDict(getOperation()->getAttrs(), 3620 {getCasesAttr(), getCompareOffsetAttr(), 3621 getTargetOffsetAttr(), 3622 fir::SelectTypeOp::getOperandSegmentSizeAttr()}); 3623 } 3624 3625 llvm::LogicalResult fir::SelectTypeOp::verify() { 3626 if (!mlir::isa<fir::BaseBoxType>(getSelector().getType())) 3627 return emitOpError("must be a fir.class or fir.box type"); 3628 if (auto boxType = mlir::dyn_cast<fir::BoxType>(getSelector().getType())) 3629 if (!mlir::isa<mlir::NoneType>(boxType.getEleTy())) 3630 return emitOpError("selector must be polymorphic"); 3631 auto typeGuardAttr = getCases(); 3632 for (unsigned idx = 0; idx < typeGuardAttr.size(); ++idx) 3633 if (mlir::isa<mlir::UnitAttr>(typeGuardAttr[idx]) && 3634 idx != typeGuardAttr.size() - 1) 3635 return emitOpError("default must be the last attribute"); 3636 auto count = getNumDest(); 3637 if (count == 0) 3638 return emitOpError("must have at least one successor"); 3639 if (getNumConditions() != count) 3640 return emitOpError("number of conditions and successors don't match"); 3641 if (targetOffsetSize() != count) 3642 return emitOpError("incorrect number of successor operand groups"); 3643 for (unsigned i = 0; i != count; ++i) { 3644 if (!mlir::isa<fir::ExactTypeAttr, fir::SubclassAttr, mlir::UnitAttr>( 3645 typeGuardAttr[i])) 3646 return emitOpError("invalid type-case alternative"); 3647 } 3648 return mlir::success(); 3649 } 3650 3651 void fir::SelectTypeOp::build(mlir::OpBuilder &builder, 3652 mlir::OperationState &result, 3653 mlir::Value selector, 3654 llvm::ArrayRef<mlir::Attribute> typeOperands, 3655 llvm::ArrayRef<mlir::Block *> destinations, 3656 llvm::ArrayRef<mlir::ValueRange> destOperands, 3657 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 3658 result.addOperands(selector); 3659 result.addAttribute(getCasesAttr(), builder.getArrayAttr(typeOperands)); 3660 const auto count = destinations.size(); 3661 for (mlir::Block *dest : destinations) 3662 result.addSuccessors(dest); 3663 const auto opCount = destOperands.size(); 3664 llvm::SmallVector<int32_t> argOffs; 3665 int32_t sumArgs = 0; 3666 for (std::remove_const_t<decltype(count)> i = 0; i != count; ++i) { 3667 if (i < opCount) { 3668 result.addOperands(destOperands[i]); 3669 const auto argSz = destOperands[i].size(); 3670 argOffs.push_back(argSz); 3671 sumArgs += argSz; 3672 } else { 3673 argOffs.push_back(0); 3674 } 3675 } 3676 result.addAttribute(getOperandSegmentSizeAttr(), 3677 builder.getDenseI32ArrayAttr({1, 0, sumArgs})); 3678 result.addAttribute(getTargetOffsetAttr(), 3679 builder.getDenseI32ArrayAttr(argOffs)); 3680 result.addAttributes(attributes); 3681 } 3682 3683 //===----------------------------------------------------------------------===// 3684 // ShapeOp 3685 //===----------------------------------------------------------------------===// 3686 3687 llvm::LogicalResult fir::ShapeOp::verify() { 3688 auto size = getExtents().size(); 3689 auto shapeTy = mlir::dyn_cast<fir::ShapeType>(getType()); 3690 assert(shapeTy && "must be a shape type"); 3691 if (shapeTy.getRank() != size) 3692 return emitOpError("shape type rank mismatch"); 3693 return mlir::success(); 3694 } 3695 3696 void fir::ShapeOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, 3697 mlir::ValueRange extents) { 3698 auto type = fir::ShapeType::get(builder.getContext(), extents.size()); 3699 build(builder, result, type, extents); 3700 } 3701 3702 //===----------------------------------------------------------------------===// 3703 // ShapeShiftOp 3704 //===----------------------------------------------------------------------===// 3705 3706 llvm::LogicalResult fir::ShapeShiftOp::verify() { 3707 auto size = getPairs().size(); 3708 if (size < 2 || size > 16 * 2) 3709 return emitOpError("incorrect number of args"); 3710 if (size % 2 != 0) 3711 return emitOpError("requires a multiple of 2 args"); 3712 auto shapeTy = mlir::dyn_cast<fir::ShapeShiftType>(getType()); 3713 assert(shapeTy && "must be a shape shift type"); 3714 if (shapeTy.getRank() * 2 != size) 3715 return emitOpError("shape type rank mismatch"); 3716 return mlir::success(); 3717 } 3718 3719 //===----------------------------------------------------------------------===// 3720 // ShiftOp 3721 //===----------------------------------------------------------------------===// 3722 3723 llvm::LogicalResult fir::ShiftOp::verify() { 3724 auto size = getOrigins().size(); 3725 auto shiftTy = mlir::dyn_cast<fir::ShiftType>(getType()); 3726 assert(shiftTy && "must be a shift type"); 3727 if (shiftTy.getRank() != size) 3728 return emitOpError("shift type rank mismatch"); 3729 return mlir::success(); 3730 } 3731 3732 //===----------------------------------------------------------------------===// 3733 // SliceOp 3734 //===----------------------------------------------------------------------===// 3735 3736 void fir::SliceOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, 3737 mlir::ValueRange trips, mlir::ValueRange path, 3738 mlir::ValueRange substr) { 3739 const auto rank = trips.size() / 3; 3740 auto sliceTy = fir::SliceType::get(builder.getContext(), rank); 3741 build(builder, result, sliceTy, trips, path, substr); 3742 } 3743 3744 /// Return the output rank of a slice op. The output rank must be between 1 and 3745 /// the rank of the array being sliced (inclusive). 3746 unsigned fir::SliceOp::getOutputRank(mlir::ValueRange triples) { 3747 unsigned rank = 0; 3748 if (!triples.empty()) { 3749 for (unsigned i = 1, end = triples.size(); i < end; i += 3) { 3750 auto *op = triples[i].getDefiningOp(); 3751 if (!mlir::isa_and_nonnull<fir::UndefOp>(op)) 3752 ++rank; 3753 } 3754 assert(rank > 0); 3755 } 3756 return rank; 3757 } 3758 3759 llvm::LogicalResult fir::SliceOp::verify() { 3760 auto size = getTriples().size(); 3761 if (size < 3 || size > 16 * 3) 3762 return emitOpError("incorrect number of args for triple"); 3763 if (size % 3 != 0) 3764 return emitOpError("requires a multiple of 3 args"); 3765 auto sliceTy = mlir::dyn_cast<fir::SliceType>(getType()); 3766 assert(sliceTy && "must be a slice type"); 3767 if (sliceTy.getRank() * 3 != size) 3768 return emitOpError("slice type rank mismatch"); 3769 return mlir::success(); 3770 } 3771 3772 //===----------------------------------------------------------------------===// 3773 // StoreOp 3774 //===----------------------------------------------------------------------===// 3775 3776 mlir::Type fir::StoreOp::elementType(mlir::Type refType) { 3777 return fir::dyn_cast_ptrEleTy(refType); 3778 } 3779 3780 mlir::ParseResult fir::StoreOp::parse(mlir::OpAsmParser &parser, 3781 mlir::OperationState &result) { 3782 mlir::Type type; 3783 mlir::OpAsmParser::UnresolvedOperand oper; 3784 mlir::OpAsmParser::UnresolvedOperand store; 3785 if (parser.parseOperand(oper) || parser.parseKeyword("to") || 3786 parser.parseOperand(store) || 3787 parser.parseOptionalAttrDict(result.attributes) || 3788 parser.parseColonType(type) || 3789 parser.resolveOperand(oper, fir::StoreOp::elementType(type), 3790 result.operands) || 3791 parser.resolveOperand(store, type, result.operands)) 3792 return mlir::failure(); 3793 return mlir::success(); 3794 } 3795 3796 void fir::StoreOp::print(mlir::OpAsmPrinter &p) { 3797 p << ' '; 3798 p.printOperand(getValue()); 3799 p << " to "; 3800 p.printOperand(getMemref()); 3801 p.printOptionalAttrDict(getOperation()->getAttrs(), {}); 3802 p << " : " << getMemref().getType(); 3803 } 3804 3805 llvm::LogicalResult fir::StoreOp::verify() { 3806 if (getValue().getType() != fir::dyn_cast_ptrEleTy(getMemref().getType())) 3807 return emitOpError("store value type must match memory reference type"); 3808 return mlir::success(); 3809 } 3810 3811 void fir::StoreOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, 3812 mlir::Value value, mlir::Value memref) { 3813 build(builder, result, value, memref, {}); 3814 } 3815 3816 //===----------------------------------------------------------------------===// 3817 // StringLitOp 3818 //===----------------------------------------------------------------------===// 3819 3820 inline fir::CharacterType::KindTy stringLitOpGetKind(fir::StringLitOp op) { 3821 auto eleTy = mlir::cast<fir::SequenceType>(op.getType()).getElementType(); 3822 return mlir::cast<fir::CharacterType>(eleTy).getFKind(); 3823 } 3824 3825 bool fir::StringLitOp::isWideValue() { return stringLitOpGetKind(*this) != 1; } 3826 3827 static mlir::NamedAttribute 3828 mkNamedIntegerAttr(mlir::OpBuilder &builder, llvm::StringRef name, int64_t v) { 3829 assert(v > 0); 3830 return builder.getNamedAttr( 3831 name, builder.getIntegerAttr(builder.getIntegerType(64), v)); 3832 } 3833 3834 void fir::StringLitOp::build(mlir::OpBuilder &builder, 3835 mlir::OperationState &result, 3836 fir::CharacterType inType, llvm::StringRef val, 3837 std::optional<int64_t> len) { 3838 auto valAttr = builder.getNamedAttr(value(), builder.getStringAttr(val)); 3839 int64_t length = len ? *len : inType.getLen(); 3840 auto lenAttr = mkNamedIntegerAttr(builder, size(), length); 3841 result.addAttributes({valAttr, lenAttr}); 3842 result.addTypes(inType); 3843 } 3844 3845 template <typename C> 3846 static mlir::ArrayAttr convertToArrayAttr(mlir::OpBuilder &builder, 3847 llvm::ArrayRef<C> xlist) { 3848 llvm::SmallVector<mlir::Attribute> attrs; 3849 auto ty = builder.getIntegerType(8 * sizeof(C)); 3850 for (auto ch : xlist) 3851 attrs.push_back(builder.getIntegerAttr(ty, ch)); 3852 return builder.getArrayAttr(attrs); 3853 } 3854 3855 void fir::StringLitOp::build(mlir::OpBuilder &builder, 3856 mlir::OperationState &result, 3857 fir::CharacterType inType, 3858 llvm::ArrayRef<char> vlist, 3859 std::optional<std::int64_t> len) { 3860 auto valAttr = 3861 builder.getNamedAttr(xlist(), convertToArrayAttr(builder, vlist)); 3862 std::int64_t length = len ? *len : inType.getLen(); 3863 auto lenAttr = mkNamedIntegerAttr(builder, size(), length); 3864 result.addAttributes({valAttr, lenAttr}); 3865 result.addTypes(inType); 3866 } 3867 3868 void fir::StringLitOp::build(mlir::OpBuilder &builder, 3869 mlir::OperationState &result, 3870 fir::CharacterType inType, 3871 llvm::ArrayRef<char16_t> vlist, 3872 std::optional<std::int64_t> len) { 3873 auto valAttr = 3874 builder.getNamedAttr(xlist(), convertToArrayAttr(builder, vlist)); 3875 std::int64_t length = len ? *len : inType.getLen(); 3876 auto lenAttr = mkNamedIntegerAttr(builder, size(), length); 3877 result.addAttributes({valAttr, lenAttr}); 3878 result.addTypes(inType); 3879 } 3880 3881 void fir::StringLitOp::build(mlir::OpBuilder &builder, 3882 mlir::OperationState &result, 3883 fir::CharacterType inType, 3884 llvm::ArrayRef<char32_t> vlist, 3885 std::optional<std::int64_t> len) { 3886 auto valAttr = 3887 builder.getNamedAttr(xlist(), convertToArrayAttr(builder, vlist)); 3888 std::int64_t length = len ? *len : inType.getLen(); 3889 auto lenAttr = mkNamedIntegerAttr(builder, size(), length); 3890 result.addAttributes({valAttr, lenAttr}); 3891 result.addTypes(inType); 3892 } 3893 3894 mlir::ParseResult fir::StringLitOp::parse(mlir::OpAsmParser &parser, 3895 mlir::OperationState &result) { 3896 auto &builder = parser.getBuilder(); 3897 mlir::Attribute val; 3898 mlir::NamedAttrList attrs; 3899 llvm::SMLoc trailingTypeLoc; 3900 if (parser.parseAttribute(val, "fake", attrs)) 3901 return mlir::failure(); 3902 if (auto v = mlir::dyn_cast<mlir::StringAttr>(val)) 3903 result.attributes.push_back( 3904 builder.getNamedAttr(fir::StringLitOp::value(), v)); 3905 else if (auto v = mlir::dyn_cast<mlir::DenseElementsAttr>(val)) 3906 result.attributes.push_back( 3907 builder.getNamedAttr(fir::StringLitOp::xlist(), v)); 3908 else if (auto v = mlir::dyn_cast<mlir::ArrayAttr>(val)) 3909 result.attributes.push_back( 3910 builder.getNamedAttr(fir::StringLitOp::xlist(), v)); 3911 else 3912 return parser.emitError(parser.getCurrentLocation(), 3913 "found an invalid constant"); 3914 mlir::IntegerAttr sz; 3915 mlir::Type type; 3916 if (parser.parseLParen() || 3917 parser.parseAttribute(sz, fir::StringLitOp::size(), result.attributes) || 3918 parser.parseRParen() || parser.getCurrentLocation(&trailingTypeLoc) || 3919 parser.parseColonType(type)) 3920 return mlir::failure(); 3921 auto charTy = mlir::dyn_cast<fir::CharacterType>(type); 3922 if (!charTy) 3923 return parser.emitError(trailingTypeLoc, "must have character type"); 3924 type = fir::CharacterType::get(builder.getContext(), charTy.getFKind(), 3925 sz.getInt()); 3926 if (!type || parser.addTypesToList(type, result.types)) 3927 return mlir::failure(); 3928 return mlir::success(); 3929 } 3930 3931 void fir::StringLitOp::print(mlir::OpAsmPrinter &p) { 3932 p << ' ' << getValue() << '('; 3933 p << mlir::cast<mlir::IntegerAttr>(getSize()).getValue() << ") : "; 3934 p.printType(getType()); 3935 } 3936 3937 llvm::LogicalResult fir::StringLitOp::verify() { 3938 if (mlir::cast<mlir::IntegerAttr>(getSize()).getValue().isNegative()) 3939 return emitOpError("size must be non-negative"); 3940 if (auto xl = getOperation()->getAttr(fir::StringLitOp::xlist())) { 3941 if (auto xList = mlir::dyn_cast<mlir::ArrayAttr>(xl)) { 3942 for (auto a : xList) 3943 if (!mlir::isa<mlir::IntegerAttr>(a)) 3944 return emitOpError("values in initializer must be integers"); 3945 } else if (mlir::isa<mlir::DenseElementsAttr>(xl)) { 3946 // do nothing 3947 } else { 3948 return emitOpError("has unexpected attribute"); 3949 } 3950 } 3951 return mlir::success(); 3952 } 3953 3954 //===----------------------------------------------------------------------===// 3955 // UnboxProcOp 3956 //===----------------------------------------------------------------------===// 3957 3958 llvm::LogicalResult fir::UnboxProcOp::verify() { 3959 if (auto eleTy = fir::dyn_cast_ptrEleTy(getRefTuple().getType())) 3960 if (mlir::isa<mlir::TupleType>(eleTy)) 3961 return mlir::success(); 3962 return emitOpError("second output argument has bad type"); 3963 } 3964 3965 //===----------------------------------------------------------------------===// 3966 // IfOp 3967 //===----------------------------------------------------------------------===// 3968 3969 void fir::IfOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, 3970 mlir::Value cond, bool withElseRegion) { 3971 build(builder, result, std::nullopt, cond, withElseRegion); 3972 } 3973 3974 void fir::IfOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, 3975 mlir::TypeRange resultTypes, mlir::Value cond, 3976 bool withElseRegion) { 3977 result.addOperands(cond); 3978 result.addTypes(resultTypes); 3979 3980 mlir::Region *thenRegion = result.addRegion(); 3981 thenRegion->push_back(new mlir::Block()); 3982 if (resultTypes.empty()) 3983 IfOp::ensureTerminator(*thenRegion, builder, result.location); 3984 3985 mlir::Region *elseRegion = result.addRegion(); 3986 if (withElseRegion) { 3987 elseRegion->push_back(new mlir::Block()); 3988 if (resultTypes.empty()) 3989 IfOp::ensureTerminator(*elseRegion, builder, result.location); 3990 } 3991 } 3992 3993 // These 3 functions copied from scf.if implementation. 3994 3995 /// Given the region at `index`, or the parent operation if `index` is None, 3996 /// return the successor regions. These are the regions that may be selected 3997 /// during the flow of control. 3998 void fir::IfOp::getSuccessorRegions( 3999 mlir::RegionBranchPoint point, 4000 llvm::SmallVectorImpl<mlir::RegionSuccessor> ®ions) { 4001 // The `then` and the `else` region branch back to the parent operation. 4002 if (!point.isParent()) { 4003 regions.push_back(mlir::RegionSuccessor(getResults())); 4004 return; 4005 } 4006 4007 // Don't consider the else region if it is empty. 4008 regions.push_back(mlir::RegionSuccessor(&getThenRegion())); 4009 4010 // Don't consider the else region if it is empty. 4011 mlir::Region *elseRegion = &this->getElseRegion(); 4012 if (elseRegion->empty()) 4013 regions.push_back(mlir::RegionSuccessor()); 4014 else 4015 regions.push_back(mlir::RegionSuccessor(elseRegion)); 4016 } 4017 4018 void fir::IfOp::getEntrySuccessorRegions( 4019 llvm::ArrayRef<mlir::Attribute> operands, 4020 llvm::SmallVectorImpl<mlir::RegionSuccessor> ®ions) { 4021 FoldAdaptor adaptor(operands); 4022 auto boolAttr = 4023 mlir::dyn_cast_or_null<mlir::BoolAttr>(adaptor.getCondition()); 4024 if (!boolAttr || boolAttr.getValue()) 4025 regions.emplace_back(&getThenRegion()); 4026 4027 // If the else region is empty, execution continues after the parent op. 4028 if (!boolAttr || !boolAttr.getValue()) { 4029 if (!getElseRegion().empty()) 4030 regions.emplace_back(&getElseRegion()); 4031 else 4032 regions.emplace_back(getResults()); 4033 } 4034 } 4035 4036 void fir::IfOp::getRegionInvocationBounds( 4037 llvm::ArrayRef<mlir::Attribute> operands, 4038 llvm::SmallVectorImpl<mlir::InvocationBounds> &invocationBounds) { 4039 if (auto cond = mlir::dyn_cast_or_null<mlir::BoolAttr>(operands[0])) { 4040 // If the condition is known, then one region is known to be executed once 4041 // and the other zero times. 4042 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0); 4043 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1); 4044 } else { 4045 // Non-constant condition. Each region may be executed 0 or 1 times. 4046 invocationBounds.assign(2, {0, 1}); 4047 } 4048 } 4049 4050 mlir::ParseResult fir::IfOp::parse(mlir::OpAsmParser &parser, 4051 mlir::OperationState &result) { 4052 result.regions.reserve(2); 4053 mlir::Region *thenRegion = result.addRegion(); 4054 mlir::Region *elseRegion = result.addRegion(); 4055 4056 auto &builder = parser.getBuilder(); 4057 mlir::OpAsmParser::UnresolvedOperand cond; 4058 mlir::Type i1Type = builder.getIntegerType(1); 4059 if (parser.parseOperand(cond) || 4060 parser.resolveOperand(cond, i1Type, result.operands)) 4061 return mlir::failure(); 4062 4063 if (parser.parseOptionalArrowTypeList(result.types)) 4064 return mlir::failure(); 4065 4066 if (parser.parseRegion(*thenRegion, {}, {})) 4067 return mlir::failure(); 4068 fir::IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), 4069 result.location); 4070 4071 if (mlir::succeeded(parser.parseOptionalKeyword("else"))) { 4072 if (parser.parseRegion(*elseRegion, {}, {})) 4073 return mlir::failure(); 4074 fir::IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), 4075 result.location); 4076 } 4077 4078 // Parse the optional attribute list. 4079 if (parser.parseOptionalAttrDict(result.attributes)) 4080 return mlir::failure(); 4081 return mlir::success(); 4082 } 4083 4084 llvm::LogicalResult fir::IfOp::verify() { 4085 if (getNumResults() != 0 && getElseRegion().empty()) 4086 return emitOpError("must have an else block if defining values"); 4087 4088 return mlir::success(); 4089 } 4090 4091 void fir::IfOp::print(mlir::OpAsmPrinter &p) { 4092 bool printBlockTerminators = false; 4093 p << ' ' << getCondition(); 4094 if (!getResults().empty()) { 4095 p << " -> (" << getResultTypes() << ')'; 4096 printBlockTerminators = true; 4097 } 4098 p << ' '; 4099 p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false, 4100 printBlockTerminators); 4101 4102 // Print the 'else' regions if it exists and has a block. 4103 auto &otherReg = getElseRegion(); 4104 if (!otherReg.empty()) { 4105 p << " else "; 4106 p.printRegion(otherReg, /*printEntryBlockArgs=*/false, 4107 printBlockTerminators); 4108 } 4109 p.printOptionalAttrDict((*this)->getAttrs()); 4110 } 4111 4112 void fir::IfOp::resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results, 4113 unsigned resultNum) { 4114 auto *term = getThenRegion().front().getTerminator(); 4115 if (resultNum < term->getNumOperands()) 4116 results.push_back(term->getOperand(resultNum)); 4117 term = getElseRegion().front().getTerminator(); 4118 if (resultNum < term->getNumOperands()) 4119 results.push_back(term->getOperand(resultNum)); 4120 } 4121 4122 //===----------------------------------------------------------------------===// 4123 // BoxOffsetOp 4124 //===----------------------------------------------------------------------===// 4125 4126 llvm::LogicalResult fir::BoxOffsetOp::verify() { 4127 auto boxType = mlir::dyn_cast_or_null<fir::BaseBoxType>( 4128 fir::dyn_cast_ptrEleTy(getBoxRef().getType())); 4129 if (!boxType) 4130 return emitOpError("box_ref operand must have !fir.ref<!fir.box<T>> type"); 4131 if (getField() != fir::BoxFieldAttr::base_addr && 4132 getField() != fir::BoxFieldAttr::derived_type) 4133 return emitOpError("cannot address provided field"); 4134 if (getField() == fir::BoxFieldAttr::derived_type) 4135 if (!fir::boxHasAddendum(boxType)) 4136 return emitOpError("can only address derived_type field of derived type " 4137 "or unlimited polymorphic fir.box"); 4138 return mlir::success(); 4139 } 4140 4141 void fir::BoxOffsetOp::build(mlir::OpBuilder &builder, 4142 mlir::OperationState &result, mlir::Value boxRef, 4143 fir::BoxFieldAttr field) { 4144 mlir::Type valueType = 4145 fir::unwrapPassByRefType(fir::unwrapRefType(boxRef.getType())); 4146 mlir::Type resultType = valueType; 4147 if (field == fir::BoxFieldAttr::base_addr) 4148 resultType = fir::LLVMPointerType::get(fir::ReferenceType::get(valueType)); 4149 else if (field == fir::BoxFieldAttr::derived_type) 4150 resultType = fir::LLVMPointerType::get( 4151 fir::TypeDescType::get(fir::unwrapSequenceType(valueType))); 4152 build(builder, result, {resultType}, boxRef, field); 4153 } 4154 4155 //===----------------------------------------------------------------------===// 4156 4157 mlir::ParseResult fir::isValidCaseAttr(mlir::Attribute attr) { 4158 if (mlir::isa<mlir::UnitAttr, fir::ClosedIntervalAttr, fir::PointIntervalAttr, 4159 fir::LowerBoundAttr, fir::UpperBoundAttr>(attr)) 4160 return mlir::success(); 4161 return mlir::failure(); 4162 } 4163 4164 unsigned fir::getCaseArgumentOffset(llvm::ArrayRef<mlir::Attribute> cases, 4165 unsigned dest) { 4166 unsigned o = 0; 4167 for (unsigned i = 0; i < dest; ++i) { 4168 auto &attr = cases[i]; 4169 if (!mlir::dyn_cast_or_null<mlir::UnitAttr>(attr)) { 4170 ++o; 4171 if (mlir::dyn_cast_or_null<fir::ClosedIntervalAttr>(attr)) 4172 ++o; 4173 } 4174 } 4175 return o; 4176 } 4177 4178 mlir::ParseResult 4179 fir::parseSelector(mlir::OpAsmParser &parser, mlir::OperationState &result, 4180 mlir::OpAsmParser::UnresolvedOperand &selector, 4181 mlir::Type &type) { 4182 if (parser.parseOperand(selector) || parser.parseColonType(type) || 4183 parser.resolveOperand(selector, type, result.operands) || 4184 parser.parseLSquare()) 4185 return mlir::failure(); 4186 return mlir::success(); 4187 } 4188 4189 mlir::func::FuncOp fir::createFuncOp(mlir::Location loc, mlir::ModuleOp module, 4190 llvm::StringRef name, 4191 mlir::FunctionType type, 4192 llvm::ArrayRef<mlir::NamedAttribute> attrs, 4193 const mlir::SymbolTable *symbolTable) { 4194 if (symbolTable) 4195 if (auto f = symbolTable->lookup<mlir::func::FuncOp>(name)) { 4196 #ifdef EXPENSIVE_CHECKS 4197 assert(f == module.lookupSymbol<mlir::func::FuncOp>(name) && 4198 "symbolTable and module out of sync"); 4199 #endif 4200 return f; 4201 } 4202 if (auto f = module.lookupSymbol<mlir::func::FuncOp>(name)) 4203 return f; 4204 mlir::OpBuilder modBuilder(module.getBodyRegion()); 4205 modBuilder.setInsertionPointToEnd(module.getBody()); 4206 auto result = modBuilder.create<mlir::func::FuncOp>(loc, name, type, attrs); 4207 result.setVisibility(mlir::SymbolTable::Visibility::Private); 4208 return result; 4209 } 4210 4211 fir::GlobalOp fir::createGlobalOp(mlir::Location loc, mlir::ModuleOp module, 4212 llvm::StringRef name, mlir::Type type, 4213 llvm::ArrayRef<mlir::NamedAttribute> attrs, 4214 const mlir::SymbolTable *symbolTable) { 4215 if (symbolTable) 4216 if (auto g = symbolTable->lookup<fir::GlobalOp>(name)) { 4217 #ifdef EXPENSIVE_CHECKS 4218 assert(g == module.lookupSymbol<fir::GlobalOp>(name) && 4219 "symbolTable and module out of sync"); 4220 #endif 4221 return g; 4222 } 4223 if (auto g = module.lookupSymbol<fir::GlobalOp>(name)) 4224 return g; 4225 mlir::OpBuilder modBuilder(module.getBodyRegion()); 4226 auto result = modBuilder.create<fir::GlobalOp>(loc, name, type, attrs); 4227 result.setVisibility(mlir::SymbolTable::Visibility::Private); 4228 return result; 4229 } 4230 4231 bool fir::hasHostAssociationArgument(mlir::func::FuncOp func) { 4232 if (auto allArgAttrs = func.getAllArgAttrs()) 4233 for (auto attr : allArgAttrs) 4234 if (auto dict = mlir::dyn_cast_or_null<mlir::DictionaryAttr>(attr)) 4235 if (dict.get(fir::getHostAssocAttrName())) 4236 return true; 4237 return false; 4238 } 4239 4240 // Test if value's definition has the specified set of 4241 // attributeNames. The value's definition is one of the operations 4242 // that are able to carry the Fortran variable attributes, e.g. 4243 // fir.alloca or fir.allocmem. Function arguments may also represent 4244 // value definitions and carry relevant attributes. 4245 // 4246 // If it is not possible to reach the limited set of definition 4247 // entities from the given value, then the function will return 4248 // std::nullopt. Otherwise, the definition is known and the return 4249 // value is computed as: 4250 // * if checkAny is true, then the function will return true 4251 // iff any of the attributeNames attributes is set on the definition. 4252 // * if checkAny is false, then the function will return true 4253 // iff all of the attributeNames attributes are set on the definition. 4254 static std::optional<bool> 4255 valueCheckFirAttributes(mlir::Value value, 4256 llvm::ArrayRef<llvm::StringRef> attributeNames, 4257 bool checkAny) { 4258 auto testAttributeSets = [&](llvm::ArrayRef<mlir::NamedAttribute> setAttrs, 4259 llvm::ArrayRef<llvm::StringRef> checkAttrs) { 4260 if (checkAny) { 4261 // Return true iff any of checkAttrs attributes is present 4262 // in setAttrs set. 4263 for (llvm::StringRef checkAttrName : checkAttrs) 4264 if (llvm::any_of(setAttrs, [&](mlir::NamedAttribute setAttr) { 4265 return setAttr.getName() == checkAttrName; 4266 })) 4267 return true; 4268 4269 return false; 4270 } 4271 4272 // Return true iff all attributes from checkAttrs are present 4273 // in setAttrs set. 4274 for (mlir::StringRef checkAttrName : checkAttrs) 4275 if (llvm::none_of(setAttrs, [&](mlir::NamedAttribute setAttr) { 4276 return setAttr.getName() == checkAttrName; 4277 })) 4278 return false; 4279 4280 return true; 4281 }; 4282 // If this is a fir.box that was loaded, the fir attributes will be on the 4283 // related fir.ref<fir.box> creation. 4284 if (mlir::isa<fir::BoxType>(value.getType())) 4285 if (auto definingOp = value.getDefiningOp()) 4286 if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(definingOp)) 4287 value = loadOp.getMemref(); 4288 // If this is a function argument, look in the argument attributes. 4289 if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(value)) { 4290 if (blockArg.getOwner() && blockArg.getOwner()->isEntryBlock()) 4291 if (auto funcOp = mlir::dyn_cast<mlir::func::FuncOp>( 4292 blockArg.getOwner()->getParentOp())) 4293 return testAttributeSets( 4294 mlir::cast<mlir::FunctionOpInterface>(*funcOp).getArgAttrs( 4295 blockArg.getArgNumber()), 4296 attributeNames); 4297 4298 // If it is not a function argument, the attributes are unknown. 4299 return std::nullopt; 4300 } 4301 4302 if (auto definingOp = value.getDefiningOp()) { 4303 // If this is an allocated value, look at the allocation attributes. 4304 if (mlir::isa<fir::AllocMemOp>(definingOp) || 4305 mlir::isa<fir::AllocaOp>(definingOp)) 4306 return testAttributeSets(definingOp->getAttrs(), attributeNames); 4307 // If this is an imported global, look at AddrOfOp and GlobalOp attributes. 4308 // Both operations are looked at because use/host associated variable (the 4309 // AddrOfOp) can have ASYNCHRONOUS/VOLATILE attributes even if the ultimate 4310 // entity (the globalOp) does not have them. 4311 if (auto addressOfOp = mlir::dyn_cast<fir::AddrOfOp>(definingOp)) { 4312 if (testAttributeSets(addressOfOp->getAttrs(), attributeNames)) 4313 return true; 4314 if (auto module = definingOp->getParentOfType<mlir::ModuleOp>()) 4315 if (auto globalOp = 4316 module.lookupSymbol<fir::GlobalOp>(addressOfOp.getSymbol())) 4317 return testAttributeSets(globalOp->getAttrs(), attributeNames); 4318 } 4319 } 4320 // TODO: Construct associated entities attributes. Decide where the fir 4321 // attributes must be placed/looked for in this case. 4322 return std::nullopt; 4323 } 4324 4325 bool fir::valueMayHaveFirAttributes( 4326 mlir::Value value, llvm::ArrayRef<llvm::StringRef> attributeNames) { 4327 std::optional<bool> mayHaveAttr = 4328 valueCheckFirAttributes(value, attributeNames, /*checkAny=*/true); 4329 return mayHaveAttr.value_or(true); 4330 } 4331 4332 bool fir::valueHasFirAttribute(mlir::Value value, 4333 llvm::StringRef attributeName) { 4334 std::optional<bool> mayHaveAttr = 4335 valueCheckFirAttributes(value, {attributeName}, /*checkAny=*/false); 4336 return mayHaveAttr.value_or(false); 4337 } 4338 4339 bool fir::anyFuncArgsHaveAttr(mlir::func::FuncOp func, llvm::StringRef attr) { 4340 for (unsigned i = 0, end = func.getNumArguments(); i < end; ++i) 4341 if (func.getArgAttr(i, attr)) 4342 return true; 4343 return false; 4344 } 4345 4346 std::optional<std::int64_t> fir::getIntIfConstant(mlir::Value value) { 4347 if (auto *definingOp = value.getDefiningOp()) { 4348 if (auto cst = mlir::dyn_cast<mlir::arith::ConstantOp>(definingOp)) 4349 if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(cst.getValue())) 4350 return intAttr.getInt(); 4351 if (auto llConstOp = mlir::dyn_cast<mlir::LLVM::ConstantOp>(definingOp)) 4352 if (auto attr = mlir::dyn_cast<mlir::IntegerAttr>(llConstOp.getValue())) 4353 return attr.getValue().getSExtValue(); 4354 } 4355 return {}; 4356 } 4357 4358 bool fir::isDummyArgument(mlir::Value v) { 4359 auto blockArg{mlir::dyn_cast<mlir::BlockArgument>(v)}; 4360 if (!blockArg) { 4361 auto defOp = v.getDefiningOp(); 4362 if (defOp) { 4363 if (auto declareOp = mlir::dyn_cast<fir::DeclareOp>(defOp)) 4364 if (declareOp.getDummyScope()) 4365 return true; 4366 } 4367 return false; 4368 } 4369 4370 auto *owner{blockArg.getOwner()}; 4371 return owner->isEntryBlock() && 4372 mlir::isa<mlir::FunctionOpInterface>(owner->getParentOp()); 4373 } 4374 4375 mlir::Type fir::applyPathToType(mlir::Type eleTy, mlir::ValueRange path) { 4376 for (auto i = path.begin(), end = path.end(); eleTy && i < end;) { 4377 eleTy = llvm::TypeSwitch<mlir::Type, mlir::Type>(eleTy) 4378 .Case<fir::RecordType>([&](fir::RecordType ty) { 4379 if (auto *op = (*i++).getDefiningOp()) { 4380 if (auto off = mlir::dyn_cast<fir::FieldIndexOp>(op)) 4381 return ty.getType(off.getFieldName()); 4382 if (auto off = mlir::dyn_cast<mlir::arith::ConstantOp>(op)) 4383 return ty.getType(fir::toInt(off)); 4384 } 4385 return mlir::Type{}; 4386 }) 4387 .Case<fir::SequenceType>([&](fir::SequenceType ty) { 4388 bool valid = true; 4389 const auto rank = ty.getDimension(); 4390 for (std::remove_const_t<decltype(rank)> ii = 0; 4391 valid && ii < rank; ++ii) 4392 valid = i < end && fir::isa_integer((*i++).getType()); 4393 return valid ? ty.getEleTy() : mlir::Type{}; 4394 }) 4395 .Case<mlir::TupleType>([&](mlir::TupleType ty) { 4396 if (auto *op = (*i++).getDefiningOp()) 4397 if (auto off = mlir::dyn_cast<mlir::arith::ConstantOp>(op)) 4398 return ty.getType(fir::toInt(off)); 4399 return mlir::Type{}; 4400 }) 4401 .Case<mlir::ComplexType>([&](mlir::ComplexType ty) { 4402 if (fir::isa_integer((*i++).getType())) 4403 return ty.getElementType(); 4404 return mlir::Type{}; 4405 }) 4406 .Default([&](const auto &) { return mlir::Type{}; }); 4407 } 4408 return eleTy; 4409 } 4410 4411 llvm::LogicalResult fir::DeclareOp::verify() { 4412 auto fortranVar = 4413 mlir::cast<fir::FortranVariableOpInterface>(this->getOperation()); 4414 return fortranVar.verifyDeclareLikeOpImpl(getMemref()); 4415 } 4416 4417 //===----------------------------------------------------------------------===// 4418 // FIROpsDialect 4419 //===----------------------------------------------------------------------===// 4420 4421 void fir::FIROpsDialect::registerOpExternalInterfaces() { 4422 // Attach default declare target interfaces to operations which can be marked 4423 // as declare target. 4424 fir::GlobalOp::attachInterface< 4425 mlir::omp::DeclareTargetDefaultModel<fir::GlobalOp>>(*getContext()); 4426 } 4427 4428 // Tablegen operators 4429 4430 #define GET_OP_CLASSES 4431 #include "flang/Optimizer/Dialect/FIROps.cpp.inc" 4432