1 //===- LowerHLFIRIntrinsics.cpp - Transformational intrinsics to FIR ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Optimizer/Builder/FIRBuilder.h" 10 #include "flang/Optimizer/Builder/HLFIRTools.h" 11 #include "flang/Optimizer/Builder/IntrinsicCall.h" 12 #include "flang/Optimizer/Builder/Todo.h" 13 #include "flang/Optimizer/Dialect/FIRDialect.h" 14 #include "flang/Optimizer/Dialect/FIROps.h" 15 #include "flang/Optimizer/Dialect/FIRType.h" 16 #include "flang/Optimizer/Dialect/Support/FIRContext.h" 17 #include "flang/Optimizer/HLFIR/HLFIRDialect.h" 18 #include "flang/Optimizer/HLFIR/HLFIROps.h" 19 #include "flang/Optimizer/HLFIR/Passes.h" 20 #include "mlir/IR/BuiltinDialect.h" 21 #include "mlir/IR/MLIRContext.h" 22 #include "mlir/IR/PatternMatch.h" 23 #include "mlir/Pass/Pass.h" 24 #include "mlir/Pass/PassManager.h" 25 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 26 #include <optional> 27 28 namespace hlfir { 29 #define GEN_PASS_DEF_LOWERHLFIRINTRINSICS 30 #include "flang/Optimizer/HLFIR/Passes.h.inc" 31 } // namespace hlfir 32 33 namespace { 34 35 /// Base class for passes converting transformational intrinsic operations into 36 /// runtime calls 37 template <class OP> 38 class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> { 39 public: 40 explicit HlfirIntrinsicConversion(mlir::MLIRContext *ctx) 41 : mlir::OpRewritePattern<OP>{ctx} { 42 // required for cases where intrinsics are chained together e.g. 43 // matmul(matmul(a, b), c) 44 // because converting the inner operation then invalidates the 45 // outer operation: causing the pattern to apply recursively. 46 // 47 // This is safe because we always progress with each iteration. Circular 48 // applications of operations are not expressible in MLIR because we use 49 // an SSA form and one must become first. E.g. 50 // %a = hlfir.matmul %b %d 51 // %b = hlfir.matmul %a %d 52 // cannot be written. 53 // MSVC needs the this-> 54 this->setHasBoundedRewriteRecursion(true); 55 } 56 57 protected: 58 struct IntrinsicArgument { 59 mlir::Value val; // allowed to be null if the argument is absent 60 mlir::Type desiredType; 61 }; 62 63 /// Lower the arguments to the intrinsic: adding necessary boxing and 64 /// conversion to match the signature of the intrinsic in the runtime library. 65 llvm::SmallVector<fir::ExtendedValue, 3> 66 lowerArguments(mlir::Operation *op, 67 const llvm::ArrayRef<IntrinsicArgument> &args, 68 mlir::PatternRewriter &rewriter, 69 const fir::IntrinsicArgumentLoweringRules *argLowering) const { 70 mlir::Location loc = op->getLoc(); 71 fir::FirOpBuilder builder{rewriter, op}; 72 73 llvm::SmallVector<fir::ExtendedValue, 3> ret; 74 llvm::SmallVector<std::function<void()>, 2> cleanupFns; 75 76 for (size_t i = 0; i < args.size(); ++i) { 77 mlir::Value arg = args[i].val; 78 mlir::Type desiredType = args[i].desiredType; 79 if (!arg) { 80 ret.emplace_back(fir::getAbsentIntrinsicArgument()); 81 continue; 82 } 83 hlfir::Entity entity{arg}; 84 85 fir::ArgLoweringRule argRules = 86 fir::lowerIntrinsicArgumentAs(*argLowering, i); 87 switch (argRules.lowerAs) { 88 case fir::LowerIntrinsicArgAs::Value: { 89 if (args[i].desiredType != arg.getType()) { 90 arg = builder.createConvert(loc, desiredType, arg); 91 entity = hlfir::Entity{arg}; 92 } 93 auto [exv, cleanup] = hlfir::convertToValue(loc, builder, entity); 94 if (cleanup) 95 cleanupFns.push_back(*cleanup); 96 ret.emplace_back(exv); 97 } break; 98 case fir::LowerIntrinsicArgAs::Addr: { 99 auto [exv, cleanup] = 100 hlfir::convertToAddress(loc, builder, entity, desiredType); 101 if (cleanup) 102 cleanupFns.push_back(*cleanup); 103 ret.emplace_back(exv); 104 } break; 105 case fir::LowerIntrinsicArgAs::Box: { 106 auto [box, cleanup] = 107 hlfir::convertToBox(loc, builder, entity, desiredType); 108 if (cleanup) 109 cleanupFns.push_back(*cleanup); 110 ret.emplace_back(box); 111 } break; 112 case fir::LowerIntrinsicArgAs::Inquired: { 113 if (args[i].desiredType != arg.getType()) { 114 arg = builder.createConvert(loc, desiredType, arg); 115 entity = hlfir::Entity{arg}; 116 } 117 // Place hlfir.expr in memory, and unbox fir.boxchar. Other entities 118 // are translated to fir::ExtendedValue without transofrmation (notably, 119 // pointers/allocatable are not dereferenced). 120 // TODO: once lowering to FIR retires, UBOUND and LBOUND can be 121 // simplified since the fir.box lowered here are now guarenteed to 122 // contain the local lower bounds thanks to the hlfir.declare (the extra 123 // rebox can be removed). 124 auto [exv, cleanup] = 125 hlfir::translateToExtendedValue(loc, builder, entity); 126 if (cleanup) 127 cleanupFns.push_back(*cleanup); 128 ret.emplace_back(exv); 129 } break; 130 } 131 } 132 133 if (cleanupFns.size()) { 134 auto oldInsertionPoint = builder.saveInsertionPoint(); 135 builder.setInsertionPointAfter(op); 136 for (std::function<void()> cleanup : cleanupFns) 137 cleanup(); 138 builder.restoreInsertionPoint(oldInsertionPoint); 139 } 140 141 return ret; 142 } 143 144 void processReturnValue(mlir::Operation *op, 145 const fir::ExtendedValue &resultExv, bool mustBeFreed, 146 fir::FirOpBuilder &builder, 147 mlir::PatternRewriter &rewriter) const { 148 mlir::Location loc = op->getLoc(); 149 150 mlir::Value firBase = fir::getBase(resultExv); 151 mlir::Type firBaseTy = firBase.getType(); 152 153 std::optional<hlfir::EntityWithAttributes> resultEntity; 154 if (fir::isa_trivial(firBaseTy)) { 155 // Some intrinsics return i1 when the original operation 156 // produces fir.logical<>, so we may need to cast it. 157 firBase = builder.createConvert(loc, op->getResult(0).getType(), firBase); 158 resultEntity = hlfir::EntityWithAttributes{firBase}; 159 } else { 160 resultEntity = 161 hlfir::genDeclare(loc, builder, resultExv, ".tmp.intrinsic_result", 162 fir::FortranVariableFlagsAttr{}); 163 } 164 165 if (resultEntity->isVariable()) { 166 hlfir::AsExprOp asExpr = builder.create<hlfir::AsExprOp>( 167 loc, *resultEntity, builder.createBool(loc, mustBeFreed)); 168 resultEntity = hlfir::EntityWithAttributes{asExpr.getResult()}; 169 } 170 171 mlir::Value base = resultEntity->getBase(); 172 if (!mlir::isa<hlfir::ExprType>(base.getType())) { 173 for (mlir::Operation *use : op->getResult(0).getUsers()) { 174 if (mlir::isa<hlfir::DestroyOp>(use)) 175 rewriter.eraseOp(use); 176 } 177 } 178 179 rewriter.replaceOp(op, base); 180 } 181 }; 182 183 // Given an integer or array of integer type, calculate the Kind parameter from 184 // the width for use in runtime intrinsic calls. 185 static unsigned getKindForType(mlir::Type ty) { 186 mlir::Type eltty = hlfir::getFortranElementType(ty); 187 unsigned width = mlir::cast<mlir::IntegerType>(eltty).getWidth(); 188 return width / 8; 189 } 190 191 template <class OP> 192 class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> { 193 using HlfirIntrinsicConversion<OP>::HlfirIntrinsicConversion; 194 using IntrinsicArgument = 195 typename HlfirIntrinsicConversion<OP>::IntrinsicArgument; 196 using HlfirIntrinsicConversion<OP>::lowerArguments; 197 using HlfirIntrinsicConversion<OP>::processReturnValue; 198 199 protected: 200 auto buildNumericalArgs(OP operation, mlir::Type i32, mlir::Type logicalType, 201 mlir::PatternRewriter &rewriter, 202 std::string opName) const { 203 llvm::SmallVector<IntrinsicArgument, 3> inArgs; 204 inArgs.push_back({operation.getArray(), operation.getArray().getType()}); 205 inArgs.push_back({operation.getDim(), i32}); 206 inArgs.push_back({operation.getMask(), logicalType}); 207 auto *argLowering = fir::getIntrinsicArgumentLowering(opName); 208 return lowerArguments(operation, inArgs, rewriter, argLowering); 209 }; 210 211 auto buildMinMaxLocArgs(OP operation, mlir::Type i32, mlir::Type logicalType, 212 mlir::PatternRewriter &rewriter, std::string opName, 213 fir::FirOpBuilder builder) const { 214 llvm::SmallVector<IntrinsicArgument, 3> inArgs; 215 inArgs.push_back({operation.getArray(), operation.getArray().getType()}); 216 inArgs.push_back({operation.getDim(), i32}); 217 inArgs.push_back({operation.getMask(), logicalType}); 218 mlir::Value kind = builder.createIntegerConstant( 219 operation->getLoc(), i32, getKindForType(operation.getType())); 220 inArgs.push_back({kind, i32}); 221 inArgs.push_back({operation.getBack(), i32}); 222 auto *argLowering = fir::getIntrinsicArgumentLowering(opName); 223 return lowerArguments(operation, inArgs, rewriter, argLowering); 224 }; 225 226 auto buildLogicalArgs(OP operation, mlir::Type i32, mlir::Type logicalType, 227 mlir::PatternRewriter &rewriter, 228 std::string opName) const { 229 llvm::SmallVector<IntrinsicArgument, 2> inArgs; 230 inArgs.push_back({operation.getMask(), logicalType}); 231 inArgs.push_back({operation.getDim(), i32}); 232 auto *argLowering = fir::getIntrinsicArgumentLowering(opName); 233 return lowerArguments(operation, inArgs, rewriter, argLowering); 234 }; 235 236 public: 237 llvm::LogicalResult 238 matchAndRewrite(OP operation, 239 mlir::PatternRewriter &rewriter) const override { 240 std::string opName; 241 if constexpr (std::is_same_v<OP, hlfir::SumOp>) { 242 opName = "sum"; 243 } else if constexpr (std::is_same_v<OP, hlfir::ProductOp>) { 244 opName = "product"; 245 } else if constexpr (std::is_same_v<OP, hlfir::MaxvalOp>) { 246 opName = "maxval"; 247 } else if constexpr (std::is_same_v<OP, hlfir::MinvalOp>) { 248 opName = "minval"; 249 } else if constexpr (std::is_same_v<OP, hlfir::MinlocOp>) { 250 opName = "minloc"; 251 } else if constexpr (std::is_same_v<OP, hlfir::MaxlocOp>) { 252 opName = "maxloc"; 253 } else if constexpr (std::is_same_v<OP, hlfir::AnyOp>) { 254 opName = "any"; 255 } else if constexpr (std::is_same_v<OP, hlfir::AllOp>) { 256 opName = "all"; 257 } else { 258 return mlir::failure(); 259 } 260 261 fir::FirOpBuilder builder{rewriter, operation.getOperation()}; 262 const mlir::Location &loc = operation->getLoc(); 263 264 mlir::Type i32 = builder.getI32Type(); 265 mlir::Type logicalType = fir::LogicalType::get( 266 builder.getContext(), builder.getKindMap().defaultLogicalKind()); 267 268 llvm::SmallVector<fir::ExtendedValue, 0> args; 269 270 if constexpr (std::is_same_v<OP, hlfir::SumOp> || 271 std::is_same_v<OP, hlfir::ProductOp> || 272 std::is_same_v<OP, hlfir::MaxvalOp> || 273 std::is_same_v<OP, hlfir::MinvalOp>) { 274 args = buildNumericalArgs(operation, i32, logicalType, rewriter, opName); 275 } else if constexpr (std::is_same_v<OP, hlfir::MinlocOp> || 276 std::is_same_v<OP, hlfir::MaxlocOp>) { 277 args = buildMinMaxLocArgs(operation, i32, logicalType, rewriter, opName, 278 builder); 279 } else { 280 args = buildLogicalArgs(operation, i32, logicalType, rewriter, opName); 281 } 282 283 mlir::Type scalarResultType = 284 hlfir::getFortranElementType(operation.getType()); 285 286 auto [resultExv, mustBeFreed] = 287 fir::genIntrinsicCall(builder, loc, opName, scalarResultType, args); 288 289 processReturnValue(operation, resultExv, mustBeFreed, builder, rewriter); 290 return mlir::success(); 291 } 292 }; 293 294 using SumOpConversion = HlfirReductionIntrinsicConversion<hlfir::SumOp>; 295 296 using ProductOpConversion = HlfirReductionIntrinsicConversion<hlfir::ProductOp>; 297 298 using MaxvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxvalOp>; 299 300 using MinvalOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinvalOp>; 301 302 using MinlocOpConversion = HlfirReductionIntrinsicConversion<hlfir::MinlocOp>; 303 304 using MaxlocOpConversion = HlfirReductionIntrinsicConversion<hlfir::MaxlocOp>; 305 306 using AnyOpConversion = HlfirReductionIntrinsicConversion<hlfir::AnyOp>; 307 308 using AllOpConversion = HlfirReductionIntrinsicConversion<hlfir::AllOp>; 309 310 struct CountOpConversion : public HlfirIntrinsicConversion<hlfir::CountOp> { 311 using HlfirIntrinsicConversion<hlfir::CountOp>::HlfirIntrinsicConversion; 312 313 llvm::LogicalResult 314 matchAndRewrite(hlfir::CountOp count, 315 mlir::PatternRewriter &rewriter) const override { 316 fir::FirOpBuilder builder{rewriter, count.getOperation()}; 317 const mlir::Location &loc = count->getLoc(); 318 319 mlir::Type i32 = builder.getI32Type(); 320 mlir::Type logicalType = fir::LogicalType::get( 321 builder.getContext(), builder.getKindMap().defaultLogicalKind()); 322 323 llvm::SmallVector<IntrinsicArgument, 3> inArgs; 324 inArgs.push_back({count.getMask(), logicalType}); 325 inArgs.push_back({count.getDim(), i32}); 326 mlir::Value kind = builder.createIntegerConstant( 327 count->getLoc(), i32, getKindForType(count.getType())); 328 inArgs.push_back({kind, i32}); 329 330 auto *argLowering = fir::getIntrinsicArgumentLowering("count"); 331 llvm::SmallVector<fir::ExtendedValue, 3> args = 332 lowerArguments(count, inArgs, rewriter, argLowering); 333 334 mlir::Type scalarResultType = hlfir::getFortranElementType(count.getType()); 335 336 auto [resultExv, mustBeFreed] = 337 fir::genIntrinsicCall(builder, loc, "count", scalarResultType, args); 338 339 processReturnValue(count, resultExv, mustBeFreed, builder, rewriter); 340 return mlir::success(); 341 } 342 }; 343 344 struct MatmulOpConversion : public HlfirIntrinsicConversion<hlfir::MatmulOp> { 345 using HlfirIntrinsicConversion<hlfir::MatmulOp>::HlfirIntrinsicConversion; 346 347 llvm::LogicalResult 348 matchAndRewrite(hlfir::MatmulOp matmul, 349 mlir::PatternRewriter &rewriter) const override { 350 fir::FirOpBuilder builder{rewriter, matmul.getOperation()}; 351 const mlir::Location &loc = matmul->getLoc(); 352 353 mlir::Value lhs = matmul.getLhs(); 354 mlir::Value rhs = matmul.getRhs(); 355 llvm::SmallVector<IntrinsicArgument, 2> inArgs; 356 inArgs.push_back({lhs, lhs.getType()}); 357 inArgs.push_back({rhs, rhs.getType()}); 358 359 auto *argLowering = fir::getIntrinsicArgumentLowering("matmul"); 360 llvm::SmallVector<fir::ExtendedValue, 2> args = 361 lowerArguments(matmul, inArgs, rewriter, argLowering); 362 363 mlir::Type scalarResultType = 364 hlfir::getFortranElementType(matmul.getType()); 365 366 auto [resultExv, mustBeFreed] = 367 fir::genIntrinsicCall(builder, loc, "matmul", scalarResultType, args); 368 369 processReturnValue(matmul, resultExv, mustBeFreed, builder, rewriter); 370 return mlir::success(); 371 } 372 }; 373 374 struct DotProductOpConversion 375 : public HlfirIntrinsicConversion<hlfir::DotProductOp> { 376 using HlfirIntrinsicConversion<hlfir::DotProductOp>::HlfirIntrinsicConversion; 377 378 llvm::LogicalResult 379 matchAndRewrite(hlfir::DotProductOp dotProduct, 380 mlir::PatternRewriter &rewriter) const override { 381 fir::FirOpBuilder builder{rewriter, dotProduct.getOperation()}; 382 const mlir::Location &loc = dotProduct->getLoc(); 383 384 mlir::Value lhs = dotProduct.getLhs(); 385 mlir::Value rhs = dotProduct.getRhs(); 386 llvm::SmallVector<IntrinsicArgument, 2> inArgs; 387 inArgs.push_back({lhs, lhs.getType()}); 388 inArgs.push_back({rhs, rhs.getType()}); 389 390 auto *argLowering = fir::getIntrinsicArgumentLowering("dot_product"); 391 llvm::SmallVector<fir::ExtendedValue, 2> args = 392 lowerArguments(dotProduct, inArgs, rewriter, argLowering); 393 394 mlir::Type scalarResultType = 395 hlfir::getFortranElementType(dotProduct.getType()); 396 397 auto [resultExv, mustBeFreed] = fir::genIntrinsicCall( 398 builder, loc, "dot_product", scalarResultType, args); 399 400 processReturnValue(dotProduct, resultExv, mustBeFreed, builder, rewriter); 401 return mlir::success(); 402 } 403 }; 404 405 class TransposeOpConversion 406 : public HlfirIntrinsicConversion<hlfir::TransposeOp> { 407 using HlfirIntrinsicConversion<hlfir::TransposeOp>::HlfirIntrinsicConversion; 408 409 llvm::LogicalResult 410 matchAndRewrite(hlfir::TransposeOp transpose, 411 mlir::PatternRewriter &rewriter) const override { 412 fir::FirOpBuilder builder{rewriter, transpose.getOperation()}; 413 const mlir::Location &loc = transpose->getLoc(); 414 415 mlir::Value arg = transpose.getArray(); 416 llvm::SmallVector<IntrinsicArgument, 1> inArgs; 417 inArgs.push_back({arg, arg.getType()}); 418 419 auto *argLowering = fir::getIntrinsicArgumentLowering("transpose"); 420 llvm::SmallVector<fir::ExtendedValue, 1> args = 421 lowerArguments(transpose, inArgs, rewriter, argLowering); 422 423 mlir::Type scalarResultType = 424 hlfir::getFortranElementType(transpose.getType()); 425 426 auto [resultExv, mustBeFreed] = fir::genIntrinsicCall( 427 builder, loc, "transpose", scalarResultType, args); 428 429 processReturnValue(transpose, resultExv, mustBeFreed, builder, rewriter); 430 return mlir::success(); 431 } 432 }; 433 434 struct MatmulTransposeOpConversion 435 : public HlfirIntrinsicConversion<hlfir::MatmulTransposeOp> { 436 using HlfirIntrinsicConversion< 437 hlfir::MatmulTransposeOp>::HlfirIntrinsicConversion; 438 439 llvm::LogicalResult 440 matchAndRewrite(hlfir::MatmulTransposeOp multranspose, 441 mlir::PatternRewriter &rewriter) const override { 442 fir::FirOpBuilder builder{rewriter, multranspose.getOperation()}; 443 const mlir::Location &loc = multranspose->getLoc(); 444 445 mlir::Value lhs = multranspose.getLhs(); 446 mlir::Value rhs = multranspose.getRhs(); 447 llvm::SmallVector<IntrinsicArgument, 2> inArgs; 448 inArgs.push_back({lhs, lhs.getType()}); 449 inArgs.push_back({rhs, rhs.getType()}); 450 451 auto *argLowering = fir::getIntrinsicArgumentLowering("matmul"); 452 llvm::SmallVector<fir::ExtendedValue, 2> args = 453 lowerArguments(multranspose, inArgs, rewriter, argLowering); 454 455 mlir::Type scalarResultType = 456 hlfir::getFortranElementType(multranspose.getType()); 457 458 auto [resultExv, mustBeFreed] = fir::genIntrinsicCall( 459 builder, loc, "matmul_transpose", scalarResultType, args); 460 461 processReturnValue(multranspose, resultExv, mustBeFreed, builder, rewriter); 462 return mlir::success(); 463 } 464 }; 465 466 class CShiftOpConversion : public HlfirIntrinsicConversion<hlfir::CShiftOp> { 467 using HlfirIntrinsicConversion<hlfir::CShiftOp>::HlfirIntrinsicConversion; 468 469 llvm::LogicalResult 470 matchAndRewrite(hlfir::CShiftOp cshift, 471 mlir::PatternRewriter &rewriter) const override { 472 fir::FirOpBuilder builder{rewriter, cshift.getOperation()}; 473 const mlir::Location &loc = cshift->getLoc(); 474 475 llvm::SmallVector<IntrinsicArgument, 3> inArgs; 476 mlir::Value array = cshift.getArray(); 477 inArgs.push_back({array, array.getType()}); 478 mlir::Value shift = cshift.getShift(); 479 inArgs.push_back({shift, shift.getType()}); 480 inArgs.push_back({cshift.getDim(), builder.getI32Type()}); 481 482 auto *argLowering = fir::getIntrinsicArgumentLowering("cshift"); 483 llvm::SmallVector<fir::ExtendedValue, 3> args = 484 lowerArguments(cshift, inArgs, rewriter, argLowering); 485 486 mlir::Type scalarResultType = 487 hlfir::getFortranElementType(cshift.getType()); 488 489 auto [resultExv, mustBeFreed] = 490 fir::genIntrinsicCall(builder, loc, "cshift", scalarResultType, args); 491 492 processReturnValue(cshift, resultExv, mustBeFreed, builder, rewriter); 493 return mlir::success(); 494 } 495 }; 496 497 class ReshapeOpConversion : public HlfirIntrinsicConversion<hlfir::ReshapeOp> { 498 using HlfirIntrinsicConversion<hlfir::ReshapeOp>::HlfirIntrinsicConversion; 499 500 llvm::LogicalResult 501 matchAndRewrite(hlfir::ReshapeOp reshape, 502 mlir::PatternRewriter &rewriter) const override { 503 fir::FirOpBuilder builder{rewriter, reshape.getOperation()}; 504 const mlir::Location &loc = reshape->getLoc(); 505 506 llvm::SmallVector<IntrinsicArgument, 4> inArgs; 507 mlir::Value array = reshape.getArray(); 508 inArgs.push_back({array, array.getType()}); 509 mlir::Value shape = reshape.getShape(); 510 inArgs.push_back({shape, shape.getType()}); 511 mlir::Type noneType = builder.getNoneType(); 512 mlir::Value pad = reshape.getPad(); 513 inArgs.push_back({pad, pad ? pad.getType() : noneType}); 514 mlir::Value order = reshape.getOrder(); 515 inArgs.push_back({order, order ? order.getType() : noneType}); 516 517 auto *argLowering = fir::getIntrinsicArgumentLowering("reshape"); 518 llvm::SmallVector<fir::ExtendedValue, 4> args = 519 lowerArguments(reshape, inArgs, rewriter, argLowering); 520 521 mlir::Type scalarResultType = 522 hlfir::getFortranElementType(reshape.getType()); 523 524 auto [resultExv, mustBeFreed] = 525 fir::genIntrinsicCall(builder, loc, "reshape", scalarResultType, args); 526 527 processReturnValue(reshape, resultExv, mustBeFreed, builder, rewriter); 528 return mlir::success(); 529 } 530 }; 531 532 class LowerHLFIRIntrinsics 533 : public hlfir::impl::LowerHLFIRIntrinsicsBase<LowerHLFIRIntrinsics> { 534 public: 535 void runOnOperation() override { 536 mlir::ModuleOp module = this->getOperation(); 537 mlir::MLIRContext *context = &getContext(); 538 mlir::RewritePatternSet patterns(context); 539 patterns.insert< 540 MatmulOpConversion, MatmulTransposeOpConversion, AllOpConversion, 541 AnyOpConversion, SumOpConversion, ProductOpConversion, 542 TransposeOpConversion, CountOpConversion, DotProductOpConversion, 543 MaxvalOpConversion, MinvalOpConversion, MinlocOpConversion, 544 MaxlocOpConversion, CShiftOpConversion, ReshapeOpConversion>(context); 545 546 // While conceptually this pass is performing dialect conversion, we use 547 // pattern rewrites here instead of dialect conversion because this pass 548 // looses array bounds from some of the expressions e.g. 549 // !hlfir.expr<2xi32> -> !hlfir.expr<?xi32> 550 // MLIR thinks this is a different type so dialect conversion fails. 551 // Pattern rewriting only requires that the resulting IR is still valid 552 mlir::GreedyRewriteConfig config; 553 // Prevent the pattern driver from merging blocks 554 config.enableRegionSimplification = 555 mlir::GreedySimplifyRegionLevel::Disabled; 556 557 if (mlir::failed( 558 mlir::applyPatternsGreedily(module, std::move(patterns), config))) { 559 mlir::emitError(mlir::UnknownLoc::get(context), 560 "failure in HLFIR intrinsic lowering"); 561 signalPassFailure(); 562 } 563 } 564 }; 565 } // namespace 566