//===- EmulateNarrowType.cpp - Narrow type emulation ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <type_traits>

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Converts a memref::ReinterpretCastOp to the converted type. The result
/// MemRefType of the old op must have rank 1 and stride 1, with a static
/// offset and size. The offset, expressed in bits, must be a multiple of the
/// bitwidth of the new converted element type.
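/// For example (illustrative only): with i4 emulated in i8 (elementsPerByte =
/// 2), a cast with static offset 6 and size 10 in i4 elements is rewritten to
/// offset 3 and size 5 in i8 elements.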
static LogicalResult
convertCastingOp(ConversionPatternRewriter &rewriter,
                 memref::ReinterpretCastOp::Adaptor adaptor,
                 memref::ReinterpretCastOp op, MemRefType newTy) {
  auto convertedElementType = newTy.getElementType();
  auto oldElementType = op.getType().getElementType();
  int srcBits = oldElementType.getIntOrFloatBitWidth();
  int dstBits = convertedElementType.getIntOrFloatBitWidth();
  if (dstBits % srcBits != 0) {
    return rewriter.notifyMatchFailure(op,
                                       "only dstBits % srcBits == 0 supported");
  }

  // Only support stride of 1.
  if (llvm::any_of(op.getStaticStrides(),
                   [](int64_t stride) { return stride != 1; })) {
    return rewriter.notifyMatchFailure(op->getLoc(),
                                       "stride != 1 is not supported");
  }

  auto sizes = op.getStaticSizes();
  int64_t offset = op.getStaticOffset(0);
  // Only support static sizes and offsets.
  if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
      offset == ShapedType::kDynamic) {
    return rewriter.notifyMatchFailure(
        op, "dynamic size or offset is not supported");
  }

  int elementsPerByte = dstBits / srcBits;
  if (offset % elementsPerByte != 0) {
    return rewriter.notifyMatchFailure(
        op, "offset not multiple of elementsPerByte is not supported");
  }

  SmallVector<int64_t> size;
  if (sizes.size())
    size.push_back(llvm::divideCeilSigned(sizes[0], elementsPerByte));
  offset = offset / elementsPerByte;

  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
      op, newTy, adaptor.getSource(), offset, size, op.getStaticStrides());
  return success();
}

/// When data is loaded/stored in `targetBits` granularity but used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), each `targetBits`
/// value is treated as an array of elements of width `sourceBits`.
/// Returns the bit offset of the value at position `srcIdx`. For example, if
/// `sourceBits` equals 4 and `targetBits` equals 8, the x-th element is
/// located at bit (x % 2) * 4, because one i8 holds two 4-bit elements.
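/// Concretely (illustrative), with that 4/8 configuration, srcIdx = 5 yields a
/// bit offset of (5 % 2) * 4 = 4.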
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                  int sourceBits, int targetBits,
                                  OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int scaleFactor = targetBits / sourceBits;
  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
  OpFoldResult offsetVal =
      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
  IntegerType dstType = builder.getIntegerType(targetBits);
  return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
}

/// When writing a sub-byte value, masked bitwise operations are used to only
/// modify the relevant bits. This function returns an AND mask for clearing
/// the destination bits in a sub-byte write. E.g., when writing to the second
/// i4 in an i32, the mask 0xFFFFFF0F is created.
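/// Step by step (illustrative): maskRightAligned = 0xF, shifting it left by
/// the bit offset 4 gives 0xF0, and XOR with -1 (all ones) yields 0xFFFFFF0F.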
static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
                                 int64_t srcBits, int64_t dstBits,
                                 Value bitwidthOffset, OpBuilder &builder) {
  auto dstIntegerType = builder.getIntegerType(dstBits);
  auto maskRightAlignedAttr =
      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
  Value maskRightAligned = builder.create<arith::ConstantOp>(
      loc, dstIntegerType, maskRightAlignedAttr);
  Value writeMaskInverse =
      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
  Value flipVal =
      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
}

/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
/// the returned index has the granularity of `dstBits`.
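/// For example (illustrative): with srcBits = 4 and dstBits = 8, a linearized
/// index of 5 (in i4 elements) maps to 5 floordiv 2 = 2 (in i8 elements).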
static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
                                      OpFoldResult linearizedIndex,
                                      int64_t srcBits, int64_t dstBits) {
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int64_t scaler = dstBits / srcBits;
  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
}

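/// Linearizes the multi-dimensional `indices` of the original access into a
/// single index in `srcBits` granularity, using the offset, sizes and strides
/// recovered from `memref` via memref.extract_strided_metadata.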
static OpFoldResult
getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
                        const SmallVector<OpFoldResult> &indices,
                        Value memref) {
  auto stridedMetadata =
      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
  OpFoldResult linearizedIndices;
  std::tie(std::ignore, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          builder, loc, srcBits, srcBits,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(), indices);
  return linearizedIndices;
}

namespace {

//===----------------------------------------------------------------------===//
// ConvertMemRefAllocation
//===----------------------------------------------------------------------===//

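/// Converts `memref.alloc` and `memref.alloca` on narrow element types into
/// allocations of the wider emulation type with a linearized size. For example
/// (illustrative), with an i8 load/store width, `memref.alloc() :
/// memref<4x8xi4>` becomes roughly `memref.alloc() : memref<16xi8>`.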
template <typename OpTy>
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
                      std::is_same<OpTy, memref::AllocaOp>(),
                  "expected only memref::AllocOp or memref::AllocaOp");
    auto currentType = cast<MemRefType>(op.getMemref().getType());
    auto newResultType =
        this->getTypeConverter()->template convertType<MemRefType>(
            op.getType());
    if (!newResultType) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Special case zero-rank memrefs.
    if (currentType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
                                        adaptor.getSymbolOperands(),
                                        adaptor.getAlignmentAttr());
      return success();
    }

    Location loc = op.getLoc();
    OpFoldResult zero = rewriter.getIndexAttr(0);
    SmallVector<OpFoldResult> indices(currentType.getRank(), zero);

    // Get linearized type.
    int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
    int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
    SmallVector<OpFoldResult> sizes = op.getMixedSizes();

    memref::LinearizedMemRefInfo linearizedMemRefInfo =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits, /*offset =*/zero, sizes);
    SmallVector<Value> dynamicLinearizedSize;
    if (!newResultType.hasStaticShape()) {
      dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
          rewriter, loc, linearizedMemRefInfo.linearizedSize));
    }

    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
                                      adaptor.getSymbolOperands(),
                                      adaptor.getAlignmentAttr());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefAssumeAlignment
//===----------------------------------------------------------------------===//

struct ConvertMemRefAssumeAlignment final
    : OpConversionPattern<memref::AssumeAlignmentOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getMemref().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
        op, adaptor.getMemref(), adaptor.getAlignmentAttr());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefCopy
//===----------------------------------------------------------------------===//

struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maybeRankedSource = dyn_cast<MemRefType>(op.getSource().getType());
    auto maybeRankedDest = dyn_cast<MemRefType>(op.getTarget().getType());
    if (maybeRankedSource && maybeRankedDest &&
        maybeRankedSource.getLayout() != maybeRankedDest.getLayout())
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("memref.copy emulation with distinct layouts ({0} "
                            "and {1}) is currently unimplemented",
                            maybeRankedSource.getLayout(),
                            maybeRankedDest.getLayout()));
    rewriter.replaceOpWithNewOp<memref::CopyOp>(op, adaptor.getSource(),
                                                adaptor.getTarget());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefDealloc
//===----------------------------------------------------------------------===//

struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, adaptor.getMemref());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefLoad
//===----------------------------------------------------------------------===//

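/// Converts `memref.load` on a narrow element type into a load of the wider
/// emulation type, followed by a shift and a mask (or truncation) that
/// extracts the requested sub-byte element. Rough sketch (illustrative):
/// loading element 5 of a `memref<8xi4>` emulated as `memref<4xi8>` loads
/// byte 2, shifts right by 4 bits, and keeps the low 4 bits.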
struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    auto convertedElementType = convertedType.getElementType();
    auto oldElementType = op.getMemRefType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    // Special case 0-rank memref loads.
    Value bitsLoad;
    if (convertedType.getRank() == 0) {
      bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                 ValueRange{});
    } else {
      // Linearize the indices of the original load instruction. Do not account
      // for the scaling yet. This will be accounted for later.
      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());

      Value newLoad = rewriter.create<memref::LoadOp>(
          loc, adaptor.getMemref(),
          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
                                   dstBits));

      // Get the bit offset and shift the loaded value to the right so that the
      // requested element ends up in the least significant bits.
      // Note: currently only big-endian is supported.
      Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
                                                  srcBits, dstBits, rewriter);
      bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
    }

    // Get the corresponding bits. If the arith computation bitwidth equals
    // the emulated bitwidth, we apply a mask to extract the low bits. It is not
    // clear if this case actually happens in practice, but we keep the
    // operations just in case. Otherwise, if the arith computation bitwidth
    // differs from the emulated bitwidth, we truncate the result.
    Operation *result;
    auto resultTy = getTypeConverter()->convertType(oldElementType);
    if (resultTy == convertedElementType) {
      auto mask = rewriter.create<arith::ConstantOp>(
          loc, convertedElementType,
          rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));

      result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
    } else {
      result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
    }

    rewriter.replaceOp(op, result->getResult(0));
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefMemorySpaceCast
//===----------------------------------------------------------------------===//

struct ConvertMemRefMemorySpaceCast final
    : OpConversionPattern<memref::MemorySpaceCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getDest().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getDest().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(op, newTy,
                                                           adaptor.getSource());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefReinterpretCast
//===----------------------------------------------------------------------===//

/// Output types should be at most one dimensional, so only 0-D and 1-D cases
/// are supported.
struct ConvertMemRefReinterpretCast final
    : OpConversionPattern<memref::ReinterpretCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        getTypeConverter()->convertType<MemRefType>(op.getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Only 0-D and 1-D cases are supported.
    if (op.getType().getRank() > 1) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), "reinterpret_cast with rank > 1 is not supported");
    }

    return convertCastingOp(rewriter, adaptor, op, newTy);
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemrefStore
//===----------------------------------------------------------------------===//

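/// Converts `memref.store` of a narrow element type into a read-modify-write
/// on the wider emulation type: the destination bits are first cleared with an
/// atomic `andi` using the mask from getSubByteWriteMask, then the shifted
/// value is merged in with an atomic `ori`. Rough sketch (illustrative):
/// storing an i4 value `v` at element 5 of a `memref<8xi4>` emulated as
/// `memref<4xi8>` clears the high nibble of byte 2 and ORs in `v << 4`.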
struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    int srcBits = op.getMemRefType().getElementTypeBitWidth();
    int dstBits = convertedType.getElementTypeBitWidth();
    auto dstIntegerType = rewriter.getIntegerType(dstBits);
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
                                                          adaptor.getValue());

    // Special case 0-rank memref stores. No need for masking.
    if (convertedType.getRank() == 0) {
      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
                                           extendedInput, adaptor.getMemref(),
                                           ValueRange{});
      rewriter.eraseOp(op);
      return success();
    }

    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
    Value storeIndices = getIndicesForLoadOrStore(
        rewriter, loc, linearizedIndices, srcBits, dstBits);
    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
                                                dstBits, rewriter);
    Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
                                          dstBits, bitwidthOffset, rewriter);
    // Align the value to write with the destination bits.
    Value alignedVal =
        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);

    // Clear the destination bits.
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
                                         writeMask, adaptor.getMemref(),
                                         storeIndices);
    // Write the source bits to the destination.
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
                                         alignedVal, adaptor.getMemref(),
                                         storeIndices);
    rewriter.eraseOp(op);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//

/// Emulating narrow ints on subviews has limited support: only static offsets
/// and sizes and a stride of 1 are supported. Ideally, the subview should be
/// folded away before running narrow type emulation, and this pattern should
/// only run for cases that can't be folded.
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        getTypeConverter()->convertType<MemRefType>(subViewOp.getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(),
          llvm::formatv("failed to convert memref type: {0}",
                        subViewOp.getType()));
    }

    Location loc = subViewOp.getLoc();
    Type convertedElementType = newTy.getElementType();
    Type oldElementType = subViewOp.getType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0)
      return rewriter.notifyMatchFailure(
          subViewOp, "only dstBits % srcBits == 0 supported");

    // Only support stride of 1.
    if (llvm::any_of(subViewOp.getStaticStrides(),
                     [](int64_t stride) { return stride != 1; })) {
      return rewriter.notifyMatchFailure(subViewOp->getLoc(),
                                         "stride != 1 is not supported");
    }

    if (!memref::isStaticShapeAndContiguousRowMajor(subViewOp.getType())) {
      return rewriter.notifyMatchFailure(
          subViewOp, "the result memref type is not contiguous");
    }

    auto sizes = subViewOp.getStaticSizes();
    int64_t lastOffset = subViewOp.getStaticOffsets().back();
    // Only support static sizes and offsets.
    if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
        lastOffset == ShapedType::kDynamic) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(), "dynamic size or offset is not supported");
    }

    // Transform the offsets, sizes and strides according to the emulation.
    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
        loc, subViewOp.getViewSource());

    OpFoldResult linearizedIndices;
    auto strides = stridedMetadata.getConstifiedMixedStrides();
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            subViewOp.getMixedSizes(), strides,
            getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
                           rewriter));

    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subViewOp, newTy, adaptor.getSource(), linearizedIndices,
        linearizedInfo.linearizedSize, strides.back());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefCollapseShape
//===----------------------------------------------------------------------===//

/// A `memref.collapse_shape` becomes a no-op after emulation: memrefs are
/// already flattened to a single dimension as part of the emulation, so there
/// is nothing left to collapse.
struct ConvertMemRefCollapseShape final
    : OpConversionPattern<memref::CollapseShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(collapseShapeOp, srcVal);
    return success();
  }
};

/// A `memref.expand_shape` becomes a no-op after emulation: memrefs are
/// flattened to a single dimension as part of the emulation, so the expansion
/// would immediately be undone.
struct ConvertMemRefExpandShape final
    : OpConversionPattern<memref::ExpandShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(expandShapeOp, srcVal);
    return success();
  }
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Public Interface Definition
//===----------------------------------------------------------------------===//

void memref::populateMemRefNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {

  // Populate `memref.*` conversion patterns.
  patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
               ConvertMemRefDealloc, ConvertMemRefCollapseShape,
               ConvertMemRefExpandShape, ConvertMemRefLoad, ConvertMemrefStore,
               ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
               ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
      typeConverter, patterns.getContext());
  memref::populateResolveExtractStridedMetadataPatterns(patterns);
}

static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
                                               int dstBits) {
  if (ty.getRank() == 0)
    return {};

  int64_t linearizedShape = 1;
  for (auto shape : ty.getShape()) {
    if (shape == ShapedType::kDynamic)
      return {ShapedType::kDynamic};
    linearizedShape *= shape;
  }
  int scale = dstBits / srcBits;
  // Scale the size to ceilDiv(linearizedShape, scale) so that all the values
  // can be accommodated.
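  // For example (illustrative): a memref<3x4xi4> emulated with i8 has 12 i4
  // elements and scale = 2, giving a linearized shape of {6}.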
  linearizedShape = (linearizedShape + scale - 1) / scale;
  return {linearizedShape};
}

void memref::populateMemRefNarrowTypeEmulationConversions(
    arith::NarrowTypeEmulationConverter &typeConverter) {
  typeConverter.addConversion(
      [&typeConverter](MemRefType ty) -> std::optional<Type> {
        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
        if (!intTy)
          return ty;

        unsigned width = intTy.getWidth();
        unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
        if (width >= loadStoreWidth)
          return ty;

        // Currently only handle the case where the innermost stride is 1.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(ty.getStridesAndOffset(strides, offset)))
          return nullptr;
        if (!strides.empty() && strides.back() != 1)
          return nullptr;

        auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
                                          intTy.getSignedness());
        if (!newElemTy)
          return nullptr;

        StridedLayoutAttr layoutAttr;
        // The stride is 1, so a strided layout is only needed when the offset
        // is non-zero.
        if (offset != 0) {
          if (offset == ShapedType::kDynamic) {
            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          } else {
            // Check that the offset, expressed in bits, is a multiple of
            // loadStoreWidth and, if so, divide it by loadStoreWidth to get
            // the new offset.
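            // For example (illustrative): an offset of 6 i4 elements is 24
            // bits, which maps to offset 3 in i8 elements.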
            if ((offset * width) % loadStoreWidth != 0)
              return std::nullopt;
            offset = (offset * width) / loadStoreWidth;

            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          }
        }

        return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
                               newElemTy, layoutAttr, ty.getMemorySpace());
      });
}