xref: /llvm-project/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (revision 6aaa8f25b66dc1fef4e465f274ee40b82d632988)
1 //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector transfer operations to SCF.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <numeric>
14 #include <optional>
15 #include <type_traits>
16 
17 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
18 
19 #include "mlir/Dialect/Affine/IR/AffineOps.h"
20 #include "mlir/Dialect/Arith/IR/Arith.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/SCF/IR/SCF.h"
23 #include "mlir/Dialect/Tensor/IR/Tensor.h"
24 #include "mlir/Dialect/Vector/IR/VectorOps.h"
25 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
26 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
27 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
28 #include "mlir/IR/Builders.h"
29 #include "mlir/IR/ImplicitLocOpBuilder.h"
30 #include "mlir/Pass/Pass.h"
31 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32 #include "mlir/Transforms/Passes.h"
33 
34 namespace mlir {
35 #define GEN_PASS_DEF_CONVERTVECTORTOSCF
36 #include "mlir/Conversion/Passes.h.inc"
37 } // namespace mlir
38 
39 using namespace mlir;
40 using vector::TransferReadOp;
41 using vector::TransferWriteOp;
42 
43 namespace {
44 
/// Name of the unit attribute that tags a vector transfer op as selected for
/// progressive lowering by the patterns in this file.
static constexpr char kPassLabel[] = "__vector_to_scf_lowering__";
47 
48 /// Return true if this transfer op operates on a source tensor.
49 static bool isTensorOp(VectorTransferOpInterface xferOp) {
50   if (isa<RankedTensorType>(xferOp.getShapedType())) {
51     if (isa<vector::TransferWriteOp>(xferOp)) {
52       // TransferWriteOps on tensors have a result.
53       assert(xferOp->getNumResults() > 0);
54     }
55     return true;
56   }
57   return false;
58 }
59 
60 /// Patterns that inherit from this struct have access to
61 /// VectorTransferToSCFOptions.
62 template <typename OpTy>
63 struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
64   explicit VectorToSCFPattern(MLIRContext *context,
65                               VectorTransferToSCFOptions opt)
66       : OpRewritePattern<OpTy>(context), options(opt) {}
67 
68   LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp,
69                                   PatternRewriter &rewriter) const {
70     if (isTensorOp(xferOp) && !options.lowerTensors) {
71       return rewriter.notifyMatchFailure(
72           xferOp, "lowering tensor transfers is disabled");
73     }
74     return success();
75   }
76 
77   VectorTransferToSCFOptions options;
78 };
79 
80 /// Given a vector transfer op, calculate which dimension of the `source`
81 /// memref should be unpacked in the next application of TransferOpConversion.
82 /// A return value of std::nullopt indicates a broadcast.
83 template <typename OpTy>
84 static std::optional<int64_t> unpackedDim(OpTy xferOp) {
85   // TODO: support 0-d corner case.
86   assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
87   auto map = xferOp.getPermutationMap();
88   if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
89     return expr.getPosition();
90   }
91   assert(xferOp.isBroadcastDim(0) &&
92          "Expected AffineDimExpr or AffineConstantExpr");
93   return std::nullopt;
94 }
95 
96 /// Compute the permutation map for the new (N-1)-D vector transfer op. This
97 /// map is identical to the current permutation map, but the first result is
98 /// omitted.
99 template <typename OpTy>
100 static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
101   // TODO: support 0-d corner case.
102   assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
103   auto map = xferOp.getPermutationMap();
104   return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
105                         b.getContext());
106 }
107 
108 /// Calculate the indices for the new vector transfer op.
109 ///
110 /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
111 ///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3f32>
112 ///                                 ^^^^^^
113 ///              `iv` is the iteration variable of the (new) surrounding loop.
114 template <typename OpTy>
115 static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
116                            SmallVector<Value, 8> &indices) {
117   typename OpTy::Adaptor adaptor(xferOp);
118   // Corresponding memref dim of the vector dim that is unpacked.
119   auto dim = unpackedDim(xferOp);
120   auto prevIndices = adaptor.getIndices();
121   indices.append(prevIndices.begin(), prevIndices.end());
122 
123   Location loc = xferOp.getLoc();
124   bool isBroadcast = !dim.has_value();
125   if (!isBroadcast) {
126     AffineExpr d0, d1;
127     bindDims(xferOp.getContext(), d0, d1);
128     Value offset = adaptor.getIndices()[*dim];
129     indices[*dim] =
130         affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
131   }
132 }
133 
134 static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
135                             Value value) {
136   if (hasRetVal) {
137     assert(value && "Expected non-empty value");
138     b.create<scf::YieldOp>(loc, value);
139   } else {
140     b.create<scf::YieldOp>(loc);
141   }
142 }
143 
144 /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
145 /// is set to true. No such check is generated under following circumstances:
146 /// * xferOp does not have a mask.
147 /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
148 ///   computed and attached to the new transfer op in the pattern.)
149 /// * The to-be-unpacked dim of xferOp is a broadcast.
150 template <typename OpTy>
151 static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
152   if (!xferOp.getMask())
153     return Value();
154   if (xferOp.getMaskType().getRank() != 1)
155     return Value();
156   if (xferOp.isBroadcastDim(0))
157     return Value();
158 
159   Location loc = xferOp.getLoc();
160   return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
161 }
162 
163 /// Helper function TransferOpConversion and TransferOp1dConversion.
164 /// Generate an in-bounds check if the transfer op may go out-of-bounds on the
165 /// specified dimension `dim` with the loop iteration variable `iv`.
166 /// E.g., when unpacking dimension 0 from:
167 /// ```
168 /// %vec = vector.transfer_read %A[%a, %b] %cst
169 ///     : vector<5x4xf32>, memref<?x?xf32>
170 /// ```
171 /// An if check similar to this will be generated inside the loop:
172 /// ```
173 /// %d = memref.dim %A, %c0 : memref<?x?xf32>
174 /// if (%a + iv < %d) {
175 ///   (in-bounds case)
176 /// } else {
177 ///   (out-of-bounds case)
178 /// }
179 /// ```
180 ///
181 /// If the transfer is 1D and has a mask, this function generates a more complex
182 /// check also accounts for potentially masked out elements.
183 ///
184 /// This function variant returns the value returned by `inBoundsCase` or
185 /// `outOfBoundsCase`. The MLIR type of the return value must be specified in
186 /// `resultTypes`.
187 template <typename OpTy>
188 static Value generateInBoundsCheck(
189     OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
190     TypeRange resultTypes,
191     function_ref<Value(OpBuilder &, Location)> inBoundsCase,
192     function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
193   bool hasRetVal = !resultTypes.empty();
194   Value cond; // Condition to be built...
195 
196   // Condition check 1: Access in-bounds?
197   bool isBroadcast = !dim; // No in-bounds check for broadcasts.
198   Location loc = xferOp.getLoc();
199   ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
200   if (!xferOp.isDimInBounds(0) && !isBroadcast) {
201     Value memrefDim =
202         vector::createOrFoldDimOp(b, loc, xferOp.getSource(), *dim);
203     AffineExpr d0, d1;
204     bindDims(xferOp.getContext(), d0, d1);
205     Value base = xferOp.getIndices()[*dim];
206     Value memrefIdx =
207         affine::makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
208     cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
209                                     memrefIdx);
210   }
211 
212   // Condition check 2: Masked in?
213   if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
214     if (cond)
215       cond = lb.create<arith::AndIOp>(cond, maskCond);
216     else
217       cond = maskCond;
218   }
219 
220   // If the condition is non-empty, generate an SCF::IfOp.
221   if (cond) {
222     auto check = lb.create<scf::IfOp>(
223         cond,
224         /*thenBuilder=*/
225         [&](OpBuilder &b, Location loc) {
226           maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
227         },
228         /*elseBuilder=*/
229         [&](OpBuilder &b, Location loc) {
230           if (outOfBoundsCase) {
231             maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
232           } else {
233             b.create<scf::YieldOp>(loc);
234           }
235         });
236 
237     return hasRetVal ? check.getResult(0) : Value();
238   }
239 
240   // Condition is empty, no need for an SCF::IfOp.
241   return inBoundsCase(b, loc);
242 }
243 
244 /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
245 /// a return value. Consequently, this function does not have a return value.
246 template <typename OpTy>
247 static void generateInBoundsCheck(
248     OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
249     function_ref<void(OpBuilder &, Location)> inBoundsCase,
250     function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
251   generateInBoundsCheck(
252       b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
253       /*inBoundsCase=*/
254       [&](OpBuilder &b, Location loc) {
255         inBoundsCase(b, loc);
256         return Value();
257       },
258       /*outOfBoundsCase=*/
259       [&](OpBuilder &b, Location loc) {
260         if (outOfBoundsCase)
261           outOfBoundsCase(b, loc);
262         return Value();
263       });
264 }
265 
266 /// Given an ArrayAttr, return a copy where the first element is dropped.
267 static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
268   if (!attr)
269     return attr;
270   return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
271 }
272 
273 /// Add the pass label to a vector transfer op if its rank is not the target
274 /// rank.
275 template <typename OpTy>
276 static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
277                                 unsigned targetRank) {
278   if (newXferOp.getVectorType().getRank() > targetRank)
279     newXferOp->setAttr(kPassLabel, b.getUnitAttr());
280 }
281 
282 namespace lowering_n_d {
283 
284 /// Helper data structure for data and mask buffers.
285 struct BufferAllocs {
286   Value dataBuffer;
287   Value maskBuffer;
288 };
289 
290 // TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
291 static Operation *getAutomaticAllocationScope(Operation *op) {
292   Operation *scope =
293       op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
294   assert(scope && "Expected op to be inside automatic allocation scope");
295   return scope;
296 }
297 
298 /// Allocate temporary buffers for data (vector) and mask (if present).
299 template <typename OpTy>
300 static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
301   Location loc = xferOp.getLoc();
302   OpBuilder::InsertionGuard guard(b);
303   Operation *scope = getAutomaticAllocationScope(xferOp);
304   assert(scope->getNumRegions() == 1 &&
305          "AutomaticAllocationScope with >1 regions");
306   b.setInsertionPointToStart(&scope->getRegion(0).front());
307 
308   BufferAllocs result;
309   auto bufferType = MemRefType::get({}, xferOp.getVectorType());
310   result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);
311 
312   if (xferOp.getMask()) {
313     auto maskType = MemRefType::get({}, xferOp.getMask().getType());
314     auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
315     b.setInsertionPoint(xferOp);
316     b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
317     result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer, ValueRange());
318   }
319 
320   return result;
321 }
322 
323 /// Given a MemRefType with VectorType element type, unpack one dimension from
324 /// the VectorType into the MemRefType.
325 ///
326 /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
327 static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
328   auto vectorType = dyn_cast<VectorType>(type.getElementType());
329   // Vectors with leading scalable dims are not supported.
330   // It may be possible to support these in future by using dynamic memref dims.
331   if (vectorType.getScalableDims().front())
332     return failure();
333   auto memrefShape = type.getShape();
334   SmallVector<int64_t, 8> newMemrefShape;
335   newMemrefShape.append(memrefShape.begin(), memrefShape.end());
336   newMemrefShape.push_back(vectorType.getDimSize(0));
337   return MemRefType::get(newMemrefShape,
338                          VectorType::Builder(vectorType).dropDim(0));
339 }
340 
341 /// Given a transfer op, find the memref from which the mask is loaded. This
342 /// is similar to Strategy<TransferWriteOp>::getBuffer.
343 template <typename OpTy>
344 static Value getMaskBuffer(OpTy xferOp) {
345   assert(xferOp.getMask() && "Expected that transfer op has mask");
346   auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
347   assert(loadOp && "Expected transfer op mask produced by LoadOp");
348   return loadOp.getMemRef();
349 }
350 
/// Per-op codegen strategy used by the progressive lowering. Declared only;
/// each supported transfer op kind provides an explicit specialization.
template <class OpTy> struct Strategy;
354 
355 /// Code strategy for vector TransferReadOp.
356 template <>
357 struct Strategy<TransferReadOp> {
358   /// Find the StoreOp that is used for writing the current TransferReadOp's
359   /// result to the temporary buffer allocation.
360   static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
361     assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
362     auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
363     assert(storeOp && "Expected TransferReadOp result used by StoreOp");
364     return storeOp;
365   }
366 
367   /// Find the temporary buffer allocation. All labeled TransferReadOps are
368   /// used like this, where %buf is either the buffer allocation or a type cast
369   /// of the buffer allocation:
370   /// ```
371   /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
372   /// memref.store %vec, %buf[...] ...
373   /// ```
374   static Value getBuffer(TransferReadOp xferOp) {
375     return getStoreOp(xferOp).getMemRef();
376   }
377 
378   /// Retrieve the indices of the current StoreOp that stores into the buffer.
379   static void getBufferIndices(TransferReadOp xferOp,
380                                SmallVector<Value, 8> &indices) {
381     auto storeOp = getStoreOp(xferOp);
382     auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
383     indices.append(prevIndices.begin(), prevIndices.end());
384   }
385 
386   /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
387   /// accesses on the to-be-unpacked dimension.
388   ///
389   /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
390   ///    variable `iv`.
391   /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
392   ///
393   /// E.g.:
394   /// ```
395   /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
396   ///     : memref<?x?x?xf32>, vector<4x3xf32>
397   /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
398   /// ```
399   /// Is rewritten to:
400   /// ```
401   /// %casted = vector.type_cast %buf
402   ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
403   /// for %j = 0 to 4 {
404   ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
405   ///       : memref<?x?x?xf32>, vector<3xf32>
406   ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
407   /// }
408   /// ```
409   ///
410   /// Note: The loop and type cast are generated in TransferOpConversion.
411   ///       The original TransferReadOp and store op are deleted in `cleanup`.
412   /// Note: The `mask` operand is set in TransferOpConversion.
413   static TransferReadOp rewriteOp(OpBuilder &b,
414                                   VectorTransferToSCFOptions options,
415                                   TransferReadOp xferOp, Value buffer, Value iv,
416                                   ValueRange /*loopState*/) {
417     SmallVector<Value, 8> storeIndices;
418     getBufferIndices(xferOp, storeIndices);
419     storeIndices.push_back(iv);
420 
421     SmallVector<Value, 8> xferIndices;
422     getXferIndices(b, xferOp, iv, xferIndices);
423 
424     Location loc = xferOp.getLoc();
425     auto bufferType = dyn_cast<ShapedType>(buffer.getType());
426     auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
427     auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
428     auto newXferOp = b.create<vector::TransferReadOp>(
429         loc, vecType, xferOp.getSource(), xferIndices,
430         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
431         xferOp.getPadding(), Value(), inBoundsAttr);
432 
433     maybeApplyPassLabel(b, newXferOp, options.targetRank);
434 
435     b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices);
436     return newXferOp;
437   }
438 
439   /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
440   /// padding value to the temporary buffer.
441   static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
442                                     Value buffer, Value iv,
443                                     ValueRange /*loopState*/) {
444     SmallVector<Value, 8> storeIndices;
445     getBufferIndices(xferOp, storeIndices);
446     storeIndices.push_back(iv);
447 
448     Location loc = xferOp.getLoc();
449     auto bufferType = dyn_cast<ShapedType>(buffer.getType());
450     auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
451     auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
452     b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
453 
454     return Value();
455   }
456 
457   /// Cleanup after rewriting the op.
458   static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
459                       scf::ForOp /*forOp*/) {
460     rewriter.eraseOp(getStoreOp(xferOp));
461     rewriter.eraseOp(xferOp);
462   }
463 
464   /// Return the initial loop state for the generated scf.for loop.
465   static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
466 };
467 
468 /// Codegen strategy for vector TransferWriteOp.
469 template <>
470 struct Strategy<TransferWriteOp> {
471   /// Find the temporary buffer allocation. All labeled TransferWriteOps are
472   /// used like this, where %buf is either the buffer allocation or a type cast
473   /// of the buffer allocation:
474   /// ```
475   /// %vec = memref.load %buf[...] ...
476   /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
477   /// ```
478   static Value getBuffer(TransferWriteOp xferOp) {
479     auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
480     assert(loadOp && "Expected transfer op vector produced by LoadOp");
481     return loadOp.getMemRef();
482   }
483 
484   /// Retrieve the indices of the current LoadOp that loads from the buffer.
485   static void getBufferIndices(TransferWriteOp xferOp,
486                                SmallVector<Value, 8> &indices) {
487     auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
488     auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
489     indices.append(prevIndices.begin(), prevIndices.end());
490   }
491 
492   /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
493   /// accesses on the to-be-unpacked dimension.
494   ///
495   /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
496   ///    using the loop iteration variable `iv`.
497   /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
498   ///    to memory.
499   ///
500   /// Note: For more details, see comments on Strategy<TransferReadOp>.
501   static TransferWriteOp rewriteOp(OpBuilder &b,
502                                    VectorTransferToSCFOptions options,
503                                    TransferWriteOp xferOp, Value buffer,
504                                    Value iv, ValueRange loopState) {
505     SmallVector<Value, 8> loadIndices;
506     getBufferIndices(xferOp, loadIndices);
507     loadIndices.push_back(iv);
508 
509     SmallVector<Value, 8> xferIndices;
510     getXferIndices(b, xferOp, iv, xferIndices);
511 
512     Location loc = xferOp.getLoc();
513     auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
514     auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
515     auto source = loopState.empty() ? xferOp.getSource() : loopState[0];
516     Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
517     auto newXferOp = b.create<vector::TransferWriteOp>(
518         loc, type, vec, source, xferIndices,
519         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
520         inBoundsAttr);
521 
522     maybeApplyPassLabel(b, newXferOp, options.targetRank);
523 
524     return newXferOp;
525   }
526 
527   /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
528   static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
529                                     Value buffer, Value iv,
530                                     ValueRange loopState) {
531     return isTensorOp(xferOp) ? loopState[0] : Value();
532   }
533 
534   /// Cleanup after rewriting the op.
535   static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
536                       scf::ForOp forOp) {
537     if (isTensorOp(xferOp)) {
538       assert(forOp->getNumResults() == 1 && "Expected one for loop result");
539       rewriter.replaceOp(xferOp, forOp->getResult(0));
540     } else {
541       rewriter.eraseOp(xferOp);
542     }
543   }
544 
545   /// Return the initial loop state for the generated scf.for loop.
546   static Value initialLoopState(TransferWriteOp xferOp) {
547     return isTensorOp(xferOp) ? xferOp.getSource() : Value();
548   }
549 };
550 
551 template <typename OpTy>
552 LogicalResult checkPrepareXferOp(OpTy xferOp,
553                                  VectorTransferToSCFOptions options) {
554   if (xferOp->hasAttr(kPassLabel))
555     return failure();
556   if (xferOp.getVectorType().getRank() <= options.targetRank)
557     return failure();
558   // Currently the unpacking of the leading dimension into the memref is not
559   // supported for scalable dimensions.
560   if (xferOp.getVectorType().getScalableDims().front())
561     return failure();
562   if (isTensorOp(xferOp) && !options.lowerTensors)
563     return failure();
564   // Transfer ops that modify the element type are not supported atm.
565   if (xferOp.getVectorType().getElementType() !=
566       xferOp.getShapedType().getElementType())
567     return failure();
568   return success();
569 }
570 
571 /// Prepare a TransferReadOp for progressive lowering.
572 ///
573 /// 1. Allocate a temporary buffer.
574 /// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
575 /// 3. Store the result of the TransferReadOp into the temporary buffer.
576 /// 4. Load the result from the temporary buffer and replace all uses of the
577 ///    original TransferReadOp with this load.
578 ///
579 /// E.g.:
580 /// ```
581 /// %vec = vector.transfer_read %A[%a, %b, %c], %cst
582 ///     : vector<5x4xf32>, memref<?x?x?xf32>
583 /// ```
584 /// is rewritten to:
585 /// ```
586 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
587 /// %1 = vector.transfer_read %A[%a, %b, %c], %cst
588 ///     { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32>
589 /// memref.store %1, %0[] : memref<vector<5x4xf32>>
590 /// %vec = memref.load %0[] : memref<vector<5x4xf32>>
591 /// ```
592 ///
593 /// Note: A second temporary buffer may be allocated for the `mask` operand.
594 struct PrepareTransferReadConversion
595     : public VectorToSCFPattern<TransferReadOp> {
596   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
597 
598   LogicalResult matchAndRewrite(TransferReadOp xferOp,
599                                 PatternRewriter &rewriter) const override {
600     if (checkPrepareXferOp(xferOp, options).failed())
601       return failure();
602 
603     auto buffers = allocBuffers(rewriter, xferOp);
604     auto *newXfer = rewriter.clone(*xferOp.getOperation());
605     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
606     if (xferOp.getMask()) {
607       dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
608           buffers.maskBuffer);
609     }
610 
611     Location loc = xferOp.getLoc();
612     rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
613                                      buffers.dataBuffer);
614     rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
615 
616     return success();
617   }
618 };
619 
620 /// Prepare a TransferWriteOp for progressive lowering.
621 ///
622 /// 1. Allocate a temporary buffer.
623 /// 2. Store the vector into the buffer.
624 /// 3. Load the vector from the buffer again.
625 /// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
626 ///    marking it eligible for progressive lowering via TransferOpConversion.
627 ///
628 /// E.g.:
629 /// ```
630 /// vector.transfer_write %vec, %A[%a, %b, %c]
631 ///     : vector<5x4xf32>, memref<?x?x?xf32>
632 /// ```
633 /// is rewritten to:
634 /// ```
635 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
636 /// memref.store %vec, %0[] : memref<vector<5x4xf32>>
637 /// %1 = memref.load %0[] : memref<vector<5x4xf32>>
638 /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
639 ///     : vector<5x4xf32>, memref<?x?x?xf32>
640 /// ```
641 ///
642 /// Note: A second temporary buffer may be allocated for the `mask` operand.
643 struct PrepareTransferWriteConversion
644     : public VectorToSCFPattern<TransferWriteOp> {
645   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
646 
647   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
648                                 PatternRewriter &rewriter) const override {
649     if (checkPrepareXferOp(xferOp, options).failed())
650       return failure();
651 
652     Location loc = xferOp.getLoc();
653     auto buffers = allocBuffers(rewriter, xferOp);
654     rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
655                                      buffers.dataBuffer);
656     auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
657     rewriter.modifyOpInPlace(xferOp, [&]() {
658       xferOp.getVectorMutable().assign(loadedVec);
659       xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
660     });
661 
662     if (xferOp.getMask()) {
663       rewriter.modifyOpInPlace(xferOp, [&]() {
664         xferOp.getMaskMutable().assign(buffers.maskBuffer);
665       });
666     }
667 
668     return success();
669   }
670 };
671 
672 /// Decompose a n-D PrintOp into a loop of elementary/scalar prints. This allows
673 /// printing both 1D scalable vectors and n-D fixed size vectors.
674 ///
675 /// E.g.:
676 /// ```
677 /// vector.print %v : vector<[4]xi32>
678 /// ```
679 /// is rewritten to:
680 /// ```
681 /// %c0 = arith.constant 0 : index
682 /// %c4 = arith.constant 4 : index
683 /// %c1 = arith.constant 1 : index
684 /// %vscale = vector.vscale
685 /// %length = arith.muli %vscale, %c4 : index
686 /// %lastIndex = arith.subi %length, %c1 : index
687 /// vector.print punctuation <open>
688 /// scf.for %i = %c0 to %length step %c1 {
689 ///   %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
690 ///   vector.print %el : i32 punctuation <no_punctuation>
691 ///   %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
692 ///   scf.if %notLastIndex {
693 ///     vector.print punctuation <comma>
694 ///   }
695 /// }
696 /// vector.print punctuation <close>
697 /// vector.print
698 /// ```
699 struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
700   using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern;
701   LogicalResult matchAndRewrite(vector::PrintOp printOp,
702                                 PatternRewriter &rewriter) const override {
703     if (!printOp.getSource())
704       return failure();
705 
706     VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
707     if (!vectorType)
708       return failure();
709 
710     // Currently >= 2D scalable vectors are not supported.
711     // These can't be lowered to LLVM (as LLVM does not support scalable vectors
712     // of scalable vectors), and due to limitations of current ops can't be
713     // indexed with SSA values or flattened. This may change after
714     // https://reviews.llvm.org/D155034, though there still needs to be a path
715     // for lowering to LLVM.
716     if (vectorType.getRank() > 1 && vectorType.isScalable())
717       return failure();
718 
719     auto loc = printOp.getLoc();
720     auto value = printOp.getSource();
721 
722     if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
723       // Oddly sized integers are (somewhat) buggy on a lot of backends, so to
724       // avoid issues extend them to a more standard size.
725       // https://github.com/llvm/llvm-project/issues/30613
726       auto width = intTy.getWidth();
727       auto legalWidth = llvm::NextPowerOf2(std::max(8u, width) - 1);
728       auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth,
729                                          intTy.getSignedness());
730       // arith can only take signless integers, so we must cast back and forth.
731       auto signlessSourceVectorType =
732           vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
733       auto signlessTargetVectorType =
734           vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
735       auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
736       value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType,
737                                                  value);
738       if (value.getType() != signlessTargetVectorType) {
739         if (width == 1 || intTy.isUnsigned())
740           value = rewriter.create<arith::ExtUIOp>(loc, signlessTargetVectorType,
741                                                   value);
742         else
743           value = rewriter.create<arith::ExtSIOp>(loc, signlessTargetVectorType,
744                                                   value);
745       }
746       value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value);
747       vectorType = targetVectorType;
748     }
749 
750     auto scalableDimensions = vectorType.getScalableDims();
751     auto shape = vectorType.getShape();
752     constexpr int64_t singletonShape[] = {1};
753     if (vectorType.getRank() == 0)
754       shape = singletonShape;
755 
756     if (vectorType.getRank() != 1) {
757       // Flatten n-D vectors to 1D. This is done to allow indexing with a
758       // non-constant value (which can currently only be done via
759       // vector.extractelement for 1D vectors).
760       auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
761                                         std::multiplies<int64_t>());
762       auto flatVectorType =
763           VectorType::get({flatLength}, vectorType.getElementType());
764       value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value);
765     }
766 
767     vector::PrintOp firstClose;
768     SmallVector<Value, 8> loopIndices;
769     for (unsigned d = 0; d < shape.size(); d++) {
770       // Setup loop bounds and step.
771       Value lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
772       Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, shape[d]);
773       Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
774       if (!scalableDimensions.empty() && scalableDimensions[d]) {
775         auto vscale = rewriter.create<vector::VectorScaleOp>(
776             loc, rewriter.getIndexType());
777         upperBound = rewriter.create<arith::MulIOp>(loc, upperBound, vscale);
778       }
779       auto lastIndex = rewriter.create<arith::SubIOp>(loc, upperBound, step);
780 
781       // Create a loop to print the elements surrounded by parentheses.
782       rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
783       auto loop =
784           rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
785       auto printClose = rewriter.create<vector::PrintOp>(
786           loc, vector::PrintPunctuation::Close);
787       if (!firstClose)
788         firstClose = printClose;
789 
790       auto loopIdx = loop.getInductionVar();
791       loopIndices.push_back(loopIdx);
792 
793       // Print a comma after all but the last element.
794       rewriter.setInsertionPointToStart(loop.getBody());
795       auto notLastIndex = rewriter.create<arith::CmpIOp>(
796           loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
797       rewriter.create<scf::IfOp>(loc, notLastIndex,
798                                  [&](OpBuilder &builder, Location loc) {
799                                    builder.create<vector::PrintOp>(
800                                        loc, vector::PrintPunctuation::Comma);
801                                    builder.create<scf::YieldOp>(loc);
802                                  });
803 
804       rewriter.setInsertionPointToStart(loop.getBody());
805     }
806 
807     // Compute the flattened index.
808     // Note: For the > rank 1 vectors this assumes non-scalable.
809     Value flatIndex;
810     auto currentStride = 1;
811     for (int d = shape.size() - 1; d >= 0; d--) {
812       auto stride = rewriter.create<arith::ConstantIndexOp>(loc, currentStride);
813       auto index = rewriter.create<arith::MulIOp>(loc, stride, loopIndices[d]);
814       if (flatIndex)
815         flatIndex = rewriter.create<arith::AddIOp>(loc, flatIndex, index);
816       else
817         flatIndex = index;
818       currentStride *= shape[d];
819     }
820 
821     // Print the scalar elements in the inner most loop.
822     auto element =
823         rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
824     rewriter.create<vector::PrintOp>(loc, element,
825                                      vector::PrintPunctuation::NoPunctuation);
826 
827     rewriter.setInsertionPointAfter(firstClose);
828     rewriter.create<vector::PrintOp>(loc, printOp.getPunctuation());
829     rewriter.eraseOp(printOp);
830     return success();
831   }
832 
833   static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) {
834     return IntegerType::get(intTy.getContext(), intTy.getWidth(),
835                             IntegerType::Signless);
836   };
837 };
838 
/// Progressive lowering of vector transfer ops: Unpack one dimension.
///
/// 1. Unpack one dimension from the current buffer type and cast the buffer
///    to that new type. E.g.:
///    ```
///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
///    vector.transfer_write %vec ...
///    ```
///    The following cast is generated:
///    ```
///    %casted = vector.type_cast %0
///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
///    ```
/// 2. Generate a for loop and rewrite the transfer op according to the
///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
///    out-of-bounds, generate an if-check and handle both cases separately.
/// 3. Clean up according to the corresponding Strategy<OpTy>.
///
/// Only transfer ops carrying the `kPassLabel` attribute are rewritten.
///
/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
/// source (as opposed to a memref source), then each iteration of the generated
/// scf.for loop yields the new tensor value. E.g.:
/// ```
/// %result = scf.for i = 0 to 5 {
///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
///   %1 = vector.transfer_write %0, %source[...]
///       : vector<4x3xf32>, tensor<5x4x3xf32>
///   scf.yield %1 : tensor<5x4x3xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    this->setHasBoundedRewriteRecursion();
  }

  /// Appends to `loadIndices` the indices used to load the current
  /// iteration's mask from the mask buffer: first the indices of a
  /// `memref.load` user of the (uncast) mask buffer, if one exists, and then
  /// `iv`, unless dimension 0 of `xferOp` is a broadcast.
  /// NOTE(review): `castedMaskBuffer` is currently not read by this function.
  static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
                                       SmallVectorImpl<Value> &loadIndices,
                                       Value iv) {
    assert(xferOp.getMask() && "Expected transfer op to have mask");

    // Add load indices from the previous iteration.
    // The mask buffer depends on the permutation map, which makes determining
    // the indices quite complex, so this is why we need to "look back" to the
    // previous iteration to find the right indices.
    Value maskBuffer = getMaskBuffer(xferOp);
    for (Operation *user : maskBuffer.getUsers()) {
      // If there is no previous load op, then the indices are empty.
      if (auto loadOp = dyn_cast<memref::LoadOp>(user)) {
        Operation::operand_range prevIndices = loadOp.getIndices();
        loadIndices.append(prevIndices.begin(), prevIndices.end());
        break;
      }
    }

    // In case of broadcast: Use same indices to load from memref
    // as before.
    if (!xferOp.isBroadcastDim(0))
      loadIndices.push_back(iv);
  }

  /// Replaces a labeled transfer op with an scf.for loop over the unpacked
  /// dimension. Each iteration runs a rank-reduced transfer on the
  /// `vector.type_cast` data buffer and, if the op is masked, loads the
  /// matching mask slice. Iterations that are out of bounds are delegated to
  /// Strategy<OpTy>::handleOutOfBoundsDim. Fails if the op lacks `kPassLabel`
  /// or if the data buffer type cannot be unpacked.
  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp->hasAttr(kPassLabel))
      return failure();

    // Find and cast data buffer. How the buffer can be found depends on OpTy.
    ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
    Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
    auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
    FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
    if (failed(castedDataType))
      return failure();

    auto castedDataBuffer =
        locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer);

    // If the xferOp has a mask: Find and cast mask buffer.
    Value castedMaskBuffer;
    if (xferOp.getMask()) {
      Value maskBuffer = getMaskBuffer(xferOp);
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        // Do not unpack a dimension of the mask, if:
        // * To-be-unpacked transfer op dimension is a broadcast.
        // * Mask is 1D, i.e., the mask cannot be further unpacked.
        //   (That means that all remaining dimensions of the transfer op must
        //   be broadcasted.)
        castedMaskBuffer = maskBuffer;
      } else {
        // It's safe to assume the mask buffer can be unpacked if the data
        // buffer was unpacked.
        auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
        MemRefType castedMaskType = *unpackOneDim(maskBufferType);
        castedMaskBuffer =
            locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
      }
    }

    // Loop bounds and step. The trip count is the innermost dimension of the
    // cast data buffer type, i.e. the dimension that was just unpacked.
    auto lb = locB.create<arith::ConstantIndexOp>(0);
    auto ub = locB.create<arith::ConstantIndexOp>(
        castedDataType->getDimSize(castedDataType->getRank() - 1));
    auto step = locB.create<arith::ConstantIndexOp>(1);
    // TransferWriteOps that operate on tensors return the modified tensor and
    // require a loop state.
    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);

    // Generate for loop.
    auto result = locB.create<scf::ForOp>(
        lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          // When a loop state exists, the in-bounds check yields a value of
          // the same type as the state.
          Type stateType = loopState.empty() ? Type() : loopState[0].getType();

          auto result = generateInBoundsCheck(
              b, xferOp, iv, unpackedDim(xferOp),
              stateType ? TypeRange(stateType) : TypeRange(),
              /*inBoundsCase=*/
              [&](OpBuilder &b, Location loc) {
                // Create new transfer op.
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
                    b, this->options, xferOp, castedDataBuffer, iv, loopState);

                // If old transfer op has a mask: Set mask on new transfer op.
                // Special case: If the mask of the old transfer op is 1D and
                // the unpacked dim is not a broadcast, no mask is needed on
                // the new transfer op.
                if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
                                         xferOp.getMaskType().getRank() > 1)) {
                  OpBuilder::InsertionGuard guard(b);
                  b.setInsertionPoint(newXfer); // Insert load before newXfer.

                  SmallVector<Value, 8> loadIndices;
                  getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
                                           loadIndices, iv);
                  auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                       loadIndices);
                  rewriter.modifyOpInPlace(newXfer, [&]() {
                    newXfer.getMaskMutable().assign(mask);
                  });
                }

                // Only tensor writes (which carry a loop state) yield a value.
                return loopState.empty() ? Value() : newXfer->getResult(0);
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                return Strategy<OpTy>::handleOutOfBoundsDim(
                    b, xferOp, castedDataBuffer, iv, loopState);
              });

          maybeYieldValue(b, loc, !loopState.empty(), result);
        });

    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
    return success();
  }
};
998 
999 /// Retrieves the dimensions sizes of a mask. Currently supports CreateMaskOp
1000 /// and ConstantMaskOp.
1001 template <typename VscaleConstantBuilder>
1002 static FailureOr<SmallVector<OpFoldResult>>
1003 getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
1004   if (!mask)
1005     return SmallVector<OpFoldResult>{};
1006   if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
1007     return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
1008       return OpFoldResult(dimSize);
1009     });
1010   }
1011   if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
1012     int dimIdx = 0;
1013     VectorType maskType = constantMask.getVectorType();
1014     auto indexType = IndexType::get(mask.getContext());
1015     return llvm::map_to_vector(
1016         constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
1017           // A scalable dim in a constant_mask means vscale x dimSize.
1018           if (maskType.getScalableDims()[dimIdx++])
1019             return OpFoldResult(createVscaleMultiple(dimSize));
1020           return OpFoldResult(IntegerAttr::get(indexType, dimSize));
1021         });
1022   }
1023   return failure();
1024 }
1025 
1026 /// Scalable vector lowering of transfer_write(transpose). This lowering only
1027 /// supports rank 2 (scalable) vectors, but can be used in conjunction with
1028 /// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion
1029 /// unrolls until the first scalable dimension.
1030 ///
1031 /// Example:
1032 ///
1033 /// BEFORE:
1034 /// ```mlir
1035 /// %transpose = vector.transpose %vec, [1, 0]
1036 ///    : vector<4x[4]xf32> to vector<[4]x4xf32>
1037 /// vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]}
1038 ///    : vector<[4]x4xf32>,  memref<?x?xf32>
1039 /// ```
1040 ///
1041 /// AFTER:
1042 /// ```mlir
1043 /// %c1 = arith.constant 1 : index
1044 /// %c4 = arith.constant 4 : index
1045 /// %c0 = arith.constant 0 : index
1046 /// %0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
1047 /// %1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
1048 /// %2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
1049 /// %3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
1050 /// %vscale = vector.vscale
1051 /// %c4_vscale = arith.muli %vscale, %c4 : index
1052 /// scf.for %idx = %c0 to %c4_vscale step %c1 {
1053 ///   %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
1054 ///   %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
1055 ///   %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
1056 ///   %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
1057 ///   %slice_i = affine.apply #map(%idx)[%i]
1058 ///   %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
1059 ///   vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
1060 ///     : vector<4xf32>, memref<?x?xf32>
1061 /// }
1062 /// ```
1063 struct ScalableTransposeTransferWriteConversion
1064     : VectorToSCFPattern<vector::TransferWriteOp> {
1065   using VectorToSCFPattern::VectorToSCFPattern;
1066 
1067   LogicalResult matchAndRewrite(TransferWriteOp writeOp,
1068                                 PatternRewriter &rewriter) const override {
1069     if (failed(checkLowerTensors(writeOp, rewriter)))
1070       return failure();
1071 
1072     VectorType vectorType = writeOp.getVectorType();
1073 
1074     // Note: By comparing the scalable dims to an ArrayRef of length two this
1075     // implicitly checks the rank (is also two).
1076     ArrayRef<bool> scalableFlags = vectorType.getScalableDims();
1077     if (scalableFlags != ArrayRef<bool>{true, false}) {
1078       return rewriter.notifyMatchFailure(
1079           writeOp, "expected vector of the form vector<[N]xMxty>");
1080     }
1081 
1082     auto permutationMap = writeOp.getPermutationMap();
1083     if (!permutationMap.isIdentity()) {
1084       return rewriter.notifyMatchFailure(
1085           writeOp, "non-identity permutations are unsupported (lower first)");
1086     }
1087 
1088     // Note: This pattern is only lowering the leading dimension (to a loop),
1089     // so we only check if the leading dimension is in bounds. The in-bounds
1090     // attribute for the trailing dimension will be propagated.
1091     if (!writeOp.isDimInBounds(0)) {
1092       return rewriter.notifyMatchFailure(
1093           writeOp, "out-of-bounds dims are unsupported (use masking)");
1094     }
1095 
1096     Value vector = writeOp.getVector();
1097     auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
1098     if (!transposeOp ||
1099         transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) {
1100       return rewriter.notifyMatchFailure(writeOp, "source not transpose");
1101     }
1102 
1103     auto loc = writeOp.getLoc();
1104     auto createVscaleMultiple =
1105         vector::makeVscaleConstantBuilder(rewriter, loc);
1106 
1107     auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple);
1108     if (failed(maskDims)) {
1109       return rewriter.notifyMatchFailure(writeOp,
1110                                          "failed to resolve mask dims");
1111     }
1112 
1113     int64_t fixedDimSize = vectorType.getDimSize(1);
1114     auto fixedDimOffsets = llvm::seq(fixedDimSize);
1115 
1116     // Extract all slices from the source of the transpose.
1117     auto transposeSource = transposeOp.getVector();
1118     SmallVector<Value> transposeSourceSlices =
1119         llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
1120           return rewriter.create<vector::ExtractOp>(loc, transposeSource, idx);
1121         });
1122 
1123     // Loop bounds and step.
1124     auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1125     auto ub =
1126         maskDims->empty()
1127             ? Value(createVscaleMultiple(vectorType.getDimSize(0)))
1128             : vector::getAsValues(rewriter, loc, maskDims->front()).front();
1129     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
1130 
1131     // Generate a new mask for the slice.
1132     VectorType sliceType = VectorType::Builder(vectorType).dropDim(0);
1133     Value sliceMask = nullptr;
1134     if (!maskDims->empty()) {
1135       sliceMask = rewriter.create<vector::CreateMaskOp>(
1136           loc, sliceType.clone(rewriter.getI1Type()),
1137           ArrayRef<OpFoldResult>(*maskDims).drop_front());
1138     }
1139 
1140     Value initDest = isTensorOp(writeOp) ? writeOp.getSource() : Value{};
1141     ValueRange initLoopArgs = initDest ? initDest : ValueRange{};
1142     auto result = rewriter.create<scf::ForOp>(
1143         loc, lb, ub, step, initLoopArgs,
1144         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) {
1145           // Indices for the new transfer op.
1146           SmallVector<Value, 8> xferIndices;
1147           getXferIndices(b, writeOp, iv, xferIndices);
1148 
1149           // Extract a transposed slice from the source vector.
1150           SmallVector<Value> transposeElements =
1151               llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
1152                 return b.create<vector::ExtractOp>(
1153                     loc, transposeSourceSlices[idx], iv);
1154               });
1155           auto sliceVec = b.create<vector::FromElementsOp>(loc, sliceType,
1156                                                            transposeElements);
1157 
1158           // Create the transfer_write for the slice.
1159           Value dest =
1160               loopIterArgs.empty() ? writeOp.getSource() : loopIterArgs.front();
1161           auto newWriteOp = b.create<vector::TransferWriteOp>(
1162               loc, sliceVec, dest, xferIndices,
1163               ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
1164           if (sliceMask)
1165             newWriteOp.getMaskMutable().assign(sliceMask);
1166 
1167           // Yield from the loop.
1168           b.create<scf::YieldOp>(loc, loopIterArgs.empty()
1169                                           ? ValueRange{}
1170                                           : newWriteOp.getResult());
1171         });
1172 
1173     if (isTensorOp(writeOp))
1174       rewriter.replaceOp(writeOp, result);
1175     else
1176       rewriter.eraseOp(writeOp);
1177 
1178     return success();
1179   }
1180 };
1181 
1182 } // namespace lowering_n_d
1183 
1184 namespace lowering_n_d_unrolled {
1185 
1186 /// If the original transfer op has a mask, compute the mask of the new transfer
1187 /// op (for the current iteration `i`) and assign it.
1188 template <typename OpTy>
1189 static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
1190                             int64_t i) {
1191   if (!xferOp.getMask())
1192     return;
1193 
1194   if (xferOp.isBroadcastDim(0)) {
1195     // To-be-unpacked dimension is a broadcast, which does not have a
1196     // corresponding mask dimension. Mask attribute remains unchanged.
1197     newXferOp.getMaskMutable().assign(xferOp.getMask());
1198     return;
1199   }
1200 
1201   if (xferOp.getMaskType().getRank() > 1) {
1202     // Unpack one dimension of the mask.
1203     OpBuilder::InsertionGuard guard(b);
1204     b.setInsertionPoint(newXferOp); // Insert load before newXfer.
1205 
1206     llvm::SmallVector<int64_t, 1> indices({i});
1207     Location loc = xferOp.getLoc();
1208     auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
1209     newXferOp.getMaskMutable().assign(newMask);
1210   }
1211 
1212   // If we end up here: The mask of the old transfer op is 1D and the unpacked
1213   // dim is not a broadcast, so no mask is needed on the new transfer op.
1214   // `generateInBoundsCheck` will have evaluated the mask already.
1215 }
1216 
1217 /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
1218 /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
1219 /// memref buffer is allocated and the SCF loop is fully unrolled.
1220 ///
1221 /// ```
1222 /// E.g.:
1223 /// ```
1224 /// %vec = vector.transfer_read %A[%a, %b, %c], %padding
1225 ///     : memref<?x?x?xf32>, vector<5x4xf32>
1226 /// ```
1227 /// is rewritten to IR such as (simplified):
1228 /// ```
1229 /// %v_init = splat %padding : vector<5x4xf32>
1230 /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
1231 ///     : memref<?x?x?xf32>, vector<4xf32>
1232 /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
1233 /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
1234 ///     : memref<?x?x?xf32>, vector<4xf32>
1235 /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
1236 /// ...
1237 /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
1238 ///     : memref<?x?x?xf32>, vector<4xf32>
1239 /// %vec = vector.insert %tmp1, %v3[4] : vector<4xf32> into vector<5x4xf32>
1240 /// ```
1241 ///
1242 /// Note: As an optimization, if the result of the original TransferReadOp
1243 /// was directly inserted into another vector, no new %v_init vector is created.
1244 /// Instead, the new TransferReadOp results are inserted into that vector.
1245 struct UnrollTransferReadConversion
1246     : public VectorToSCFPattern<TransferReadOp> {
1247   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
1248 
1249   void initialize() {
1250     // This pattern recursively unpacks one dimension at a time. The recursion
1251     // bounded as the rank is strictly decreasing.
1252     setHasBoundedRewriteRecursion();
1253   }
1254 
1255   /// Get or build the vector into which the newly created TransferReadOp
1256   /// results are inserted.
1257   Value buildResultVector(PatternRewriter &rewriter,
1258                           TransferReadOp xferOp) const {
1259     if (auto insertOp = getInsertOp(xferOp))
1260       return insertOp.getDest();
1261     Location loc = xferOp.getLoc();
1262     return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
1263                                             xferOp.getPadding());
1264   }
1265 
1266   /// If the result of the TransferReadOp has exactly one user, which is a
1267   /// vector::InsertOp, return that operation.
1268   vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
1269     if (xferOp->hasOneUse()) {
1270       Operation *xferOpUser = *xferOp->getUsers().begin();
1271       if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
1272         return insertOp;
1273     }
1274 
1275     return vector::InsertOp();
1276   }
1277 
1278   /// If the result of the TransferReadOp has exactly one user, which is a
1279   /// vector::InsertOp, return that operation's indices.
1280   void getInsertionIndices(TransferReadOp xferOp,
1281                            SmallVectorImpl<OpFoldResult> &indices) const {
1282     if (auto insertOp = getInsertOp(xferOp)) {
1283       auto pos = insertOp.getMixedPosition();
1284       indices.append(pos.begin(), pos.end());
1285     }
1286   }
1287 
1288   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
1289   /// accesses, and broadcasts and transposes in permutation maps.
1290   LogicalResult matchAndRewrite(TransferReadOp xferOp,
1291                                 PatternRewriter &rewriter) const override {
1292     if (xferOp.getVectorType().getRank() <= options.targetRank)
1293       return rewriter.notifyMatchFailure(
1294           xferOp, "vector rank is less or equal to target rank");
1295     if (failed(checkLowerTensors(xferOp, rewriter)))
1296       return failure();
1297     // Transfer ops that modify the element type are not supported atm.
1298     if (xferOp.getVectorType().getElementType() !=
1299         xferOp.getShapedType().getElementType())
1300       return rewriter.notifyMatchFailure(
1301           xferOp, "not yet supported: element type mismatch");
1302     auto xferVecType = xferOp.getVectorType();
1303     if (xferVecType.getScalableDims()[0]) {
1304       // Cannot unroll a scalable dimension at compile time.
1305       return rewriter.notifyMatchFailure(
1306           xferOp, "scalable dimensions cannot be unrolled");
1307     }
1308 
1309     auto insertOp = getInsertOp(xferOp);
1310     auto vec = buildResultVector(rewriter, xferOp);
1311     auto vecType = dyn_cast<VectorType>(vec.getType());
1312 
1313     VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);
1314 
1315     int64_t dimSize = xferVecType.getShape()[0];
1316 
1317     // Generate fully unrolled loop of transfer ops.
1318     Location loc = xferOp.getLoc();
1319     for (int64_t i = 0; i < dimSize; ++i) {
1320       Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
1321 
1322       vec = generateInBoundsCheck(
1323           rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
1324           /*inBoundsCase=*/
1325           [&](OpBuilder &b, Location loc) {
1326             // Indices for the new transfer op.
1327             SmallVector<Value, 8> xferIndices;
1328             getXferIndices(b, xferOp, iv, xferIndices);
1329 
1330             // Indices for the new vector.insert op.
1331             SmallVector<OpFoldResult, 8> insertionIndices;
1332             getInsertionIndices(xferOp, insertionIndices);
1333             insertionIndices.push_back(rewriter.getIndexAttr(i));
1334 
1335             auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1336             auto newXferOp = b.create<vector::TransferReadOp>(
1337                 loc, newXferVecType, xferOp.getSource(), xferIndices,
1338                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
1339                 xferOp.getPadding(), Value(), inBoundsAttr);
1340             maybeAssignMask(b, xferOp, newXferOp, i);
1341             return b.create<vector::InsertOp>(loc, newXferOp, vec,
1342                                               insertionIndices);
1343           },
1344           /*outOfBoundsCase=*/
1345           [&](OpBuilder &b, Location loc) {
1346             // Loop through original (unmodified) vector.
1347             return vec;
1348           });
1349     }
1350 
1351     if (insertOp) {
1352       // Rewrite single user of the old TransferReadOp, which was an InsertOp.
1353       rewriter.replaceOp(insertOp, vec);
1354       rewriter.eraseOp(xferOp);
1355     } else {
1356       rewriter.replaceOp(xferOp, vec);
1357     }
1358 
1359     return success();
1360   }
1361 };
1362 
1363 /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
1364 /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
1365 /// memref buffer is allocated and the SCF loop is fully unrolled.
1366 ///
1367 /// ```
1368 /// E.g.:
1369 /// ```
1370 /// vector.transfer_write %vec, %A[%a, %b, %c]
1371 ///     : vector<5x4xf32>, memref<?x?x?xf32>
1372 /// ```
1373 /// is rewritten to IR such as (simplified):
1374 /// ```
1375 /// %v0 = vector.extract %vec[0] : vector<4xf32> from vector<5x4xf32>
1376 /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
1377 /// %v1 = vector.extract %vec[1] : vector<4xf32> from vector<5x4xf32>
1378 /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
1379 /// ...
1380 /// %v4 = vector.extract %vec[4] : vector<4xf32> from vector<5x4xf32>
1381 /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
1382 /// ```
1383 ///
1384 /// Note: As an optimization, if the vector of the original TransferWriteOp
1385 /// was directly extracted from another vector via an ExtractOp `a`, extract
1386 /// the vectors for the newly generated TransferWriteOps from `a`'s input. By
1387 /// doing so, `a` may become dead, and the number of ExtractOps generated during
1388 /// recursive application of this pattern will be minimal.
1389 struct UnrollTransferWriteConversion
1390     : public VectorToSCFPattern<TransferWriteOp> {
1391   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
1392 
1393   void initialize() {
1394     // This pattern recursively unpacks one dimension at a time. The recursion
1395     // bounded as the rank is strictly decreasing.
1396     setHasBoundedRewriteRecursion();
1397   }
1398 
  /// Return the vector from which newly generated ExtractOps will extract.
1400   Value getDataVector(TransferWriteOp xferOp) const {
1401     if (auto extractOp = getExtractOp(xferOp))
1402       return extractOp.getVector();
1403     return xferOp.getVector();
1404   }
1405 
1406   /// If the input of the given TransferWriteOp is an ExtractOp, return it.
1407   vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
1408     if (auto *op = xferOp.getVector().getDefiningOp())
1409       return dyn_cast<vector::ExtractOp>(op);
1410     return vector::ExtractOp();
1411   }
1412 
1413   /// If the input of the given TransferWriteOp is an ExtractOp, return its
1414   /// indices.
1415   void getExtractionIndices(TransferWriteOp xferOp,
1416                             SmallVectorImpl<OpFoldResult> &indices) const {
1417     if (auto extractOp = getExtractOp(xferOp)) {
1418       auto pos = extractOp.getMixedPosition();
1419       indices.append(pos.begin(), pos.end());
1420     }
1421   }
1422 
1423   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
1424   /// accesses, and broadcasts and transposes in permutation maps.
1425   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
1426                                 PatternRewriter &rewriter) const override {
1427     VectorType inputVectorTy = xferOp.getVectorType();
1428 
1429     if (inputVectorTy.getRank() <= options.targetRank)
1430       return failure();
1431 
1432     if (failed(checkLowerTensors(xferOp, rewriter)))
1433       return failure();
1434     // Transfer ops that modify the element type are not supported atm.
1435     if (inputVectorTy.getElementType() !=
1436         xferOp.getShapedType().getElementType())
1437       return failure();
1438 
1439     auto vec = getDataVector(xferOp);
1440     if (inputVectorTy.getScalableDims()[0]) {
1441       // Cannot unroll a scalable dimension at compile time.
1442       return failure();
1443     }
1444 
1445     int64_t dimSize = inputVectorTy.getShape()[0];
1446     Value source = xferOp.getSource(); // memref or tensor to be written to.
1447     auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
1448 
1449     // Generate fully unrolled loop of transfer ops.
1450     Location loc = xferOp.getLoc();
1451     for (int64_t i = 0; i < dimSize; ++i) {
1452       Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
1453 
1454       auto updatedSource = generateInBoundsCheck(
1455           rewriter, xferOp, iv, unpackedDim(xferOp),
1456           isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
1457           /*inBoundsCase=*/
1458           [&](OpBuilder &b, Location loc) {
1459             // Indices for the new transfer op.
1460             SmallVector<Value, 8> xferIndices;
1461             getXferIndices(b, xferOp, iv, xferIndices);
1462 
1463             // Indices for the new vector.extract op.
1464             SmallVector<OpFoldResult, 8> extractionIndices;
1465             getExtractionIndices(xferOp, extractionIndices);
1466             extractionIndices.push_back(b.getI64IntegerAttr(i));
1467 
1468             auto extracted =
1469                 b.create<vector::ExtractOp>(loc, vec, extractionIndices);
1470             auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1471             Value xferVec;
1472             if (inputVectorTy.getRank() == 1) {
1473               // When target-rank=0, unrolling would causes the vector input
1474               // argument into `transfer_write` to become a scalar. We solve
1475               // this by broadcasting the scalar to a 0D vector.
1476               xferVec = b.create<vector::BroadcastOp>(
1477                   loc, VectorType::get({}, extracted.getType()), extracted);
1478             } else {
1479               xferVec = extracted;
1480             }
1481             auto newXferOp = b.create<vector::TransferWriteOp>(
1482                 loc, sourceType, xferVec, source, xferIndices,
1483                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
1484                 inBoundsAttr);
1485 
1486             maybeAssignMask(b, xferOp, newXferOp, i);
1487 
1488             return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
1489           },
1490           /*outOfBoundsCase=*/
1491           [&](OpBuilder &b, Location loc) {
1492             return isTensorOp(xferOp) ? source : Value();
1493           });
1494 
1495       if (isTensorOp(xferOp))
1496         source = updatedSource;
1497     }
1498 
1499     if (isTensorOp(xferOp))
1500       rewriter.replaceOp(xferOp, source);
1501     else
1502       rewriter.eraseOp(xferOp);
1503 
1504     return success();
1505   }
1506 };
1507 
1508 } // namespace lowering_n_d_unrolled
1509 
1510 namespace lowering_1_d {
1511 
1512 /// Compute the indices into the memref for the LoadOp/StoreOp generated as
1513 /// part of TransferOp1dConversion. Return the memref dimension on which
1514 /// the transfer is operating. A return value of std::nullopt indicates a
1515 /// broadcast.
1516 template <typename OpTy>
1517 static std::optional<int64_t>
1518 get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
1519                    SmallVector<Value, 8> &memrefIndices) {
1520   auto indices = xferOp.getIndices();
1521   auto map = xferOp.getPermutationMap();
1522   assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
1523 
1524   memrefIndices.append(indices.begin(), indices.end());
1525   assert(map.getNumResults() == 1 &&
1526          "Expected 1 permutation map result for 1D transfer");
1527   if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
1528     Location loc = xferOp.getLoc();
1529     auto dim = expr.getPosition();
1530     AffineExpr d0, d1;
1531     bindDims(xferOp.getContext(), d0, d1);
1532     Value offset = memrefIndices[dim];
1533     memrefIndices[dim] =
1534         affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
1535     return dim;
1536   }
1537 
1538   assert(xferOp.isBroadcastDim(0) &&
1539          "Expected AffineDimExpr or AffineConstantExpr");
1540   return std::nullopt;
1541 }
1542 
/// Codegen strategy for TransferOp1dConversion, depending on the
/// operation. Each specialization provides:
///   * `generateForLoopBody`, which emits the body of the scalar scf.for
///     loop (terminated by an scf.yield), and
///   * `initialLoopState`, which returns the loop's initial iter_arg, or a
///     null Value when the loop carries no state.
template <typename OpTy>
struct Strategy1d;
1547 
1548 /// Codegen strategy for TransferReadOp.
1549 template <>
1550 struct Strategy1d<TransferReadOp> {
1551   static void generateForLoopBody(OpBuilder &b, Location loc,
1552                                   TransferReadOp xferOp, Value iv,
1553                                   ValueRange loopState) {
1554     SmallVector<Value, 8> indices;
1555     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1556     auto vec = loopState[0];
1557 
1558     // In case of out-of-bounds access, leave `vec` as is (was initialized with
1559     // padding value).
1560     auto nextVec = generateInBoundsCheck(
1561         b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
1562         /*inBoundsCase=*/
1563         [&](OpBuilder &b, Location loc) {
1564           Value val =
1565               b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
1566           return b.create<vector::InsertElementOp>(loc, val, vec, iv);
1567         },
1568         /*outOfBoundsCase=*/
1569         [&](OpBuilder & /*b*/, Location loc) { return vec; });
1570     b.create<scf::YieldOp>(loc, nextVec);
1571   }
1572 
1573   static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
1574     // Inititalize vector with padding value.
1575     Location loc = xferOp.getLoc();
1576     return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
1577                                      xferOp.getPadding());
1578   }
1579 };
1580 
1581 /// Codegen strategy for TransferWriteOp.
1582 template <>
1583 struct Strategy1d<TransferWriteOp> {
1584   static void generateForLoopBody(OpBuilder &b, Location loc,
1585                                   TransferWriteOp xferOp, Value iv,
1586                                   ValueRange /*loopState*/) {
1587     SmallVector<Value, 8> indices;
1588     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1589 
1590     // Nothing to do in case of out-of-bounds access.
1591     generateInBoundsCheck(
1592         b, xferOp, iv, dim,
1593         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
1594           auto val =
1595               b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
1596           b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
1597         });
1598     b.create<scf::YieldOp>(loc);
1599   }
1600 
1601   static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
1602     return Value();
1603   }
1604 };
1605 
1606 /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
1607 /// necessary in cases where a 1D vector transfer op cannot be lowered into
1608 /// vector load/stores due to non-unit strides or broadcasts:
1609 ///
1610 /// * Transfer dimension is not the last memref dimension
1611 /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
1612 /// * Memref has a layout map with non-unit stride on the last dimension
1613 ///
1614 /// This pattern generates IR as follows:
1615 ///
1616 /// 1. Generate a for loop iterating over each vector element.
1617 /// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
1618 ///    depending on OpTy.
1619 ///
1620 /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
1621 ///       can be generated instead of TransferOp1dConversion. Add such a pattern
1622 ///       to ConvertVectorToLLVM.
1623 ///
1624 /// E.g.:
1625 /// ```
1626 /// vector.transfer_write %vec, %A[%a, %b]
1627 ///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
1628 ///    : vector<9xf32>, memref<?x?xf32>
1629 /// ```
1630 /// Is rewritten to approximately the following pseudo-IR:
1631 /// ```
1632 /// for i = 0 to 9 {
1633 ///   %t = vector.extractelement %vec[i] : vector<9xf32>
1634 ///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
1635 /// }
1636 /// ```
1637 template <typename OpTy>
1638 struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
1639   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
1640 
1641   LogicalResult matchAndRewrite(OpTy xferOp,
1642                                 PatternRewriter &rewriter) const override {
1643     // TODO: support 0-d corner case.
1644     if (xferOp.getTransferRank() == 0)
1645       return failure();
1646     auto map = xferOp.getPermutationMap();
1647     auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
1648 
1649     if (!memRefType)
1650       return failure();
1651     if (xferOp.getVectorType().getRank() != 1)
1652       return failure();
1653     if (map.isMinorIdentity() && memRefType.isLastDimUnitStride())
1654       return failure(); // Handled by ConvertVectorToLLVM
1655 
1656     // Loop bounds, step, state...
1657     Location loc = xferOp.getLoc();
1658     auto vecType = xferOp.getVectorType();
1659     auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1660     Value ub =
1661         rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
1662     if (vecType.isScalable()) {
1663       Value vscale =
1664           rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
1665       ub = rewriter.create<arith::MulIOp>(loc, ub, vscale);
1666     }
1667     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
1668     auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
1669 
1670     // Generate for loop.
1671     rewriter.replaceOpWithNewOp<scf::ForOp>(
1672         xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
1673         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
1674           Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
1675         });
1676 
1677     return success();
1678   }
1679 };
1680 
1681 } // namespace lowering_1_d
1682 } // namespace
1683 
1684 void mlir::populateVectorToSCFConversionPatterns(
1685     RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
1686   if (options.unroll) {
1687     patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1688                  lowering_n_d_unrolled::UnrollTransferWriteConversion>(
1689         patterns.getContext(), options);
1690   } else {
1691     patterns.add<lowering_n_d::PrepareTransferReadConversion,
1692                  lowering_n_d::PrepareTransferWriteConversion,
1693                  lowering_n_d::TransferOpConversion<TransferReadOp>,
1694                  lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1695         patterns.getContext(), options);
1696   }
1697   if (options.lowerScalable) {
1698     patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
1699         patterns.getContext(), options);
1700   }
1701   if (options.targetRank == 1) {
1702     patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1703                  lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1704         patterns.getContext(), options);
1705   }
1706   patterns.add<lowering_n_d::DecomposePrintOpConversion>(patterns.getContext(),
1707                                                          options);
1708 }
1709 
1710 namespace {
1711 
1712 struct ConvertVectorToSCFPass
1713     : public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
1714   ConvertVectorToSCFPass() = default;
1715   ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
1716     this->fullUnroll = options.unroll;
1717     this->targetRank = options.targetRank;
1718     this->lowerTensors = options.lowerTensors;
1719     this->lowerScalable = options.lowerScalable;
1720   }
1721 
1722   void runOnOperation() override {
1723     VectorTransferToSCFOptions options;
1724     options.unroll = fullUnroll;
1725     options.targetRank = targetRank;
1726     options.lowerTensors = lowerTensors;
1727     options.lowerScalable = lowerScalable;
1728 
1729     // Lower permutation maps first.
1730     RewritePatternSet lowerTransferPatterns(&getContext());
1731     mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
1732         lowerTransferPatterns);
1733     (void)applyPatternsGreedily(getOperation(),
1734                                 std::move(lowerTransferPatterns));
1735 
1736     RewritePatternSet patterns(&getContext());
1737     populateVectorToSCFConversionPatterns(patterns, options);
1738     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
1739   }
1740 };
1741 
1742 } // namespace
1743 
1744 std::unique_ptr<Pass>
1745 mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
1746   return std::make_unique<ConvertVectorToSCFPass>(options);
1747 }
1748