1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h" 14 #include "mlir/Conversion/LLVMCommon/Pattern.h" 15 #include "mlir/Conversion/LLVMCommon/TypeConverter.h" 16 #include "mlir/Conversion/SPIRVCommon/AttrToLLVMConverter.h" 17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 19 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" 20 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 21 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h" 22 #include "mlir/IR/BuiltinOps.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/Transforms/DialectConversion.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 #include "llvm/Support/Debug.h" 27 #include "llvm/Support/FormatVariadic.h" 28 29 #define DEBUG_TYPE "spirv-to-llvm-pattern" 30 31 using namespace mlir; 32 33 //===----------------------------------------------------------------------===// 34 // Utility functions 35 //===----------------------------------------------------------------------===// 36 37 /// Returns true if the given type is a signed integer or vector type. 38 static bool isSignedIntegerOrVector(Type type) { 39 if (type.isSignedInteger()) 40 return true; 41 if (auto vecType = dyn_cast<VectorType>(type)) 42 return vecType.getElementType().isSignedInteger(); 43 return false; 44 } 45 46 /// Returns true if the given type is an unsigned integer or vector type 47 static bool isUnsignedIntegerOrVector(Type type) { 48 if (type.isUnsignedInteger()) 49 return true; 50 if (auto vecType = dyn_cast<VectorType>(type)) 51 return vecType.getElementType().isUnsignedInteger(); 52 return false; 53 } 54 55 /// Returns the width of an integer or of the element type of an integer vector, 56 /// if applicable. 57 static std::optional<uint64_t> getIntegerOrVectorElementWidth(Type type) { 58 if (auto intType = dyn_cast<IntegerType>(type)) 59 return intType.getWidth(); 60 if (auto vecType = dyn_cast<VectorType>(type)) 61 if (auto intType = dyn_cast<IntegerType>(vecType.getElementType())) 62 return intType.getWidth(); 63 return std::nullopt; 64 } 65 66 /// Returns the bit width of integer, float or vector of float or integer values 67 static unsigned getBitWidth(Type type) { 68 assert((type.isIntOrFloat() || isa<VectorType>(type)) && 69 "bitwidth is not supported for this type"); 70 if (type.isIntOrFloat()) 71 return type.getIntOrFloatBitWidth(); 72 auto vecType = dyn_cast<VectorType>(type); 73 auto elementType = vecType.getElementType(); 74 assert(elementType.isIntOrFloat() && 75 "only integers and floats have a bitwidth"); 76 return elementType.getIntOrFloatBitWidth(); 77 } 78 79 /// Returns the bit width of LLVMType integer or vector. 80 static unsigned getLLVMTypeBitWidth(Type type) { 81 return cast<IntegerType>((LLVM::isCompatibleVectorType(type) 82 ? LLVM::getVectorElementType(type) 83 : type)) 84 .getWidth(); 85 } 86 87 /// Creates `IntegerAttribute` with all bits set for given type 88 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) { 89 if (auto vecType = dyn_cast<VectorType>(type)) { 90 auto integerType = cast<IntegerType>(vecType.getElementType()); 91 return builder.getIntegerAttr(integerType, -1); 92 } 93 auto integerType = cast<IntegerType>(type); 94 return builder.getIntegerAttr(integerType, -1); 95 } 96 97 /// Creates `llvm.mlir.constant` with all bits set for the given type. 98 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, 99 PatternRewriter &rewriter) { 100 if (isa<VectorType>(srcType)) { 101 return rewriter.create<LLVM::ConstantOp>( 102 loc, dstType, 103 SplatElementsAttr::get(cast<ShapedType>(srcType), 104 minusOneIntegerAttribute(srcType, rewriter))); 105 } 106 return rewriter.create<LLVM::ConstantOp>( 107 loc, dstType, minusOneIntegerAttribute(srcType, rewriter)); 108 } 109 110 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value. 111 static Value createFPConstant(Location loc, Type srcType, Type dstType, 112 PatternRewriter &rewriter, double value) { 113 if (auto vecType = dyn_cast<VectorType>(srcType)) { 114 auto floatType = cast<FloatType>(vecType.getElementType()); 115 return rewriter.create<LLVM::ConstantOp>( 116 loc, dstType, 117 SplatElementsAttr::get(vecType, 118 rewriter.getFloatAttr(floatType, value))); 119 } 120 auto floatType = cast<FloatType>(srcType); 121 return rewriter.create<LLVM::ConstantOp>( 122 loc, dstType, rewriter.getFloatAttr(floatType, value)); 123 } 124 125 /// Utility function for bitfield ops: 126 /// - `BitFieldInsert` 127 /// - `BitFieldSExtract` 128 /// - `BitFieldUExtract` 129 /// Truncates or extends the value. If the bitwidth of the value is the same as 130 /// `llvmType` bitwidth, the value remains unchanged. 131 static Value optionallyTruncateOrExtend(Location loc, Value value, 132 Type llvmType, 133 PatternRewriter &rewriter) { 134 auto srcType = value.getType(); 135 unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType); 136 unsigned valueBitWidth = LLVM::isCompatibleType(srcType) 137 ? getLLVMTypeBitWidth(srcType) 138 : getBitWidth(srcType); 139 140 if (valueBitWidth < targetBitWidth) 141 return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value); 142 // If the bit widths of `Count` and `Offset` are greater than the bit width 143 // of the target type, they are truncated. Truncation is safe since `Count` 144 // and `Offset` must be no more than 64 for op behaviour to be defined. Hence, 145 // both values can be expressed in 8 bits. 146 if (valueBitWidth > targetBitWidth) 147 return rewriter.create<LLVM::TruncOp>(loc, llvmType, value); 148 return value; 149 } 150 151 /// Broadcasts the value to vector with `numElements` number of elements. 152 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, 153 const TypeConverter &typeConverter, 154 ConversionPatternRewriter &rewriter) { 155 auto vectorType = VectorType::get(numElements, toBroadcast.getType()); 156 auto llvmVectorType = typeConverter.convertType(vectorType); 157 auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); 158 Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType); 159 for (unsigned i = 0; i < numElements; ++i) { 160 auto index = rewriter.create<LLVM::ConstantOp>( 161 loc, llvmI32Type, rewriter.getI32IntegerAttr(i)); 162 broadcasted = rewriter.create<LLVM::InsertElementOp>( 163 loc, llvmVectorType, broadcasted, toBroadcast, index); 164 } 165 return broadcasted; 166 } 167 168 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged. 169 static Value optionallyBroadcast(Location loc, Value value, Type srcType, 170 const TypeConverter &typeConverter, 171 ConversionPatternRewriter &rewriter) { 172 if (auto vectorType = dyn_cast<VectorType>(srcType)) { 173 unsigned numElements = vectorType.getNumElements(); 174 return broadcast(loc, value, numElements, typeConverter, rewriter); 175 } 176 return value; 177 } 178 179 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and 180 /// `BitFieldUExtract`. 181 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of 182 /// a vector type, construct a vector that has: 183 /// - same number of elements as `Base` 184 /// - each element has the type that is the same as the type of `Offset` or 185 /// `Count` 186 /// - each element has the same value as `Offset` or `Count` 187 /// Then cast `Offset` and `Count` if their bit width is different 188 /// from `Base` bit width. 189 static Value processCountOrOffset(Location loc, Value value, Type srcType, 190 Type dstType, const TypeConverter &converter, 191 ConversionPatternRewriter &rewriter) { 192 Value broadcasted = 193 optionallyBroadcast(loc, value, srcType, converter, rewriter); 194 return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter); 195 } 196 197 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`) 198 /// offset to LLVM struct. Otherwise, the conversion is not supported. 199 static Type convertStructTypeWithOffset(spirv::StructType type, 200 const TypeConverter &converter) { 201 if (type != VulkanLayoutUtils::decorateType(type)) 202 return nullptr; 203 204 SmallVector<Type> elementsVector; 205 if (failed(converter.convertTypes(type.getElementTypes(), elementsVector))) 206 return nullptr; 207 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector, 208 /*isPacked=*/false); 209 } 210 211 /// Converts SPIR-V struct with no offset to packed LLVM struct. 212 static Type convertStructTypePacked(spirv::StructType type, 213 const TypeConverter &converter) { 214 SmallVector<Type> elementsVector; 215 if (failed(converter.convertTypes(type.getElementTypes(), elementsVector))) 216 return nullptr; 217 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector, 218 /*isPacked=*/true); 219 } 220 221 /// Creates LLVM dialect constant with the given value. 222 static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, 223 unsigned value) { 224 return rewriter.create<LLVM::ConstantOp>( 225 loc, IntegerType::get(rewriter.getContext(), 32), 226 rewriter.getIntegerAttr(rewriter.getI32Type(), value)); 227 } 228 229 /// Utility for `spirv.Load` and `spirv.Store` conversion. 230 static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, 231 ConversionPatternRewriter &rewriter, 232 const TypeConverter &typeConverter, 233 unsigned alignment, bool isVolatile, 234 bool isNonTemporal) { 235 if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) { 236 auto dstType = typeConverter.convertType(loadOp.getType()); 237 if (!dstType) 238 return rewriter.notifyMatchFailure(op, "type conversion failed"); 239 rewriter.replaceOpWithNewOp<LLVM::LoadOp>( 240 loadOp, dstType, spirv::LoadOpAdaptor(operands).getPtr(), alignment, 241 isVolatile, isNonTemporal); 242 return success(); 243 } 244 auto storeOp = cast<spirv::StoreOp>(op); 245 spirv::StoreOpAdaptor adaptor(operands); 246 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, adaptor.getValue(), 247 adaptor.getPtr(), alignment, 248 isVolatile, isNonTemporal); 249 return success(); 250 } 251 252 //===----------------------------------------------------------------------===// 253 // Type conversion 254 //===----------------------------------------------------------------------===// 255 256 /// Converts SPIR-V array type to LLVM array. Natural stride (according to 257 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected 258 /// when converting ops that manipulate array types. 259 static std::optional<Type> convertArrayType(spirv::ArrayType type, 260 TypeConverter &converter) { 261 unsigned stride = type.getArrayStride(); 262 Type elementType = type.getElementType(); 263 auto sizeInBytes = cast<spirv::SPIRVType>(elementType).getSizeInBytes(); 264 if (stride != 0 && (!sizeInBytes || *sizeInBytes != stride)) 265 return std::nullopt; 266 267 auto llvmElementType = converter.convertType(elementType); 268 unsigned numElements = type.getNumElements(); 269 return LLVM::LLVMArrayType::get(llvmElementType, numElements); 270 } 271 272 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not 273 /// modelled at the moment. 274 static Type convertPointerType(spirv::PointerType type, 275 const TypeConverter &converter, 276 spirv::ClientAPI clientAPI) { 277 unsigned addressSpace = 278 storageClassToAddressSpace(clientAPI, type.getStorageClass()); 279 return LLVM::LLVMPointerType::get(type.getContext(), addressSpace); 280 } 281 282 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over 283 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is 284 /// no modelling of array stride at the moment. 285 static std::optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type, 286 TypeConverter &converter) { 287 if (type.getArrayStride() != 0) 288 return std::nullopt; 289 auto elementType = converter.convertType(type.getElementType()); 290 return LLVM::LLVMArrayType::get(elementType, 0); 291 } 292 293 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with 294 /// member decorations. Also, only natural offset is supported. 295 static Type convertStructType(spirv::StructType type, 296 const TypeConverter &converter) { 297 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; 298 type.getMemberDecorations(memberDecorations); 299 if (!memberDecorations.empty()) 300 return nullptr; 301 if (type.hasOffset()) 302 return convertStructTypeWithOffset(type, converter); 303 return convertStructTypePacked(type, converter); 304 } 305 306 //===----------------------------------------------------------------------===// 307 // Operation conversion 308 //===----------------------------------------------------------------------===// 309 310 namespace { 311 312 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> { 313 public: 314 using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion; 315 316 LogicalResult 317 matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor, 318 ConversionPatternRewriter &rewriter) const override { 319 auto dstType = 320 getTypeConverter()->convertType(op.getComponentPtr().getType()); 321 if (!dstType) 322 return rewriter.notifyMatchFailure(op, "type conversion failed"); 323 // To use GEP we need to add a first 0 index to go through the pointer. 324 auto indices = llvm::to_vector<4>(adaptor.getIndices()); 325 Type indexType = op.getIndices().front().getType(); 326 auto llvmIndexType = getTypeConverter()->convertType(indexType); 327 if (!llvmIndexType) 328 return rewriter.notifyMatchFailure(op, "type conversion failed"); 329 Value zero = rewriter.create<LLVM::ConstantOp>( 330 op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0)); 331 indices.insert(indices.begin(), zero); 332 333 auto elementType = getTypeConverter()->convertType( 334 cast<spirv::PointerType>(op.getBasePtr().getType()).getPointeeType()); 335 if (!elementType) 336 return rewriter.notifyMatchFailure(op, "type conversion failed"); 337 rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, elementType, 338 adaptor.getBasePtr(), indices); 339 return success(); 340 } 341 }; 342 343 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> { 344 public: 345 using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion; 346 347 LogicalResult 348 matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor, 349 ConversionPatternRewriter &rewriter) const override { 350 auto dstType = getTypeConverter()->convertType(op.getPointer().getType()); 351 if (!dstType) 352 return rewriter.notifyMatchFailure(op, "type conversion failed"); 353 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, 354 op.getVariable()); 355 return success(); 356 } 357 }; 358 359 class BitFieldInsertPattern 360 : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> { 361 public: 362 using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion; 363 364 LogicalResult 365 matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor, 366 ConversionPatternRewriter &rewriter) const override { 367 auto srcType = op.getType(); 368 auto dstType = getTypeConverter()->convertType(srcType); 369 if (!dstType) 370 return rewriter.notifyMatchFailure(op, "type conversion failed"); 371 Location loc = op.getLoc(); 372 373 // Process `Offset` and `Count`: broadcast and extend/truncate if needed. 374 Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType, 375 *getTypeConverter(), rewriter); 376 Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType, 377 *getTypeConverter(), rewriter); 378 379 // Create a mask with bits set outside [Offset, Offset + Count - 1]. 380 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); 381 Value maskShiftedByCount = 382 rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count); 383 Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType, 384 maskShiftedByCount, minusOne); 385 Value maskShiftedByCountAndOffset = 386 rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset); 387 Value mask = rewriter.create<LLVM::XOrOp>( 388 loc, dstType, maskShiftedByCountAndOffset, minusOne); 389 390 // Extract unchanged bits from the `Base` that are outside of 391 // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`. 392 Value baseAndMask = 393 rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask); 394 Value insertShiftedByOffset = 395 rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset); 396 rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask, 397 insertShiftedByOffset); 398 return success(); 399 } 400 }; 401 402 /// Converts SPIR-V ConstantOp with scalar or vector type. 403 class ConstantScalarAndVectorPattern 404 : public SPIRVToLLVMConversion<spirv::ConstantOp> { 405 public: 406 using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion; 407 408 LogicalResult 409 matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor, 410 ConversionPatternRewriter &rewriter) const override { 411 auto srcType = constOp.getType(); 412 if (!isa<VectorType>(srcType) && !srcType.isIntOrFloat()) 413 return failure(); 414 415 auto dstType = getTypeConverter()->convertType(srcType); 416 if (!dstType) 417 return rewriter.notifyMatchFailure(constOp, "type conversion failed"); 418 419 // SPIR-V constant can be a signed/unsigned integer, which has to be 420 // casted to signless integer when converting to LLVM dialect. Removing the 421 // sign bit may have unexpected behaviour. However, it is better to handle 422 // it case-by-case, given that the purpose of the conversion is not to 423 // cover all possible corner cases. 424 if (isSignedIntegerOrVector(srcType) || 425 isUnsignedIntegerOrVector(srcType)) { 426 auto signlessType = rewriter.getIntegerType(getBitWidth(srcType)); 427 428 if (isa<VectorType>(srcType)) { 429 auto dstElementsAttr = cast<DenseIntElementsAttr>(constOp.getValue()); 430 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>( 431 constOp, dstType, 432 dstElementsAttr.mapValues( 433 signlessType, [&](const APInt &value) { return value; })); 434 return success(); 435 } 436 auto srcAttr = cast<IntegerAttr>(constOp.getValue()); 437 auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue()); 438 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr); 439 return success(); 440 } 441 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>( 442 constOp, dstType, adaptor.getOperands(), constOp->getAttrs()); 443 return success(); 444 } 445 }; 446 447 class BitFieldSExtractPattern 448 : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> { 449 public: 450 using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion; 451 452 LogicalResult 453 matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor, 454 ConversionPatternRewriter &rewriter) const override { 455 auto srcType = op.getType(); 456 auto dstType = getTypeConverter()->convertType(srcType); 457 if (!dstType) 458 return rewriter.notifyMatchFailure(op, "type conversion failed"); 459 Location loc = op.getLoc(); 460 461 // Process `Offset` and `Count`: broadcast and extend/truncate if needed. 462 Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType, 463 *getTypeConverter(), rewriter); 464 Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType, 465 *getTypeConverter(), rewriter); 466 467 // Create a constant that holds the size of the `Base`. 468 IntegerType integerType; 469 if (auto vecType = dyn_cast<VectorType>(srcType)) 470 integerType = cast<IntegerType>(vecType.getElementType()); 471 else 472 integerType = cast<IntegerType>(srcType); 473 474 auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType)); 475 Value size = 476 isa<VectorType>(srcType) 477 ? rewriter.create<LLVM::ConstantOp>( 478 loc, dstType, 479 SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize)) 480 : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize); 481 482 // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit 483 // at Offset + Count - 1 is the most significant bit now. 484 Value countPlusOffset = 485 rewriter.create<LLVM::AddOp>(loc, dstType, count, offset); 486 Value amountToShiftLeft = 487 rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset); 488 Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>( 489 loc, dstType, op.getBase(), amountToShiftLeft); 490 491 // Shift the result right, filling the bits with the sign bit. 492 Value amountToShiftRight = 493 rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft); 494 rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft, 495 amountToShiftRight); 496 return success(); 497 } 498 }; 499 500 class BitFieldUExtractPattern 501 : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> { 502 public: 503 using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion; 504 505 LogicalResult 506 matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor, 507 ConversionPatternRewriter &rewriter) const override { 508 auto srcType = op.getType(); 509 auto dstType = getTypeConverter()->convertType(srcType); 510 if (!dstType) 511 return rewriter.notifyMatchFailure(op, "type conversion failed"); 512 Location loc = op.getLoc(); 513 514 // Process `Offset` and `Count`: broadcast and extend/truncate if needed. 515 Value offset = processCountOrOffset(loc, op.getOffset(), srcType, dstType, 516 *getTypeConverter(), rewriter); 517 Value count = processCountOrOffset(loc, op.getCount(), srcType, dstType, 518 *getTypeConverter(), rewriter); 519 520 // Create a mask with bits set at [0, Count - 1]. 521 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); 522 Value maskShiftedByCount = 523 rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count); 524 Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount, 525 minusOne); 526 527 // Shift `Base` by `Offset` and apply the mask on it. 528 Value shiftedBase = 529 rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset); 530 rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask); 531 return success(); 532 } 533 }; 534 535 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> { 536 public: 537 using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion; 538 539 LogicalResult 540 matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor, 541 ConversionPatternRewriter &rewriter) const override { 542 rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, adaptor.getOperands(), 543 branchOp.getTarget()); 544 return success(); 545 } 546 }; 547 548 class BranchConditionalConversionPattern 549 : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> { 550 public: 551 using SPIRVToLLVMConversion< 552 spirv::BranchConditionalOp>::SPIRVToLLVMConversion; 553 554 LogicalResult 555 matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor, 556 ConversionPatternRewriter &rewriter) const override { 557 // If branch weights exist, map them to 32-bit integer vector. 558 DenseI32ArrayAttr branchWeights = nullptr; 559 if (auto weights = op.getBranchWeights()) { 560 SmallVector<int32_t> weightValues; 561 for (auto weight : weights->getAsRange<IntegerAttr>()) 562 weightValues.push_back(weight.getInt()); 563 branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues); 564 } 565 566 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( 567 op, op.getCondition(), op.getTrueBlockArguments(), 568 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(), 569 op.getFalseBlock()); 570 return success(); 571 } 572 }; 573 574 /// Converts `spirv.getCompositeExtract` to `llvm.extractvalue` if the container 575 /// type is an aggregate type (struct or array). Otherwise, converts to 576 /// `llvm.extractelement` that operates on vectors. 577 class CompositeExtractPattern 578 : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> { 579 public: 580 using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion; 581 582 LogicalResult 583 matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor, 584 ConversionPatternRewriter &rewriter) const override { 585 auto dstType = this->getTypeConverter()->convertType(op.getType()); 586 if (!dstType) 587 return rewriter.notifyMatchFailure(op, "type conversion failed"); 588 589 Type containerType = op.getComposite().getType(); 590 if (isa<VectorType>(containerType)) { 591 Location loc = op.getLoc(); 592 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]); 593 Value index = createI32ConstantOf(loc, rewriter, value.getInt()); 594 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( 595 op, dstType, adaptor.getComposite(), index); 596 return success(); 597 } 598 599 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>( 600 op, adaptor.getComposite(), 601 LLVM::convertArrayToIndices(op.getIndices())); 602 return success(); 603 } 604 }; 605 606 /// Converts `spirv.getCompositeInsert` to `llvm.insertvalue` if the container 607 /// type is an aggregate type (struct or array). Otherwise, converts to 608 /// `llvm.insertelement` that operates on vectors. 609 class CompositeInsertPattern 610 : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> { 611 public: 612 using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion; 613 614 LogicalResult 615 matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor, 616 ConversionPatternRewriter &rewriter) const override { 617 auto dstType = this->getTypeConverter()->convertType(op.getType()); 618 if (!dstType) 619 return rewriter.notifyMatchFailure(op, "type conversion failed"); 620 621 Type containerType = op.getComposite().getType(); 622 if (isa<VectorType>(containerType)) { 623 Location loc = op.getLoc(); 624 IntegerAttr value = cast<IntegerAttr>(op.getIndices()[0]); 625 Value index = createI32ConstantOf(loc, rewriter, value.getInt()); 626 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( 627 op, dstType, adaptor.getComposite(), adaptor.getObject(), index); 628 return success(); 629 } 630 631 rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>( 632 op, adaptor.getComposite(), adaptor.getObject(), 633 LLVM::convertArrayToIndices(op.getIndices())); 634 return success(); 635 } 636 }; 637 638 /// Converts SPIR-V operations that have straightforward LLVM equivalent 639 /// into LLVM dialect operations. 640 template <typename SPIRVOp, typename LLVMOp> 641 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> { 642 public: 643 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 644 645 LogicalResult 646 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, 647 ConversionPatternRewriter &rewriter) const override { 648 auto dstType = this->getTypeConverter()->convertType(op.getType()); 649 if (!dstType) 650 return rewriter.notifyMatchFailure(op, "type conversion failed"); 651 rewriter.template replaceOpWithNewOp<LLVMOp>( 652 op, dstType, adaptor.getOperands(), op->getAttrs()); 653 return success(); 654 } 655 }; 656 657 /// Converts `spirv.ExecutionMode` into a global struct constant that holds 658 /// execution mode information. 659 class ExecutionModePattern 660 : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> { 661 public: 662 using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion; 663 664 LogicalResult 665 matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor, 666 ConversionPatternRewriter &rewriter) const override { 667 // First, create the global struct's name that would be associated with 668 // this entry point's execution mode. We set it to be: 669 // __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode} 670 ModuleOp module = op->getParentOfType<ModuleOp>(); 671 spirv::ExecutionModeAttr executionModeAttr = op.getExecutionModeAttr(); 672 std::string moduleName; 673 if (module.getName().has_value()) 674 moduleName = "_" + module.getName()->str(); 675 else 676 moduleName = ""; 677 std::string executionModeInfoName = llvm::formatv( 678 "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.getFn().str(), 679 static_cast<uint32_t>(executionModeAttr.getValue())); 680 681 MLIRContext *context = rewriter.getContext(); 682 OpBuilder::InsertionGuard guard(rewriter); 683 rewriter.setInsertionPointToStart(module.getBody()); 684 685 // Create a struct type, corresponding to the C struct below. 686 // struct { 687 // int32_t executionMode; 688 // int32_t values[]; // optional values 689 // }; 690 auto llvmI32Type = IntegerType::get(context, 32); 691 SmallVector<Type, 2> fields; 692 fields.push_back(llvmI32Type); 693 ArrayAttr values = op.getValues(); 694 if (!values.empty()) { 695 auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size()); 696 fields.push_back(arrayType); 697 } 698 auto structType = LLVM::LLVMStructType::getLiteral(context, fields); 699 700 // Create `llvm.mlir.global` with initializer region containing one block. 701 auto global = rewriter.create<LLVM::GlobalOp>( 702 UnknownLoc::get(context), structType, /*isConstant=*/true, 703 LLVM::Linkage::External, executionModeInfoName, Attribute(), 704 /*alignment=*/0); 705 Location loc = global.getLoc(); 706 Region ®ion = global.getInitializerRegion(); 707 Block *block = rewriter.createBlock(®ion); 708 709 // Initialize the struct and set the execution mode value. 710 rewriter.setInsertionPointToStart(block); 711 Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType); 712 Value executionMode = rewriter.create<LLVM::ConstantOp>( 713 loc, llvmI32Type, 714 rewriter.getI32IntegerAttr( 715 static_cast<uint32_t>(executionModeAttr.getValue()))); 716 structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue, 717 executionMode, 0); 718 719 // Insert extra operands if they exist into execution mode info struct. 720 for (unsigned i = 0, e = values.size(); i < e; ++i) { 721 auto attr = values.getValue()[i]; 722 Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr); 723 structValue = rewriter.create<LLVM::InsertValueOp>( 724 loc, structValue, entry, ArrayRef<int64_t>({1, i})); 725 } 726 rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue})); 727 rewriter.eraseOp(op); 728 return success(); 729 } 730 }; 731 732 /// Converts `spirv.GlobalVariable` to `llvm.mlir.global`. Note that SPIR-V 733 /// global returns a pointer, whereas in LLVM dialect the global holds an actual 734 /// value. This difference is handled by `spirv.mlir.addressof` and 735 /// `llvm.mlir.addressof`ops that both return a pointer. 736 class GlobalVariablePattern 737 : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> { 738 public: 739 template <typename... Args> 740 GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args) 741 : SPIRVToLLVMConversion<spirv::GlobalVariableOp>( 742 std::forward<Args>(args)...), 743 clientAPI(clientAPI) {} 744 745 LogicalResult 746 matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor, 747 ConversionPatternRewriter &rewriter) const override { 748 // Currently, there is no support of initialization with a constant value in 749 // SPIR-V dialect. Specialization constants are not considered as well. 750 if (op.getInitializer()) 751 return failure(); 752 753 auto srcType = cast<spirv::PointerType>(op.getType()); 754 auto dstType = getTypeConverter()->convertType(srcType.getPointeeType()); 755 if (!dstType) 756 return rewriter.notifyMatchFailure(op, "type conversion failed"); 757 758 // Limit conversion to the current invocation only or `StorageBuffer` 759 // required by SPIR-V runner. 760 // This is okay because multiple invocations are not supported yet. 761 auto storageClass = srcType.getStorageClass(); 762 switch (storageClass) { 763 case spirv::StorageClass::Input: 764 case spirv::StorageClass::Private: 765 case spirv::StorageClass::Output: 766 case spirv::StorageClass::StorageBuffer: 767 case spirv::StorageClass::UniformConstant: 768 break; 769 default: 770 return failure(); 771 } 772 773 // LLVM dialect spec: "If the global value is a constant, storing into it is 774 // not allowed.". This corresponds to SPIR-V 'Input' and 'UniformConstant' 775 // storage class that is read-only. 776 bool isConstant = (storageClass == spirv::StorageClass::Input) || 777 (storageClass == spirv::StorageClass::UniformConstant); 778 // SPIR-V spec: "By default, functions and global variables are private to a 779 // module and cannot be accessed by other modules. However, a module may be 780 // written to export or import functions and global (module scope) 781 // variables.". Therefore, map 'Private' storage class to private linkage, 782 // 'Input' and 'Output' to external linkage. 783 auto linkage = storageClass == spirv::StorageClass::Private 784 ? LLVM::Linkage::Private 785 : LLVM::Linkage::External; 786 auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>( 787 op, dstType, isConstant, linkage, op.getSymName(), Attribute(), 788 /*alignment=*/0, storageClassToAddressSpace(clientAPI, storageClass)); 789 790 // Attach location attribute if applicable 791 if (op.getLocationAttr()) 792 newGlobalOp->setAttr(op.getLocationAttrName(), op.getLocationAttr()); 793 794 return success(); 795 } 796 797 private: 798 spirv::ClientAPI clientAPI; 799 }; 800 801 /// Converts SPIR-V cast ops that do not have straightforward LLVM 802 /// equivalent in LLVM dialect. 803 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp> 804 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> { 805 public: 806 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 807 808 LogicalResult 809 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, 810 ConversionPatternRewriter &rewriter) const override { 811 812 Type fromType = op.getOperand().getType(); 813 Type toType = op.getType(); 814 815 auto dstType = this->getTypeConverter()->convertType(toType); 816 if (!dstType) 817 return rewriter.notifyMatchFailure(op, "type conversion failed"); 818 819 if (getBitWidth(fromType) < getBitWidth(toType)) { 820 rewriter.template replaceOpWithNewOp<LLVMExtOp>(op, dstType, 821 adaptor.getOperands()); 822 return success(); 823 } 824 if (getBitWidth(fromType) > getBitWidth(toType)) { 825 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(op, dstType, 826 adaptor.getOperands()); 827 return success(); 828 } 829 return failure(); 830 } 831 }; 832 833 class FunctionCallPattern 834 : public SPIRVToLLVMConversion<spirv::FunctionCallOp> { 835 public: 836 using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion; 837 838 LogicalResult 839 matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor, 840 ConversionPatternRewriter &rewriter) const override { 841 if (callOp.getNumResults() == 0) { 842 auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>( 843 callOp, std::nullopt, adaptor.getOperands(), callOp->getAttrs()); 844 newOp.getProperties().operandSegmentSizes = { 845 static_cast<int32_t>(adaptor.getOperands().size()), 0}; 846 newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({}); 847 return success(); 848 } 849 850 // Function returns a single result. 851 auto dstType = getTypeConverter()->convertType(callOp.getType(0)); 852 if (!dstType) 853 return rewriter.notifyMatchFailure(callOp, "type conversion failed"); 854 auto newOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>( 855 callOp, dstType, adaptor.getOperands(), callOp->getAttrs()); 856 newOp.getProperties().operandSegmentSizes = { 857 static_cast<int32_t>(adaptor.getOperands().size()), 0}; 858 newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({}); 859 return success(); 860 } 861 }; 862 863 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate" 864 template <typename SPIRVOp, LLVM::FCmpPredicate predicate> 865 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> { 866 public: 867 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 868 869 LogicalResult 870 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, 871 ConversionPatternRewriter &rewriter) const override { 872 873 auto dstType = this->getTypeConverter()->convertType(op.getType()); 874 if (!dstType) 875 return rewriter.notifyMatchFailure(op, "type conversion failed"); 876 877 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>( 878 op, dstType, predicate, op.getOperand1(), op.getOperand2()); 879 return success(); 880 } 881 }; 882 883 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate" 884 template <typename SPIRVOp, LLVM::ICmpPredicate predicate> 885 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> { 886 public: 887 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 888 889 LogicalResult 890 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, 891 ConversionPatternRewriter &rewriter) const override { 892 893 auto dstType = this->getTypeConverter()->convertType(op.getType()); 894 if (!dstType) 895 return rewriter.notifyMatchFailure(op, "type conversion failed"); 896 897 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>( 898 op, dstType, predicate, op.getOperand1(), op.getOperand2()); 899 return success(); 900 } 901 }; 902 903 class InverseSqrtPattern 904 : public SPIRVToLLVMConversion<spirv::GLInverseSqrtOp> { 905 public: 906 using SPIRVToLLVMConversion<spirv::GLInverseSqrtOp>::SPIRVToLLVMConversion; 907 908 LogicalResult 909 matchAndRewrite(spirv::GLInverseSqrtOp op, OpAdaptor adaptor, 910 ConversionPatternRewriter &rewriter) const override { 911 auto srcType = op.getType(); 912 auto dstType = getTypeConverter()->convertType(srcType); 913 if (!dstType) 914 return rewriter.notifyMatchFailure(op, "type conversion failed"); 915 916 Location loc = op.getLoc(); 917 Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); 918 Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand()); 919 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt); 920 return success(); 921 } 922 }; 923 924 /// Converts `spirv.Load` and `spirv.Store` to LLVM dialect. 925 template <typename SPIRVOp> 926 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVOp> { 927 public: 928 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 929 930 LogicalResult 931 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, 932 ConversionPatternRewriter &rewriter) const override { 933 if (!op.getMemoryAccess()) { 934 return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter, 935 *this->getTypeConverter(), /*alignment=*/0, 936 /*isVolatile=*/false, 937 /*isNonTemporal=*/false); 938 } 939 auto memoryAccess = *op.getMemoryAccess(); 940 switch (memoryAccess) { 941 case spirv::MemoryAccess::Aligned: 942 case spirv::MemoryAccess::None: 943 case spirv::MemoryAccess::Nontemporal: 944 case spirv::MemoryAccess::Volatile: { 945 unsigned alignment = 946 memoryAccess == spirv::MemoryAccess::Aligned ? *op.getAlignment() : 0; 947 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal; 948 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile; 949 return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter, 950 *this->getTypeConverter(), alignment, 951 isVolatile, isNonTemporal); 952 } 953 default: 954 // There is no support of other memory access attributes. 955 return failure(); 956 } 957 } 958 }; 959 960 /// Converts `spirv.Not` and `spirv.LogicalNot` into LLVM dialect. 961 template <typename SPIRVOp> 962 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> { 963 public: 964 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 965 966 LogicalResult 967 matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor, 968 ConversionPatternRewriter &rewriter) const override { 969 auto srcType = notOp.getType(); 970 auto dstType = this->getTypeConverter()->convertType(srcType); 971 if (!dstType) 972 return rewriter.notifyMatchFailure(notOp, "type conversion failed"); 973 974 Location loc = notOp.getLoc(); 975 IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter); 976 auto mask = 977 isa<VectorType>(srcType) 978 ? rewriter.create<LLVM::ConstantOp>( 979 loc, dstType, 980 SplatElementsAttr::get(cast<VectorType>(srcType), minusOne)) 981 : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne); 982 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType, 983 notOp.getOperand(), mask); 984 return success(); 985 } 986 }; 987 988 /// A template pattern that erases the given `SPIRVOp`. 989 template <typename SPIRVOp> 990 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> { 991 public: 992 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 993 994 LogicalResult 995 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, 996 ConversionPatternRewriter &rewriter) const override { 997 rewriter.eraseOp(op); 998 return success(); 999 } 1000 }; 1001 1002 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> { 1003 public: 1004 using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion; 1005 1006 LogicalResult 1007 matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor, 1008 ConversionPatternRewriter &rewriter) const override { 1009 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(), 1010 ArrayRef<Value>()); 1011 return success(); 1012 } 1013 }; 1014 1015 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> { 1016 public: 1017 using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion; 1018 1019 LogicalResult 1020 matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor, 1021 ConversionPatternRewriter &rewriter) const override { 1022 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(), 1023 adaptor.getOperands()); 1024 return success(); 1025 } 1026 }; 1027 1028 static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, 1029 StringRef name, 1030 ArrayRef<Type> paramTypes, 1031 Type resultType, 1032 bool convergent = true) { 1033 auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>( 1034 SymbolTable::lookupSymbolIn(symbolTable, name)); 1035 if (func) 1036 return func; 1037 1038 OpBuilder b(symbolTable->getRegion(0)); 1039 func = b.create<LLVM::LLVMFuncOp>( 1040 symbolTable->getLoc(), name, 1041 LLVM::LLVMFunctionType::get(resultType, paramTypes)); 1042 func.setCConv(LLVM::cconv::CConv::SPIR_FUNC); 1043 func.setConvergent(convergent); 1044 func.setNoUnwind(true); 1045 func.setWillReturn(true); 1046 return func; 1047 } 1048 1049 static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder, 1050 LLVM::LLVMFuncOp func, 1051 ValueRange args) { 1052 auto call = builder.create<LLVM::CallOp>(loc, func, args); 1053 call.setCConv(func.getCConv()); 1054 call.setConvergentAttr(func.getConvergentAttr()); 1055 call.setNoUnwindAttr(func.getNoUnwindAttr()); 1056 call.setWillReturnAttr(func.getWillReturnAttr()); 1057 return call; 1058 } 1059 1060 template <typename BarrierOpTy> 1061 class ControlBarrierPattern : public SPIRVToLLVMConversion<BarrierOpTy> { 1062 public: 1063 using OpAdaptor = typename SPIRVToLLVMConversion<BarrierOpTy>::OpAdaptor; 1064 1065 using SPIRVToLLVMConversion<BarrierOpTy>::SPIRVToLLVMConversion; 1066 1067 static constexpr StringRef getFuncName(); 1068 1069 LogicalResult 1070 matchAndRewrite(BarrierOpTy controlBarrierOp, OpAdaptor adaptor, 1071 ConversionPatternRewriter &rewriter) const override { 1072 constexpr StringRef funcName = getFuncName(); 1073 Operation *symbolTable = 1074 controlBarrierOp->template getParentWithTrait<OpTrait::SymbolTable>(); 1075 1076 Type i32 = rewriter.getI32Type(); 1077 1078 Type voidTy = rewriter.getType<LLVM::LLVMVoidType>(); 1079 LLVM::LLVMFuncOp func = 1080 lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy); 1081 1082 Location loc = controlBarrierOp->getLoc(); 1083 Value execution = rewriter.create<LLVM::ConstantOp>( 1084 loc, i32, static_cast<int32_t>(adaptor.getExecutionScope())); 1085 Value memory = rewriter.create<LLVM::ConstantOp>( 1086 loc, i32, static_cast<int32_t>(adaptor.getMemoryScope())); 1087 Value semantics = rewriter.create<LLVM::ConstantOp>( 1088 loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics())); 1089 1090 auto call = createSPIRVBuiltinCall(loc, rewriter, func, 1091 {execution, memory, semantics}); 1092 1093 rewriter.replaceOp(controlBarrierOp, call); 1094 return success(); 1095 } 1096 }; 1097 1098 namespace { 1099 1100 StringRef getTypeMangling(Type type, bool isSigned) { 1101 return llvm::TypeSwitch<Type, StringRef>(type) 1102 .Case<Float16Type>([](auto) { return "Dh"; }) 1103 .Case<Float32Type>([](auto) { return "f"; }) 1104 .Case<Float64Type>([](auto) { return "d"; }) 1105 .Case<IntegerType>([isSigned](IntegerType intTy) { 1106 switch (intTy.getWidth()) { 1107 case 1: 1108 return "b"; 1109 case 8: 1110 return (isSigned) ? "a" : "c"; 1111 case 16: 1112 return (isSigned) ? "s" : "t"; 1113 case 32: 1114 return (isSigned) ? "i" : "j"; 1115 case 64: 1116 return (isSigned) ? "l" : "m"; 1117 default: 1118 llvm_unreachable("Unsupported integer width"); 1119 } 1120 }) 1121 .Default([](auto) { 1122 llvm_unreachable("No mangling defined"); 1123 return ""; 1124 }); 1125 } 1126 1127 template <typename ReduceOp> 1128 constexpr StringLiteral getGroupFuncName(); 1129 1130 template <> 1131 constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() { 1132 return "_Z17__spirv_GroupIAddii"; 1133 } 1134 template <> 1135 constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() { 1136 return "_Z17__spirv_GroupFAddii"; 1137 } 1138 template <> 1139 constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() { 1140 return "_Z17__spirv_GroupSMinii"; 1141 } 1142 template <> 1143 constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() { 1144 return "_Z17__spirv_GroupUMinii"; 1145 } 1146 template <> 1147 constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() { 1148 return "_Z17__spirv_GroupFMinii"; 1149 } 1150 template <> 1151 constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() { 1152 return "_Z17__spirv_GroupSMaxii"; 1153 } 1154 template <> 1155 constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() { 1156 return "_Z17__spirv_GroupUMaxii"; 1157 } 1158 template <> 1159 constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() { 1160 return "_Z17__spirv_GroupFMaxii"; 1161 } 1162 template <> 1163 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() { 1164 return "_Z27__spirv_GroupNonUniformIAddii"; 1165 } 1166 template <> 1167 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() { 1168 return "_Z27__spirv_GroupNonUniformFAddii"; 1169 } 1170 template <> 1171 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() { 1172 return "_Z27__spirv_GroupNonUniformIMulii"; 1173 } 1174 template <> 1175 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() { 1176 return "_Z27__spirv_GroupNonUniformFMulii"; 1177 } 1178 template <> 1179 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() { 1180 return "_Z27__spirv_GroupNonUniformSMinii"; 1181 } 1182 template <> 1183 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() { 1184 return "_Z27__spirv_GroupNonUniformUMinii"; 1185 } 1186 template <> 1187 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() { 1188 return "_Z27__spirv_GroupNonUniformFMinii"; 1189 } 1190 template <> 1191 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() { 1192 return "_Z27__spirv_GroupNonUniformSMaxii"; 1193 } 1194 template <> 1195 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() { 1196 return "_Z27__spirv_GroupNonUniformUMaxii"; 1197 } 1198 template <> 1199 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() { 1200 return "_Z27__spirv_GroupNonUniformFMaxii"; 1201 } 1202 template <> 1203 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() { 1204 return "_Z33__spirv_GroupNonUniformBitwiseAndii"; 1205 } 1206 template <> 1207 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() { 1208 return "_Z32__spirv_GroupNonUniformBitwiseOrii"; 1209 } 1210 template <> 1211 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() { 1212 return "_Z33__spirv_GroupNonUniformBitwiseXorii"; 1213 } 1214 template <> 1215 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() { 1216 return "_Z33__spirv_GroupNonUniformLogicalAndii"; 1217 } 1218 template <> 1219 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() { 1220 return "_Z32__spirv_GroupNonUniformLogicalOrii"; 1221 } 1222 template <> 1223 constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() { 1224 return "_Z33__spirv_GroupNonUniformLogicalXorii"; 1225 } 1226 } // namespace 1227 1228 template <typename ReduceOp, bool Signed = false, bool NonUniform = false> 1229 class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> { 1230 public: 1231 using SPIRVToLLVMConversion<ReduceOp>::SPIRVToLLVMConversion; 1232 1233 LogicalResult 1234 matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor, 1235 ConversionPatternRewriter &rewriter) const override { 1236 1237 Type retTy = op.getResult().getType(); 1238 if (!retTy.isIntOrFloat()) { 1239 return failure(); 1240 } 1241 SmallString<36> funcName = getGroupFuncName<ReduceOp>(); 1242 funcName += getTypeMangling(retTy, false); 1243 1244 Type i32Ty = rewriter.getI32Type(); 1245 SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy}; 1246 if constexpr (NonUniform) { 1247 if (adaptor.getClusterSize()) { 1248 funcName += "j"; 1249 paramTypes.push_back(i32Ty); 1250 } 1251 } 1252 1253 Operation *symbolTable = 1254 op->template getParentWithTrait<OpTrait::SymbolTable>(); 1255 1256 LLVM::LLVMFuncOp func = 1257 lookupOrCreateSPIRVFn(symbolTable, funcName, paramTypes, retTy); 1258 1259 Location loc = op.getLoc(); 1260 Value scope = rewriter.create<LLVM::ConstantOp>( 1261 loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope())); 1262 Value groupOp = rewriter.create<LLVM::ConstantOp>( 1263 loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation())); 1264 SmallVector<Value> operands{scope, groupOp}; 1265 operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end()); 1266 1267 auto call = createSPIRVBuiltinCall(loc, rewriter, func, operands); 1268 rewriter.replaceOp(op, call); 1269 return success(); 1270 } 1271 }; 1272 1273 template <> 1274 constexpr StringRef 1275 ControlBarrierPattern<spirv::ControlBarrierOp>::getFuncName() { 1276 return "_Z22__spirv_ControlBarrieriii"; 1277 } 1278 1279 template <> 1280 constexpr StringRef 1281 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>::getFuncName() { 1282 return "_Z33__spirv_ControlBarrierArriveINTELiii"; 1283 } 1284 1285 template <> 1286 constexpr StringRef 1287 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>::getFuncName() { 1288 return "_Z31__spirv_ControlBarrierWaitINTELiii"; 1289 } 1290 1291 /// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection 1292 /// should be reachable for conversion to succeed. The structure of the loop in 1293 /// LLVM dialect will be the following: 1294 /// 1295 /// +------------------------------------+ 1296 /// | <code before spirv.mlir.loop> | 1297 /// | llvm.br ^header | 1298 /// +------------------------------------+ 1299 /// | 1300 /// +----------------+ | 1301 /// | | | 1302 /// | V V 1303 /// | +------------------------------------+ 1304 /// | | ^header: | 1305 /// | | <header code> | 1306 /// | | llvm.cond_br %cond, ^body, ^exit | 1307 /// | +------------------------------------+ 1308 /// | | 1309 /// | |----------------------+ 1310 /// | | | 1311 /// | V | 1312 /// | +------------------------------------+ | 1313 /// | | ^body: | | 1314 /// | | <body code> | | 1315 /// | | llvm.br ^continue | | 1316 /// | +------------------------------------+ | 1317 /// | | | 1318 /// | V | 1319 /// | +------------------------------------+ | 1320 /// | | ^continue: | | 1321 /// | | <continue code> | | 1322 /// | | llvm.br ^header | | 1323 /// | +------------------------------------+ | 1324 /// | | | 1325 /// +---------------+ +----------------------+ 1326 /// | 1327 /// V 1328 /// +------------------------------------+ 1329 /// | ^exit: | 1330 /// | llvm.br ^remaining | 1331 /// +------------------------------------+ 1332 /// | 1333 /// V 1334 /// +------------------------------------+ 1335 /// | ^remaining: | 1336 /// | <code after spirv.mlir.loop> | 1337 /// +------------------------------------+ 1338 /// 1339 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> { 1340 public: 1341 using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion; 1342 1343 LogicalResult 1344 matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor, 1345 ConversionPatternRewriter &rewriter) const override { 1346 // There is no support of loop control at the moment. 1347 if (loopOp.getLoopControl() != spirv::LoopControl::None) 1348 return failure(); 1349 1350 // `spirv.mlir.loop` with empty region is redundant and should be erased. 1351 if (loopOp.getBody().empty()) { 1352 rewriter.eraseOp(loopOp); 1353 return success(); 1354 } 1355 1356 Location loc = loopOp.getLoc(); 1357 1358 // Split the current block after `spirv.mlir.loop`. The remaining ops will 1359 // be used in `endBlock`. 1360 Block *currentBlock = rewriter.getBlock(); 1361 auto position = Block::iterator(loopOp); 1362 Block *endBlock = rewriter.splitBlock(currentBlock, position); 1363 1364 // Remove entry block and create a branch in the current block going to the 1365 // header block. 1366 Block *entryBlock = loopOp.getEntryBlock(); 1367 assert(entryBlock->getOperations().size() == 1); 1368 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front()); 1369 if (!brOp) 1370 return failure(); 1371 Block *headerBlock = loopOp.getHeaderBlock(); 1372 rewriter.setInsertionPointToEnd(currentBlock); 1373 rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock); 1374 rewriter.eraseBlock(entryBlock); 1375 1376 // Branch from merge block to end block. 1377 Block *mergeBlock = loopOp.getMergeBlock(); 1378 Operation *terminator = mergeBlock->getTerminator(); 1379 ValueRange terminatorOperands = terminator->getOperands(); 1380 rewriter.setInsertionPointToEnd(mergeBlock); 1381 rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock); 1382 1383 rewriter.inlineRegionBefore(loopOp.getBody(), endBlock); 1384 rewriter.replaceOp(loopOp, endBlock->getArguments()); 1385 return success(); 1386 } 1387 }; 1388 1389 /// Converts `spirv.mlir.selection` with `spirv.BranchConditional` in its header 1390 /// block. All blocks within selection should be reachable for conversion to 1391 /// succeed. 1392 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> { 1393 public: 1394 using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion; 1395 1396 LogicalResult 1397 matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor, 1398 ConversionPatternRewriter &rewriter) const override { 1399 // There is no support for `Flatten` or `DontFlatten` selection control at 1400 // the moment. This are just compiler hints and can be performed during the 1401 // optimization passes. 1402 if (op.getSelectionControl() != spirv::SelectionControl::None) 1403 return failure(); 1404 1405 // `spirv.mlir.selection` should have at least two blocks: one selection 1406 // header block and one merge block. If no blocks are present, or control 1407 // flow branches straight to merge block (two blocks are present), the op is 1408 // redundant and it is erased. 1409 if (op.getBody().getBlocks().size() <= 2) { 1410 rewriter.eraseOp(op); 1411 return success(); 1412 } 1413 1414 Location loc = op.getLoc(); 1415 1416 // Split the current block after `spirv.mlir.selection`. The remaining ops 1417 // will be used in `continueBlock`. 1418 auto *currentBlock = rewriter.getInsertionBlock(); 1419 rewriter.setInsertionPointAfter(op); 1420 auto position = rewriter.getInsertionPoint(); 1421 auto *continueBlock = rewriter.splitBlock(currentBlock, position); 1422 1423 // Extract conditional branch information from the header block. By SPIR-V 1424 // dialect spec, it should contain `spirv.BranchConditional` or 1425 // `spirv.Switch` op. Note that `spirv.Switch op` is not supported at the 1426 // moment in the SPIR-V dialect. Remove this block when finished. 1427 auto *headerBlock = op.getHeaderBlock(); 1428 assert(headerBlock->getOperations().size() == 1); 1429 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>( 1430 headerBlock->getOperations().front()); 1431 if (!condBrOp) 1432 return failure(); 1433 rewriter.eraseBlock(headerBlock); 1434 1435 // Branch from merge block to continue block. 1436 auto *mergeBlock = op.getMergeBlock(); 1437 Operation *terminator = mergeBlock->getTerminator(); 1438 ValueRange terminatorOperands = terminator->getOperands(); 1439 rewriter.setInsertionPointToEnd(mergeBlock); 1440 rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock); 1441 1442 // Link current block to `true` and `false` blocks within the selection. 1443 Block *trueBlock = condBrOp.getTrueBlock(); 1444 Block *falseBlock = condBrOp.getFalseBlock(); 1445 rewriter.setInsertionPointToEnd(currentBlock); 1446 rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock, 1447 condBrOp.getTrueTargetOperands(), 1448 falseBlock, 1449 condBrOp.getFalseTargetOperands()); 1450 1451 rewriter.inlineRegionBefore(op.getBody(), continueBlock); 1452 rewriter.replaceOp(op, continueBlock->getArguments()); 1453 return success(); 1454 } 1455 }; 1456 1457 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect 1458 /// puts a restriction on `Shift` and `Base` to have the same bit width, 1459 /// `Shift` is zero or sign extended to match this specification. Cases when 1460 /// `Shift` bit width > `Base` bit width are considered to be illegal. 1461 template <typename SPIRVOp, typename LLVMOp> 1462 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> { 1463 public: 1464 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion; 1465 1466 LogicalResult 1467 matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, 1468 ConversionPatternRewriter &rewriter) const override { 1469 1470 auto dstType = this->getTypeConverter()->convertType(op.getType()); 1471 if (!dstType) 1472 return rewriter.notifyMatchFailure(op, "type conversion failed"); 1473 1474 Type op1Type = op.getOperand1().getType(); 1475 Type op2Type = op.getOperand2().getType(); 1476 1477 if (op1Type == op2Type) { 1478 rewriter.template replaceOpWithNewOp<LLVMOp>(op, dstType, 1479 adaptor.getOperands()); 1480 return success(); 1481 } 1482 1483 std::optional<uint64_t> dstTypeWidth = 1484 getIntegerOrVectorElementWidth(dstType); 1485 std::optional<uint64_t> op2TypeWidth = 1486 getIntegerOrVectorElementWidth(op2Type); 1487 1488 if (!dstTypeWidth || !op2TypeWidth) 1489 return failure(); 1490 1491 Location loc = op.getLoc(); 1492 Value extended; 1493 if (op2TypeWidth < dstTypeWidth) { 1494 if (isUnsignedIntegerOrVector(op2Type)) { 1495 extended = rewriter.template create<LLVM::ZExtOp>( 1496 loc, dstType, adaptor.getOperand2()); 1497 } else { 1498 extended = rewriter.template create<LLVM::SExtOp>( 1499 loc, dstType, adaptor.getOperand2()); 1500 } 1501 } else if (op2TypeWidth == dstTypeWidth) { 1502 extended = adaptor.getOperand2(); 1503 } else { 1504 return failure(); 1505 } 1506 1507 Value result = rewriter.template create<LLVMOp>( 1508 loc, dstType, adaptor.getOperand1(), extended); 1509 rewriter.replaceOp(op, result); 1510 return success(); 1511 } 1512 }; 1513 1514 class TanPattern : public SPIRVToLLVMConversion<spirv::GLTanOp> { 1515 public: 1516 using SPIRVToLLVMConversion<spirv::GLTanOp>::SPIRVToLLVMConversion; 1517 1518 LogicalResult 1519 matchAndRewrite(spirv::GLTanOp tanOp, OpAdaptor adaptor, 1520 ConversionPatternRewriter &rewriter) const override { 1521 auto dstType = getTypeConverter()->convertType(tanOp.getType()); 1522 if (!dstType) 1523 return rewriter.notifyMatchFailure(tanOp, "type conversion failed"); 1524 1525 Location loc = tanOp.getLoc(); 1526 Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand()); 1527 Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand()); 1528 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos); 1529 return success(); 1530 } 1531 }; 1532 1533 /// Convert `spirv.Tanh` to 1534 /// 1535 /// exp(2x) - 1 1536 /// ----------- 1537 /// exp(2x) + 1 1538 /// 1539 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLTanhOp> { 1540 public: 1541 using SPIRVToLLVMConversion<spirv::GLTanhOp>::SPIRVToLLVMConversion; 1542 1543 LogicalResult 1544 matchAndRewrite(spirv::GLTanhOp tanhOp, OpAdaptor adaptor, 1545 ConversionPatternRewriter &rewriter) const override { 1546 auto srcType = tanhOp.getType(); 1547 auto dstType = getTypeConverter()->convertType(srcType); 1548 if (!dstType) 1549 return rewriter.notifyMatchFailure(tanhOp, "type conversion failed"); 1550 1551 Location loc = tanhOp.getLoc(); 1552 Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0); 1553 Value multiplied = 1554 rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand()); 1555 Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied); 1556 Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); 1557 Value numerator = 1558 rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one); 1559 Value denominator = 1560 rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one); 1561 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator, 1562 denominator); 1563 return success(); 1564 } 1565 }; 1566 1567 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> { 1568 public: 1569 using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion; 1570 1571 LogicalResult 1572 matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor, 1573 ConversionPatternRewriter &rewriter) const override { 1574 auto srcType = varOp.getType(); 1575 // Initialization is supported for scalars and vectors only. 1576 auto pointerTo = cast<spirv::PointerType>(srcType).getPointeeType(); 1577 auto init = varOp.getInitializer(); 1578 if (init && !pointerTo.isIntOrFloat() && !isa<VectorType>(pointerTo)) 1579 return failure(); 1580 1581 auto dstType = getTypeConverter()->convertType(srcType); 1582 if (!dstType) 1583 return rewriter.notifyMatchFailure(varOp, "type conversion failed"); 1584 1585 Location loc = varOp.getLoc(); 1586 Value size = createI32ConstantOf(loc, rewriter, 1); 1587 if (!init) { 1588 auto elementType = getTypeConverter()->convertType(pointerTo); 1589 if (!elementType) 1590 return rewriter.notifyMatchFailure(varOp, "type conversion failed"); 1591 rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, elementType, 1592 size); 1593 return success(); 1594 } 1595 auto elementType = getTypeConverter()->convertType(pointerTo); 1596 if (!elementType) 1597 return rewriter.notifyMatchFailure(varOp, "type conversion failed"); 1598 Value allocated = 1599 rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size); 1600 rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated); 1601 rewriter.replaceOp(varOp, allocated); 1602 return success(); 1603 } 1604 }; 1605 1606 //===----------------------------------------------------------------------===// 1607 // BitcastOp conversion 1608 //===----------------------------------------------------------------------===// 1609 1610 class BitcastConversionPattern 1611 : public SPIRVToLLVMConversion<spirv::BitcastOp> { 1612 public: 1613 using SPIRVToLLVMConversion<spirv::BitcastOp>::SPIRVToLLVMConversion; 1614 1615 LogicalResult 1616 matchAndRewrite(spirv::BitcastOp bitcastOp, OpAdaptor adaptor, 1617 ConversionPatternRewriter &rewriter) const override { 1618 auto dstType = getTypeConverter()->convertType(bitcastOp.getType()); 1619 if (!dstType) 1620 return rewriter.notifyMatchFailure(bitcastOp, "type conversion failed"); 1621 1622 // LLVM's opaque pointers do not require bitcasts. 1623 if (isa<LLVM::LLVMPointerType>(dstType)) { 1624 rewriter.replaceOp(bitcastOp, adaptor.getOperand()); 1625 return success(); 1626 } 1627 1628 rewriter.replaceOpWithNewOp<LLVM::BitcastOp>( 1629 bitcastOp, dstType, adaptor.getOperands(), bitcastOp->getAttrs()); 1630 return success(); 1631 } 1632 }; 1633 1634 //===----------------------------------------------------------------------===// 1635 // FuncOp conversion 1636 //===----------------------------------------------------------------------===// 1637 1638 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> { 1639 public: 1640 using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion; 1641 1642 LogicalResult 1643 matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor, 1644 ConversionPatternRewriter &rewriter) const override { 1645 1646 // Convert function signature. At the moment LLVMType converter is enough 1647 // for currently supported types. 1648 auto funcType = funcOp.getFunctionType(); 1649 TypeConverter::SignatureConversion signatureConverter( 1650 funcType.getNumInputs()); 1651 auto llvmType = static_cast<const LLVMTypeConverter *>(getTypeConverter()) 1652 ->convertFunctionSignature( 1653 funcType, /*isVariadic=*/false, 1654 /*useBarePtrCallConv=*/false, signatureConverter); 1655 if (!llvmType) 1656 return failure(); 1657 1658 // Create a new `LLVMFuncOp` 1659 Location loc = funcOp.getLoc(); 1660 StringRef name = funcOp.getName(); 1661 auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType); 1662 1663 // Convert SPIR-V Function Control to equivalent LLVM function attribute 1664 MLIRContext *context = funcOp.getContext(); 1665 switch (funcOp.getFunctionControl()) { 1666 case spirv::FunctionControl::Inline: 1667 newFuncOp.setAlwaysInline(true); 1668 break; 1669 case spirv::FunctionControl::DontInline: 1670 newFuncOp.setNoInline(true); 1671 break; 1672 1673 #define DISPATCH(functionControl, llvmAttr) \ 1674 case functionControl: \ 1675 newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \ 1676 break; 1677 1678 DISPATCH(spirv::FunctionControl::Pure, 1679 StringAttr::get(context, "readonly")); 1680 DISPATCH(spirv::FunctionControl::Const, 1681 StringAttr::get(context, "readnone")); 1682 1683 #undef DISPATCH 1684 1685 // Default: if `spirv::FunctionControl::None`, then no attributes are 1686 // needed. 1687 default: 1688 break; 1689 } 1690 1691 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), 1692 newFuncOp.end()); 1693 if (failed(rewriter.convertRegionTypes( 1694 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) { 1695 return failure(); 1696 } 1697 rewriter.eraseOp(funcOp); 1698 return success(); 1699 } 1700 }; 1701 1702 //===----------------------------------------------------------------------===// 1703 // ModuleOp conversion 1704 //===----------------------------------------------------------------------===// 1705 1706 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> { 1707 public: 1708 using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion; 1709 1710 LogicalResult 1711 matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor, 1712 ConversionPatternRewriter &rewriter) const override { 1713 1714 auto newModuleOp = 1715 rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName()); 1716 rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody()); 1717 1718 // Remove the terminator block that was automatically added by builder 1719 rewriter.eraseBlock(&newModuleOp.getBodyRegion().back()); 1720 rewriter.eraseOp(spvModuleOp); 1721 return success(); 1722 } 1723 }; 1724 1725 //===----------------------------------------------------------------------===// 1726 // VectorShuffleOp conversion 1727 //===----------------------------------------------------------------------===// 1728 1729 class VectorShufflePattern 1730 : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> { 1731 public: 1732 using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion; 1733 LogicalResult 1734 matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor, 1735 ConversionPatternRewriter &rewriter) const override { 1736 Location loc = op.getLoc(); 1737 auto components = adaptor.getComponents(); 1738 auto vector1 = adaptor.getVector1(); 1739 auto vector2 = adaptor.getVector2(); 1740 int vector1Size = cast<VectorType>(vector1.getType()).getNumElements(); 1741 int vector2Size = cast<VectorType>(vector2.getType()).getNumElements(); 1742 if (vector1Size == vector2Size) { 1743 rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>( 1744 op, vector1, vector2, 1745 LLVM::convertArrayToIndices<int32_t>(components)); 1746 return success(); 1747 } 1748 1749 auto dstType = getTypeConverter()->convertType(op.getType()); 1750 if (!dstType) 1751 return rewriter.notifyMatchFailure(op, "type conversion failed"); 1752 auto scalarType = cast<VectorType>(dstType).getElementType(); 1753 auto componentsArray = components.getValue(); 1754 auto *context = rewriter.getContext(); 1755 auto llvmI32Type = IntegerType::get(context, 32); 1756 Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType); 1757 for (unsigned i = 0; i < componentsArray.size(); i++) { 1758 if (!isa<IntegerAttr>(componentsArray[i])) 1759 return op.emitError("unable to support non-constant component"); 1760 1761 int indexVal = cast<IntegerAttr>(componentsArray[i]).getInt(); 1762 if (indexVal == -1) 1763 continue; 1764 1765 int offsetVal = 0; 1766 Value baseVector = vector1; 1767 if (indexVal >= vector1Size) { 1768 offsetVal = vector1Size; 1769 baseVector = vector2; 1770 } 1771 1772 Value dstIndex = rewriter.create<LLVM::ConstantOp>( 1773 loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i)); 1774 Value index = rewriter.create<LLVM::ConstantOp>( 1775 loc, llvmI32Type, 1776 rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal)); 1777 1778 auto extractOp = rewriter.create<LLVM::ExtractElementOp>( 1779 loc, scalarType, baseVector, index); 1780 targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp, 1781 extractOp, dstIndex); 1782 } 1783 rewriter.replaceOp(op, targetOp); 1784 return success(); 1785 } 1786 }; 1787 } // namespace 1788 1789 //===----------------------------------------------------------------------===// 1790 // Pattern population 1791 //===----------------------------------------------------------------------===// 1792 1793 void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter, 1794 spirv::ClientAPI clientAPI) { 1795 typeConverter.addConversion([&](spirv::ArrayType type) { 1796 return convertArrayType(type, typeConverter); 1797 }); 1798 typeConverter.addConversion([&, clientAPI](spirv::PointerType type) { 1799 return convertPointerType(type, typeConverter, clientAPI); 1800 }); 1801 typeConverter.addConversion([&](spirv::RuntimeArrayType type) { 1802 return convertRuntimeArrayType(type, typeConverter); 1803 }); 1804 typeConverter.addConversion([&](spirv::StructType type) { 1805 return convertStructType(type, typeConverter); 1806 }); 1807 } 1808 1809 void mlir::populateSPIRVToLLVMConversionPatterns( 1810 const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, 1811 spirv::ClientAPI clientAPI) { 1812 patterns.add< 1813 // Arithmetic ops 1814 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>, 1815 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>, 1816 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>, 1817 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>, 1818 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>, 1819 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>, 1820 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>, 1821 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>, 1822 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>, 1823 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>, 1824 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>, 1825 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>, 1826 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>, 1827 1828 // Bitwise ops 1829 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern, 1830 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>, 1831 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>, 1832 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>, 1833 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>, 1834 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>, 1835 NotPattern<spirv::NotOp>, 1836 1837 // Cast ops 1838 BitcastConversionPattern, 1839 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>, 1840 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>, 1841 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>, 1842 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>, 1843 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>, 1844 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>, 1845 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>, 1846 1847 // Comparison ops 1848 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>, 1849 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>, 1850 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>, 1851 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>, 1852 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>, 1853 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>, 1854 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>, 1855 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>, 1856 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>, 1857 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>, 1858 FComparePattern<spirv::FUnordGreaterThanEqualOp, 1859 LLVM::FCmpPredicate::uge>, 1860 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>, 1861 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>, 1862 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>, 1863 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>, 1864 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>, 1865 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>, 1866 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>, 1867 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>, 1868 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>, 1869 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>, 1870 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>, 1871 1872 // Constant op 1873 ConstantScalarAndVectorPattern, 1874 1875 // Control Flow ops 1876 BranchConversionPattern, BranchConditionalConversionPattern, 1877 FunctionCallPattern, LoopPattern, SelectionPattern, 1878 ErasePattern<spirv::MergeOp>, 1879 1880 // Entry points and execution mode are handled separately. 1881 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern, 1882 1883 // GLSL extended instruction set ops 1884 DirectConversionPattern<spirv::GLCeilOp, LLVM::FCeilOp>, 1885 DirectConversionPattern<spirv::GLCosOp, LLVM::CosOp>, 1886 DirectConversionPattern<spirv::GLExpOp, LLVM::ExpOp>, 1887 DirectConversionPattern<spirv::GLFAbsOp, LLVM::FAbsOp>, 1888 DirectConversionPattern<spirv::GLFloorOp, LLVM::FFloorOp>, 1889 DirectConversionPattern<spirv::GLFMaxOp, LLVM::MaxNumOp>, 1890 DirectConversionPattern<spirv::GLFMinOp, LLVM::MinNumOp>, 1891 DirectConversionPattern<spirv::GLLogOp, LLVM::LogOp>, 1892 DirectConversionPattern<spirv::GLSinOp, LLVM::SinOp>, 1893 DirectConversionPattern<spirv::GLSMaxOp, LLVM::SMaxOp>, 1894 DirectConversionPattern<spirv::GLSMinOp, LLVM::SMinOp>, 1895 DirectConversionPattern<spirv::GLSqrtOp, LLVM::SqrtOp>, 1896 InverseSqrtPattern, TanPattern, TanhPattern, 1897 1898 // Logical ops 1899 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>, 1900 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>, 1901 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>, 1902 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>, 1903 NotPattern<spirv::LogicalNotOp>, 1904 1905 // Memory ops 1906 AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>, 1907 LoadStorePattern<spirv::StoreOp>, VariablePattern, 1908 1909 // Miscellaneous ops 1910 CompositeExtractPattern, CompositeInsertPattern, 1911 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>, 1912 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>, 1913 VectorShufflePattern, 1914 1915 // Shift ops 1916 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>, 1917 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>, 1918 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>, 1919 1920 // Return ops 1921 ReturnPattern, ReturnValuePattern, 1922 1923 // Barrier ops 1924 ControlBarrierPattern<spirv::ControlBarrierOp>, 1925 ControlBarrierPattern<spirv::INTELControlBarrierArriveOp>, 1926 ControlBarrierPattern<spirv::INTELControlBarrierWaitOp>, 1927 1928 // Group reduction operations 1929 GroupReducePattern<spirv::GroupIAddOp>, 1930 GroupReducePattern<spirv::GroupFAddOp>, 1931 GroupReducePattern<spirv::GroupFMinOp>, 1932 GroupReducePattern<spirv::GroupUMinOp>, 1933 GroupReducePattern<spirv::GroupSMinOp, /*Signed=*/true>, 1934 GroupReducePattern<spirv::GroupFMaxOp>, 1935 GroupReducePattern<spirv::GroupUMaxOp>, 1936 GroupReducePattern<spirv::GroupSMaxOp, /*Signed=*/true>, 1937 GroupReducePattern<spirv::GroupNonUniformIAddOp, /*Signed=*/false, 1938 /*NonUniform=*/true>, 1939 GroupReducePattern<spirv::GroupNonUniformFAddOp, /*Signed=*/false, 1940 /*NonUniform=*/true>, 1941 GroupReducePattern<spirv::GroupNonUniformIMulOp, /*Signed=*/false, 1942 /*NonUniform=*/true>, 1943 GroupReducePattern<spirv::GroupNonUniformFMulOp, /*Signed=*/false, 1944 /*NonUniform=*/true>, 1945 GroupReducePattern<spirv::GroupNonUniformSMinOp, /*Signed=*/true, 1946 /*NonUniform=*/true>, 1947 GroupReducePattern<spirv::GroupNonUniformUMinOp, /*Signed=*/false, 1948 /*NonUniform=*/true>, 1949 GroupReducePattern<spirv::GroupNonUniformFMinOp, /*Signed=*/false, 1950 /*NonUniform=*/true>, 1951 GroupReducePattern<spirv::GroupNonUniformSMaxOp, /*Signed=*/true, 1952 /*NonUniform=*/true>, 1953 GroupReducePattern<spirv::GroupNonUniformUMaxOp, /*Signed=*/false, 1954 /*NonUniform=*/true>, 1955 GroupReducePattern<spirv::GroupNonUniformFMaxOp, /*Signed=*/false, 1956 /*NonUniform=*/true>, 1957 GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /*Signed=*/false, 1958 /*NonUniform=*/true>, 1959 GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /*Signed=*/false, 1960 /*NonUniform=*/true>, 1961 GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /*Signed=*/false, 1962 /*NonUniform=*/true>, 1963 GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /*Signed=*/false, 1964 /*NonUniform=*/true>, 1965 GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /*Signed=*/false, 1966 /*NonUniform=*/true>, 1967 GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /*Signed=*/false, 1968 /*NonUniform=*/true>>(patterns.getContext(), 1969 typeConverter); 1970 1971 patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(), 1972 typeConverter); 1973 } 1974 1975 void mlir::populateSPIRVToLLVMFunctionConversionPatterns( 1976 const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { 1977 patterns.add<FuncConversionPattern>(patterns.getContext(), typeConverter); 1978 } 1979 1980 void mlir::populateSPIRVToLLVMModuleConversionPatterns( 1981 const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { 1982 patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter); 1983 } 1984 1985 //===----------------------------------------------------------------------===// 1986 // Pre-conversion hooks 1987 //===----------------------------------------------------------------------===// 1988 1989 /// Hook for descriptor set and binding number encoding. 1990 static constexpr StringRef kBinding = "binding"; 1991 static constexpr StringRef kDescriptorSet = "descriptor_set"; 1992 void mlir::encodeBindAttribute(ModuleOp module) { 1993 auto spvModules = module.getOps<spirv::ModuleOp>(); 1994 for (auto spvModule : spvModules) { 1995 spvModule.walk([&](spirv::GlobalVariableOp op) { 1996 IntegerAttr descriptorSet = 1997 op->getAttrOfType<IntegerAttr>(kDescriptorSet); 1998 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding); 1999 // For every global variable in the module, get the ones with descriptor 2000 // set and binding numbers. 2001 if (descriptorSet && binding) { 2002 // Encode these numbers into the variable's symbolic name. If the 2003 // SPIR-V module has a name, add it at the beginning. 2004 auto moduleAndName = 2005 spvModule.getName().has_value() 2006 ? spvModule.getName()->str() + "_" + op.getSymName().str() 2007 : op.getSymName().str(); 2008 std::string name = 2009 llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName, 2010 std::to_string(descriptorSet.getInt()), 2011 std::to_string(binding.getInt())); 2012 auto nameAttr = StringAttr::get(op->getContext(), name); 2013 2014 // Replace all symbol uses and set the new symbol name. Finally, remove 2015 // descriptor set and binding attributes. 2016 if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule))) 2017 op.emitError("unable to replace all symbol uses for ") << name; 2018 SymbolTable::setSymbolName(op, nameAttr); 2019 op->removeAttr(kDescriptorSet); 2020 op->removeAttr(kBinding); 2021 } 2022 }); 2023 } 2024 } 2025