1447bb5beSAndrzej Warzynski //===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===// 2447bb5beSAndrzej Warzynski // 3447bb5beSAndrzej Warzynski // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4447bb5beSAndrzej Warzynski // See https://llvm.org/LICENSE.txt for license information. 5447bb5beSAndrzej Warzynski // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6447bb5beSAndrzej Warzynski // 7447bb5beSAndrzej Warzynski //===----------------------------------------------------------------------===// 8447bb5beSAndrzej Warzynski 9447bb5beSAndrzej Warzynski #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h" 10447bb5beSAndrzej Warzynski 11447bb5beSAndrzej Warzynski #include "mlir/Dialect/ArmSME/IR/ArmSME.h" 12ca9a3354SCullen Rhodes #include "mlir/Dialect/ArmSME/Utils/Utils.h" 13e2296d82SBenjamin Maxwell #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" 14eaf15900SCullen Rhodes #include "mlir/Dialect/MemRef/IR/MemRef.h" 15447bb5beSAndrzej Warzynski #include "mlir/IR/BuiltinTypes.h" 16447bb5beSAndrzej Warzynski #include "llvm/Support/Casting.h" 17447bb5beSAndrzej Warzynski 18447bb5beSAndrzej Warzynski using namespace mlir; 19447bb5beSAndrzej Warzynski 20447bb5beSAndrzej Warzynski namespace { 21447bb5beSAndrzej Warzynski 2222f11592SCullen Rhodes /// Conversion pattern for vector.transfer_read. 2322f11592SCullen Rhodes /// 2422f11592SCullen Rhodes /// --- 2522f11592SCullen Rhodes /// 2622f11592SCullen Rhodes /// Example 1: op with identity permutation map to horizontal 2722f11592SCullen Rhodes /// arm_sme.tile_load: 2822f11592SCullen Rhodes /// 2922f11592SCullen Rhodes /// vector.transfer_read ... permutation_map: (d0, d1) -> (d0, d1) 3022f11592SCullen Rhodes /// 3122f11592SCullen Rhodes /// is converted to: 3222f11592SCullen Rhodes /// 3322f11592SCullen Rhodes /// arm_sme.tile_load ... 3422f11592SCullen Rhodes /// 3522f11592SCullen Rhodes /// --- 3622f11592SCullen Rhodes /// 3722f11592SCullen Rhodes /// Example 2: op with transpose permutation map to vertical arm_sme.tile_load 3822f11592SCullen Rhodes /// (in-flight transpose): 398e64e9c3SCullen Rhodes /// 408e64e9c3SCullen Rhodes /// vector.transfer_read ... permutation_map: (d0, d1) -> (d1, d0) 418e64e9c3SCullen Rhodes /// 428e64e9c3SCullen Rhodes /// is converted to: 438e64e9c3SCullen Rhodes /// 44d86047cbSCullen Rhodes /// arm_sme.tile_load ... layout<vertical> 4522f11592SCullen Rhodes struct TransferReadToArmSMELowering 468e64e9c3SCullen Rhodes : public OpRewritePattern<vector::TransferReadOp> { 478e64e9c3SCullen Rhodes using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; 488e64e9c3SCullen Rhodes 498e64e9c3SCullen Rhodes LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp, 508e64e9c3SCullen Rhodes PatternRewriter &rewriter) const final { 518e64e9c3SCullen Rhodes // The permutation map must have two results. 528e64e9c3SCullen Rhodes if (transferReadOp.getTransferRank() != 2) 538e64e9c3SCullen Rhodes return rewriter.notifyMatchFailure(transferReadOp, 548e64e9c3SCullen Rhodes "not a 2 result permutation map"); 558e64e9c3SCullen Rhodes 568e64e9c3SCullen Rhodes auto vectorType = transferReadOp.getVectorType(); 578e64e9c3SCullen Rhodes if (!arm_sme::isValidSMETileVectorType(vectorType)) 588e64e9c3SCullen Rhodes return rewriter.notifyMatchFailure(transferReadOp, 598e64e9c3SCullen Rhodes "not a valid vector type for SME"); 608e64e9c3SCullen Rhodes 618e64e9c3SCullen Rhodes if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType())) 628e64e9c3SCullen Rhodes return rewriter.notifyMatchFailure(transferReadOp, "not a memref source"); 638e64e9c3SCullen Rhodes 648e64e9c3SCullen Rhodes // Out-of-bounds dims are not supported. 658e64e9c3SCullen Rhodes if (transferReadOp.hasOutOfBoundsDim()) 668e64e9c3SCullen Rhodes return rewriter.notifyMatchFailure(transferReadOp, 678e64e9c3SCullen Rhodes "not inbounds transfer read"); 688e64e9c3SCullen Rhodes 6922f11592SCullen Rhodes AffineMap map = transferReadOp.getPermutationMap(); 70b49c0b8aSCullen Rhodes if (!map.isPermutation()) 718e64e9c3SCullen Rhodes return rewriter.notifyMatchFailure(transferReadOp, 7222f11592SCullen Rhodes "unsupported permutation map"); 738e64e9c3SCullen Rhodes 74b49c0b8aSCullen Rhodes // Note: For 2D vector types the only non-identity permutation is a simple 75b49c0b8aSCullen Rhodes // transpose [1, 0]. 76b49c0b8aSCullen Rhodes bool transposed = !map.isIdentity(); 77b49c0b8aSCullen Rhodes arm_sme::TileSliceLayout layout = 78b49c0b8aSCullen Rhodes transposed ? arm_sme::TileSliceLayout::Vertical 79b49c0b8aSCullen Rhodes : arm_sme::TileSliceLayout::Horizontal; 80b49c0b8aSCullen Rhodes 8122f11592SCullen Rhodes // Padding isn't optional for transfer_read, but is only used in the case 8222f11592SCullen Rhodes // of out-of-bounds accesses (not supported here) and/or masking. Mask is 8322f11592SCullen Rhodes // optional, if it's not present don't pass padding. 8422f11592SCullen Rhodes auto mask = transferReadOp.getMask(); 8522f11592SCullen Rhodes auto padding = mask ? transferReadOp.getPadding() : nullptr; 868e64e9c3SCullen Rhodes rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>( 878e64e9c3SCullen Rhodes transferReadOp, vectorType, transferReadOp.getSource(), 8822f11592SCullen Rhodes transferReadOp.getIndices(), padding, mask, layout); 898e64e9c3SCullen Rhodes 908e64e9c3SCullen Rhodes return success(); 918e64e9c3SCullen Rhodes } 928e64e9c3SCullen Rhodes }; 938e64e9c3SCullen Rhodes 9412e1a9b8SCullen Rhodes /// Conversion pattern for vector.transfer_write. 95447bb5beSAndrzej Warzynski /// 964240b179SCullen Rhodes /// --- 974240b179SCullen Rhodes /// 984240b179SCullen Rhodes /// Example 1: op with identity permutation map to horizontal 994240b179SCullen Rhodes /// arm_sme.tile_store: 1004240b179SCullen Rhodes /// 1014240b179SCullen Rhodes /// vector.transfer_write %vector, %source[%c0, %c0] 1024240b179SCullen Rhodes /// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8> 103447bb5beSAndrzej Warzynski /// 104447bb5beSAndrzej Warzynski /// is converted to: 105447bb5beSAndrzej Warzynski /// 10612e1a9b8SCullen Rhodes /// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>, 107447bb5beSAndrzej Warzynski /// vector<[16]x[16]xi8> 1084240b179SCullen Rhodes /// --- 1094240b179SCullen Rhodes /// 1104240b179SCullen Rhodes /// Example 2: op with transpose permutation map to vertical arm_sme.tile_store 1114240b179SCullen Rhodes /// (in-flight transpose): 1124240b179SCullen Rhodes /// 1134240b179SCullen Rhodes /// vector.transfer_write %vector, %source[%c0, %c0] 1144240b179SCullen Rhodes /// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, 1154240b179SCullen Rhodes /// in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8> 1164240b179SCullen Rhodes /// 1174240b179SCullen Rhodes /// is converted to: 1184240b179SCullen Rhodes /// 1194240b179SCullen Rhodes /// arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical> 1204240b179SCullen Rhodes /// : memref<?x?xi8>, vector<[16]x[16]xi8> 121447bb5beSAndrzej Warzynski struct TransferWriteToArmSMELowering 122447bb5beSAndrzej Warzynski : public OpRewritePattern<vector::TransferWriteOp> { 123447bb5beSAndrzej Warzynski using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern; 124447bb5beSAndrzej Warzynski 125447bb5beSAndrzej Warzynski LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, 126447bb5beSAndrzej Warzynski PatternRewriter &rewriter) const final { 127447bb5beSAndrzej Warzynski auto vType = writeOp.getVectorType(); 12812e1a9b8SCullen Rhodes if (!arm_sme::isValidSMETileVectorType(vType)) 129447bb5beSAndrzej Warzynski return failure(); 130447bb5beSAndrzej Warzynski 131447bb5beSAndrzej Warzynski if (!llvm::isa<MemRefType>(writeOp.getSource().getType())) 132447bb5beSAndrzej Warzynski return failure(); 133447bb5beSAndrzej Warzynski 1344240b179SCullen Rhodes // Out-of-bounds dims are not supported. 1354240b179SCullen Rhodes if (writeOp.hasOutOfBoundsDim()) 1364240b179SCullen Rhodes return rewriter.notifyMatchFailure(writeOp, 1374240b179SCullen Rhodes "not inbounds transfer write"); 1384240b179SCullen Rhodes 1394240b179SCullen Rhodes AffineMap map = writeOp.getPermutationMap(); 140b49c0b8aSCullen Rhodes if (!map.isPermutation()) 1414240b179SCullen Rhodes return rewriter.notifyMatchFailure(writeOp, 1424240b179SCullen Rhodes "unsupported permutation map"); 1434240b179SCullen Rhodes 144b49c0b8aSCullen Rhodes // Note: For 2D vector types the only non-identity permutation is a simple 145b49c0b8aSCullen Rhodes // transpose [1, 0]. 146b49c0b8aSCullen Rhodes bool transposed = !map.isIdentity(); 1474240b179SCullen Rhodes arm_sme::TileSliceLayout layout = 148b49c0b8aSCullen Rhodes transposed ? arm_sme::TileSliceLayout::Vertical 1494240b179SCullen Rhodes : arm_sme::TileSliceLayout::Horizontal; 1504240b179SCullen Rhodes 151447bb5beSAndrzej Warzynski rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>( 1521908f47aSCullen Rhodes writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(), 1534240b179SCullen Rhodes writeOp.getMask(), layout); 154447bb5beSAndrzej Warzynski return success(); 155447bb5beSAndrzej Warzynski } 156447bb5beSAndrzej Warzynski }; 157447bb5beSAndrzej Warzynski 158ca9a3354SCullen Rhodes /// Conversion pattern for vector.load. 159ca9a3354SCullen Rhodes struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> { 160ca9a3354SCullen Rhodes using OpRewritePattern<vector::LoadOp>::OpRewritePattern; 161ca9a3354SCullen Rhodes 162ca9a3354SCullen Rhodes LogicalResult matchAndRewrite(vector::LoadOp load, 163ca9a3354SCullen Rhodes PatternRewriter &rewriter) const override { 164ca9a3354SCullen Rhodes if (!arm_sme::isValidSMETileVectorType(load.getVectorType())) 165ca9a3354SCullen Rhodes return failure(); 166ca9a3354SCullen Rhodes 167ca9a3354SCullen Rhodes rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>( 168ca9a3354SCullen Rhodes load, load.getVectorType(), load.getBase(), load.getIndices()); 169ca9a3354SCullen Rhodes 170ca9a3354SCullen Rhodes return success(); 171ca9a3354SCullen Rhodes } 172ca9a3354SCullen Rhodes }; 173ca9a3354SCullen Rhodes 174ca9a3354SCullen Rhodes /// Conversion pattern for vector.store. 175ca9a3354SCullen Rhodes struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> { 176ca9a3354SCullen Rhodes using OpRewritePattern<vector::StoreOp>::OpRewritePattern; 177ca9a3354SCullen Rhodes 178ca9a3354SCullen Rhodes LogicalResult matchAndRewrite(vector::StoreOp store, 179ca9a3354SCullen Rhodes PatternRewriter &rewriter) const override { 180ca9a3354SCullen Rhodes if (!arm_sme::isValidSMETileVectorType(store.getVectorType())) 181ca9a3354SCullen Rhodes return failure(); 182ca9a3354SCullen Rhodes 183ca9a3354SCullen Rhodes rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>( 184ca9a3354SCullen Rhodes store, store.getValueToStore(), store.getBase(), store.getIndices()); 185ca9a3354SCullen Rhodes 186ca9a3354SCullen Rhodes return success(); 187ca9a3354SCullen Rhodes } 188ca9a3354SCullen Rhodes }; 189ca9a3354SCullen Rhodes 1902dd3f420SCullen Rhodes /// Conversion pattern for vector.broadcast. 1912dd3f420SCullen Rhodes /// 1922dd3f420SCullen Rhodes /// Example: 1932dd3f420SCullen Rhodes /// 1942dd3f420SCullen Rhodes /// %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32> 1952dd3f420SCullen Rhodes /// 1962dd3f420SCullen Rhodes /// is converted to: 1972dd3f420SCullen Rhodes /// 1982dd3f420SCullen Rhodes /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32> 199b0b69fd8SBenjamin Maxwell /// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices 200b0b69fd8SBenjamin Maxwell /// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) 201b0b69fd8SBenjamin Maxwell /// { 202c4251243SBenjamin Maxwell /// %tile_update = arm_sme.insert_tile_slice 203c4251243SBenjamin Maxwell /// %broadcast_to_1d, %iter_tile[%tile_slice_index] : 204b0b69fd8SBenjamin Maxwell /// vector<[4]xi32> into vector<[4]x[4]xi32> 205b0b69fd8SBenjamin Maxwell /// scf.yield %tile_update : vector<[4]x[4]xi32> 2062dd3f420SCullen Rhodes /// } 2072dd3f420SCullen Rhodes /// 2082dd3f420SCullen Rhodes /// Supports scalar, 0-d vector, and 1-d vector broadcasts. 2092dd3f420SCullen Rhodes struct BroadcastOpToArmSMELowering 2102dd3f420SCullen Rhodes : public OpRewritePattern<vector::BroadcastOp> { 2112dd3f420SCullen Rhodes using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern; 2122dd3f420SCullen Rhodes 2132dd3f420SCullen Rhodes LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp, 2142dd3f420SCullen Rhodes PatternRewriter &rewriter) const final { 2152dd3f420SCullen Rhodes auto tileType = broadcastOp.getResultVectorType(); 2162dd3f420SCullen Rhodes if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) 2172dd3f420SCullen Rhodes return failure(); 2182dd3f420SCullen Rhodes 2192dd3f420SCullen Rhodes auto loc = broadcastOp.getLoc(); 2202dd3f420SCullen Rhodes 2212dd3f420SCullen Rhodes auto srcType = broadcastOp.getSourceType(); 2222dd3f420SCullen Rhodes auto srcVectorType = dyn_cast<VectorType>(srcType); 2232dd3f420SCullen Rhodes 2242dd3f420SCullen Rhodes Value broadcastOp1D; 2252dd3f420SCullen Rhodes if (srcType.isIntOrFloat() || 2262dd3f420SCullen Rhodes (srcVectorType && (srcVectorType.getRank() == 0))) { 2272dd3f420SCullen Rhodes // Broadcast scalar or 0-d vector to 1-d vector. 228b0b69fd8SBenjamin Maxwell VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); 2292dd3f420SCullen Rhodes broadcastOp1D = rewriter.create<vector::BroadcastOp>( 2302dd3f420SCullen Rhodes loc, tileSliceType, broadcastOp.getSource()); 2312dd3f420SCullen Rhodes } else if (srcVectorType && (srcVectorType.getRank() == 1)) 2322dd3f420SCullen Rhodes // Value to broadcast is already a 1-d vector, nothing to do. 2332dd3f420SCullen Rhodes broadcastOp1D = broadcastOp.getSource(); 2342dd3f420SCullen Rhodes else 2352dd3f420SCullen Rhodes return failure(); 2362dd3f420SCullen Rhodes 237b0b69fd8SBenjamin Maxwell auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); 2382dd3f420SCullen Rhodes 2399f7fff7fSCullen Rhodes auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, 2409f7fff7fSCullen Rhodes Value currentTile) { 241c4251243SBenjamin Maxwell // Create 'arm_sme.insert_tile_slice' to broadcast the value 242b0b69fd8SBenjamin Maxwell // to each tile slice. 243c4251243SBenjamin Maxwell auto nextTile = b.create<arm_sme::InsertTileSliceOp>( 244b0b69fd8SBenjamin Maxwell loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); 2459f7fff7fSCullen Rhodes return nextTile.getResult(); 2469f7fff7fSCullen Rhodes }; 2479f7fff7fSCullen Rhodes 2489f7fff7fSCullen Rhodes // Create a loop over ZA tile slices. 2499f7fff7fSCullen Rhodes auto forOp = 2509f7fff7fSCullen Rhodes createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody); 2512dd3f420SCullen Rhodes 252b0b69fd8SBenjamin Maxwell rewriter.replaceOp(broadcastOp, forOp.getResult(0)); 2532dd3f420SCullen Rhodes 2542dd3f420SCullen Rhodes return success(); 2552dd3f420SCullen Rhodes } 2562dd3f420SCullen Rhodes }; 2572dd3f420SCullen Rhodes 2580cb0df41SAndrzej Warzyński /// Conversion pattern for vector.splat. 2590cb0df41SAndrzej Warzyński /// 2600cb0df41SAndrzej Warzyński /// Example: 2610cb0df41SAndrzej Warzyński /// 2620cb0df41SAndrzej Warzyński /// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32> 2630cb0df41SAndrzej Warzyński /// 2640cb0df41SAndrzej Warzyński /// is converted to: 2650cb0df41SAndrzej Warzyński /// 2660cb0df41SAndrzej Warzyński /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32> 267b0b69fd8SBenjamin Maxwell /// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices 268b0b69fd8SBenjamin Maxwell /// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) 269b0b69fd8SBenjamin Maxwell /// { 270c4251243SBenjamin Maxwell /// %tile_update = arm_sme.insert_tile_slice 271c4251243SBenjamin Maxwell /// %broadcast_to_1d, %iter_tile[%tile_slice_index] : 272b0b69fd8SBenjamin Maxwell /// vector<[4]xi32> into vector<[4]x[4]xi32> 273b0b69fd8SBenjamin Maxwell /// scf.yield %tile_update : vector<[4]x[4]xi32> 2740cb0df41SAndrzej Warzyński /// } 2750cb0df41SAndrzej Warzyński /// 2760cb0df41SAndrzej Warzyński /// This is identical to vector.broadcast of a scalar. 2770cb0df41SAndrzej Warzyński struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> { 2780cb0df41SAndrzej Warzyński using OpRewritePattern<vector::SplatOp>::OpRewritePattern; 2790cb0df41SAndrzej Warzyński 2800cb0df41SAndrzej Warzyński LogicalResult matchAndRewrite(vector::SplatOp splatOp, 2810cb0df41SAndrzej Warzyński PatternRewriter &rewriter) const final { 2820cb0df41SAndrzej Warzyński auto tileType = splatOp.getResult().getType(); 2830cb0df41SAndrzej Warzyński if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) 2840cb0df41SAndrzej Warzyński return failure(); 2850cb0df41SAndrzej Warzyński 2860cb0df41SAndrzej Warzyński auto loc = splatOp.getLoc(); 2870cb0df41SAndrzej Warzyński auto srcType = splatOp.getOperand().getType(); 2880cb0df41SAndrzej Warzyński 2890cb0df41SAndrzej Warzyński assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat"); 290042468bfSGoran Flegar // Avoid unused-variable warning when building without assertions. 291042468bfSGoran Flegar (void)srcType; 2920cb0df41SAndrzej Warzyński 2930cb0df41SAndrzej Warzyński // First, broadcast the scalar to a 1-d vector. 2940cb0df41SAndrzej Warzyński VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); 2950cb0df41SAndrzej Warzyński Value broadcastOp1D = rewriter.create<vector::BroadcastOp>( 2960cb0df41SAndrzej Warzyński loc, tileSliceType, splatOp.getInput()); 2970cb0df41SAndrzej Warzyński 298b0b69fd8SBenjamin Maxwell auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); 2990cb0df41SAndrzej Warzyński 3009f7fff7fSCullen Rhodes auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, 3019f7fff7fSCullen Rhodes Value currentTile) { 302c4251243SBenjamin Maxwell auto nextTile = b.create<arm_sme::InsertTileSliceOp>( 3039f7fff7fSCullen Rhodes loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); 3049f7fff7fSCullen Rhodes return nextTile.getResult(); 3059f7fff7fSCullen Rhodes }; 3069f7fff7fSCullen Rhodes 3070cb0df41SAndrzej Warzyński // Next, create a loop over ZA tile slices and "move" the generated 1-d 3080cb0df41SAndrzej Warzyński // vector to each slice. 309b0b69fd8SBenjamin Maxwell auto forOp = 3109f7fff7fSCullen Rhodes createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody); 3110cb0df41SAndrzej Warzyński 312b0b69fd8SBenjamin Maxwell rewriter.replaceOp(splatOp, forOp.getResult(0)); 3130cb0df41SAndrzej Warzyński 3140cb0df41SAndrzej Warzyński return success(); 3150cb0df41SAndrzej Warzyński } 3160cb0df41SAndrzej Warzyński }; 3170cb0df41SAndrzej Warzyński 318eaf15900SCullen Rhodes /// Conversion pattern for vector.transpose. 319eaf15900SCullen Rhodes /// 320eaf15900SCullen Rhodes /// Stores the input tile to memory and reloads vertically. 321eaf15900SCullen Rhodes /// 322eaf15900SCullen Rhodes /// Example: 323eaf15900SCullen Rhodes /// 324eaf15900SCullen Rhodes /// %transposed_src = vector.transpose %src, [1, 0] 325eaf15900SCullen Rhodes /// : vector<[4]x[4]xi32> to vector<[4]x[4]xi32> 326eaf15900SCullen Rhodes /// 327eaf15900SCullen Rhodes /// is converted to: 328eaf15900SCullen Rhodes /// 329eaf15900SCullen Rhodes /// %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32> 330eaf15900SCullen Rhodes /// %arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0] 331eaf15900SCullen Rhodes /// : memref<?x?xi32>, vector<[4]x[4]xi32> 332d86047cbSCullen Rhodes /// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0] 333d86047cbSCullen Rhodes /// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32> 334eaf15900SCullen Rhodes /// 335*aa295216SJay Foad /// NOTE: Transposing via memory is obviously expensive, the current intention 336eaf15900SCullen Rhodes /// is to avoid the transpose if possible, this is therefore intended as a 337eaf15900SCullen Rhodes /// fallback and to provide base support for Vector ops. If it turns out 338eaf15900SCullen Rhodes /// transposes can't be avoided then this should be replaced with a more optimal 339eaf15900SCullen Rhodes /// implementation, perhaps with tile <-> vector (MOVA) ops. 340eaf15900SCullen Rhodes struct TransposeOpToArmSMELowering 341eaf15900SCullen Rhodes : public OpRewritePattern<vector::TransposeOp> { 342eaf15900SCullen Rhodes using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; 343eaf15900SCullen Rhodes 344eaf15900SCullen Rhodes LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, 345eaf15900SCullen Rhodes PatternRewriter &rewriter) const final { 346eaf15900SCullen Rhodes auto tileType = transposeOp.getResultVectorType(); 347eaf15900SCullen Rhodes if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) 348eaf15900SCullen Rhodes return failure(); 349eaf15900SCullen Rhodes 350eaf15900SCullen Rhodes // Bail unless this is a true 2-D matrix transpose. 35132c3decbSMatthias Springer ArrayRef<int64_t> permutation = transposeOp.getPermutation(); 35232c3decbSMatthias Springer if (permutation[0] != 1 || permutation[1] != 0) 353eaf15900SCullen Rhodes return failure(); 354eaf15900SCullen Rhodes 355eaf15900SCullen Rhodes auto loc = transposeOp.getLoc(); 356bfb5fe21SCullen Rhodes Value input = transposeOp.getVector(); 357bfb5fe21SCullen Rhodes 358bfb5fe21SCullen Rhodes if (auto xferOp = input.getDefiningOp<vector::TransferReadOp>(); 359bfb5fe21SCullen Rhodes xferOp && xferOp->hasOneUse()) { 360bfb5fe21SCullen Rhodes // Fold transpose into transfer_read to enable in-flight transpose when 361bfb5fe21SCullen Rhodes // converting to arm_sme.tile_load. 362bfb5fe21SCullen Rhodes rewriter.modifyOpInPlace(xferOp, [&]() { 363bfb5fe21SCullen Rhodes xferOp->setAttr(xferOp.getPermutationMapAttrName(), 364bfb5fe21SCullen Rhodes AffineMapAttr::get(AffineMap::getPermutationMap( 365bfb5fe21SCullen Rhodes permutation, transposeOp.getContext()))); 366bfb5fe21SCullen Rhodes }); 367bfb5fe21SCullen Rhodes rewriter.replaceOp(transposeOp, xferOp); 368bfb5fe21SCullen Rhodes return success(); 369bfb5fe21SCullen Rhodes } 370eaf15900SCullen Rhodes 371eaf15900SCullen Rhodes // Allocate buffer to store input tile to. 372eaf15900SCullen Rhodes Value vscale = 373eaf15900SCullen Rhodes rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType()); 374eaf15900SCullen Rhodes Value minTileSlices = rewriter.create<arith::ConstantOp>( 375eaf15900SCullen Rhodes loc, rewriter.getIndexAttr(tileType.getDimSize(0))); 376eaf15900SCullen Rhodes Value c0 = 377eaf15900SCullen Rhodes rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0)); 378eaf15900SCullen Rhodes Value numTileSlices = 379eaf15900SCullen Rhodes rewriter.create<arith::MulIOp>(loc, vscale, minTileSlices); 380eaf15900SCullen Rhodes auto bufferType = 381eaf15900SCullen Rhodes MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic}, 382eaf15900SCullen Rhodes tileType.getElementType()); 383eaf15900SCullen Rhodes auto buffer = rewriter.create<memref::AllocaOp>( 384eaf15900SCullen Rhodes loc, bufferType, ValueRange{numTileSlices, numTileSlices}); 385eaf15900SCullen Rhodes 386eaf15900SCullen Rhodes // Store input tile. 387eaf15900SCullen Rhodes auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>( 388eaf15900SCullen Rhodes loc, input, buffer, ValueRange{c0, c0}); 389eaf15900SCullen Rhodes 390eaf15900SCullen Rhodes // Reload input tile vertically. 391eaf15900SCullen Rhodes rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>( 392eaf15900SCullen Rhodes transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(), 393eaf15900SCullen Rhodes arm_sme::TileSliceLayout::Vertical); 394eaf15900SCullen Rhodes 395eaf15900SCullen Rhodes return success(); 396eaf15900SCullen Rhodes } 397eaf15900SCullen Rhodes }; 398eaf15900SCullen Rhodes 399e6662950SBenjamin Maxwell /// Conversion pattern for vector.outerproduct. 400e6662950SBenjamin Maxwell /// 401e6662950SBenjamin Maxwell /// If the vector.outerproduct is masked (and the mask is from a 402e6662950SBenjamin Maxwell /// vector.create_mask), then the mask is decomposed into two 1-D masks for the 403e6662950SBenjamin Maxwell /// operands. 404e6662950SBenjamin Maxwell /// 405e6662950SBenjamin Maxwell /// Example: 406e6662950SBenjamin Maxwell /// 407e6662950SBenjamin Maxwell /// %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1> 408e6662950SBenjamin Maxwell /// %result = vector.mask %mask { 409e6662950SBenjamin Maxwell /// vector.outerproduct %vecA, %vecB 410e6662950SBenjamin Maxwell /// : vector<[4]xf32>, vector<[4]xf32> 411e6662950SBenjamin Maxwell /// } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32> 412e6662950SBenjamin Maxwell /// 413e6662950SBenjamin Maxwell /// is converted to: 414e6662950SBenjamin Maxwell /// 415e6662950SBenjamin Maxwell /// %maskA = vector.create_mask %dimA : vector<[4]xi1> 416e6662950SBenjamin Maxwell /// %maskB = vector.create_mask %dimB : vector<[4]xi1> 417e6662950SBenjamin Maxwell /// %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB) 418e6662950SBenjamin Maxwell /// : vector<[4]xf32>, vector<[4]xf32> 419e6662950SBenjamin Maxwell /// 420e6662950SBenjamin Maxwell /// Unmasked outerproducts can be directly replaced with the arm_sme op. 421e6662950SBenjamin Maxwell /// 422e6662950SBenjamin Maxwell /// Example: 423e6662950SBenjamin Maxwell /// 424e6662950SBenjamin Maxwell /// %result = vector.outerproduct %vecA, %vecB 425e6662950SBenjamin Maxwell /// : vector<[4]xf32>, vector<[4]xf32> 426e6662950SBenjamin Maxwell /// 427e6662950SBenjamin Maxwell /// is converted to: 428e6662950SBenjamin Maxwell /// 429e6662950SBenjamin Maxwell /// %result = arm_sme.outerproduct %vecA, %vecB 430e6662950SBenjamin Maxwell /// : vector<[4]xf32>, vector<[4]xf32> 431e6662950SBenjamin Maxwell /// 432e6662950SBenjamin Maxwell struct VectorOuterProductToArmSMELowering 433e6662950SBenjamin Maxwell : public OpRewritePattern<vector::OuterProductOp> { 434e6662950SBenjamin Maxwell 435e6662950SBenjamin Maxwell using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern; 436e6662950SBenjamin Maxwell 437e6662950SBenjamin Maxwell LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp, 438e6662950SBenjamin Maxwell PatternRewriter &rewriter) const override { 439e6662950SBenjamin Maxwell 440e6662950SBenjamin Maxwell // We don't yet support lowering AXPY operations to SME. These could be 441e6662950SBenjamin Maxwell // lowered by masking out all but the first element of the LHS. 442e6662950SBenjamin Maxwell if (!isa<VectorType>(outerProductOp.getOperandTypeRHS())) 443e7432babSCullen Rhodes return rewriter.notifyMatchFailure(outerProductOp, 444e7432babSCullen Rhodes "AXPY operations not supported"); 445e6662950SBenjamin Maxwell 446e6662950SBenjamin Maxwell if (!arm_sme::isValidSMETileVectorType( 447e6662950SBenjamin Maxwell outerProductOp.getResultVectorType())) 448e7432babSCullen Rhodes return rewriter.notifyMatchFailure( 449e7432babSCullen Rhodes outerProductOp, "outer product does not fit into SME tile"); 450e6662950SBenjamin Maxwell 451e6662950SBenjamin Maxwell auto kind = outerProductOp.getKind(); 452e6662950SBenjamin Maxwell if (kind != vector::CombiningKind::ADD) 453e7432babSCullen Rhodes return rewriter.notifyMatchFailure( 454e7432babSCullen Rhodes outerProductOp, 455e6662950SBenjamin Maxwell "unsupported kind (lowering to SME only supports ADD at the moment)"); 456e6662950SBenjamin Maxwell 457e6662950SBenjamin Maxwell Value lhsMask = {}; 458e6662950SBenjamin Maxwell Value rhsMask = {}; 459e6662950SBenjamin Maxwell Operation *rootOp = outerProductOp; 460e6662950SBenjamin Maxwell auto loc = outerProductOp.getLoc(); 461e6662950SBenjamin Maxwell if (outerProductOp.isMasked()) { 462e6662950SBenjamin Maxwell auto maskOp = outerProductOp.getMaskingOp(); 463e6662950SBenjamin Maxwell rewriter.setInsertionPoint(maskOp); 464e6662950SBenjamin Maxwell rootOp = maskOp; 465e6662950SBenjamin Maxwell auto operandMasks = decomposeResultMask(loc, maskOp.getMask(), rewriter); 466e6662950SBenjamin Maxwell if (failed(operandMasks)) 467e6662950SBenjamin Maxwell return failure(); 468e6662950SBenjamin Maxwell std::tie(lhsMask, rhsMask) = *operandMasks; 469e6662950SBenjamin Maxwell } 470e6662950SBenjamin Maxwell 471e6662950SBenjamin Maxwell rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>( 472e6662950SBenjamin Maxwell rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(), 473e6662950SBenjamin Maxwell outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc()); 474e6662950SBenjamin Maxwell 475e6662950SBenjamin Maxwell return success(); 476e6662950SBenjamin Maxwell } 477e6662950SBenjamin Maxwell 478e6662950SBenjamin Maxwell static FailureOr<std::pair<Value, Value>> 479e6662950SBenjamin Maxwell decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) { 480e6662950SBenjamin Maxwell // Attempt to extract masks from vector.create_mask. 481e6662950SBenjamin Maxwell // TODO: Add support for other mask sources. 482e6662950SBenjamin Maxwell auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>(); 483e6662950SBenjamin Maxwell if (!createMaskOp) 484e6662950SBenjamin Maxwell return failure(); 485e6662950SBenjamin Maxwell 486e6662950SBenjamin Maxwell auto maskType = createMaskOp.getVectorType(); 487e6662950SBenjamin Maxwell Value lhsMaskDim = createMaskOp.getOperand(0); 488e6662950SBenjamin Maxwell Value rhsMaskDim = createMaskOp.getOperand(1); 489e6662950SBenjamin Maxwell 490e6662950SBenjamin Maxwell VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0); 491e6662950SBenjamin Maxwell Value lhsMask = 492e6662950SBenjamin Maxwell rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim); 493e6662950SBenjamin Maxwell Value rhsMask = 494e6662950SBenjamin Maxwell rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim); 495e6662950SBenjamin Maxwell 496e6662950SBenjamin Maxwell return std::make_pair(lhsMask, rhsMask); 497e6662950SBenjamin Maxwell } 498e6662950SBenjamin Maxwell }; 499e6662950SBenjamin Maxwell 500c4251243SBenjamin Maxwell /// Lower `vector.extract` using `arm_sme.extract_tile_slice`. 501c4c52d41SBenjamin Maxwell /// 502c4c52d41SBenjamin Maxwell /// Example: 503c4c52d41SBenjamin Maxwell /// ``` 504c4c52d41SBenjamin Maxwell /// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32> 505c4c52d41SBenjamin Maxwell /// ``` 506c4c52d41SBenjamin Maxwell /// Becomes: 507c4c52d41SBenjamin Maxwell /// ``` 508c4251243SBenjamin Maxwell /// %slice = arm_sme.extract_tile_slice %tile[%row] 509c4c52d41SBenjamin Maxwell /// : vector<[4]xi32> from vector<[4]x[4]xi32> 510c4c52d41SBenjamin Maxwell /// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32> 511c4c52d41SBenjamin Maxwell /// ``` 512c4c52d41SBenjamin Maxwell struct VectorExtractToArmSMELowering 513c4c52d41SBenjamin Maxwell : public OpRewritePattern<vector::ExtractOp> { 514c4c52d41SBenjamin Maxwell using OpRewritePattern<vector::ExtractOp>::OpRewritePattern; 515c4c52d41SBenjamin Maxwell 516c4c52d41SBenjamin Maxwell LogicalResult matchAndRewrite(vector::ExtractOp extractOp, 517c4c52d41SBenjamin Maxwell PatternRewriter &rewriter) const override { 518c4c52d41SBenjamin Maxwell VectorType sourceType = extractOp.getSourceVectorType(); 519c4c52d41SBenjamin Maxwell if (!arm_sme::isValidSMETileVectorType(sourceType)) 520c4c52d41SBenjamin Maxwell return failure(); 521c4c52d41SBenjamin Maxwell 522c4c52d41SBenjamin Maxwell auto loc = extractOp.getLoc(); 523c4c52d41SBenjamin Maxwell auto position = extractOp.getMixedPosition(); 524c4c52d41SBenjamin Maxwell 525c4c52d41SBenjamin Maxwell Value sourceVector = extractOp.getVector(); 526c4c52d41SBenjamin Maxwell 527c4c52d41SBenjamin Maxwell // Extract entire vector. Should be handled by folder, but just to be safe. 528c4c52d41SBenjamin Maxwell if (position.empty()) { 529c4c52d41SBenjamin Maxwell rewriter.replaceOp(extractOp, sourceVector); 530c4c52d41SBenjamin Maxwell return success(); 531c4c52d41SBenjamin Maxwell } 532c4c52d41SBenjamin Maxwell 533c4c52d41SBenjamin Maxwell Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front(); 534c4251243SBenjamin Maxwell auto extractTileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>( 535c4251243SBenjamin Maxwell loc, sourceVector, sliceIndex); 536c4c52d41SBenjamin Maxwell 537c4c52d41SBenjamin Maxwell if (position.size() == 1) { 538c4c52d41SBenjamin Maxwell // Single index case: Extracts a 1D slice. 539c4251243SBenjamin Maxwell rewriter.replaceOp(extractOp, extractTileSlice); 540c4c52d41SBenjamin Maxwell return success(); 541c4c52d41SBenjamin Maxwell } 542c4c52d41SBenjamin Maxwell 543c4c52d41SBenjamin Maxwell // Two indices case: Extracts a single element. 544c4c52d41SBenjamin Maxwell assert(position.size() == 2); 545c4251243SBenjamin Maxwell rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, extractTileSlice, 546c4251243SBenjamin Maxwell position[1]); 547c4c52d41SBenjamin Maxwell 548c4c52d41SBenjamin Maxwell return success(); 549c4c52d41SBenjamin Maxwell } 550c4c52d41SBenjamin Maxwell }; 551c4c52d41SBenjamin Maxwell 552c4251243SBenjamin Maxwell /// Lower `vector.insert` using `arm_sme.insert_tile_slice` and 553c4251243SBenjamin Maxwell /// `arm_sme.extract_tile_slice`. 554c4c52d41SBenjamin Maxwell /// 555c4c52d41SBenjamin Maxwell /// Example: 556c4c52d41SBenjamin Maxwell /// ``` 557c4c52d41SBenjamin Maxwell /// %new_tile = vector.insert %el, %tile[%row, %col] 558c4c52d41SBenjamin Maxwell /// : i32 into vector<[4]x[4]xi32> 559c4c52d41SBenjamin Maxwell /// ``` 560c4c52d41SBenjamin Maxwell /// Becomes: 561c4c52d41SBenjamin Maxwell /// ``` 562c4251243SBenjamin Maxwell /// %slice = arm_sme.extract_tile_slice %tile[%row] 563c4c52d41SBenjamin Maxwell /// : vector<[4]xi32> from vector<[4]x[4]xi32> 564c4c52d41SBenjamin Maxwell /// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32> 565c4251243SBenjamin Maxwell /// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row] 566c4c52d41SBenjamin Maxwell /// : vector<[4]xi32> into vector<[4]x[4]xi32> 567c4c52d41SBenjamin Maxwell /// ``` 568c4c52d41SBenjamin Maxwell struct VectorInsertToArmSMELowering 569c4c52d41SBenjamin Maxwell : public OpRewritePattern<vector::InsertOp> { 570c4c52d41SBenjamin Maxwell using OpRewritePattern<vector::InsertOp>::OpRewritePattern; 571c4c52d41SBenjamin Maxwell 572c4c52d41SBenjamin Maxwell LogicalResult matchAndRewrite(vector::InsertOp insertOp, 573c4c52d41SBenjamin Maxwell PatternRewriter &rewriter) const override { 574c4c52d41SBenjamin Maxwell VectorType resultType = insertOp.getResult().getType(); 575c4c52d41SBenjamin Maxwell 576c4c52d41SBenjamin Maxwell if (!arm_sme::isValidSMETileVectorType(resultType)) 577c4c52d41SBenjamin Maxwell return failure(); 578c4c52d41SBenjamin Maxwell 579c4c52d41SBenjamin Maxwell auto loc = insertOp.getLoc(); 580c4c52d41SBenjamin Maxwell auto position = insertOp.getMixedPosition(); 581c4c52d41SBenjamin Maxwell 582c4c52d41SBenjamin Maxwell Value source = insertOp.getSource(); 583c4c52d41SBenjamin Maxwell 584c4c52d41SBenjamin Maxwell // Overwrite entire vector with value. Should be handled by folder, but 585c4c52d41SBenjamin Maxwell // just to be safe. 586c4c52d41SBenjamin Maxwell if (position.empty()) { 587c4c52d41SBenjamin Maxwell rewriter.replaceOp(insertOp, source); 588c4c52d41SBenjamin Maxwell return success(); 589c4c52d41SBenjamin Maxwell } 590c4c52d41SBenjamin Maxwell 591c4c52d41SBenjamin Maxwell Value tileSlice = source; 592c4c52d41SBenjamin Maxwell Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front(); 593c4c52d41SBenjamin Maxwell if (position.size() == 2) { 594c4c52d41SBenjamin Maxwell // Two indices case: Insert single element into tile. 595c4c52d41SBenjamin Maxwell // We need to first extract the existing slice and update the element. 596c4251243SBenjamin Maxwell tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>( 597c4c52d41SBenjamin Maxwell loc, insertOp.getDest(), sliceIndex); 598c4c52d41SBenjamin Maxwell tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice, 599c4c52d41SBenjamin Maxwell position[1]); 600c4c52d41SBenjamin Maxwell } 601c4c52d41SBenjamin Maxwell 602c4c52d41SBenjamin Maxwell // Insert the slice into the destination tile. 603c4251243SBenjamin Maxwell rewriter.replaceOpWithNewOp<arm_sme::InsertTileSliceOp>( 604c4c52d41SBenjamin Maxwell insertOp, tileSlice, insertOp.getDest(), sliceIndex); 605c4c52d41SBenjamin Maxwell return success(); 606c4c52d41SBenjamin Maxwell } 607c4c52d41SBenjamin Maxwell }; 608c4c52d41SBenjamin Maxwell 60910063c5aSBenjamin Maxwell /// Lowers `vector.print` of a tile into a loop over the rows of the tile, 610c4251243SBenjamin Maxwell /// extracting them via `arm_sme.extract_tile_slice`, then printing with 61110063c5aSBenjamin Maxwell /// a 1D `vector.print`. 61210063c5aSBenjamin Maxwell /// 61310063c5aSBenjamin Maxwell /// BEFORE: 61410063c5aSBenjamin Maxwell /// ```mlir 61510063c5aSBenjamin Maxwell /// vector.print %tile : vector<[4]x[4]xf32> 61610063c5aSBenjamin Maxwell /// ``` 61710063c5aSBenjamin Maxwell /// AFTER: 61810063c5aSBenjamin Maxwell /// ```mlir 61910063c5aSBenjamin Maxwell /// %c0 = arith.constant 0 : index 62010063c5aSBenjamin Maxwell /// %c1 = arith.constant 1 : index 62110063c5aSBenjamin Maxwell /// %c4 = arith.constant 4 : index 62210063c5aSBenjamin Maxwell /// %vscale = vector.vscale 62310063c5aSBenjamin Maxwell /// %svl_s = arith.muli %c4, %vscale : index 62410063c5aSBenjamin Maxwell /// scf.for %i = %c0 to %svl_s step %c1 { 625c4251243SBenjamin Maxwell /// %tile_slice = arm_sme.extract_tile_slice %tile[%i] 62610063c5aSBenjamin Maxwell /// : vector<[4]xf32> from vector<[4]x[4]xf32> 62710063c5aSBenjamin Maxwell /// vector.print %tile_slice : vector<[4]xf32> 62810063c5aSBenjamin Maxwell /// } 62910063c5aSBenjamin Maxwell /// ``` 63010063c5aSBenjamin Maxwell struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> { 63110063c5aSBenjamin Maxwell using OpRewritePattern<vector::PrintOp>::OpRewritePattern; 63210063c5aSBenjamin Maxwell 63310063c5aSBenjamin Maxwell LogicalResult matchAndRewrite(vector::PrintOp printOp, 63410063c5aSBenjamin Maxwell PatternRewriter &rewriter) const override { 63510063c5aSBenjamin Maxwell if (!printOp.getSource()) 63610063c5aSBenjamin Maxwell return failure(); 63710063c5aSBenjamin Maxwell 63810063c5aSBenjamin Maxwell VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType()); 63910063c5aSBenjamin Maxwell if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType)) 64010063c5aSBenjamin Maxwell return failure(); 64110063c5aSBenjamin Maxwell 64210063c5aSBenjamin Maxwell auto loc = printOp.getLoc(); 64310063c5aSBenjamin Maxwell 64410063c5aSBenjamin Maxwell // Create a loop over the rows of the tile. 64510063c5aSBenjamin Maxwell auto vscale = rewriter.create<vector::VectorScaleOp>(loc); 64610063c5aSBenjamin Maxwell auto minTileRows = 64710063c5aSBenjamin Maxwell rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0)); 64810063c5aSBenjamin Maxwell auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); 64910063c5aSBenjamin Maxwell auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale); 65010063c5aSBenjamin Maxwell auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); 65110063c5aSBenjamin Maxwell auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step); 65210063c5aSBenjamin Maxwell { 65310063c5aSBenjamin Maxwell // Loop body. 65410063c5aSBenjamin Maxwell rewriter.setInsertionPointToStart(forOp.getBody()); 65510063c5aSBenjamin Maxwell // Extract the current row from the tile. 65610063c5aSBenjamin Maxwell Value rowIndex = forOp.getInductionVar(); 657c4251243SBenjamin Maxwell auto tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>( 65810063c5aSBenjamin Maxwell loc, printOp.getSource(), rowIndex); 65910063c5aSBenjamin Maxwell // Print the row with a 1D vector.print. 66010063c5aSBenjamin Maxwell rewriter.create<vector::PrintOp>(loc, tileSlice, 66110063c5aSBenjamin Maxwell printOp.getPunctuation()); 66210063c5aSBenjamin Maxwell } 66310063c5aSBenjamin Maxwell 66410063c5aSBenjamin Maxwell rewriter.eraseOp(printOp); 66510063c5aSBenjamin Maxwell return success(); 66610063c5aSBenjamin Maxwell } 66710063c5aSBenjamin Maxwell }; 66810063c5aSBenjamin Maxwell 669c4251243SBenjamin Maxwell /// Folds a ExtractTileSliceOp + TransferWriteOp to a StoreTileSliceOp. 6704d6b9921SBenjamin Maxwell /// 6714d6b9921SBenjamin Maxwell /// BEFORE: 6724d6b9921SBenjamin Maxwell /// ```mlir 673c4251243SBenjamin Maxwell /// %slice = arm_sme.extract_tile_slice %tile[%index] 6744d6b9921SBenjamin Maxwell /// : vector<[4]xf32> from vector<[4]x[4]xf32> 6754d6b9921SBenjamin Maxwell /// vector.transfer_write %slice, %memref[%i, %j], %mask {in_bounds = [true]} 6764d6b9921SBenjamin Maxwell /// : vector<[4]xf32>, memref<?x?xf32> 6774d6b9921SBenjamin Maxwell /// ``` 6784d6b9921SBenjamin Maxwell /// AFTER: 6794d6b9921SBenjamin Maxwell /// ```mlir 6804d6b9921SBenjamin Maxwell /// arm_sme.store_tile_slice %tile, %index, %mask, %memref[%i, %j] 6814d6b9921SBenjamin Maxwell /// : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32> 6824d6b9921SBenjamin Maxwell /// ``` 6834d6b9921SBenjamin Maxwell struct FoldTransferWriteOfExtractTileSlice 6844d6b9921SBenjamin Maxwell : public OpRewritePattern<vector::TransferWriteOp> { 6854d6b9921SBenjamin Maxwell using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern; 6864d6b9921SBenjamin Maxwell 6874d6b9921SBenjamin Maxwell LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, 6884d6b9921SBenjamin Maxwell PatternRewriter &rewriter) const final { 6894d6b9921SBenjamin Maxwell if (!isa<MemRefType>(writeOp.getSource().getType())) 6904d6b9921SBenjamin Maxwell return rewriter.notifyMatchFailure(writeOp, "destination not a memref"); 6914d6b9921SBenjamin Maxwell 6924d6b9921SBenjamin Maxwell if (writeOp.hasOutOfBoundsDim()) 6934d6b9921SBenjamin Maxwell return rewriter.notifyMatchFailure(writeOp, 6944d6b9921SBenjamin Maxwell "not inbounds transfer write"); 6954d6b9921SBenjamin Maxwell 696c4251243SBenjamin Maxwell auto extractTileSlice = 697c4251243SBenjamin Maxwell writeOp.getVector().getDefiningOp<arm_sme::ExtractTileSliceOp>(); 698c4251243SBenjamin Maxwell if (!extractTileSlice) 6994d6b9921SBenjamin Maxwell return rewriter.notifyMatchFailure( 700c4251243SBenjamin Maxwell writeOp, "vector to store not from ExtractTileSliceOp"); 7014d6b9921SBenjamin Maxwell 7024d6b9921SBenjamin Maxwell AffineMap map = writeOp.getPermutationMap(); 7034d6b9921SBenjamin Maxwell if (!map.isMinorIdentity()) 7044d6b9921SBenjamin Maxwell return rewriter.notifyMatchFailure(writeOp, 7054d6b9921SBenjamin Maxwell "unsupported permutation map"); 7064d6b9921SBenjamin Maxwell 7074d6b9921SBenjamin Maxwell Value mask = writeOp.getMask(); 7084d6b9921SBenjamin Maxwell if (!mask) { 7094d6b9921SBenjamin Maxwell auto maskType = writeOp.getVectorType().clone(rewriter.getI1Type()); 7104d6b9921SBenjamin Maxwell mask = rewriter.create<arith::ConstantOp>( 7114d6b9921SBenjamin Maxwell writeOp.getLoc(), maskType, DenseElementsAttr::get(maskType, true)); 7124d6b9921SBenjamin Maxwell } 7134d6b9921SBenjamin Maxwell 7144d6b9921SBenjamin Maxwell rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>( 715c4251243SBenjamin Maxwell writeOp, extractTileSlice.getTile(), 716c4251243SBenjamin Maxwell extractTileSlice.getTileSliceIndex(), mask, writeOp.getSource(), 717c4251243SBenjamin Maxwell writeOp.getIndices(), extractTileSlice.getLayout()); 7184d6b9921SBenjamin Maxwell return success(); 7194d6b9921SBenjamin Maxwell } 7204d6b9921SBenjamin Maxwell }; 7214d6b9921SBenjamin Maxwell 722e2296d82SBenjamin Maxwell /// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to 723e2296d82SBenjamin Maxwell /// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or 724e2296d82SBenjamin Maxwell /// SVE 2.1), so this is currently the most logical place for this lowering. 725e2296d82SBenjamin Maxwell /// 726e2296d82SBenjamin Maxwell /// Example: 727e2296d82SBenjamin Maxwell /// ```mlir 728e2296d82SBenjamin Maxwell /// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1> 729e2296d82SBenjamin Maxwell /// %slice = vector.extract %mask[%index] 730e2296d82SBenjamin Maxwell /// : vector<[8]xi1> from vector<[4]x[8]xi1> 731e2296d82SBenjamin Maxwell /// ``` 732e2296d82SBenjamin Maxwell /// Becomes: 733e2296d82SBenjamin Maxwell /// ``` 734e2296d82SBenjamin Maxwell /// %mask_rows = vector.create_mask %a : vector<[4]xi1> 735e2296d82SBenjamin Maxwell /// %mask_cols = vector.create_mask %b : vector<[8]xi1> 736e2296d82SBenjamin Maxwell /// %slice = arm_sve.psel %mask_cols, %mask_rows[%index] 737e2296d82SBenjamin Maxwell /// : vector<[8]xi1>, vector<[4]xi1> 738e2296d82SBenjamin Maxwell /// ``` 739e2296d82SBenjamin Maxwell struct ExtractFromCreateMaskToPselLowering 740e2296d82SBenjamin Maxwell : public OpRewritePattern<vector::ExtractOp> { 741e2296d82SBenjamin Maxwell using OpRewritePattern<vector::ExtractOp>::OpRewritePattern; 742e2296d82SBenjamin Maxwell 743e2296d82SBenjamin Maxwell LogicalResult matchAndRewrite(vector::ExtractOp extractOp, 744e2296d82SBenjamin Maxwell PatternRewriter &rewriter) const override { 745e2296d82SBenjamin Maxwell if (extractOp.getNumIndices() != 1) 746e2296d82SBenjamin Maxwell return rewriter.notifyMatchFailure(extractOp, "not single extract index"); 747e2296d82SBenjamin Maxwell 748e2296d82SBenjamin Maxwell auto resultType = extractOp.getResult().getType(); 749e2296d82SBenjamin Maxwell auto resultVectorType = dyn_cast<VectorType>(resultType); 750e2296d82SBenjamin Maxwell if (!resultVectorType) 751e2296d82SBenjamin Maxwell return rewriter.notifyMatchFailure(extractOp, "result not VectorType"); 752e2296d82SBenjamin Maxwell 753e2296d82SBenjamin Maxwell auto createMaskOp = 754e2296d82SBenjamin Maxwell extractOp.getVector().getDefiningOp<vector::CreateMaskOp>(); 755e2296d82SBenjamin Maxwell if (!createMaskOp) 756e2296d82SBenjamin Maxwell return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp"); 757e2296d82SBenjamin Maxwell 758e2296d82SBenjamin Maxwell auto maskType = createMaskOp.getVectorType(); 759e2296d82SBenjamin Maxwell if (maskType.getRank() != 2 || !maskType.allDimsScalable()) 760e2296d82SBenjamin Maxwell return rewriter.notifyMatchFailure(createMaskOp, "not 2-D scalable mask"); 761e2296d82SBenjamin Maxwell 762e2296d82SBenjamin Maxwell auto isSVEPredicateSize = [](int64_t size) { 763e2296d82SBenjamin Maxwell return size > 0 && size <= 16 && llvm::isPowerOf2_32(uint32_t(size)); 764e2296d82SBenjamin Maxwell }; 765e2296d82SBenjamin Maxwell 766e2296d82SBenjamin Maxwell auto rowsBaseSize = maskType.getDimSize(0); 767e2296d82SBenjamin Maxwell auto colsBaseSize = maskType.getDimSize(1); 768e2296d82SBenjamin Maxwell if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize)) 769e2296d82SBenjamin Maxwell return rewriter.notifyMatchFailure( 770e2296d82SBenjamin Maxwell createMaskOp, "mask dimensions not SVE predicate-sized"); 771e2296d82SBenjamin Maxwell 772e2296d82SBenjamin Maxwell auto loc = extractOp.getLoc(); 773e2296d82SBenjamin Maxwell VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1); 774e2296d82SBenjamin Maxwell VectorType colMaskType = VectorType::Builder(maskType).dropDim(0); 775e2296d82SBenjamin Maxwell 776e2296d82SBenjamin Maxwell // Create the two 1-D masks at the location of the 2-D create_mask (which is 777e2296d82SBenjamin Maxwell // usually outside a loop). This prevents the need for later hoisting. 778e2296d82SBenjamin Maxwell rewriter.setInsertionPoint(createMaskOp); 779e2296d82SBenjamin Maxwell auto rowMask = rewriter.create<vector::CreateMaskOp>( 780e2296d82SBenjamin Maxwell loc, rowMaskType, createMaskOp.getOperand(0)); 781e2296d82SBenjamin Maxwell auto colMask = rewriter.create<vector::CreateMaskOp>( 782e2296d82SBenjamin Maxwell loc, colMaskType, createMaskOp.getOperand(1)); 783e2296d82SBenjamin Maxwell 784e2296d82SBenjamin Maxwell rewriter.setInsertionPoint(extractOp); 785e2296d82SBenjamin Maxwell auto position = 786e2296d82SBenjamin Maxwell vector::getAsValues(rewriter, loc, extractOp.getMixedPosition()); 787e2296d82SBenjamin Maxwell rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colMask, rowMask, 788e2296d82SBenjamin Maxwell position[0]); 789e2296d82SBenjamin Maxwell return success(); 790e2296d82SBenjamin Maxwell } 791e2296d82SBenjamin Maxwell }; 792e2296d82SBenjamin Maxwell 793447bb5beSAndrzej Warzynski } // namespace 794447bb5beSAndrzej Warzynski 795447bb5beSAndrzej Warzynski void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns, 796447bb5beSAndrzej Warzynski MLIRContext &ctx) { 797e2296d82SBenjamin Maxwell patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering, 7989f7fff7fSCullen Rhodes TransferReadToArmSMELowering, TransferWriteToArmSMELowering, 7999f7fff7fSCullen Rhodes TransposeOpToArmSMELowering, VectorLoadToArmSMELowering, 8009f7fff7fSCullen Rhodes VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering, 8019f7fff7fSCullen Rhodes VectorExtractToArmSMELowering, VectorInsertToArmSMELowering, 802e2296d82SBenjamin Maxwell VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice, 803e2296d82SBenjamin Maxwell ExtractFromCreateMaskToPselLowering>(&ctx); 804447bb5beSAndrzej Warzynski } 805