xref: /llvm-project/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (revision aa2952165cd1808dab2bb49b97becc097f4c9cac)
1042800a4SBenjamin Maxwell //===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
2042800a4SBenjamin Maxwell //
3042800a4SBenjamin Maxwell // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4042800a4SBenjamin Maxwell // See https://llvm.org/LICENSE.txt for license information.
5042800a4SBenjamin Maxwell // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6042800a4SBenjamin Maxwell //
7042800a4SBenjamin Maxwell //===----------------------------------------------------------------------===//
8042800a4SBenjamin Maxwell //
9042800a4SBenjamin Maxwell // This pass legalizes vector operations so they can be lowered to ArmSME.
10042800a4SBenjamin Maxwell //
11042800a4SBenjamin Maxwell // Note: In the context of this pass 'tile' always refers to an SME tile.
12042800a4SBenjamin Maxwell //
13042800a4SBenjamin Maxwell //===----------------------------------------------------------------------===//
14042800a4SBenjamin Maxwell 
15c2dea712SBenjamin Maxwell #include "mlir/Dialect/Arith/Utils/Utils.h"
16042800a4SBenjamin Maxwell #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
17042800a4SBenjamin Maxwell #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
18042800a4SBenjamin Maxwell #include "mlir/Dialect/ArmSME/Utils/Utils.h"
19042800a4SBenjamin Maxwell #include "mlir/Dialect/Func/IR/FuncOps.h"
2031613de9SMatthias Springer #include "mlir/Dialect/Func/Transforms/FuncConversions.h"
21c194bc77SBenjamin Maxwell #include "mlir/Dialect/Index/IR/IndexDialect.h"
22c194bc77SBenjamin Maxwell #include "mlir/Dialect/Index/IR/IndexOps.h"
230473e322SBenjamin Maxwell #include "mlir/Dialect/MemRef/IR/MemRef.h"
245ed5d723SBenjamin Maxwell #include "mlir/Dialect/SCF/IR/SCF.h"
25042800a4SBenjamin Maxwell #include "mlir/Dialect/SCF/Transforms/Patterns.h"
26042800a4SBenjamin Maxwell #include "mlir/Dialect/Utils/IndexingUtils.h"
27c194bc77SBenjamin Maxwell #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
2831613de9SMatthias Springer #include "mlir/Transforms/DialectConversion.h"
2931613de9SMatthias Springer #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30042800a4SBenjamin Maxwell 
31042800a4SBenjamin Maxwell #define DEBUG_TYPE "arm-sme-vector-legalization"
32042800a4SBenjamin Maxwell 
33042800a4SBenjamin Maxwell namespace mlir::arm_sme {
34042800a4SBenjamin Maxwell #define GEN_PASS_DEF_VECTORLEGALIZATION
35042800a4SBenjamin Maxwell #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
36042800a4SBenjamin Maxwell } // namespace mlir::arm_sme
37042800a4SBenjamin Maxwell 
38042800a4SBenjamin Maxwell using namespace mlir;
39042800a4SBenjamin Maxwell using namespace mlir::arm_sme;
40042800a4SBenjamin Maxwell 
41042800a4SBenjamin Maxwell namespace {
42042800a4SBenjamin Maxwell 
43c2dea712SBenjamin Maxwell //===----------------------------------------------------------------------===//
44c2dea712SBenjamin Maxwell // Decomposition of vector operations larger than an SME tile
45c2dea712SBenjamin Maxwell //===----------------------------------------------------------------------===//
46c2dea712SBenjamin Maxwell 
47042800a4SBenjamin Maxwell // Common match failure reasons.
481408667fSBenjamin Maxwell static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
49042800a4SBenjamin Maxwell     "op vector size is not multiple of SME tiles");
501408667fSBenjamin Maxwell static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
51042800a4SBenjamin Maxwell     "op mask is unsupported for legalization/decomposition");
52042800a4SBenjamin Maxwell static constexpr StringLiteral
531408667fSBenjamin Maxwell     kMatchFailureNonPermutationMap("op affine map is not a permutation");
54d1fc59c3SBenjamin Maxwell static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
55d1fc59c3SBenjamin Maxwell     "expected transpose from illegal type to legal type");
56042800a4SBenjamin Maxwell 
57042800a4SBenjamin Maxwell /// An SMESubTile represents a single SME-sized sub-tile from decomposing a
58042800a4SBenjamin Maxwell /// larger vector type. The (`row`, `col`) are the position of the tile in the
59042800a4SBenjamin Maxwell /// original vector type. For example for an [8]x[8] tile with four [4]x[4]
60042800a4SBenjamin Maxwell /// sub-tiles, we would have:
61042800a4SBenjamin Maxwell ///
62042800a4SBenjamin Maxwell ///           8 x vscale
63042800a4SBenjamin Maxwell /// ┌─────────────┬─────────────┐
64042800a4SBenjamin Maxwell /// │(0,0)        │(0,4)        │
65042800a4SBenjamin Maxwell /// │             │             │
66042800a4SBenjamin Maxwell /// ├─────────────┼─────────────┤ 8 x vscale
67042800a4SBenjamin Maxwell /// │(4,0)        │(4,4)        │
68042800a4SBenjamin Maxwell /// │             │             │
69042800a4SBenjamin Maxwell /// └─────────────┴─────────────┘
70042800a4SBenjamin Maxwell struct SMESubTile {
71042800a4SBenjamin Maxwell   // Note: The units of (row, col) are vscale (as SME tiles are scalable).
72042800a4SBenjamin Maxwell   int row{0};
73042800a4SBenjamin Maxwell   int col{0};
74042800a4SBenjamin Maxwell   // The SME tile type.
75042800a4SBenjamin Maxwell   VectorType type;
76042800a4SBenjamin Maxwell };
77042800a4SBenjamin Maxwell 
78042800a4SBenjamin Maxwell /// Adds a constant elementwise scalable offset to `indices` (which are of equal
79042800a4SBenjamin Maxwell /// length). For example, in the 2D case this would return:
80042800a4SBenjamin Maxwell // { indices[0] + offset[0] * vscale, indices[1] + offset[1] *  vscale }
81042800a4SBenjamin Maxwell SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
82042800a4SBenjamin Maxwell                                                 Location loc,
83042800a4SBenjamin Maxwell                                                 ValueRange indices,
84042800a4SBenjamin Maxwell                                                 ArrayRef<int> scalableOffsets) {
85042800a4SBenjamin Maxwell   auto vscale = builder.create<vector::VectorScaleOp>(loc);
86042800a4SBenjamin Maxwell   return llvm::map_to_vector(
87042800a4SBenjamin Maxwell       llvm::zip_equal(indices, scalableOffsets), [&](auto pair) -> Value {
88042800a4SBenjamin Maxwell         auto [index, base] = pair;
89042800a4SBenjamin Maxwell         auto offset = builder.create<arith::MulIOp>(
90042800a4SBenjamin Maxwell             loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale);
91042800a4SBenjamin Maxwell         return builder.create<arith::AddIOp>(loc, index, offset);
92042800a4SBenjamin Maxwell       });
93042800a4SBenjamin Maxwell }
94042800a4SBenjamin Maxwell 
95042800a4SBenjamin Maxwell /// Adjusts `indices` (e.g. from a load/store) for a larger vector type to
96042800a4SBenjamin Maxwell /// indices for one of the SME sub-tiles it will decompose into.
97042800a4SBenjamin Maxwell ///
98042800a4SBenjamin Maxwell /// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
99042800a4SBenjamin Maxwell /// indices for each tile would need to be adjusted as follows:
100042800a4SBenjamin Maxwell ///
101042800a4SBenjamin Maxwell /// initial indices = [a,b], inital size = 8x8, target size = 4x4
102042800a4SBenjamin Maxwell /// ┌─────────────┬─────────────┐
103042800a4SBenjamin Maxwell /// │[a,b]        │[a,b+4]      │
104042800a4SBenjamin Maxwell /// │             │             │
105042800a4SBenjamin Maxwell /// ├─────────────┼─────────────┤
106042800a4SBenjamin Maxwell /// │[a+4,b]      │[a+4,b+4]    │
107042800a4SBenjamin Maxwell /// │             │             │
108042800a4SBenjamin Maxwell /// └─────────────┴─────────────┘
109042800a4SBenjamin Maxwell SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc,
110042800a4SBenjamin Maxwell                                            ValueRange indices,
111042800a4SBenjamin Maxwell                                            SMESubTile smeTile) {
112042800a4SBenjamin Maxwell   return addConstantScalableOffset(builder, loc, indices,
113042800a4SBenjamin Maxwell                                    {smeTile.row, smeTile.col});
114042800a4SBenjamin Maxwell }
115042800a4SBenjamin Maxwell 
116042800a4SBenjamin Maxwell /// Returns true if `mask` is generated by an operation that can be decomposed
117042800a4SBenjamin Maxwell /// for SME. Currently, that is just no mask, or vector.create_mask.
118042800a4SBenjamin Maxwell /// TODO: Add support for vector.constant_mask once required for SME.
119042800a4SBenjamin Maxwell bool isSupportedMaskOp(Value mask) {
120042800a4SBenjamin Maxwell   return !mask || mask.getDefiningOp<vector::CreateMaskOp>();
121042800a4SBenjamin Maxwell }
122042800a4SBenjamin Maxwell 
123042800a4SBenjamin Maxwell /// Extracts a mask for an SME sub-tile from the mask of a larger vector type.
124042800a4SBenjamin Maxwell Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
125042800a4SBenjamin Maxwell                      SMESubTile smeTile) {
126042800a4SBenjamin Maxwell   assert(isSupportedMaskOp(mask));
127042800a4SBenjamin Maxwell   if (!mask)
128042800a4SBenjamin Maxwell     return Value{};
129042800a4SBenjamin Maxwell   auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
130042800a4SBenjamin Maxwell   // The operands of `vector.create_mask` (from a 2D perspective) are the
131042800a4SBenjamin Maxwell   // coordinates where the mask ends. So we subtract where this tile starts,
132042800a4SBenjamin Maxwell   // from the mask operands to get the parameters for this sub-tile.
133042800a4SBenjamin Maxwell   auto smeTileMaskDims = addConstantScalableOffset(
134042800a4SBenjamin Maxwell       builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col});
135042800a4SBenjamin Maxwell   auto smeTileCreateMask = builder.create<vector::CreateMaskOp>(
136042800a4SBenjamin Maxwell       loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims);
137042800a4SBenjamin Maxwell   return smeTileCreateMask.getResult();
138042800a4SBenjamin Maxwell }
139042800a4SBenjamin Maxwell 
140042800a4SBenjamin Maxwell /// Constructs an iterator that returns each SME tile (with coordinates)
141042800a4SBenjamin Maxwell /// contained within a VectorType. For example, if decomposing an [8]x[8] into
142042800a4SBenjamin Maxwell /// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0),
143042800a4SBenjamin Maxwell /// (4, 4).
144042800a4SBenjamin Maxwell auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
145042800a4SBenjamin Maxwell                          VectorType smeTileType,
146042800a4SBenjamin Maxwell                          bool transposeIndices = false) {
147042800a4SBenjamin Maxwell   return llvm::map_range(
148c194bc77SBenjamin Maxwell       StaticTileOffsetRange(
149c194bc77SBenjamin Maxwell           type.getShape(),
150c194bc77SBenjamin Maxwell           {std::min(type.getDimSize(0), smeTileType.getDimSize(0)),
151c194bc77SBenjamin Maxwell            std::min(type.getDimSize(1), smeTileType.getDimSize(1))}),
152042800a4SBenjamin Maxwell       [=](auto indices) {
153042800a4SBenjamin Maxwell         int row = int(indices[0]);
154042800a4SBenjamin Maxwell         int col = int(indices[1]);
155042800a4SBenjamin Maxwell         if (transposeIndices)
156042800a4SBenjamin Maxwell           std::swap(row, col);
157042800a4SBenjamin Maxwell         return SMESubTile{row, col, smeTileType};
158042800a4SBenjamin Maxwell       });
159042800a4SBenjamin Maxwell }
160042800a4SBenjamin Maxwell 
161042800a4SBenjamin Maxwell /// Returns the number of SME tiles that fit into the (2D-scalable) vector type
162042800a4SBenjamin Maxwell /// `type`.
163042800a4SBenjamin Maxwell int getNumberOfSMETilesForVectorType(VectorType type) {
164042800a4SBenjamin Maxwell   assert(isMultipleOfSMETileVectorType(type) &&
165042800a4SBenjamin Maxwell          "`type` not multiple of SME tiles");
166042800a4SBenjamin Maxwell   int64_t vectorRows = type.getDimSize(0);
167042800a4SBenjamin Maxwell   int64_t vectorCols = type.getDimSize(1);
168042800a4SBenjamin Maxwell   auto elementType = type.getElementType();
169042800a4SBenjamin Maxwell   unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
170042800a4SBenjamin Maxwell   return (vectorRows * vectorCols) / (minNumElts * minNumElts);
171042800a4SBenjamin Maxwell }
172042800a4SBenjamin Maxwell 
173dadcaf82SBenjamin Maxwell /// Legalize `arith.constant dense<value>` splat operations to fit within SME
174dadcaf82SBenjamin Maxwell /// tiles by decomposing them into tile-sized operations.
175dadcaf82SBenjamin Maxwell struct LegalizeArithConstantOpsByDecomposition
17631613de9SMatthias Springer     : public OpConversionPattern<arith::ConstantOp> {
17731613de9SMatthias Springer   using OpConversionPattern::OpConversionPattern;
178dadcaf82SBenjamin Maxwell 
179dadcaf82SBenjamin Maxwell   LogicalResult
180dadcaf82SBenjamin Maxwell   matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
18131613de9SMatthias Springer                   ConversionPatternRewriter &rewriter) const override {
182dadcaf82SBenjamin Maxwell     auto vectorType = dyn_cast<VectorType>(constantOp.getType());
183dadcaf82SBenjamin Maxwell     auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
184dadcaf82SBenjamin Maxwell     if (!vectorType || !denseAttr || !denseAttr.isSplat())
185dadcaf82SBenjamin Maxwell       return failure();
186dadcaf82SBenjamin Maxwell 
187dadcaf82SBenjamin Maxwell     if (!isMultipleOfSMETileVectorType(vectorType))
188dadcaf82SBenjamin Maxwell       return rewriter.notifyMatchFailure(constantOp,
189dadcaf82SBenjamin Maxwell                                          kMatchFailureNotSMETileTypeMultiple);
190dadcaf82SBenjamin Maxwell 
191dadcaf82SBenjamin Maxwell     auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
192dadcaf82SBenjamin Maxwell     auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
193dadcaf82SBenjamin Maxwell     auto tileSplat = rewriter.create<arith::ConstantOp>(
194dadcaf82SBenjamin Maxwell         constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
19531613de9SMatthias Springer     SmallVector<Value> repl(tileCount, tileSplat);
19631613de9SMatthias Springer     rewriter.replaceOpWithMultiple(constantOp, {repl});
197dadcaf82SBenjamin Maxwell 
198dadcaf82SBenjamin Maxwell     return success();
199dadcaf82SBenjamin Maxwell   }
200dadcaf82SBenjamin Maxwell };
201dadcaf82SBenjamin Maxwell 
202042800a4SBenjamin Maxwell /// Legalize `vector.outerproduct` operations to fit within SME tiles by
203042800a4SBenjamin Maxwell /// decomposing them into tile-sized operations.
204042800a4SBenjamin Maxwell struct LegalizeVectorOuterProductOpsByDecomposition
20531613de9SMatthias Springer     : public OpConversionPattern<vector::OuterProductOp> {
20631613de9SMatthias Springer   using OpConversionPattern::OpConversionPattern;
207042800a4SBenjamin Maxwell 
208042800a4SBenjamin Maxwell   LogicalResult
20931613de9SMatthias Springer   matchAndRewrite(vector::OuterProductOp outerProductOp,
21031613de9SMatthias Springer                   OneToNOpAdaptor adaptor,
21131613de9SMatthias Springer                   ConversionPatternRewriter &rewriter) const override {
212042800a4SBenjamin Maxwell     auto vectorType = outerProductOp.getResultVectorType();
213042800a4SBenjamin Maxwell     if (!isMultipleOfSMETileVectorType(vectorType))
2141408667fSBenjamin Maxwell       return rewriter.notifyMatchFailure(outerProductOp,
2151408667fSBenjamin Maxwell                                          kMatchFailureNotSMETileTypeMultiple);
216042800a4SBenjamin Maxwell 
217042800a4SBenjamin Maxwell     Value mask;
218042800a4SBenjamin Maxwell     Operation *rootOp = outerProductOp;
219042800a4SBenjamin Maxwell     auto loc = outerProductOp.getLoc();
220042800a4SBenjamin Maxwell     if (outerProductOp.isMasked()) {
221042800a4SBenjamin Maxwell       auto maskOp = outerProductOp.getMaskingOp();
222042800a4SBenjamin Maxwell       mask = maskOp.getMask();
223042800a4SBenjamin Maxwell       rootOp = maskOp;
22431613de9SMatthias Springer       rewriter.setInsertionPoint(rootOp);
225042800a4SBenjamin Maxwell     }
226042800a4SBenjamin Maxwell 
227042800a4SBenjamin Maxwell     if (!isSupportedMaskOp(mask))
228042800a4SBenjamin Maxwell       return rewriter.notifyMatchFailure(outerProductOp,
2291408667fSBenjamin Maxwell                                          kMatchFailureUnsupportedMaskOp);
230042800a4SBenjamin Maxwell 
231042800a4SBenjamin Maxwell     ValueRange accSMETiles = adaptor.getAcc();
232042800a4SBenjamin Maxwell     auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
233042800a4SBenjamin Maxwell     VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0);
234042800a4SBenjamin Maxwell 
235042800a4SBenjamin Maxwell     SmallVector<Value> resultSMETiles;
236042800a4SBenjamin Maxwell     for (auto [index, smeTile] : llvm::enumerate(
237042800a4SBenjamin Maxwell              decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
238042800a4SBenjamin Maxwell 
239042800a4SBenjamin Maxwell       auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
240042800a4SBenjamin Maxwell       auto lhs = rewriter.create<vector::ScalableExtractOp>(
241042800a4SBenjamin Maxwell           loc, sliceType, outerProductOp.getLhs(), smeTile.row);
242042800a4SBenjamin Maxwell       auto rhs = rewriter.create<vector::ScalableExtractOp>(
243042800a4SBenjamin Maxwell           loc, sliceType, outerProductOp.getRhs(), smeTile.col);
244042800a4SBenjamin Maxwell       auto smeOuterProduct = rewriter.create<vector::OuterProductOp>(
245042800a4SBenjamin Maxwell           loc, smeTileType, lhs, rhs,
246042800a4SBenjamin Maxwell           !accSMETiles.empty() ? accSMETiles[index] : Value{},
247042800a4SBenjamin Maxwell           outerProductOp.getKind());
248042800a4SBenjamin Maxwell 
249042800a4SBenjamin Maxwell       auto maskedOuterProduct =
250042800a4SBenjamin Maxwell           vector::maskOperation(rewriter, smeOuterProduct, smeMask);
251042800a4SBenjamin Maxwell       resultSMETiles.push_back(maskedOuterProduct->getResult(0));
252042800a4SBenjamin Maxwell     }
253042800a4SBenjamin Maxwell 
25431613de9SMatthias Springer     rewriter.replaceOpWithMultiple(rootOp, {resultSMETiles});
255042800a4SBenjamin Maxwell     return success();
256042800a4SBenjamin Maxwell   }
257042800a4SBenjamin Maxwell };
258042800a4SBenjamin Maxwell 
259042800a4SBenjamin Maxwell // Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to
260042800a4SBenjamin Maxwell // get the help of the type conversion), but doing so results in the type
261042800a4SBenjamin Maxwell // conversion adding target materializations in the `vector.mask` region
262042800a4SBenjamin Maxwell // (invalid). This pattern matches on `vector.mask` then calls into the
263042800a4SBenjamin Maxwell // `vector.outerproduct` pattern to work around this issue.
264042800a4SBenjamin Maxwell struct LegalizeMaskedVectorOuterProductOpsByDecomposition
26531613de9SMatthias Springer     : public OpConversionPattern<vector::MaskOp> {
26631613de9SMatthias Springer   using OpConversionPattern::OpConversionPattern;
267042800a4SBenjamin Maxwell 
268042800a4SBenjamin Maxwell   LogicalResult
26931613de9SMatthias Springer   matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor,
27031613de9SMatthias Springer                   ConversionPatternRewriter &rewriter) const override {
271a9eb8f0eSBenjamin Maxwell     if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
272a9eb8f0eSBenjamin Maxwell             maskOp.getMaskableOp())) {
273042800a4SBenjamin Maxwell       LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
274042800a4SBenjamin Maxwell                                                            getContext());
275042800a4SBenjamin Maxwell       return static_cast<RewritePattern &>(pattern).matchAndRewrite(
276042800a4SBenjamin Maxwell           outerProductOp, rewriter);
277042800a4SBenjamin Maxwell     }
278042800a4SBenjamin Maxwell     return failure();
279042800a4SBenjamin Maxwell   }
280042800a4SBenjamin Maxwell };
281042800a4SBenjamin Maxwell 
282042800a4SBenjamin Maxwell /// Legalize `vector.transfer_read` operations to fit within SME tiles by
283042800a4SBenjamin Maxwell /// decomposing them into tile-sized operations.
284042800a4SBenjamin Maxwell struct LegalizeTransferReadOpsByDecomposition
28531613de9SMatthias Springer     : public OpConversionPattern<vector::TransferReadOp> {
28631613de9SMatthias Springer   using OpConversionPattern::OpConversionPattern;
287042800a4SBenjamin Maxwell 
288042800a4SBenjamin Maxwell   LogicalResult
28931613de9SMatthias Springer   matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor,
29031613de9SMatthias Springer                   ConversionPatternRewriter &rewriter) const override {
291042800a4SBenjamin Maxwell     auto vectorType = readOp.getVectorType();
292042800a4SBenjamin Maxwell     if (!isMultipleOfSMETileVectorType(vectorType))
2931408667fSBenjamin Maxwell       return rewriter.notifyMatchFailure(readOp,
2941408667fSBenjamin Maxwell                                          kMatchFailureNotSMETileTypeMultiple);
295042800a4SBenjamin Maxwell 
296042800a4SBenjamin Maxwell     auto mask = readOp.getMask();
297042800a4SBenjamin Maxwell     if (!isSupportedMaskOp(mask))
298042800a4SBenjamin Maxwell       return rewriter.notifyMatchFailure(readOp,
2991408667fSBenjamin Maxwell                                          kMatchFailureUnsupportedMaskOp);
300042800a4SBenjamin Maxwell 
301042800a4SBenjamin Maxwell     auto permutationMap = readOp.getPermutationMap();
302042800a4SBenjamin Maxwell     if (!permutationMap.isPermutation())
303042800a4SBenjamin Maxwell       return rewriter.notifyMatchFailure(readOp,
3041408667fSBenjamin Maxwell                                          kMatchFailureNonPermutationMap);
305042800a4SBenjamin Maxwell 
306042800a4SBenjamin Maxwell     // Note: For 2D vector types the only non-identity permutation is a simple
307*aa295216SJay Foad     // transpose [1, 0].
308042800a4SBenjamin Maxwell     bool transposed = !permutationMap.isIdentity();
309042800a4SBenjamin Maxwell 
310042800a4SBenjamin Maxwell     auto loc = readOp.getLoc();
311042800a4SBenjamin Maxwell     auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
312042800a4SBenjamin Maxwell 
313042800a4SBenjamin Maxwell     SmallVector<Value> resultSMETiles;
314042800a4SBenjamin Maxwell     for (SMESubTile smeTile :
315042800a4SBenjamin Maxwell          decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
316042800a4SBenjamin Maxwell       auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
317042800a4SBenjamin Maxwell       auto smeRead = rewriter.create<vector::TransferReadOp>(
318042800a4SBenjamin Maxwell           loc, smeTileType, readOp.getSource(),
319042800a4SBenjamin Maxwell           getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
320042800a4SBenjamin Maxwell           readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
321042800a4SBenjamin Maxwell           readOp.getInBoundsAttr());
322042800a4SBenjamin Maxwell       resultSMETiles.push_back(smeRead);
323042800a4SBenjamin Maxwell     }
324042800a4SBenjamin Maxwell 
32531613de9SMatthias Springer     rewriter.replaceOpWithMultiple(readOp, {resultSMETiles});
326042800a4SBenjamin Maxwell     return success();
327042800a4SBenjamin Maxwell   }
328042800a4SBenjamin Maxwell };
329042800a4SBenjamin Maxwell 
330042800a4SBenjamin Maxwell /// Legalize `vector.transfer_write` operations to fit within SME tiles by
331042800a4SBenjamin Maxwell /// decomposing them into tile-sized operations.
332042800a4SBenjamin Maxwell struct LegalizeTransferWriteOpsByDecomposition
33331613de9SMatthias Springer     : public OpConversionPattern<vector::TransferWriteOp> {
33431613de9SMatthias Springer   using OpConversionPattern::OpConversionPattern;
335042800a4SBenjamin Maxwell 
336042800a4SBenjamin Maxwell   LogicalResult
33731613de9SMatthias Springer   matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
33831613de9SMatthias Springer                   ConversionPatternRewriter &rewriter) const override {
339042800a4SBenjamin Maxwell     auto vectorType = writeOp.getVectorType();
340042800a4SBenjamin Maxwell     if (!isMultipleOfSMETileVectorType(vectorType))
3411408667fSBenjamin Maxwell       return rewriter.notifyMatchFailure(writeOp,
3421408667fSBenjamin Maxwell                                          kMatchFailureNotSMETileTypeMultiple);
343042800a4SBenjamin Maxwell 
344042800a4SBenjamin Maxwell     auto mask = writeOp.getMask();
345042800a4SBenjamin Maxwell     if (!isSupportedMaskOp(mask))
346042800a4SBenjamin Maxwell       return rewriter.notifyMatchFailure(writeOp,
3471408667fSBenjamin Maxwell                                          kMatchFailureUnsupportedMaskOp);
348042800a4SBenjamin Maxwell 
349042800a4SBenjamin Maxwell     auto permutationMap = writeOp.getPermutationMap();
350042800a4SBenjamin Maxwell     if (!permutationMap.isPermutation())
351042800a4SBenjamin Maxwell       return rewriter.notifyMatchFailure(writeOp,
3521408667fSBenjamin Maxwell                                          kMatchFailureNonPermutationMap);
353042800a4SBenjamin Maxwell 
354042800a4SBenjamin Maxwell     // Note: For 2D vector types the only non-identity permutation is a simple
355*aa295216SJay Foad     // transpose [1, 0].
356042800a4SBenjamin Maxwell     bool transposed = !permutationMap.isIdentity();
357042800a4SBenjamin Maxwell 
358042800a4SBenjamin Maxwell     auto loc = writeOp.getLoc();
359042800a4SBenjamin Maxwell     auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
360042800a4SBenjamin Maxwell     auto inputSMETiles = adaptor.getVector();
361042800a4SBenjamin Maxwell 
362042800a4SBenjamin Maxwell     Value destTensorOrMemref = writeOp.getSource();
363042800a4SBenjamin Maxwell     for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
364042800a4SBenjamin Maxwell              rewriter, vectorType, smeTileType, transposed))) {
365042800a4SBenjamin Maxwell       auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
366042800a4SBenjamin Maxwell       auto smeWrite = rewriter.create<vector::TransferWriteOp>(
367042800a4SBenjamin Maxwell           loc, inputSMETiles[index], destTensorOrMemref,
368042800a4SBenjamin Maxwell           getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile),
369042800a4SBenjamin Maxwell           writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr());
370042800a4SBenjamin Maxwell       if (writeOp.hasPureTensorSemantics())
371042800a4SBenjamin Maxwell         destTensorOrMemref = smeWrite.getResult();
372042800a4SBenjamin Maxwell     }
373042800a4SBenjamin Maxwell 
374042800a4SBenjamin Maxwell     if (writeOp.hasPureTensorSemantics())
375042800a4SBenjamin Maxwell       rewriter.replaceOp(writeOp, destTensorOrMemref);
376042800a4SBenjamin Maxwell     else
377042800a4SBenjamin Maxwell       rewriter.eraseOp(writeOp);
378042800a4SBenjamin Maxwell 
379042800a4SBenjamin Maxwell     return success();
380042800a4SBenjamin Maxwell   }
381042800a4SBenjamin Maxwell };
382042800a4SBenjamin Maxwell 
3835ed5d723SBenjamin Maxwell /// Legalize a multi-tile transfer_write as a single store loop. This is done as
3845ed5d723SBenjamin Maxwell /// part of type decomposition as at this level we know each tile write is
3855ed5d723SBenjamin Maxwell /// disjoint, but that information is lost after decomposition (without analysis
3865ed5d723SBenjamin Maxwell /// to reconstruct it).
3875ed5d723SBenjamin Maxwell ///
3885ed5d723SBenjamin Maxwell /// Example (pseudo-MLIR):
3895ed5d723SBenjamin Maxwell ///
3905ed5d723SBenjamin Maxwell /// ```
3915ed5d723SBenjamin Maxwell /// vector.transfer_write %vector, %dest[%y, %x], %mask
3925ed5d723SBenjamin Maxwell ///   : vector<[16]x[8]xi16>, memref<?x?xi16>
3935ed5d723SBenjamin Maxwell /// ```
3945ed5d723SBenjamin Maxwell /// Is rewritten to:
3955ed5d723SBenjamin Maxwell /// ```
3965ed5d723SBenjamin Maxwell /// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
3975ed5d723SBenjamin Maxwell ///   %upper_slice_mask = vector.extract %mask[%slice_idx] ─┐
3985ed5d723SBenjamin Maxwell ///     : vector<[8]xi1> from vector<[16]x[8]xi1>           |
3995ed5d723SBenjamin Maxwell ///   %upper_slice = vector.extract %upper_tile[%slice_idx] |- Store upper tile
4005ed5d723SBenjamin Maxwell ///     : vector<[8]xi16> from vector<[8]x[8]xi16>          |
4015ed5d723SBenjamin Maxwell ///   vector.transfer_write %upper_slice,                   |
4025ed5d723SBenjamin Maxwell ///     %dest[%slice_idx + %y, %x], %upper_slice_mask       |
4035ed5d723SBenjamin Maxwell ///     : vector<[8]xi16>, memref<?x?xi16>                  ┘
4045ed5d723SBenjamin Maxwell ///   %lower_slice_idx = %slice_idx + %c8_vscale                 ─┐
4055ed5d723SBenjamin Maxwell ///   %lower_slice_mask = vector.extract %mask[%lower_slice_idx]  |
4065ed5d723SBenjamin Maxwell ///     : vector<[8]xi1> from vector<[16]x[8]xi1>                 |
4075ed5d723SBenjamin Maxwell ///   %lower_slice = vector.extract %lower_tile[%slice_idx]       |- Store lower
4085ed5d723SBenjamin Maxwell ///     : vector<[8]xi16> from vector<[8]x[8]xi16>                |  tile
4095ed5d723SBenjamin Maxwell ///   vector.transfer_write %lower_slice,                         |
4105ed5d723SBenjamin Maxwell ///     %dest[%lower_slice_idx + %y, %x], %lower_slice_mask       |
4115ed5d723SBenjamin Maxwell ///     : vector<[8]xi16>, memref<?x?xi16>                        ┘
4125ed5d723SBenjamin Maxwell /// }
4135ed5d723SBenjamin Maxwell /// ```
4145ed5d723SBenjamin Maxwell struct LegalizeMultiTileTransferWriteAsStoreLoop
41531613de9SMatthias Springer     : public OpConversionPattern<vector::TransferWriteOp> {
41631613de9SMatthias Springer   using OpConversionPattern::OpConversionPattern;
4175ed5d723SBenjamin Maxwell 
4185ed5d723SBenjamin Maxwell   LogicalResult
41931613de9SMatthias Springer   matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
42031613de9SMatthias Springer                   ConversionPatternRewriter &rewriter) const override {
4215ed5d723SBenjamin Maxwell     if (writeOp.hasPureTensorSemantics())
4225ed5d723SBenjamin Maxwell       return rewriter.notifyMatchFailure(
4235ed5d723SBenjamin Maxwell           writeOp, "TODO: tensor semantics are unsupported");
4245ed5d723SBenjamin Maxwell 
4255ed5d723SBenjamin Maxwell     auto permutationMap = writeOp.getPermutationMap();
4265ed5d723SBenjamin Maxwell     if (!permutationMap.isPermutation())
4275ed5d723SBenjamin Maxwell       return rewriter.notifyMatchFailure(writeOp,
4285ed5d723SBenjamin Maxwell                                          kMatchFailureNonPermutationMap);
4295ed5d723SBenjamin Maxwell 
4305ed5d723SBenjamin Maxwell     bool transposed = !permutationMap.isIdentity();
4315ed5d723SBenjamin Maxwell     if (transposed)
4325ed5d723SBenjamin Maxwell       return rewriter.notifyMatchFailure(writeOp,
4335ed5d723SBenjamin Maxwell                                          "TODO: transpose unsupported");
4345ed5d723SBenjamin Maxwell 
4355ed5d723SBenjamin Maxwell     auto vectorType = writeOp.getVectorType();
4365ed5d723SBenjamin Maxwell     if (!isMultipleOfSMETileVectorType(vectorType))
4375ed5d723SBenjamin Maxwell       return rewriter.notifyMatchFailure(writeOp,
4385ed5d723SBenjamin Maxwell                                          kMatchFailureNotSMETileTypeMultiple);
4395ed5d723SBenjamin Maxwell 
4405ed5d723SBenjamin Maxwell     // Note: We also disallow masks where any dimension is > 16 because that
4415ed5d723SBenjamin Maxwell     // prevents the masking from being lowered to use arm_sve.psel.
4425ed5d723SBenjamin Maxwell     auto mask = writeOp.getMask();
4435ed5d723SBenjamin Maxwell     if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
4445ed5d723SBenjamin Maxwell                                               vectorType.getDimSize(1) > 16)))
4455ed5d723SBenjamin Maxwell       return rewriter.notifyMatchFailure(writeOp,
4465ed5d723SBenjamin Maxwell                                          kMatchFailureUnsupportedMaskOp);
4475ed5d723SBenjamin Maxwell 
4485ed5d723SBenjamin Maxwell     auto loc = writeOp.getLoc();
449c194bc77SBenjamin Maxwell     auto createVscaleMultiple =
450c194bc77SBenjamin Maxwell         vector::makeVscaleConstantBuilder(rewriter, loc);
4515ed5d723SBenjamin Maxwell 
4525ed5d723SBenjamin Maxwell     // Get SME tile and slice types.
4535ed5d723SBenjamin Maxwell     auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
4545ed5d723SBenjamin Maxwell     auto minTileSlices = smeTileType.getDimSize(0);
4555ed5d723SBenjamin Maxwell     VectorType sliceMaskType =
4565ed5d723SBenjamin Maxwell         VectorType::get(minTileSlices, rewriter.getI1Type(), true);
4575ed5d723SBenjamin Maxwell 
4585ed5d723SBenjamin Maxwell     // Create loop over all tile slices.
4595ed5d723SBenjamin Maxwell     auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
4605ed5d723SBenjamin Maxwell     auto upperBound = createVscaleMultiple(minTileSlices);
4615ed5d723SBenjamin Maxwell     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
4625ed5d723SBenjamin Maxwell     auto storeLoop =
4635ed5d723SBenjamin Maxwell         rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
4645ed5d723SBenjamin Maxwell     rewriter.setInsertionPointToStart(storeLoop.getBody());
4655ed5d723SBenjamin Maxwell 
4665ed5d723SBenjamin Maxwell     // For each sub-tile of the multi-tile `vectorType`.
4675ed5d723SBenjamin Maxwell     auto inputSMETiles = adaptor.getVector();
4685ed5d723SBenjamin Maxwell     auto tileSliceIndex = storeLoop.getInductionVar();
4695ed5d723SBenjamin Maxwell     for (auto [index, smeTile] : llvm::enumerate(
4705ed5d723SBenjamin Maxwell              decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
4715ed5d723SBenjamin Maxwell       // The coordinates of the tile within `vectorType`.
4725ed5d723SBenjamin Maxwell       auto tileRow = createVscaleMultiple(smeTile.row);
4735ed5d723SBenjamin Maxwell       auto tileCol = createVscaleMultiple(smeTile.col);
4745ed5d723SBenjamin Maxwell 
4755ed5d723SBenjamin Maxwell       // The current slice of `vectorType` we are processing.
4765ed5d723SBenjamin Maxwell       auto sliceIndex =
4775ed5d723SBenjamin Maxwell           rewriter.create<arith::AddIOp>(loc, tileRow, tileSliceIndex);
4785ed5d723SBenjamin Maxwell 
4795ed5d723SBenjamin Maxwell       // Where in the destination memref the current slice will be stored.
4805ed5d723SBenjamin Maxwell       auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex,
4815ed5d723SBenjamin Maxwell                                                      writeOp.getIndices()[0]);
4825ed5d723SBenjamin Maxwell       auto storeCol =
4835ed5d723SBenjamin Maxwell           rewriter.create<arith::AddIOp>(loc, tileCol, writeOp.getIndices()[1]);
4845ed5d723SBenjamin Maxwell 
4855ed5d723SBenjamin Maxwell       // Extract the mask for the current slice.
4865ed5d723SBenjamin Maxwell       Value sliceMask = nullptr;
4875ed5d723SBenjamin Maxwell       if (mask) {
4885ed5d723SBenjamin Maxwell         sliceMask = rewriter.create<vector::ExtractOp>(
4895ed5d723SBenjamin Maxwell             loc, mask, OpFoldResult(sliceIndex));
4905ed5d723SBenjamin Maxwell         if (sliceMaskType != sliceMask.getType())
4915ed5d723SBenjamin Maxwell           sliceMask = rewriter.create<vector::ScalableExtractOp>(
4925ed5d723SBenjamin Maxwell               loc, sliceMaskType, sliceMask, smeTile.col);
4935ed5d723SBenjamin Maxwell       }
4945ed5d723SBenjamin Maxwell 
4955ed5d723SBenjamin Maxwell       // Extract and store the current slice.
4965ed5d723SBenjamin Maxwell       Value tile = inputSMETiles[index];
4975ed5d723SBenjamin Maxwell       auto slice =
4985ed5d723SBenjamin Maxwell           rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex);
4995ed5d723SBenjamin Maxwell       rewriter.create<vector::TransferWriteOp>(
5005ed5d723SBenjamin Maxwell           loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
5015ed5d723SBenjamin Maxwell           AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
5025ed5d723SBenjamin Maxwell           sliceMask,
5035ed5d723SBenjamin Maxwell           rewriter.getBoolArrayAttr(
5045ed5d723SBenjamin Maxwell               ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()));
5055ed5d723SBenjamin Maxwell     }
5065ed5d723SBenjamin Maxwell 
5075ed5d723SBenjamin Maxwell     rewriter.eraseOp(writeOp);
5085ed5d723SBenjamin Maxwell     return success();
5095ed5d723SBenjamin Maxwell   }
5105ed5d723SBenjamin Maxwell };
5115ed5d723SBenjamin Maxwell 
512c2dea712SBenjamin Maxwell //===----------------------------------------------------------------------===//
513c2dea712SBenjamin Maxwell // ArmSME-specific fixup canonicalizations/folds
514c2dea712SBenjamin Maxwell //===----------------------------------------------------------------------===//
515c2dea712SBenjamin Maxwell 
516c2dea712SBenjamin Maxwell /// Folds an extract from a 3D `vector.create_mask` (which is a vector of
517c2dea712SBenjamin Maxwell /// SME-like masks), into a compare and a 2D `vector.create_mask`. This is
518c2dea712SBenjamin Maxwell /// necessary for the mask to be lowered to ArmSME.
519c2dea712SBenjamin Maxwell ///
520c2dea712SBenjamin Maxwell /// Example:
521c2dea712SBenjamin Maxwell ///
522c2dea712SBenjamin Maxwell ///  BEFORE:
523c2dea712SBenjamin Maxwell ///  ```mlir
524c2dea712SBenjamin Maxwell ///  %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
525c2dea712SBenjamin Maxwell ///  %subMask = vector.extract %mask[2]
526c2dea712SBenjamin Maxwell ///          : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
527c2dea712SBenjamin Maxwell ///  ```
528c2dea712SBenjamin Maxwell ///
529c2dea712SBenjamin Maxwell ///  AFTER:
530c2dea712SBenjamin Maxwell ///  ```mlir
531c2dea712SBenjamin Maxwell ///  %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
532c2dea712SBenjamin Maxwell ///  %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
533c2dea712SBenjamin Maxwell ///  %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
534c2dea712SBenjamin Maxwell ///  ```
535c2dea712SBenjamin Maxwell struct FoldExtractFromVectorOfSMELikeCreateMasks
536c2dea712SBenjamin Maxwell     : public OpRewritePattern<vector::ExtractOp> {
537c2dea712SBenjamin Maxwell   using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
538c2dea712SBenjamin Maxwell 
539c2dea712SBenjamin Maxwell   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
540c2dea712SBenjamin Maxwell                                 PatternRewriter &rewriter) const override {
541c2dea712SBenjamin Maxwell     auto loc = extractOp.getLoc();
542c2dea712SBenjamin Maxwell     auto createMaskOp =
543c2dea712SBenjamin Maxwell         extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
544c2dea712SBenjamin Maxwell     if (!createMaskOp)
545c2dea712SBenjamin Maxwell       return rewriter.notifyMatchFailure(
546c2dea712SBenjamin Maxwell           extractOp, "extract not from vector.create_mask op");
547c2dea712SBenjamin Maxwell 
548c2dea712SBenjamin Maxwell     VectorType extractedMaskType =
549c2dea712SBenjamin Maxwell         llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
550c2dea712SBenjamin Maxwell     if (!extractedMaskType)
551c2dea712SBenjamin Maxwell       return rewriter.notifyMatchFailure(extractOp,
552c2dea712SBenjamin Maxwell                                          "extracted type is not a vector type");
553c2dea712SBenjamin Maxwell 
554fe07d9aaSAndrzej Warzyński     auto numScalable = extractedMaskType.getNumScalableDims();
555c2dea712SBenjamin Maxwell     if (numScalable != 2)
556c2dea712SBenjamin Maxwell       return rewriter.notifyMatchFailure(
557c2dea712SBenjamin Maxwell           extractOp, "expected extracted type to be an SME-like mask");
558c2dea712SBenjamin Maxwell 
559c2dea712SBenjamin Maxwell     // TODO: Support multiple extraction indices.
560c2dea712SBenjamin Maxwell     if (extractOp.getStaticPosition().size() != 1)
561c2dea712SBenjamin Maxwell       return rewriter.notifyMatchFailure(
562c2dea712SBenjamin Maxwell           extractOp, "only a single extraction index is supported");
563c2dea712SBenjamin Maxwell 
564c2dea712SBenjamin Maxwell     auto frontMaskDim = createMaskOp.getOperand(0);
565c2dea712SBenjamin Maxwell     if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
566c2dea712SBenjamin Maxwell       return rewriter.notifyMatchFailure(
567c2dea712SBenjamin Maxwell           extractOp,
568c2dea712SBenjamin Maxwell           "constant vector.create_masks dims should be folded elsewhere");
569c2dea712SBenjamin Maxwell 
570c2dea712SBenjamin Maxwell     auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
571c2dea712SBenjamin Maxwell     auto extractionIndex = getValueOrCreateConstantIndexOp(
572c2dea712SBenjamin Maxwell         rewriter, loc, extractOp.getMixedPosition()[0]);
573c2dea712SBenjamin Maxwell     auto extractionInTrueRegion = rewriter.create<arith::CmpIOp>(
574c2dea712SBenjamin Maxwell         loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, extractionIndex,
575c2dea712SBenjamin Maxwell         frontMaskDim);
576c2dea712SBenjamin Maxwell     auto newMaskFrontDim = rewriter.create<arith::SelectOp>(
577c2dea712SBenjamin Maxwell         loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero);
578c2dea712SBenjamin Maxwell 
579c2dea712SBenjamin Maxwell     rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
580c2dea712SBenjamin Maxwell         extractOp, extractedMaskType,
581c2dea712SBenjamin Maxwell         ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
582c2dea712SBenjamin Maxwell     return success();
583c2dea712SBenjamin Maxwell   }
584c2dea712SBenjamin Maxwell };
585c2dea712SBenjamin Maxwell 
586d1fc59c3SBenjamin Maxwell /// A vector type where no fixed dimension comes after a scalable dimension.
587d1fc59c3SBenjamin Maxwell bool isLegalVectorType(VectorType vType) {
588d1fc59c3SBenjamin Maxwell   bool seenFixedDim = false;
589d1fc59c3SBenjamin Maxwell   for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
590d1fc59c3SBenjamin Maxwell     seenFixedDim |= !scalableFlag;
591d1fc59c3SBenjamin Maxwell     if (seenFixedDim && scalableFlag)
592d1fc59c3SBenjamin Maxwell       return false;
593d1fc59c3SBenjamin Maxwell   }
594d1fc59c3SBenjamin Maxwell   return true;
595d1fc59c3SBenjamin Maxwell }
596d1fc59c3SBenjamin Maxwell 
5970473e322SBenjamin Maxwell /// Lifts an illegal vector.transpose and vector.transfer_read to a
5980473e322SBenjamin Maxwell /// memref.subview + memref.transpose, followed by a legal read.
5990473e322SBenjamin Maxwell ///
6000473e322SBenjamin Maxwell /// 'Illegal' here means a leading scalable dimension and a fixed trailing
6010473e322SBenjamin Maxwell /// dimension, which has no valid lowering.
6020473e322SBenjamin Maxwell ///
6030473e322SBenjamin Maxwell /// The memref.transpose is metadata-only transpose that produces a strided
6040473e322SBenjamin Maxwell /// memref, which eventually becomes a loop reading individual elements.
6050473e322SBenjamin Maxwell ///
6060473e322SBenjamin Maxwell /// Example:
6070473e322SBenjamin Maxwell ///
6080473e322SBenjamin Maxwell ///  BEFORE:
6090473e322SBenjamin Maxwell ///  ```mlir
6100473e322SBenjamin Maxwell ///  %illegalRead = vector.transfer_read %memref[%a, %b]
6110473e322SBenjamin Maxwell ///                  : memref<?x?xf32>, vector<[8]x4xf32>
6120473e322SBenjamin Maxwell ///  %legalType = vector.transpose %illegalRead, [1, 0]
6130473e322SBenjamin Maxwell ///                  : vector<[8]x4xf32> to vector<4x[8]xf32>
6140473e322SBenjamin Maxwell ///  ```
6150473e322SBenjamin Maxwell ///
6160473e322SBenjamin Maxwell ///  AFTER:
6170473e322SBenjamin Maxwell ///  ```mlir
6180473e322SBenjamin Maxwell ///  %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
6190473e322SBenjamin Maxwell ///                  : memref<?x?xf32> to memref<?x?xf32>
6200473e322SBenjamin Maxwell ///  %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
6210473e322SBenjamin Maxwell ///                  : memref<?x?xf32> to memref<?x?xf32>
6220473e322SBenjamin Maxwell ///  %legalType = vector.transfer_read %transpose[%c0, %c0]
6230473e322SBenjamin Maxwell ///                  : memref<?x?xf32>, vector<4x[8]xf32>
6240473e322SBenjamin Maxwell ///  ```
6250473e322SBenjamin Maxwell struct LiftIllegalVectorTransposeToMemory
6260473e322SBenjamin Maxwell     : public OpRewritePattern<vector::TransposeOp> {
6270473e322SBenjamin Maxwell   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
6280473e322SBenjamin Maxwell 
6290473e322SBenjamin Maxwell   static Value getExtensionSource(Operation *op) {
6308cfb7161SBenjamin Maxwell     if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
6310473e322SBenjamin Maxwell       return op->getOperand(0);
6320473e322SBenjamin Maxwell     return {};
6330473e322SBenjamin Maxwell   }
6340473e322SBenjamin Maxwell 
6350473e322SBenjamin Maxwell   LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
6360473e322SBenjamin Maxwell                                 PatternRewriter &rewriter) const override {
6370473e322SBenjamin Maxwell     auto sourceType = transposeOp.getSourceVectorType();
6380473e322SBenjamin Maxwell     auto resultType = transposeOp.getResultVectorType();
639d1fc59c3SBenjamin Maxwell     if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
640d1fc59c3SBenjamin Maxwell       return rewriter.notifyMatchFailure(transposeOp,
641d1fc59c3SBenjamin Maxwell                                          kMatchFailureNotIllegalToLegal);
6420473e322SBenjamin Maxwell 
6430473e322SBenjamin Maxwell     // Look through extend for transfer_read.
6440473e322SBenjamin Maxwell     Value maybeRead = transposeOp.getVector();
6450473e322SBenjamin Maxwell     auto *transposeSourceOp = maybeRead.getDefiningOp();
6460473e322SBenjamin Maxwell     Operation *extendOp = nullptr;
6470473e322SBenjamin Maxwell     if (Value extendSource = getExtensionSource(transposeSourceOp)) {
6480473e322SBenjamin Maxwell       maybeRead = extendSource;
6490473e322SBenjamin Maxwell       extendOp = transposeSourceOp;
6500473e322SBenjamin Maxwell     }
6510473e322SBenjamin Maxwell 
6520473e322SBenjamin Maxwell     auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
6530473e322SBenjamin Maxwell     if (!illegalRead)
6540473e322SBenjamin Maxwell       return rewriter.notifyMatchFailure(
6550473e322SBenjamin Maxwell           transposeOp,
6560473e322SBenjamin Maxwell           "expected source to be (possibly extended) transfer_read");
6570473e322SBenjamin Maxwell 
6580473e322SBenjamin Maxwell     if (!illegalRead.getPermutationMap().isIdentity())
6590473e322SBenjamin Maxwell       return rewriter.notifyMatchFailure(
6600473e322SBenjamin Maxwell           illegalRead, "expected read to have identity permutation map");
6610473e322SBenjamin Maxwell 
6620473e322SBenjamin Maxwell     auto loc = transposeOp.getLoc();
6630473e322SBenjamin Maxwell     auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
6640473e322SBenjamin Maxwell     auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
6650473e322SBenjamin Maxwell 
6660473e322SBenjamin Maxwell     // Create a subview that matches the size of the illegal read vector type.
6670473e322SBenjamin Maxwell     auto readType = illegalRead.getVectorType();
6680473e322SBenjamin Maxwell     auto readSizes = llvm::map_to_vector(
6690473e322SBenjamin Maxwell         llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
6700473e322SBenjamin Maxwell         [&](auto dim) -> Value {
6710473e322SBenjamin Maxwell           auto [size, isScalable] = dim;
6720473e322SBenjamin Maxwell           auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
6730473e322SBenjamin Maxwell           if (!isScalable)
6740473e322SBenjamin Maxwell             return dimSize;
6750473e322SBenjamin Maxwell           auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
6760473e322SBenjamin Maxwell           return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
6770473e322SBenjamin Maxwell         });
6780473e322SBenjamin Maxwell     SmallVector<Value> strides(readType.getRank(), Value(one));
6790473e322SBenjamin Maxwell     auto readSubview = rewriter.create<memref::SubViewOp>(
6800473e322SBenjamin Maxwell         loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
6810473e322SBenjamin Maxwell         strides);
6820473e322SBenjamin Maxwell 
6830473e322SBenjamin Maxwell     // Apply the transpose to all values/attributes of the transfer_read:
6840473e322SBenjamin Maxwell     // - The mask
6850473e322SBenjamin Maxwell     Value mask = illegalRead.getMask();
6860473e322SBenjamin Maxwell     if (mask) {
6870473e322SBenjamin Maxwell       // Note: The transpose for the mask should fold into the
6880473e322SBenjamin Maxwell       // vector.create_mask/constant_mask op, which will then become legal.
6890473e322SBenjamin Maxwell       mask = rewriter.create<vector::TransposeOp>(loc, mask,
6900473e322SBenjamin Maxwell                                                   transposeOp.getPermutation());
6910473e322SBenjamin Maxwell     }
6920473e322SBenjamin Maxwell     // - The source memref
6930473e322SBenjamin Maxwell     mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
6940473e322SBenjamin Maxwell         transposeOp.getPermutation(), getContext());
6950473e322SBenjamin Maxwell     auto transposedSubview = rewriter.create<memref::TransposeOp>(
6960473e322SBenjamin Maxwell         loc, readSubview, AffineMapAttr::get(transposeMap));
6970473e322SBenjamin Maxwell     ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
6980473e322SBenjamin Maxwell     // - The `in_bounds` attribute
6990473e322SBenjamin Maxwell     if (inBoundsAttr) {
7000473e322SBenjamin Maxwell       SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
7010473e322SBenjamin Maxwell                                             inBoundsAttr.end());
7020473e322SBenjamin Maxwell       applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
7030473e322SBenjamin Maxwell       inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
7040473e322SBenjamin Maxwell     }
7050473e322SBenjamin Maxwell 
7060473e322SBenjamin Maxwell     VectorType legalReadType = resultType.clone(readType.getElementType());
7070473e322SBenjamin Maxwell     // Note: The indices are all zero as the subview is already offset.
7080473e322SBenjamin Maxwell     SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
7090473e322SBenjamin Maxwell     auto legalRead = rewriter.create<vector::TransferReadOp>(
7100473e322SBenjamin Maxwell         loc, legalReadType, transposedSubview, readIndices,
7110473e322SBenjamin Maxwell         illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
7120473e322SBenjamin Maxwell         inBoundsAttr);
7130473e322SBenjamin Maxwell 
7140473e322SBenjamin Maxwell     // Replace the transpose with the new read, extending the result if
7150473e322SBenjamin Maxwell     // necessary.
7160473e322SBenjamin Maxwell     rewriter.replaceOp(transposeOp, [&]() -> Operation * {
7170473e322SBenjamin Maxwell       if (extendOp)
7180473e322SBenjamin Maxwell         return rewriter.create(loc, extendOp->getName().getIdentifier(),
7190473e322SBenjamin Maxwell                                Value(legalRead), resultType);
7200473e322SBenjamin Maxwell       return legalRead;
7210473e322SBenjamin Maxwell     }());
7220473e322SBenjamin Maxwell 
7230473e322SBenjamin Maxwell     return success();
7240473e322SBenjamin Maxwell   }
7250473e322SBenjamin Maxwell };
7260473e322SBenjamin Maxwell 
727d1fc59c3SBenjamin Maxwell /// A rewrite to turn unit dim transpose-like vector.shape_casts into
728d1fc59c3SBenjamin Maxwell /// vector.transposes. The shape_cast has to be from an illegal vector type to a
729d1fc59c3SBenjamin Maxwell /// legal one (as defined by isLegalVectorType).
730d1fc59c3SBenjamin Maxwell ///
731d1fc59c3SBenjamin Maxwell /// The reasoning for this is if we've got to this pass and we still have
732d1fc59c3SBenjamin Maxwell /// shape_casts of illegal types, then they likely will not cancel out. Turning
733d1fc59c3SBenjamin Maxwell /// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
734d1fc59c3SBenjamin Maxwell /// eliminate them.
735d1fc59c3SBenjamin Maxwell ///
736d1fc59c3SBenjamin Maxwell /// Example:
737d1fc59c3SBenjamin Maxwell ///
738d1fc59c3SBenjamin Maxwell ///  BEFORE:
739d1fc59c3SBenjamin Maxwell ///  ```mlir
740d1fc59c3SBenjamin Maxwell ///  %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
741d1fc59c3SBenjamin Maxwell ///  ```
742d1fc59c3SBenjamin Maxwell ///
743d1fc59c3SBenjamin Maxwell ///  AFTER:
744d1fc59c3SBenjamin Maxwell ///  ```mlir
745d1fc59c3SBenjamin Maxwell ///  %0 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
746d1fc59c3SBenjamin Maxwell ///  ```
747d1fc59c3SBenjamin Maxwell struct ConvertIllegalShapeCastOpsToTransposes
748d1fc59c3SBenjamin Maxwell     : public OpRewritePattern<vector::ShapeCastOp> {
749d1fc59c3SBenjamin Maxwell   using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
750d1fc59c3SBenjamin Maxwell 
751d1fc59c3SBenjamin Maxwell   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
752d1fc59c3SBenjamin Maxwell                                 PatternRewriter &rewriter) const override {
753d1fc59c3SBenjamin Maxwell     auto sourceType = shapeCastOp.getSourceVectorType();
754d1fc59c3SBenjamin Maxwell     auto resultType = shapeCastOp.getResultVectorType();
755d1fc59c3SBenjamin Maxwell     if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
756d1fc59c3SBenjamin Maxwell       return rewriter.notifyMatchFailure(shapeCastOp,
757d1fc59c3SBenjamin Maxwell                                          kMatchFailureNotIllegalToLegal);
758d1fc59c3SBenjamin Maxwell 
759d1fc59c3SBenjamin Maxwell     // Note: If we know that `sourceType` is an illegal vector type (and 2D)
760d1fc59c3SBenjamin Maxwell     // then dim 0 is scalable and dim 1 is fixed.
761d1fc59c3SBenjamin Maxwell     if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
762d1fc59c3SBenjamin Maxwell       return rewriter.notifyMatchFailure(
763d1fc59c3SBenjamin Maxwell           shapeCastOp, "expected source to be a 2D scalable vector with a "
764d1fc59c3SBenjamin Maxwell                        "trailing unit dim");
765d1fc59c3SBenjamin Maxwell 
766d1fc59c3SBenjamin Maxwell     auto loc = shapeCastOp.getLoc();
767d1fc59c3SBenjamin Maxwell     auto transpose = rewriter.create<vector::TransposeOp>(
768d1fc59c3SBenjamin Maxwell         loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});
769d1fc59c3SBenjamin Maxwell 
770d1fc59c3SBenjamin Maxwell     if (resultType.getRank() == 1)
771d1fc59c3SBenjamin Maxwell       rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
772d1fc59c3SBenjamin Maxwell                                                        transpose);
773d1fc59c3SBenjamin Maxwell     else
774d1fc59c3SBenjamin Maxwell       rewriter.replaceOp(shapeCastOp, transpose);
775d1fc59c3SBenjamin Maxwell 
776d1fc59c3SBenjamin Maxwell     return success();
777d1fc59c3SBenjamin Maxwell   }
778d1fc59c3SBenjamin Maxwell };
779d1fc59c3SBenjamin Maxwell 
780c194bc77SBenjamin Maxwell /// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
781c194bc77SBenjamin Maxwell /// the ZA state. This workaround rewrite to support these transposes when ZA is
782c194bc77SBenjamin Maxwell /// available.
783c194bc77SBenjamin Maxwell ///
784c194bc77SBenjamin Maxwell /// Example:
785c194bc77SBenjamin Maxwell ///
786c194bc77SBenjamin Maxwell ///  BEFORE:
787c194bc77SBenjamin Maxwell ///  ```mlir
788c194bc77SBenjamin Maxwell ///  %transpose = vector.transpose %vec, [1, 0]
789c194bc77SBenjamin Maxwell ///     : vector<2x[4]xf32> to vector<[4]x2xf32>
790c194bc77SBenjamin Maxwell ///  vector.transfer_write %transpose, %dest[%y, %x]
791c194bc77SBenjamin Maxwell ///     : vector<[4]x2xf32>,  memref<?x?xf32>
792c194bc77SBenjamin Maxwell ///  ```
793c194bc77SBenjamin Maxwell ///
794c194bc77SBenjamin Maxwell ///  AFTER:
795c194bc77SBenjamin Maxwell ///  ```mlir
796c194bc77SBenjamin Maxwell ///   %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
797c194bc77SBenjamin Maxwell ///   %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
798c194bc77SBenjamin Maxwell ///   %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
799c194bc77SBenjamin Maxwell ///   %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
800c194bc77SBenjamin Maxwell ///   %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
801c194bc77SBenjamin Maxwell ///   %c4_vscale = arith.muli %vscale, %c4 : index
802c194bc77SBenjamin Maxwell ///   %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
803c194bc77SBenjamin Maxwell ///   vector.transfer_write %4, %dest[%y, %x], %mask
804c194bc77SBenjamin Maxwell ///      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
805c194bc77SBenjamin Maxwell ///      : vector<[4]x[4]xf32>, memref<?x?xf32>
806c194bc77SBenjamin Maxwell ///  ```
807c194bc77SBenjamin Maxwell ///
808c194bc77SBenjamin Maxwell /// Values larger than a single tile are supported via decomposition.
809c194bc77SBenjamin Maxwell struct LowerIllegalTransposeStoreViaZA
810c194bc77SBenjamin Maxwell     : public OpRewritePattern<vector::TransferWriteOp> {
811c194bc77SBenjamin Maxwell   using OpRewritePattern::OpRewritePattern;
812c194bc77SBenjamin Maxwell 
813c194bc77SBenjamin Maxwell   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
814c194bc77SBenjamin Maxwell                                 PatternRewriter &rewriter) const override {
815c194bc77SBenjamin Maxwell     if (!isSupportedMaskOp(writeOp.getMask()))
816c194bc77SBenjamin Maxwell       return rewriter.notifyMatchFailure(writeOp,
817c194bc77SBenjamin Maxwell                                          kMatchFailureUnsupportedMaskOp);
818c194bc77SBenjamin Maxwell 
819c194bc77SBenjamin Maxwell     auto permutationMap = writeOp.getPermutationMap();
820c194bc77SBenjamin Maxwell     if (!permutationMap.isIdentity())
821c194bc77SBenjamin Maxwell       return rewriter.notifyMatchFailure(writeOp,
822c194bc77SBenjamin Maxwell                                          kMatchFailureNonPermutationMap);
823c194bc77SBenjamin Maxwell 
824c194bc77SBenjamin Maxwell     auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
825c194bc77SBenjamin Maxwell     if (!transposeOp)
826c194bc77SBenjamin Maxwell       return failure();
827c194bc77SBenjamin Maxwell 
828c194bc77SBenjamin Maxwell     auto sourceType = transposeOp.getSourceVectorType();
829c194bc77SBenjamin Maxwell     auto resultType = transposeOp.getResultVectorType();
830c194bc77SBenjamin Maxwell 
831c194bc77SBenjamin Maxwell     if (resultType.getRank() != 2)
832c194bc77SBenjamin Maxwell       return rewriter.notifyMatchFailure(transposeOp, "TransposeOp not rank 2");
833c194bc77SBenjamin Maxwell 
834c194bc77SBenjamin Maxwell     if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
835c194bc77SBenjamin Maxwell       return rewriter.notifyMatchFailure(
836c194bc77SBenjamin Maxwell           transposeOp, "not illegal/unsupported SVE transpose");
837c194bc77SBenjamin Maxwell 
838c194bc77SBenjamin Maxwell     auto smeTileType = getSMETileTypeForElement(resultType.getElementType());
839c194bc77SBenjamin Maxwell     VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0);
840c194bc77SBenjamin Maxwell 
841c194bc77SBenjamin Maxwell     if (sourceType.getDimSize(0) <= 1 ||
842c194bc77SBenjamin Maxwell         sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
843c194bc77SBenjamin Maxwell       return rewriter.notifyMatchFailure(writeOp, "unsupported source shape");
844c194bc77SBenjamin Maxwell 
845c194bc77SBenjamin Maxwell     auto loc = writeOp.getLoc();
846c194bc77SBenjamin Maxwell     auto createVscaleMultiple =
847c194bc77SBenjamin Maxwell         vector::makeVscaleConstantBuilder(rewriter, loc);
848c194bc77SBenjamin Maxwell 
849c194bc77SBenjamin Maxwell     auto transposeMap = AffineMapAttr::get(
850c194bc77SBenjamin Maxwell         AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext()));
851c194bc77SBenjamin Maxwell 
852c194bc77SBenjamin Maxwell     // Note: We need to use `get_tile` as there's no vector-level `undef`.
853c194bc77SBenjamin Maxwell     Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType);
854c194bc77SBenjamin Maxwell     Value destTensorOrMemref = writeOp.getSource();
855c194bc77SBenjamin Maxwell     auto numSlicesPerTile =
856c194bc77SBenjamin Maxwell         std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
857c194bc77SBenjamin Maxwell     auto numSlices =
858c194bc77SBenjamin Maxwell         rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile);
859c194bc77SBenjamin Maxwell     for (auto [index, smeTile] : llvm::enumerate(
860c194bc77SBenjamin Maxwell              decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
861c194bc77SBenjamin Maxwell       // 1. _Deliberately_ drop a scalable dimension and insert a fixed number
862c194bc77SBenjamin Maxwell       // of slices from the source type into the SME tile. Without checking
863c194bc77SBenjamin Maxwell       // vscale (and emitting multiple implementations) we can't make use of the
864c194bc77SBenjamin Maxwell       // rows of the tile after 1*vscale rows.
865c194bc77SBenjamin Maxwell       Value tile = undefTile;
866c194bc77SBenjamin Maxwell       for (int d = 0; d < numSlicesPerTile; ++d) {
867c194bc77SBenjamin Maxwell         Value vector = rewriter.create<vector::ExtractOp>(
868c194bc77SBenjamin Maxwell             loc, transposeOp.getVector(),
869c194bc77SBenjamin Maxwell             rewriter.getIndexAttr(d + smeTile.row));
870c194bc77SBenjamin Maxwell         if (vector.getType() != smeSliceType) {
871c194bc77SBenjamin Maxwell           vector = rewriter.create<vector::ScalableExtractOp>(
872c194bc77SBenjamin Maxwell               loc, smeSliceType, vector, smeTile.col);
873c194bc77SBenjamin Maxwell         }
874c194bc77SBenjamin Maxwell         tile = rewriter.create<vector::InsertOp>(loc, vector, tile, d);
875c194bc77SBenjamin Maxwell       }
876c194bc77SBenjamin Maxwell 
877c194bc77SBenjamin Maxwell       // 2. Transpose the tile position.
878c194bc77SBenjamin Maxwell       auto transposedRow = createVscaleMultiple(smeTile.col);
879c194bc77SBenjamin Maxwell       auto transposedCol =
880c194bc77SBenjamin Maxwell           rewriter.create<arith::ConstantIndexOp>(loc, smeTile.row);
881c194bc77SBenjamin Maxwell 
882c194bc77SBenjamin Maxwell       // 3. Compute mask for tile store.
883c194bc77SBenjamin Maxwell       Value maskRows;
884c194bc77SBenjamin Maxwell       Value maskCols;
885c194bc77SBenjamin Maxwell       if (auto mask = writeOp.getMask()) {
886c194bc77SBenjamin Maxwell         auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
887c194bc77SBenjamin Maxwell         maskRows = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(0),
888c194bc77SBenjamin Maxwell                                                   transposedRow);
889c194bc77SBenjamin Maxwell         maskCols = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(1),
890c194bc77SBenjamin Maxwell                                                   transposedCol);
891c194bc77SBenjamin Maxwell         maskCols = rewriter.create<index::MinSOp>(loc, maskCols, numSlices);
892c194bc77SBenjamin Maxwell       } else {
893c194bc77SBenjamin Maxwell         maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
894c194bc77SBenjamin Maxwell         maskCols = numSlices;
895c194bc77SBenjamin Maxwell       }
896c194bc77SBenjamin Maxwell       auto subMask = rewriter.create<vector::CreateMaskOp>(
897c194bc77SBenjamin Maxwell           loc, smeTileType.clone(rewriter.getI1Type()),
898c194bc77SBenjamin Maxwell           ValueRange{maskRows, maskCols});
899c194bc77SBenjamin Maxwell 
900c194bc77SBenjamin Maxwell       // 4. Emit a transposed tile write.
901c194bc77SBenjamin Maxwell       auto writeIndices = writeOp.getIndices();
902c194bc77SBenjamin Maxwell       Value destRow =
903c194bc77SBenjamin Maxwell           rewriter.create<arith::AddIOp>(loc, transposedRow, writeIndices[0]);
904c194bc77SBenjamin Maxwell       Value destCol =
905c194bc77SBenjamin Maxwell           rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]);
906c194bc77SBenjamin Maxwell       auto smeWrite = rewriter.create<vector::TransferWriteOp>(
907c194bc77SBenjamin Maxwell           loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
908c194bc77SBenjamin Maxwell           transposeMap, subMask, writeOp.getInBounds());
909c194bc77SBenjamin Maxwell 
910c194bc77SBenjamin Maxwell       if (writeOp.hasPureTensorSemantics())
911c194bc77SBenjamin Maxwell         destTensorOrMemref = smeWrite.getResult();
912c194bc77SBenjamin Maxwell     }
913c194bc77SBenjamin Maxwell 
914c194bc77SBenjamin Maxwell     if (writeOp.hasPureTensorSemantics())
915c194bc77SBenjamin Maxwell       rewriter.replaceOp(writeOp, destTensorOrMemref);
916c194bc77SBenjamin Maxwell     else
917c194bc77SBenjamin Maxwell       rewriter.eraseOp(writeOp);
918c194bc77SBenjamin Maxwell 
919c194bc77SBenjamin Maxwell     return success();
920c194bc77SBenjamin Maxwell   }
921c194bc77SBenjamin Maxwell };
922c194bc77SBenjamin Maxwell 
923042800a4SBenjamin Maxwell struct VectorLegalizationPass
924042800a4SBenjamin Maxwell     : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
925042800a4SBenjamin Maxwell   void runOnOperation() override {
926042800a4SBenjamin Maxwell     auto *context = &getContext();
9278c4bc1e7SMatthias Springer     TypeConverter converter;
928042800a4SBenjamin Maxwell     RewritePatternSet patterns(context);
929042800a4SBenjamin Maxwell     converter.addConversion([](Type type) { return type; });
930042800a4SBenjamin Maxwell     converter.addConversion(
931042800a4SBenjamin Maxwell         [](VectorType vectorType,
932042800a4SBenjamin Maxwell            SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
933042800a4SBenjamin Maxwell           if (!isMultipleOfSMETileVectorType(vectorType))
934042800a4SBenjamin Maxwell             return std::nullopt;
935042800a4SBenjamin Maxwell           auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType);
936042800a4SBenjamin Maxwell           auto smeTileType =
937042800a4SBenjamin Maxwell               getSMETileTypeForElement(vectorType.getElementType());
938042800a4SBenjamin Maxwell           types = SmallVector<Type>(smeTileCount, smeTileType);
939042800a4SBenjamin Maxwell           return success();
940042800a4SBenjamin Maxwell         });
941042800a4SBenjamin Maxwell 
94231613de9SMatthias Springer     // Apply preprocessing patterns.
94331613de9SMatthias Springer     RewritePatternSet rewritePatterns(context);
94431613de9SMatthias Springer     rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
945d1fc59c3SBenjamin Maxwell                         LiftIllegalVectorTransposeToMemory,
946c194bc77SBenjamin Maxwell                         ConvertIllegalShapeCastOpsToTransposes,
947fc4485bfSBenjamin Maxwell                         LowerIllegalTransposeStoreViaZA>(context);
94831613de9SMatthias Springer     if (failed(
94931613de9SMatthias Springer             applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
95031613de9SMatthias Springer       return signalPassFailure();
95131613de9SMatthias Springer 
9525ed5d723SBenjamin Maxwell     // Note: These two patterns are added with a high benefit to ensure:
9535ed5d723SBenjamin Maxwell     //  - Masked outer products are handled before unmasked ones
9545ed5d723SBenjamin Maxwell     //  - Multi-tile writes are lowered as a store loop (if possible)
9555ed5d723SBenjamin Maxwell     patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
9565ed5d723SBenjamin Maxwell                  LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context,
9575ed5d723SBenjamin Maxwell                                                             /*benefit=*/1024);
958dadcaf82SBenjamin Maxwell     patterns.add<LegalizeArithConstantOpsByDecomposition,
959dadcaf82SBenjamin Maxwell                  LegalizeVectorOuterProductOpsByDecomposition,
960042800a4SBenjamin Maxwell                  LegalizeTransferReadOpsByDecomposition,
961042800a4SBenjamin Maxwell                  LegalizeTransferWriteOpsByDecomposition>(converter, context);
96231613de9SMatthias Springer     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
96331613de9SMatthias Springer                                                                    converter);
96431613de9SMatthias Springer     populateCallOpTypeConversionPattern(patterns, converter);
96531613de9SMatthias Springer     populateReturnOpTypeConversionPattern(patterns, converter);
96631613de9SMatthias Springer     scf::populateSCFStructuralTypeConversions(converter, patterns);
967042800a4SBenjamin Maxwell 
96831613de9SMatthias Springer     ConversionTarget target(getContext());
96931613de9SMatthias Springer     target.markUnknownOpDynamicallyLegal(
97031613de9SMatthias Springer         [&](Operation *op) { return converter.isLegal(op); });
97131613de9SMatthias Springer     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
97231613de9SMatthias Springer       return converter.isSignatureLegal(op.getFunctionType());
97331613de9SMatthias Springer     });
97431613de9SMatthias Springer     if (failed(applyPartialConversion(getOperation(), target,
975042800a4SBenjamin Maxwell                                       std::move(patterns))))
976042800a4SBenjamin Maxwell       return signalPassFailure();
977042800a4SBenjamin Maxwell   }
978042800a4SBenjamin Maxwell };
979042800a4SBenjamin Maxwell 
980042800a4SBenjamin Maxwell } // namespace
981042800a4SBenjamin Maxwell 
982042800a4SBenjamin Maxwell std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() {
983042800a4SBenjamin Maxwell   return std::make_unique<VectorLegalizationPass>();
984042800a4SBenjamin Maxwell }
985