//===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass legalizes vector operations so they can be lowered to ArmSME.
//
// Note: In the context of this pass 'tile' always refers to an SME tile.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "arm-sme-vector-legalization"

namespace mlir::arm_sme {
#define GEN_PASS_DEF_VECTORLEGALIZATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace mlir::arm_sme

using namespace mlir;
using namespace mlir::arm_sme;

namespace {

//===----------------------------------------------------------------------===//
// Decomposition of vector operations larger than an SME tile
//===----------------------------------------------------------------------===//

// Common match failure reasons.
static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
    "op vector size is not multiple of SME tiles");
static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
    "op mask is unsupported for legalization/decomposition");
static constexpr StringLiteral
    kMatchFailureNonPermutationMap("op affine map is not a permutation");
static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
    "expected transpose from illegal type to legal type");

/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
/// larger vector type. The (`row`, `col`) are the position of the sub-tile
/// within the original vector type. For example, for an [8]x[8] vector
/// decomposed into four [4]x[4] sub-tiles, we would have:
///
///           8 x vscale
/// ┌─────────────┬─────────────┐
/// │(0,0)        │(0,4)        │
/// │             │             │
/// ├─────────────┼─────────────┤ 8 x vscale
/// │(4,0)        │(4,4)        │
/// │             │             │
/// └─────────────┴─────────────┘
struct SMESubTile {
  // Note: The units of (row, col) are vscale (as SME tiles are scalable).
  int row{0};
  int col{0};
  // The SME tile type.
  VectorType type;
};

/// Adds a constant elementwise scalable offset to `indices` (which must be
/// the same length as `scalableOffsets`). For example, in the 2D case this
/// would return:
/// { indices[0] + scalableOffsets[0] * vscale,
///   indices[1] + scalableOffsets[1] * vscale }
SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
                                                Location loc,
                                                ValueRange indices,
                                                ArrayRef<int> scalableOffsets) {
  auto vscale = builder.create<vector::VectorScaleOp>(loc);
  return llvm::map_to_vector(
      llvm::zip_equal(indices, scalableOffsets), [&](auto pair) -> Value {
        auto [index, base] = pair;
        auto offset = builder.create<arith::MulIOp>(
            loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale);
        return builder.create<arith::AddIOp>(loc, index, offset);
      });
}

/// Adjusts `indices` (e.g. from a load/store) for a larger vector type to
/// indices for one of the SME sub-tiles it will decompose into.
///
/// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
/// indices for each tile would need to be adjusted as follows:
///
/// initial indices = [a,b], initial size = 8x8, target size = 4x4
/// ┌─────────────┬─────────────┐
/// │[a,b]        │[a,b+4]      │
/// │             │             │
/// ├─────────────┼─────────────┤
/// │[a+4,b]      │[a+4,b+4]    │
/// │             │             │
/// └─────────────┴─────────────┘
SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc,
                                           ValueRange indices,
                                           SMESubTile smeTile) {
  return addConstantScalableOffset(builder, loc, indices,
                                   {smeTile.row, smeTile.col});
}

/// Returns true if `mask` is generated by an operation that can be decomposed
/// for SME. Currently, that is either no mask or a `vector.create_mask`.
/// TODO: Add support for vector.constant_mask once required for SME.
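///
/// E.g. (a sketch) a mask of the following form is supported:
/// ```mlir
/// %mask = vector.create_mask %dimA, %dimB : vector<[8]x[8]xi1>
/// ```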
bool isSupportedMaskOp(Value mask) {
  return !mask || mask.getDefiningOp<vector::CreateMaskOp>();
}

/// Extracts a mask for an SME sub-tile from the mask of a larger vector type.
Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
                     SMESubTile smeTile) {
  assert(isSupportedMaskOp(mask));
  if (!mask)
    return Value{};
  auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
  // The operands of `vector.create_mask` (from a 2D perspective) are the
  // coordinates where the mask ends. So we subtract where this sub-tile
  // starts from the mask operands to get the parameters for this sub-tile.
  auto smeTileMaskDims = addConstantScalableOffset(
      builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col});
  auto smeTileCreateMask = builder.create<vector::CreateMaskOp>(
      loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims);
  return smeTileCreateMask.getResult();
}

/// Constructs an iterator that returns each SME tile (with coordinates)
/// contained within a VectorType. For example, if decomposing an [8]x[8]
/// vector into [4]x[4] tiles, the iterator would yield tiles at the
/// coordinates: (0, 0), (0, 4), (4, 0), (4, 4).
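/// With `transposeIndices` set, the (row, col) of each yielded tile is
/// swapped, so the same decomposition would yield: (0, 0), (4, 0), (0, 4),
/// (4, 4).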
auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
                         VectorType smeTileType,
                         bool transposeIndices = false) {
  return llvm::map_range(
      StaticTileOffsetRange(
          type.getShape(),
          {std::min(type.getDimSize(0), smeTileType.getDimSize(0)),
           std::min(type.getDimSize(1), smeTileType.getDimSize(1))}),
      [=](auto indices) {
        int row = int(indices[0]);
        int col = int(indices[1]);
        if (transposeIndices)
          std::swap(row, col);
        return SMESubTile{row, col, smeTileType};
      });
}

/// Returns the number of SME tiles that fit into the (2D-scalable) vector type
/// `type`.
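///
/// E.g. for f32 (where an SME tile is [4]x[4]) a vector<[8]x[8]xf32> contains
/// (8 * 8) / (4 * 4) = 4 SME tiles.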
int getNumberOfSMETilesForVectorType(VectorType type) {
  assert(isMultipleOfSMETileVectorType(type) &&
         "`type` not multiple of SME tiles");
  int64_t vectorRows = type.getDimSize(0);
  int64_t vectorCols = type.getDimSize(1);
  auto elementType = type.getElementType();
  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
  return (vectorRows * vectorCols) / (minNumElts * minNumElts);
}

/// Legalize `arith.constant dense<value>` splat operations to fit within SME
/// tiles by decomposing them into tile-sized operations.
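///
/// Example (a sketch; assuming f32, where an SME tile is vector<[4]x[4]xf32>):
///
///  BEFORE:
///  ```mlir
///  %const = arith.constant dense<1.0> : vector<[8]x[8]xf32>
///  ```
///
///  AFTER (pseudo-MLIR; the single constant becomes one tile-sized splat,
///  used once per sub-tile):
///  ```mlir
///  %tileSplat = arith.constant dense<1.0> : vector<[4]x[4]xf32>
///  ```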
struct LegalizeArithConstantOpsByDecomposition
    : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = dyn_cast<VectorType>(constantOp.getType());
    auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
    if (!vectorType || !denseAttr || !denseAttr.isSplat())
      return failure();

    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(constantOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
    auto tileSplat = rewriter.create<arith::ConstantOp>(
        constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
    SmallVector<Value> repl(tileCount, tileSplat);
    rewriter.replaceOpWithMultiple(constantOp, {repl});

    return success();
  }
};

/// Legalize `vector.outerproduct` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
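///
/// Example (a sketch; assuming f32, where an SME tile is vector<[4]x[4]xf32>):
///
///  BEFORE:
///  ```mlir
///  %op = vector.outerproduct %lhs, %rhs, %acc
///          : vector<[8]xf32>, vector<[8]xf32>
///  ```
///
///  AFTER (pseudo-MLIR; %acc0 stands for the corresponding decomposed
///  accumulator tile):
///  ```mlir
///  %lhs0 = vector.scalable.extract %lhs[0]
///            : vector<[4]xf32> from vector<[8]xf32>
///  %rhs0 = vector.scalable.extract %rhs[0]
///            : vector<[4]xf32> from vector<[8]xf32>
///  %op00 = vector.outerproduct %lhs0, %rhs0, %acc0
///            : vector<[4]xf32>, vector<[4]xf32>
///  // ...and similarly for the (0,4), (4,0), and (4,4) sub-tiles...
///  ```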
struct LegalizeVectorOuterProductOpsByDecomposition
    : public OpConversionPattern<vector::OuterProductOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::OuterProductOp outerProductOp,
                  OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = outerProductOp.getResultVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    Value mask;
    Operation *rootOp = outerProductOp;
    auto loc = outerProductOp.getLoc();
    if (outerProductOp.isMasked()) {
      auto maskOp = outerProductOp.getMaskingOp();
      mask = maskOp.getMask();
      rootOp = maskOp;
      rewriter.setInsertionPoint(rootOp);
    }

    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         kMatchFailureUnsupportedMaskOp);

    ValueRange accSMETiles = adaptor.getAcc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0);

    SmallVector<Value> resultSMETiles;
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {

      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto lhs = rewriter.create<vector::ScalableExtractOp>(
          loc, sliceType, outerProductOp.getLhs(), smeTile.row);
      auto rhs = rewriter.create<vector::ScalableExtractOp>(
          loc, sliceType, outerProductOp.getRhs(), smeTile.col);
      auto smeOuterProduct = rewriter.create<vector::OuterProductOp>(
          loc, smeTileType, lhs, rhs,
          !accSMETiles.empty() ? accSMETiles[index] : Value{},
          outerProductOp.getKind());

      auto maskedOuterProduct =
          vector::maskOperation(rewriter, smeOuterProduct, smeMask);
      resultSMETiles.push_back(maskedOuterProduct->getResult(0));
    }

    rewriter.replaceOpWithMultiple(rootOp, {resultSMETiles});
    return success();
  }
};

// Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to
// get the help of the type conversion), but doing so results in the type
// conversion adding target materializations in the `vector.mask` region
// (invalid). This pattern matches on `vector.mask` then calls into the
// `vector.outerproduct` pattern to work around this issue.
struct LegalizeMaskedVectorOuterProductOpsByDecomposition
    : public OpConversionPattern<vector::MaskOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
            maskOp.getMaskableOp())) {
      LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
                                                           getContext());
      return static_cast<RewritePattern &>(pattern).matchAndRewrite(
          outerProductOp, rewriter);
    }
    return failure();
  }
};

/// Legalize `vector.transfer_read` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
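///
/// Example (a sketch; assuming f32, where an SME tile is vector<[4]x[4]xf32>):
///
///  BEFORE:
///  ```mlir
///  %read = vector.transfer_read %src[%y, %x]
///            : memref<?x?xf32>, vector<[8]x[8]xf32>
///  ```
///
///  AFTER (pseudo-MLIR; one read per sub-tile, offset by multiples of
///  4 x vscale):
///  ```mlir
///  %tile_0_0 = vector.transfer_read %src[%y, %x]
///                : memref<?x?xf32>, vector<[4]x[4]xf32>
///  %tile_0_1 = vector.transfer_read %src[%y, %x + %c4_vscale]
///                : memref<?x?xf32>, vector<[4]x[4]xf32>
///  // ...and similarly for the (4, 0) and (4, 4) sub-tiles...
///  ```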
struct LegalizeTransferReadOpsByDecomposition
    : public OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = readOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto mask = readOp.getMask();
    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = readOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureNonPermutationMap);

    // Note: For 2D vector types the only non-identity permutation is a simple
    // transpose [1, 0].
    bool transposed = !permutationMap.isIdentity();

    auto loc = readOp.getLoc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());

    SmallVector<Value> resultSMETiles;
    for (SMESubTile smeTile :
         decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto smeRead = rewriter.create<vector::TransferReadOp>(
          loc, smeTileType, readOp.getSource(),
          getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
          readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
          readOp.getInBoundsAttr());
      resultSMETiles.push_back(smeRead);
    }

    rewriter.replaceOpWithMultiple(readOp, {resultSMETiles});
    return success();
  }
};

/// Legalize `vector.transfer_write` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
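///
/// Example (a sketch; assuming f32, where an SME tile is vector<[4]x[4]xf32>,
/// and %tile_N_M are the decomposed sub-tiles of the written vector):
///
///  BEFORE:
///  ```mlir
///  vector.transfer_write %vec, %dest[%y, %x]
///    : vector<[8]x[8]xf32>, memref<?x?xf32>
///  ```
///
///  AFTER (pseudo-MLIR):
///  ```mlir
///  vector.transfer_write %tile_0_0, %dest[%y, %x]
///    : vector<[4]x[4]xf32>, memref<?x?xf32>
///  vector.transfer_write %tile_0_1, %dest[%y, %x + %c4_vscale]
///    : vector<[4]x[4]xf32>, memref<?x?xf32>
///  // ...and similarly for the (4, 0) and (4, 4) sub-tiles...
///  ```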
struct LegalizeTransferWriteOpsByDecomposition
    : public OpConversionPattern<vector::TransferWriteOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = writeOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto mask = writeOp.getMask();
    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    // Note: For 2D vector types the only non-identity permutation is a simple
    // transpose [1, 0].
    bool transposed = !permutationMap.isIdentity();

    auto loc = writeOp.getLoc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto inputSMETiles = adaptor.getVector();

    Value destTensorOrMemref = writeOp.getSource();
    for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
             rewriter, vectorType, smeTileType, transposed))) {
      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto smeWrite = rewriter.create<vector::TransferWriteOp>(
          loc, inputSMETiles[index], destTensorOrMemref,
          getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile),
          writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr());
      if (writeOp.hasPureTensorSemantics())
        destTensorOrMemref = smeWrite.getResult();
    }

    if (writeOp.hasPureTensorSemantics())
      rewriter.replaceOp(writeOp, destTensorOrMemref);
    else
      rewriter.eraseOp(writeOp);

    return success();
  }
};

/// Legalize a multi-tile transfer_write as a single store loop. This is done
/// as part of type decomposition because at this level we know each tile write
/// is disjoint, but that information is lost after decomposition (without
/// analysis to reconstruct it).
///
/// Example (pseudo-MLIR):
///
/// ```
/// vector.transfer_write %vector, %dest[%y, %x], %mask
///   : vector<[16]x[8]xi16>, memref<?x?xi16>
/// ```
/// Is rewritten to:
/// ```
/// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
///   %upper_slice_mask = vector.extract %mask[%slice_idx] ─┐
///     : vector<[8]xi1> from vector<[16]x[8]xi1>           |
///   %upper_slice = vector.extract %upper_tile[%slice_idx] |- Store upper tile
///     : vector<[8]xi16> from vector<[8]x[8]xi16>          |
///   vector.transfer_write %upper_slice,                   |
///     %dest[%slice_idx + %y, %x], %upper_slice_mask       |
///     : vector<[8]xi16>, memref<?x?xi16>                  ┘
///   %lower_slice_idx = %slice_idx + %c8_vscale                 ─┐
///   %lower_slice_mask = vector.extract %mask[%lower_slice_idx]  |
///     : vector<[8]xi1> from vector<[16]x[8]xi1>                 |
///   %lower_slice = vector.extract %lower_tile[%slice_idx]       |- Store lower
///     : vector<[8]xi16> from vector<[8]x[8]xi16>                |  tile
///   vector.transfer_write %lower_slice,                         |
///     %dest[%lower_slice_idx + %y, %x], %lower_slice_mask       |
///     : vector<[8]xi16>, memref<?x?xi16>                        ┘
/// }
/// ```
struct LegalizeMultiTileTransferWriteAsStoreLoop
    : public OpConversionPattern<vector::TransferWriteOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (writeOp.hasPureTensorSemantics())
      return rewriter.notifyMatchFailure(
          writeOp, "TODO: tensor semantics are unsupported");

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    bool transposed = !permutationMap.isIdentity();
    if (transposed)
      return rewriter.notifyMatchFailure(writeOp,
                                         "TODO: transpose unsupported");

    auto vectorType = writeOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    // Note: We also disallow masks where any dimension is > 16 because that
    // prevents the masking from being lowered to use arm_sve.psel.
    auto mask = writeOp.getMask();
    if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
                                              vectorType.getDimSize(1) > 16)))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    // Get SME tile and slice types.
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto minTileSlices = smeTileType.getDimSize(0);
    VectorType sliceMaskType =
        VectorType::get(minTileSlices, rewriter.getI1Type(), true);

    // Create loop over all tile slices.
    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto upperBound = createVscaleMultiple(minTileSlices);
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto storeLoop =
        rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
    rewriter.setInsertionPointToStart(storeLoop.getBody());

    // For each sub-tile of the multi-tile `vectorType`.
    auto inputSMETiles = adaptor.getVector();
    auto tileSliceIndex = storeLoop.getInductionVar();
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
      // The coordinates of the tile within `vectorType`.
      auto tileRow = createVscaleMultiple(smeTile.row);
      auto tileCol = createVscaleMultiple(smeTile.col);

      // The current slice of `vectorType` we are processing.
      auto sliceIndex =
          rewriter.create<arith::AddIOp>(loc, tileRow, tileSliceIndex);

      // Where in the destination memref the current slice will be stored.
      auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex,
                                                     writeOp.getIndices()[0]);
      auto storeCol =
          rewriter.create<arith::AddIOp>(loc, tileCol, writeOp.getIndices()[1]);

      // Extract the mask for the current slice.
      Value sliceMask = nullptr;
      if (mask) {
        sliceMask = rewriter.create<vector::ExtractOp>(
            loc, mask, OpFoldResult(sliceIndex));
        if (sliceMaskType != sliceMask.getType())
          sliceMask = rewriter.create<vector::ScalableExtractOp>(
              loc, sliceMaskType, sliceMask, smeTile.col);
      }

      // Extract and store the current slice.
      Value tile = inputSMETiles[index];
      auto slice =
          rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex);
      rewriter.create<vector::TransferWriteOp>(
          loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
          AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
          sliceMask,
          rewriter.getBoolArrayAttr(
              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()));
    }

    rewriter.eraseOp(writeOp);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ArmSME-specific fixup canonicalizations/folds
//===----------------------------------------------------------------------===//

/// Folds an extract from a 3D `vector.create_mask` (which is a vector of
/// SME-like masks) into a compare and a 2D `vector.create_mask`. This is
/// necessary for the mask to be lowered to ArmSME.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
///  %subMask = vector.extract %mask[2]
///          : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
///  ```
///
///  AFTER:
///  ```mlir
///  %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
///  %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
///  %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
///  ```
struct FoldExtractFromVectorOfSMELikeCreateMasks
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto loc = extractOp.getLoc();
    auto createMaskOp =
        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return rewriter.notifyMatchFailure(
          extractOp, "extract not from vector.create_mask op");

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
    if (!extractedMaskType)
      return rewriter.notifyMatchFailure(extractOp,
                                         "extracted type is not a vector type");

    auto numScalable = extractedMaskType.getNumScalableDims();
    if (numScalable != 2)
      return rewriter.notifyMatchFailure(
          extractOp, "expected extracted type to be an SME-like mask");

    // TODO: Support multiple extraction indices.
    if (extractOp.getStaticPosition().size() != 1)
      return rewriter.notifyMatchFailure(
          extractOp, "only a single extraction index is supported");

    auto frontMaskDim = createMaskOp.getOperand(0);
    if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
      return rewriter.notifyMatchFailure(
          extractOp,
          "constant vector.create_masks dims should be folded elsewhere");

    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto extractionIndex = getValueOrCreateConstantIndexOp(
        rewriter, loc, extractOp.getMixedPosition()[0]);
    auto extractionInTrueRegion = rewriter.create<arith::CmpIOp>(
        loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, extractionIndex,
        frontMaskDim);
    auto newMaskFrontDim = rewriter.create<arith::SelectOp>(
        loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero);

    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
        extractOp, extractedMaskType,
        ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
    return success();
  }
};

/// Returns true if `vType` is a vector type where no fixed dimension comes
/// after a scalable dimension.
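/// E.g. vector<[4]x[4]xf32> and vector<4x[4]xf32> are legal, but
/// vector<[4]x4xf32> (a fixed dim after a scalable dim) is not.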
bool isLegalVectorType(VectorType vType) {
  bool seenFixedDim = false;
  for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
    seenFixedDim |= !scalableFlag;
    if (seenFixedDim && scalableFlag)
      return false;
  }
  return true;
}

/// Lifts an illegal vector.transpose and vector.transfer_read to a
/// memref.subview + memref.transpose, followed by a legal read.
///
/// 'Illegal' here means a leading scalable dimension and a fixed trailing
/// dimension, which has no valid lowering.
///
/// The memref.transpose is a metadata-only transpose that produces a strided
/// memref, which eventually becomes a loop reading individual elements.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %illegalRead = vector.transfer_read %memref[%a, %b]
///                  : memref<?x?xf32>, vector<[8]x4xf32>
///  %legalType = vector.transpose %illegalRead, [1, 0]
///                  : vector<[8]x4xf32> to vector<4x[8]xf32>
///  ```
///
///  AFTER:
///  ```mlir
///  %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
///                  : memref<?x?xf32> to memref<?x?xf32>
///  %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
///                  : memref<?x?xf32> to memref<?x?xf32>
///  %legalType = vector.transfer_read %transpose[%c0, %c0]
///                  : memref<?x?xf32>, vector<4x[8]xf32>
///  ```
struct LiftIllegalVectorTransposeToMemory
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  static Value getExtensionSource(Operation *op) {
    if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
      return op->getOperand(0);
    return {};
  }

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto sourceType = transposeOp.getSourceVectorType();
    auto resultType = transposeOp.getResultVectorType();
    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(transposeOp,
                                         kMatchFailureNotIllegalToLegal);

    // Look through extend for transfer_read.
    Value maybeRead = transposeOp.getVector();
    auto *transposeSourceOp = maybeRead.getDefiningOp();
    Operation *extendOp = nullptr;
    if (Value extendSource = getExtensionSource(transposeSourceOp)) {
      maybeRead = extendSource;
      extendOp = transposeSourceOp;
    }

    auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
    if (!illegalRead)
      return rewriter.notifyMatchFailure(
          transposeOp,
          "expected source to be (possibly extended) transfer_read");

    if (!illegalRead.getPermutationMap().isIdentity())
      return rewriter.notifyMatchFailure(
          illegalRead, "expected read to have identity permutation map");

    auto loc = transposeOp.getLoc();
    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);

    // Create a subview that matches the size of the illegal read vector type.
    auto readType = illegalRead.getVectorType();
    auto readSizes = llvm::map_to_vector(
        llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
        [&](auto dim) -> Value {
          auto [size, isScalable] = dim;
          auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
          if (!isScalable)
            return dimSize;
          auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
          return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
        });
    SmallVector<Value> strides(readType.getRank(), Value(one));
    auto readSubview = rewriter.create<memref::SubViewOp>(
        loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
        strides);

    // Apply the transpose to all values/attributes of the transfer_read:
    // - The mask
    Value mask = illegalRead.getMask();
    if (mask) {
      // Note: The transpose for the mask should fold into the
      // vector.create_mask/constant_mask op, which will then become legal.
      mask = rewriter.create<vector::TransposeOp>(loc, mask,
                                                  transposeOp.getPermutation());
    }
    // - The source memref
    mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
        transposeOp.getPermutation(), getContext());
    auto transposedSubview = rewriter.create<memref::TransposeOp>(
        loc, readSubview, AffineMapAttr::get(transposeMap));
    ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
    // - The `in_bounds` attribute
    if (inBoundsAttr) {
      SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
                                            inBoundsAttr.end());
      applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
      inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
    }

    VectorType legalReadType = resultType.clone(readType.getElementType());
    // Note: The indices are all zero as the subview is already offset.
    SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
    auto legalRead = rewriter.create<vector::TransferReadOp>(
        loc, legalReadType, transposedSubview, readIndices,
        illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
        inBoundsAttr);

    // Replace the transpose with the new read, extending the result if
    // necessary.
    rewriter.replaceOp(transposeOp, [&]() -> Operation * {
      if (extendOp)
        return rewriter.create(loc, extendOp->getName().getIdentifier(),
                               Value(legalRead), resultType);
      return legalRead;
    }());

    return success();
  }
};

/// A rewrite to turn unit dim transpose-like vector.shape_casts into
/// vector.transposes. The shape_cast has to be from an illegal vector type to
/// a legal one (as defined by isLegalVectorType).
///
/// The reasoning for this is that if we've reached this pass and still have
/// shape_casts of illegal types, then they likely will not cancel out. Turning
/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
/// eliminate them.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
///  ```
///
///  AFTER:
///  ```mlir
///  %0 = vector.transpose %a, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
///  ```
struct ConvertIllegalShapeCastOpsToTransposes
    : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto sourceType = shapeCastOp.getSourceVectorType();
    auto resultType = shapeCastOp.getResultVectorType();
    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(shapeCastOp,
                                         kMatchFailureNotIllegalToLegal);

    // Note: If we know that `sourceType` is an illegal vector type (and 2D)
    // then dim 0 is scalable and dim 1 is fixed.
    if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
      return rewriter.notifyMatchFailure(
          shapeCastOp, "expected source to be a 2D scalable vector with a "
                       "trailing unit dim");

    auto loc = shapeCastOp.getLoc();
    auto transpose = rewriter.create<vector::TransposeOp>(
        loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});

    if (resultType.getRank() == 1)
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
                                                       transpose);
    else
      rewriter.replaceOp(shapeCastOp, transpose);

    return success();
  }
};

/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead
/// use the ZA state. This is a workaround to support these transposes when ZA
/// is available.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %transpose = vector.transpose %vec, [1, 0]
///     : vector<2x[4]xf32> to vector<[4]x2xf32>
///  vector.transfer_write %transpose, %dest[%y, %x]
///     : vector<[4]x2xf32>, memref<?x?xf32>
///  ```
///
///  AFTER:
///  ```mlir
///   %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
///   %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
///   %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
///   %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
///   %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
///   %c4_vscale = arith.muli %vscale, %c4 : index
///   %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
///   vector.transfer_write %4, %dest[%y, %x], %mask
///      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
///      : vector<[4]x[4]xf32>, memref<?x?xf32>
///  ```
///
/// Values larger than a single tile are supported via decomposition.
struct LowerIllegalTransposeStoreViaZA
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!isSupportedMaskOp(writeOp.getMask()))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isIdentity())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!transposeOp)
      return failure();

    auto sourceType = transposeOp.getSourceVectorType();
    auto resultType = transposeOp.getResultVectorType();

    if (resultType.getRank() != 2)
      return rewriter.notifyMatchFailure(transposeOp, "TransposeOp not rank 2");

    if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(
          transposeOp, "not illegal/unsupported SVE transpose");

    auto smeTileType = getSMETileTypeForElement(resultType.getElementType());
    VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0);

    if (sourceType.getDimSize(0) <= 1 ||
        sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
      return rewriter.notifyMatchFailure(writeOp, "unsupported source shape");

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    auto transposeMap = AffineMapAttr::get(
        AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext()));

    // Note: We need to use `get_tile` as there's no vector-level `undef`.
    Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType);
    Value destTensorOrMemref = writeOp.getSource();
    auto numSlicesPerTile =
        std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
    auto numSlices =
        rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile);
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
      // 1. _Deliberately_ drop a scalable dimension and insert a fixed number
      // of slices from the source type into the SME tile. Without checking
      // vscale (and emitting multiple implementations) we can't make use of the
      // rows of the tile after 1*vscale rows.
      Value tile = undefTile;
      for (int d = 0; d < numSlicesPerTile; ++d) {
        Value vector = rewriter.create<vector::ExtractOp>(
            loc, transposeOp.getVector(),
            rewriter.getIndexAttr(d + smeTile.row));
        if (vector.getType() != smeSliceType) {
          vector = rewriter.create<vector::ScalableExtractOp>(
              loc, smeSliceType, vector, smeTile.col);
        }
        tile = rewriter.create<vector::InsertOp>(loc, vector, tile, d);
      }

      // 2. Transpose the tile position.
      auto transposedRow = createVscaleMultiple(smeTile.col);
      auto transposedCol =
          rewriter.create<arith::ConstantIndexOp>(loc, smeTile.row);

      // 3. Compute mask for tile store.
      Value maskRows;
      Value maskCols;
      if (auto mask = writeOp.getMask()) {
        auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
        maskRows = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(0),
                                                  transposedRow);
        maskCols = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(1),
                                                  transposedCol);
        maskCols = rewriter.create<index::MinSOp>(loc, maskCols, numSlices);
      } else {
        maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
        maskCols = numSlices;
      }
      auto subMask = rewriter.create<vector::CreateMaskOp>(
          loc, smeTileType.clone(rewriter.getI1Type()),
          ValueRange{maskRows, maskCols});

      // 4. Emit a transposed tile write.
      auto writeIndices = writeOp.getIndices();
      Value destRow =
          rewriter.create<arith::AddIOp>(loc, transposedRow, writeIndices[0]);
      Value destCol =
          rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]);
      auto smeWrite = rewriter.create<vector::TransferWriteOp>(
          loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
          transposeMap, subMask, writeOp.getInBounds());

      if (writeOp.hasPureTensorSemantics())
        destTensorOrMemref = smeWrite.getResult();
    }

    if (writeOp.hasPureTensorSemantics())
      rewriter.replaceOp(writeOp, destTensorOrMemref);
    else
      rewriter.eraseOp(writeOp);

    return success();
  }
};

struct VectorLegalizationPass
    : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
  void runOnOperation() override {
    auto *context = &getContext();
    TypeConverter converter;
    RewritePatternSet patterns(context);
    converter.addConversion([](Type type) { return type; });
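    // 1:N conversion: a vector type that is an exact multiple of an SME tile
    // converts to `smeTileCount` SME tile types. E.g. (with [4]x[4] f32 tiles)
    // vector<[8]x[8]xf32> becomes four vector<[4]x[4]xf32> values.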
    converter.addConversion(
        [](VectorType vectorType,
           SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
          if (!isMultipleOfSMETileVectorType(vectorType))
            return std::nullopt;
          auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType);
          auto smeTileType =
              getSMETileTypeForElement(vectorType.getElementType());
          types = SmallVector<Type>(smeTileCount, smeTileType);
          return success();
        });

    // Apply preprocessing patterns.
    RewritePatternSet rewritePatterns(context);
    rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                        LiftIllegalVectorTransposeToMemory,
                        ConvertIllegalShapeCastOpsToTransposes,
                        LowerIllegalTransposeStoreViaZA>(context);
    if (failed(
            applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
      return signalPassFailure();

    // Note: These two patterns are added with a high benefit to ensure:
    //  - Masked outer products are handled before unmasked ones
    //  - Multi-tile writes are lowered as a store loop (if possible)
    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
                 LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context,
                                                            /*benefit=*/1024);
    patterns.add<LegalizeArithConstantOpsByDecomposition,
                 LegalizeVectorOuterProductOpsByDecomposition,
                 LegalizeTransferReadOpsByDecomposition,
                 LegalizeTransferWriteOpsByDecomposition>(converter, context);
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversions(converter, patterns);

    ConversionTarget target(getContext());
    target.markUnknownOpDynamicallyLegal(
        [&](Operation *op) { return converter.isLegal(op); });
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() {
  return std::make_unique<VectorLegalizationPass>();
}