xref: /llvm-project/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (revision aa2952165cd1808dab2bb49b97becc097f4c9cac)
1 //===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
10 
11 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
12 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
13 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "llvm/Support/Casting.h"
17 
18 using namespace mlir;
19 
20 namespace {
21 
22 /// Conversion pattern for vector.transfer_read.
23 ///
24 /// ---
25 ///
26 /// Example 1: op with identity permutation map to horizontal
27 ///            arm_sme.tile_load:
28 ///
29 ///   vector.transfer_read ...  permutation_map: (d0, d1) -> (d0, d1)
30 ///
31 /// is converted to:
32 ///
33 ///   arm_sme.tile_load ...
34 ///
35 /// ---
36 ///
37 /// Example 2: op with transpose permutation map to vertical arm_sme.tile_load
38 ///            (in-flight transpose):
39 ///
40 ///   vector.transfer_read ...  permutation_map: (d0, d1) -> (d1, d0)
41 ///
42 /// is converted to:
43 ///
44 ///   arm_sme.tile_load ... layout<vertical>
45 struct TransferReadToArmSMELowering
46     : public OpRewritePattern<vector::TransferReadOp> {
47   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
48 
49   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
50                                 PatternRewriter &rewriter) const final {
51     // The permutation map must have two results.
52     if (transferReadOp.getTransferRank() != 2)
53       return rewriter.notifyMatchFailure(transferReadOp,
54                                          "not a 2 result permutation map");
55 
56     auto vectorType = transferReadOp.getVectorType();
57     if (!arm_sme::isValidSMETileVectorType(vectorType))
58       return rewriter.notifyMatchFailure(transferReadOp,
59                                          "not a valid vector type for SME");
60 
61     if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType()))
62       return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
63 
64     // Out-of-bounds dims are not supported.
65     if (transferReadOp.hasOutOfBoundsDim())
66       return rewriter.notifyMatchFailure(transferReadOp,
67                                          "not inbounds transfer read");
68 
69     AffineMap map = transferReadOp.getPermutationMap();
70     if (!map.isPermutation())
71       return rewriter.notifyMatchFailure(transferReadOp,
72                                          "unsupported permutation map");
73 
74     // Note: For 2D vector types the only non-identity permutation is a simple
75     // transpose [1, 0].
76     bool transposed = !map.isIdentity();
77     arm_sme::TileSliceLayout layout =
78         transposed ? arm_sme::TileSliceLayout::Vertical
79                    : arm_sme::TileSliceLayout::Horizontal;
80 
81     // Padding isn't optional for transfer_read, but is only used in the case
82     // of out-of-bounds accesses (not supported here) and/or masking. Mask is
83     // optional, if it's not present don't pass padding.
84     auto mask = transferReadOp.getMask();
85     auto padding = mask ? transferReadOp.getPadding() : nullptr;
86     rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
87         transferReadOp, vectorType, transferReadOp.getSource(),
88         transferReadOp.getIndices(), padding, mask, layout);
89 
90     return success();
91   }
92 };
93 
94 /// Conversion pattern for vector.transfer_write.
95 ///
96 /// ---
97 ///
98 /// Example 1: op with identity permutation map to horizontal
99 ///            arm_sme.tile_store:
100 ///
101 ///   vector.transfer_write %vector, %source[%c0, %c0]
102 ///     {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
103 ///
104 /// is converted to:
105 ///
106 ///   arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
107 ///                                                   vector<[16]x[16]xi8>
108 /// ---
109 ///
110 /// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
111 ///            (in-flight transpose):
112 ///
113 ///   vector.transfer_write %vector, %source[%c0, %c0]
114 ///     {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
115 ///      in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
116 ///
117 /// is converted to:
118 ///
119 ///   arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
120 ///     : memref<?x?xi8>, vector<[16]x[16]xi8>
121 struct TransferWriteToArmSMELowering
122     : public OpRewritePattern<vector::TransferWriteOp> {
123   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
124 
125   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
126                                 PatternRewriter &rewriter) const final {
127     auto vType = writeOp.getVectorType();
128     if (!arm_sme::isValidSMETileVectorType(vType))
129       return failure();
130 
131     if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
132       return failure();
133 
134     // Out-of-bounds dims are not supported.
135     if (writeOp.hasOutOfBoundsDim())
136       return rewriter.notifyMatchFailure(writeOp,
137                                          "not inbounds transfer write");
138 
139     AffineMap map = writeOp.getPermutationMap();
140     if (!map.isPermutation())
141       return rewriter.notifyMatchFailure(writeOp,
142                                          "unsupported permutation map");
143 
144     // Note: For 2D vector types the only non-identity permutation is a simple
145     // transpose [1, 0].
146     bool transposed = !map.isIdentity();
147     arm_sme::TileSliceLayout layout =
148         transposed ? arm_sme::TileSliceLayout::Vertical
149                    : arm_sme::TileSliceLayout::Horizontal;
150 
151     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
152         writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
153         writeOp.getMask(), layout);
154     return success();
155   }
156 };
157 
158 /// Conversion pattern for vector.load.
159 struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
160   using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
161 
162   LogicalResult matchAndRewrite(vector::LoadOp load,
163                                 PatternRewriter &rewriter) const override {
164     if (!arm_sme::isValidSMETileVectorType(load.getVectorType()))
165       return failure();
166 
167     rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
168         load, load.getVectorType(), load.getBase(), load.getIndices());
169 
170     return success();
171   }
172 };
173 
174 /// Conversion pattern for vector.store.
175 struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
176   using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
177 
178   LogicalResult matchAndRewrite(vector::StoreOp store,
179                                 PatternRewriter &rewriter) const override {
180     if (!arm_sme::isValidSMETileVectorType(store.getVectorType()))
181       return failure();
182 
183     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
184         store, store.getValueToStore(), store.getBase(), store.getIndices());
185 
186     return success();
187   }
188 };
189 
190 /// Conversion pattern for vector.broadcast.
191 ///
192 /// Example:
193 ///
194 ///   %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32>
195 ///
196 /// is converted to:
197 ///
198 ///   %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
199 ///   %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
200 ///       step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
201 ///   {
202 ///     %tile_update = arm_sme.insert_tile_slice
203 ///        %broadcast_to_1d, %iter_tile[%tile_slice_index] :
204 ///        vector<[4]xi32> into vector<[4]x[4]xi32>
205 ///     scf.yield %tile_update : vector<[4]x[4]xi32>
206 ///   }
207 ///
208 /// Supports scalar, 0-d vector, and 1-d vector broadcasts.
209 struct BroadcastOpToArmSMELowering
210     : public OpRewritePattern<vector::BroadcastOp> {
211   using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
212 
213   LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
214                                 PatternRewriter &rewriter) const final {
215     auto tileType = broadcastOp.getResultVectorType();
216     if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
217       return failure();
218 
219     auto loc = broadcastOp.getLoc();
220 
221     auto srcType = broadcastOp.getSourceType();
222     auto srcVectorType = dyn_cast<VectorType>(srcType);
223 
224     Value broadcastOp1D;
225     if (srcType.isIntOrFloat() ||
226         (srcVectorType && (srcVectorType.getRank() == 0))) {
227       // Broadcast scalar or 0-d vector to 1-d vector.
228       VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
229       broadcastOp1D = rewriter.create<vector::BroadcastOp>(
230           loc, tileSliceType, broadcastOp.getSource());
231     } else if (srcVectorType && (srcVectorType.getRank() == 1))
232       // Value to broadcast is already a 1-d vector, nothing to do.
233       broadcastOp1D = broadcastOp.getSource();
234     else
235       return failure();
236 
237     auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
238 
239     auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
240                             Value currentTile) {
241       // Create 'arm_sme.insert_tile_slice' to broadcast the value
242       // to each tile slice.
243       auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
244           loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
245       return nextTile.getResult();
246     };
247 
248     // Create a loop over ZA tile slices.
249     auto forOp =
250         createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
251 
252     rewriter.replaceOp(broadcastOp, forOp.getResult(0));
253 
254     return success();
255   }
256 };
257 
258 /// Conversion pattern for vector.splat.
259 ///
260 /// Example:
261 ///
262 ///   %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
263 ///
264 /// is converted to:
265 ///
266 ///   %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
267 ///   %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
268 ///       step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
269 ///   {
270 ///     %tile_update = arm_sme.insert_tile_slice
271 ///        %broadcast_to_1d, %iter_tile[%tile_slice_index] :
272 ///        vector<[4]xi32> into vector<[4]x[4]xi32>
273 ///     scf.yield %tile_update : vector<[4]x[4]xi32>
274 ///   }
275 ///
276 /// This is identical to vector.broadcast of a scalar.
277 struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
278   using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
279 
280   LogicalResult matchAndRewrite(vector::SplatOp splatOp,
281                                 PatternRewriter &rewriter) const final {
282     auto tileType = splatOp.getResult().getType();
283     if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
284       return failure();
285 
286     auto loc = splatOp.getLoc();
287     auto srcType = splatOp.getOperand().getType();
288 
289     assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
290     // Avoid unused-variable warning when building without assertions.
291     (void)srcType;
292 
293     // First, broadcast the scalar to a 1-d vector.
294     VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
295     Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
296         loc, tileSliceType, splatOp.getInput());
297 
298     auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
299 
300     auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
301                             Value currentTile) {
302       auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
303           loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
304       return nextTile.getResult();
305     };
306 
307     // Next, create a loop over ZA tile slices and "move" the generated 1-d
308     // vector to each slice.
309     auto forOp =
310         createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
311 
312     rewriter.replaceOp(splatOp, forOp.getResult(0));
313 
314     return success();
315   }
316 };
317 
318 /// Conversion pattern for vector.transpose.
319 ///
320 /// Stores the input tile to memory and reloads vertically.
321 ///
322 /// Example:
323 ///
324 ///   %transposed_src = vector.transpose %src, [1, 0]
325 ///     : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
326 ///
327 /// is converted to:
328 ///
329 ///   %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
330 ///   %arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0]
331 ///     : memref<?x?xi32>, vector<[4]x[4]xi32>
332 ///   %transposed_src = arm_sme.tile_load %alloca[%c0, %c0]
333 ///     layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
334 ///
335 /// NOTE: Transposing via memory is obviously expensive, the current intention
336 /// is to avoid the transpose if possible, this is therefore intended as a
337 /// fallback and to provide base support for Vector ops. If it turns out
338 /// transposes can't be avoided then this should be replaced with a more optimal
339 /// implementation, perhaps with tile <-> vector (MOVA) ops.
340 struct TransposeOpToArmSMELowering
341     : public OpRewritePattern<vector::TransposeOp> {
342   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
343 
344   LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
345                                 PatternRewriter &rewriter) const final {
346     auto tileType = transposeOp.getResultVectorType();
347     if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
348       return failure();
349 
350     // Bail unless this is a true 2-D matrix transpose.
351     ArrayRef<int64_t> permutation = transposeOp.getPermutation();
352     if (permutation[0] != 1 || permutation[1] != 0)
353       return failure();
354 
355     auto loc = transposeOp.getLoc();
356     Value input = transposeOp.getVector();
357 
358     if (auto xferOp = input.getDefiningOp<vector::TransferReadOp>();
359         xferOp && xferOp->hasOneUse()) {
360       // Fold transpose into transfer_read to enable in-flight transpose when
361       // converting to arm_sme.tile_load.
362       rewriter.modifyOpInPlace(xferOp, [&]() {
363         xferOp->setAttr(xferOp.getPermutationMapAttrName(),
364                         AffineMapAttr::get(AffineMap::getPermutationMap(
365                             permutation, transposeOp.getContext())));
366       });
367       rewriter.replaceOp(transposeOp, xferOp);
368       return success();
369     }
370 
371     // Allocate buffer to store input tile to.
372     Value vscale =
373         rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
374     Value minTileSlices = rewriter.create<arith::ConstantOp>(
375         loc, rewriter.getIndexAttr(tileType.getDimSize(0)));
376     Value c0 =
377         rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
378     Value numTileSlices =
379         rewriter.create<arith::MulIOp>(loc, vscale, minTileSlices);
380     auto bufferType =
381         MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic},
382                         tileType.getElementType());
383     auto buffer = rewriter.create<memref::AllocaOp>(
384         loc, bufferType, ValueRange{numTileSlices, numTileSlices});
385 
386     // Store input tile.
387     auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
388         loc, input, buffer, ValueRange{c0, c0});
389 
390     // Reload input tile vertically.
391     rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
392         transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(),
393         arm_sme::TileSliceLayout::Vertical);
394 
395     return success();
396   }
397 };
398 
399 /// Conversion pattern for vector.outerproduct.
400 ///
401 /// If the vector.outerproduct is masked (and the mask is from a
402 /// vector.create_mask), then the mask is decomposed into two 1-D masks for the
403 /// operands.
404 ///
405 /// Example:
406 ///
407 ///   %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
408 ///   %result = vector.mask %mask {
409 ///                vector.outerproduct %vecA, %vecB
410 ///                 : vector<[4]xf32>, vector<[4]xf32>
411 ///             } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
412 ///
413 /// is converted to:
414 ///
415 ///    %maskA = vector.create_mask %dimA : vector<[4]xi1>
416 ///    %maskB = vector.create_mask %dimB : vector<[4]xi1>
417 ///    %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
418 ///                : vector<[4]xf32>, vector<[4]xf32>
419 ///
420 /// Unmasked outerproducts can be directly replaced with the arm_sme op.
421 ///
422 /// Example:
423 ///
424 ///   %result = vector.outerproduct %vecA, %vecB
425 ///              : vector<[4]xf32>, vector<[4]xf32>
426 ///
427 /// is converted to:
428 ///
429 ///   %result = arm_sme.outerproduct %vecA, %vecB
430 ///              : vector<[4]xf32>, vector<[4]xf32>
431 ///
432 struct VectorOuterProductToArmSMELowering
433     : public OpRewritePattern<vector::OuterProductOp> {
434 
435   using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
436 
437   LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
438                                 PatternRewriter &rewriter) const override {
439 
440     // We don't yet support lowering AXPY operations to SME. These could be
441     // lowered by masking out all but the first element of the LHS.
442     if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
443       return rewriter.notifyMatchFailure(outerProductOp,
444                                          "AXPY operations not supported");
445 
446     if (!arm_sme::isValidSMETileVectorType(
447             outerProductOp.getResultVectorType()))
448       return rewriter.notifyMatchFailure(
449           outerProductOp, "outer product does not fit into SME tile");
450 
451     auto kind = outerProductOp.getKind();
452     if (kind != vector::CombiningKind::ADD)
453       return rewriter.notifyMatchFailure(
454           outerProductOp,
455           "unsupported kind (lowering to SME only supports ADD at the moment)");
456 
457     Value lhsMask = {};
458     Value rhsMask = {};
459     Operation *rootOp = outerProductOp;
460     auto loc = outerProductOp.getLoc();
461     if (outerProductOp.isMasked()) {
462       auto maskOp = outerProductOp.getMaskingOp();
463       rewriter.setInsertionPoint(maskOp);
464       rootOp = maskOp;
465       auto operandMasks = decomposeResultMask(loc, maskOp.getMask(), rewriter);
466       if (failed(operandMasks))
467         return failure();
468       std::tie(lhsMask, rhsMask) = *operandMasks;
469     }
470 
471     rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
472         rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
473         outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
474 
475     return success();
476   }
477 
478   static FailureOr<std::pair<Value, Value>>
479   decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) {
480     // Attempt to extract masks from vector.create_mask.
481     // TODO: Add support for other mask sources.
482     auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
483     if (!createMaskOp)
484       return failure();
485 
486     auto maskType = createMaskOp.getVectorType();
487     Value lhsMaskDim = createMaskOp.getOperand(0);
488     Value rhsMaskDim = createMaskOp.getOperand(1);
489 
490     VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
491     Value lhsMask =
492         rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim);
493     Value rhsMask =
494         rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim);
495 
496     return std::make_pair(lhsMask, rhsMask);
497   }
498 };
499 
500 /// Lower `vector.extract` using `arm_sme.extract_tile_slice`.
501 ///
502 /// Example:
503 /// ```
504 /// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
505 /// ```
506 /// Becomes:
507 /// ```
508 /// %slice = arm_sme.extract_tile_slice %tile[%row]
509 ///            : vector<[4]xi32> from vector<[4]x[4]xi32>
510 /// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
511 /// ```
512 struct VectorExtractToArmSMELowering
513     : public OpRewritePattern<vector::ExtractOp> {
514   using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
515 
516   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
517                                 PatternRewriter &rewriter) const override {
518     VectorType sourceType = extractOp.getSourceVectorType();
519     if (!arm_sme::isValidSMETileVectorType(sourceType))
520       return failure();
521 
522     auto loc = extractOp.getLoc();
523     auto position = extractOp.getMixedPosition();
524 
525     Value sourceVector = extractOp.getVector();
526 
527     // Extract entire vector. Should be handled by folder, but just to be safe.
528     if (position.empty()) {
529       rewriter.replaceOp(extractOp, sourceVector);
530       return success();
531     }
532 
533     Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
534     auto extractTileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
535         loc, sourceVector, sliceIndex);
536 
537     if (position.size() == 1) {
538       // Single index case: Extracts a 1D slice.
539       rewriter.replaceOp(extractOp, extractTileSlice);
540       return success();
541     }
542 
543     // Two indices case: Extracts a single element.
544     assert(position.size() == 2);
545     rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, extractTileSlice,
546                                                    position[1]);
547 
548     return success();
549   }
550 };
551 
552 /// Lower `vector.insert` using `arm_sme.insert_tile_slice` and
553 /// `arm_sme.extract_tile_slice`.
554 ///
555 /// Example:
556 /// ```
557 /// %new_tile = vector.insert %el, %tile[%row, %col]
558 ///                     : i32 into vector<[4]x[4]xi32>
559 /// ```
560 /// Becomes:
561 /// ```
562 /// %slice = arm_sme.extract_tile_slice %tile[%row]
563 ///            : vector<[4]xi32> from vector<[4]x[4]xi32>
564 /// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
565 /// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row]
566 ///               : vector<[4]xi32> into vector<[4]x[4]xi32>
567 /// ```
568 struct VectorInsertToArmSMELowering
569     : public OpRewritePattern<vector::InsertOp> {
570   using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
571 
572   LogicalResult matchAndRewrite(vector::InsertOp insertOp,
573                                 PatternRewriter &rewriter) const override {
574     VectorType resultType = insertOp.getResult().getType();
575 
576     if (!arm_sme::isValidSMETileVectorType(resultType))
577       return failure();
578 
579     auto loc = insertOp.getLoc();
580     auto position = insertOp.getMixedPosition();
581 
582     Value source = insertOp.getSource();
583 
584     // Overwrite entire vector with value. Should be handled by folder, but
585     // just to be safe.
586     if (position.empty()) {
587       rewriter.replaceOp(insertOp, source);
588       return success();
589     }
590 
591     Value tileSlice = source;
592     Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
593     if (position.size() == 2) {
594       // Two indices case: Insert single element into tile.
595       // We need to first extract the existing slice and update the element.
596       tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
597           loc, insertOp.getDest(), sliceIndex);
598       tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
599                                                     position[1]);
600     }
601 
602     // Insert the slice into the destination tile.
603     rewriter.replaceOpWithNewOp<arm_sme::InsertTileSliceOp>(
604         insertOp, tileSlice, insertOp.getDest(), sliceIndex);
605     return success();
606   }
607 };
608 
609 /// Lowers `vector.print` of a tile into a loop over the rows of the tile,
610 /// extracting them via `arm_sme.extract_tile_slice`, then printing with
611 /// a 1D `vector.print`.
612 ///
613 ///  BEFORE:
614 ///  ```mlir
615 ///  vector.print %tile : vector<[4]x[4]xf32>
616 ///  ```
617 ///  AFTER:
618 ///  ```mlir
619 ///  %c0 = arith.constant 0 : index
620 ///  %c1 = arith.constant 1 : index
621 ///  %c4 = arith.constant 4 : index
622 ///  %vscale = vector.vscale
623 ///  %svl_s = arith.muli %c4, %vscale : index
624 ///  scf.for %i = %c0 to %svl_s step %c1 {
625 ///    %tile_slice = arm_sme.extract_tile_slice %tile[%i]
626 ///                     : vector<[4]xf32> from vector<[4]x[4]xf32>
627 ///    vector.print %tile_slice : vector<[4]xf32>
628 ///  }
629 ///  ```
630 struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
631   using OpRewritePattern<vector::PrintOp>::OpRewritePattern;
632 
633   LogicalResult matchAndRewrite(vector::PrintOp printOp,
634                                 PatternRewriter &rewriter) const override {
635     if (!printOp.getSource())
636       return failure();
637 
638     VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
639     if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
640       return failure();
641 
642     auto loc = printOp.getLoc();
643 
644     // Create a loop over the rows of the tile.
645     auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
646     auto minTileRows =
647         rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
648     auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
649     auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
650     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
651     auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
652     {
653       // Loop body.
654       rewriter.setInsertionPointToStart(forOp.getBody());
655       // Extract the current row from the tile.
656       Value rowIndex = forOp.getInductionVar();
657       auto tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
658           loc, printOp.getSource(), rowIndex);
659       // Print the row with a 1D vector.print.
660       rewriter.create<vector::PrintOp>(loc, tileSlice,
661                                        printOp.getPunctuation());
662     }
663 
664     rewriter.eraseOp(printOp);
665     return success();
666   }
667 };
668 
669 /// Folds a ExtractTileSliceOp + TransferWriteOp to a StoreTileSliceOp.
670 ///
671 ///  BEFORE:
672 ///  ```mlir
673 ///  %slice = arm_sme.extract_tile_slice %tile[%index]
674 ///             : vector<[4]xf32> from vector<[4]x[4]xf32>
675 ///  vector.transfer_write %slice, %memref[%i, %j], %mask {in_bounds = [true]}
676 ///             : vector<[4]xf32>, memref<?x?xf32>
677 ///  ```
678 ///  AFTER:
679 ///  ```mlir
680 ///  arm_sme.store_tile_slice %tile, %index, %mask, %memref[%i, %j]
681 ///             : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
682 ///  ```
683 struct FoldTransferWriteOfExtractTileSlice
684     : public OpRewritePattern<vector::TransferWriteOp> {
685   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
686 
687   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
688                                 PatternRewriter &rewriter) const final {
689     if (!isa<MemRefType>(writeOp.getSource().getType()))
690       return rewriter.notifyMatchFailure(writeOp, "destination not a memref");
691 
692     if (writeOp.hasOutOfBoundsDim())
693       return rewriter.notifyMatchFailure(writeOp,
694                                          "not inbounds transfer write");
695 
696     auto extractTileSlice =
697         writeOp.getVector().getDefiningOp<arm_sme::ExtractTileSliceOp>();
698     if (!extractTileSlice)
699       return rewriter.notifyMatchFailure(
700           writeOp, "vector to store not from ExtractTileSliceOp");
701 
702     AffineMap map = writeOp.getPermutationMap();
703     if (!map.isMinorIdentity())
704       return rewriter.notifyMatchFailure(writeOp,
705                                          "unsupported permutation map");
706 
707     Value mask = writeOp.getMask();
708     if (!mask) {
709       auto maskType = writeOp.getVectorType().clone(rewriter.getI1Type());
710       mask = rewriter.create<arith::ConstantOp>(
711           writeOp.getLoc(), maskType, DenseElementsAttr::get(maskType, true));
712     }
713 
714     rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
715         writeOp, extractTileSlice.getTile(),
716         extractTileSlice.getTileSliceIndex(), mask, writeOp.getSource(),
717         writeOp.getIndices(), extractTileSlice.getLayout());
718     return success();
719   }
720 };
721 
722 /// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to
723 /// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or
724 /// SVE 2.1), so this is currently the most logical place for this lowering.
725 ///
726 /// Example:
727 /// ```mlir
728 /// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
729 /// %slice = vector.extract %mask[%index]
730 ///            : vector<[8]xi1> from vector<[4]x[8]xi1>
731 /// ```
732 /// Becomes:
733 /// ```
734 /// %mask_rows = vector.create_mask %a : vector<[4]xi1>
735 /// %mask_cols = vector.create_mask %b : vector<[8]xi1>
736 /// %slice = arm_sve.psel %mask_cols, %mask_rows[%index]
737 ///            : vector<[8]xi1>, vector<[4]xi1>
738 /// ```
739 struct ExtractFromCreateMaskToPselLowering
740     : public OpRewritePattern<vector::ExtractOp> {
741   using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
742 
743   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
744                                 PatternRewriter &rewriter) const override {
745     if (extractOp.getNumIndices() != 1)
746       return rewriter.notifyMatchFailure(extractOp, "not single extract index");
747 
748     auto resultType = extractOp.getResult().getType();
749     auto resultVectorType = dyn_cast<VectorType>(resultType);
750     if (!resultVectorType)
751       return rewriter.notifyMatchFailure(extractOp, "result not VectorType");
752 
753     auto createMaskOp =
754         extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
755     if (!createMaskOp)
756       return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");
757 
758     auto maskType = createMaskOp.getVectorType();
759     if (maskType.getRank() != 2 || !maskType.allDimsScalable())
760       return rewriter.notifyMatchFailure(createMaskOp, "not 2-D scalable mask");
761 
762     auto isSVEPredicateSize = [](int64_t size) {
763       return size > 0 && size <= 16 && llvm::isPowerOf2_32(uint32_t(size));
764     };
765 
766     auto rowsBaseSize = maskType.getDimSize(0);
767     auto colsBaseSize = maskType.getDimSize(1);
768     if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize))
769       return rewriter.notifyMatchFailure(
770           createMaskOp, "mask dimensions not SVE predicate-sized");
771 
772     auto loc = extractOp.getLoc();
773     VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1);
774     VectorType colMaskType = VectorType::Builder(maskType).dropDim(0);
775 
776     // Create the two 1-D masks at the location of the 2-D create_mask (which is
777     // usually outside a loop). This prevents the need for later hoisting.
778     rewriter.setInsertionPoint(createMaskOp);
779     auto rowMask = rewriter.create<vector::CreateMaskOp>(
780         loc, rowMaskType, createMaskOp.getOperand(0));
781     auto colMask = rewriter.create<vector::CreateMaskOp>(
782         loc, colMaskType, createMaskOp.getOperand(1));
783 
784     rewriter.setInsertionPoint(extractOp);
785     auto position =
786         vector::getAsValues(rewriter, loc, extractOp.getMixedPosition());
787     rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colMask, rowMask,
788                                                  position[0]);
789     return success();
790   }
791 };
792 
793 } // namespace
794 
795 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
796                                           MLIRContext &ctx) {
797   patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
798                TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
799                TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
800                VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
801                VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
802                VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
803                ExtractFromCreateMaskToPselLowering>(&ctx);
804 }
805