xref: /llvm-project/mlir/lib/Dialect/SPIRV/IR/AtomicOps.cpp (revision 1d5e3b2d6559a853c544099e4cf1d46f44f83368)
1 //===- AtomicOps.cpp - MLIR SPIR-V Atomic Ops  ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the atomic operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14 
15 #include "SPIRVOpUtils.h"
16 #include "SPIRVParsingUtils.h"
17 
18 using namespace mlir::spirv::AttrNames;
19 
20 namespace mlir::spirv {
21 
22 template <typename T>
23 static StringRef stringifyTypeName();
24 
25 template <>
stringifyTypeName()26 StringRef stringifyTypeName<IntegerType>() {
27   return "integer";
28 }
29 
30 template <>
stringifyTypeName()31 StringRef stringifyTypeName<FloatType>() {
32   return "float";
33 }
34 
35 // Verifies an atomic update op.
36 template <typename AtomicOpTy, typename ExpectedElementType>
verifyAtomicUpdateOp(Operation * op)37 static LogicalResult verifyAtomicUpdateOp(Operation *op) {
38   auto ptrType = llvm::cast<spirv::PointerType>(op->getOperand(0).getType());
39   auto elementType = ptrType.getPointeeType();
40   if (!llvm::isa<ExpectedElementType>(elementType))
41     return op->emitOpError() << "pointer operand must point to an "
42                              << stringifyTypeName<ExpectedElementType>()
43                              << " value, found " << elementType;
44 
45   StringAttr semanticsAttrName =
46       AtomicOpTy::getSemanticsAttrName(op->getName());
47   auto memorySemantics =
48       op->getAttrOfType<spirv::MemorySemanticsAttr>(semanticsAttrName)
49           .getValue();
50   if (failed(verifyMemorySemantics(op, memorySemantics))) {
51     return failure();
52   }
53   return success();
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // spirv.AtomicAndOp
58 //===----------------------------------------------------------------------===//
59 
verify()60 LogicalResult AtomicAndOp::verify() {
61   return verifyAtomicUpdateOp<AtomicAndOp, IntegerType>(getOperation());
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // spirv.AtomicIAddOp
66 //===----------------------------------------------------------------------===//
67 
verify()68 LogicalResult AtomicIAddOp::verify() {
69   return verifyAtomicUpdateOp<AtomicIAddOp, IntegerType>(getOperation());
70 }
71 
72 //===----------------------------------------------------------------------===//
73 // spirv.EXT.AtomicFAddOp
74 //===----------------------------------------------------------------------===//
75 
verify()76 LogicalResult EXTAtomicFAddOp::verify() {
77   return verifyAtomicUpdateOp<EXTAtomicFAddOp, FloatType>(getOperation());
78 }
79 
80 //===----------------------------------------------------------------------===//
81 // spirv.AtomicIDecrementOp
82 //===----------------------------------------------------------------------===//
83 
verify()84 LogicalResult AtomicIDecrementOp::verify() {
85   return verifyAtomicUpdateOp<AtomicIDecrementOp, IntegerType>(getOperation());
86 }
87 
88 //===----------------------------------------------------------------------===//
89 // spirv.AtomicIIncrementOp
90 //===----------------------------------------------------------------------===//
91 
verify()92 LogicalResult AtomicIIncrementOp::verify() {
93   return verifyAtomicUpdateOp<AtomicIIncrementOp, IntegerType>(getOperation());
94 }
95 
96 //===----------------------------------------------------------------------===//
97 // spirv.AtomicISubOp
98 //===----------------------------------------------------------------------===//
99 
verify()100 LogicalResult AtomicISubOp::verify() {
101   return verifyAtomicUpdateOp<AtomicISubOp, IntegerType>(getOperation());
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // spirv.AtomicOrOp
106 //===----------------------------------------------------------------------===//
107 
verify()108 LogicalResult AtomicOrOp::verify() {
109   return verifyAtomicUpdateOp<AtomicOrOp, IntegerType>(getOperation());
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // spirv.AtomicSMaxOp
114 //===----------------------------------------------------------------------===//
115 
verify()116 LogicalResult AtomicSMaxOp::verify() {
117   return verifyAtomicUpdateOp<AtomicSMaxOp, IntegerType>(getOperation());
118 }
119 
120 //===----------------------------------------------------------------------===//
121 // spirv.AtomicSMinOp
122 //===----------------------------------------------------------------------===//
123 
verify()124 LogicalResult AtomicSMinOp::verify() {
125   return verifyAtomicUpdateOp<AtomicSMinOp, IntegerType>(getOperation());
126 }
127 
128 //===----------------------------------------------------------------------===//
129 // spirv.AtomicUMaxOp
130 //===----------------------------------------------------------------------===//
131 
verify()132 LogicalResult AtomicUMaxOp::verify() {
133   return verifyAtomicUpdateOp<AtomicUMaxOp, IntegerType>(getOperation());
134 }
135 
136 //===----------------------------------------------------------------------===//
137 // spirv.AtomicUMinOp
138 //===----------------------------------------------------------------------===//
139 
verify()140 LogicalResult AtomicUMinOp::verify() {
141   return verifyAtomicUpdateOp<AtomicUMinOp, IntegerType>(getOperation());
142 }
143 
144 //===----------------------------------------------------------------------===//
145 // spirv.AtomicXorOp
146 //===----------------------------------------------------------------------===//
147 
verify()148 LogicalResult AtomicXorOp::verify() {
149   return verifyAtomicUpdateOp<AtomicXorOp, IntegerType>(getOperation());
150 }
151 
152 } // namespace mlir::spirv
153