xref: /llvm-project/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp (revision e4578616476426595737c73c9ac357467ee19123)
//===-- CUFOps.cpp --------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h"
#include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallVector.h"

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

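// Wrap the allocated type in a !fir.ref for the op result, or return a null
// type if the input type is already a reference.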
static mlir::Type wrapAllocaResultType(mlir::Type intype) {
  if (mlir::isa<fir::ReferenceType>(intype))
    return {};
  return fir::ReferenceType::get(intype);
}

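// Build a cuf.alloc operation: the result type is the allocated type wrapped
// in !fir.ref, and the uniqued/bind(c) name attributes are only set when the
// corresponding strings are non-empty.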
void cuf::AllocOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
                         mlir::Type inType, llvm::StringRef uniqName,
                         llvm::StringRef bindcName,
                         cuf::DataAttributeAttr cudaAttr,
                         mlir::ValueRange typeparams, mlir::ValueRange shape,
                         llvm::ArrayRef<mlir::NamedAttribute> attributes) {
  mlir::StringAttr nameAttr =
      uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName);
  mlir::StringAttr bindcAttr =
      bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName);
  build(builder, result, wrapAllocaResultType(inType),
        mlir::TypeAttr::get(inType), nameAttr, bindcAttr, typeparams, shape,
        cudaAttr);
  result.addAttributes(attributes);
}

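// Verifier shared by cuf.alloc and cuf.free: the op must carry a device,
// managed, unified, or pinned CUDA data attribute.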
template <typename Op>
static llvm::LogicalResult checkCudaAttr(Op op) {
  if (op.getDataAttr() == cuf::DataAttribute::Device ||
      op.getDataAttr() == cuf::DataAttribute::Managed ||
      op.getDataAttr() == cuf::DataAttribute::Unified ||
      op.getDataAttr() == cuf::DataAttribute::Pinned)
    return mlir::success();
  return op.emitOpError()
         << "expect device, managed, pinned or unified cuda attribute";
}

llvm::LogicalResult cuf::AllocOp::verify() { return checkCudaAttr(*this); }

//===----------------------------------------------------------------------===//
// FreeOp
//===----------------------------------------------------------------------===//

llvm::LogicalResult cuf::FreeOp::verify() { return checkCudaAttr(*this); }

//===----------------------------------------------------------------------===//
// AllocateOp
//===----------------------------------------------------------------------===//

llvm::LogicalResult cuf::AllocateOp::verify() {
  if (getPinned() && getStream())
    return emitOpError("pinned and stream cannot appear at the same time");
  if (!mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(getBox().getType())))
    return emitOpError(
        "expect box to be a reference to a class or box type value");
  if (getSource() &&
      !mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(getSource().getType())))
    return emitOpError(
        "expect source to be a class or box type value or a reference to one");
  if (getErrmsg() &&
      !mlir::isa<fir::BoxType>(fir::unwrapRefType(getErrmsg().getType())))
    return emitOpError(
        "expect errmsg to be a box type value or a reference to one");
  if (getErrmsg() && !getHasStat())
    return emitOpError("expect stat attribute when errmsg is provided");
  return mlir::success();
}

//===----------------------------------------------------------------------===//
// DataTransferOp
//===----------------------------------------------------------------------===//

llvm::LogicalResult cuf::DataTransferOp::verify() {
  mlir::Type srcTy = getSrc().getType();
  mlir::Type dstTy = getDst().getType();
  if (getShape()) {
    if (!fir::isa_ref_type(srcTy) && !fir::isa_ref_type(dstTy))
      return emitOpError()
             << "shape can only be specified on data transfer with references";
  }
  if ((fir::isa_ref_type(srcTy) && fir::isa_ref_type(dstTy)) ||
      (fir::isa_box_type(srcTy) && fir::isa_box_type(dstTy)) ||
      (fir::isa_ref_type(srcTy) && fir::isa_box_type(dstTy)) ||
      (fir::isa_box_type(srcTy) && fir::isa_ref_type(dstTy)))
    return mlir::success();
  if (fir::isa_trivial(srcTy) &&
      matchPattern(getSrc().getDefiningOp(), mlir::m_Constant()))
    return mlir::success();

  return emitOpError()
         << "expect src and dst to be references or descriptors or src to "
            "be a constant: "
         << srcTy << " - " << dstTy;
}

//===----------------------------------------------------------------------===//
// DeallocateOp
//===----------------------------------------------------------------------===//

llvm::LogicalResult cuf::DeallocateOp::verify() {
  if (!mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(getBox().getType())))
    return emitOpError(
        "expect box to be a reference to a class or box type value");
  if (getErrmsg() &&
      !mlir::isa<fir::BoxType>(fir::unwrapRefType(getErrmsg().getType())))
    return emitOpError(
        "expect errmsg to be a box type value or a reference to one");
  if (getErrmsg() && !getHasStat())
    return emitOpError("expect stat attribute when errmsg is provided");
  return mlir::success();
}

//===----------------------------------------------------------------------===//
// KernelOp
//===----------------------------------------------------------------------===//

llvm::SmallVector<mlir::Region *> cuf::KernelOp::getLoopRegions() {
  return {&getRegion()};
}

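// Parse a list of launch configuration values for cuf.kernel (e.g. the grid
// or block dimensions): either `*`, a single operand, or a parenthesized
// comma-separated list of operands. No types are spelled in the assembly;
// every value is given the i32 type.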
mlir::ParseResult parseCUFKernelValues(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &values,
    llvm::SmallVectorImpl<mlir::Type> &types) {
  if (mlir::succeeded(parser.parseOptionalStar()))
    return mlir::success();

  if (mlir::succeeded(parser.parseOptionalLParen())) {
    if (mlir::failed(parser.parseCommaSeparatedList(
            mlir::AsmParser::Delimiter::None, [&]() {
              if (parser.parseOperand(values.emplace_back()))
                return mlir::failure();
              return mlir::success();
            })))
      return mlir::failure();
    auto builder = parser.getBuilder();
    types.append(values.size(), builder.getI32Type());
    if (parser.parseRParen())
      return mlir::failure();
  } else {
    if (parser.parseOperand(values.emplace_back()))
      return mlir::failure();
    auto builder = parser.getBuilder();
    types.emplace_back(builder.getI32Type());
    return mlir::success();
  }
  return mlir::success();
}

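// Print a cuf.kernel value list in the form accepted by parseCUFKernelValues:
// `*` when there are no values, a bare value, or a parenthesized list.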
void printCUFKernelValues(mlir::OpAsmPrinter &p, mlir::Operation *op,
                          mlir::ValueRange values, mlir::TypeRange types) {
  if (values.empty())
    p << "*";

  if (values.size() > 1)
    p << "(";
  llvm::interleaveComma(values, p, [&p](mlir::Value v) { p << v; });
  if (values.size() > 1)
    p << ")";
}

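// Parse the loop control of a cuf.kernel operation, for example:
//   (%i : index, %j : index) = (%lb0, %lb1 : index, index)
//     to (%ub0, %ub1 : index, index) step (%s0, %s1 : index, index)
// followed by the body region, whose entry block arguments are the induction
// variables.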
mlir::ParseResult parseCUFKernelLoopControl(
    mlir::OpAsmParser &parser, mlir::Region &region,
    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &lowerbound,
    llvm::SmallVectorImpl<mlir::Type> &lowerboundType,
    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &upperbound,
    llvm::SmallVectorImpl<mlir::Type> &upperboundType,
    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &step,
    llvm::SmallVectorImpl<mlir::Type> &stepType) {

  llvm::SmallVector<mlir::OpAsmParser::Argument> inductionVars;
  if (parser.parseLParen() ||
      parser.parseArgumentList(inductionVars,
                               mlir::OpAsmParser::Delimiter::None,
                               /*allowType=*/true) ||
      parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
      parser.parseOperandList(lowerbound, inductionVars.size(),
                              mlir::OpAsmParser::Delimiter::None) ||
      parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
      parser.parseKeyword("to") || parser.parseLParen() ||
      parser.parseOperandList(upperbound, inductionVars.size(),
                              mlir::OpAsmParser::Delimiter::None) ||
      parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
      parser.parseKeyword("step") || parser.parseLParen() ||
      parser.parseOperandList(step, inductionVars.size(),
                              mlir::OpAsmParser::Delimiter::None) ||
      parser.parseColonTypeList(stepType) || parser.parseRParen())
    return mlir::failure();
  return parser.parseRegion(region, inductionVars);
}

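// Print the cuf.kernel loop control in the syntax accepted by
// parseCUFKernelLoopControl, followed by the body region without its entry
// block arguments.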
void printCUFKernelLoopControl(
    mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Region &region,
    mlir::ValueRange lowerbound, mlir::TypeRange lowerboundType,
    mlir::ValueRange upperbound, mlir::TypeRange upperboundType,
    mlir::ValueRange steps, mlir::TypeRange stepType) {
  mlir::ValueRange regionArgs = region.front().getArguments();
  if (!regionArgs.empty()) {
    p << "(";
    llvm::interleaveComma(
        regionArgs, p, [&p](mlir::Value v) { p << v << " : " << v.getType(); });
    p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
      << upperbound << " : " << upperboundType << ") step (" << steps << " : "
      << stepType << ") ";
  }
  p.printRegion(region, /*printEntryBlockArgs=*/false);
}

llvm::LogicalResult cuf::KernelOp::verify() {
  if (getLowerbound().size() != getUpperbound().size() ||
      getLowerbound().size() != getStep().size())
    return emitOpError(
        "expect same number of values in lowerbound, upperbound and step");
  auto reduceAttrs = getReduceAttrs();
  std::size_t reduceAttrsSize = reduceAttrs ? reduceAttrs->size() : 0;
  if (getReduceOperands().size() != reduceAttrsSize)
    return emitOpError("expect same number of values in reduce operands and "
                       "reduce attributes");
  if (reduceAttrs) {
    for (const auto &attr : reduceAttrs.value()) {
      if (!mlir::isa<fir::ReduceAttr>(attr))
        return emitOpError("expect reduce attributes to be ReduceAttr");
    }
  }
  return mlir::success();
}

//===----------------------------------------------------------------------===//
// RegisterKernelOp
//===----------------------------------------------------------------------===//

mlir::StringAttr cuf::RegisterKernelOp::getKernelModuleName() {
  return getName().getRootReference();
}

mlir::StringAttr cuf::RegisterKernelOp::getKernelName() {
  return getName().getLeafReference();
}

mlir::LogicalResult cuf::RegisterKernelOp::verify() {
  if (getKernelName() == getKernelModuleName())
    return emitOpError("expect a module and a kernel name");

  auto mod = getOperation()->getParentOfType<mlir::ModuleOp>();
  if (!mod)
    return emitOpError("expect to be in a module");

  mlir::SymbolTable symTab(mod);
  auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(getKernelModuleName());
  if (!gpuMod) {
    // If already a gpu.binary then stop the check here.
    if (symTab.lookup<mlir::gpu::BinaryOp>(getKernelModuleName()))
      return mlir::success();
    return emitOpError("gpu module not found");
  }

  mlir::SymbolTable gpuSymTab(gpuMod);
  if (auto func = gpuSymTab.lookup<mlir::gpu::GPUFuncOp>(getKernelName())) {
    if (!func.isKernel())
      return emitOpError("only kernel gpu.func can be registered");
    return mlir::success();
  } else if (auto func =
                 gpuSymTab.lookup<mlir::LLVM::LLVMFuncOp>(getKernelName())) {
    if (!func->getAttrOfType<mlir::UnitAttr>(
            mlir::gpu::GPUDialect::getKernelFuncAttrName()))
      return emitOpError("only gpu.kernel llvm.func can be registered");
    return mlir::success();
  }
  return emitOpError("device function not found");
}

// Tablegen operators

#define GET_OP_CLASSES
#include "flang/Optimizer/Dialect/CUF/CUFOps.cpp.inc"