1 //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements lowering of vector transfer operations to SCF. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include <numeric> 14 #include <optional> 15 #include <type_traits> 16 17 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 18 19 #include "mlir/Dialect/Affine/IR/AffineOps.h" 20 #include "mlir/Dialect/Arith/IR/Arith.h" 21 #include "mlir/Dialect/MemRef/IR/MemRef.h" 22 #include "mlir/Dialect/SCF/IR/SCF.h" 23 #include "mlir/Dialect/Tensor/IR/Tensor.h" 24 #include "mlir/Dialect/Vector/IR/VectorOps.h" 25 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 26 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 27 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 28 #include "mlir/IR/Builders.h" 29 #include "mlir/IR/ImplicitLocOpBuilder.h" 30 #include "mlir/Pass/Pass.h" 31 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 32 #include "mlir/Transforms/Passes.h" 33 34 namespace mlir { 35 #define GEN_PASS_DEF_CONVERTVECTORTOSCF 36 #include "mlir/Conversion/Passes.h.inc" 37 } // namespace mlir 38 39 using namespace mlir; 40 using vector::TransferReadOp; 41 using vector::TransferWriteOp; 42 43 namespace { 44 45 /// Attribute name used for labeling transfer ops during progressive lowering. 46 static const char kPassLabel[] = "__vector_to_scf_lowering__"; 47 48 /// Return true if this transfer op operates on a source tensor. 49 static bool isTensorOp(VectorTransferOpInterface xferOp) { 50 if (isa<RankedTensorType>(xferOp.getShapedType())) { 51 if (isa<vector::TransferWriteOp>(xferOp)) { 52 // TransferWriteOps on tensors have a result. 53 assert(xferOp->getNumResults() > 0); 54 } 55 return true; 56 } 57 return false; 58 } 59 60 /// Patterns that inherit from this struct have access to 61 /// VectorTransferToSCFOptions. 62 template <typename OpTy> 63 struct VectorToSCFPattern : public OpRewritePattern<OpTy> { 64 explicit VectorToSCFPattern(MLIRContext *context, 65 VectorTransferToSCFOptions opt) 66 : OpRewritePattern<OpTy>(context), options(opt) {} 67 68 LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp, 69 PatternRewriter &rewriter) const { 70 if (isTensorOp(xferOp) && !options.lowerTensors) { 71 return rewriter.notifyMatchFailure( 72 xferOp, "lowering tensor transfers is disabled"); 73 } 74 return success(); 75 } 76 77 VectorTransferToSCFOptions options; 78 }; 79 80 /// Given a vector transfer op, calculate which dimension of the `source` 81 /// memref should be unpacked in the next application of TransferOpConversion. 82 /// A return value of std::nullopt indicates a broadcast. 83 template <typename OpTy> 84 static std::optional<int64_t> unpackedDim(OpTy xferOp) { 85 // TODO: support 0-d corner case. 86 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer"); 87 auto map = xferOp.getPermutationMap(); 88 if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) { 89 return expr.getPosition(); 90 } 91 assert(xferOp.isBroadcastDim(0) && 92 "Expected AffineDimExpr or AffineConstantExpr"); 93 return std::nullopt; 94 } 95 96 /// Compute the permutation map for the new (N-1)-D vector transfer op. 
This 97 /// map is identical to the current permutation map, but the first result is 98 /// omitted. 99 template <typename OpTy> 100 static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) { 101 // TODO: support 0-d corner case. 102 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer"); 103 auto map = xferOp.getPermutationMap(); 104 return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(), 105 b.getContext()); 106 } 107 108 /// Calculate the indices for the new vector transfer op. 109 /// 110 /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ... 111 /// --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3f32> 112 /// ^^^^^^ 113 /// `iv` is the iteration variable of the (new) surrounding loop. 114 template <typename OpTy> 115 static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv, 116 SmallVector<Value, 8> &indices) { 117 typename OpTy::Adaptor adaptor(xferOp); 118 // Corresponding memref dim of the vector dim that is unpacked. 119 auto dim = unpackedDim(xferOp); 120 auto prevIndices = adaptor.getIndices(); 121 indices.append(prevIndices.begin(), prevIndices.end()); 122 123 Location loc = xferOp.getLoc(); 124 bool isBroadcast = !dim.has_value(); 125 if (!isBroadcast) { 126 AffineExpr d0, d1; 127 bindDims(xferOp.getContext(), d0, d1); 128 Value offset = adaptor.getIndices()[*dim]; 129 indices[*dim] = 130 affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv}); 131 } 132 } 133 134 static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal, 135 Value value) { 136 if (hasRetVal) { 137 assert(value && "Expected non-empty value"); 138 b.create<scf::YieldOp>(loc, value); 139 } else { 140 b.create<scf::YieldOp>(loc); 141 } 142 } 143 144 /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask 145 /// is set to true. No such check is generated under following circumstances: 146 /// * xferOp does not have a mask. 147 /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is 148 /// computed and attached to the new transfer op in the pattern.) 149 /// * The to-be-unpacked dim of xferOp is a broadcast. 150 template <typename OpTy> 151 static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) { 152 if (!xferOp.getMask()) 153 return Value(); 154 if (xferOp.getMaskType().getRank() != 1) 155 return Value(); 156 if (xferOp.isBroadcastDim(0)) 157 return Value(); 158 159 Location loc = xferOp.getLoc(); 160 return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv); 161 } 162 163 /// Helper function TransferOpConversion and TransferOp1dConversion. 164 /// Generate an in-bounds check if the transfer op may go out-of-bounds on the 165 /// specified dimension `dim` with the loop iteration variable `iv`. 166 /// E.g., when unpacking dimension 0 from: 167 /// ``` 168 /// %vec = vector.transfer_read %A[%a, %b] %cst 169 /// : vector<5x4xf32>, memref<?x?xf32> 170 /// ``` 171 /// An if check similar to this will be generated inside the loop: 172 /// ``` 173 /// %d = memref.dim %A, %c0 : memref<?x?xf32> 174 /// if (%a + iv < %d) { 175 /// (in-bounds case) 176 /// } else { 177 /// (out-of-bounds case) 178 /// } 179 /// ``` 180 /// 181 /// If the transfer is 1D and has a mask, this function generates a more complex 182 /// check also accounts for potentially masked out elements. 183 /// 184 /// This function variant returns the value returned by `inBoundsCase` or 185 /// `outOfBoundsCase`. The MLIR type of the return value must be specified in 186 /// `resultTypes`. 
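/// For illustration only, a rough sketch (SSA names invented here) of the IR
/// this helper emits when both the bounds check and a 1-D mask check apply:
/// ```
/// %dim = memref.dim %A, %c0 : memref<?x?xf32>
/// %idx = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%a, %iv)
/// %in_bounds = arith.cmpi sgt, %dim, %idx : index
/// %mask_bit = vector.extractelement %mask[%iv : index] : vector<5xi1>
/// %cond = arith.andi %in_bounds, %mask_bit : i1
/// %res = scf.if %cond -> (vector<4xf32>) {
///   // ... inBoundsCase ...
///   scf.yield %in_val : vector<4xf32>
/// } else {
///   // ... outOfBoundsCase ...
///   scf.yield %oob_val : vector<4xf32>
/// }
/// ```
/// If neither check is needed, no scf.if is created and `inBoundsCase` is
/// emitted directly.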
187 template <typename OpTy> 188 static Value generateInBoundsCheck( 189 OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim, 190 TypeRange resultTypes, 191 function_ref<Value(OpBuilder &, Location)> inBoundsCase, 192 function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) { 193 bool hasRetVal = !resultTypes.empty(); 194 Value cond; // Condition to be built... 195 196 // Condition check 1: Access in-bounds? 197 bool isBroadcast = !dim; // No in-bounds check for broadcasts. 198 Location loc = xferOp.getLoc(); 199 ImplicitLocOpBuilder lb(xferOp.getLoc(), b); 200 if (!xferOp.isDimInBounds(0) && !isBroadcast) { 201 Value memrefDim = 202 vector::createOrFoldDimOp(b, loc, xferOp.getSource(), *dim); 203 AffineExpr d0, d1; 204 bindDims(xferOp.getContext(), d0, d1); 205 Value base = xferOp.getIndices()[*dim]; 206 Value memrefIdx = 207 affine::makeComposedAffineApply(b, loc, d0 + d1, {base, iv}); 208 cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim, 209 memrefIdx); 210 } 211 212 // Condition check 2: Masked in? 213 if (auto maskCond = generateMaskCheck(b, xferOp, iv)) { 214 if (cond) 215 cond = lb.create<arith::AndIOp>(cond, maskCond); 216 else 217 cond = maskCond; 218 } 219 220 // If the condition is non-empty, generate an SCF::IfOp. 221 if (cond) { 222 auto check = lb.create<scf::IfOp>( 223 cond, 224 /*thenBuilder=*/ 225 [&](OpBuilder &b, Location loc) { 226 maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc)); 227 }, 228 /*elseBuilder=*/ 229 [&](OpBuilder &b, Location loc) { 230 if (outOfBoundsCase) { 231 maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc)); 232 } else { 233 b.create<scf::YieldOp>(loc); 234 } 235 }); 236 237 return hasRetVal ? check.getResult(0) : Value(); 238 } 239 240 // Condition is empty, no need for an SCF::IfOp. 241 return inBoundsCase(b, loc); 242 } 243 244 /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have 245 /// a return value. Consequently, this function does not have a return value. 246 template <typename OpTy> 247 static void generateInBoundsCheck( 248 OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim, 249 function_ref<void(OpBuilder &, Location)> inBoundsCase, 250 function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) { 251 generateInBoundsCheck( 252 b, xferOp, iv, dim, /*resultTypes=*/TypeRange(), 253 /*inBoundsCase=*/ 254 [&](OpBuilder &b, Location loc) { 255 inBoundsCase(b, loc); 256 return Value(); 257 }, 258 /*outOfBoundsCase=*/ 259 [&](OpBuilder &b, Location loc) { 260 if (outOfBoundsCase) 261 outOfBoundsCase(b, loc); 262 return Value(); 263 }); 264 } 265 266 /// Given an ArrayAttr, return a copy where the first element is dropped. 267 static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) { 268 if (!attr) 269 return attr; 270 return ArrayAttr::get(b.getContext(), attr.getValue().drop_front()); 271 } 272 273 /// Add the pass label to a vector transfer op if its rank is not the target 274 /// rank. 275 template <typename OpTy> 276 static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp, 277 unsigned targetRank) { 278 if (newXferOp.getVectorType().getRank() > targetRank) 279 newXferOp->setAttr(kPassLabel, b.getUnitAttr()); 280 } 281 282 namespace lowering_n_d { 283 284 /// Helper data structure for data and mask buffers. 285 struct BufferAllocs { 286 Value dataBuffer; 287 Value maskBuffer; 288 }; 289 290 // TODO: Parallelism and threadlocal considerations with a ParallelScope trait. 
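// Note: allocBuffers (below) places its memref.alloca ops at the start of the
// single region of this scope (typically a func.func), so the buffers dominate
// all uses and are not re-created inside the loops generated later. As a rough
// illustration (not verbatim output; names invented), a masked transfer of
// vector<5x4xf32> with a vector<5xi1> mask is prepared as:
// ```
// %data_buf = memref.alloca() : memref<vector<5x4xf32>>
// %mask_buf = memref.alloca() : memref<vector<5xi1>>
// ...
// memref.store %mask, %mask_buf[] : memref<vector<5xi1>>
// %mask_val = memref.load %mask_buf[] : memref<vector<5xi1>>
// ```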
291 static Operation *getAutomaticAllocationScope(Operation *op) { 292 Operation *scope = 293 op->getParentWithTrait<OpTrait::AutomaticAllocationScope>(); 294 assert(scope && "Expected op to be inside automatic allocation scope"); 295 return scope; 296 } 297 298 /// Allocate temporary buffers for data (vector) and mask (if present). 299 template <typename OpTy> 300 static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) { 301 Location loc = xferOp.getLoc(); 302 OpBuilder::InsertionGuard guard(b); 303 Operation *scope = getAutomaticAllocationScope(xferOp); 304 assert(scope->getNumRegions() == 1 && 305 "AutomaticAllocationScope with >1 regions"); 306 b.setInsertionPointToStart(&scope->getRegion(0).front()); 307 308 BufferAllocs result; 309 auto bufferType = MemRefType::get({}, xferOp.getVectorType()); 310 result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType); 311 312 if (xferOp.getMask()) { 313 auto maskType = MemRefType::get({}, xferOp.getMask().getType()); 314 auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType); 315 b.setInsertionPoint(xferOp); 316 b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer); 317 result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer, ValueRange()); 318 } 319 320 return result; 321 } 322 323 /// Given a MemRefType with VectorType element type, unpack one dimension from 324 /// the VectorType into the MemRefType. 325 /// 326 /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>> 327 static FailureOr<MemRefType> unpackOneDim(MemRefType type) { 328 auto vectorType = dyn_cast<VectorType>(type.getElementType()); 329 // Vectors with leading scalable dims are not supported. 330 // It may be possible to support these in future by using dynamic memref dims. 331 if (vectorType.getScalableDims().front()) 332 return failure(); 333 auto memrefShape = type.getShape(); 334 SmallVector<int64_t, 8> newMemrefShape; 335 newMemrefShape.append(memrefShape.begin(), memrefShape.end()); 336 newMemrefShape.push_back(vectorType.getDimSize(0)); 337 return MemRefType::get(newMemrefShape, 338 VectorType::Builder(vectorType).dropDim(0)); 339 } 340 341 /// Given a transfer op, find the memref from which the mask is loaded. This 342 /// is similar to Strategy<TransferWriteOp>::getBuffer. 343 template <typename OpTy> 344 static Value getMaskBuffer(OpTy xferOp) { 345 assert(xferOp.getMask() && "Expected that transfer op has mask"); 346 auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>(); 347 assert(loadOp && "Expected transfer op mask produced by LoadOp"); 348 return loadOp.getMemRef(); 349 } 350 351 /// Codegen strategy, depending on the operation. 352 template <typename OpTy> 353 struct Strategy; 354 355 /// Code strategy for vector TransferReadOp. 356 template <> 357 struct Strategy<TransferReadOp> { 358 /// Find the StoreOp that is used for writing the current TransferReadOp's 359 /// result to the temporary buffer allocation. 360 static memref::StoreOp getStoreOp(TransferReadOp xferOp) { 361 assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp"); 362 auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner()); 363 assert(storeOp && "Expected TransferReadOp result used by StoreOp"); 364 return storeOp; 365 } 366 367 /// Find the temporary buffer allocation. All labeled TransferReadOps are 368 /// used like this, where %buf is either the buffer allocation or a type cast 369 /// of the buffer allocation: 370 /// ``` 371 /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ... 
372 /// memref.store %vec, %buf[...] ... 373 /// ``` 374 static Value getBuffer(TransferReadOp xferOp) { 375 return getStoreOp(xferOp).getMemRef(); 376 } 377 378 /// Retrieve the indices of the current StoreOp that stores into the buffer. 379 static void getBufferIndices(TransferReadOp xferOp, 380 SmallVector<Value, 8> &indices) { 381 auto storeOp = getStoreOp(xferOp); 382 auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices(); 383 indices.append(prevIndices.begin(), prevIndices.end()); 384 } 385 386 /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds 387 /// accesses on the to-be-unpacked dimension. 388 /// 389 /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration 390 /// variable `iv`. 391 /// 2. Store the result into the (already `vector.type_cast`ed) buffer. 392 /// 393 /// E.g.: 394 /// ``` 395 /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst 396 /// : memref<?x?x?xf32>, vector<4x3xf32> 397 /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>> 398 /// ``` 399 /// Is rewritten to: 400 /// ``` 401 /// %casted = vector.type_cast %buf 402 /// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>> 403 /// for %j = 0 to 4 { 404 /// %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst 405 /// : memref<?x?x?xf32>, vector<3xf32> 406 /// memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>> 407 /// } 408 /// ``` 409 /// 410 /// Note: The loop and type cast are generated in TransferOpConversion. 411 /// The original TransferReadOp and store op are deleted in `cleanup`. 412 /// Note: The `mask` operand is set in TransferOpConversion. 413 static TransferReadOp rewriteOp(OpBuilder &b, 414 VectorTransferToSCFOptions options, 415 TransferReadOp xferOp, Value buffer, Value iv, 416 ValueRange /*loopState*/) { 417 SmallVector<Value, 8> storeIndices; 418 getBufferIndices(xferOp, storeIndices); 419 storeIndices.push_back(iv); 420 421 SmallVector<Value, 8> xferIndices; 422 getXferIndices(b, xferOp, iv, xferIndices); 423 424 Location loc = xferOp.getLoc(); 425 auto bufferType = dyn_cast<ShapedType>(buffer.getType()); 426 auto vecType = dyn_cast<VectorType>(bufferType.getElementType()); 427 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); 428 auto newXferOp = b.create<vector::TransferReadOp>( 429 loc, vecType, xferOp.getSource(), xferIndices, 430 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), 431 xferOp.getPadding(), Value(), inBoundsAttr); 432 433 maybeApplyPassLabel(b, newXferOp, options.targetRank); 434 435 b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices); 436 return newXferOp; 437 } 438 439 /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write 440 /// padding value to the temporary buffer. 441 static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp, 442 Value buffer, Value iv, 443 ValueRange /*loopState*/) { 444 SmallVector<Value, 8> storeIndices; 445 getBufferIndices(xferOp, storeIndices); 446 storeIndices.push_back(iv); 447 448 Location loc = xferOp.getLoc(); 449 auto bufferType = dyn_cast<ShapedType>(buffer.getType()); 450 auto vecType = dyn_cast<VectorType>(bufferType.getElementType()); 451 auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding()); 452 b.create<memref::StoreOp>(loc, vec, buffer, storeIndices); 453 454 return Value(); 455 } 456 457 /// Cleanup after rewriting the op. 
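/// For the read strategy this erases both the original transfer_read and the
/// memref.store that spilled its result into the temporary buffer; the scf.for
/// generated by TransferOpConversion remains in place.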
458 static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp, 459 scf::ForOp /*forOp*/) { 460 rewriter.eraseOp(getStoreOp(xferOp)); 461 rewriter.eraseOp(xferOp); 462 } 463 464 /// Return the initial loop state for the generated scf.for loop. 465 static Value initialLoopState(TransferReadOp xferOp) { return Value(); } 466 }; 467 468 /// Codegen strategy for vector TransferWriteOp. 469 template <> 470 struct Strategy<TransferWriteOp> { 471 /// Find the temporary buffer allocation. All labeled TransferWriteOps are 472 /// used like this, where %buf is either the buffer allocation or a type cast 473 /// of the buffer allocation: 474 /// ``` 475 /// %vec = memref.load %buf[...] ... 476 /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ... 477 /// ``` 478 static Value getBuffer(TransferWriteOp xferOp) { 479 auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>(); 480 assert(loadOp && "Expected transfer op vector produced by LoadOp"); 481 return loadOp.getMemRef(); 482 } 483 484 /// Retrieve the indices of the current LoadOp that loads from the buffer. 485 static void getBufferIndices(TransferWriteOp xferOp, 486 SmallVector<Value, 8> &indices) { 487 auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>(); 488 auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices(); 489 indices.append(prevIndices.begin(), prevIndices.end()); 490 } 491 492 /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds 493 /// accesses on the to-be-unpacked dimension. 494 /// 495 /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer, 496 /// using the loop iteration variable `iv`. 497 /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back 498 /// to memory. 499 /// 500 /// Note: For more details, see comments on Strategy<TransferReadOp>. 501 static TransferWriteOp rewriteOp(OpBuilder &b, 502 VectorTransferToSCFOptions options, 503 TransferWriteOp xferOp, Value buffer, 504 Value iv, ValueRange loopState) { 505 SmallVector<Value, 8> loadIndices; 506 getBufferIndices(xferOp, loadIndices); 507 loadIndices.push_back(iv); 508 509 SmallVector<Value, 8> xferIndices; 510 getXferIndices(b, xferOp, iv, xferIndices); 511 512 Location loc = xferOp.getLoc(); 513 auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices); 514 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); 515 auto source = loopState.empty() ? xferOp.getSource() : loopState[0]; 516 Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type(); 517 auto newXferOp = b.create<vector::TransferWriteOp>( 518 loc, type, vec, source, xferIndices, 519 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(), 520 inBoundsAttr); 521 522 maybeApplyPassLabel(b, newXferOp, options.targetRank); 523 524 return newXferOp; 525 } 526 527 /// Handle out-of-bounds accesses on the to-be-unpacked dimension. 528 static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp, 529 Value buffer, Value iv, 530 ValueRange loopState) { 531 return isTensorOp(xferOp) ? loopState[0] : Value(); 532 } 533 534 /// Cleanup after rewriting the op. 535 static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp, 536 scf::ForOp forOp) { 537 if (isTensorOp(xferOp)) { 538 assert(forOp->getNumResults() == 1 && "Expected one for loop result"); 539 rewriter.replaceOp(xferOp, forOp->getResult(0)); 540 } else { 541 rewriter.eraseOp(xferOp); 542 } 543 } 544 545 /// Return the initial loop state for the generated scf.for loop. 
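/// For transfer_write ops on tensors, the initial state is the source tensor,
/// which is threaded through the generated loop as an iter_arg and replaced by
/// each new transfer_write's result; memref destinations need no loop state.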
546 static Value initialLoopState(TransferWriteOp xferOp) { 547 return isTensorOp(xferOp) ? xferOp.getSource() : Value(); 548 } 549 }; 550 551 template <typename OpTy> 552 LogicalResult checkPrepareXferOp(OpTy xferOp, 553 VectorTransferToSCFOptions options) { 554 if (xferOp->hasAttr(kPassLabel)) 555 return failure(); 556 if (xferOp.getVectorType().getRank() <= options.targetRank) 557 return failure(); 558 // Currently the unpacking of the leading dimension into the memref is not 559 // supported for scalable dimensions. 560 if (xferOp.getVectorType().getScalableDims().front()) 561 return failure(); 562 if (isTensorOp(xferOp) && !options.lowerTensors) 563 return failure(); 564 // Transfer ops that modify the element type are not supported atm. 565 if (xferOp.getVectorType().getElementType() != 566 xferOp.getShapedType().getElementType()) 567 return failure(); 568 return success(); 569 } 570 571 /// Prepare a TransferReadOp for progressive lowering. 572 /// 573 /// 1. Allocate a temporary buffer. 574 /// 2. Label the TransferReadOp, marking it eligible for progressive lowering. 575 /// 3. Store the result of the TransferReadOp into the temporary buffer. 576 /// 4. Load the result from the temporary buffer and replace all uses of the 577 /// original TransferReadOp with this load. 578 /// 579 /// E.g.: 580 /// ``` 581 /// %vec = vector.transfer_read %A[%a, %b, %c], %cst 582 /// : vector<5x4xf32>, memref<?x?x?xf32> 583 /// ``` 584 /// is rewritten to: 585 /// ``` 586 /// %0 = memref.alloca() : memref<vector<5x4xf32>> 587 /// %1 = vector.transfer_read %A[%a, %b, %c], %cst 588 /// { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32> 589 /// memref.store %1, %0[] : memref<vector<5x4xf32>> 590 /// %vec = memref.load %0[] : memref<vector<5x4xf32>> 591 /// ``` 592 /// 593 /// Note: A second temporary buffer may be allocated for the `mask` operand. 594 struct PrepareTransferReadConversion 595 : public VectorToSCFPattern<TransferReadOp> { 596 using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern; 597 598 LogicalResult matchAndRewrite(TransferReadOp xferOp, 599 PatternRewriter &rewriter) const override { 600 if (checkPrepareXferOp(xferOp, options).failed()) 601 return failure(); 602 603 auto buffers = allocBuffers(rewriter, xferOp); 604 auto *newXfer = rewriter.clone(*xferOp.getOperation()); 605 newXfer->setAttr(kPassLabel, rewriter.getUnitAttr()); 606 if (xferOp.getMask()) { 607 dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign( 608 buffers.maskBuffer); 609 } 610 611 Location loc = xferOp.getLoc(); 612 rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0), 613 buffers.dataBuffer); 614 rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer); 615 616 return success(); 617 } 618 }; 619 620 /// Prepare a TransferWriteOp for progressive lowering. 621 /// 622 /// 1. Allocate a temporary buffer. 623 /// 2. Store the vector into the buffer. 624 /// 3. Load the vector from the buffer again. 625 /// 4. Use the loaded vector as a TransferWriteOp operand and label the op, 626 /// marking it eligible for progressive lowering via TransferOpConversion. 
627 /// 628 /// E.g.: 629 /// ``` 630 /// vector.transfer_write %vec, %A[%a, %b, %c] 631 /// : vector<5x4xf32>, memref<?x?x?xf32> 632 /// ``` 633 /// is rewritten to: 634 /// ``` 635 /// %0 = memref.alloca() : memref<vector<5x4xf32>> 636 /// memref.store %vec, %0[] : memref<vector<5x4xf32>> 637 /// %1 = memref.load %0[] : memref<vector<5x4xf32>> 638 /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ } 639 /// : vector<5x4xf32>, memref<?x?x?xf32> 640 /// ``` 641 /// 642 /// Note: A second temporary buffer may be allocated for the `mask` operand. 643 struct PrepareTransferWriteConversion 644 : public VectorToSCFPattern<TransferWriteOp> { 645 using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern; 646 647 LogicalResult matchAndRewrite(TransferWriteOp xferOp, 648 PatternRewriter &rewriter) const override { 649 if (checkPrepareXferOp(xferOp, options).failed()) 650 return failure(); 651 652 Location loc = xferOp.getLoc(); 653 auto buffers = allocBuffers(rewriter, xferOp); 654 rewriter.create<memref::StoreOp>(loc, xferOp.getVector(), 655 buffers.dataBuffer); 656 auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer); 657 rewriter.modifyOpInPlace(xferOp, [&]() { 658 xferOp.getVectorMutable().assign(loadedVec); 659 xferOp->setAttr(kPassLabel, rewriter.getUnitAttr()); 660 }); 661 662 if (xferOp.getMask()) { 663 rewriter.modifyOpInPlace(xferOp, [&]() { 664 xferOp.getMaskMutable().assign(buffers.maskBuffer); 665 }); 666 } 667 668 return success(); 669 } 670 }; 671 672 /// Decompose a n-D PrintOp into a loop of elementary/scalar prints. This allows 673 /// printing both 1D scalable vectors and n-D fixed size vectors. 674 /// 675 /// E.g.: 676 /// ``` 677 /// vector.print %v : vector<[4]xi32> 678 /// ``` 679 /// is rewritten to: 680 /// ``` 681 /// %c0 = arith.constant 0 : index 682 /// %c4 = arith.constant 4 : index 683 /// %c1 = arith.constant 1 : index 684 /// %vscale = vector.vscale 685 /// %length = arith.muli %vscale, %c4 : index 686 /// %lastIndex = arith.subi %length, %c1 : index 687 /// vector.print punctuation <open> 688 /// scf.for %i = %c0 to %length step %c1 { 689 /// %el = vector.extractelement %v[%i : index] : vector<[4]xi32> 690 /// vector.print %el : i32 punctuation <no_punctuation> 691 /// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index 692 /// scf.if %notLastIndex { 693 /// vector.print punctuation <comma> 694 /// } 695 /// } 696 /// vector.print punctuation <close> 697 /// vector.print 698 /// ``` 699 struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> { 700 using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern; 701 LogicalResult matchAndRewrite(vector::PrintOp printOp, 702 PatternRewriter &rewriter) const override { 703 if (!printOp.getSource()) 704 return failure(); 705 706 VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType()); 707 if (!vectorType) 708 return failure(); 709 710 // Currently >= 2D scalable vectors are not supported. 711 // These can't be lowered to LLVM (as LLVM does not support scalable vectors 712 // of scalable vectors), and due to limitations of current ops can't be 713 // indexed with SSA values or flattened. This may change after 714 // https://reviews.llvm.org/D155034, though there still needs to be a path 715 // for lowering to LLVM. 
716 if (vectorType.getRank() > 1 && vectorType.isScalable()) 717 return failure(); 718 719 auto loc = printOp.getLoc(); 720 auto value = printOp.getSource(); 721 722 if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) { 723 // Oddly sized integers are (somewhat) buggy on a lot of backends, so to 724 // avoid issues extend them to a more standard size. 725 // https://github.com/llvm/llvm-project/issues/30613 726 auto width = intTy.getWidth(); 727 auto legalWidth = llvm::NextPowerOf2(std::max(8u, width) - 1); 728 auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth, 729 intTy.getSignedness()); 730 // arith can only take signless integers, so we must cast back and forth. 731 auto signlessSourceVectorType = 732 vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy)); 733 auto signlessTargetVectorType = 734 vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy)); 735 auto targetVectorType = vectorType.cloneWith({}, legalIntTy); 736 value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType, 737 value); 738 if (value.getType() != signlessTargetVectorType) { 739 if (width == 1 || intTy.isUnsigned()) 740 value = rewriter.create<arith::ExtUIOp>(loc, signlessTargetVectorType, 741 value); 742 else 743 value = rewriter.create<arith::ExtSIOp>(loc, signlessTargetVectorType, 744 value); 745 } 746 value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value); 747 vectorType = targetVectorType; 748 } 749 750 auto scalableDimensions = vectorType.getScalableDims(); 751 auto shape = vectorType.getShape(); 752 constexpr int64_t singletonShape[] = {1}; 753 if (vectorType.getRank() == 0) 754 shape = singletonShape; 755 756 if (vectorType.getRank() != 1) { 757 // Flatten n-D vectors to 1D. This is done to allow indexing with a 758 // non-constant value (which can currently only be done via 759 // vector.extractelement for 1D vectors). 760 auto flatLength = std::accumulate(shape.begin(), shape.end(), 1, 761 std::multiplies<int64_t>()); 762 auto flatVectorType = 763 VectorType::get({flatLength}, vectorType.getElementType()); 764 value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value); 765 } 766 767 vector::PrintOp firstClose; 768 SmallVector<Value, 8> loopIndices; 769 for (unsigned d = 0; d < shape.size(); d++) { 770 // Setup loop bounds and step. 771 Value lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); 772 Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, shape[d]); 773 Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1); 774 if (!scalableDimensions.empty() && scalableDimensions[d]) { 775 auto vscale = rewriter.create<vector::VectorScaleOp>( 776 loc, rewriter.getIndexType()); 777 upperBound = rewriter.create<arith::MulIOp>(loc, upperBound, vscale); 778 } 779 auto lastIndex = rewriter.create<arith::SubIOp>(loc, upperBound, step); 780 781 // Create a loop to print the elements surrounded by parentheses. 782 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open); 783 auto loop = 784 rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step); 785 auto printClose = rewriter.create<vector::PrintOp>( 786 loc, vector::PrintPunctuation::Close); 787 if (!firstClose) 788 firstClose = printClose; 789 790 auto loopIdx = loop.getInductionVar(); 791 loopIndices.push_back(loopIdx); 792 793 // Print a comma after all but the last element. 
794 rewriter.setInsertionPointToStart(loop.getBody()); 795 auto notLastIndex = rewriter.create<arith::CmpIOp>( 796 loc, arith::CmpIPredicate::ult, loopIdx, lastIndex); 797 rewriter.create<scf::IfOp>(loc, notLastIndex, 798 [&](OpBuilder &builder, Location loc) { 799 builder.create<vector::PrintOp>( 800 loc, vector::PrintPunctuation::Comma); 801 builder.create<scf::YieldOp>(loc); 802 }); 803 804 rewriter.setInsertionPointToStart(loop.getBody()); 805 } 806 807 // Compute the flattened index. 808 // Note: For the > rank 1 vectors this assumes non-scalable. 809 Value flatIndex; 810 auto currentStride = 1; 811 for (int d = shape.size() - 1; d >= 0; d--) { 812 auto stride = rewriter.create<arith::ConstantIndexOp>(loc, currentStride); 813 auto index = rewriter.create<arith::MulIOp>(loc, stride, loopIndices[d]); 814 if (flatIndex) 815 flatIndex = rewriter.create<arith::AddIOp>(loc, flatIndex, index); 816 else 817 flatIndex = index; 818 currentStride *= shape[d]; 819 } 820 821 // Print the scalar elements in the inner most loop. 822 auto element = 823 rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex); 824 rewriter.create<vector::PrintOp>(loc, element, 825 vector::PrintPunctuation::NoPunctuation); 826 827 rewriter.setInsertionPointAfter(firstClose); 828 rewriter.create<vector::PrintOp>(loc, printOp.getPunctuation()); 829 rewriter.eraseOp(printOp); 830 return success(); 831 } 832 833 static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) { 834 return IntegerType::get(intTy.getContext(), intTy.getWidth(), 835 IntegerType::Signless); 836 }; 837 }; 838 839 /// Progressive lowering of vector transfer ops: Unpack one dimension. 840 /// 841 /// 1. Unpack one dimension from the current buffer type and cast the buffer 842 /// to that new type. E.g.: 843 /// ``` 844 /// %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>> 845 /// vector.transfer_write %vec ... 846 /// ``` 847 /// The following cast is generated: 848 /// ``` 849 /// %casted = vector.type_cast %0 850 /// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>> 851 /// ``` 852 /// 2. Generate a for loop and rewrite the transfer op according to the 853 /// corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be 854 /// out-of-bounds, generate an if-check and handle both cases separately. 855 /// 3. Clean up according to the corresponding Strategy<OpTy>. 856 /// 857 /// Note: If the transfer op is a TransferWriteOp and operates on a tensor 858 /// source (as opposed to a memref source), then each iteration of the generated 859 /// scf.for loop yields the new tensor value. E.g.: 860 /// ``` 861 /// %result = scf.for i = 0 to 5 { 862 /// %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>> 863 /// %1 = vector.transfer_write %0, %source[...] 864 /// : vector<4x3xf32>, tensor<5x4x3xf32> 865 /// scf.yield %1 : tensor<5x4x3xf32> 866 /// } 867 /// ``` 868 template <typename OpTy> 869 struct TransferOpConversion : public VectorToSCFPattern<OpTy> { 870 using VectorToSCFPattern<OpTy>::VectorToSCFPattern; 871 872 void initialize() { 873 // This pattern recursively unpacks one dimension at a time. The recursion 874 // bounded as the rank is strictly decreasing. 875 this->setHasBoundedRewriteRecursion(); 876 } 877 878 static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer, 879 SmallVectorImpl<Value> &loadIndices, 880 Value iv) { 881 assert(xferOp.getMask() && "Expected transfer op to have mask"); 882 883 // Add load indices from the previous iteration. 
884 // The mask buffer depends on the permutation map, which makes determining 885 // the indices quite complex, so this is why we need to "look back" to the 886 // previous iteration to find the right indices. 887 Value maskBuffer = getMaskBuffer(xferOp); 888 for (Operation *user : maskBuffer.getUsers()) { 889 // If there is no previous load op, then the indices are empty. 890 if (auto loadOp = dyn_cast<memref::LoadOp>(user)) { 891 Operation::operand_range prevIndices = loadOp.getIndices(); 892 loadIndices.append(prevIndices.begin(), prevIndices.end()); 893 break; 894 } 895 } 896 897 // In case of broadcast: Use same indices to load from memref 898 // as before. 899 if (!xferOp.isBroadcastDim(0)) 900 loadIndices.push_back(iv); 901 } 902 903 LogicalResult matchAndRewrite(OpTy xferOp, 904 PatternRewriter &rewriter) const override { 905 if (!xferOp->hasAttr(kPassLabel)) 906 return failure(); 907 908 // Find and cast data buffer. How the buffer can be found depends on OpTy. 909 ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter); 910 Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp); 911 auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType()); 912 FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType); 913 if (failed(castedDataType)) 914 return failure(); 915 916 auto castedDataBuffer = 917 locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer); 918 919 // If the xferOp has a mask: Find and cast mask buffer. 920 Value castedMaskBuffer; 921 if (xferOp.getMask()) { 922 Value maskBuffer = getMaskBuffer(xferOp); 923 if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) { 924 // Do not unpack a dimension of the mask, if: 925 // * To-be-unpacked transfer op dimension is a broadcast. 926 // * Mask is 1D, i.e., the mask cannot be further unpacked. 927 // (That means that all remaining dimensions of the transfer op must 928 // be broadcasted.) 929 castedMaskBuffer = maskBuffer; 930 } else { 931 // It's safe to assume the mask buffer can be unpacked if the data 932 // buffer was unpacked. 933 auto maskBufferType = cast<MemRefType>(maskBuffer.getType()); 934 MemRefType castedMaskType = *unpackOneDim(maskBufferType); 935 castedMaskBuffer = 936 locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer); 937 } 938 } 939 940 // Loop bounds and step. 941 auto lb = locB.create<arith::ConstantIndexOp>(0); 942 auto ub = locB.create<arith::ConstantIndexOp>( 943 castedDataType->getDimSize(castedDataType->getRank() - 1)); 944 auto step = locB.create<arith::ConstantIndexOp>(1); 945 // TransferWriteOps that operate on tensors return the modified tensor and 946 // require a loop state. 947 auto loopState = Strategy<OpTy>::initialLoopState(xferOp); 948 949 // Generate for loop. 950 auto result = locB.create<scf::ForOp>( 951 lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(), 952 [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { 953 Type stateType = loopState.empty() ? Type() : loopState[0].getType(); 954 955 auto result = generateInBoundsCheck( 956 b, xferOp, iv, unpackedDim(xferOp), 957 stateType ? TypeRange(stateType) : TypeRange(), 958 /*inBoundsCase=*/ 959 [&](OpBuilder &b, Location loc) { 960 // Create new transfer op. 961 OpTy newXfer = Strategy<OpTy>::rewriteOp( 962 b, this->options, xferOp, castedDataBuffer, iv, loopState); 963 964 // If old transfer op has a mask: Set mask on new transfer op. 
965 // Special case: If the mask of the old transfer op is 1D and 966 // the unpacked dim is not a broadcast, no mask is needed on 967 // the new transfer op. 968 if (xferOp.getMask() && (xferOp.isBroadcastDim(0) || 969 xferOp.getMaskType().getRank() > 1)) { 970 OpBuilder::InsertionGuard guard(b); 971 b.setInsertionPoint(newXfer); // Insert load before newXfer. 972 973 SmallVector<Value, 8> loadIndices; 974 getMaskBufferLoadIndices(xferOp, castedMaskBuffer, 975 loadIndices, iv); 976 auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer, 977 loadIndices); 978 rewriter.modifyOpInPlace(newXfer, [&]() { 979 newXfer.getMaskMutable().assign(mask); 980 }); 981 } 982 983 return loopState.empty() ? Value() : newXfer->getResult(0); 984 }, 985 /*outOfBoundsCase=*/ 986 [&](OpBuilder &b, Location /*loc*/) { 987 return Strategy<OpTy>::handleOutOfBoundsDim( 988 b, xferOp, castedDataBuffer, iv, loopState); 989 }); 990 991 maybeYieldValue(b, loc, !loopState.empty(), result); 992 }); 993 994 Strategy<OpTy>::cleanup(rewriter, xferOp, result); 995 return success(); 996 } 997 }; 998 999 /// Retrieves the dimensions sizes of a mask. Currently supports CreateMaskOp 1000 /// and ConstantMaskOp. 1001 template <typename VscaleConstantBuilder> 1002 static FailureOr<SmallVector<OpFoldResult>> 1003 getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) { 1004 if (!mask) 1005 return SmallVector<OpFoldResult>{}; 1006 if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) { 1007 return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) { 1008 return OpFoldResult(dimSize); 1009 }); 1010 } 1011 if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) { 1012 int dimIdx = 0; 1013 VectorType maskType = constantMask.getVectorType(); 1014 auto indexType = IndexType::get(mask.getContext()); 1015 return llvm::map_to_vector( 1016 constantMask.getMaskDimSizes(), [&](int64_t dimSize) { 1017 // A scalable dim in a constant_mask means vscale x dimSize. 1018 if (maskType.getScalableDims()[dimIdx++]) 1019 return OpFoldResult(createVscaleMultiple(dimSize)); 1020 return OpFoldResult(IntegerAttr::get(indexType, dimSize)); 1021 }); 1022 } 1023 return failure(); 1024 } 1025 1026 /// Scalable vector lowering of transfer_write(transpose). This lowering only 1027 /// supports rank 2 (scalable) vectors, but can be used in conjunction with 1028 /// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion 1029 /// unrolls until the first scalable dimension. 
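/// The scalable (transposed) dimension becomes an scf.for loop of vscale x N
/// iterations; within each iteration the fixed-size row is re-assembled with
/// vector.from_elements and written out with a rank-1 transfer_write.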
1030 /// 1031 /// Example: 1032 /// 1033 /// BEFORE: 1034 /// ```mlir 1035 /// %transpose = vector.transpose %vec, [1, 0] 1036 /// : vector<4x[4]xf32> to vector<[4]x4xf32> 1037 /// vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]} 1038 /// : vector<[4]x4xf32>, memref<?x?xf32> 1039 /// ``` 1040 /// 1041 /// AFTER: 1042 /// ```mlir 1043 /// %c1 = arith.constant 1 : index 1044 /// %c4 = arith.constant 4 : index 1045 /// %c0 = arith.constant 0 : index 1046 /// %0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32> 1047 /// %1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32> 1048 /// %2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32> 1049 /// %3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32> 1050 /// %vscale = vector.vscale 1051 /// %c4_vscale = arith.muli %vscale, %c4 : index 1052 /// scf.for %idx = %c0 to %c4_vscale step %c1 { 1053 /// %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32> 1054 /// %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32> 1055 /// %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32> 1056 /// %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32> 1057 /// %slice_i = affine.apply #map(%idx)[%i] 1058 /// %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32> 1059 /// vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]} 1060 /// : vector<4xf32>, memref<?x?xf32> 1061 /// } 1062 /// ``` 1063 struct ScalableTransposeTransferWriteConversion 1064 : VectorToSCFPattern<vector::TransferWriteOp> { 1065 using VectorToSCFPattern::VectorToSCFPattern; 1066 1067 LogicalResult matchAndRewrite(TransferWriteOp writeOp, 1068 PatternRewriter &rewriter) const override { 1069 if (failed(checkLowerTensors(writeOp, rewriter))) 1070 return failure(); 1071 1072 VectorType vectorType = writeOp.getVectorType(); 1073 1074 // Note: By comparing the scalable dims to an ArrayRef of length two this 1075 // implicitly checks the rank (is also two). 1076 ArrayRef<bool> scalableFlags = vectorType.getScalableDims(); 1077 if (scalableFlags != ArrayRef<bool>{true, false}) { 1078 return rewriter.notifyMatchFailure( 1079 writeOp, "expected vector of the form vector<[N]xMxty>"); 1080 } 1081 1082 auto permutationMap = writeOp.getPermutationMap(); 1083 if (!permutationMap.isIdentity()) { 1084 return rewriter.notifyMatchFailure( 1085 writeOp, "non-identity permutations are unsupported (lower first)"); 1086 } 1087 1088 // Note: This pattern is only lowering the leading dimension (to a loop), 1089 // so we only check if the leading dimension is in bounds. The in-bounds 1090 // attribute for the trailing dimension will be propagated. 
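// (The trailing dimension's in_bounds value is carried over onto each
// per-slice transfer_write below via getInBoundsValues().drop_front().)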
1091 if (!writeOp.isDimInBounds(0)) { 1092 return rewriter.notifyMatchFailure( 1093 writeOp, "out-of-bounds dims are unsupported (use masking)"); 1094 } 1095 1096 Value vector = writeOp.getVector(); 1097 auto transposeOp = vector.getDefiningOp<vector::TransposeOp>(); 1098 if (!transposeOp || 1099 transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) { 1100 return rewriter.notifyMatchFailure(writeOp, "source not transpose"); 1101 } 1102 1103 auto loc = writeOp.getLoc(); 1104 auto createVscaleMultiple = 1105 vector::makeVscaleConstantBuilder(rewriter, loc); 1106 1107 auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple); 1108 if (failed(maskDims)) { 1109 return rewriter.notifyMatchFailure(writeOp, 1110 "failed to resolve mask dims"); 1111 } 1112 1113 int64_t fixedDimSize = vectorType.getDimSize(1); 1114 auto fixedDimOffsets = llvm::seq(fixedDimSize); 1115 1116 // Extract all slices from the source of the transpose. 1117 auto transposeSource = transposeOp.getVector(); 1118 SmallVector<Value> transposeSourceSlices = 1119 llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value { 1120 return rewriter.create<vector::ExtractOp>(loc, transposeSource, idx); 1121 }); 1122 1123 // Loop bounds and step. 1124 auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0); 1125 auto ub = 1126 maskDims->empty() 1127 ? Value(createVscaleMultiple(vectorType.getDimSize(0))) 1128 : vector::getAsValues(rewriter, loc, maskDims->front()).front(); 1129 auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); 1130 1131 // Generate a new mask for the slice. 1132 VectorType sliceType = VectorType::Builder(vectorType).dropDim(0); 1133 Value sliceMask = nullptr; 1134 if (!maskDims->empty()) { 1135 sliceMask = rewriter.create<vector::CreateMaskOp>( 1136 loc, sliceType.clone(rewriter.getI1Type()), 1137 ArrayRef<OpFoldResult>(*maskDims).drop_front()); 1138 } 1139 1140 Value initDest = isTensorOp(writeOp) ? writeOp.getSource() : Value{}; 1141 ValueRange initLoopArgs = initDest ? initDest : ValueRange{}; 1142 auto result = rewriter.create<scf::ForOp>( 1143 loc, lb, ub, step, initLoopArgs, 1144 [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) { 1145 // Indices for the new transfer op. 1146 SmallVector<Value, 8> xferIndices; 1147 getXferIndices(b, writeOp, iv, xferIndices); 1148 1149 // Extract a transposed slice from the source vector. 1150 SmallVector<Value> transposeElements = 1151 llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value { 1152 return b.create<vector::ExtractOp>( 1153 loc, transposeSourceSlices[idx], iv); 1154 }); 1155 auto sliceVec = b.create<vector::FromElementsOp>(loc, sliceType, 1156 transposeElements); 1157 1158 // Create the transfer_write for the slice. 1159 Value dest = 1160 loopIterArgs.empty() ? writeOp.getSource() : loopIterArgs.front(); 1161 auto newWriteOp = b.create<vector::TransferWriteOp>( 1162 loc, sliceVec, dest, xferIndices, 1163 ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()); 1164 if (sliceMask) 1165 newWriteOp.getMaskMutable().assign(sliceMask); 1166 1167 // Yield from the loop. 1168 b.create<scf::YieldOp>(loc, loopIterArgs.empty() 1169 ? 
ValueRange{} 1170 : newWriteOp.getResult()); 1171 }); 1172 1173 if (isTensorOp(writeOp)) 1174 rewriter.replaceOp(writeOp, result); 1175 else 1176 rewriter.eraseOp(writeOp); 1177 1178 return success(); 1179 } 1180 }; 1181 1182 } // namespace lowering_n_d 1183 1184 namespace lowering_n_d_unrolled { 1185 1186 /// If the original transfer op has a mask, compute the mask of the new transfer 1187 /// op (for the current iteration `i`) and assign it. 1188 template <typename OpTy> 1189 static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp, 1190 int64_t i) { 1191 if (!xferOp.getMask()) 1192 return; 1193 1194 if (xferOp.isBroadcastDim(0)) { 1195 // To-be-unpacked dimension is a broadcast, which does not have a 1196 // corresponding mask dimension. Mask attribute remains unchanged. 1197 newXferOp.getMaskMutable().assign(xferOp.getMask()); 1198 return; 1199 } 1200 1201 if (xferOp.getMaskType().getRank() > 1) { 1202 // Unpack one dimension of the mask. 1203 OpBuilder::InsertionGuard guard(b); 1204 b.setInsertionPoint(newXferOp); // Insert load before newXfer. 1205 1206 llvm::SmallVector<int64_t, 1> indices({i}); 1207 Location loc = xferOp.getLoc(); 1208 auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices); 1209 newXferOp.getMaskMutable().assign(newMask); 1210 } 1211 1212 // If we end up here: The mask of the old transfer op is 1D and the unpacked 1213 // dim is not a broadcast, so no mask is needed on the new transfer op. 1214 // `generateInBoundsCheck` will have evaluated the mask already. 1215 } 1216 1217 /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one 1218 /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no 1219 /// memref buffer is allocated and the SCF loop is fully unrolled. 1220 /// 1221 /// ``` 1222 /// E.g.: 1223 /// ``` 1224 /// %vec = vector.transfer_read %A[%a, %b, %c], %padding 1225 /// : memref<?x?x?xf32>, vector<5x4xf32> 1226 /// ``` 1227 /// is rewritten to IR such as (simplified): 1228 /// ``` 1229 /// %v_init = splat %padding : vector<5x4xf32> 1230 /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding 1231 /// : memref<?x?x?xf32>, vector<4xf32> 1232 /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32> 1233 /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding 1234 /// : memref<?x?x?xf32>, vector<4xf32> 1235 /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32> 1236 /// ... 1237 /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding 1238 /// : memref<?x?x?xf32>, vector<4xf32> 1239 /// %vec = vector.insert %tmp1, %v3[4] : vector<4xf32> into vector<5x4xf32> 1240 /// ``` 1241 /// 1242 /// Note: As an optimization, if the result of the original TransferReadOp 1243 /// was directly inserted into another vector, no new %v_init vector is created. 1244 /// Instead, the new TransferReadOp results are inserted into that vector. 1245 struct UnrollTransferReadConversion 1246 : public VectorToSCFPattern<TransferReadOp> { 1247 using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern; 1248 1249 void initialize() { 1250 // This pattern recursively unpacks one dimension at a time. The recursion 1251 // bounded as the rank is strictly decreasing. 1252 setHasBoundedRewriteRecursion(); 1253 } 1254 1255 /// Get or build the vector into which the newly created TransferReadOp 1256 /// results are inserted. 
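/// If the read already feeds a single vector.insert, that insert's destination
/// is reused as the accumulator; otherwise a fresh accumulator is splatted from
/// the padding value, roughly (illustrative only):
/// ```
/// %v_init = vector.splat %padding : vector<5x4xf32>
/// ```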
1257 Value buildResultVector(PatternRewriter &rewriter, 1258 TransferReadOp xferOp) const { 1259 if (auto insertOp = getInsertOp(xferOp)) 1260 return insertOp.getDest(); 1261 Location loc = xferOp.getLoc(); 1262 return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(), 1263 xferOp.getPadding()); 1264 } 1265 1266 /// If the result of the TransferReadOp has exactly one user, which is a 1267 /// vector::InsertOp, return that operation. 1268 vector::InsertOp getInsertOp(TransferReadOp xferOp) const { 1269 if (xferOp->hasOneUse()) { 1270 Operation *xferOpUser = *xferOp->getUsers().begin(); 1271 if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser)) 1272 return insertOp; 1273 } 1274 1275 return vector::InsertOp(); 1276 } 1277 1278 /// If the result of the TransferReadOp has exactly one user, which is a 1279 /// vector::InsertOp, return that operation's indices. 1280 void getInsertionIndices(TransferReadOp xferOp, 1281 SmallVectorImpl<OpFoldResult> &indices) const { 1282 if (auto insertOp = getInsertOp(xferOp)) { 1283 auto pos = insertOp.getMixedPosition(); 1284 indices.append(pos.begin(), pos.end()); 1285 } 1286 } 1287 1288 /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds 1289 /// accesses, and broadcasts and transposes in permutation maps. 1290 LogicalResult matchAndRewrite(TransferReadOp xferOp, 1291 PatternRewriter &rewriter) const override { 1292 if (xferOp.getVectorType().getRank() <= options.targetRank) 1293 return rewriter.notifyMatchFailure( 1294 xferOp, "vector rank is less or equal to target rank"); 1295 if (failed(checkLowerTensors(xferOp, rewriter))) 1296 return failure(); 1297 // Transfer ops that modify the element type are not supported atm. 1298 if (xferOp.getVectorType().getElementType() != 1299 xferOp.getShapedType().getElementType()) 1300 return rewriter.notifyMatchFailure( 1301 xferOp, "not yet supported: element type mismatch"); 1302 auto xferVecType = xferOp.getVectorType(); 1303 if (xferVecType.getScalableDims()[0]) { 1304 // Cannot unroll a scalable dimension at compile time. 1305 return rewriter.notifyMatchFailure( 1306 xferOp, "scalable dimensions cannot be unrolled"); 1307 } 1308 1309 auto insertOp = getInsertOp(xferOp); 1310 auto vec = buildResultVector(rewriter, xferOp); 1311 auto vecType = dyn_cast<VectorType>(vec.getType()); 1312 1313 VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0); 1314 1315 int64_t dimSize = xferVecType.getShape()[0]; 1316 1317 // Generate fully unrolled loop of transfer ops. 1318 Location loc = xferOp.getLoc(); 1319 for (int64_t i = 0; i < dimSize; ++i) { 1320 Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i); 1321 1322 vec = generateInBoundsCheck( 1323 rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType), 1324 /*inBoundsCase=*/ 1325 [&](OpBuilder &b, Location loc) { 1326 // Indices for the new transfer op. 1327 SmallVector<Value, 8> xferIndices; 1328 getXferIndices(b, xferOp, iv, xferIndices); 1329 1330 // Indices for the new vector.insert op. 
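// These are the positions of the enclosing vector.insert (if the read already
// fed one), followed by the constant position `i` of the unrolled dimension.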
1331 SmallVector<OpFoldResult, 8> insertionIndices; 1332 getInsertionIndices(xferOp, insertionIndices); 1333 insertionIndices.push_back(rewriter.getIndexAttr(i)); 1334 1335 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); 1336 auto newXferOp = b.create<vector::TransferReadOp>( 1337 loc, newXferVecType, xferOp.getSource(), xferIndices, 1338 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), 1339 xferOp.getPadding(), Value(), inBoundsAttr); 1340 maybeAssignMask(b, xferOp, newXferOp, i); 1341 return b.create<vector::InsertOp>(loc, newXferOp, vec, 1342 insertionIndices); 1343 }, 1344 /*outOfBoundsCase=*/ 1345 [&](OpBuilder &b, Location loc) { 1346 // Loop through original (unmodified) vector. 1347 return vec; 1348 }); 1349 } 1350 1351 if (insertOp) { 1352 // Rewrite single user of the old TransferReadOp, which was an InsertOp. 1353 rewriter.replaceOp(insertOp, vec); 1354 rewriter.eraseOp(xferOp); 1355 } else { 1356 rewriter.replaceOp(xferOp, vec); 1357 } 1358 1359 return success(); 1360 } 1361 }; 1362 1363 /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one 1364 /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no 1365 /// memref buffer is allocated and the SCF loop is fully unrolled. 1366 /// 1367 /// ``` 1368 /// E.g.: 1369 /// ``` 1370 /// vector.transfer_write %vec, %A[%a, %b, %c] 1371 /// : vector<5x4xf32>, memref<?x?x?xf32> 1372 /// ``` 1373 /// is rewritten to IR such as (simplified): 1374 /// ``` 1375 /// %v0 = vector.extract %vec[0] : vector<4xf32> from vector<5x4xf32> 1376 /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...> 1377 /// %v1 = vector.extract %vec[1] : vector<4xf32> from vector<5x4xf32> 1378 /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...> 1379 /// ... 1380 /// %v4 = vector.extract %vec[4] : vector<4xf32> from vector<5x4xf32> 1381 /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...> 1382 /// ``` 1383 /// 1384 /// Note: As an optimization, if the vector of the original TransferWriteOp 1385 /// was directly extracted from another vector via an ExtractOp `a`, extract 1386 /// the vectors for the newly generated TransferWriteOps from `a`'s input. By 1387 /// doing so, `a` may become dead, and the number of ExtractOps generated during 1388 /// recursive application of this pattern will be minimal. 1389 struct UnrollTransferWriteConversion 1390 : public VectorToSCFPattern<TransferWriteOp> { 1391 using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern; 1392 1393 void initialize() { 1394 // This pattern recursively unpacks one dimension at a time. The recursion 1395 // bounded as the rank is strictly decreasing. 1396 setHasBoundedRewriteRecursion(); 1397 } 1398 1399 /// Return the vector from which newly generated ExtracOps will extract. 1400 Value getDataVector(TransferWriteOp xferOp) const { 1401 if (auto extractOp = getExtractOp(xferOp)) 1402 return extractOp.getVector(); 1403 return xferOp.getVector(); 1404 } 1405 1406 /// If the input of the given TransferWriteOp is an ExtractOp, return it. 1407 vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const { 1408 if (auto *op = xferOp.getVector().getDefiningOp()) 1409 return dyn_cast<vector::ExtractOp>(op); 1410 return vector::ExtractOp(); 1411 } 1412 1413 /// If the input of the given TransferWriteOp is an ExtractOp, return its 1414 /// indices. 
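/// The returned positions are prepended to the position of the unrolled
/// dimension, so when the written vector itself came from a vector.extract the
/// new per-slice extracts read directly from that op's source and the original
/// extract may become dead.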
1415 void getExtractionIndices(TransferWriteOp xferOp, 1416 SmallVectorImpl<OpFoldResult> &indices) const { 1417 if (auto extractOp = getExtractOp(xferOp)) { 1418 auto pos = extractOp.getMixedPosition(); 1419 indices.append(pos.begin(), pos.end()); 1420 } 1421 } 1422 1423 /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds 1424 /// accesses, and broadcasts and transposes in permutation maps. 1425 LogicalResult matchAndRewrite(TransferWriteOp xferOp, 1426 PatternRewriter &rewriter) const override { 1427 VectorType inputVectorTy = xferOp.getVectorType(); 1428 1429 if (inputVectorTy.getRank() <= options.targetRank) 1430 return failure(); 1431 1432 if (failed(checkLowerTensors(xferOp, rewriter))) 1433 return failure(); 1434 // Transfer ops that modify the element type are not supported atm. 1435 if (inputVectorTy.getElementType() != 1436 xferOp.getShapedType().getElementType()) 1437 return failure(); 1438 1439 auto vec = getDataVector(xferOp); 1440 if (inputVectorTy.getScalableDims()[0]) { 1441 // Cannot unroll a scalable dimension at compile time. 1442 return failure(); 1443 } 1444 1445 int64_t dimSize = inputVectorTy.getShape()[0]; 1446 Value source = xferOp.getSource(); // memref or tensor to be written to. 1447 auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type(); 1448 1449 // Generate fully unrolled loop of transfer ops. 1450 Location loc = xferOp.getLoc(); 1451 for (int64_t i = 0; i < dimSize; ++i) { 1452 Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i); 1453 1454 auto updatedSource = generateInBoundsCheck( 1455 rewriter, xferOp, iv, unpackedDim(xferOp), 1456 isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(), 1457 /*inBoundsCase=*/ 1458 [&](OpBuilder &b, Location loc) { 1459 // Indices for the new transfer op. 1460 SmallVector<Value, 8> xferIndices; 1461 getXferIndices(b, xferOp, iv, xferIndices); 1462 1463 // Indices for the new vector.extract op. 1464 SmallVector<OpFoldResult, 8> extractionIndices; 1465 getExtractionIndices(xferOp, extractionIndices); 1466 extractionIndices.push_back(b.getI64IntegerAttr(i)); 1467 1468 auto extracted = 1469 b.create<vector::ExtractOp>(loc, vec, extractionIndices); 1470 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); 1471 Value xferVec; 1472 if (inputVectorTy.getRank() == 1) { 1473 // When target-rank=0, unrolling would causes the vector input 1474 // argument into `transfer_write` to become a scalar. We solve 1475 // this by broadcasting the scalar to a 0D vector. 1476 xferVec = b.create<vector::BroadcastOp>( 1477 loc, VectorType::get({}, extracted.getType()), extracted); 1478 } else { 1479 xferVec = extracted; 1480 } 1481 auto newXferOp = b.create<vector::TransferWriteOp>( 1482 loc, sourceType, xferVec, source, xferIndices, 1483 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(), 1484 inBoundsAttr); 1485 1486 maybeAssignMask(b, xferOp, newXferOp, i); 1487 1488 return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value(); 1489 }, 1490 /*outOfBoundsCase=*/ 1491 [&](OpBuilder &b, Location loc) { 1492 return isTensorOp(xferOp) ? 
template <typename OpTy>
static std::optional<int64_t>
get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
                   SmallVector<Value, 8> &memrefIndices) {
  auto indices = xferOp.getIndices();
  auto map = xferOp.getPermutationMap();
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");

  memrefIndices.append(indices.begin(), indices.end());
  assert(map.getNumResults() == 1 &&
         "Expected 1 permutation map result for 1D transfer");
  if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
    Location loc = xferOp.getLoc();
    auto dim = expr.getPosition();
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = memrefIndices[dim];
    memrefIndices[dim] =
        affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
    return dim;
  }

  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return std::nullopt;
}

/// Codegen strategy for TransferOp1dConversion, depending on the
/// operation.
template <typename OpTy>
struct Strategy1d;

/// Codegen strategy for TransferReadOp.
template <>
struct Strategy1d<TransferReadOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferReadOp xferOp, Value iv,
                                  ValueRange loopState) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    auto vec = loopState[0];

    // In case of out-of-bounds access, leave `vec` as is (was initialized with
    // padding value).
    auto nextVec = generateInBoundsCheck(
        b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
        /*inBoundsCase=*/
        [&](OpBuilder &b, Location loc) {
          Value val =
              b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
          return b.create<vector::InsertElementOp>(loc, val, vec, iv);
        },
        /*outOfBoundsCase=*/
        [&](OpBuilder & /*b*/, Location loc) { return vec; });
    b.create<scf::YieldOp>(loc, nextVec);
  }

  static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize vector with padding value.
    Location loc = xferOp.getLoc();
    return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
                                     xferOp.getPadding());
  }
};

/// Codegen strategy for TransferWriteOp.
template <>
struct Strategy1d<TransferWriteOp> {
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferWriteOp xferOp, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);

    // Nothing to do in case of out-of-bounds access.
    generateInBoundsCheck(
        b, xferOp, iv, dim,
        /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
          auto val =
              b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
          b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
        });
    b.create<scf::YieldOp>(loc);
  }

  static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
    return Value();
  }
};

/// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
/// necessary in cases where a 1D vector transfer op cannot be lowered into
/// vector load/stores due to non-unit strides or broadcasts:
///
/// * Transfer dimension is not the last memref dimension
/// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
/// * Memref has a layout map with non-unit stride on the last dimension
///
/// This pattern generates IR as follows:
///
/// 1. Generate a for loop iterating over each vector element.
/// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
///    depending on OpTy.
///
/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
/// can be generated instead of TransferOp1dConversion. Add such a pattern
/// to ConvertVectorToLLVM.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b]
///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
///    : vector<9xf32>, memref<?x?xf32>
/// ```
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
///   %t = vector.extractelement %vec[i] : vector<9xf32>
///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
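///
/// For comparison, a TransferReadOp under the same constraints is lowered to
/// roughly the following pseudo-IR (illustrative sketch, not verbatim output;
/// %init, %acc, and %v are made-up names):
/// ```
/// %init = vector.splat %padding : vector<9xf32>
/// %v = for i = 0 to 9 iter_args(%acc = %init) {
///   %t = memref.load %arg0[%a + i, %b] : memref<?x?xf32>
///   %next = vector.insertelement %t, %acc[i] : vector<9xf32>
///   yield %next
/// }
/// ```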
template <typename OpTy>
struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();
    auto map = xferOp.getPermutationMap();
    auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());

    if (!memRefType)
      return failure();
    if (xferOp.getVectorType().getRank() != 1)
      return failure();
    if (map.isMinorIdentity() && memRefType.isLastDimUnitStride())
      return failure(); // Handled by ConvertVectorToLLVM

    // Loop bounds, step, state...
    Location loc = xferOp.getLoc();
    auto vecType = xferOp.getVectorType();
    auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value ub =
        rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
    if (vecType.isScalable()) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      ub = rewriter.create<arith::MulIOp>(loc, ub, vscale);
    }
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);

    // Generate for loop.
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
        });

    return success();
  }
};

} // namespace lowering_1_d
} // namespace

void mlir::populateVectorToSCFConversionPatterns(
    RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
  if (options.unroll) {
    patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
                 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
        patterns.getContext(), options);
  } else {
    patterns.add<lowering_n_d::PrepareTransferReadConversion,
                 lowering_n_d::PrepareTransferWriteConversion,
                 lowering_n_d::TransferOpConversion<TransferReadOp>,
                 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
  if (options.lowerScalable) {
    patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
        patterns.getContext(), options);
  }
  if (options.targetRank == 1) {
    patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
  patterns.add<lowering_n_d::DecomposePrintOpConversion>(patterns.getContext(),
                                                         options);
}
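
// Example usage of the public entry point above (a minimal sketch; the
// wrapper function `lowerVectorTransfersToSCF` is hypothetical, while the
// options fields and the pattern-driver calls are the ones used in this file):
//
//   void lowerVectorTransfersToSCF(Operation *op) {
//     VectorTransferToSCFOptions options;
//     options.unroll = true;   // fully unroll instead of generating loops
//     options.targetRank = 1;  // also add the 1-D scalar lowering patterns
//     RewritePatternSet patterns(op->getContext());
//     populateVectorToSCFConversionPatterns(patterns, options);
//     (void)applyPatternsGreedily(op, std::move(patterns));
//   }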

namespace {

struct ConvertVectorToSCFPass
    : public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
  ConvertVectorToSCFPass() = default;
  ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
    this->fullUnroll = options.unroll;
    this->targetRank = options.targetRank;
    this->lowerTensors = options.lowerTensors;
    this->lowerScalable = options.lowerScalable;
  }

  void runOnOperation() override {
    VectorTransferToSCFOptions options;
    options.unroll = fullUnroll;
    options.targetRank = targetRank;
    options.lowerTensors = lowerTensors;
    options.lowerScalable = lowerScalable;

    // Lower permutation maps first.
    RewritePatternSet lowerTransferPatterns(&getContext());
    mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
        lowerTransferPatterns);
    (void)applyPatternsGreedily(getOperation(),
                                std::move(lowerTransferPatterns));

    RewritePatternSet patterns(&getContext());
    populateVectorToSCFConversionPatterns(patterns, options);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

} // namespace

std::unique_ptr<Pass>
mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  return std::make_unique<ConvertVectorToSCFPass>(options);
}