1042800a4SBenjamin Maxwell //===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===// 2042800a4SBenjamin Maxwell // 3042800a4SBenjamin Maxwell // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4042800a4SBenjamin Maxwell // See https://llvm.org/LICENSE.txt for license information. 5042800a4SBenjamin Maxwell // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6042800a4SBenjamin Maxwell // 7042800a4SBenjamin Maxwell //===----------------------------------------------------------------------===// 8042800a4SBenjamin Maxwell // 9042800a4SBenjamin Maxwell // This pass legalizes vector operations so they can be lowered to ArmSME. 10042800a4SBenjamin Maxwell // 11042800a4SBenjamin Maxwell // Note: In the context of this pass 'tile' always refers to an SME tile. 12042800a4SBenjamin Maxwell // 13042800a4SBenjamin Maxwell //===----------------------------------------------------------------------===// 14042800a4SBenjamin Maxwell 15c2dea712SBenjamin Maxwell #include "mlir/Dialect/Arith/Utils/Utils.h" 16042800a4SBenjamin Maxwell #include "mlir/Dialect/ArmSME/IR/ArmSME.h" 17042800a4SBenjamin Maxwell #include "mlir/Dialect/ArmSME/Transforms/Passes.h" 18042800a4SBenjamin Maxwell #include "mlir/Dialect/ArmSME/Utils/Utils.h" 19042800a4SBenjamin Maxwell #include "mlir/Dialect/Func/IR/FuncOps.h" 2031613de9SMatthias Springer #include "mlir/Dialect/Func/Transforms/FuncConversions.h" 21c194bc77SBenjamin Maxwell #include "mlir/Dialect/Index/IR/IndexDialect.h" 22c194bc77SBenjamin Maxwell #include "mlir/Dialect/Index/IR/IndexOps.h" 230473e322SBenjamin Maxwell #include "mlir/Dialect/MemRef/IR/MemRef.h" 245ed5d723SBenjamin Maxwell #include "mlir/Dialect/SCF/IR/SCF.h" 25042800a4SBenjamin Maxwell #include "mlir/Dialect/SCF/Transforms/Patterns.h" 26042800a4SBenjamin Maxwell #include "mlir/Dialect/Utils/IndexingUtils.h" 27c194bc77SBenjamin Maxwell #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 2831613de9SMatthias Springer #include 
"mlir/Transforms/DialectConversion.h" 2931613de9SMatthias Springer #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 30042800a4SBenjamin Maxwell 31042800a4SBenjamin Maxwell #define DEBUG_TYPE "arm-sme-vector-legalization" 32042800a4SBenjamin Maxwell 33042800a4SBenjamin Maxwell namespace mlir::arm_sme { 34042800a4SBenjamin Maxwell #define GEN_PASS_DEF_VECTORLEGALIZATION 35042800a4SBenjamin Maxwell #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc" 36042800a4SBenjamin Maxwell } // namespace mlir::arm_sme 37042800a4SBenjamin Maxwell 38042800a4SBenjamin Maxwell using namespace mlir; 39042800a4SBenjamin Maxwell using namespace mlir::arm_sme; 40042800a4SBenjamin Maxwell 41042800a4SBenjamin Maxwell namespace { 42042800a4SBenjamin Maxwell 43c2dea712SBenjamin Maxwell //===----------------------------------------------------------------------===// 44c2dea712SBenjamin Maxwell // Decomposition of vector operations larger than an SME tile 45c2dea712SBenjamin Maxwell //===----------------------------------------------------------------------===// 46c2dea712SBenjamin Maxwell 47042800a4SBenjamin Maxwell // Common match failure reasons. 
481408667fSBenjamin Maxwell static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple( 49042800a4SBenjamin Maxwell "op vector size is not multiple of SME tiles"); 501408667fSBenjamin Maxwell static constexpr StringLiteral kMatchFailureUnsupportedMaskOp( 51042800a4SBenjamin Maxwell "op mask is unsupported for legalization/decomposition"); 52042800a4SBenjamin Maxwell static constexpr StringLiteral 531408667fSBenjamin Maxwell kMatchFailureNonPermutationMap("op affine map is not a permutation"); 54d1fc59c3SBenjamin Maxwell static constexpr StringLiteral kMatchFailureNotIllegalToLegal( 55d1fc59c3SBenjamin Maxwell "expected transpose from illegal type to legal type"); 56042800a4SBenjamin Maxwell 57042800a4SBenjamin Maxwell /// An SMESubTile represents a single SME-sized sub-tile from decomposing a 58042800a4SBenjamin Maxwell /// larger vector type. The (`row`, `col`) are the position of the tile in the 59042800a4SBenjamin Maxwell /// original vector type. For example for an [8]x[8] tile with four [4]x[4] 60042800a4SBenjamin Maxwell /// sub-tiles, we would have: 61042800a4SBenjamin Maxwell /// 62042800a4SBenjamin Maxwell /// 8 x vscale 63042800a4SBenjamin Maxwell /// ┌─────────────┬─────────────┐ 64042800a4SBenjamin Maxwell /// │(0,0) │(0,4) │ 65042800a4SBenjamin Maxwell /// │ │ │ 66042800a4SBenjamin Maxwell /// ├─────────────┼─────────────┤ 8 x vscale 67042800a4SBenjamin Maxwell /// │(4,0) │(4,4) │ 68042800a4SBenjamin Maxwell /// │ │ │ 69042800a4SBenjamin Maxwell /// └─────────────┴─────────────┘ 70042800a4SBenjamin Maxwell struct SMESubTile { 71042800a4SBenjamin Maxwell // Note: The units of (row, col) are vscale (as SME tiles are scalable). 72042800a4SBenjamin Maxwell int row{0}; 73042800a4SBenjamin Maxwell int col{0}; 74042800a4SBenjamin Maxwell // The SME tile type. 
75042800a4SBenjamin Maxwell VectorType type; 76042800a4SBenjamin Maxwell }; 77042800a4SBenjamin Maxwell 78042800a4SBenjamin Maxwell /// Adds a constant elementwise scalable offset to `indices` (which are of equal 79042800a4SBenjamin Maxwell /// length). For example, in the 2D case this would return: 80042800a4SBenjamin Maxwell // { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale } 81042800a4SBenjamin Maxwell SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder, 82042800a4SBenjamin Maxwell Location loc, 83042800a4SBenjamin Maxwell ValueRange indices, 84042800a4SBenjamin Maxwell ArrayRef<int> scalableOffsets) { 85042800a4SBenjamin Maxwell auto vscale = builder.create<vector::VectorScaleOp>(loc); 86042800a4SBenjamin Maxwell return llvm::map_to_vector( 87042800a4SBenjamin Maxwell llvm::zip_equal(indices, scalableOffsets), [&](auto pair) -> Value { 88042800a4SBenjamin Maxwell auto [index, base] = pair; 89042800a4SBenjamin Maxwell auto offset = builder.create<arith::MulIOp>( 90042800a4SBenjamin Maxwell loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale); 91042800a4SBenjamin Maxwell return builder.create<arith::AddIOp>(loc, index, offset); 92042800a4SBenjamin Maxwell }); 93042800a4SBenjamin Maxwell } 94042800a4SBenjamin Maxwell 95042800a4SBenjamin Maxwell /// Adjusts `indices` (e.g. from a load/store) for a larger vector type to 96042800a4SBenjamin Maxwell /// indices for one of the SME sub-tiles it will decompose into. 
97042800a4SBenjamin Maxwell /// 98042800a4SBenjamin Maxwell /// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the 99042800a4SBenjamin Maxwell /// indices for each tile would need to be adjusted as follows: 100042800a4SBenjamin Maxwell /// 101042800a4SBenjamin Maxwell /// initial indices = [a,b], inital size = 8x8, target size = 4x4 102042800a4SBenjamin Maxwell /// ┌─────────────┬─────────────┐ 103042800a4SBenjamin Maxwell /// │[a,b] │[a,b+4] │ 104042800a4SBenjamin Maxwell /// │ │ │ 105042800a4SBenjamin Maxwell /// ├─────────────┼─────────────┤ 106042800a4SBenjamin Maxwell /// │[a+4,b] │[a+4,b+4] │ 107042800a4SBenjamin Maxwell /// │ │ │ 108042800a4SBenjamin Maxwell /// └─────────────┴─────────────┘ 109042800a4SBenjamin Maxwell SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc, 110042800a4SBenjamin Maxwell ValueRange indices, 111042800a4SBenjamin Maxwell SMESubTile smeTile) { 112042800a4SBenjamin Maxwell return addConstantScalableOffset(builder, loc, indices, 113042800a4SBenjamin Maxwell {smeTile.row, smeTile.col}); 114042800a4SBenjamin Maxwell } 115042800a4SBenjamin Maxwell 116042800a4SBenjamin Maxwell /// Returns true if `mask` is generated by an operation that can be decomposed 117042800a4SBenjamin Maxwell /// for SME. Currently, that is just no mask, or vector.create_mask. 118042800a4SBenjamin Maxwell /// TODO: Add support for vector.constant_mask once required for SME. 119042800a4SBenjamin Maxwell bool isSupportedMaskOp(Value mask) { 120042800a4SBenjamin Maxwell return !mask || mask.getDefiningOp<vector::CreateMaskOp>(); 121042800a4SBenjamin Maxwell } 122042800a4SBenjamin Maxwell 123042800a4SBenjamin Maxwell /// Extracts a mask for an SME sub-tile from the mask of a larger vector type. 
124042800a4SBenjamin Maxwell Value extractSMEMask(OpBuilder &builder, Location loc, Value mask, 125042800a4SBenjamin Maxwell SMESubTile smeTile) { 126042800a4SBenjamin Maxwell assert(isSupportedMaskOp(mask)); 127042800a4SBenjamin Maxwell if (!mask) 128042800a4SBenjamin Maxwell return Value{}; 129042800a4SBenjamin Maxwell auto createMask = mask.getDefiningOp<vector::CreateMaskOp>(); 130042800a4SBenjamin Maxwell // The operands of `vector.create_mask` (from a 2D perspective) are the 131042800a4SBenjamin Maxwell // coordinates where the mask ends. So we subtract where this tile starts, 132042800a4SBenjamin Maxwell // from the mask operands to get the parameters for this sub-tile. 133042800a4SBenjamin Maxwell auto smeTileMaskDims = addConstantScalableOffset( 134042800a4SBenjamin Maxwell builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col}); 135042800a4SBenjamin Maxwell auto smeTileCreateMask = builder.create<vector::CreateMaskOp>( 136042800a4SBenjamin Maxwell loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims); 137042800a4SBenjamin Maxwell return smeTileCreateMask.getResult(); 138042800a4SBenjamin Maxwell } 139042800a4SBenjamin Maxwell 140042800a4SBenjamin Maxwell /// Constructs an iterator that returns each SME tile (with coordinates) 141042800a4SBenjamin Maxwell /// contained within a VectorType. For example, if decomposing an [8]x[8] into 142042800a4SBenjamin Maxwell /// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0), 143042800a4SBenjamin Maxwell /// (4, 4). 
144042800a4SBenjamin Maxwell auto decomposeToSMETiles(OpBuilder &builder, VectorType type, 145042800a4SBenjamin Maxwell VectorType smeTileType, 146042800a4SBenjamin Maxwell bool transposeIndices = false) { 147042800a4SBenjamin Maxwell return llvm::map_range( 148c194bc77SBenjamin Maxwell StaticTileOffsetRange( 149c194bc77SBenjamin Maxwell type.getShape(), 150c194bc77SBenjamin Maxwell {std::min(type.getDimSize(0), smeTileType.getDimSize(0)), 151c194bc77SBenjamin Maxwell std::min(type.getDimSize(1), smeTileType.getDimSize(1))}), 152042800a4SBenjamin Maxwell [=](auto indices) { 153042800a4SBenjamin Maxwell int row = int(indices[0]); 154042800a4SBenjamin Maxwell int col = int(indices[1]); 155042800a4SBenjamin Maxwell if (transposeIndices) 156042800a4SBenjamin Maxwell std::swap(row, col); 157042800a4SBenjamin Maxwell return SMESubTile{row, col, smeTileType}; 158042800a4SBenjamin Maxwell }); 159042800a4SBenjamin Maxwell } 160042800a4SBenjamin Maxwell 161042800a4SBenjamin Maxwell /// Returns the number of SME tiles that fit into the (2D-scalable) vector type 162042800a4SBenjamin Maxwell /// `type`. 163042800a4SBenjamin Maxwell int getNumberOfSMETilesForVectorType(VectorType type) { 164042800a4SBenjamin Maxwell assert(isMultipleOfSMETileVectorType(type) && 165042800a4SBenjamin Maxwell "`type` not multiple of SME tiles"); 166042800a4SBenjamin Maxwell int64_t vectorRows = type.getDimSize(0); 167042800a4SBenjamin Maxwell int64_t vectorCols = type.getDimSize(1); 168042800a4SBenjamin Maxwell auto elementType = type.getElementType(); 169042800a4SBenjamin Maxwell unsigned minNumElts = getSMETileSliceMinNumElts(elementType); 170042800a4SBenjamin Maxwell return (vectorRows * vectorCols) / (minNumElts * minNumElts); 171042800a4SBenjamin Maxwell } 172042800a4SBenjamin Maxwell 173dadcaf82SBenjamin Maxwell /// Legalize `arith.constant dense<value>` splat operations to fit within SME 174dadcaf82SBenjamin Maxwell /// tiles by decomposing them into tile-sized operations. 
175dadcaf82SBenjamin Maxwell struct LegalizeArithConstantOpsByDecomposition 17631613de9SMatthias Springer : public OpConversionPattern<arith::ConstantOp> { 17731613de9SMatthias Springer using OpConversionPattern::OpConversionPattern; 178dadcaf82SBenjamin Maxwell 179dadcaf82SBenjamin Maxwell LogicalResult 180dadcaf82SBenjamin Maxwell matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor, 18131613de9SMatthias Springer ConversionPatternRewriter &rewriter) const override { 182dadcaf82SBenjamin Maxwell auto vectorType = dyn_cast<VectorType>(constantOp.getType()); 183dadcaf82SBenjamin Maxwell auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr()); 184dadcaf82SBenjamin Maxwell if (!vectorType || !denseAttr || !denseAttr.isSplat()) 185dadcaf82SBenjamin Maxwell return failure(); 186dadcaf82SBenjamin Maxwell 187dadcaf82SBenjamin Maxwell if (!isMultipleOfSMETileVectorType(vectorType)) 188dadcaf82SBenjamin Maxwell return rewriter.notifyMatchFailure(constantOp, 189dadcaf82SBenjamin Maxwell kMatchFailureNotSMETileTypeMultiple); 190dadcaf82SBenjamin Maxwell 191dadcaf82SBenjamin Maxwell auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); 192dadcaf82SBenjamin Maxwell auto tileCount = getNumberOfSMETilesForVectorType(vectorType); 193dadcaf82SBenjamin Maxwell auto tileSplat = rewriter.create<arith::ConstantOp>( 194dadcaf82SBenjamin Maxwell constantOp.getLoc(), denseAttr.resizeSplat(smeTileType)); 19531613de9SMatthias Springer SmallVector<Value> repl(tileCount, tileSplat); 19631613de9SMatthias Springer rewriter.replaceOpWithMultiple(constantOp, {repl}); 197dadcaf82SBenjamin Maxwell 198dadcaf82SBenjamin Maxwell return success(); 199dadcaf82SBenjamin Maxwell } 200dadcaf82SBenjamin Maxwell }; 201dadcaf82SBenjamin Maxwell 202042800a4SBenjamin Maxwell /// Legalize `vector.outerproduct` operations to fit within SME tiles by 203042800a4SBenjamin Maxwell /// decomposing them into tile-sized operations. 
204042800a4SBenjamin Maxwell struct LegalizeVectorOuterProductOpsByDecomposition 20531613de9SMatthias Springer : public OpConversionPattern<vector::OuterProductOp> { 20631613de9SMatthias Springer using OpConversionPattern::OpConversionPattern; 207042800a4SBenjamin Maxwell 208042800a4SBenjamin Maxwell LogicalResult 20931613de9SMatthias Springer matchAndRewrite(vector::OuterProductOp outerProductOp, 21031613de9SMatthias Springer OneToNOpAdaptor adaptor, 21131613de9SMatthias Springer ConversionPatternRewriter &rewriter) const override { 212042800a4SBenjamin Maxwell auto vectorType = outerProductOp.getResultVectorType(); 213042800a4SBenjamin Maxwell if (!isMultipleOfSMETileVectorType(vectorType)) 2141408667fSBenjamin Maxwell return rewriter.notifyMatchFailure(outerProductOp, 2151408667fSBenjamin Maxwell kMatchFailureNotSMETileTypeMultiple); 216042800a4SBenjamin Maxwell 217042800a4SBenjamin Maxwell Value mask; 218042800a4SBenjamin Maxwell Operation *rootOp = outerProductOp; 219042800a4SBenjamin Maxwell auto loc = outerProductOp.getLoc(); 220042800a4SBenjamin Maxwell if (outerProductOp.isMasked()) { 221042800a4SBenjamin Maxwell auto maskOp = outerProductOp.getMaskingOp(); 222042800a4SBenjamin Maxwell mask = maskOp.getMask(); 223042800a4SBenjamin Maxwell rootOp = maskOp; 22431613de9SMatthias Springer rewriter.setInsertionPoint(rootOp); 225042800a4SBenjamin Maxwell } 226042800a4SBenjamin Maxwell 227042800a4SBenjamin Maxwell if (!isSupportedMaskOp(mask)) 228042800a4SBenjamin Maxwell return rewriter.notifyMatchFailure(outerProductOp, 2291408667fSBenjamin Maxwell kMatchFailureUnsupportedMaskOp); 230042800a4SBenjamin Maxwell 231042800a4SBenjamin Maxwell ValueRange accSMETiles = adaptor.getAcc(); 232042800a4SBenjamin Maxwell auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); 233042800a4SBenjamin Maxwell VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0); 234042800a4SBenjamin Maxwell 235042800a4SBenjamin Maxwell SmallVector<Value> 
resultSMETiles; 236042800a4SBenjamin Maxwell for (auto [index, smeTile] : llvm::enumerate( 237042800a4SBenjamin Maxwell decomposeToSMETiles(rewriter, vectorType, smeTileType))) { 238042800a4SBenjamin Maxwell 239042800a4SBenjamin Maxwell auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile); 240042800a4SBenjamin Maxwell auto lhs = rewriter.create<vector::ScalableExtractOp>( 241042800a4SBenjamin Maxwell loc, sliceType, outerProductOp.getLhs(), smeTile.row); 242042800a4SBenjamin Maxwell auto rhs = rewriter.create<vector::ScalableExtractOp>( 243042800a4SBenjamin Maxwell loc, sliceType, outerProductOp.getRhs(), smeTile.col); 244042800a4SBenjamin Maxwell auto smeOuterProduct = rewriter.create<vector::OuterProductOp>( 245042800a4SBenjamin Maxwell loc, smeTileType, lhs, rhs, 246042800a4SBenjamin Maxwell !accSMETiles.empty() ? accSMETiles[index] : Value{}, 247042800a4SBenjamin Maxwell outerProductOp.getKind()); 248042800a4SBenjamin Maxwell 249042800a4SBenjamin Maxwell auto maskedOuterProduct = 250042800a4SBenjamin Maxwell vector::maskOperation(rewriter, smeOuterProduct, smeMask); 251042800a4SBenjamin Maxwell resultSMETiles.push_back(maskedOuterProduct->getResult(0)); 252042800a4SBenjamin Maxwell } 253042800a4SBenjamin Maxwell 25431613de9SMatthias Springer rewriter.replaceOpWithMultiple(rootOp, {resultSMETiles}); 255042800a4SBenjamin Maxwell return success(); 256042800a4SBenjamin Maxwell } 257042800a4SBenjamin Maxwell }; 258042800a4SBenjamin Maxwell 259042800a4SBenjamin Maxwell // Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to 260042800a4SBenjamin Maxwell // get the help of the type conversion), but doing so results in the type 261042800a4SBenjamin Maxwell // conversion adding target materializations in the `vector.mask` region 262042800a4SBenjamin Maxwell // (invalid). This pattern matches on `vector.mask` then calls into the 263042800a4SBenjamin Maxwell // `vector.outerproduct` pattern to work around this issue. 
264042800a4SBenjamin Maxwell struct LegalizeMaskedVectorOuterProductOpsByDecomposition 26531613de9SMatthias Springer : public OpConversionPattern<vector::MaskOp> { 26631613de9SMatthias Springer using OpConversionPattern::OpConversionPattern; 267042800a4SBenjamin Maxwell 268042800a4SBenjamin Maxwell LogicalResult 26931613de9SMatthias Springer matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor, 27031613de9SMatthias Springer ConversionPatternRewriter &rewriter) const override { 271a9eb8f0eSBenjamin Maxwell if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>( 272a9eb8f0eSBenjamin Maxwell maskOp.getMaskableOp())) { 273042800a4SBenjamin Maxwell LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(), 274042800a4SBenjamin Maxwell getContext()); 275042800a4SBenjamin Maxwell return static_cast<RewritePattern &>(pattern).matchAndRewrite( 276042800a4SBenjamin Maxwell outerProductOp, rewriter); 277042800a4SBenjamin Maxwell } 278042800a4SBenjamin Maxwell return failure(); 279042800a4SBenjamin Maxwell } 280042800a4SBenjamin Maxwell }; 281042800a4SBenjamin Maxwell 282042800a4SBenjamin Maxwell /// Legalize `vector.transfer_read` operations to fit within SME tiles by 283042800a4SBenjamin Maxwell /// decomposing them into tile-sized operations. 
284042800a4SBenjamin Maxwell struct LegalizeTransferReadOpsByDecomposition 28531613de9SMatthias Springer : public OpConversionPattern<vector::TransferReadOp> { 28631613de9SMatthias Springer using OpConversionPattern::OpConversionPattern; 287042800a4SBenjamin Maxwell 288042800a4SBenjamin Maxwell LogicalResult 28931613de9SMatthias Springer matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor, 29031613de9SMatthias Springer ConversionPatternRewriter &rewriter) const override { 291042800a4SBenjamin Maxwell auto vectorType = readOp.getVectorType(); 292042800a4SBenjamin Maxwell if (!isMultipleOfSMETileVectorType(vectorType)) 2931408667fSBenjamin Maxwell return rewriter.notifyMatchFailure(readOp, 2941408667fSBenjamin Maxwell kMatchFailureNotSMETileTypeMultiple); 295042800a4SBenjamin Maxwell 296042800a4SBenjamin Maxwell auto mask = readOp.getMask(); 297042800a4SBenjamin Maxwell if (!isSupportedMaskOp(mask)) 298042800a4SBenjamin Maxwell return rewriter.notifyMatchFailure(readOp, 2991408667fSBenjamin Maxwell kMatchFailureUnsupportedMaskOp); 300042800a4SBenjamin Maxwell 301042800a4SBenjamin Maxwell auto permutationMap = readOp.getPermutationMap(); 302042800a4SBenjamin Maxwell if (!permutationMap.isPermutation()) 303042800a4SBenjamin Maxwell return rewriter.notifyMatchFailure(readOp, 3041408667fSBenjamin Maxwell kMatchFailureNonPermutationMap); 305042800a4SBenjamin Maxwell 306042800a4SBenjamin Maxwell // Note: For 2D vector types the only non-identity permutation is a simple 307*aa295216SJay Foad // transpose [1, 0]. 
308042800a4SBenjamin Maxwell bool transposed = !permutationMap.isIdentity(); 309042800a4SBenjamin Maxwell 310042800a4SBenjamin Maxwell auto loc = readOp.getLoc(); 311042800a4SBenjamin Maxwell auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); 312042800a4SBenjamin Maxwell 313042800a4SBenjamin Maxwell SmallVector<Value> resultSMETiles; 314042800a4SBenjamin Maxwell for (SMESubTile smeTile : 315042800a4SBenjamin Maxwell decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) { 316042800a4SBenjamin Maxwell auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile); 317042800a4SBenjamin Maxwell auto smeRead = rewriter.create<vector::TransferReadOp>( 318042800a4SBenjamin Maxwell loc, smeTileType, readOp.getSource(), 319042800a4SBenjamin Maxwell getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile), 320042800a4SBenjamin Maxwell readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask, 321042800a4SBenjamin Maxwell readOp.getInBoundsAttr()); 322042800a4SBenjamin Maxwell resultSMETiles.push_back(smeRead); 323042800a4SBenjamin Maxwell } 324042800a4SBenjamin Maxwell 32531613de9SMatthias Springer rewriter.replaceOpWithMultiple(readOp, {resultSMETiles}); 326042800a4SBenjamin Maxwell return success(); 327042800a4SBenjamin Maxwell } 328042800a4SBenjamin Maxwell }; 329042800a4SBenjamin Maxwell 330042800a4SBenjamin Maxwell /// Legalize `vector.transfer_write` operations to fit within SME tiles by 331042800a4SBenjamin Maxwell /// decomposing them into tile-sized operations. 
332042800a4SBenjamin Maxwell struct LegalizeTransferWriteOpsByDecomposition 33331613de9SMatthias Springer : public OpConversionPattern<vector::TransferWriteOp> { 33431613de9SMatthias Springer using OpConversionPattern::OpConversionPattern; 335042800a4SBenjamin Maxwell 336042800a4SBenjamin Maxwell LogicalResult 33731613de9SMatthias Springer matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor, 33831613de9SMatthias Springer ConversionPatternRewriter &rewriter) const override { 339042800a4SBenjamin Maxwell auto vectorType = writeOp.getVectorType(); 340042800a4SBenjamin Maxwell if (!isMultipleOfSMETileVectorType(vectorType)) 3411408667fSBenjamin Maxwell return rewriter.notifyMatchFailure(writeOp, 3421408667fSBenjamin Maxwell kMatchFailureNotSMETileTypeMultiple); 343042800a4SBenjamin Maxwell 344042800a4SBenjamin Maxwell auto mask = writeOp.getMask(); 345042800a4SBenjamin Maxwell if (!isSupportedMaskOp(mask)) 346042800a4SBenjamin Maxwell return rewriter.notifyMatchFailure(writeOp, 3471408667fSBenjamin Maxwell kMatchFailureUnsupportedMaskOp); 348042800a4SBenjamin Maxwell 349042800a4SBenjamin Maxwell auto permutationMap = writeOp.getPermutationMap(); 350042800a4SBenjamin Maxwell if (!permutationMap.isPermutation()) 351042800a4SBenjamin Maxwell return rewriter.notifyMatchFailure(writeOp, 3521408667fSBenjamin Maxwell kMatchFailureNonPermutationMap); 353042800a4SBenjamin Maxwell 354042800a4SBenjamin Maxwell // Note: For 2D vector types the only non-identity permutation is a simple 355*aa295216SJay Foad // transpose [1, 0]. 
356042800a4SBenjamin Maxwell bool transposed = !permutationMap.isIdentity(); 357042800a4SBenjamin Maxwell 358042800a4SBenjamin Maxwell auto loc = writeOp.getLoc(); 359042800a4SBenjamin Maxwell auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); 360042800a4SBenjamin Maxwell auto inputSMETiles = adaptor.getVector(); 361042800a4SBenjamin Maxwell 362042800a4SBenjamin Maxwell Value destTensorOrMemref = writeOp.getSource(); 363042800a4SBenjamin Maxwell for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles( 364042800a4SBenjamin Maxwell rewriter, vectorType, smeTileType, transposed))) { 365042800a4SBenjamin Maxwell auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile); 366042800a4SBenjamin Maxwell auto smeWrite = rewriter.create<vector::TransferWriteOp>( 367042800a4SBenjamin Maxwell loc, inputSMETiles[index], destTensorOrMemref, 368042800a4SBenjamin Maxwell getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile), 369042800a4SBenjamin Maxwell writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr()); 370042800a4SBenjamin Maxwell if (writeOp.hasPureTensorSemantics()) 371042800a4SBenjamin Maxwell destTensorOrMemref = smeWrite.getResult(); 372042800a4SBenjamin Maxwell } 373042800a4SBenjamin Maxwell 374042800a4SBenjamin Maxwell if (writeOp.hasPureTensorSemantics()) 375042800a4SBenjamin Maxwell rewriter.replaceOp(writeOp, destTensorOrMemref); 376042800a4SBenjamin Maxwell else 377042800a4SBenjamin Maxwell rewriter.eraseOp(writeOp); 378042800a4SBenjamin Maxwell 379042800a4SBenjamin Maxwell return success(); 380042800a4SBenjamin Maxwell } 381042800a4SBenjamin Maxwell }; 382042800a4SBenjamin Maxwell 3835ed5d723SBenjamin Maxwell /// Legalize a multi-tile transfer_write as a single store loop. 
/// This is done as part of type decomposition because at this level we know
/// each tile write is disjoint. That information is lost after decomposition
/// (without analysis to reconstruct it).
///
/// Example (pseudo-MLIR):
///
/// ```
/// vector.transfer_write %vector, %dest[%y, %x], %mask
///   : vector<[16]x[8]xi16>, memref<?x?xi16>
/// ```
/// Is rewritten to:
/// ```
/// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
///   %upper_slice_mask = vector.extract %mask[%slice_idx]         ─┐
///     : vector<[8]xi1> from vector<[16]x[8]xi1>                   |
///   %upper_slice = vector.extract %upper_tile[%slice_idx]         |- Store upper tile
///     : vector<[8]xi16> from vector<[8]x[8]xi16>                  |
///   vector.transfer_write %upper_slice,                           |
///     %dest[%slice_idx + %y, %x], %upper_slice_mask               |
///     : vector<[8]xi16>, memref<?x?xi16>                          ┘
///   %lower_slice_idx = %slice_idx + %c8_vscale                   ─┐
///   %lower_slice_mask = vector.extract %mask[%lower_slice_idx]    |
///     : vector<[8]xi1> from vector<[16]x[8]xi1>                   |
///   %lower_slice = vector.extract %lower_tile[%slice_idx]         |- Store lower
///     : vector<[8]xi16> from vector<[8]x[8]xi16>                  |  tile
///   vector.transfer_write %lower_slice,                           |
///     %dest[%lower_slice_idx + %y, %x], %lower_slice_mask         |
///     : vector<[8]xi16>, memref<?x?xi16>                          ┘
/// }
/// ```
///
/// Only non-transposed writes without tensor semantics are handled. A mask,
/// if present, must come from `vector.create_mask`, and neither vector
/// dimension may exceed 16.
struct LegalizeMultiTileTransferWriteAsStoreLoop
    : public OpConversionPattern<vector::TransferWriteOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Tensor writes produce a result value; this pattern only emits
    // side-effecting stores and then erases the op, so they are rejected.
    if (writeOp.hasPureTensorSemantics())
      return rewriter.notifyMatchFailure(
          writeOp, "TODO: tensor semantics are unsupported");

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    bool transposed = !permutationMap.isIdentity();
    if (transposed)
      return rewriter.notifyMatchFailure(writeOp,
                                         "TODO: transpose unsupported");

    auto vectorType = writeOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    // Note: We also disallow masks where any dimension is > 16 because that
    // prevents the masking from being lowered to use arm_sve.psel.
    auto mask = writeOp.getMask();
    if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
                                              vectorType.getDimSize(1) > 16)))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto loc = writeOp.getLoc();
    // Callable that builds a vscale-multiple index value (e.g. `8 * vscale`,
    // shown as %c8_vscale in the example above).
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    // Get SME tile and slice types.
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    // Number of slices (rows) in one SME tile, before scaling by vscale.
    auto minTileSlices = smeTileType.getDimSize(0);
    // Mask type for a single (scalable) tile slice.
    VectorType sliceMaskType =
        VectorType::get(minTileSlices, rewriter.getI1Type(), true);

    // Create loop over all tile slices.
    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto upperBound = createVscaleMultiple(minTileSlices);
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto storeLoop =
        rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
    // Everything created below is placed inside the loop body.
    rewriter.setInsertionPointToStart(storeLoop.getBody());

    // For each sub-tile of the multi-tile `vectorType`.
    auto inputSMETiles = adaptor.getVector();
    auto tileSliceIndex = storeLoop.getInductionVar();
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
      // The coordinates of the tile within `vectorType`.
      auto tileRow = createVscaleMultiple(smeTile.row);
      auto tileCol = createVscaleMultiple(smeTile.col);

      // The current slice of `vectorType` we are processing.
      auto sliceIndex =
          rewriter.create<arith::AddIOp>(loc, tileRow, tileSliceIndex);

      // Where in the destination memref the current slice will be stored.
      auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex,
                                                     writeOp.getIndices()[0]);
      auto storeCol =
          rewriter.create<arith::AddIOp>(loc, tileCol, writeOp.getIndices()[1]);

      // Extract the mask for the current slice. The row extracted from the
      // full mask spans the whole vector width. If that is wider than one tile
      // slice, the chunk starting at this tile's column offset is taken.
      Value sliceMask = nullptr;
      if (mask) {
        sliceMask = rewriter.create<vector::ExtractOp>(
            loc, mask, OpFoldResult(sliceIndex));
        if (sliceMaskType != sliceMask.getType())
          sliceMask = rewriter.create<vector::ScalableExtractOp>(
              loc, sliceMaskType, sliceMask, smeTile.col);
      }

      // Extract and store the current slice. The slice write uses the
      // original permutation map and in_bounds flags with the leading
      // (row) entry dropped, since each write is 1-D.
      Value tile = inputSMETiles[index];
      auto slice =
          rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex);
      rewriter.create<vector::TransferWriteOp>(
          loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
          AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
          sliceMask,
          rewriter.getBoolArrayAttr(
              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()));
    }

    // The original op has no tensor result here (rejected above), so it is
    // simply erased once the loop has been emitted.
    rewriter.eraseOp(writeOp);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ArmSME-specific fixup canonicalizations/folds
//===----------------------------------------------------------------------===//

/// Folds an extract from a 3D `vector.create_mask` (which is a vector of
/// SME-like masks), into a compare and a 2D `vector.create_mask`. This is
/// necessary for the mask to be lowered to ArmSME.
519c2dea712SBenjamin Maxwell /// 520c2dea712SBenjamin Maxwell /// Example: 521c2dea712SBenjamin Maxwell /// 522c2dea712SBenjamin Maxwell /// BEFORE: 523c2dea712SBenjamin Maxwell /// ```mlir 524c2dea712SBenjamin Maxwell /// %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1> 525c2dea712SBenjamin Maxwell /// %subMask = vector.extract %mask[2] 526c2dea712SBenjamin Maxwell /// : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1> 527c2dea712SBenjamin Maxwell /// ``` 528c2dea712SBenjamin Maxwell /// 529c2dea712SBenjamin Maxwell /// AFTER: 530c2dea712SBenjamin Maxwell /// ```mlir 531c2dea712SBenjamin Maxwell /// %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index 532c2dea712SBenjamin Maxwell /// %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index 533c2dea712SBenjamin Maxwell /// %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1> 534c2dea712SBenjamin Maxwell /// ``` 535c2dea712SBenjamin Maxwell struct FoldExtractFromVectorOfSMELikeCreateMasks 536c2dea712SBenjamin Maxwell : public OpRewritePattern<vector::ExtractOp> { 537c2dea712SBenjamin Maxwell using OpRewritePattern<vector::ExtractOp>::OpRewritePattern; 538c2dea712SBenjamin Maxwell 539c2dea712SBenjamin Maxwell LogicalResult matchAndRewrite(vector::ExtractOp extractOp, 540c2dea712SBenjamin Maxwell PatternRewriter &rewriter) const override { 541c2dea712SBenjamin Maxwell auto loc = extractOp.getLoc(); 542c2dea712SBenjamin Maxwell auto createMaskOp = 543c2dea712SBenjamin Maxwell extractOp.getVector().getDefiningOp<vector::CreateMaskOp>(); 544c2dea712SBenjamin Maxwell if (!createMaskOp) 545c2dea712SBenjamin Maxwell return rewriter.notifyMatchFailure( 546c2dea712SBenjamin Maxwell extractOp, "extract not from vector.create_mask op"); 547c2dea712SBenjamin Maxwell 548c2dea712SBenjamin Maxwell VectorType extractedMaskType = 549c2dea712SBenjamin Maxwell llvm::dyn_cast<VectorType>(extractOp.getResult().getType()); 550c2dea712SBenjamin Maxwell if 
(!extractedMaskType) 551c2dea712SBenjamin Maxwell return rewriter.notifyMatchFailure(extractOp, 552c2dea712SBenjamin Maxwell "extracted type is not a vector type"); 553c2dea712SBenjamin Maxwell 554fe07d9aaSAndrzej Warzyński auto numScalable = extractedMaskType.getNumScalableDims(); 555c2dea712SBenjamin Maxwell if (numScalable != 2) 556c2dea712SBenjamin Maxwell return rewriter.notifyMatchFailure( 557c2dea712SBenjamin Maxwell extractOp, "expected extracted type to be an SME-like mask"); 558c2dea712SBenjamin Maxwell 559c2dea712SBenjamin Maxwell // TODO: Support multiple extraction indices. 560c2dea712SBenjamin Maxwell if (extractOp.getStaticPosition().size() != 1) 561c2dea712SBenjamin Maxwell return rewriter.notifyMatchFailure( 562c2dea712SBenjamin Maxwell extractOp, "only a single extraction index is supported"); 563c2dea712SBenjamin Maxwell 564c2dea712SBenjamin Maxwell auto frontMaskDim = createMaskOp.getOperand(0); 565c2dea712SBenjamin Maxwell if (frontMaskDim.getDefiningOp<arith::ConstantOp>()) 566c2dea712SBenjamin Maxwell return rewriter.notifyMatchFailure( 567c2dea712SBenjamin Maxwell extractOp, 568c2dea712SBenjamin Maxwell "constant vector.create_masks dims should be folded elsewhere"); 569c2dea712SBenjamin Maxwell 570c2dea712SBenjamin Maxwell auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 571c2dea712SBenjamin Maxwell auto extractionIndex = getValueOrCreateConstantIndexOp( 572c2dea712SBenjamin Maxwell rewriter, loc, extractOp.getMixedPosition()[0]); 573c2dea712SBenjamin Maxwell auto extractionInTrueRegion = rewriter.create<arith::CmpIOp>( 574c2dea712SBenjamin Maxwell loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, extractionIndex, 575c2dea712SBenjamin Maxwell frontMaskDim); 576c2dea712SBenjamin Maxwell auto newMaskFrontDim = rewriter.create<arith::SelectOp>( 577c2dea712SBenjamin Maxwell loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero); 578c2dea712SBenjamin Maxwell 579c2dea712SBenjamin Maxwell 
rewriter.replaceOpWithNewOp<vector::CreateMaskOp>( 580c2dea712SBenjamin Maxwell extractOp, extractedMaskType, 581c2dea712SBenjamin Maxwell ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)}); 582c2dea712SBenjamin Maxwell return success(); 583c2dea712SBenjamin Maxwell } 584c2dea712SBenjamin Maxwell }; 585c2dea712SBenjamin Maxwell 586d1fc59c3SBenjamin Maxwell /// A vector type where no fixed dimension comes after a scalable dimension. 587d1fc59c3SBenjamin Maxwell bool isLegalVectorType(VectorType vType) { 588d1fc59c3SBenjamin Maxwell bool seenFixedDim = false; 589d1fc59c3SBenjamin Maxwell for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) { 590d1fc59c3SBenjamin Maxwell seenFixedDim |= !scalableFlag; 591d1fc59c3SBenjamin Maxwell if (seenFixedDim && scalableFlag) 592d1fc59c3SBenjamin Maxwell return false; 593d1fc59c3SBenjamin Maxwell } 594d1fc59c3SBenjamin Maxwell return true; 595d1fc59c3SBenjamin Maxwell } 596d1fc59c3SBenjamin Maxwell 5970473e322SBenjamin Maxwell /// Lifts an illegal vector.transpose and vector.transfer_read to a 5980473e322SBenjamin Maxwell /// memref.subview + memref.transpose, followed by a legal read. 5990473e322SBenjamin Maxwell /// 6000473e322SBenjamin Maxwell /// 'Illegal' here means a leading scalable dimension and a fixed trailing 6010473e322SBenjamin Maxwell /// dimension, which has no valid lowering. 6020473e322SBenjamin Maxwell /// 6030473e322SBenjamin Maxwell /// The memref.transpose is metadata-only transpose that produces a strided 6040473e322SBenjamin Maxwell /// memref, which eventually becomes a loop reading individual elements. 
6050473e322SBenjamin Maxwell /// 6060473e322SBenjamin Maxwell /// Example: 6070473e322SBenjamin Maxwell /// 6080473e322SBenjamin Maxwell /// BEFORE: 6090473e322SBenjamin Maxwell /// ```mlir 6100473e322SBenjamin Maxwell /// %illegalRead = vector.transfer_read %memref[%a, %b] 6110473e322SBenjamin Maxwell /// : memref<?x?xf32>, vector<[8]x4xf32> 6120473e322SBenjamin Maxwell /// %legalType = vector.transpose %illegalRead, [1, 0] 6130473e322SBenjamin Maxwell /// : vector<[8]x4xf32> to vector<4x[8]xf32> 6140473e322SBenjamin Maxwell /// ``` 6150473e322SBenjamin Maxwell /// 6160473e322SBenjamin Maxwell /// AFTER: 6170473e322SBenjamin Maxwell /// ```mlir 6180473e322SBenjamin Maxwell /// %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1] 6190473e322SBenjamin Maxwell /// : memref<?x?xf32> to memref<?x?xf32> 6200473e322SBenjamin Maxwell /// %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0) 6210473e322SBenjamin Maxwell /// : memref<?x?xf32> to memref<?x?xf32> 6220473e322SBenjamin Maxwell /// %legalType = vector.transfer_read %transpose[%c0, %c0] 6230473e322SBenjamin Maxwell /// : memref<?x?xf32>, vector<4x[8]xf32> 6240473e322SBenjamin Maxwell /// ``` 6250473e322SBenjamin Maxwell struct LiftIllegalVectorTransposeToMemory 6260473e322SBenjamin Maxwell : public OpRewritePattern<vector::TransposeOp> { 6270473e322SBenjamin Maxwell using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; 6280473e322SBenjamin Maxwell 6290473e322SBenjamin Maxwell static Value getExtensionSource(Operation *op) { 6308cfb7161SBenjamin Maxwell if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op)) 6310473e322SBenjamin Maxwell return op->getOperand(0); 6320473e322SBenjamin Maxwell return {}; 6330473e322SBenjamin Maxwell } 6340473e322SBenjamin Maxwell 6350473e322SBenjamin Maxwell LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, 6360473e322SBenjamin Maxwell PatternRewriter &rewriter) const override { 6370473e322SBenjamin Maxwell auto 
sourceType = transposeOp.getSourceVectorType(); 6380473e322SBenjamin Maxwell auto resultType = transposeOp.getResultVectorType(); 639d1fc59c3SBenjamin Maxwell if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType)) 640d1fc59c3SBenjamin Maxwell return rewriter.notifyMatchFailure(transposeOp, 641d1fc59c3SBenjamin Maxwell kMatchFailureNotIllegalToLegal); 6420473e322SBenjamin Maxwell 6430473e322SBenjamin Maxwell // Look through extend for transfer_read. 6440473e322SBenjamin Maxwell Value maybeRead = transposeOp.getVector(); 6450473e322SBenjamin Maxwell auto *transposeSourceOp = maybeRead.getDefiningOp(); 6460473e322SBenjamin Maxwell Operation *extendOp = nullptr; 6470473e322SBenjamin Maxwell if (Value extendSource = getExtensionSource(transposeSourceOp)) { 6480473e322SBenjamin Maxwell maybeRead = extendSource; 6490473e322SBenjamin Maxwell extendOp = transposeSourceOp; 6500473e322SBenjamin Maxwell } 6510473e322SBenjamin Maxwell 6520473e322SBenjamin Maxwell auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>(); 6530473e322SBenjamin Maxwell if (!illegalRead) 6540473e322SBenjamin Maxwell return rewriter.notifyMatchFailure( 6550473e322SBenjamin Maxwell transposeOp, 6560473e322SBenjamin Maxwell "expected source to be (possibly extended) transfer_read"); 6570473e322SBenjamin Maxwell 6580473e322SBenjamin Maxwell if (!illegalRead.getPermutationMap().isIdentity()) 6590473e322SBenjamin Maxwell return rewriter.notifyMatchFailure( 6600473e322SBenjamin Maxwell illegalRead, "expected read to have identity permutation map"); 6610473e322SBenjamin Maxwell 6620473e322SBenjamin Maxwell auto loc = transposeOp.getLoc(); 6630473e322SBenjamin Maxwell auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 6640473e322SBenjamin Maxwell auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1); 6650473e322SBenjamin Maxwell 6660473e322SBenjamin Maxwell // Create a subview that matches the size of the illegal read vector type. 
6670473e322SBenjamin Maxwell auto readType = illegalRead.getVectorType(); 6680473e322SBenjamin Maxwell auto readSizes = llvm::map_to_vector( 6690473e322SBenjamin Maxwell llvm::zip_equal(readType.getShape(), readType.getScalableDims()), 6700473e322SBenjamin Maxwell [&](auto dim) -> Value { 6710473e322SBenjamin Maxwell auto [size, isScalable] = dim; 6720473e322SBenjamin Maxwell auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size); 6730473e322SBenjamin Maxwell if (!isScalable) 6740473e322SBenjamin Maxwell return dimSize; 6750473e322SBenjamin Maxwell auto vscale = rewriter.create<vector::VectorScaleOp>(loc); 6760473e322SBenjamin Maxwell return rewriter.create<arith::MulIOp>(loc, vscale, dimSize); 6770473e322SBenjamin Maxwell }); 6780473e322SBenjamin Maxwell SmallVector<Value> strides(readType.getRank(), Value(one)); 6790473e322SBenjamin Maxwell auto readSubview = rewriter.create<memref::SubViewOp>( 6800473e322SBenjamin Maxwell loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes, 6810473e322SBenjamin Maxwell strides); 6820473e322SBenjamin Maxwell 6830473e322SBenjamin Maxwell // Apply the transpose to all values/attributes of the transfer_read: 6840473e322SBenjamin Maxwell // - The mask 6850473e322SBenjamin Maxwell Value mask = illegalRead.getMask(); 6860473e322SBenjamin Maxwell if (mask) { 6870473e322SBenjamin Maxwell // Note: The transpose for the mask should fold into the 6880473e322SBenjamin Maxwell // vector.create_mask/constant_mask op, which will then become legal. 
6890473e322SBenjamin Maxwell mask = rewriter.create<vector::TransposeOp>(loc, mask, 6900473e322SBenjamin Maxwell transposeOp.getPermutation()); 6910473e322SBenjamin Maxwell } 6920473e322SBenjamin Maxwell // - The source memref 6930473e322SBenjamin Maxwell mlir::AffineMap transposeMap = AffineMap::getPermutationMap( 6940473e322SBenjamin Maxwell transposeOp.getPermutation(), getContext()); 6950473e322SBenjamin Maxwell auto transposedSubview = rewriter.create<memref::TransposeOp>( 6960473e322SBenjamin Maxwell loc, readSubview, AffineMapAttr::get(transposeMap)); 6970473e322SBenjamin Maxwell ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr(); 6980473e322SBenjamin Maxwell // - The `in_bounds` attribute 6990473e322SBenjamin Maxwell if (inBoundsAttr) { 7000473e322SBenjamin Maxwell SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(), 7010473e322SBenjamin Maxwell inBoundsAttr.end()); 7020473e322SBenjamin Maxwell applyPermutationToVector(inBoundsValues, transposeOp.getPermutation()); 7030473e322SBenjamin Maxwell inBoundsAttr = rewriter.getArrayAttr(inBoundsValues); 7040473e322SBenjamin Maxwell } 7050473e322SBenjamin Maxwell 7060473e322SBenjamin Maxwell VectorType legalReadType = resultType.clone(readType.getElementType()); 7070473e322SBenjamin Maxwell // Note: The indices are all zero as the subview is already offset. 7080473e322SBenjamin Maxwell SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero); 7090473e322SBenjamin Maxwell auto legalRead = rewriter.create<vector::TransferReadOp>( 7100473e322SBenjamin Maxwell loc, legalReadType, transposedSubview, readIndices, 7110473e322SBenjamin Maxwell illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask, 7120473e322SBenjamin Maxwell inBoundsAttr); 7130473e322SBenjamin Maxwell 7140473e322SBenjamin Maxwell // Replace the transpose with the new read, extending the result if 7150473e322SBenjamin Maxwell // necessary. 
7160473e322SBenjamin Maxwell rewriter.replaceOp(transposeOp, [&]() -> Operation * { 7170473e322SBenjamin Maxwell if (extendOp) 7180473e322SBenjamin Maxwell return rewriter.create(loc, extendOp->getName().getIdentifier(), 7190473e322SBenjamin Maxwell Value(legalRead), resultType); 7200473e322SBenjamin Maxwell return legalRead; 7210473e322SBenjamin Maxwell }()); 7220473e322SBenjamin Maxwell 7230473e322SBenjamin Maxwell return success(); 7240473e322SBenjamin Maxwell } 7250473e322SBenjamin Maxwell }; 7260473e322SBenjamin Maxwell 727d1fc59c3SBenjamin Maxwell /// A rewrite to turn unit dim transpose-like vector.shape_casts into 728d1fc59c3SBenjamin Maxwell /// vector.transposes. The shape_cast has to be from an illegal vector type to a 729d1fc59c3SBenjamin Maxwell /// legal one (as defined by isLegalVectorType). 730d1fc59c3SBenjamin Maxwell /// 731d1fc59c3SBenjamin Maxwell /// The reasoning for this is if we've got to this pass and we still have 732d1fc59c3SBenjamin Maxwell /// shape_casts of illegal types, then they likely will not cancel out. Turning 733d1fc59c3SBenjamin Maxwell /// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to 734d1fc59c3SBenjamin Maxwell /// eliminate them. 
735d1fc59c3SBenjamin Maxwell /// 736d1fc59c3SBenjamin Maxwell /// Example: 737d1fc59c3SBenjamin Maxwell /// 738d1fc59c3SBenjamin Maxwell /// BEFORE: 739d1fc59c3SBenjamin Maxwell /// ```mlir 740d1fc59c3SBenjamin Maxwell /// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32> 741d1fc59c3SBenjamin Maxwell /// ``` 742d1fc59c3SBenjamin Maxwell /// 743d1fc59c3SBenjamin Maxwell /// AFTER: 744d1fc59c3SBenjamin Maxwell /// ```mlir 745d1fc59c3SBenjamin Maxwell /// %0 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32> 746d1fc59c3SBenjamin Maxwell /// ``` 747d1fc59c3SBenjamin Maxwell struct ConvertIllegalShapeCastOpsToTransposes 748d1fc59c3SBenjamin Maxwell : public OpRewritePattern<vector::ShapeCastOp> { 749d1fc59c3SBenjamin Maxwell using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern; 750d1fc59c3SBenjamin Maxwell 751d1fc59c3SBenjamin Maxwell LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, 752d1fc59c3SBenjamin Maxwell PatternRewriter &rewriter) const override { 753d1fc59c3SBenjamin Maxwell auto sourceType = shapeCastOp.getSourceVectorType(); 754d1fc59c3SBenjamin Maxwell auto resultType = shapeCastOp.getResultVectorType(); 755d1fc59c3SBenjamin Maxwell if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType)) 756d1fc59c3SBenjamin Maxwell return rewriter.notifyMatchFailure(shapeCastOp, 757d1fc59c3SBenjamin Maxwell kMatchFailureNotIllegalToLegal); 758d1fc59c3SBenjamin Maxwell 759d1fc59c3SBenjamin Maxwell // Note: If we know that `sourceType` is an illegal vector type (and 2D) 760d1fc59c3SBenjamin Maxwell // then dim 0 is scalable and dim 1 is fixed. 
761d1fc59c3SBenjamin Maxwell if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1) 762d1fc59c3SBenjamin Maxwell return rewriter.notifyMatchFailure( 763d1fc59c3SBenjamin Maxwell shapeCastOp, "expected source to be a 2D scalable vector with a " 764d1fc59c3SBenjamin Maxwell "trailing unit dim"); 765d1fc59c3SBenjamin Maxwell 766d1fc59c3SBenjamin Maxwell auto loc = shapeCastOp.getLoc(); 767d1fc59c3SBenjamin Maxwell auto transpose = rewriter.create<vector::TransposeOp>( 768d1fc59c3SBenjamin Maxwell loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0}); 769d1fc59c3SBenjamin Maxwell 770d1fc59c3SBenjamin Maxwell if (resultType.getRank() == 1) 771d1fc59c3SBenjamin Maxwell rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType, 772d1fc59c3SBenjamin Maxwell transpose); 773d1fc59c3SBenjamin Maxwell else 774d1fc59c3SBenjamin Maxwell rewriter.replaceOp(shapeCastOp, transpose); 775d1fc59c3SBenjamin Maxwell 776d1fc59c3SBenjamin Maxwell return success(); 777d1fc59c3SBenjamin Maxwell } 778d1fc59c3SBenjamin Maxwell }; 779d1fc59c3SBenjamin Maxwell 780c194bc77SBenjamin Maxwell /// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use 781c194bc77SBenjamin Maxwell /// the ZA state. This workaround rewrite to support these transposes when ZA is 782c194bc77SBenjamin Maxwell /// available. 
783c194bc77SBenjamin Maxwell /// 784c194bc77SBenjamin Maxwell /// Example: 785c194bc77SBenjamin Maxwell /// 786c194bc77SBenjamin Maxwell /// BEFORE: 787c194bc77SBenjamin Maxwell /// ```mlir 788c194bc77SBenjamin Maxwell /// %transpose = vector.transpose %vec, [1, 0] 789c194bc77SBenjamin Maxwell /// : vector<2x[4]xf32> to vector<[4]x2xf32> 790c194bc77SBenjamin Maxwell /// vector.transfer_write %transpose, %dest[%y, %x] 791c194bc77SBenjamin Maxwell /// : vector<[4]x2xf32>, memref<?x?xf32> 792c194bc77SBenjamin Maxwell /// ``` 793c194bc77SBenjamin Maxwell /// 794c194bc77SBenjamin Maxwell /// AFTER: 795c194bc77SBenjamin Maxwell /// ```mlir 796c194bc77SBenjamin Maxwell /// %0 = arm_sme.get_tile : vector<[4]x[4]xf32> 797c194bc77SBenjamin Maxwell /// %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32> 798c194bc77SBenjamin Maxwell /// %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32> 799c194bc77SBenjamin Maxwell /// %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32> 800c194bc77SBenjamin Maxwell /// %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32> 801c194bc77SBenjamin Maxwell /// %c4_vscale = arith.muli %vscale, %c4 : index 802c194bc77SBenjamin Maxwell /// %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1> 803c194bc77SBenjamin Maxwell /// vector.transfer_write %4, %dest[%y, %x], %mask 804c194bc77SBenjamin Maxwell /// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} 805c194bc77SBenjamin Maxwell /// : vector<[4]x[4]xf32>, memref<?x?xf32> 806c194bc77SBenjamin Maxwell /// ``` 807c194bc77SBenjamin Maxwell /// 808c194bc77SBenjamin Maxwell /// Values larger than a single tile are supported via decomposition. 
809c194bc77SBenjamin Maxwell struct LowerIllegalTransposeStoreViaZA 810c194bc77SBenjamin Maxwell : public OpRewritePattern<vector::TransferWriteOp> { 811c194bc77SBenjamin Maxwell using OpRewritePattern::OpRewritePattern; 812c194bc77SBenjamin Maxwell 813c194bc77SBenjamin Maxwell LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, 814c194bc77SBenjamin Maxwell PatternRewriter &rewriter) const override { 815c194bc77SBenjamin Maxwell if (!isSupportedMaskOp(writeOp.getMask())) 816c194bc77SBenjamin Maxwell return rewriter.notifyMatchFailure(writeOp, 817c194bc77SBenjamin Maxwell kMatchFailureUnsupportedMaskOp); 818c194bc77SBenjamin Maxwell 819c194bc77SBenjamin Maxwell auto permutationMap = writeOp.getPermutationMap(); 820c194bc77SBenjamin Maxwell if (!permutationMap.isIdentity()) 821c194bc77SBenjamin Maxwell return rewriter.notifyMatchFailure(writeOp, 822c194bc77SBenjamin Maxwell kMatchFailureNonPermutationMap); 823c194bc77SBenjamin Maxwell 824c194bc77SBenjamin Maxwell auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>(); 825c194bc77SBenjamin Maxwell if (!transposeOp) 826c194bc77SBenjamin Maxwell return failure(); 827c194bc77SBenjamin Maxwell 828c194bc77SBenjamin Maxwell auto sourceType = transposeOp.getSourceVectorType(); 829c194bc77SBenjamin Maxwell auto resultType = transposeOp.getResultVectorType(); 830c194bc77SBenjamin Maxwell 831c194bc77SBenjamin Maxwell if (resultType.getRank() != 2) 832c194bc77SBenjamin Maxwell return rewriter.notifyMatchFailure(transposeOp, "TransposeOp not rank 2"); 833c194bc77SBenjamin Maxwell 834c194bc77SBenjamin Maxwell if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType)) 835c194bc77SBenjamin Maxwell return rewriter.notifyMatchFailure( 836c194bc77SBenjamin Maxwell transposeOp, "not illegal/unsupported SVE transpose"); 837c194bc77SBenjamin Maxwell 838c194bc77SBenjamin Maxwell auto smeTileType = getSMETileTypeForElement(resultType.getElementType()); 839c194bc77SBenjamin Maxwell VectorType 
smeSliceType = VectorType::Builder(smeTileType).dropDim(0); 840c194bc77SBenjamin Maxwell 841c194bc77SBenjamin Maxwell if (sourceType.getDimSize(0) <= 1 || 842c194bc77SBenjamin Maxwell sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0) 843c194bc77SBenjamin Maxwell return rewriter.notifyMatchFailure(writeOp, "unsupported source shape"); 844c194bc77SBenjamin Maxwell 845c194bc77SBenjamin Maxwell auto loc = writeOp.getLoc(); 846c194bc77SBenjamin Maxwell auto createVscaleMultiple = 847c194bc77SBenjamin Maxwell vector::makeVscaleConstantBuilder(rewriter, loc); 848c194bc77SBenjamin Maxwell 849c194bc77SBenjamin Maxwell auto transposeMap = AffineMapAttr::get( 850c194bc77SBenjamin Maxwell AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext())); 851c194bc77SBenjamin Maxwell 852c194bc77SBenjamin Maxwell // Note: We need to use `get_tile` as there's no vector-level `undef`. 853c194bc77SBenjamin Maxwell Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType); 854c194bc77SBenjamin Maxwell Value destTensorOrMemref = writeOp.getSource(); 855c194bc77SBenjamin Maxwell auto numSlicesPerTile = 856c194bc77SBenjamin Maxwell std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0)); 857c194bc77SBenjamin Maxwell auto numSlices = 858c194bc77SBenjamin Maxwell rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile); 859c194bc77SBenjamin Maxwell for (auto [index, smeTile] : llvm::enumerate( 860c194bc77SBenjamin Maxwell decomposeToSMETiles(rewriter, sourceType, smeTileType))) { 861c194bc77SBenjamin Maxwell // 1. _Deliberately_ drop a scalable dimension and insert a fixed number 862c194bc77SBenjamin Maxwell // of slices from the source type into the SME tile. Without checking 863c194bc77SBenjamin Maxwell // vscale (and emitting multiple implementations) we can't make use of the 864c194bc77SBenjamin Maxwell // rows of the tile after 1*vscale rows. 
865c194bc77SBenjamin Maxwell Value tile = undefTile; 866c194bc77SBenjamin Maxwell for (int d = 0; d < numSlicesPerTile; ++d) { 867c194bc77SBenjamin Maxwell Value vector = rewriter.create<vector::ExtractOp>( 868c194bc77SBenjamin Maxwell loc, transposeOp.getVector(), 869c194bc77SBenjamin Maxwell rewriter.getIndexAttr(d + smeTile.row)); 870c194bc77SBenjamin Maxwell if (vector.getType() != smeSliceType) { 871c194bc77SBenjamin Maxwell vector = rewriter.create<vector::ScalableExtractOp>( 872c194bc77SBenjamin Maxwell loc, smeSliceType, vector, smeTile.col); 873c194bc77SBenjamin Maxwell } 874c194bc77SBenjamin Maxwell tile = rewriter.create<vector::InsertOp>(loc, vector, tile, d); 875c194bc77SBenjamin Maxwell } 876c194bc77SBenjamin Maxwell 877c194bc77SBenjamin Maxwell // 2. Transpose the tile position. 878c194bc77SBenjamin Maxwell auto transposedRow = createVscaleMultiple(smeTile.col); 879c194bc77SBenjamin Maxwell auto transposedCol = 880c194bc77SBenjamin Maxwell rewriter.create<arith::ConstantIndexOp>(loc, smeTile.row); 881c194bc77SBenjamin Maxwell 882c194bc77SBenjamin Maxwell // 3. Compute mask for tile store. 
883c194bc77SBenjamin Maxwell Value maskRows; 884c194bc77SBenjamin Maxwell Value maskCols; 885c194bc77SBenjamin Maxwell if (auto mask = writeOp.getMask()) { 886c194bc77SBenjamin Maxwell auto createMask = mask.getDefiningOp<vector::CreateMaskOp>(); 887c194bc77SBenjamin Maxwell maskRows = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(0), 888c194bc77SBenjamin Maxwell transposedRow); 889c194bc77SBenjamin Maxwell maskCols = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(1), 890c194bc77SBenjamin Maxwell transposedCol); 891c194bc77SBenjamin Maxwell maskCols = rewriter.create<index::MinSOp>(loc, maskCols, numSlices); 892c194bc77SBenjamin Maxwell } else { 893c194bc77SBenjamin Maxwell maskRows = createVscaleMultiple(smeTileType.getDimSize(0)); 894c194bc77SBenjamin Maxwell maskCols = numSlices; 895c194bc77SBenjamin Maxwell } 896c194bc77SBenjamin Maxwell auto subMask = rewriter.create<vector::CreateMaskOp>( 897c194bc77SBenjamin Maxwell loc, smeTileType.clone(rewriter.getI1Type()), 898c194bc77SBenjamin Maxwell ValueRange{maskRows, maskCols}); 899c194bc77SBenjamin Maxwell 900c194bc77SBenjamin Maxwell // 4. Emit a transposed tile write. 
901c194bc77SBenjamin Maxwell auto writeIndices = writeOp.getIndices(); 902c194bc77SBenjamin Maxwell Value destRow = 903c194bc77SBenjamin Maxwell rewriter.create<arith::AddIOp>(loc, transposedRow, writeIndices[0]); 904c194bc77SBenjamin Maxwell Value destCol = 905c194bc77SBenjamin Maxwell rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]); 906c194bc77SBenjamin Maxwell auto smeWrite = rewriter.create<vector::TransferWriteOp>( 907c194bc77SBenjamin Maxwell loc, tile, destTensorOrMemref, ValueRange{destRow, destCol}, 908c194bc77SBenjamin Maxwell transposeMap, subMask, writeOp.getInBounds()); 909c194bc77SBenjamin Maxwell 910c194bc77SBenjamin Maxwell if (writeOp.hasPureTensorSemantics()) 911c194bc77SBenjamin Maxwell destTensorOrMemref = smeWrite.getResult(); 912c194bc77SBenjamin Maxwell } 913c194bc77SBenjamin Maxwell 914c194bc77SBenjamin Maxwell if (writeOp.hasPureTensorSemantics()) 915c194bc77SBenjamin Maxwell rewriter.replaceOp(writeOp, destTensorOrMemref); 916c194bc77SBenjamin Maxwell else 917c194bc77SBenjamin Maxwell rewriter.eraseOp(writeOp); 918c194bc77SBenjamin Maxwell 919c194bc77SBenjamin Maxwell return success(); 920c194bc77SBenjamin Maxwell } 921c194bc77SBenjamin Maxwell }; 922c194bc77SBenjamin Maxwell 923042800a4SBenjamin Maxwell struct VectorLegalizationPass 924042800a4SBenjamin Maxwell : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> { 925042800a4SBenjamin Maxwell void runOnOperation() override { 926042800a4SBenjamin Maxwell auto *context = &getContext(); 9278c4bc1e7SMatthias Springer TypeConverter converter; 928042800a4SBenjamin Maxwell RewritePatternSet patterns(context); 929042800a4SBenjamin Maxwell converter.addConversion([](Type type) { return type; }); 930042800a4SBenjamin Maxwell converter.addConversion( 931042800a4SBenjamin Maxwell [](VectorType vectorType, 932042800a4SBenjamin Maxwell SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> { 933042800a4SBenjamin Maxwell if 
(!isMultipleOfSMETileVectorType(vectorType)) 934042800a4SBenjamin Maxwell return std::nullopt; 935042800a4SBenjamin Maxwell auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType); 936042800a4SBenjamin Maxwell auto smeTileType = 937042800a4SBenjamin Maxwell getSMETileTypeForElement(vectorType.getElementType()); 938042800a4SBenjamin Maxwell types = SmallVector<Type>(smeTileCount, smeTileType); 939042800a4SBenjamin Maxwell return success(); 940042800a4SBenjamin Maxwell }); 941042800a4SBenjamin Maxwell 94231613de9SMatthias Springer // Apply preprocessing patterns. 94331613de9SMatthias Springer RewritePatternSet rewritePatterns(context); 94431613de9SMatthias Springer rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks, 945d1fc59c3SBenjamin Maxwell LiftIllegalVectorTransposeToMemory, 946c194bc77SBenjamin Maxwell ConvertIllegalShapeCastOpsToTransposes, 947fc4485bfSBenjamin Maxwell LowerIllegalTransposeStoreViaZA>(context); 94831613de9SMatthias Springer if (failed( 94931613de9SMatthias Springer applyPatternsGreedily(getOperation(), std::move(rewritePatterns)))) 95031613de9SMatthias Springer return signalPassFailure(); 95131613de9SMatthias Springer 9525ed5d723SBenjamin Maxwell // Note: These two patterns are added with a high benefit to ensure: 9535ed5d723SBenjamin Maxwell // - Masked outer products are handled before unmasked ones 9545ed5d723SBenjamin Maxwell // - Multi-tile writes are lowered as a store loop (if possible) 9555ed5d723SBenjamin Maxwell patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition, 9565ed5d723SBenjamin Maxwell LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context, 9575ed5d723SBenjamin Maxwell /*benefit=*/1024); 958dadcaf82SBenjamin Maxwell patterns.add<LegalizeArithConstantOpsByDecomposition, 959dadcaf82SBenjamin Maxwell LegalizeVectorOuterProductOpsByDecomposition, 960042800a4SBenjamin Maxwell LegalizeTransferReadOpsByDecomposition, 961042800a4SBenjamin Maxwell 
LegalizeTransferWriteOpsByDecomposition>(converter, context); 96231613de9SMatthias Springer populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, 96331613de9SMatthias Springer converter); 96431613de9SMatthias Springer populateCallOpTypeConversionPattern(patterns, converter); 96531613de9SMatthias Springer populateReturnOpTypeConversionPattern(patterns, converter); 96631613de9SMatthias Springer scf::populateSCFStructuralTypeConversions(converter, patterns); 967042800a4SBenjamin Maxwell 96831613de9SMatthias Springer ConversionTarget target(getContext()); 96931613de9SMatthias Springer target.markUnknownOpDynamicallyLegal( 97031613de9SMatthias Springer [&](Operation *op) { return converter.isLegal(op); }); 97131613de9SMatthias Springer target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { 97231613de9SMatthias Springer return converter.isSignatureLegal(op.getFunctionType()); 97331613de9SMatthias Springer }); 97431613de9SMatthias Springer if (failed(applyPartialConversion(getOperation(), target, 975042800a4SBenjamin Maxwell std::move(patterns)))) 976042800a4SBenjamin Maxwell return signalPassFailure(); 977042800a4SBenjamin Maxwell } 978042800a4SBenjamin Maxwell }; 979042800a4SBenjamin Maxwell 980042800a4SBenjamin Maxwell } // namespace 981042800a4SBenjamin Maxwell 982042800a4SBenjamin Maxwell std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() { 983042800a4SBenjamin Maxwell return std::make_unique<VectorLegalizationPass>(); 984042800a4SBenjamin Maxwell } 985