xref: /llvm-project/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp (revision f6c8e7dc3e1cbcecc2f01d898b895b96bb5723be)
1fee40fefSDeven Desai //===- ROCDLDialect.cpp - ROCDL IR Ops and Dialect registration -----------===//
2fee40fefSDeven Desai //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fee40fefSDeven Desai //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
8fee40fefSDeven Desai //
9fee40fefSDeven Desai // This file defines the types and operation details for the ROCDL IR dialect in
10fee40fefSDeven Desai // MLIR, and the LLVM IR dialect.  It also registers the dialect.
11fee40fefSDeven Desai //
12fee40fefSDeven Desai // The ROCDL dialect only contains GPU specific additions on top of the general
13fee40fefSDeven Desai // LLVM dialect.
14fee40fefSDeven Desai //
15fee40fefSDeven Desai //===----------------------------------------------------------------------===//
16fee40fefSDeven Desai 
17fee40fefSDeven Desai #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
18fee40fefSDeven Desai 
199779a731SMarkus Böck #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
20fee40fefSDeven Desai #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21fee40fefSDeven Desai #include "mlir/IR/Builders.h"
2209f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
2306821313SFabian Mora #include "mlir/IR/DialectImplementation.h"
24fee40fefSDeven Desai #include "mlir/IR/MLIRContext.h"
25fee40fefSDeven Desai #include "mlir/IR/Operation.h"
2606821313SFabian Mora #include "llvm/ADT/TypeSwitch.h"
27fee40fefSDeven Desai #include "llvm/AsmParser/Parser.h"
28fee40fefSDeven Desai #include "llvm/IR/Attributes.h"
29fee40fefSDeven Desai #include "llvm/IR/Function.h"
30fee40fefSDeven Desai #include "llvm/IR/Type.h"
31fee40fefSDeven Desai #include "llvm/Support/SourceMgr.h"
32fee40fefSDeven Desai 
33e7aa47ffSRiver Riddle using namespace mlir;
34e7aa47ffSRiver Riddle using namespace ROCDL;
35fee40fefSDeven Desai 
36485cc55eSStella Laurenzo #include "mlir/Dialect/LLVMIR/ROCDLOpsDialect.cpp.inc"
37485cc55eSStella Laurenzo 
38fee40fefSDeven Desai //===----------------------------------------------------------------------===//
// Parsing and printing for ROCDL ops
409c53ac08Sjerryyin //===----------------------------------------------------------------------===//
419c53ac08Sjerryyin 
429c53ac08Sjerryyin // <operation> ::=
43f1f05a91SKrzysztof Drewniak //     `llvm.amdgcn.raw.buffer.load.* %rsrc, %offset, %soffset, %aux
44f1f05a91SKrzysztof Drewniak //     : result_type`
parse(OpAsmParser & parser,OperationState & result)45f1f05a91SKrzysztof Drewniak ParseResult RawBufferLoadOp::parse(OpAsmParser &parser,
46f1f05a91SKrzysztof Drewniak                                    OperationState &result) {
47f1f05a91SKrzysztof Drewniak   SmallVector<OpAsmParser::UnresolvedOperand, 4> ops;
48f1f05a91SKrzysztof Drewniak   Type type;
49f1f05a91SKrzysztof Drewniak   if (parser.parseOperandList(ops, 4) || parser.parseColonType(type) ||
50f1f05a91SKrzysztof Drewniak       parser.addTypeToList(type, result.types))
51f1f05a91SKrzysztof Drewniak     return failure();
52f1f05a91SKrzysztof Drewniak 
53f1f05a91SKrzysztof Drewniak   auto bldr = parser.getBuilder();
54f1f05a91SKrzysztof Drewniak   auto int32Ty = bldr.getI32Type();
55f1f05a91SKrzysztof Drewniak   auto i32x4Ty = VectorType::get({4}, int32Ty);
56f1f05a91SKrzysztof Drewniak   return parser.resolveOperands(ops, {i32x4Ty, int32Ty, int32Ty, int32Ty},
57f1f05a91SKrzysztof Drewniak                                 parser.getNameLoc(), result.operands);
58f1f05a91SKrzysztof Drewniak }
59f1f05a91SKrzysztof Drewniak 
print(OpAsmPrinter & p)60f1f05a91SKrzysztof Drewniak void RawBufferLoadOp::print(OpAsmPrinter &p) {
618df54a6aSJacques Pienaar   p << " " << getOperands() << " : " << getRes().getType();
62f1f05a91SKrzysztof Drewniak }
63f1f05a91SKrzysztof Drewniak 
64f1f05a91SKrzysztof Drewniak // <operation> ::=
65f1f05a91SKrzysztof Drewniak //     `llvm.amdgcn.raw.buffer.store.* %vdata, %rsrc,  %offset,
66f1f05a91SKrzysztof Drewniak //     %soffset, %aux : result_type`
parse(OpAsmParser & parser,OperationState & result)67f1f05a91SKrzysztof Drewniak ParseResult RawBufferStoreOp::parse(OpAsmParser &parser,
68f1f05a91SKrzysztof Drewniak                                     OperationState &result) {
69f1f05a91SKrzysztof Drewniak   SmallVector<OpAsmParser::UnresolvedOperand, 5> ops;
70f1f05a91SKrzysztof Drewniak   Type type;
71f1f05a91SKrzysztof Drewniak   if (parser.parseOperandList(ops, 5) || parser.parseColonType(type))
72f1f05a91SKrzysztof Drewniak     return failure();
73f1f05a91SKrzysztof Drewniak 
74f1f05a91SKrzysztof Drewniak   auto bldr = parser.getBuilder();
75f1f05a91SKrzysztof Drewniak   auto int32Ty = bldr.getI32Type();
76f1f05a91SKrzysztof Drewniak   auto i32x4Ty = VectorType::get({4}, int32Ty);
77f1f05a91SKrzysztof Drewniak 
78f1f05a91SKrzysztof Drewniak   if (parser.resolveOperands(ops, {type, i32x4Ty, int32Ty, int32Ty, int32Ty},
79f1f05a91SKrzysztof Drewniak                              parser.getNameLoc(), result.operands))
80f1f05a91SKrzysztof Drewniak     return failure();
81f1f05a91SKrzysztof Drewniak   return success();
82f1f05a91SKrzysztof Drewniak }
83f1f05a91SKrzysztof Drewniak 
print(OpAsmPrinter & p)84f1f05a91SKrzysztof Drewniak void RawBufferStoreOp::print(OpAsmPrinter &p) {
858df54a6aSJacques Pienaar   p << " " << getOperands() << " : " << getVdata().getType();
86f1f05a91SKrzysztof Drewniak }
87f1f05a91SKrzysztof Drewniak 
88f1f05a91SKrzysztof Drewniak // <operation> ::=
89f1f05a91SKrzysztof Drewniak //     `llvm.amdgcn.raw.buffer.atomic.fadd.* %vdata, %rsrc,  %offset,
90f1f05a91SKrzysztof Drewniak //     %soffset, %aux : result_type`
parse(OpAsmParser & parser,OperationState & result)91f1f05a91SKrzysztof Drewniak ParseResult RawBufferAtomicFAddOp::parse(OpAsmParser &parser,
92f1f05a91SKrzysztof Drewniak                                          OperationState &result) {
93f1f05a91SKrzysztof Drewniak   SmallVector<OpAsmParser::UnresolvedOperand, 5> ops;
94f1f05a91SKrzysztof Drewniak   Type type;
95f1f05a91SKrzysztof Drewniak   if (parser.parseOperandList(ops, 5) || parser.parseColonType(type))
96f1f05a91SKrzysztof Drewniak     return failure();
97f1f05a91SKrzysztof Drewniak 
98f1f05a91SKrzysztof Drewniak   auto bldr = parser.getBuilder();
99f1f05a91SKrzysztof Drewniak   auto int32Ty = bldr.getI32Type();
100f1f05a91SKrzysztof Drewniak   auto i32x4Ty = VectorType::get({4}, int32Ty);
101f1f05a91SKrzysztof Drewniak 
102f1f05a91SKrzysztof Drewniak   if (parser.resolveOperands(ops, {type, i32x4Ty, int32Ty, int32Ty, int32Ty},
103f1f05a91SKrzysztof Drewniak                              parser.getNameLoc(), result.operands))
104f1f05a91SKrzysztof Drewniak     return failure();
105f1f05a91SKrzysztof Drewniak   return success();
106f1f05a91SKrzysztof Drewniak }
107f1f05a91SKrzysztof Drewniak 
print(mlir::OpAsmPrinter & p)108f1f05a91SKrzysztof Drewniak void RawBufferAtomicFAddOp::print(mlir::OpAsmPrinter &p) {
1098df54a6aSJacques Pienaar   p << " " << getOperands() << " : " << getVdata().getType();
110f1f05a91SKrzysztof Drewniak }
111f1f05a91SKrzysztof Drewniak 
112584f6436SManupa Karunaratne // <operation> ::=
113584f6436SManupa Karunaratne //     `llvm.amdgcn.raw.buffer.atomic.fmax.* %vdata, %rsrc,  %offset,
114584f6436SManupa Karunaratne //     %soffset, %aux : result_type`
parse(OpAsmParser & parser,OperationState & result)115584f6436SManupa Karunaratne ParseResult RawBufferAtomicFMaxOp::parse(OpAsmParser &parser,
116584f6436SManupa Karunaratne                                          OperationState &result) {
117584f6436SManupa Karunaratne   SmallVector<OpAsmParser::UnresolvedOperand, 5> ops;
118584f6436SManupa Karunaratne   Type type;
119584f6436SManupa Karunaratne   if (parser.parseOperandList(ops, 5) || parser.parseColonType(type))
120584f6436SManupa Karunaratne     return failure();
121584f6436SManupa Karunaratne 
122584f6436SManupa Karunaratne   auto bldr = parser.getBuilder();
123584f6436SManupa Karunaratne   auto int32Ty = bldr.getI32Type();
124584f6436SManupa Karunaratne   auto i32x4Ty = VectorType::get({4}, int32Ty);
125584f6436SManupa Karunaratne 
126584f6436SManupa Karunaratne   if (parser.resolveOperands(ops, {type, i32x4Ty, int32Ty, int32Ty, int32Ty},
127584f6436SManupa Karunaratne                              parser.getNameLoc(), result.operands))
128584f6436SManupa Karunaratne     return failure();
129584f6436SManupa Karunaratne   return success();
130584f6436SManupa Karunaratne }
131584f6436SManupa Karunaratne 
print(mlir::OpAsmPrinter & p)132584f6436SManupa Karunaratne void RawBufferAtomicFMaxOp::print(mlir::OpAsmPrinter &p) {
133584f6436SManupa Karunaratne   p << " " << getOperands() << " : " << getVdata().getType();
134584f6436SManupa Karunaratne }
135584f6436SManupa Karunaratne 
136584f6436SManupa Karunaratne // <operation> ::=
137584f6436SManupa Karunaratne //     `llvm.amdgcn.raw.buffer.atomic.smax.* %vdata, %rsrc,  %offset,
138584f6436SManupa Karunaratne //     %soffset, %aux : result_type`
parse(OpAsmParser & parser,OperationState & result)139584f6436SManupa Karunaratne ParseResult RawBufferAtomicSMaxOp::parse(OpAsmParser &parser,
140584f6436SManupa Karunaratne                                          OperationState &result) {
141584f6436SManupa Karunaratne   SmallVector<OpAsmParser::UnresolvedOperand, 5> ops;
142584f6436SManupa Karunaratne   Type type;
143584f6436SManupa Karunaratne   if (parser.parseOperandList(ops, 5) || parser.parseColonType(type))
144584f6436SManupa Karunaratne     return failure();
145584f6436SManupa Karunaratne 
146584f6436SManupa Karunaratne   auto bldr = parser.getBuilder();
147584f6436SManupa Karunaratne   auto int32Ty = bldr.getI32Type();
148584f6436SManupa Karunaratne   auto i32x4Ty = VectorType::get({4}, int32Ty);
149584f6436SManupa Karunaratne 
150584f6436SManupa Karunaratne   if (parser.resolveOperands(ops, {type, i32x4Ty, int32Ty, int32Ty, int32Ty},
151584f6436SManupa Karunaratne                              parser.getNameLoc(), result.operands))
152584f6436SManupa Karunaratne     return failure();
153584f6436SManupa Karunaratne   return success();
154584f6436SManupa Karunaratne }
155584f6436SManupa Karunaratne 
print(mlir::OpAsmPrinter & p)156584f6436SManupa Karunaratne void RawBufferAtomicSMaxOp::print(mlir::OpAsmPrinter &p) {
157584f6436SManupa Karunaratne   p << " " << getOperands() << " : " << getVdata().getType();
158584f6436SManupa Karunaratne }
159584f6436SManupa Karunaratne 
160584f6436SManupa Karunaratne // <operation> ::=
161584f6436SManupa Karunaratne //     `llvm.amdgcn.raw.buffer.atomic.umin.* %vdata, %rsrc,  %offset,
162584f6436SManupa Karunaratne //     %soffset, %aux : result_type`
parse(OpAsmParser & parser,OperationState & result)163584f6436SManupa Karunaratne ParseResult RawBufferAtomicUMinOp::parse(OpAsmParser &parser,
164584f6436SManupa Karunaratne                                          OperationState &result) {
165584f6436SManupa Karunaratne   SmallVector<OpAsmParser::UnresolvedOperand, 5> ops;
166584f6436SManupa Karunaratne   Type type;
167584f6436SManupa Karunaratne   if (parser.parseOperandList(ops, 5) || parser.parseColonType(type))
168584f6436SManupa Karunaratne     return failure();
169584f6436SManupa Karunaratne 
170584f6436SManupa Karunaratne   auto bldr = parser.getBuilder();
171584f6436SManupa Karunaratne   auto int32Ty = bldr.getI32Type();
172584f6436SManupa Karunaratne   auto i32x4Ty = VectorType::get({4}, int32Ty);
173584f6436SManupa Karunaratne 
174584f6436SManupa Karunaratne   if (parser.resolveOperands(ops, {type, i32x4Ty, int32Ty, int32Ty, int32Ty},
175584f6436SManupa Karunaratne                              parser.getNameLoc(), result.operands))
176584f6436SManupa Karunaratne     return failure();
177584f6436SManupa Karunaratne   return success();
178584f6436SManupa Karunaratne }
179584f6436SManupa Karunaratne 
print(mlir::OpAsmPrinter & p)180584f6436SManupa Karunaratne void RawBufferAtomicUMinOp::print(mlir::OpAsmPrinter &p) {
181584f6436SManupa Karunaratne   p << " " << getOperands() << " : " << getVdata().getType();
182584f6436SManupa Karunaratne }
183584f6436SManupa Karunaratne 
1849c53ac08Sjerryyin //===----------------------------------------------------------------------===//
// ROCDLDialect initialization, attribute verification, and registration.
186fee40fefSDeven Desai //===----------------------------------------------------------------------===//
187fee40fefSDeven Desai 
1889db53a18SRiver Riddle // TODO: This should be the llvm.rocdl dialect once this is supported.
initialize()189575b22b5SMehdi Amini void ROCDLDialect::initialize() {
190fee40fefSDeven Desai   addOperations<
191fee40fefSDeven Desai #define GET_OP_LIST
192fee40fefSDeven Desai #include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc"
193fee40fefSDeven Desai       >();
194fee40fefSDeven Desai 
19506821313SFabian Mora   addAttributes<
19606821313SFabian Mora #define GET_ATTRDEF_LIST
19706821313SFabian Mora #include "mlir/Dialect/LLVMIR/ROCDLOpsAttributes.cpp.inc"
19806821313SFabian Mora       >();
19906821313SFabian Mora 
200fee40fefSDeven Desai   // Support unknown operations because not all ROCDL operations are registered.
201fee40fefSDeven Desai   allowUnknownOperations();
202*35d55f28SJustin Fargnoli   declarePromisedInterface<gpu::TargetAttrInterface, ROCDLTargetAttr>();
203fee40fefSDeven Desai }
204fee40fefSDeven Desai 
verifyOperationAttribute(Operation * op,NamedAttribute attr)2059cd47a26SAlex Zinenko LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
2069cd47a26SAlex Zinenko                                                      NamedAttribute attr) {
2079cd47a26SAlex Zinenko   // Kernel function attribute should be attached to functions.
20845c226d4SMehdi Amini   if (kernelAttrName.getName() == attr.getName()) {
2099cd47a26SAlex Zinenko     if (!isa<LLVM::LLVMFuncOp>(op)) {
21045c226d4SMehdi Amini       return op->emitError() << "'" << kernelAttrName.getName()
2119cd47a26SAlex Zinenko                              << "' attribute attached to unexpected op";
2129cd47a26SAlex Zinenko     }
2139cd47a26SAlex Zinenko   }
2149cd47a26SAlex Zinenko   return success();
2159cd47a26SAlex Zinenko }
2169cd47a26SAlex Zinenko 
21706821313SFabian Mora //===----------------------------------------------------------------------===//
21806821313SFabian Mora // ROCDL target attribute.
21906821313SFabian Mora //===----------------------------------------------------------------------===//
22006821313SFabian Mora LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,int optLevel,StringRef triple,StringRef chip,StringRef features,StringRef abiVersion,DictionaryAttr flags,ArrayAttr files)22106821313SFabian Mora ROCDLTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
22206821313SFabian Mora                         int optLevel, StringRef triple, StringRef chip,
22306821313SFabian Mora                         StringRef features, StringRef abiVersion,
22406821313SFabian Mora                         DictionaryAttr flags, ArrayAttr files) {
22506821313SFabian Mora   if (optLevel < 0 || optLevel > 3) {
22606821313SFabian Mora     emitError() << "The optimization level must be a number between 0 and 3.";
22706821313SFabian Mora     return failure();
22806821313SFabian Mora   }
22906821313SFabian Mora   if (triple.empty()) {
23006821313SFabian Mora     emitError() << "The target triple cannot be empty.";
23106821313SFabian Mora     return failure();
23206821313SFabian Mora   }
23306821313SFabian Mora   if (chip.empty()) {
23406821313SFabian Mora     emitError() << "The target chip cannot be empty.";
23506821313SFabian Mora     return failure();
23606821313SFabian Mora   }
23706821313SFabian Mora   if (abiVersion != "400" && abiVersion != "500") {
23806821313SFabian Mora     emitError() << "Invalid ABI version, it must be either `400` or `500`.";
23906821313SFabian Mora     return failure();
24006821313SFabian Mora   }
24106821313SFabian Mora   if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
24206821313SFabian Mora         return attr && mlir::isa<StringAttr>(attr);
24306821313SFabian Mora       })) {
24406821313SFabian Mora     emitError() << "All the elements in the `link` array must be strings.";
24506821313SFabian Mora     return failure();
24606821313SFabian Mora   }
24706821313SFabian Mora   return success();
24806821313SFabian Mora }
24906821313SFabian Mora 
250fee40fefSDeven Desai #define GET_OP_CLASSES
251fee40fefSDeven Desai #include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc"
25206821313SFabian Mora 
25306821313SFabian Mora #define GET_ATTRDEF_CLASSES
25406821313SFabian Mora #include "mlir/Dialect/LLVMIR/ROCDLOpsAttributes.cpp.inc"
255