1 //===- CastOps.cpp - MLIR SPIR-V Cast Ops --------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Defines the cast and conversion operations in the SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 14 15 #include "SPIRVOpUtils.h" 16 #include "SPIRVParsingUtils.h" 17 18 #include "llvm/ADT/TypeSwitch.h" 19 20 using namespace mlir::spirv::AttrNames; 21 22 namespace mlir::spirv { 23 24 static LogicalResult verifyCastOp(Operation *op, 25 bool requireSameBitWidth = true, 26 bool skipBitWidthCheck = false) { 27 // Some CastOps have no limit on bit widths for result and operand type. 28 if (skipBitWidthCheck) 29 return success(); 30 31 Type operandType = op->getOperand(0).getType(); 32 Type resultType = op->getResult(0).getType(); 33 34 // ODS checks that result type and operand type have the same shape. Check 35 // that composite types match and extract the element types, if any. 
36 using TypePair = std::pair<Type, Type>; 37 auto [operandElemTy, resultElemTy] = 38 TypeSwitch<Type, TypePair>(operandType) 39 .Case<VectorType, spirv::CooperativeMatrixType>( 40 [resultType](auto concreteOperandTy) -> TypePair { 41 if (auto concreteResultTy = 42 dyn_cast<decltype(concreteOperandTy)>(resultType)) { 43 return {concreteOperandTy.getElementType(), 44 concreteResultTy.getElementType()}; 45 } 46 return {}; 47 }) 48 .Default([resultType](Type operandType) -> TypePair { 49 return {operandType, resultType}; 50 }); 51 52 if (!operandElemTy || !resultElemTy) 53 return op->emitOpError("incompatible operand and result types"); 54 55 unsigned operandTypeBitWidth = operandElemTy.getIntOrFloatBitWidth(); 56 unsigned resultTypeBitWidth = resultElemTy.getIntOrFloatBitWidth(); 57 bool isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth; 58 59 if (requireSameBitWidth) { 60 if (!isSameBitWidth) { 61 return op->emitOpError( 62 "expected the same bit widths for operand type and result " 63 "type, but provided ") 64 << operandElemTy << " and " << resultElemTy; 65 } 66 return success(); 67 } 68 69 if (isSameBitWidth) { 70 return op->emitOpError( 71 "expected the different bit widths for operand type and result " 72 "type, but provided ") 73 << operandElemTy << " and " << resultElemTy; 74 } 75 return success(); 76 } 77 78 //===----------------------------------------------------------------------===// 79 // spirv.BitcastOp 80 //===----------------------------------------------------------------------===// 81 82 LogicalResult BitcastOp::verify() { 83 // TODO: The SPIR-V spec validation rules are different for different 84 // versions. 
85 auto operandType = getOperand().getType(); 86 auto resultType = getResult().getType(); 87 if (operandType == resultType) { 88 return emitError("result type must be different from operand type"); 89 } 90 if (llvm::isa<spirv::PointerType>(operandType) && 91 !llvm::isa<spirv::PointerType>(resultType)) { 92 return emitError( 93 "unhandled bit cast conversion from pointer type to non-pointer type"); 94 } 95 if (!llvm::isa<spirv::PointerType>(operandType) && 96 llvm::isa<spirv::PointerType>(resultType)) { 97 return emitError( 98 "unhandled bit cast conversion from non-pointer type to pointer type"); 99 } 100 auto operandBitWidth = getBitWidth(operandType); 101 auto resultBitWidth = getBitWidth(resultType); 102 if (operandBitWidth != resultBitWidth) { 103 return emitOpError("mismatch in result type bitwidth ") 104 << resultBitWidth << " and operand type bitwidth " 105 << operandBitWidth; 106 } 107 return success(); 108 } 109 110 //===----------------------------------------------------------------------===// 111 // spirv.ConvertPtrToUOp 112 //===----------------------------------------------------------------------===// 113 114 LogicalResult ConvertPtrToUOp::verify() { 115 auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType()); 116 auto resultType = llvm::cast<spirv::ScalarType>(getResult().getType()); 117 if (!resultType || !resultType.isSignlessInteger()) 118 return emitError("result must be a scalar type of unsigned integer"); 119 auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>(); 120 if (!spirvModule) 121 return success(); 122 auto addressingModel = spirvModule.getAddressingModel(); 123 if ((addressingModel == spirv::AddressingModel::Logical) || 124 (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 && 125 operandType.getStorageClass() != 126 spirv::StorageClass::PhysicalStorageBuffer)) 127 return emitError("operand must be a physical pointer"); 128 return success(); 129 } 130 131 
//===----------------------------------------------------------------------===// 132 // spirv.ConvertUToPtrOp 133 //===----------------------------------------------------------------------===// 134 135 LogicalResult ConvertUToPtrOp::verify() { 136 auto operandType = llvm::cast<spirv::ScalarType>(getOperand().getType()); 137 auto resultType = llvm::cast<spirv::PointerType>(getResult().getType()); 138 if (!operandType || !operandType.isSignlessInteger()) 139 return emitError("result must be a scalar type of unsigned integer"); 140 auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>(); 141 if (!spirvModule) 142 return success(); 143 auto addressingModel = spirvModule.getAddressingModel(); 144 if ((addressingModel == spirv::AddressingModel::Logical) || 145 (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 && 146 resultType.getStorageClass() != 147 spirv::StorageClass::PhysicalStorageBuffer)) 148 return emitError("result must be a physical pointer"); 149 return success(); 150 } 151 152 //===----------------------------------------------------------------------===// 153 // spirv.PtrCastToGenericOp 154 //===----------------------------------------------------------------------===// 155 156 LogicalResult PtrCastToGenericOp::verify() { 157 auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType()); 158 auto resultType = llvm::cast<spirv::PointerType>(getResult().getType()); 159 160 spirv::StorageClass operandStorage = operandType.getStorageClass(); 161 if (operandStorage != spirv::StorageClass::Workgroup && 162 operandStorage != spirv::StorageClass::CrossWorkgroup && 163 operandStorage != spirv::StorageClass::Function) 164 return emitError("pointer must point to the Workgroup, CrossWorkgroup" 165 ", or Function Storage Class"); 166 167 spirv::StorageClass resultStorage = resultType.getStorageClass(); 168 if (resultStorage != spirv::StorageClass::Generic) 169 return emitError("result type must be of storage class Generic"); 170 171 
Type operandPointeeType = operandType.getPointeeType(); 172 Type resultPointeeType = resultType.getPointeeType(); 173 if (operandPointeeType != resultPointeeType) 174 return emitOpError("pointer operand's pointee type must have the same " 175 "as the op result type, but found ") 176 << operandPointeeType << " vs " << resultPointeeType; 177 return success(); 178 } 179 180 //===----------------------------------------------------------------------===// 181 // spirv.GenericCastToPtrOp 182 //===----------------------------------------------------------------------===// 183 184 LogicalResult GenericCastToPtrOp::verify() { 185 auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType()); 186 auto resultType = llvm::cast<spirv::PointerType>(getResult().getType()); 187 188 spirv::StorageClass operandStorage = operandType.getStorageClass(); 189 if (operandStorage != spirv::StorageClass::Generic) 190 return emitError("pointer type must be of storage class Generic"); 191 192 spirv::StorageClass resultStorage = resultType.getStorageClass(); 193 if (resultStorage != spirv::StorageClass::Workgroup && 194 resultStorage != spirv::StorageClass::CrossWorkgroup && 195 resultStorage != spirv::StorageClass::Function) 196 return emitError("result must point to the Workgroup, CrossWorkgroup, " 197 "or Function Storage Class"); 198 199 Type operandPointeeType = operandType.getPointeeType(); 200 Type resultPointeeType = resultType.getPointeeType(); 201 if (operandPointeeType != resultPointeeType) 202 return emitOpError("pointer operand's pointee type must have the same " 203 "as the op result type, but found ") 204 << operandPointeeType << " vs " << resultPointeeType; 205 return success(); 206 } 207 208 //===----------------------------------------------------------------------===// 209 // spirv.GenericCastToPtrExplicitOp 210 //===----------------------------------------------------------------------===// 211 212 LogicalResult GenericCastToPtrExplicitOp::verify() { 213 auto 
operandType = llvm::cast<spirv::PointerType>(getPointer().getType()); 214 auto resultType = llvm::cast<spirv::PointerType>(getResult().getType()); 215 216 spirv::StorageClass operandStorage = operandType.getStorageClass(); 217 if (operandStorage != spirv::StorageClass::Generic) 218 return emitError("pointer type must be of storage class Generic"); 219 220 spirv::StorageClass resultStorage = resultType.getStorageClass(); 221 if (resultStorage != spirv::StorageClass::Workgroup && 222 resultStorage != spirv::StorageClass::CrossWorkgroup && 223 resultStorage != spirv::StorageClass::Function) 224 return emitError("result must point to the Workgroup, CrossWorkgroup, " 225 "or Function Storage Class"); 226 227 Type operandPointeeType = operandType.getPointeeType(); 228 Type resultPointeeType = resultType.getPointeeType(); 229 if (operandPointeeType != resultPointeeType) 230 return emitOpError("pointer operand's pointee type must have the same " 231 "as the op result type, but found ") 232 << operandPointeeType << " vs " << resultPointeeType; 233 return success(); 234 } 235 236 //===----------------------------------------------------------------------===// 237 // spirv.ConvertFToSOp 238 //===----------------------------------------------------------------------===// 239 240 LogicalResult ConvertFToSOp::verify() { 241 return verifyCastOp(*this, /*requireSameBitWidth=*/false, 242 /*skipBitWidthCheck=*/true); 243 } 244 245 //===----------------------------------------------------------------------===// 246 // spirv.ConvertFToUOp 247 //===----------------------------------------------------------------------===// 248 249 LogicalResult ConvertFToUOp::verify() { 250 return verifyCastOp(*this, /*requireSameBitWidth=*/false, 251 /*skipBitWidthCheck=*/true); 252 } 253 254 //===----------------------------------------------------------------------===// 255 // spirv.ConvertSToFOp 256 //===----------------------------------------------------------------------===// 257 258 
/// Verifies spirv.ConvertSToF. The shared cast verifier is told to skip its
/// bit-width comparison, so no width constraint applies.
LogicalResult ConvertSToFOp::verify() {
  constexpr bool kRequireSameBitWidth = false;
  constexpr bool kSkipBitWidthCheck = true;
  return verifyCastOp(getOperation(), kRequireSameBitWidth, kSkipBitWidthCheck);
}

//===----------------------------------------------------------------------===//
// spirv.ConvertUToFOp
//===----------------------------------------------------------------------===//

/// Verifies spirv.ConvertUToF. Like ConvertSToF, no bit-width constraint is
/// applied.
LogicalResult ConvertUToFOp::verify() {
  constexpr bool kRequireSameBitWidth = false;
  constexpr bool kSkipBitWidthCheck = true;
  return verifyCastOp(getOperation(), kRequireSameBitWidth, kSkipBitWidthCheck);
}

//===----------------------------------------------------------------------===//
// INTEL BF16 conversion helpers
//===----------------------------------------------------------------------===//

/// Element-count check shared by the INTEL BF16 conversion ops.
///
/// A scalar operand is accepted without further checks. A vector operand
/// requires the result to be a vector with the same number of elements. The
/// result is cast with llvm::cast, i.e. it is assumed to be a vector whenever
/// the operand is. Counts are compared after narrowing to `unsigned`.
static LogicalResult verifyBF16ConversionElementCount(Operation *op) {
  auto operandVecTy = llvm::dyn_cast<VectorType>(op->getOperand(0).getType());
  if (!operandVecTy)
    return success();

  auto resultVecTy = llvm::cast<VectorType>(op->getResult(0).getType());
  const unsigned operandCount = operandVecTy.getNumElements();
  const unsigned resultCount = resultVecTy.getNumElements();
  if (operandCount == resultCount)
    return success();

  return op->emitOpError(
      "operand and result must have same number of elements");
}

//===----------------------------------------------------------------------===//
// spirv.INTELConvertBF16ToFOp
//===----------------------------------------------------------------------===//

/// Verifies spirv.INTELConvertBF16ToF; see verifyBF16ConversionElementCount.
LogicalResult INTELConvertBF16ToFOp::verify() {
  return verifyBF16ConversionElementCount(getOperation());
}

//===----------------------------------------------------------------------===//
// spirv.INTELConvertFToBF16Op
//===----------------------------------------------------------------------===//

/// Verifies spirv.INTELConvertFToBF16; see verifyBF16ConversionElementCount.
LogicalResult INTELConvertFToBF16Op::verify() {
  return verifyBF16ConversionElementCount(getOperation());
}

//===----------------------------------------------------------------------===//
// spirv.FConvertOp
//===----------------------------------------------------------------------===//

/// Verifies spirv.FConvert: the element bit widths must differ.
LogicalResult FConvertOp::verify() {
  return verifyCastOp(getOperation(), /*requireSameBitWidth=*/false,
                      /*skipBitWidthCheck=*/false);
}

//===----------------------------------------------------------------------===//
// spirv.SConvertOp
//===----------------------------------------------------------------------===//

/// Verifies spirv.SConvert: the element bit widths must differ.
LogicalResult SConvertOp::verify() {
  return verifyCastOp(getOperation(), /*requireSameBitWidth=*/false,
                      /*skipBitWidthCheck=*/false);
}

//===----------------------------------------------------------------------===//
// spirv.UConvertOp
//===----------------------------------------------------------------------===//

/// Verifies spirv.UConvert: the element bit widths must differ.
LogicalResult UConvertOp::verify() {
  return verifyCastOp(getOperation(), /*requireSameBitWidth=*/false,
                      /*skipBitWidthCheck=*/false);
}

} // namespace mlir::spirv