Lines Matching full:mlir
38 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
39 #include "mlir/IR/Matchers.h"
40 #include "mlir/IR/Operation.h"
41 #include "mlir/Pass/Pass.h"
42 #include "mlir/Transforms/DialectConversion.h"
43 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
44 #include "mlir/Transforms/RegionUtils.h"
48 #include <mlir/Dialect/Arith/IR/Arith.h>
49 #include <mlir/IR/BuiltinTypes.h>
50 #include <mlir/IR/Location.h>
51 #include <mlir/IR/MLIRContext.h>
52 #include <mlir/IR/Value.h>
53 #include <mlir/Support/LLVM.h>
68 llvm::function_ref<mlir::FunctionType(fir::FirOpBuilder &)>;
70 llvm::function_ref<void(fir::FirOpBuilder &, mlir::func::FuncOp &)>;
72 fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp, unsigned rank,
73 mlir::Type elementType)>;
84 mlir::func::FuncOp getOrCreateFunction(fir::FirOpBuilder &builder,
85 const mlir::StringRef &basename,
89 void getDependentDialects(mlir::DialectRegistry ®istry) const override;
112 const mlir::StringRef &basename,
113 mlir::Type elementType);
121 getSimplificationBuilder(mlir::Operation *op, const fir::KindMapping &kindMap) {
123 auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
135 static mlir::FunctionType genNoneBoxType(fir::FirOpBuilder &builder,
136 const mlir::Type &elementType) {
137 mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
138 return mlir::FunctionType::get(builder.getContext(), {boxType},
143 Op expectOp(mlir::Value val) {
144 if (Op op = mlir::dyn_cast_or_null<Op>(val.getDefiningOp()))
152 static mlir::Value findDefSingle(fir::ConvertOp op) {
160 static mlir::Value findDef(fir::ConvertOp op) {
161 mlir::Value defOp;
169 static bool isOperandAbsent(mlir::Value val) {
172 return mlir::isa_and_nonnull<fir::AbsentOp>(
178 static bool isTrueOrNotConstant(mlir::Value val) {
179 if (auto op = expectOp<mlir::arith::ConstantOp>(val)) {
180 return !mlir::matchPattern(val, mlir::m_Zero());
185 static bool isZero(mlir::Value val) {
188 if (mlir::Operation *defOp = op->getOperand(0).getDefiningOp())
189 return mlir::matchPattern(defOp, mlir::m_Zero());
194 static mlir::Value findBoxDef(mlir::Value val) {
202 static mlir::Value findMaskDef(mlir::Value val) {
210 static unsigned getDimCount(mlir::Value val) {
218 if (mlir::Value emboxVal = findBoxDef(val))
219 if (auto boxTy = mlir::dyn_cast<fir::BoxType>(emboxVal.getType()))
220 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(boxTy.getEleTy()))
231 static std::optional<mlir::Type> getArgElementType(mlir::Value val) {
232 mlir::Operation *defOp;
236 if (!mlir::isa<fir::ConvertOp>(defOp))
241 auto boxType = mlir::cast<fir::BoxType>(val.getType());
243 if (!mlir::isa<mlir::NoneType>(elementType))
248 using BodyOpGeneratorTy = llvm::function_ref<mlir::Value(
249 fir::FirOpBuilder &, mlir::Location, const mlir::Type &, mlir::Value,
250 mlir::Value)>;
251 using ContinueLoopGenTy = llvm::function_ref<llvm::SmallVector<mlir::Value>(
252 fir::FirOpBuilder &, mlir::Location, mlir::Value)>;
266 /// mlir constant, and controls the inital value for while loops
271 genReductionLoop(fir::FirOpBuilder &builder, mlir::func::FuncOp &funcOp,
274 unsigned rank, mlir::Type elementType, mlir::Location loc) {
276 mlir::IndexType idxTy = builder.getIndexType();
278 mlir::Block::BlockArgListType args = funcOp.front().getArguments();
279 mlir::Value arg = args[0];
281 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
285 mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
286 mlir::Type boxArrTy = fir::BoxType::get(arrTy);
287 mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, arg);
288 mlir::Type resultType = funcOp.getResultTypes()[0];
289 mlir::Value init = initVal(builder, loc, resultType);
291 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds;
294 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
301 mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
304 mlir::Value len = dims.getResult(1);
306 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
315 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
317 mlir::Value step = one;
318 mlir::Value loopCount = bounds[i - 1];
333 mlir::Type eleRefTy = builder.getRefType(elementType);
334 mlir::Value addr =
336 mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
337 mlir::Value reductionVal = genBody(builder, loc, elementType, elem, init);
341 llvm::SmallVector<mlir::Value> results = loopCond(builder, loc, reductionVal);
349 auto loop = mlir::cast<OP>(result->getParentOp());
357 builder.create<mlir::func::ReturnOp>(loc, results[resultIndex]);
360 static llvm::SmallVector<mlir::Value> nopLoopCond(fir::FirOpBuilder &builder,
361 mlir::Location loc,
362 mlir::Value reductionVal) {
372 mlir::func::FuncOp &funcOp, unsigned rank,
373 mlir::Type elementType) {
383 auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
384 mlir::Type elementType) {
385 if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
393 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
394 mlir::Type elementType, mlir::Value elem1,
395 mlir::Value elem2) -> mlir::Value {
396 if (mlir::isa<mlir::FloatType>(elementType))
397 return builder.create<mlir::arith::AddFOp>(loc, elem1, elem2);
398 if (mlir::isa<mlir::IntegerType>(elementType))
399 return builder.create<mlir::arith::AddIOp>(loc, elem1, elem2);
405 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
414 mlir::func::FuncOp &funcOp, unsigned rank,
415 mlir::Type elementType) {
416 auto init = [](fir::FirOpBuilder builder, mlir::Location loc,
417 mlir::Type elementType) {
418 if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
428 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
429 mlir::Type elementType, mlir::Value elem1,
430 mlir::Value elem2) -> mlir::Value {
431 if (mlir::isa<mlir::FloatType>(elementType)) {
439 auto compare = builder.create<mlir::arith::CmpFOp>(
440 loc, mlir::arith::CmpFPredicate::OGT, elem1, elem2);
441 return builder.create<mlir::arith::SelectOp>(loc, compare, elem1, elem2);
443 if (mlir::isa<mlir::IntegerType>(elementType))
444 return builder.create<mlir::arith::MaxSIOp>(loc, elem1, elem2);
450 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
459 mlir::func::FuncOp &funcOp, unsigned rank,
460 mlir::Type elementType) {
461 auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
462 mlir::Type elementType) {
468 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
469 mlir::Type elementType, mlir::Value elem1,
470 mlir::Value elem2) -> mlir::Value {
475 auto compare = builder.create<mlir::arith::CmpIOp>(
476 loc, mlir::arith::CmpIPredicate::eq, elem1, zero32);
478 builder.create<mlir::arith::SelectOp>(loc, compare, zero64, one64);
479 return builder.create<mlir::arith::AddIOp>(loc, select, elem2);
484 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
493 mlir::func::FuncOp &funcOp, unsigned rank,
494 mlir::Type elementType) {
495 auto zero = [](fir::FirOpBuilder builder, mlir::Location loc,
496 mlir::Type elementType) {
500 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
501 mlir::Type elementType, mlir::Value elem1,
502 mlir::Value elem2) -> mlir::Value {
504 return builder.create<mlir::arith::CmpIOp>(
505 loc, mlir::arith::CmpIPredicate::ne, elem1, zero);
508 auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc,
509 mlir::Value reductionVal) {
511 auto eor = builder.create<mlir::arith::XOrIOp>(loc, reductionVal, one1);
512 llvm::SmallVector<mlir::Value> results = {eor, reductionVal};
516 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
518 mlir::Value ok = builder.createBool(loc, true);
520 genReductionLoop<fir::IterWhileOp, mlir::Value, 1>(
526 mlir::func::FuncOp &funcOp, unsigned rank,
527 mlir::Type elementType) {
528 auto one = [](fir::FirOpBuilder builder, mlir::Location loc,
529 mlir::Type elementType) {
533 auto genBodyOp = [](fir::FirOpBuilder builder, mlir::Location loc,
534 mlir::Type elementType, mlir::Value elem1,
535 mlir::Value elem2) -> mlir::Value {
537 return builder.create<mlir::arith::CmpIOp>(
538 loc, mlir::arith::CmpIPredicate::ne, elem1, zero);
541 auto continueCond = [](fir::FirOpBuilder builder, mlir::Location loc,
542 mlir::Value reductionVal) {
543 llvm::SmallVector<mlir::Value> results = {reductionVal, reductionVal};
547 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
549 mlir::Value ok = builder.createBool(loc, true);
551 genReductionLoop<fir::IterWhileOp, mlir::Value, 1>(
556 static mlir::FunctionType genRuntimeMinlocType(fir::FirOpBuilder &builder,
558 mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
559 mlir::Type boxRefType = builder.getRefType(boxType);
561 return mlir::FunctionType::get(builder.getContext(),
567 fir::FirOpBuilder &builder, mlir::Value array,
569 fir::AddrGeneratorTy getAddrFn, unsigned rank, mlir::Type elementType,
570 mlir::Location loc, mlir::Type maskElemType, mlir::Value resultArr,
572 mlir::IndexType idxTy = builder.getIndexType();
574 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
578 mlir::Type arrTy = fir::SequenceType::get(flatShape, elementType);
579 mlir::Type boxArrTy = fir::BoxType::get(arrTy);
582 mlir::Type resultElemType = hlfir::getFortranElementType(resultArr.getType());
583 mlir::Value flagSet = builder.createIntegerConstant(loc, resultElemType, 1);
584 mlir::Value zero = builder.createIntegerConstant(loc, resultElemType, 0);
585 mlir::Value flagRef = builder.createTemporary(loc, resultElemType);
588 mlir::Value init = initVal(builder, loc, elementType);
589 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> bounds;
592 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
599 mlir::Value dimIdx = builder.createIntegerConstant(loc, idxTy, i);
602 mlir::Value len = dims.getResult(1);
604 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
613 llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
615 mlir::Value step = one;
616 mlir::Value loopCount = bounds[i - 1];
630 mlir::Value reductionVal =
639 auto loop = mlir::cast<fir::DoLoopOp>(result->getParentOp());
648 mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp())) {
659 mlir::func::FuncOp &funcOp, bool isMax,
661 mlir::Type elementType,
662 mlir::Type maskElemType,
663 mlir::Type resultElemTy, bool isDim) {
664 auto init = [isMax](fir::FirOpBuilder builder, mlir::Location loc,
665 mlir::Type elementType) {
666 if (auto ty = mlir::dyn_cast<mlir::FloatType>(elementType)) {
678 mlir::Location loc = mlir::UnknownLoc::get(builder.getContext());
681 mlir::Value mask = funcOp.front().getArgument(2);
684 mlir::IndexType idxTy = builder.getIndexType();
685 mlir::Type resultTy = fir::SequenceType::get(rank, resultElemTy);
686 mlir::Type resultHeapTy = fir::HeapType::get(resultTy);
687 mlir::Type resultBoxTy = fir::BoxType::get(resultHeapTy);
689 mlir::Value returnValue = builder.createIntegerConstant(loc, resultElemTy, 0);
690 mlir::Value resultArrSize = builder.createIntegerConstant(loc, idxTy, rank);
692 mlir::Value resultArrInit = builder.create<fir::AllocMemOp>(loc, resultTy);
693 mlir::Value resultArrShape = builder.create<fir::ShapeOp>(loc, resultArrSize);
694 mlir::Value resultArr = builder.create<fir::EmboxOp>(
697 mlir::Type resultRefTy = builder.getRefType(resultElemTy);
702 mlir::Type maskTy = fir::SequenceType::get(flatShape, maskElemType);
703 mlir::Type boxMaskTy = fir::BoxType::get(maskTy);
708 mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
709 mlir::Value resultElemAddr =
716 fir::FirOpBuilder builder, mlir::Location loc, mlir::Type elementType,
717 mlir::Value array, mlir::Value flagRef, mlir::Value reduction,
718 const llvm::SmallVectorImpl<mlir::Value> &indices) -> mlir::Value {
721 mlir::Type logicalRef = builder.getRefType(maskElemType);
722 mlir::Value maskAddr =
724 mlir::Value maskElem = builder.create<fir::LoadOp>(loc, maskAddr);
728 mlir::Type ifCompatType = builder.getI1Type();
729 mlir::Value ifCompatElem =
732 llvm::SmallVector<mlir::Type> resultsTy = {elementType, elementType};
739 mlir::Value flagSet = builder.createIntegerConstant(
740 loc, mlir::cast<fir::ReferenceType>(flagRef.getType()).getEleTy(), 1);
741 mlir::Value isFirst = builder.create<fir::LoadOp>(loc, flagRef);
742 mlir::Type eleRefTy = builder.getRefType(elementType);
743 mlir::Value addr =
745 mlir::Value elem = builder.create<fir::LoadOp>(loc, addr);
747 mlir::Value cmp;
748 if (mlir::isa<mlir::FloatType>(elementType)) {
753 cmp = builder.create<mlir::arith::CmpFOp>(
755 isMax ? mlir::arith::CmpFPredicate::OGT
756 : mlir::arith::CmpFPredicate::OLT,
759 mlir::Value cmpNan = builder.create<mlir::arith::CmpFOp>(
760 loc, mlir::arith::CmpFPredicate::UNE, reduction, reduction);
761 mlir::Value cmpNan2 = builder.create<mlir::arith::CmpFOp>(
762 loc, mlir::arith::CmpFPredicate::OEQ, elem, elem);
763 cmpNan = builder.create<mlir::arith::AndIOp>(loc, cmpNan, cmpNan2);
764 cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, cmpNan);
765 } else if (mlir::isa<mlir::IntegerType>(elementType)) {
766 cmp = builder.create<mlir::arith::CmpIOp>(
768 isMax ? mlir::arith::CmpIPredicate::sgt
769 : mlir::arith::CmpIPredicate::slt,
777 isFirst = builder.create<mlir::arith::XOrIOp>(
779 cmp = builder.create<mlir::arith::OrIOp>(loc, cmp, isFirst);
785 mlir::Type resultElemTy = hlfir::getFortranElementType(resultArr.getType());
786 mlir::Type returnRefTy = builder.getRefType(resultElemTy);
787 mlir::IndexType idxTy = builder.getIndexType();
789 mlir::Value one = builder.createIntegerConstant(loc, resultElemTy, 1);
792 mlir::Value index = builder.createIntegerConstant(loc, idxTy, i);
793 mlir::Value resultElemAddr =
795 mlir::Value convert =
797 mlir::Value fortranIndex =
798 builder.create<mlir::arith::AddIOp>(loc, convert, one);
805 mlir::Value reductionVal = ifOp.getResult(0);
810 mlir::dyn_cast<fir::IfOp>(builder.getBlock()->getParentOp());
824 mlir::Type logical = builder.getI1Type();
825 mlir::IndexType idxTy = builder.getIndexType();
828 mlir::Type arrTy = fir::SequenceType::get(singleElement, logical);
829 mlir::Type boxArrTy = fir::BoxType::get(arrTy);
830 mlir::Value array = builder.create<fir::ConvertOp>(loc, boxArrTy, mask);
832 mlir::Value indx = builder.createIntegerConstant(loc, idxTy, 0);
833 mlir::Type logicalRefTy = builder.getRefType(logical);
834 mlir::Value condAddr =
836 mlir::Value cond = builder.create<fir::LoadOp>(loc, condAddr);
842 mlir::Value basicValue;
843 if (mlir::isa<mlir::IntegerType>(elementType)) {
852 auto getAddrFn = [](fir::FirOpBuilder builder, mlir::Location loc,
853 const mlir::Type &resultElemType, mlir::Value resultArr,
854 mlir::Value index) {
855 mlir::Type resultRefTy = builder.getRefType(resultElemType);
866 mlir::Type resultBoxTy =
868 mlir::Value outputArr = builder.create<fir::ConvertOp>(
870 mlir::Value resultArrScalar = builder.create<fir::ConvertOp>(
872 mlir::Value resultBox =
877 mlir::Type outputArrTy = fir::SequenceType::get(resultShape, resultElemTy);
878 mlir::Type outputHeapTy = fir::HeapType::get(outputArrTy);
879 mlir::Type outputBoxTy = fir::BoxType::get(outputHeapTy);
880 mlir::Type outputRefTy = builder.getRefType(outputBoxTy);
881 mlir::Value outputArr = builder.create<fir::ConvertOp>(
886 builder.create<mlir::func::ReturnOp>(loc);
891 static mlir::FunctionType genRuntimeDotType(fir::FirOpBuilder &builder,
892 const mlir::Type &elementType) {
893 mlir::Type boxType = fir::BoxType::get(builder.getNoneType());
894 return mlir::FunctionType::get(builder.getContext(), {boxType, boxType},
906 mlir::func::FuncOp &funcOp,
907 mlir::Type arg1ElementTy,
908 mlir::Type arg2ElementTy) {
918 auto loc = mlir::UnknownLoc::get(builder.getContext());
919 mlir::Type resultElementType = funcOp.getResultTypes()[0];
922 mlir::IndexType idxTy = builder.getIndexType();
924 mlir::Value zero =
925 mlir::isa<mlir::FloatType>(resultElementType)
929 mlir::Block::BlockArgListType args = funcOp.front().getArguments();
930 mlir::Value arg1 = args[0];
931 mlir::Value arg2 = args[1];
933 mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
936 mlir::Type arrTy1 = fir::SequenceType::get(flatShape, arg1ElementTy);
937 mlir::Type boxArrTy1 = fir::BoxType::get(arrTy1);
938 mlir::Value array1 = builder.create<fir::ConvertOp>(loc, boxArrTy1, arg1);
939 mlir::Type arrTy2 = fir::SequenceType::get(flatShape, arg2ElementTy);
940 mlir::Type boxArrTy2 = fir::BoxType::get(arrTy2);
941 mlir::Value array2 = builder.create<fir::ConvertOp>(loc, boxArrTy2, arg2);
952 mlir::Value len = dims.getResult(1);
953 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
954 mlir::Value step = one;
957 mlir::Value loopCount = builder.create<mlir::arith::SubIOp>(loc, len, one);
961 mlir::Value sumVal = loop.getRegionIterArgs()[0];
964 mlir::OpBuilder::InsertPoint loopEndPt = builder.saveInsertionPoint();
967 mlir::Type eleRef1Ty = builder.getRefType(arg1ElementTy);
968 mlir::Value index = loop.getInductionVar();
969 mlir::Value addr1 =
971 mlir::Value elem1 = builder.create<fir::LoadOp>(loc, addr1);
975 mlir::Type eleRef2Ty = builder.getRefType(arg2ElementTy);
976 mlir::Value addr2 =
978 mlir::Value elem2 = builder.create<fir::LoadOp>(loc, addr2);
982 if (mlir::isa<mlir::FloatType>(resultElementType))
983 sumVal = builder.create<mlir::arith::AddFOp>(
984 loc, builder.create<mlir::arith::MulFOp>(loc, elem1, elem2), sumVal);
985 else if (mlir::isa<mlir::IntegerType>(resultElementType))
986 sumVal = builder.create<mlir::arith::AddIOp>(
987 loc, builder.create<mlir::arith::MulIOp>(loc, elem1, elem2), sumVal);
995 mlir::Value resultVal = loop.getResult(0);
996 builder.create<mlir::func::ReturnOp>(loc, resultVal);
999 mlir::func::FuncOp SimplifyIntrinsicsPass::getOrCreateFunction(
1000 fir::FirOpBuilder &builder, const mlir::StringRef &baseName,
1010 std::string replacementName = mlir::Twine{baseName, "_simplified"}.str();
1012 mlir::func::FuncOp newFunc = builder.getNamedFunction(replacementName);
1013 mlir::FunctionType fType = typeGenerator(builder);
1021 auto loc = mlir::UnknownLoc::get(builder.getContext());
1023 auto inlineLinkage = mlir::LLVM::linkage::Linkage::LinkonceODR;
1025 mlir::LLVM::LinkageAttr::get(builder.getContext(), inlineLinkage);
1029 mlir::OpBuilder::InsertPoint insertPt = builder.saveInsertionPoint();
1043 mlir::Operation::operand_range args = call.getArgs();
1045 const mlir::Value &dim = args[3];
1046 const mlir::Value &mask = args[4];
1058 mlir::Type resultType = call.getResult(0).getType();
1060 if (!mlir::isa<mlir::FloatType>(resultType) &&
1061 !mlir::isa<mlir::IntegerType>(resultType))
1070 mlir::SymbolRefAttr callee = call.getCalleeAttr();
1075 (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
1076 mlir::Twine{rank} +
1079 (fmfString.empty() ? mlir::Twine{} : mlir::Twine{"_", fmfString}))
1090 mlir::Operation::operand_range args = call.getArgs();
1091 const mlir::Value &dim = args[3];
1099 mlir::Value inputBox = findBoxDef(args[0]);
1101 mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType());
1102 mlir::SymbolRefAttr callee = call.getCalleeAttr();
1108 mlir::dyn_cast<fir::LogicalType>(elementType)};
1110 mlir::Type intElementType = builder.getIntegerType(kind * 8);
1114 (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} +
1115 mlir::Twine{kind} + "x" + mlir::Twine{rank})
1126 mlir::Operation::operand_range args = call.getArgs();
1127 mlir::SymbolRefAttr callee = call.getCalleeAttr();
1128 mlir::StringRef funcNameBase = callee.getLeafReference().getValue();
1137 mlir::Value inputBox = findBoxDef(args[0]);
1138 mlir::Type elementType = hlfir::getFortranElementType(inputBox.getType());
1144 mlir::dyn_cast<fir::LogicalType>(elementType)};
1146 mlir::Type intElementType = builder.getIntegerType(kind * 8);
1150 (mlir::Twine{callee.getLeafReference().getValue(), "Logical"} +
1151 mlir::Twine{kind} + "x" + mlir::Twine{rank})
1161 mlir::Operation::operand_range args = call.getArgs();
1163 mlir::SymbolRefAttr callee = call.getCalleeAttr();
1164 mlir::StringRef funcNameBase = callee.getLeafReference().getValue();
1166 mlir::Value back = args[isDim ? 7 : 6];
1170 mlir::Value mask = args[isDim ? 6 : 5];
1171 mlir::Value maskDef = findMaskDef(mask);
1184 mlir::Location loc = call.getLoc();
1186 mlir::Type inputType = hlfir::getFortranElementType(inputBox.getType());
1188 if (mlir::isa<fir::CharacterType>(inputType))
1193 mlir::Type logicalElemType = builder.getI1Type();
1198 mlir::Type maskElemTy = hlfir::getFortranElementType(maskDef.getType());
1200 mlir::dyn_cast<fir::LogicalType>(maskElemTy)};
1202 // Convert fir::LogicalType to mlir::Type
1206 mlir::Operation *outputDef = args[0].getDefiningOp();
1207 mlir::Value outputAlloc = outputDef->getOperand(0);
1208 mlir::Type outType = hlfir::getFortranElementType(outputAlloc.getType());
1212 (mlir::Twine{callee.getLeafReference().getValue(), "x"} +
1213 mlir::Twine{rank} +
1215 ? "_Logical" + mlir::Twine{kind} + "x" + mlir::Twine{maskRank}
1231 mlir::func::FuncOp &funcOp) {
1236 mlir::func::FuncOp newFunc =
1239 mlir::ValueRange{args[0], args[1], mask});
1247 const mlir::StringRef &funcName, mlir::Type elementType) {
1249 mlir::Operation::operand_range args = call.getArgs();
1251 mlir::Type resultType = call.getResult(0).getType();
1254 mlir::Location loc = call.getLoc();
1261 mlir::func::FuncOp &funcOp) {
1265 mlir::func::FuncOp newFunc =
1268 builder.create<fir::CallOp>(loc, newFunc, mlir::ValueRange{args[0]});
1276 mlir::ModuleOp module = getOperation();
1278 module.walk([&](mlir::Operation *op) {
1279 if (auto call = mlir::dyn_cast<fir::CallOp>(op)) {
1282 if (mlir::SymbolRefAttr callee = call.getCalleeAttr()) {
1283 mlir::StringRef funcName = callee.getLeafReference().getValue();
1301 mlir::Operation::operand_range args = call.getArgs();
1302 const mlir::Value &v1 = args[0];
1303 const mlir::Value &v2 = args[1];
1304 mlir::Location loc = call.getLoc();
1310 mlir::Type type = call.getResult(0).getType();
1311 if (!mlir::isa<mlir::FloatType>(type) &&
1312 !mlir::isa<mlir::IntegerType>(type))
1324 if (!mlir::isa<mlir::FloatType, mlir::IntegerType>(*arg1Type))
1326 if (!mlir::isa<mlir::FloatType, mlir::IntegerType>(*arg2Type))
1334 mlir::func::FuncOp &funcOp) {
1351 mlir::func::FuncOp newFunc = getOrCreateFunction(
1354 mlir::ValueRange{v1, v2});
1394 mlir::DialectRegistry ®istry) const {
1396 registry.insert<mlir::LLVM::LLVMDialect>();