1fee40fefSDeven Desai //===- ROCDLDialect.cpp - ROCDL IR Ops and Dialect registration -----------===//
2fee40fefSDeven Desai //
330857107SMehdi Amini // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
456222a06SMehdi Amini // See https://llvm.org/LICENSE.txt for license information.
556222a06SMehdi Amini // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fee40fefSDeven Desai //
756222a06SMehdi Amini //===----------------------------------------------------------------------===//
8fee40fefSDeven Desai //
9fee40fefSDeven Desai // This file defines the types and operation details for the ROCDL IR dialect in
10fee40fefSDeven Desai // MLIR, and the LLVM IR dialect. It also registers the dialect.
11fee40fefSDeven Desai //
12fee40fefSDeven Desai // The ROCDL dialect only contains GPU specific additions on top of the general
13fee40fefSDeven Desai // LLVM dialect.
14fee40fefSDeven Desai //
15fee40fefSDeven Desai //===----------------------------------------------------------------------===//
16fee40fefSDeven Desai
17fee40fefSDeven Desai #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
18fee40fefSDeven Desai
199779a731SMarkus Böck #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
20fee40fefSDeven Desai #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21fee40fefSDeven Desai #include "mlir/IR/Builders.h"
2209f7a55fSRiver Riddle #include "mlir/IR/BuiltinTypes.h"
2306821313SFabian Mora #include "mlir/IR/DialectImplementation.h"
24fee40fefSDeven Desai #include "mlir/IR/MLIRContext.h"
25fee40fefSDeven Desai #include "mlir/IR/Operation.h"
2606821313SFabian Mora #include "llvm/ADT/TypeSwitch.h"
27fee40fefSDeven Desai #include "llvm/AsmParser/Parser.h"
28fee40fefSDeven Desai #include "llvm/IR/Attributes.h"
29fee40fefSDeven Desai #include "llvm/IR/Function.h"
30fee40fefSDeven Desai #include "llvm/IR/Type.h"
31fee40fefSDeven Desai #include "llvm/Support/SourceMgr.h"
32fee40fefSDeven Desai
33e7aa47ffSRiver Riddle using namespace mlir;
34e7aa47ffSRiver Riddle using namespace ROCDL;
35fee40fefSDeven Desai
36485cc55eSStella Laurenzo #include "mlir/Dialect/LLVMIR/ROCDLOpsDialect.cpp.inc"
37485cc55eSStella Laurenzo
38fee40fefSDeven Desai //===----------------------------------------------------------------------===//
399c53ac08Sjerryyin // Parsing for ROCDL ops
409c53ac08Sjerryyin //===----------------------------------------------------------------------===//
419c53ac08Sjerryyin
429c53ac08Sjerryyin // <operation> ::=
43f1f05a91SKrzysztof Drewniak // `llvm.amdgcn.raw.buffer.load.* %rsrc, %offset, %soffset, %aux
44f1f05a91SKrzysztof Drewniak // : result_type`
parse(OpAsmParser & parser,OperationState & result)45f1f05a91SKrzysztof Drewniak ParseResult RawBufferLoadOp::parse(OpAsmParser &parser,
46f1f05a91SKrzysztof Drewniak OperationState &result) {
47f1f05a91SKrzysztof Drewniak SmallVector<OpAsmParser::UnresolvedOperand, 4> ops;
48f1f05a91SKrzysztof Drewniak Type type;
49f1f05a91SKrzysztof Drewniak if (parser.parseOperandList(ops, 4) || parser.parseColonType(type) ||
50f1f05a91SKrzysztof Drewniak parser.addTypeToList(type, result.types))
51f1f05a91SKrzysztof Drewniak return failure();
52f1f05a91SKrzysztof Drewniak
53f1f05a91SKrzysztof Drewniak auto bldr = parser.getBuilder();
54f1f05a91SKrzysztof Drewniak auto int32Ty = bldr.getI32Type();
55f1f05a91SKrzysztof Drewniak auto i32x4Ty = VectorType::get({4}, int32Ty);
56f1f05a91SKrzysztof Drewniak return parser.resolveOperands(ops, {i32x4Ty, int32Ty, int32Ty, int32Ty},
57f1f05a91SKrzysztof Drewniak parser.getNameLoc(), result.operands);
58f1f05a91SKrzysztof Drewniak }
59f1f05a91SKrzysztof Drewniak
print(OpAsmPrinter & p)60f1f05a91SKrzysztof Drewniak void RawBufferLoadOp::print(OpAsmPrinter &p) {
618df54a6aSJacques Pienaar p << " " << getOperands() << " : " << getRes().getType();
62f1f05a91SKrzysztof Drewniak }
63f1f05a91SKrzysztof Drewniak
// <operation> ::=
//     `llvm.amdgcn.raw.buffer.store.* %vdata, %rsrc, %offset,
//     %soffset, %aux : vdata_type`
parse(OpAsmParser & parser,OperationState & result)67f1f05a91SKrzysztof Drewniak ParseResult RawBufferStoreOp::parse(OpAsmParser &parser,
68f1f05a91SKrzysztof Drewniak OperationState &result) {
69f1f05a91SKrzysztof Drewniak SmallVector<OpAsmParser::UnresolvedOperand, 5> ops;
70f1f05a91SKrzysztof Drewniak Type type;
71f1f05a91SKrzysztof Drewniak if (parser.parseOperandList(ops, 5) || parser.parseColonType(type))
72f1f05a91SKrzysztof Drewniak return failure();
73f1f05a91SKrzysztof Drewniak
74f1f05a91SKrzysztof Drewniak auto bldr = parser.getBuilder();
75f1f05a91SKrzysztof Drewniak auto int32Ty = bldr.getI32Type();
76f1f05a91SKrzysztof Drewniak auto i32x4Ty = VectorType::get({4}, int32Ty);
77f1f05a91SKrzysztof Drewniak
78f1f05a91SKrzysztof Drewniak if (parser.resolveOperands(ops, {type, i32x4Ty, int32Ty, int32Ty, int32Ty},
79f1f05a91SKrzysztof Drewniak parser.getNameLoc(), result.operands))
80f1f05a91SKrzysztof Drewniak return failure();
81f1f05a91SKrzysztof Drewniak return success();
82f1f05a91SKrzysztof Drewniak }
83f1f05a91SKrzysztof Drewniak
print(OpAsmPrinter & p)84f1f05a91SKrzysztof Drewniak void RawBufferStoreOp::print(OpAsmPrinter &p) {
858df54a6aSJacques Pienaar p << " " << getOperands() << " : " << getVdata().getType();
86f1f05a91SKrzysztof Drewniak }
87f1f05a91SKrzysztof Drewniak
// <operation> ::=
//     `llvm.amdgcn.raw.buffer.atomic.fadd.* %vdata, %rsrc, %offset,
//     %soffset, %aux : vdata_type`
parse(OpAsmParser & parser,OperationState & result)91f1f05a91SKrzysztof Drewniak ParseResult RawBufferAtomicFAddOp::parse(OpAsmParser &parser,
92f1f05a91SKrzysztof Drewniak OperationState &result) {
93f1f05a91SKrzysztof Drewniak SmallVector<OpAsmParser::UnresolvedOperand, 5> ops;
94f1f05a91SKrzysztof Drewniak Type type;
95f1f05a91SKrzysztof Drewniak if (parser.parseOperandList(ops, 5) || parser.parseColonType(type))
96f1f05a91SKrzysztof Drewniak return failure();
97f1f05a91SKrzysztof Drewniak
98f1f05a91SKrzysztof Drewniak auto bldr = parser.getBuilder();
99f1f05a91SKrzysztof Drewniak auto int32Ty = bldr.getI32Type();
100f1f05a91SKrzysztof Drewniak auto i32x4Ty = VectorType::get({4}, int32Ty);
101f1f05a91SKrzysztof Drewniak
102f1f05a91SKrzysztof Drewniak if (parser.resolveOperands(ops, {type, i32x4Ty, int32Ty, int32Ty, int32Ty},
103f1f05a91SKrzysztof Drewniak parser.getNameLoc(), result.operands))
104f1f05a91SKrzysztof Drewniak return failure();
105f1f05a91SKrzysztof Drewniak return success();
106f1f05a91SKrzysztof Drewniak }
107f1f05a91SKrzysztof Drewniak
print(mlir::OpAsmPrinter & p)108f1f05a91SKrzysztof Drewniak void RawBufferAtomicFAddOp::print(mlir::OpAsmPrinter &p) {
1098df54a6aSJacques Pienaar p << " " << getOperands() << " : " << getVdata().getType();
110f1f05a91SKrzysztof Drewniak }
111f1f05a91SKrzysztof Drewniak
// <operation> ::=
//     `llvm.amdgcn.raw.buffer.atomic.fmax.* %vdata, %rsrc, %offset,
//     %soffset, %aux : vdata_type`
parse(OpAsmParser & parser,OperationState & result)115584f6436SManupa Karunaratne ParseResult RawBufferAtomicFMaxOp::parse(OpAsmParser &parser,
116584f6436SManupa Karunaratne OperationState &result) {
117584f6436SManupa Karunaratne SmallVector<OpAsmParser::UnresolvedOperand, 5> ops;
118584f6436SManupa Karunaratne Type type;
119584f6436SManupa Karunaratne if (parser.parseOperandList(ops, 5) || parser.parseColonType(type))
120584f6436SManupa Karunaratne return failure();
121584f6436SManupa Karunaratne
122584f6436SManupa Karunaratne auto bldr = parser.getBuilder();
123584f6436SManupa Karunaratne auto int32Ty = bldr.getI32Type();
124584f6436SManupa Karunaratne auto i32x4Ty = VectorType::get({4}, int32Ty);
125584f6436SManupa Karunaratne
126584f6436SManupa Karunaratne if (parser.resolveOperands(ops, {type, i32x4Ty, int32Ty, int32Ty, int32Ty},
127584f6436SManupa Karunaratne parser.getNameLoc(), result.operands))
128584f6436SManupa Karunaratne return failure();
129584f6436SManupa Karunaratne return success();
130584f6436SManupa Karunaratne }
131584f6436SManupa Karunaratne
print(mlir::OpAsmPrinter & p)132584f6436SManupa Karunaratne void RawBufferAtomicFMaxOp::print(mlir::OpAsmPrinter &p) {
133584f6436SManupa Karunaratne p << " " << getOperands() << " : " << getVdata().getType();
134584f6436SManupa Karunaratne }
135584f6436SManupa Karunaratne
// <operation> ::=
//     `llvm.amdgcn.raw.buffer.atomic.smax.* %vdata, %rsrc, %offset,
//     %soffset, %aux : vdata_type`
parse(OpAsmParser & parser,OperationState & result)139584f6436SManupa Karunaratne ParseResult RawBufferAtomicSMaxOp::parse(OpAsmParser &parser,
140584f6436SManupa Karunaratne OperationState &result) {
141584f6436SManupa Karunaratne SmallVector<OpAsmParser::UnresolvedOperand, 5> ops;
142584f6436SManupa Karunaratne Type type;
143584f6436SManupa Karunaratne if (parser.parseOperandList(ops, 5) || parser.parseColonType(type))
144584f6436SManupa Karunaratne return failure();
145584f6436SManupa Karunaratne
146584f6436SManupa Karunaratne auto bldr = parser.getBuilder();
147584f6436SManupa Karunaratne auto int32Ty = bldr.getI32Type();
148584f6436SManupa Karunaratne auto i32x4Ty = VectorType::get({4}, int32Ty);
149584f6436SManupa Karunaratne
150584f6436SManupa Karunaratne if (parser.resolveOperands(ops, {type, i32x4Ty, int32Ty, int32Ty, int32Ty},
151584f6436SManupa Karunaratne parser.getNameLoc(), result.operands))
152584f6436SManupa Karunaratne return failure();
153584f6436SManupa Karunaratne return success();
154584f6436SManupa Karunaratne }
155584f6436SManupa Karunaratne
print(mlir::OpAsmPrinter & p)156584f6436SManupa Karunaratne void RawBufferAtomicSMaxOp::print(mlir::OpAsmPrinter &p) {
157584f6436SManupa Karunaratne p << " " << getOperands() << " : " << getVdata().getType();
158584f6436SManupa Karunaratne }
159584f6436SManupa Karunaratne
// <operation> ::=
//     `llvm.amdgcn.raw.buffer.atomic.umin.* %vdata, %rsrc, %offset,
//     %soffset, %aux : vdata_type`
parse(OpAsmParser & parser,OperationState & result)163584f6436SManupa Karunaratne ParseResult RawBufferAtomicUMinOp::parse(OpAsmParser &parser,
164584f6436SManupa Karunaratne OperationState &result) {
165584f6436SManupa Karunaratne SmallVector<OpAsmParser::UnresolvedOperand, 5> ops;
166584f6436SManupa Karunaratne Type type;
167584f6436SManupa Karunaratne if (parser.parseOperandList(ops, 5) || parser.parseColonType(type))
168584f6436SManupa Karunaratne return failure();
169584f6436SManupa Karunaratne
170584f6436SManupa Karunaratne auto bldr = parser.getBuilder();
171584f6436SManupa Karunaratne auto int32Ty = bldr.getI32Type();
172584f6436SManupa Karunaratne auto i32x4Ty = VectorType::get({4}, int32Ty);
173584f6436SManupa Karunaratne
174584f6436SManupa Karunaratne if (parser.resolveOperands(ops, {type, i32x4Ty, int32Ty, int32Ty, int32Ty},
175584f6436SManupa Karunaratne parser.getNameLoc(), result.operands))
176584f6436SManupa Karunaratne return failure();
177584f6436SManupa Karunaratne return success();
178584f6436SManupa Karunaratne }
179584f6436SManupa Karunaratne
print(mlir::OpAsmPrinter & p)180584f6436SManupa Karunaratne void RawBufferAtomicUMinOp::print(mlir::OpAsmPrinter &p) {
181584f6436SManupa Karunaratne p << " " << getOperands() << " : " << getVdata().getType();
182584f6436SManupa Karunaratne }
183584f6436SManupa Karunaratne
1849c53ac08Sjerryyin //===----------------------------------------------------------------------===//
185fee40fefSDeven Desai // ROCDLDialect initialization, type parsing, and registration.
186fee40fefSDeven Desai //===----------------------------------------------------------------------===//
187fee40fefSDeven Desai
1889db53a18SRiver Riddle // TODO: This should be the llvm.rocdl dialect once this is supported.
initialize()189575b22b5SMehdi Amini void ROCDLDialect::initialize() {
190fee40fefSDeven Desai addOperations<
191fee40fefSDeven Desai #define GET_OP_LIST
192fee40fefSDeven Desai #include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc"
193fee40fefSDeven Desai >();
194fee40fefSDeven Desai
19506821313SFabian Mora addAttributes<
19606821313SFabian Mora #define GET_ATTRDEF_LIST
19706821313SFabian Mora #include "mlir/Dialect/LLVMIR/ROCDLOpsAttributes.cpp.inc"
19806821313SFabian Mora >();
19906821313SFabian Mora
200fee40fefSDeven Desai // Support unknown operations because not all ROCDL operations are registered.
201fee40fefSDeven Desai allowUnknownOperations();
202*35d55f28SJustin Fargnoli declarePromisedInterface<gpu::TargetAttrInterface, ROCDLTargetAttr>();
203fee40fefSDeven Desai }
204fee40fefSDeven Desai
verifyOperationAttribute(Operation * op,NamedAttribute attr)2059cd47a26SAlex Zinenko LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
2069cd47a26SAlex Zinenko NamedAttribute attr) {
2079cd47a26SAlex Zinenko // Kernel function attribute should be attached to functions.
20845c226d4SMehdi Amini if (kernelAttrName.getName() == attr.getName()) {
2099cd47a26SAlex Zinenko if (!isa<LLVM::LLVMFuncOp>(op)) {
21045c226d4SMehdi Amini return op->emitError() << "'" << kernelAttrName.getName()
2119cd47a26SAlex Zinenko << "' attribute attached to unexpected op";
2129cd47a26SAlex Zinenko }
2139cd47a26SAlex Zinenko }
2149cd47a26SAlex Zinenko return success();
2159cd47a26SAlex Zinenko }
2169cd47a26SAlex Zinenko
21706821313SFabian Mora //===----------------------------------------------------------------------===//
21806821313SFabian Mora // ROCDL target attribute.
21906821313SFabian Mora //===----------------------------------------------------------------------===//
22006821313SFabian Mora LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,int optLevel,StringRef triple,StringRef chip,StringRef features,StringRef abiVersion,DictionaryAttr flags,ArrayAttr files)22106821313SFabian Mora ROCDLTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
22206821313SFabian Mora int optLevel, StringRef triple, StringRef chip,
22306821313SFabian Mora StringRef features, StringRef abiVersion,
22406821313SFabian Mora DictionaryAttr flags, ArrayAttr files) {
22506821313SFabian Mora if (optLevel < 0 || optLevel > 3) {
22606821313SFabian Mora emitError() << "The optimization level must be a number between 0 and 3.";
22706821313SFabian Mora return failure();
22806821313SFabian Mora }
22906821313SFabian Mora if (triple.empty()) {
23006821313SFabian Mora emitError() << "The target triple cannot be empty.";
23106821313SFabian Mora return failure();
23206821313SFabian Mora }
23306821313SFabian Mora if (chip.empty()) {
23406821313SFabian Mora emitError() << "The target chip cannot be empty.";
23506821313SFabian Mora return failure();
23606821313SFabian Mora }
23706821313SFabian Mora if (abiVersion != "400" && abiVersion != "500") {
23806821313SFabian Mora emitError() << "Invalid ABI version, it must be either `400` or `500`.";
23906821313SFabian Mora return failure();
24006821313SFabian Mora }
24106821313SFabian Mora if (files && !llvm::all_of(files, [](::mlir::Attribute attr) {
24206821313SFabian Mora return attr && mlir::isa<StringAttr>(attr);
24306821313SFabian Mora })) {
24406821313SFabian Mora emitError() << "All the elements in the `link` array must be strings.";
24506821313SFabian Mora return failure();
24606821313SFabian Mora }
24706821313SFabian Mora return success();
24806821313SFabian Mora }
24906821313SFabian Mora
250fee40fefSDeven Desai #define GET_OP_CLASSES
251fee40fefSDeven Desai #include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc"
25206821313SFabian Mora
25306821313SFabian Mora #define GET_ATTRDEF_CLASSES
25406821313SFabian Mora #include "mlir/Dialect/LLVMIR/ROCDLOpsAttributes.cpp.inc"
255