xref: /llvm-project/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (revision aa2952165cd1808dab2bb49b97becc097f4c9cac)
1447bb5beSAndrzej Warzynski //===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===//
2447bb5beSAndrzej Warzynski //
3447bb5beSAndrzej Warzynski // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4447bb5beSAndrzej Warzynski // See https://llvm.org/LICENSE.txt for license information.
5447bb5beSAndrzej Warzynski // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6447bb5beSAndrzej Warzynski //
7447bb5beSAndrzej Warzynski //===----------------------------------------------------------------------===//
8447bb5beSAndrzej Warzynski 
9447bb5beSAndrzej Warzynski #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
10447bb5beSAndrzej Warzynski 
11447bb5beSAndrzej Warzynski #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
12ca9a3354SCullen Rhodes #include "mlir/Dialect/ArmSME/Utils/Utils.h"
13e2296d82SBenjamin Maxwell #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
14eaf15900SCullen Rhodes #include "mlir/Dialect/MemRef/IR/MemRef.h"
15447bb5beSAndrzej Warzynski #include "mlir/IR/BuiltinTypes.h"
16447bb5beSAndrzej Warzynski #include "llvm/Support/Casting.h"
17447bb5beSAndrzej Warzynski 
18447bb5beSAndrzej Warzynski using namespace mlir;
19447bb5beSAndrzej Warzynski 
20447bb5beSAndrzej Warzynski namespace {
21447bb5beSAndrzej Warzynski 
2222f11592SCullen Rhodes /// Conversion pattern for vector.transfer_read.
2322f11592SCullen Rhodes ///
2422f11592SCullen Rhodes /// ---
2522f11592SCullen Rhodes ///
2622f11592SCullen Rhodes /// Example 1: op with identity permutation map to horizontal
2722f11592SCullen Rhodes ///            arm_sme.tile_load:
2822f11592SCullen Rhodes ///
2922f11592SCullen Rhodes ///   vector.transfer_read ...  permutation_map: (d0, d1) -> (d0, d1)
3022f11592SCullen Rhodes ///
3122f11592SCullen Rhodes /// is converted to:
3222f11592SCullen Rhodes ///
3322f11592SCullen Rhodes ///   arm_sme.tile_load ...
3422f11592SCullen Rhodes ///
3522f11592SCullen Rhodes /// ---
3622f11592SCullen Rhodes ///
3722f11592SCullen Rhodes /// Example 2: op with transpose permutation map to vertical arm_sme.tile_load
3822f11592SCullen Rhodes ///            (in-flight transpose):
398e64e9c3SCullen Rhodes ///
408e64e9c3SCullen Rhodes ///   vector.transfer_read ...  permutation_map: (d0, d1) -> (d1, d0)
418e64e9c3SCullen Rhodes ///
428e64e9c3SCullen Rhodes /// is converted to:
438e64e9c3SCullen Rhodes ///
44d86047cbSCullen Rhodes ///   arm_sme.tile_load ... layout<vertical>
4522f11592SCullen Rhodes struct TransferReadToArmSMELowering
468e64e9c3SCullen Rhodes     : public OpRewritePattern<vector::TransferReadOp> {
478e64e9c3SCullen Rhodes   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
488e64e9c3SCullen Rhodes 
498e64e9c3SCullen Rhodes   LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
508e64e9c3SCullen Rhodes                                 PatternRewriter &rewriter) const final {
518e64e9c3SCullen Rhodes     // The permutation map must have two results.
528e64e9c3SCullen Rhodes     if (transferReadOp.getTransferRank() != 2)
538e64e9c3SCullen Rhodes       return rewriter.notifyMatchFailure(transferReadOp,
548e64e9c3SCullen Rhodes                                          "not a 2 result permutation map");
558e64e9c3SCullen Rhodes 
568e64e9c3SCullen Rhodes     auto vectorType = transferReadOp.getVectorType();
578e64e9c3SCullen Rhodes     if (!arm_sme::isValidSMETileVectorType(vectorType))
588e64e9c3SCullen Rhodes       return rewriter.notifyMatchFailure(transferReadOp,
598e64e9c3SCullen Rhodes                                          "not a valid vector type for SME");
608e64e9c3SCullen Rhodes 
618e64e9c3SCullen Rhodes     if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType()))
628e64e9c3SCullen Rhodes       return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
638e64e9c3SCullen Rhodes 
648e64e9c3SCullen Rhodes     // Out-of-bounds dims are not supported.
658e64e9c3SCullen Rhodes     if (transferReadOp.hasOutOfBoundsDim())
668e64e9c3SCullen Rhodes       return rewriter.notifyMatchFailure(transferReadOp,
678e64e9c3SCullen Rhodes                                          "not inbounds transfer read");
688e64e9c3SCullen Rhodes 
6922f11592SCullen Rhodes     AffineMap map = transferReadOp.getPermutationMap();
70b49c0b8aSCullen Rhodes     if (!map.isPermutation())
718e64e9c3SCullen Rhodes       return rewriter.notifyMatchFailure(transferReadOp,
7222f11592SCullen Rhodes                                          "unsupported permutation map");
738e64e9c3SCullen Rhodes 
74b49c0b8aSCullen Rhodes     // Note: For 2D vector types the only non-identity permutation is a simple
75b49c0b8aSCullen Rhodes     // transpose [1, 0].
76b49c0b8aSCullen Rhodes     bool transposed = !map.isIdentity();
77b49c0b8aSCullen Rhodes     arm_sme::TileSliceLayout layout =
78b49c0b8aSCullen Rhodes         transposed ? arm_sme::TileSliceLayout::Vertical
79b49c0b8aSCullen Rhodes                    : arm_sme::TileSliceLayout::Horizontal;
80b49c0b8aSCullen Rhodes 
8122f11592SCullen Rhodes     // Padding isn't optional for transfer_read, but is only used in the case
8222f11592SCullen Rhodes     // of out-of-bounds accesses (not supported here) and/or masking. Mask is
8322f11592SCullen Rhodes     // optional, if it's not present don't pass padding.
8422f11592SCullen Rhodes     auto mask = transferReadOp.getMask();
8522f11592SCullen Rhodes     auto padding = mask ? transferReadOp.getPadding() : nullptr;
868e64e9c3SCullen Rhodes     rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
878e64e9c3SCullen Rhodes         transferReadOp, vectorType, transferReadOp.getSource(),
8822f11592SCullen Rhodes         transferReadOp.getIndices(), padding, mask, layout);
898e64e9c3SCullen Rhodes 
908e64e9c3SCullen Rhodes     return success();
918e64e9c3SCullen Rhodes   }
928e64e9c3SCullen Rhodes };
938e64e9c3SCullen Rhodes 
9412e1a9b8SCullen Rhodes /// Conversion pattern for vector.transfer_write.
95447bb5beSAndrzej Warzynski ///
964240b179SCullen Rhodes /// ---
974240b179SCullen Rhodes ///
984240b179SCullen Rhodes /// Example 1: op with identity permutation map to horizontal
994240b179SCullen Rhodes ///            arm_sme.tile_store:
1004240b179SCullen Rhodes ///
1014240b179SCullen Rhodes ///   vector.transfer_write %vector, %source[%c0, %c0]
1024240b179SCullen Rhodes ///     {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
103447bb5beSAndrzej Warzynski ///
104447bb5beSAndrzej Warzynski /// is converted to:
105447bb5beSAndrzej Warzynski ///
10612e1a9b8SCullen Rhodes ///   arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
107447bb5beSAndrzej Warzynski ///                                                   vector<[16]x[16]xi8>
1084240b179SCullen Rhodes /// ---
1094240b179SCullen Rhodes ///
1104240b179SCullen Rhodes /// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
1114240b179SCullen Rhodes ///            (in-flight transpose):
1124240b179SCullen Rhodes ///
1134240b179SCullen Rhodes ///   vector.transfer_write %vector, %source[%c0, %c0]
1144240b179SCullen Rhodes ///     {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
1154240b179SCullen Rhodes ///      in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
1164240b179SCullen Rhodes ///
1174240b179SCullen Rhodes /// is converted to:
1184240b179SCullen Rhodes ///
1194240b179SCullen Rhodes ///   arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
1204240b179SCullen Rhodes ///     : memref<?x?xi8>, vector<[16]x[16]xi8>
121447bb5beSAndrzej Warzynski struct TransferWriteToArmSMELowering
122447bb5beSAndrzej Warzynski     : public OpRewritePattern<vector::TransferWriteOp> {
123447bb5beSAndrzej Warzynski   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
124447bb5beSAndrzej Warzynski 
125447bb5beSAndrzej Warzynski   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
126447bb5beSAndrzej Warzynski                                 PatternRewriter &rewriter) const final {
127447bb5beSAndrzej Warzynski     auto vType = writeOp.getVectorType();
12812e1a9b8SCullen Rhodes     if (!arm_sme::isValidSMETileVectorType(vType))
129447bb5beSAndrzej Warzynski       return failure();
130447bb5beSAndrzej Warzynski 
131447bb5beSAndrzej Warzynski     if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
132447bb5beSAndrzej Warzynski       return failure();
133447bb5beSAndrzej Warzynski 
1344240b179SCullen Rhodes     // Out-of-bounds dims are not supported.
1354240b179SCullen Rhodes     if (writeOp.hasOutOfBoundsDim())
1364240b179SCullen Rhodes       return rewriter.notifyMatchFailure(writeOp,
1374240b179SCullen Rhodes                                          "not inbounds transfer write");
1384240b179SCullen Rhodes 
1394240b179SCullen Rhodes     AffineMap map = writeOp.getPermutationMap();
140b49c0b8aSCullen Rhodes     if (!map.isPermutation())
1414240b179SCullen Rhodes       return rewriter.notifyMatchFailure(writeOp,
1424240b179SCullen Rhodes                                          "unsupported permutation map");
1434240b179SCullen Rhodes 
144b49c0b8aSCullen Rhodes     // Note: For 2D vector types the only non-identity permutation is a simple
145b49c0b8aSCullen Rhodes     // transpose [1, 0].
146b49c0b8aSCullen Rhodes     bool transposed = !map.isIdentity();
1474240b179SCullen Rhodes     arm_sme::TileSliceLayout layout =
148b49c0b8aSCullen Rhodes         transposed ? arm_sme::TileSliceLayout::Vertical
1494240b179SCullen Rhodes                    : arm_sme::TileSliceLayout::Horizontal;
1504240b179SCullen Rhodes 
151447bb5beSAndrzej Warzynski     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
1521908f47aSCullen Rhodes         writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
1534240b179SCullen Rhodes         writeOp.getMask(), layout);
154447bb5beSAndrzej Warzynski     return success();
155447bb5beSAndrzej Warzynski   }
156447bb5beSAndrzej Warzynski };
157447bb5beSAndrzej Warzynski 
158ca9a3354SCullen Rhodes /// Conversion pattern for vector.load.
159ca9a3354SCullen Rhodes struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
160ca9a3354SCullen Rhodes   using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
161ca9a3354SCullen Rhodes 
162ca9a3354SCullen Rhodes   LogicalResult matchAndRewrite(vector::LoadOp load,
163ca9a3354SCullen Rhodes                                 PatternRewriter &rewriter) const override {
164ca9a3354SCullen Rhodes     if (!arm_sme::isValidSMETileVectorType(load.getVectorType()))
165ca9a3354SCullen Rhodes       return failure();
166ca9a3354SCullen Rhodes 
167ca9a3354SCullen Rhodes     rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
168ca9a3354SCullen Rhodes         load, load.getVectorType(), load.getBase(), load.getIndices());
169ca9a3354SCullen Rhodes 
170ca9a3354SCullen Rhodes     return success();
171ca9a3354SCullen Rhodes   }
172ca9a3354SCullen Rhodes };
173ca9a3354SCullen Rhodes 
174ca9a3354SCullen Rhodes /// Conversion pattern for vector.store.
175ca9a3354SCullen Rhodes struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
176ca9a3354SCullen Rhodes   using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
177ca9a3354SCullen Rhodes 
178ca9a3354SCullen Rhodes   LogicalResult matchAndRewrite(vector::StoreOp store,
179ca9a3354SCullen Rhodes                                 PatternRewriter &rewriter) const override {
180ca9a3354SCullen Rhodes     if (!arm_sme::isValidSMETileVectorType(store.getVectorType()))
181ca9a3354SCullen Rhodes       return failure();
182ca9a3354SCullen Rhodes 
183ca9a3354SCullen Rhodes     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
184ca9a3354SCullen Rhodes         store, store.getValueToStore(), store.getBase(), store.getIndices());
185ca9a3354SCullen Rhodes 
186ca9a3354SCullen Rhodes     return success();
187ca9a3354SCullen Rhodes   }
188ca9a3354SCullen Rhodes };
189ca9a3354SCullen Rhodes 
1902dd3f420SCullen Rhodes /// Conversion pattern for vector.broadcast.
1912dd3f420SCullen Rhodes ///
1922dd3f420SCullen Rhodes /// Example:
1932dd3f420SCullen Rhodes ///
1942dd3f420SCullen Rhodes ///   %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32>
1952dd3f420SCullen Rhodes ///
1962dd3f420SCullen Rhodes /// is converted to:
1972dd3f420SCullen Rhodes ///
1982dd3f420SCullen Rhodes ///   %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
199b0b69fd8SBenjamin Maxwell ///   %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
200b0b69fd8SBenjamin Maxwell ///       step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
201b0b69fd8SBenjamin Maxwell ///   {
202c4251243SBenjamin Maxwell ///     %tile_update = arm_sme.insert_tile_slice
203c4251243SBenjamin Maxwell ///        %broadcast_to_1d, %iter_tile[%tile_slice_index] :
204b0b69fd8SBenjamin Maxwell ///        vector<[4]xi32> into vector<[4]x[4]xi32>
205b0b69fd8SBenjamin Maxwell ///     scf.yield %tile_update : vector<[4]x[4]xi32>
2062dd3f420SCullen Rhodes ///   }
2072dd3f420SCullen Rhodes ///
2082dd3f420SCullen Rhodes /// Supports scalar, 0-d vector, and 1-d vector broadcasts.
2092dd3f420SCullen Rhodes struct BroadcastOpToArmSMELowering
2102dd3f420SCullen Rhodes     : public OpRewritePattern<vector::BroadcastOp> {
2112dd3f420SCullen Rhodes   using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
2122dd3f420SCullen Rhodes 
2132dd3f420SCullen Rhodes   LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
2142dd3f420SCullen Rhodes                                 PatternRewriter &rewriter) const final {
2152dd3f420SCullen Rhodes     auto tileType = broadcastOp.getResultVectorType();
2162dd3f420SCullen Rhodes     if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
2172dd3f420SCullen Rhodes       return failure();
2182dd3f420SCullen Rhodes 
2192dd3f420SCullen Rhodes     auto loc = broadcastOp.getLoc();
2202dd3f420SCullen Rhodes 
2212dd3f420SCullen Rhodes     auto srcType = broadcastOp.getSourceType();
2222dd3f420SCullen Rhodes     auto srcVectorType = dyn_cast<VectorType>(srcType);
2232dd3f420SCullen Rhodes 
2242dd3f420SCullen Rhodes     Value broadcastOp1D;
2252dd3f420SCullen Rhodes     if (srcType.isIntOrFloat() ||
2262dd3f420SCullen Rhodes         (srcVectorType && (srcVectorType.getRank() == 0))) {
2272dd3f420SCullen Rhodes       // Broadcast scalar or 0-d vector to 1-d vector.
228b0b69fd8SBenjamin Maxwell       VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
2292dd3f420SCullen Rhodes       broadcastOp1D = rewriter.create<vector::BroadcastOp>(
2302dd3f420SCullen Rhodes           loc, tileSliceType, broadcastOp.getSource());
2312dd3f420SCullen Rhodes     } else if (srcVectorType && (srcVectorType.getRank() == 1))
2322dd3f420SCullen Rhodes       // Value to broadcast is already a 1-d vector, nothing to do.
2332dd3f420SCullen Rhodes       broadcastOp1D = broadcastOp.getSource();
2342dd3f420SCullen Rhodes     else
2352dd3f420SCullen Rhodes       return failure();
2362dd3f420SCullen Rhodes 
237b0b69fd8SBenjamin Maxwell     auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
2382dd3f420SCullen Rhodes 
2399f7fff7fSCullen Rhodes     auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
2409f7fff7fSCullen Rhodes                             Value currentTile) {
241c4251243SBenjamin Maxwell       // Create 'arm_sme.insert_tile_slice' to broadcast the value
242b0b69fd8SBenjamin Maxwell       // to each tile slice.
243c4251243SBenjamin Maxwell       auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
244b0b69fd8SBenjamin Maxwell           loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
2459f7fff7fSCullen Rhodes       return nextTile.getResult();
2469f7fff7fSCullen Rhodes     };
2479f7fff7fSCullen Rhodes 
2489f7fff7fSCullen Rhodes     // Create a loop over ZA tile slices.
2499f7fff7fSCullen Rhodes     auto forOp =
2509f7fff7fSCullen Rhodes         createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
2512dd3f420SCullen Rhodes 
252b0b69fd8SBenjamin Maxwell     rewriter.replaceOp(broadcastOp, forOp.getResult(0));
2532dd3f420SCullen Rhodes 
2542dd3f420SCullen Rhodes     return success();
2552dd3f420SCullen Rhodes   }
2562dd3f420SCullen Rhodes };
2572dd3f420SCullen Rhodes 
2580cb0df41SAndrzej Warzyński /// Conversion pattern for vector.splat.
2590cb0df41SAndrzej Warzyński ///
2600cb0df41SAndrzej Warzyński /// Example:
2610cb0df41SAndrzej Warzyński ///
2620cb0df41SAndrzej Warzyński ///   %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32>
2630cb0df41SAndrzej Warzyński ///
2640cb0df41SAndrzej Warzyński /// is converted to:
2650cb0df41SAndrzej Warzyński ///
2660cb0df41SAndrzej Warzyński ///   %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32>
267b0b69fd8SBenjamin Maxwell ///   %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices
268b0b69fd8SBenjamin Maxwell ///       step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>)
269b0b69fd8SBenjamin Maxwell ///   {
270c4251243SBenjamin Maxwell ///     %tile_update = arm_sme.insert_tile_slice
271c4251243SBenjamin Maxwell ///        %broadcast_to_1d, %iter_tile[%tile_slice_index] :
272b0b69fd8SBenjamin Maxwell ///        vector<[4]xi32> into vector<[4]x[4]xi32>
273b0b69fd8SBenjamin Maxwell ///     scf.yield %tile_update : vector<[4]x[4]xi32>
2740cb0df41SAndrzej Warzyński ///   }
2750cb0df41SAndrzej Warzyński ///
2760cb0df41SAndrzej Warzyński /// This is identical to vector.broadcast of a scalar.
2770cb0df41SAndrzej Warzyński struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
2780cb0df41SAndrzej Warzyński   using OpRewritePattern<vector::SplatOp>::OpRewritePattern;
2790cb0df41SAndrzej Warzyński 
2800cb0df41SAndrzej Warzyński   LogicalResult matchAndRewrite(vector::SplatOp splatOp,
2810cb0df41SAndrzej Warzyński                                 PatternRewriter &rewriter) const final {
2820cb0df41SAndrzej Warzyński     auto tileType = splatOp.getResult().getType();
2830cb0df41SAndrzej Warzyński     if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
2840cb0df41SAndrzej Warzyński       return failure();
2850cb0df41SAndrzej Warzyński 
2860cb0df41SAndrzej Warzyński     auto loc = splatOp.getLoc();
2870cb0df41SAndrzej Warzyński     auto srcType = splatOp.getOperand().getType();
2880cb0df41SAndrzej Warzyński 
2890cb0df41SAndrzej Warzyński     assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat");
290042468bfSGoran Flegar     // Avoid unused-variable warning when building without assertions.
291042468bfSGoran Flegar     (void)srcType;
2920cb0df41SAndrzej Warzyński 
2930cb0df41SAndrzej Warzyński     // First, broadcast the scalar to a 1-d vector.
2940cb0df41SAndrzej Warzyński     VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
2950cb0df41SAndrzej Warzyński     Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
2960cb0df41SAndrzej Warzyński         loc, tileSliceType, splatOp.getInput());
2970cb0df41SAndrzej Warzyński 
298b0b69fd8SBenjamin Maxwell     auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
2990cb0df41SAndrzej Warzyński 
3009f7fff7fSCullen Rhodes     auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
3019f7fff7fSCullen Rhodes                             Value currentTile) {
302c4251243SBenjamin Maxwell       auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
3039f7fff7fSCullen Rhodes           loc, tileType, broadcastOp1D, currentTile, tileSliceIndex);
3049f7fff7fSCullen Rhodes       return nextTile.getResult();
3059f7fff7fSCullen Rhodes     };
3069f7fff7fSCullen Rhodes 
3070cb0df41SAndrzej Warzyński     // Next, create a loop over ZA tile slices and "move" the generated 1-d
3080cb0df41SAndrzej Warzyński     // vector to each slice.
309b0b69fd8SBenjamin Maxwell     auto forOp =
3109f7fff7fSCullen Rhodes         createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody);
3110cb0df41SAndrzej Warzyński 
312b0b69fd8SBenjamin Maxwell     rewriter.replaceOp(splatOp, forOp.getResult(0));
3130cb0df41SAndrzej Warzyński 
3140cb0df41SAndrzej Warzyński     return success();
3150cb0df41SAndrzej Warzyński   }
3160cb0df41SAndrzej Warzyński };
3170cb0df41SAndrzej Warzyński 
318eaf15900SCullen Rhodes /// Conversion pattern for vector.transpose.
319eaf15900SCullen Rhodes ///
320eaf15900SCullen Rhodes /// Stores the input tile to memory and reloads vertically.
321eaf15900SCullen Rhodes ///
322eaf15900SCullen Rhodes /// Example:
323eaf15900SCullen Rhodes ///
324eaf15900SCullen Rhodes ///   %transposed_src = vector.transpose %src, [1, 0]
325eaf15900SCullen Rhodes ///     : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
326eaf15900SCullen Rhodes ///
327eaf15900SCullen Rhodes /// is converted to:
328eaf15900SCullen Rhodes ///
329eaf15900SCullen Rhodes ///   %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
330eaf15900SCullen Rhodes ///   %arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0]
331eaf15900SCullen Rhodes ///     : memref<?x?xi32>, vector<[4]x[4]xi32>
332d86047cbSCullen Rhodes ///   %transposed_src = arm_sme.tile_load %alloca[%c0, %c0]
333d86047cbSCullen Rhodes ///     layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
334eaf15900SCullen Rhodes ///
335*aa295216SJay Foad /// NOTE: Transposing via memory is obviously expensive, the current intention
336eaf15900SCullen Rhodes /// is to avoid the transpose if possible, this is therefore intended as a
337eaf15900SCullen Rhodes /// fallback and to provide base support for Vector ops. If it turns out
338eaf15900SCullen Rhodes /// transposes can't be avoided then this should be replaced with a more optimal
339eaf15900SCullen Rhodes /// implementation, perhaps with tile <-> vector (MOVA) ops.
340eaf15900SCullen Rhodes struct TransposeOpToArmSMELowering
341eaf15900SCullen Rhodes     : public OpRewritePattern<vector::TransposeOp> {
342eaf15900SCullen Rhodes   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
343eaf15900SCullen Rhodes 
344eaf15900SCullen Rhodes   LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
345eaf15900SCullen Rhodes                                 PatternRewriter &rewriter) const final {
346eaf15900SCullen Rhodes     auto tileType = transposeOp.getResultVectorType();
347eaf15900SCullen Rhodes     if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
348eaf15900SCullen Rhodes       return failure();
349eaf15900SCullen Rhodes 
350eaf15900SCullen Rhodes     // Bail unless this is a true 2-D matrix transpose.
35132c3decbSMatthias Springer     ArrayRef<int64_t> permutation = transposeOp.getPermutation();
35232c3decbSMatthias Springer     if (permutation[0] != 1 || permutation[1] != 0)
353eaf15900SCullen Rhodes       return failure();
354eaf15900SCullen Rhodes 
355eaf15900SCullen Rhodes     auto loc = transposeOp.getLoc();
356bfb5fe21SCullen Rhodes     Value input = transposeOp.getVector();
357bfb5fe21SCullen Rhodes 
358bfb5fe21SCullen Rhodes     if (auto xferOp = input.getDefiningOp<vector::TransferReadOp>();
359bfb5fe21SCullen Rhodes         xferOp && xferOp->hasOneUse()) {
360bfb5fe21SCullen Rhodes       // Fold transpose into transfer_read to enable in-flight transpose when
361bfb5fe21SCullen Rhodes       // converting to arm_sme.tile_load.
362bfb5fe21SCullen Rhodes       rewriter.modifyOpInPlace(xferOp, [&]() {
363bfb5fe21SCullen Rhodes         xferOp->setAttr(xferOp.getPermutationMapAttrName(),
364bfb5fe21SCullen Rhodes                         AffineMapAttr::get(AffineMap::getPermutationMap(
365bfb5fe21SCullen Rhodes                             permutation, transposeOp.getContext())));
366bfb5fe21SCullen Rhodes       });
367bfb5fe21SCullen Rhodes       rewriter.replaceOp(transposeOp, xferOp);
368bfb5fe21SCullen Rhodes       return success();
369bfb5fe21SCullen Rhodes     }
370eaf15900SCullen Rhodes 
371eaf15900SCullen Rhodes     // Allocate buffer to store input tile to.
372eaf15900SCullen Rhodes     Value vscale =
373eaf15900SCullen Rhodes         rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
374eaf15900SCullen Rhodes     Value minTileSlices = rewriter.create<arith::ConstantOp>(
375eaf15900SCullen Rhodes         loc, rewriter.getIndexAttr(tileType.getDimSize(0)));
376eaf15900SCullen Rhodes     Value c0 =
377eaf15900SCullen Rhodes         rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
378eaf15900SCullen Rhodes     Value numTileSlices =
379eaf15900SCullen Rhodes         rewriter.create<arith::MulIOp>(loc, vscale, minTileSlices);
380eaf15900SCullen Rhodes     auto bufferType =
381eaf15900SCullen Rhodes         MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic},
382eaf15900SCullen Rhodes                         tileType.getElementType());
383eaf15900SCullen Rhodes     auto buffer = rewriter.create<memref::AllocaOp>(
384eaf15900SCullen Rhodes         loc, bufferType, ValueRange{numTileSlices, numTileSlices});
385eaf15900SCullen Rhodes 
386eaf15900SCullen Rhodes     // Store input tile.
387eaf15900SCullen Rhodes     auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
388eaf15900SCullen Rhodes         loc, input, buffer, ValueRange{c0, c0});
389eaf15900SCullen Rhodes 
390eaf15900SCullen Rhodes     // Reload input tile vertically.
391eaf15900SCullen Rhodes     rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
392eaf15900SCullen Rhodes         transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(),
393eaf15900SCullen Rhodes         arm_sme::TileSliceLayout::Vertical);
394eaf15900SCullen Rhodes 
395eaf15900SCullen Rhodes     return success();
396eaf15900SCullen Rhodes   }
397eaf15900SCullen Rhodes };
398eaf15900SCullen Rhodes 
399e6662950SBenjamin Maxwell /// Conversion pattern for vector.outerproduct.
400e6662950SBenjamin Maxwell ///
401e6662950SBenjamin Maxwell /// If the vector.outerproduct is masked (and the mask is from a
402e6662950SBenjamin Maxwell /// vector.create_mask), then the mask is decomposed into two 1-D masks for the
403e6662950SBenjamin Maxwell /// operands.
404e6662950SBenjamin Maxwell ///
405e6662950SBenjamin Maxwell /// Example:
406e6662950SBenjamin Maxwell ///
407e6662950SBenjamin Maxwell ///   %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
408e6662950SBenjamin Maxwell ///   %result = vector.mask %mask {
409e6662950SBenjamin Maxwell ///                vector.outerproduct %vecA, %vecB
410e6662950SBenjamin Maxwell ///                 : vector<[4]xf32>, vector<[4]xf32>
411e6662950SBenjamin Maxwell ///             } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
412e6662950SBenjamin Maxwell ///
413e6662950SBenjamin Maxwell /// is converted to:
414e6662950SBenjamin Maxwell ///
415e6662950SBenjamin Maxwell ///    %maskA = vector.create_mask %dimA : vector<[4]xi1>
416e6662950SBenjamin Maxwell ///    %maskB = vector.create_mask %dimB : vector<[4]xi1>
417e6662950SBenjamin Maxwell ///    %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
418e6662950SBenjamin Maxwell ///                : vector<[4]xf32>, vector<[4]xf32>
419e6662950SBenjamin Maxwell ///
420e6662950SBenjamin Maxwell /// Unmasked outerproducts can be directly replaced with the arm_sme op.
421e6662950SBenjamin Maxwell ///
422e6662950SBenjamin Maxwell /// Example:
423e6662950SBenjamin Maxwell ///
424e6662950SBenjamin Maxwell ///   %result = vector.outerproduct %vecA, %vecB
425e6662950SBenjamin Maxwell ///              : vector<[4]xf32>, vector<[4]xf32>
426e6662950SBenjamin Maxwell ///
427e6662950SBenjamin Maxwell /// is converted to:
428e6662950SBenjamin Maxwell ///
429e6662950SBenjamin Maxwell ///   %result = arm_sme.outerproduct %vecA, %vecB
430e6662950SBenjamin Maxwell ///              : vector<[4]xf32>, vector<[4]xf32>
431e6662950SBenjamin Maxwell ///
432e6662950SBenjamin Maxwell struct VectorOuterProductToArmSMELowering
433e6662950SBenjamin Maxwell     : public OpRewritePattern<vector::OuterProductOp> {
434e6662950SBenjamin Maxwell 
435e6662950SBenjamin Maxwell   using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;
436e6662950SBenjamin Maxwell 
437e6662950SBenjamin Maxwell   LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
438e6662950SBenjamin Maxwell                                 PatternRewriter &rewriter) const override {
439e6662950SBenjamin Maxwell 
440e6662950SBenjamin Maxwell     // We don't yet support lowering AXPY operations to SME. These could be
441e6662950SBenjamin Maxwell     // lowered by masking out all but the first element of the LHS.
442e6662950SBenjamin Maxwell     if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
443e7432babSCullen Rhodes       return rewriter.notifyMatchFailure(outerProductOp,
444e7432babSCullen Rhodes                                          "AXPY operations not supported");
445e6662950SBenjamin Maxwell 
446e6662950SBenjamin Maxwell     if (!arm_sme::isValidSMETileVectorType(
447e6662950SBenjamin Maxwell             outerProductOp.getResultVectorType()))
448e7432babSCullen Rhodes       return rewriter.notifyMatchFailure(
449e7432babSCullen Rhodes           outerProductOp, "outer product does not fit into SME tile");
450e6662950SBenjamin Maxwell 
451e6662950SBenjamin Maxwell     auto kind = outerProductOp.getKind();
452e6662950SBenjamin Maxwell     if (kind != vector::CombiningKind::ADD)
453e7432babSCullen Rhodes       return rewriter.notifyMatchFailure(
454e7432babSCullen Rhodes           outerProductOp,
455e6662950SBenjamin Maxwell           "unsupported kind (lowering to SME only supports ADD at the moment)");
456e6662950SBenjamin Maxwell 
457e6662950SBenjamin Maxwell     Value lhsMask = {};
458e6662950SBenjamin Maxwell     Value rhsMask = {};
459e6662950SBenjamin Maxwell     Operation *rootOp = outerProductOp;
460e6662950SBenjamin Maxwell     auto loc = outerProductOp.getLoc();
461e6662950SBenjamin Maxwell     if (outerProductOp.isMasked()) {
462e6662950SBenjamin Maxwell       auto maskOp = outerProductOp.getMaskingOp();
463e6662950SBenjamin Maxwell       rewriter.setInsertionPoint(maskOp);
464e6662950SBenjamin Maxwell       rootOp = maskOp;
465e6662950SBenjamin Maxwell       auto operandMasks = decomposeResultMask(loc, maskOp.getMask(), rewriter);
466e6662950SBenjamin Maxwell       if (failed(operandMasks))
467e6662950SBenjamin Maxwell         return failure();
468e6662950SBenjamin Maxwell       std::tie(lhsMask, rhsMask) = *operandMasks;
469e6662950SBenjamin Maxwell     }
470e6662950SBenjamin Maxwell 
471e6662950SBenjamin Maxwell     rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
472e6662950SBenjamin Maxwell         rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
473e6662950SBenjamin Maxwell         outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());
474e6662950SBenjamin Maxwell 
475e6662950SBenjamin Maxwell     return success();
476e6662950SBenjamin Maxwell   }
477e6662950SBenjamin Maxwell 
478e6662950SBenjamin Maxwell   static FailureOr<std::pair<Value, Value>>
479e6662950SBenjamin Maxwell   decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) {
480e6662950SBenjamin Maxwell     // Attempt to extract masks from vector.create_mask.
481e6662950SBenjamin Maxwell     // TODO: Add support for other mask sources.
482e6662950SBenjamin Maxwell     auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
483e6662950SBenjamin Maxwell     if (!createMaskOp)
484e6662950SBenjamin Maxwell       return failure();
485e6662950SBenjamin Maxwell 
486e6662950SBenjamin Maxwell     auto maskType = createMaskOp.getVectorType();
487e6662950SBenjamin Maxwell     Value lhsMaskDim = createMaskOp.getOperand(0);
488e6662950SBenjamin Maxwell     Value rhsMaskDim = createMaskOp.getOperand(1);
489e6662950SBenjamin Maxwell 
490e6662950SBenjamin Maxwell     VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
491e6662950SBenjamin Maxwell     Value lhsMask =
492e6662950SBenjamin Maxwell         rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim);
493e6662950SBenjamin Maxwell     Value rhsMask =
494e6662950SBenjamin Maxwell         rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim);
495e6662950SBenjamin Maxwell 
496e6662950SBenjamin Maxwell     return std::make_pair(lhsMask, rhsMask);
497e6662950SBenjamin Maxwell   }
498e6662950SBenjamin Maxwell };
499e6662950SBenjamin Maxwell 
500c4251243SBenjamin Maxwell /// Lower `vector.extract` using `arm_sme.extract_tile_slice`.
501c4c52d41SBenjamin Maxwell ///
502c4c52d41SBenjamin Maxwell /// Example:
503c4c52d41SBenjamin Maxwell /// ```
504c4c52d41SBenjamin Maxwell /// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
505c4c52d41SBenjamin Maxwell /// ```
506c4c52d41SBenjamin Maxwell /// Becomes:
507c4c52d41SBenjamin Maxwell /// ```
508c4251243SBenjamin Maxwell /// %slice = arm_sme.extract_tile_slice %tile[%row]
509c4c52d41SBenjamin Maxwell ///            : vector<[4]xi32> from vector<[4]x[4]xi32>
510c4c52d41SBenjamin Maxwell /// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
511c4c52d41SBenjamin Maxwell /// ```
512c4c52d41SBenjamin Maxwell struct VectorExtractToArmSMELowering
513c4c52d41SBenjamin Maxwell     : public OpRewritePattern<vector::ExtractOp> {
514c4c52d41SBenjamin Maxwell   using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
515c4c52d41SBenjamin Maxwell 
516c4c52d41SBenjamin Maxwell   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
517c4c52d41SBenjamin Maxwell                                 PatternRewriter &rewriter) const override {
518c4c52d41SBenjamin Maxwell     VectorType sourceType = extractOp.getSourceVectorType();
519c4c52d41SBenjamin Maxwell     if (!arm_sme::isValidSMETileVectorType(sourceType))
520c4c52d41SBenjamin Maxwell       return failure();
521c4c52d41SBenjamin Maxwell 
522c4c52d41SBenjamin Maxwell     auto loc = extractOp.getLoc();
523c4c52d41SBenjamin Maxwell     auto position = extractOp.getMixedPosition();
524c4c52d41SBenjamin Maxwell 
525c4c52d41SBenjamin Maxwell     Value sourceVector = extractOp.getVector();
526c4c52d41SBenjamin Maxwell 
527c4c52d41SBenjamin Maxwell     // Extract entire vector. Should be handled by folder, but just to be safe.
528c4c52d41SBenjamin Maxwell     if (position.empty()) {
529c4c52d41SBenjamin Maxwell       rewriter.replaceOp(extractOp, sourceVector);
530c4c52d41SBenjamin Maxwell       return success();
531c4c52d41SBenjamin Maxwell     }
532c4c52d41SBenjamin Maxwell 
533c4c52d41SBenjamin Maxwell     Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
534c4251243SBenjamin Maxwell     auto extractTileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
535c4251243SBenjamin Maxwell         loc, sourceVector, sliceIndex);
536c4c52d41SBenjamin Maxwell 
537c4c52d41SBenjamin Maxwell     if (position.size() == 1) {
538c4c52d41SBenjamin Maxwell       // Single index case: Extracts a 1D slice.
539c4251243SBenjamin Maxwell       rewriter.replaceOp(extractOp, extractTileSlice);
540c4c52d41SBenjamin Maxwell       return success();
541c4c52d41SBenjamin Maxwell     }
542c4c52d41SBenjamin Maxwell 
543c4c52d41SBenjamin Maxwell     // Two indices case: Extracts a single element.
544c4c52d41SBenjamin Maxwell     assert(position.size() == 2);
545c4251243SBenjamin Maxwell     rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, extractTileSlice,
546c4251243SBenjamin Maxwell                                                    position[1]);
547c4c52d41SBenjamin Maxwell 
548c4c52d41SBenjamin Maxwell     return success();
549c4c52d41SBenjamin Maxwell   }
550c4c52d41SBenjamin Maxwell };
551c4c52d41SBenjamin Maxwell 
552c4251243SBenjamin Maxwell /// Lower `vector.insert` using `arm_sme.insert_tile_slice` and
553c4251243SBenjamin Maxwell /// `arm_sme.extract_tile_slice`.
554c4c52d41SBenjamin Maxwell ///
555c4c52d41SBenjamin Maxwell /// Example:
556c4c52d41SBenjamin Maxwell /// ```
557c4c52d41SBenjamin Maxwell /// %new_tile = vector.insert %el, %tile[%row, %col]
558c4c52d41SBenjamin Maxwell ///                     : i32 into vector<[4]x[4]xi32>
559c4c52d41SBenjamin Maxwell /// ```
560c4c52d41SBenjamin Maxwell /// Becomes:
561c4c52d41SBenjamin Maxwell /// ```
562c4251243SBenjamin Maxwell /// %slice = arm_sme.extract_tile_slice %tile[%row]
563c4c52d41SBenjamin Maxwell ///            : vector<[4]xi32> from vector<[4]x[4]xi32>
564c4c52d41SBenjamin Maxwell /// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
565c4251243SBenjamin Maxwell /// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row]
566c4c52d41SBenjamin Maxwell ///               : vector<[4]xi32> into vector<[4]x[4]xi32>
567c4c52d41SBenjamin Maxwell /// ```
568c4c52d41SBenjamin Maxwell struct VectorInsertToArmSMELowering
569c4c52d41SBenjamin Maxwell     : public OpRewritePattern<vector::InsertOp> {
570c4c52d41SBenjamin Maxwell   using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
571c4c52d41SBenjamin Maxwell 
572c4c52d41SBenjamin Maxwell   LogicalResult matchAndRewrite(vector::InsertOp insertOp,
573c4c52d41SBenjamin Maxwell                                 PatternRewriter &rewriter) const override {
574c4c52d41SBenjamin Maxwell     VectorType resultType = insertOp.getResult().getType();
575c4c52d41SBenjamin Maxwell 
576c4c52d41SBenjamin Maxwell     if (!arm_sme::isValidSMETileVectorType(resultType))
577c4c52d41SBenjamin Maxwell       return failure();
578c4c52d41SBenjamin Maxwell 
579c4c52d41SBenjamin Maxwell     auto loc = insertOp.getLoc();
580c4c52d41SBenjamin Maxwell     auto position = insertOp.getMixedPosition();
581c4c52d41SBenjamin Maxwell 
582c4c52d41SBenjamin Maxwell     Value source = insertOp.getSource();
583c4c52d41SBenjamin Maxwell 
584c4c52d41SBenjamin Maxwell     // Overwrite entire vector with value. Should be handled by folder, but
585c4c52d41SBenjamin Maxwell     // just to be safe.
586c4c52d41SBenjamin Maxwell     if (position.empty()) {
587c4c52d41SBenjamin Maxwell       rewriter.replaceOp(insertOp, source);
588c4c52d41SBenjamin Maxwell       return success();
589c4c52d41SBenjamin Maxwell     }
590c4c52d41SBenjamin Maxwell 
591c4c52d41SBenjamin Maxwell     Value tileSlice = source;
592c4c52d41SBenjamin Maxwell     Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
593c4c52d41SBenjamin Maxwell     if (position.size() == 2) {
594c4c52d41SBenjamin Maxwell       // Two indices case: Insert single element into tile.
595c4c52d41SBenjamin Maxwell       // We need to first extract the existing slice and update the element.
596c4251243SBenjamin Maxwell       tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
597c4c52d41SBenjamin Maxwell           loc, insertOp.getDest(), sliceIndex);
598c4c52d41SBenjamin Maxwell       tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
599c4c52d41SBenjamin Maxwell                                                     position[1]);
600c4c52d41SBenjamin Maxwell     }
601c4c52d41SBenjamin Maxwell 
602c4c52d41SBenjamin Maxwell     // Insert the slice into the destination tile.
603c4251243SBenjamin Maxwell     rewriter.replaceOpWithNewOp<arm_sme::InsertTileSliceOp>(
604c4c52d41SBenjamin Maxwell         insertOp, tileSlice, insertOp.getDest(), sliceIndex);
605c4c52d41SBenjamin Maxwell     return success();
606c4c52d41SBenjamin Maxwell   }
607c4c52d41SBenjamin Maxwell };
608c4c52d41SBenjamin Maxwell 
60910063c5aSBenjamin Maxwell /// Lowers `vector.print` of a tile into a loop over the rows of the tile,
610c4251243SBenjamin Maxwell /// extracting them via `arm_sme.extract_tile_slice`, then printing with
61110063c5aSBenjamin Maxwell /// a 1D `vector.print`.
61210063c5aSBenjamin Maxwell ///
61310063c5aSBenjamin Maxwell ///  BEFORE:
61410063c5aSBenjamin Maxwell ///  ```mlir
61510063c5aSBenjamin Maxwell ///  vector.print %tile : vector<[4]x[4]xf32>
61610063c5aSBenjamin Maxwell ///  ```
61710063c5aSBenjamin Maxwell ///  AFTER:
61810063c5aSBenjamin Maxwell ///  ```mlir
61910063c5aSBenjamin Maxwell ///  %c0 = arith.constant 0 : index
62010063c5aSBenjamin Maxwell ///  %c1 = arith.constant 1 : index
62110063c5aSBenjamin Maxwell ///  %c4 = arith.constant 4 : index
62210063c5aSBenjamin Maxwell ///  %vscale = vector.vscale
62310063c5aSBenjamin Maxwell ///  %svl_s = arith.muli %c4, %vscale : index
62410063c5aSBenjamin Maxwell ///  scf.for %i = %c0 to %svl_s step %c1 {
625c4251243SBenjamin Maxwell ///    %tile_slice = arm_sme.extract_tile_slice %tile[%i]
62610063c5aSBenjamin Maxwell ///                     : vector<[4]xf32> from vector<[4]x[4]xf32>
62710063c5aSBenjamin Maxwell ///    vector.print %tile_slice : vector<[4]xf32>
62810063c5aSBenjamin Maxwell ///  }
62910063c5aSBenjamin Maxwell ///  ```
63010063c5aSBenjamin Maxwell struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
63110063c5aSBenjamin Maxwell   using OpRewritePattern<vector::PrintOp>::OpRewritePattern;
63210063c5aSBenjamin Maxwell 
63310063c5aSBenjamin Maxwell   LogicalResult matchAndRewrite(vector::PrintOp printOp,
63410063c5aSBenjamin Maxwell                                 PatternRewriter &rewriter) const override {
63510063c5aSBenjamin Maxwell     if (!printOp.getSource())
63610063c5aSBenjamin Maxwell       return failure();
63710063c5aSBenjamin Maxwell 
63810063c5aSBenjamin Maxwell     VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
63910063c5aSBenjamin Maxwell     if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
64010063c5aSBenjamin Maxwell       return failure();
64110063c5aSBenjamin Maxwell 
64210063c5aSBenjamin Maxwell     auto loc = printOp.getLoc();
64310063c5aSBenjamin Maxwell 
64410063c5aSBenjamin Maxwell     // Create a loop over the rows of the tile.
64510063c5aSBenjamin Maxwell     auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
64610063c5aSBenjamin Maxwell     auto minTileRows =
64710063c5aSBenjamin Maxwell         rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
64810063c5aSBenjamin Maxwell     auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
64910063c5aSBenjamin Maxwell     auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
65010063c5aSBenjamin Maxwell     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
65110063c5aSBenjamin Maxwell     auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
65210063c5aSBenjamin Maxwell     {
65310063c5aSBenjamin Maxwell       // Loop body.
65410063c5aSBenjamin Maxwell       rewriter.setInsertionPointToStart(forOp.getBody());
65510063c5aSBenjamin Maxwell       // Extract the current row from the tile.
65610063c5aSBenjamin Maxwell       Value rowIndex = forOp.getInductionVar();
657c4251243SBenjamin Maxwell       auto tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
65810063c5aSBenjamin Maxwell           loc, printOp.getSource(), rowIndex);
65910063c5aSBenjamin Maxwell       // Print the row with a 1D vector.print.
66010063c5aSBenjamin Maxwell       rewriter.create<vector::PrintOp>(loc, tileSlice,
66110063c5aSBenjamin Maxwell                                        printOp.getPunctuation());
66210063c5aSBenjamin Maxwell     }
66310063c5aSBenjamin Maxwell 
66410063c5aSBenjamin Maxwell     rewriter.eraseOp(printOp);
66510063c5aSBenjamin Maxwell     return success();
66610063c5aSBenjamin Maxwell   }
66710063c5aSBenjamin Maxwell };
66810063c5aSBenjamin Maxwell 
669c4251243SBenjamin Maxwell /// Folds a ExtractTileSliceOp + TransferWriteOp to a StoreTileSliceOp.
6704d6b9921SBenjamin Maxwell ///
6714d6b9921SBenjamin Maxwell ///  BEFORE:
6724d6b9921SBenjamin Maxwell ///  ```mlir
673c4251243SBenjamin Maxwell ///  %slice = arm_sme.extract_tile_slice %tile[%index]
6744d6b9921SBenjamin Maxwell ///             : vector<[4]xf32> from vector<[4]x[4]xf32>
6754d6b9921SBenjamin Maxwell ///  vector.transfer_write %slice, %memref[%i, %j], %mask {in_bounds = [true]}
6764d6b9921SBenjamin Maxwell ///             : vector<[4]xf32>, memref<?x?xf32>
6774d6b9921SBenjamin Maxwell ///  ```
6784d6b9921SBenjamin Maxwell ///  AFTER:
6794d6b9921SBenjamin Maxwell ///  ```mlir
6804d6b9921SBenjamin Maxwell ///  arm_sme.store_tile_slice %tile, %index, %mask, %memref[%i, %j]
6814d6b9921SBenjamin Maxwell ///             : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
6824d6b9921SBenjamin Maxwell ///  ```
6834d6b9921SBenjamin Maxwell struct FoldTransferWriteOfExtractTileSlice
6844d6b9921SBenjamin Maxwell     : public OpRewritePattern<vector::TransferWriteOp> {
6854d6b9921SBenjamin Maxwell   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
6864d6b9921SBenjamin Maxwell 
6874d6b9921SBenjamin Maxwell   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
6884d6b9921SBenjamin Maxwell                                 PatternRewriter &rewriter) const final {
6894d6b9921SBenjamin Maxwell     if (!isa<MemRefType>(writeOp.getSource().getType()))
6904d6b9921SBenjamin Maxwell       return rewriter.notifyMatchFailure(writeOp, "destination not a memref");
6914d6b9921SBenjamin Maxwell 
6924d6b9921SBenjamin Maxwell     if (writeOp.hasOutOfBoundsDim())
6934d6b9921SBenjamin Maxwell       return rewriter.notifyMatchFailure(writeOp,
6944d6b9921SBenjamin Maxwell                                          "not inbounds transfer write");
6954d6b9921SBenjamin Maxwell 
696c4251243SBenjamin Maxwell     auto extractTileSlice =
697c4251243SBenjamin Maxwell         writeOp.getVector().getDefiningOp<arm_sme::ExtractTileSliceOp>();
698c4251243SBenjamin Maxwell     if (!extractTileSlice)
6994d6b9921SBenjamin Maxwell       return rewriter.notifyMatchFailure(
700c4251243SBenjamin Maxwell           writeOp, "vector to store not from ExtractTileSliceOp");
7014d6b9921SBenjamin Maxwell 
7024d6b9921SBenjamin Maxwell     AffineMap map = writeOp.getPermutationMap();
7034d6b9921SBenjamin Maxwell     if (!map.isMinorIdentity())
7044d6b9921SBenjamin Maxwell       return rewriter.notifyMatchFailure(writeOp,
7054d6b9921SBenjamin Maxwell                                          "unsupported permutation map");
7064d6b9921SBenjamin Maxwell 
7074d6b9921SBenjamin Maxwell     Value mask = writeOp.getMask();
7084d6b9921SBenjamin Maxwell     if (!mask) {
7094d6b9921SBenjamin Maxwell       auto maskType = writeOp.getVectorType().clone(rewriter.getI1Type());
7104d6b9921SBenjamin Maxwell       mask = rewriter.create<arith::ConstantOp>(
7114d6b9921SBenjamin Maxwell           writeOp.getLoc(), maskType, DenseElementsAttr::get(maskType, true));
7124d6b9921SBenjamin Maxwell     }
7134d6b9921SBenjamin Maxwell 
7144d6b9921SBenjamin Maxwell     rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
715c4251243SBenjamin Maxwell         writeOp, extractTileSlice.getTile(),
716c4251243SBenjamin Maxwell         extractTileSlice.getTileSliceIndex(), mask, writeOp.getSource(),
717c4251243SBenjamin Maxwell         writeOp.getIndices(), extractTileSlice.getLayout());
7184d6b9921SBenjamin Maxwell     return success();
7194d6b9921SBenjamin Maxwell   }
7204d6b9921SBenjamin Maxwell };
7214d6b9921SBenjamin Maxwell 
722e2296d82SBenjamin Maxwell /// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to
723e2296d82SBenjamin Maxwell /// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or
724e2296d82SBenjamin Maxwell /// SVE 2.1), so this is currently the most logical place for this lowering.
725e2296d82SBenjamin Maxwell ///
726e2296d82SBenjamin Maxwell /// Example:
727e2296d82SBenjamin Maxwell /// ```mlir
728e2296d82SBenjamin Maxwell /// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
729e2296d82SBenjamin Maxwell /// %slice = vector.extract %mask[%index]
730e2296d82SBenjamin Maxwell ///            : vector<[8]xi1> from vector<[4]x[8]xi1>
731e2296d82SBenjamin Maxwell /// ```
732e2296d82SBenjamin Maxwell /// Becomes:
733e2296d82SBenjamin Maxwell /// ```
734e2296d82SBenjamin Maxwell /// %mask_rows = vector.create_mask %a : vector<[4]xi1>
735e2296d82SBenjamin Maxwell /// %mask_cols = vector.create_mask %b : vector<[8]xi1>
736e2296d82SBenjamin Maxwell /// %slice = arm_sve.psel %mask_cols, %mask_rows[%index]
737e2296d82SBenjamin Maxwell ///            : vector<[8]xi1>, vector<[4]xi1>
738e2296d82SBenjamin Maxwell /// ```
739e2296d82SBenjamin Maxwell struct ExtractFromCreateMaskToPselLowering
740e2296d82SBenjamin Maxwell     : public OpRewritePattern<vector::ExtractOp> {
741e2296d82SBenjamin Maxwell   using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
742e2296d82SBenjamin Maxwell 
743e2296d82SBenjamin Maxwell   LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
744e2296d82SBenjamin Maxwell                                 PatternRewriter &rewriter) const override {
745e2296d82SBenjamin Maxwell     if (extractOp.getNumIndices() != 1)
746e2296d82SBenjamin Maxwell       return rewriter.notifyMatchFailure(extractOp, "not single extract index");
747e2296d82SBenjamin Maxwell 
748e2296d82SBenjamin Maxwell     auto resultType = extractOp.getResult().getType();
749e2296d82SBenjamin Maxwell     auto resultVectorType = dyn_cast<VectorType>(resultType);
750e2296d82SBenjamin Maxwell     if (!resultVectorType)
751e2296d82SBenjamin Maxwell       return rewriter.notifyMatchFailure(extractOp, "result not VectorType");
752e2296d82SBenjamin Maxwell 
753e2296d82SBenjamin Maxwell     auto createMaskOp =
754e2296d82SBenjamin Maxwell         extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
755e2296d82SBenjamin Maxwell     if (!createMaskOp)
756e2296d82SBenjamin Maxwell       return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");
757e2296d82SBenjamin Maxwell 
758e2296d82SBenjamin Maxwell     auto maskType = createMaskOp.getVectorType();
759e2296d82SBenjamin Maxwell     if (maskType.getRank() != 2 || !maskType.allDimsScalable())
760e2296d82SBenjamin Maxwell       return rewriter.notifyMatchFailure(createMaskOp, "not 2-D scalable mask");
761e2296d82SBenjamin Maxwell 
762e2296d82SBenjamin Maxwell     auto isSVEPredicateSize = [](int64_t size) {
763e2296d82SBenjamin Maxwell       return size > 0 && size <= 16 && llvm::isPowerOf2_32(uint32_t(size));
764e2296d82SBenjamin Maxwell     };
765e2296d82SBenjamin Maxwell 
766e2296d82SBenjamin Maxwell     auto rowsBaseSize = maskType.getDimSize(0);
767e2296d82SBenjamin Maxwell     auto colsBaseSize = maskType.getDimSize(1);
768e2296d82SBenjamin Maxwell     if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize))
769e2296d82SBenjamin Maxwell       return rewriter.notifyMatchFailure(
770e2296d82SBenjamin Maxwell           createMaskOp, "mask dimensions not SVE predicate-sized");
771e2296d82SBenjamin Maxwell 
772e2296d82SBenjamin Maxwell     auto loc = extractOp.getLoc();
773e2296d82SBenjamin Maxwell     VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1);
774e2296d82SBenjamin Maxwell     VectorType colMaskType = VectorType::Builder(maskType).dropDim(0);
775e2296d82SBenjamin Maxwell 
776e2296d82SBenjamin Maxwell     // Create the two 1-D masks at the location of the 2-D create_mask (which is
777e2296d82SBenjamin Maxwell     // usually outside a loop). This prevents the need for later hoisting.
778e2296d82SBenjamin Maxwell     rewriter.setInsertionPoint(createMaskOp);
779e2296d82SBenjamin Maxwell     auto rowMask = rewriter.create<vector::CreateMaskOp>(
780e2296d82SBenjamin Maxwell         loc, rowMaskType, createMaskOp.getOperand(0));
781e2296d82SBenjamin Maxwell     auto colMask = rewriter.create<vector::CreateMaskOp>(
782e2296d82SBenjamin Maxwell         loc, colMaskType, createMaskOp.getOperand(1));
783e2296d82SBenjamin Maxwell 
784e2296d82SBenjamin Maxwell     rewriter.setInsertionPoint(extractOp);
785e2296d82SBenjamin Maxwell     auto position =
786e2296d82SBenjamin Maxwell         vector::getAsValues(rewriter, loc, extractOp.getMixedPosition());
787e2296d82SBenjamin Maxwell     rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colMask, rowMask,
788e2296d82SBenjamin Maxwell                                                  position[0]);
789e2296d82SBenjamin Maxwell     return success();
790e2296d82SBenjamin Maxwell   }
791e2296d82SBenjamin Maxwell };
792e2296d82SBenjamin Maxwell 
793447bb5beSAndrzej Warzynski } // namespace
794447bb5beSAndrzej Warzynski 
795447bb5beSAndrzej Warzynski void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
796447bb5beSAndrzej Warzynski                                           MLIRContext &ctx) {
797e2296d82SBenjamin Maxwell   patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
7989f7fff7fSCullen Rhodes                TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
7999f7fff7fSCullen Rhodes                TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
8009f7fff7fSCullen Rhodes                VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
8019f7fff7fSCullen Rhodes                VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
802e2296d82SBenjamin Maxwell                VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
803e2296d82SBenjamin Maxwell                ExtractFromCreateMaskToPselLowering>(&ctx);
804447bb5beSAndrzej Warzynski }
805