1 //===- OptimizedBufferization.cpp - special cases for bufferization -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // In some special cases we can bufferize hlfir expressions in a more optimal 9 // way so as to avoid creating temporaries. This pass handles these. It should 10 // be run before the catch-all bufferization pass. 11 // 12 // This requires constant subexpression elimination to have already been run. 13 //===----------------------------------------------------------------------===// 14 15 #include "flang/Optimizer/Analysis/AliasAnalysis.h" 16 #include "flang/Optimizer/Builder/FIRBuilder.h" 17 #include "flang/Optimizer/Builder/HLFIRTools.h" 18 #include "flang/Optimizer/Dialect/FIROps.h" 19 #include "flang/Optimizer/Dialect/FIRType.h" 20 #include "flang/Optimizer/HLFIR/HLFIRDialect.h" 21 #include "flang/Optimizer/HLFIR/HLFIROps.h" 22 #include "flang/Optimizer/HLFIR/Passes.h" 23 #include "flang/Optimizer/OpenMP/Passes.h" 24 #include "flang/Optimizer/Transforms/Utils.h" 25 #include "mlir/Dialect/Func/IR/FuncOps.h" 26 #include "mlir/IR/Dominance.h" 27 #include "mlir/IR/PatternMatch.h" 28 #include "mlir/Interfaces/SideEffectInterfaces.h" 29 #include "mlir/Pass/Pass.h" 30 #include "mlir/Support/LLVM.h" 31 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 32 #include "llvm/ADT/TypeSwitch.h" 33 #include <iterator> 34 #include <memory> 35 #include <mlir/Analysis/AliasAnalysis.h> 36 #include <optional> 37 38 namespace hlfir { 39 #define GEN_PASS_DEF_OPTIMIZEDBUFFERIZATION 40 #include "flang/Optimizer/HLFIR/Passes.h.inc" 41 } // namespace hlfir 42 43 #define DEBUG_TYPE "opt-bufferization" 44 45 namespace { 46 47 /// This transformation should match in place modification of arrays. 48 /// It should match code of the form 49 /// %array = some.operation // array has shape %shape 50 /// %expr = hlfir.elemental %shape : [...] { 51 /// bb0(%arg0: index) 52 /// %0 = hlfir.designate %array(%arg0) 53 /// [...] // no other reads or writes to %array 54 /// hlfir.yield_element %element 55 /// } 56 /// hlfir.assign %expr to %array 57 /// hlfir.destroy %expr 58 /// 59 /// Or 60 /// 61 /// %read_array = some.operation // shape %shape 62 /// %expr = hlfir.elemental %shape : [...] { 63 /// bb0(%arg0: index) 64 /// %0 = hlfir.designate %read_array(%arg0) 65 /// [...] 66 /// hlfir.yield_element %element 67 /// } 68 /// %write_array = some.operation // with shape %shape 69 /// [...] // operations which don't effect write_array 70 /// hlfir.assign %expr to %write_array 71 /// hlfir.destroy %expr 72 /// 73 /// In these cases, it is safe to turn the elemental into a do loop and modify 74 /// elements of %array in place without creating an extra temporary for the 75 /// elemental. We must check that there are no reads from the array at indexes 76 /// which might conflict with the assignment or any writes. For now we will keep 77 /// that strict and say that all reads must be at the elemental index (it is 78 /// probably safe to read from higher indices if lowering to an ordered loop). 79 class ElementalAssignBufferization 80 : public mlir::OpRewritePattern<hlfir::ElementalOp> { 81 private: 82 struct MatchInfo { 83 mlir::Value array; 84 hlfir::AssignOp assign; 85 hlfir::DestroyOp destroy; 86 }; 87 /// determines if the transformation can be applied to this elemental 88 static std::optional<MatchInfo> findMatch(hlfir::ElementalOp elemental); 89 90 /// Returns the array indices for the given hlfir.designate. 91 /// It recognizes the computations used to transform the one-based indices 92 /// into the array's lb-based indices, and returns the one-based indices 93 /// in these cases. 94 static llvm::SmallVector<mlir::Value> 95 getDesignatorIndices(hlfir::DesignateOp designate); 96 97 public: 98 using mlir::OpRewritePattern<hlfir::ElementalOp>::OpRewritePattern; 99 100 llvm::LogicalResult 101 matchAndRewrite(hlfir::ElementalOp elemental, 102 mlir::PatternRewriter &rewriter) const override; 103 }; 104 105 /// recursively collect all effects between start and end (including start, not 106 /// including end) start must properly dominate end, start and end must be in 107 /// the same block. If any operations with unknown effects are found, 108 /// std::nullopt is returned 109 static std::optional<mlir::SmallVector<mlir::MemoryEffects::EffectInstance>> 110 getEffectsBetween(mlir::Operation *start, mlir::Operation *end) { 111 mlir::SmallVector<mlir::MemoryEffects::EffectInstance> ret; 112 if (start == end) 113 return ret; 114 assert(start->getBlock() && end->getBlock() && "TODO: block arguments"); 115 assert(start->getBlock() == end->getBlock()); 116 assert(mlir::DominanceInfo{}.properlyDominates(start, end)); 117 118 mlir::Operation *nextOp = start; 119 while (nextOp && nextOp != end) { 120 std::optional<mlir::SmallVector<mlir::MemoryEffects::EffectInstance>> 121 effects = mlir::getEffectsRecursively(nextOp); 122 if (!effects) 123 return std::nullopt; 124 ret.append(*effects); 125 nextOp = nextOp->getNextNode(); 126 } 127 return ret; 128 } 129 130 /// If effect is a read or write on val, return whether it aliases. 131 /// Otherwise return mlir::AliasResult::NoAlias 132 static mlir::AliasResult 133 containsReadOrWriteEffectOn(const mlir::MemoryEffects::EffectInstance &effect, 134 mlir::Value val) { 135 fir::AliasAnalysis aliasAnalysis; 136 137 if (mlir::isa<mlir::MemoryEffects::Read, mlir::MemoryEffects::Write>( 138 effect.getEffect())) { 139 mlir::Value accessedVal = effect.getValue(); 140 if (mlir::isa<fir::DebuggingResource>(effect.getResource())) 141 return mlir::AliasResult::NoAlias; 142 if (!accessedVal) 143 return mlir::AliasResult::MayAlias; 144 if (accessedVal == val) 145 return mlir::AliasResult::MustAlias; 146 147 // if the accessed value might alias val 148 mlir::AliasResult res = aliasAnalysis.alias(val, accessedVal); 149 if (!res.isNo()) 150 return res; 151 152 // FIXME: alias analysis of fir.load 153 // follow this common pattern: 154 // %ref = hlfir.designate %array(%index) 155 // %val = fir.load $ref 156 if (auto designate = accessedVal.getDefiningOp<hlfir::DesignateOp>()) { 157 if (designate.getMemref() == val) 158 return mlir::AliasResult::MustAlias; 159 160 // if the designate is into an array that might alias val 161 res = aliasAnalysis.alias(val, designate.getMemref()); 162 if (!res.isNo()) 163 return res; 164 } 165 } 166 return mlir::AliasResult::NoAlias; 167 } 168 169 // Helper class for analyzing two array slices represented 170 // by two hlfir.designate operations. 171 class ArraySectionAnalyzer { 172 public: 173 // The result of the analyzis is one of the values below. 174 enum class SlicesOverlapKind { 175 // Slices overlap is unknown. 176 Unknown, 177 // Slices are definitely identical. 178 DefinitelyIdentical, 179 // Slices are definitely disjoint. 180 DefinitelyDisjoint, 181 // Slices may be either disjoint or identical, 182 // i.e. there is definitely no partial overlap. 183 EitherIdenticalOrDisjoint 184 }; 185 186 // Analyzes two hlfir.designate results and returns the overlap kind. 187 // The callers may use this method when the alias analysis reports 188 // an alias of some kind, so that we can run Fortran specific analysis 189 // on the array slices to see if they are identical or disjoint. 190 // Note that the alias analysis are not able to give such an answer 191 // about the references. 192 static SlicesOverlapKind analyze(mlir::Value ref1, mlir::Value ref2); 193 194 private: 195 struct SectionDesc { 196 // An array section is described by <lb, ub, stride> tuple. 197 // If the designator's subscript is not a triple, then 198 // the section descriptor is constructed as <lb, nullptr, nullptr>. 199 mlir::Value lb, ub, stride; 200 201 SectionDesc(mlir::Value lb, mlir::Value ub, mlir::Value stride) 202 : lb(lb), ub(ub), stride(stride) { 203 assert(lb && "lower bound or index must be specified"); 204 normalize(); 205 } 206 207 // Normalize the section descriptor: 208 // 1. If UB is nullptr, then it is set to LB. 209 // 2. If LB==UB, then stride does not matter, 210 // so it is reset to nullptr. 211 // 3. If STRIDE==1, then it is reset to nullptr. 212 void normalize() { 213 if (!ub) 214 ub = lb; 215 if (lb == ub) 216 stride = nullptr; 217 if (stride) 218 if (auto val = fir::getIntIfConstant(stride)) 219 if (*val == 1) 220 stride = nullptr; 221 } 222 223 bool operator==(const SectionDesc &other) const { 224 return lb == other.lb && ub == other.ub && stride == other.stride; 225 } 226 }; 227 228 // Given an operand_iterator over the indices operands, 229 // read the subscript values and return them as SectionDesc 230 // updating the iterator. If isTriplet is true, 231 // the subscript is a triplet, and the result is <lb, ub, stride>. 232 // Otherwise, the subscript is a scalar index, and the result 233 // is <index, nullptr, nullptr>. 234 static SectionDesc readSectionDesc(mlir::Operation::operand_iterator &it, 235 bool isTriplet) { 236 if (isTriplet) 237 return {*it++, *it++, *it++}; 238 return {*it++, nullptr, nullptr}; 239 } 240 241 // Return the ordered lower and upper bounds of the section. 242 // If stride is known to be non-negative, then the ordered 243 // bounds match the <lb, ub> of the descriptor. 244 // If stride is known to be negative, then the ordered 245 // bounds are <ub, lb> of the descriptor. 246 // If stride is unknown, we cannot deduce any order, 247 // so the result is <nullptr, nullptr> 248 static std::pair<mlir::Value, mlir::Value> 249 getOrderedBounds(const SectionDesc &desc) { 250 mlir::Value stride = desc.stride; 251 // Null stride means stride=1. 252 if (!stride) 253 return {desc.lb, desc.ub}; 254 // Reverse the bounds, if stride is negative. 255 if (auto val = fir::getIntIfConstant(stride)) { 256 if (*val >= 0) 257 return {desc.lb, desc.ub}; 258 else 259 return {desc.ub, desc.lb}; 260 } 261 262 return {nullptr, nullptr}; 263 } 264 265 // Given two array sections <lb1, ub1, stride1> and 266 // <lb2, ub2, stride2>, return true only if the sections 267 // are known to be disjoint. 268 // 269 // For example, for any positive constant C: 270 // X:Y does not overlap with (Y+C):Z 271 // X:Y does not overlap with Z:(X-C) 272 static bool areDisjointSections(const SectionDesc &desc1, 273 const SectionDesc &desc2) { 274 auto [lb1, ub1] = getOrderedBounds(desc1); 275 auto [lb2, ub2] = getOrderedBounds(desc2); 276 if (!lb1 || !lb2) 277 return false; 278 // Note that this comparison must be made on the ordered bounds, 279 // otherwise 'a(x:y:1) = a(z:x-1:-1) + 1' may be incorrectly treated 280 // as not overlapping (x=2, y=10, z=9). 281 if (isLess(ub1, lb2) || isLess(ub2, lb1)) 282 return true; 283 return false; 284 } 285 286 // Given two array sections <lb1, ub1, stride1> and 287 // <lb2, ub2, stride2>, return true only if the sections 288 // are known to be identical. 289 // 290 // For example: 291 // <x, x, stride> 292 // <x, nullptr, nullptr> 293 // 294 // These sections are identical, from the point of which array 295 // elements are being addresses, even though the shape 296 // of the array slices might be different. 297 static bool areIdenticalSections(const SectionDesc &desc1, 298 const SectionDesc &desc2) { 299 if (desc1 == desc2) 300 return true; 301 return false; 302 } 303 304 // Return true, if v1 is known to be less than v2. 305 static bool isLess(mlir::Value v1, mlir::Value v2); 306 }; 307 308 ArraySectionAnalyzer::SlicesOverlapKind 309 ArraySectionAnalyzer::analyze(mlir::Value ref1, mlir::Value ref2) { 310 if (ref1 == ref2) 311 return SlicesOverlapKind::DefinitelyIdentical; 312 313 auto des1 = ref1.getDefiningOp<hlfir::DesignateOp>(); 314 auto des2 = ref2.getDefiningOp<hlfir::DesignateOp>(); 315 // We only support a pair of designators right now. 316 if (!des1 || !des2) 317 return SlicesOverlapKind::Unknown; 318 319 if (des1.getMemref() != des2.getMemref()) { 320 // If the bases are different, then there is unknown overlap. 321 LLVM_DEBUG(llvm::dbgs() << "No identical base for:\n" 322 << des1 << "and:\n" 323 << des2 << "\n"); 324 return SlicesOverlapKind::Unknown; 325 } 326 327 // Require all components of the designators to be the same. 328 // It might be too strict, e.g. we may probably allow for 329 // different type parameters. 330 if (des1.getComponent() != des2.getComponent() || 331 des1.getComponentShape() != des2.getComponentShape() || 332 des1.getSubstring() != des2.getSubstring() || 333 des1.getComplexPart() != des2.getComplexPart() || 334 des1.getTypeparams() != des2.getTypeparams()) { 335 LLVM_DEBUG(llvm::dbgs() << "Different designator specs for:\n" 336 << des1 << "and:\n" 337 << des2 << "\n"); 338 return SlicesOverlapKind::Unknown; 339 } 340 341 // Analyze the subscripts. 342 auto des1It = des1.getIndices().begin(); 343 auto des2It = des2.getIndices().begin(); 344 bool identicalTriplets = true; 345 bool identicalIndices = true; 346 for (auto [isTriplet1, isTriplet2] : 347 llvm::zip(des1.getIsTriplet(), des2.getIsTriplet())) { 348 SectionDesc desc1 = readSectionDesc(des1It, isTriplet1); 349 SectionDesc desc2 = readSectionDesc(des2It, isTriplet2); 350 351 // See if we can prove that any of the sections do not overlap. 352 // This is mostly a Polyhedron/nf performance hack that looks for 353 // particular relations between the lower and upper bounds 354 // of the array sections, e.g. for any positive constant C: 355 // X:Y does not overlap with (Y+C):Z 356 // X:Y does not overlap with Z:(X-C) 357 if (areDisjointSections(desc1, desc2)) 358 return SlicesOverlapKind::DefinitelyDisjoint; 359 360 if (!areIdenticalSections(desc1, desc2)) { 361 if (isTriplet1 || isTriplet2) { 362 // For example: 363 // hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %0) 364 // hlfir.designate %6#0 (%c2:%c7999:%c1, %c1:%c120:%c1, %1) 365 // 366 // If all the triplets (section speficiers) are the same, then 367 // we do not care if %0 is equal to %1 - the slices are either 368 // identical or completely disjoint. 369 // 370 // Also, treat these as identical sections: 371 // hlfir.designate %6#0 (%c2:%c2:%c1) 372 // hlfir.designate %6#0 (%c2) 373 identicalTriplets = false; 374 LLVM_DEBUG(llvm::dbgs() << "Triplet mismatch for:\n" 375 << des1 << "and:\n" 376 << des2 << "\n"); 377 } else { 378 identicalIndices = false; 379 LLVM_DEBUG(llvm::dbgs() << "Indices mismatch for:\n" 380 << des1 << "and:\n" 381 << des2 << "\n"); 382 } 383 } 384 } 385 386 if (identicalTriplets) { 387 if (identicalIndices) 388 return SlicesOverlapKind::DefinitelyIdentical; 389 else 390 return SlicesOverlapKind::EitherIdenticalOrDisjoint; 391 } 392 393 LLVM_DEBUG(llvm::dbgs() << "Different sections for:\n" 394 << des1 << "and:\n" 395 << des2 << "\n"); 396 return SlicesOverlapKind::Unknown; 397 } 398 399 bool ArraySectionAnalyzer::isLess(mlir::Value v1, mlir::Value v2) { 400 auto removeConvert = [](mlir::Value v) -> mlir::Operation * { 401 auto *op = v.getDefiningOp(); 402 while (auto conv = mlir::dyn_cast_or_null<fir::ConvertOp>(op)) 403 op = conv.getValue().getDefiningOp(); 404 return op; 405 }; 406 407 auto isPositiveConstant = [](mlir::Value v) -> bool { 408 if (auto val = fir::getIntIfConstant(v)) 409 return *val > 0; 410 return false; 411 }; 412 413 auto *op1 = removeConvert(v1); 414 auto *op2 = removeConvert(v2); 415 if (!op1 || !op2) 416 return false; 417 418 // Check if they are both constants. 419 if (auto val1 = fir::getIntIfConstant(op1->getResult(0))) 420 if (auto val2 = fir::getIntIfConstant(op2->getResult(0))) 421 return *val1 < *val2; 422 423 // Handle some variable cases (C > 0): 424 // v2 = v1 + C 425 // v2 = C + v1 426 // v1 = v2 - C 427 if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2)) 428 if ((addi.getLhs().getDefiningOp() == op1 && 429 isPositiveConstant(addi.getRhs())) || 430 (addi.getRhs().getDefiningOp() == op1 && 431 isPositiveConstant(addi.getLhs()))) 432 return true; 433 if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1)) 434 if (subi.getLhs().getDefiningOp() == op2 && 435 isPositiveConstant(subi.getRhs())) 436 return true; 437 return false; 438 } 439 440 llvm::SmallVector<mlir::Value> 441 ElementalAssignBufferization::getDesignatorIndices( 442 hlfir::DesignateOp designate) { 443 mlir::Value memref = designate.getMemref(); 444 445 // If the object is a box, then the indices may be adjusted 446 // according to the box's lower bound(s). Scan through 447 // the computations to try to find the one-based indices. 448 if (mlir::isa<fir::BaseBoxType>(memref.getType())) { 449 // Look for the following pattern: 450 // %13 = fir.load %12 : !fir.ref<!fir.box<...> 451 // %14:3 = fir.box_dims %13, %c0 : (!fir.box<...>, index) -> ... 452 // %17 = arith.subi %14#0, %c1 : index 453 // %18 = arith.addi %arg2, %17 : index 454 // %19 = hlfir.designate %13 (%18) : (!fir.box<...>, index) -> ... 455 // 456 // %arg2 is a one-based index. 457 458 auto isNormalizedLb = [memref](mlir::Value v, unsigned dim) { 459 // Return true, if v and dim are such that: 460 // %14:3 = fir.box_dims %13, %dim : (!fir.box<...>, index) -> ... 461 // %17 = arith.subi %14#0, %c1 : index 462 // %19 = hlfir.designate %13 (...) : (!fir.box<...>, index) -> ... 463 if (auto subOp = 464 mlir::dyn_cast_or_null<mlir::arith::SubIOp>(v.getDefiningOp())) { 465 auto cst = fir::getIntIfConstant(subOp.getRhs()); 466 if (!cst || *cst != 1) 467 return false; 468 if (auto dimsOp = mlir::dyn_cast_or_null<fir::BoxDimsOp>( 469 subOp.getLhs().getDefiningOp())) { 470 if (memref != dimsOp.getVal() || 471 dimsOp.getResult(0) != subOp.getLhs()) 472 return false; 473 auto dimsOpDim = fir::getIntIfConstant(dimsOp.getDim()); 474 return dimsOpDim && dimsOpDim == dim; 475 } 476 } 477 return false; 478 }; 479 480 llvm::SmallVector<mlir::Value> newIndices; 481 for (auto index : llvm::enumerate(designate.getIndices())) { 482 if (auto addOp = mlir::dyn_cast_or_null<mlir::arith::AddIOp>( 483 index.value().getDefiningOp())) { 484 for (unsigned opNum = 0; opNum < 2; ++opNum) 485 if (isNormalizedLb(addOp->getOperand(opNum), index.index())) { 486 newIndices.push_back(addOp->getOperand((opNum + 1) % 2)); 487 break; 488 } 489 490 // If new one-based index was not added, exit early. 491 if (newIndices.size() <= index.index()) 492 break; 493 } 494 } 495 496 // If any of the indices is not adjusted to the array's lb, 497 // then return the original designator indices. 498 if (newIndices.size() != designate.getIndices().size()) 499 return designate.getIndices(); 500 501 return newIndices; 502 } 503 504 return designate.getIndices(); 505 } 506 507 std::optional<ElementalAssignBufferization::MatchInfo> 508 ElementalAssignBufferization::findMatch(hlfir::ElementalOp elemental) { 509 mlir::Operation::user_range users = elemental->getUsers(); 510 // the only uses of the elemental should be the assignment and the destroy 511 if (std::distance(users.begin(), users.end()) != 2) { 512 LLVM_DEBUG(llvm::dbgs() << "Too many uses of the elemental\n"); 513 return std::nullopt; 514 } 515 516 // If the ElementalOp must produce a temporary (e.g. for 517 // finalization purposes), then we cannot inline it. 518 if (hlfir::elementalOpMustProduceTemp(elemental)) { 519 LLVM_DEBUG(llvm::dbgs() << "ElementalOp must produce a temp\n"); 520 return std::nullopt; 521 } 522 523 MatchInfo match; 524 for (mlir::Operation *user : users) 525 mlir::TypeSwitch<mlir::Operation *, void>(user) 526 .Case([&](hlfir::AssignOp op) { match.assign = op; }) 527 .Case([&](hlfir::DestroyOp op) { match.destroy = op; }); 528 529 if (!match.assign || !match.destroy) { 530 LLVM_DEBUG(llvm::dbgs() << "Couldn't find assign or destroy\n"); 531 return std::nullopt; 532 } 533 534 // the array is what the elemental is assigned into 535 // TODO: this could be extended to also allow hlfir.expr by first bufferizing 536 // the incoming expression 537 match.array = match.assign.getLhs(); 538 mlir::Type arrayType = mlir::dyn_cast<fir::SequenceType>( 539 fir::unwrapPassByRefType(match.array.getType())); 540 if (!arrayType) { 541 LLVM_DEBUG(llvm::dbgs() << "AssignOp's result is not an array\n"); 542 return std::nullopt; 543 } 544 545 // require that the array elements are trivial 546 // TODO: this is just to make the pass easier to think about. Not an inherent 547 // limitation 548 mlir::Type eleTy = hlfir::getFortranElementType(arrayType); 549 if (!fir::isa_trivial(eleTy)) { 550 LLVM_DEBUG(llvm::dbgs() << "AssignOp's data type is not trivial\n"); 551 return std::nullopt; 552 } 553 554 // The array must have the same shape as the elemental. 555 // 556 // f2018 10.2.1.2 (3) requires the lhs and rhs of an assignment to be 557 // conformable unless the lhs is an allocatable array. In HLFIR we can 558 // see this from the presence or absence of the realloc attribute on 559 // hlfir.assign. If it is not a realloc assignment, we can trust that 560 // the shapes do conform. 561 // 562 // TODO: the lhs's shape is dynamic, so it is hard to prove that 563 // there is no reallocation of the lhs due to the assignment. 564 // We can probably try generating multiple versions of the code 565 // with checking for the shape match, length parameters match, etc. 566 if (match.assign.isAllocatableAssignment()) { 567 LLVM_DEBUG(llvm::dbgs() << "AssignOp may involve (re)allocation of LHS\n"); 568 return std::nullopt; 569 } 570 571 // the transformation wants to apply the elemental in a do-loop at the 572 // hlfir.assign, check there are no effects which make this unsafe 573 574 // keep track of any values written to in the elemental, as these can't be 575 // read from between the elemental and the assignment 576 // likewise, values read in the elemental cannot be written to between the 577 // elemental and the assign 578 mlir::SmallVector<mlir::Value, 1> notToBeAccessedBeforeAssign; 579 // any accesses to the array between the array and the assignment means it 580 // would be unsafe to move the elemental to the assignment 581 notToBeAccessedBeforeAssign.push_back(match.array); 582 583 // 1) side effects in the elemental body - it isn't sufficient to just look 584 // for ordered elementals because we also cannot support out of order reads 585 std::optional<mlir::SmallVector<mlir::MemoryEffects::EffectInstance>> 586 effects = getEffectsBetween(&elemental.getBody()->front(), 587 elemental.getBody()->getTerminator()); 588 if (!effects) { 589 LLVM_DEBUG(llvm::dbgs() 590 << "operation with unknown effects inside elemental\n"); 591 return std::nullopt; 592 } 593 for (const mlir::MemoryEffects::EffectInstance &effect : *effects) { 594 mlir::AliasResult res = containsReadOrWriteEffectOn(effect, match.array); 595 if (res.isNo()) { 596 if (mlir::isa<mlir::MemoryEffects::Write, mlir::MemoryEffects::Read>( 597 effect.getEffect())) 598 if (effect.getValue()) 599 notToBeAccessedBeforeAssign.push_back(effect.getValue()); 600 601 // this is safe in the elemental 602 continue; 603 } 604 605 // don't allow any aliasing writes in the elemental 606 if (mlir::isa<mlir::MemoryEffects::Write>(effect.getEffect())) { 607 LLVM_DEBUG(llvm::dbgs() << "write inside the elemental body\n"); 608 return std::nullopt; 609 } 610 611 // allow if and only if the reads are from the elemental indices, in order 612 // => each iteration doesn't read values written by other iterations 613 // don't allow reads from a different value which may alias: fir alias 614 // analysis isn't precise enough to tell us if two aliasing arrays overlap 615 // exactly or only partially. If they overlap partially, a designate at the 616 // elemental indices could be accessing different elements: e.g. we could 617 // designate two slices of the same array at different start indexes. These 618 // two MustAlias but index 1 of one array isn't the same element as index 1 619 // of the other array. 620 if (!res.isPartial()) { 621 if (auto designate = 622 effect.getValue().getDefiningOp<hlfir::DesignateOp>()) { 623 ArraySectionAnalyzer::SlicesOverlapKind overlap = 624 ArraySectionAnalyzer::analyze(match.array, designate.getMemref()); 625 if (overlap == 626 ArraySectionAnalyzer::SlicesOverlapKind::DefinitelyDisjoint) 627 continue; 628 629 if (overlap == ArraySectionAnalyzer::SlicesOverlapKind::Unknown) { 630 LLVM_DEBUG(llvm::dbgs() << "possible read conflict: " << designate 631 << " at " << elemental.getLoc() << "\n"); 632 return std::nullopt; 633 } 634 auto indices = getDesignatorIndices(designate); 635 auto elementalIndices = elemental.getIndices(); 636 if (indices.size() == elementalIndices.size() && 637 std::equal(indices.begin(), indices.end(), elementalIndices.begin(), 638 elementalIndices.end())) 639 continue; 640 641 LLVM_DEBUG(llvm::dbgs() << "possible read conflict: " << designate 642 << " at " << elemental.getLoc() << "\n"); 643 return std::nullopt; 644 } 645 } 646 LLVM_DEBUG(llvm::dbgs() << "disallowed side-effect: " << effect.getValue() 647 << " for " << elemental.getLoc() << "\n"); 648 return std::nullopt; 649 } 650 651 // 2) look for conflicting effects between the elemental and the assignment 652 effects = getEffectsBetween(elemental->getNextNode(), match.assign); 653 if (!effects) { 654 LLVM_DEBUG( 655 llvm::dbgs() 656 << "operation with unknown effects between elemental and assign\n"); 657 return std::nullopt; 658 } 659 for (const mlir::MemoryEffects::EffectInstance &effect : *effects) { 660 // not safe to access anything written in the elemental as this write 661 // will be moved to the assignment 662 for (mlir::Value val : notToBeAccessedBeforeAssign) { 663 mlir::AliasResult res = containsReadOrWriteEffectOn(effect, val); 664 if (!res.isNo()) { 665 LLVM_DEBUG(llvm::dbgs() 666 << "diasllowed side-effect: " << effect.getValue() << " for " 667 << elemental.getLoc() << "\n"); 668 return std::nullopt; 669 } 670 } 671 } 672 673 return match; 674 } 675 676 llvm::LogicalResult ElementalAssignBufferization::matchAndRewrite( 677 hlfir::ElementalOp elemental, mlir::PatternRewriter &rewriter) const { 678 std::optional<MatchInfo> match = findMatch(elemental); 679 if (!match) 680 return rewriter.notifyMatchFailure( 681 elemental, "cannot prove safety of ElementalAssignBufferization"); 682 683 mlir::Location loc = elemental->getLoc(); 684 fir::FirOpBuilder builder(rewriter, elemental.getOperation()); 685 auto extents = hlfir::getIndexExtents(loc, builder, elemental.getShape()); 686 687 // create the loop at the assignment 688 builder.setInsertionPoint(match->assign); 689 690 // Generate a loop nest looping around the hlfir.elemental shape and clone 691 // hlfir.elemental region inside the inner loop 692 hlfir::LoopNest loopNest = 693 hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered(), 694 flangomp::shouldUseWorkshareLowering(elemental)); 695 builder.setInsertionPointToStart(loopNest.body); 696 auto yield = hlfir::inlineElementalOp(loc, builder, elemental, 697 loopNest.oneBasedIndices); 698 hlfir::Entity elementValue{yield.getElementValue()}; 699 rewriter.eraseOp(yield); 700 701 // Assign the element value to the array element for this iteration. 702 auto arrayElement = hlfir::getElementAt( 703 loc, builder, hlfir::Entity{match->array}, loopNest.oneBasedIndices); 704 builder.create<hlfir::AssignOp>( 705 loc, elementValue, arrayElement, /*realloc=*/false, 706 /*keep_lhs_length_if_realloc=*/false, match->assign.getTemporaryLhs()); 707 708 rewriter.eraseOp(match->assign); 709 rewriter.eraseOp(match->destroy); 710 rewriter.eraseOp(elemental); 711 return mlir::success(); 712 } 713 714 /// Expand hlfir.assign of a scalar RHS to array LHS into a loop nest 715 /// of element-by-element assignments: 716 /// hlfir.assign %cst to %0 : f32, !fir.ref<!fir.array<6x6xf32>> 717 /// into: 718 /// fir.do_loop %arg0 = %c1 to %c6 step %c1 unordered { 719 /// fir.do_loop %arg1 = %c1 to %c6 step %c1 unordered { 720 /// %1 = hlfir.designate %0 (%arg1, %arg0) : 721 /// (!fir.ref<!fir.array<6x6xf32>>, index, index) -> !fir.ref<f32> 722 /// hlfir.assign %cst to %1 : f32, !fir.ref<f32> 723 /// } 724 /// } 725 class BroadcastAssignBufferization 726 : public mlir::OpRewritePattern<hlfir::AssignOp> { 727 private: 728 public: 729 using mlir::OpRewritePattern<hlfir::AssignOp>::OpRewritePattern; 730 731 llvm::LogicalResult 732 matchAndRewrite(hlfir::AssignOp assign, 733 mlir::PatternRewriter &rewriter) const override; 734 }; 735 736 llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite( 737 hlfir::AssignOp assign, mlir::PatternRewriter &rewriter) const { 738 // Since RHS is a scalar and LHS is an array, LHS must be allocated 739 // in a conforming Fortran program, and LHS cannot be reallocated 740 // as a result of the assignment. So we can ignore isAllocatableAssignment 741 // and do the transformation always. 742 mlir::Value rhs = assign.getRhs(); 743 if (!fir::isa_trivial(rhs.getType())) 744 return rewriter.notifyMatchFailure( 745 assign, "AssignOp's RHS is not a trivial scalar"); 746 747 hlfir::Entity lhs{assign.getLhs()}; 748 if (!lhs.isArray()) 749 return rewriter.notifyMatchFailure(assign, 750 "AssignOp's LHS is not an array"); 751 752 mlir::Type eleTy = lhs.getFortranElementType(); 753 if (!fir::isa_trivial(eleTy)) 754 return rewriter.notifyMatchFailure( 755 assign, "AssignOp's LHS data type is not trivial"); 756 757 mlir::Location loc = assign->getLoc(); 758 fir::FirOpBuilder builder(rewriter, assign.getOperation()); 759 builder.setInsertionPoint(assign); 760 lhs = hlfir::derefPointersAndAllocatables(loc, builder, lhs); 761 mlir::Value shape = hlfir::genShape(loc, builder, lhs); 762 llvm::SmallVector<mlir::Value> extents = 763 hlfir::getIndexExtents(loc, builder, shape); 764 hlfir::LoopNest loopNest = 765 hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true, 766 flangomp::shouldUseWorkshareLowering(assign)); 767 builder.setInsertionPointToStart(loopNest.body); 768 auto arrayElement = 769 hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices); 770 builder.create<hlfir::AssignOp>(loc, rhs, arrayElement); 771 rewriter.eraseOp(assign); 772 return mlir::success(); 773 } 774 775 using GenBodyFn = 776 std::function<mlir::Value(fir::FirOpBuilder &, mlir::Location, mlir::Value, 777 const llvm::SmallVectorImpl<mlir::Value> &)>; 778 static mlir::Value generateReductionLoop(fir::FirOpBuilder &builder, 779 mlir::Location loc, mlir::Value init, 780 mlir::Value shape, GenBodyFn genBody) { 781 auto extents = hlfir::getIndexExtents(loc, builder, shape); 782 mlir::Value reduction = init; 783 mlir::IndexType idxTy = builder.getIndexType(); 784 mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1); 785 786 // Create a reduction loop nest. We use one-based indices so that they can be 787 // passed to the elemental, and reverse the order so that they can be 788 // generated in column-major order for better performance. 789 llvm::SmallVector<mlir::Value> indices(extents.size(), mlir::Value{}); 790 for (unsigned i = 0; i < extents.size(); ++i) { 791 auto loop = builder.create<fir::DoLoopOp>( 792 loc, oneIdx, extents[extents.size() - i - 1], oneIdx, false, 793 /*finalCountValue=*/false, reduction); 794 reduction = loop.getRegionIterArgs()[0]; 795 indices[extents.size() - i - 1] = loop.getInductionVar(); 796 // Set insertion point to the loop body so that the next loop 797 // is inserted inside the current one. 798 builder.setInsertionPointToStart(loop.getBody()); 799 } 800 801 // Generate the body 802 reduction = genBody(builder, loc, reduction, indices); 803 804 // Unwind the loop nest. 805 for (unsigned i = 0; i < extents.size(); ++i) { 806 auto result = builder.create<fir::ResultOp>(loc, reduction); 807 auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp()); 808 reduction = loop.getResult(0); 809 // Set insertion point after the loop operation that we have 810 // just processed. 811 builder.setInsertionPointAfter(loop.getOperation()); 812 } 813 814 return reduction; 815 } 816 817 auto makeMinMaxInitValGenerator(bool isMax) { 818 return [isMax](fir::FirOpBuilder builder, mlir::Location loc, 819 mlir::Type elementType) -> mlir::Value { 820 if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) { 821 const llvm::fltSemantics &sem = ty.getFloatSemantics(); 822 llvm::APFloat limit = llvm::APFloat::getInf(sem, /*Negative=*/isMax); 823 return builder.createRealConstant(loc, elementType, limit); 824 } 825 unsigned bits = elementType.getIntOrFloatBitWidth(); 826 int64_t limitInt = 827 isMax ? llvm::APInt::getSignedMinValue(bits).getSExtValue() 828 : llvm::APInt::getSignedMaxValue(bits).getSExtValue(); 829 return builder.createIntegerConstant(loc, elementType, limitInt); 830 }; 831 } 832 833 mlir::Value generateMinMaxComparison(fir::FirOpBuilder builder, 834 mlir::Location loc, mlir::Value elem, 835 mlir::Value reduction, bool isMax) { 836 if (mlir::isa<mlir::FloatType>(reduction.getType())) { 837 // For FP reductions we want the first smallest value to be used, that 838 // is not NaN. A OGL/OLT condition will usually work for this unless all 839 // the values are Nan or Inf. This follows the same logic as 840 // NumericCompare for Minloc/Maxlox in extrema.cpp. 841 mlir::Value cmp = builder.create<mlir::arith::CmpFOp>( 842 loc, 843 isMax ? mlir::arith::CmpFPredicate::OGT 844 : mlir::arith::CmpFPredicate::OLT, 845 elem, reduction); 846 mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>( 847 loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction); 848 mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>( 849 loc, mlir::arith::CmpFPredicate::OEQ, elem, elem); 850 cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2); 851 return builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan); 852 } else if (mlir::isa<mlir::IntegerType>(reduction.getType())) { 853 return builder.create<mlir::arith::CmpIOp>( 854 loc, 855 isMax ? mlir::arith::CmpIPredicate::sgt 856 : mlir::arith::CmpIPredicate::slt, 857 elem, reduction); 858 } 859 llvm_unreachable("unsupported type"); 860 } 861 862 /// Given a reduction operation with an elemental/designate source, attempt to 863 /// generate a do-loop to perform the operation inline. 864 /// %e = hlfir.elemental %shape unordered 865 /// %r = hlfir.count %e 866 /// => 867 /// %r = for.do_loop %arg = 1 to bound(%shape) step 1 iter_args(%arg2 = init) 868 /// %i = <inline elemental> 869 /// %c = <reduce count> %i 870 /// fir.result %c 871 template <typename Op> 872 class ReductionConversion : public mlir::OpRewritePattern<Op> { 873 public: 874 using mlir::OpRewritePattern<Op>::OpRewritePattern; 875 876 llvm::LogicalResult 877 matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override { 878 mlir::Location loc = op.getLoc(); 879 // Select source and validate its arguments. 880 mlir::Value source; 881 bool valid = false; 882 if constexpr (std::is_same_v<Op, hlfir::AnyOp> || 883 std::is_same_v<Op, hlfir::AllOp> || 884 std::is_same_v<Op, hlfir::CountOp>) { 885 source = op.getMask(); 886 valid = !op.getDim(); 887 } else if constexpr (std::is_same_v<Op, hlfir::MaxvalOp> || 888 std::is_same_v<Op, hlfir::MinvalOp>) { 889 source = op.getArray(); 890 valid = !op.getDim() && !op.getMask(); 891 } else if constexpr (std::is_same_v<Op, hlfir::MaxlocOp> || 892 std::is_same_v<Op, hlfir::MinlocOp>) { 893 source = op.getArray(); 894 valid = !op.getDim() && !op.getMask() && !op.getBack(); 895 } 896 if (!valid) 897 return rewriter.notifyMatchFailure( 898 op, "Currently does not accept optional arguments"); 899 900 hlfir::ElementalOp elemental; 901 hlfir::DesignateOp designate; 902 mlir::Value shape; 903 if ((elemental = source.template getDefiningOp<hlfir::ElementalOp>())) { 904 shape = elemental.getOperand(0); 905 } else if ((designate = 906 source.template getDefiningOp<hlfir::DesignateOp>())) { 907 shape = designate.getShape(); 908 } else { 909 return rewriter.notifyMatchFailure(op, "Did not find valid argument"); 910 } 911 912 auto inlineSource = 913 [elemental, &designate]( 914 fir::FirOpBuilder builder, mlir::Location loc, 915 const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value { 916 if (elemental) { 917 // Inline the elemental and get the value from it. 918 auto yield = inlineElementalOp(loc, builder, elemental, indices); 919 auto tmp = yield.getElementValue(); 920 yield->erase(); 921 return tmp; 922 } 923 if (designate) { 924 // Create a designator over designator, then load the reference. 925 auto resEntity = hlfir::Entity{designate.getResult()}; 926 auto tmp = builder.create<hlfir::DesignateOp>( 927 loc, getVariableElementType(resEntity), designate, indices); 928 return builder.create<fir::LoadOp>(loc, tmp); 929 } 930 llvm_unreachable("unsupported type"); 931 }; 932 933 fir::KindMapping kindMap = 934 fir::getKindMapping(op->template getParentOfType<mlir::ModuleOp>()); 935 fir::FirOpBuilder builder{op, kindMap}; 936 937 mlir::Value init; 938 GenBodyFn genBodyFn; 939 if constexpr (std::is_same_v<Op, hlfir::AnyOp>) { 940 init = builder.createIntegerConstant(loc, builder.getI1Type(), 0); 941 genBodyFn = 942 [inlineSource](fir::FirOpBuilder builder, mlir::Location loc, 943 mlir::Value reduction, 944 const llvm::SmallVectorImpl<mlir::Value> &indices) 945 -> mlir::Value { 946 // Conditionally set the reduction variable. 947 mlir::Value cond = builder.create<fir::ConvertOp>( 948 loc, builder.getI1Type(), inlineSource(builder, loc, indices)); 949 return builder.create<mlir::arith::OrIOp>(loc, reduction, cond); 950 }; 951 } else if constexpr (std::is_same_v<Op, hlfir::AllOp>) { 952 init = builder.createIntegerConstant(loc, builder.getI1Type(), 1); 953 genBodyFn = 954 [inlineSource](fir::FirOpBuilder builder, mlir::Location loc, 955 mlir::Value reduction, 956 const llvm::SmallVectorImpl<mlir::Value> &indices) 957 -> mlir::Value { 958 // Conditionally set the reduction variable. 959 mlir::Value cond = builder.create<fir::ConvertOp>( 960 loc, builder.getI1Type(), inlineSource(builder, loc, indices)); 961 return builder.create<mlir::arith::AndIOp>(loc, reduction, cond); 962 }; 963 } else if constexpr (std::is_same_v<Op, hlfir::CountOp>) { 964 init = builder.createIntegerConstant(loc, op.getType(), 0); 965 genBodyFn = 966 [inlineSource](fir::FirOpBuilder builder, mlir::Location loc, 967 mlir::Value reduction, 968 const llvm::SmallVectorImpl<mlir::Value> &indices) 969 -> mlir::Value { 970 // Conditionally add one to the current value 971 mlir::Value cond = builder.create<fir::ConvertOp>( 972 loc, builder.getI1Type(), inlineSource(builder, loc, indices)); 973 mlir::Value one = 974 builder.createIntegerConstant(loc, reduction.getType(), 1); 975 mlir::Value add1 = 976 builder.create<mlir::arith::AddIOp>(loc, reduction, one); 977 return builder.create<mlir::arith::SelectOp>(loc, cond, add1, 978 reduction); 979 }; 980 } else if constexpr (std::is_same_v<Op, hlfir::MaxlocOp> || 981 std::is_same_v<Op, hlfir::MinlocOp>) { 982 // TODO: implement minloc/maxloc conversion. 983 return rewriter.notifyMatchFailure( 984 op, "Currently minloc/maxloc is not handled"); 985 } else if constexpr (std::is_same_v<Op, hlfir::MaxvalOp> || 986 std::is_same_v<Op, hlfir::MinvalOp>) { 987 bool isMax = std::is_same_v<Op, hlfir::MaxvalOp>; 988 init = makeMinMaxInitValGenerator(isMax)(builder, loc, op.getType()); 989 genBodyFn = [inlineSource, 990 isMax](fir::FirOpBuilder builder, mlir::Location loc, 991 mlir::Value reduction, 992 const llvm::SmallVectorImpl<mlir::Value> &indices) 993 -> mlir::Value { 994 mlir::Value val = inlineSource(builder, loc, indices); 995 mlir::Value cmp = 996 generateMinMaxComparison(builder, loc, val, reduction, isMax); 997 return builder.create<mlir::arith::SelectOp>(loc, cmp, val, reduction); 998 }; 999 } else { 1000 llvm_unreachable("unsupported type"); 1001 } 1002 1003 mlir::Value res = 1004 generateReductionLoop(builder, loc, init, shape, genBodyFn); 1005 if (res.getType() != op.getType()) 1006 res = builder.create<fir::ConvertOp>(loc, op.getType(), res); 1007 1008 // Check if the op was the only user of the source (apart from a destroy), 1009 // and remove it if so. 1010 mlir::Operation *sourceOp = source.getDefiningOp(); 1011 mlir::Operation::user_range srcUsers = sourceOp->getUsers(); 1012 hlfir::DestroyOp srcDestroy; 1013 if (std::distance(srcUsers.begin(), srcUsers.end()) == 2) { 1014 srcDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*srcUsers.begin()); 1015 if (!srcDestroy) 1016 srcDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++srcUsers.begin()); 1017 } 1018 1019 rewriter.replaceOp(op, res); 1020 if (srcDestroy) { 1021 rewriter.eraseOp(srcDestroy); 1022 rewriter.eraseOp(sourceOp); 1023 } 1024 return mlir::success(); 1025 } 1026 }; 1027 1028 // Look for minloc(mask=elemental) and generate the minloc loop with 1029 // inlined elemental. 1030 // %e = hlfir.elemental %shape ({ ... }) 1031 // %m = hlfir.minloc %array mask %e 1032 template <typename Op> 1033 class ReductionMaskConversion : public mlir::OpRewritePattern<Op> { 1034 public: 1035 using mlir::OpRewritePattern<Op>::OpRewritePattern; 1036 1037 llvm::LogicalResult 1038 matchAndRewrite(Op mloc, mlir::PatternRewriter &rewriter) const override { 1039 if (!mloc.getMask() || mloc.getDim() || mloc.getBack()) 1040 return rewriter.notifyMatchFailure(mloc, 1041 "Did not find valid minloc/maxloc"); 1042 1043 bool isMax = std::is_same_v<Op, hlfir::MaxlocOp>; 1044 1045 auto elemental = 1046 mloc.getMask().template getDefiningOp<hlfir::ElementalOp>(); 1047 if (!elemental || hlfir::elementalOpMustProduceTemp(elemental)) 1048 return rewriter.notifyMatchFailure(mloc, "Did not find elemental"); 1049 1050 mlir::Value array = mloc.getArray(); 1051 1052 unsigned rank = mlir::cast<hlfir::ExprType>(mloc.getType()).getShape()[0]; 1053 mlir::Type arrayType = array.getType(); 1054 if (!mlir::isa<fir::BoxType>(arrayType)) 1055 return rewriter.notifyMatchFailure( 1056 mloc, "Currently requires a boxed type input"); 1057 mlir::Type elementType = hlfir::getFortranElementType(arrayType); 1058 if (!fir::isa_trivial(elementType)) 1059 return rewriter.notifyMatchFailure( 1060 mloc, "Character arrays are currently not handled"); 1061 1062 mlir::Location loc = mloc.getLoc(); 1063 fir::FirOpBuilder builder{rewriter, mloc.getOperation()}; 1064 mlir::Value resultArr = builder.createTemporary( 1065 loc, fir::SequenceType::get( 1066 rank, hlfir::getFortranElementType(mloc.getType()))); 1067 1068 auto init = makeMinMaxInitValGenerator(isMax); 1069 1070 auto genBodyOp = 1071 [&rank, &resultArr, &elemental, isMax]( 1072 fir::FirOpBuilder builder, mlir::Location loc, 1073 mlir::Type elementType, mlir::Value array, mlir::Value flagRef, 1074 mlir::Value reduction, 1075 const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value { 1076 // We are in the innermost loop: generate the elemental inline 1077 mlir::Value oneIdx = 1078 builder.createIntegerConstant(loc, builder.getIndexType(), 1); 1079 llvm::SmallVector<mlir::Value> oneBasedIndices; 1080 llvm::transform( 1081 indices, std::back_inserter(oneBasedIndices), [&](mlir::Value V) { 1082 return builder.create<mlir::arith::AddIOp>(loc, V, oneIdx); 1083 }); 1084 hlfir::YieldElementOp yield = 1085 hlfir::inlineElementalOp(loc, builder, elemental, oneBasedIndices); 1086 mlir::Value maskElem = yield.getElementValue(); 1087 yield->erase(); 1088 1089 mlir::Type ifCompatType = builder.getI1Type(); 1090 mlir::Value ifCompatElem = 1091 builder.create<fir::ConvertOp>(loc, ifCompatType, maskElem); 1092 1093 llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType}; 1094 fir::IfOp maskIfOp = 1095 builder.create<fir::IfOp>(loc, elementType, ifCompatElem, 1096 /*withElseRegion=*/true); 1097 builder.setInsertionPointToStart(&maskIfOp.getThenRegion().front()); 1098 1099 // Set flag that mask was true at some point 1100 mlir::Value flagSet = builder.createIntegerConstant( 1101 loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1); 1102 mlir::Value isFirst = builder.create<fir::LoadOp>(loc, flagRef); 1103 mlir::Value addr = hlfir::getElementAt(loc, builder, hlfir::Entity{array}, 1104 oneBasedIndices); 1105 mlir::Value elem = builder.create<fir::LoadOp>(loc, addr); 1106 1107 // Compare with the max reduction value 1108 mlir::Value cmp = 1109 generateMinMaxComparison(builder, loc, elem, reduction, isMax); 1110 1111 // The condition used for the loop is isFirst || <the condition above>. 1112 isFirst = builder.create<fir::ConvertOp>(loc, cmp.getType(), isFirst); 1113 isFirst = builder.create<mlir::arith::XOrIOp>( 1114 loc, isFirst, builder.createIntegerConstant(loc, cmp.getType(), 1)); 1115 cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, isFirst); 1116 1117 // Set the new coordinate to the result 1118 fir::IfOp ifOp = builder.create<fir::IfOp>(loc, elementType, cmp, 1119 /*withElseRegion*/ true); 1120 1121 builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); 1122 builder.create<fir::StoreOp>(loc, flagSet, flagRef); 1123 mlir::Type resultElemTy = 1124 hlfir::getFortranElementType(resultArr.getType()); 1125 mlir::Type returnRefTy = builder.getRefType(resultElemTy); 1126 mlir::IndexType idxTy = builder.getIndexType(); 1127 1128 for (unsigned int i = 0; i < rank; ++i) { 1129 mlir::Value index = builder.createIntegerConstant(loc, idxTy, i + 1); 1130 mlir::Value resultElemAddr = builder.create<hlfir::DesignateOp>( 1131 loc, returnRefTy, resultArr, index); 1132 mlir::Value fortranIndex = builder.create<fir::ConvertOp>( 1133 loc, resultElemTy, oneBasedIndices[i]); 1134 builder.create<fir::StoreOp>(loc, fortranIndex, resultElemAddr); 1135 } 1136 builder.create<fir::ResultOp>(loc, elem); 1137 builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); 1138 builder.create<fir::ResultOp>(loc, reduction); 1139 builder.setInsertionPointAfter(ifOp); 1140 1141 // Close the mask if 1142 builder.create<fir::ResultOp>(loc, ifOp.getResult(0)); 1143 builder.setInsertionPointToStart(&maskIfOp.getElseRegion().front()); 1144 builder.create<fir::ResultOp>(loc, reduction); 1145 builder.setInsertionPointAfter(maskIfOp); 1146 1147 return maskIfOp.getResult(0); 1148 }; 1149 auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc, 1150 const mlir::Type &resultElemType, mlir::Value resultArr, 1151 mlir::Value index) { 1152 mlir::Type resultRefTy = builder.getRefType(resultElemType); 1153 mlir::Value oneIdx = 1154 builder.createIntegerConstant(loc, builder.getIndexType(), 1); 1155 index = builder.create<mlir::arith::AddIOp>(loc, index, oneIdx); 1156 return builder.create<hlfir::DesignateOp>(loc, resultRefTy, resultArr, 1157 index); 1158 }; 1159 1160 // Initialize the result 1161 mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType()); 1162 mlir::Type resultRefTy = builder.getRefType(resultElemTy); 1163 mlir::Value returnValue = 1164 builder.createIntegerConstant(loc, resultElemTy, 0); 1165 for (unsigned int i = 0; i < rank; ++i) { 1166 mlir::Value index = 1167 builder.createIntegerConstant(loc, builder.getIndexType(), i + 1); 1168 mlir::Value resultElemAddr = builder.create<hlfir::DesignateOp>( 1169 loc, resultRefTy, resultArr, index); 1170 builder.create<fir::StoreOp>(loc, returnValue, resultElemAddr); 1171 } 1172 1173 fir::genMinMaxlocReductionLoop(builder, array, init, genBodyOp, getAddrFn, 1174 rank, elementType, loc, builder.getI1Type(), 1175 resultArr, false); 1176 1177 mlir::Value asExpr = builder.create<hlfir::AsExprOp>( 1178 loc, resultArr, builder.createBool(loc, false)); 1179 1180 // Check all the users - the destroy is no longer required, and any assign 1181 // can use resultArr directly so that InlineHLFIRAssign pass 1182 // can optimize the results. Other operations are replaced with an AsExpr 1183 // for the temporary resultArr. 1184 llvm::SmallVector<hlfir::DestroyOp> destroys; 1185 llvm::SmallVector<hlfir::AssignOp> assigns; 1186 for (auto user : mloc->getUsers()) { 1187 if (auto destroy = mlir::dyn_cast<hlfir::DestroyOp>(user)) 1188 destroys.push_back(destroy); 1189 else if (auto assign = mlir::dyn_cast<hlfir::AssignOp>(user)) 1190 assigns.push_back(assign); 1191 } 1192 1193 // Check if the minloc/maxloc was the only user of the elemental (apart from 1194 // a destroy), and remove it if so. 1195 mlir::Operation::user_range elemUsers = elemental->getUsers(); 1196 hlfir::DestroyOp elemDestroy; 1197 if (std::distance(elemUsers.begin(), elemUsers.end()) == 2) { 1198 elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*elemUsers.begin()); 1199 if (!elemDestroy) 1200 elemDestroy = mlir::dyn_cast<hlfir::DestroyOp>(*++elemUsers.begin()); 1201 } 1202 1203 for (auto d : destroys) 1204 rewriter.eraseOp(d); 1205 for (auto a : assigns) 1206 a.setOperand(0, resultArr); 1207 rewriter.replaceOp(mloc, asExpr); 1208 if (elemDestroy) { 1209 rewriter.eraseOp(elemDestroy); 1210 rewriter.eraseOp(elemental); 1211 } 1212 return mlir::success(); 1213 } 1214 }; 1215 1216 class EvaluateIntoMemoryAssignBufferization 1217 : public mlir::OpRewritePattern<hlfir::EvaluateInMemoryOp> { 1218 1219 public: 1220 using mlir::OpRewritePattern<hlfir::EvaluateInMemoryOp>::OpRewritePattern; 1221 1222 llvm::LogicalResult 1223 matchAndRewrite(hlfir::EvaluateInMemoryOp, 1224 mlir::PatternRewriter &rewriter) const override; 1225 }; 1226 1227 static llvm::LogicalResult 1228 tryUsingAssignLhsDirectly(hlfir::EvaluateInMemoryOp evalInMem, 1229 mlir::PatternRewriter &rewriter) { 1230 mlir::Location loc = evalInMem.getLoc(); 1231 hlfir::DestroyOp destroy; 1232 hlfir::AssignOp assign; 1233 for (auto user : llvm::enumerate(evalInMem->getUsers())) { 1234 if (user.index() > 2) 1235 return mlir::failure(); 1236 mlir::TypeSwitch<mlir::Operation *, void>(user.value()) 1237 .Case([&](hlfir::AssignOp op) { assign = op; }) 1238 .Case([&](hlfir::DestroyOp op) { destroy = op; }); 1239 } 1240 if (!assign || !destroy || destroy.mustFinalizeExpr() || 1241 assign.isAllocatableAssignment()) 1242 return mlir::failure(); 1243 1244 hlfir::Entity lhs{assign.getLhs()}; 1245 // EvaluateInMemoryOp memory is contiguous, so in general, it can only be 1246 // replace by the LHS if the LHS is contiguous. 1247 if (!lhs.isSimplyContiguous()) 1248 return mlir::failure(); 1249 // Character assignment may involves truncation/padding, so the LHS 1250 // cannot be used to evaluate RHS in place without proving the LHS and 1251 // RHS lengths are the same. 1252 if (lhs.isCharacter()) 1253 return mlir::failure(); 1254 fir::AliasAnalysis aliasAnalysis; 1255 // The region must not read or write the LHS. 1256 // Note that getModRef is used instead of mlir::MemoryEffects because 1257 // EvaluateInMemoryOp is typically expected to hold fir.calls and that 1258 // Fortran calls cannot be modeled in a useful way with mlir::MemoryEffects: 1259 // it is hard/impossible to list all the read/written SSA values in a call, 1260 // but it is often possible to tell that an SSA value cannot be accessed, 1261 // hence getModRef is needed here and below. Also note that getModRef uses 1262 // mlir::MemoryEffects for operations that do not have special handling in 1263 // getModRef. 1264 if (aliasAnalysis.getModRef(evalInMem.getBody(), lhs).isModOrRef()) 1265 return mlir::failure(); 1266 // Any variables affected between the hlfir.evalInMem and assignment must not 1267 // be read or written inside the region since it will be moved at the 1268 // assignment insertion point. 1269 auto effects = getEffectsBetween(evalInMem->getNextNode(), assign); 1270 if (!effects) { 1271 LLVM_DEBUG( 1272 llvm::dbgs() 1273 << "operation with unknown effects between eval_in_mem and assign\n"); 1274 return mlir::failure(); 1275 } 1276 for (const mlir::MemoryEffects::EffectInstance &effect : *effects) { 1277 mlir::Value affected = effect.getValue(); 1278 if (!affected || 1279 aliasAnalysis.getModRef(evalInMem.getBody(), affected).isModOrRef()) 1280 return mlir::failure(); 1281 } 1282 1283 rewriter.setInsertionPoint(assign); 1284 fir::FirOpBuilder builder(rewriter, evalInMem.getOperation()); 1285 mlir::Value rawLhs = hlfir::genVariableRawAddress(loc, builder, lhs); 1286 hlfir::computeEvaluateOpIn(loc, builder, evalInMem, rawLhs); 1287 rewriter.eraseOp(assign); 1288 rewriter.eraseOp(destroy); 1289 rewriter.eraseOp(evalInMem); 1290 return mlir::success(); 1291 } 1292 1293 llvm::LogicalResult EvaluateIntoMemoryAssignBufferization::matchAndRewrite( 1294 hlfir::EvaluateInMemoryOp evalInMem, 1295 mlir::PatternRewriter &rewriter) const { 1296 if (mlir::succeeded(tryUsingAssignLhsDirectly(evalInMem, rewriter))) 1297 return mlir::success(); 1298 // Rewrite to temp + as_expr here so that the assign + as_expr pattern can 1299 // kick-in for simple types and at least implement the assignment inline 1300 // instead of call Assign runtime. 1301 fir::FirOpBuilder builder(rewriter, evalInMem.getOperation()); 1302 mlir::Location loc = evalInMem.getLoc(); 1303 auto [temp, isHeapAllocated] = hlfir::computeEvaluateOpInNewTemp( 1304 loc, builder, evalInMem, evalInMem.getShape(), evalInMem.getTypeparams()); 1305 rewriter.replaceOpWithNewOp<hlfir::AsExprOp>( 1306 evalInMem, temp, /*mustFree=*/builder.createBool(loc, isHeapAllocated)); 1307 return mlir::success(); 1308 } 1309 1310 class OptimizedBufferizationPass 1311 : public hlfir::impl::OptimizedBufferizationBase< 1312 OptimizedBufferizationPass> { 1313 public: 1314 void runOnOperation() override { 1315 mlir::MLIRContext *context = &getContext(); 1316 1317 mlir::GreedyRewriteConfig config; 1318 // Prevent the pattern driver from merging blocks 1319 config.enableRegionSimplification = 1320 mlir::GreedySimplifyRegionLevel::Disabled; 1321 1322 mlir::RewritePatternSet patterns(context); 1323 // TODO: right now the patterns are non-conflicting, 1324 // but it might be better to run this pass on hlfir.assign 1325 // operations and decide which transformation to apply 1326 // at one place (e.g. we may use some heuristics and 1327 // choose different optimization strategies). 1328 // This requires small code reordering in ElementalAssignBufferization. 1329 patterns.insert<ElementalAssignBufferization>(context); 1330 patterns.insert<BroadcastAssignBufferization>(context); 1331 patterns.insert<EvaluateIntoMemoryAssignBufferization>(context); 1332 patterns.insert<ReductionConversion<hlfir::CountOp>>(context); 1333 patterns.insert<ReductionConversion<hlfir::AnyOp>>(context); 1334 patterns.insert<ReductionConversion<hlfir::AllOp>>(context); 1335 // TODO: implement basic minloc/maxloc conversion. 1336 // patterns.insert<ReductionConversion<hlfir::MaxlocOp>>(context); 1337 // patterns.insert<ReductionConversion<hlfir::MinlocOp>>(context); 1338 patterns.insert<ReductionConversion<hlfir::MaxvalOp>>(context); 1339 patterns.insert<ReductionConversion<hlfir::MinvalOp>>(context); 1340 patterns.insert<ReductionMaskConversion<hlfir::MinlocOp>>(context); 1341 patterns.insert<ReductionMaskConversion<hlfir::MaxlocOp>>(context); 1342 // TODO: implement masked minval/maxval conversion. 1343 // patterns.insert<ReductionMaskConversion<hlfir::MaxvalOp>>(context); 1344 // patterns.insert<ReductionMaskConversion<hlfir::MinvalOp>>(context); 1345 1346 if (mlir::failed(mlir::applyPatternsGreedily( 1347 getOperation(), std::move(patterns), config))) { 1348 mlir::emitError(getOperation()->getLoc(), 1349 "failure in HLFIR optimized bufferization"); 1350 signalPassFailure(); 1351 } 1352 } 1353 }; 1354 } // namespace 1355