//===- EmulateNarrowType.cpp - Narrow type emulation ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <type_traits>

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Converts a memref::ReinterpretCastOp to the converted type. The result
/// MemRefType of the old op must have rank at most 1 and a stride of 1, with
/// static offset and sizes. The offset, expressed in bits, must be a multiple
/// of the bitwidth of the new converted type.
static LogicalResult
convertCastingOp(ConversionPatternRewriter &rewriter,
                 memref::ReinterpretCastOp::Adaptor adaptor,
                 memref::ReinterpretCastOp op, MemRefType newTy) {
  auto convertedElementType = newTy.getElementType();
  auto oldElementType = op.getType().getElementType();
  int srcBits = oldElementType.getIntOrFloatBitWidth();
  int dstBits = convertedElementType.getIntOrFloatBitWidth();
  if (dstBits % srcBits != 0) {
    return rewriter.notifyMatchFailure(op,
                                       "only dstBits % srcBits == 0 supported");
  }

  // Only support stride of 1.
  if (llvm::any_of(op.getStaticStrides(),
                   [](int64_t stride) { return stride != 1; })) {
    return rewriter.notifyMatchFailure(op->getLoc(),
                                       "stride != 1 is not supported");
  }

  auto sizes = op.getStaticSizes();
  int64_t offset = op.getStaticOffset(0);
  // Only support static sizes and offsets.
  if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
      offset == ShapedType::kDynamic) {
    return rewriter.notifyMatchFailure(
        op, "dynamic size or offset is not supported");
  }

  int elementsPerByte = dstBits / srcBits;
  if (offset % elementsPerByte != 0) {
    return rewriter.notifyMatchFailure(
        op, "offset not multiple of elementsPerByte is not supported");
  }

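  // E.g. (illustrative) with i4 emulated via i8, elementsPerByte is 2: a
  // static size of 6 i4 elements becomes ceilDiv(6, 2) = 3 bytes and a static
  // offset of 4 elements becomes 2.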
  SmallVector<int64_t> size;
  if (sizes.size())
    size.push_back(llvm::divideCeilSigned(sizes[0], elementsPerByte));
  offset = offset / elementsPerByte;

  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
      op, newTy, adaptor.getSource(), offset, size, op.getStaticStrides());
  return success();
}

/// When data is loaded/stored in `targetBits` granularity, but is used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), each `targetBits`
/// unit is treated as an array of elements of width `sourceBits`.
/// Return the bit offset of the value at position `srcIdx`. For example, if
/// `sourceBits` equals 4 and `targetBits` equals 8, the x-th element is
/// located at (x % 2) * 4, because there are two 4-bit elements in one i8.
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                  int sourceBits, int targetBits,
                                  OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int scaleFactor = targetBits / sourceBits;
  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
  OpFoldResult offsetVal =
      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
  IntegerType dstType = builder.getIntegerType(targetBits);
  return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
}

/// When writing a subbyte size, masked bitwise operations are used to only
/// modify the relevant bits. This function returns an `and` mask for clearing
/// the destination bits in a subbyte write. E.g., when writing to the second
/// i4 in an i32, 0xFFFFFF0F is created.
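/// The mask is built by shifting a right-aligned mask of `srcBits` ones to the
/// element's bit offset and flipping it with an XOR against all-ones, e.g. for
/// the example above: 0x0000000F << 4 = 0x000000F0, and
/// 0x000000F0 ^ 0xFFFFFFFF = 0xFFFFFF0F.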
static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
                                 int64_t srcBits, int64_t dstBits,
                                 Value bitwidthOffset, OpBuilder &builder) {
  auto dstIntegerType = builder.getIntegerType(dstBits);
  auto maskRightAlignedAttr =
      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
  Value maskRightAligned = builder.create<arith::ConstantOp>(
      loc, dstIntegerType, maskRightAlignedAttr);
  Value writeMaskInverse =
      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
  Value flipVal =
      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
}

/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
/// the returned index has the granularity of `dstBits`.
static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
                                      OpFoldResult linearizedIndex,
                                      int64_t srcBits, int64_t dstBits) {
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int64_t scaler = dstBits / srcBits;
  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
}

static OpFoldResult
getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
                        const SmallVector<OpFoldResult> &indices,
                        Value memref) {
  auto stridedMetadata =
      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
  OpFoldResult linearizedIndices;
  std::tie(std::ignore, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          builder, loc, srcBits, srcBits,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(), indices);
  return linearizedIndices;
}

namespace {

//===----------------------------------------------------------------------===//
// ConvertMemRefAllocation
//===----------------------------------------------------------------------===//

template <typename OpTy>
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
                      std::is_same<OpTy, memref::AllocaOp>(),
                  "expected only memref::AllocOp or memref::AllocaOp");
    auto currentType = cast<MemRefType>(op.getMemref().getType());
    auto newResultType =
        this->getTypeConverter()->template convertType<MemRefType>(
            op.getType());
    if (!newResultType) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Special case zero-rank memrefs.
    if (currentType.getRank() == 0) {
      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
                                        adaptor.getSymbolOperands(),
                                        adaptor.getAlignmentAttr());
      return success();
    }

    Location loc = op.getLoc();
    OpFoldResult zero = rewriter.getIndexAttr(0);
    SmallVector<OpFoldResult> indices(currentType.getRank(), zero);

    // Get linearized type.
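    // E.g. (illustrative) with an i8 load/store width, memref<4x8xi4> holds
    // 32 i4 values and is therefore re-allocated as memref<16xi8>.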
    int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
    int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
    SmallVector<OpFoldResult> sizes = op.getMixedSizes();

    memref::LinearizedMemRefInfo linearizedMemRefInfo =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits, /*offset =*/zero, sizes);
    SmallVector<Value> dynamicLinearizedSize;
    if (!newResultType.hasStaticShape()) {
      dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
          rewriter, loc, linearizedMemRefInfo.linearizedSize));
    }

    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
                                      adaptor.getSymbolOperands(),
                                      adaptor.getAlignmentAttr());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefAssumeAlignment
//===----------------------------------------------------------------------===//

struct ConvertMemRefAssumeAlignment final
    : OpConversionPattern<memref::AssumeAlignmentOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AssumeAlignmentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getMemref().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getMemref().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::AssumeAlignmentOp>(
        op, adaptor.getMemref(), adaptor.getAlignmentAttr());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefCopy
//===----------------------------------------------------------------------===//

struct ConvertMemRefCopy final : OpConversionPattern<memref::CopyOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto maybeRankedSource = dyn_cast<MemRefType>(op.getSource().getType());
    auto maybeRankedDest = dyn_cast<MemRefType>(op.getTarget().getType());
    if (maybeRankedSource && maybeRankedDest &&
        maybeRankedSource.getLayout() != maybeRankedDest.getLayout())
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("memref.copy emulation with distinct layouts ({0} "
                            "and {1}) is currently unimplemented",
                            maybeRankedSource.getLayout(),
                            maybeRankedDest.getLayout()));
    rewriter.replaceOpWithNewOp<memref::CopyOp>(op, adaptor.getSource(),
                                                adaptor.getTarget());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefDealloc
//===----------------------------------------------------------------------===//

struct ConvertMemRefDealloc final : OpConversionPattern<memref::DeallocOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<memref::DeallocOp>(op, adaptor.getMemref());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefLoad
//===----------------------------------------------------------------------===//

struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    auto convertedElementType = convertedType.getElementType();
    auto oldElementType = op.getMemRefType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    // Special case 0-rank memref loads.
    Value bitsLoad;
    if (convertedType.getRank() == 0) {
      bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                 ValueRange{});
    } else {
      // Linearize the indices of the original load instruction. Do not account
      // for the scaling yet; this will be accounted for later.
      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());

      Value newLoad = rewriter.create<memref::LoadOp>(
          loc, adaptor.getMemref(),
          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
                                   dstBits));

      // Get the offset and shift the bits to the rightmost position.
      // Note that currently only big-endian is supported.
      Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
                                                  srcBits, dstBits, rewriter);
      bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
    }

    // Get the corresponding bits. If the arith computation bitwidth equals
    // the emulated bitwidth, we apply a mask to extract the low bits. It is
    // not clear if this case actually happens in practice, but we keep the
    // operations just in case. Otherwise, if the arith computation bitwidth
    // is different from the emulated bitwidth, we truncate the result.
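    // E.g. (illustrative) for i4 emulated via i8: the loaded byte has already
    // been shifted right by (idx % 2) * 4 above, so the element is either
    // masked out with 0xF (same-bitwidth case) or truncated from i8 to i4.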
    Operation *result;
    auto resultTy = getTypeConverter()->convertType(oldElementType);
    if (resultTy == convertedElementType) {
      auto mask = rewriter.create<arith::ConstantOp>(
          loc, convertedElementType,
          rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));

      result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
    } else {
      result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
    }

    rewriter.replaceOp(op, result->getResult(0));
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefMemorySpaceCast
//===----------------------------------------------------------------------===//

struct ConvertMemRefMemorySpaceCast final
    : OpConversionPattern<memref::MemorySpaceCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::MemorySpaceCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type newTy = getTypeConverter()->convertType(op.getDest().getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
                                      op.getDest().getType()));
    }

    rewriter.replaceOpWithNewOp<memref::MemorySpaceCastOp>(op, newTy,
                                                           adaptor.getSource());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefReinterpretCast
//===----------------------------------------------------------------------===//

/// Output types should be at most one-dimensional, so only the 0- or
/// 1-dimensional cases are supported.
struct ConvertMemRefReinterpretCast final
    : OpConversionPattern<memref::ReinterpretCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        getTypeConverter()->convertType<MemRefType>(op.getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Only support 0- or 1-dimensional cases.
    if (op.getType().getRank() > 1) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), "reinterpret_cast with rank > 1 is not supported");
    }

    return convertCastingOp(rewriter, adaptor, op, newTy);
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemrefStore
//===----------------------------------------------------------------------===//

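/// Emulates a subbyte store with a read-modify-write sequence on the wider
/// type: the destination bits are first cleared with an atomic `andi` using a
/// write mask, and the shifted source bits are then merged in with an atomic
/// `ori`. E.g. (illustrative) storing the i4 value 0x5 into the second nibble
/// of a byte performs `andi` with 0x0F followed by `ori` with 0x50.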
struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
    int srcBits = op.getMemRefType().getElementTypeBitWidth();
    int dstBits = convertedType.getElementTypeBitWidth();
    auto dstIntegerType = rewriter.getIntegerType(dstBits);
    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    Location loc = op.getLoc();
    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
                                                          adaptor.getValue());

    // Special case 0-rank memref stores. No need for masking.
    if (convertedType.getRank() == 0) {
      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
                                           extendedInput, adaptor.getMemref(),
                                           ValueRange{});
      rewriter.eraseOp(op);
      return success();
    }

    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
    Value storeIndices = getIndicesForLoadOrStore(
        rewriter, loc, linearizedIndices, srcBits, dstBits);
    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
                                                dstBits, rewriter);
    Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
                                          dstBits, bitwidthOffset, rewriter);
    // Align the value to write with its bit position in the destination
    // element.
    Value alignedVal =
        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);

    // Clear the destination bits.
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
                                         writeMask, adaptor.getMemref(),
                                         storeIndices);
    // Write the source bits into the destination.
    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
                                         alignedVal, adaptor.getMemref(),
                                         storeIndices);
    rewriter.eraseOp(op);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//

/// Emulating narrow ints on subviews has limited support: only static offsets
/// and sizes with a stride of 1 are supported. Ideally, the subview should be
/// folded away before running narrow type emulation, and this pattern should
/// only run for cases that can't be folded.
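/// E.g. (illustrative) with i4 emulated via i8, a subview with static sizes
/// and offsets of a contiguous row-major source is rewritten into a rank-1
/// subview of the flattened source, whose offset and size are the linearized
/// values rescaled to the i8 granularity.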
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        getTypeConverter()->convertType<MemRefType>(subViewOp.getType());
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(),
          llvm::formatv("failed to convert memref type: {0}",
                        subViewOp.getType()));
    }

    Location loc = subViewOp.getLoc();
    Type convertedElementType = newTy.getElementType();
    Type oldElementType = subViewOp.getType().getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = convertedElementType.getIntOrFloatBitWidth();
    if (dstBits % srcBits != 0)
      return rewriter.notifyMatchFailure(
          subViewOp, "only dstBits % srcBits == 0 supported");

    // Only support stride of 1.
    if (llvm::any_of(subViewOp.getStaticStrides(),
                     [](int64_t stride) { return stride != 1; })) {
      return rewriter.notifyMatchFailure(subViewOp->getLoc(),
                                         "stride != 1 is not supported");
    }

    if (!memref::isStaticShapeAndContiguousRowMajor(subViewOp.getType())) {
      return rewriter.notifyMatchFailure(
          subViewOp, "the result memref type is not contiguous");
    }

    auto sizes = subViewOp.getStaticSizes();
    int64_t lastOffset = subViewOp.getStaticOffsets().back();
    // Only support static sizes and offsets.
    if (llvm::is_contained(sizes, ShapedType::kDynamic) ||
        lastOffset == ShapedType::kDynamic) {
      return rewriter.notifyMatchFailure(
          subViewOp->getLoc(), "dynamic size or offset is not supported");
    }

    // Transform the offsets, sizes and strides according to the emulation.
    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
        loc, subViewOp.getViewSource());

    OpFoldResult linearizedIndices;
    auto strides = stridedMetadata.getConstifiedMixedStrides();
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            subViewOp.getMixedSizes(), strides,
            getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
                           rewriter));

    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subViewOp, newTy, adaptor.getSource(), linearizedIndices,
        linearizedInfo.linearizedSize, strides.back());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertMemRefCollapseShape
//===----------------------------------------------------------------------===//

/// Emulating a `memref.collapse_shape` becomes a no-op after emulation given
/// that we flatten memrefs to a single dimension as part of the emulation and
/// there is no dimension to collapse any further.
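/// E.g. (illustrative) collapsing memref<4x8xi4> into memref<32xi4> is dropped
/// entirely: the source has already been flattened to memref<16xi8>, so the
/// converted source value is reused directly.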
struct ConvertMemRefCollapseShape final
    : OpConversionPattern<memref::CollapseShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(collapseShapeOp, srcVal);
    return success();
  }
};

/// Emulating a `memref.expand_shape` becomes a no-op after emulation given
/// that we flatten memrefs to a single dimension as part of the emulation and
/// the expansion would just have been undone.
struct ConvertMemRefExpandShape final
    : OpConversionPattern<memref::ExpandShapeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ExpandShapeOp expandShapeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value srcVal = adaptor.getSrc();
    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
    if (!newTy)
      return failure();

    if (newTy.getRank() != 1)
      return failure();

    rewriter.replaceOp(expandShapeOp, srcVal);
    return success();
  }
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Public Interface Definition
//===----------------------------------------------------------------------===//

void memref::populateMemRefNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {

  // Populate `memref.*` conversion patterns.
  patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefCopy,
               ConvertMemRefDealloc, ConvertMemRefCollapseShape,
               ConvertMemRefExpandShape, ConvertMemRefLoad, ConvertMemrefStore,
               ConvertMemRefAssumeAlignment, ConvertMemRefMemorySpaceCast,
               ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
      typeConverter, patterns.getContext());
  memref::populateResolveExtractStridedMetadataPatterns(patterns);
}

static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
                                               int dstBits) {
  if (ty.getRank() == 0)
    return {};

  int64_t linearizedShape = 1;
  for (auto shape : ty.getShape()) {
    if (shape == ShapedType::kDynamic)
      return {ShapedType::kDynamic};
    linearizedShape *= shape;
  }
  int scale = dstBits / srcBits;
  // Scale the size to ceilDiv(linearizedShape, scale) to accommodate all the
  // values.
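  // E.g. 33 i4 values with an i8 load/store width need ceilDiv(33, 2) = 17
  // bytes.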
  linearizedShape = (linearizedShape + scale - 1) / scale;
  return {linearizedShape};
}

void memref::populateMemRefNarrowTypeEmulationConversions(
    arith::NarrowTypeEmulationConverter &typeConverter) {
  typeConverter.addConversion(
      [&typeConverter](MemRefType ty) -> std::optional<Type> {
        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
        if (!intTy)
          return ty;

        unsigned width = intTy.getWidth();
        unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
        if (width >= loadStoreWidth)
          return ty;

        // Currently only handle the innermost stride being 1.
        SmallVector<int64_t> strides;
        int64_t offset;
        if (failed(ty.getStridesAndOffset(strides, offset)))
          return nullptr;
        if (!strides.empty() && strides.back() != 1)
          return nullptr;

        auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
                                          intTy.getSignedness());
        if (!newElemTy)
          return nullptr;

        StridedLayoutAttr layoutAttr;
        // If the offset is 0, we do not need a strided layout as the stride is
        // 1, so we only use the strided layout if the offset is not 0.
        if (offset != 0) {
          if (offset == ShapedType::kDynamic) {
            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          } else {
            // Check that the offset, in bits, is a multiple of the
            // loadStoreWidth and, if so, divide it by the loadStoreWidth to
            // get the new offset.
            if ((offset * width) % loadStoreWidth != 0)
              return std::nullopt;
            offset = (offset * width) / loadStoreWidth;

            layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
                                                ArrayRef<int64_t>{1});
          }
        }

        return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
                               newElemTy, layoutAttr, ty.getMemorySpace());
      });
}