1 //===-- CUFOps.cpp --------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "flang/Optimizer/Dialect/CUF/CUFOps.h" 14 #include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h" 15 #include "flang/Optimizer/Dialect/CUF/CUFDialect.h" 16 #include "flang/Optimizer/Dialect/FIRAttr.h" 17 #include "flang/Optimizer/Dialect/FIRType.h" 18 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/IR/Attributes.h" 21 #include "mlir/IR/BuiltinAttributes.h" 22 #include "mlir/IR/BuiltinOps.h" 23 #include "mlir/IR/Diagnostics.h" 24 #include "mlir/IR/Matchers.h" 25 #include "mlir/IR/OpDefinition.h" 26 #include "mlir/IR/PatternMatch.h" 27 #include "llvm/ADT/SmallVector.h" 28 29 //===----------------------------------------------------------------------===// 30 // AllocOp 31 //===----------------------------------------------------------------------===// 32 33 static mlir::Type wrapAllocaResultType(mlir::Type intype) { 34 if (mlir::isa<fir::ReferenceType>(intype)) 35 return {}; 36 return fir::ReferenceType::get(intype); 37 } 38 39 void cuf::AllocOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, 40 mlir::Type inType, llvm::StringRef uniqName, 41 llvm::StringRef bindcName, 42 cuf::DataAttributeAttr cudaAttr, 43 mlir::ValueRange typeparams, mlir::ValueRange shape, 44 llvm::ArrayRef<mlir::NamedAttribute> attributes) { 45 mlir::StringAttr nameAttr = 46 uniqName.empty() ? mlir::StringAttr{} : builder.getStringAttr(uniqName); 47 mlir::StringAttr bindcAttr = 48 bindcName.empty() ? mlir::StringAttr{} : builder.getStringAttr(bindcName); 49 build(builder, result, wrapAllocaResultType(inType), 50 mlir::TypeAttr::get(inType), nameAttr, bindcAttr, typeparams, shape, 51 cudaAttr); 52 result.addAttributes(attributes); 53 } 54 55 template <typename Op> 56 static llvm::LogicalResult checkCudaAttr(Op op) { 57 if (op.getDataAttr() == cuf::DataAttribute::Device || 58 op.getDataAttr() == cuf::DataAttribute::Managed || 59 op.getDataAttr() == cuf::DataAttribute::Unified || 60 op.getDataAttr() == cuf::DataAttribute::Pinned) 61 return mlir::success(); 62 return op.emitOpError() 63 << "expect device, managed, pinned or unified cuda attribute"; 64 } 65 66 llvm::LogicalResult cuf::AllocOp::verify() { return checkCudaAttr(*this); } 67 68 //===----------------------------------------------------------------------===// 69 // FreeOp 70 //===----------------------------------------------------------------------===// 71 72 llvm::LogicalResult cuf::FreeOp::verify() { return checkCudaAttr(*this); } 73 74 //===----------------------------------------------------------------------===// 75 // AllocateOp 76 //===----------------------------------------------------------------------===// 77 78 llvm::LogicalResult cuf::AllocateOp::verify() { 79 if (getPinned() && getStream()) 80 return emitOpError("pinned and stream cannot appears at the same time"); 81 if (!mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(getBox().getType()))) 82 return emitOpError( 83 "expect box to be a reference to a class or box type value"); 84 if (getSource() && 85 !mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(getSource().getType()))) 86 return emitOpError( 87 "expect source to be a reference to/or a class or box type value"); 88 if (getErrmsg() && 89 !mlir::isa<fir::BoxType>(fir::unwrapRefType(getErrmsg().getType()))) 90 return emitOpError( 91 "expect errmsg to be a reference to/or a box type value"); 92 if (getErrmsg() && !getHasStat()) 93 return emitOpError("expect stat attribute when errmsg is provided"); 94 return mlir::success(); 95 } 96 97 //===----------------------------------------------------------------------===// 98 // DataTransferOp 99 //===----------------------------------------------------------------------===// 100 101 llvm::LogicalResult cuf::DataTransferOp::verify() { 102 mlir::Type srcTy = getSrc().getType(); 103 mlir::Type dstTy = getDst().getType(); 104 if (getShape()) { 105 if (!fir::isa_ref_type(srcTy) && !fir::isa_ref_type(dstTy)) 106 return emitOpError() 107 << "shape can only be specified on data transfer with references"; 108 } 109 if ((fir::isa_ref_type(srcTy) && fir::isa_ref_type(dstTy)) || 110 (fir::isa_box_type(srcTy) && fir::isa_box_type(dstTy)) || 111 (fir::isa_ref_type(srcTy) && fir::isa_box_type(dstTy)) || 112 (fir::isa_box_type(srcTy) && fir::isa_ref_type(dstTy))) 113 return mlir::success(); 114 if (fir::isa_trivial(srcTy) && 115 matchPattern(getSrc().getDefiningOp(), mlir::m_Constant())) 116 return mlir::success(); 117 118 return emitOpError() 119 << "expect src and dst to be references or descriptors or src to " 120 "be a constant: " 121 << srcTy << " - " << dstTy; 122 } 123 124 //===----------------------------------------------------------------------===// 125 // DeallocateOp 126 //===----------------------------------------------------------------------===// 127 128 llvm::LogicalResult cuf::DeallocateOp::verify() { 129 if (!mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(getBox().getType()))) 130 return emitOpError( 131 "expect box to be a reference to class or box type value"); 132 if (getErrmsg() && 133 !mlir::isa<fir::BoxType>(fir::unwrapRefType(getErrmsg().getType()))) 134 return emitOpError( 135 "expect errmsg to be a reference to/or a box type value"); 136 if (getErrmsg() && !getHasStat()) 137 return emitOpError("expect stat attribute when errmsg is provided"); 138 return mlir::success(); 139 } 140 141 //===----------------------------------------------------------------------===// 142 // KernelOp 143 //===----------------------------------------------------------------------===// 144 145 llvm::SmallVector<mlir::Region *> cuf::KernelOp::getLoopRegions() { 146 return {&getRegion()}; 147 } 148 149 mlir::ParseResult parseCUFKernelValues( 150 mlir::OpAsmParser &parser, 151 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &values, 152 llvm::SmallVectorImpl<mlir::Type> &types) { 153 if (mlir::succeeded(parser.parseOptionalStar())) 154 return mlir::success(); 155 156 if (mlir::succeeded(parser.parseOptionalLParen())) { 157 if (mlir::failed(parser.parseCommaSeparatedList( 158 mlir::AsmParser::Delimiter::None, [&]() { 159 if (parser.parseOperand(values.emplace_back())) 160 return mlir::failure(); 161 return mlir::success(); 162 }))) 163 return mlir::failure(); 164 auto builder = parser.getBuilder(); 165 for (size_t i = 0; i < values.size(); i++) { 166 types.emplace_back(builder.getI32Type()); 167 } 168 if (parser.parseRParen()) 169 return mlir::failure(); 170 } else { 171 if (parser.parseOperand(values.emplace_back())) 172 return mlir::failure(); 173 auto builder = parser.getBuilder(); 174 types.emplace_back(builder.getI32Type()); 175 return mlir::success(); 176 } 177 return mlir::success(); 178 } 179 180 void printCUFKernelValues(mlir::OpAsmPrinter &p, mlir::Operation *op, 181 mlir::ValueRange values, mlir::TypeRange types) { 182 if (values.empty()) 183 p << "*"; 184 185 if (values.size() > 1) 186 p << "("; 187 llvm::interleaveComma(values, p, [&p](mlir::Value v) { p << v; }); 188 if (values.size() > 1) 189 p << ")"; 190 } 191 192 mlir::ParseResult parseCUFKernelLoopControl( 193 mlir::OpAsmParser &parser, mlir::Region ®ion, 194 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &lowerbound, 195 llvm::SmallVectorImpl<mlir::Type> &lowerboundType, 196 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &upperbound, 197 llvm::SmallVectorImpl<mlir::Type> &upperboundType, 198 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &step, 199 llvm::SmallVectorImpl<mlir::Type> &stepType) { 200 201 llvm::SmallVector<mlir::OpAsmParser::Argument> inductionVars; 202 if (parser.parseLParen() || 203 parser.parseArgumentList(inductionVars, 204 mlir::OpAsmParser::Delimiter::None, 205 /*allowType=*/true) || 206 parser.parseRParen() || parser.parseEqual() || parser.parseLParen() || 207 parser.parseOperandList(lowerbound, inductionVars.size(), 208 mlir::OpAsmParser::Delimiter::None) || 209 parser.parseColonTypeList(lowerboundType) || parser.parseRParen() || 210 parser.parseKeyword("to") || parser.parseLParen() || 211 parser.parseOperandList(upperbound, inductionVars.size(), 212 mlir::OpAsmParser::Delimiter::None) || 213 parser.parseColonTypeList(upperboundType) || parser.parseRParen() || 214 parser.parseKeyword("step") || parser.parseLParen() || 215 parser.parseOperandList(step, inductionVars.size(), 216 mlir::OpAsmParser::Delimiter::None) || 217 parser.parseColonTypeList(stepType) || parser.parseRParen()) 218 return mlir::failure(); 219 return parser.parseRegion(region, inductionVars); 220 } 221 222 void printCUFKernelLoopControl( 223 mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Region ®ion, 224 mlir::ValueRange lowerbound, mlir::TypeRange lowerboundType, 225 mlir::ValueRange upperbound, mlir::TypeRange upperboundType, 226 mlir::ValueRange steps, mlir::TypeRange stepType) { 227 mlir::ValueRange regionArgs = region.front().getArguments(); 228 if (!regionArgs.empty()) { 229 p << "("; 230 llvm::interleaveComma( 231 regionArgs, p, [&p](mlir::Value v) { p << v << " : " << v.getType(); }); 232 p << ") = (" << lowerbound << " : " << lowerboundType << ") to (" 233 << upperbound << " : " << upperboundType << ") " 234 << " step (" << steps << " : " << stepType << ") "; 235 } 236 p.printRegion(region, /*printEntryBlockArgs=*/false); 237 } 238 239 llvm::LogicalResult cuf::KernelOp::verify() { 240 if (getLowerbound().size() != getUpperbound().size() || 241 getLowerbound().size() != getStep().size()) 242 return emitOpError( 243 "expect same number of values in lowerbound, upperbound and step"); 244 auto reduceAttrs = getReduceAttrs(); 245 std::size_t reduceAttrsSize = reduceAttrs ? reduceAttrs->size() : 0; 246 if (getReduceOperands().size() != reduceAttrsSize) 247 return emitOpError("expect same number of values in reduce operands and " 248 "reduce attributes"); 249 if (reduceAttrs) { 250 for (const auto &attr : reduceAttrs.value()) { 251 if (!mlir::isa<fir::ReduceAttr>(attr)) 252 return emitOpError("expect reduce attributes to be ReduceAttr"); 253 } 254 } 255 return mlir::success(); 256 } 257 258 //===----------------------------------------------------------------------===// 259 // RegisterKernelOp 260 //===----------------------------------------------------------------------===// 261 262 mlir::StringAttr cuf::RegisterKernelOp::getKernelModuleName() { 263 return getName().getRootReference(); 264 } 265 266 mlir::StringAttr cuf::RegisterKernelOp::getKernelName() { 267 return getName().getLeafReference(); 268 } 269 270 mlir::LogicalResult cuf::RegisterKernelOp::verify() { 271 if (getKernelName() == getKernelModuleName()) 272 return emitOpError("expect a module and a kernel name"); 273 274 auto mod = getOperation()->getParentOfType<mlir::ModuleOp>(); 275 if (!mod) 276 return emitOpError("expect to be in a module"); 277 278 mlir::SymbolTable symTab(mod); 279 auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(getKernelModuleName()); 280 if (!gpuMod) { 281 // If already a gpu.binary then stop the check here. 282 if (symTab.lookup<mlir::gpu::BinaryOp>(getKernelModuleName())) 283 return mlir::success(); 284 return emitOpError("gpu module not found"); 285 } 286 287 mlir::SymbolTable gpuSymTab(gpuMod); 288 if (auto func = gpuSymTab.lookup<mlir::gpu::GPUFuncOp>(getKernelName())) { 289 if (!func.isKernel()) 290 return emitOpError("only kernel gpu.func can be registered"); 291 return mlir::success(); 292 } else if (auto func = 293 gpuSymTab.lookup<mlir::LLVM::LLVMFuncOp>(getKernelName())) { 294 if (!func->getAttrOfType<mlir::UnitAttr>( 295 mlir::gpu::GPUDialect::getKernelFuncAttrName())) 296 return emitOpError("only gpu.kernel llvm.func can be registered"); 297 return mlir::success(); 298 } 299 return emitOpError("device function not found"); 300 } 301 302 // Tablegen operators 303 304 #define GET_OP_CLASSES 305 #include "flang/Optimizer/Dialect/CUF/CUFOps.cpp.inc" 306