1 //===- VectorUnrollDistribute.cpp - patterns to do vector unrolling -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to do vector unrolling and vector distribution. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Utils/IndexingUtils.h" 15 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 16 #include "mlir/IR/ImplicitLocOpBuilder.h" 17 #include "mlir/Interfaces/VectorInterfaces.h" 18 #include "llvm/ADT/MapVector.h" 19 #include "llvm/ADT/STLExtras.h" 20 #include "llvm/Support/Debug.h" 21 #include <numeric> 22 #include <optional> 23 24 #define DEBUG_TYPE "vector-unroll" 25 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") 26 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") 27 28 using namespace mlir; 29 using namespace mlir::vector; 30 31 /// Compute the indices of the slice `index` for a tranfer op. 32 static SmallVector<Value> sliceTransferIndices(ArrayRef<int64_t> elementOffsets, 33 ArrayRef<Value> indices, 34 AffineMap permutationMap, 35 Location loc, 36 OpBuilder &builder) { 37 MLIRContext *ctx = builder.getContext(); 38 auto isBroadcast = [](AffineExpr expr) { 39 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) 40 return constExpr.getValue() == 0; 41 return false; 42 }; 43 // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'. 
44 SmallVector<Value> slicedIndices(indices); 45 for (const auto &dim : llvm::enumerate(permutationMap.getResults())) { 46 if (isBroadcast(dim.value())) 47 continue; 48 unsigned pos = cast<AffineDimExpr>(dim.value()).getPosition(); 49 auto expr = getAffineDimExpr(0, builder.getContext()) + 50 getAffineConstantExpr(elementOffsets[dim.index()], ctx); 51 auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); 52 slicedIndices[pos] = 53 builder.create<affine::AffineApplyOp>(loc, map, indices[pos]); 54 } 55 return slicedIndices; 56 } 57 58 // Clones `op` into a new operations that takes `operands` and returns 59 // `resultTypes`. 60 static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, 61 Operation *op, 62 ArrayRef<Value> operands, 63 ArrayRef<Type> resultTypes) { 64 return builder.create(loc, op->getName().getIdentifier(), operands, 65 resultTypes, op->getAttrs()); 66 } 67 68 /// Return the target shape for unrolling for the given `op`. Return 69 /// std::nullopt if the op shouldn't be or cannot be unrolled. 
static std::optional<SmallVector<int64_t>>
getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) {
  LDBG("");
  LDBG("Get unroll shape for op " << op->getName().getStringRef());
  // Respect the user-provided filter: ops it rejects are left untouched.
  if (options.filterConstraint && failed(options.filterConstraint(op))) {
    LDBG("--no filter constraint -> BAIL");
    return std::nullopt;
  }
  assert(options.nativeShape &&
         "vector unrolling expects the native shape or native"
         "shape call back function to be set");
  auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
  if (!unrollableVectorOp) {
    LDBG("--not an unrollable op -> BAIL");
    return std::nullopt;
  }
  auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
  if (!maybeUnrollShape) {
    LDBG("--could not get shape of op " << *op << " -> BAIL");
    return std::nullopt;
  }
  LLVM_DEBUG(
      llvm::interleaveComma(*maybeUnrollShape, DBGS() << "--vector op shape: ");
      llvm::dbgs() << "\n";);

  std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op);
  if (!targetShape) {
    LDBG("--no unrolling target shape defined " << *op << "-> SKIP");
    return std::nullopt;
  }
  LLVM_DEBUG(llvm::interleaveComma(*targetShape, DBGS() << "--target shape: ");
             llvm::dbgs() << "\n";);

  // The op shape must be an integral multiple of the target shape along every
  // dimension, otherwise unrolling cannot tile it exactly.
  auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape);
  if (!maybeShapeRatio) {
    LDBG("--could not compute integral shape ratio -> BAIL");
    return std::nullopt;
  }
  // A ratio of all ones means the op is already at the native shape.
  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
    LDBG("--no unrolling needed -> SKIP");
    return std::nullopt;
  }
  LDBG("--found an integral shape ratio to unroll to -> SUCCESS");
  return targetShape;
}

/// Return the order in which the `numLoops` unroll loops of `op` should be
/// traversed: identity order unless `options.traversalOrderCallback` supplies
/// a custom order for this op.
static SmallVector<int64_t>
getUnrollOrder(unsigned numLoops, Operation *op,
               const vector::UnrollVectorOptions &options) {
  SmallVector<int64_t> loopOrder =
      llvm::to_vector(llvm::seq<int64_t>(0, static_cast<int64_t>(numLoops)));
  if (options.traversalOrderCallback != nullptr) {
    std::optional<SmallVector<int64_t>> order =
        options.traversalOrderCallback(op);
    if (order) {
      loopOrder = std::move(*order);
    }
  }
  return loopOrder;
}

namespace {

/// Unrolls vector.transfer_read into reads of the target shape whose results
/// are reassembled with vector.insert_strided_slice.
struct UnrollTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  UnrollTransferReadPattern(MLIRContext *context,
                            const vector::UnrollVectorOptions &options,
                            PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (readOp.getTransferRank() == 0)
      return failure();
    // Masked transfers are not supported by this unrolling.
    if (readOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, readOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = readOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = readOp.getLoc();
    ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());
    SmallVector<Value> originalIndices(readOp.getIndices().begin(),
                                       readOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), readOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      // Shift the transfer indices by the tile offsets, then read one tile.
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               readOp.getPermutationMap(), loc, rewriter);
      auto slicedRead = rewriter.create<vector::TransferReadOp>(
          loc, targetType, readOp.getSource(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
          readOp.getInBoundsAttr());

      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedRead, result, elementOffsets, strides);
    }
    rewriter.replaceOp(readOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// Unrolls vector.transfer_write into writes of the target shape. For tensor
/// destinations, each sliced write is chained through the previous write's
/// result tensor.
struct UnrollTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  UnrollTransferWritePattern(MLIRContext *context,
                             const vector::UnrollVectorOptions &options,
                             PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (writeOp.getTransferRank() == 0)
      return failure();

    // Masked transfers are not supported by this unrolling.
    if (writeOp.getMask())
      return failure();
    auto targetShape = getTargetShape(options, writeOp);
    if (!targetShape)
      return failure();
    auto sourceVectorType = writeOp.getVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = writeOp.getLoc();
    ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
    SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
                                       writeOp.getIndices().end());
    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), writeOp, options);
    // Non-null only when writing into a tensor: holds the most recent result.
    Value resultTensor;
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      Value slicedVector = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, writeOp.getVector(), elementOffsets, *targetShape, strides);
      SmallVector<Value> indices =
          sliceTransferIndices(elementOffsets, originalIndices,
                               writeOp.getPermutationMap(), loc, rewriter);
      Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
          loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
          indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      // For the tensor case update the destination for the next transfer write.
      if (!slicedWrite->getResults().empty())
        resultTensor = slicedWrite->getResult(0);
    }
    if (resultTensor)
      rewriter.replaceOp(writeOp, resultTensor);
    else
      rewriter.eraseOp(writeOp);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// DenseMap info allowing SmallVector<int64_t> offset vectors to be used as
/// keys in the per-offset accumulator caches below.
struct OffsetMapInfo {
  static SmallVector<int64_t> getEmptyKey() { return {int64_t(-1)}; }

  static SmallVector<int64_t> getTombstoneKey() { return {int64_t(-2)}; }

  static unsigned getHashValue(const SmallVector<int64_t> &v) {
    return static_cast<unsigned>(llvm::hash_combine_range(v.begin(), v.end()));
  }

  static bool isEqual(const SmallVector<int64_t> &lhs,
                      const SmallVector<int64_t> &rhs) {
    return lhs == rhs;
  }
};

/// Unrolls vector.contract into contractions of the target shape, caching the
/// partial accumulator for each destination offset and reassembling all
/// partial results at the end.
struct UnrollContractionPattern
    : public OpRewritePattern<vector::ContractionOp> {
  UnrollContractionPattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ContractionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    auto targetShape = getTargetShape(options, contractOp);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(contractOp.getResultType());
    SmallVector<int64_t> originalSize = *contractOp.getShapeForUnroll();

    Location loc = contractOp.getLoc();
    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
    AffineMap dstAffineMap = contractOp.getIndexingMapsArray()[accIndex];
    // Maps destination offsets to the current partial accumulator for that
    // tile; MapVector keeps insertion order for deterministic reassembly.
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;

    SmallVector<int64_t> loopOrder = getUnrollOrder(
        contractOp.getIteratorTypes().size(), contractOp, options);

    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      SmallVector<Value> slicesOperands(contractOp.getNumOperands());

      // Helper to compute the new shape of each operand and extract the slice.
      auto extractOperand = [&](unsigned index, Value operand,
                                AffineMap permutationMap,
                                ArrayRef<int64_t> operandOffets) {
        SmallVector<int64_t> operandShape = applyPermutationMap(
            permutationMap, ArrayRef<int64_t>(*targetShape));
        SmallVector<int64_t> operandStrides(operandOffets.size(), 1);
        slicesOperands[index] = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, operand, operandOffets, operandShape, operandStrides);
      };

      // Extract the new lhs operand.
      AffineMap lhsPermutationMap = contractOp.getIndexingMapsArray()[0];
      SmallVector<int64_t> lhsOffets =
          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(0, contractOp.getLhs(), lhsPermutationMap, lhsOffets);

      // Extract the new rhs operand.
      AffineMap rhsPermutationMap = contractOp.getIndexingMapsArray()[1];
      SmallVector<int64_t> rhsOffets =
          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
      extractOperand(1, contractOp.getRhs(), rhsPermutationMap, rhsOffets);

      AffineMap accPermutationMap = contractOp.getIndexingMapsArray()[2];
      SmallVector<int64_t> accOffets =
          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
      // If a version of the accumulator has already been computed, use it
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(accOffets);
      if (accIt != accCache.end())
        slicesOperands[2] = accIt->second;
      else
        extractOperand(2, contractOp.getAcc(), accPermutationMap, accOffets);

      SmallVector<int64_t> dstShape =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(*targetShape));
      auto targetType = VectorType::get(dstShape, dstVecType.getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, contractOp, slicesOperands, targetType);

      SmallVector<int64_t> dstOffets =
          applyPermutationMap(dstAffineMap, ArrayRef<int64_t>(offsets));
      // Save the accumulated value until all the loops are unrolled since the
      // reduction loop keeps updating the accumulator.
      accCache[dstOffets] = newOp->getResult(0);
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(contractOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// Unrolls vector.multi_reduction into reductions of the target shape,
/// accumulating the slices that reduce into the same destination offset.
struct UnrollMultiReductionPattern
    : public OpRewritePattern<vector::MultiDimReductionOp> {
  UnrollMultiReductionPattern(MLIRContext *context,
                              const vector::UnrollVectorOptions &options,
                              PatternBenefit benefit = 1)
      : OpRewritePattern<vector::MultiDimReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
    // Maps destination offsets to the current partial accumulator for that
    // tile (see OffsetMapInfo above).
    llvm::MapVector<
        SmallVector<int64_t>, Value,
        llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
        accCache;
    Location loc = reductionOp.getLoc();

    // Stride of the ratios, this gives us the offsets of sliceCount in a basis
    // of multiples of the targetShape.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> operands;
      SmallVector<int64_t> operandStrides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getSource(), offsets, *targetShape, operandStrides);
      operands.push_back(slicedOperand);
      SmallVector<int64_t> dstShape;
      SmallVector<int64_t> destOffset;
      // Keep only the non-reduced dimensions for the destination slice.
      for (size_t i : llvm::seq(size_t(0), targetShape->size())) {
        if (!reductionOp.isReducedDim(i)) {
          destOffset.push_back(offsets[i]);
          dstShape.push_back((*targetShape)[i]);
        }
      }
      Value acc;
      SmallVector<int64_t> accStrides(destOffset.size(), 1);
      // If a version of the accumulator has already been computed, use it
      // otherwise extract the first version from the original operand.
      auto *accIt = accCache.find(destOffset);
      if (accIt != accCache.end())
        acc = accIt->second;
      else
        acc = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, reductionOp.getAcc(), destOffset, dstShape, accStrides);
      operands.push_back(acc);
      auto targetType = VectorType::get(
          dstShape, reductionOp.getSourceVectorType().getElementType());
      Operation *newOp = cloneOpWithOperandsAndTypes(rewriter, loc, reductionOp,
                                                     operands, targetType);
      Value result = newOp->getResult(0);
      accCache[destOffset] = result;
    }
    // Assemble back the accumulator into a single vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, reductionOp.getDestType(),
        rewriter.getZeroAttr(reductionOp.getDestType()));
    for (const auto &it : accCache) {
      SmallVector<int64_t> dstStrides(it.first.size(), 1);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, it.second, result, it.first, dstStrides);
    }
    rewriter.replaceOp(reductionOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// Unrolls any op with elementwise-mappable traits and a single vector result
/// into ops of the target shape.
struct UnrollElementwisePattern : public RewritePattern {
  UnrollElementwisePattern(MLIRContext *context,
                           const vector::UnrollVectorOptions &options,
                           PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
        options(options) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
      return failure();
    auto targetShape = getTargetShape(options, op);
    if (!targetShape)
      return failure();
    auto dstVecType = cast<VectorType>(op->getResult(0).getType());
    SmallVector<int64_t> originalSize =
        *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
    Location loc = op->getLoc();
    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, dstVecType, rewriter.getZeroAttr(dstVecType));
    SmallVector<int64_t> strides(targetShape->size(), 1);
    VectorType newVecType =
        VectorType::get(*targetShape, dstVecType.getElementType());

    // Create the unrolled computation.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<Value> extractOperands;
      for (OpOperand &operand : op->getOpOperands()) {
        auto vecType = dyn_cast<VectorType>(operand.get().getType());
        // Non-vector (scalar) operands are forwarded as-is; only vector
        // operands get sliced.
        if (!vecType) {
          extractOperands.push_back(operand.get());
          continue;
        }
        extractOperands.push_back(
            rewriter.create<vector::ExtractStridedSliceOp>(
                loc, operand.get(), offsets, *targetShape, strides));
      }
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, op, extractOperands, newVecType);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, newOp->getResult(0), result, offsets, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// Unrolls vector.reduction into reductions over slices of the target shape,
/// combining the partial results with makeArithReduction.
struct UnrollReductionPattern : public OpRewritePattern<vector::ReductionOp> {
  UnrollReductionPattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::ReductionOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::ReductionOp reductionOp,
                                PatternRewriter &rewriter) const override {
    std::optional<SmallVector<int64_t>> targetShape =
        getTargetShape(options, reductionOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();

    // Create unrolled vector reduction.
    Location loc = reductionOp.getLoc();
    Value accumulator = nullptr;
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> strides(offsets.size(), 1);
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, reductionOp.getVector(), offsets, *targetShape, strides);
      Operation *newOp = cloneOpWithOperandsAndTypes(
          rewriter, loc, reductionOp, slicedOperand, reductionOp.getType());
      Value result = newOp->getResult(0);

      if (!accumulator) {
        // This is the first reduction.
        accumulator = result;
      } else {
        // On subsequent reduction, combine with the accumulator.
        accumulator = makeArithReduction(rewriter, loc, reductionOp.getKind(),
                                         accumulator, result);
      }
    }

    rewriter.replaceOp(reductionOp, accumulator);
    return success();
  }

private:
  const vector::UnrollVectorOptions options;
};

/// Unrolls vector.transpose into transposes of the target shape, mapping each
/// result tile back to its (permuted) source tile.
struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
  UnrollTransposePattern(MLIRContext *context,
                         const vector::UnrollVectorOptions &options,
                         PatternBenefit benefit = 1)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (transposeOp.getResultVectorType().getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, transposeOp);
    if (!targetShape)
      return failure();
    auto originalVectorType = transposeOp.getResultVectorType();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = transposeOp.getLoc();
    ArrayRef<int64_t> originalSize = originalVectorType.getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, originalVectorType, rewriter.getZeroAttr(originalVectorType));
    ArrayRef<int64_t> permutation = transposeOp.getPermutation();

    // Unroll the computation.
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape)) {
      SmallVector<int64_t> permutedOffsets(elementOffsets.size());
      SmallVector<int64_t> permutedShape(elementOffsets.size());
      // Compute the source offsets and shape.
      for (auto indices : llvm::enumerate(permutation)) {
        permutedOffsets[indices.value()] = elementOffsets[indices.index()];
        permutedShape[indices.value()] = (*targetShape)[indices.index()];
      }
      Value slicedOperand = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, transposeOp.getVector(), permutedOffsets, permutedShape,
          strides);
      Value transposedSlice =
          rewriter.create<vector::TransposeOp>(loc, slicedOperand, permutation);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, transposedSlice, result, elementOffsets, strides);
    }
    rewriter.replaceOp(transposeOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

/// Unrolls vector.gather into gathers of the target shape by slicing the
/// index, mask, and pass-through vectors identically.
struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
  UnrollGatherPattern(MLIRContext *context,
                      const vector::UnrollVectorOptions &options,
                      PatternBenefit benefit = 1)
      : OpRewritePattern<vector::GatherOp>(context, benefit), options(options) {
  }

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceVectorType = gatherOp.getVectorType();
    if (sourceVectorType.getRank() == 0)
      return failure();
    auto targetShape = getTargetShape(options, gatherOp);
    if (!targetShape)
      return failure();
    SmallVector<int64_t> strides(targetShape->size(), 1);
    Location loc = gatherOp.getLoc();
    ArrayRef<int64_t> originalSize = gatherOp.getVectorType().getShape();

    // Prepare the result vector.
    Value result = rewriter.create<arith::ConstantOp>(
        loc, sourceVectorType, rewriter.getZeroAttr(sourceVectorType));
    auto targetType =
        VectorType::get(*targetShape, sourceVectorType.getElementType());

    SmallVector<int64_t> loopOrder =
        getUnrollOrder(originalSize.size(), gatherOp, options);
    for (SmallVector<int64_t> elementOffsets :
         StaticTileOffsetRange(originalSize, *targetShape, loopOrder)) {
      // To get the unrolled gather, extract the same slice based on the
      // decomposed shape from each of the index, mask, and pass-through
      // vectors.
      Value indexSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides);
      Value maskSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getMask(), elementOffsets, *targetShape, strides);
      Value passThruSubVec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides);
      auto slicedGather = rewriter.create<vector::GatherOp>(
          loc, targetType, gatherOp.getBase(), gatherOp.getIndices(),
          indexSubVec, maskSubVec, passThruSubVec);

      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, slicedGather, result, elementOffsets, strides);
    }
    rewriter.replaceOp(gatherOp, result);
    return success();
  }

private:
  vector::UnrollVectorOptions options;
};

} // namespace

void mlir::vector::populateVectorUnrollPatterns(
    RewritePatternSet &patterns, const UnrollVectorOptions &options,
    PatternBenefit benefit) {
  patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
               UnrollContractionPattern, UnrollElementwisePattern,
               UnrollReductionPattern, UnrollMultiReductionPattern,
               UnrollTransposePattern, UnrollGatherPattern>(
      patterns.getContext(), options, benefit);
}