//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to emulate
// narrow types that are not supported by the target hardware, e.g. i4, using
// wider types, e.g. i8.
//
// Currently, only power-of-two integer types are supported. These are
// converted to wider integers that are either 8 bits wide or wider.
//
// TODO: Support non-power-of-two types.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <optional>

using namespace mlir;

#define DEBUG_TYPE "vector-narrow-type-emulation"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using VectorValue = TypedValue<VectorType>;
using MemRefValue = TypedValue<MemRefType>;

/// Returns a compressed mask for the emulated vector. For example, when
/// emulating an eight-element `i8` vector with `i32` (i.e. when the source
/// elements span two dest elements), this method compresses `vector<8xi1>`
/// into `vector<2xi1>`.
///
/// The compressed/output mask value is set iff any mask in the corresponding
/// `numSrcElemsPerDest` range of uncompressed/input masks is set. E.g., if
/// `numSrcElemsPerDest` equals 2 and `numFrontPadElems` equals 1, the
/// following mask:
///
///   %mask = [1, 1, 0, 0, 0, 0]
///
/// will first be padded in the front with `numFrontPadElems` zeros, and zeros
/// will be added in the back to make the number of elements a multiple of
/// `numSrcElemsPerDest` (for easier computation). The resulting mask will be:
///
///   %mask = [0, 1, 1, 0, 0, 0, 0, 0]
///
/// Then the method returns the following new compressed mask:
///
///   %mask = [1, 1, 0, 0]
///
/// NOTE: `numFrontPadElems` is assumed to be strictly smaller than
/// `numSrcElemsPerDest`.
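///
/// For a non-constant mask, the bound is recomputed instead. E.g., for
///
///   %mask = vector.create_mask %n : vector<8xi1>
///
/// with `numSrcElemsPerDest` = 2 and no front padding, the rewrite emits
/// (roughly; the affine map name below is illustrative only):
///
///   %bound    = affine.apply #ceil_div_by_2()[%n]
///   %new_mask = vector.create_mask %bound : vector<4xi1>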
static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                  Location loc, Value mask,
                                                  int numSrcElems,
                                                  int numSrcElemsPerDest,
                                                  int numFrontPadElems = 0) {

  assert(numFrontPadElems < numSrcElemsPerDest &&
         "numFrontPadElems must be less than numSrcElemsPerDest");

  auto numDestElems =
      (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
      numSrcElemsPerDest;

  Operation *maskOp = mask.getDefiningOp();
  SmallVector<vector::ExtractOp, 2> extractOps;
  // TODO: add support for `vector.splat`.
  // Find the mask creation operation, looking through `vector.extract` ops.
  while (maskOp &&
         !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
             maskOp)) {
    if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
      maskOp = extractOp.getVector().getDefiningOp();
      extractOps.push_back(extractOp);
    } else {
      // Bail out on any other defining op to avoid walking indefinitely.
      break;
    }
  }

  if (!maskOp ||
      !isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
          maskOp))
    return failure();

  // Computing the "compressed" mask. All the emulation logic (i.e. computing
  // the new mask index) only happens on the last dimension of the vectors.
  SmallVector<int64_t> maskShape(
      cast<VectorType>(maskOp->getResultTypes()[0]).getShape());
  maskShape.back() = numDestElems;
  auto newMaskType = VectorType::get(maskShape, rewriter.getI1Type());
  std::optional<Operation *> newMask =
      TypeSwitch<Operation *, std::optional<Operation *>>(maskOp)
          .Case<vector::CreateMaskOp>(
              [&](auto createMaskOp) -> std::optional<Operation *> {
                OperandRange maskOperands = createMaskOp.getOperands();
                // The `vector.create_mask` op creates a mask arrangement
                // without any zeros at the front. Also, because
                // `numFrontPadElems` is strictly smaller than
                // `numSrcElemsPerDest`, the compressed mask generated by
                // padding the original mask by `numFrontPadElems` will not
                // have any zeros at the front either.
                AffineExpr s0;
                bindSymbols(rewriter.getContext(), s0);
                s0 = (s0 + numFrontPadElems).ceilDiv(numSrcElemsPerDest);
                OpFoldResult origIndex = getAsOpFoldResult(maskOperands.back());
                OpFoldResult maskIndex = affine::makeComposedFoldedAffineApply(
                    rewriter, loc, s0, origIndex);
                SmallVector<Value> newMaskOperands(maskOperands.drop_back());
                newMaskOperands.push_back(
                    getValueOrCreateConstantIndexOp(rewriter, loc, maskIndex));
                return rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
                                                             newMaskOperands);
              })
          .Case<vector::ConstantMaskOp>(
              [&](auto constantMaskOp) -> std::optional<Operation *> {
                // Take the shape of the mask and compress its trailing
                // dimension:
                SmallVector<int64_t> maskDimSizes(
                    constantMaskOp.getMaskDimSizes());
                int64_t &maskIndex = maskDimSizes.back();
                maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex,
                                             numSrcElemsPerDest);
                return rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
                                                               maskDimSizes);
              })
          .Case<arith::ConstantOp>([&](auto constantOp)
                                       -> std::optional<Operation *> {
            // TODO: Support multiple dimensions.
            if (maskShape.size() != 1)
              return std::nullopt;
            // Rearrange the original mask values to cover the whole potential
            // loading region. For example, in the case of using byte-size for
            // emulation, given the following mask:
            //
            //   %mask = [0, 1, 0, 1, 0, 0]
            //
            // With a front offset of 1, the mask will be padded with 0s in the
            // front and back so that:
            // 1. It is aligned with the effective loading bits.
            // 2. Its length is a multiple of `numSrcElemsPerDest` (and the
            //    total coverage size is a multiple of bytes). The new mask
            //    will look like this before compressing:
            //
            //   %new_mask = [0, 0, 1, 0, 1, 0, 0, 0]
            auto originalMask =
                cast<DenseIntElementsAttr>(constantOp.getValue());
            SmallVector<bool> paddedMaskValues(numFrontPadElems, false);
            paddedMaskValues.append(originalMask.template value_begin<bool>(),
                                    originalMask.template value_end<bool>());
            paddedMaskValues.resize(numDestElems * numSrcElemsPerDest, false);

            // Compressing by combining every `numSrcElemsPerDest` elements:
            SmallVector<bool> compressedMaskValues;
            for (size_t i = 0; i < paddedMaskValues.size();
                 i += numSrcElemsPerDest) {
              bool combinedValue = false;
              for (int j = 0; j < numSrcElemsPerDest; ++j) {
                combinedValue |= paddedMaskValues[i + j];
              }
              compressedMaskValues.push_back(combinedValue);
            }
            return rewriter.create<arith::ConstantOp>(
                loc, DenseElementsAttr::get(newMaskType, compressedMaskValues));
          });

  if (!newMask)
    return failure();

  while (!extractOps.empty()) {
    newMask = rewriter.create<vector::ExtractOp>(
        loc, (*newMask)->getResults()[0], extractOps.back().getMixedPosition());
    extractOps.pop_back();
  }

  return *newMask;
}

/// Extracts a 1-D subvector from a 1-D vector. It is a wrapper function for
/// emitting `vector.extract_strided_slice`.
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
                                        Value source, int64_t frontOffset,
                                        int64_t subvecSize) {
  auto vectorType = cast<VectorType>(source.getType());
  assert(vectorType.getRank() == 1 && "expected 1-D source types");
  assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
         "subvector out of bounds");

  // No extraction is needed if the subvector size matches the source size.
  if (vectorType.getNumElements() == subvecSize)
    return source;

  auto offsets = rewriter.getI64ArrayAttr({frontOffset});
  auto sizes = rewriter.getI64ArrayAttr({subvecSize});
  auto strides = rewriter.getI64ArrayAttr({1});

  auto resultVectorType =
      VectorType::get({subvecSize}, vectorType.getElementType());
  return rewriter
      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
                                             offsets, sizes, strides)
      ->getResult(0);
}

/// Inserts a 1-D subvector into a 1-D vector by overwriting the elements
/// starting at `offset`. It is a wrapper function for emitting
/// `vector.insert_strided_slice`.
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                       Value src, Value dest, int64_t offset) {
  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
  assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
         "expected source and dest to be vector type");
  auto offsets = rewriter.getI64ArrayAttr({offset});
  auto strides = rewriter.getI64ArrayAttr({1});
  return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
                                                       dest, offsets, strides);
}

/// Extracts a 1-D subvector from a 1-D `source` vector, with index at `offset`
/// and size `numElementsToExtract`, and inserts it into the `dest` vector.
/// This function emits multiple `vector.extract` and `vector.insert` ops, so
/// only use it when `offset` cannot be folded into a constant value.
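///
/// For example, with `numElementsToExtract` = 2 the emitted sequence is
/// (roughly, with types omitted for brevity):
///
///   %e0 = vector.extract %source[%offset]
///   %d0 = vector.insert %e0, %dest [0]
///   %i1 = arith.addi %offset, %c1
///   %e1 = vector.extract %source[%i1]
///   %d1 = vector.insert %e1, %d0 [1]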
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
                                         Value source, Value dest,
                                         OpFoldResult offset,
                                         int64_t numElementsToExtract) {
  assert(isa<VectorValue>(source) && "expected `source` to be a vector type");
  for (int i = 0; i < numElementsToExtract; ++i) {
    Value extractLoc =
        (i == 0) ? offset.dyn_cast<Value>()
                 : rewriter.create<arith::AddIOp>(
                       loc, rewriter.getIndexType(), offset.dyn_cast<Value>(),
                       rewriter.create<arith::ConstantIndexOp>(loc, i));
    auto extractOp =
        rewriter.create<vector::ExtractOp>(loc, source, extractLoc);
    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
  }
  return dest;
}

/// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
                                        Value source, Value dest,
                                        OpFoldResult destOffsetVar,
                                        size_t length) {
  assert(isa<VectorValue>(source) && "expected `source` to be a vector type");
  assert(length > 0 && "length must be greater than 0");
  Value destOffsetVal =
      getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar);
  for (size_t i = 0; i < length; ++i) {
    auto insertLoc = i == 0
                         ? destOffsetVal
                         : rewriter.create<arith::AddIOp>(
                               loc, rewriter.getIndexType(), destOffsetVal,
                               rewriter.create<arith::ConstantIndexOp>(loc, i));
    auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
  }
  return dest;
}

/// Returns the op sequence for an emulated sub-byte data type vector load.
/// Specifically, uses `emulatedElemType` for loading a vector of
/// `origElemType`. The load location is given by `base` and
/// `linearizedIndices`, and the load size is given by
/// `numEmulatedElementsToLoad`.
static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
                                      Value base,
                                      OpFoldResult linearizedIndices,
                                      int64_t numEmulatedElementsToLoad,
                                      Type origElemType,
                                      Type emulatedElemType) {
  auto scale = emulatedElemType.getIntOrFloatBitWidth() /
               origElemType.getIntOrFloatBitWidth();
  auto newLoad = rewriter.create<vector::LoadOp>(
      loc, VectorType::get(numEmulatedElementsToLoad, emulatedElemType), base,
      getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
  return rewriter.create<vector::BitCastOp>(
      loc, VectorType::get(numEmulatedElementsToLoad * scale, origElemType),
      newLoad);
}

/// Downcasts two values to `downcastType`, selects between them based on
/// `mask`, and casts the result to `upcastType`.
static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
                                     VectorType downcastType,
                                     VectorType upcastType, Value mask,
                                     Value trueValue, Value falseValue) {
  assert(
      downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
          upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
      "expected input and output number of bits to match");
  if (trueValue.getType() != downcastType) {
    trueValue = builder.create<vector::BitCastOp>(loc, downcastType, trueValue);
  }
  if (falseValue.getType() != downcastType) {
    falseValue =
        builder.create<vector::BitCastOp>(loc, downcastType, falseValue);
  }
  Value selectedValue =
      builder.create<arith::SelectOp>(loc, mask, trueValue, falseValue);
  // Upcast the selected value to the new type.
  return builder.create<vector::BitCastOp>(loc, upcastType, selectedValue);
}

/// Emits a `memref.generic_atomic_rmw` op to store a sub-byte-sized value to
/// a byte in `linearizedMemref`, with a mask. The `valueToStore` is a vector
/// of sub-byte-sized elements whose total size is 8 bits (one byte), and the
/// mask selects which of those elements to store.
///
/// Inputs:
///   linearizedMemref = |2|2|2|2| : <4xi2> (<1xi8>)
///   storeIdx = 2
///   valueToStore = |3|3|3|3| : vector<4xi2>
///   mask = |0|0|1|1| : vector<4xi1>
///
/// Result:
///   linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
static void atomicStore(OpBuilder &builder, Location loc,
                        MemRefValue linearizedMemref, Value storeIdx,
                        VectorValue valueToStore, Value mask) {
  assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");

  // Create an atomic load-modify-write region using
  // `memref.generic_atomic_rmw`.
  auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
      loc, linearizedMemref, ValueRange{storeIdx});
  Value origValue = atomicOp.getCurrentValue();

  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(atomicOp.getBody());

  // Wrap the original value (the byte loaded by the atomic RMW region) into a
  // one-element vector.
  auto oneElemVecType = VectorType::get({1}, origValue.getType());
  Value origVecValue = builder.create<vector::FromElementsOp>(
      loc, oneElemVecType, ValueRange{origValue});

  // Construct the final masked value and yield it.
  Value maskedValue =
      downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
                              oneElemVecType, mask, valueToStore, origVecValue);
  auto scalarMaskedValue =
      builder.create<vector::ExtractOp>(loc, maskedValue, 0);
  builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
}

/// Extracts `sliceNumElements` elements from source `vector` at
/// `extractOffset`, and inserts them into an empty vector at `insertOffset`.
/// Inputs:
///   vec_in = |0|1|2|3| : vector<4xi2>
///   extractOffset = 1
///   sliceNumElements = 2
///   insertOffset = 2
/// Output:
///   vec_out = |0|0|1|2| : vector<4xi2>
static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
                                  Location loc, VectorValue vector,
                                  int64_t extractOffset,
                                  int64_t sliceNumElements,
                                  int64_t insertOffset) {
  assert(vector.getType().getRank() == 1 && "expected 1-D vector");
  auto vectorElementType = vector.getType().getElementType();
  // TODO: update and use `alignedConversionPrecondition` in the place of
  // these asserts.
  assert(
      sliceNumElements * vectorElementType.getIntOrFloatBitWidth() <= 8 &&
      "sliceNumElements * vector element size must be less than or equal to 8");
  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
         "vector element must be a valid sub-byte type");
  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
      loc, VectorType::get({scale}, vectorElementType),
      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
                                              extractOffset, sliceNumElements);
  return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
                                   insertOffset);
}

namespace {

//===----------------------------------------------------------------------===//
// ConvertVectorStore
//===----------------------------------------------------------------------===//

struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // See #115653
    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto valueToStore = cast<VectorValue>(op.getValueToStore());
    auto oldElementType = valueToStore.getType().getElementType();
    auto newElementType =
        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int numSrcElemsPerDest = dstBits / srcBits;

    // Adjust the number of elements to store when emulating narrow types.
    // Here only the 1-D vector store is considered, and the N-D memref types
    // should be linearized.
    // For example, to emulate i4 to i8, the following op:
    //
    //   vector.store %arg1, %0[%arg2, %arg3] : memref<4x8xi4>, vector<8xi4>
    //
    // can be replaced with
    //
    //   %bitcast = vector.bitcast %arg1 : vector<8xi4> to vector<4xi8>
    //   vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
    //     vector<4xi8>

    auto origElements = valueToStore.getType().getNumElements();
    bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedNumFrontPadElems =
        isAlignedEmulation
            ? 0
            : getConstantIntValue(linearizedInfo.intraDataOffset);

    if (!foldedNumFrontPadElems) {
      return rewriter.notifyMatchFailure(
          op, "subbyte store emulation: dynamic front padding size is "
              "not yet implemented");
    }

    auto memrefBase = cast<MemRefValue>(adaptor.getBase());

    // Conditions under which atomic RMWs are not needed:
    // 1. The source vector size (in bits) is a multiple of the byte size.
    // 2. The address of the store is aligned to the emulated width boundary.
    //
    // For example, storing a vector<4xi2> to memref<13xi2> at offset 4 does
    // not need unaligned emulation because the store address is aligned and
    // the source is a whole byte.
    bool emulationRequiresPartialStores =
        !isAlignedEmulation || *foldedNumFrontPadElems != 0;
    if (!emulationRequiresPartialStores) {
      // Basic case: storing full bytes.
      auto numElements = origElements / numSrcElemsPerDest;
      auto bitCast = rewriter.create<vector::BitCastOp>(
          loc, VectorType::get(numElements, newElementType),
          op.getValueToStore());
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          op, bitCast.getResult(), memrefBase,
          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
      return success();
    }

    // Next, handle the case when sub-byte read-modify-write
    // sequences are needed to emulate a vector store.
    // Here is an example:
    //
    // Vector to store: vector<7xi2>
    // Value to store: 11 11 11 11 11 11 11 (all ones)
    //
    // Destination: memref<12xi2>
    // Store offset: 2 (i.e. 4 bits into the 1st emulated byte).
    //
    // Input MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
    //
    // Destination memref before:
    //
    //    Byte 0     Byte 1     Byte 2
    // +----------+----------+----------+
    // | 00000000 | 00000000 | 00000000 |
    // +----------+----------+----------+
    //
    // Destination memref after:
    //
    //    Byte 0     Byte 1     Byte 2
    // +----------+----------+----------+
    // | 00001111 | 11111111 | 11000000 |
    // +----------+----------+----------+
    //
    // Note, stores to Byte 1 are "full-width" and hence don't require RMW (no
    // need for atomicity). Stores to Byte 0 and Byte 2 are "partial", hence
    // requiring RMW access (atomicity is required).

    // The index into the target memref we are storing to.
    Value currentDestIndex =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
    // The index into the source vector we are currently processing.
    auto currentSourceIndex = 0;

    // Build a mask used for RMW.
    auto subWidthStoreMaskType =
        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());

    // 1. Partial width store for the leading byte.
    // When the store address is not aligned to the emulated width boundary,
    // deal with the unaligned part first so that the remaining elements are
    // aligned to the width boundary.
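    // For the vector<7xi2> example above (store offset 2, i.e.
    // `*foldedNumFrontPadElems` == 2 and `numSrcElemsPerDest` == 4):
    // `frontSubWidthStoreElem` is 2, the RMW mask for Byte 0 is [0, 0, 1, 1],
    // and source elements 0-1 are placed at positions 2-3 of a zeroed byte
    // before the atomic store.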
    auto frontSubWidthStoreElem =
        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
    if (frontSubWidthStoreElem > 0) {
      SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
                    origElements, true);
        frontSubWidthStoreElem = origElements;
      } else {
        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
                    frontSubWidthStoreElem, true);
      }
      auto frontMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));

      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
      auto value =
          extractSliceIntoByte(rewriter, loc, valueToStore, 0,
                               frontSubWidthStoreElem, *foldedNumFrontPadElems);

      atomicStore(rewriter, loc, memrefBase, currentDestIndex,
                  cast<VectorValue>(value), frontMask.getResult());
    }

    if (currentSourceIndex >= origElements) {
      rewriter.eraseOp(op);
      return success();
    }

    // Increment the destination index by 1 to align to the emulated width
    // boundary.
    auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    currentDestIndex = rewriter.create<arith::AddIOp>(
        loc, rewriter.getIndexType(), currentDestIndex, constantOne);

    // 2. Full width store for the inner output bytes.
    // After the previous step, the store address is aligned to the emulated
    // width boundary.
    int64_t fullWidthStoreSize =
        (origElements - currentSourceIndex) / numSrcElemsPerDest;
    int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
    if (fullWidthStoreSize > 0) {
      auto fullWidthStorePart = staticallyExtractSubvector(
          rewriter, loc, valueToStore, currentSourceIndex,
          numNonFullWidthElements);

      auto originType = cast<VectorType>(fullWidthStorePart.getType());
      auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
      auto storeType = VectorType::get(
          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
                                                        fullWidthStorePart);
      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
                                       currentDestIndex);

      currentSourceIndex += numNonFullWidthElements;
      currentDestIndex = rewriter.create<arith::AddIOp>(
          loc, rewriter.getIndexType(), currentDestIndex,
          rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
    }

    // 3. Partial width store for the trailing output byte.
    // This is needed when the residual length is smaller than the emulated
    // width, which is not covered in step 2 above.
    auto remainingElements = origElements - currentSourceIndex;
    if (remainingElements != 0) {
      auto subWidthStorePart =
          extractSliceIntoByte(rewriter, loc, cast<VectorValue>(valueToStore),
                               currentSourceIndex, remainingElements, 0);

      // Generate the back mask.
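      // For the vector<7xi2> example above, a single element (element 6)
      // remains, so the back mask is [1, 0, 0, 0] and only the first i2 slot
      // of Byte 2 is overwritten by the atomic RMW below.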
      auto maskValues = SmallVector<bool>(numSrcElemsPerDest, false);
      std::fill_n(maskValues.begin(), remainingElements, true);
      auto backMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));

      atomicStore(rewriter, loc, memrefBase, currentDestIndex,
                  cast<VectorValue>(subWidthStorePart), backMask.getResult());
    }

    rewriter.eraseOp(op);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertVectorMaskedStore
//===----------------------------------------------------------------------===//

struct ConvertVectorMaskedStore final
    : OpConversionPattern<vector::MaskedStoreOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // See #115653
    if (op.getValueToStore().getType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getValueToStore().getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }

    int scale = dstBits / srcBits;
    int origElements = op.getValueToStore().getType().getNumElements();
    if (origElements % scale != 0)
      return failure();

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    OpFoldResult linearizedIndicesOfr;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndicesOfr) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));
    Value linearizedIndices =
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);

    // Load the whole data and use arith.select to handle the corner cases.
    //
    // As an example, for this masked store of i4 values:
    //
    //   vector.maskedstore %0[%c0, %c0], %mask, %val_to_store
    //
    // and given these input values:
    //
    //   %mask = [0, 1, 1, 1, 1, 0, 0, 0]                     (8 * i1)
    //   %0[%c0, %c0] =
    //     [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8]           (8 * i4)
    //   %val_to_store =
    //     [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0]           (8 * i4)
    //
    // we'll have the following i4 output:
    //
    //   expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x7, 0x8]
    //
    // Emulating the above using i8 will give:
    //
    //   %compressed_mask = [1, 1, 1, 0]                      (4 * i1)
    //   %maskedload = [0x12, 0x34, 0x56, 0x00]               (4 * i8)
    //   %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0]  (8 * i4)
    //   %select_using_shifted_mask =
    //     [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x0, 0x0]           (8 * i4)
    //   %packed_data = [0x1A, 0xBC, 0xD6, 0x00]              (4 * i8)
    //
    // Using the compressed mask to store %packed_data results in the expected
    // output.
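    //
    // Note that, unlike the `vector.store` lowering above, the partial bytes
    // are updated here with a non-atomic load-select-store sequence.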
    //
    // FIXME: Make an example based on the comment above work (see #115460 for
    // reproducer).
    FailureOr<Operation *> newMask =
        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
    if (failed(newMask))
      return failure();

    auto numElements = (origElements + scale - 1) / scale;
    auto newType = VectorType::get(numElements, newElementType);
    auto passThru = rewriter.create<arith::ConstantOp>(
        loc, newType, rewriter.getZeroAttr(newType));

    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, newType, adaptor.getBase(), linearizedIndices,
        newMask.value()->getResult(0), passThru);

    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
    Value valueToStore =
        rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
    valueToStore = rewriter.create<arith::SelectOp>(
        loc, op.getMask(), op.getValueToStore(), valueToStore);
    valueToStore =
        rewriter.create<vector::BitCastOp>(loc, newType, valueToStore);

    rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
        op, adaptor.getBase(), linearizedIndices, newMask.value()->getResult(0),
        valueToStore);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertVectorLoad
//===----------------------------------------------------------------------===//

struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // See #115653
    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    // Adjust the number of elements to load when emulating narrow types,
    // and then cast back to the original type with a vector.bitcast op.
    // Here only the 1-D vector load is considered, and the N-D memref types
    // should be linearized.
    // For example, to emulate i4 to i8, the following op:
    //
    //   %1 = vector.load %0[%c0, %c0] : memref<3x4xi4>, vector<4xi4>
    //
    // can be replaced with
    //
    //   %1 = vector.load %0[%linear_index] : memref<6xi8>, vector<2xi8>
    //   %2 = vector.bitcast %1 : vector<2xi8> to vector<4xi4>
    //
    // There are cases where the number of elements to load is not
    // byte-aligned, for example:
    //
    //   %1 = vector.load %0[%c1, %c0] : memref<3x3xi2>, vector<3xi2>
    //
    // In this case we will have to load extra bytes and extract the exact
    // slice in between:
    //
    //   %1 = vector.load %0[%c2] : memref<3xi8>, vector<2xi8>
    //   %2 = vector.bitcast %1 : vector<2xi8> to vector<8xi2>
    //   %3 = vector.extract_strided_slice %2 {offsets = [2], sizes = [3],
    //     strides = [1]} : vector<8xi2> to vector<3xi2>
    //
    // TODO: Currently the extract_strided_slice's attributes must be known at
    // compile time as they must be constants.

    auto origElements = op.getVectorType().getNumElements();
    bool isAlignedEmulation = origElements % scale == 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isAlignedEmulation
            ? 0
            : getConstantIntValue(linearizedInfo.intraDataOffset);

    // Always load enough elements to cover the original elements.
    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
    auto numElements =
        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
    Value result =
        emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
                           numElements, oldElementType, newElementType);

    if (!foldedIntraVectorOffset) {
      auto resultVector = rewriter.create<arith::ConstantOp>(
          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
      result = dynamicallyExtractSubVector(rewriter, loc, result, resultVector,
                                           linearizedInfo.intraDataOffset,
                                           origElements);
    } else if (!isAlignedEmulation) {
      result = staticallyExtractSubvector(
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertVectorMaskedLoad
//===----------------------------------------------------------------------===//

struct ConvertVectorMaskedLoad final
    : OpConversionPattern<vector::MaskedLoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // See #115653
    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    // Adjust the number of elements to load when emulating narrow types,
    // and then cast back to the original type with a vector.bitcast op.
    // For example, to emulate i4 to i8, the following op:
    //
    //   %mask = vector.constant_mask [3] : vector<6xi1>
    //   %1 = vector.maskedload %0[%c0, %c0], %mask, %pass_thru :
    //     memref<3x6xi4>, vector<6xi1>, vector<6xi4> into vector<6xi4>
    //
    // can be replaced with
    //
    //   %new_mask = vector.constant_mask [2] : vector<3xi1>
    //   %new_pass_thru = vector.bitcast %pass_thru :
    //     vector<6xi4> to vector<3xi8>
    //   %1 = vector.maskedload %0[%linear_index], %new_mask, %new_pass_thru :
    //     memref<9xi8>, vector<3xi1>, vector<3xi8> into vector<3xi8>
    //   %2 = vector.bitcast %1 : vector<3xi8> to vector<6xi4>
    //
    // Since we are effectively loading 16 bits (2xi8) from the memref with the
    // new mask, while originally we only wanted to effectively load 12 bits
    // (3xi4) from the memref, we need to set the second half of the last i8
    // that was effectively loaded (i.e. the second i8) to %pass_thru.
    //
    //   %3 = arith.select %mask, %2, %pass_thru : vector<6xi1>, vector<6xi4>
    //
    // Given these input values:
    //   %mask = [1, 1, 1, 0, 0, 0]
    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6]
    //   %pass_thru = [0x7, 0x8, 0x9, 0xA, 0xB, 0xC]
    //
    // we'll have:
    //
    //   expected output: [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
    //
    //   %new_mask = [1, 1, 0]
    //   %new_pass_thru = [0x78, 0x9A, 0xBC]
    //   %1 = [0x12, 0x34, 0xBC]
    //   %2 = [0x1, 0x2, 0x3, 0x4, 0xB, 0xC]
    //   %3 = [0x1, 0x2, 0x3, 0xA, 0xB, 0xC]
    //
    // TODO: Currently, only loading an even number of elements is supported.
    // To deal with an odd number of elements, one has to extract the
    // subvector at the proper offset after bit-casting.
    auto origType = op.getVectorType();
    auto origElements = origType.getNumElements();
    bool isAlignedEmulation = origElements % scale == 0;

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isAlignedEmulation
            ? 0
            : getConstantIntValue(linearizedInfo.intraDataOffset);

    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
    FailureOr<Operation *> newMask = getCompressedMaskOp(
        rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
    if (failed(newMask))
      return failure();

    Value passthru = op.getPassThru();

    auto numElements =
        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
    auto loadType = VectorType::get(numElements, newElementType);
    auto newBitcastType = VectorType::get(numElements * scale, oldElementType);

    auto emptyVector = rewriter.create<arith::ConstantOp>(
        loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
    if (!foldedIntraVectorOffset) {
      passthru = dynamicallyInsertSubVector(
          rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
          origElements);
    } else if (!isAlignedEmulation) {
      passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                           *foldedIntraVectorOffset);
    }
    auto newPassThru =
        rewriter.create<vector::BitCastOp>(loc, loadType, passthru);

    // Generating the new masked load.
    auto newLoad = rewriter.create<vector::MaskedLoadOp>(
        loc, loadType, adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newMask.value()->getResult(0), newPassThru);

    // Setting the part that originally was not effectively loaded from memory
    // to pass through.
    auto bitCast =
        rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);

    Value mask = op.getMask();
    auto newSelectMaskType =
        VectorType::get(numElements * scale, rewriter.getI1Type());
    // TODO: try to fold if the op's mask is constant.
    auto emptyMask = rewriter.create<arith::ConstantOp>(
        loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
    if (!foldedIntraVectorOffset) {
      mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
                                        linearizedInfo.intraDataOffset,
                                        origElements);
    } else if (!isAlignedEmulation) {
      mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
                                       *foldedIntraVectorOffset);
    }

    Value result =
        rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
    if (!foldedIntraVectorOffset) {
      result = dynamicallyExtractSubVector(
          rewriter, loc, result, op.getPassThru(),
          linearizedInfo.intraDataOffset, origElements);
    } else if (!isAlignedEmulation) {
      result = staticallyExtractSubvector(
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);

    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertVectorTransferRead
//===----------------------------------------------------------------------===//

struct ConvertVectorTransferRead final
    : OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // See #115653
    if (op.getVectorType().getRank() != 1)
      return rewriter.notifyMatchFailure(op,
                                         "only 1-D vectors are supported ATM");

    auto loc = op.getLoc();
    auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
    int srcBits = oldElementType.getIntOrFloatBitWidth();
    int dstBits = newElementType.getIntOrFloatBitWidth();

    if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
    }
    int scale = dstBits / srcBits;

    auto origElements = op.getVectorType().getNumElements();

    bool isAlignedEmulation = origElements % scale == 0;

    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                      adaptor.getPadding());

    auto stridedMetadata =
        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());

    OpFoldResult linearizedIndices;
    memref::LinearizedMemRefInfo linearizedInfo;
    std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, srcBits, dstBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));

    std::optional<int64_t> foldedIntraVectorOffset =
        isAlignedEmulation
            ? 0
            : getConstantIntValue(linearizedInfo.intraDataOffset);

    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
    auto numElements =
        llvm::divideCeil(maxIntraDataOffset + origElements, scale);

    auto newRead = rewriter.create<vector::TransferReadOp>(
        loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newPadding);

    auto bitCast = rewriter.create<vector::BitCastOp>(
        loc, VectorType::get(numElements * scale, oldElementType), newRead);

    Value result = bitCast->getResult(0);
    if (!foldedIntraVectorOffset) {
      auto zeros = rewriter.create<arith::ConstantOp>(
          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
      result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
                                           linearizedInfo.intraDataOffset,
                                           origElements);
    } else if (!isAlignedEmulation) {
      result = staticallyExtractSubvector(
          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
    }
    rewriter.replaceOp(op, result);

    return success();
  }
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// RewriteBitCastOfTruncI
//===----------------------------------------------------------------------===//

namespace {

/// Helper struct to keep track of the provenance of a contiguous set of bits
/// in a source vector.
struct SourceElementRange {
  /// The index of the source vector element that contributes bits to *this.
  int64_t sourceElementIdx;
  /// The range of bits in the source vector element that contribute to *this.
  int64_t sourceBitBegin;
  int64_t sourceBitEnd;
};

struct SourceElementRangeList : public SmallVector<SourceElementRange> {
  /// Given the index of a SourceElementRange in the SourceElementRangeList,
  /// compute the number of bits that need to be shifted to the left to get
  /// the bits in their final location. This shift amount is simply the sum of
  /// the bits *before* `shuffleIdx` (i.e. the bits of `shuffleIdx = 0` are
  /// always the LSBs, the bits of `shuffleIdx = 1` come next, etc.).
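  ///
  /// For example, for the list {0: b@[0..5)}{1: b@[0..3)}, the left-shift
  /// amount for `shuffleIdx` = 1 is 5: the 5 bits of entry 0 occupy the LSBs
  /// of the result element, so entry 1 must land just above them.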
  int64_t computeLeftShiftAmount(int64_t shuffleIdx) const {
    int64_t res = 0;
    for (int64_t i = 0; i < shuffleIdx; ++i)
      res += (*this)[i].sourceBitEnd - (*this)[i].sourceBitBegin;
    return res;
  }
};

/// Helper struct to enumerate the source elements and bit ranges that are
/// involved in a bitcast operation.
/// This allows rewriting a vector.bitcast into shuffles and bitwise ops for
/// any 1-D vector shape and any source/target bitwidths.
/// This creates and holds a mapping of the form:
///   [dstVectorElementJ] ==
///     [ {srcVectorElementX, bitRange}, {srcVectorElementY, bitRange}, ... ]
/// E.g. `vector.bitcast ... : vector<1xi24> to vector<3xi8>` is decomposed as:
///   [0] = {0, [0-8)}
///   [1] = {0, [8-16)}
///   [2] = {0, [16-24)}
/// and `vector.bitcast ... : vector<2xi15> to vector<3xi10>` is decomposed as:
///   [0] = {0, [0-10)}
///   [1] = {0, [10-15)}, {1, [0-5)}
///   [2] = {1, [5-15)}
struct BitCastBitsEnumerator {
  BitCastBitsEnumerator(VectorType sourceVectorType,
                        VectorType targetVectorType);

  int64_t getMaxNumberOfEntries() {
    int64_t numVectors = 0;
    for (const auto &l : sourceElementRanges)
      numVectors = std::max(numVectors, (int64_t)l.size());
    return numVectors;
  }

  VectorType sourceVectorType;
  VectorType targetVectorType;
  SmallVector<SourceElementRangeList> sourceElementRanges;
};

/// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
/// advantage of high-level information to avoid leaving LLVM to scramble with
/// peephole optimizations.
/// BitCastBitsEnumerator encodes for each element of the target vector the
/// provenance of the bits in the source vector. We can "transpose" this
/// information to build a sequence of shuffles and bitwise ops that will
/// produce the desired result.
///
/// Consider the following motivating example:
/// ```
///   %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
/// ```
///
/// BitCastBitsEnumerator contains the following information:
/// ```
///   { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
///   { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
///   { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
///   { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
///   { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
///   { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
///   { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
///   {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
///   {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
///   {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
///   {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
///   {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
///   {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
///   {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
///   {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
///   {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
///   {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
///   {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
///   {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
///   {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
/// ```
///
/// In the above, each row represents one target vector element and each
/// column represents one bit contribution from a source vector element.
/// The algorithm creates vector.shuffle operations (in this case there are 3
/// shuffles, i.e. the max number of columns in BitCastBitsEnumerator). The
/// algorithm populates the bits as follows:
/// ```
///     src bits 0 ...
/// 1st shuffle |xxxxx   |xx      |...
/// 2nd shuffle |     xxx|  xxxxx |...
/// 3rd shuffle |        |       x|...
/// ```
///
/// The algorithm proceeds as follows:
///   1. for each vector.shuffle, collect the source vectors that participate
///      in this shuffle. One source vector per target element of the
///      resulting vector.shuffle. If there is no source element contributing
///      bits for the current vector.shuffle, take 0 (i.e. row 0 in the above
///      example has only 2 columns).
///   2. represent the bitrange in the source vector as a mask. If there is no
///      source element contributing bits for the current vector.shuffle, take
///      0.
///   3. shift right by the proper amount to align the source bitrange at
///      position 0. This is exactly the low end of the bitrange. For
///      instance, the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and
///      one needs to shift right by 3 to get the bits contributed by the
///      source element #1 into position 0.
///   4. shift left by the proper amount to align to the desired position in
///      the result element vector. For instance, the contribution of the
///      second source element for the first row needs to be shifted by `5` to
///      form the first i8 result element.
///
/// Eventually, we end up building the sequence
/// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
/// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
/// bits extracted from the source vector (i.e. the `shuffle -> and` part).
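///
/// For the example above, each step is materialized (roughly; names and
/// constant placeholders are illustrative only) as:
/// ```
///   %s = vector.shuffle %src, %src [<step's source element indices>]
///   %a = arith.andi %s, <step's bit masks>
///   %r = arith.shrui %a, <step's right-shift amounts>
///   %l = arith.shli %r, <step's left-shift amounts>
///   %acc = arith.ori %prev_acc, %l
/// ```
/// where the first step skips the `arith.ori` since there is no running
/// result yet.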
struct BitCastRewriter {
  /// Helper metadata struct to hold the static quantities for the rewrite.
  struct Metadata {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
  };

  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);

  /// Verify that general preconditions for the rewrite are met.
  LogicalResult commonPrecondition(PatternRewriter &rewriter,
                                   VectorType preconditionType, Operation *op);

  /// Precompute the metadata for the rewrite.
  SmallVector<BitCastRewriter::Metadata>
  precomputeMetadata(IntegerType shuffledElementType);

  /// Rewrite one step of the sequence:
  ///   `(shuffle -> and -> shiftright -> shiftleft -> or)`.
  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
                           Value initialValue, Value runningResult,
                           const BitCastRewriter::Metadata &metadata);

private:
  /// Underlying enumerator that encodes the provenance of the bits in each
  /// element of the result vector.
  BitCastBitsEnumerator enumerator;
};

} // namespace

[[maybe_unused]] static raw_ostream &
operator<<(raw_ostream &os, const SmallVector<SourceElementRangeList> &vec) {
  for (const auto &l : vec) {
    for (auto it : llvm::enumerate(l)) {
      os << "{ " << it.value().sourceElementIdx << ": b@["
         << it.value().sourceBitBegin << ".." << it.value().sourceBitEnd
         << ") lshl: " << l.computeLeftShiftAmount(it.index()) << " } ";
    }
    os << "\n";
  }
  return os;
}

BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
                                             VectorType targetVectorType)
    : sourceVectorType(sourceVectorType), targetVectorType(targetVectorType) {

  assert(sourceVectorType.getRank() == 1 && !sourceVectorType.isScalable() &&
         "requires a 1-D non-scalable vector type");
  assert(targetVectorType.getRank() == 1 && !targetVectorType.isScalable() &&
         "requires a 1-D non-scalable vector type");
  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
  LDBG("sourceVectorType: " << sourceVectorType);

  int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth();
  int64_t mostMinorTargetDim = targetVectorType.getShape().back();
  LDBG("targetVectorType: " << targetVectorType);

  int64_t bitwidth = targetBitWidth * mostMinorTargetDim;
  (void)mostMinorSourceDim;
  assert(bitwidth == sourceBitWidth * mostMinorSourceDim &&
         "source and target bitwidths must match");

  // Prepopulate one source element range per target element.
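  // The loop below walks the flat bit index of the cast: each iteration
  // records the largest contiguous bit run that stays within both the current
  // source element and the current target element, then advances by the size
  // of that run.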
  sourceElementRanges =
      SmallVector<SourceElementRangeList>(mostMinorTargetDim);
  for (int64_t resultBit = 0; resultBit < bitwidth;) {
    int64_t resultElement = resultBit / targetBitWidth;
    int64_t resultBitInElement = resultBit % targetBitWidth;
    int64_t sourceElementIdx = resultBit / sourceBitWidth;
    int64_t sourceBitInElement = resultBit % sourceBitWidth;
    int64_t step = std::min(sourceBitWidth - sourceBitInElement,
                            targetBitWidth - resultBitInElement);
    sourceElementRanges[resultElement].push_back(
        {sourceElementIdx, sourceBitInElement, sourceBitInElement + step});
    resultBit += step;
  }
}

BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
                                 VectorType targetVectorType)
    : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
  LDBG("\n" << enumerator.sourceElementRanges);
}

/// Verify that the precondition type meets the common preconditions for any
/// conversion.
static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
                                                  Operation *op) {
  if (!preconditionType || preconditionType.isScalable())
    return rewriter.notifyMatchFailure(op, "scalable vector");

  // TODO: consider relaxing this restriction in the future if we find ways
  // to really work with subbyte elements across the MLIR/LLVM boundary.
  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
  if (bitwidth % 8 != 0)
    return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");

  return success();
}

LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
                                                  VectorType preconditionType,
                                                  Operation *op) {
  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
    return rewriter.notifyMatchFailure(op, "types are not vector");

  if (!preconditionType || preconditionType.getRank() != 1)
    return rewriter.notifyMatchFailure(op, "unsupported >1-D vector");

  return commonConversionPrecondition(rewriter, preconditionType, op);
}

/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
/// means that:
///   1. The `dstType` element type is a multiple of the `subByteVecType`
///      element type (e.g. i4 vs i8 is OK, but i3 vs i8 is not supported).
///      Let this multiple be `N`.
///   2. The number of the (trailing) elements in `subByteVecType` is a
///      multiple of `N` from 1. (e.g., when targeting i8, 2xi4 is OK, but
///      3xi4 is not supported).
///
/// NOTE: This method assumes that common conversion preconditions are met. In
/// particular, the element type of `dstType` is assumed to be a multi-byte
/// type (e.g. i8, i16, i32).
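///
/// Examples (with an i8 `dstType` element type, so `N` = 2 for i4 sources):
///   * 8xi4 -> i8: aligned (8 % 2 == 0).
///   * 3xi4 -> i8: rejected (3 % 2 != 0).
///   * any i3 source -> i8: rejected (i3 is not a supported sub-byte width).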
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
                                                   VectorType subByteVecType,
                                                   VectorType dstType,
                                                   Operation *op) {
  if (!subByteVecType || !dstType)
    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
  unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();

  if (dstElemBitwidth < 8)
    return rewriter.notifyMatchFailure(
        op, "the bitwidth of dstType must be greater than or equal to 8");
  if (dstElemBitwidth % srcElemBitwidth != 0)
    return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
  if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
    return rewriter.notifyMatchFailure(
        op, "only src bitwidth of 2 or 4 is supported at this moment");

  const int numSrcElemsPerByte = 8 / srcElemBitwidth;
  if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0)
    return rewriter.notifyMatchFailure(
        op, "the trailing dimension of the input vector of sub-bytes must be a "
            "multiple of 8 / <sub-byte-width>");

  return success();
}

SmallVector<BitCastRewriter::Metadata>
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
  SmallVector<BitCastRewriter::Metadata> result;
  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
       shuffleIdx < e; ++shuffleIdx) {
    SmallVector<int64_t> shuffles;
    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;

    // Create the attribute quantities for the shuffle / mask / shift ops.
    for (auto &srcEltRangeList : enumerator.sourceElementRanges) {
      int64_t sourceElement = (shuffleIdx < (int64_t)srcEltRangeList.size())
                                  ? srcEltRangeList[shuffleIdx].sourceElementIdx
                                  : 0;
      shuffles.push_back(sourceElement);

      int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitBegin
                          : 0;
      int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
                          ? srcEltRangeList[shuffleIdx].sourceBitEnd
                          : 0;
      IntegerAttr mask = IntegerAttr::get(
          shuffledElementType,
          llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
                                  bitLo, bitHi));
      masks.push_back(mask);

      int64_t shiftRight = bitLo;
      shiftRightAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftRight));

      int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
      shiftLeftAmounts.push_back(
          IntegerAttr::get(shuffledElementType, shiftLeft));
    }

    result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
  }
  return result;
}

Value BitCastRewriter::genericRewriteStep(
    PatternRewriter &rewriter, Location loc, Value initialValue,
    Value runningResult, const BitCastRewriter::Metadata &metadata) {
  // Create vector.shuffle from the metadata.
  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
      loc, initialValue, initialValue, metadata.shuffles);

  // Intersect with the mask.
  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
  auto constOp = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);

  // Align right on 0.
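  // That is, move the selected bit range of each shuffled element down to
  // bit position 0 before shifting it into its final location.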
  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
  Value shiftedRight =
      rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);

  // Shift bits left into their final position.
  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
      loc,
      DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
  Value shiftedLeft =
      rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);

  runningResult =
      runningResult
          ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
          : shiftedLeft;

  return runningResult;
}

/// Bitcasts the aligned `subByteVec` vector to a vector of i8, where "aligned"
/// means it satisfies `alignedConversionPrecondition`.
///
/// Example:
///   vector<16x16xi2> -> vector<16x4xi8>
///   vector<16x16xi4> -> vector<16x8xi8>
static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
                                      Value subByteVec) {
  auto srcVecType = cast<VectorType>(subByteVec.getType());
  int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
  assert(8 % srcBitwidth == 0 &&
         "Unsupported sub-byte type (not a divisor of i8)");
  int64_t numSrcElemsPerByte = 8 / srcBitwidth;
  SmallVector<int64_t> vecShape(srcVecType.getShape());
  // Adjust the last dimension of the vector so the total bit size stays the
  // same.
  vecShape.back() = vecShape.back() / numSrcElemsPerByte;
  auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
  return rewriter.create<vector::BitCastOp>(loc, i8VecType, subByteVec);
}

/// Extracts a signed N-bit sequence from each element of a vector of bytes,
/// starting at the specified bit index.
/// The `bitIdx` starts at 0 from the LSB and moves to the left.
///
/// Example for a single element:
/// Extract numBits=2 starting at bitIdx=2
///   src     = [0 | 1 | 0 | 1 | 1 | 1 | 1 | 0]
///   indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
///   target  = [.   .   .   .   ^   ^   .   .]
///
/// The target sequence is [11] (decimal=-1) as a signed 2-bit integer.
/// So the result should be [11 11 11 11] (decimal=-1) as a signed 8-bit
/// integer.
///
///   src    = [01 01 11 10]
///   shl    = arith.shl(src, 4)   -> [11 10 00 00]
///   result = arith.shrsi(shl, 6) -> [11 11 11 11]
static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter,
                                                  Location loc, Value src,
                                                  int bitIdx, int numBits) {
  auto srcType = cast<VectorType>(src.getType());
  Value shl = src;
  int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
  assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
         "Invalid bitIdx range");
  if (bitsToShiftLeft != 0) {
    Value shiftLeftValues = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
    shl = rewriter.create<arith::ShLIOp>(loc, src, shiftLeftValues);
  }

  int8_t bitsToShiftRight = 8 - numBits;
  Value shiftRightValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
  Value shr = rewriter.create<arith::ShRSIOp>(loc, shl, shiftRightValues);
  return shr;
}

/// Extracts an unsigned N-bit sequence from each element of a vector of bytes,
/// starting at the specified bit index.
/// The `bitIdx` starts at 0 from the LSB and moves to the left.
///
/// Example for a single element:
/// Extract numBits=2 starting at bitIdx=2
///   src     = [0 | 1 | 0 | 1 | 1 | 0 | 1 | 0]
///   indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
///   target  = [.   .   .   .   ^   ^   .   .]
///
/// The target sequence is [10] (decimal=2) as an unsigned 2-bit integer.
/// So the result should be [00 00 00 10] (decimal=2) as an unsigned 8-bit
/// integer.
///
///   src    = [01 01 10 10]
///   mask   = [00 00 00 11]
///   shr    = arith.shrui(src, 2)   = [00 01 01 10]
///   result = arith.andi(shr, mask) = [00 00 00 10]
///
/// NOTE: Similarly to extractNBitsPerByteAndSignExtendToI8, this could be
/// achieved by using arith::ShLIOp + arith::ShRUIOp instead of the masking.
/// However, by using arith::ShRUIOp + arith::AndIOp, we eliminate the shift
/// left when the index is 0.
static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter,
                                              Location loc, Value src,
                                              int bitIdx, int numBits) {
  assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
         "Invalid bitIdx range");
  auto srcType = cast<VectorType>(src.getType());
  int8_t bitsToShiftRight = bitIdx;
  Value shr = src;
  if (bitsToShiftRight != 0) {
    Value shiftRightValues = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
    shr = rewriter.create<arith::ShRUIOp>(loc, src, shiftRightValues);
  }
  if (bitIdx + numBits == 8) {
    return shr;
  }
  uint8_t lowBitsMask = (1 << numBits) - 1;
  Value lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(srcType, lowBitsMask));
  return rewriter.create<arith::AndIOp>(loc, shr, lowBitsMaskValues);
}

using ExtractNBitsFn =
    std::function<Value(PatternRewriter &, Location, Value, int, int)>;

/// Rewrite the i4 -> i8 extension into a sequence of shuffles and
/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
                              Value srcValue, const ExtractNBitsFn &extFn) {
  [[maybe_unused]] auto srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(4) &&
         "Expected i4 type");

  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);

  // 2. Extend i4 elements to i8 elements. The low i4 elements of each byte
  //    are placed in one vector and the high i4 elements in another vector.
  Value low = extFn(rewriter, loc, i8Vector, 0, 4);
  Value high = extFn(rewriter, loc, i8Vector, 4, 4);

  // 3. Interleave low and high i8 elements.
  return rewriter.create<vector::InterleaveOp>(loc, low, high);
}

/// Rewrite the i2 -> i8 extension into a sequence of shuffles and
/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
                              Value srcValue, const ExtractNBitsFn &extFn) {
  [[maybe_unused]] VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(2) &&
         "Expected i2 type");

  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
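  //    (e.g., vector<8xi2> -> vector<2xi8>: four i2 values are packed into
  //    each byte.)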
  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);

  // 2. Extract each i2 element.
  // Position 0 (bits 0-1)
  Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
  // Position 1 (bits 2-3)
  Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
  // Position 2 (bits 4-5)
  Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
  // Position 3 (bits 6-7)
  Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);

  // 3. Interleave all 4 elements by first interleaving the even elements and
  //    then the odd ones:
  // vec0 = [0,0,0,0],...
  // vec1 = [1,1,1,1],...
  // vec2 = [2,2,2,2],...
  // vec3 = [3,3,3,3],...
  // 02   = [0,2,0,2,0,2,0,2],...
  // 13   = [1,3,1,3,1,3,1,3],...
  // 0213 = [0,1,2,3,...],...
  Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, vec0, vec2);
  Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, vec1, vec3);
  return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
}

/// Rewrite the i8 -> i4 truncation into a deinterleave and a series of bitwise
/// ops to avoid leaving LLVM to scramble with peephole optimizations.
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
                                Value srcValue) {
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(8) &&
         "Expected i8 type");

  // 1. De-interleave low and high i8 elements.
  auto deinterleaveOp = rewriter.create<vector::DeinterleaveOp>(loc, srcValue);

  // 2. Zero out the upper side of each low i8 element.
  constexpr int8_t i8LowBitMask = 0x0F;
  VectorType deinterI8VecType = deinterleaveOp.getResultVectorType();
  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(deinterI8VecType, i8LowBitMask));
  Value zeroOutLow = rewriter.create<arith::AndIOp>(
      loc, deinterleaveOp.getRes1(), zeroOutMask);

  // 3. Move high i4 values to the upper side of the byte.
  constexpr int8_t bitsToShift = 4;
  auto shiftValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
  Value shlHigh = rewriter.create<arith::ShLIOp>(loc, deinterleaveOp.getRes2(),
                                                 shiftValues);

  // 4. Merge high and low i4 values.
  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);

  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
}

namespace {
/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
/// advantage of high-level information to avoid leaving LLVM to scramble with
/// peephole optimizations.
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
                                PatternRewriter &rewriter) const override {
    // The source must be a trunc op.
    auto truncOp =
        bitCastOp.getSource().template getDefiningOp<arith::TruncIOp>();
    if (!truncOp)
      return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");

    // Set up the BitCastRewriter and verify the precondition.
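    // The preconditions reject scalable vectors, vectors of rank > 1, and
    // result element types whose bitwidth is not a multiple of 8.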
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
      return failure();

    // Perform the rewrite.
    Value truncValue = truncOp.getIn();
    auto shuffledElementType =
        cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
    Value runningResult;
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
    }

    // Finalize the rewrite.
    bool narrowing = targetVectorType.getElementTypeBitWidth() <=
                     shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
      } else {
        rewriter.replaceOpWithNewOp<arith::TruncIOp>(
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      }
    } else {
      if (runningResult.getType() == bitCastOp.getResultVectorType()) {
        rewriter.replaceOp(bitCastOp, runningResult);
      } else {
        rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
            bitCastOp, bitCastOp.getResultVectorType(), runningResult);
      }
    }

    return success();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// RewriteExtOfBitCast
//===----------------------------------------------------------------------===//

namespace {
/// Rewrite ext{s,u}i(bitcast) to a sequence of shuffles and bitwise ops that
/// take advantage of high-level information to avoid leaving LLVM to scramble
/// with peephole optimizations.
template <typename ExtOpType>
struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
  using OpRewritePattern<ExtOpType>::OpRewritePattern;

  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
      : OpRewritePattern<ExtOpType>(context, benefit) {}

  LogicalResult matchAndRewrite(ExtOpType extOp,
                                PatternRewriter &rewriter) const override {
    // The source must be a bitcast op.
    auto bitCastOp = extOp.getIn().template getDefiningOp<vector::BitCastOp>();
    if (!bitCastOp)
      return rewriter.notifyMatchFailure(extOp, "not a bitcast source");

    // Set up the BitCastRewriter and verify the precondition.
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
    if (failed(bcr.commonPrecondition(
            rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
      return failure();

    // Perform the rewrite.
    Value runningResult;
    Value sourceValue = bitCastOp.getSource();
    auto shuffledElementType =
        cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
      runningResult = bcr.genericRewriteStep(
          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
    }

    // Finalize the rewrite.
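    // If the ext result element type is no wider than the shuffled element
    // type, narrow with a trunci; otherwise widen by re-applying the original
    // extension op.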
    bool narrowing =
        cast<VectorType>(extOp.getOut().getType()).getElementTypeBitWidth() <=
        shuffledElementType.getIntOrFloatBitWidth();
    if (narrowing) {
      rewriter.replaceOpWithNewOp<arith::TruncIOp>(
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
    } else {
      rewriter.replaceOpWithNewOp<ExtOpType>(
          extOp, cast<VectorType>(extOp.getOut().getType()), runningResult);
    }

    return success();
  }
};

/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
/// bitwise ops that take advantage of high-level information to avoid leaving
/// LLVM to scramble with peephole optimizations. Templated to choose between
/// signed and unsigned conversions.
///
/// For example (signed):
///    arith.extsi %in : vector<8xi4> to vector<8xi32>
///      is rewritten as
///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
///        %1 = arith.shli %0, 4 : vector<4xi8>
///        %2 = arith.shrsi %1, 4 : vector<4xi8>
///        %3 = arith.shrsi %0, 4 : vector<4xi8>
///        %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
///        %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
///
///    arith.sitofp %in : vector<8xi4> to vector<8xf32>
///      is rewritten as
///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
///        %1 = arith.shli %0, 4 : vector<4xi8>
///        %2 = arith.shrsi %1, 4 : vector<4xi8>
///        %3 = arith.shrsi %0, 4 : vector<4xi8>
///        %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
///        %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
///
/// Example (unsigned):
///    arith.extui %in : vector<8xi4> to vector<8xi32>
///      is rewritten as
///        %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
///        %1 = arith.andi %0, 15 : vector<4xi8>
///        %2 = arith.shrui %0, 4 : vector<4xi8>
///        %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
///        %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
///
template <typename ConversionOpType, bool isSigned>
struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
  using OpRewritePattern<ConversionOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                PatternRewriter &rewriter) const override {
    // Verify the preconditions.
    Value srcValue = conversionOp.getIn();
    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());

    if (failed(
            commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
      return failure();

    // Check general alignment preconditions.
    if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
                                             conversionOp)))
      return failure();

    // Perform the rewrite.
    Location loc = conversionOp.getLoc();
    const auto &extFn = isSigned ? extractNBitsPerByteAndSignExtendToI8
                                 : extractNBitsPerByteAndExtendToI8;
    Value subByteExt;
    switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
    case 2:
      subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn);
      break;
    case 4:
      subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn);
      break;
    default:
      return failure();
    }

    // Finalize the rewrite.
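    // The sub-byte values now live in i8 elements; the original conversion op
    // handles the remaining i8 -> result-type step (e.g. i8 -> i32 or
    // i8 -> f32).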
    rewriter.replaceOpWithNewOp<ConversionOpType>(
        conversionOp, conversionOp.getType(), subByteExt);
    return success();
  }
};

/// Rewrite the i8 -> i4 part of any truncation into a deinterleave and
/// bitwise ops that take advantage of high-level information to avoid leaving
/// LLVM to scramble with peephole optimizations.
///
/// For example:
///    arith.trunci %in : vector<8xi32> to vector<8xi4>
///      is rewritten as
///
///        %cst = arith.constant dense<15> : vector<4xi8>
///        %cst_0 = arith.constant dense<4> : vector<4xi8>
///        %tr = arith.trunci %in : vector<8xi32> to vector<8xi8>
///        %0, %1 = vector.deinterleave %tr : vector<8xi8> -> vector<4xi8>
///        %2 = arith.andi %0, %cst : vector<4xi8>
///        %3 = arith.shli %1, %cst_0 : vector<4xi8>
///        %4 = arith.ori %2, %3 : vector<4xi8>
///        %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
///
struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
  using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
                                PatternRewriter &rewriter) const override {
    // Verify the preconditions.
    Value srcValue = truncOp.getIn();
    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
    auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
    if (!srcVecType || !dstVecType)
      return failure();

    if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
      return failure();

    // TODO: Add support for truncating to i2.
    if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
      return failure();

    // Check general alignment preconditions. We invert the src/dst type order
    // to reuse the existing precondition logic.
    if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
                                             truncOp)))
      return failure();

    // Create a new iX -> i8 truncation op.
    Location loc = truncOp.getLoc();
    auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
    Value i8TruncVal =
        rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);

    // Rewrite the i8 -> i4 truncation part.
    Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);

    // Finalize the rewrite.
    rewriter.replaceOp(truncOp, subByteTrunc);
    return success();
  }
};

/// Rewrite a sub-byte vector transpose into a sequence of instructions that
/// perform the transpose on wider (byte) element types.
/// For example:
///   %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
///
/// is rewritten as:
///
///   %0 = arith.extsi %a : vector<8x16xi4> to vector<8x16xi8>
///   %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
///   %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
///
struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // Precondition: sub-byte integer transpose.
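    // That is, the element type must be a signless integer narrower than 8
    // bits (e.g. i2 or i4).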
    constexpr unsigned minNativeBitwidth = 8;
    VectorType srcSubByteVecType = transposeOp.getSourceVectorType();
    if (!srcSubByteVecType.getElementType().isSignlessInteger() ||
        srcSubByteVecType.getElementTypeBitWidth() >= minNativeBitwidth) {
      return rewriter.notifyMatchFailure(transposeOp,
                                         "not a sub-byte transpose");
    }

    // Perform the rewrite.
    Location loc = transposeOp.getLoc();
    // Signed/unsigned interpretation shouldn't matter here as we are just
    // transposing the elements and truncating them back to the original size.
    // TODO: Use unsigned extension (more efficient) when emulation or backend
    // support is available.
    auto srcNativeVecType = srcSubByteVecType.cloneWith(
        std::nullopt, rewriter.getIntegerType(minNativeBitwidth));
    Value extOp = rewriter.create<arith::ExtSIOp>(loc, srcNativeVecType,
                                                  transposeOp.getVector());
    Value newTranspose = rewriter.create<vector::TransposeOp>(
        loc, extOp, transposeOp.getPermutation());
    VectorType dstSubByteVecType = transposeOp.getResultVectorType();
    rewriter.replaceOpWithNewOp<arith::TruncIOp>(transposeOp, dstSubByteVecType,
                                                 newTranspose);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Public Interface Definition
//===----------------------------------------------------------------------===//

void vector::populateVectorNarrowTypeEmulationPatterns(
    const arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns) {

  // Populate `vector.*` conversion patterns.
  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
               ConvertVectorMaskedStore, ConvertVectorTransferRead>(
      typeConverter, patterns.getContext());
}

void vector::populateVectorNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
               RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                    benefit);

  // Patterns for aligned cases. We give these a higher benefit because they
  // are expected to generate better code for aligned cases.
  patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                              benefit.getBenefit() + 1);
  patterns
      .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
           RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
          patterns.getContext(), benefit.getBenefit() + 1);
}

void vector::populateVectorTransposeNarrowTypeRewritePatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
}
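
// Usage sketch (illustrative only; the enclosing pass, variable names, and
// error handling below are assumptions, not part of this file's API). A pass
// would typically collect these patterns and apply them greedily:
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorNarrowTypeRewritePatterns(patterns, /*benefit=*/1);
//   vector::populateVectorTransposeNarrowTypeRewritePatterns(patterns,
//                                                            /*benefit=*/1);
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     signalPassFailure();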