//===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector transfer operations to SCF.
//
//===----------------------------------------------------------------------===//

#include <numeric>
#include <optional>
#include <type_traits>

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOSCF
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using vector::TransferReadOp;
using vector::TransferWriteOp;

namespace {

/// Attribute name used for labeling transfer ops during progressive lowering.
static const char kPassLabel[] = "__vector_to_scf_lowering__";

/// Return true if this transfer op operates on a source tensor.
static bool isTensorOp(VectorTransferOpInterface xferOp) {
  if (isa<RankedTensorType>(xferOp.getShapedType())) {
    if (isa<vector::TransferWriteOp>(xferOp)) {
      // TransferWriteOps on tensors have a result.
      assert(xferOp->getNumResults() > 0);
    }
    return true;
  }
  return false;
}

/// Patterns that inherit from this struct have access to
/// VectorTransferToSCFOptions.
template <typename OpTy>
struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
  explicit VectorToSCFPattern(MLIRContext *context,
                              VectorTransferToSCFOptions opt)
      : OpRewritePattern<OpTy>(context), options(opt) {}

  LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp,
                                  PatternRewriter &rewriter) const {
    if (isTensorOp(xferOp) && !options.lowerTensors) {
      return rewriter.notifyMatchFailure(
          xferOp, "lowering tensor transfers is disabled");
    }
    return success();
  }

  VectorTransferToSCFOptions options;
};

/// Given a vector transfer op, calculate which dimension of the `source`
/// memref should be unpacked in the next application of TransferOpConversion.
/// A return value of std::nullopt indicates a broadcast.
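///
/// For example (an illustrative sketch): with permutation map
/// (d0, d1, d2) -> (d1, d2), the first result is d1, so dimension 1 of the
/// source is unpacked next. If the first result is the constant 0 (i.e., a
/// broadcast), std::nullopt is returned.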
template <typename OpTy>
static std::optional<int64_t> unpackedDim(OpTy xferOp) {
  // TODO: support 0-d corner case.
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
  auto map = xferOp.getPermutationMap();
  if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
    return expr.getPosition();
  }
  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return std::nullopt;
}

/// Compute the permutation map for the new (N-1)-D vector transfer op. This
/// map is identical to the current permutation map, but the first result is
/// omitted.
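///
/// For example (illustrative only):
///   (d0, d1, d2) -> (d1, d2)  becomes  (d0, d1, d2) -> (d2)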
template <typename OpTy>
static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
  // TODO: support 0-d corner case.
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
  auto map = xferOp.getPermutationMap();
  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
                        b.getContext());
}

/// Calculate the indices for the new vector transfer op.
///
/// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
///                                 ^^^^^^
///              `iv` is the iteration variable of the (new) surrounding loop.
template <typename OpTy>
static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
                           SmallVector<Value, 8> &indices) {
  typename OpTy::Adaptor adaptor(xferOp);
  // Corresponding memref dim of the vector dim that is unpacked.
  auto dim = unpackedDim(xferOp);
  auto prevIndices = adaptor.getIndices();
  indices.append(prevIndices.begin(), prevIndices.end());

  Location loc = xferOp.getLoc();
  bool isBroadcast = !dim.has_value();
  if (!isBroadcast) {
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = adaptor.getIndices()[*dim];
    indices[*dim] =
        affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
  }
}

static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
                            Value value) {
  if (hasRetVal) {
    assert(value && "Expected non-empty value");
    b.create<scf::YieldOp>(loc, value);
  } else {
    b.create<scf::YieldOp>(loc);
  }
}

/// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
/// is set to true. No such check is generated under the following
/// circumstances:
/// * xferOp does not have a mask.
/// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
///   computed and attached to the new transfer op in the pattern.)
/// * The to-be-unpacked dim of xferOp is a broadcast.
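///
/// Otherwise, a check of the following form is generated (a sketch; %mask,
/// %iv and the mask type are illustrative):
/// ```
/// %mask_bit = vector.extractelement %mask[%iv : index] : vector<16xi1>
/// ```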
template <typename OpTy>
static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
  if (!xferOp.getMask())
    return Value();
  if (xferOp.getMaskType().getRank() != 1)
    return Value();
  if (xferOp.isBroadcastDim(0))
    return Value();

  Location loc = xferOp.getLoc();
  return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
}

/// Helper function for TransferOpConversion and TransferOp1dConversion.
/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
/// specified dimension `dim` with the loop iteration variable `iv`.
/// E.g., when unpacking dimension 0 from:
/// ```
/// %vec = vector.transfer_read %A[%a, %b], %cst
///     : memref<?x?xf32>, vector<5x4xf32>
/// ```
/// An if check similar to this will be generated inside the loop:
/// ```
/// %d = memref.dim %A, %c0 : memref<?x?xf32>
/// if (%a + iv < %d) {
///   (in-bounds case)
/// } else {
///   (out-of-bounds case)
/// }
/// ```
///
/// If the transfer is 1D and has a mask, this function generates a more
/// complex check that also accounts for potentially masked-out elements.
///
/// This function variant returns the value returned by `inBoundsCase` or
/// `outOfBoundsCase`. The MLIR type of the return value must be specified in
/// `resultTypes`.
template <typename OpTy>
static Value generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
    TypeRange resultTypes,
    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  bool hasRetVal = !resultTypes.empty();
  Value cond; // Condition to be built...

  // Condition check 1: Access in-bounds?
  bool isBroadcast = !dim; // No in-bounds check for broadcasts.
  Location loc = xferOp.getLoc();
  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
    Value memrefDim =
        vector::createOrFoldDimOp(b, loc, xferOp.getSource(), *dim);
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value base = xferOp.getIndices()[*dim];
    Value memrefIdx =
        affine::makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
    cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
                                    memrefIdx);
  }

  // Condition check 2: Masked in?
  if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
    if (cond)
      cond = lb.create<arith::AndIOp>(cond, maskCond);
    else
      cond = maskCond;
  }

  // If the condition is non-empty, generate an SCF::IfOp.
  if (cond) {
    auto check = lb.create<scf::IfOp>(
        cond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          if (outOfBoundsCase) {
            maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
          } else {
            b.create<scf::YieldOp>(loc);
          }
        });

    return hasRetVal ? check.getResult(0) : Value();
  }

  // Condition is empty, no need for an SCF::IfOp.
  return inBoundsCase(b, loc);
}

/// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
/// a return value. Consequently, this function does not have a return value.
template <typename OpTy>
static void generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim,
    function_ref<void(OpBuilder &, Location)> inBoundsCase,
    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  generateInBoundsCheck(
      b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
      /*inBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        inBoundsCase(b, loc);
        return Value();
      },
      /*outOfBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        if (outOfBoundsCase)
          outOfBoundsCase(b, loc);
        return Value();
      });
}

/// Given an ArrayAttr, return a copy where the first element is dropped.
static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
  if (!attr)
    return attr;
  return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
}

/// Add the pass label to a vector transfer op if its rank is not the target
/// rank.
template <typename OpTy>
static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
                                unsigned targetRank) {
  if (newXferOp.getVectorType().getRank() > targetRank)
    newXferOp->setAttr(kPassLabel, b.getUnitAttr());
}

namespace lowering_n_d {

/// Helper data structure for data and mask buffers.
struct BufferAllocs {
  Value dataBuffer;
  Value maskBuffer;
};

// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
static Operation *getAutomaticAllocationScope(Operation *op) {
  Operation *scope =
      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}

/// Allocate temporary buffers for data (vector) and mask (if present).
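///
/// A sketch of the IR created for a masked transfer of vector<5x4xf32> (names
/// and types are illustrative):
/// ```
/// %data_buf = memref.alloca() : memref<vector<5x4xf32>>
/// %mask_buf = memref.alloca() : memref<vector<5x4xi1>>
/// memref.store %mask, %mask_buf[] : memref<vector<5x4xi1>>
/// %mask_val = memref.load %mask_buf[] : memref<vector<5x4xi1>>
/// ```
/// The allocas are placed at the beginning of the enclosing
/// AutomaticAllocationScope; the mask store/load are inserted at the original
/// transfer op.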
template <typename OpTy>
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
  Location loc = xferOp.getLoc();
  OpBuilder::InsertionGuard guard(b);
  Operation *scope = getAutomaticAllocationScope(xferOp);
  assert(scope->getNumRegions() == 1 &&
         "AutomaticAllocationScope with >1 regions");
  b.setInsertionPointToStart(&scope->getRegion(0).front());

  BufferAllocs result;
  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
  result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);

  if (xferOp.getMask()) {
    auto maskType = MemRefType::get({}, xferOp.getMask().getType());
    auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
    b.setInsertionPoint(xferOp);
    b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
    result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer, ValueRange());
  }

  return result;
}

/// Given a MemRefType with VectorType element type, unpack one dimension from
/// the VectorType into the MemRefType.
///
/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
static FailureOr<MemRefType> unpackOneDim(MemRefType type) {
  auto vectorType = dyn_cast<VectorType>(type.getElementType());
  // Vectors with leading scalable dims are not supported.
  // It may be possible to support these in the future by using dynamic
  // memref dims.
  if (vectorType.getScalableDims().front())
    return failure();
  auto memrefShape = type.getShape();
  SmallVector<int64_t, 8> newMemrefShape;
  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
  newMemrefShape.push_back(vectorType.getDimSize(0));
  return MemRefType::get(newMemrefShape,
                         VectorType::Builder(vectorType).dropDim(0));
}

/// Given a transfer op, find the memref from which the mask is loaded. This
/// is similar to Strategy<TransferWriteOp>::getBuffer.
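///
/// That is, for a labeled transfer op of the form (a sketch; names and types
/// are illustrative):
/// ```
/// %mask = memref.load %mask_buf[] : memref<vector<5xi1>>
/// %vec = vector.transfer_read ..., %mask ... { __vector_to_scf_lowering__ }
/// ```
/// this function returns %mask_buf.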
template <typename OpTy>
static Value getMaskBuffer(OpTy xferOp) {
  assert(xferOp.getMask() && "Expected that transfer op has mask");
  auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
  assert(loadOp && "Expected transfer op mask produced by LoadOp");
  return loadOp.getMemRef();
}

/// Codegen strategy, depending on the operation.
template <typename OpTy>
struct Strategy;

/// Codegen strategy for vector TransferReadOp.
template <>
struct Strategy<TransferReadOp> {
  /// Find the StoreOp that is used for writing the current TransferReadOp's
  /// result to the temporary buffer allocation.
  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
    return storeOp;
  }

  /// Find the temporary buffer allocation. All labeled TransferReadOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
  /// memref.store %vec, %buf[...] ...
  /// ```
  static Value getBuffer(TransferReadOp xferOp) {
    return getStoreOp(xferOp).getMemRef();
  }

  /// Retrieve the indices of the current StoreOp that stores into the buffer.
  static void getBufferIndices(TransferReadOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto storeOp = getStoreOp(xferOp);
    auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
  ///    variable `iv`.
  /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
  ///
  /// E.g.:
  /// ```
  /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
  ///     : memref<?x?x?xf32>, vector<4x3xf32>
  /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
  /// ```
  /// Is rewritten to:
  /// ```
  /// %casted = vector.type_cast %buf
  ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
  /// for %j = 0 to 4 {
  ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
  ///       : memref<?x?x?xf32>, vector<3xf32>
  ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// }
  /// ```
  ///
  /// Note: The loop and type cast are generated in TransferOpConversion.
  ///       The original TransferReadOp and store op are deleted in `cleanup`.
  /// Note: The `mask` operand is set in TransferOpConversion.
  static TransferReadOp rewriteOp(OpBuilder &b,
                                  VectorTransferToSCFOptions options,
                                  TransferReadOp xferOp, Value buffer, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto bufferType = dyn_cast<ShapedType>(buffer.getType());
    auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
    auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
    auto newXferOp = b.create<vector::TransferReadOp>(
        loc, vecType, xferOp.getSource(), xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
        xferOp.getPadding(), Value(), inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices);
    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
  /// padding value to the temporary buffer.
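  ///
  /// A sketch of the generated IR (values and types are illustrative):
  /// ```
  /// %pad_vec = vector.splat %padding : vector<3xf32>
  /// memref.store %pad_vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// ```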
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    Location loc = xferOp.getLoc();
    auto bufferType = dyn_cast<ShapedType>(buffer.getType());
    auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
    auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
    b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);

    return Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
                      scf::ForOp /*forOp*/) {
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
};

/// Codegen strategy for vector TransferWriteOp.
template <>
struct Strategy<TransferWriteOp> {
  /// Find the temporary buffer allocation. All labeled TransferWriteOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = memref.load %buf[...] ...
  /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
  /// ```
  static Value getBuffer(TransferWriteOp xferOp) {
    auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();
  }

  /// Retrieve the indices of the current LoadOp that loads from the buffer.
  static void getBufferIndices(TransferWriteOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
  ///    using the loop iteration variable `iv`.
  /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
  ///    to memory.
  ///
  /// Note: For more details, see comments on Strategy<TransferReadOp>.
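  ///
  /// E.g. (a sketch, mirroring the read-side example above):
  /// ```
  /// %vec = memref.load %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// vector.transfer_write %vec, %A[%a+%i, %b+%j, %c]
  ///     : vector<3xf32>, memref<?x?x?xf32>
  /// ```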
  static TransferWriteOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
                                   TransferWriteOp xferOp, Value buffer,
                                   Value iv, ValueRange loopState) {
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
    auto source = loopState.empty() ? xferOp.getSource() : loopState[0];
    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
    auto newXferOp = b.create<vector::TransferWriteOp>(
        loc, type, vec, source, xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
        inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange loopState) {
    return isTensorOp(xferOp) ? loopState[0] : Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
                      scf::ForOp forOp) {
    if (isTensorOp(xferOp)) {
      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
      rewriter.replaceOp(xferOp, forOp->getResult(0));
    } else {
      rewriter.eraseOp(xferOp);
    }
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferWriteOp xferOp) {
    return isTensorOp(xferOp) ? xferOp.getSource() : Value();
  }
};

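/// Return success if the transfer op can be prepared for progressive
/// lowering: it is not yet labeled, its vector rank exceeds the target rank,
/// its leading vector dimension is not scalable, tensor transfers are enabled
/// if required, and the transfer does not change the element type.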
template <typename OpTy>
LogicalResult checkPrepareXferOp(OpTy xferOp,
                                 VectorTransferToSCFOptions options) {
  if (xferOp->hasAttr(kPassLabel))
    return failure();
  if (xferOp.getVectorType().getRank() <= options.targetRank)
    return failure();
  // Currently the unpacking of the leading dimension into the memref is not
  // supported for scalable dimensions.
  if (xferOp.getVectorType().getScalableDims().front())
    return failure();
  if (isTensorOp(xferOp) && !options.lowerTensors)
    return failure();
  // Transfer ops that modify the element type are not supported atm.
  if (xferOp.getVectorType().getElementType() !=
      xferOp.getShapedType().getElementType())
    return failure();
  return success();
}

/// Prepare a TransferReadOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
/// 3. Store the result of the TransferReadOp into the temporary buffer.
/// 4. Load the result from the temporary buffer and replace all uses of the
///    original TransferReadOp with this load.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %cst
///     : memref<?x?x?xf32>, vector<5x4xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// %1 = vector.transfer_read %A[%a, %b, %c], %cst
///     { __vector_to_scf_lowering__ } : memref<?x?x?xf32>, vector<5x4xf32>
/// memref.store %1, %0[] : memref<vector<5x4xf32>>
/// %vec = memref.load %0[] : memref<vector<5x4xf32>>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    auto buffers = allocBuffers(rewriter, xferOp);
    auto *newXfer = rewriter.clone(*xferOp.getOperation());
    newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
    if (xferOp.getMask()) {
      dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
          buffers.maskBuffer);
    }

    Location loc = xferOp.getLoc();
    rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
                                     buffers.dataBuffer);
    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);

    return success();
  }
};

/// Prepare a TransferWriteOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Store the vector into the buffer.
/// 3. Load the vector from the buffer again.
/// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
///    marking it eligible for progressive lowering via TransferOpConversion.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// memref.store %vec, %0[] : memref<vector<5x4xf32>>
/// %1 = memref.load %0[] : memref<vector<5x4xf32>>
/// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    Location loc = xferOp.getLoc();
    auto buffers = allocBuffers(rewriter, xferOp);
    rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
                                     buffers.dataBuffer);
    auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
    rewriter.modifyOpInPlace(xferOp, [&]() {
      xferOp.getVectorMutable().assign(loadedVec);
      xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
    });

    if (xferOp.getMask()) {
      rewriter.modifyOpInPlace(xferOp, [&]() {
        xferOp.getMaskMutable().assign(buffers.maskBuffer);
      });
    }

    return success();
  }
};

/// Decompose an n-D PrintOp into a loop of elementary/scalar prints. This
/// allows printing both 1D scalable vectors and n-D fixed-size vectors.
///
/// E.g.:
/// ```
/// vector.print %v : vector<[4]xi32>
/// ```
/// is rewritten to:
/// ```
/// %c0 = arith.constant 0 : index
/// %c4 = arith.constant 4 : index
/// %c1 = arith.constant 1 : index
/// %vscale = vector.vscale
/// %length = arith.muli %vscale, %c4 : index
/// %lastIndex = arith.subi %length, %c1 : index
/// vector.print punctuation <open>
/// scf.for %i = %c0 to %length step %c1 {
///   %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
///   vector.print %el : i32 punctuation <no_punctuation>
///   %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
///   scf.if %notLastIndex {
///     vector.print punctuation <comma>
///   }
/// }
/// vector.print punctuation <close>
/// vector.print
/// ```
struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
  using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern;
  LogicalResult matchAndRewrite(vector::PrintOp printOp,
                                PatternRewriter &rewriter) const override {
    if (!printOp.getSource())
      return failure();

    VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
    if (!vectorType)
      return failure();

    // Currently >= 2D scalable vectors are not supported.
    // These can't be lowered to LLVM (as LLVM does not support scalable
    // vectors of scalable vectors) and, due to limitations of the current ops,
    // can't be indexed with SSA values or flattened. This may change after
    // https://reviews.llvm.org/D155034, though there still needs to be a path
    // for lowering to LLVM.
    if (vectorType.getRank() > 1 && vectorType.isScalable())
      return failure();

    auto loc = printOp.getLoc();
    auto value = printOp.getSource();

    if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
      // Oddly sized integers are (somewhat) buggy on a lot of backends, so to
      // avoid issues extend them to a more standard size.
      // https://github.com/llvm/llvm-project/issues/30613
      auto width = intTy.getWidth();
      auto legalWidth = llvm::NextPowerOf2(std::max(8u, width) - 1);
      auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth,
                                         intTy.getSignedness());
      // arith can only take signless integers, so we must cast back and forth.
      auto signlessSourceVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
      auto signlessTargetVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
      auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
      value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType,
                                                 value);
      if (value.getType() != signlessTargetVectorType) {
        if (width == 1 || intTy.isUnsigned())
          value = rewriter.create<arith::ExtUIOp>(loc, signlessTargetVectorType,
                                                  value);
        else
          value = rewriter.create<arith::ExtSIOp>(loc, signlessTargetVectorType,
                                                  value);
      }
      value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value);
      vectorType = targetVectorType;
    }

    auto scalableDimensions = vectorType.getScalableDims();
    auto shape = vectorType.getShape();
    constexpr int64_t singletonShape[] = {1};
    if (vectorType.getRank() == 0)
      shape = singletonShape;

    if (vectorType.getRank() != 1) {
      // Flatten n-D vectors to 1D. This is done to allow indexing with a
      // non-constant value (which can currently only be done via
      // vector.extractelement for 1D vectors).
      auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
                                        std::multiplies<int64_t>());
      auto flatVectorType =
          VectorType::get({flatLength}, vectorType.getElementType());
      value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value);
    }

    vector::PrintOp firstClose;
    SmallVector<Value, 8> loopIndices;
    for (unsigned d = 0; d < shape.size(); d++) {
      // Setup loop bounds and step.
      Value lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
      Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, shape[d]);
      Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
      if (!scalableDimensions.empty() && scalableDimensions[d]) {
        auto vscale = rewriter.create<vector::VectorScaleOp>(
            loc, rewriter.getIndexType());
        upperBound = rewriter.create<arith::MulIOp>(loc, upperBound, vscale);
      }
      auto lastIndex = rewriter.create<arith::SubIOp>(loc, upperBound, step);

      // Create a loop to print the elements surrounded by parentheses.
      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
      auto loop =
          rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
      auto printClose = rewriter.create<vector::PrintOp>(
          loc, vector::PrintPunctuation::Close);
      if (!firstClose)
        firstClose = printClose;

      auto loopIdx = loop.getInductionVar();
      loopIndices.push_back(loopIdx);

      // Print a comma after all but the last element.
      rewriter.setInsertionPointToStart(loop.getBody());
      auto notLastIndex = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
      rewriter.create<scf::IfOp>(loc, notLastIndex,
                                 [&](OpBuilder &builder, Location loc) {
                                   builder.create<vector::PrintOp>(
                                       loc, vector::PrintPunctuation::Comma);
                                   builder.create<scf::YieldOp>(loc);
                                 });

      rewriter.setInsertionPointToStart(loop.getBody());
    }

    // Compute the flattened index.
    // Note: For vectors of rank > 1, this assumes that all dimensions are
    // non-scalable.
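    // E.g., for shape [2, 3, 4]: flatIndex = i0 * 12 + i1 * 4 + i2 * 1
    // (row-major strides, accumulated from the innermost dimension outward).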
    Value flatIndex;
    auto currentStride = 1;
    for (int d = shape.size() - 1; d >= 0; d--) {
      auto stride = rewriter.create<arith::ConstantIndexOp>(loc, currentStride);
      auto index = rewriter.create<arith::MulIOp>(loc, stride, loopIndices[d]);
      if (flatIndex)
        flatIndex = rewriter.create<arith::AddIOp>(loc, flatIndex, index);
      else
        flatIndex = index;
      currentStride *= shape[d];
    }

    // Print the scalar elements in the innermost loop.
822f36e909dSBenjamin Maxwell     auto element =
823f36e909dSBenjamin Maxwell         rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
824f36e909dSBenjamin Maxwell     rewriter.create<vector::PrintOp>(loc, element,
825f36e909dSBenjamin Maxwell                                      vector::PrintPunctuation::NoPunctuation);
826f36e909dSBenjamin Maxwell 
827f36e909dSBenjamin Maxwell     rewriter.setInsertionPointAfter(firstClose);
828f36e909dSBenjamin Maxwell     rewriter.create<vector::PrintOp>(loc, printOp.getPunctuation());
829f36e909dSBenjamin Maxwell     rewriter.eraseOp(printOp);
830f36e909dSBenjamin Maxwell     return success();
831f36e909dSBenjamin Maxwell   }
832f36e909dSBenjamin Maxwell 
833f36e909dSBenjamin Maxwell   static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) {
834f36e909dSBenjamin Maxwell     return IntegerType::get(intTy.getContext(), intTy.getWidth(),
835f36e909dSBenjamin Maxwell                             IntegerType::Signless);
836f36e909dSBenjamin Maxwell   };
837f36e909dSBenjamin Maxwell };
838f36e909dSBenjamin Maxwell 
8390f241638SMatthias Springer /// Progressive lowering of vector transfer ops: Unpack one dimension.
8400f241638SMatthias Springer ///
8410f241638SMatthias Springer /// 1. Unpack one dimension from the current buffer type and cast the buffer
8420f241638SMatthias Springer ///    to that new type. E.g.:
8430f241638SMatthias Springer ///    ```
8440f241638SMatthias Springer ///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
8450f241638SMatthias Springer ///    vector.transfer_write %vec ...
8460f241638SMatthias Springer ///    ```
8470f241638SMatthias Springer ///    The following cast is generated:
8480f241638SMatthias Springer ///    ```
8490f241638SMatthias Springer ///    %casted = vector.type_cast %0
8500f241638SMatthias Springer ///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
8510f241638SMatthias Springer ///    ```
8520f241638SMatthias Springer /// 2. Generate a for loop and rewrite the transfer op according to the
8530f241638SMatthias Springer ///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
8540f241638SMatthias Springer ///    out-of-bounds, generate an if-check and handle both cases separately.
8550f241638SMatthias Springer /// 3. Clean up according to the corresponding Strategy<OpTy>.
856558e7401SMatthias Springer ///
857558e7401SMatthias Springer /// Note: If the transfer op is a TransferWriteOp and operates on a tensor
858558e7401SMatthias Springer /// source (as opposed to a memref source), then each iteration of the generated
859558e7401SMatthias Springer /// scf.for loop yields the new tensor value. E.g.:
860558e7401SMatthias Springer /// ```
861558e7401SMatthias Springer /// %result = scf.for i = 0 to 5 {
862558e7401SMatthias Springer ///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
863558e7401SMatthias Springer ///   %1 = vector.transfer_write %0, %source[...]
864558e7401SMatthias Springer ///       : vector<4x3xf32>, tensor<5x4x3xf32>
865558e7401SMatthias Springer ///   scf.yield %1 : tensor<5x4x3xf32>
866558e7401SMatthias Springer /// }
867558e7401SMatthias Springer /// ```
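///
/// For a memref source there is no loop-carried value; a simplified sketch of
/// the corresponding loop is:
/// ```
/// scf.for i = 0 to 5 {
///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
///   vector.transfer_write %0, %source[...]
///       : vector<4x3xf32>, memref<5x4x3xf32>
/// }
/// ```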
8680f241638SMatthias Springer template <typename OpTy>
8692ca887deSMatthias Springer struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
8702ca887deSMatthias Springer   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
8710f241638SMatthias Springer 
872700b64dcSMatthias Springer   void initialize() {
873700b64dcSMatthias Springer     // This pattern recursively unpacks one dimension at a time. The recursion
874700b64dcSMatthias Springer     // is bounded as the rank is strictly decreasing.
875700b64dcSMatthias Springer     this->setHasBoundedRewriteRecursion();
876700b64dcSMatthias Springer   }
877700b64dcSMatthias Springer 
8786b21948fSRik Huijzer   static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
8796b21948fSRik Huijzer                                        SmallVectorImpl<Value> &loadIndices,
8806b21948fSRik Huijzer                                        Value iv) {
8816b21948fSRik Huijzer     assert(xferOp.getMask() && "Expected transfer op to have mask");
8826b21948fSRik Huijzer 
8836b21948fSRik Huijzer     // Add load indices from the previous iteration.
8846b21948fSRik Huijzer     // The mask buffer depends on the permutation map, which makes determining
8856b21948fSRik Huijzer     // the indices quite complex, which is why we need to "look back" to the
8866b21948fSRik Huijzer     // previous iteration to find the right indices.
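    // E.g. (illustrative): if a previous iteration loaded the mask from
    // %maskBuffer[%i, %j], those same indices are reused here and the current
    // `iv` is appended below (unless the unpacked dim is a broadcast).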
8876b21948fSRik Huijzer     Value maskBuffer = getMaskBuffer(xferOp);
8886b21948fSRik Huijzer     for (Operation *user : maskBuffer.getUsers()) {
8896b21948fSRik Huijzer       // If there is no previous load op, then the indices are empty.
8906b21948fSRik Huijzer       if (auto loadOp = dyn_cast<memref::LoadOp>(user)) {
8916b21948fSRik Huijzer         Operation::operand_range prevIndices = loadOp.getIndices();
8926b21948fSRik Huijzer         loadIndices.append(prevIndices.begin(), prevIndices.end());
8936b21948fSRik Huijzer         break;
8946b21948fSRik Huijzer       }
8956b21948fSRik Huijzer     }
8966b21948fSRik Huijzer 
8976b21948fSRik Huijzer     // In case of broadcast: Use same indices to load from memref
8986b21948fSRik Huijzer     // as before.
8996b21948fSRik Huijzer     if (!xferOp.isBroadcastDim(0))
9006b21948fSRik Huijzer       loadIndices.push_back(iv);
9016b21948fSRik Huijzer   }
9026b21948fSRik Huijzer 
9030f241638SMatthias Springer   LogicalResult matchAndRewrite(OpTy xferOp,
9040f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
9050f241638SMatthias Springer     if (!xferOp->hasAttr(kPassLabel))
9060f241638SMatthias Springer       return failure();
9070f241638SMatthias Springer 
9080f241638SMatthias Springer     // Find and cast data buffer. How the buffer can be found depends on OpTy.
9096825bfe2SNicolas Vasilache     ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
9106b21948fSRik Huijzer     Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
9115550c821STres Popp     auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
9126b21948fSRik Huijzer     FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
9132a82dfd7SBenjamin Maxwell     if (failed(castedDataType))
9142a82dfd7SBenjamin Maxwell       return failure();
9152a82dfd7SBenjamin Maxwell 
9166825bfe2SNicolas Vasilache     auto castedDataBuffer =
9172a82dfd7SBenjamin Maxwell         locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer);
9180f241638SMatthias Springer 
9190f241638SMatthias Springer     // If the xferOp has a mask: Find and cast mask buffer.
9200f241638SMatthias Springer     Value castedMaskBuffer;
9217c38fd60SJacques Pienaar     if (xferOp.getMask()) {
9226b21948fSRik Huijzer       Value maskBuffer = getMaskBuffer(xferOp);
9230f241638SMatthias Springer       if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
9240f241638SMatthias Springer         // Do not unpack a dimension of the mask if:
9250f241638SMatthias Springer         // * To-be-unpacked transfer op dimension is a broadcast.
9260f241638SMatthias Springer         // * Mask is 1D, i.e., the mask cannot be further unpacked.
9270f241638SMatthias Springer         //   (That means that all remaining dimensions of the transfer op must
9280f241638SMatthias Springer         //   be broadcasted.)
9290f241638SMatthias Springer         castedMaskBuffer = maskBuffer;
9300f241638SMatthias Springer       } else {
9312a82dfd7SBenjamin Maxwell         // It's safe to assume the mask buffer can be unpacked if the data
9322a82dfd7SBenjamin Maxwell         // buffer was unpacked.
9336b21948fSRik Huijzer         auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
9346b21948fSRik Huijzer         MemRefType castedMaskType = *unpackOneDim(maskBufferType);
9356825bfe2SNicolas Vasilache         castedMaskBuffer =
9366825bfe2SNicolas Vasilache             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
9370f241638SMatthias Springer       }
9380f241638SMatthias Springer     }
9390f241638SMatthias Springer 
9400f241638SMatthias Springer     // Loop bounds and step.
941a54f4eaeSMogball     auto lb = locB.create<arith::ConstantIndexOp>(0);
942a54f4eaeSMogball     auto ub = locB.create<arith::ConstantIndexOp>(
9432a82dfd7SBenjamin Maxwell         castedDataType->getDimSize(castedDataType->getRank() - 1));
944a54f4eaeSMogball     auto step = locB.create<arith::ConstantIndexOp>(1);
945558e7401SMatthias Springer     // TransferWriteOps that operate on tensors return the modified tensor and
946558e7401SMatthias Springer     // require a loop state.
947558e7401SMatthias Springer     auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
9480f241638SMatthias Springer 
9490f241638SMatthias Springer     // Generate for loop.
950558e7401SMatthias Springer     auto result = locB.create<scf::ForOp>(
951558e7401SMatthias Springer         lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
952558e7401SMatthias Springer         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
953558e7401SMatthias Springer           Type stateType = loopState.empty() ? Type() : loopState[0].getType();
954558e7401SMatthias Springer 
955558e7401SMatthias Springer           auto result = generateInBoundsCheck(
9566825bfe2SNicolas Vasilache               b, xferOp, iv, unpackedDim(xferOp),
957558e7401SMatthias Springer               stateType ? TypeRange(stateType) : TypeRange(),
9580f241638SMatthias Springer               /*inBoundsCase=*/
9596825bfe2SNicolas Vasilache               [&](OpBuilder &b, Location loc) {
9600f241638SMatthias Springer                 // Create new transfer op.
9612ca887deSMatthias Springer                 OpTy newXfer = Strategy<OpTy>::rewriteOp(
962558e7401SMatthias Springer                     b, this->options, xferOp, castedDataBuffer, iv, loopState);
9630f241638SMatthias Springer 
9640f241638SMatthias Springer                 // If old transfer op has a mask: Set mask on new transfer op.
9650f241638SMatthias Springer                 // Special case: If the mask of the old transfer op is 1D and
9666b21948fSRik Huijzer                 // the unpacked dim is not a broadcast, no mask is needed on
9676b21948fSRik Huijzer                 // the new transfer op.
9687c38fd60SJacques Pienaar                 if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
9690f241638SMatthias Springer                                          xferOp.getMaskType().getRank() > 1)) {
9700f241638SMatthias Springer                   OpBuilder::InsertionGuard guard(b);
9710f241638SMatthias Springer                   b.setInsertionPoint(newXfer); // Insert load before newXfer.
9720f241638SMatthias Springer 
9730f241638SMatthias Springer                   SmallVector<Value, 8> loadIndices;
9746b21948fSRik Huijzer                   getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
9756b21948fSRik Huijzer                                            loadIndices, iv);
9766825bfe2SNicolas Vasilache                   auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
9776825bfe2SNicolas Vasilache                                                        loadIndices);
9785fcf907bSMatthias Springer                   rewriter.modifyOpInPlace(newXfer, [&]() {
9797c38fd60SJacques Pienaar                     newXfer.getMaskMutable().assign(mask);
9807c38fd60SJacques Pienaar                   });
9810f241638SMatthias Springer                 }
982558e7401SMatthias Springer 
983558e7401SMatthias Springer                 return loopState.empty() ? Value() : newXfer->getResult(0);
9840f241638SMatthias Springer               },
9850f241638SMatthias Springer               /*outOfBoundsCase=*/
9860f241638SMatthias Springer               [&](OpBuilder &b, Location /*loc*/) {
987558e7401SMatthias Springer                 return Strategy<OpTy>::handleOutOfBoundsDim(
988558e7401SMatthias Springer                     b, xferOp, castedDataBuffer, iv, loopState);
9890f241638SMatthias Springer               });
9900f241638SMatthias Springer 
991558e7401SMatthias Springer           maybeYieldValue(b, loc, !loopState.empty(), result);
992558e7401SMatthias Springer         });
993558e7401SMatthias Springer 
994558e7401SMatthias Springer     Strategy<OpTy>::cleanup(rewriter, xferOp, result);
9950f241638SMatthias Springer     return success();
9960f241638SMatthias Springer   }
9970f241638SMatthias Springer };
9980f241638SMatthias Springer 
99927a713f5SBenjamin Maxwell /// Retrieves the dimension sizes of a mask. Currently supports CreateMaskOp
100027a713f5SBenjamin Maxwell /// and ConstantMaskOp.
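///
/// Example (illustrative sketch): the dims of
/// `%m = vector.create_mask %a, %b : vector<[4]x4xi1>` resolve to {%a, %b};
/// for a ConstantMaskOp, a size N on a scalable dimension resolves to
/// `vscale * N`, materialized via `createVscaleMultiple`.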
100127a713f5SBenjamin Maxwell template <typename VscaleConstantBuilder>
100227a713f5SBenjamin Maxwell static FailureOr<SmallVector<OpFoldResult>>
100327a713f5SBenjamin Maxwell getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) {
100427a713f5SBenjamin Maxwell   if (!mask)
100527a713f5SBenjamin Maxwell     return SmallVector<OpFoldResult>{};
100627a713f5SBenjamin Maxwell   if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) {
100727a713f5SBenjamin Maxwell     return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) {
100827a713f5SBenjamin Maxwell       return OpFoldResult(dimSize);
100927a713f5SBenjamin Maxwell     });
101027a713f5SBenjamin Maxwell   }
101127a713f5SBenjamin Maxwell   if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) {
101227a713f5SBenjamin Maxwell     int dimIdx = 0;
101327a713f5SBenjamin Maxwell     VectorType maskType = constantMask.getVectorType();
101427a713f5SBenjamin Maxwell     auto indexType = IndexType::get(mask.getContext());
101527a713f5SBenjamin Maxwell     return llvm::map_to_vector(
101627a713f5SBenjamin Maxwell         constantMask.getMaskDimSizes(), [&](int64_t dimSize) {
101727a713f5SBenjamin Maxwell           // A scalable dim in a constant_mask means vscale x dimSize.
101827a713f5SBenjamin Maxwell           if (maskType.getScalableDims()[dimIdx++])
101927a713f5SBenjamin Maxwell             return OpFoldResult(createVscaleMultiple(dimSize));
102027a713f5SBenjamin Maxwell           return OpFoldResult(IntegerAttr::get(indexType, dimSize));
102127a713f5SBenjamin Maxwell         });
102227a713f5SBenjamin Maxwell   }
102327a713f5SBenjamin Maxwell   return failure();
102427a713f5SBenjamin Maxwell }
102527a713f5SBenjamin Maxwell 
102627a713f5SBenjamin Maxwell /// Scalable vector lowering of transfer_write(transpose). This lowering only
102727a713f5SBenjamin Maxwell /// supports rank 2 (scalable) vectors, but can be used in conjunction with
102827a713f5SBenjamin Maxwell /// `UnrollTransferWriteConversion` to support n-D cases. The unroll conversion
102927a713f5SBenjamin Maxwell /// unrolls until the first scalable dimension.
103027a713f5SBenjamin Maxwell ///
103127a713f5SBenjamin Maxwell /// Example:
103227a713f5SBenjamin Maxwell ///
103327a713f5SBenjamin Maxwell /// BEFORE:
103427a713f5SBenjamin Maxwell /// ```mlir
103527a713f5SBenjamin Maxwell /// %transpose = vector.transpose %vec, [1, 0]
103627a713f5SBenjamin Maxwell ///    : vector<4x[4]xf32> to vector<[4]x4xf32>
103727a713f5SBenjamin Maxwell /// vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]}
103827a713f5SBenjamin Maxwell ///    : vector<[4]x4xf32>,  memref<?x?xf32>
103927a713f5SBenjamin Maxwell /// ```
104027a713f5SBenjamin Maxwell ///
104127a713f5SBenjamin Maxwell /// AFTER:
104227a713f5SBenjamin Maxwell /// ```mlir
104327a713f5SBenjamin Maxwell /// %c1 = arith.constant 1 : index
104427a713f5SBenjamin Maxwell /// %c4 = arith.constant 4 : index
104527a713f5SBenjamin Maxwell /// %c0 = arith.constant 0 : index
104627a713f5SBenjamin Maxwell /// %0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32>
104727a713f5SBenjamin Maxwell /// %1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32>
104827a713f5SBenjamin Maxwell /// %2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32>
104927a713f5SBenjamin Maxwell /// %3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32>
105027a713f5SBenjamin Maxwell /// %vscale = vector.vscale
105127a713f5SBenjamin Maxwell /// %c4_vscale = arith.muli %vscale, %c4 : index
105227a713f5SBenjamin Maxwell /// scf.for %idx = %c0 to %c4_vscale step %c1 {
105327a713f5SBenjamin Maxwell ///   %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32>
105427a713f5SBenjamin Maxwell ///   %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32>
105527a713f5SBenjamin Maxwell ///   %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32>
105627a713f5SBenjamin Maxwell ///   %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32>
105727a713f5SBenjamin Maxwell ///   %slice_i = affine.apply #map(%idx)[%i]
105827a713f5SBenjamin Maxwell ///   %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32>
105927a713f5SBenjamin Maxwell ///   vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]}
106027a713f5SBenjamin Maxwell ///     : vector<4xf32>, memref<?x?xf32>
106127a713f5SBenjamin Maxwell /// }
106227a713f5SBenjamin Maxwell /// ```
106327a713f5SBenjamin Maxwell struct ScalableTransposeTransferWriteConversion
106427a713f5SBenjamin Maxwell     : VectorToSCFPattern<vector::TransferWriteOp> {
106527a713f5SBenjamin Maxwell   using VectorToSCFPattern::VectorToSCFPattern;
106627a713f5SBenjamin Maxwell 
106727a713f5SBenjamin Maxwell   LogicalResult matchAndRewrite(TransferWriteOp writeOp,
106827a713f5SBenjamin Maxwell                                 PatternRewriter &rewriter) const override {
106927a713f5SBenjamin Maxwell     if (failed(checkLowerTensors(writeOp, rewriter)))
107027a713f5SBenjamin Maxwell       return failure();
107127a713f5SBenjamin Maxwell 
107227a713f5SBenjamin Maxwell     VectorType vectorType = writeOp.getVectorType();
107327a713f5SBenjamin Maxwell 
107427a713f5SBenjamin Maxwell     // Note: Comparing the scalable dims against an ArrayRef of length two
107527a713f5SBenjamin Maxwell     // implicitly checks that the rank is also two.
107627a713f5SBenjamin Maxwell     ArrayRef<bool> scalableFlags = vectorType.getScalableDims();
107727a713f5SBenjamin Maxwell     if (scalableFlags != ArrayRef<bool>{true, false}) {
107827a713f5SBenjamin Maxwell       return rewriter.notifyMatchFailure(
107927a713f5SBenjamin Maxwell           writeOp, "expected vector of the form vector<[N]xMxty>");
108027a713f5SBenjamin Maxwell     }
108127a713f5SBenjamin Maxwell 
108227a713f5SBenjamin Maxwell     auto permutationMap = writeOp.getPermutationMap();
108327a713f5SBenjamin Maxwell     if (!permutationMap.isIdentity()) {
108427a713f5SBenjamin Maxwell       return rewriter.notifyMatchFailure(
108527a713f5SBenjamin Maxwell           writeOp, "non-identity permutations are unsupported (lower first)");
108627a713f5SBenjamin Maxwell     }
108727a713f5SBenjamin Maxwell 
108827a713f5SBenjamin Maxwell     // Note: This pattern is only lowering the leading dimension (to a loop),
108927a713f5SBenjamin Maxwell     // so we only check if the leading dimension is in bounds. The in-bounds
109027a713f5SBenjamin Maxwell     // attribute for the trailing dimension will be propagated.
109127a713f5SBenjamin Maxwell     if (!writeOp.isDimInBounds(0)) {
109227a713f5SBenjamin Maxwell       return rewriter.notifyMatchFailure(
109327a713f5SBenjamin Maxwell           writeOp, "out-of-bounds dims are unsupported (use masking)");
109427a713f5SBenjamin Maxwell     }
109527a713f5SBenjamin Maxwell 
109627a713f5SBenjamin Maxwell     Value vector = writeOp.getVector();
109727a713f5SBenjamin Maxwell     auto transposeOp = vector.getDefiningOp<vector::TransposeOp>();
109827a713f5SBenjamin Maxwell     if (!transposeOp ||
109927a713f5SBenjamin Maxwell         transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) {
110027a713f5SBenjamin Maxwell       return rewriter.notifyMatchFailure(writeOp, "source not transpose");
110127a713f5SBenjamin Maxwell     }
110227a713f5SBenjamin Maxwell 
110327a713f5SBenjamin Maxwell     auto loc = writeOp.getLoc();
110427a713f5SBenjamin Maxwell     auto createVscaleMultiple =
110527a713f5SBenjamin Maxwell         vector::makeVscaleConstantBuilder(rewriter, loc);
110627a713f5SBenjamin Maxwell 
110727a713f5SBenjamin Maxwell     auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple);
110827a713f5SBenjamin Maxwell     if (failed(maskDims)) {
110927a713f5SBenjamin Maxwell       return rewriter.notifyMatchFailure(writeOp,
111027a713f5SBenjamin Maxwell                                          "failed to resolve mask dims");
111127a713f5SBenjamin Maxwell     }
111227a713f5SBenjamin Maxwell 
111327a713f5SBenjamin Maxwell     int64_t fixedDimSize = vectorType.getDimSize(1);
111427a713f5SBenjamin Maxwell     auto fixedDimOffsets = llvm::seq(fixedDimSize);
111527a713f5SBenjamin Maxwell 
111627a713f5SBenjamin Maxwell     // Extract all slices from the source of the transpose.
111727a713f5SBenjamin Maxwell     auto transposeSource = transposeOp.getVector();
111827a713f5SBenjamin Maxwell     SmallVector<Value> transposeSourceSlices =
111927a713f5SBenjamin Maxwell         llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
112027a713f5SBenjamin Maxwell           return rewriter.create<vector::ExtractOp>(loc, transposeSource, idx);
112127a713f5SBenjamin Maxwell         });
112227a713f5SBenjamin Maxwell 
112327a713f5SBenjamin Maxwell     // Loop bounds and step.
112427a713f5SBenjamin Maxwell     auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
112527a713f5SBenjamin Maxwell     auto ub =
112627a713f5SBenjamin Maxwell         maskDims->empty()
112727a713f5SBenjamin Maxwell             ? Value(createVscaleMultiple(vectorType.getDimSize(0)))
112827a713f5SBenjamin Maxwell             : vector::getAsValues(rewriter, loc, maskDims->front()).front();
112927a713f5SBenjamin Maxwell     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
113027a713f5SBenjamin Maxwell 
113127a713f5SBenjamin Maxwell     // Generate a new mask for the slice.
113227a713f5SBenjamin Maxwell     VectorType sliceType = VectorType::Builder(vectorType).dropDim(0);
113327a713f5SBenjamin Maxwell     Value sliceMask = nullptr;
113427a713f5SBenjamin Maxwell     if (!maskDims->empty()) {
113527a713f5SBenjamin Maxwell       sliceMask = rewriter.create<vector::CreateMaskOp>(
113627a713f5SBenjamin Maxwell           loc, sliceType.clone(rewriter.getI1Type()),
113727a713f5SBenjamin Maxwell           ArrayRef<OpFoldResult>(*maskDims).drop_front());
113827a713f5SBenjamin Maxwell     }
113927a713f5SBenjamin Maxwell 
114027a713f5SBenjamin Maxwell     Value initDest = isTensorOp(writeOp) ? writeOp.getSource() : Value{};
114127a713f5SBenjamin Maxwell     ValueRange initLoopArgs = initDest ? initDest : ValueRange{};
114227a713f5SBenjamin Maxwell     auto result = rewriter.create<scf::ForOp>(
114327a713f5SBenjamin Maxwell         loc, lb, ub, step, initLoopArgs,
114427a713f5SBenjamin Maxwell         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) {
114527a713f5SBenjamin Maxwell           // Indices for the new transfer op.
114627a713f5SBenjamin Maxwell           SmallVector<Value, 8> xferIndices;
114727a713f5SBenjamin Maxwell           getXferIndices(b, writeOp, iv, xferIndices);
114827a713f5SBenjamin Maxwell 
114927a713f5SBenjamin Maxwell           // Extract a transposed slice from the source vector.
115027a713f5SBenjamin Maxwell           SmallVector<Value> transposeElements =
115127a713f5SBenjamin Maxwell               llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value {
115227a713f5SBenjamin Maxwell                 return b.create<vector::ExtractOp>(
115327a713f5SBenjamin Maxwell                     loc, transposeSourceSlices[idx], iv);
115427a713f5SBenjamin Maxwell               });
115527a713f5SBenjamin Maxwell           auto sliceVec = b.create<vector::FromElementsOp>(loc, sliceType,
115627a713f5SBenjamin Maxwell                                                            transposeElements);
115727a713f5SBenjamin Maxwell 
115827a713f5SBenjamin Maxwell           // Create the transfer_write for the slice.
115927a713f5SBenjamin Maxwell           Value dest =
116027a713f5SBenjamin Maxwell               loopIterArgs.empty() ? writeOp.getSource() : loopIterArgs.front();
116127a713f5SBenjamin Maxwell           auto newWriteOp = b.create<vector::TransferWriteOp>(
116227a713f5SBenjamin Maxwell               loc, sliceVec, dest, xferIndices,
116327a713f5SBenjamin Maxwell               ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
116427a713f5SBenjamin Maxwell           if (sliceMask)
116527a713f5SBenjamin Maxwell             newWriteOp.getMaskMutable().assign(sliceMask);
116627a713f5SBenjamin Maxwell 
116727a713f5SBenjamin Maxwell           // Yield from the loop.
116827a713f5SBenjamin Maxwell           b.create<scf::YieldOp>(loc, loopIterArgs.empty()
116927a713f5SBenjamin Maxwell                                           ? ValueRange{}
117027a713f5SBenjamin Maxwell                                           : newWriteOp.getResult());
117127a713f5SBenjamin Maxwell         });
117227a713f5SBenjamin Maxwell 
117327a713f5SBenjamin Maxwell     if (isTensorOp(writeOp))
117427a713f5SBenjamin Maxwell       rewriter.replaceOp(writeOp, result);
117527a713f5SBenjamin Maxwell     else
117627a713f5SBenjamin Maxwell       rewriter.eraseOp(writeOp);
117727a713f5SBenjamin Maxwell 
117827a713f5SBenjamin Maxwell     return success();
117927a713f5SBenjamin Maxwell   }
118027a713f5SBenjamin Maxwell };
118127a713f5SBenjamin Maxwell 
1182a088bed4SMatthias Springer } // namespace lowering_n_d
1183a088bed4SMatthias Springer 
1184a088bed4SMatthias Springer namespace lowering_n_d_unrolled {
1185a088bed4SMatthias Springer 
11860f241638SMatthias Springer /// If the original transfer op has a mask, compute the mask of the new transfer
11870f241638SMatthias Springer /// op (for the current iteration `i`) and assign it.
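///
/// For example (illustrative sketch): at iteration `i`, a rank-2 mask of type
/// `vector<5x4xi1>` is reduced to the rank-1 slice
/// `vector.extract %mask[i] : vector<4xi1> from vector<5x4xi1>`.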
11880f241638SMatthias Springer template <typename OpTy>
11896825bfe2SNicolas Vasilache static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
11900f241638SMatthias Springer                             int64_t i) {
11917c38fd60SJacques Pienaar   if (!xferOp.getMask())
11920f241638SMatthias Springer     return;
11930f241638SMatthias Springer 
11940f241638SMatthias Springer   if (xferOp.isBroadcastDim(0)) {
11950f241638SMatthias Springer     // To-be-unpacked dimension is a broadcast, which does not have a
11960f241638SMatthias Springer     // corresponding mask dimension. Mask attribute remains unchanged.
11977c38fd60SJacques Pienaar     newXferOp.getMaskMutable().assign(xferOp.getMask());
11980f241638SMatthias Springer     return;
11990f241638SMatthias Springer   }
12000f241638SMatthias Springer 
12010f241638SMatthias Springer   if (xferOp.getMaskType().getRank() > 1) {
12020f241638SMatthias Springer     // Unpack one dimension of the mask.
12036825bfe2SNicolas Vasilache     OpBuilder::InsertionGuard guard(b);
12046825bfe2SNicolas Vasilache     b.setInsertionPoint(newXferOp); // Insert load before newXfer.
12050f241638SMatthias Springer 
12060f241638SMatthias Springer     llvm::SmallVector<int64_t, 1> indices({i});
12076825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
12087c38fd60SJacques Pienaar     auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
12097c38fd60SJacques Pienaar     newXferOp.getMaskMutable().assign(newMask);
12100f241638SMatthias Springer   }
12110f241638SMatthias Springer 
12120f241638SMatthias Springer   // If we end up here: The mask of the old transfer op is 1D and the unpacked
12130f241638SMatthias Springer   // dim is not a broadcast, so no mask is needed on the new transfer op.
12140f241638SMatthias Springer   // `generateInBoundsCheck` will have evaluated the mask already.
12150f241638SMatthias Springer }
12160f241638SMatthias Springer 
12170f241638SMatthias Springer /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
12180f241638SMatthias Springer /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
12190f241638SMatthias Springer /// memref buffer is allocated and the SCF loop is fully unrolled.
12200f241638SMatthias Springer ///
12210f241638SMatthias Springer /// E.g.:
12220f241638SMatthias Springer /// ```
12240f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a, %b, %c], %padding
12250f241638SMatthias Springer ///     : memref<?x?x?xf32>, vector<5x4xf32>
12260f241638SMatthias Springer /// ```
12270f241638SMatthias Springer /// is rewritten to IR such as (simplified):
12280f241638SMatthias Springer /// ```
12290f241638SMatthias Springer /// %v_init = splat %padding : vector<5x4xf32>
12300f241638SMatthias Springer /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
12310f241638SMatthias Springer ///     : memref<?x?x?xf32>, vector<4xf32>
12320f241638SMatthias Springer /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
12330f241638SMatthias Springer /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
12340f241638SMatthias Springer ///     : memref<?x?x?xf32>, vector<4xf32>
12350f241638SMatthias Springer /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
12360f241638SMatthias Springer /// ...
12370f241638SMatthias Springer /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
12380f241638SMatthias Springer ///     : memref<?x?x?xf32>, vector<4xf32>
12390f241638SMatthias Springer /// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
12400f241638SMatthias Springer /// ```
12410f241638SMatthias Springer ///
12420f241638SMatthias Springer /// Note: As an optimization, if the result of the original TransferReadOp
12430f241638SMatthias Springer /// was directly inserted into another vector, no new %v_init vector is created.
12440f241638SMatthias Springer /// Instead, the new TransferReadOp results are inserted into that vector.
12452ca887deSMatthias Springer struct UnrollTransferReadConversion
12462ca887deSMatthias Springer     : public VectorToSCFPattern<TransferReadOp> {
12472ca887deSMatthias Springer   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
12480f241638SMatthias Springer 
1249700b64dcSMatthias Springer   void initialize() {
1250700b64dcSMatthias Springer     // This pattern recursively unpacks one dimension at a time. The recursion
1251700b64dcSMatthias Springer     // is bounded as the rank is strictly decreasing.
1252700b64dcSMatthias Springer     setHasBoundedRewriteRecursion();
1253700b64dcSMatthias Springer   }
1254700b64dcSMatthias Springer 
1255aa2dc792SMatthias Springer   /// Get or build the vector into which the newly created TransferReadOp
1256aa2dc792SMatthias Springer   /// results are inserted.
1257aa2dc792SMatthias Springer   Value buildResultVector(PatternRewriter &rewriter,
1258aa2dc792SMatthias Springer                           TransferReadOp xferOp) const {
12590f241638SMatthias Springer     if (auto insertOp = getInsertOp(xferOp))
12607c38fd60SJacques Pienaar       return insertOp.getDest();
12616825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
12626a8ba318SRiver Riddle     return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
12637c38fd60SJacques Pienaar                                             xferOp.getPadding());
12640f241638SMatthias Springer   }
12650f241638SMatthias Springer 
12660f241638SMatthias Springer   /// If the result of the TransferReadOp has exactly one user, which is a
12670f241638SMatthias Springer   /// vector::InsertOp, return that operation.
12680f241638SMatthias Springer   vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
12690f241638SMatthias Springer     if (xferOp->hasOneUse()) {
12700f241638SMatthias Springer       Operation *xferOpUser = *xferOp->getUsers().begin();
12710f241638SMatthias Springer       if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
12720f241638SMatthias Springer         return insertOp;
12730f241638SMatthias Springer     }
12740f241638SMatthias Springer 
12750f241638SMatthias Springer     return vector::InsertOp();
12760f241638SMatthias Springer   }
12770f241638SMatthias Springer 
12780f241638SMatthias Springer   /// If the result of the TransferReadOp has exactly one user, which is a
12790f241638SMatthias Springer   /// vector::InsertOp, return that operation's indices.
12800f241638SMatthias Springer   void getInsertionIndices(TransferReadOp xferOp,
128198f6289aSDiego Caballero                            SmallVectorImpl<OpFoldResult> &indices) const {
128298f6289aSDiego Caballero     if (auto insertOp = getInsertOp(xferOp)) {
128398f6289aSDiego Caballero       auto pos = insertOp.getMixedPosition();
128498f6289aSDiego Caballero       indices.append(pos.begin(), pos.end());
128598f6289aSDiego Caballero     }
12860f241638SMatthias Springer   }
12870f241638SMatthias Springer 
12880f241638SMatthias Springer   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
12890f241638SMatthias Springer   /// accesses, and broadcasts and transposes in permutation maps.
12900f241638SMatthias Springer   LogicalResult matchAndRewrite(TransferReadOp xferOp,
12910f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
12922ca887deSMatthias Springer     if (xferOp.getVectorType().getRank() <= options.targetRank)
1293aa2dc792SMatthias Springer       return rewriter.notifyMatchFailure(
1294aa2dc792SMatthias Springer           xferOp, "vector rank is less or equal to target rank");
129527a713f5SBenjamin Maxwell     if (failed(checkLowerTensors(xferOp, rewriter)))
129627a713f5SBenjamin Maxwell       return failure();
1297f718a53dSMatthias Springer     // Transfer ops that modify the element type are not supported atm.
1298f718a53dSMatthias Springer     if (xferOp.getVectorType().getElementType() !=
1299f718a53dSMatthias Springer         xferOp.getShapedType().getElementType())
1300aa2dc792SMatthias Springer       return rewriter.notifyMatchFailure(
1301aa2dc792SMatthias Springer           xferOp, "not yet supported: element type mismatch");
13020f241638SMatthias Springer     auto xferVecType = xferOp.getVectorType();
13032a82dfd7SBenjamin Maxwell     if (xferVecType.getScalableDims()[0]) {
13042a82dfd7SBenjamin Maxwell       // Cannot unroll a scalable dimension at compile time.
1305aa2dc792SMatthias Springer       return rewriter.notifyMatchFailure(
1306aa2dc792SMatthias Springer           xferOp, "scalable dimensions cannot be unrolled");
13072a82dfd7SBenjamin Maxwell     }
13082a82dfd7SBenjamin Maxwell 
1309aa2dc792SMatthias Springer     auto insertOp = getInsertOp(xferOp);
1310aa2dc792SMatthias Springer     auto vec = buildResultVector(rewriter, xferOp);
1311aa2dc792SMatthias Springer     auto vecType = dyn_cast<VectorType>(vec.getType());
1312aa2dc792SMatthias Springer 
13132a82dfd7SBenjamin Maxwell     VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0);
13142a82dfd7SBenjamin Maxwell 
13150f241638SMatthias Springer     int64_t dimSize = xferVecType.getShape()[0];
13160f241638SMatthias Springer 
13170f241638SMatthias Springer     // Generate fully unrolled loop of transfer ops.
13186825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
13190f241638SMatthias Springer     for (int64_t i = 0; i < dimSize; ++i) {
1320a54f4eaeSMogball       Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
13210f241638SMatthias Springer 
13220f241638SMatthias Springer       vec = generateInBoundsCheck(
13236825bfe2SNicolas Vasilache           rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
13240f241638SMatthias Springer           /*inBoundsCase=*/
13250f241638SMatthias Springer           [&](OpBuilder &b, Location loc) {
13260f241638SMatthias Springer             // Indices for the new transfer op.
13270f241638SMatthias Springer             SmallVector<Value, 8> xferIndices;
13286825bfe2SNicolas Vasilache             getXferIndices(b, xferOp, iv, xferIndices);
13290f241638SMatthias Springer 
13300f241638SMatthias Springer             // Indices for the new vector.insert op.
133198f6289aSDiego Caballero             SmallVector<OpFoldResult, 8> insertionIndices;
13320f241638SMatthias Springer             getInsertionIndices(xferOp, insertionIndices);
133398f6289aSDiego Caballero             insertionIndices.push_back(rewriter.getIndexAttr(i));
13340f241638SMatthias Springer 
13357c38fd60SJacques Pienaar             auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
13366825bfe2SNicolas Vasilache             auto newXferOp = b.create<vector::TransferReadOp>(
13377c38fd60SJacques Pienaar                 loc, newXferVecType, xferOp.getSource(), xferIndices,
13386825bfe2SNicolas Vasilache                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
13397c38fd60SJacques Pienaar                 xferOp.getPadding(), Value(), inBoundsAttr);
13400f241638SMatthias Springer             maybeAssignMask(b, xferOp, newXferOp, i);
13416825bfe2SNicolas Vasilache             return b.create<vector::InsertOp>(loc, newXferOp, vec,
13426825bfe2SNicolas Vasilache                                               insertionIndices);
13430f241638SMatthias Springer           },
13440f241638SMatthias Springer           /*outOfBoundsCase=*/
13450f241638SMatthias Springer           [&](OpBuilder &b, Location loc) {
13460f241638SMatthias Springer             // Loop through original (unmodified) vector.
13470f241638SMatthias Springer             return vec;
13480f241638SMatthias Springer           });
13490f241638SMatthias Springer     }
13500f241638SMatthias Springer 
13510f241638SMatthias Springer     if (insertOp) {
13520f241638SMatthias Springer       // Rewrite single user of the old TransferReadOp, which was an InsertOp.
13530f241638SMatthias Springer       rewriter.replaceOp(insertOp, vec);
13540f241638SMatthias Springer       rewriter.eraseOp(xferOp);
13550f241638SMatthias Springer     } else {
13560f241638SMatthias Springer       rewriter.replaceOp(xferOp, vec);
13570f241638SMatthias Springer     }
13580f241638SMatthias Springer 
13590f241638SMatthias Springer     return success();
13600f241638SMatthias Springer   }
13610f241638SMatthias Springer };
13620f241638SMatthias Springer 
13630f241638SMatthias Springer /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
13640f241638SMatthias Springer /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
13650f241638SMatthias Springer /// memref buffer is allocated and the SCF loop is fully unrolled.
13660f241638SMatthias Springer ///
13670f241638SMatthias Springer /// E.g.:
13680f241638SMatthias Springer /// ```
13700f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b, %c]
13710f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?x?xf32>
13720f241638SMatthias Springer /// ```
13730f241638SMatthias Springer /// is rewritten to IR such as (simplified):
13740f241638SMatthias Springer /// ```
13759816edc9SCullen Rhodes /// %v0 = vector.extract %vec[0] : vector<4xf32> from vector<5x4xf32>
13760f241638SMatthias Springer /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
13779816edc9SCullen Rhodes /// %v1 = vector.extract %vec[1] : vector<4xf32> from vector<5x4xf32>
13780f241638SMatthias Springer /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
13790f241638SMatthias Springer /// ...
13809816edc9SCullen Rhodes /// %v4 = vector.extract %vec[4] : vector<4xf32> from vector<5x4xf32>
13810f241638SMatthias Springer /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
13820f241638SMatthias Springer /// ```
13830f241638SMatthias Springer ///
13840f241638SMatthias Springer /// Note: As an optimization, if the vector of the original TransferWriteOp
13850f241638SMatthias Springer /// was directly extracted from another vector via an ExtractOp `a`, extract
13860f241638SMatthias Springer /// the vectors for the newly generated TransferWriteOps from `a`'s input. By
13870f241638SMatthias Springer /// doing so, `a` may become dead, and the number of ExtractOps generated during
13880f241638SMatthias Springer /// recursive application of this pattern will be minimal.
13890f241638SMatthias Springer struct UnrollTransferWriteConversion
13902ca887deSMatthias Springer     : public VectorToSCFPattern<TransferWriteOp> {
13912ca887deSMatthias Springer   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
13920f241638SMatthias Springer 
1393700b64dcSMatthias Springer   void initialize() {
1394700b64dcSMatthias Springer     // This pattern recursively unpacks one dimension at a time. The recursion
1395700b64dcSMatthias Springer     // is bounded as the rank is strictly decreasing.
1396700b64dcSMatthias Springer     setHasBoundedRewriteRecursion();
1397700b64dcSMatthias Springer   }
1398700b64dcSMatthias Springer 
13990f241638SMatthias Springer   /// Return the vector from which newly generated ExtractOps will extract.
14000f241638SMatthias Springer   Value getDataVector(TransferWriteOp xferOp) const {
14010f241638SMatthias Springer     if (auto extractOp = getExtractOp(xferOp))
14027c38fd60SJacques Pienaar       return extractOp.getVector();
14037c38fd60SJacques Pienaar     return xferOp.getVector();
14040f241638SMatthias Springer   }
14050f241638SMatthias Springer 
14060f241638SMatthias Springer   /// If the input of the given TransferWriteOp is an ExtractOp, return it.
14070f241638SMatthias Springer   vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
14087c38fd60SJacques Pienaar     if (auto *op = xferOp.getVector().getDefiningOp())
14090f241638SMatthias Springer       return dyn_cast<vector::ExtractOp>(op);
14100f241638SMatthias Springer     return vector::ExtractOp();
14110f241638SMatthias Springer   }
14120f241638SMatthias Springer 
14130f241638SMatthias Springer   /// If the input of the given TransferWriteOp is an ExtractOp, return its
14140f241638SMatthias Springer   /// indices.
14150f241638SMatthias Springer   void getExtractionIndices(TransferWriteOp xferOp,
141698f6289aSDiego Caballero                             SmallVectorImpl<OpFoldResult> &indices) const {
141798f6289aSDiego Caballero     if (auto extractOp = getExtractOp(xferOp)) {
141898f6289aSDiego Caballero       auto pos = extractOp.getMixedPosition();
141998f6289aSDiego Caballero       indices.append(pos.begin(), pos.end());
142098f6289aSDiego Caballero     }
14210f241638SMatthias Springer   }
14220f241638SMatthias Springer 
14230f241638SMatthias Springer   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
14240f241638SMatthias Springer   /// accesses, and broadcasts and transposes in permutation maps.
14250f241638SMatthias Springer   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
14260f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
1427c84061fdSRik Huijzer     VectorType inputVectorTy = xferOp.getVectorType();
1428c84061fdSRik Huijzer 
1429c84061fdSRik Huijzer     if (inputVectorTy.getRank() <= options.targetRank)
14300f241638SMatthias Springer       return failure();
1431c84061fdSRik Huijzer 
143227a713f5SBenjamin Maxwell     if (failed(checkLowerTensors(xferOp, rewriter)))
14338fb48979SMatthias Springer       return failure();
1434f718a53dSMatthias Springer     // Transfer ops that modify the element type are not supported atm.
1435c84061fdSRik Huijzer     if (inputVectorTy.getElementType() !=
1436f718a53dSMatthias Springer         xferOp.getShapedType().getElementType())
1437f718a53dSMatthias Springer       return failure();
14380f241638SMatthias Springer 
14390f241638SMatthias Springer     auto vec = getDataVector(xferOp);
1440c84061fdSRik Huijzer     if (inputVectorTy.getScalableDims()[0]) {
1441ced97ffdSBenjamin Maxwell       // Cannot unroll a scalable dimension at compile time.
1442ced97ffdSBenjamin Maxwell       return failure();
1443ced97ffdSBenjamin Maxwell     }
1444ced97ffdSBenjamin Maxwell 
1445c84061fdSRik Huijzer     int64_t dimSize = inputVectorTy.getShape()[0];
14460ce25b12SRahul Kayaith     Value source = xferOp.getSource(); // memref or tensor to be written to.
1447bd20756dSMatthias Springer     auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
14480f241638SMatthias Springer 
14490f241638SMatthias Springer     // Generate fully unrolled loop of transfer ops.
14506825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
14510f241638SMatthias Springer     for (int64_t i = 0; i < dimSize; ++i) {
1452a54f4eaeSMogball       Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
14530f241638SMatthias Springer 
1454bd20756dSMatthias Springer       auto updatedSource = generateInBoundsCheck(
14556825bfe2SNicolas Vasilache           rewriter, xferOp, iv, unpackedDim(xferOp),
1456bd20756dSMatthias Springer           isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
1457bd20756dSMatthias Springer           /*inBoundsCase=*/
1458bd20756dSMatthias Springer           [&](OpBuilder &b, Location loc) {
14590f241638SMatthias Springer             // Indices for the new transfer op.
14600f241638SMatthias Springer             SmallVector<Value, 8> xferIndices;
14616825bfe2SNicolas Vasilache             getXferIndices(b, xferOp, iv, xferIndices);
14620f241638SMatthias Springer 
14630f241638SMatthias Springer             // Indices for the new vector.extract op.
146498f6289aSDiego Caballero             SmallVector<OpFoldResult, 8> extractionIndices;
14650f241638SMatthias Springer             getExtractionIndices(xferOp, extractionIndices);
146698f6289aSDiego Caballero             extractionIndices.push_back(b.getI64IntegerAttr(i));
14670f241638SMatthias Springer 
14686825bfe2SNicolas Vasilache             auto extracted =
14696825bfe2SNicolas Vasilache                 b.create<vector::ExtractOp>(loc, vec, extractionIndices);
14707c38fd60SJacques Pienaar             auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1471c84061fdSRik Huijzer             Value xferVec;
1472c84061fdSRik Huijzer             if (inputVectorTy.getRank() == 1) {
1473c84061fdSRik Huijzer               // When target-rank=0, unrolling would cause the vector input
1474c84061fdSRik Huijzer               // argument of `transfer_write` to become a scalar. We solve
1475c84061fdSRik Huijzer               // this by broadcasting the scalar to a 0D vector.
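              // E.g. (illustrative): %bc = vector.broadcast %f : f32 to vector<f32>.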
1476c84061fdSRik Huijzer               xferVec = b.create<vector::BroadcastOp>(
1477c84061fdSRik Huijzer                   loc, VectorType::get({}, extracted.getType()), extracted);
1478c84061fdSRik Huijzer             } else {
1479c84061fdSRik Huijzer               xferVec = extracted;
1480c84061fdSRik Huijzer             }
14816825bfe2SNicolas Vasilache             auto newXferOp = b.create<vector::TransferWriteOp>(
1482c84061fdSRik Huijzer                 loc, sourceType, xferVec, source, xferIndices,
14836825bfe2SNicolas Vasilache                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
14846825bfe2SNicolas Vasilache                 inBoundsAttr);
14850f241638SMatthias Springer 
14860f241638SMatthias Springer             maybeAssignMask(b, xferOp, newXferOp, i);
1487bd20756dSMatthias Springer 
1488bd20756dSMatthias Springer             return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
1489bd20756dSMatthias Springer           },
1490bd20756dSMatthias Springer           /*outOfBoundsCase=*/
1491bd20756dSMatthias Springer           [&](OpBuilder &b, Location loc) {
1492bd20756dSMatthias Springer             return isTensorOp(xferOp) ? source : Value();
14930f241638SMatthias Springer           });
1494bd20756dSMatthias Springer 
1495bd20756dSMatthias Springer       if (isTensorOp(xferOp))
1496bd20756dSMatthias Springer         source = updatedSource;
14970f241638SMatthias Springer     }
14980f241638SMatthias Springer 
1499bd20756dSMatthias Springer     if (isTensorOp(xferOp))
1500bd20756dSMatthias Springer       rewriter.replaceOp(xferOp, source);
1501bd20756dSMatthias Springer     else
15020f241638SMatthias Springer       rewriter.eraseOp(xferOp);
1503bd20756dSMatthias Springer 
15040f241638SMatthias Springer     return success();
15050f241638SMatthias Springer   }
15060f241638SMatthias Springer };
15070f241638SMatthias Springer 
1508a088bed4SMatthias Springer } // namespace lowering_n_d_unrolled
1509a088bed4SMatthias Springer 
1510a088bed4SMatthias Springer namespace lowering_1_d {
1511a088bed4SMatthias Springer 
15120f241638SMatthias Springer /// Compute the indices into the memref for the LoadOp/StoreOp generated as
15130f241638SMatthias Springer /// part of TransferOp1dConversion. Return the memref dimension on which
151415ae9964SKazu Hirata /// the transfer is operating. A return value of std::nullopt indicates a
151515ae9964SKazu Hirata /// broadcast.
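///
/// For example (illustrative): with indices `[%a, %b]` and permutation map
/// `(d0, d1) -> (d0)`, the indices become `[%a + iv, %b]` (materialized via
/// `affine.apply`) and the returned dimension is 0.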
15160f241638SMatthias Springer template <typename OpTy>
15170a81ace0SKazu Hirata static std::optional<int64_t>
15186825bfe2SNicolas Vasilache get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
15190f241638SMatthias Springer                    SmallVector<Value, 8> &memrefIndices) {
15207c38fd60SJacques Pienaar   auto indices = xferOp.getIndices();
15217c38fd60SJacques Pienaar   auto map = xferOp.getPermutationMap();
1522c537a943SNicolas Vasilache   assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
15230f241638SMatthias Springer 
15240f241638SMatthias Springer   memrefIndices.append(indices.begin(), indices.end());
15250f241638SMatthias Springer   assert(map.getNumResults() == 1 &&
15260f241638SMatthias Springer          "Expected 1 permutation map result for 1D transfer");
15271609f1c2Slong.chen   if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) {
15286825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
15290f241638SMatthias Springer     auto dim = expr.getPosition();
15306825bfe2SNicolas Vasilache     AffineExpr d0, d1;
15316825bfe2SNicolas Vasilache     bindDims(xferOp.getContext(), d0, d1);
15326825bfe2SNicolas Vasilache     Value offset = memrefIndices[dim];
15334c48f016SMatthias Springer     memrefIndices[dim] =
15344c48f016SMatthias Springer         affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
15350f241638SMatthias Springer     return dim;
15360f241638SMatthias Springer   }
15370f241638SMatthias Springer 
15380f241638SMatthias Springer   assert(xferOp.isBroadcastDim(0) &&
15390f241638SMatthias Springer          "Expected AffineDimExpr or AffineConstantExpr");
15401a36588eSKazu Hirata   return std::nullopt;
15410f241638SMatthias Springer }
15420f241638SMatthias Springer 
15430f241638SMatthias Springer /// Codegen strategy for TransferOp1dConversion, depending on the
15440f241638SMatthias Springer /// operation.
15450f241638SMatthias Springer template <typename OpTy>
15460f241638SMatthias Springer struct Strategy1d;
15470f241638SMatthias Springer 
15480f241638SMatthias Springer /// Codegen strategy for TransferReadOp.
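///
/// The generated loop body (see TransferOp1dConversion below) loads one scalar
/// per iteration and inserts it into the loop-carried vector, e.g.
/// (illustrative sketch):
/// ```
/// %t = memref.load %A[%a + i, %b] : memref<?x?xf32>
/// %v_next = vector.insertelement %t, %v[i] : vector<9xf32>
/// ```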
15490f241638SMatthias Springer template <>
15500f241638SMatthias Springer struct Strategy1d<TransferReadOp> {
15516825bfe2SNicolas Vasilache   static void generateForLoopBody(OpBuilder &b, Location loc,
15520f241638SMatthias Springer                                   TransferReadOp xferOp, Value iv,
15530f241638SMatthias Springer                                   ValueRange loopState) {
15540f241638SMatthias Springer     SmallVector<Value, 8> indices;
15556825bfe2SNicolas Vasilache     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
15560f241638SMatthias Springer     auto vec = loopState[0];
15570f241638SMatthias Springer 
15580f241638SMatthias Springer     // In case of out-of-bounds access, leave `vec` as is (was initialized with
15590f241638SMatthias Springer     // padding value).
15600f241638SMatthias Springer     auto nextVec = generateInBoundsCheck(
15616825bfe2SNicolas Vasilache         b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
15620f241638SMatthias Springer         /*inBoundsCase=*/
15636825bfe2SNicolas Vasilache         [&](OpBuilder &b, Location loc) {
15647c38fd60SJacques Pienaar           Value val =
15657c38fd60SJacques Pienaar               b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
15667c5ecc8bSMogball           return b.create<vector::InsertElementOp>(loc, val, vec, iv);
15670f241638SMatthias Springer         },
15680f241638SMatthias Springer         /*outOfBoundsCase=*/
15690f241638SMatthias Springer         [&](OpBuilder & /*b*/, Location loc) { return vec; });
15706825bfe2SNicolas Vasilache     b.create<scf::YieldOp>(loc, nextVec);
15710f241638SMatthias Springer   }
15720f241638SMatthias Springer 
15736825bfe2SNicolas Vasilache   static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
15740f241638SMatthias Springer     // Initialize vector with padding value.
15756825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
15766a8ba318SRiver Riddle     return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
15777c38fd60SJacques Pienaar                                      xferOp.getPadding());
15780f241638SMatthias Springer   }
15790f241638SMatthias Springer };
15800f241638SMatthias Springer 
15810f241638SMatthias Springer /// Codegen strategy for TransferWriteOp.
15820f241638SMatthias Springer template <>
15830f241638SMatthias Springer struct Strategy1d<TransferWriteOp> {
15846825bfe2SNicolas Vasilache   static void generateForLoopBody(OpBuilder &b, Location loc,
15850f241638SMatthias Springer                                   TransferWriteOp xferOp, Value iv,
15860f241638SMatthias Springer                                   ValueRange /*loopState*/) {
15870f241638SMatthias Springer     SmallVector<Value, 8> indices;
15886825bfe2SNicolas Vasilache     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
15890f241638SMatthias Springer 
15900f241638SMatthias Springer     // Nothing to do in case of out-of-bounds access.
15910f241638SMatthias Springer     generateInBoundsCheck(
15926825bfe2SNicolas Vasilache         b, xferOp, iv, dim,
15936825bfe2SNicolas Vasilache         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
15946825bfe2SNicolas Vasilache           auto val =
15957c38fd60SJacques Pienaar               b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
15967c38fd60SJacques Pienaar           b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
15970f241638SMatthias Springer         });
15986825bfe2SNicolas Vasilache     b.create<scf::YieldOp>(loc);
15990f241638SMatthias Springer   }
16000f241638SMatthias Springer 
16016825bfe2SNicolas Vasilache   static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
16026825bfe2SNicolas Vasilache     return Value();
16036825bfe2SNicolas Vasilache   }
16040f241638SMatthias Springer };
16050f241638SMatthias Springer 
16060f241638SMatthias Springer /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
16070f241638SMatthias Springer /// necessary in cases where a 1D vector transfer op cannot be lowered into
16080f241638SMatthias Springer /// vector load/stores due to non-unit strides or broadcasts:
16090f241638SMatthias Springer ///
16100f241638SMatthias Springer /// * Transfer dimension is not the last memref dimension
16110f241638SMatthias Springer /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
16120f241638SMatthias Springer /// * Memref has a layout map with non-unit stride on the last dimension
16130f241638SMatthias Springer ///
16140f241638SMatthias Springer /// This pattern generates IR as follows:
16150f241638SMatthias Springer ///
16160f241638SMatthias Springer /// 1. Generate a for loop iterating over each vector element.
16170f241638SMatthias Springer /// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
16180f241638SMatthias Springer ///    depending on OpTy.
16190f241638SMatthias Springer ///
16200f241638SMatthias Springer /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
16210f241638SMatthias Springer ///       can be generated instead of TransferOp1dConversion. Add such a pattern
16220f241638SMatthias Springer ///       to ConvertVectorToLLVM.
16230f241638SMatthias Springer ///
16240f241638SMatthias Springer /// E.g.:
16250f241638SMatthias Springer /// ```
16260f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b]
16270f241638SMatthias Springer ///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
16280f241638SMatthias Springer ///    : vector<9xf32>, memref<?x?xf32>
16290f241638SMatthias Springer /// ```
16300f241638SMatthias Springer /// Is rewritten to approximately the following pseudo-IR:
16310f241638SMatthias Springer /// ```
16320f241638SMatthias Springer /// for i = 0 to 9 {
16330f241638SMatthias Springer ///   %t = vector.extractelement %vec[i] : vector<9xf32>
16340f241638SMatthias Springer ///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
16350f241638SMatthias Springer /// }
16360f241638SMatthias Springer /// ```
16370f241638SMatthias Springer template <typename OpTy>
16382ca887deSMatthias Springer struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
16392ca887deSMatthias Springer   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
16400f241638SMatthias Springer 
16410f241638SMatthias Springer   LogicalResult matchAndRewrite(OpTy xferOp,
16420f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
1643c537a943SNicolas Vasilache     // TODO: support 0-d corner case.
1644c537a943SNicolas Vasilache     if (xferOp.getTransferRank() == 0)
1645c537a943SNicolas Vasilache       return failure();
16467c38fd60SJacques Pienaar     auto map = xferOp.getPermutationMap();
16475550c821STres Popp     auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());
16480f241638SMatthias Springer 
16490f241638SMatthias Springer     if (!memRefType)
16500f241638SMatthias Springer       return failure();
16510f241638SMatthias Springer     if (xferOp.getVectorType().getRank() != 1)
16520f241638SMatthias Springer       return failure();
1653*6aaa8f25SMatthias Springer     if (map.isMinorIdentity() && memRefType.isLastDimUnitStride())
16540f241638SMatthias Springer       return failure(); // Handled by ConvertVectorToLLVM
16550f241638SMatthias Springer 
16560f241638SMatthias Springer     // Loop bounds, step, state...
16576825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
16580f241638SMatthias Springer     auto vecType = xferOp.getVectorType();
1659a54f4eaeSMogball     auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
16605cebffc2SAndrzej Warzynski     Value ub =
1661a54f4eaeSMogball         rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
16625cebffc2SAndrzej Warzynski     if (vecType.isScalable()) {
16635cebffc2SAndrzej Warzynski       Value vscale =
16645cebffc2SAndrzej Warzynski           rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
16655cebffc2SAndrzej Warzynski       ub = rewriter.create<arith::MulIOp>(loc, ub, vscale);
16665cebffc2SAndrzej Warzynski     }
1667a54f4eaeSMogball     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
16686825bfe2SNicolas Vasilache     auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
16690f241638SMatthias Springer 
16700f241638SMatthias Springer     // Generate for loop.
16710f241638SMatthias Springer     rewriter.replaceOpWithNewOp<scf::ForOp>(
16720f241638SMatthias Springer         xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
16736825bfe2SNicolas Vasilache         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
16746825bfe2SNicolas Vasilache           Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
16750f241638SMatthias Springer         });
16760f241638SMatthias Springer 
16770f241638SMatthias Springer     return success();
16780f241638SMatthias Springer   }
16790f241638SMatthias Springer };
16804ead2cf7SAlex Zinenko 
1681a088bed4SMatthias Springer } // namespace lowering_1_d
1682df63eedeSBenjamin Kramer } // namespace
1683df63eedeSBenjamin Kramer 
168447f175b0SRiver Riddle void mlir::populateVectorToSCFConversionPatterns(
1685dc4e913bSChris Lattner     RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
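  // With full unrolling, transfer ops are decomposed dimension by dimension
  // into lower-rank transfers and vector.insert/extract ops, without creating
  // loops. Otherwise, they are lowered progressively through scf.for loops.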
16860f241638SMatthias Springer   if (options.unroll) {
1687a088bed4SMatthias Springer     patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1688a088bed4SMatthias Springer                  lowering_n_d_unrolled::UnrollTransferWriteConversion>(
16892ca887deSMatthias Springer         patterns.getContext(), options);
16900f241638SMatthias Springer   } else {
1691a088bed4SMatthias Springer     patterns.add<lowering_n_d::PrepareTransferReadConversion,
1692a088bed4SMatthias Springer                  lowering_n_d::PrepareTransferWriteConversion,
1693a088bed4SMatthias Springer                  lowering_n_d::TransferOpConversion<TransferReadOp>,
1694a088bed4SMatthias Springer                  lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1695a088bed4SMatthias Springer         patterns.getContext(), options);
16960f241638SMatthias Springer   }
169727a713f5SBenjamin Maxwell   if (options.lowerScalable) {
169827a713f5SBenjamin Maxwell     patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
169927a713f5SBenjamin Maxwell         patterns.getContext(), options);
170027a713f5SBenjamin Maxwell   }
17012ca887deSMatthias Springer   if (options.targetRank == 1) {
1702a088bed4SMatthias Springer     patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1703a088bed4SMatthias Springer                  lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1704a088bed4SMatthias Springer         patterns.getContext(), options);
17050f241638SMatthias Springer   }
1706f36e909dSBenjamin Maxwell   patterns.add<lowering_n_d::DecomposePrintOpConversion>(patterns.getContext(),
1707f36e909dSBenjamin Maxwell                                                          options);
17084ead2cf7SAlex Zinenko }
17093393cc4cSNicolas Vasilache 
17105f9e0466SNicolas Vasilache namespace {
17115f9e0466SNicolas Vasilache 
17125f9e0466SNicolas Vasilache struct ConvertVectorToSCFPass
171367d0d7acSMichele Scuttari     : public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
17145f9e0466SNicolas Vasilache   ConvertVectorToSCFPass() = default;
17155f9e0466SNicolas Vasilache   ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
17165f9e0466SNicolas Vasilache     this->fullUnroll = options.unroll;
17172ca887deSMatthias Springer     this->targetRank = options.targetRank;
1718558e7401SMatthias Springer     this->lowerTensors = options.lowerTensors;
171927a713f5SBenjamin Maxwell     this->lowerScalable = options.lowerScalable;
17205f9e0466SNicolas Vasilache   }
17215f9e0466SNicolas Vasilache 
172241574554SRiver Riddle   void runOnOperation() override {
17232ca887deSMatthias Springer     VectorTransferToSCFOptions options;
1724fb7ec1f1SMatthias Springer     options.unroll = fullUnroll;
1725fb7ec1f1SMatthias Springer     options.targetRank = targetRank;
1726558e7401SMatthias Springer     options.lowerTensors = lowerTensors;
172727a713f5SBenjamin Maxwell     options.lowerScalable = lowerScalable;
1728fb7ec1f1SMatthias Springer 
1729fb7ec1f1SMatthias Springer     // Lower permutation maps first.
173047f175b0SRiver Riddle     RewritePatternSet lowerTransferPatterns(&getContext());
1731fb7ec1f1SMatthias Springer     mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
1732fb7ec1f1SMatthias Springer         lowerTransferPatterns);
173309dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(),
1734fb7ec1f1SMatthias Springer                                 std::move(lowerTransferPatterns));
17352ca887deSMatthias Springer 
173647f175b0SRiver Riddle     RewritePatternSet patterns(&getContext());
17372ca887deSMatthias Springer     populateVectorToSCFConversionPatterns(patterns, options);
173809dfc571SJacques Pienaar     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
17395f9e0466SNicolas Vasilache   }
17405f9e0466SNicolas Vasilache };
17415f9e0466SNicolas Vasilache 
17425f9e0466SNicolas Vasilache } // namespace
17435f9e0466SNicolas Vasilache 
17445f9e0466SNicolas Vasilache std::unique_ptr<Pass>
17455f9e0466SNicolas Vasilache mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
17465f9e0466SNicolas Vasilache   return std::make_unique<ConvertVectorToSCFPass>(options);
17475f9e0466SNicolas Vasilache }
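
// Usage sketch (illustrative only): the standalone pass can be run from
// mlir-opt, e.g.:
//
//   mlir-opt input.mlir -convert-vector-to-scf="full-unroll=true"
//
// (option spellings are derived from the pass's TableGen definition; consult
// Passes.td for the authoritative names), or added to a pipeline in C++,
// assuming a valid MLIRContext `context`:
//
//   PassManager pm(&context);
//   VectorTransferToSCFOptions options;
//   options.unroll = true; // unroll transfer ops instead of emitting loops
//   pm.addNestedPass<func::FuncOp>(createConvertVectorToSCFPass(options));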
1748