//===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass legalizes vector operations so they can be lowered to ArmSME.
//
// Note: In the context of this pass 'tile' always refers to an SME tile.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "arm-sme-vector-legalization"

namespace mlir::arm_sme {
#define GEN_PASS_DEF_VECTORLEGALIZATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace mlir::arm_sme

using namespace mlir;
using namespace mlir::arm_sme;

namespace {

//===----------------------------------------------------------------------===//
// Decomposition of vector operations larger than an SME tile
//===----------------------------------------------------------------------===//

// Common match failure reasons.
static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
    "op vector size is not multiple of SME tiles");
static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
    "op mask is unsupported for legalization/decomposition");
static constexpr StringLiteral
    kMatchFailureNonPermutationMap("op affine map is not a permutation");
static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
    "expected transpose from illegal type to legal type");

/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
/// larger vector type. The (`row`, `col`) are the position of the tile in the
/// original vector type. For example for an [8]x[8] tile with four [4]x[4]
/// sub-tiles, we would have:
///
///           8 x vscale
/// ┌─────────────┬─────────────┐
/// │(0,0)        │(0,4)        │
/// │             │             │
/// ├─────────────┼─────────────┤ 8 x vscale
/// │(4,0)        │(4,4)        │
/// │             │             │
/// └─────────────┴─────────────┘
struct SMESubTile {
  // Note: The units of (row, col) are vscale (as SME tiles are scalable).
  int row{0};
  int col{0};
  // The SME tile type.
  VectorType type;
};

/// Adds a constant elementwise scalable offset to `indices` (which are of
/// equal length). For example, in the 2D case this would return:
// { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
                                                Location loc,
                                                ValueRange indices,
                                                ArrayRef<int> scalableOffsets) {
  auto vscale = builder.create<vector::VScaleOp>(loc);
  return llvm::map_to_vector(
      llvm::zip_equal(indices, scalableOffsets), [&](auto pair) -> Value {
        auto [index, base] = pair;
        auto offset = builder.create<arith::MulIOp>(
            loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale);
        return builder.create<arith::AddIOp>(loc, index, offset);
      });
}
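// For illustration (value names hypothetical), a call such as
// addConstantScalableOffset(builder, loc, {%i, %j}, {4, -4}) would emit
// roughly the following IR:
//
//   %vscale = vector.vscale
//   %c4 = arith.constant 4 : index
//   %offset0 = arith.muli %c4, %vscale : index
//   %0 = arith.addi %i, %offset0 : index
//   %c-4 = arith.constant -4 : index
//   %offset1 = arith.muli %c-4, %vscale : index
//   %1 = arith.addi %j, %offset1 : index
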
/// Adjusts `indices` (e.g. from a load/store) for a larger vector type to
/// indices for one of the SME sub-tiles it will decompose into.
///
/// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
/// indices for each tile would need to be adjusted as follows:
///
/// initial indices = [a,b], initial size = 8x8, target size = 4x4
/// ┌─────────────┬─────────────┐
/// │[a,b]        │[a,b+4]      │
/// │             │             │
/// ├─────────────┼─────────────┤
/// │[a+4,b]      │[a+4,b+4]    │
/// │             │             │
/// └─────────────┴─────────────┘
SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc,
                                           ValueRange indices,
                                           SMESubTile smeTile) {
  return addConstantScalableOffset(builder, loc, indices,
                                   {smeTile.row, smeTile.col});
}

/// Returns true if `mask` is generated by an operation that can be decomposed
/// for SME. Currently, that is just no mask, or vector.create_mask.
/// TODO: Add support for vector.constant_mask once required for SME.
bool isSupportedMaskOp(Value mask) {
  return !mask || mask.getDefiningOp<vector::CreateMaskOp>();
}

/// Extracts a mask for an SME sub-tile from the mask of a larger vector type.
Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
                     SMESubTile smeTile) {
  assert(isSupportedMaskOp(mask));
  if (!mask)
    return Value{};
  auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
  // The operands of `vector.create_mask` (from a 2D perspective) are the
  // coordinates where the mask ends. So we subtract where this tile starts,
  // from the mask operands to get the parameters for this sub-tile.
  auto smeTileMaskDims = addConstantScalableOffset(
      builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col});
  auto smeTileCreateMask = builder.create<vector::CreateMaskOp>(
      loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims);
  return smeTileCreateMask.getResult();
}

/// Constructs an iterator that returns each SME tile (with coordinates)
/// contained within a VectorType. For example, if decomposing an [8]x[8] into
/// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0),
/// (4, 4).
auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
                         VectorType smeTileType,
                         bool transposeIndices = false) {
  return llvm::map_range(
      StaticTileOffsetRange(
          type.getShape(),
          {std::min(type.getDimSize(0), smeTileType.getDimSize(0)),
           std::min(type.getDimSize(1), smeTileType.getDimSize(1))}),
      [=](auto indices) {
        int row = int(indices[0]);
        int col = int(indices[1]);
        if (transposeIndices)
          std::swap(row, col);
        return SMESubTile{row, col, smeTileType};
      });
}

/// Returns the number of SME tiles that fit into the (2D-scalable) vector type
/// `type`.
int getNumberOfSMETilesForVectorType(VectorType type) {
  assert(isMultipleOfSMETileVectorType(type) &&
         "`type` not multiple of SME tiles");
  int64_t vectorRows = type.getDimSize(0);
  int64_t vectorCols = type.getDimSize(1);
  auto elementType = type.getElementType();
  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
  return (vectorRows * vectorCols) / (minNumElts * minNumElts);
}
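// For illustration: for vector<[8]x[8]xf32> an SME tile holds [4]x[4] f32
// elements (getSMETileSliceMinNumElts returns 4 for f32), so the count is
// (8 * 8) / (4 * 4) = 4 tiles. For vector<[16]x[16]xi8> (16 i8 elements per
// tile slice) it is (16 * 16) / (16 * 16) = 1 tile.
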
/// Legalize `arith.constant dense<splat>` operations to fit within SME tiles
/// by decomposing them into tile-sized operations.
struct LegalizeArithConstantOpsByDecomposition
    : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = dyn_cast<VectorType>(constantOp.getType());
    auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
    if (!vectorType || !denseAttr || !denseAttr.isSplat())
      return failure();

    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(constantOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
    auto tileSplat = rewriter.create<arith::ConstantOp>(
        constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
    SmallVector<Value> repl(tileCount, tileSplat);
    rewriter.replaceOpWithMultiple(constantOp, {repl});

    return success();
  }
};
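// For illustration, a splat constant such as:
//
//   %cst = arith.constant dense<1.0> : vector<[8]x[8]xf32>
//
// is replaced (1->N) by four uses of a single tile-sized splat:
//
//   %tileSplat = arith.constant dense<1.0> : vector<[4]x[4]xf32>
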
/// Legalize `vector.outerproduct` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeVectorOuterProductOpsByDecomposition
    : public OpConversionPattern<vector::OuterProductOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::OuterProductOp outerProductOp,
                  OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = outerProductOp.getResultVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    Value mask;
    Operation *rootOp = outerProductOp;
    auto loc = outerProductOp.getLoc();
    if (outerProductOp.isMasked()) {
      auto maskOp = outerProductOp.getMaskingOp();
      mask = maskOp.getMask();
      rootOp = maskOp;
      rewriter.setInsertionPoint(rootOp);
    }

    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         kMatchFailureUnsupportedMaskOp);

    ValueRange accSMETiles = adaptor.getAcc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0);

    SmallVector<Value> resultSMETiles;
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto lhs = rewriter.create<vector::ScalableExtractOp>(
          loc, sliceType, outerProductOp.getLhs(), smeTile.row);
      auto rhs = rewriter.create<vector::ScalableExtractOp>(
          loc, sliceType, outerProductOp.getRhs(), smeTile.col);
      auto smeOuterProduct = rewriter.create<vector::OuterProductOp>(
          loc, smeTileType, lhs, rhs,
          !accSMETiles.empty() ? accSMETiles[index] : Value{},
          outerProductOp.getKind());

      auto maskedOuterProduct =
          vector::maskOperation(rewriter, smeOuterProduct, smeMask);
      resultSMETiles.push_back(maskedOuterProduct->getResult(0));
    }

    rewriter.replaceOpWithMultiple(rootOp, {resultSMETiles});
    return success();
  }
};

// Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to
// get the help of the type conversion), but doing so results in the type
// conversion adding target materializations in the `vector.mask` region
// (invalid). This pattern matches on `vector.mask` then calls into the
// `vector.outerproduct` pattern to work around this issue.
struct LegalizeMaskedVectorOuterProductOpsByDecomposition
    : public OpConversionPattern<vector::MaskOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
            maskOp.getMaskableOp())) {
      LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
                                                           getContext());
      return static_cast<RewritePattern &>(pattern).matchAndRewrite(
          outerProductOp, rewriter);
    }
    return failure();
  }
};

/// Legalize `vector.transfer_read` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeTransferReadOpsByDecomposition
    : public OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = readOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto mask = readOp.getMask();
    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = readOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureNonPermutationMap);

    // Note: For 2D vector types the only non-identity permutation is a simple
    // transpose [1, 0].
    bool transposed = !permutationMap.isIdentity();

    auto loc = readOp.getLoc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());

    SmallVector<Value> resultSMETiles;
    for (SMESubTile smeTile :
         decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto smeRead = rewriter.create<vector::TransferReadOp>(
          loc, smeTileType, readOp.getSource(),
          getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
          readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
          readOp.getInBoundsAttr());
      resultSMETiles.push_back(smeRead);
    }

    rewriter.replaceOpWithMultiple(readOp, {resultSMETiles});
    return success();
  }
};
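// For illustration (value names hypothetical), a two-tile read such as:
//
//   %vec = vector.transfer_read %src[%y, %x] ...
//     : memref<?x?xf32>, vector<[8]x[4]xf32>
//
// decomposes into one tile-sized read at [%y, %x] and another at
// [%y + 4 * vscale, %x], each of type vector<[4]x[4]xf32>, whose results
// (1->N) replace the uses of %vec.
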
/// Legalize `vector.transfer_write` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeTransferWriteOpsByDecomposition
    : public OpConversionPattern<vector::TransferWriteOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = writeOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto mask = writeOp.getMask();
    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    // Note: For 2D vector types the only non-identity permutation is a simple
    // transpose [1, 0].
    bool transposed = !permutationMap.isIdentity();

    auto loc = writeOp.getLoc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto inputSMETiles = adaptor.getVector();

    Value destTensorOrMemref = writeOp.getSource();
    for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
             rewriter, vectorType, smeTileType, transposed))) {
      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto smeWrite = rewriter.create<vector::TransferWriteOp>(
          loc, inputSMETiles[index], destTensorOrMemref,
          getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile),
          writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr());
      if (writeOp.hasPureTensorSemantics())
        destTensorOrMemref = smeWrite.getResult();
    }

    if (writeOp.hasPureTensorSemantics())
      rewriter.replaceOp(writeOp, destTensorOrMemref);
    else
      rewriter.eraseOp(writeOp);

    return success();
  }
};
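// For illustration (value names hypothetical), with tensor semantics each
// decomposed tile write is chained through the result of the previous one:
//
//   %dest0 = vector.transfer_write %tile0, %dest[%y, %x] ...
//     : vector<[4]x[4]xf32>, tensor<?x?xf32>
//   %dest1 = vector.transfer_write %tile1, %dest0[%y, %x4] ...
//     : vector<[4]x[4]xf32>, tensor<?x?xf32>
//
// (where %x4 is %x + 4 * vscale), and the final tensor replaces the original
// write. With memref semantics the writes are independent and the original op
// is simply erased.
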
/// Legalize a multi-tile transfer_write as a single store loop. This is done
/// as part of type decomposition as at this level we know each tile write is
/// disjoint, but that information is lost after decomposition (without
/// analysis to reconstruct it).
///
/// Example (pseudo-MLIR):
///
/// ```
/// vector.transfer_write %vector, %dest[%y, %x], %mask
///   : vector<[16]x[8]xi16>, memref<?x?xi16>
/// ```
/// Is rewritten to:
/// ```
/// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
///   %upper_slice_mask = vector.extract %mask[%slice_idx]      ─┐
///     : vector<[8]xi1> from vector<[16]x[8]xi1>                |
///   %upper_slice = vector.extract %upper_tile[%slice_idx]     |- Store upper
///     : vector<[8]xi16> from vector<[8]x[8]xi16>               |  tile
///   vector.transfer_write %upper_slice,                        |
///     %dest[%slice_idx + %y, %x], %upper_slice_mask            |
///     : vector<[8]xi16>, memref<?x?xi16>                      ─┘
///   %lower_slice_idx = %slice_idx + %c8_vscale                ─┐
///   %lower_slice_mask = vector.extract %mask[%lower_slice_idx] |
///     : vector<[8]xi1> from vector<[16]x[8]xi1>                |
///   %lower_slice = vector.extract %lower_tile[%slice_idx]     |- Store lower
///     : vector<[8]xi16> from vector<[8]x[8]xi16>               |  tile
///   vector.transfer_write %lower_slice,                        |
///     %dest[%lower_slice_idx + %y, %x], %lower_slice_mask      |
///     : vector<[8]xi16>, memref<?x?xi16>                      ─┘
/// }
/// ```
struct LegalizeMultiTileTransferWriteAsStoreLoop
    : public OpConversionPattern<vector::TransferWriteOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (writeOp.hasPureTensorSemantics())
      return rewriter.notifyMatchFailure(
          writeOp, "TODO: tensor semantics are unsupported");

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    bool transposed = !permutationMap.isIdentity();
    if (transposed)
      return rewriter.notifyMatchFailure(writeOp,
                                         "TODO: transpose unsupported");

    auto vectorType = writeOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    // Note: We also disallow masks where any dimension is > 16 because that
    // prevents the masking from being lowered to use arm_sve.psel.
    auto mask = writeOp.getMask();
    if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
                                              vectorType.getDimSize(1) > 16)))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    // Get SME tile and slice types.
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto minTileSlices = smeTileType.getDimSize(0);
    VectorType sliceMaskType =
        VectorType::get(minTileSlices, rewriter.getI1Type(), true);

    // Create loop over all tile slices.
    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto upperBound = createVscaleMultiple(minTileSlices);
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto storeLoop =
        rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
    rewriter.setInsertionPointToStart(storeLoop.getBody());

    // For each sub-tile of the multi-tile `vectorType`.
    auto inputSMETiles = adaptor.getVector();
    auto tileSliceIndex = storeLoop.getInductionVar();
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
      // The coordinates of the tile within `vectorType`.
      auto tileRow = createVscaleMultiple(smeTile.row);
      auto tileCol = createVscaleMultiple(smeTile.col);

      // The current slice of `vectorType` we are processing.
      auto sliceIndex =
          rewriter.create<arith::AddIOp>(loc, tileRow, tileSliceIndex);

      // Where in the destination memref the current slice will be stored.
      auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex,
                                                     writeOp.getIndices()[0]);
      auto storeCol = rewriter.create<arith::AddIOp>(loc, tileCol,
                                                     writeOp.getIndices()[1]);

      // Extract the mask for the current slice.
      Value sliceMask = nullptr;
      if (mask) {
        sliceMask = rewriter.create<vector::ExtractOp>(
            loc, mask, OpFoldResult(sliceIndex));
        if (sliceMaskType != sliceMask.getType())
          sliceMask = rewriter.create<vector::ScalableExtractOp>(
              loc, sliceMaskType, sliceMask, smeTile.col);
      }

      // Extract and store the current slice.
      Value tile = inputSMETiles[index];
      auto slice =
          rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex);
      rewriter.create<vector::TransferWriteOp>(
          loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
          AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
          sliceMask,
          rewriter.getBoolArrayAttr(
              ArrayRef(writeOp.getInBoundsValues()).drop_front()));
    }

    rewriter.eraseOp(writeOp);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ArmSME-specific fixup canonicalizations/folds
//===----------------------------------------------------------------------===//
/// Folds an extract from a 3D `vector.create_mask` (which is a vector of
/// SME-like masks), into a compare and a 2D `vector.create_mask`. This is
/// necessary for the mask to be lowered to ArmSME.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
///  %subMask = vector.extract %mask[2]
///          : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
///  ```
///
///  AFTER:
///  ```mlir
///  %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
///  %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
///  %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
///  ```
struct FoldExtractFromVectorOfSMELikeCreateMasks
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto loc = extractOp.getLoc();
    auto createMaskOp =
        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return rewriter.notifyMatchFailure(
          extractOp, "extract not from vector.create_mask op");

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
    if (!extractedMaskType)
      return rewriter.notifyMatchFailure(
          extractOp, "extracted type is not a vector type");

    auto numScalable = extractedMaskType.getNumScalableDims();
    if (numScalable != 2)
      return rewriter.notifyMatchFailure(
          extractOp, "expected extracted type to be an SME-like mask");

    // TODO: Support multiple extraction indices.
    if (extractOp.getStaticPosition().size() != 1)
      return rewriter.notifyMatchFailure(
          extractOp, "only a single extraction index is supported");

    auto frontMaskDim = createMaskOp.getOperand(0);
    if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
      return rewriter.notifyMatchFailure(
          extractOp,
          "constant vector.create_masks dims should be folded elsewhere");

    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto extractionIndex = getValueOrCreateConstantIndexOp(
        rewriter, loc, extractOp.getMixedPosition()[0]);
    auto extractionInTrueRegion = rewriter.create<arith::CmpIOp>(
        loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, extractionIndex,
        frontMaskDim);
    auto newMaskFrontDim = rewriter.create<arith::SelectOp>(
        loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero);

    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
        extractOp, extractedMaskType,
        ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
    return success();
  }
};

/// A vector type where no fixed dimension comes after a scalable dimension.
bool isLegalVectorType(VectorType vType) {
  bool seenFixedDim = false;
  for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
    seenFixedDim |= !scalableFlag;
    if (seenFixedDim && scalableFlag)
      return false;
  }
  return true;
}
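// For illustration:
//   isLegalVectorType(vector<[4]x[4]xf32>) == true   // scalable dims only
//   isLegalVectorType(vector<4x[4]xf32>)   == true   // fixed dim leads
//   isLegalVectorType(vector<[4]x4xf32>)   == false  // fixed dim trails a
//                                                    // scalable dim
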
/// Lifts an illegal vector.transpose and vector.transfer_read to a
/// memref.subview + memref.transpose, followed by a legal read.
///
/// 'Illegal' here means a leading scalable dimension and a fixed trailing
/// dimension, which has no valid lowering.
///
/// The memref.transpose is metadata-only transpose that produces a strided
/// memref, which eventually becomes a loop reading individual elements.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %illegalRead = vector.transfer_read %memref[%a, %b]
///                  : memref<?x?xf32>, vector<[8]x4xf32>
///  %legalType = vector.transpose %illegalRead, [1, 0]
///                  : vector<[8]x4xf32> to vector<4x[8]xf32>
///  ```
///
///  AFTER:
///  ```mlir
///  %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
///                  : memref<?x?xf32> to memref<?x?xf32>
///  %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
///                  : memref<?x?xf32> to memref<?x?xf32>
///  %legalType = vector.transfer_read %transpose[%c0, %c0]
///                  : memref<?x?xf32>, vector<4x[8]xf32>
///  ```
struct LiftIllegalVectorTransposeToMemory
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  static Value getExtensionSource(Operation *op) {
    if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
      return op->getOperand(0);
    return {};
  }

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto sourceType = transposeOp.getSourceVectorType();
    auto resultType = transposeOp.getResultVectorType();
    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(transposeOp,
                                         kMatchFailureNotIllegalToLegal);

    // Look through extend for transfer_read.
    Value maybeRead = transposeOp.getVector();
    auto *transposeSourceOp = maybeRead.getDefiningOp();
    Operation *extendOp = nullptr;
    if (Value extendSource = getExtensionSource(transposeSourceOp)) {
      maybeRead = extendSource;
      extendOp = transposeSourceOp;
    }

    auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
    if (!illegalRead)
      return rewriter.notifyMatchFailure(
          transposeOp,
          "expected source to be (possibly extended) transfer_read");

    if (!illegalRead.getPermutationMap().isIdentity())
      return rewriter.notifyMatchFailure(
          illegalRead, "expected read to have identity permutation map");

    auto loc = transposeOp.getLoc();
    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);

    // Create a subview that matches the size of the illegal read vector type.
    auto readType = illegalRead.getVectorType();
    auto readSizes = llvm::map_to_vector(
        llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
        [&](auto dim) -> Value {
          auto [size, isScalable] = dim;
          auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
          if (!isScalable)
            return dimSize;
          auto vscale = rewriter.create<vector::VScaleOp>(loc);
          return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
        });
    SmallVector<Value> strides(readType.getRank(), Value(one));
    auto readSubview = rewriter.create<memref::SubViewOp>(
        loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
        strides);

    // Apply the transpose to all values/attributes of the transfer_read:
    // - The mask
    Value mask = illegalRead.getMask();
    if (mask) {
      // Note: The transpose for the mask should fold into the
      // vector.create_mask/constant_mask op, which will then become legal.
      mask = rewriter.create<vector::TransposeOp>(
          loc, mask, transposeOp.getPermutation());
    }
    // - The source memref
    mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
        transposeOp.getPermutation(), getContext());
    auto transposedSubview = rewriter.create<memref::TransposeOp>(
        loc, readSubview, AffineMapAttr::get(transposeMap));
    ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
    // - The `in_bounds` attribute
    if (inBoundsAttr) {
      SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
                                            inBoundsAttr.end());
      applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
      inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
    }

    VectorType legalReadType = resultType.clone(readType.getElementType());
    // Note: The indices are all zero as the subview is already offset.
    SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
    auto legalRead = rewriter.create<vector::TransferReadOp>(
        loc, legalReadType, transposedSubview, readIndices,
        illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
        inBoundsAttr);

    // Replace the transpose with the new read, extending the result if
    // necessary.
    rewriter.replaceOp(transposeOp, [&]() -> Operation * {
      if (extendOp)
        return rewriter.create(loc, extendOp->getName().getIdentifier(),
                               Value(legalRead), resultType);
      return legalRead;
    }());

    return success();
  }
};
/// A rewrite to turn unit dim transpose-like vector.shape_casts into
/// vector.transposes. The shape_cast has to be from an illegal vector type to
/// a legal one (as defined by isLegalVectorType).
///
/// The reasoning for this is if we've got to this pass and we still have
/// shape_casts of illegal types, then they likely will not cancel out. Turning
/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
/// eliminate them.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
///  ```
///
///  AFTER:
///  ```mlir
///  %0 = vector.transpose %a, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
///  ```
struct ConvertIllegalShapeCastOpsToTransposes
    : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto sourceType = shapeCastOp.getSourceVectorType();
    auto resultType = shapeCastOp.getResultVectorType();
    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(shapeCastOp,
                                         kMatchFailureNotIllegalToLegal);

    // Note: If we know that `sourceType` is an illegal vector type (and 2D)
    // then dim 0 is scalable and dim 1 is fixed.
    if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
      return rewriter.notifyMatchFailure(
          shapeCastOp, "expected source to be a 2D scalable vector with a "
                       "trailing unit dim");

    auto loc = shapeCastOp.getLoc();
    auto transpose = rewriter.create<vector::TransposeOp>(
        loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});

    if (resultType.getRank() == 1)
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
                                                       transpose);
    else
      rewriter.replaceOp(shapeCastOp, transpose);

    return success();
  }
};

/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead
/// use the ZA state. This is a workaround to support these transposes when ZA
/// is available.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %transpose = vector.transpose %vec, [1, 0]
///     : vector<2x[4]xf32> to vector<[4]x2xf32>
///  vector.transfer_write %transpose, %dest[%y, %x]
///     : vector<[4]x2xf32>, memref<?x?xf32>
///  ```
///
///  AFTER:
///  ```mlir
///   %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
///   %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
///   %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
///   %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
///   %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
///   %c4_vscale = arith.muli %vscale, %c4 : index
///   %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
///   vector.transfer_write %4, %dest[%y, %x], %mask
///     {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
///     : vector<[4]x[4]xf32>, memref<?x?xf32>
///  ```
///
/// Values larger than a single tile are supported via decomposition.
struct LowerIllegalTransposeStoreViaZA
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!isSupportedMaskOp(writeOp.getMask()))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isIdentity())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    auto transposeOp =
        writeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!transposeOp)
      return failure();

    auto sourceType = transposeOp.getSourceVectorType();
    auto resultType = transposeOp.getResultVectorType();

    if (resultType.getRank() != 2)
      return rewriter.notifyMatchFailure(transposeOp,
                                         "TransposeOp not rank 2");

    if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(
          transposeOp, "not illegal/unsupported SVE transpose");

    auto smeTileType = getSMETileTypeForElement(resultType.getElementType());
    VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0);

    if (sourceType.getDimSize(0) <= 1 ||
        sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
      return rewriter.notifyMatchFailure(writeOp, "unsupported source shape");

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    auto transposeMap = AffineMapAttr::get(
        AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext()));

    // Note: We need to use `get_tile` as there's no vector-level `undef`.
    Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType);
    Value destTensorOrMemref = writeOp.getSource();
    auto numSlicesPerTile =
        std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
    auto numSlices =
        rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile);
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
      // 1. _Deliberately_ drop a scalable dimension and insert a fixed number
      // of slices from the source type into the SME tile. Without checking
      // vscale (and emitting multiple implementations) we can't make use of
      // the rows of the tile after 1*vscale rows.
      Value tile = undefTile;
      for (int d = 0; d < numSlicesPerTile; ++d) {
        Value vector = rewriter.create<vector::ExtractOp>(
            loc, transposeOp.getVector(),
            rewriter.getIndexAttr(d + smeTile.row));
        if (vector.getType() != smeSliceType) {
          vector = rewriter.create<vector::ScalableExtractOp>(
              loc, smeSliceType, vector, smeTile.col);
        }
        tile = rewriter.create<vector::InsertOp>(loc, vector, tile, d);
      }

      // 2. Transpose the tile position.
      auto transposedRow = createVscaleMultiple(smeTile.col);
      auto transposedCol =
          rewriter.create<arith::ConstantIndexOp>(loc, smeTile.row);

      // 3. Compute mask for tile store.
      Value maskRows;
      Value maskCols;
      if (auto mask = writeOp.getMask()) {
        auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
        maskRows = rewriter.create<arith::SubIOp>(
            loc, createMask.getOperand(0), transposedRow);
        maskCols = rewriter.create<arith::SubIOp>(
            loc, createMask.getOperand(1), transposedCol);
        maskCols = rewriter.create<index::MinSOp>(loc, maskCols, numSlices);
      } else {
        maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
        maskCols = numSlices;
      }
      auto subMask = rewriter.create<vector::CreateMaskOp>(
          loc, smeTileType.clone(rewriter.getI1Type()),
          ValueRange{maskRows, maskCols});

      // 4. Emit a transposed tile write.
      auto writeIndices = writeOp.getIndices();
      Value destRow =
          rewriter.create<arith::AddIOp>(loc, transposedRow, writeIndices[0]);
      Value destCol =
          rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]);
      auto smeWrite = rewriter.create<vector::TransferWriteOp>(
          loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
          transposeMap, subMask, writeOp.getInBounds());

      if (writeOp.hasPureTensorSemantics())
        destTensorOrMemref = smeWrite.getResult();
    }

    if (writeOp.hasPureTensorSemantics())
      rewriter.replaceOp(writeOp, destTensorOrMemref);
    else
      rewriter.eraseOp(writeOp);

    return success();
  }
};
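// For illustration, the 1->N type conversion in the pass below maps a
// multi-tile vector type to a sequence of SME tiles, e.g.:
//
//   vector<[8]x[8]xf32> -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>,
//                           vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
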
struct VectorLegalizationPass
    : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
  void runOnOperation() override {
    auto *context = &getContext();
    TypeConverter converter;
    RewritePatternSet patterns(context);
    converter.addConversion([](Type type) { return type; });
    converter.addConversion(
        [](VectorType vectorType,
           SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
          if (!isMultipleOfSMETileVectorType(vectorType))
            return std::nullopt;
          auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType);
          auto smeTileType =
              getSMETileTypeForElement(vectorType.getElementType());
          types = SmallVector<Type>(smeTileCount, smeTileType);
          return success();
        });

    // Apply preprocessing patterns.
    RewritePatternSet rewritePatterns(context);
    rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                        LiftIllegalVectorTransposeToMemory,
                        ConvertIllegalShapeCastOpsToTransposes,
                        LowerIllegalTransposeStoreViaZA>(context);
    if (failed(
            applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
      return signalPassFailure();

    // Note: These two patterns are added with a high benefit to ensure:
    //  - Masked outer products are handled before unmasked ones
    //  - Multi-tile writes are lowered as a store loop (if possible)
    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
                 LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context,
                                                            /*benefit=*/1024);
    patterns.add<LegalizeArithConstantOpsByDecomposition,
                 LegalizeVectorOuterProductOpsByDecomposition,
                 LegalizeTransferReadOpsByDecomposition,
                 LegalizeTransferWriteOpsByDecomposition>(converter, context);
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversions(converter, patterns);

    ConversionTarget target(getContext());
    target.markUnknownOpDynamicallyLegal(
        [&](Operation *op) { return converter.isLegal(op); });
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() {
  return std::make_unique<VectorLegalizationPass>();
}