1 //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements lowering of vector transfer operations to SCF. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include <numeric> 14 #include <optional> 15 #include <type_traits> 16 17 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 18 19 #include "mlir/Dialect/Affine/IR/AffineOps.h" 20 #include "mlir/Dialect/Arith/IR/Arith.h" 21 #include "mlir/Dialect/MemRef/IR/MemRef.h" 22 #include "mlir/Dialect/SCF/IR/SCF.h" 23 #include "mlir/Dialect/Tensor/IR/Tensor.h" 24 #include "mlir/Dialect/Vector/IR/VectorOps.h" 25 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 26 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 27 #include "mlir/IR/Builders.h" 28 #include "mlir/IR/ImplicitLocOpBuilder.h" 29 #include "mlir/Pass/Pass.h" 30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 31 #include "mlir/Transforms/Passes.h" 32 33 namespace mlir { 34 #define GEN_PASS_DEF_CONVERTVECTORTOSCF 35 #include "mlir/Conversion/Passes.h.inc" 36 } // namespace mlir 37 38 using namespace mlir; 39 using vector::TransferReadOp; 40 using vector::TransferWriteOp; 41 42 namespace { 43 44 /// Attribute name used for labeling transfer ops during progressive lowering. 45 static const char kPassLabel[] = "__vector_to_scf_lowering__"; 46 47 /// Patterns that inherit from this struct have access to 48 /// VectorTransferToSCFOptions. 49 template <typename OpTy> 50 struct VectorToSCFPattern : public OpRewritePattern<OpTy> { 51 explicit VectorToSCFPattern(MLIRContext *context, 52 VectorTransferToSCFOptions opt) 53 : OpRewritePattern<OpTy>(context), options(opt) {} 54 55 VectorTransferToSCFOptions options; 56 }; 57 58 /// Given a vector transfer op, calculate which dimension of the `source` 59 /// memref should be unpacked in the next application of TransferOpConversion. 60 /// A return value of std::nullopt indicates a broadcast. 61 template <typename OpTy> 62 static std::optional<int64_t> unpackedDim(OpTy xferOp) { 63 // TODO: support 0-d corner case. 64 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer"); 65 auto map = xferOp.getPermutationMap(); 66 if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) { 67 return expr.getPosition(); 68 } 69 assert(xferOp.isBroadcastDim(0) && 70 "Expected AffineDimExpr or AffineConstantExpr"); 71 return std::nullopt; 72 } 73 74 /// Compute the permutation map for the new (N-1)-D vector transfer op. This 75 /// map is identical to the current permutation map, but the first result is 76 /// omitted. 77 template <typename OpTy> 78 static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) { 79 // TODO: support 0-d corner case. 80 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer"); 81 auto map = xferOp.getPermutationMap(); 82 return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(), 83 b.getContext()); 84 } 85 86 /// Calculate the indices for the new vector transfer op. 87 /// 88 /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ... 89 /// --> transfer_read %A[%a, %b + iv, %c, %d] ... 
///           vector<4x3xf32>
///                                ^^^^^^
/// `iv` is the iteration variable of the (new) surrounding loop.
template <typename OpTy>
static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
                           SmallVector<Value, 8> &indices) {
  typename OpTy::Adaptor adaptor(xferOp);
  // Corresponding memref dim of the vector dim that is unpacked.
  auto dim = unpackedDim(xferOp);
  auto prevIndices = adaptor.getIndices();
  indices.append(prevIndices.begin(), prevIndices.end());

  Location loc = xferOp.getLoc();
  bool isBroadcast = !dim.has_value();
  if (!isBroadcast) {
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = adaptor.getIndices()[*dim];
    indices[*dim] =
        affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
  }
}

static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
                            Value value) {
  if (hasRetVal) {
    assert(value && "Expected non-empty value");
    b.create<scf::YieldOp>(loc, value);
  } else {
    b.create<scf::YieldOp>(loc);
  }
}

/// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
/// is set to true. No such check is generated under the following
/// circumstances:
/// * xferOp does not have a mask.
/// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
///   computed and attached to the new transfer op in the pattern.)
/// * The to-be-unpacked dim of xferOp is a broadcast.
template <typename OpTy>
static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
  if (!xferOp.getMask())
    return Value();
  if (xferOp.getMaskType().getRank() != 1)
    return Value();
  if (xferOp.isBroadcastDim(0))
    return Value();

  Location loc = xferOp.getLoc();
  return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
}

/// Helper function for TransferOpConversion and TransferOp1dConversion.
/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
/// specified dimension `dim` with the loop iteration variable `iv`.
/// E.g., when unpacking dimension 0 from:
/// ```
/// %vec = vector.transfer_read %A[%a, %b] %cst
///     : memref<?x?xf32>, vector<5x4xf32>
/// ```
/// An if check similar to this will be generated inside the loop:
/// ```
/// %d = memref.dim %A, %c0 : memref<?x?xf32>
/// if (%a + iv < %d) {
///   (in-bounds case)
/// } else {
///   (out-of-bounds case)
/// }
/// ```
///
/// If the transfer is 1D and has a mask, this function generates a more
/// complex check that also accounts for potentially masked out elements.
///
/// This function variant returns the value returned by `inBoundsCase` or
/// `outOfBoundsCase`. The MLIR type of the return value must be specified in
/// `resultTypes`.
template <typename OpTy>
static Value generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
    TypeRange resultTypes,
    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  bool hasRetVal = !resultTypes.empty();
  Value cond; // Condition to be built...

  // Condition check 1: Access in-bounds?
  bool isBroadcast = !dim; // No in-bounds check for broadcasts.
176 Location loc = xferOp.getLoc(); 177 ImplicitLocOpBuilder lb(xferOp.getLoc(), b); 178 if (!xferOp.isDimInBounds(0) && !isBroadcast) { 179 Value memrefDim = 180 vector::createOrFoldDimOp(b, loc, xferOp.getSource(), *dim); 181 AffineExpr d0, d1; 182 bindDims(xferOp.getContext(), d0, d1); 183 Value base = xferOp.getIndices()[*dim]; 184 Value memrefIdx = 185 affine::makeComposedAffineApply(b, loc, d0 + d1, {base, iv}); 186 cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim, 187 memrefIdx); 188 } 189 190 // Condition check 2: Masked in? 191 if (auto maskCond = generateMaskCheck(b, xferOp, iv)) { 192 if (cond) 193 cond = lb.create<arith::AndIOp>(cond, maskCond); 194 else 195 cond = maskCond; 196 } 197 198 // If the condition is non-empty, generate an SCF::IfOp. 199 if (cond) { 200 auto check = lb.create<scf::IfOp>( 201 cond, 202 /*thenBuilder=*/ 203 [&](OpBuilder &b, Location loc) { 204 maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc)); 205 }, 206 /*elseBuilder=*/ 207 [&](OpBuilder &b, Location loc) { 208 if (outOfBoundsCase) { 209 maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc)); 210 } else { 211 b.create<scf::YieldOp>(loc); 212 } 213 }); 214 215 return hasRetVal ? check.getResult(0) : Value(); 216 } 217 218 // Condition is empty, no need for an SCF::IfOp. 219 return inBoundsCase(b, loc); 220 } 221 222 /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have 223 /// a return value. Consequently, this function does not have a return value. 224 template <typename OpTy> 225 static void generateInBoundsCheck( 226 OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim, 227 function_ref<void(OpBuilder &, Location)> inBoundsCase, 228 function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) { 229 generateInBoundsCheck( 230 b, xferOp, iv, dim, /*resultTypes=*/TypeRange(), 231 /*inBoundsCase=*/ 232 [&](OpBuilder &b, Location loc) { 233 inBoundsCase(b, loc); 234 return Value(); 235 }, 236 /*outOfBoundsCase=*/ 237 [&](OpBuilder &b, Location loc) { 238 if (outOfBoundsCase) 239 outOfBoundsCase(b, loc); 240 return Value(); 241 }); 242 } 243 244 /// Given an ArrayAttr, return a copy where the first element is dropped. 245 static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) { 246 if (!attr) 247 return attr; 248 return ArrayAttr::get(b.getContext(), attr.getValue().drop_front()); 249 } 250 251 /// Add the pass label to a vector transfer op if its rank is not the target 252 /// rank. 253 template <typename OpTy> 254 static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp, 255 unsigned targetRank) { 256 if (newXferOp.getVectorType().getRank() > targetRank) 257 newXferOp->setAttr(kPassLabel, b.getUnitAttr()); 258 } 259 260 /// Return true if this transfer op operates on a source tensor. 261 template <typename OpTy> 262 static bool isTensorOp(OpTy xferOp) { 263 if (isa<RankedTensorType>(xferOp.getShapedType())) { 264 if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) { 265 // TransferWriteOps on tensors have a result. 266 assert(xferOp->getNumResults() > 0); 267 } 268 return true; 269 } 270 return false; 271 } 272 273 namespace lowering_n_d { 274 275 /// Helper data structure for data and mask buffers. 276 struct BufferAllocs { 277 Value dataBuffer; 278 Value maskBuffer; 279 }; 280 281 // TODO: Parallelism and threadlocal considerations with a ParallelScope trait. 
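/// Return the closest enclosing op that has the AutomaticAllocationScope
/// trait (e.g., a func.func); the temporary buffers created by this lowering
/// are alloca'ed at the start of its entry block.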
282 static Operation *getAutomaticAllocationScope(Operation *op) { 283 Operation *scope = 284 op->getParentWithTrait<OpTrait::AutomaticAllocationScope>(); 285 assert(scope && "Expected op to be inside automatic allocation scope"); 286 return scope; 287 } 288 289 /// Allocate temporary buffers for data (vector) and mask (if present). 290 template <typename OpTy> 291 static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) { 292 Location loc = xferOp.getLoc(); 293 OpBuilder::InsertionGuard guard(b); 294 Operation *scope = getAutomaticAllocationScope(xferOp); 295 assert(scope->getNumRegions() == 1 && 296 "AutomaticAllocationScope with >1 regions"); 297 b.setInsertionPointToStart(&scope->getRegion(0).front()); 298 299 BufferAllocs result; 300 auto bufferType = MemRefType::get({}, xferOp.getVectorType()); 301 result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType); 302 303 if (xferOp.getMask()) { 304 auto maskType = MemRefType::get({}, xferOp.getMask().getType()); 305 auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType); 306 b.setInsertionPoint(xferOp); 307 b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer); 308 result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer, ValueRange()); 309 } 310 311 return result; 312 } 313 314 /// Given a MemRefType with VectorType element type, unpack one dimension from 315 /// the VectorType into the MemRefType. 316 /// 317 /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>> 318 static FailureOr<MemRefType> unpackOneDim(MemRefType type) { 319 auto vectorType = dyn_cast<VectorType>(type.getElementType()); 320 // Vectors with leading scalable dims are not supported. 321 // It may be possible to support these in future by using dynamic memref dims. 322 if (vectorType.getScalableDims().front()) 323 return failure(); 324 auto memrefShape = type.getShape(); 325 SmallVector<int64_t, 8> newMemrefShape; 326 newMemrefShape.append(memrefShape.begin(), memrefShape.end()); 327 newMemrefShape.push_back(vectorType.getDimSize(0)); 328 return MemRefType::get(newMemrefShape, 329 VectorType::Builder(vectorType).dropDim(0)); 330 } 331 332 /// Given a transfer op, find the memref from which the mask is loaded. This 333 /// is similar to Strategy<TransferWriteOp>::getBuffer. 334 template <typename OpTy> 335 static Value getMaskBuffer(OpTy xferOp) { 336 assert(xferOp.getMask() && "Expected that transfer op has mask"); 337 auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>(); 338 assert(loadOp && "Expected transfer op mask produced by LoadOp"); 339 return loadOp.getMemRef(); 340 } 341 342 /// Codegen strategy, depending on the operation. 343 template <typename OpTy> 344 struct Strategy; 345 346 /// Code strategy for vector TransferReadOp. 347 template <> 348 struct Strategy<TransferReadOp> { 349 /// Find the StoreOp that is used for writing the current TransferReadOp's 350 /// result to the temporary buffer allocation. 351 static memref::StoreOp getStoreOp(TransferReadOp xferOp) { 352 assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp"); 353 auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner()); 354 assert(storeOp && "Expected TransferReadOp result used by StoreOp"); 355 return storeOp; 356 } 357 358 /// Find the temporary buffer allocation. All labeled TransferReadOps are 359 /// used like this, where %buf is either the buffer allocation or a type cast 360 /// of the buffer allocation: 361 /// ``` 362 /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ... 
363 /// memref.store %vec, %buf[...] ... 364 /// ``` 365 static Value getBuffer(TransferReadOp xferOp) { 366 return getStoreOp(xferOp).getMemRef(); 367 } 368 369 /// Retrieve the indices of the current StoreOp that stores into the buffer. 370 static void getBufferIndices(TransferReadOp xferOp, 371 SmallVector<Value, 8> &indices) { 372 memref::StoreOp storeOp = getStoreOp(xferOp); 373 auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices(); 374 indices.append(prevIndices.begin(), prevIndices.end()); 375 } 376 377 /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds 378 /// accesses on the to-be-unpacked dimension. 379 /// 380 /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration 381 /// variable `iv`. 382 /// 2. Store the result into the (already `vector.type_cast`ed) buffer. 383 /// 384 /// E.g.: 385 /// ``` 386 /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst 387 /// : memref<?x?x?xf32>, vector<4x3xf32> 388 /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>> 389 /// ``` 390 /// Is rewritten to: 391 /// ``` 392 /// %casted = vector.type_cast %buf 393 /// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>> 394 /// for %j = 0 to 4 { 395 /// %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst 396 /// : memref<?x?x?xf32>, vector<3xf32> 397 /// memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>> 398 /// } 399 /// ``` 400 /// 401 /// Note: The loop and type cast are generated in TransferOpConversion. 402 /// The original TransferReadOp and store op are deleted in `cleanup`. 403 /// Note: The `mask` operand is set in TransferOpConversion. 404 static TransferReadOp rewriteOp(OpBuilder &b, 405 VectorTransferToSCFOptions options, 406 TransferReadOp xferOp, Value buffer, Value iv, 407 ValueRange /*loopState*/) { 408 SmallVector<Value, 8> storeIndices; 409 getBufferIndices(xferOp, storeIndices); 410 storeIndices.push_back(iv); 411 412 SmallVector<Value, 8> xferIndices; 413 getXferIndices(b, xferOp, iv, xferIndices); 414 415 Location loc = xferOp.getLoc(); 416 auto bufferType = dyn_cast<ShapedType>(buffer.getType()); 417 auto vecType = dyn_cast<VectorType>(bufferType.getElementType()); 418 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); 419 auto newXferOp = b.create<vector::TransferReadOp>( 420 loc, vecType, xferOp.getSource(), xferIndices, 421 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), 422 xferOp.getPadding(), Value(), inBoundsAttr); 423 424 maybeApplyPassLabel(b, newXferOp, options.targetRank); 425 426 b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices); 427 return newXferOp; 428 } 429 430 /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write 431 /// padding value to the temporary buffer. 432 static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp, 433 Value buffer, Value iv, 434 ValueRange /*loopState*/) { 435 SmallVector<Value, 8> storeIndices; 436 getBufferIndices(xferOp, storeIndices); 437 storeIndices.push_back(iv); 438 439 Location loc = xferOp.getLoc(); 440 auto bufferType = dyn_cast<ShapedType>(buffer.getType()); 441 auto vecType = dyn_cast<VectorType>(bufferType.getElementType()); 442 auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding()); 443 b.create<memref::StoreOp>(loc, vec, buffer, storeIndices); 444 445 return Value(); 446 } 447 448 /// Cleanup after rewriting the op. 
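  /// The original TransferReadOp and the memref.store that copied its result
  /// into the temporary buffer are erased here; the scf.for loop generated by
  /// TransferOpConversion has taken over their role.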
449 static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp, 450 scf::ForOp /*forOp*/) { 451 rewriter.eraseOp(getStoreOp(xferOp)); 452 rewriter.eraseOp(xferOp); 453 } 454 455 /// Return the initial loop state for the generated scf.for loop. 456 static Value initialLoopState(TransferReadOp xferOp) { return Value(); } 457 }; 458 459 /// Codegen strategy for vector TransferWriteOp. 460 template <> 461 struct Strategy<TransferWriteOp> { 462 /// Find the temporary buffer allocation. All labeled TransferWriteOps are 463 /// used like this, where %buf is either the buffer allocation or a type cast 464 /// of the buffer allocation: 465 /// ``` 466 /// %vec = memref.load %buf[...] ... 467 /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ... 468 /// ``` 469 static Value getBuffer(TransferWriteOp xferOp) { 470 auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>(); 471 assert(loadOp && "Expected transfer op vector produced by LoadOp"); 472 return loadOp.getMemRef(); 473 } 474 475 /// Retrieve the indices of the current LoadOp that loads from the buffer. 476 static void getBufferIndices(TransferWriteOp xferOp, 477 SmallVector<Value, 8> &indices) { 478 auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>(); 479 auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices(); 480 indices.append(prevIndices.begin(), prevIndices.end()); 481 } 482 483 /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds 484 /// accesses on the to-be-unpacked dimension. 485 /// 486 /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer, 487 /// using the loop iteration variable `iv`. 488 /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back 489 /// to memory. 490 /// 491 /// Note: For more details, see comments on Strategy<TransferReadOp>. 492 static TransferWriteOp rewriteOp(OpBuilder &b, 493 VectorTransferToSCFOptions options, 494 TransferWriteOp xferOp, Value buffer, 495 Value iv, ValueRange loopState) { 496 SmallVector<Value, 8> loadIndices; 497 getBufferIndices(xferOp, loadIndices); 498 loadIndices.push_back(iv); 499 500 SmallVector<Value, 8> xferIndices; 501 getXferIndices(b, xferOp, iv, xferIndices); 502 503 Location loc = xferOp.getLoc(); 504 auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices); 505 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); 506 auto source = loopState.empty() ? xferOp.getSource() : loopState[0]; 507 Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type(); 508 auto newXferOp = b.create<vector::TransferWriteOp>( 509 loc, type, vec, source, xferIndices, 510 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(), 511 inBoundsAttr); 512 513 maybeApplyPassLabel(b, newXferOp, options.targetRank); 514 515 return newXferOp; 516 } 517 518 /// Handle out-of-bounds accesses on the to-be-unpacked dimension. 519 static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp, 520 Value buffer, Value iv, 521 ValueRange loopState) { 522 return isTensorOp(xferOp) ? loopState[0] : Value(); 523 } 524 525 /// Cleanup after rewriting the op. 526 static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp, 527 scf::ForOp forOp) { 528 if (isTensorOp(xferOp)) { 529 assert(forOp->getNumResults() == 1 && "Expected one for loop result"); 530 rewriter.replaceOp(xferOp, forOp->getResult(0)); 531 } else { 532 rewriter.eraseOp(xferOp); 533 } 534 } 535 536 /// Return the initial loop state for the generated scf.for loop. 
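  /// For transfers into a tensor, the source tensor is threaded through the
  /// generated loop as an iteration argument; memref transfers carry no loop
  /// state.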
537 static Value initialLoopState(TransferWriteOp xferOp) { 538 return isTensorOp(xferOp) ? xferOp.getSource() : Value(); 539 } 540 }; 541 542 template <typename OpTy> 543 LogicalResult checkPrepareXferOp(OpTy xferOp, 544 VectorTransferToSCFOptions options) { 545 if (xferOp->hasAttr(kPassLabel)) 546 return failure(); 547 if (xferOp.getVectorType().getRank() <= options.targetRank) 548 return failure(); 549 // Currently the unpacking of the leading dimension into the memref is not 550 // supported for scalable dimensions. 551 if (xferOp.getVectorType().getScalableDims().front()) 552 return failure(); 553 if (isTensorOp(xferOp) && !options.lowerTensors) 554 return failure(); 555 // Transfer ops that modify the element type are not supported atm. 556 if (xferOp.getVectorType().getElementType() != 557 xferOp.getShapedType().getElementType()) 558 return failure(); 559 return success(); 560 } 561 562 /// Prepare a TransferReadOp for progressive lowering. 563 /// 564 /// 1. Allocate a temporary buffer. 565 /// 2. Label the TransferReadOp, marking it eligible for progressive lowering. 566 /// 3. Store the result of the TransferReadOp into the temporary buffer. 567 /// 4. Load the result from the temporary buffer and replace all uses of the 568 /// original TransferReadOp with this load. 569 /// 570 /// E.g.: 571 /// ``` 572 /// %vec = vector.transfer_read %A[%a, %b, %c], %cst 573 /// : vector<5x4xf32>, memref<?x?x?xf32> 574 /// ``` 575 /// is rewritten to: 576 /// ``` 577 /// %0 = memref.alloca() : memref<vector<5x4xf32>> 578 /// %1 = vector.transfer_read %A[%a, %b, %c], %cst 579 /// { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32> 580 /// memref.store %1, %0[] : memref<vector<5x4xf32>> 581 /// %vec = memref.load %0[] : memref<vector<5x4xf32>> 582 /// ``` 583 /// 584 /// Note: A second temporary buffer may be allocated for the `mask` operand. 585 struct PrepareTransferReadConversion 586 : public VectorToSCFPattern<TransferReadOp> { 587 using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern; 588 589 LogicalResult matchAndRewrite(TransferReadOp xferOp, 590 PatternRewriter &rewriter) const override { 591 if (checkPrepareXferOp(xferOp, options).failed()) 592 return failure(); 593 594 BufferAllocs buffers = allocBuffers(rewriter, xferOp); 595 Operation *newXfer = rewriter.clone(*xferOp.getOperation()); 596 newXfer->setAttr(kPassLabel, rewriter.getUnitAttr()); 597 if (xferOp.getMask()) { 598 dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign( 599 buffers.maskBuffer); 600 } 601 602 Location loc = xferOp.getLoc(); 603 rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0), 604 buffers.dataBuffer); 605 rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer); 606 607 return success(); 608 } 609 }; 610 611 /// Prepare a TransferWriteOp for progressive lowering. 612 /// 613 /// 1. Allocate a temporary buffer. 614 /// 2. Store the vector into the buffer. 615 /// 3. Load the vector from the buffer again. 616 /// 4. Use the loaded vector as a TransferWriteOp operand and label the op, 617 /// marking it eligible for progressive lowering via TransferOpConversion. 
618 /// 619 /// E.g.: 620 /// ``` 621 /// vector.transfer_write %vec, %A[%a, %b, %c] 622 /// : vector<5x4xf32>, memref<?x?x?xf32> 623 /// ``` 624 /// is rewritten to: 625 /// ``` 626 /// %0 = memref.alloca() : memref<vector<5x4xf32>> 627 /// memref.store %vec, %0[] : memref<vector<5x4xf32>> 628 /// %1 = memref.load %0[] : memref<vector<5x4xf32>> 629 /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ } 630 /// : vector<5x4xf32>, memref<?x?x?xf32> 631 /// ``` 632 /// 633 /// Note: A second temporary buffer may be allocated for the `mask` operand. 634 struct PrepareTransferWriteConversion 635 : public VectorToSCFPattern<TransferWriteOp> { 636 using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern; 637 638 LogicalResult matchAndRewrite(TransferWriteOp xferOp, 639 PatternRewriter &rewriter) const override { 640 if (checkPrepareXferOp(xferOp, options).failed()) 641 return failure(); 642 643 Location loc = xferOp.getLoc(); 644 auto buffers = allocBuffers(rewriter, xferOp); 645 rewriter.create<memref::StoreOp>(loc, xferOp.getVector(), 646 buffers.dataBuffer); 647 auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer); 648 rewriter.updateRootInPlace(xferOp, [&]() { 649 xferOp.getVectorMutable().assign(loadedVec); 650 xferOp->setAttr(kPassLabel, rewriter.getUnitAttr()); 651 }); 652 653 if (xferOp.getMask()) { 654 rewriter.updateRootInPlace(xferOp, [&]() { 655 xferOp.getMaskMutable().assign(buffers.maskBuffer); 656 }); 657 } 658 659 return success(); 660 } 661 }; 662 663 /// Decompose a n-D PrintOp into a loop of elementary/scalar prints. This allows 664 /// printing both 1D scalable vectors and n-D fixed size vectors. 665 /// 666 /// E.g.: 667 /// ``` 668 /// vector.print %v : vector<[4]xi32> 669 /// ``` 670 /// is rewritten to: 671 /// ``` 672 /// %c0 = arith.constant 0 : index 673 /// %c4 = arith.constant 4 : index 674 /// %c1 = arith.constant 1 : index 675 /// %vscale = vector.vscale 676 /// %length = arith.muli %vscale, %c4 : index 677 /// %lastIndex = arith.subi %length, %c1 : index 678 /// vector.print punctuation <open> 679 /// scf.for %i = %c0 to %length step %c1 { 680 /// %el = vector.extractelement %v[%i : index] : vector<[4]xi32> 681 /// vector.print %el : i32 punctuation <no_punctuation> 682 /// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index 683 /// scf.if %notLastIndex { 684 /// vector.print punctuation <comma> 685 /// } 686 /// } 687 /// vector.print punctuation <close> 688 /// vector.print 689 /// ``` 690 struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> { 691 using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern; 692 LogicalResult matchAndRewrite(vector::PrintOp printOp, 693 PatternRewriter &rewriter) const override { 694 if (!printOp.getSource()) 695 return failure(); 696 697 VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType()); 698 if (!vectorType) 699 return failure(); 700 701 // Currently >= 2D scalable vectors are not supported. 702 // These can't be lowered to LLVM (as LLVM does not support scalable vectors 703 // of scalable vectors), and due to limitations of current ops can't be 704 // indexed with SSA values or flattened. This may change after 705 // https://reviews.llvm.org/D155034, though there still needs to be a path 706 // for lowering to LLVM. 
707 if (vectorType.getRank() > 1 && vectorType.isScalable()) 708 return failure(); 709 710 auto loc = printOp.getLoc(); 711 auto value = printOp.getSource(); 712 713 if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) { 714 // Oddly sized integers are (somewhat) buggy on a lot of backends, so to 715 // avoid issues extend them to a more standard size. 716 // https://github.com/llvm/llvm-project/issues/30613 717 auto width = intTy.getWidth(); 718 auto legalWidth = llvm::NextPowerOf2(std::max(8u, width) - 1); 719 auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth, 720 intTy.getSignedness()); 721 // arith can only take signless integers, so we must cast back and forth. 722 auto signlessSourceVectorType = 723 vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy)); 724 auto signlessTargetVectorType = 725 vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy)); 726 auto targetVectorType = vectorType.cloneWith({}, legalIntTy); 727 value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType, 728 value); 729 if (value.getType() != signlessTargetVectorType) { 730 if (width == 1 || intTy.isUnsigned()) 731 value = rewriter.create<arith::ExtUIOp>(loc, signlessTargetVectorType, 732 value); 733 else 734 value = rewriter.create<arith::ExtSIOp>(loc, signlessTargetVectorType, 735 value); 736 } 737 value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value); 738 vectorType = targetVectorType; 739 } 740 741 auto scalableDimensions = vectorType.getScalableDims(); 742 auto shape = vectorType.getShape(); 743 constexpr int64_t singletonShape[] = {1}; 744 if (vectorType.getRank() == 0) 745 shape = singletonShape; 746 747 if (vectorType.getRank() != 1) { 748 // Flatten n-D vectors to 1D. This is done to allow indexing with a 749 // non-constant value (which can currently only be done via 750 // vector.extractelement for 1D vectors). 751 auto flatLength = std::accumulate(shape.begin(), shape.end(), 1, 752 std::multiplies<int64_t>()); 753 auto flatVectorType = 754 VectorType::get({flatLength}, vectorType.getElementType()); 755 value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value); 756 } 757 758 vector::PrintOp firstClose; 759 SmallVector<Value, 8> loopIndices; 760 for (unsigned d = 0; d < shape.size(); d++) { 761 // Setup loop bounds and step. 762 Value lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); 763 Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, shape[d]); 764 Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1); 765 if (!scalableDimensions.empty() && scalableDimensions[d]) { 766 auto vscale = rewriter.create<vector::VectorScaleOp>( 767 loc, rewriter.getIndexType()); 768 upperBound = rewriter.create<arith::MulIOp>(loc, upperBound, vscale); 769 } 770 auto lastIndex = rewriter.create<arith::SubIOp>(loc, upperBound, step); 771 772 // Create a loop to print the elements surrounded by parentheses. 773 rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open); 774 auto loop = 775 rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step); 776 auto printClose = rewriter.create<vector::PrintOp>( 777 loc, vector::PrintPunctuation::Close); 778 if (!firstClose) 779 firstClose = printClose; 780 781 auto loopIdx = loop.getInductionVar(); 782 loopIndices.push_back(loopIdx); 783 784 // Print a comma after all but the last element. 
      rewriter.setInsertionPointToStart(loop.getBody());
      auto notLastIndex = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
      rewriter.create<scf::IfOp>(loc, notLastIndex,
                                 [&](OpBuilder &builder, Location loc) {
                                   builder.create<vector::PrintOp>(
                                       loc, vector::PrintPunctuation::Comma);
                                   builder.create<scf::YieldOp>(loc);
                                 });

      rewriter.setInsertionPointToStart(loop.getBody());
    }

    // Compute the flattened index.
    // Note: For rank > 1 vectors, this assumes non-scalable dimensions.
    Value flatIndex;
    auto currentStride = 1;
    for (int d = shape.size() - 1; d >= 0; d--) {
      auto stride = rewriter.create<arith::ConstantIndexOp>(loc, currentStride);
      auto index = rewriter.create<arith::MulIOp>(loc, stride, loopIndices[d]);
      if (flatIndex)
        flatIndex = rewriter.create<arith::AddIOp>(loc, flatIndex, index);
      else
        flatIndex = index;
      currentStride *= shape[d];
    }

    // Print the scalar elements in the innermost loop.
    auto element =
        rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
    rewriter.create<vector::PrintOp>(loc, element,
                                     vector::PrintPunctuation::NoPunctuation);

    rewriter.setInsertionPointAfter(firstClose);
    rewriter.create<vector::PrintOp>(loc, printOp.getPunctuation());
    rewriter.eraseOp(printOp);
    return success();
  }

  static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) {
    return IntegerType::get(intTy.getContext(), intTy.getWidth(),
                            IntegerType::Signless);
  };
};

/// Progressive lowering of vector transfer ops: Unpack one dimension.
///
/// 1. Unpack one dimension from the current buffer type and cast the buffer
///    to that new type. E.g.:
///    ```
///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
///    vector.transfer_write %vec ...
///    ```
///    The following cast is generated:
///    ```
///    %casted = vector.type_cast %0
///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
///    ```
/// 2. Generate a for loop and rewrite the transfer op according to the
///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
///    out-of-bounds, generate an if-check and handle both cases separately.
/// 3. Clean up according to the corresponding Strategy<OpTy>.
///
/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
/// source (as opposed to a memref source), then each iteration of the
/// generated scf.for loop yields the new tensor value. E.g.:
/// ```
/// %result = scf.for i = 0 to 5 {
///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
///   %1 = vector.transfer_write %0, %source[...]
///       : vector<4x3xf32>, tensor<5x4x3xf32>
///   scf.yield %1 : tensor<5x4x3xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    this->setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp->hasAttr(kPassLabel))
      return failure();

    // Find and cast data buffer. How the buffer can be found depends on OpTy.
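    // For TransferReadOp, the buffer is the destination of the memref.store
    // that consumes the read's result; for TransferWriteOp, it is the source
    // of the memref.load that feeds the written vector (see
    // Strategy<OpTy>::getBuffer).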
875 ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter); 876 auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp); 877 auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType()); 878 auto castedDataType = unpackOneDim(dataBufferType); 879 if (failed(castedDataType)) 880 return failure(); 881 882 auto castedDataBuffer = 883 locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer); 884 885 // If the xferOp has a mask: Find and cast mask buffer. 886 Value castedMaskBuffer; 887 if (xferOp.getMask()) { 888 Value maskBuffer = getMaskBuffer(xferOp); 889 if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) { 890 // Do not unpack a dimension of the mask, if: 891 // * To-be-unpacked transfer op dimension is a broadcast. 892 // * Mask is 1D, i.e., the mask cannot be further unpacked. 893 // (That means that all remaining dimensions of the transfer op must 894 // be broadcasted.) 895 castedMaskBuffer = maskBuffer; 896 } else { 897 // It's safe to assume the mask buffer can be unpacked if the data 898 // buffer was unpacked. 899 auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType()); 900 MemRefType castedMaskType = *unpackOneDim(maskBufferType); 901 castedMaskBuffer = 902 locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer); 903 } 904 } 905 906 // Loop bounds and step. 907 auto lb = locB.create<arith::ConstantIndexOp>(0); 908 auto ub = locB.create<arith::ConstantIndexOp>( 909 castedDataType->getDimSize(castedDataType->getRank() - 1)); 910 auto step = locB.create<arith::ConstantIndexOp>(1); 911 // TransferWriteOps that operate on tensors return the modified tensor and 912 // require a loop state. 913 auto loopState = Strategy<OpTy>::initialLoopState(xferOp); 914 915 // Generate for loop. 916 auto result = locB.create<scf::ForOp>( 917 lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(), 918 [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { 919 Type stateType = loopState.empty() ? Type() : loopState[0].getType(); 920 921 auto result = generateInBoundsCheck( 922 b, xferOp, iv, unpackedDim(xferOp), 923 stateType ? TypeRange(stateType) : TypeRange(), 924 /*inBoundsCase=*/ 925 [&](OpBuilder &b, Location loc) { 926 // Create new transfer op. 927 OpTy newXfer = Strategy<OpTy>::rewriteOp( 928 b, this->options, xferOp, castedDataBuffer, iv, loopState); 929 930 // If old transfer op has a mask: Set mask on new transfer op. 931 // Special case: If the mask of the old transfer op is 1D and 932 // the 933 // unpacked dim is not a broadcast, no mask is 934 // needed on the new transfer op. 935 if (xferOp.getMask() && (xferOp.isBroadcastDim(0) || 936 xferOp.getMaskType().getRank() > 1)) { 937 OpBuilder::InsertionGuard guard(b); 938 b.setInsertionPoint(newXfer); // Insert load before newXfer. 939 940 SmallVector<Value, 8> loadIndices; 941 if (auto memrefType = 942 castedMaskBuffer.getType().dyn_cast<MemRefType>()) { 943 // If castedMaskBuffer is a memref, then one dim was 944 // unpacked; see above. 945 loadIndices.push_back(iv); 946 } else { 947 Strategy<OpTy>::getBufferIndices(xferOp, loadIndices); 948 // In case of broadcast: Use same indices to load from 949 // memref as before. 950 if (!xferOp.isBroadcastDim(0)) 951 loadIndices.push_back(iv); 952 } 953 954 auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer, 955 loadIndices); 956 rewriter.updateRootInPlace(newXfer, [&]() { 957 newXfer.getMaskMutable().assign(mask); 958 }); 959 } 960 961 return loopState.empty() ? 
                                             Value() : newXfer->getResult(0);
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                return Strategy<OpTy>::handleOutOfBoundsDim(
                    b, xferOp, castedDataBuffer, iv, loopState);
              });

          maybeYieldValue(b, loc, !loopState.empty(), result);
        });

    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
    return success();
  }
};

} // namespace lowering_n_d

namespace lowering_n_d_unrolled {

/// If the original transfer op has a mask, compute the mask of the new transfer
/// op (for the current iteration `i`) and assign it.
template <typename OpTy>
static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
                            int64_t i) {
  if (!xferOp.getMask())
    return;

  if (xferOp.isBroadcastDim(0)) {
    // To-be-unpacked dimension is a broadcast, which does not have a
    // corresponding mask dimension. Mask attribute remains unchanged.
    newXferOp.getMaskMutable().assign(xferOp.getMask());
    return;
  }

  if (xferOp.getMaskType().getRank() > 1) {
    // Unpack one dimension of the mask.
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(newXferOp); // Insert load before newXfer.

    llvm::SmallVector<int64_t, 1> indices({i});
    Location loc = xferOp.getLoc();
    auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
    newXferOp.getMaskMutable().assign(newMask);
  }

  // If we end up here: The mask of the old transfer op is 1D and the unpacked
  // dim is not a broadcast, so no mask is needed on the new transfer op.
  // `generateInBoundsCheck` will have evaluated the mask already.
}

/// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<5x4xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v_init = splat %padding : vector<5x4xf32>
/// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
/// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
/// ...
/// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
/// ```
///
/// Note: As an optimization, if the result of the original TransferReadOp
/// was directly inserted into another vector, no new %v_init vector is created.
/// Instead, the new TransferReadOp results are inserted into that vector.
struct UnrollTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector into which the newly created TransferReadOp results
  /// are inserted.
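  /// If the read's single user is a vector.insert, that insert's destination
  /// vector is reused; otherwise a new vector is created by splatting the
  /// padding value.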
1052 Value getResultVector(TransferReadOp xferOp, 1053 PatternRewriter &rewriter) const { 1054 if (auto insertOp = getInsertOp(xferOp)) 1055 return insertOp.getDest(); 1056 Location loc = xferOp.getLoc(); 1057 return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(), 1058 xferOp.getPadding()); 1059 } 1060 1061 /// If the result of the TransferReadOp has exactly one user, which is a 1062 /// vector::InsertOp, return that operation. 1063 vector::InsertOp getInsertOp(TransferReadOp xferOp) const { 1064 if (xferOp->hasOneUse()) { 1065 Operation *xferOpUser = *xferOp->getUsers().begin(); 1066 if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser)) 1067 return insertOp; 1068 } 1069 1070 return vector::InsertOp(); 1071 } 1072 1073 /// If the result of the TransferReadOp has exactly one user, which is a 1074 /// vector::InsertOp, return that operation's indices. 1075 void getInsertionIndices(TransferReadOp xferOp, 1076 SmallVectorImpl<OpFoldResult> &indices) const { 1077 if (auto insertOp = getInsertOp(xferOp)) { 1078 auto pos = insertOp.getMixedPosition(); 1079 indices.append(pos.begin(), pos.end()); 1080 } 1081 } 1082 1083 /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds 1084 /// accesses, and broadcasts and transposes in permutation maps. 1085 LogicalResult matchAndRewrite(TransferReadOp xferOp, 1086 PatternRewriter &rewriter) const override { 1087 if (xferOp.getVectorType().getRank() <= options.targetRank) 1088 return failure(); 1089 if (isTensorOp(xferOp) && !options.lowerTensors) 1090 return failure(); 1091 // Transfer ops that modify the element type are not supported atm. 1092 if (xferOp.getVectorType().getElementType() != 1093 xferOp.getShapedType().getElementType()) 1094 return failure(); 1095 1096 auto insertOp = getInsertOp(xferOp); 1097 auto vec = getResultVector(xferOp, rewriter); 1098 auto vecType = dyn_cast<VectorType>(vec.getType()); 1099 auto xferVecType = xferOp.getVectorType(); 1100 1101 if (xferVecType.getScalableDims()[0]) { 1102 // Cannot unroll a scalable dimension at compile time. 1103 return failure(); 1104 } 1105 1106 VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0); 1107 1108 int64_t dimSize = xferVecType.getShape()[0]; 1109 1110 // Generate fully unrolled loop of transfer ops. 1111 Location loc = xferOp.getLoc(); 1112 for (int64_t i = 0; i < dimSize; ++i) { 1113 Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i); 1114 1115 vec = generateInBoundsCheck( 1116 rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType), 1117 /*inBoundsCase=*/ 1118 [&](OpBuilder &b, Location loc) { 1119 // Indices for the new transfer op. 1120 SmallVector<Value, 8> xferIndices; 1121 getXferIndices(b, xferOp, iv, xferIndices); 1122 1123 // Indices for the new vector.insert op. 1124 SmallVector<OpFoldResult, 8> insertionIndices; 1125 getInsertionIndices(xferOp, insertionIndices); 1126 insertionIndices.push_back(rewriter.getIndexAttr(i)); 1127 1128 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); 1129 auto newXferOp = b.create<vector::TransferReadOp>( 1130 loc, newXferVecType, xferOp.getSource(), xferIndices, 1131 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), 1132 xferOp.getPadding(), Value(), inBoundsAttr); 1133 maybeAssignMask(b, xferOp, newXferOp, i); 1134 return b.create<vector::InsertOp>(loc, newXferOp, vec, 1135 insertionIndices); 1136 }, 1137 /*outOfBoundsCase=*/ 1138 [&](OpBuilder &b, Location loc) { 1139 // Loop through original (unmodified) vector. 
              return vec;
            });
    }

    if (insertOp) {
      // Rewrite single user of the old TransferReadOp, which was an InsertOp.
      rewriter.replaceOp(insertOp, vec);
      rewriter.eraseOp(xferOp);
    } else {
      rewriter.replaceOp(xferOp, vec);
    }

    return success();
  }
};

/// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v0 = vector.extract %vec[0] : vector<4xf32> from vector<5x4xf32>
/// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
/// %v1 = vector.extract %vec[1] : vector<4xf32> from vector<5x4xf32>
/// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
/// ...
/// %v4 = vector.extract %vec[4] : vector<4xf32> from vector<5x4xf32>
/// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
/// ```
///
/// Note: As an optimization, if the vector of the original TransferWriteOp
/// was directly extracted from another vector via an ExtractOp `a`, extract
/// the vectors for the newly generated TransferWriteOps from `a`'s input. By
/// doing so, `a` may become dead, and the number of ExtractOps generated during
/// recursive application of this pattern will be minimal.
struct UnrollTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector from which newly generated ExtractOps will extract.
  Value getDataVector(TransferWriteOp xferOp) const {
    if (auto extractOp = getExtractOp(xferOp))
      return extractOp.getVector();
    return xferOp.getVector();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return it.
  vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
    if (auto *op = xferOp.getVector().getDefiningOp())
      return dyn_cast<vector::ExtractOp>(op);
    return vector::ExtractOp();
  }

  /// If the input of the given TransferWriteOp is an ExtractOp, return its
  /// indices.
  void getExtractionIndices(TransferWriteOp xferOp,
                            SmallVectorImpl<OpFoldResult> &indices) const {
    if (auto extractOp = getExtractOp(xferOp)) {
      auto pos = extractOp.getMixedPosition();
      indices.append(pos.begin(), pos.end());
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    VectorType inputVectorTy = xferOp.getVectorType();

    if (inputVectorTy.getRank() <= options.targetRank)
      return failure();

    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported atm.
1228 if (inputVectorTy.getElementType() != 1229 xferOp.getShapedType().getElementType()) 1230 return failure(); 1231 1232 auto vec = getDataVector(xferOp); 1233 if (inputVectorTy.getScalableDims()[0]) { 1234 // Cannot unroll a scalable dimension at compile time. 1235 return failure(); 1236 } 1237 1238 int64_t dimSize = inputVectorTy.getShape()[0]; 1239 Value source = xferOp.getSource(); // memref or tensor to be written to. 1240 auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type(); 1241 1242 // Generate fully unrolled loop of transfer ops. 1243 Location loc = xferOp.getLoc(); 1244 for (int64_t i = 0; i < dimSize; ++i) { 1245 Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i); 1246 1247 auto updatedSource = generateInBoundsCheck( 1248 rewriter, xferOp, iv, unpackedDim(xferOp), 1249 isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(), 1250 /*inBoundsCase=*/ 1251 [&](OpBuilder &b, Location loc) { 1252 // Indices for the new transfer op. 1253 SmallVector<Value, 8> xferIndices; 1254 getXferIndices(b, xferOp, iv, xferIndices); 1255 1256 // Indices for the new vector.extract op. 1257 SmallVector<OpFoldResult, 8> extractionIndices; 1258 getExtractionIndices(xferOp, extractionIndices); 1259 extractionIndices.push_back(b.getI64IntegerAttr(i)); 1260 1261 auto extracted = 1262 b.create<vector::ExtractOp>(loc, vec, extractionIndices); 1263 auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); 1264 Value xferVec; 1265 if (inputVectorTy.getRank() == 1) { 1266 // When target-rank=0, unrolling would causes the vector input 1267 // argument into `transfer_write` to become a scalar. We solve 1268 // this by broadcasting the scalar to a 0D vector. 1269 xferVec = b.create<vector::BroadcastOp>( 1270 loc, VectorType::get({}, extracted.getType()), extracted); 1271 } else { 1272 xferVec = extracted; 1273 } 1274 auto newXferOp = b.create<vector::TransferWriteOp>( 1275 loc, sourceType, xferVec, source, xferIndices, 1276 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(), 1277 inBoundsAttr); 1278 1279 maybeAssignMask(b, xferOp, newXferOp, i); 1280 1281 return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value(); 1282 }, 1283 /*outOfBoundsCase=*/ 1284 [&](OpBuilder &b, Location loc) { 1285 return isTensorOp(xferOp) ? source : Value(); 1286 }); 1287 1288 if (isTensorOp(xferOp)) 1289 source = updatedSource; 1290 } 1291 1292 if (isTensorOp(xferOp)) 1293 rewriter.replaceOp(xferOp, source); 1294 else 1295 rewriter.eraseOp(xferOp); 1296 1297 return success(); 1298 } 1299 }; 1300 1301 } // namespace lowering_n_d_unrolled 1302 1303 namespace lowering_1_d { 1304 1305 /// Compute the indices into the memref for the LoadOp/StoreOp generated as 1306 /// part of TransferOp1dConversion. Return the memref dimension on which 1307 /// the transfer is operating. A return value of std::nullopt indicates a 1308 /// broadcast. 
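/// In the non-broadcast case, the index of the transferred memref dimension
/// is replaced by `index + iv`, materialized as a composed affine.apply.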
1309 template <typename OpTy> 1310 static std::optional<int64_t> 1311 get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv, 1312 SmallVector<Value, 8> &memrefIndices) { 1313 auto indices = xferOp.getIndices(); 1314 auto map = xferOp.getPermutationMap(); 1315 assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer"); 1316 1317 memrefIndices.append(indices.begin(), indices.end()); 1318 assert(map.getNumResults() == 1 && 1319 "Expected 1 permutation map result for 1D transfer"); 1320 if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) { 1321 Location loc = xferOp.getLoc(); 1322 auto dim = expr.getPosition(); 1323 AffineExpr d0, d1; 1324 bindDims(xferOp.getContext(), d0, d1); 1325 Value offset = memrefIndices[dim]; 1326 memrefIndices[dim] = 1327 affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv}); 1328 return dim; 1329 } 1330 1331 assert(xferOp.isBroadcastDim(0) && 1332 "Expected AffineDimExpr or AffineConstantExpr"); 1333 return std::nullopt; 1334 } 1335 1336 /// Codegen strategy for TransferOp1dConversion, depending on the 1337 /// operation. 1338 template <typename OpTy> 1339 struct Strategy1d; 1340 1341 /// Codegen strategy for TransferReadOp. 1342 template <> 1343 struct Strategy1d<TransferReadOp> { 1344 static void generateForLoopBody(OpBuilder &b, Location loc, 1345 TransferReadOp xferOp, Value iv, 1346 ValueRange loopState) { 1347 SmallVector<Value, 8> indices; 1348 auto dim = get1dMemrefIndices(b, xferOp, iv, indices); 1349 auto vec = loopState[0]; 1350 1351 // In case of out-of-bounds access, leave `vec` as is (was initialized with 1352 // padding value). 1353 auto nextVec = generateInBoundsCheck( 1354 b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()), 1355 /*inBoundsCase=*/ 1356 [&](OpBuilder &b, Location loc) { 1357 Value val = 1358 b.create<memref::LoadOp>(loc, xferOp.getSource(), indices); 1359 return b.create<vector::InsertElementOp>(loc, val, vec, iv); 1360 }, 1361 /*outOfBoundsCase=*/ 1362 [&](OpBuilder & /*b*/, Location loc) { return vec; }); 1363 b.create<scf::YieldOp>(loc, nextVec); 1364 } 1365 1366 static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) { 1367 // Inititalize vector with padding value. 1368 Location loc = xferOp.getLoc(); 1369 return b.create<vector::SplatOp>(loc, xferOp.getVectorType(), 1370 xferOp.getPadding()); 1371 } 1372 }; 1373 1374 /// Codegen strategy for TransferWriteOp. 1375 template <> 1376 struct Strategy1d<TransferWriteOp> { 1377 static void generateForLoopBody(OpBuilder &b, Location loc, 1378 TransferWriteOp xferOp, Value iv, 1379 ValueRange /*loopState*/) { 1380 SmallVector<Value, 8> indices; 1381 auto dim = get1dMemrefIndices(b, xferOp, iv, indices); 1382 1383 // Nothing to do in case of out-of-bounds access. 1384 generateInBoundsCheck( 1385 b, xferOp, iv, dim, 1386 /*inBoundsCase=*/[&](OpBuilder &b, Location loc) { 1387 auto val = 1388 b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv); 1389 b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices); 1390 }); 1391 b.create<scf::YieldOp>(loc); 1392 } 1393 1394 static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) { 1395 return Value(); 1396 } 1397 }; 1398 1399 /// Lower a 1D vector transfer op to SCF using scalar loads/stores. 
This is 1400 /// necessary in cases where a 1D vector transfer op cannot be lowered into 1401 /// vector load/stores due to non-unit strides or broadcasts: 1402 /// 1403 /// * Transfer dimension is not the last memref dimension 1404 /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast) 1405 /// * Memref has a layout map with non-unit stride on the last dimension 1406 /// 1407 /// This pattern generates IR as follows: 1408 /// 1409 /// 1. Generate a for loop iterating over each vector element. 1410 /// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp, 1411 /// depending on OpTy. 1412 /// 1413 /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp 1414 /// can be generated instead of TransferOp1dConversion. Add such a pattern 1415 /// to ConvertVectorToLLVM. 1416 /// 1417 /// E.g.: 1418 /// ``` 1419 /// vector.transfer_write %vec, %A[%a, %b] 1420 /// {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} 1421 /// : vector<9xf32>, memref<?x?xf32> 1422 /// ``` 1423 /// Is rewritten to approximately the following pseudo-IR: 1424 /// ``` 1425 /// for i = 0 to 9 { 1426 /// %t = vector.extractelement %vec[i] : vector<9xf32> 1427 /// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32> 1428 /// } 1429 /// ``` 1430 template <typename OpTy> 1431 struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> { 1432 using VectorToSCFPattern<OpTy>::VectorToSCFPattern; 1433 1434 LogicalResult matchAndRewrite(OpTy xferOp, 1435 PatternRewriter &rewriter) const override { 1436 // TODO: support 0-d corner case. 1437 if (xferOp.getTransferRank() == 0) 1438 return failure(); 1439 auto map = xferOp.getPermutationMap(); 1440 auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType()); 1441 1442 if (!memRefType) 1443 return failure(); 1444 if (xferOp.getVectorType().getRank() != 1) 1445 return failure(); 1446 if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType)) 1447 return failure(); // Handled by ConvertVectorToLLVM 1448 1449 // Loop bounds, step, state... 1450 Location loc = xferOp.getLoc(); 1451 auto vecType = xferOp.getVectorType(); 1452 auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0); 1453 Value ub = 1454 rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0)); 1455 if (vecType.isScalable()) { 1456 Value vscale = 1457 rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType()); 1458 ub = rewriter.create<arith::MulIOp>(loc, ub, vscale); 1459 } 1460 auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); 1461 auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp); 1462 1463 // Generate for loop. 1464 rewriter.replaceOpWithNewOp<scf::ForOp>( 1465 xferOp, lb, ub, step, loopState ? 
ValueRange(loopState) : ValueRange(), 1466 [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { 1467 Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState); 1468 }); 1469 1470 return success(); 1471 } 1472 }; 1473 1474 } // namespace lowering_1_d 1475 } // namespace 1476 1477 void mlir::populateVectorToSCFConversionPatterns( 1478 RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) { 1479 if (options.unroll) { 1480 patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion, 1481 lowering_n_d_unrolled::UnrollTransferWriteConversion>( 1482 patterns.getContext(), options); 1483 } else { 1484 patterns.add<lowering_n_d::PrepareTransferReadConversion, 1485 lowering_n_d::PrepareTransferWriteConversion, 1486 lowering_n_d::TransferOpConversion<TransferReadOp>, 1487 lowering_n_d::TransferOpConversion<TransferWriteOp>>( 1488 patterns.getContext(), options); 1489 } 1490 1491 if (options.targetRank == 1) { 1492 patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>, 1493 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>( 1494 patterns.getContext(), options); 1495 } 1496 patterns.add<lowering_n_d::DecomposePrintOpConversion>(patterns.getContext(), 1497 options); 1498 } 1499 1500 namespace { 1501 1502 struct ConvertVectorToSCFPass 1503 : public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> { 1504 ConvertVectorToSCFPass() = default; 1505 ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) { 1506 this->fullUnroll = options.unroll; 1507 this->targetRank = options.targetRank; 1508 this->lowerTensors = options.lowerTensors; 1509 } 1510 1511 void runOnOperation() override { 1512 VectorTransferToSCFOptions options; 1513 options.unroll = fullUnroll; 1514 options.targetRank = targetRank; 1515 options.lowerTensors = lowerTensors; 1516 1517 // Lower permutation maps first. 1518 RewritePatternSet lowerTransferPatterns(&getContext()); 1519 mlir::vector::populateVectorTransferPermutationMapLoweringPatterns( 1520 lowerTransferPatterns); 1521 (void)applyPatternsAndFoldGreedily(getOperation(), 1522 std::move(lowerTransferPatterns)); 1523 1524 RewritePatternSet patterns(&getContext()); 1525 populateVectorToSCFConversionPatterns(patterns, options); 1526 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 1527 } 1528 }; 1529 1530 } // namespace 1531 1532 std::unique_ptr<Pass> 1533 mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) { 1534 return std::make_unique<ConvertVectorToSCFPass>(options); 1535 } 1536
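
// Example usage (illustrative sketch only, not part of the upstream
// implementation): the lowering is typically exercised via the
// `-convert-vector-to-scf` pass in a pass pipeline, or programmatically by
// populating the patterns directly. `ctx` and `op` below are placeholders for
// a surrounding MLIRContext and the root operation to transform.
//
//   VectorTransferToSCFOptions options;
//   options.targetRank = 1;    // progressively lower until 1-D transfers
//   options.unroll = false;    // use scf.for loops instead of full unrolling
//   RewritePatternSet patterns(ctx);
//   populateVectorToSCFConversionPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));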