//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Visitors.h"
#include "llvm/Support/Debug.h"
#include <cassert>
#include <optional>

#define DEBUG_TYPE "memref-to-spirv-pattern"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
/// It's assumed to be non-negative.
///
/// When accessing an element in the array treating it as having elements of
/// `targetBits`, multiple values are loaded at the same time. This method
/// returns the offset of `srcIdx` within the loaded value. For example, if
/// `sourceBits` is 8 and `targetBits` is 32, the x-th element is located at
/// bit (x % 4) * 8, because one i32 holds four elements of 8 bits each.
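///
/// For example, with `sourceBits` = 8 and `targetBits` = 32, the ops emitted
/// below are roughly (a sketch; %idx stands for `srcIdx`):
///   %four   = spirv.Constant 4 : i32
///   %eight  = spirv.Constant 8 : i32
///   %rem    = spirv.UMod %idx, %four : i32
///   %offset = spirv.IMul %rem, %eight : i32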
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                  int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  Type type = srcIdx.getType();
  IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
  IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
  auto srcBitsValue =
      builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
  auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
  return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
}

/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
/// supported. During conversion, if a memref of an unsupported type is used,
/// loads/stores to this memref need to be modified to use a supported higher
/// bitwidth `targetBits` and to extract the required bits. For an access to a
/// 1-D array (spirv.array or spirv.rtarray), the last index is modified to load
/// the bits needed. The extraction of the actual bits needed is handled
/// separately. Note that this only works for a 1-D tensor.
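///
/// For example, for i8 data accessed through an i32 pointer, the last index is
/// divided by 32/8 = 4; roughly (a sketch; %i is the original last index):
///   %four = spirv.Constant 4 : i32
///   %idx  = spirv.SDiv %i, %four : i32
///   %ptr  = spirv.AccessChain %base[%zero, %idx] : ...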
static Value
adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
                             spirv::AccessChainOp op, int sourceBits,
                             int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  const auto loc = op.getLoc();
  Value lastDim = op->getOperand(op.getNumOperands() - 1);
  Type type = lastDim.getType();
  IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
  auto indices = llvm::to_vector<4>(op.getIndices());
  // There are two elements if this is a 1-D tensor.
  assert(indices.size() == 2);
  indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
  Type t = typeConverter.convertType(op.getComponentPtr().getType());
  return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
}

/// Casts the given `srcBool` into an integer of `dstType`.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
                            OpBuilder &builder) {
  assert(srcBool.getType().isInteger(1));
  if (dstType.isInteger(1))
    return srcBool;
  Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
  Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
  return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
                                               zero);
}

/// Returns the `targetBits`-bit value: the given `value` cast to the
/// destination type, masked, and shifted left by the given `offset`.
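///
/// For example, placing an i8 `value` into the second byte of an i32 word is
/// roughly (a sketch; here `mask` is 255 : i32 and `offset` is 8 : i32):
///   %ext = spirv.UConvert %value : i8 to i32
///   %and = spirv.BitwiseAnd %ext, %mask : i32
///   %res = spirv.ShiftLeftLogical %and, %offset : i32, i32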
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                        OpBuilder &builder) {
  IntegerType dstType = cast<IntegerType>(mask.getType());
  int targetBits = static_cast<int>(dstType.getWidth());
  int valueBits = value.getType().getIntOrFloatBitWidth();
  assert(valueBits <= targetBits);

  if (valueBits == 1) {
    value = castBoolToIntN(loc, value, dstType, builder);
  } else {
    if (valueBits < targetBits) {
      value = builder.create<spirv::UConvertOp>(
          loc, builder.getIntegerType(targetBits), value);
    }

    value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
  }
  return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
                                                         value, offset);
}

/// Returns true if the allocations of memref `type` generated from `allocOp`
/// can be lowered to SPIR-V.
static bool isAllocationSupported(Operation *allocOp, MemRefType type) {
  if (isa<memref::AllocOp, memref::DeallocOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Workgroup)
      return false;
  } else if (isa<memref::AllocaOp>(allocOp)) {
    auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
    if (!sc || sc.getValue() != spirv::StorageClass::Function)
      return false;
  } else {
    return false;
  }

  // Currently we only support static shapes and element types that are int,
  // float, or vector of int or float.
  if (!type.hasStaticShape())
    return false;

  Type elementType = type.getElementType();
  if (auto vecType = dyn_cast<VectorType>(elementType))
    elementType = vecType.getElementType();
  return elementType.isIntOrFloat();
}

/// Returns the scope to use for atomic operations used for emulating store
/// operations of unsupported integer bitwidths, based on the memref
/// type. Returns std::nullopt on failure.
static std::optional<spirv::Scope> getAtomicOpScope(MemRefType type) {
  auto sc = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
  // Memory spaces that are not SPIR-V storage classes cannot be handled.
  if (!sc)
    return std::nullopt;
  switch (sc.getValue()) {
  case spirv::StorageClass::StorageBuffer:
    return spirv::Scope::Device;
  case spirv::StorageClass::Workgroup:
    return spirv::Scope::Workgroup;
  default:
    break;
  }
  return std::nullopt;
}

/// Casts the given `srcInt` into a boolean value.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
  if (srcInt.getType().isInteger(1))
    return srcInt;

  auto zero = spirv::ConstantOp::getZero(srcInt.getType(), loc, builder);
  return builder.createOrFold<spirv::INotEqualOp>(loc, srcInt, zero);
}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires ConversionPattern. DRR generates
// normal RewritePatterns.

namespace {

/// Converts memref.alloca to SPIR-V Function variables.
class AllocaOpPattern final : public OpConversionPattern<memref::AllocaOp> {
public:
  using OpConversionPattern<memref::AllocaOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts an allocation operation to SPIR-V. Currently only supports lowering
/// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spirv.module scope since it
/// will add global variables into the spirv.module.
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
  using OpConversionPattern<memref::AllocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.atomic_rmw operations to SPIR-V atomic operations.
class AtomicRMWOpPattern final
    : public OpConversionPattern<memref::AtomicRMWOp> {
public:
  using OpConversionPattern<memref::AtomicRMWOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Removes a deallocation if it is for a supported allocation. Currently only
/// removes deallocations when the memory space is workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
  using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain on integers.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spirv.Load + spirv.AccessChain.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.memory_space_cast to the appropriate spirv cast operations.
class MemorySpaceCastOpPattern final
    : public OpConversionPattern<memref::MemorySpaceCastOp> {
public:
  using OpConversionPattern<memref::MemorySpaceCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spirv.Store.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ReinterpretCastPattern final
    : public OpConversionPattern<memref::ReinterpretCastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class CastPattern final : public OpConversionPattern<memref::CastOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value src = adaptor.getSource();
    Type srcType = src.getType();

    const TypeConverter *converter = getTypeConverter();
    Type dstType = converter->convertType(op.getType());
    if (srcType != dstType)
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "types don't match: " << srcType << " and " << dstType;
      });

    rewriter.replaceOp(op, src);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// AllocaOp
//===----------------------------------------------------------------------===//

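// A sketch of the rewrite performed below (the exact converted pointer type
// depends on the type converter):
//   %0 = memref.alloca() : memref<4xf32, #spirv.storage_class<Function>>
// becomes:
//   %0 = spirv.Variable : !spirv.ptr<..., Function>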
LogicalResult
AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = allocaOp.getType();
  if (!isAllocationSupported(allocaOp, allocType))
    return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(allocaOp, "type conversion failed");

  rewriter.replaceOpWithNewOp<spirv::VariableOp>(allocaOp, spirvType,
                                                 spirv::StorageClass::Function,
                                                 /*initializer=*/nullptr);
  return success();
}

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

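// A sketch of the rewrite performed below (names and types illustrative):
//   %0 = memref.alloc() : memref<4xf32, #spirv.storage_class<Workgroup>>
// becomes a module-level global plus a reference to it:
//   spirv.GlobalVariable @__workgroup_mem__0 : !spirv.ptr<..., Workgroup>
//   %0 = spirv.mlir.addressof @__workgroup_mem__0 : !spirv.ptr<..., Workgroup>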
LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = operation.getType();
  if (!isAllocationSupported(operation, allocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");

  // Get the SPIR-V type for the allocation.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(operation, "type conversion failed");

  // Insert spirv.GlobalVariable for this allocation.
  Operation *parent =
      SymbolTable::getNearestSymbolTable(operation->getParentOp());
  if (!parent)
    return failure();
  Location loc = operation.getLoc();
  spirv::GlobalVariableOp varOp;
  {
    OpBuilder::InsertionGuard guard(rewriter);
    Block &entryBlock = *parent->getRegion(0).begin();
    rewriter.setInsertionPointToStart(&entryBlock);
    auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
    std::string varName =
        std::string("__workgroup_mem__") +
        std::to_string(std::distance(varOps.begin(), varOps.end()));
    varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
                                                     /*initializer=*/nullptr);
  }

  // Get pointer to global variable at the current scope.
  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
  return success();
}

//===----------------------------------------------------------------------===//
// AtomicRMWOp
//===----------------------------------------------------------------------===//

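// A sketch of the rewrite performed below for a storage-buffer memref (types
// illustrative):
//   %0 = memref.atomic_rmw addi %val, %mem[%i] : (i32, memref<8xi32, ...>) -> i32
// becomes:
//   %ptr = spirv.AccessChain %mem[...]
//   %0 = spirv.AtomicIAdd <Device> <AcquireRelease> %ptr, %val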
LogicalResult
AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp,
                                    OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (isa<FloatType>(atomicOp.getType()))
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unimplemented floating-point case");

  auto memrefType = cast<MemRefType>(atomicOp.getMemref().getType());
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "unsupported memref memory space");

  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Type resultType = typeConverter.convertType(atomicOp.getType());
  if (!resultType)
    return rewriter.notifyMatchFailure(atomicOp,
                                       "failed to convert result type");

  auto loc = atomicOp.getLoc();
  Value ptr =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!ptr)
    return failure();

#define ATOMIC_CASE(kind, spirvOp)                                             \
  case arith::AtomicRMWKind::kind:                                             \
    rewriter.replaceOpWithNewOp<spirv::spirvOp>(                               \
        atomicOp, resultType, ptr, *scope,                                     \
        spirv::MemorySemantics::AcquireRelease, adaptor.getValue());           \
    break

  switch (atomicOp.getKind()) {
    ATOMIC_CASE(addi, AtomicIAddOp);
    ATOMIC_CASE(maxs, AtomicSMaxOp);
    ATOMIC_CASE(maxu, AtomicUMaxOp);
    ATOMIC_CASE(mins, AtomicSMinOp);
    ATOMIC_CASE(minu, AtomicUMinOp);
    ATOMIC_CASE(ori, AtomicOrOp);
    ATOMIC_CASE(andi, AtomicAndOp);
  default:
    return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind");
  }

#undef ATOMIC_CASE

  return success();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  MemRefType deallocType = cast<MemRefType>(operation.getMemref().getType());
  if (!isAllocationSupported(operation, deallocType))
    return rewriter.notifyMatchFailure(operation, "unhandled allocation type");
  rewriter.eraseOp(operation);
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

struct MemoryRequirements {
  spirv::MemoryAccessAttr memoryAccess;
  IntegerAttr alignment;
};

/// Given an accessed SPIR-V pointer, calculates its alignment requirements, if
/// any.
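///
/// For example (a sketch): for a pointer of type
///   !spirv.ptr<f32, PhysicalStorageBuffer>
/// this returns the equivalent of {Aligned, /*alignment=*/4}, since the
/// alignment of a scalar pointee is derived from its size and f32 is 4 bytes
/// wide.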
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, bool isNontemporal) {
  MLIRContext *ctx = accessedPtr.getContext();

  auto memoryAccess = spirv::MemoryAccess::None;
  if (isNontemporal) {
    memoryAccess = spirv::MemoryAccess::Nontemporal;
  }

  auto ptrType = cast<spirv::PointerType>(accessedPtr.getType());
  if (ptrType.getStorageClass() != spirv::StorageClass::PhysicalStorageBuffer) {
    if (memoryAccess == spirv::MemoryAccess::None) {
      return MemoryRequirements{spirv::MemoryAccessAttr{}, IntegerAttr{}};
    }
    return MemoryRequirements{spirv::MemoryAccessAttr::get(ctx, memoryAccess),
                              IntegerAttr{}};
  }

  // PhysicalStorageBuffers require the `Aligned` attribute.
  auto pointeeType = dyn_cast<spirv::ScalarType>(ptrType.getPointeeType());
  if (!pointeeType)
    return failure();

  // For scalar types, the alignment is determined by their size.
  std::optional<int64_t> sizeInBytes = pointeeType.getSizeInBytes();
  if (!sizeInBytes.has_value())
    return failure();

  memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned;
  auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess);
  auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), *sizeInBytes);
  return MemoryRequirements{memAccessAttr, alignment};
}

/// Given an accessed SPIR-V pointer and the original memref load/store
/// `memAccess` op, calculates the alignment requirements, if any. Takes into
/// account the alignment attributes applied to the load/store op.
template <class LoadOrStoreOp>
static FailureOr<MemoryRequirements>
calculateMemoryRequirements(Value accessedPtr, LoadOrStoreOp loadOrStoreOp) {
  static_assert(
      llvm::is_one_of<LoadOrStoreOp, memref::LoadOp, memref::StoreOp>::value,
      "Must be called on either memref::LoadOp or memref::StoreOp");

  Operation *memrefAccessOp = loadOrStoreOp.getOperation();
  auto memrefMemAccess = memrefAccessOp->getAttrOfType<spirv::MemoryAccessAttr>(
      spirv::attributeName<spirv::MemoryAccess>());
  auto memrefAlignment =
      memrefAccessOp->getAttrOfType<IntegerAttr>("alignment");
  if (memrefMemAccess && memrefAlignment)
    return MemoryRequirements{memrefMemAccess, memrefAlignment};

  return calculateMemoryRequirements(accessedPtr,
                                     loadOrStoreOp.getNontemporal());
}

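// A sketch of the narrow-bitwidth load emulation below, for an i8 load through
// an i32 pointer on the Vulkan path (values illustrative):
//   %w = spirv.Load "StorageBuffer" %adjusted_ptr : i32
//   %s = spirv.ShiftRightArithmetic %w, %offset : i32, i32
//   %v = spirv.BitwiseAnd %s, %mask : i32  // mask = 255
// followed by a ShiftLeftLogical/ShiftRightArithmetic pair to sign-extend %v.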
LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  auto loc = loadOp.getLoc();
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  Type dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = arrayType.getElementType();
    else
      dstType = pointeeType;
  } else {
    // For Vulkan we need to extract the element from the wrapping struct and
    // array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = arrayType.getElementType();
    else
      dstType = cast<spirv::RuntimeArrayType>(structElemType).getElementType();
  }
  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loaded value
  // directly.
  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, loadOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          loadOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value loadVal = rewriter.create<spirv::LoadOp>(loc, accessChain,
                                                   memoryAccess, alignment);
    if (isBool)
      loadVal = castIntNToBool(loc, loadVal, rewriter);
    rewriter.replaceOp(loadOp, loadVal);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Assume that getElementPtr() linearizes the access. Even for a scalar, the
  // method still returns a linearized access chain. If the access is not
  // linearized, there will be offset issues.
  assert(accessChainOp.getIndices().size() == 2);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  auto memoryRequirements = calculateMemoryRequirements(adjustedPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  Value spvLoadOp = rewriter.create<spirv::LoadOp>(loc, dstType, adjustedPtr,
                                                   memoryAccess, alignment);

  // Shift the bits to the rightmost position.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract the corresponding bits.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result =
      rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);

  // Apply sign extension to the loaded value unconditionally. The signedness
  // semantics are carried in the operator itself; we rely on other patterns to
  // handle the casting.
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
  Value shiftValue =
      rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
  result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
                                                            result, shiftValue);
  result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
      loc, dstType, result, shiftValue);

  rewriter.replaceOp(loadOp, result);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(loadOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return failure();
  Value loadPtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), loadOp.getLoc(), rewriter);

  if (!loadPtr)
    return failure();

  auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        loadOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr, memoryAccess,
                                             alignment);
  return success();
}

LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (!memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp,
                                       "element type is not a signless int");

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  Value accessChain =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
                           adaptor.getIndices(), loc, rewriter);

  if (!accessChain)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to convert element pointer type");

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  auto pointerType = typeConverter.convertType<spirv::PointerType>(memrefType);
  if (!pointerType)
    return rewriter.notifyMatchFailure(storeOp,
                                       "failed to convert memref type");

  Type pointeeType = pointerType.getPointeeType();
  IntegerType dstType;
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    if (auto arrayType = dyn_cast<spirv::ArrayType>(pointeeType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(pointeeType);
  } else {
    // For Vulkan we need to extract the element from the wrapping struct and
    // array.
    Type structElemType =
        cast<spirv::StructType>(pointeeType).getElementType(0);
    if (auto arrayType = dyn_cast<spirv::ArrayType>(structElemType))
      dstType = dyn_cast<IntegerType>(arrayType.getElementType());
    else
      dstType = dyn_cast<IntegerType>(
          cast<spirv::RuntimeArrayType>(structElemType).getElementType());
  }

  if (!dstType)
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine destination element type");

  int dstBits = static_cast<int>(dstType.getWidth());
  assert(dstBits % srcBits == 0);

  if (srcBits == dstBits) {
    auto memoryRequirements = calculateMemoryRequirements(accessChain, storeOp);
    if (failed(memoryRequirements))
      return rewriter.notifyMatchFailure(
          storeOp, "failed to determine memory requirements");

    auto [memoryAccess, alignment] = *memoryRequirements;
    Value storeVal = adaptor.getValue();
    if (isBool)
      storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
                                                memoryAccess, alignment);
    return success();
  }

  // Bitcasting is currently unsupported for Kernel capability /
  // spirv.PtrAccessChain.
  if (typeConverter.allows(spirv::Capability::Kernel))
    return failure();

  auto accessChainOp = accessChain.getDefiningOp<spirv::AccessChainOp>();
  if (!accessChainOp)
    return failure();

  // Since multiple threads may be writing into adjacent bits of the same word,
  // the emulation is done with atomic operations. E.g., if the stored value is
  // i8, rewrite the StoreOp to:
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) set 8 bits in the loaded value
  // 4) store the 32-bit value back
  //
  // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the
  // loaded 32-bit value and the shifted 8-bit store value) as another atomic
  // step.
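  //
  // For example, storing an i8 into the second byte of an i32 word becomes
  // roughly (a sketch; the mask value is illustrative):
  //   %old = spirv.AtomicAnd <Device> <AcquireRelease> %ptr, %c0xFFFF00FF
  //   %new = spirv.AtomicOr <Device> <AcquireRelease> %ptr, %shifted_value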
  assert(accessChainOp.getIndices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // i32, 0xFFFF00FF is created.
  Value mask = rewriter.createOrFold<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
      loc, dstType, mask, offset);
  clearBitsMask =
      rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);

  Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return rewriter.notifyMatchFailure(storeOp, "atomic scope not available");

  Value result = rewriter.create<spirv::AtomicAndOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      clearBitsMask);
  result = rewriter.create<spirv::AtomicOrOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      storeVal);

  // The result of the AtomicOrOp is not used, and the op has already been
  // inserted, so we can simply erase the original StoreOp. Note that
  // rewriter.replaceOp() doesn't work here because it requires the numbers of
  // results to match.
  rewriter.eraseOp(storeOp);

  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

//===----------------------------------------------------------------------===//
// MemorySpaceCastOp
//===----------------------------------------------------------------------===//

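// A sketch of the rewrite performed below for a cast between two non-Generic
// storage classes (types illustrative):
//   %1 = memref.memory_space_cast %0
//       : memref<4xf32, #spirv.storage_class<CrossWorkgroup>>
//       to memref<4xf32, #spirv.storage_class<Function>>
// becomes, via an intermediate Generic pointer:
//   %g = spirv.PtrCastToGeneric %0 : ... to !spirv.ptr<..., Generic>
//   %1 = spirv.GenericCastToPtr %g : ... to !spirv.ptr<..., Function>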
LogicalResult MemorySpaceCastOpPattern::matchAndRewrite(
    memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = addrCastOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  if (!typeConverter.allows(spirv::Capability::Kernel))
    return rewriter.notifyMatchFailure(
        loc, "address space casts require kernel capability");

  auto sourceType = dyn_cast<MemRefType>(addrCastOp.getSource().getType());
  if (!sourceType)
    return rewriter.notifyMatchFailure(
        loc, "SPIR-V lowering requires ranked memref types");
  auto resultType = cast<MemRefType>(addrCastOp.getResult().getType());

  auto sourceStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(sourceType.getMemorySpace());
  if (!sourceStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) {
      diag << "source address space " << sourceType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });
  auto resultStorageClassAttr =
      dyn_cast_or_null<spirv::StorageClassAttr>(resultType.getMemorySpace());
  if (!resultStorageClassAttr)
    return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) {
      diag << "result address space " << resultType.getMemorySpace()
           << " must be a SPIR-V storage class";
    });

  spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue();
  spirv::StorageClass resultSc = resultStorageClassAttr.getValue();

  Value result = adaptor.getSource();
  Type resultPtrType = typeConverter.convertType(resultType);
  if (!resultPtrType)
    return rewriter.notifyMatchFailure(addrCastOp,
                                       "failed to convert memref type");

  Type genericPtrType = resultPtrType;
  // SPIR-V doesn't have a general address space cast operation. Instead, it has
  // conversions to and from generic pointers. To implement the general case,
  // we use specific-to-generic conversions when the source class is not
  // generic. Then, when the result storage class is not generic, we convert the
  // generic pointer (either the input or an intermediate result) to that
  // class. This also means that we'll need the intermediate generic pointer
  // type if neither the source nor the destination has it.
  if (sourceSc != spirv::StorageClass::Generic &&
      resultSc != spirv::StorageClass::Generic) {
    Type intermediateType =
        MemRefType::get(sourceType.getShape(), sourceType.getElementType(),
                        sourceType.getLayout(),
                        rewriter.getAttr<spirv::StorageClassAttr>(
                            spirv::StorageClass::Generic));
    genericPtrType = typeConverter.convertType(intermediateType);
  }
  if (sourceSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::PtrCastToGenericOp>(loc, genericPtrType, result);
  }
  if (resultSc != spirv::StorageClass::Generic) {
    result =
        rewriter.create<spirv::GenericCastToPtrOp>(loc, resultPtrType, result);
  }
  rewriter.replaceOp(addrCastOp, result);
  return success();
}


LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  auto memrefType = cast<MemRefType>(storeOp.getMemref().getType());
  if (memrefType.getElementType().isSignlessInteger())
    return rewriter.notifyMatchFailure(storeOp, "signless int");
  auto storePtr = spirv::getElementPtr(
      *getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
      adaptor.getIndices(), storeOp.getLoc(), rewriter);

  if (!storePtr)
    return rewriter.notifyMatchFailure(storeOp, "type conversion failed");

  auto memoryRequirements = calculateMemoryRequirements(storePtr, storeOp);
  if (failed(memoryRequirements))
    return rewriter.notifyMatchFailure(
        storeOp, "failed to determine memory requirements");

  auto [memoryAccess, alignment] = *memoryRequirements;
  rewriter.replaceOpWithNewOp<spirv::StoreOp>(
      storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
  return success();
}

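// A sketch of the rewrite performed below for a nonzero offset (a zero offset
// folds to the source pointer; types illustrative):
//   %1 = memref.reinterpret_cast %0 to offset: [4], sizes: [...], strides: [...]
// becomes:
//   %c4 = spirv.Constant 4 : i32
//   %1 = spirv.InBoundsPtrAccessChain %0[%c4]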
LogicalResult ReinterpretCastPattern::matchAndRewrite(
    memref::ReinterpretCastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value src = adaptor.getSource();
  auto srcType = dyn_cast<spirv::PointerType>(src.getType());

  if (!srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid src type " << src.getType();
    });

  const TypeConverter *converter = getTypeConverter();

  auto dstType = converter->convertType<spirv::PointerType>(op.getType());
  if (dstType != srcType)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "invalid dst type " << op.getType();
    });

  OpFoldResult offset =
      getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
          .front();
  if (isConstantIntValue(offset, 0)) {
    rewriter.replaceOp(op, src);
    return success();
  }

  Type intType = converter->convertType(rewriter.getIndexType());
  if (!intType)
    return rewriter.notifyMatchFailure(op, "failed to convert index type");

  Location loc = op.getLoc();
  auto offsetValue = [&]() -> Value {
    if (auto val = dyn_cast<Value>(offset))
      return val;

    int64_t attrVal = cast<IntegerAttr>(cast<Attribute>(offset)).getInt();
    Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
    return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
  }();

  rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
      op, src, offsetValue, std::nullopt);
  return success();
}

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

namespace mlir {
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                   RewritePatternSet &patterns) {
  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
               DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
               LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
               ReinterpretCastPattern, CastPattern>(typeConverter,
                                                    patterns.getContext());
}
} // namespace mlir
938