//===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass legalizes vector operations so they can be lowered to ArmSME.
//
// Note: In the context of this pass 'tile' always refers to an SME tile.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "arm-sme-vector-legalization"

namespace mlir::arm_sme {
#define GEN_PASS_DEF_VECTORLEGALIZATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace mlir::arm_sme

using namespace mlir;
using namespace mlir::arm_sme;

namespace {

//===----------------------------------------------------------------------===//
// Decomposition of vector operations larger than an SME tile
//===----------------------------------------------------------------------===//

// Common match failure reasons.
static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
    "op vector size is not multiple of SME tiles");
static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
    "op mask is unsupported for legalization/decomposition");
static constexpr StringLiteral
    kMatchFailureNonPermutationMap("op affine map is not a permutation");
static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
    "expected transpose from illegal type to legal type");

/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
/// larger vector type. The (`row`, `col`) are the position of the tile in the
/// original vector type. For example, for an [8]x[8] tile with four [4]x[4]
/// sub-tiles, we would have:
///
///          8 x vscale
/// ┌─────────────┬─────────────┐
/// │(0,0)        │(0,4)        │
/// │             │             │
/// ├─────────────┼─────────────┤ 8 x vscale
/// │(4,0)        │(4,4)        │
/// │             │             │
/// └─────────────┴─────────────┘
struct SMESubTile {
  // Note: The units of (row, col) are vscale (as SME tiles are scalable).
  int row{0};
  int col{0};
  // The SME tile type.
  VectorType type;
};

/// Adds a constant elementwise scalable offset to `indices` (which must be the
/// same length as `scalableOffsets`).
/// For example, in the 2D case this would return:
// { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
                                                Location loc,
                                                ValueRange indices,
                                                ArrayRef<int> scalableOffsets) {
  auto vscale = builder.create<vector::VectorScaleOp>(loc);
  return llvm::map_to_vector(
      llvm::zip_equal(indices, scalableOffsets), [&](auto pair) -> Value {
        auto [index, base] = pair;
        auto offset = builder.create<arith::MulIOp>(
            loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale);
        return builder.create<arith::AddIOp>(loc, index, offset);
      });
}

/// Adjusts `indices` (e.g. from a load/store) for a larger vector type to
/// indices for one of the SME sub-tiles it will decompose into.
///
/// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
/// indices for each tile would need to be adjusted as follows:
///
/// initial indices = [a,b], initial size = 8x8, target size = 4x4
/// ┌─────────────┬─────────────┐
/// │[a,b]        │[a,b+4]      │
/// │             │             │
/// ├─────────────┼─────────────┤
/// │[a+4,b]      │[a+4,b+4]    │
/// │             │             │
/// └─────────────┴─────────────┘
SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc,
                                           ValueRange indices,
                                           SMESubTile smeTile) {
  return addConstantScalableOffset(builder, loc, indices,
                                   {smeTile.row, smeTile.col});
}

/// Returns true if `mask` is generated by an operation that can be decomposed
/// for SME. Currently, that is just no mask, or vector.create_mask.
/// TODO: Add support for vector.constant_mask once required for SME.
bool isSupportedMaskOp(Value mask) {
  return !mask || mask.getDefiningOp<vector::CreateMaskOp>();
}

/// Extracts a mask for an SME sub-tile from the mask of a larger vector type.
Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
                     SMESubTile smeTile) {
  assert(isSupportedMaskOp(mask));
  if (!mask)
    return Value{};
  auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
  // The operands of `vector.create_mask` (from a 2D perspective) are the
  // coordinates where the mask ends. So we subtract where this tile starts
  // from the mask operands to get the parameters for this sub-tile.
  auto smeTileMaskDims = addConstantScalableOffset(
      builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col});
  auto smeTileCreateMask = builder.create<vector::CreateMaskOp>(
      loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims);
  return smeTileCreateMask.getResult();
}

/// Constructs an iterator that returns each SME tile (with coordinates)
/// contained within a VectorType. For example, if decomposing an [8]x[8] into
/// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0),
/// (4, 4).
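/// If `transposeIndices` is true, the (row, col) position of each yielded
/// sub-tile is swapped; this is used when decomposing transposed
/// transfer_read/transfer_write ops.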
auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
                         VectorType smeTileType,
                         bool transposeIndices = false) {
  return llvm::map_range(
      StaticTileOffsetRange(
          type.getShape(),
          {std::min(type.getDimSize(0), smeTileType.getDimSize(0)),
           std::min(type.getDimSize(1), smeTileType.getDimSize(1))}),
      [=](auto indices) {
        int row = int(indices[0]);
        int col = int(indices[1]);
        if (transposeIndices)
          std::swap(row, col);
        return SMESubTile{row, col, smeTileType};
      });
}

/// Returns the number of SME tiles that fit into the (2D-scalable) vector type
/// `type`.
int getNumberOfSMETilesForVectorType(VectorType type) {
  assert(isMultipleOfSMETileVectorType(type) &&
         "`type` not multiple of SME tiles");
  int64_t vectorRows = type.getDimSize(0);
  int64_t vectorCols = type.getDimSize(1);
  auto elementType = type.getElementType();
  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
  return (vectorRows * vectorCols) / (minNumElts * minNumElts);
}

/// Legalize `arith.constant dense<value>` splat operations to fit within SME
/// tiles by decomposing them into tile-sized operations.
struct LegalizeArithConstantOpsByDecomposition
    : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = dyn_cast<VectorType>(constantOp.getType());
    auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
    if (!vectorType || !denseAttr || !denseAttr.isSplat())
      return failure();

    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(constantOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
    auto tileSplat = rewriter.create<arith::ConstantOp>(
        constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
    SmallVector<Value> repl(tileCount, tileSplat);
    rewriter.replaceOpWithMultiple(constantOp, {repl});

    return success();
  }
};

/// Legalize `vector.outerproduct` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
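///
/// For example (for f32, where an SME tile is vector<[4]x[4]xf32>), an outer
/// product producing vector<[8]x[8]xf32> is decomposed into four [4]x[4]
/// outer products, each taking [4]-element scalable slices of the original
/// LHS/RHS (and of the mask, if present).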
struct LegalizeVectorOuterProductOpsByDecomposition
    : public OpConversionPattern<vector::OuterProductOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::OuterProductOp outerProductOp,
                  OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = outerProductOp.getResultVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    Value mask;
    Operation *rootOp = outerProductOp;
    auto loc = outerProductOp.getLoc();
    if (outerProductOp.isMasked()) {
      auto maskOp = outerProductOp.getMaskingOp();
      mask = maskOp.getMask();
      rootOp = maskOp;
      rewriter.setInsertionPoint(rootOp);
    }

    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         kMatchFailureUnsupportedMaskOp);

    ValueRange accSMETiles = adaptor.getAcc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0);

    SmallVector<Value> resultSMETiles;
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {

      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto lhs = rewriter.create<vector::ScalableExtractOp>(
          loc, sliceType, outerProductOp.getLhs(), smeTile.row);
      auto rhs = rewriter.create<vector::ScalableExtractOp>(
          loc, sliceType, outerProductOp.getRhs(), smeTile.col);
      auto smeOuterProduct = rewriter.create<vector::OuterProductOp>(
          loc, smeTileType, lhs, rhs,
          !accSMETiles.empty() ? accSMETiles[index] : Value{},
          outerProductOp.getKind());

      auto maskedOuterProduct =
          vector::maskOperation(rewriter, smeOuterProduct, smeMask);
      resultSMETiles.push_back(maskedOuterProduct->getResult(0));
    }

    rewriter.replaceOpWithMultiple(rootOp, {resultSMETiles});
    return success();
  }
};

// Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to
// get the help of the type conversion), but doing so results in the type
// conversion adding target materializations in the `vector.mask` region
// (invalid). This pattern matches on `vector.mask` then calls into the
// `vector.outerproduct` pattern to work around this issue.
struct LegalizeMaskedVectorOuterProductOpsByDecomposition
    : public OpConversionPattern<vector::MaskOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
            maskOp.getMaskableOp())) {
      LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
                                                           getContext());
      return static_cast<RewritePattern &>(pattern).matchAndRewrite(
          outerProductOp, rewriter);
    }
    return failure();
  }
};

/// Legalize `vector.transfer_read` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
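///
/// For example (for f32), a read producing vector<[8]x[8]xf32> becomes four
/// reads of vector<[4]x[4]xf32>, with the original indices offset (by
/// multiples of vscale) to the start of each sub-tile.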
struct LegalizeTransferReadOpsByDecomposition
    : public OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = readOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto mask = readOp.getMask();
    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = readOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureNonPermutationMap);

    // Note: For 2D vector types the only non-identity permutation is a simple
    // transpose [1, 0].
    bool transposed = !permutationMap.isIdentity();

    auto loc = readOp.getLoc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());

    SmallVector<Value> resultSMETiles;
    for (SMESubTile smeTile :
         decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto smeRead = rewriter.create<vector::TransferReadOp>(
          loc, smeTileType, readOp.getSource(),
          getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
          readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
          readOp.getInBoundsAttr());
      resultSMETiles.push_back(smeRead);
    }

    rewriter.replaceOpWithMultiple(readOp, {resultSMETiles});
    return success();
  }
};

/// Legalize `vector.transfer_write` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeTransferWriteOpsByDecomposition
    : public OpConversionPattern<vector::TransferWriteOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = writeOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    auto mask = writeOp.getMask();
    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    // Note: For 2D vector types the only non-identity permutation is a simple
    // transpose [1, 0].
    bool transposed = !permutationMap.isIdentity();

    auto loc = writeOp.getLoc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto inputSMETiles = adaptor.getVector();

    Value destTensorOrMemref = writeOp.getSource();
    for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
             rewriter, vectorType, smeTileType, transposed))) {
      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto smeWrite = rewriter.create<vector::TransferWriteOp>(
          loc, inputSMETiles[index], destTensorOrMemref,
          getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile),
          writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr());
      if (writeOp.hasPureTensorSemantics())
        destTensorOrMemref = smeWrite.getResult();
    }

    if (writeOp.hasPureTensorSemantics())
      rewriter.replaceOp(writeOp, destTensorOrMemref);
    else
      rewriter.eraseOp(writeOp);

    return success();
  }
};

/// Legalize a multi-tile transfer_write as a single store loop. This is done
/// as part of type decomposition because at this level we know each tile write
/// is disjoint, but that information is lost after decomposition (without
/// analysis to reconstruct it).
///
/// Example (pseudo-MLIR):
///
/// ```
/// vector.transfer_write %vector, %dest[%y, %x], %mask
///   : vector<[16]x[8]xi16>, memref<?x?xi16>
/// ```
/// Is rewritten to:
/// ```
/// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
///   %upper_slice_mask = vector.extract %mask[%slice_idx]       ─┐
///     : vector<[8]xi1> from vector<[16]x[8]xi1>                 |
///   %upper_slice = vector.extract %upper_tile[%slice_idx]       |- Store upper tile
///     : vector<[8]xi16> from vector<[8]x[8]xi16>                |
///   vector.transfer_write %upper_slice,                         |
///     %dest[%slice_idx + %y, %x], %upper_slice_mask             |
///     : vector<[8]xi16>, memref<?x?xi16>                       ─┘
///   %lower_slice_idx = %slice_idx + %c8_vscale                 ─┐
///   %lower_slice_mask = vector.extract %mask[%lower_slice_idx]  |
///     : vector<[8]xi1> from vector<[16]x[8]xi1>                 |
///   %lower_slice = vector.extract %lower_tile[%slice_idx]       |- Store lower
///     : vector<[8]xi16> from vector<[8]x[8]xi16>                |  tile
///   vector.transfer_write %lower_slice,                         |
///     %dest[%lower_slice_idx + %y, %x], %lower_slice_mask       |
///     : vector<[8]xi16>, memref<?x?xi16>                       ─┘
/// }
/// ```
struct LegalizeMultiTileTransferWriteAsStoreLoop
    : public OpConversionPattern<vector::TransferWriteOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (writeOp.hasPureTensorSemantics())
      return rewriter.notifyMatchFailure(
          writeOp, "TODO: tensor semantics are unsupported");

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    bool transposed = !permutationMap.isIdentity();
    if (transposed)
      return rewriter.notifyMatchFailure(writeOp,
                                         "TODO: transpose unsupported");

    auto vectorType = writeOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    // Note: We also disallow masks where any dimension is > 16 because that
    // prevents the masking from being lowered to use arm_sve.psel.
    auto mask = writeOp.getMask();
    if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
                                              vectorType.getDimSize(1) > 16)))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    // Get SME tile and slice types.
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto minTileSlices = smeTileType.getDimSize(0);
    VectorType sliceMaskType =
        VectorType::get(minTileSlices, rewriter.getI1Type(), true);

    // Create loop over all tile slices.
    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto upperBound = createVscaleMultiple(minTileSlices);
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto storeLoop =
        rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
    rewriter.setInsertionPointToStart(storeLoop.getBody());

    // For each sub-tile of the multi-tile `vectorType`.
    auto inputSMETiles = adaptor.getVector();
    auto tileSliceIndex = storeLoop.getInductionVar();
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
      // The coordinates of the tile within `vectorType`.
      auto tileRow = createVscaleMultiple(smeTile.row);
      auto tileCol = createVscaleMultiple(smeTile.col);

      // The current slice of `vectorType` we are processing.
      auto sliceIndex =
          rewriter.create<arith::AddIOp>(loc, tileRow, tileSliceIndex);

      // Where in the destination memref the current slice will be stored.
      auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex,
                                                     writeOp.getIndices()[0]);
      auto storeCol =
          rewriter.create<arith::AddIOp>(loc, tileCol, writeOp.getIndices()[1]);

      // Extract the mask for the current slice.
      Value sliceMask = nullptr;
      if (mask) {
        sliceMask = rewriter.create<vector::ExtractOp>(
            loc, mask, OpFoldResult(sliceIndex));
        if (sliceMaskType != sliceMask.getType())
          sliceMask = rewriter.create<vector::ScalableExtractOp>(
              loc, sliceMaskType, sliceMask, smeTile.col);
      }

      // Extract and store the current slice.
      Value tile = inputSMETiles[index];
      auto slice =
          rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex);
      rewriter.create<vector::TransferWriteOp>(
          loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
          AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
          sliceMask,
          rewriter.getBoolArrayAttr(
              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()));
    }

    rewriter.eraseOp(writeOp);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ArmSME-specific fixup canonicalizations/folds
//===----------------------------------------------------------------------===//

/// Folds an extract from a 3D `vector.create_mask` (which is a vector of
/// SME-like masks) into a compare and a 2D `vector.create_mask`. This is
/// necessary for the mask to be lowered to ArmSME.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
/// %subMask = vector.extract %mask[2]
///   : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
/// ```
///
/// AFTER:
/// ```mlir
/// %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
/// %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
/// %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
/// ```
struct FoldExtractFromVectorOfSMELikeCreateMasks
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto loc = extractOp.getLoc();
    auto createMaskOp =
        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return rewriter.notifyMatchFailure(
          extractOp, "extract not from vector.create_mask op");

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
    if (!extractedMaskType)
      return rewriter.notifyMatchFailure(extractOp,
                                         "extracted type is not a vector type");

    auto numScalable = extractedMaskType.getNumScalableDims();
    if (numScalable != 2)
      return rewriter.notifyMatchFailure(
          extractOp, "expected extracted type to be an SME-like mask");

    // TODO: Support multiple extraction indices.
    if (extractOp.getStaticPosition().size() != 1)
      return rewriter.notifyMatchFailure(
          extractOp, "only a single extraction index is supported");

    auto frontMaskDim = createMaskOp.getOperand(0);
    if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
      return rewriter.notifyMatchFailure(
          extractOp,
          "constant vector.create_masks dims should be folded elsewhere");

    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto extractionIndex = getValueOrCreateConstantIndexOp(
        rewriter, loc, extractOp.getMixedPosition()[0]);
    auto extractionInTrueRegion = rewriter.create<arith::CmpIOp>(
        loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, extractionIndex,
        frontMaskDim);
    auto newMaskFrontDim = rewriter.create<arith::SelectOp>(
        loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero);

    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
        extractOp, extractedMaskType,
        ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
    return success();
  }
};

/// A vector type where no fixed dimension comes after a scalable dimension.
bool isLegalVectorType(VectorType vType) {
  bool seenFixedDim = false;
  for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
    seenFixedDim |= !scalableFlag;
    if (seenFixedDim && scalableFlag)
      return false;
  }
  return true;
}

/// Lifts an illegal vector.transpose and vector.transfer_read to a
/// memref.subview + memref.transpose, followed by a legal read.
///
/// 'Illegal' here means a leading scalable dimension and a fixed trailing
/// dimension, which has no valid lowering.
///
/// The memref.transpose is a metadata-only transpose that produces a strided
/// memref, which eventually becomes a loop reading individual elements.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %illegalRead = vector.transfer_read %memref[%a, %b]
///   : memref<?x?xf32>, vector<[8]x4xf32>
/// %legalType = vector.transpose %illegalRead, [1, 0]
///   : vector<[8]x4xf32> to vector<4x[8]xf32>
/// ```
///
/// AFTER:
/// ```mlir
/// %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
///   : memref<?x?xf32> to memref<?x?xf32>
/// %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
///   : memref<?x?xf32> to memref<?x?xf32>
/// %legalType = vector.transfer_read %transpose[%c0, %c0]
///   : memref<?x?xf32>, vector<4x[8]xf32>
/// ```
struct LiftIllegalVectorTransposeToMemory
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  static Value getExtensionSource(Operation *op) {
    if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
      return op->getOperand(0);
    return {};
  }

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto sourceType = transposeOp.getSourceVectorType();
    auto resultType = transposeOp.getResultVectorType();
    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(transposeOp,
                                         kMatchFailureNotIllegalToLegal);

    // Look through extend for transfer_read.
    Value maybeRead = transposeOp.getVector();
    auto *transposeSourceOp = maybeRead.getDefiningOp();
    Operation *extendOp = nullptr;
    if (Value extendSource = getExtensionSource(transposeSourceOp)) {
      maybeRead = extendSource;
      extendOp = transposeSourceOp;
    }

    auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
    if (!illegalRead)
      return rewriter.notifyMatchFailure(
          transposeOp,
          "expected source to be (possibly extended) transfer_read");

    if (!illegalRead.getPermutationMap().isIdentity())
      return rewriter.notifyMatchFailure(
          illegalRead, "expected read to have identity permutation map");

    auto loc = transposeOp.getLoc();
    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);

    // Create a subview that matches the size of the illegal read vector type.
    auto readType = illegalRead.getVectorType();
    auto readSizes = llvm::map_to_vector(
        llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
        [&](auto dim) -> Value {
          auto [size, isScalable] = dim;
          auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size);
          if (!isScalable)
            return dimSize;
          auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
          return rewriter.create<arith::MulIOp>(loc, vscale, dimSize);
        });
    SmallVector<Value> strides(readType.getRank(), Value(one));
    auto readSubview = rewriter.create<memref::SubViewOp>(
        loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
        strides);

    // Apply the transpose to all values/attributes of the transfer_read:
    // - The mask
    Value mask = illegalRead.getMask();
    if (mask) {
      // Note: The transpose for the mask should fold into the
      // vector.create_mask/constant_mask op, which will then become legal.
      mask = rewriter.create<vector::TransposeOp>(loc, mask,
                                                  transposeOp.getPermutation());
    }
    // - The source memref
    mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
        transposeOp.getPermutation(), getContext());
    auto transposedSubview = rewriter.create<memref::TransposeOp>(
        loc, readSubview, AffineMapAttr::get(transposeMap));
    ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
    // - The `in_bounds` attribute
    if (inBoundsAttr) {
      SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
                                            inBoundsAttr.end());
      applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
      inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
    }

    VectorType legalReadType = resultType.clone(readType.getElementType());
    // Note: The indices are all zero as the subview is already offset.
    SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
    auto legalRead = rewriter.create<vector::TransferReadOp>(
        loc, legalReadType, transposedSubview, readIndices,
        illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
        inBoundsAttr);

    // Replace the transpose with the new read, extending the result if
    // necessary.
    rewriter.replaceOp(transposeOp, [&]() -> Operation * {
      if (extendOp)
        return rewriter.create(loc, extendOp->getName().getIdentifier(),
                               Value(legalRead), resultType);
      return legalRead;
    }());

    return success();
  }
};

/// A rewrite to turn unit dim transpose-like vector.shape_casts into
/// vector.transposes. The shape_cast has to be from an illegal vector type to
/// a legal one (as defined by isLegalVectorType).
///
/// The reasoning for this is that if we've reached this pass and still have
/// shape_casts of illegal types, then they likely will not cancel out. Turning
/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
/// eliminate them.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
/// ```
///
/// AFTER:
/// ```mlir
/// %0 = vector.transpose %a, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
/// ```
struct ConvertIllegalShapeCastOpsToTransposes
    : public OpRewritePattern<vector::ShapeCastOp> {
  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto sourceType = shapeCastOp.getSourceVectorType();
    auto resultType = shapeCastOp.getResultVectorType();
    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(shapeCastOp,
                                         kMatchFailureNotIllegalToLegal);

    // Note: If we know that `sourceType` is an illegal vector type (and 2D)
    // then dim 0 is scalable and dim 1 is fixed.
    if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
      return rewriter.notifyMatchFailure(
          shapeCastOp, "expected source to be a 2D scalable vector with a "
                       "trailing unit dim");

    auto loc = shapeCastOp.getLoc();
    auto transpose = rewriter.create<vector::TransposeOp>(
        loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});

    if (resultType.getRank() == 1)
      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
                                                       transpose);
    else
      rewriter.replaceOp(shapeCastOp, transpose);

    return success();
  }
};

/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead
/// use the ZA state. This is a workaround to support these transposes when ZA
/// is available.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %transpose = vector.transpose %vec, [1, 0]
///   : vector<2x[4]xf32> to vector<[4]x2xf32>
/// vector.transfer_write %transpose, %dest[%y, %x]
///   : vector<[4]x2xf32>, memref<?x?xf32>
/// ```
///
/// AFTER:
/// ```mlir
/// %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
/// %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
/// %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
/// %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
/// %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
/// %c4_vscale = arith.muli %vscale, %c4 : index
/// %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
/// vector.transfer_write %4, %dest[%y, %x], %mask
///   {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
///   : vector<[4]x[4]xf32>, memref<?x?xf32>
/// ```
///
/// Values larger than a single tile are supported via decomposition.
struct LowerIllegalTransposeStoreViaZA
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    if (!isSupportedMaskOp(writeOp.getMask()))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isIdentity())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!transposeOp)
      return failure();

    auto sourceType = transposeOp.getSourceVectorType();
    auto resultType = transposeOp.getResultVectorType();

    if (resultType.getRank() != 2)
      return rewriter.notifyMatchFailure(transposeOp, "TransposeOp not rank 2");

    if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(
          transposeOp, "not illegal/unsupported SVE transpose");

    auto smeTileType = getSMETileTypeForElement(resultType.getElementType());
    VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0);

    if (sourceType.getDimSize(0) <= 1 ||
        sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
      return rewriter.notifyMatchFailure(writeOp, "unsupported source shape");

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    auto transposeMap = AffineMapAttr::get(
        AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext()));

    // Note: We need to use `get_tile` as there's no vector-level `undef`.
    Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType);
    Value destTensorOrMemref = writeOp.getSource();
    auto numSlicesPerTile =
        std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
    auto numSlices =
        rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile);
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
      // 1. _Deliberately_ drop a scalable dimension and insert a fixed number
      // of slices from the source type into the SME tile. Without checking
      // vscale (and emitting multiple implementations) we can't make use of
      // the rows of the tile after 1*vscale rows.
      Value tile = undefTile;
      for (int d = 0; d < numSlicesPerTile; ++d) {
        Value vector = rewriter.create<vector::ExtractOp>(
            loc, transposeOp.getVector(),
            rewriter.getIndexAttr(d + smeTile.row));
        if (vector.getType() != smeSliceType) {
          vector = rewriter.create<vector::ScalableExtractOp>(
              loc, smeSliceType, vector, smeTile.col);
        }
        tile = rewriter.create<vector::InsertOp>(loc, vector, tile, d);
      }

      // 2. Transpose the tile position.
      auto transposedRow = createVscaleMultiple(smeTile.col);
      auto transposedCol =
          rewriter.create<arith::ConstantIndexOp>(loc, smeTile.row);

      // 3. Compute mask for tile store.
      Value maskRows;
      Value maskCols;
      if (auto mask = writeOp.getMask()) {
        auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
        maskRows = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(0),
                                                  transposedRow);
        maskCols = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(1),
                                                  transposedCol);
        maskCols = rewriter.create<index::MinSOp>(loc, maskCols, numSlices);
      } else {
        maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
        maskCols = numSlices;
      }
      auto subMask = rewriter.create<vector::CreateMaskOp>(
          loc, smeTileType.clone(rewriter.getI1Type()),
          ValueRange{maskRows, maskCols});

      // 4. Emit a transposed tile write.
      auto writeIndices = writeOp.getIndices();
      Value destRow =
          rewriter.create<arith::AddIOp>(loc, transposedRow, writeIndices[0]);
      Value destCol =
          rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]);
      auto smeWrite = rewriter.create<vector::TransferWriteOp>(
          loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
          transposeMap, subMask, writeOp.getInBounds());

      if (writeOp.hasPureTensorSemantics())
        destTensorOrMemref = smeWrite.getResult();
    }

    if (writeOp.hasPureTensorSemantics())
      rewriter.replaceOp(writeOp, destTensorOrMemref);
    else
      rewriter.eraseOp(writeOp);

    return success();
  }
};

struct VectorLegalizationPass
    : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
  void runOnOperation() override {
    auto *context = &getContext();
    TypeConverter converter;
    RewritePatternSet patterns(context);
    converter.addConversion([](Type type) { return type; });
    converter.addConversion(
        [](VectorType vectorType,
           SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
          if (!isMultipleOfSMETileVectorType(vectorType))
            return std::nullopt;
          auto smeTileCount = getNumberOfSMETilesForVectorType(vectorType);
          auto smeTileType =
              getSMETileTypeForElement(vectorType.getElementType());
          types = SmallVector<Type>(smeTileCount, smeTileType);
          return success();
        });

    // Apply preprocessing patterns.
    RewritePatternSet rewritePatterns(context);
    rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                        LiftIllegalVectorTransposeToMemory,
                        ConvertIllegalShapeCastOpsToTransposes,
                        LowerIllegalTransposeStoreViaZA>(context);
    if (failed(
            applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
      return signalPassFailure();

    // Note: These two patterns are added with a high benefit to ensure:
    //  - Masked outer products are handled before unmasked ones
    //  - Multi-tile writes are lowered as a store loop (if possible)
    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
                 LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context,
                                                            /*benefit=*/1024);
    patterns.add<LegalizeArithConstantOpsByDecomposition,
                 LegalizeVectorOuterProductOpsByDecomposition,
                 LegalizeTransferReadOpsByDecomposition,
                 LegalizeTransferWriteOpsByDecomposition>(converter, context);
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                   converter);
    populateCallOpTypeConversionPattern(patterns, converter);
    populateReturnOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversions(converter, patterns);

    ConversionTarget target(getContext());
    target.markUnknownOpDynamicallyLegal(
        [&](Operation *op) { return converter.isLegal(op); });
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return converter.isSignatureLegal(op.getFunctionType());
    });
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() {
  return std::make_unique<VectorLegalizationPass>();
}