//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a vector.transfer
// op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//

#include <optional>
#include <type_traits>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"

#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "vector-transfer-split"

using namespace mlir;
using namespace mlir::vector;

/// Build the condition to ensure that a particular VectorTransferOpInterface
/// is in-bounds. Return a null Value when every accessed dimension is
/// statically known to be in-bounds (i.e., no runtime check is needed).
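///
/// A rough sketch of the IR this builds, for a hypothetical 2-D read of
/// vector<4x8> from a fully dynamic memref<?x?xf32> at indices [%i, %j]
/// (illustrative only; the affine.apply / memref.dim ops fold away whenever
/// their operands are statically known):
/// ```
///    %s0 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i)
///    %d0 = memref.dim %A, %c0 : memref<?x?xf32>
///    %cmp0 = arith.cmpi sle, %s0, %d0 : index
///    %s1 = affine.apply affine_map<(d0) -> (d0 + 8)>(%j)
///    %d1 = memref.dim %A, %c1 : memref<?x?xf32>
///    %cmp1 = arith.cmpi sle, %s1, %d1 : index
///    %cond = arith.andi %cmp0, %cmp1 : i1
/// ```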
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.getPermutationMap().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be in-bounds, it does not participate in
    // the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    OpFoldResult sum = affine::makeComposedFoldedAffineApply(
        b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize),
        {xferOp.getIndices()[indicesIdx]});
    OpFoldResult dimSz =
        memref::getMixedSize(b, loc, xferOp.getSource(), indicesIdx);
    auto maybeCstSum = getConstantIntValue(sum);
    auto maybeCstDimSz = getConstantIntValue(dimSz);
    if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
      return;
    Value cond =
        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
                                getValueOrCreateConstantIndexOp(b, loc, sum),
                                getValueOrCreateConstantIndexOp(b, loc, dimSz));
    // Conjunction over all dims for which we are in-bounds.
    if (inBoundsCond)
      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fast path, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slow path, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer alloca'ed at the top of the function, sized to
/// hold one vector.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map
///  2. the rank of the `xferOp.getSource()` and the rank of the
///     `xferOp.getVector()` must be equal. This will be relaxed in the future
///     but requires rank-reducing subviews.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  if (xferOp.getTransferRank() == 0)
    return failure();

  // TODO: expand support to these 2 cases.
  if (!xferOp.getPermutationMap().isMinorIdentity())
    return failure();
  // Must have some out-of-bounds dimension to be a candidate for splitting.
  if (!xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfer operations directly under an scf.if op; this avoids
  // applying the pattern recursively.
  // TODO: improve the filtering condition to make it more applicable.
  if (isa<scf::IfOp>(xferOp->getParentOp()))
    return failure();
  return success();
}

/// Given two MemRefTypes `aT` and `bT`, return a MemRefType to which both can
/// be cast. If the MemRefTypes don't have the same rank or are not strided,
/// return null; otherwise:
///   1. if `aT` and `bT` are cast-compatible, return `aT`.
///   2. else return a new MemRefType obtained by iterating over the shape and
///   strides and:
///     a. keeping the ones that are static and equal across `aT` and `bT`.
///     b. using a dynamic shape and/or stride for the dimensions that don't
///        agree.
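///
/// For illustration, with two hypothetical types that are not cast-compatible:
/// ```
///   aT = memref<4x8xf32, strided<[8, 1]>>
///   bT = memref<4x16xf32, strided<[16, 1]>>
///   // sizes/strides that disagree become dynamic, the rest are kept:
///   result = memref<4x?xf32, strided<[?, 1]>>
/// ```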
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;
  if (aT.getRank() != bT.getRank())
    return MemRefType();
  int64_t aOffset, bOffset;
  SmallVector<int64_t, 4> aStrides, bStrides;
  if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
      failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
      aStrides.size() != bStrides.size())
    return MemRefType();

  ArrayRef<int64_t> aShape = aT.getShape(), bShape = bT.getShape();
  int64_t resOffset;
  SmallVector<int64_t, 4> resShape(aT.getRank(), 0),
      resStrides(bT.getRank(), 0);
  for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
    resShape[idx] =
        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
    resStrides[idx] =
        (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
  }
  resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
  return MemRefType::get(
      resShape, aT.getElementType(),
      StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
}

/// Casts the given memref to a compatible memref type. If the source memref has
/// a different address space than the target type, a `memref.memory_space_cast`
/// is first inserted, followed by a `memref.cast`.
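///
/// A minimal sketch of the emitted ops for a hypothetical source in address
/// space 3 cast to a compatible type in the default space (SSA names are
/// illustrative):
/// ```
///   %0 = memref.memory_space_cast %src : memref<4x8xf32, 3> to memref<4x8xf32>
///   %1 = memref.cast %0 : memref<4x8xf32> to memref<?x?xf32>
/// ```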
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
                                        MemRefType compatibleMemRefType) {
  MemRefType sourceType = cast<MemRefType>(memref.getType());
  Value res = memref;
  if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
    sourceType = MemRefType::get(
        sourceType.getShape(), sourceType.getElementType(),
        sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
    res = b.create<memref::MemorySpaceCastOp>(memref.getLoc(), sourceType, res);
  }
  if (sourceType == compatibleMemRefType)
    return res;
  return b.create<memref::CastOp>(memref.getLoc(), compatibleMemRefType, res);
}

/// Operates under a scoped context to build the intersection between the
/// view `xferOp.getSource()` @ `xferOp.getIndices()` and the view `alloc`.
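/// The returned pair is (copy source, copy destination): for a read the source
/// is a subview of `xferOp.getSource()` and the destination a subview of
/// `alloc`; the operands are swapped for a write.
///
/// A rough sketch for a hypothetical 2-D read of vector<4x8> from
/// memref<?x?xf32> at indices [%i, %j] (SSA names illustrative, subview result
/// types elided):
/// ```
///   %dA0 = memref.dim %A, %c0 : memref<?x?xf32>
///   %dB0 = memref.dim %alloc, %c0 : memref<4x8xf32>
///   %sz0 = affine.min affine_map<(d0, d1, d2) -> (d0 - d1, d2)>(%dA0, %i, %dB0)
///   // ... likewise %sz1 for dimension 1 using %j ...
///   %src  = memref.subview %A[%i, %j] [%sz0, %sz1] [1, 1]
///   %dest = memref.subview %alloc[0, 0] [%sz0, %sz1] [1, 1]
/// ```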
// TODO: view intersection/union/differences should be a proper std op.
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == cast<MemRefType>(alloc.getType()).getRank() &&
         "Expected memref rank to match the alloc rank");
  ValueRange leadingIndices =
      xferOp.getIndices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult, 4> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef = b.create<memref::DimOp>(xferOp.getLoc(),
                                              xferOp.getSource(), indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
    Value index = xferOp.getIndices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}}, b.getContext());
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = b.create<affine::AffineMinOp>(
        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
      xferOp.getIndices(), [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
  auto copySrc = b.create<memref::SubViewOp>(
      loc, isaWrite ? alloc : xferOp.getSource(), srcIndices, sizes, strides);
  auto copyDest = b.create<memref::SubViewOp>(
      loc, isaWrite ? xferOp.getSource() : alloc, destIndices, sizes, strides);
  return std::make_pair(copySrc, copyDest);
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a single-vector memref `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = subview %view [...][...][...]
///      %4 = subview %alloc [0, 0] [...] [...]
///      linalg.copy(%3, %4)
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///   }
/// ```
/// Return the produced scf::IfOp.
248 /// Return the produced scf::IfOp.
249 static scf::IfOp
250 createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
251                             TypeRange returnTypes, Value inBoundsCond,
252                             MemRefType compatibleMemRefType, Value alloc) {
253   Location loc = xferOp.getLoc();
254   Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
255   Value memref = xferOp.getSource();
256   return b.create<scf::IfOp>(
257       loc, inBoundsCond,
258       [&](OpBuilder &b, Location loc) {
259         Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
260         scf::ValueVector viewAndIndices{res};
261         viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
262                               xferOp.getIndices().end());
263         b.create<scf::YieldOp>(loc, viewAndIndices);
264       },
265       [&](OpBuilder &b, Location loc) {
266         b.create<linalg::FillOp>(loc, ValueRange{xferOp.getPadding()},
267                                  ValueRange{alloc});
268         // Take partial subview of memref which guarantees no dimension
269         // overflows.
270         IRRewriter rewriter(b);
271         std::pair<Value, Value> copyArgs = createSubViewIntersection(
272             rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
273             alloc);
274         b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
275         Value casted =
276             castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
277         scf::ValueVector viewAndIndices{casted};
278         viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
279                               zero);
280         b.create<scf::YieldOp>(loc, viewAndIndices);
281       });
282 }

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a single-vector memref `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///   }
/// ```
/// Return the produced scf::IfOp.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                              xferOp.getIndices().end());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).getVector();
        b.create<memref::StoreOp>(
            loc, vector,
            b.create<vector::TypeCastOp>(
                loc, MemRefType::get({}, vector.getType()), alloc));

        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a single-vector memref `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %3 = vector.type_cast %extra_alloc :
///        memref<...> to memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///   }
/// ```
/// Return the results of the produced scf.if: the memref to write to, followed
/// by the indices at which to write.
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b
      .create<scf::IfOp>(
          loc, inBoundsCond,
          [&](OpBuilder &b, Location loc) {
            Value res =
                castToCompatibleMemRefType(b, memref, compatibleMemRefType);
            scf::ValueVector viewAndIndices{res};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getIndices().begin(),
                                  xferOp.getIndices().end());
            b.create<scf::YieldOp>(loc, viewAndIndices);
          },
          [&](OpBuilder &b, Location loc) {
            Value casted =
                castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
            scf::ValueVector viewAndIndices{casted};
            viewAndIndices.insert(viewAndIndices.end(),
                                  xferOp.getTransferRank(), zero);
            b.create<scf::YieldOp>(loc, viewAndIndices);
          })
      ->getResults();
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a single-vector memref `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = subview %alloc [...][...][...]
///      %4 = subview %view [0, 0][...][...]
///      linalg.copy(%3, %4)
///   }
/// ```
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRRewriter rewriter(b);
    std::pair<Value, Value> copyArgs = createSubViewIntersection(
        rewriter, cast<VectorTransferOpInterface>(xferOp.getOperation()),
        alloc);
    b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

/// Given an `xferOp` for which:
///   1. `inBoundsCond` has been computed.
///   2. a single-vector memref `alloc` has been allocated.
///   3. it originally wrote to %view
/// Produce IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %2 = load %alloc : memref<vector<...>>
///      vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
///   }
/// ```
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  auto notInBounds = b.create<arith::XOrIOp>(
      loc, inBoundsCond, b.create<arith::ConstantIntOp>(loc, true, 1));
  b.create<scf::IfOp>(loc, notInBounds, [&](OpBuilder &b, Location loc) {
    IRMapping mapping;
    Value load = b.create<memref::LoadOp>(
        loc,
        b.create<vector::TypeCastOp>(
            loc, MemRefType::get({}, xferOp.getVector().getType()), alloc),
        ValueRange());
    mapping.map(xferOp.getVector(), load);
    b.clone(*xferOp.getOperation(), mapping);
    b.create<scf::YieldOp>(loc, ValueRange{});
  });
}

// TODO: Parallelism and thread-local considerations with a ParallelScope trait.
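/// Return the automatic allocation scope under which the transient
/// single-vector buffer should be alloca'ed. A minimal sketch of the intended
/// placement (illustrative only):
/// ```
///   func.func @foo() {           // <- returned scope; the alloca goes here
///     scf.for %i = ... {         // known looping construct, walked through
///       vector.transfer_read ... // the op passed to this helper
///     }
///   }
/// ```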
static Operation *getAutomaticAllocationScope(Operation *op) {
  // Find the closest surrounding allocation scope that is not a known looping
  // construct (putting allocas in loops doesn't always lower to deallocation
  // until the end of the loop).
  Operation *scope = nullptr;
  for (Operation *parent = op->getParentOp(); parent != nullptr;
       parent = parent->getParentOp()) {
    if (parent->hasTrait<OpTrait::AutomaticAllocationScope>())
      scope = parent;
    if (!isa<scf::ForOp, affine::AffineForOp>(parent))
      break;
  }
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}

/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fast path and a slow path.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate the fact that the original vector.transfer indexing may be
/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
/// scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fast path, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slow path, not in-bounds vector.transfer or linalg.copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer alloca'ed at the top of the function, sized to
/// hold one vector.
///
/// For vector.transfer_write:
/// There are 2 conditional blocks. First a block to decide which memref and
/// indices to use for an unmasked, in-bounds write. Then a conditional block to
/// further copy a partial buffer into the final result in the slow path case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...] : vector<...>, memref<A...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
///    scf.if (%notInBounds) {
///      // slow path: not in-bounds vector.transfer or linalg.copy.
///    }
/// ```
/// where `alloc` is a buffer alloca'ed at the top of the function, sized to
/// hold one vector.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map
///  2. the rank of the `xferOp.getSource()` and the rank of the
///     `xferOp.getVector()` must be equal. This will be relaxed in the future
///     but requires rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope to
  // ensure they aren't used in the wrong scopes further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.getMask())
      return failure();
    if (xferReadOp && xferReadOp.getMask())
      return failure();
  }

  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  Value inBoundsCond = createInBoundsCond(
      b, cast<VectorTransferOpInterface>(xferOp.getOperation()));
  if (!inBoundsCond)
    return failure();

  // Top of the function `alloc` for transient storage.
  Value alloc;
  {
    RewriterBase::InsertionGuard guard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    auto shape = xferOp.getVectorType().getShape();
    Type elementType = xferOp.getVectorType().getElementType();
    alloc = b.create<memref::AllocaOp>(scope->getLoc(),
                                       MemRefType::get(shape, elementType),
                                       ValueRange{}, b.getI64IntegerAttr(32));
  }

  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(cast<MemRefType>(xferOp.getShapedType()),
                                  cast<MemRefType>(alloc.getType()));
  if (!compatibleMemRefType)
    return failure();

  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Set existing read op to in-bounds; it always reads from a full buffer.
    for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
      xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));

    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in-bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  IRMapping mapping;
  mapping.map(xferWriteOp.getSource(), memrefAndIndices.front());
  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  b.eraseOp(xferOp);

  return success();
}

namespace {
/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
/// may take an extra filter to perform selection at a finer granularity.
struct VectorTransferFullPartialRewriter : public RewritePattern {
  using FilterConstraintType =
      std::function<LogicalResult(VectorTransferOpInterface op)>;

  explicit VectorTransferFullPartialRewriter(
      MLIRContext *context,
      VectorTransformsOptions options = VectorTransformsOptions(),
      FilterConstraintType filter =
          [](VectorTransferOpInterface op) { return success(); },
      PatternBenefit benefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
        filter(std::move(filter)) {}

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  VectorTransformsOptions options;
  FilterConstraintType filter;
};

} // namespace

LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
  if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
      failed(filter(xferOp)))
    return failure();
  return splitFullAndPartialTransfer(rewriter, xferOp, options);
}

void mlir::vector::populateVectorTransferFullPartialPatterns(
    RewritePatternSet &patterns, const VectorTransformsOptions &options) {
  patterns.add<VectorTransferFullPartialRewriter>(patterns.getContext(),
                                                  options);
}
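
// Example usage (an illustrative sketch, not part of this file): clients
// typically register the pattern from a pass and pick the split strategy via
// VectorTransformsOptions. The surrounding pass boilerplate (`funcOp`, `ctx`,
// the greedy rewrite driver call) is assumed here, not defined in this file.
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorTransferFullPartialPatterns(
//       patterns, VectorTransformsOptions().setVectorTransferSplit(
//                     VectorTransferSplit::VectorTransfer));
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));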