xref: /llvm-project/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp (revision b719ab4eef634f24605ca7ccd4874338c34e05bd)
1 //===- CastOps.cpp - MLIR SPIR-V Cast Ops  --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the cast and conversion operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14 
15 #include "SPIRVOpUtils.h"
16 #include "SPIRVParsingUtils.h"
17 
18 #include "llvm/ADT/TypeSwitch.h"
19 
20 using namespace mlir::spirv::AttrNames;
21 
22 namespace mlir::spirv {
23 
24 static LogicalResult verifyCastOp(Operation *op,
25                                   bool requireSameBitWidth = true,
26                                   bool skipBitWidthCheck = false) {
27   // Some CastOps have no limit on bit widths for result and operand type.
28   if (skipBitWidthCheck)
29     return success();
30 
31   Type operandType = op->getOperand(0).getType();
32   Type resultType = op->getResult(0).getType();
33 
34   // ODS checks that result type and operand type have the same shape. Check
35   // that composite types match and extract the element types, if any.
36   using TypePair = std::pair<Type, Type>;
37   auto [operandElemTy, resultElemTy] =
38       TypeSwitch<Type, TypePair>(operandType)
39           .Case<VectorType, spirv::CooperativeMatrixType>(
40               [resultType](auto concreteOperandTy) -> TypePair {
41                 if (auto concreteResultTy =
42                         dyn_cast<decltype(concreteOperandTy)>(resultType)) {
43                   return {concreteOperandTy.getElementType(),
44                           concreteResultTy.getElementType()};
45                 }
46                 return {};
47               })
48           .Default([resultType](Type operandType) -> TypePair {
49             return {operandType, resultType};
50           });
51 
52   if (!operandElemTy || !resultElemTy)
53     return op->emitOpError("incompatible operand and result types");
54 
55   unsigned operandTypeBitWidth = operandElemTy.getIntOrFloatBitWidth();
56   unsigned resultTypeBitWidth = resultElemTy.getIntOrFloatBitWidth();
57   bool isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
58 
59   if (requireSameBitWidth) {
60     if (!isSameBitWidth) {
61       return op->emitOpError(
62                  "expected the same bit widths for operand type and result "
63                  "type, but provided ")
64              << operandElemTy << " and " << resultElemTy;
65     }
66     return success();
67   }
68 
69   if (isSameBitWidth) {
70     return op->emitOpError(
71                "expected the different bit widths for operand type and result "
72                "type, but provided ")
73            << operandElemTy << " and " << resultElemTy;
74   }
75   return success();
76 }
77 
78 //===----------------------------------------------------------------------===//
79 // spirv.BitcastOp
80 //===----------------------------------------------------------------------===//
81 
82 LogicalResult BitcastOp::verify() {
83   // TODO: The SPIR-V spec validation rules are different for different
84   // versions.
85   auto operandType = getOperand().getType();
86   auto resultType = getResult().getType();
87   if (operandType == resultType) {
88     return emitError("result type must be different from operand type");
89   }
90   if (llvm::isa<spirv::PointerType>(operandType) &&
91       !llvm::isa<spirv::PointerType>(resultType)) {
92     return emitError(
93         "unhandled bit cast conversion from pointer type to non-pointer type");
94   }
95   if (!llvm::isa<spirv::PointerType>(operandType) &&
96       llvm::isa<spirv::PointerType>(resultType)) {
97     return emitError(
98         "unhandled bit cast conversion from non-pointer type to pointer type");
99   }
100   auto operandBitWidth = getBitWidth(operandType);
101   auto resultBitWidth = getBitWidth(resultType);
102   if (operandBitWidth != resultBitWidth) {
103     return emitOpError("mismatch in result type bitwidth ")
104            << resultBitWidth << " and operand type bitwidth "
105            << operandBitWidth;
106   }
107   return success();
108 }
109 
110 //===----------------------------------------------------------------------===//
111 // spirv.ConvertPtrToUOp
112 //===----------------------------------------------------------------------===//
113 
114 LogicalResult ConvertPtrToUOp::verify() {
115   auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
116   auto resultType = llvm::cast<spirv::ScalarType>(getResult().getType());
117   if (!resultType || !resultType.isSignlessInteger())
118     return emitError("result must be a scalar type of unsigned integer");
119   auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
120   if (!spirvModule)
121     return success();
122   auto addressingModel = spirvModule.getAddressingModel();
123   if ((addressingModel == spirv::AddressingModel::Logical) ||
124       (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
125        operandType.getStorageClass() !=
126            spirv::StorageClass::PhysicalStorageBuffer))
127     return emitError("operand must be a physical pointer");
128   return success();
129 }
130 
131 //===----------------------------------------------------------------------===//
132 // spirv.ConvertUToPtrOp
133 //===----------------------------------------------------------------------===//
134 
135 LogicalResult ConvertUToPtrOp::verify() {
136   auto operandType = llvm::cast<spirv::ScalarType>(getOperand().getType());
137   auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
138   if (!operandType || !operandType.isSignlessInteger())
139     return emitError("result must be a scalar type of unsigned integer");
140   auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
141   if (!spirvModule)
142     return success();
143   auto addressingModel = spirvModule.getAddressingModel();
144   if ((addressingModel == spirv::AddressingModel::Logical) ||
145       (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
146        resultType.getStorageClass() !=
147            spirv::StorageClass::PhysicalStorageBuffer))
148     return emitError("result must be a physical pointer");
149   return success();
150 }
151 
152 //===----------------------------------------------------------------------===//
153 // spirv.PtrCastToGenericOp
154 //===----------------------------------------------------------------------===//
155 
156 LogicalResult PtrCastToGenericOp::verify() {
157   auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
158   auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
159 
160   spirv::StorageClass operandStorage = operandType.getStorageClass();
161   if (operandStorage != spirv::StorageClass::Workgroup &&
162       operandStorage != spirv::StorageClass::CrossWorkgroup &&
163       operandStorage != spirv::StorageClass::Function)
164     return emitError("pointer must point to the Workgroup, CrossWorkgroup"
165                      ", or Function Storage Class");
166 
167   spirv::StorageClass resultStorage = resultType.getStorageClass();
168   if (resultStorage != spirv::StorageClass::Generic)
169     return emitError("result type must be of storage class Generic");
170 
171   Type operandPointeeType = operandType.getPointeeType();
172   Type resultPointeeType = resultType.getPointeeType();
173   if (operandPointeeType != resultPointeeType)
174     return emitOpError("pointer operand's pointee type must have the same "
175                        "as the op result type, but found ")
176            << operandPointeeType << " vs " << resultPointeeType;
177   return success();
178 }
179 
180 //===----------------------------------------------------------------------===//
181 // spirv.GenericCastToPtrOp
182 //===----------------------------------------------------------------------===//
183 
184 LogicalResult GenericCastToPtrOp::verify() {
185   auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
186   auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
187 
188   spirv::StorageClass operandStorage = operandType.getStorageClass();
189   if (operandStorage != spirv::StorageClass::Generic)
190     return emitError("pointer type must be of storage class Generic");
191 
192   spirv::StorageClass resultStorage = resultType.getStorageClass();
193   if (resultStorage != spirv::StorageClass::Workgroup &&
194       resultStorage != spirv::StorageClass::CrossWorkgroup &&
195       resultStorage != spirv::StorageClass::Function)
196     return emitError("result must point to the Workgroup, CrossWorkgroup, "
197                      "or Function Storage Class");
198 
199   Type operandPointeeType = operandType.getPointeeType();
200   Type resultPointeeType = resultType.getPointeeType();
201   if (operandPointeeType != resultPointeeType)
202     return emitOpError("pointer operand's pointee type must have the same "
203                        "as the op result type, but found ")
204            << operandPointeeType << " vs " << resultPointeeType;
205   return success();
206 }
207 
208 //===----------------------------------------------------------------------===//
209 // spirv.GenericCastToPtrExplicitOp
210 //===----------------------------------------------------------------------===//
211 
212 LogicalResult GenericCastToPtrExplicitOp::verify() {
213   auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
214   auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
215 
216   spirv::StorageClass operandStorage = operandType.getStorageClass();
217   if (operandStorage != spirv::StorageClass::Generic)
218     return emitError("pointer type must be of storage class Generic");
219 
220   spirv::StorageClass resultStorage = resultType.getStorageClass();
221   if (resultStorage != spirv::StorageClass::Workgroup &&
222       resultStorage != spirv::StorageClass::CrossWorkgroup &&
223       resultStorage != spirv::StorageClass::Function)
224     return emitError("result must point to the Workgroup, CrossWorkgroup, "
225                      "or Function Storage Class");
226 
227   Type operandPointeeType = operandType.getPointeeType();
228   Type resultPointeeType = resultType.getPointeeType();
229   if (operandPointeeType != resultPointeeType)
230     return emitOpError("pointer operand's pointee type must have the same "
231                        "as the op result type, but found ")
232            << operandPointeeType << " vs " << resultPointeeType;
233   return success();
234 }
235 
236 //===----------------------------------------------------------------------===//
237 // spirv.ConvertFToSOp
238 //===----------------------------------------------------------------------===//
239 
240 LogicalResult ConvertFToSOp::verify() {
241   return verifyCastOp(*this, /*requireSameBitWidth=*/false,
242                       /*skipBitWidthCheck=*/true);
243 }
244 
245 //===----------------------------------------------------------------------===//
246 // spirv.ConvertFToUOp
247 //===----------------------------------------------------------------------===//
248 
249 LogicalResult ConvertFToUOp::verify() {
250   return verifyCastOp(*this, /*requireSameBitWidth=*/false,
251                       /*skipBitWidthCheck=*/true);
252 }
253 
254 //===----------------------------------------------------------------------===//
255 // spirv.ConvertSToFOp
256 //===----------------------------------------------------------------------===//
257 
258 LogicalResult ConvertSToFOp::verify() {
259   return verifyCastOp(*this, /*requireSameBitWidth=*/false,
260                       /*skipBitWidthCheck=*/true);
261 }
262 
263 //===----------------------------------------------------------------------===//
264 // spirv.ConvertUToFOp
265 //===----------------------------------------------------------------------===//
266 
267 LogicalResult ConvertUToFOp::verify() {
268   return verifyCastOp(*this, /*requireSameBitWidth=*/false,
269                       /*skipBitWidthCheck=*/true);
270 }
271 
272 //===----------------------------------------------------------------------===//
273 // spirv.INTELConvertBF16ToFOp
274 //===----------------------------------------------------------------------===//
275 
276 LogicalResult INTELConvertBF16ToFOp::verify() {
277   auto operandType = getOperand().getType();
278   auto resultType = getResult().getType();
279   // ODS checks that vector result type and vector operand type have the same
280   // shape.
281   if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
282     unsigned operandNumElements = vectorType.getNumElements();
283     unsigned resultNumElements =
284         llvm::cast<VectorType>(resultType).getNumElements();
285     if (operandNumElements != resultNumElements) {
286       return emitOpError(
287           "operand and result must have same number of elements");
288     }
289   }
290   return success();
291 }
292 
293 //===----------------------------------------------------------------------===//
294 // spirv.INTELConvertFToBF16Op
295 //===----------------------------------------------------------------------===//
296 
297 LogicalResult INTELConvertFToBF16Op::verify() {
298   auto operandType = getOperand().getType();
299   auto resultType = getResult().getType();
300   // ODS checks that vector result type and vector operand type have the same
301   // shape.
302   if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
303     unsigned operandNumElements = vectorType.getNumElements();
304     unsigned resultNumElements =
305         llvm::cast<VectorType>(resultType).getNumElements();
306     if (operandNumElements != resultNumElements) {
307       return emitOpError(
308           "operand and result must have same number of elements");
309     }
310   }
311   return success();
312 }
313 
314 //===----------------------------------------------------------------------===//
315 // spirv.FConvertOp
316 //===----------------------------------------------------------------------===//
317 
318 LogicalResult spirv::FConvertOp::verify() {
319   return verifyCastOp(*this, /*requireSameBitWidth=*/false);
320 }
321 
322 //===----------------------------------------------------------------------===//
323 // spirv.SConvertOp
324 //===----------------------------------------------------------------------===//
325 
326 LogicalResult spirv::SConvertOp::verify() {
327   return verifyCastOp(*this, /*requireSameBitWidth=*/false);
328 }
329 
330 //===----------------------------------------------------------------------===//
331 // spirv.UConvertOp
332 //===----------------------------------------------------------------------===//
333 
334 LogicalResult spirv::UConvertOp::verify() {
335   return verifyCastOp(*this, /*requireSameBitWidth=*/false);
336 }
337 
338 } // namespace mlir::spirv
339