//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a vector.transfer
// op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//

#include <optional>
#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;

/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds.
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.getPermutationMap().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be in-bounds, it does not participate in
    // the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    OpFoldResult sum = affine::makeComposedFoldedAffineApply(
        b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize),
        {xferOp.getIndices()[indicesIdx]});
    OpFoldResult dimSz =
        memref::getMixedSize(b, loc, xferOp.getSource(), indicesIdx);
    auto maybeCstSum = getConstantIntValue(sum);
    auto maybeCstDimSz = getConstantIntValue(dimSz);
    if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
      return;
    Value cond =
        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
                                getValueOrCreateConstantIndexOp(b, loc, sum),
                                getValueOrCreateConstantIndexOp(b, loc, dimSz));
    // Conjunction over all dims for which we are in-bounds.
    if (inBoundsCond)
      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}
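
// Illustrative sketch (value names invented for exposition; not generated
// verbatim): for a transfer_read of vector<4x8xf32> from %A : memref<?x?xf32>
// at indices (%i, %j), assuming no dimension is already known in-bounds, the
// condition built above is roughly:
// ```
//   %sum0 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i)
//   %d0   = memref.dim %A, %c0 : memref<?x?xf32>
//   %ok0  = arith.cmpi sle, %sum0, %d0 : index
//   %sum1 = affine.apply affine_map<(d0) -> (d0 + 8)>(%j)
//   %d1   = memref.dim %A, %c1 : memref<?x?xf32>
//   %ok1  = arith.cmpi sle, %sum1, %d1 : index
//   %cond = arith.andi %ok0, %ok1 : i1
// ```
// Dimensions that are statically known in-bounds are skipped, so the result
// may be a null Value when every dimension is provably in-bounds.
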
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fast path, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slow path, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer of a single vector, alloca'ed at the top of the
/// function.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map
///  2. the rank of the `xferOp.getSource()` and the rank of the
///     `xferOp.getVector()` must be equal. This will be relaxed in the future
///     but requires rank-reducing subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();

  // TODO: expand support to these 2 cases.
  if (!xferOp.getPermutationMap().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfer operations directly under IfOp, this avoids applying
  // the pattern recursively.
  // TODO: improve the filtering condition to make it more applicable.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///      strides and:
///     a. keeping the ones that are static and equal across `aT` and `bT`.
///     b. using a dynamic shape and/or stride for the dimensions that don't
///        agree.
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
      failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
    resStrides[idx] =
        (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
  }
  resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
  return MemRefType::get(
      resShape, aT.getElementType(),
      StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
}
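
// Illustrative example (types chosen for exposition, not taken from the code
// above): given
//   aT = memref<4x8xf32>   (strides [8, 1],  offset 0)
//   bT = memref<4x16xf32>  (strides [16, 1], offset 0)
// the two types are not cast-compatible because the static inner dimensions
// differ, so the helper keeps what matches and relaxes the rest to dynamic:
//   getCastCompatibleMemRefType(aT, bT) == memref<4x?xf32, strided<[?, 1]>>
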
/// Casts the given memref to a compatible memref type. If the source memref
/// has a different address space than the target type, a
/// `memref.memory_space_cast` is first inserted, followed by a `memref.cast`.
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
                                        MemRefType compatibleMemRefType) {
  MemRefType sourceType = cast<MemRefType>(memref.getType());
  Value res = memref;
  if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
    sourceType = MemRefType::get(
        sourceType.getShape(), sourceType.getElementType(),
        sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
    res = b.create<memref::MemorySpaceCastOp>(memref.getLoc(), sourceType, res);
  }
  if (sourceType == compatibleMemRefType)
    return res;
  return b.create<memref::CastOp>(memref.getLoc(), compatibleMemRefType, res);
}

/// Operates under a scoped context to build the intersection between the
/// view `xferOp.getSource()` @ `xferOp.getIndices()` and the view `alloc`.
// TODO: view intersection/union/differences should be a proper std op.
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == cast<MemRefType>(alloc.getType()).getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.getIndices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult, 4> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef = b.create<memref::DimOp>(xferOp.getLoc(),
                                              xferOp.getSource(), indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
    Value index = xferOp.getIndices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}}, b.getContext());
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = b.create<affine::AffineMinOp>(
        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
      xferOp.getIndices(), [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  auto copySrc = b.create<memref::SubViewOp>(
      loc, isaWrite ? alloc : xferOp.getSource(), srcIndices, sizes, strides);
  auto copyDest = b.create<memref::SubViewOp>(
      loc, isaWrite ? xferOp.getSource() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}
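
// Illustrative sketch (value names invented for exposition): for a 2-D
// transfer of vector<4x8xf32> at indices (%i, %j) on %A : memref<?x?xf32>
// with %alloc : memref<4x8xf32>, the clipped sizes computed above look
// roughly like:
// ```
//   %d0  = memref.dim %A, %c0 : memref<?x?xf32>
//   %sz0 = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>(%d0, %i, %c4)
//   %d1  = memref.dim %A, %c1 : memref<?x?xf32>
//   %sz1 = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>(%d1, %j, %c8)
// ```
// and, for a read, the returned pair is
//   (memref.subview %A[%i, %j] [%sz0, %sz1] [1, 1],
//    memref.subview %alloc[0, 0] [%sz0, %sz1] [1, 1]).
// The memref.dim on the statically shaped %alloc folds to a constant.
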
/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = subview %view [...][...][...]
///      %4 = subview %alloc [0, 0] [...] [...]
///      linalg.copy(%3, %4)
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                              xferOp.getIndices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        b.create<linalg::FillOp>(loc, ValueRange{xferOp.getPadding()},
                                 ValueRange{alloc});
        // Take partial subview of memref which guarantees no dimension
        // overflows.
        IRRewriter rewriter(b);
        std::pair<Value, Value> copyArgs = createSubViewIntersection(
            rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
            alloc);
        b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}
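
// Illustrative, fully-typed instance of the read-path transformation above
// (value names and types invented for exposition; the exact IR depends on
// folding): a read of vector<4x8xf32> from %A : memref<?x?xf32> at (%i, %j)
// becomes roughly
// ```
//   %res:3 = scf.if %cond -> (memref<?x?xf32>, index, index) {
//     scf.yield %A, %i, %j : memref<?x?xf32>, index, index
//   } else {
//     linalg.fill ins(%pad : f32) outs(%alloc : memref<4x8xf32>)
//     %sv  = memref.subview %A[%i, %j] [%sz0, %sz1] [1, 1] ...
//     %svA = memref.subview %alloc[0, 0] [%sz0, %sz1] [1, 1] ...
//     memref.copy %sv, %svA ...
//     %c = memref.cast %alloc : memref<4x8xf32> to memref<?x?xf32>
//     scf.yield %c, %c0, %c0 : memref<?x?xf32>, index, index
//   }
// ```
// The caller then redirects the original vector.transfer_read to
// %res#0[%res#1, %res#2] and marks it in-bounds.
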
/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                              xferOp.getIndices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).getVector();
        b.create<memref::StoreOp>(
            loc, vector,
            b.create<vector::TypeCastOp>(
                loc, MemRefType::get({}, vector.getType()), alloc));

        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of a single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b
      .create<scf::IfOp>(
          loc, inBoundsCond,
          [&](OpBuilder &b, Location loc) {
            Value res =
                castToCompatibleMemRefType(b, memref, compatibleMemRefType);
            scf::ValueVector viewAndIndices{res};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
            b.create<scf::YieldOp>(loc, viewAndIndices);
          },
          [&](OpBuilder &b, Location loc) {
            Value casted =
                castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
            scf::ValueVector viewAndIndices{casted};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getTransferRank(), zero);
            b.create<scf::YieldOp>(loc, viewAndIndices);
          })
      ->getResults();
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of a single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = subview %alloc [...][...][...]
///      %4 = subview %view [0, 0][...][...]
///      linalg.copy(%3, %4)
///    }
/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a memref of a single vector `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %2 = load %alloc : memref<vector<...>>
///      vector.transfer_write %2, %view[...] : vector<...>, memref<A...>
///    }
/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRMapping mapping;
    Value load = b.create<memref::LoadOp>(
        loc,
        b.create<vector::TypeCastOp>(
            loc, MemRefType::get({}, xferOp.getVector().getType()), alloc),
        ValueRange());
    mapping.map(xferOp.getVector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
static Operation *getAutomaticAllocationScope(Operation *op) {
  // Find the closest surrounding allocation scope that is not a known looping
  // construct (putting allocas in loops doesn't always lower to a
  // deallocation until the end of the loop).
  Operation *scope = nullptr;
  for (Operation *parent = op->getParentOp(); parent != nullptr;
       parent = parent->getParentOp()) {
    if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
      scope = parent;
    if (!isa<scf::ForOp, affine::AffineForOp>(parent))
      break;
  }
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}
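
// Illustrative sketch (function and value names invented for exposition): for
// a transfer op of vector<4x8xf32> nested inside loops, the transient buffer
// created by the splitting below ends up alloca'ed in the enclosing function
// body rather than inside the loop nest, e.g.:
// ```
//   func.func @foo(%A: memref<?x?xf32>, ...) {
//     %alloc = memref.alloca() {alignment = 32 : i64} : memref<4x8xf32>
//     scf.for %i = ... {
//       scf.for %j = ... {
//         // split vector.transfer that uses %alloc on its slow path
//       }
//     }
//   }
// ```
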
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fast path, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slow path, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer of a single vector, alloca'ed at the top of the
/// function.
///
/// For vector.transfer_write:
/// There are 2 conditional blocks. First a block to decide which memref and
/// indices to use for an unmasked, in-bounds write. Then a conditional block
/// to further copy a partial buffer into the final result in the slow path
/// case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...] : vector<...>, memref<A...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    vector.transfer_write %arg, %1#0[%1#1, %1#2]
///        {in_bounds = [true ... true]}
///    scf.if (%notInBounds) {
///      // slow path: not in-bounds vector.transfer or linalg.copy.
///    }
/// ```
/// where `alloc` is a buffer of a single vector, alloca'ed at the top of the
/// function.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map
///  2. the rank of the `xferOp.getSource()` and the rank of the
///     `xferOp.getVector()` must be equal. This will be relaxed in the future
///     but requires rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });
    return success();
  }

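  // Illustrative example of the ForceInBounds mode handled above (IR names
  // invented for exposition): the op is only re-tagged as in-bounds, no
  // scf.if is created:
  // ```
  //   %1 = vector.transfer_read %A[%i, %j], %pad
  //       : memref<?x?xf32>, vector<4x8xf32>
  // ```
  // becomes
  // ```
  //   %1 = vector.transfer_read %A[%i, %j], %pad {in_bounds = [true, true]}
  //       : memref<?x?xf32>, vector<4x8xf32>
  // ```
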
  // Assert preconditions. Additionally, keep the variables in an inner scope
  // to ensure they aren't used in the wrong scopes further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.getMask())
      return failure();
    if (xferReadOp && xferReadOp.getMask())
      return failure();
  }

  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Alloca a top-of-the-function `alloc` buffer for transient storage.
  Value alloc;
  {
    RewriterBase::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    alloc = b.create<memref::AllocaOp>(scope->getLoc(),
                                       MemRefType::get(shape, elementType),
                                       ValueRange{}, b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(cast<MemRefType>(xferOp.getShapedType()),
                                  cast<MemRefType>(alloc.getType()));
  if (!compatibleMemRefType)
    return failure();

  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set the existing read op to in-bounds; it always reads from a full
    // buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  IRMapping mapping;
  mapping.map(xferWriteOp.getSource(), memrefAndIndices.front());
  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  b.eraseOp(xferOp);

  return success();
}

namespace {
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
struct VectorTransferFullPartialRewriter : public RewritePattern {
  using FilterConstraintType =
      std::function<LogicalResult(VectorTransferOpInterface op)>;

  explicit VectorTransferFullPartialRewriter(
      MLIRContext *context,
      VectorTransformsOptions options = VectorTransformsOptions(),
      FilterConstraintType filter =
          [](VectorTransferOpInterface op) { return success(); },
      PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
        filter(std::move(filter)) {}

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  VectorTransformsOptions options;
  FilterConstraintType filter;
};

} // namespace

LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  return splitFullAndPartialTransfer(rewriter, xferOp, options);
}

void mlir::vector::populateVectorTransferFullPartialPatterns(
    RewritePatternSet &patterns, const VectorTransformsOptions &options) {
  patterns.add<VectorTransferFullPartialRewriter>(patterns.getContext(),
                                                  options);
}
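
// Typical usage (illustrative sketch, not part of this file; assumes the
// greedy driver declared in mlir/Transforms/GreedyPatternRewriteDriver.h and a
// surrounding func.func `funcOp`):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   vector::VectorTransformsOptions options;
//   options.vectorTransferSplit = vector::VectorTransferSplit::LinalgCopy;
//   vector::populateVectorTransferFullPartialPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));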