xref: /llvm-project/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp (revision 382d3599c203573388b82717dc17e3db4039916a)
//===-- CUFOpConversion.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Optimizer/Transforms/CUFOpConversion.h"
10 #include "flang/Common/Fortran.h"
11 #include "flang/Optimizer/Builder/CUFCommon.h"
12 #include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
13 #include "flang/Optimizer/CodeGen/TypeConverter.h"
14 #include "flang/Optimizer/Dialect/CUF/CUFOps.h"
15 #include "flang/Optimizer/Dialect/FIRDialect.h"
16 #include "flang/Optimizer/Dialect/FIROps.h"
17 #include "flang/Optimizer/HLFIR/HLFIROps.h"
18 #include "flang/Optimizer/Support/DataLayout.h"
19 #include "flang/Runtime/CUDA/allocatable.h"
20 #include "flang/Runtime/CUDA/common.h"
21 #include "flang/Runtime/CUDA/descriptor.h"
22 #include "flang/Runtime/CUDA/memory.h"
23 #include "flang/Runtime/CUDA/pointer.h"
24 #include "flang/Runtime/allocatable.h"
25 #include "mlir/Conversion/LLVMCommon/Pattern.h"
26 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
27 #include "mlir/IR/Matchers.h"
28 #include "mlir/Pass/Pass.h"
29 #include "mlir/Transforms/DialectConversion.h"
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31 
32 namespace fir {
33 #define GEN_PASS_DEF_CUFOPCONVERSION
34 #include "flang/Optimizer/Transforms/Passes.h.inc"
35 } // namespace fir
36 
37 using namespace fir;
38 using namespace mlir;
39 using namespace Fortran::runtime;
40 using namespace Fortran::runtime::cuda;
41 
42 namespace {
43 
44 static inline unsigned getMemType(cuf::DataAttribute attr) {
45   if (attr == cuf::DataAttribute::Device)
46     return kMemTypeDevice;
47   if (attr == cuf::DataAttribute::Managed)
48     return kMemTypeManaged;
49   if (attr == cuf::DataAttribute::Unified)
50     return kMemTypeUnified;
51   if (attr == cuf::DataAttribute::Pinned)
52     return kMemTypePinned;
53   llvm::report_fatal_error("unsupported memory type");
54 }
55 
56 template <typename OpTy>
57 static bool isPinned(OpTy op) {
58   if (op.getDataAttr() && *op.getDataAttr() == cuf::DataAttribute::Pinned)
59     return true;
60   return false;
61 }
62 
63 template <typename OpTy>
64 static bool hasDoubleDescriptors(OpTy op) {
65   if (auto declareOp =
66           mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp())) {
67     if (mlir::isa_and_nonnull<fir::AddrOfOp>(
68             declareOp.getMemref().getDefiningOp())) {
69       if (isPinned(declareOp))
70         return false;
71       return true;
72     }
73   } else if (auto declareOp = mlir::dyn_cast_or_null<hlfir::DeclareOp>(
74                  op.getBox().getDefiningOp())) {
75     if (mlir::isa_and_nonnull<fir::AddrOfOp>(
76             declareOp.getMemref().getDefiningOp())) {
77       if (isPinned(declareOp))
78         return false;
79       return true;
80     }
81   }
82   return false;
83 }
84 
85 static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
86                                    mlir::Location loc, mlir::Type toTy,
87                                    mlir::Value val) {
88   if (val.getType() != toTy)
89     return rewriter.create<fir::ConvertOp>(loc, toTy, val);
90   return val;
91 }
92 
93 template <typename OpTy>
94 static mlir::LogicalResult convertOpToCall(OpTy op,
95                                            mlir::PatternRewriter &rewriter,
96                                            mlir::func::FuncOp func) {
97   auto mod = op->template getParentOfType<mlir::ModuleOp>();
98   fir::FirOpBuilder builder(rewriter, mod);
99   mlir::Location loc = op.getLoc();
100   auto fTy = func.getFunctionType();
101 
102   mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
103   mlir::Value sourceLine;
104   if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>)
105     sourceLine = fir::factory::locationToLineNo(
106         builder, loc, op.getSource() ? fTy.getInput(6) : fTy.getInput(5));
107   else
108     sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
109 
110   mlir::Value hasStat = op.getHasStat() ? builder.createBool(loc, true)
111                                         : builder.createBool(loc, false);
112 
113   mlir::Value errmsg;
114   if (op.getErrmsg()) {
115     errmsg = op.getErrmsg();
116   } else {
117     mlir::Type boxNoneTy = fir::BoxType::get(builder.getNoneType());
118     errmsg = builder.create<fir::AbsentOp>(loc, boxNoneTy).getResult();
119   }
120   llvm::SmallVector<mlir::Value> args;
121   if constexpr (std::is_same_v<OpTy, cuf::AllocateOp>) {
122     if (op.getSource()) {
123       mlir::Value stream =
124           op.getStream()
125               ? op.getStream()
126               : builder.createIntegerConstant(loc, fTy.getInput(2), -1);
127       args = fir::runtime::createArguments(builder, loc, fTy, op.getBox(),
128                                            op.getSource(), stream, hasStat,
129                                            errmsg, sourceFile, sourceLine);
130     } else {
131       mlir::Value stream =
132           op.getStream()
133               ? op.getStream()
134               : builder.createIntegerConstant(loc, fTy.getInput(1), -1);
135       args = fir::runtime::createArguments(builder, loc, fTy, op.getBox(),
136                                            stream, hasStat, errmsg, sourceFile,
137                                            sourceLine);
138     }
139   } else {
140     args =
141         fir::runtime::createArguments(builder, loc, fTy, op.getBox(), hasStat,
142                                       errmsg, sourceFile, sourceLine);
143   }
144   auto callOp = builder.create<fir::CallOp>(loc, func, args);
145   rewriter.replaceOp(op, callOp);
146   return mlir::success();
147 }
148 
149 struct CUFAllocateOpConversion
150     : public mlir::OpRewritePattern<cuf::AllocateOp> {
151   using OpRewritePattern::OpRewritePattern;
152 
153   mlir::LogicalResult
154   matchAndRewrite(cuf::AllocateOp op,
155                   mlir::PatternRewriter &rewriter) const override {
156     // TODO: Pinned is a reference to a logical value that can be set to true
157     // when pinned allocation succeed. This will require a new entry point.
158     if (op.getPinned())
159       return mlir::failure();
160 
161     auto mod = op->getParentOfType<mlir::ModuleOp>();
162     fir::FirOpBuilder builder(rewriter, mod);
163     mlir::Location loc = op.getLoc();
164 
165     bool isPointer = false;
166 
167     if (auto declareOp =
168             mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp()))
169       if (declareOp.getFortranAttrs() &&
170           bitEnumContainsAny(*declareOp.getFortranAttrs(),
171                              fir::FortranVariableFlagsEnum::pointer))
172         isPointer = true;
173 
174     if (hasDoubleDescriptors(op)) {
175       // Allocation for module variable are done with custom runtime entry point
176       // so the descriptors can be synchronized.
177       mlir::func::FuncOp func;
178       if (op.getSource()) {
179         func = isPointer ? fir::runtime::getRuntimeFunc<mkRTKey(
180                                CUFPointerAllocateSourceSync)>(loc, builder)
181                          : fir::runtime::getRuntimeFunc<mkRTKey(
182                                CUFAllocatableAllocateSourceSync)>(loc, builder);
183       } else {
184         func =
185             isPointer
186                 ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSync)>(
187                       loc, builder)
188                 : fir::runtime::getRuntimeFunc<mkRTKey(
189                       CUFAllocatableAllocateSync)>(loc, builder);
190       }
191       return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
192     }
193 
194     mlir::func::FuncOp func;
195     if (op.getSource()) {
196       func =
197           isPointer
198               ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocateSource)>(
199                     loc, builder)
200               : fir::runtime::getRuntimeFunc<mkRTKey(
201                     CUFAllocatableAllocateSource)>(loc, builder);
202     } else {
203       func =
204           isPointer
205               ? fir::runtime::getRuntimeFunc<mkRTKey(CUFPointerAllocate)>(
206                     loc, builder)
207               : fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocate)>(
208                     loc, builder);
209     }
210 
211     return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
212   }
213 };
214 
215 struct CUFDeallocateOpConversion
216     : public mlir::OpRewritePattern<cuf::DeallocateOp> {
217   using OpRewritePattern::OpRewritePattern;
218 
219   mlir::LogicalResult
220   matchAndRewrite(cuf::DeallocateOp op,
221                   mlir::PatternRewriter &rewriter) const override {
222 
223     auto mod = op->getParentOfType<mlir::ModuleOp>();
224     fir::FirOpBuilder builder(rewriter, mod);
225     mlir::Location loc = op.getLoc();
226 
227     if (hasDoubleDescriptors(op)) {
228       // Deallocation for module variable are done with custom runtime entry
229       // point so the descriptors can be synchronized.
230       mlir::func::FuncOp func =
231           fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableDeallocate)>(
232               loc, builder);
233       return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
234     }
235 
236     // Deallocation for local descriptor falls back on the standard runtime
237     // AllocatableDeallocate as the dedicated deallocator is set in the
238     // descriptor before the call.
239     mlir::func::FuncOp func =
240         fir::runtime::getRuntimeFunc<mkRTKey(AllocatableDeallocate)>(loc,
241                                                                      builder);
242     return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
243   }
244 };
245 
246 static bool inDeviceContext(mlir::Operation *op) {
247   if (op->getParentOfType<cuf::KernelOp>())
248     return true;
249   if (auto funcOp = op->getParentOfType<mlir::gpu::GPUFuncOp>())
250     return true;
251   if (auto funcOp = op->getParentOfType<mlir::gpu::LaunchOp>())
252     return true;
253   if (auto funcOp = op->getParentOfType<mlir::func::FuncOp>()) {
254     if (auto cudaProcAttr =
255             funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
256                 cuf::getProcAttrName())) {
257       return cudaProcAttr.getValue() != cuf::ProcAttribute::Host &&
258              cudaProcAttr.getValue() != cuf::ProcAttribute::HostDevice;
259     }
260   }
261   return false;
262 }
263 
264 static int computeWidth(mlir::Location loc, mlir::Type type,
265                         fir::KindMapping &kindMap) {
266   auto eleTy = fir::unwrapSequenceType(type);
267   if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
268     return t.getWidth() / 8;
269   if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
270     return t.getWidth() / 8;
271   if (eleTy.isInteger(1))
272     return 1;
273   if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
274     return kindMap.getLogicalBitsize(t.getFKind()) / 8;
275   if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
276     int elemSize =
277         mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
278     return 2 * elemSize;
279   }
280   if (auto t{mlir::dyn_cast_or_null<fir::CharacterType>(eleTy)})
281     return kindMap.getCharacterBitsize(t.getFKind()) / 8;
282   mlir::emitError(loc, "unsupported type");
283   return 0;
284 }
285 
286 struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
287   using OpRewritePattern::OpRewritePattern;
288 
289   CUFAllocOpConversion(mlir::MLIRContext *context, mlir::DataLayout *dl,
290                        const fir::LLVMTypeConverter *typeConverter)
291       : OpRewritePattern(context), dl{dl}, typeConverter{typeConverter} {}
292 
293   mlir::LogicalResult
294   matchAndRewrite(cuf::AllocOp op,
295                   mlir::PatternRewriter &rewriter) const override {
296 
297     mlir::Location loc = op.getLoc();
298 
299     if (inDeviceContext(op.getOperation())) {
300       // In device context just replace the cuf.alloc operation with a fir.alloc
301       // the cuf.free will be removed.
302       auto allocaOp = rewriter.create<fir::AllocaOp>(
303           loc, op.getInType(), op.getUniqName() ? *op.getUniqName() : "",
304           op.getBindcName() ? *op.getBindcName() : "", op.getTypeparams(),
305           op.getShape());
306       allocaOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
307       rewriter.replaceOp(op, allocaOp);
308       return mlir::success();
309     }
310 
311     auto mod = op->getParentOfType<mlir::ModuleOp>();
312     fir::FirOpBuilder builder(rewriter, mod);
313     mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
314 
315     if (!mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType())) {
316       // Convert scalar and known size array allocations.
317       mlir::Value bytes;
318       fir::KindMapping kindMap{fir::getKindMapping(mod)};
319       if (fir::isa_trivial(op.getInType())) {
320         int width = computeWidth(loc, op.getInType(), kindMap);
321         bytes =
322             builder.createIntegerConstant(loc, builder.getIndexType(), width);
323       } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
324                      op.getInType())) {
325         std::size_t size = 0;
326         if (fir::isa_derived(seqTy.getEleTy())) {
327           mlir::Type structTy = typeConverter->convertType(seqTy.getEleTy());
328           size = dl->getTypeSizeInBits(structTy) / 8;
329         } else {
330           size = computeWidth(loc, seqTy.getEleTy(), kindMap);
331         }
332         mlir::Value width =
333             builder.createIntegerConstant(loc, builder.getIndexType(), size);
334         mlir::Value nbElem;
335         if (fir::sequenceWithNonConstantShape(seqTy)) {
336           assert(!op.getShape().empty() && "expect shape with dynamic arrays");
337           nbElem = builder.loadIfRef(loc, op.getShape()[0]);
338           for (unsigned i = 1; i < op.getShape().size(); ++i) {
339             nbElem = rewriter.create<mlir::arith::MulIOp>(
340                 loc, nbElem, builder.loadIfRef(loc, op.getShape()[i]));
341           }
342         } else {
343           nbElem = builder.createIntegerConstant(loc, builder.getIndexType(),
344                                                  seqTy.getConstantArraySize());
345         }
346         bytes = rewriter.create<mlir::arith::MulIOp>(loc, nbElem, width);
347       } else if (fir::isa_derived(op.getInType())) {
348         mlir::Type structTy = typeConverter->convertType(op.getInType());
349         std::size_t structSize = dl->getTypeSizeInBits(structTy) / 8;
350         bytes = builder.createIntegerConstant(loc, builder.getIndexType(),
351                                               structSize);
352       } else {
353         mlir::emitError(loc, "unsupported type in cuf.alloc\n");
354       }
355       mlir::func::FuncOp func =
356           fir::runtime::getRuntimeFunc<mkRTKey(CUFMemAlloc)>(loc, builder);
357       auto fTy = func.getFunctionType();
358       mlir::Value sourceLine =
359           fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
360       mlir::Value memTy = builder.createIntegerConstant(
361           loc, builder.getI32Type(), getMemType(op.getDataAttr()));
362       llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
363           builder, loc, fTy, bytes, memTy, sourceFile, sourceLine)};
364       auto callOp = builder.create<fir::CallOp>(loc, func, args);
365       callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
366       auto convOp = builder.createConvert(loc, op.getResult().getType(),
367                                           callOp.getResult(0));
368       rewriter.replaceOp(op, convOp);
369       return mlir::success();
370     }
371 
372     // Convert descriptor allocations to function call.
373     auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
374     mlir::func::FuncOp func =
375         fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDescriptor)>(loc, builder);
376     auto fTy = func.getFunctionType();
377     mlir::Value sourceLine =
378         fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
379 
380     mlir::Type structTy = typeConverter->convertBoxTypeAsStruct(boxTy);
381     std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
382     mlir::Value sizeInBytes =
383         builder.createIntegerConstant(loc, builder.getIndexType(), boxSize);
384 
385     llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
386         builder, loc, fTy, sizeInBytes, sourceFile, sourceLine)};
387     auto callOp = builder.create<fir::CallOp>(loc, func, args);
388     callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
389     auto convOp = builder.createConvert(loc, op.getResult().getType(),
390                                         callOp.getResult(0));
391     rewriter.replaceOp(op, convOp);
392     return mlir::success();
393   }
394 
395 private:
396   mlir::DataLayout *dl;
397   const fir::LLVMTypeConverter *typeConverter;
398 };
399 
400 struct CUFDeviceAddressOpConversion
401     : public mlir::OpRewritePattern<cuf::DeviceAddressOp> {
402   using OpRewritePattern::OpRewritePattern;
403 
404   CUFDeviceAddressOpConversion(mlir::MLIRContext *context,
405                                const mlir::SymbolTable &symtab)
406       : OpRewritePattern(context), symTab{symtab} {}
407 
408   mlir::LogicalResult
409   matchAndRewrite(cuf::DeviceAddressOp op,
410                   mlir::PatternRewriter &rewriter) const override {
411     if (auto global = symTab.lookup<fir::GlobalOp>(
412             op.getHostSymbol().getRootReference().getValue())) {
413       auto mod = op->getParentOfType<mlir::ModuleOp>();
414       mlir::Location loc = op.getLoc();
415       auto hostAddr = rewriter.create<fir::AddrOfOp>(
416           loc, fir::ReferenceType::get(global.getType()), op.getHostSymbol());
417       fir::FirOpBuilder builder(rewriter, mod);
418       mlir::func::FuncOp callee =
419           fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc,
420                                                                      builder);
421       auto fTy = callee.getFunctionType();
422       mlir::Value conv =
423           createConvertOp(rewriter, loc, fTy.getInput(0), hostAddr);
424       mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
425       mlir::Value sourceLine =
426           fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
427       llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
428           builder, loc, fTy, conv, sourceFile, sourceLine)};
429       auto call = rewriter.create<fir::CallOp>(loc, callee, args);
430       mlir::Value addr = createConvertOp(rewriter, loc, hostAddr.getType(),
431                                          call->getResult(0));
432       rewriter.replaceOp(op, addr.getDefiningOp());
433       return success();
434     }
435     return failure();
436   }
437 
438 private:
439   const mlir::SymbolTable &symTab;
440 };
441 
442 struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
443   using OpRewritePattern::OpRewritePattern;
444 
445   DeclareOpConversion(mlir::MLIRContext *context,
446                       const mlir::SymbolTable &symtab)
447       : OpRewritePattern(context), symTab{symtab} {}
448 
449   mlir::LogicalResult
450   matchAndRewrite(fir::DeclareOp op,
451                   mlir::PatternRewriter &rewriter) const override {
452     if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) {
453       if (auto global = symTab.lookup<fir::GlobalOp>(
454               addrOfOp.getSymbol().getRootReference().getValue())) {
455         if (cuf::isRegisteredDeviceGlobal(global)) {
456           rewriter.setInsertionPointAfter(addrOfOp);
457           mlir::Value devAddr = rewriter.create<cuf::DeviceAddressOp>(
458               op.getLoc(), addrOfOp.getType(), addrOfOp.getSymbol());
459           rewriter.startOpModification(op);
460           op.getMemrefMutable().assign(devAddr);
461           rewriter.finalizeOpModification(op);
462           return success();
463         }
464       }
465     }
466     return failure();
467   }
468 
469 private:
470   const mlir::SymbolTable &symTab;
471 };
472 
473 struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
474   using OpRewritePattern::OpRewritePattern;
475 
476   mlir::LogicalResult
477   matchAndRewrite(cuf::FreeOp op,
478                   mlir::PatternRewriter &rewriter) const override {
479     if (inDeviceContext(op.getOperation())) {
480       rewriter.eraseOp(op);
481       return mlir::success();
482     }
483 
484     if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
485       return failure();
486 
487     auto mod = op->getParentOfType<mlir::ModuleOp>();
488     fir::FirOpBuilder builder(rewriter, mod);
489     mlir::Location loc = op.getLoc();
490     mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
491 
492     auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
493     if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy())) {
494       mlir::func::FuncOp func =
495           fir::runtime::getRuntimeFunc<mkRTKey(CUFMemFree)>(loc, builder);
496       auto fTy = func.getFunctionType();
497       mlir::Value sourceLine =
498           fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
499       mlir::Value memTy = builder.createIntegerConstant(
500           loc, builder.getI32Type(), getMemType(op.getDataAttr()));
501       llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
502           builder, loc, fTy, op.getDevptr(), memTy, sourceFile, sourceLine)};
503       builder.create<fir::CallOp>(loc, func, args);
504       rewriter.eraseOp(op);
505       return mlir::success();
506     }
507 
508     // Convert cuf.free on descriptors.
509     mlir::func::FuncOp func =
510         fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDescriptor)>(loc, builder);
511     auto fTy = func.getFunctionType();
512     mlir::Value sourceLine =
513         fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
514     llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
515         builder, loc, fTy, op.getDevptr(), sourceFile, sourceLine)};
516     auto callOp = builder.create<fir::CallOp>(loc, func, args);
517     callOp->setAttr(cuf::getDataAttrName(), op.getDataAttrAttr());
518     rewriter.eraseOp(op);
519     return mlir::success();
520   }
521 };
522 
523 static bool isDstGlobal(cuf::DataTransferOp op) {
524   if (auto declareOp = op.getDst().getDefiningOp<fir::DeclareOp>())
525     if (declareOp.getMemref().getDefiningOp<fir::AddrOfOp>())
526       return true;
527   if (auto declareOp = op.getDst().getDefiningOp<hlfir::DeclareOp>())
528     if (declareOp.getMemref().getDefiningOp<fir::AddrOfOp>())
529       return true;
530   return false;
531 }
532 
533 static mlir::Value getShapeFromDecl(mlir::Value src) {
534   if (auto declareOp = src.getDefiningOp<fir::DeclareOp>())
535     return declareOp.getShape();
536   if (auto declareOp = src.getDefiningOp<hlfir::DeclareOp>())
537     return declareOp.getShape();
538   return mlir::Value{};
539 }
540 
541 static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
542                             cuf::DataTransferOp op,
543                             const mlir::SymbolTable &symtab) {
544   auto mod = op->getParentOfType<mlir::ModuleOp>();
545   mlir::Location loc = op.getLoc();
546   fir::FirOpBuilder builder(rewriter, mod);
547   mlir::Value addr;
548   mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
549   if (fir::isa_trivial(srcTy) &&
550       mlir::matchPattern(op.getSrc().getDefiningOp(), mlir::m_Constant())) {
551     mlir::Value src = op.getSrc();
552     if (srcTy.isInteger(1)) {
553       // i1 is not a supported type in the descriptor and it is actually coming
554       // from a LOGICAL constant. Store it as a fir.logical.
555       srcTy = fir::LogicalType::get(rewriter.getContext(), 4);
556       src = createConvertOp(rewriter, loc, srcTy, src);
557     }
558     // Put constant in memory if it is not.
559     mlir::Value alloc = builder.createTemporary(loc, srcTy);
560     builder.create<fir::StoreOp>(loc, src, alloc);
561     addr = alloc;
562   } else {
563     addr = op.getSrc();
564   }
565   llvm::SmallVector<mlir::Value> lenParams;
566   mlir::Type boxTy = fir::BoxType::get(srcTy);
567   mlir::Value box =
568       builder.createBox(loc, boxTy, addr, getShapeFromDecl(op.getSrc()),
569                         /*slice=*/nullptr, lenParams,
570                         /*tdesc=*/nullptr);
571   mlir::Value src = builder.createTemporary(loc, box.getType());
572   builder.create<fir::StoreOp>(loc, box, src);
573   return src;
574 }
575 
576 static mlir::Value emboxDst(mlir::PatternRewriter &rewriter,
577                             cuf::DataTransferOp op,
578                             const mlir::SymbolTable &symtab) {
579   auto mod = op->getParentOfType<mlir::ModuleOp>();
580   mlir::Location loc = op.getLoc();
581   fir::FirOpBuilder builder(rewriter, mod);
582   mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());
583   mlir::Value dstAddr = op.getDst();
584   mlir::Type dstBoxTy = fir::BoxType::get(dstTy);
585   llvm::SmallVector<mlir::Value> lenParams;
586   mlir::Value dstBox =
587       builder.createBox(loc, dstBoxTy, dstAddr, getShapeFromDecl(op.getDst()),
588                         /*slice=*/nullptr, lenParams,
589                         /*tdesc=*/nullptr);
590   mlir::Value dst = builder.createTemporary(loc, dstBox.getType());
591   builder.create<fir::StoreOp>(loc, dstBox, dst);
592   return dst;
593 }
594 
595 struct CUFDataTransferOpConversion
596     : public mlir::OpRewritePattern<cuf::DataTransferOp> {
597   using OpRewritePattern::OpRewritePattern;
598 
599   CUFDataTransferOpConversion(mlir::MLIRContext *context,
600                               const mlir::SymbolTable &symtab,
601                               mlir::DataLayout *dl,
602                               const fir::LLVMTypeConverter *typeConverter)
603       : OpRewritePattern(context), symtab{symtab}, dl{dl},
604         typeConverter{typeConverter} {}
605 
606   mlir::LogicalResult
607   matchAndRewrite(cuf::DataTransferOp op,
608                   mlir::PatternRewriter &rewriter) const override {
609 
610     mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
611     mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());
612 
613     mlir::Location loc = op.getLoc();
614     unsigned mode = 0;
615     if (op.getTransferKind() == cuf::DataTransferKind::HostDevice) {
616       mode = kHostToDevice;
617     } else if (op.getTransferKind() == cuf::DataTransferKind::DeviceHost) {
618       mode = kDeviceToHost;
619     } else if (op.getTransferKind() == cuf::DataTransferKind::DeviceDevice) {
620       mode = kDeviceToDevice;
621     } else {
622       mlir::emitError(loc, "unsupported transfer kind\n");
623     }
624 
625     auto mod = op->getParentOfType<mlir::ModuleOp>();
626     fir::FirOpBuilder builder(rewriter, mod);
627     fir::KindMapping kindMap{fir::getKindMapping(mod)};
628     mlir::Value modeValue =
629         builder.createIntegerConstant(loc, builder.getI32Type(), mode);
630 
631     // Convert data transfer without any descriptor.
632     if (!mlir::isa<fir::BaseBoxType>(srcTy) &&
633         !mlir::isa<fir::BaseBoxType>(dstTy)) {
634 
635       if (fir::isa_trivial(srcTy) && !fir::isa_trivial(dstTy)) {
636         // Initialization of an array from a scalar value should be implemented
637         // via a kernel launch. Use the flan runtime via the Assign function
638         // until we have more infrastructure.
639         mlir::Value src = emboxSrc(rewriter, op, symtab);
640         mlir::Value dst = emboxDst(rewriter, op, symtab);
641         mlir::func::FuncOp func =
642             fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(
643                 loc, builder);
644         auto fTy = func.getFunctionType();
645         mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
646         mlir::Value sourceLine =
647             fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
648         llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
649             builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
650         builder.create<fir::CallOp>(loc, func, args);
651         rewriter.eraseOp(op);
652         return mlir::success();
653       }
654 
655       mlir::Type i64Ty = builder.getI64Type();
656       mlir::Value nbElement;
657       if (op.getShape()) {
658         llvm::SmallVector<mlir::Value> extents;
659         if (auto shapeOp =
660                 mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp())) {
661           extents = shapeOp.getExtents();
662         } else if (auto shapeShiftOp = mlir::dyn_cast<fir::ShapeShiftOp>(
663                        op.getShape().getDefiningOp())) {
664           for (auto i : llvm::enumerate(shapeShiftOp.getPairs()))
665             if (i.index() & 1)
666               extents.push_back(i.value());
667         }
668 
669         nbElement = rewriter.create<fir::ConvertOp>(loc, i64Ty, extents[0]);
670         for (unsigned i = 1; i < extents.size(); ++i) {
671           auto operand =
672               rewriter.create<fir::ConvertOp>(loc, i64Ty, extents[i]);
673           nbElement =
674               rewriter.create<mlir::arith::MulIOp>(loc, nbElement, operand);
675         }
676       } else {
677         if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(dstTy))
678           nbElement = builder.createIntegerConstant(
679               loc, i64Ty, seqTy.getConstantArraySize());
680       }
681       unsigned width = 0;
682       if (fir::isa_derived(fir::unwrapSequenceType(dstTy))) {
683         mlir::Type structTy =
684             typeConverter->convertType(fir::unwrapSequenceType(dstTy));
685         width = dl->getTypeSizeInBits(structTy) / 8;
686       } else {
687         width = computeWidth(loc, dstTy, kindMap);
688       }
689       mlir::Value widthValue = rewriter.create<mlir::arith::ConstantOp>(
690           loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width));
691       mlir::Value bytes =
692           nbElement
693               ? rewriter.create<mlir::arith::MulIOp>(loc, nbElement, widthValue)
694               : widthValue;
695 
696       mlir::func::FuncOp func =
697           fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrPtr)>(loc,
698                                                                        builder);
699       auto fTy = func.getFunctionType();
700       mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
701       mlir::Value sourceLine =
702           fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
703 
704       mlir::Value dst = op.getDst();
705       mlir::Value src = op.getSrc();
706       // Materialize the src if constant.
707       if (matchPattern(src.getDefiningOp(), mlir::m_Constant())) {
708         mlir::Value temp = builder.createTemporary(loc, srcTy);
709         builder.create<fir::StoreOp>(loc, src, temp);
710         src = temp;
711       }
712       llvm::SmallVector<mlir::Value> args{
713           fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
714                                         modeValue, sourceFile, sourceLine)};
715       builder.create<fir::CallOp>(loc, func, args);
716       rewriter.eraseOp(op);
717       return mlir::success();
718     }
719 
720     auto materializeBoxIfNeeded = [&](mlir::Value val) -> mlir::Value {
721       if (mlir::isa<fir::EmboxOp, fir::ReboxOp>(val.getDefiningOp())) {
722         // Materialize the box to memory to be able to call the runtime.
723         mlir::Value box = builder.createTemporary(loc, val.getType());
724         builder.create<fir::StoreOp>(loc, val, box);
725         return box;
726       }
727       return val;
728     };
729 
730     // Conversion of data transfer involving at least one descriptor.
731     if (mlir::isa<fir::BaseBoxType>(dstTy)) {
732       // Transfer to a descriptor.
733       mlir::func::FuncOp func =
734           isDstGlobal(op)
735               ? fir::runtime::getRuntimeFunc<mkRTKey(
736                     CUFDataTransferGlobalDescDesc)>(loc, builder)
737               : fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
738                     loc, builder);
739       mlir::Value dst = op.getDst();
740       mlir::Value src = op.getSrc();
741       if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
742         src = emboxSrc(rewriter, op, symtab);
743         if (fir::isa_trivial(srcTy))
744           func = fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(
745               loc, builder);
746       }
747 
748       src = materializeBoxIfNeeded(src);
749       dst = materializeBoxIfNeeded(dst);
750 
751       auto fTy = func.getFunctionType();
752       mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
753       mlir::Value sourceLine =
754           fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
755       llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
756           builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
757       builder.create<fir::CallOp>(loc, func, args);
758       rewriter.eraseOp(op);
759     } else {
760       // Transfer from a descriptor.
761       mlir::Value dst = emboxDst(rewriter, op, symtab);
762       mlir::Value src = materializeBoxIfNeeded(op.getSrc());
763 
764       mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(
765           CUFDataTransferDescDescNoRealloc)>(loc, builder);
766 
767       auto fTy = func.getFunctionType();
768       mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
769       mlir::Value sourceLine =
770           fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
771       llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
772           builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
773       builder.create<fir::CallOp>(loc, func, args);
774       rewriter.eraseOp(op);
775     }
776     return mlir::success();
777   }
778 
779 private:
780   const mlir::SymbolTable &symtab;
781   mlir::DataLayout *dl;
782   const fir::LLVMTypeConverter *typeConverter;
783 };
784 
785 struct CUFLaunchOpConversion
786     : public mlir::OpRewritePattern<cuf::KernelLaunchOp> {
787 public:
788   using OpRewritePattern::OpRewritePattern;
789 
790   CUFLaunchOpConversion(mlir::MLIRContext *context,
791                         const mlir::SymbolTable &symTab)
792       : OpRewritePattern(context), symTab{symTab} {}
793 
794   mlir::LogicalResult
795   matchAndRewrite(cuf::KernelLaunchOp op,
796                   mlir::PatternRewriter &rewriter) const override {
797     mlir::Location loc = op.getLoc();
798     auto idxTy = mlir::IndexType::get(op.getContext());
799     auto zero = rewriter.create<mlir::arith::ConstantOp>(
800         loc, rewriter.getIntegerType(32), rewriter.getI32IntegerAttr(0));
801     auto gridSizeX =
802         rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getGridX());
803     auto gridSizeY =
804         rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getGridY());
805     auto gridSizeZ =
806         rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getGridZ());
807     auto blockSizeX =
808         rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getBlockX());
809     auto blockSizeY =
810         rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getBlockY());
811     auto blockSizeZ =
812         rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getBlockZ());
813     auto kernelName = mlir::SymbolRefAttr::get(
814         rewriter.getStringAttr(cudaDeviceModuleName),
815         {mlir::SymbolRefAttr::get(
816             rewriter.getContext(),
817             op.getCallee().getLeafReference().getValue())});
818     mlir::Value clusterDimX, clusterDimY, clusterDimZ;
819     cuf::ProcAttributeAttr procAttr;
820     if (auto funcOp = symTab.lookup<mlir::func::FuncOp>(
821             op.getCallee().getLeafReference())) {
822       if (auto clusterDimsAttr = funcOp->getAttrOfType<cuf::ClusterDimsAttr>(
823               cuf::getClusterDimsAttrName())) {
824         clusterDimX = rewriter.create<mlir::arith::ConstantIndexOp>(
825             loc, clusterDimsAttr.getX().getInt());
826         clusterDimY = rewriter.create<mlir::arith::ConstantIndexOp>(
827             loc, clusterDimsAttr.getY().getInt());
828         clusterDimZ = rewriter.create<mlir::arith::ConstantIndexOp>(
829             loc, clusterDimsAttr.getZ().getInt());
830       }
831       procAttr =
832           funcOp->getAttrOfType<cuf::ProcAttributeAttr>(cuf::getProcAttrName());
833     }
834     llvm::SmallVector<mlir::Value> args;
835     for (mlir::Value arg : op.getArgs()) {
836       // If the argument is a global descriptor, make sure we pass the device
837       // copy of this descriptor and not the host one.
838       if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(arg.getType()))) {
839         if (auto declareOp =
840                 mlir::dyn_cast_or_null<fir::DeclareOp>(arg.getDefiningOp())) {
841           if (auto addrOfOp = mlir::dyn_cast_or_null<fir::AddrOfOp>(
842                   declareOp.getMemref().getDefiningOp())) {
843             if (auto global = symTab.lookup<fir::GlobalOp>(
844                     addrOfOp.getSymbol().getRootReference().getValue())) {
845               if (cuf::isRegisteredDeviceGlobal(global)) {
846                 arg = rewriter
847                           .create<cuf::DeviceAddressOp>(op.getLoc(),
848                                                         addrOfOp.getType(),
849                                                         addrOfOp.getSymbol())
850                           .getResult();
851               }
852             }
853           }
854         }
855       }
856       args.push_back(arg);
857     }
858 
859     auto gpuLaunchOp = rewriter.create<mlir::gpu::LaunchFuncOp>(
860         loc, kernelName, mlir::gpu::KernelDim3{gridSizeX, gridSizeY, gridSizeZ},
861         mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ}, zero, args);
862     if (clusterDimX && clusterDimY && clusterDimZ) {
863       gpuLaunchOp.getClusterSizeXMutable().assign(clusterDimX);
864       gpuLaunchOp.getClusterSizeYMutable().assign(clusterDimY);
865       gpuLaunchOp.getClusterSizeZMutable().assign(clusterDimZ);
866     }
867     if (procAttr)
868       gpuLaunchOp->setAttr(cuf::getProcAttrName(), procAttr);
869     rewriter.replaceOp(op, gpuLaunchOp);
870     return mlir::success();
871   }
872 
873 private:
874   const mlir::SymbolTable &symTab;
875 };
876 
877 struct CUFSyncDescriptorOpConversion
878     : public mlir::OpRewritePattern<cuf::SyncDescriptorOp> {
879   using OpRewritePattern::OpRewritePattern;
880 
881   mlir::LogicalResult
882   matchAndRewrite(cuf::SyncDescriptorOp op,
883                   mlir::PatternRewriter &rewriter) const override {
884     auto mod = op->getParentOfType<mlir::ModuleOp>();
885     fir::FirOpBuilder builder(rewriter, mod);
886     mlir::Location loc = op.getLoc();
887 
888     auto globalOp = mod.lookupSymbol<fir::GlobalOp>(op.getGlobalName());
889     if (!globalOp)
890       return mlir::failure();
891 
892     auto hostAddr = builder.create<fir::AddrOfOp>(
893         loc, fir::ReferenceType::get(globalOp.getType()), op.getGlobalName());
894     mlir::func::FuncOp callee =
895         fir::runtime::getRuntimeFunc<mkRTKey(CUFSyncGlobalDescriptor)>(loc,
896                                                                        builder);
897     auto fTy = callee.getFunctionType();
898     mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
899     mlir::Value sourceLine =
900         fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
901     llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
902         builder, loc, fTy, hostAddr, sourceFile, sourceLine)};
903     builder.create<fir::CallOp>(loc, callee, args);
904     op.erase();
905     return mlir::success();
906   }
907 };
908 
909 class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
910 public:
911   void runOnOperation() override {
912     auto *ctx = &getContext();
913     mlir::RewritePatternSet patterns(ctx);
914     mlir::ConversionTarget target(*ctx);
915 
916     mlir::Operation *op = getOperation();
917     mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op);
918     if (!module)
919       return signalPassFailure();
920     mlir::SymbolTable symtab(module);
921 
922     std::optional<mlir::DataLayout> dl =
923         fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false);
924     fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
925                                          /*forceUnifiedTBAATree=*/false, *dl);
926     target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
927                            mlir::gpu::GPUDialect>();
928     cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, symtab,
929                                             patterns);
930     if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
931                                                   std::move(patterns)))) {
932       mlir::emitError(mlir::UnknownLoc::get(ctx),
933                       "error in CUF op conversion\n");
934       signalPassFailure();
935     }
936 
937     target.addDynamicallyLegalOp<fir::DeclareOp>([&](fir::DeclareOp op) {
938       if (inDeviceContext(op))
939         return true;
940       if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) {
941         if (auto global = symtab.lookup<fir::GlobalOp>(
942                 addrOfOp.getSymbol().getRootReference().getValue())) {
943           if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(global.getType())))
944             return true;
945           if (cuf::isRegisteredDeviceGlobal(global))
946             return false;
947         }
948       }
949       return true;
950     });
951 
952     patterns.clear();
953     cuf::populateFIRCUFConversionPatterns(symtab, patterns);
954     if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
955                                                   std::move(patterns)))) {
956       mlir::emitError(mlir::UnknownLoc::get(ctx),
957                       "error in CUF op conversion\n");
958       signalPassFailure();
959     }
960   }
961 };
962 } // namespace
963 
964 void cuf::populateCUFToFIRConversionPatterns(
965     const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
966     const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
967   patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter);
968   patterns.insert<CUFAllocateOpConversion, CUFDeallocateOpConversion,
969                   CUFFreeOpConversion, CUFSyncDescriptorOpConversion>(
970       patterns.getContext());
971   patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
972                                                &dl, &converter);
973   patterns.insert<CUFLaunchOpConversion, CUFDeviceAddressOpConversion>(
974       patterns.getContext(), symtab);
975 }
976 
977 void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
978                                            mlir::RewritePatternSet &patterns) {
979   patterns.insert<DeclareOpConversion, CUFDeviceAddressOpConversion>(
980       patterns.getContext(), symtab);
981 }
982