//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include "llvm/Support/Debug.h"
#include <cassert>
#include <optional>

#define DEBUG_TYPE "memref-to-spirv-pattern"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
/// It's assumed to be non-negative.
///
/// When accessing an element in the array treating it as having elements of
/// `targetBits`, multiple values are loaded at the same time. The method
/// returns the bit offset within the loaded value where the `srcIdx`-th
/// element is located. For example, if `sourceBits` equals 8 and `targetBits`
/// equals 32, the x-th element is located at (x % 4) * 8, because one i32
/// holds four 8-bit elements.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                  int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  Type type = srcIdx.getType();
  IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
  IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
  auto srcBitsValue =
      builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
  auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
  return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
}

/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
/// supported. During conversion, if a memref of an unsupported type is used,
/// loads/stores to this memref need to be modified to use a supported higher
/// bitwidth `targetBits` and to extract the required bits. For an access into
/// a 1-D array (spirv.array or spirv.rtarray), the last index is modified to
/// load the bits needed. The extraction of the actual bits needed is handled
/// separately. Note that this only works for a 1-D tensor.
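///
/// For example (an illustrative sketch, not produced verbatim), with
/// `sourceBits` = 8 and `targetBits` = 32, an access to element `i` of an i8
/// array is rewritten so that the last index becomes `i / 4` (selecting the
/// i32 containing the byte), while getOffsetForBitwidth() computes the bit
/// offset `(i % 4) * 8` used later to extract the byte.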
static Value
adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
                             spirv::AccessChainOp op, int sourceBits,
                             int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  const auto loc = op.getLoc();
  Value lastDim = op->getOperand(op.getNumOperands() - 1);
  Type type = lastDim.getType();
  IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
  auto indices = llvm::to_vector<4>(op.getIndices());
  // There are two elements if this is a 1-D tensor.
  assert(indices.size() == 2);
  indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
  Type t = typeConverter.convertType(op.getComponentPtr().getType());
  return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
}

/// Casts the given `srcBool` into an integer of `dstType`.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
                            OpBuilder &builder) {
  assert(srcBool.getType().isInteger(1));
  if (dstType.isInteger(1))
    return srcBool;
  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
  return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
                                               zero);
}

/// Returns the `targetBits`-bit value shifted by the given `offset`, cast to
/// the destination type, and masked.
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                        OpBuilder &builder) {
  IntegerType dstType = cast<IntegerType>(mask.getType());
  int targetBits = static_cast<int>(dstType.getWidth());
  int valueBits = value.getType().getIntOrFloatBitWidth();
  assert(valueBits <= targetBits);

  if (valueBits == 1) {
    value = castBoolToIntN(loc, value, dstType, builder);
  } else {
    if (valueBits < targetBits) {
      value = builder.create<spirv::UConvertOp>(
          loc, builder.getIntegerType(targetBits), value);
    }

    value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
  }
  return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
                                                         value, offset);
}

/// Returns true if the allocations of memref `type` generated from `allocOp`
/// can be lowered to SPIR-V.
static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
      return false;
  } else if (isa<memref::AllocaOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Function)
      return false;
  } else {
    return false;
  }

  // Currently only support static shape and int or float or vector of int or
  // float element type.
  if (!type.hasStaticShape())
    return false;

  Type elementType = type.getElementType();
  if (auto vecType = dyn_cast<VectorType>(elementType))
    elementType = vecType.getElementType();
  return elementType.isIntOrFloat();
}

/// Returns the scope to use for atomic operations used for emulating store
/// operations of unsupported integer bitwidths, based on the memref type.
/// Returns std::nullopt on failure.
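///
/// For example, a memref in the StorageBuffer storage class may be visible to
/// invocations across the whole device, so Device scope is used; Workgroup
/// memory is only shared within a workgroup, so Workgroup scope suffices.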
static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
  if (!sc)
    return std::nullopt;
  switch (sc.getValue()) {
  case spirv::StorageClass::StorageBuffer:
    return spirv::Scope::Device;
  case spirv::StorageClass::Workgroup:
    return spirv::Scope::Workgroup;
  default:
    break;
  }
  return std::nullopt;
}

/// Casts the given `srcInt` into a boolean value.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
  if (srcInt.getType().isInteger(1))
    return srcInt;

  auto zero = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
  return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, zero);
}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires ConversionPattern. DRR generates
// normal RewritePatterns.

namespace {

/// Converts memref.alloca to SPIR-V Function variables.
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
public:
  using OpConversionPattern<memref::AllocaOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts an allocation operation to SPIR-V. Currently only supports
/// lowering to Workgroup memory when the size is constant. Note that this
/// pattern needs to be applied in a pass that runs at least at spirv.module
/// scope since it will add global variables into the spirv.module.
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
  using OpConversionPattern<memref::AllocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
class AtomicRMWOpPattern final
    : public OpConversionPattern<memref::AtomicRMWOp> {
public:
  using OpConversionPattern<memref::AtomicRMWOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Removes a deallocation if it is a supported allocation. Currently only
/// removes deallocations if the memory space is workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
  using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain.
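/// Handles non-signless-integer element types (e.g. floating-point); loads of
/// signless integers, including the emulation of unsupported bitwidths, are
/// handled by IntLoadOpPattern above.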
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.memory_space_cast to the appropriate SPIR-V cast
/// operations.
class MemorySpaceCastOpPattern final
    : public OpConversionPattern<memref::MemorySpaceCastOp> {
public:
  using OpConversionPattern<memref::MemorySpaceCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ReinterpretCastPattern final
    : public OpConversionPattern<memref::ReinterpretCastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class CastPattern final : public OpConversionPattern<memref::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.getSource();
    Type srcType = src.getType();

    const TypeConverter *converter = getTypeConverter();
    Type dstType = converter->convertType(op.getType());
    if (srcType != dstType)
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "types don't match: " << srcType << " and " << dstType;
      });

    rewriter.replaceOp(op, src);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = allocaOp.getType();
  if (!isAllocationSupported(allocaOp, allocType))
    return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
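  // For example (illustrative; the exact converted type depends on the
  // converter options and target environment), memref<4x5xf32,
  // #spirv.storage_class<Function>> may become
  // !spirv.ptr<!spirv.array<20 x f32>, Function>.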
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");

  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
                                                 spirv::StorageClass::Function,
                                                 /*initializer=*/nullptr);
  return success();
}

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = operation.getType();
  if (!isAllocationSupported(operation, allocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(operation, "type conversion failed");

  // Insert spirv.GlobalVariable for this allocation.
  Operation *parent =
      SymbolTable::getNearestSymbolTable(operation->getParentOp());
  if (!parent)
    return failure();
  Location loc = operation.getLoc();
  spirv::GlobalVariableOp varOp;
  {
    OpBuilder::InsertionGuard guard(rewriter);
    Block &entryBlock = *parent->getRegion(0).begin();
    rewriter.setInsertionPointToStart(&entryBlock);
    auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
    std::string varName =
        std::string("__workgroup_mem__") +
        std::to_string(std::distance(varOps.begin(), varOps.end()));
    varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
                                                     /*initializer=*/nullptr);
  }

  // Get pointer to global variable at the current scope.
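  // The net effect is a module-level variable plus a local address (a sketch,
  // assuming the first such allocation in the module):
  //
  //   spirv.GlobalVariable @__workgroup_mem__0 : !spirv.ptr<..., Workgroup>
  //   ...
  //   %0 = spirv.mlir.addressof @__workgroup_mem__0
  //        : !spirv.ptr<..., Workgroup>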
  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
  return success();
}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (isa<FloatType>(atomicOp.getType()))
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unimplemented floating-point case");

  auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unsupported memref memory space");

  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type resultType = typeConverter.convertType(atomicOp.getType());
  if (!resultType)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "failed to convert result type");

  auto loc = atomicOp.getLoc();
  Value ptr =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!ptr)
    return failure();

#define ATOMIC_CASE(kind, spirvOp)                                             \
  case arith::AtomicRMWKind::kind:                                             \
    rewriter.replaceOpWithNewOp<spirv::spirvOp>(                               \
        atomicOp, resultType, ptr, *scope,                                     \
        spirv::MemorySemantics::AcquireRelease, adaptor.getValue());           \
    break

  switch (atomicOp.getKind()) {
    ATOMIC_CASE(addi, AtomicIAddOp);
    ATOMIC_CASE(maxs, AtomicSMaxOp);
    ATOMIC_CASE(maxu, AtomicUMaxOp);
    ATOMIC_CASE(mins, AtomicSMinOp);
    ATOMIC_CASE(minu, AtomicUMinOp);
    ATOMIC_CASE(ori, AtomicOrOp);
    ATOMIC_CASE(andi, AtomicAndOp);
  default:
    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
  }

#undef ATOMIC_CASE

  return success();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
  if (!isAllocationSupported(operation, deallocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
  rewriter.eraseOp(operation);
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

struct MemoryRequirements {
  spirv::MemoryAccessAttr memoryAccess;
  IntegerAttr alignment;
};

/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
/// any.
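///
/// For example, an f32 accessed through a PhysicalStorageBuffer pointer gets
/// MemoryAccess<Aligned> with an alignment of 4 (the scalar's size in bytes);
/// a nontemporal access additionally sets the Nontemporal flag. Pointers in
/// other storage classes need no mandatory attributes.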
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
  MLIRContext *ctx = accessedPtr.getContext();

  auto memoryAccess = spirv::MemoryAccess::None;
  if (isNontemporal) {
    memoryAccess = spirv::MemoryAccess::Nontemporal;
  }

  auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
  if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
    if (memoryAccess == spirv::MemoryAccess::None) {
      return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
    }
    return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
                              IntegerAttr{}};
  }

  // PhysicalStorageBuffers require the `Aligned` attribute.
  auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
  if (!pointeeType)
    return failure();

  // For scalar types, the alignment is determined by their size.
  std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
  if (!sizeInBytes.has_value())
    return failure();

  memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
  auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
  return MemoryRequirements{memAccessAttr, alignment};
}

/// Given an accessed SPIR-V pointer and the original memref load/store
/// `memAccess` op, calculates the alignment requirements, if any. Takes into
/// account the alignment attributes applied to the load/store op.
template <class LoadOrStoreOp>
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
  static_assert(
      llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
      "Must be called on either memref::LoadOp or memref::StoreOp");

  Operation *memrefAccessOp = loadOrStoreOp.getOperation();
  auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
      spirv::attributeName<spirv::MemoryAccess>());
  auto memrefAlignment =
      memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
  if (memrefMemAccess && memrefAlignment)
    return MemoryRequirements{memrefMemAccess, memrefAlignment};

  return calculateMemoryRequirements(accessedPtr,
                                     loadOrStoreOp.getNontemporal());
}

LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  auto loc = loadOp.getLoc();
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  Type dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = arrayType.getElementType();
    else
      dstType = pointeeType;
  } else {
    // For Vulkan we need to extract the element from the wrapping struct and
    // array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = arrayType.getElementType();
    else
      dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
  }
  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loaded value
  // directly.
  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          loadOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
                                                   memoryAccess, alignment);
    if (isBool)
      loadVal = castIntNToBool(loc, loadVal, rewriter);
    rewriter.replaceOp(loadOp, loadVal);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Assume that getElementPtr() works linearly. If it's a scalar, the method
  // still returns a linearized access. If the access is not linearized, there
  // will be offset issues.
  assert(accessChainOp.getIndices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
                                                   memoryAccess, alignment);

  // Shift the bits to the rightmost.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract the corresponding bits.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result =
      rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);

  // Apply sign extension on the loaded value unconditionally. The signedness
  // semantics are carried in the operator itself; we rely on other patterns to
  // handle the casting.
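  // For example, with srcBits = 8 and dstBits = 32, a loaded byte 0x80 (i.e.
  // -128) that was masked to 0x00000080 is shifted left by 24 to 0x80000000,
  // then arithmetically shifted right by 24, yielding 0xFFFFFF80 (-128 as
  // i32).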
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
  Value shiftValue =
      rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
  result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
                                                            result, shiftValue);
  result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, dstType, result, shiftValue);

  rewriter.replaceOp(loadOp, result);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return failure();
  Value loadPtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), loadOp.getLoc(), rewriter);

  if (!loadPtr)
    return failure();

  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
                                             alignment);
  return success();
}

LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp,
                                       "element type is not a signless int");

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to convert element pointer type");

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(storeOp,
                                       "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  IntegerType dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(pointeeType);
  } else {
    // For Vulkan we need to extract the element from the wrapping struct and
    // array.
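    // For Vulkan, e.g., memref<8xi32, #spirv.storage_class<StorageBuffer>>
    // typically converts to a pointer such as
    //   !spirv.ptr<!spirv.struct<(!spirv.array<8 x i32, stride=4> [0])>,
    //              StorageBuffer>
    // so the integer element type sits two levels down (illustrative; the
    // exact form depends on the converter options).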
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(
          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
  }

  if (!dstType)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine destination element type");

  int dstBits = static_cast<int>(dstType.getWidth());
  assert(dstBits % srcBits == 0);

  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          storeOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value storeVal = adaptor.getValue();
    if (isBool)
      storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
                                                memoryAccess, alignment);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Since multiple threads may be writing concurrently, the emulation is done
  // with atomic operations. E.g., if the stored value is i8, rewrite the
  // StoreOp to:
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) set 8 bits in the loaded value
  // 4) store the 32-bit value back
  //
  // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
  // loaded 32-bit value and the shifted 8-bit store value) as another atomic
  // step.
  assert(accessChainOp.getIndices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // an i32, 0xFFFF00FF is created.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
      loc, dstType, mask, offset);
  clearBitsMask =
      rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);

  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");

  Value result = rewriter.create<spirv::AtomicAndOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      clearBitsMask);
  result = rewriter.create<spirv::AtomicOrOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      storeVal);

  // The results of the atomic ops are unused. Since the replacement ops are
  // already inserted, we can simply remove the original StoreOp. Note that
  // rewriter.replaceOp() doesn't work here because it requires the number of
  // results to match.
  rewriter.eraseOp(storeOp);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//

LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
    memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = addrCastOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  if (!typeConverter.allows(spirv::Capability::Kernel))
    return rewriter.notifyMatchFailure(
        loc, "address space casts require kernel capability");

  auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
  if (!sourceType)
    return rewriter.notifyMatchFailure(
        loc, "SPIR-V lowering requires ranked memref types");
  auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());

  auto sourceStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
  if (!sourceStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
      diag << "source address space " << sourceType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });
  auto resultStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
  if (!resultStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
      diag << "result address space " << resultType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });

  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();

  Value result = adaptor.getSource();
  Type resultPtrType = typeConverter.convertType(resultType);
  if (!resultPtrType)
    return rewriter.notifyMatchFailure(addrCastOp,
                                       "failed to convert memref type");

  Type genericPtrType = resultPtrType;
  // SPIR-V doesn't have a general address space cast operation. Instead, it has
  // conversions to and from generic pointers. To implement the general case,
  // we use specific-to-generic conversions when the source class is not
  // generic. Then, when the result storage class is not generic, we convert the
  // generic pointer (either the input or an intermediate result) to that
  // class. This also means that we'll need the intermediate generic pointer
  // type if neither the source nor the destination has it.
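  // For example (a sketch), a cast from CrossWorkgroup to Workgroup would
  // emit:
  //   %g = spirv.PtrCastToGeneric %src
  //        : !spirv.ptr<f32, CrossWorkgroup> to !spirv.ptr<f32, Generic>
  //   %r = spirv.GenericCastToPtr %g
  //        : !spirv.ptr<f32, Generic> to !spirv.ptr<f32, Workgroup>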
  if (sourceSc != spirv::StorageClass::Generic &&
      resultSc != spirv::StorageClass::Generic) {
    Type intermediateType =
        MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
                        sourceType.getLayout(),
                        rewriter.getAttr<spirv::StorageClassAttr>(
                            spirv::StorageClass::Generic));
    genericPtrType = typeConverter.convertType(intermediateType);
  }
  if (sourceSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
  }
  if (resultSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
  }
  rewriter.replaceOp(addrCastOp, result);
  return success();
}

LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp, "signless int");
  auto storePtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), storeOp.getLoc(), rewriter);

  if (!storePtr)
    return rewriter.notifyMatchFailure(storeOp, "type conversion failed");

  auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::StoreOp>(
      storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
  return success();
}

LogicalResult ReinterpretCastPattern::matchAndRewrite(
    memref::ReinterpretCastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value src = adaptor.getSource();
  auto srcType = dyn_cast<spirv::PointerType>(src.getType());

  if (!srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid src type " << src.getType();
    });

  const TypeConverter *converter = getTypeConverter();

  auto dstType = converter->convertType<spirv::PointerType>(op.getType());
  if (dstType != srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid dst type " << op.getType();
    });

  OpFoldResult offset =
      getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
          .front();
  if (isConstantIntValue(offset, 0)) {
    rewriter.replaceOp(op, src);
    return success();
  }

  Type intType = converter->convertType(rewriter.getIndexType());
  if (!intType)
    return rewriter.notifyMatchFailure(op, "failed to convert index type");

  Location loc = op.getLoc();
  auto offsetValue = [&]() -> Value {
    if (auto val = dyn_cast<Value>(offset))
      return val;

    int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
    Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
    return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
  }();

  rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
      op, src, offsetValue, std::nullopt);
  return success();
}

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

namespace mlir {
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                   RewritePatternSet &patterns) {
  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
               DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
               LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
               ReinterpretCastPattern, CastPattern>(typeConverter,
                                                    patterns.getContext());
}
} // namespace mlir