10f241638SMatthias Springer //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===// 24ead2cf7SAlex Zinenko // 34ead2cf7SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 44ead2cf7SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 54ead2cf7SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 64ead2cf7SAlex Zinenko // 74ead2cf7SAlex Zinenko //===----------------------------------------------------------------------===// 84ead2cf7SAlex Zinenko // 90f241638SMatthias Springer // This file implements lowering of vector transfer operations to SCF. 104ead2cf7SAlex Zinenko // 114ead2cf7SAlex Zinenko //===----------------------------------------------------------------------===// 124ead2cf7SAlex Zinenko 13f36e909dSBenjamin Maxwell #include <numeric> 14a1fe1f5fSKazu Hirata #include <optional> 152bc4c3e9SNicolas Vasilache #include <type_traits> 164ead2cf7SAlex Zinenko 174ead2cf7SAlex Zinenko #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 185f9e0466SNicolas Vasilache 196825bfe2SNicolas Vasilache #include "mlir/Dialect/Affine/IR/AffineOps.h" 20abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h" 2166f878ceSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 228b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h" 23a795955fSKai Sasaki #include "mlir/Dialect/Tensor/IR/Tensor.h" 24c84061fdSRik Huijzer #include "mlir/Dialect/Vector/IR/VectorOps.h" 252bc4c3e9SNicolas Vasilache #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" 2699ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 2727a713f5SBenjamin Maxwell #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 284ead2cf7SAlex Zinenko #include "mlir/IR/Builders.h" 296825bfe2SNicolas Vasilache #include "mlir/IR/ImplicitLocOpBuilder.h" 305f9e0466SNicolas Vasilache #include "mlir/Pass/Pass.h" 31b6eb26fdSRiver Riddle #include 
"mlir/Transforms/GreedyPatternRewriteDriver.h" 325f9e0466SNicolas Vasilache #include "mlir/Transforms/Passes.h" 334ead2cf7SAlex Zinenko 3467d0d7acSMichele Scuttari namespace mlir { 3567d0d7acSMichele Scuttari #define GEN_PASS_DEF_CONVERTVECTORTOSCF 3667d0d7acSMichele Scuttari #include "mlir/Conversion/Passes.h.inc" 3767d0d7acSMichele Scuttari } // namespace mlir 3867d0d7acSMichele Scuttari 394ead2cf7SAlex Zinenko using namespace mlir; 404ead2cf7SAlex Zinenko using vector::TransferReadOp; 414ead2cf7SAlex Zinenko using vector::TransferWriteOp; 424ead2cf7SAlex Zinenko 43350dadaaSBenjamin Kramer namespace { 440f241638SMatthias Springer 450f241638SMatthias Springer /// Attribute name used for labeling transfer ops during progressive lowering. 460f241638SMatthias Springer static const char kPassLabel[] = "__vector_to_scf_lowering__"; 470f241638SMatthias Springer 4827a713f5SBenjamin Maxwell /// Return true if this transfer op operates on a source tensor. 4927a713f5SBenjamin Maxwell static bool isTensorOp(VectorTransferOpInterface xferOp) { 5027a713f5SBenjamin Maxwell if (isa<RankedTensorType>(xferOp.getShapedType())) { 5127a713f5SBenjamin Maxwell if (isa<vector::TransferWriteOp>(xferOp)) { 5227a713f5SBenjamin Maxwell // TransferWriteOps on tensors have a result. 5327a713f5SBenjamin Maxwell assert(xferOp->getNumResults() > 0); 5427a713f5SBenjamin Maxwell } 5527a713f5SBenjamin Maxwell return true; 5627a713f5SBenjamin Maxwell } 5727a713f5SBenjamin Maxwell return false; 5827a713f5SBenjamin Maxwell } 5927a713f5SBenjamin Maxwell 602ca887deSMatthias Springer /// Patterns that inherit from this struct have access to 612ca887deSMatthias Springer /// VectorTransferToSCFOptions. 
622ca887deSMatthias Springer template <typename OpTy> 632ca887deSMatthias Springer struct VectorToSCFPattern : public OpRewritePattern<OpTy> { 642ca887deSMatthias Springer explicit VectorToSCFPattern(MLIRContext *context, 652ca887deSMatthias Springer VectorTransferToSCFOptions opt) 662ca887deSMatthias Springer : OpRewritePattern<OpTy>(context), options(opt) {} 672ca887deSMatthias Springer 6827a713f5SBenjamin Maxwell LogicalResult checkLowerTensors(VectorTransferOpInterface xferOp, 6927a713f5SBenjamin Maxwell PatternRewriter &rewriter) const { 7027a713f5SBenjamin Maxwell if (isTensorOp(xferOp) && !options.lowerTensors) { 7127a713f5SBenjamin Maxwell return rewriter.notifyMatchFailure( 7227a713f5SBenjamin Maxwell xferOp, "lowering tensor transfers is disabled"); 7327a713f5SBenjamin Maxwell } 7427a713f5SBenjamin Maxwell return success(); 7527a713f5SBenjamin Maxwell } 7627a713f5SBenjamin Maxwell 772ca887deSMatthias Springer VectorTransferToSCFOptions options; 782ca887deSMatthias Springer }; 790f241638SMatthias Springer 800f241638SMatthias Springer /// Given a vector transfer op, calculate which dimension of the `source` 810f241638SMatthias Springer /// memref should be unpacked in the next application of TransferOpConversion. 8215ae9964SKazu Hirata /// A return value of std::nullopt indicates a broadcast. 830f241638SMatthias Springer template <typename OpTy> 840a81ace0SKazu Hirata static std::optional<int64_t> unpackedDim(OpTy xferOp) { 85c537a943SNicolas Vasilache // TODO: support 0-d corner case. 
86c537a943SNicolas Vasilache assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer"); 877c38fd60SJacques Pienaar auto map = xferOp.getPermutationMap(); 881609f1c2Slong.chen if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) { 890f241638SMatthias Springer return expr.getPosition(); 907c3c5b11SNicolas Vasilache } 910f241638SMatthias Springer assert(xferOp.isBroadcastDim(0) && 920f241638SMatthias Springer "Expected AffineDimExpr or AffineConstantExpr"); 931a36588eSKazu Hirata return std::nullopt; 940f241638SMatthias Springer } 950f241638SMatthias Springer 960f241638SMatthias Springer /// Compute the permutation map for the new (N-1)-D vector transfer op. This 970f241638SMatthias Springer /// map is identical to the current permutation map, but the first result is 980f241638SMatthias Springer /// omitted. 990f241638SMatthias Springer template <typename OpTy> 1006825bfe2SNicolas Vasilache static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) { 101c537a943SNicolas Vasilache // TODO: support 0-d corner case. 102c537a943SNicolas Vasilache assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer"); 1037c38fd60SJacques Pienaar auto map = xferOp.getPermutationMap(); 1040f241638SMatthias Springer return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(), 1056825bfe2SNicolas Vasilache b.getContext()); 1060f241638SMatthias Springer } 1070f241638SMatthias Springer 1080f241638SMatthias Springer /// Calculate the indices for the new vector transfer op. 1090f241638SMatthias Springer /// 1100f241638SMatthias Springer /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ... 1110f241638SMatthias Springer /// --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3f32> 1120f241638SMatthias Springer /// ^^^^^^ 1130f241638SMatthias Springer /// `iv` is the iteration variable of the (new) surrounding loop. 
1140f241638SMatthias Springer template <typename OpTy> 1156825bfe2SNicolas Vasilache static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv, 1160f241638SMatthias Springer SmallVector<Value, 8> &indices) { 1170f241638SMatthias Springer typename OpTy::Adaptor adaptor(xferOp); 1180f241638SMatthias Springer // Corresponding memref dim of the vector dim that is unpacked. 1190f241638SMatthias Springer auto dim = unpackedDim(xferOp); 1207c38fd60SJacques Pienaar auto prevIndices = adaptor.getIndices(); 1210f241638SMatthias Springer indices.append(prevIndices.begin(), prevIndices.end()); 1220f241638SMatthias Springer 1236825bfe2SNicolas Vasilache Location loc = xferOp.getLoc(); 124491d2701SKazu Hirata bool isBroadcast = !dim.has_value(); 1250f241638SMatthias Springer if (!isBroadcast) { 1266825bfe2SNicolas Vasilache AffineExpr d0, d1; 1276825bfe2SNicolas Vasilache bindDims(xferOp.getContext(), d0, d1); 128cbb09813SFangrui Song Value offset = adaptor.getIndices()[*dim]; 1294c48f016SMatthias Springer indices[*dim] = 1304c48f016SMatthias Springer affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv}); 1310f241638SMatthias Springer } 1320f241638SMatthias Springer } 1330f241638SMatthias Springer 1346825bfe2SNicolas Vasilache static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal, 1350f241638SMatthias Springer Value value) { 1360f241638SMatthias Springer if (hasRetVal) { 137558e7401SMatthias Springer assert(value && "Expected non-empty value"); 1386825bfe2SNicolas Vasilache b.create<scf::YieldOp>(loc, value); 1390f241638SMatthias Springer } else { 1406825bfe2SNicolas Vasilache b.create<scf::YieldOp>(loc); 1410f241638SMatthias Springer } 1420f241638SMatthias Springer } 1430f241638SMatthias Springer 1440f241638SMatthias Springer /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask 1450f241638SMatthias Springer /// is set to true. 
No such check is generated under following circumstances: 1460f241638SMatthias Springer /// * xferOp does not have a mask. 1470f241638SMatthias Springer /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is 1480f241638SMatthias Springer /// computed and attached to the new transfer op in the pattern.) 1490f241638SMatthias Springer /// * The to-be-unpacked dim of xferOp is a broadcast. 1500f241638SMatthias Springer template <typename OpTy> 1516825bfe2SNicolas Vasilache static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) { 1527c38fd60SJacques Pienaar if (!xferOp.getMask()) 1530f241638SMatthias Springer return Value(); 1540f241638SMatthias Springer if (xferOp.getMaskType().getRank() != 1) 1550f241638SMatthias Springer return Value(); 1560f241638SMatthias Springer if (xferOp.isBroadcastDim(0)) 1570f241638SMatthias Springer return Value(); 1580f241638SMatthias Springer 1596825bfe2SNicolas Vasilache Location loc = xferOp.getLoc(); 1607c38fd60SJacques Pienaar return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv); 1610f241638SMatthias Springer } 1620f241638SMatthias Springer 1630f241638SMatthias Springer /// Helper function TransferOpConversion and TransferOp1dConversion. 1640f241638SMatthias Springer /// Generate an in-bounds check if the transfer op may go out-of-bounds on the 1650f241638SMatthias Springer /// specified dimension `dim` with the loop iteration variable `iv`. 
1660f241638SMatthias Springer /// E.g., when unpacking dimension 0 from: 1670f241638SMatthias Springer /// ``` 1680f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a, %b] %cst 1690f241638SMatthias Springer /// : vector<5x4xf32>, memref<?x?xf32> 1700f241638SMatthias Springer /// ``` 1710f241638SMatthias Springer /// An if check similar to this will be generated inside the loop: 1720f241638SMatthias Springer /// ``` 1730f241638SMatthias Springer /// %d = memref.dim %A, %c0 : memref<?x?xf32> 1740f241638SMatthias Springer /// if (%a + iv < %d) { 1750f241638SMatthias Springer /// (in-bounds case) 1760f241638SMatthias Springer /// } else { 1770f241638SMatthias Springer /// (out-of-bounds case) 1780f241638SMatthias Springer /// } 1790f241638SMatthias Springer /// ``` 1800f241638SMatthias Springer /// 1810f241638SMatthias Springer /// If the transfer is 1D and has a mask, this function generates a more complex 1820f241638SMatthias Springer /// check also accounts for potentially masked out elements. 1830f241638SMatthias Springer /// 1840f241638SMatthias Springer /// This function variant returns the value returned by `inBoundsCase` or 1850f241638SMatthias Springer /// `outOfBoundsCase`. The MLIR type of the return value must be specified in 1860f241638SMatthias Springer /// `resultTypes`. 1870f241638SMatthias Springer template <typename OpTy> 1880f241638SMatthias Springer static Value generateInBoundsCheck( 1890a81ace0SKazu Hirata OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim, 1900f241638SMatthias Springer TypeRange resultTypes, 1910f241638SMatthias Springer function_ref<Value(OpBuilder &, Location)> inBoundsCase, 1920f241638SMatthias Springer function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) { 1930f241638SMatthias Springer bool hasRetVal = !resultTypes.empty(); 1940f241638SMatthias Springer Value cond; // Condition to be built... 
1950f241638SMatthias Springer 1960f241638SMatthias Springer // Condition check 1: Access in-bounds? 1970916d96dSKazu Hirata bool isBroadcast = !dim; // No in-bounds check for broadcasts. 1986825bfe2SNicolas Vasilache Location loc = xferOp.getLoc(); 1996825bfe2SNicolas Vasilache ImplicitLocOpBuilder lb(xferOp.getLoc(), b); 2000f241638SMatthias Springer if (!xferOp.isDimInBounds(0) && !isBroadcast) { 2017c38fd60SJacques Pienaar Value memrefDim = 2027c38fd60SJacques Pienaar vector::createOrFoldDimOp(b, loc, xferOp.getSource(), *dim); 2036825bfe2SNicolas Vasilache AffineExpr d0, d1; 2046825bfe2SNicolas Vasilache bindDims(xferOp.getContext(), d0, d1); 2056d5fc1e3SKazu Hirata Value base = xferOp.getIndices()[*dim]; 2064c48f016SMatthias Springer Value memrefIdx = 2074c48f016SMatthias Springer affine::makeComposedAffineApply(b, loc, d0 + d1, {base, iv}); 208a54f4eaeSMogball cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim, 209a54f4eaeSMogball memrefIdx); 2100f241638SMatthias Springer } 2110f241638SMatthias Springer 2120f241638SMatthias Springer // Condition check 2: Masked in? 2136825bfe2SNicolas Vasilache if (auto maskCond = generateMaskCheck(b, xferOp, iv)) { 2146825bfe2SNicolas Vasilache if (cond) 215a54f4eaeSMogball cond = lb.create<arith::AndIOp>(cond, maskCond); 2166825bfe2SNicolas Vasilache else 2170f241638SMatthias Springer cond = maskCond; 2180f241638SMatthias Springer } 2190f241638SMatthias Springer 2200f241638SMatthias Springer // If the condition is non-empty, generate an SCF::IfOp. 
2210f241638SMatthias Springer if (cond) { 2226825bfe2SNicolas Vasilache auto check = lb.create<scf::IfOp>( 2231125c5c0SFrederik Gossen cond, 2240f241638SMatthias Springer /*thenBuilder=*/ 2256825bfe2SNicolas Vasilache [&](OpBuilder &b, Location loc) { 2266825bfe2SNicolas Vasilache maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc)); 227cadb7ccfSAlex Zinenko }, 2280f241638SMatthias Springer /*elseBuilder=*/ 2296825bfe2SNicolas Vasilache [&](OpBuilder &b, Location loc) { 2300f241638SMatthias Springer if (outOfBoundsCase) { 2316825bfe2SNicolas Vasilache maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc)); 2327c3c5b11SNicolas Vasilache } else { 2336825bfe2SNicolas Vasilache b.create<scf::YieldOp>(loc); 2347c3c5b11SNicolas Vasilache } 2357c3c5b11SNicolas Vasilache }); 2367c3c5b11SNicolas Vasilache 2370f241638SMatthias Springer return hasRetVal ? check.getResult(0) : Value(); 2384ead2cf7SAlex Zinenko } 2394ead2cf7SAlex Zinenko 2400f241638SMatthias Springer // Condition is empty, no need for an SCF::IfOp. 2416825bfe2SNicolas Vasilache return inBoundsCase(b, loc); 2420f241638SMatthias Springer } 2430f241638SMatthias Springer 2440f241638SMatthias Springer /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have 2450f241638SMatthias Springer /// a return value. Consequently, this function does not have a return value. 
2460f241638SMatthias Springer template <typename OpTy> 2470f241638SMatthias Springer static void generateInBoundsCheck( 2480a81ace0SKazu Hirata OpBuilder &b, OpTy xferOp, Value iv, std::optional<int64_t> dim, 2490f241638SMatthias Springer function_ref<void(OpBuilder &, Location)> inBoundsCase, 2500f241638SMatthias Springer function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) { 2510f241638SMatthias Springer generateInBoundsCheck( 2526825bfe2SNicolas Vasilache b, xferOp, iv, dim, /*resultTypes=*/TypeRange(), 2530f241638SMatthias Springer /*inBoundsCase=*/ 2546825bfe2SNicolas Vasilache [&](OpBuilder &b, Location loc) { 2556825bfe2SNicolas Vasilache inBoundsCase(b, loc); 2560f241638SMatthias Springer return Value(); 2570f241638SMatthias Springer }, 2580f241638SMatthias Springer /*outOfBoundsCase=*/ 2596825bfe2SNicolas Vasilache [&](OpBuilder &b, Location loc) { 2600f241638SMatthias Springer if (outOfBoundsCase) 2616825bfe2SNicolas Vasilache outOfBoundsCase(b, loc); 2620f241638SMatthias Springer return Value(); 2630f241638SMatthias Springer }); 2640f241638SMatthias Springer } 2650f241638SMatthias Springer 2660f241638SMatthias Springer /// Given an ArrayAttr, return a copy where the first element is dropped. 2676825bfe2SNicolas Vasilache static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) { 2680f241638SMatthias Springer if (!attr) 2690f241638SMatthias Springer return attr; 2706825bfe2SNicolas Vasilache return ArrayAttr::get(b.getContext(), attr.getValue().drop_front()); 2710f241638SMatthias Springer } 2720f241638SMatthias Springer 2730f241638SMatthias Springer /// Add the pass label to a vector transfer op if its rank is not the target 2740f241638SMatthias Springer /// rank. 
2750f241638SMatthias Springer template <typename OpTy> 2766825bfe2SNicolas Vasilache static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp, 2772ca887deSMatthias Springer unsigned targetRank) { 2782ca887deSMatthias Springer if (newXferOp.getVectorType().getRank() > targetRank) 2796825bfe2SNicolas Vasilache newXferOp->setAttr(kPassLabel, b.getUnitAttr()); 2800f241638SMatthias Springer } 2810f241638SMatthias Springer 282a088bed4SMatthias Springer namespace lowering_n_d { 283a088bed4SMatthias Springer 284a088bed4SMatthias Springer /// Helper data structure for data and mask buffers. 285a088bed4SMatthias Springer struct BufferAllocs { 286a088bed4SMatthias Springer Value dataBuffer; 287a088bed4SMatthias Springer Value maskBuffer; 288a088bed4SMatthias Springer }; 289a088bed4SMatthias Springer 2903c3810e7SNicolas Vasilache // TODO: Parallelism and threadlocal considerations with a ParallelScope trait. 2913c3810e7SNicolas Vasilache static Operation *getAutomaticAllocationScope(Operation *op) { 2923c3810e7SNicolas Vasilache Operation *scope = 2933c3810e7SNicolas Vasilache op->getParentWithTrait<OpTrait::AutomaticAllocationScope>(); 2943c3810e7SNicolas Vasilache assert(scope && "Expected op to be inside automatic allocation scope"); 2953c3810e7SNicolas Vasilache return scope; 2963c3810e7SNicolas Vasilache } 2973c3810e7SNicolas Vasilache 298a088bed4SMatthias Springer /// Allocate temporary buffers for data (vector) and mask (if present). 
299a088bed4SMatthias Springer template <typename OpTy> 3006825bfe2SNicolas Vasilache static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) { 3016825bfe2SNicolas Vasilache Location loc = xferOp.getLoc(); 302a088bed4SMatthias Springer OpBuilder::InsertionGuard guard(b); 3033c3810e7SNicolas Vasilache Operation *scope = getAutomaticAllocationScope(xferOp); 3043c3810e7SNicolas Vasilache assert(scope->getNumRegions() == 1 && 3053c3810e7SNicolas Vasilache "AutomaticAllocationScope with >1 regions"); 306a088bed4SMatthias Springer b.setInsertionPointToStart(&scope->getRegion(0).front()); 307a088bed4SMatthias Springer 308a088bed4SMatthias Springer BufferAllocs result; 309a088bed4SMatthias Springer auto bufferType = MemRefType::get({}, xferOp.getVectorType()); 3106825bfe2SNicolas Vasilache result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType); 311a088bed4SMatthias Springer 3127c38fd60SJacques Pienaar if (xferOp.getMask()) { 3137c38fd60SJacques Pienaar auto maskType = MemRefType::get({}, xferOp.getMask().getType()); 3146825bfe2SNicolas Vasilache auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType); 315fb7ec1f1SMatthias Springer b.setInsertionPoint(xferOp); 3167c38fd60SJacques Pienaar b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer); 3171b60f0d7SJeff Niu result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer, ValueRange()); 318a088bed4SMatthias Springer } 319a088bed4SMatthias Springer 320a088bed4SMatthias Springer return result; 321a088bed4SMatthias Springer } 322a088bed4SMatthias Springer 323a088bed4SMatthias Springer /// Given a MemRefType with VectorType element type, unpack one dimension from 324a088bed4SMatthias Springer /// the VectorType into the MemRefType. 
325a088bed4SMatthias Springer /// 326a088bed4SMatthias Springer /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>> 3272a82dfd7SBenjamin Maxwell static FailureOr<MemRefType> unpackOneDim(MemRefType type) { 3285550c821STres Popp auto vectorType = dyn_cast<VectorType>(type.getElementType()); 3292a82dfd7SBenjamin Maxwell // Vectors with leading scalable dims are not supported. 3302a82dfd7SBenjamin Maxwell // It may be possible to support these in future by using dynamic memref dims. 3312a82dfd7SBenjamin Maxwell if (vectorType.getScalableDims().front()) 3322a82dfd7SBenjamin Maxwell return failure(); 333a088bed4SMatthias Springer auto memrefShape = type.getShape(); 334a088bed4SMatthias Springer SmallVector<int64_t, 8> newMemrefShape; 335a088bed4SMatthias Springer newMemrefShape.append(memrefShape.begin(), memrefShape.end()); 336a088bed4SMatthias Springer newMemrefShape.push_back(vectorType.getDimSize(0)); 337a088bed4SMatthias Springer return MemRefType::get(newMemrefShape, 3382a82dfd7SBenjamin Maxwell VectorType::Builder(vectorType).dropDim(0)); 339a088bed4SMatthias Springer } 340a088bed4SMatthias Springer 3410f241638SMatthias Springer /// Given a transfer op, find the memref from which the mask is loaded. This 3420f241638SMatthias Springer /// is similar to Strategy<TransferWriteOp>::getBuffer. 3430f241638SMatthias Springer template <typename OpTy> 3440f241638SMatthias Springer static Value getMaskBuffer(OpTy xferOp) { 3457c38fd60SJacques Pienaar assert(xferOp.getMask() && "Expected that transfer op has mask"); 3467c38fd60SJacques Pienaar auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>(); 3470f241638SMatthias Springer assert(loadOp && "Expected transfer op mask produced by LoadOp"); 3480f241638SMatthias Springer return loadOp.getMemRef(); 3490f241638SMatthias Springer } 3500f241638SMatthias Springer 3510f241638SMatthias Springer /// Codegen strategy, depending on the operation. 
3520f241638SMatthias Springer template <typename OpTy> 3530f241638SMatthias Springer struct Strategy; 3540f241638SMatthias Springer 3550f241638SMatthias Springer /// Code strategy for vector TransferReadOp. 3564ead2cf7SAlex Zinenko template <> 3570f241638SMatthias Springer struct Strategy<TransferReadOp> { 3580f241638SMatthias Springer /// Find the StoreOp that is used for writing the current TransferReadOp's 3590f241638SMatthias Springer /// result to the temporary buffer allocation. 3600f241638SMatthias Springer static memref::StoreOp getStoreOp(TransferReadOp xferOp) { 3610f241638SMatthias Springer assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp"); 3620f241638SMatthias Springer auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner()); 3630f241638SMatthias Springer assert(storeOp && "Expected TransferReadOp result used by StoreOp"); 3640f241638SMatthias Springer return storeOp; 3657c3c5b11SNicolas Vasilache } 3664ead2cf7SAlex Zinenko 3670f241638SMatthias Springer /// Find the temporary buffer allocation. All labeled TransferReadOps are 3680f241638SMatthias Springer /// used like this, where %buf is either the buffer allocation or a type cast 3690f241638SMatthias Springer /// of the buffer allocation: 3700f241638SMatthias Springer /// ``` 3710f241638SMatthias Springer /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ... 3720f241638SMatthias Springer /// memref.store %vec, %buf[...] ... 3730f241638SMatthias Springer /// ``` 3740f241638SMatthias Springer static Value getBuffer(TransferReadOp xferOp) { 3750f241638SMatthias Springer return getStoreOp(xferOp).getMemRef(); 3761870e787SNicolas Vasilache } 3770f241638SMatthias Springer 3780f241638SMatthias Springer /// Retrieve the indices of the current StoreOp that stores into the buffer. 
3790f241638SMatthias Springer static void getBufferIndices(TransferReadOp xferOp, 3800f241638SMatthias Springer SmallVector<Value, 8> &indices) { 3819f5afc3dSRik Huijzer auto storeOp = getStoreOp(xferOp); 382136d746eSJacques Pienaar auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices(); 3830f241638SMatthias Springer indices.append(prevIndices.begin(), prevIndices.end()); 3840f241638SMatthias Springer } 3850f241638SMatthias Springer 3860f241638SMatthias Springer /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds 3870f241638SMatthias Springer /// accesses on the to-be-unpacked dimension. 3880f241638SMatthias Springer /// 3890f241638SMatthias Springer /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration 3900f241638SMatthias Springer /// variable `iv`. 3910f241638SMatthias Springer /// 2. Store the result into the (already `vector.type_cast`ed) buffer. 3920f241638SMatthias Springer /// 3930f241638SMatthias Springer /// E.g.: 3940f241638SMatthias Springer /// ``` 3950f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst 3960f241638SMatthias Springer /// : memref<?x?x?xf32>, vector<4x3xf32> 3970f241638SMatthias Springer /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>> 3980f241638SMatthias Springer /// ``` 3990f241638SMatthias Springer /// Is rewritten to: 4000f241638SMatthias Springer /// ``` 4010f241638SMatthias Springer /// %casted = vector.type_cast %buf 4020f241638SMatthias Springer /// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>> 4030f241638SMatthias Springer /// for %j = 0 to 4 { 4040f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst 4050f241638SMatthias Springer /// : memref<?x?x?xf32>, vector<3xf32> 4060f241638SMatthias Springer /// memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>> 4070f241638SMatthias Springer /// } 4080f241638SMatthias Springer /// ``` 4090f241638SMatthias Springer /// 4100f241638SMatthias Springer 
/// Note: The loop and type cast are generated in TransferOpConversion. 4110f241638SMatthias Springer /// The original TransferReadOp and store op are deleted in `cleanup`. 4120f241638SMatthias Springer /// Note: The `mask` operand is set in TransferOpConversion. 4136825bfe2SNicolas Vasilache static TransferReadOp rewriteOp(OpBuilder &b, 4142ca887deSMatthias Springer VectorTransferToSCFOptions options, 415558e7401SMatthias Springer TransferReadOp xferOp, Value buffer, Value iv, 416558e7401SMatthias Springer ValueRange /*loopState*/) { 4170f241638SMatthias Springer SmallVector<Value, 8> storeIndices; 4180f241638SMatthias Springer getBufferIndices(xferOp, storeIndices); 4190f241638SMatthias Springer storeIndices.push_back(iv); 4200f241638SMatthias Springer 4210f241638SMatthias Springer SmallVector<Value, 8> xferIndices; 4226825bfe2SNicolas Vasilache getXferIndices(b, xferOp, iv, xferIndices); 4230f241638SMatthias Springer 4246825bfe2SNicolas Vasilache Location loc = xferOp.getLoc(); 4255550c821STres Popp auto bufferType = dyn_cast<ShapedType>(buffer.getType()); 4265550c821STres Popp auto vecType = dyn_cast<VectorType>(bufferType.getElementType()); 4277c38fd60SJacques Pienaar auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); 4286825bfe2SNicolas Vasilache auto newXferOp = b.create<vector::TransferReadOp>( 4297c38fd60SJacques Pienaar loc, vecType, xferOp.getSource(), xferIndices, 4307c38fd60SJacques Pienaar AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), 4317c38fd60SJacques Pienaar xferOp.getPadding(), Value(), inBoundsAttr); 4320f241638SMatthias Springer 4336825bfe2SNicolas Vasilache maybeApplyPassLabel(b, newXferOp, options.targetRank); 4340f241638SMatthias Springer 4357c38fd60SJacques Pienaar b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices); 4366825bfe2SNicolas Vasilache return newXferOp; 4370f241638SMatthias Springer } 4380f241638SMatthias Springer 4390f241638SMatthias Springer /// Handle out-of-bounds accesses on the 
to-be-unpacked dimension: Write 4400f241638SMatthias Springer /// padding value to the temporary buffer. 441558e7401SMatthias Springer static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp, 442558e7401SMatthias Springer Value buffer, Value iv, 443558e7401SMatthias Springer ValueRange /*loopState*/) { 4440f241638SMatthias Springer SmallVector<Value, 8> storeIndices; 4450f241638SMatthias Springer getBufferIndices(xferOp, storeIndices); 4460f241638SMatthias Springer storeIndices.push_back(iv); 4470f241638SMatthias Springer 4486825bfe2SNicolas Vasilache Location loc = xferOp.getLoc(); 4495550c821STres Popp auto bufferType = dyn_cast<ShapedType>(buffer.getType()); 4505550c821STres Popp auto vecType = dyn_cast<VectorType>(bufferType.getElementType()); 4517c38fd60SJacques Pienaar auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding()); 4526825bfe2SNicolas Vasilache b.create<memref::StoreOp>(loc, vec, buffer, storeIndices); 453558e7401SMatthias Springer 454558e7401SMatthias Springer return Value(); 4550f241638SMatthias Springer } 4560f241638SMatthias Springer 4570f241638SMatthias Springer /// Cleanup after rewriting the op. 458558e7401SMatthias Springer static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp, 459558e7401SMatthias Springer scf::ForOp /*forOp*/) { 4600f241638SMatthias Springer rewriter.eraseOp(getStoreOp(xferOp)); 4610f241638SMatthias Springer rewriter.eraseOp(xferOp); 4620f241638SMatthias Springer } 463558e7401SMatthias Springer 464558e7401SMatthias Springer /// Return the initial loop state for the generated scf.for loop. 465558e7401SMatthias Springer static Value initialLoopState(TransferReadOp xferOp) { return Value(); } 4664ead2cf7SAlex Zinenko }; 4677c3c5b11SNicolas Vasilache 4680f241638SMatthias Springer /// Codegen strategy for vector TransferWriteOp. 
/// Codegen strategy for vector.transfer_write: how to locate the staging
/// buffer, rewrite one unpacked dimension, and clean up afterwards.
template <>
struct Strategy<TransferWriteOp> {
  /// Find the temporary buffer allocation. All labeled TransferWriteOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = memref.load %buf[...] ...
  /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
  /// ```
  static Value getBuffer(TransferWriteOp xferOp) {
    auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();
  }

  /// Retrieve the indices of the current LoadOp that loads from the buffer.
  static void getBufferIndices(TransferWriteOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
  ///    using the loop iteration variable `iv`.
  /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
  ///    to memory.
  ///
  /// Note: For more details, see comments on Strategy<TransferReadOp>.
  static TransferWriteOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
                                   TransferWriteOp xferOp, Value buffer,
                                   Value iv, ValueRange loopState) {
    // Buffer indices: the previous load's indices extended by `iv` for the
    // dimension being unpacked.
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    // Indices into the destination memref/tensor of the new (N-1)-d transfer.
    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
    // On tensors, the loop state carries the most recent tensor value to
    // write into; on memrefs there is no state and the original source is
    // used directly.
    auto source = loopState.empty() ? xferOp.getSource() : loopState[0];
    // Only tensor transfers produce a result; a null Type means "no result"
    // for the memref case.
    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
    auto newXferOp = b.create<vector::TransferWriteOp>(
        loc, type, vec, source, xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
        inBoundsAttr);

    // Re-label the new op if it still exceeds the target rank, so it gets
    // unpacked again by TransferOpConversion.
    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: nothing
  /// is written, so just forward the loop state (if any).
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange loopState) {
    return isTensorOp(xferOp) ? loopState[0] : Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
                      scf::ForOp forOp) {
    if (isTensorOp(xferOp)) {
      // The generated loop yields the updated tensor, which replaces the
      // original op's result.
      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
      rewriter.replaceOp(xferOp, forOp->getResult(0));
    } else {
      rewriter.eraseOp(xferOp);
    }
  }

  /// Return the initial loop state for the generated scf.for loop: the source
  /// tensor for tensor transfers, no state (null Value) otherwise.
  static Value initialLoopState(TransferWriteOp xferOp) {
    return isTensorOp(xferOp) ? xferOp.getSource() : Value();
  }
};

/// Return success if `xferOp` is eligible for preparation (labeling plus
/// buffer allocation). Fails if the op was already labeled, its rank is
/// already at or below the target rank, its leading dimension is scalable,
/// it operates on a tensor while tensor lowering is disabled, or it changes
/// the element type.
template <typename OpTy>
LogicalResult checkPrepareXferOp(OpTy xferOp,
                                 VectorTransferToSCFOptions options) {
  if (xferOp->hasAttr(kPassLabel))
    return failure();
  if (xferOp.getVectorType().getRank() <= options.targetRank)
    return failure();
  // Currently the unpacking of the leading dimension into the memref is not
  // supported for scalable dimensions.
  if (xferOp.getVectorType().getScalableDims().front())
    return failure();
  if (isTensorOp(xferOp) && !options.lowerTensors)
    return failure();
  // Transfer ops that modify the element type are not supported atm.
  if (xferOp.getVectorType().getElementType() !=
      xferOp.getShapedType().getElementType())
    return failure();
  return success();
}

/// Prepare a TransferReadOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
/// 3. Store the result of the TransferReadOp into the temporary buffer.
/// 4. Load the result from the temporary buffer and replace all uses of the
///    original TransferReadOp with this load.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %cst
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// %1 = vector.transfer_read %A[%a, %b, %c], %cst
///     { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32>
/// memref.store %1, %0[] : memref<vector<5x4xf32>>
/// %vec = memref.load %0[] : memref<vector<5x4xf32>>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
5942ca887deSMatthias Springer struct PrepareTransferReadConversion 5952ca887deSMatthias Springer : public VectorToSCFPattern<TransferReadOp> { 5962ca887deSMatthias Springer using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern; 5970f241638SMatthias Springer 5980f241638SMatthias Springer LogicalResult matchAndRewrite(TransferReadOp xferOp, 5990f241638SMatthias Springer PatternRewriter &rewriter) const override { 600fb7ec1f1SMatthias Springer if (checkPrepareXferOp(xferOp, options).failed()) 6010f241638SMatthias Springer return failure(); 6020f241638SMatthias Springer 6039f5afc3dSRik Huijzer auto buffers = allocBuffers(rewriter, xferOp); 6049f5afc3dSRik Huijzer auto *newXfer = rewriter.clone(*xferOp.getOperation()); 6050f241638SMatthias Springer newXfer->setAttr(kPassLabel, rewriter.getUnitAttr()); 6067c38fd60SJacques Pienaar if (xferOp.getMask()) { 6077c38fd60SJacques Pienaar dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign( 6080f241638SMatthias Springer buffers.maskBuffer); 6090f241638SMatthias Springer } 6100f241638SMatthias Springer 6116825bfe2SNicolas Vasilache Location loc = xferOp.getLoc(); 6126825bfe2SNicolas Vasilache rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0), 6136825bfe2SNicolas Vasilache buffers.dataBuffer); 6140f241638SMatthias Springer rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer); 6154ead2cf7SAlex Zinenko 6164ead2cf7SAlex Zinenko return success(); 6174ead2cf7SAlex Zinenko } 6180f241638SMatthias Springer }; 6190f241638SMatthias Springer 6200f241638SMatthias Springer /// Prepare a TransferWriteOp for progressive lowering. 6210f241638SMatthias Springer /// 6220f241638SMatthias Springer /// 1. Allocate a temporary buffer. 6230f241638SMatthias Springer /// 2. Store the vector into the buffer. 6240f241638SMatthias Springer /// 3. Load the vector from the buffer again. 6250f241638SMatthias Springer /// 4. 
Use the loaded vector as a TransferWriteOp operand and label the op, 6260f241638SMatthias Springer /// marking it eligible for progressive lowering via TransferOpConversion. 6270f241638SMatthias Springer /// 6280f241638SMatthias Springer /// E.g.: 6290f241638SMatthias Springer /// ``` 6300f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b, %c] 6310f241638SMatthias Springer /// : vector<5x4xf32>, memref<?x?x?xf32> 6320f241638SMatthias Springer /// ``` 6330f241638SMatthias Springer /// is rewritten to: 6340f241638SMatthias Springer /// ``` 6350f241638SMatthias Springer /// %0 = memref.alloca() : memref<vector<5x4xf32>> 6360f241638SMatthias Springer /// memref.store %vec, %0[] : memref<vector<5x4xf32>> 6370f241638SMatthias Springer /// %1 = memref.load %0[] : memref<vector<5x4xf32>> 6380f241638SMatthias Springer /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ } 6390f241638SMatthias Springer /// : vector<5x4xf32>, memref<?x?x?xf32> 6400f241638SMatthias Springer /// ``` 6410f241638SMatthias Springer /// 6420f241638SMatthias Springer /// Note: A second temporary buffer may be allocated for the `mask` operand. 
6430f241638SMatthias Springer struct PrepareTransferWriteConversion 6442ca887deSMatthias Springer : public VectorToSCFPattern<TransferWriteOp> { 6452ca887deSMatthias Springer using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern; 6460f241638SMatthias Springer 6470f241638SMatthias Springer LogicalResult matchAndRewrite(TransferWriteOp xferOp, 6480f241638SMatthias Springer PatternRewriter &rewriter) const override { 649fb7ec1f1SMatthias Springer if (checkPrepareXferOp(xferOp, options).failed()) 6500f241638SMatthias Springer return failure(); 6510f241638SMatthias Springer 6526825bfe2SNicolas Vasilache Location loc = xferOp.getLoc(); 6536825bfe2SNicolas Vasilache auto buffers = allocBuffers(rewriter, xferOp); 6547c38fd60SJacques Pienaar rewriter.create<memref::StoreOp>(loc, xferOp.getVector(), 6557c38fd60SJacques Pienaar buffers.dataBuffer); 6566825bfe2SNicolas Vasilache auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer); 6575fcf907bSMatthias Springer rewriter.modifyOpInPlace(xferOp, [&]() { 6587c38fd60SJacques Pienaar xferOp.getVectorMutable().assign(loadedVec); 6590f241638SMatthias Springer xferOp->setAttr(kPassLabel, rewriter.getUnitAttr()); 6600f241638SMatthias Springer }); 6610f241638SMatthias Springer 6627c38fd60SJacques Pienaar if (xferOp.getMask()) { 6635fcf907bSMatthias Springer rewriter.modifyOpInPlace(xferOp, [&]() { 6647c38fd60SJacques Pienaar xferOp.getMaskMutable().assign(buffers.maskBuffer); 6657c38fd60SJacques Pienaar }); 6660f241638SMatthias Springer } 6670f241638SMatthias Springer 6680f241638SMatthias Springer return success(); 6690f241638SMatthias Springer } 6700f241638SMatthias Springer }; 6710f241638SMatthias Springer 672f36e909dSBenjamin Maxwell /// Decompose a n-D PrintOp into a loop of elementary/scalar prints. This allows 673f36e909dSBenjamin Maxwell /// printing both 1D scalable vectors and n-D fixed size vectors. 
///
/// E.g.:
/// ```
/// vector.print %v : vector<[4]xi32>
/// ```
/// is rewritten to:
/// ```
/// %c0 = arith.constant 0 : index
/// %c4 = arith.constant 4 : index
/// %c1 = arith.constant 1 : index
/// %vscale = vector.vscale
/// %length = arith.muli %vscale, %c4 : index
/// %lastIndex = arith.subi %length, %c1 : index
/// vector.print punctuation <open>
/// scf.for %i = %c0 to %length step %c1 {
///   %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
///   vector.print %el : i32 punctuation <no_punctuation>
///   %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
///   scf.if %notLastIndex {
///     vector.print punctuation <comma>
///   }
/// }
/// vector.print punctuation <close>
/// vector.print
/// ```
struct DecomposePrintOpConversion : public VectorToSCFPattern<vector::PrintOp> {
  using VectorToSCFPattern<vector::PrintOp>::VectorToSCFPattern;
  LogicalResult matchAndRewrite(vector::PrintOp printOp,
                                PatternRewriter &rewriter) const override {
    // Only vector prints with a source are decomposed here; punctuation-only
    // prints are left alone.
    if (!printOp.getSource())
      return failure();

    VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
    if (!vectorType)
      return failure();

    // Currently >= 2D scalable vectors are not supported.
    // These can't be lowered to LLVM (as LLVM does not support scalable vectors
    // of scalable vectors), and due to limitations of current ops can't be
    // indexed with SSA values or flattened. This may change after
    // https://reviews.llvm.org/D155034, though there still needs to be a path
    // for lowering to LLVM.
    if (vectorType.getRank() > 1 && vectorType.isScalable())
      return failure();

    auto loc = printOp.getLoc();
    auto value = printOp.getSource();

    if (auto intTy = dyn_cast<IntegerType>(vectorType.getElementType())) {
      // Oddly sized integers are (somewhat) buggy on a lot of backends, so to
      // avoid issues extend them to a more standard size.
      // https://github.com/llvm/llvm-project/issues/30613
      auto width = intTy.getWidth();
      // Round up to the next power of two, with a floor of 8 bits.
      auto legalWidth = llvm::NextPowerOf2(std::max(8u, width) - 1);
      auto legalIntTy = IntegerType::get(rewriter.getContext(), legalWidth,
                                         intTy.getSignedness());
      // arith can only take signless integers, so we must cast back and forth.
      auto signlessSourceVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(intTy));
      auto signlessTargetVectorType =
          vectorType.cloneWith({}, getIntTypeWithSignlessSemantics(legalIntTy));
      auto targetVectorType = vectorType.cloneWith({}, legalIntTy);
      value = rewriter.create<vector::BitCastOp>(loc, signlessSourceVectorType,
                                                 value);
      if (value.getType() != signlessTargetVectorType) {
        // i1 and unsigned values are zero-extended; signed values are
        // sign-extended to the legal width.
        if (width == 1 || intTy.isUnsigned())
          value = rewriter.create<arith::ExtUIOp>(loc, signlessTargetVectorType,
                                                  value);
        else
          value = rewriter.create<arith::ExtSIOp>(loc, signlessTargetVectorType,
                                                  value);
      }
      value = rewriter.create<vector::BitCastOp>(loc, targetVectorType, value);
      vectorType = targetVectorType;
    }

    auto scalableDimensions = vectorType.getScalableDims();
    auto shape = vectorType.getShape();
    // Rank-0 vectors are printed as if they contained a single element.
    constexpr int64_t singletonShape[] = {1};
    if (vectorType.getRank() == 0)
      shape = singletonShape;

    if (vectorType.getRank() != 1) {
      // Flatten n-D vectors to 1D. This is done to allow indexing with a
      // non-constant value (which can currently only be done via
      // vector.extractelement for 1D vectors).
      auto flatLength = std::accumulate(shape.begin(), shape.end(), 1,
                                        std::multiplies<int64_t>());
      auto flatVectorType =
          VectorType::get({flatLength}, vectorType.getElementType());
      value = rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, value);
    }

    // One nested scf.for per (original) dimension; the body of the innermost
    // loop prints one scalar element.
    vector::PrintOp firstClose;
    SmallVector<Value, 8> loopIndices;
    for (unsigned d = 0; d < shape.size(); d++) {
      // Setup loop bounds and step.
      Value lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
      Value upperBound = rewriter.create<arith::ConstantIndexOp>(loc, shape[d]);
      Value step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
      if (!scalableDimensions.empty() && scalableDimensions[d]) {
        // Scalable dimension: runtime trip count is shape[d] * vscale.
        auto vscale = rewriter.create<vector::VectorScaleOp>(
            loc, rewriter.getIndexType());
        upperBound = rewriter.create<arith::MulIOp>(loc, upperBound, vscale);
      }
      auto lastIndex = rewriter.create<arith::SubIOp>(loc, upperBound, step);

      // Create a loop to print the elements surrounded by parentheses.
      rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
      auto loop =
          rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
      auto printClose = rewriter.create<vector::PrintOp>(
          loc, vector::PrintPunctuation::Close);
      // Remember the outermost <close> so the trailing punctuation can be
      // inserted after it once all loops are built.
      if (!firstClose)
        firstClose = printClose;

      auto loopIdx = loop.getInductionVar();
      loopIndices.push_back(loopIdx);

      // Print a comma after all but the last element.
      rewriter.setInsertionPointToStart(loop.getBody());
      auto notLastIndex = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::ult, loopIdx, lastIndex);
      rewriter.create<scf::IfOp>(loc, notLastIndex,
                                 [&](OpBuilder &builder, Location loc) {
                                   builder.create<vector::PrintOp>(
                                       loc, vector::PrintPunctuation::Comma);
                                   builder.create<scf::YieldOp>(loc);
                                 });

      // Subsequent ops (inner loop or element print) go at the start of this
      // loop's body.
      rewriter.setInsertionPointToStart(loop.getBody());
    }

    // Compute the flattened index.
    // Note: For the > rank 1 vectors this assumes non-scalable.
    Value flatIndex;
    auto currentStride = 1;
    for (int d = shape.size() - 1; d >= 0; d--) {
      auto stride = rewriter.create<arith::ConstantIndexOp>(loc, currentStride);
      auto index = rewriter.create<arith::MulIOp>(loc, stride, loopIndices[d]);
      if (flatIndex)
        flatIndex = rewriter.create<arith::AddIOp>(loc, flatIndex, index);
      else
        flatIndex = index;
      currentStride *= shape[d];
    }

    // Print the scalar elements in the inner most loop.
    auto element =
        rewriter.create<vector::ExtractElementOp>(loc, value, flatIndex);
    rewriter.create<vector::PrintOp>(loc, element,
                                     vector::PrintPunctuation::NoPunctuation);

    // Emit the original op's punctuation after the outermost <close>, then
    // remove the original print.
    rewriter.setInsertionPointAfter(firstClose);
    rewriter.create<vector::PrintOp>(loc, printOp.getPunctuation());
    rewriter.eraseOp(printOp);
    return success();
  }

  /// Return the signless integer type with the same width as `intTy`.
  static IntegerType getIntTypeWithSignlessSemantics(IntegerType intTy) {
    return IntegerType::get(intTy.getContext(), intTy.getWidth(),
                            IntegerType::Signless);
  };
};

/// Progressive lowering of vector transfer ops: Unpack one dimension.
///
/// 1. Unpack one dimension from the current buffer type and cast the buffer
///    to that new type. E.g.:
///    ```
///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
///    vector.transfer_write %vec ...
///    ```
///    The following cast is generated:
///    ```
///    %casted = vector.type_cast %0
///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
///    ```
/// 2. Generate a for loop and rewrite the transfer op according to the
///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
///    out-of-bounds, generate an if-check and handle both cases separately.
/// 3. Clean up according to the corresponding Strategy<OpTy>.
///
/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
/// source (as opposed to a memref source), then each iteration of the generated
/// scf.for loop yields the new tensor value. E.g.:
/// ```
/// %result = scf.for i = 0 to 5 {
///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
///   %1 = vector.transfer_write %0, %source[...]
///       : vector<4x3xf32>, tensor<5x4x3xf32>
///   scf.yield %1 : tensor<5x4x3xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded as the rank is strictly decreasing.
    this->setHasBoundedRewriteRecursion();
  }

  /// Compute the indices for loading the (possibly casted) mask buffer in the
  /// current loop iteration.
  static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
                                       SmallVectorImpl<Value> &loadIndices,
                                       Value iv) {
    assert(xferOp.getMask() && "Expected transfer op to have mask");

    // Add load indices from the previous iteration.
    // The mask buffer depends on the permutation map, which makes determining
    // the indices quite complex, so this is why we need to "look back" to the
    // previous iteration to find the right indices.
    Value maskBuffer = getMaskBuffer(xferOp);
    for (Operation *user : maskBuffer.getUsers()) {
      // If there is no previous load op, then the indices are empty.
      if (auto loadOp = dyn_cast<memref::LoadOp>(user)) {
        Operation::operand_range prevIndices = loadOp.getIndices();
        loadIndices.append(prevIndices.begin(), prevIndices.end());
        break;
      }
    }

    // In case of broadcast: Use same indices to load from memref
    // as before.
    if (!xferOp.isBroadcastDim(0))
      loadIndices.push_back(iv);
  }

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    // Only ops labeled by the prepare patterns are lowered here.
    if (!xferOp->hasAttr(kPassLabel))
      return failure();

    // Find and cast data buffer. How the buffer can be found depends on OpTy.
    ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
    Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
    auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
    FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
    if (failed(castedDataType))
      return failure();

    auto castedDataBuffer =
        locB.create<vector::TypeCastOp>(*castedDataType, dataBuffer);

    // If the xferOp has a mask: Find and cast mask buffer.
    Value castedMaskBuffer;
    if (xferOp.getMask()) {
      Value maskBuffer = getMaskBuffer(xferOp);
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        // Do not unpack a dimension of the mask, if:
        // * To-be-unpacked transfer op dimension is a broadcast.
        // * Mask is 1D, i.e., the mask cannot be further unpacked.
        //   (That means that all remaining dimensions of the transfer op must
        //   be broadcasted.)
        castedMaskBuffer = maskBuffer;
      } else {
        // It's safe to assume the mask buffer can be unpacked if the data
        // buffer was unpacked.
        auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
        MemRefType castedMaskType = *unpackOneDim(maskBufferType);
        castedMaskBuffer =
            locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
      }
    }

    // Loop bounds and step.
    auto lb = locB.create<arith::ConstantIndexOp>(0);
    auto ub = locB.create<arith::ConstantIndexOp>(
        castedDataType->getDimSize(castedDataType->getRank() - 1));
    auto step = locB.create<arith::ConstantIndexOp>(1);
    // TransferWriteOps that operate on tensors return the modified tensor and
    // require a loop state.
    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);

    // Generate for loop.
    auto result = locB.create<scf::ForOp>(
        lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Type stateType = loopState.empty() ? Type() : loopState[0].getType();

          auto result = generateInBoundsCheck(
              b, xferOp, iv, unpackedDim(xferOp),
              stateType ? TypeRange(stateType) : TypeRange(),
              /*inBoundsCase=*/
              [&](OpBuilder &b, Location loc) {
                // Create new transfer op.
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
                    b, this->options, xferOp, castedDataBuffer, iv, loopState);

                // If old transfer op has a mask: Set mask on new transfer op.
                // Special case: If the mask of the old transfer op is 1D and
                // the unpacked dim is not a broadcast, no mask is needed on
                // the new transfer op.
                if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
                                         xferOp.getMaskType().getRank() > 1)) {
                  OpBuilder::InsertionGuard guard(b);
                  b.setInsertionPoint(newXfer); // Insert load before newXfer.

                  SmallVector<Value, 8> loadIndices;
                  getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
                                           loadIndices, iv);
                  auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                       loadIndices);
                  rewriter.modifyOpInPlace(newXfer, [&]() {
                    newXfer.getMaskMutable().assign(mask);
                  });
                }

                return loopState.empty() ? Value() : newXfer->getResult(0);
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                return Strategy<OpTy>::handleOutOfBoundsDim(
                    b, xferOp, castedDataBuffer, iv, loopState);
              });

          maybeYieldValue(b, loc, !loopState.empty(), result);
        });

    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
    return success();
  }
};

/// Retrieves the dimensions sizes of a mask. Currently supports CreateMaskOp
/// and ConstantMaskOp.
100127a713f5SBenjamin Maxwell template <typename VscaleConstantBuilder> 100227a713f5SBenjamin Maxwell static FailureOr<SmallVector<OpFoldResult>> 100327a713f5SBenjamin Maxwell getMaskDimSizes(Value mask, VscaleConstantBuilder &createVscaleMultiple) { 100427a713f5SBenjamin Maxwell if (!mask) 100527a713f5SBenjamin Maxwell return SmallVector<OpFoldResult>{}; 100627a713f5SBenjamin Maxwell if (auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>()) { 100727a713f5SBenjamin Maxwell return llvm::map_to_vector(createMaskOp.getOperands(), [](Value dimSize) { 100827a713f5SBenjamin Maxwell return OpFoldResult(dimSize); 100927a713f5SBenjamin Maxwell }); 101027a713f5SBenjamin Maxwell } 101127a713f5SBenjamin Maxwell if (auto constantMask = mask.getDefiningOp<vector::ConstantMaskOp>()) { 101227a713f5SBenjamin Maxwell int dimIdx = 0; 101327a713f5SBenjamin Maxwell VectorType maskType = constantMask.getVectorType(); 101427a713f5SBenjamin Maxwell auto indexType = IndexType::get(mask.getContext()); 101527a713f5SBenjamin Maxwell return llvm::map_to_vector( 101627a713f5SBenjamin Maxwell constantMask.getMaskDimSizes(), [&](int64_t dimSize) { 101727a713f5SBenjamin Maxwell // A scalable dim in a constant_mask means vscale x dimSize. 101827a713f5SBenjamin Maxwell if (maskType.getScalableDims()[dimIdx++]) 101927a713f5SBenjamin Maxwell return OpFoldResult(createVscaleMultiple(dimSize)); 102027a713f5SBenjamin Maxwell return OpFoldResult(IntegerAttr::get(indexType, dimSize)); 102127a713f5SBenjamin Maxwell }); 102227a713f5SBenjamin Maxwell } 102327a713f5SBenjamin Maxwell return failure(); 102427a713f5SBenjamin Maxwell } 102527a713f5SBenjamin Maxwell 102627a713f5SBenjamin Maxwell /// Scalable vector lowering of transfer_write(transpose). This lowering only 102727a713f5SBenjamin Maxwell /// supports rank 2 (scalable) vectors, but can be used in conjunction with 102827a713f5SBenjamin Maxwell /// `UnrollTransferWriteConversion` to support n-D cases. 
The unroll conversion 102927a713f5SBenjamin Maxwell /// unrolls until the first scalable dimension. 103027a713f5SBenjamin Maxwell /// 103127a713f5SBenjamin Maxwell /// Example: 103227a713f5SBenjamin Maxwell /// 103327a713f5SBenjamin Maxwell /// BEFORE: 103427a713f5SBenjamin Maxwell /// ```mlir 103527a713f5SBenjamin Maxwell /// %transpose = vector.transpose %vec, [1, 0] 103627a713f5SBenjamin Maxwell /// : vector<4x[4]xf32> to vector<[4]x4xf32> 103727a713f5SBenjamin Maxwell /// vector.transfer_write %transpose, %dest[%i, %j] {in_bounds = [true, true]} 103827a713f5SBenjamin Maxwell /// : vector<[4]x4xf32>, memref<?x?xf32> 103927a713f5SBenjamin Maxwell /// ``` 104027a713f5SBenjamin Maxwell /// 104127a713f5SBenjamin Maxwell /// AFTER: 104227a713f5SBenjamin Maxwell /// ```mlir 104327a713f5SBenjamin Maxwell /// %c1 = arith.constant 1 : index 104427a713f5SBenjamin Maxwell /// %c4 = arith.constant 4 : index 104527a713f5SBenjamin Maxwell /// %c0 = arith.constant 0 : index 104627a713f5SBenjamin Maxwell /// %0 = vector.extract %arg0[0] : vector<[4]xf32> from vector<4x[4]xf32> 104727a713f5SBenjamin Maxwell /// %1 = vector.extract %arg0[1] : vector<[4]xf32> from vector<4x[4]xf32> 104827a713f5SBenjamin Maxwell /// %2 = vector.extract %arg0[2] : vector<[4]xf32> from vector<4x[4]xf32> 104927a713f5SBenjamin Maxwell /// %3 = vector.extract %arg0[3] : vector<[4]xf32> from vector<4x[4]xf32> 105027a713f5SBenjamin Maxwell /// %vscale = vector.vscale 105127a713f5SBenjamin Maxwell /// %c4_vscale = arith.muli %vscale, %c4 : index 105227a713f5SBenjamin Maxwell /// scf.for %idx = %c0 to %c4_vscale step %c1 { 105327a713f5SBenjamin Maxwell /// %4 = vector.extract %0[%idx] : f32 from vector<[4]xf32> 105427a713f5SBenjamin Maxwell /// %5 = vector.extract %1[%idx] : f32 from vector<[4]xf32> 105527a713f5SBenjamin Maxwell /// %6 = vector.extract %2[%idx] : f32 from vector<[4]xf32> 105627a713f5SBenjamin Maxwell /// %7 = vector.extract %3[%idx] : f32 from vector<[4]xf32> 105727a713f5SBenjamin Maxwell 
/// %slice_i = affine.apply #map(%idx)[%i] 105827a713f5SBenjamin Maxwell /// %slice = vector.from_elements %4, %5, %6, %7 : vector<4xf32> 105927a713f5SBenjamin Maxwell /// vector.transfer_write %slice, %arg1[%slice_i, %j] {in_bounds = [true]} 106027a713f5SBenjamin Maxwell /// : vector<4xf32>, memref<?x?xf32> 106127a713f5SBenjamin Maxwell /// } 106227a713f5SBenjamin Maxwell /// ``` 106327a713f5SBenjamin Maxwell struct ScalableTransposeTransferWriteConversion 106427a713f5SBenjamin Maxwell : VectorToSCFPattern<vector::TransferWriteOp> { 106527a713f5SBenjamin Maxwell using VectorToSCFPattern::VectorToSCFPattern; 106627a713f5SBenjamin Maxwell 106727a713f5SBenjamin Maxwell LogicalResult matchAndRewrite(TransferWriteOp writeOp, 106827a713f5SBenjamin Maxwell PatternRewriter &rewriter) const override { 106927a713f5SBenjamin Maxwell if (failed(checkLowerTensors(writeOp, rewriter))) 107027a713f5SBenjamin Maxwell return failure(); 107127a713f5SBenjamin Maxwell 107227a713f5SBenjamin Maxwell VectorType vectorType = writeOp.getVectorType(); 107327a713f5SBenjamin Maxwell 107427a713f5SBenjamin Maxwell // Note: By comparing the scalable dims to an ArrayRef of length two this 107527a713f5SBenjamin Maxwell // implicitly checks the rank (is also two). 
107627a713f5SBenjamin Maxwell ArrayRef<bool> scalableFlags = vectorType.getScalableDims(); 107727a713f5SBenjamin Maxwell if (scalableFlags != ArrayRef<bool>{true, false}) { 107827a713f5SBenjamin Maxwell return rewriter.notifyMatchFailure( 107927a713f5SBenjamin Maxwell writeOp, "expected vector of the form vector<[N]xMxty>"); 108027a713f5SBenjamin Maxwell } 108127a713f5SBenjamin Maxwell 108227a713f5SBenjamin Maxwell auto permutationMap = writeOp.getPermutationMap(); 108327a713f5SBenjamin Maxwell if (!permutationMap.isIdentity()) { 108427a713f5SBenjamin Maxwell return rewriter.notifyMatchFailure( 108527a713f5SBenjamin Maxwell writeOp, "non-identity permutations are unsupported (lower first)"); 108627a713f5SBenjamin Maxwell } 108727a713f5SBenjamin Maxwell 108827a713f5SBenjamin Maxwell // Note: This pattern is only lowering the leading dimension (to a loop), 108927a713f5SBenjamin Maxwell // so we only check if the leading dimension is in bounds. The in-bounds 109027a713f5SBenjamin Maxwell // attribute for the trailing dimension will be propagated. 
109127a713f5SBenjamin Maxwell if (!writeOp.isDimInBounds(0)) { 109227a713f5SBenjamin Maxwell return rewriter.notifyMatchFailure( 109327a713f5SBenjamin Maxwell writeOp, "out-of-bounds dims are unsupported (use masking)"); 109427a713f5SBenjamin Maxwell } 109527a713f5SBenjamin Maxwell 109627a713f5SBenjamin Maxwell Value vector = writeOp.getVector(); 109727a713f5SBenjamin Maxwell auto transposeOp = vector.getDefiningOp<vector::TransposeOp>(); 109827a713f5SBenjamin Maxwell if (!transposeOp || 109927a713f5SBenjamin Maxwell transposeOp.getPermutation() != ArrayRef<int64_t>{1, 0}) { 110027a713f5SBenjamin Maxwell return rewriter.notifyMatchFailure(writeOp, "source not transpose"); 110127a713f5SBenjamin Maxwell } 110227a713f5SBenjamin Maxwell 110327a713f5SBenjamin Maxwell auto loc = writeOp.getLoc(); 110427a713f5SBenjamin Maxwell auto createVscaleMultiple = 110527a713f5SBenjamin Maxwell vector::makeVscaleConstantBuilder(rewriter, loc); 110627a713f5SBenjamin Maxwell 110727a713f5SBenjamin Maxwell auto maskDims = getMaskDimSizes(writeOp.getMask(), createVscaleMultiple); 110827a713f5SBenjamin Maxwell if (failed(maskDims)) { 110927a713f5SBenjamin Maxwell return rewriter.notifyMatchFailure(writeOp, 111027a713f5SBenjamin Maxwell "failed to resolve mask dims"); 111127a713f5SBenjamin Maxwell } 111227a713f5SBenjamin Maxwell 111327a713f5SBenjamin Maxwell int64_t fixedDimSize = vectorType.getDimSize(1); 111427a713f5SBenjamin Maxwell auto fixedDimOffsets = llvm::seq(fixedDimSize); 111527a713f5SBenjamin Maxwell 111627a713f5SBenjamin Maxwell // Extract all slices from the source of the transpose. 
111727a713f5SBenjamin Maxwell auto transposeSource = transposeOp.getVector(); 111827a713f5SBenjamin Maxwell SmallVector<Value> transposeSourceSlices = 111927a713f5SBenjamin Maxwell llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value { 112027a713f5SBenjamin Maxwell return rewriter.create<vector::ExtractOp>(loc, transposeSource, idx); 112127a713f5SBenjamin Maxwell }); 112227a713f5SBenjamin Maxwell 112327a713f5SBenjamin Maxwell // Loop bounds and step. 112427a713f5SBenjamin Maxwell auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0); 112527a713f5SBenjamin Maxwell auto ub = 112627a713f5SBenjamin Maxwell maskDims->empty() 112727a713f5SBenjamin Maxwell ? Value(createVscaleMultiple(vectorType.getDimSize(0))) 112827a713f5SBenjamin Maxwell : vector::getAsValues(rewriter, loc, maskDims->front()).front(); 112927a713f5SBenjamin Maxwell auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); 113027a713f5SBenjamin Maxwell 113127a713f5SBenjamin Maxwell // Generate a new mask for the slice. 113227a713f5SBenjamin Maxwell VectorType sliceType = VectorType::Builder(vectorType).dropDim(0); 113327a713f5SBenjamin Maxwell Value sliceMask = nullptr; 113427a713f5SBenjamin Maxwell if (!maskDims->empty()) { 113527a713f5SBenjamin Maxwell sliceMask = rewriter.create<vector::CreateMaskOp>( 113627a713f5SBenjamin Maxwell loc, sliceType.clone(rewriter.getI1Type()), 113727a713f5SBenjamin Maxwell ArrayRef<OpFoldResult>(*maskDims).drop_front()); 113827a713f5SBenjamin Maxwell } 113927a713f5SBenjamin Maxwell 114027a713f5SBenjamin Maxwell Value initDest = isTensorOp(writeOp) ? writeOp.getSource() : Value{}; 114127a713f5SBenjamin Maxwell ValueRange initLoopArgs = initDest ? 
initDest : ValueRange{}; 114227a713f5SBenjamin Maxwell auto result = rewriter.create<scf::ForOp>( 114327a713f5SBenjamin Maxwell loc, lb, ub, step, initLoopArgs, 114427a713f5SBenjamin Maxwell [&](OpBuilder &b, Location loc, Value iv, ValueRange loopIterArgs) { 114527a713f5SBenjamin Maxwell // Indices for the new transfer op. 114627a713f5SBenjamin Maxwell SmallVector<Value, 8> xferIndices; 114727a713f5SBenjamin Maxwell getXferIndices(b, writeOp, iv, xferIndices); 114827a713f5SBenjamin Maxwell 114927a713f5SBenjamin Maxwell // Extract a transposed slice from the source vector. 115027a713f5SBenjamin Maxwell SmallVector<Value> transposeElements = 115127a713f5SBenjamin Maxwell llvm::map_to_vector(fixedDimOffsets, [&](int64_t idx) -> Value { 115227a713f5SBenjamin Maxwell return b.create<vector::ExtractOp>( 115327a713f5SBenjamin Maxwell loc, transposeSourceSlices[idx], iv); 115427a713f5SBenjamin Maxwell }); 115527a713f5SBenjamin Maxwell auto sliceVec = b.create<vector::FromElementsOp>(loc, sliceType, 115627a713f5SBenjamin Maxwell transposeElements); 115727a713f5SBenjamin Maxwell 115827a713f5SBenjamin Maxwell // Create the transfer_write for the slice. 115927a713f5SBenjamin Maxwell Value dest = 116027a713f5SBenjamin Maxwell loopIterArgs.empty() ? writeOp.getSource() : loopIterArgs.front(); 116127a713f5SBenjamin Maxwell auto newWriteOp = b.create<vector::TransferWriteOp>( 116227a713f5SBenjamin Maxwell loc, sliceVec, dest, xferIndices, 116327a713f5SBenjamin Maxwell ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()); 116427a713f5SBenjamin Maxwell if (sliceMask) 116527a713f5SBenjamin Maxwell newWriteOp.getMaskMutable().assign(sliceMask); 116627a713f5SBenjamin Maxwell 116727a713f5SBenjamin Maxwell // Yield from the loop. 116827a713f5SBenjamin Maxwell b.create<scf::YieldOp>(loc, loopIterArgs.empty() 116927a713f5SBenjamin Maxwell ? 
ValueRange{} 117027a713f5SBenjamin Maxwell : newWriteOp.getResult()); 117127a713f5SBenjamin Maxwell }); 117227a713f5SBenjamin Maxwell 117327a713f5SBenjamin Maxwell if (isTensorOp(writeOp)) 117427a713f5SBenjamin Maxwell rewriter.replaceOp(writeOp, result); 117527a713f5SBenjamin Maxwell else 117627a713f5SBenjamin Maxwell rewriter.eraseOp(writeOp); 117727a713f5SBenjamin Maxwell 117827a713f5SBenjamin Maxwell return success(); 117927a713f5SBenjamin Maxwell } 118027a713f5SBenjamin Maxwell }; 118127a713f5SBenjamin Maxwell 1182a088bed4SMatthias Springer } // namespace lowering_n_d 1183a088bed4SMatthias Springer 1184a088bed4SMatthias Springer namespace lowering_n_d_unrolled { 1185a088bed4SMatthias Springer 11860f241638SMatthias Springer /// If the original transfer op has a mask, compute the mask of the new transfer 11870f241638SMatthias Springer /// op (for the current iteration `i`) and assign it. 11880f241638SMatthias Springer template <typename OpTy> 11896825bfe2SNicolas Vasilache static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp, 11900f241638SMatthias Springer int64_t i) { 11917c38fd60SJacques Pienaar if (!xferOp.getMask()) 11920f241638SMatthias Springer return; 11930f241638SMatthias Springer 11940f241638SMatthias Springer if (xferOp.isBroadcastDim(0)) { 11950f241638SMatthias Springer // To-be-unpacked dimension is a broadcast, which does not have a 11960f241638SMatthias Springer // corresponding mask dimension. Mask attribute remains unchanged. 11977c38fd60SJacques Pienaar newXferOp.getMaskMutable().assign(xferOp.getMask()); 11980f241638SMatthias Springer return; 11990f241638SMatthias Springer } 12000f241638SMatthias Springer 12010f241638SMatthias Springer if (xferOp.getMaskType().getRank() > 1) { 12020f241638SMatthias Springer // Unpack one dimension of the mask. 12036825bfe2SNicolas Vasilache OpBuilder::InsertionGuard guard(b); 12046825bfe2SNicolas Vasilache b.setInsertionPoint(newXferOp); // Insert load before newXfer. 
12050f241638SMatthias Springer 12060f241638SMatthias Springer llvm::SmallVector<int64_t, 1> indices({i}); 12076825bfe2SNicolas Vasilache Location loc = xferOp.getLoc(); 12087c38fd60SJacques Pienaar auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices); 12097c38fd60SJacques Pienaar newXferOp.getMaskMutable().assign(newMask); 12100f241638SMatthias Springer } 12110f241638SMatthias Springer 12120f241638SMatthias Springer // If we end up here: The mask of the old transfer op is 1D and the unpacked 12130f241638SMatthias Springer // dim is not a broadcast, so no mask is needed on the new transfer op. 12140f241638SMatthias Springer // `generateInBoundsCheck` will have evaluated the mask already. 12150f241638SMatthias Springer } 12160f241638SMatthias Springer 12170f241638SMatthias Springer /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one 12180f241638SMatthias Springer /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no 12190f241638SMatthias Springer /// memref buffer is allocated and the SCF loop is fully unrolled. 
12200f241638SMatthias Springer /// 12210f241638SMatthias Springer /// ``` 12220f241638SMatthias Springer /// E.g.: 12230f241638SMatthias Springer /// ``` 12240f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a, %b, %c], %padding 12250f241638SMatthias Springer /// : memref<?x?x?xf32>, vector<5x4xf32> 12260f241638SMatthias Springer /// ``` 12270f241638SMatthias Springer /// is rewritten to IR such as (simplified): 12280f241638SMatthias Springer /// ``` 12290f241638SMatthias Springer /// %v_init = splat %padding : vector<5x4xf32> 12300f241638SMatthias Springer /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding 12310f241638SMatthias Springer /// : memref<?x?x?xf32>, vector<4xf32> 12320f241638SMatthias Springer /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32> 12330f241638SMatthias Springer /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding 12340f241638SMatthias Springer /// : memref<?x?x?xf32>, vector<4xf32> 12350f241638SMatthias Springer /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32> 12360f241638SMatthias Springer /// ... 12370f241638SMatthias Springer /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding 12380f241638SMatthias Springer /// : memref<?x?x?xf32>, vector<4xf32> 12390f241638SMatthias Springer /// %vec = vector.insert %tmp1, %v3[4] : vector<4xf32> into vector<5x4xf32> 12400f241638SMatthias Springer /// ``` 12410f241638SMatthias Springer /// 12420f241638SMatthias Springer /// Note: As an optimization, if the result of the original TransferReadOp 12430f241638SMatthias Springer /// was directly inserted into another vector, no new %v_init vector is created. 12440f241638SMatthias Springer /// Instead, the new TransferReadOp results are inserted into that vector. 
12452ca887deSMatthias Springer struct UnrollTransferReadConversion 12462ca887deSMatthias Springer : public VectorToSCFPattern<TransferReadOp> { 12472ca887deSMatthias Springer using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern; 12480f241638SMatthias Springer 1249700b64dcSMatthias Springer void initialize() { 1250700b64dcSMatthias Springer // This pattern recursively unpacks one dimension at a time. The recursion 1251700b64dcSMatthias Springer // bounded as the rank is strictly decreasing. 1252700b64dcSMatthias Springer setHasBoundedRewriteRecursion(); 1253700b64dcSMatthias Springer } 1254700b64dcSMatthias Springer 1255aa2dc792SMatthias Springer /// Get or build the vector into which the newly created TransferReadOp 1256aa2dc792SMatthias Springer /// results are inserted. 1257aa2dc792SMatthias Springer Value buildResultVector(PatternRewriter &rewriter, 1258aa2dc792SMatthias Springer TransferReadOp xferOp) const { 12590f241638SMatthias Springer if (auto insertOp = getInsertOp(xferOp)) 12607c38fd60SJacques Pienaar return insertOp.getDest(); 12616825bfe2SNicolas Vasilache Location loc = xferOp.getLoc(); 12626a8ba318SRiver Riddle return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(), 12637c38fd60SJacques Pienaar xferOp.getPadding()); 12640f241638SMatthias Springer } 12650f241638SMatthias Springer 12660f241638SMatthias Springer /// If the result of the TransferReadOp has exactly one user, which is a 12670f241638SMatthias Springer /// vector::InsertOp, return that operation. 
12680f241638SMatthias Springer vector::InsertOp getInsertOp(TransferReadOp xferOp) const { 12690f241638SMatthias Springer if (xferOp->hasOneUse()) { 12700f241638SMatthias Springer Operation *xferOpUser = *xferOp->getUsers().begin(); 12710f241638SMatthias Springer if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser)) 12720f241638SMatthias Springer return insertOp; 12730f241638SMatthias Springer } 12740f241638SMatthias Springer 12750f241638SMatthias Springer return vector::InsertOp(); 12760f241638SMatthias Springer } 12770f241638SMatthias Springer 12780f241638SMatthias Springer /// If the result of the TransferReadOp has exactly one user, which is a 12790f241638SMatthias Springer /// vector::InsertOp, return that operation's indices. 12800f241638SMatthias Springer void getInsertionIndices(TransferReadOp xferOp, 128198f6289aSDiego Caballero SmallVectorImpl<OpFoldResult> &indices) const { 128298f6289aSDiego Caballero if (auto insertOp = getInsertOp(xferOp)) { 128398f6289aSDiego Caballero auto pos = insertOp.getMixedPosition(); 128498f6289aSDiego Caballero indices.append(pos.begin(), pos.end()); 128598f6289aSDiego Caballero } 12860f241638SMatthias Springer } 12870f241638SMatthias Springer 12880f241638SMatthias Springer /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds 12890f241638SMatthias Springer /// accesses, and broadcasts and transposes in permutation maps. 
12900f241638SMatthias Springer LogicalResult matchAndRewrite(TransferReadOp xferOp, 12910f241638SMatthias Springer PatternRewriter &rewriter) const override { 12922ca887deSMatthias Springer if (xferOp.getVectorType().getRank() <= options.targetRank) 1293aa2dc792SMatthias Springer return rewriter.notifyMatchFailure( 1294aa2dc792SMatthias Springer xferOp, "vector rank is less or equal to target rank"); 129527a713f5SBenjamin Maxwell if (failed(checkLowerTensors(xferOp, rewriter))) 129627a713f5SBenjamin Maxwell return failure(); 1297f718a53dSMatthias Springer // Transfer ops that modify the element type are not supported atm. 1298f718a53dSMatthias Springer if (xferOp.getVectorType().getElementType() != 1299f718a53dSMatthias Springer xferOp.getShapedType().getElementType()) 1300aa2dc792SMatthias Springer return rewriter.notifyMatchFailure( 1301aa2dc792SMatthias Springer xferOp, "not yet supported: element type mismatch"); 13020f241638SMatthias Springer auto xferVecType = xferOp.getVectorType(); 13032a82dfd7SBenjamin Maxwell if (xferVecType.getScalableDims()[0]) { 13042a82dfd7SBenjamin Maxwell // Cannot unroll a scalable dimension at compile time. 1305aa2dc792SMatthias Springer return rewriter.notifyMatchFailure( 1306aa2dc792SMatthias Springer xferOp, "scalable dimensions cannot be unrolled"); 13072a82dfd7SBenjamin Maxwell } 13082a82dfd7SBenjamin Maxwell 1309aa2dc792SMatthias Springer auto insertOp = getInsertOp(xferOp); 1310aa2dc792SMatthias Springer auto vec = buildResultVector(rewriter, xferOp); 1311aa2dc792SMatthias Springer auto vecType = dyn_cast<VectorType>(vec.getType()); 1312aa2dc792SMatthias Springer 13132a82dfd7SBenjamin Maxwell VectorType newXferVecType = VectorType::Builder(xferVecType).dropDim(0); 13142a82dfd7SBenjamin Maxwell 13150f241638SMatthias Springer int64_t dimSize = xferVecType.getShape()[0]; 13160f241638SMatthias Springer 13170f241638SMatthias Springer // Generate fully unrolled loop of transfer ops. 
13186825bfe2SNicolas Vasilache Location loc = xferOp.getLoc(); 13190f241638SMatthias Springer for (int64_t i = 0; i < dimSize; ++i) { 1320a54f4eaeSMogball Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i); 13210f241638SMatthias Springer 13220f241638SMatthias Springer vec = generateInBoundsCheck( 13236825bfe2SNicolas Vasilache rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType), 13240f241638SMatthias Springer /*inBoundsCase=*/ 13250f241638SMatthias Springer [&](OpBuilder &b, Location loc) { 13260f241638SMatthias Springer // Indices for the new transfer op. 13270f241638SMatthias Springer SmallVector<Value, 8> xferIndices; 13286825bfe2SNicolas Vasilache getXferIndices(b, xferOp, iv, xferIndices); 13290f241638SMatthias Springer 13300f241638SMatthias Springer // Indices for the new vector.insert op. 133198f6289aSDiego Caballero SmallVector<OpFoldResult, 8> insertionIndices; 13320f241638SMatthias Springer getInsertionIndices(xferOp, insertionIndices); 133398f6289aSDiego Caballero insertionIndices.push_back(rewriter.getIndexAttr(i)); 13340f241638SMatthias Springer 13357c38fd60SJacques Pienaar auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); 13366825bfe2SNicolas Vasilache auto newXferOp = b.create<vector::TransferReadOp>( 13377c38fd60SJacques Pienaar loc, newXferVecType, xferOp.getSource(), xferIndices, 13386825bfe2SNicolas Vasilache AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), 13397c38fd60SJacques Pienaar xferOp.getPadding(), Value(), inBoundsAttr); 13400f241638SMatthias Springer maybeAssignMask(b, xferOp, newXferOp, i); 13416825bfe2SNicolas Vasilache return b.create<vector::InsertOp>(loc, newXferOp, vec, 13426825bfe2SNicolas Vasilache insertionIndices); 13430f241638SMatthias Springer }, 13440f241638SMatthias Springer /*outOfBoundsCase=*/ 13450f241638SMatthias Springer [&](OpBuilder &b, Location loc) { 13460f241638SMatthias Springer // Loop through original (unmodified) vector. 
13470f241638SMatthias Springer return vec; 13480f241638SMatthias Springer }); 13490f241638SMatthias Springer } 13500f241638SMatthias Springer 13510f241638SMatthias Springer if (insertOp) { 13520f241638SMatthias Springer // Rewrite single user of the old TransferReadOp, which was an InsertOp. 13530f241638SMatthias Springer rewriter.replaceOp(insertOp, vec); 13540f241638SMatthias Springer rewriter.eraseOp(xferOp); 13550f241638SMatthias Springer } else { 13560f241638SMatthias Springer rewriter.replaceOp(xferOp, vec); 13570f241638SMatthias Springer } 13580f241638SMatthias Springer 13590f241638SMatthias Springer return success(); 13600f241638SMatthias Springer } 13610f241638SMatthias Springer }; 13620f241638SMatthias Springer 13630f241638SMatthias Springer /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one 13640f241638SMatthias Springer /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no 13650f241638SMatthias Springer /// memref buffer is allocated and the SCF loop is fully unrolled. 13660f241638SMatthias Springer /// 13670f241638SMatthias Springer /// ``` 13680f241638SMatthias Springer /// E.g.: 13690f241638SMatthias Springer /// ``` 13700f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b, %c] 13710f241638SMatthias Springer /// : vector<5x4xf32>, memref<?x?x?xf32> 13720f241638SMatthias Springer /// ``` 13730f241638SMatthias Springer /// is rewritten to IR such as (simplified): 13740f241638SMatthias Springer /// ``` 13759816edc9SCullen Rhodes /// %v0 = vector.extract %vec[0] : vector<4xf32> from vector<5x4xf32> 13760f241638SMatthias Springer /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...> 13779816edc9SCullen Rhodes /// %v1 = vector.extract %vec[1] : vector<4xf32> from vector<5x4xf32> 13780f241638SMatthias Springer /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...> 13790f241638SMatthias Springer /// ... 
13809816edc9SCullen Rhodes /// %v4 = vector.extract %vec[4] : vector<4xf32> from vector<5x4xf32> 13810f241638SMatthias Springer /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...> 13820f241638SMatthias Springer /// ``` 13830f241638SMatthias Springer /// 13840f241638SMatthias Springer /// Note: As an optimization, if the vector of the original TransferWriteOp 13850f241638SMatthias Springer /// was directly extracted from another vector via an ExtractOp `a`, extract 13860f241638SMatthias Springer /// the vectors for the newly generated TransferWriteOps from `a`'s input. By 13870f241638SMatthias Springer /// doing so, `a` may become dead, and the number of ExtractOps generated during 13880f241638SMatthias Springer /// recursive application of this pattern will be minimal. 13890f241638SMatthias Springer struct UnrollTransferWriteConversion 13902ca887deSMatthias Springer : public VectorToSCFPattern<TransferWriteOp> { 13912ca887deSMatthias Springer using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern; 13920f241638SMatthias Springer 1393700b64dcSMatthias Springer void initialize() { 1394700b64dcSMatthias Springer // This pattern recursively unpacks one dimension at a time. The recursion 1395700b64dcSMatthias Springer // bounded as the rank is strictly decreasing. 1396700b64dcSMatthias Springer setHasBoundedRewriteRecursion(); 1397700b64dcSMatthias Springer } 1398700b64dcSMatthias Springer 13990f241638SMatthias Springer /// Return the vector from which newly generated ExtracOps will extract. 14000f241638SMatthias Springer Value getDataVector(TransferWriteOp xferOp) const { 14010f241638SMatthias Springer if (auto extractOp = getExtractOp(xferOp)) 14027c38fd60SJacques Pienaar return extractOp.getVector(); 14037c38fd60SJacques Pienaar return xferOp.getVector(); 14040f241638SMatthias Springer } 14050f241638SMatthias Springer 14060f241638SMatthias Springer /// If the input of the given TransferWriteOp is an ExtractOp, return it. 
14070f241638SMatthias Springer vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const { 14087c38fd60SJacques Pienaar if (auto *op = xferOp.getVector().getDefiningOp()) 14090f241638SMatthias Springer return dyn_cast<vector::ExtractOp>(op); 14100f241638SMatthias Springer return vector::ExtractOp(); 14110f241638SMatthias Springer } 14120f241638SMatthias Springer 14130f241638SMatthias Springer /// If the input of the given TransferWriteOp is an ExtractOp, return its 14140f241638SMatthias Springer /// indices. 14150f241638SMatthias Springer void getExtractionIndices(TransferWriteOp xferOp, 141698f6289aSDiego Caballero SmallVectorImpl<OpFoldResult> &indices) const { 141798f6289aSDiego Caballero if (auto extractOp = getExtractOp(xferOp)) { 141898f6289aSDiego Caballero auto pos = extractOp.getMixedPosition(); 141998f6289aSDiego Caballero indices.append(pos.begin(), pos.end()); 142098f6289aSDiego Caballero } 14210f241638SMatthias Springer } 14220f241638SMatthias Springer 14230f241638SMatthias Springer /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds 14240f241638SMatthias Springer /// accesses, and broadcasts and transposes in permutation maps. 14250f241638SMatthias Springer LogicalResult matchAndRewrite(TransferWriteOp xferOp, 14260f241638SMatthias Springer PatternRewriter &rewriter) const override { 1427c84061fdSRik Huijzer VectorType inputVectorTy = xferOp.getVectorType(); 1428c84061fdSRik Huijzer 1429c84061fdSRik Huijzer if (inputVectorTy.getRank() <= options.targetRank) 14300f241638SMatthias Springer return failure(); 1431c84061fdSRik Huijzer 143227a713f5SBenjamin Maxwell if (failed(checkLowerTensors(xferOp, rewriter))) 14338fb48979SMatthias Springer return failure(); 1434f718a53dSMatthias Springer // Transfer ops that modify the element type are not supported atm. 
1435c84061fdSRik Huijzer if (inputVectorTy.getElementType() != 1436f718a53dSMatthias Springer xferOp.getShapedType().getElementType()) 1437f718a53dSMatthias Springer return failure(); 14380f241638SMatthias Springer 14390f241638SMatthias Springer auto vec = getDataVector(xferOp); 1440c84061fdSRik Huijzer if (inputVectorTy.getScalableDims()[0]) { 1441ced97ffdSBenjamin Maxwell // Cannot unroll a scalable dimension at compile time. 1442ced97ffdSBenjamin Maxwell return failure(); 1443ced97ffdSBenjamin Maxwell } 1444ced97ffdSBenjamin Maxwell 1445c84061fdSRik Huijzer int64_t dimSize = inputVectorTy.getShape()[0]; 14460ce25b12SRahul Kayaith Value source = xferOp.getSource(); // memref or tensor to be written to. 1447bd20756dSMatthias Springer auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type(); 14480f241638SMatthias Springer 14490f241638SMatthias Springer // Generate fully unrolled loop of transfer ops. 14506825bfe2SNicolas Vasilache Location loc = xferOp.getLoc(); 14510f241638SMatthias Springer for (int64_t i = 0; i < dimSize; ++i) { 1452a54f4eaeSMogball Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i); 14530f241638SMatthias Springer 1454bd20756dSMatthias Springer auto updatedSource = generateInBoundsCheck( 14556825bfe2SNicolas Vasilache rewriter, xferOp, iv, unpackedDim(xferOp), 1456bd20756dSMatthias Springer isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(), 1457bd20756dSMatthias Springer /*inBoundsCase=*/ 1458bd20756dSMatthias Springer [&](OpBuilder &b, Location loc) { 14590f241638SMatthias Springer // Indices for the new transfer op. 14600f241638SMatthias Springer SmallVector<Value, 8> xferIndices; 14616825bfe2SNicolas Vasilache getXferIndices(b, xferOp, iv, xferIndices); 14620f241638SMatthias Springer 14630f241638SMatthias Springer // Indices for the new vector.extract op. 
146498f6289aSDiego Caballero SmallVector<OpFoldResult, 8> extractionIndices; 14650f241638SMatthias Springer getExtractionIndices(xferOp, extractionIndices); 146698f6289aSDiego Caballero extractionIndices.push_back(b.getI64IntegerAttr(i)); 14670f241638SMatthias Springer 14686825bfe2SNicolas Vasilache auto extracted = 14696825bfe2SNicolas Vasilache b.create<vector::ExtractOp>(loc, vec, extractionIndices); 14707c38fd60SJacques Pienaar auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); 1471c84061fdSRik Huijzer Value xferVec; 1472c84061fdSRik Huijzer if (inputVectorTy.getRank() == 1) { 1473c84061fdSRik Huijzer // When target-rank=0, unrolling would causes the vector input 1474c84061fdSRik Huijzer // argument into `transfer_write` to become a scalar. We solve 1475c84061fdSRik Huijzer // this by broadcasting the scalar to a 0D vector. 1476c84061fdSRik Huijzer xferVec = b.create<vector::BroadcastOp>( 1477c84061fdSRik Huijzer loc, VectorType::get({}, extracted.getType()), extracted); 1478c84061fdSRik Huijzer } else { 1479c84061fdSRik Huijzer xferVec = extracted; 1480c84061fdSRik Huijzer } 14816825bfe2SNicolas Vasilache auto newXferOp = b.create<vector::TransferWriteOp>( 1482c84061fdSRik Huijzer loc, sourceType, xferVec, source, xferIndices, 14836825bfe2SNicolas Vasilache AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(), 14846825bfe2SNicolas Vasilache inBoundsAttr); 14850f241638SMatthias Springer 14860f241638SMatthias Springer maybeAssignMask(b, xferOp, newXferOp, i); 1487bd20756dSMatthias Springer 1488bd20756dSMatthias Springer return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value(); 1489bd20756dSMatthias Springer }, 1490bd20756dSMatthias Springer /*outOfBoundsCase=*/ 1491bd20756dSMatthias Springer [&](OpBuilder &b, Location loc) { 1492bd20756dSMatthias Springer return isTensorOp(xferOp) ? 
source : Value(); 14930f241638SMatthias Springer }); 1494bd20756dSMatthias Springer 1495bd20756dSMatthias Springer if (isTensorOp(xferOp)) 1496bd20756dSMatthias Springer source = updatedSource; 14970f241638SMatthias Springer } 14980f241638SMatthias Springer 1499bd20756dSMatthias Springer if (isTensorOp(xferOp)) 1500bd20756dSMatthias Springer rewriter.replaceOp(xferOp, source); 1501bd20756dSMatthias Springer else 15020f241638SMatthias Springer rewriter.eraseOp(xferOp); 1503bd20756dSMatthias Springer 15040f241638SMatthias Springer return success(); 15050f241638SMatthias Springer } 15060f241638SMatthias Springer }; 15070f241638SMatthias Springer 1508a088bed4SMatthias Springer } // namespace lowering_n_d_unrolled 1509a088bed4SMatthias Springer 1510a088bed4SMatthias Springer namespace lowering_1_d { 1511a088bed4SMatthias Springer 15120f241638SMatthias Springer /// Compute the indices into the memref for the LoadOp/StoreOp generated as 15130f241638SMatthias Springer /// part of TransferOp1dConversion. Return the memref dimension on which 151415ae9964SKazu Hirata /// the transfer is operating. A return value of std::nullopt indicates a 151515ae9964SKazu Hirata /// broadcast. 
15160f241638SMatthias Springer template <typename OpTy> 15170a81ace0SKazu Hirata static std::optional<int64_t> 15186825bfe2SNicolas Vasilache get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv, 15190f241638SMatthias Springer SmallVector<Value, 8> &memrefIndices) { 15207c38fd60SJacques Pienaar auto indices = xferOp.getIndices(); 15217c38fd60SJacques Pienaar auto map = xferOp.getPermutationMap(); 1522c537a943SNicolas Vasilache assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer"); 15230f241638SMatthias Springer 15240f241638SMatthias Springer memrefIndices.append(indices.begin(), indices.end()); 15250f241638SMatthias Springer assert(map.getNumResults() == 1 && 15260f241638SMatthias Springer "Expected 1 permutation map result for 1D transfer"); 15271609f1c2Slong.chen if (auto expr = dyn_cast<AffineDimExpr>(map.getResult(0))) { 15286825bfe2SNicolas Vasilache Location loc = xferOp.getLoc(); 15290f241638SMatthias Springer auto dim = expr.getPosition(); 15306825bfe2SNicolas Vasilache AffineExpr d0, d1; 15316825bfe2SNicolas Vasilache bindDims(xferOp.getContext(), d0, d1); 15326825bfe2SNicolas Vasilache Value offset = memrefIndices[dim]; 15334c48f016SMatthias Springer memrefIndices[dim] = 15344c48f016SMatthias Springer affine::makeComposedAffineApply(b, loc, d0 + d1, {offset, iv}); 15350f241638SMatthias Springer return dim; 15360f241638SMatthias Springer } 15370f241638SMatthias Springer 15380f241638SMatthias Springer assert(xferOp.isBroadcastDim(0) && 15390f241638SMatthias Springer "Expected AffineDimExpr or AffineConstantExpr"); 15401a36588eSKazu Hirata return std::nullopt; 15410f241638SMatthias Springer } 15420f241638SMatthias Springer 15430f241638SMatthias Springer /// Codegen strategy for TransferOp1dConversion, depending on the 15440f241638SMatthias Springer /// operation. 
template <typename OpTy>
struct Strategy1d;

/// Codegen strategy for TransferReadOp.
template <>
struct Strategy1d<TransferReadOp> {
  /// Emit the body of one iteration of the scalarized read loop: load one
  /// element from the memref (guarded by an in-bounds check) and insert it
  /// into the vector carried as loop state, yielding the updated vector.
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferReadOp xferOp, Value iv,
                                  ValueRange loopState) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    // The loop carries exactly one value: the vector being assembled.
    auto vec = loopState[0];

    // In case of out-of-bounds access, leave `vec` as is (was initialized with
    // padding value).
    auto nextVec = generateInBoundsCheck(
        b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
        /*inBoundsCase=*/
        [&](OpBuilder &b, Location loc) {
          Value val =
              b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
          return b.create<vector::InsertElementOp>(loc, val, vec, iv);
        },
        /*outOfBoundsCase=*/
        [&](OpBuilder & /*b*/, Location loc) { return vec; });
    b.create<scf::YieldOp>(loc, nextVec);
  }

  /// Create the initial loop-carried value for a read: the full result
  /// vector splatted with the transfer's padding value, so out-of-bounds
  /// lanes need no further work.
  static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
    // Initialize vector with padding value.
    Location loc = xferOp.getLoc();
    return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
                                     xferOp.getPadding());
  }
};

/// Codegen strategy for TransferWriteOp.
template <>
struct Strategy1d<TransferWriteOp> {
  /// Emit the body of one iteration of the scalarized write loop: extract one
  /// element from the vector and store it to the memref, guarded by an
  /// in-bounds check.
  static void generateForLoopBody(OpBuilder &b, Location loc,
                                  TransferWriteOp xferOp, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> indices;
    auto dim = get1dMemrefIndices(b, xferOp, iv, indices);

    // Nothing to do in case of out-of-bounds access.
    generateInBoundsCheck(
        b, xferOp, iv, dim,
        /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
          auto val =
              b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
          b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
        });
    b.create<scf::YieldOp>(loc);
  }

  /// Writes carry no loop state; return a null Value so the caller creates a
  /// result-less scf.for.
  static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
    return Value();
  }
};

/// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
/// necessary in cases where a 1D vector transfer op cannot be lowered into
/// vector load/stores due to non-unit strides or broadcasts:
///
/// * Transfer dimension is not the last memref dimension
/// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
/// * Memref has a layout map with non-unit stride on the last dimension
///
/// This pattern generates IR as follows:
///
/// 1. Generate a for loop iterating over each vector element.
/// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
///    depending on OpTy.
///
/// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
/// can be generated instead of TransferOp1dConversion. Add such a pattern
/// to ConvertVectorToLLVM.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b]
///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
///    : vector<9xf32>, memref<?x?xf32>
/// ```
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
///   %t = vector.extractelement %vec[i] : vector<9xf32>
///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();
    auto map = xferOp.getPermutationMap();
    auto memRefType = dyn_cast<MemRefType>(xferOp.getShapedType());

    // This pattern only applies to 1D transfers on memrefs (tensors are
    // handled elsewhere).
    if (!memRefType)
      return failure();
    if (xferOp.getVectorType().getRank() != 1)
      return failure();
    // Contiguous minor-identity transfers lower to vector load/stores in
    // ConvertVectorToLLVM; only the awkward cases are scalarized here.
    if (map.isMinorIdentity() && memRefType.isLastDimUnitStride())
      return failure(); // Handled by ConvertVectorToLLVM

    // Loop bounds, step, state...
    Location loc = xferOp.getLoc();
    auto vecType = xferOp.getVectorType();
    auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value ub =
        rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
    // For scalable vectors the trip count is only known up to a runtime
    // multiple: scale the static dim size by vector.vscale.
    if (vecType.isScalable()) {
      Value vscale =
          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
      ub = rewriter.create<arith::MulIOp>(loc, ub, vscale);
    }
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    // Reads carry the vector being built as an iter_arg; writes carry none.
    auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);

    // Generate for loop.
    rewriter.replaceOpWithNewOp<scf::ForOp>(
        xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
        });

    return success();
  }
};

} // namespace lowering_1_d
} // namespace

/// Populate `patterns` with the vector-transfer-to-SCF lowering patterns
/// selected by `options` (unrolled vs. looped n-D lowering, optional scalable
/// transpose handling, and the scalarized 1D fallback when targetRank == 1),
/// plus the vector.print decomposition.
void mlir::populateVectorToSCFConversionPatterns(
    RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
  if (options.unroll) {
    patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
                 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
        patterns.getContext(), options);
  } else {
    patterns.add<lowering_n_d::PrepareTransferReadConversion,
                 lowering_n_d::PrepareTransferWriteConversion,
                 lowering_n_d::TransferOpConversion<TransferReadOp>,
                 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
  if (options.lowerScalable) {
    patterns.add<lowering_n_d::ScalableTransposeTransferWriteConversion>(
        patterns.getContext(), options);
  }
  if (options.targetRank == 1) {
    patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
                 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
        patterns.getContext(), options);
  }
  patterns.add<lowering_n_d::DecomposePrintOpConversion>(patterns.getContext(),
                                                         options);
}

namespace {

/// Pass that greedily applies the vector-to-SCF lowering patterns. Mirrors
/// VectorTransferToSCFOptions onto the tablegen-declared pass options and
/// back so the pass works both from the API and from the command line.
struct ConvertVectorToSCFPass
    : public impl::ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
  ConvertVectorToSCFPass() = default;
  ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
    this->fullUnroll = options.unroll;
    this->targetRank = options.targetRank;
    this->lowerTensors = options.lowerTensors;
    this->lowerScalable = options.lowerScalable;
  }

  void runOnOperation() override {
    // Rebuild the options struct from the (possibly CLI-overridden) pass
    // options.
    VectorTransferToSCFOptions options;
    options.unroll = fullUnroll;
    options.targetRank = targetRank;
    options.lowerTensors = lowerTensors;
    options.lowerScalable = lowerScalable;

    // Lower permutation maps first.
    RewritePatternSet lowerTransferPatterns(&getContext());
    mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
        lowerTransferPatterns);
    (void)applyPatternsGreedily(getOperation(),
                                std::move(lowerTransferPatterns));

    RewritePatternSet patterns(&getContext());
    populateVectorToSCFConversionPatterns(patterns, options);
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};

} // namespace

/// Create a ConvertVectorToSCF pass configured with `options`.
std::unique_ptr<Pass>
mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  return std::make_unique<ConvertVectorToSCFPass>(options);
}