xref: /llvm-project/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (revision 3a1ae2f46db473cfde4baa6e1b090f5dae67e8db)
1 //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector transfer operations to SCF.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <numeric>
14 #include <optional>
15 #include <type_traits>
16 
17 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
18 
19 #include "mlir/Dialect/Affine/IR/AffineOps.h"
20 #include "mlir/Dialect/Arith/IR/Arith.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/SCF/IR/SCF.h"
23 #include "mlir/Dialect/Tensor/IR/Tensor.h"
24 #include "mlir/Dialect/Vector/IR/VectorOps.h"
25 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
26 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
27 #include "mlir/IR/Builders.h"
28 #include "mlir/IR/ImplicitLocOpBuilder.h"
29 #include "mlir/Pass/Pass.h"
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
31 #include "mlir/Transforms/Passes.h"
32 
33 namespace mlir {
34 #define GEN_PASS_DEF_CONVERTVECTORTOSCF
35 #include "mlir/Conversion/Passes.h.inc"
36 } // namespace mlir
37 
38 using namespace mlir;
39 using vector::TransferReadOp;
40 using vector::TransferWriteOp;
41 
42 namespace {
43 
44 /// Attribute name used for labeling transfer ops during progressive lowering.
45 static const char kPassLabel[] = "__vector_to_scf_lowering__";
46 
47 /// Patterns that inherit from this struct have access to
48 /// VectorTransferToSCFOptions.
49 template <typename OpTy>
50 struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
51   explicit VectorToSCFPattern(MLIRContext *context,
52                               VectorTransferToSCFOptions opt)
53       : OpRewritePattern<OpTy>(context), options(opt) {}
54 
55   VectorTransferToSCFOptions options;
56 };
57 
58 /// Given a vector transfer op, calculate which dimension of the `source`
59 /// memref should be unpacked in the next application of TransferOpConversion.
60 /// A return value of std::nullopt indicates a broadcast.
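/// E.g. (illustrative permutation maps, assumed for this example):
///   (d0, d1, d2) -> (d2, d1)  returns 2 (memref dim 2 is unpacked next)
///   (d0, d1, d2) -> (0, d1)   returns std::nullopt (broadcast)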
61 template <typename OpTy>
62 static std::optional<int64_t> unpackedDim(OpTy xferOp) {
63   // TODO: support 0-d corner case.
64   assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
65   auto map = xferOp.getPermutationMap();
66   if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
67     return expr.getPosition();
68   }
69   assert(xferOp.isBroadcastDim(0) &&
70          "Expected AffineDimExpr or AffineConstantExpr");
71   return std::nullopt;
72 }
73 
74 /// Compute the permutation map for the new (N-1)-D vector transfer op. This
75 /// map is identical to the current permutation map, but the first result is
76 /// omitted.
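/// E.g. (illustrative): (d0, d1, d2) -> (d1, d2) becomes (d0, d1, d2) -> (d2).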
77 template <typename OpTy>
78 static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
79   // TODO: support 0-d corner case.
80   assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
81   auto map = xferOp.getPermutationMap();
82   return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
83                         b.getContext());
84 }
85 
86 /// Calculate the indices for the new vector transfer op.
87 ///
88 /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
89 ///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
90 ///                                 ^^^^^^
91 ///              `iv` is the iteration variable of the (new) surrounding loop.
92 template <typename OpTy>
93 static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
94                            SmallVector<Value, 8> &indices) {
95   typename OpTy::Adaptor adaptor(xferOp);
96   // Corresponding memref dim of the vector dim that is unpacked.
97   auto dim = unpackedDim(xferOp);
98   auto prevIndices = adaptor.getIndices();
99   indices.append(prevIndices.begin(), prevIndices.end());
100 
101   Location loc = xferOp.getLoc();
102   bool isBroadcast = !dim.has_value();
103   if (!isBroadcast) {
104     AffineExpr d0, d1;
105     bindDims(xferOp.getContext(), d0, d1);
106     Value offset = adaptor.getIndices()[*dim];
107     indices[*dim] =
108         affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
109   }
110 }
111 
112 static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
113                             Value value) {
114   if (hasRetVal) {
115     assert(value && "Expected non-empty value");
116     b.create<scf::YieldOp>(loc, value);
117   } else {
118     b.create<scf::YieldOp>(loc);
119   }
120 }
121 
122 /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
123 /// is set to true. No such check is generated under the following circumstances:
124 /// * xferOp does not have a mask.
125 /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
126 ///   computed and attached to the new transfer op in the pattern.)
127 /// * The to-be-unpacked dim of xferOp is a broadcast.
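///
/// Otherwise, for a 1D mask (illustrative sketch, assuming %mask has type
/// vector<5xi1>; names are illustrative), the generated check is:
/// ```
/// %is_masked_in = vector.extractelement %mask[%iv : index] : vector<5xi1>
/// ```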
128 template <typename OpTy>
129 static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
130   if (!xferOp.getMask())
131     return Value();
132   if (xferOp.getMaskType().getRank() != 1)
133     return Value();
134   if (xferOp.isBroadcastDim(0))
135     return Value();
136 
137   Location loc = xferOp.getLoc();
138   return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
139 }
140 
141 /// Helper function for TransferOpConversion and TransferOp1dConversion.
142 /// Generate an in-bounds check if the transfer op may go out-of-bounds on the
143 /// specified dimension `dim` with the loop iteration variable `iv`.
144 /// E.g., when unpacking dimension 0 from:
145 /// ```
146 /// %vec = vector.transfer_read %A[%a, %b], %cst
147 ///     : memref<?x?xf32>, vector<5x4xf32>
148 /// ```
149 /// An if check similar to this will be generated inside the loop:
150 /// ```
151 /// %d = memref.dim %A, %c0 : memref<?x?xf32>
152 /// if (%a + iv < %d) {
153 ///   (in-bounds case)
154 /// } else {
155 ///   (out-of-bounds case)
156 /// }
157 /// ```
158 ///
159 /// If the transfer is 1D and has a mask, this function generates a more complex
160 /// check that also accounts for potentially masked-out elements.
161 ///
162 /// This function variant returns the value returned by `inBoundsCase` or
163 /// `outOfBoundsCase`. The MLIR type of the return value must be specified in
164 /// `resultTypes`.
165 template <typename OpTy>
166 static Value generateInBoundsCheck(
167     OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
168     TypeRange resultTypes,
169     function_ref<Value(OpBuilder &, Location)> inBoundsCase,
170     function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
171   bool hasRetVal = !resultTypes.empty();
172   Value cond; // Condition to be built...
173 
174   // Condition check 1: Access in-bounds?
175   bool isBroadcast = !dim; // No in-bounds check for broadcasts.
176   Location loc = xferOp.getLoc();
177   ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
178   if (!xferOp.isDimInBounds(0) && !isBroadcast) {
179     Value memrefDim =
180         vector::createOrFoldDimOp(b, loc, xferOp.getSource(), *dim);
181     AffineExpr d0, d1;
182     bindDims(xferOp.getContext(), d0, d1);
183     Value base = xferOp.getIndices()[*dim];
184     Value memrefIdx =
185         affine::makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
186     cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
187                                     memrefIdx);
188   }
189 
190   // Condition check 2: Masked in?
191   if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
192     if (cond)
193       cond = lb.create<arith::AndIOp>(cond, maskCond);
194     else
195       cond = maskCond;
196   }
197 
198   // If the condition is non-empty, generate an SCF::IfOp.
199   if (cond) {
200     auto check = lb.create<scf::IfOp>(
201         cond,
202         /*thenBuilder=*/
203         [&](OpBuilder &b, Location loc) {
204           maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
205         },
206         /*elseBuilder=*/
207         [&](OpBuilder &b, Location loc) {
208           if (outOfBoundsCase) {
209             maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
210           } else {
211             b.create<scf::YieldOp>(loc);
212           }
213         });
214 
215     return hasRetVal ? check.getResult(0) : Value();
216   }
217 
218   // Condition is empty, no need for an SCF::IfOp.
219   return inBoundsCase(b, loc);
220 }
221 
222 /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
223 /// a return value. Consequently, this function does not have a return value.
224 template <typename OpTy>
225 static void generateInBoundsCheck(
226     OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
227     function_ref<void(OpBuilder &, Location)> inBoundsCase,
228     function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
229   generateInBoundsCheck(
230       b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
231       /*inBoundsCase=*/
232       [&](OpBuilder &b, Location loc) {
233         inBoundsCase(b, loc);
234         return Value();
235       },
236       /*outOfBoundsCase=*/
237       [&](OpBuilder &b, Location loc) {
238         if (outOfBoundsCase)
239           outOfBoundsCase(b, loc);
240         return Value();
241       });
242 }
243 
244 /// Given an ArrayAttr, return a copy where the first element is dropped.
245 static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
246   if (!attr)
247     return attr;
248   return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
249 }
250 
251 /// Add the pass label to a vector transfer op if its vector rank is greater
252 /// than the target rank.
253 template <typename OpTy>
254 static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
255                                 unsigned targetRank) {
256   if (newXferOp.getVectorType().getRank() > targetRank)
257     newXferOp->setAttr(kPassLabel, b.getUnitAttr());
258 }
259 
260 /// Return true if this transfer op operates on a source tensor.
261 template <typename OpTy>
262 static bool isTensorOp(OpTy xferOp) {
263   if (isa<RankedTensorType>(xferOp.getShapedType())) {
264     if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
265       // TransferWriteOps on tensors have a result.
266       assert(xferOp->getNumResults() > 0);
267     }
268     return true;
269   }
270   return false;
271 }
272 
273 namespace lowering_n_d {
274 
275 /// Helper data structure for data and mask buffers.
276 struct BufferAllocs {
277   Value dataBuffer;
278   Value maskBuffer;
279 };
280 
281 // TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
282 static Operation *getAutomaticAllocationScope(Operation *op) {
283   Operation *scope =
284       op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
285   assert(scope && "Expected op to be inside automatic allocation scope");
286   return scope;
287 }
288 
289 /// Allocate temporary buffers for data (vector) and mask (if present).
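///
/// E.g. (sketch, assuming a transfer op on vector<5x4xf32> with a mask of type
/// vector<5xi1>; names are illustrative), buffers similar to the following are
/// created at the top of the enclosing AutomaticAllocationScope:
/// ```
/// %data_buf = memref.alloca() : memref<vector<5x4xf32>>
/// %mask_buf = memref.alloca() : memref<vector<5xi1>>
/// ```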
290 template <typename OpTy>
291 static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
292   Location loc = xferOp.getLoc();
293   OpBuilder::InsertionGuard guard(b);
294   Operation *scope = getAutomaticAllocationScope(xferOp);
295   assert(scope->getNumRegions() == 1 &&
296          "AutomaticAllocationScope with >1 regions");
297   b.setInsertionPointToStart(&scope->getRegion(0).front());
298 
299   BufferAllocs result;
300   auto bufferType = MemRefType::get({}, xferOp.getVectorType());
301   result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);
302 
303   if (xferOp.getMask()) {
304     auto maskType = MemRefType::get({}, xferOp.getMask().getType());
305     auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
306     b.setInsertionPoint(xferOp);
307     b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
308     result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer, ValueRange());
309   }
310 
311   return result;
312 }
313 
314 /// Given a MemRefType with VectorType element type, unpack one dimension from
315 /// the VectorType into the MemRefType.
316 ///
317 /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
318 static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
319   auto vectorType = cast<VectorType>(type.getElementType());
320   // Vectors with leading scalable dims are not supported.
321   // It may be possible to support these in the future by using dynamic memref dims.
322   if (vectorType.getScalableDims().front())
323     return failure();
324   auto memrefShape = type.getShape();
325   SmallVector<int64_t, 8> newMemrefShape;
326   newMemrefShape.append(memrefShape.begin(), memrefShape.end());
327   newMemrefShape.push_back(vectorType.getDimSize(0));
328   return MemRefType::get(newMemrefShape,
329                          VectorType::Builder(vectorType).dropDim(0));
330 }
331 
332 /// Given a transfer op, find the memref from which the mask is loaded. This
333 /// is similar to Strategy<TransferWriteOp>::getBuffer.
334 template <typename OpTy>
335 static Value getMaskBuffer(OpTy xferOp) {
336   assert(xferOp.getMask() && "Expected that transfer op has mask");
337   auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
338   assert(loadOp && "Expected transfer op mask produced by LoadOp");
339   return loadOp.getMemRef();
340 }
341 
342 /// Codegen strategy, depending on the operation.
343 template <typename OpTy>
344 struct Strategy;
345 
346 /// Codegen strategy for vector TransferReadOp.
347 template <>
348 struct Strategy<TransferReadOp> {
349   /// Find the StoreOp that is used for writing the current TransferReadOp's
350   /// result to the temporary buffer allocation.
351   static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
352     assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
353     auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
354     assert(storeOp && "Expected TransferReadOp result used by StoreOp");
355     return storeOp;
356   }
357 
358   /// Find the temporary buffer allocation. All labeled TransferReadOps are
359   /// used like this, where %buf is either the buffer allocation or a type cast
360   /// of the buffer allocation:
361   /// ```
362   /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
363   /// memref.store %vec, %buf[...] ...
364   /// ```
365   static Value getBuffer(TransferReadOp xferOp) {
366     return getStoreOp(xferOp).getMemRef();
367   }
368 
369   /// Retrieve the indices of the current StoreOp that stores into the buffer.
370   static void getBufferIndices(TransferReadOp xferOp,
371                                SmallVector<Value, 8> &indices) {
372     memref::StoreOp storeOp = getStoreOp(xferOp);
373     auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
374     indices.append(prevIndices.begin(), prevIndices.end());
375   }
376 
377   /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
378   /// accesses on the to-be-unpacked dimension.
379   ///
380   /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
381   ///    variable `iv`.
382   /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
383   ///
384   /// E.g.:
385   /// ```
386   /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
387   ///     : memref<?x?x?xf32>, vector<4x3xf32>
388   /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
389   /// ```
390   /// Is rewritten to:
391   /// ```
392   /// %casted = vector.type_cast %buf
393   ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
394   /// for %j = 0 to 4 {
395   ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
396   ///       : memref<?x?x?xf32>, vector<3xf32>
397   ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
398   /// }
399   /// ```
400   ///
401   /// Note: The loop and type cast are generated in TransferOpConversion.
402   ///       The original TransferReadOp and store op are deleted in `cleanup`.
403   /// Note: The `mask` operand is set in TransferOpConversion.
404   static TransferReadOp rewriteOp(OpBuilder &b,
405                                   VectorTransferToSCFOptions options,
406                                   TransferReadOp xferOp, Value buffer, Value iv,
407                                   ValueRange /*loopState*/) {
408     SmallVector<Value, 8> storeIndices;
409     getBufferIndices(xferOp, storeIndices);
410     storeIndices.push_back(iv);
411 
412     SmallVector<Value, 8> xferIndices;
413     getXferIndices(b, xferOp, iv, xferIndices);
414 
415     Location loc = xferOp.getLoc();
416     auto bufferType = dyn_cast<ShapedType>(buffer.getType());
417     auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
418     auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
419     auto newXferOp = b.create<vector::TransferReadOp>(
420         loc, vecType, xferOp.getSource(), xferIndices,
421         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
422         xferOp.getPadding(), Value(), inBoundsAttr);
423 
424     maybeApplyPassLabel(b, newXferOp, options.targetRank);
425 
426     b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices);
427     return newXferOp;
428   }
429 
430   /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
431   /// padding value to the temporary buffer.
432   static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
433                                     Value buffer, Value iv,
434                                     ValueRange /*loopState*/) {
435     SmallVector<Value, 8> storeIndices;
436     getBufferIndices(xferOp, storeIndices);
437     storeIndices.push_back(iv);
438 
439     Location loc = xferOp.getLoc();
440     auto bufferType = dyn_cast<ShapedType>(buffer.getType());
441     auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
442     auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
443     b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
444 
445     return Value();
446   }
447 
448   /// Cleanup after rewriting the op.
449   static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
450                       scf::ForOp /*forOp*/) {
451     rewriter.eraseOp(getStoreOp(xferOp));
452     rewriter.eraseOp(xferOp);
453   }
454 
455   /// Return the initial loop state for the generated scf.for loop.
456   static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
457 };
458 
459 /// Codegen strategy for vector TransferWriteOp.
460 template <>
461 struct Strategy<TransferWriteOp> {
462   /// Find the temporary buffer allocation. All labeled TransferWriteOps are
463   /// used like this, where %buf is either the buffer allocation or a type cast
464   /// of the buffer allocation:
465   /// ```
466   /// %vec = memref.load %buf[...] ...
467   /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
468   /// ```
469   static Value getBuffer(TransferWriteOp xferOp) {
470     auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
471     assert(loadOp && "Expected transfer op vector produced by LoadOp");
472     return loadOp.getMemRef();
473   }
474 
475   /// Retrieve the indices of the current LoadOp that loads from the buffer.
476   static void getBufferIndices(TransferWriteOp xferOp,
477                                SmallVector<Value, 8> &indices) {
478     auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
479     auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
480     indices.append(prevIndices.begin(), prevIndices.end());
481   }
482 
483   /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
484   /// accesses on the to-be-unpacked dimension.
485   ///
486   /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
487   ///    using the loop iteration variable `iv`.
488   /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
489   ///    to memory.
490   ///
491   /// Note: For more details, see comments on Strategy<TransferReadOp>.
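  ///
  /// E.g. (sketch, mirroring the TransferReadOp example above; names are
  /// illustrative): inside the generated loop with induction variable %j,
  /// roughly the following is emitted:
  /// ```
  /// %vec = memref.load %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// vector.transfer_write %vec, %A[%a + %i, %b + %j, %c]
  ///     : vector<3xf32>, memref<?x?x?xf32>
  /// ```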
492   static TransferWriteOp rewriteOp(OpBuilder &b,
493                                    VectorTransferToSCFOptions options,
494                                    TransferWriteOp xferOp, Value buffer,
495                                    Value iv, ValueRange loopState) {
496     SmallVector<Value, 8> loadIndices;
497     getBufferIndices(xferOp, loadIndices);
498     loadIndices.push_back(iv);
499 
500     SmallVector<Value, 8> xferIndices;
501     getXferIndices(b, xferOp, iv, xferIndices);
502 
503     Location loc = xferOp.getLoc();
504     auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
505     auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
506     auto source = loopState.empty() ? xferOp.getSource() : loopState[0];
507     Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
508     auto newXferOp = b.create<vector::TransferWriteOp>(
509         loc, type, vec, source, xferIndices,
510         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
511         inBoundsAttr);
512 
513     maybeApplyPassLabel(b, newXferOp, options.targetRank);
514 
515     return newXferOp;
516   }
517 
518   /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
519   static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
520                                     Value buffer, Value iv,
521                                     ValueRange loopState) {
522     return isTensorOp(xferOp) ? loopState[0] : Value();
523   }
524 
525   /// Cleanup after rewriting the op.
526   static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
527                       scf::ForOp forOp) {
528     if (isTensorOp(xferOp)) {
529       assert(forOp->getNumResults() == 1 && "Expected one for loop result");
530       rewriter.replaceOp(xferOp, forOp->getResult(0));
531     } else {
532       rewriter.eraseOp(xferOp);
533     }
534   }
535 
536   /// Return the initial loop state for the generated scf.for loop.
537   static Value initialLoopState(TransferWriteOp xferOp) {
538     return isTensorOp(xferOp) ? xferOp.getSource() : Value();
539   }
540 };
541 
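/// Check whether a transfer op should be prepared for progressive lowering:
/// it must not be labeled yet, its vector rank must exceed the target rank,
/// its leading vector dimension must not be scalable, tensor sources require
/// lowerTensors, and the element type must not change.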
542 template <typename OpTy>
543 LogicalResult checkPrepareXferOp(OpTy xferOp,
544                                  VectorTransferToSCFOptions options) {
545   if (xferOp->hasAttr(kPassLabel))
546     return failure();
547   if (xferOp.getVectorType().getRank() <= options.targetRank)
548     return failure();
549   // Currently the unpacking of the leading dimension into the memref is not
550   // supported for scalable dimensions.
551   if (xferOp.getVectorType().getScalableDims().front())
552     return failure();
553   if (isTensorOp(xferOp) && !options.lowerTensors)
554     return failure();
555   // Transfer ops that modify the element type are not supported atm.
556   if (xferOp.getVectorType().getElementType() !=
557       xferOp.getShapedType().getElementType())
558     return failure();
559   return success();
560 }
561 
562 /// Prepare a TransferReadOp for progressive lowering.
563 ///
564 /// 1. Allocate a temporary buffer.
565 /// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
566 /// 3. Store the result of the TransferReadOp into the temporary buffer.
567 /// 4. Load the result from the temporary buffer and replace all uses of the
568 ///    original TransferReadOp with this load.
569 ///
570 /// E.g.:
571 /// ```
572 /// %vec = vector.transfer_read %A[%a, %b, %c], %cst
573 ///     : memref<?x?x?xf32>, vector<5x4xf32>
574 /// ```
575 /// is rewritten to:
576 /// ```
577 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
578 /// %1 = vector.transfer_read %A[%a, %b, %c], %cst
579 ///     { __vector_to_scf_lowering__ } : memref<?x?x?xf32>, vector<5x4xf32>
580 /// memref.store %1, %0[] : memref<vector<5x4xf32>>
581 /// %vec = memref.load %0[] : memref<vector<5x4xf32>>
582 /// ```
583 ///
584 /// Note: A second temporary buffer may be allocated for the `mask` operand.
585 struct PrepareTransferReadConversion
586     : public VectorToSCFPattern<TransferReadOp> {
587   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
588 
589   LogicalResult matchAndRewrite(TransferReadOp xferOp,
590                                 PatternRewriter &rewriter) const override {
591     if (checkPrepareXferOp(xferOp, options).failed())
592       return failure();
593 
594     BufferAllocs buffers = allocBuffers(rewriter, xferOp);
595     Operation *newXfer = rewriter.clone(*xferOp.getOperation());
596     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
597     if (xferOp.getMask()) {
598       cast<TransferReadOp>(newXfer).getMaskMutable().assign(
599           buffers.maskBuffer);
600     }
601 
602     Location loc = xferOp.getLoc();
603     rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
604                                      buffers.dataBuffer);
605     rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
606 
607     return success();
608   }
609 };
610 
611 /// Prepare a TransferWriteOp for progressive lowering.
612 ///
613 /// 1. Allocate a temporary buffer.
614 /// 2. Store the vector into the buffer.
615 /// 3. Load the vector from the buffer again.
616 /// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
617 ///    marking it eligible for progressive lowering via TransferOpConversion.
618 ///
619 /// E.g.:
620 /// ```
621 /// vector.transfer_write %vec, %A[%a, %b, %c]
622 ///     : vector<5x4xf32>, memref<?x?x?xf32>
623 /// ```
624 /// is rewritten to:
625 /// ```
626 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
627 /// memref.store %vec, %0[] : memref<vector<5x4xf32>>
628 /// %1 = memref.load %0[] : memref<vector<5x4xf32>>
629 /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
630 ///     : vector<5x4xf32>, memref<?x?x?xf32>
631 /// ```
632 ///
633 /// Note: A second temporary buffer may be allocated for the `mask` operand.
634 struct PrepareTransferWriteConversion
635     : public VectorToSCFPattern<TransferWriteOp> {
636   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
637 
638   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
639                                 PatternRewriter &rewriter) const override {
640     if (checkPrepareXferOp(xferOp, options).failed())
641       return failure();
642 
643     Location loc = xferOp.getLoc();
644     auto buffers = allocBuffers(rewriter, xferOp);
645     rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
646                                      buffers.dataBuffer);
647     auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
648     rewriter.updateRootInPlace(xferOp, [&]() {
649       xferOp.getVectorMutable().assign(loadedVec);
650       xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
651     });
652 
653     if (xferOp.getMask()) {
654       rewriter.updateRootInPlace(xferOp, [&]() {
655         xferOp.getMaskMutable().assign(buffers.maskBuffer);
656       });
657     }
658 
659     return success();
660   }
661 };
662 
663 /// Decompose an n-D PrintOp into a loop of elementary/scalar prints. This
664 /// allows printing both 1D scalable vectors and n-D fixed-size vectors.
665 ///
666 /// E.g.:
667 /// ```
668 /// vector.print %v : vector<[4]xi32>
669 /// ```
670 /// is rewritten to:
671 /// ```
672 /// %c0 = arith.constant 0 : index
673 /// %c4 = arith.constant 4 : index
674 /// %c1 = arith.constant 1 : index
675 /// %vscale = vector.vscale
676 /// %length = arith.muli %vscale, %c4 : index
677 /// %lastIndex = arith.subi %length, %c1 : index
678 /// vector.print punctuation <open>
679 /// scf.for %i = %c0 to %length step %c1 {
680 ///   %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
681 ///   vector.print %el : i32 punctuation <no_punctuation>
682 ///   %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
683 ///   scf.if %notLastIndex {
684 ///     vector.print punctuation <comma>
685 ///   }
686 /// }
687 /// vector.print punctuation <close>
688 /// vector.print
689 /// ```
690 struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
691   using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern;
692   LogicalResult matchAndRewrite(vector::PrintOp printOp,
693                                 PatternRewriter &rewriter) const override {
694     if (!printOp.getSource())
695       return failure();
696 
697     VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
698     if (!vectorType)
699       return failure();
700 
701     // Currently >= 2D scalable vectors are not supported.
702     // These can't be lowered to LLVM (as LLVM does not support scalable vectors
703     // of scalable vectors), and due to limitations of current ops can't be
704     // indexed with SSA values or flattened. This may change after
705     // https://reviews.llvm.org/D155034, though there still needs to be a path
706     // for lowering to LLVM.
707     if (vectorType.getRank() > 1 && vectorType.isScalable())
708       return failure();
709 
710     auto loc = printOp.getLoc();
711     auto value = printOp.getSource();
712 
713     if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
714       // Oddly sized integers are (somewhat) buggy on a lot of backends, so to
715       // avoid issues extend them to a more standard size.
716       // https://github.com/llvm/llvm-project/issues/30613
717       auto width = intTy.getWidth();
718       auto legalWidth = llvm::NextPowerOf2(std::max(8u, width) - 1);
719       auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth,
720                                          intTy.getSignedness());
721       // arith can only take signless integers, so we must cast back and forth.
722       auto signlessSourceVectorType =
723           vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
724       auto signlessTargetVectorType =
725           vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
726       auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
727       value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType,
728                                                  value);
729       if (value.getType() != signlessTargetVectorType) {
730         if (width == 1 || intTy.isUnsigned())
731           value = rewriter.create<arith::ExtUIOp>(loc, signlessTargetVectorType,
732                                                   value);
733         else
734           value = rewriter.create<arith::ExtSIOp>(loc, signlessTargetVectorType,
735                                                   value);
736       }
737       value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value);
738       vectorType = targetVectorType;
739     }
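    // E.g. (illustrative, assuming a vector<2xui5> source): the legalization
    // above emits roughly:
    //   %0 = vector.bitcast %v : vector<2xui5> to vector<2xi5>
    //   %1 = arith.extui %0 : vector<2xi5> to vector<2xi8>
    //   %2 = vector.bitcast %1 : vector<2xi8> to vector<2xui8>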
740 
741     auto scalableDimensions = vectorType.getScalableDims();
742     auto shape = vectorType.getShape();
743     constexpr int64_t singletonShape[] = {1};
744     if (vectorType.getRank() == 0)
745       shape = singletonShape;
746 
747     if (vectorType.getRank() != 1) {
748       // Flatten n-D vectors to 1D. This is done to allow indexing with a
749       // non-constant value (which can currently only be done via
750       // vector.extractelement for 1D vectors).
751       auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
752                                         std::multiplies<int64_t>());
753       auto flatVectorType =
754           VectorType::get({flatLength}, vectorType.getElementType());
755       value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value);
756     }
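    // E.g. (illustrative): a vector<3x2xi32> value is reshaped to vector<6xi32>
    // by the vector.shape_cast above before the print loops are generated.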
757 
758     vector::PrintOp firstClose;
759     SmallVector<Value, 8> loopIndices;
760     for (unsigned d = 0; d < shape.size(); d++) {
761       // Setup loop bounds and step.
762       Value lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
763       Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, shape[d]);
764       Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
765       if (!scalableDimensions.empty() && scalableDimensions[d]) {
766         auto vscale = rewriter.create<vector::VectorScaleOp>(
767             loc, rewriter.getIndexType());
768         upperBound = rewriter.create<arith::MulIOp>(loc, upperBound, vscale);
769       }
770       auto lastIndex = rewriter.create<arith::SubIOp>(loc, upperBound, step);
771 
772       // Create a loop to print the elements surrounded by parentheses.
773       rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
774       auto loop =
775           rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
776       auto printClose = rewriter.create<vector::PrintOp>(
777           loc, vector::PrintPunctuation::Close);
778       if (!firstClose)
779         firstClose = printClose;
780 
781       auto loopIdx = loop.getInductionVar();
782       loopIndices.push_back(loopIdx);
783 
784       // Print a comma after all but the last element.
785       rewriter.setInsertionPointToStart(loop.getBody());
786       auto notLastIndex = rewriter.create<arith::CmpIOp>(
787           loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
788       rewriter.create<scf::IfOp>(loc, notLastIndex,
789                                  [&](OpBuilder &builder, Location loc) {
790                                    builder.create<vector::PrintOp>(
791                                        loc, vector::PrintPunctuation::Comma);
792                                    builder.create<scf::YieldOp>(loc);
793                                  });
794 
795       rewriter.setInsertionPointToStart(loop.getBody());
796     }
797 
798     // Compute the flattened index.
799     // Note: For vectors of rank > 1, this assumes non-scalable dimensions.
800     Value flatIndex;
801     auto currentStride = 1;
802     for (int d = shape.size() - 1; d >= 0; d--) {
803       auto stride = rewriter.create<arith::ConstantIndexOp>(loc, currentStride);
804       auto index = rewriter.create<arith::MulIOp>(loc, stride, loopIndices[d]);
805       if (flatIndex)
806         flatIndex = rewriter.create<arith::AddIOp>(loc, flatIndex, index);
807       else
808         flatIndex = index;
809       currentStride *= shape[d];
810     }
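    // E.g. (illustrative): for shape [3, 2] and loop indices (%i, %j), the
    // flattened index computed above is %j * 1 + %i * 2.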
811 
812     // Print the scalar elements in the innermost loop.
813     auto element =
814         rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
815     rewriter.create<vector::PrintOp>(loc, element,
816                                      vector::PrintPunctuation::NoPunctuation);
817 
818     rewriter.setInsertionPointAfter(firstClose);
819     rewriter.create<vector::PrintOp>(loc, printOp.getPunctuation());
820     rewriter.eraseOp(printOp);
821     return success();
822   }
823 
824   static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) {
825     return IntegerType::get(intTy.getContext(), intTy.getWidth(),
826                             IntegerType::Signless);
827   };
828 };
829 
830 /// Progressive lowering of vector transfer ops: Unpack one dimension.
831 ///
832 /// 1. Unpack one dimension from the current buffer type and cast the buffer
833 ///    to that new type. E.g.:
834 ///    ```
835 ///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
836 ///    vector.transfer_write %vec ...
837 ///    ```
838 ///    The following cast is generated:
839 ///    ```
840 ///    %casted = vector.type_cast %0
841 ///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
842 ///    ```
843 /// 2. Generate a for loop and rewrite the transfer op according to the
844 ///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
845 ///    out-of-bounds, generate an if-check and handle both cases separately.
846 /// 3. Clean up according to the corresponding Strategy<OpTy>.
847 ///
848 /// Note: If the transfer op is a TransferWriteOp and operates on a tensor
849 /// source (as opposed to a memref source), then each iteration of the generated
850 /// scf.for loop yields the new tensor value. E.g.:
851 /// ```
852 /// %result = scf.for i = 0 to 5 {
853 ///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
854 ///   %1 = vector.transfer_write %0, %source[...]
855 ///       : vector<4x3xf32>, tensor<5x4x3xf32>
856 ///   scf.yield %1 : tensor<5x4x3xf32>
857 /// }
858 /// ```
859 template <typename OpTy>
860 struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
861   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
862 
863   void initialize() {
864     // This pattern recursively unpacks one dimension at a time. The recursion
865     // is bounded because the rank is strictly decreasing.
866     this->setHasBoundedRewriteRecursion();
867   }
868 
869   LogicalResult matchAndRewrite(OpTy xferOp,
870                                 PatternRewriter &rewriter) const override {
871     if (!xferOp->hasAttr(kPassLabel))
872       return failure();
873 
874     // Find and cast data buffer. How the buffer can be found depends on OpTy.
875     ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
876     auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
877     auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
878     auto castedDataType = unpackOneDim(dataBufferType);
879     if (failed(castedDataType))
880       return failure();
881 
882     auto castedDataBuffer =
883         locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer);
884 
885     // If the xferOp has a mask: Find and cast mask buffer.
886     Value castedMaskBuffer;
887     if (xferOp.getMask()) {
888       Value maskBuffer = getMaskBuffer(xferOp);
889       if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
890         // Do not unpack a dimension of the mask, if:
891         // * To-be-unpacked transfer op dimension is a broadcast.
892         // * Mask is 1D, i.e., the mask cannot be further unpacked.
893         //   (That means that all remaining dimensions of the transfer op must
894         //   be broadcasted.)
895         castedMaskBuffer = maskBuffer;
896       } else {
897         // It's safe to assume the mask buffer can be unpacked if the data
898         // buffer was unpacked.
899         auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
900         MemRefType castedMaskType = *unpackOneDim(maskBufferType);
901         castedMaskBuffer =
902             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
903       }
904     }
905 
906     // Loop bounds and step.
907     auto lb = locB.create<arith::ConstantIndexOp>(0);
908     auto ub = locB.create<arith::ConstantIndexOp>(
909         castedDataType->getDimSize(castedDataType->getRank() - 1));
910     auto step = locB.create<arith::ConstantIndexOp>(1);
911     // TransferWriteOps that operate on tensors return the modified tensor and
912     // require a loop state.
913     auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
914 
915     // Generate for loop.
916     auto result = locB.create<scf::ForOp>(
917         lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
918         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
919           Type stateType = loopState.empty() ? Type() : loopState[0].getType();
920 
921           auto result = generateInBoundsCheck(
922               b, xferOp, iv, unpackedDim(xferOp),
923               stateType ? TypeRange(stateType) : TypeRange(),
924               /*inBoundsCase=*/
925               [&](OpBuilder &b, Location loc) {
926                 // Create new transfer op.
927                 OpTy newXfer = Strategy<OpTy>::rewriteOp(
928                     b, this->options, xferOp, castedDataBuffer, iv, loopState);
929 
930                 // If old transfer op has a mask: Set mask on new transfer op.
931                 // Special case: If the mask of the old transfer op is 1D and the
932                 //               unpacked dim is not a broadcast, no mask is
933                 //               needed on the new transfer op.
935                 if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
936                                          xferOp.getMaskType().getRank() > 1)) {
937                   OpBuilder::InsertionGuard guard(b);
938                   b.setInsertionPoint(newXfer); // Insert load before newXfer.
939 
940                   SmallVector<Value, 8> loadIndices;
941                   if (auto memrefType =
942                           dyn_cast<MemRefType>(castedMaskBuffer.getType())) {
943                     // If castedMaskBuffer is a memref, then one dim was
944                     // unpacked; see above.
945                     loadIndices.push_back(iv);
946                   } else {
947                     Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
948                     // In case of broadcast: Use same indices to load from
949                     // memref as before.
950                     if (!xferOp.isBroadcastDim(0))
951                       loadIndices.push_back(iv);
952                   }
953 
954                   auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
955                                                        loadIndices);
956                   rewriter.updateRootInPlace(newXfer, [&]() {
957                     newXfer.getMaskMutable().assign(mask);
958                   });
959                 }
960 
961                 return loopState.empty() ? Value() : newXfer->getResult(0);
962               },
963               /*outOfBoundsCase=*/
964               [&](OpBuilder &b, Location /*loc*/) {
965                 return Strategy<OpTy>::handleOutOfBoundsDim(
966                     b, xferOp, castedDataBuffer, iv, loopState);
967               });
968 
969           maybeYieldValue(b, loc, !loopState.empty(), result);
970         });
971 
972     Strategy<OpTy>::cleanup(rewriter, xferOp, result);
973     return success();
974   }
975 };
976 
977 } // namespace lowering_n_d
978 
979 namespace lowering_n_d_unrolled {
980 
981 /// If the original transfer op has a mask, compute the mask of the new transfer
982 /// op (for the current iteration `i`) and assign it.
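///
/// E.g. (sketch, assuming a mask of type vector<5x4xi1> and i = 2; names are
/// illustrative), the unpacked mask is computed as:
/// ```
/// %new_mask = vector.extract %mask[2] : vector<4xi1> from vector<5x4xi1>
/// ```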
983 template <typename OpTy>
984 static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
985                             int64_t i) {
986   if (!xferOp.getMask())
987     return;
988 
989   if (xferOp.isBroadcastDim(0)) {
990     // To-be-unpacked dimension is a broadcast, which does not have a
991     // corresponding mask dimension. Mask attribute remains unchanged.
992     newXferOp.getMaskMutable().assign(xferOp.getMask());
993     return;
994   }
995 
996   if (xferOp.getMaskType().getRank() > 1) {
997     // Unpack one dimension of the mask.
998     OpBuilder::InsertionGuard guard(b);
999     b.setInsertionPoint(newXferOp); // Insert load before newXfer.
1000 
1001     llvm::SmallVector<int64_t, 1> indices({i});
1002     Location loc = xferOp.getLoc();
1003     auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
1004     newXferOp.getMaskMutable().assign(newMask);
1005   }
1006 
1007   // If we end up here: The mask of the old transfer op is 1D and the unpacked
1008   // dim is not a broadcast, so no mask is needed on the new transfer op.
1009   // `generateInBoundsCheck` will have evaluated the mask already.
1010 }
1011 
1012 /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
1013 /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
1014 /// memref buffer is allocated and the SCF loop is fully unrolled.
1015 ///
1016 /// ```
1018 /// ```
1019 /// %vec = vector.transfer_read %A[%a, %b, %c], %padding
1020 ///     : memref<?x?x?xf32>, vector<5x4xf32>
1021 /// ```
1022 /// is rewritten to IR such as (simplified):
1023 /// ```
1024 /// %v_init = splat %padding : vector<5x4xf32>
1025 /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
1026 ///     : memref<?x?x?xf32>, vector<4xf32>
1027 /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
1028 /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
1029 ///     : memref<?x?x?xf32>, vector<4xf32>
1030 /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
1031 /// ...
1032 /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
1033 ///     : memref<?x?x?xf32>, vector<4xf32>
1034 /// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
1035 /// ```
1036 ///
1037 /// Note: As an optimization, if the result of the original TransferReadOp
1038 /// was directly inserted into another vector, no new %v_init vector is created.
1039 /// Instead, the new TransferReadOp results are inserted into that vector.
1040 struct UnrollTransferReadConversion
1041     : public VectorToSCFPattern<TransferReadOp> {
1042   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
1043 
1044   void initialize() {
1045     // This pattern recursively unpacks one dimension at a time. The recursion
1046     // is bounded because the rank is strictly decreasing.
1047     setHasBoundedRewriteRecursion();
1048   }
1049 
1050   /// Return the vector into which the newly created TransferReadOp results
1051   /// are inserted.
1052   Value getResultVector(TransferReadOp xferOp,
1053                         PatternRewriter &rewriter) const {
1054     if (auto insertOp = getInsertOp(xferOp))
1055       return insertOp.getDest();
1056     Location loc = xferOp.getLoc();
1057     return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
1058                                             xferOp.getPadding());
1059   }
1060 
1061   /// If the result of the TransferReadOp has exactly one user, which is a
1062   /// vector::InsertOp, return that operation.
1063   vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
1064     if (xferOp->hasOneUse()) {
1065       Operation *xferOpUser = *xferOp->getUsers().begin();
1066       if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
1067         return insertOp;
1068     }
1069 
1070     return vector::InsertOp();
1071   }
1072 
1073   /// If the result of the TransferReadOp has exactly one user, which is a
1074   /// vector::InsertOp, return that operation's indices.
1075   void getInsertionIndices(TransferReadOp xferOp,
1076                            SmallVectorImpl<OpFoldResult> &indices) const {
1077     if (auto insertOp = getInsertOp(xferOp)) {
1078       auto pos = insertOp.getMixedPosition();
1079       indices.append(pos.begin(), pos.end());
1080     }
1081   }
1082 
1083   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
1084   /// accesses, as well as broadcasts and transposes in permutation maps.
1085   LogicalResult matchAndRewrite(TransferReadOp xferOp,
1086                                 PatternRewriter &rewriter) const override {
1087     if (xferOp.getVectorType().getRank() <= options.targetRank)
1088       return failure();
1089     if (isTensorOp(xferOp) && !options.lowerTensors)
1090       return failure();
1091     // Transfer ops that modify the element type are not supported atm.
1092     if (xferOp.getVectorType().getElementType() !=
1093         xferOp.getShapedType().getElementType())
1094       return failure();
1095 
1096     auto insertOp = getInsertOp(xferOp);
1097     auto vec = getResultVector(xferOp, rewriter);
1098     auto vecType = dyn_cast<VectorType>(vec.getType());
1099     auto xferVecType = xferOp.getVectorType();
1100 
1101     if (xferVecType.getScalableDims()[0]) {
1102       // Cannot unroll a scalable dimension at compile time.
1103       return failure();
1104     }
1105 
1106     VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);
1107 
1108     int64_t dimSize = xferVecType.getShape()[0];
1109 
1110     // Generate fully unrolled loop of transfer ops.
1111     Location loc = xferOp.getLoc();
1112     for (int64_t i = 0; i < dimSize; ++i) {
1113       Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
1114 
1115       vec = generateInBoundsCheck(
1116           rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
1117           /*inBoundsCase=*/
1118           [&](OpBuilder &b, Location loc) {
1119             // Indices for the new transfer op.
1120             SmallVector<Value, 8> xferIndices;
1121             getXferIndices(b, xferOp, iv, xferIndices);
1122 
1123             // Indices for the new vector.insert op.
1124             SmallVector<OpFoldResult, 8> insertionIndices;
1125             getInsertionIndices(xferOp, insertionIndices);
1126             insertionIndices.push_back(rewriter.getIndexAttr(i));
1127 
1128             auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1129             auto newXferOp = b.create<vector::TransferReadOp>(
1130                 loc, newXferVecType, xferOp.getSource(), xferIndices,
1131                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
1132                 xferOp.getPadding(), Value(), inBoundsAttr);
1133             maybeAssignMask(b, xferOp, newXferOp, i);
1134             return b.create<vector::InsertOp>(loc, newXferOp, vec,
1135                                               insertionIndices);
1136           },
1137           /*outOfBoundsCase=*/
1138           [&](OpBuilder &b, Location loc) {
1139             // Pass through the original (unmodified) vector.
1140             return vec;
1141           });
1142     }
1143 
1144     if (insertOp) {
1145       // Rewrite single user of the old TransferReadOp, which was an InsertOp.
1146       rewriter.replaceOp(insertOp, vec);
1147       rewriter.eraseOp(xferOp);
1148     } else {
1149       rewriter.replaceOp(xferOp, vec);
1150     }
1151 
1152     return success();
1153   }
1154 };
1155 
1156 /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
1157 /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
1158 /// memref buffer is allocated and the SCF loop is fully unrolled.
1159 ///
1160 /// ```
1162 /// ```
1163 /// vector.transfer_write %vec, %A[%a, %b, %c]
1164 ///     : vector<5x4xf32>, memref<?x?x?xf32>
1165 /// ```
1166 /// is rewritten to IR such as (simplified):
1167 /// ```
1168 /// %v0 = vector.extract %vec[0] : vector<4xf32> from vector<5x4xf32>
1169 /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
1170 /// %v1 = vector.extract %vec[1] : vector<4xf32> from vector<5x4xf32>
1171 /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
1172 /// ...
1173 /// %v4 = vector.extract %vec[4] : vector<4xf32> from vector<5x4xf32>
1174 /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
1175 /// ```
1176 ///
1177 /// Note: As an optimization, if the vector of the original TransferWriteOp
1178 /// was directly extracted from another vector via an ExtractOp `a`, extract
1179 /// the vectors for the newly generated TransferWriteOps from `a`'s input. By
1180 /// doing so, `a` may become dead, and the number of ExtractOps generated during
1181 /// recursive application of this pattern will be minimal.
1182 struct UnrollTransferWriteConversion
1183     : public VectorToSCFPattern<TransferWriteOp> {
1184   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
1185 
1186   void initialize() {
1187     // This pattern recursively unpacks one dimension at a time. The recursion
1188     // is bounded because the rank is strictly decreasing.
1189     setHasBoundedRewriteRecursion();
1190   }
1191 
1192   /// Return the vector from which newly generated ExtractOps will extract.
1193   Value getDataVector(TransferWriteOp xferOp) const {
1194     if (auto extractOp = getExtractOp(xferOp))
1195       return extractOp.getVector();
1196     return xferOp.getVector();
1197   }
1198 
1199   /// If the input of the given TransferWriteOp is an ExtractOp, return it.
1200   vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
1201     if (auto *op = xferOp.getVector().getDefiningOp())
1202       return dyn_cast<vector::ExtractOp>(op);
1203     return vector::ExtractOp();
1204   }
1205 
1206   /// If the input of the given TransferWriteOp is an ExtractOp, return its
1207   /// indices.
1208   void getExtractionIndices(TransferWriteOp xferOp,
1209                             SmallVectorImpl<OpFoldResult> &indices) const {
1210     if (auto extractOp = getExtractOp(xferOp)) {
1211       auto pos = extractOp.getMixedPosition();
1212       indices.append(pos.begin(), pos.end());
1213     }
1214   }
1215 
1216   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
1217   /// accesses, as well as broadcasts and transposes in permutation maps.
1218   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
1219                                 PatternRewriter &rewriter) const override {
1220     VectorType inputVectorTy = xferOp.getVectorType();
1221 
1222     if (inputVectorTy.getRank() <= options.targetRank)
1223       return failure();
1224 
1225     if (isTensorOp(xferOp) && !options.lowerTensors)
1226       return failure();
1227     // Transfer ops that modify the element type are not supported atm.
1228     if (inputVectorTy.getElementType() !=
1229         xferOp.getShapedType().getElementType())
1230       return failure();
1231 
1232     auto vec = getDataVector(xferOp);
1233     if (inputVectorTy.getScalableDims()[0]) {
1234       // Cannot unroll a scalable dimension at compile time.
1235       return failure();
1236     }
1237 
1238     int64_t dimSize = inputVectorTy.getShape()[0];
1239     Value source = xferOp.getSource(); // memref or tensor to be written to.
1240     auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
1241 
1242     // Generate fully unrolled loop of transfer ops.
1243     Location loc = xferOp.getLoc();
1244     for (int64_t i = 0; i < dimSize; ++i) {
1245       Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
1246 
1247       auto updatedSource = generateInBoundsCheck(
1248           rewriter, xferOp, iv, unpackedDim(xferOp),
1249           isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
1250           /*inBoundsCase=*/
1251           [&](OpBuilder &b, Location loc) {
1252             // Indices for the new transfer op.
1253             SmallVector<Value, 8> xferIndices;
1254             getXferIndices(b, xferOp, iv, xferIndices);
1255 
1256             // Indices for the new vector.extract op.
1257             SmallVector<OpFoldResult, 8> extractionIndices;
1258             getExtractionIndices(xferOp, extractionIndices);
1259             extractionIndices.push_back(b.getI64IntegerAttr(i));
1260 
1261             auto extracted =
1262                 b.create<vector::ExtractOp>(loc, vec, extractionIndices);
1263             auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1264             Value xferVec;
1265             if (inputVectorTy.getRank() == 1) {
1266               // When target-rank=0, unrolling would cause the vector input
1267               // argument of `transfer_write` to become a scalar. We solve
1268               // this by broadcasting the scalar to a 0D vector.
1269               xferVec = b.create<vector::BroadcastOp>(
1270                   loc, VectorType::get({}, extracted.getType()), extracted);
1271             } else {
1272               xferVec = extracted;
1273             }
1274             auto newXferOp = b.create<vector::TransferWriteOp>(
1275                 loc, sourceType, xferVec, source, xferIndices,
1276                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
1277                 inBoundsAttr);
1278 
1279             maybeAssignMask(b, xferOp, newXferOp, i);
1280 
1281             return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
1282           },
1283           /*outOfBoundsCase=*/
1284           [&](OpBuilder &b, Location loc) {
1285             return isTensorOp(xferOp) ? source : Value();
1286           });
1287 
1288       if (isTensorOp(xferOp))
1289         source = updatedSource;
1290     }
1291 
1292     if (isTensorOp(xferOp))
1293       rewriter.replaceOp(xferOp, source);
1294     else
1295       rewriter.eraseOp(xferOp);
1296 
1297     return success();
1298   }
1299 };
1300 
1301 } // namespace lowering_n_d_unrolled
1302 
1303 namespace lowering_1_d {
1304 
1305 /// Compute the indices into the memref for the LoadOp/StoreOp generated as
1306 /// part of TransferOp1dConversion. Return the memref dimension on which
1307 /// the transfer is operating. A return value of std::nullopt indicates a
1308 /// broadcast.
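///
/// E.g. (illustrative): for a transfer with indices [%a, %b] and permutation
/// map (d0, d1) -> (d0), the memref indices for iteration `iv` become
/// [%a + iv, %b] and the returned dimension is 0.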
1309 template <typename OpTy>
1310 static std::optional<int64_t>
1311 get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
1312                    SmallVector<Value, 8> &memrefIndices) {
1313   auto indices = xferOp.getIndices();
1314   auto map = xferOp.getPermutationMap();
1315   assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
1316 
1317   memrefIndices.append(indices.begin(), indices.end());
1318   assert(map.getNumResults() == 1 &&
1319          "Expected 1 permutation map result for 1D transfer");
1320   if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
1321     Location loc = xferOp.getLoc();
1322     auto dim = expr.getPosition();
1323     AffineExpr d0, d1;
1324     bindDims(xferOp.getContext(), d0, d1);
1325     Value offset = memrefIndices[dim];
1326     memrefIndices[dim] =
1327         affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
1328     return dim;
1329   }
1330 
1331   assert(xferOp.isBroadcastDim(0) &&
1332          "Expected AffineDimExpr or AffineConstantExpr");
1333   return std::nullopt;
1334 }
1335 
/// Codegen strategy for TransferOp1dConversion, depending on the operation.
/// Each specialization provides `initialLoopState` (the loop's iter_args, if
/// any) and `generateForLoopBody` (the per-element load/store logic).
1338 template <typename OpTy>
1339 struct Strategy1d;
1340 
1341 /// Codegen strategy for TransferReadOp.
1342 template <>
1343 struct Strategy1d<TransferReadOp> {
1344   static void generateForLoopBody(OpBuilder &b, Location loc,
1345                                   TransferReadOp xferOp, Value iv,
1346                                   ValueRange loopState) {
1347     SmallVector<Value, 8> indices;
1348     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1349     auto vec = loopState[0];
1350 
    // In case of an out-of-bounds access, leave `vec` as is; it was
    // initialized with the padding value.
1353     auto nextVec = generateInBoundsCheck(
1354         b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
1355         /*inBoundsCase=*/
1356         [&](OpBuilder &b, Location loc) {
1357           Value val =
1358               b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
1359           return b.create<vector::InsertElementOp>(loc, val, vec, iv);
1360         },
1361         /*outOfBoundsCase=*/
1362         [&](OpBuilder & /*b*/, Location loc) { return vec; });
1363     b.create<scf::YieldOp>(loc, nextVec);
1364   }
1365 
1366   static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize the vector with the padding value.
1368     Location loc = xferOp.getLoc();
1369     return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
1370                                      xferOp.getPadding());
1371   }
1372 };
1373 
1374 /// Codegen strategy for TransferWriteOp.
1375 template <>
1376 struct Strategy1d<TransferWriteOp> {
1377   static void generateForLoopBody(OpBuilder &b, Location loc,
1378                                   TransferWriteOp xferOp, Value iv,
1379                                   ValueRange /*loopState*/) {
1380     SmallVector<Value, 8> indices;
1381     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1382 
1383     // Nothing to do in case of out-of-bounds access.
1384     generateInBoundsCheck(
1385         b, xferOp, iv, dim,
1386         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
1387           auto val =
1388               b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
1389           b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
1390         });
1391     b.create<scf::YieldOp>(loc);
1392   }
1393 
1394   static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
1395     return Value();
1396   }
1397 };
1398 
1399 /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
1400 /// necessary in cases where a 1D vector transfer op cannot be lowered into
1401 /// vector load/stores due to non-unit strides or broadcasts:
1402 ///
1403 /// * Transfer dimension is not the last memref dimension
1404 /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
1405 /// * Memref has a layout map with non-unit stride on the last dimension
1406 ///
1407 /// This pattern generates IR as follows:
1408 ///
1409 /// 1. Generate a for loop iterating over each vector element.
/// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
1411 ///    depending on OpTy.
1412 ///
1413 /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
1414 ///       can be generated instead of TransferOp1dConversion. Add such a pattern
1415 ///       to ConvertVectorToLLVM.
1416 ///
1417 /// E.g.:
1418 /// ```
1419 /// vector.transfer_write %vec, %A[%a, %b]
1420 ///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
1421 ///    : vector<9xf32>, memref<?x?xf32>
1422 /// ```
1423 /// Is rewritten to approximately the following pseudo-IR:
1424 /// ```
1425 /// for i = 0 to 9 {
1426 ///   %t = vector.extractelement %vec[i] : vector<9xf32>
///   memref.store %t, %A[%a + i, %b] : memref<?x?xf32>
1428 /// }
1429 /// ```
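///
/// Similarly (illustrative sketch), a transfer_read such as
/// ```
/// %v = vector.transfer_read %A[%a, %b], %pad
///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
///    : memref<?x?xf32>, vector<9xf32>
/// ```
/// is rewritten to approximately the following pseudo-IR:
/// ```
/// %init = vector.splat %pad : vector<9xf32>
/// %v = for i = 0 to 9 with iter_args(%vec = %init) {
///   %t = memref.load %A[%a + i, %b] : memref<?x?xf32>
///   %vec_next = vector.insertelement %t, %vec[i] : vector<9xf32>
///   yield %vec_next
/// }
/// ```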
1430 template <typename OpTy>
1431 struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
1432   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
1433 
1434   LogicalResult matchAndRewrite(OpTy xferOp,
1435                                 PatternRewriter &rewriter) const override {
1436     // TODO: support 0-d corner case.
1437     if (xferOp.getTransferRank() == 0)
1438       return failure();
1439     auto map = xferOp.getPermutationMap();
1440     auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
1441 
1442     if (!memRefType)
1443       return failure();
1444     if (xferOp.getVectorType().getRank() != 1)
1445       return failure();
1446     if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
1447       return failure(); // Handled by ConvertVectorToLLVM
1448 
    // Loop bounds, step, and initial loop state.
1450     Location loc = xferOp.getLoc();
1451     auto vecType = xferOp.getVectorType();
1452     auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1453     Value ub =
1454         rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
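    // For scalable vectors, the element count is only known at runtime: scale
    // the static dimension size by vector.vscale.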
1455     if (vecType.isScalable()) {
1456       Value vscale =
1457           rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
1458       ub = rewriter.create<arith::MulIOp>(loc, ub, vscale);
1459     }
1460     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
1461     auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
1462 
1463     // Generate for loop.
1464     rewriter.replaceOpWithNewOp<scf::ForOp>(
1465         xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
1466         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
1467           Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
1468         });
1469 
1470     return success();
1471   }
1472 };
1473 
1474 } // namespace lowering_1_d
1475 } // namespace
1476 
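/// Illustrative usage sketch for these patterns (hypothetical caller code;
/// `op` is assumed to be the operation to convert and `ctx` a pointer to its
/// MLIRContext):
/// ```
/// RewritePatternSet patterns(ctx);
/// VectorTransferToSCFOptions options;
/// options.unroll = true;  // Fully unroll instead of generating scf.for loops.
/// options.targetRank = 1; // Also enables the lowering_1_d patterns below.
/// populateVectorToSCFConversionPatterns(patterns, options);
/// (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
/// ```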
1477 void mlir::populateVectorToSCFConversionPatterns(
1478     RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
1479   if (options.unroll) {
1480     patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1481                  lowering_n_d_unrolled::UnrollTransferWriteConversion>(
1482         patterns.getContext(), options);
1483   } else {
1484     patterns.add<lowering_n_d::PrepareTransferReadConversion,
1485                  lowering_n_d::PrepareTransferWriteConversion,
1486                  lowering_n_d::TransferOpConversion<TransferReadOp>,
1487                  lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1488         patterns.getContext(), options);
1489   }
1490 
1491   if (options.targetRank == 1) {
1492     patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1493                  lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1494         patterns.getContext(), options);
1495   }
1496   patterns.add<lowering_n_d::DecomposePrintOpConversion>(patterns.getContext(),
1497                                                          options);
1498 }
1499 
1500 namespace {
1501 
1502 struct ConvertVectorToSCFPass
1503     : public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
1504   ConvertVectorToSCFPass() = default;
1505   ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
1506     this->fullUnroll = options.unroll;
1507     this->targetRank = options.targetRank;
1508     this->lowerTensors = options.lowerTensors;
1509   }
1510 
1511   void runOnOperation() override {
1512     VectorTransferToSCFOptions options;
1513     options.unroll = fullUnroll;
1514     options.targetRank = targetRank;
1515     options.lowerTensors = lowerTensors;
1516 
1517     // Lower permutation maps first.
1518     RewritePatternSet lowerTransferPatterns(&getContext());
1519     mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
1520         lowerTransferPatterns);
1521     (void)applyPatternsAndFoldGreedily(getOperation(),
1522                                        std::move(lowerTransferPatterns));
1523 
1524     RewritePatternSet patterns(&getContext());
1525     populateVectorToSCFConversionPatterns(patterns, options);
1526     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
1527   }
1528 };
1529 
1530 } // namespace
1531 
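/// Illustrative sketch of adding this pass to a pipeline (hypothetical caller
/// code; `pm` is assumed to be an mlir::PassManager):
/// ```
/// VectorTransferToSCFOptions options;
/// options.targetRank = 1;
/// pm.addPass(createConvertVectorToSCFPass(options));
/// ```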
1532 std::unique_ptr<Pass>
1533 mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
1534   return std::make_unique<ConvertVectorToSCFPass>(options);
1535 }
1536