1 //===- GroupOps.cpp - MLIR SPIR-V Group Ops ------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Defines the group operations in the SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 14 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" 15 16 #include "SPIRVOpUtils.h" 17 #include "SPIRVParsingUtils.h" 18 19 using namespace mlir::spirv::AttrNames; 20 21 namespace mlir::spirv { 22 23 template <typename OpTy> 24 static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) { 25 spirv::Scope scope = 26 groupOp 27 ->getAttrOfType<spirv::ScopeAttr>( 28 OpTy::getExecutionScopeAttrName(groupOp->getName())) 29 .getValue(); 30 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) 31 return groupOp->emitOpError( 32 "execution scope must be 'Workgroup' or 'Subgroup'"); 33 34 GroupOperation operation = 35 groupOp 36 ->getAttrOfType<GroupOperationAttr>( 37 OpTy::getGroupOperationAttrName(groupOp->getName())) 38 .getValue(); 39 if (operation == GroupOperation::ClusteredReduce && 40 groupOp->getNumOperands() == 1) 41 return groupOp->emitOpError("cluster size operand must be provided for " 42 "'ClusteredReduce' group operation"); 43 if (groupOp->getNumOperands() > 1) { 44 Operation *sizeOp = groupOp->getOperand(1).getDefiningOp(); 45 int32_t clusterSize = 0; 46 47 // TODO: support specialization constant here. 48 if (failed(extractValueFromConstOp(sizeOp, clusterSize))) 49 return groupOp->emitOpError( 50 "cluster size operand must come from a constant op"); 51 52 if (!llvm::isPowerOf2_32(clusterSize)) 53 return groupOp->emitOpError( 54 "cluster size operand must be a power of two"); 55 } 56 return success(); 57 } 58 59 //===----------------------------------------------------------------------===// 60 // spirv.GroupBroadcast 61 //===----------------------------------------------------------------------===// 62 63 LogicalResult GroupBroadcastOp::verify() { 64 spirv::Scope scope = getExecutionScope(); 65 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) 66 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); 67 68 if (auto localIdTy = llvm::dyn_cast<VectorType>(getLocalid().getType())) 69 if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3) 70 return emitOpError("localid is a vector and can be with only " 71 " 2 or 3 components, actual number is ") 72 << localIdTy.getNumElements(); 73 74 return success(); 75 } 76 77 //===----------------------------------------------------------------------===// 78 // spirv.GroupNonUniformBallotOp 79 //===----------------------------------------------------------------------===// 80 81 LogicalResult GroupNonUniformBallotOp::verify() { 82 spirv::Scope scope = getExecutionScope(); 83 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) 84 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); 85 86 return success(); 87 } 88 89 //===----------------------------------------------------------------------===// 90 // spirv.GroupNonUniformBallotFindLSBOp 91 //===----------------------------------------------------------------------===// 92 93 LogicalResult GroupNonUniformBallotFindLSBOp::verify() { 94 spirv::Scope scope = getExecutionScope(); 95 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) 96 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); 97 98 return success(); 99 } 100 101 //===----------------------------------------------------------------------===// 102 // spirv.GroupNonUniformBallotFindLSBOp 103 //===----------------------------------------------------------------------===// 104 105 LogicalResult GroupNonUniformBallotFindMSBOp::verify() { 106 spirv::Scope scope = getExecutionScope(); 107 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) 108 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); 109 110 return success(); 111 } 112 113 //===----------------------------------------------------------------------===// 114 // spirv.GroupNonUniformBroadcast 115 //===----------------------------------------------------------------------===// 116 117 LogicalResult GroupNonUniformBroadcastOp::verify() { 118 spirv::Scope scope = getExecutionScope(); 119 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) 120 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); 121 122 // SPIR-V spec: "Before version 1.5, Id must come from a 123 // constant instruction. 124 auto targetEnv = spirv::getDefaultTargetEnv(getContext()); 125 if (auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>()) 126 targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule); 127 128 if (targetEnv.getVersion() < spirv::Version::V_1_5) { 129 auto *idOp = getId().getDefiningOp(); 130 if (!idOp || !isa<spirv::ConstantOp, // for normal constant 131 spirv::ReferenceOfOp>(idOp)) // for spec constant 132 return emitOpError("id must be the result of a constant op"); 133 } 134 135 return success(); 136 } 137 138 //===----------------------------------------------------------------------===// 139 // spirv.GroupNonUniformShuffle* 140 //===----------------------------------------------------------------------===// 141 142 template <typename OpTy> 143 static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op) { 144 spirv::Scope scope = op.getExecutionScope(); 145 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) 146 return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); 147 148 if (op.getOperands().back().getType().isSignedInteger()) 149 return op.emitOpError("second operand must be a singless/unsigned integer"); 150 151 return success(); 152 } 153 154 LogicalResult GroupNonUniformShuffleOp::verify() { 155 return verifyGroupNonUniformShuffleOp(*this); 156 } 157 LogicalResult GroupNonUniformShuffleDownOp::verify() { 158 return verifyGroupNonUniformShuffleOp(*this); 159 } 160 LogicalResult GroupNonUniformShuffleUpOp::verify() { 161 return verifyGroupNonUniformShuffleOp(*this); 162 } 163 LogicalResult GroupNonUniformShuffleXorOp::verify() { 164 return verifyGroupNonUniformShuffleOp(*this); 165 } 166 167 //===----------------------------------------------------------------------===// 168 // spirv.GroupNonUniformElectOp 169 //===----------------------------------------------------------------------===// 170 171 LogicalResult GroupNonUniformElectOp::verify() { 172 spirv::Scope scope = getExecutionScope(); 173 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) 174 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); 175 176 return success(); 177 } 178 179 //===----------------------------------------------------------------------===// 180 // spirv.GroupNonUniformFAddOp 181 //===----------------------------------------------------------------------===// 182 183 LogicalResult GroupNonUniformFAddOp::verify() { 184 return verifyGroupNonUniformArithmeticOp<GroupNonUniformFAddOp>(*this); 185 } 186 187 //===----------------------------------------------------------------------===// 188 // spirv.GroupNonUniformFMaxOp 189 //===----------------------------------------------------------------------===// 190 191 LogicalResult GroupNonUniformFMaxOp::verify() { 192 return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMaxOp>(*this); 193 } 194 195 //===----------------------------------------------------------------------===// 196 // spirv.GroupNonUniformFMinOp 197 //===----------------------------------------------------------------------===// 198 199 LogicalResult GroupNonUniformFMinOp::verify() { 200 return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMinOp>(*this); 201 } 202 203 //===----------------------------------------------------------------------===// 204 // spirv.GroupNonUniformFMulOp 205 //===----------------------------------------------------------------------===// 206 207 LogicalResult GroupNonUniformFMulOp::verify() { 208 return verifyGroupNonUniformArithmeticOp<GroupNonUniformFMulOp>(*this); 209 } 210 211 //===----------------------------------------------------------------------===// 212 // spirv.GroupNonUniformIAddOp 213 //===----------------------------------------------------------------------===// 214 215 LogicalResult GroupNonUniformIAddOp::verify() { 216 return verifyGroupNonUniformArithmeticOp<GroupNonUniformIAddOp>(*this); 217 } 218 219 //===----------------------------------------------------------------------===// 220 // spirv.GroupNonUniformIMulOp 221 //===----------------------------------------------------------------------===// 222 223 LogicalResult GroupNonUniformIMulOp::verify() { 224 return verifyGroupNonUniformArithmeticOp<GroupNonUniformIMulOp>(*this); 225 } 226 227 //===----------------------------------------------------------------------===// 228 // spirv.GroupNonUniformSMaxOp 229 //===----------------------------------------------------------------------===// 230 231 LogicalResult GroupNonUniformSMaxOp::verify() { 232 return verifyGroupNonUniformArithmeticOp<GroupNonUniformSMaxOp>(*this); 233 } 234 235 //===----------------------------------------------------------------------===// 236 // spirv.GroupNonUniformSMinOp 237 //===----------------------------------------------------------------------===// 238 239 LogicalResult GroupNonUniformSMinOp::verify() { 240 return verifyGroupNonUniformArithmeticOp<GroupNonUniformSMinOp>(*this); 241 } 242 243 //===----------------------------------------------------------------------===// 244 // spirv.GroupNonUniformUMaxOp 245 //===----------------------------------------------------------------------===// 246 247 LogicalResult GroupNonUniformUMaxOp::verify() { 248 return verifyGroupNonUniformArithmeticOp<GroupNonUniformUMaxOp>(*this); 249 } 250 251 //===----------------------------------------------------------------------===// 252 // spirv.GroupNonUniformUMinOp 253 //===----------------------------------------------------------------------===// 254 255 LogicalResult GroupNonUniformUMinOp::verify() { 256 return verifyGroupNonUniformArithmeticOp<GroupNonUniformUMinOp>(*this); 257 } 258 259 //===----------------------------------------------------------------------===// 260 // spirv.GroupNonUniformBitwiseAnd 261 //===----------------------------------------------------------------------===// 262 263 LogicalResult GroupNonUniformBitwiseAndOp::verify() { 264 return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseAndOp>(*this); 265 } 266 267 //===----------------------------------------------------------------------===// 268 // spirv.GroupNonUniformBitwiseOr 269 //===----------------------------------------------------------------------===// 270 271 LogicalResult GroupNonUniformBitwiseOrOp::verify() { 272 return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseOrOp>(*this); 273 } 274 275 //===----------------------------------------------------------------------===// 276 // spirv.GroupNonUniformBitwiseXor 277 //===----------------------------------------------------------------------===// 278 279 LogicalResult GroupNonUniformBitwiseXorOp::verify() { 280 return verifyGroupNonUniformArithmeticOp<GroupNonUniformBitwiseXorOp>(*this); 281 } 282 283 //===----------------------------------------------------------------------===// 284 // spirv.GroupNonUniformLogicalAnd 285 //===----------------------------------------------------------------------===// 286 287 LogicalResult GroupNonUniformLogicalAndOp::verify() { 288 return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalAndOp>(*this); 289 } 290 291 //===----------------------------------------------------------------------===// 292 // spirv.GroupNonUniformLogicalOr 293 //===----------------------------------------------------------------------===// 294 295 LogicalResult GroupNonUniformLogicalOrOp::verify() { 296 return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalOrOp>(*this); 297 } 298 299 //===----------------------------------------------------------------------===// 300 // spirv.GroupNonUniformLogicalXor 301 //===----------------------------------------------------------------------===// 302 303 LogicalResult GroupNonUniformLogicalXorOp::verify() { 304 return verifyGroupNonUniformArithmeticOp<GroupNonUniformLogicalXorOp>(*this); 305 } 306 307 //===----------------------------------------------------------------------===// 308 // Group op verification 309 //===----------------------------------------------------------------------===// 310 311 template <typename Op> 312 static LogicalResult verifyGroupOp(Op op) { 313 spirv::Scope scope = op.getExecutionScope(); 314 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) 315 return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); 316 317 return success(); 318 } 319 320 LogicalResult GroupIAddOp::verify() { return verifyGroupOp(*this); } 321 322 LogicalResult GroupFAddOp::verify() { return verifyGroupOp(*this); } 323 324 LogicalResult GroupFMinOp::verify() { return verifyGroupOp(*this); } 325 326 LogicalResult GroupUMinOp::verify() { return verifyGroupOp(*this); } 327 328 LogicalResult GroupSMinOp::verify() { return verifyGroupOp(*this); } 329 330 LogicalResult GroupFMaxOp::verify() { return verifyGroupOp(*this); } 331 332 LogicalResult GroupUMaxOp::verify() { return verifyGroupOp(*this); } 333 334 LogicalResult GroupSMaxOp::verify() { return verifyGroupOp(*this); } 335 336 LogicalResult GroupIMulKHROp::verify() { return verifyGroupOp(*this); } 337 338 LogicalResult GroupFMulKHROp::verify() { return verifyGroupOp(*this); } 339 340 } // namespace mlir::spirv 341