1 //===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h" 10 11 #include "mlir/Dialect/ArmSME/IR/ArmSME.h" 12 #include "mlir/Dialect/ArmSME/Utils/Utils.h" 13 #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" 14 #include "mlir/Dialect/MemRef/IR/MemRef.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "llvm/Support/Casting.h" 17 18 using namespace mlir; 19 20 namespace { 21 22 /// Conversion pattern for vector.transfer_read. 23 /// 24 /// --- 25 /// 26 /// Example 1: op with identity permutation map to horizontal 27 /// arm_sme.tile_load: 28 /// 29 /// vector.transfer_read ... permutation_map: (d0, d1) -> (d0, d1) 30 /// 31 /// is converted to: 32 /// 33 /// arm_sme.tile_load ... 34 /// 35 /// --- 36 /// 37 /// Example 2: op with transpose permutation map to vertical arm_sme.tile_load 38 /// (in-flight transpose): 39 /// 40 /// vector.transfer_read ... permutation_map: (d0, d1) -> (d1, d0) 41 /// 42 /// is converted to: 43 /// 44 /// arm_sme.tile_load ... layout<vertical> 45 struct TransferReadToArmSMELowering 46 : public OpRewritePattern<vector::TransferReadOp> { 47 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; 48 49 LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp, 50 PatternRewriter &rewriter) const final { 51 // The permutation map must have two results. 
52 if (transferReadOp.getTransferRank() != 2) 53 return rewriter.notifyMatchFailure(transferReadOp, 54 "not a 2 result permutation map"); 55 56 auto vectorType = transferReadOp.getVectorType(); 57 if (!arm_sme::isValidSMETileVectorType(vectorType)) 58 return rewriter.notifyMatchFailure(transferReadOp, 59 "not a valid vector type for SME"); 60 61 if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType())) 62 return rewriter.notifyMatchFailure(transferReadOp, "not a memref source"); 63 64 // Out-of-bounds dims are not supported. 65 if (transferReadOp.hasOutOfBoundsDim()) 66 return rewriter.notifyMatchFailure(transferReadOp, 67 "not inbounds transfer read"); 68 69 AffineMap map = transferReadOp.getPermutationMap(); 70 if (!map.isPermutation()) 71 return rewriter.notifyMatchFailure(transferReadOp, 72 "unsupported permutation map"); 73 74 // Note: For 2D vector types the only non-identity permutation is a simple 75 // transpose [1, 0]. 76 bool transposed = !map.isIdentity(); 77 arm_sme::TileSliceLayout layout = 78 transposed ? arm_sme::TileSliceLayout::Vertical 79 : arm_sme::TileSliceLayout::Horizontal; 80 81 // Padding isn't optional for transfer_read, but is only used in the case 82 // of out-of-bounds accesses (not supported here) and/or masking. Mask is 83 // optional, if it's not present don't pass padding. 84 auto mask = transferReadOp.getMask(); 85 auto padding = mask ? transferReadOp.getPadding() : nullptr; 86 rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>( 87 transferReadOp, vectorType, transferReadOp.getSource(), 88 transferReadOp.getIndices(), padding, mask, layout); 89 90 return success(); 91 } 92 }; 93 94 /// Conversion pattern for vector.transfer_write. 
95 /// 96 /// --- 97 /// 98 /// Example 1: op with identity permutation map to horizontal 99 /// arm_sme.tile_store: 100 /// 101 /// vector.transfer_write %vector, %source[%c0, %c0] 102 /// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8> 103 /// 104 /// is converted to: 105 /// 106 /// arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>, 107 /// vector<[16]x[16]xi8> 108 /// --- 109 /// 110 /// Example 2: op with transpose permutation map to vertical arm_sme.tile_store 111 /// (in-flight transpose): 112 /// 113 /// vector.transfer_write %vector, %source[%c0, %c0] 114 /// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, 115 /// in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8> 116 /// 117 /// is converted to: 118 /// 119 /// arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical> 120 /// : memref<?x?xi8>, vector<[16]x[16]xi8> 121 struct TransferWriteToArmSMELowering 122 : public OpRewritePattern<vector::TransferWriteOp> { 123 using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern; 124 125 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, 126 PatternRewriter &rewriter) const final { 127 auto vType = writeOp.getVectorType(); 128 if (!arm_sme::isValidSMETileVectorType(vType)) 129 return failure(); 130 131 if (!llvm::isa<MemRefType>(writeOp.getSource().getType())) 132 return failure(); 133 134 // Out-of-bounds dims are not supported. 135 if (writeOp.hasOutOfBoundsDim()) 136 return rewriter.notifyMatchFailure(writeOp, 137 "not inbounds transfer write"); 138 139 AffineMap map = writeOp.getPermutationMap(); 140 if (!map.isPermutation()) 141 return rewriter.notifyMatchFailure(writeOp, 142 "unsupported permutation map"); 143 144 // Note: For 2D vector types the only non-identity permutation is a simple 145 // transpose [1, 0]. 146 bool transposed = !map.isIdentity(); 147 arm_sme::TileSliceLayout layout = 148 transposed ? 
arm_sme::TileSliceLayout::Vertical 149 : arm_sme::TileSliceLayout::Horizontal; 150 151 rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>( 152 writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(), 153 writeOp.getMask(), layout); 154 return success(); 155 } 156 }; 157 158 /// Conversion pattern for vector.load. 159 struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> { 160 using OpRewritePattern<vector::LoadOp>::OpRewritePattern; 161 162 LogicalResult matchAndRewrite(vector::LoadOp load, 163 PatternRewriter &rewriter) const override { 164 if (!arm_sme::isValidSMETileVectorType(load.getVectorType())) 165 return failure(); 166 167 rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>( 168 load, load.getVectorType(), load.getBase(), load.getIndices()); 169 170 return success(); 171 } 172 }; 173 174 /// Conversion pattern for vector.store. 175 struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> { 176 using OpRewritePattern<vector::StoreOp>::OpRewritePattern; 177 178 LogicalResult matchAndRewrite(vector::StoreOp store, 179 PatternRewriter &rewriter) const override { 180 if (!arm_sme::isValidSMETileVectorType(store.getVectorType())) 181 return failure(); 182 183 rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>( 184 store, store.getValueToStore(), store.getBase(), store.getIndices()); 185 186 return success(); 187 } 188 }; 189 190 /// Conversion pattern for vector.broadcast. 
191 /// 192 /// Example: 193 /// 194 /// %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32> 195 /// 196 /// is converted to: 197 /// 198 /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32> 199 /// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices 200 /// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) 201 /// { 202 /// %tile_update = arm_sme.insert_tile_slice 203 /// %broadcast_to_1d, %iter_tile[%tile_slice_index] : 204 /// vector<[4]xi32> into vector<[4]x[4]xi32> 205 /// scf.yield %tile_update : vector<[4]x[4]xi32> 206 /// } 207 /// 208 /// Supports scalar, 0-d vector, and 1-d vector broadcasts. 209 struct BroadcastOpToArmSMELowering 210 : public OpRewritePattern<vector::BroadcastOp> { 211 using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern; 212 213 LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp, 214 PatternRewriter &rewriter) const final { 215 auto tileType = broadcastOp.getResultVectorType(); 216 if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) 217 return failure(); 218 219 auto loc = broadcastOp.getLoc(); 220 221 auto srcType = broadcastOp.getSourceType(); 222 auto srcVectorType = dyn_cast<VectorType>(srcType); 223 224 Value broadcastOp1D; 225 if (srcType.isIntOrFloat() || 226 (srcVectorType && (srcVectorType.getRank() == 0))) { 227 // Broadcast scalar or 0-d vector to 1-d vector. 228 VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); 229 broadcastOp1D = rewriter.create<vector::BroadcastOp>( 230 loc, tileSliceType, broadcastOp.getSource()); 231 } else if (srcVectorType && (srcVectorType.getRank() == 1)) 232 // Value to broadcast is already a 1-d vector, nothing to do. 
233 broadcastOp1D = broadcastOp.getSource(); 234 else 235 return failure(); 236 237 auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); 238 239 auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, 240 Value currentTile) { 241 // Create 'arm_sme.insert_tile_slice' to broadcast the value 242 // to each tile slice. 243 auto nextTile = b.create<arm_sme::InsertTileSliceOp>( 244 loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); 245 return nextTile.getResult(); 246 }; 247 248 // Create a loop over ZA tile slices. 249 auto forOp = 250 createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody); 251 252 rewriter.replaceOp(broadcastOp, forOp.getResult(0)); 253 254 return success(); 255 } 256 }; 257 258 /// Conversion pattern for vector.splat. 259 /// 260 /// Example: 261 /// 262 /// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32> 263 /// 264 /// is converted to: 265 /// 266 /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32> 267 /// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices 268 /// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) 269 /// { 270 /// %tile_update = arm_sme.insert_tile_slice 271 /// %broadcast_to_1d, %iter_tile[%tile_slice_index] : 272 /// vector<[4]xi32> into vector<[4]x[4]xi32> 273 /// scf.yield %tile_update : vector<[4]x[4]xi32> 274 /// } 275 /// 276 /// This is identical to vector.broadcast of a scalar. 
277 struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> { 278 using OpRewritePattern<vector::SplatOp>::OpRewritePattern; 279 280 LogicalResult matchAndRewrite(vector::SplatOp splatOp, 281 PatternRewriter &rewriter) const final { 282 auto tileType = splatOp.getResult().getType(); 283 if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) 284 return failure(); 285 286 auto loc = splatOp.getLoc(); 287 auto srcType = splatOp.getOperand().getType(); 288 289 assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat"); 290 // Avoid unused-variable warning when building without assertions. 291 (void)srcType; 292 293 // First, broadcast the scalar to a 1-d vector. 294 VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); 295 Value broadcastOp1D = rewriter.create<vector::BroadcastOp>( 296 loc, tileSliceType, splatOp.getInput()); 297 298 auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType); 299 300 auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, 301 Value currentTile) { 302 auto nextTile = b.create<arm_sme::InsertTileSliceOp>( 303 loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); 304 return nextTile.getResult(); 305 }; 306 307 // Next, create a loop over ZA tile slices and "move" the generated 1-d 308 // vector to each slice. 309 auto forOp = 310 createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody); 311 312 rewriter.replaceOp(splatOp, forOp.getResult(0)); 313 314 return success(); 315 } 316 }; 317 318 /// Conversion pattern for vector.transpose. 319 /// 320 /// Stores the input tile to memory and reloads vertically. 
321 /// 322 /// Example: 323 /// 324 /// %transposed_src = vector.transpose %src, [1, 0] 325 /// : vector<[4]x[4]xi32> to vector<[4]x[4]xi32> 326 /// 327 /// is converted to: 328 /// 329 /// %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32> 330 /// %arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0] 331 /// : memref<?x?xi32>, vector<[4]x[4]xi32> 332 /// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0] 333 /// layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32> 334 /// 335 /// NOTE: Transposing via memory is obviously expensive, the current intention 336 /// is to avoid the transpose if possible, this is therefore intended as a 337 /// fallback and to provide base support for Vector ops. If it turns out 338 /// transposes can't be avoided then this should be replaced with a more optimal 339 /// implementation, perhaps with tile <-> vector (MOVA) ops. 340 struct TransposeOpToArmSMELowering 341 : public OpRewritePattern<vector::TransposeOp> { 342 using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; 343 344 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, 345 PatternRewriter &rewriter) const final { 346 auto tileType = transposeOp.getResultVectorType(); 347 if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) 348 return failure(); 349 350 // Bail unless this is a true 2-D matrix transpose. 351 ArrayRef<int64_t> permutation = transposeOp.getPermutation(); 352 if (permutation[0] != 1 || permutation[1] != 0) 353 return failure(); 354 355 auto loc = transposeOp.getLoc(); 356 Value input = transposeOp.getVector(); 357 358 if (auto xferOp = input.getDefiningOp<vector::TransferReadOp>(); 359 xferOp && xferOp->hasOneUse()) { 360 // Fold transpose into transfer_read to enable in-flight transpose when 361 // converting to arm_sme.tile_load. 
362 rewriter.modifyOpInPlace(xferOp, [&]() { 363 xferOp->setAttr(xferOp.getPermutationMapAttrName(), 364 AffineMapAttr::get(AffineMap::getPermutationMap( 365 permutation, transposeOp.getContext()))); 366 }); 367 rewriter.replaceOp(transposeOp, xferOp); 368 return success(); 369 } 370 371 // Allocate buffer to store input tile to. 372 Value vscale = 373 rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType()); 374 Value minTileSlices = rewriter.create<arith::ConstantOp>( 375 loc, rewriter.getIndexAttr(tileType.getDimSize(0))); 376 Value c0 = 377 rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0)); 378 Value numTileSlices = 379 rewriter.create<arith::MulIOp>(loc, vscale, minTileSlices); 380 auto bufferType = 381 MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic}, 382 tileType.getElementType()); 383 auto buffer = rewriter.create<memref::AllocaOp>( 384 loc, bufferType, ValueRange{numTileSlices, numTileSlices}); 385 386 // Store input tile. 387 auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>( 388 loc, input, buffer, ValueRange{c0, c0}); 389 390 // Reload input tile vertically. 391 rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>( 392 transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(), 393 arm_sme::TileSliceLayout::Vertical); 394 395 return success(); 396 } 397 }; 398 399 /// Conversion pattern for vector.outerproduct. 400 /// 401 /// If the vector.outerproduct is masked (and the mask is from a 402 /// vector.create_mask), then the mask is decomposed into two 1-D masks for the 403 /// operands. 
///
/// Example:
///
///   %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1>
///   %result = vector.mask %mask {
///                vector.outerproduct %vecA, %vecB
///                 : vector<[4]xf32>, vector<[4]xf32>
///             } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
///
/// is converted to:
///
///    %maskA = vector.create_mask %dimA : vector<[4]xi1>
///    %maskB = vector.create_mask %dimB : vector<[4]xi1>
///    %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB)
///              : vector<[4]xf32>, vector<[4]xf32>
///
/// Unmasked outerproducts can be directly replaced with the arm_sme op.
///
/// Example:
///
///   %result = vector.outerproduct %vecA, %vecB
///    : vector<[4]xf32>, vector<[4]xf32>
///
/// is converted to:
///
///   %result = arm_sme.outerproduct %vecA, %vecB
///    : vector<[4]xf32>, vector<[4]xf32>
///
struct VectorOuterProductToArmSMELowering
    : public OpRewritePattern<vector::OuterProductOp> {

  using OpRewritePattern<vector::OuterProductOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp,
                                PatternRewriter &rewriter) const override {

    // We don't yet support lowering AXPY operations to SME. These could be
    // lowered by masking out all but the first element of the LHS.
    if (!isa<VectorType>(outerProductOp.getOperandTypeRHS()))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         "AXPY operations not supported");

    if (!arm_sme::isValidSMETileVectorType(
            outerProductOp.getResultVectorType()))
      return rewriter.notifyMatchFailure(
          outerProductOp, "outer product does not fit into SME tile");

    auto kind = outerProductOp.getKind();
    if (kind != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(
          outerProductOp,
          "unsupported kind (lowering to SME only supports ADD at the moment)");

    Value lhsMask = {};
    Value rhsMask = {};
    // If the outerproduct is wrapped in a vector.mask, the masking op (not the
    // outerproduct itself) is the op that must be replaced.
    Operation *rootOp = outerProductOp;
    auto loc = outerProductOp.getLoc();
    if (outerProductOp.isMasked()) {
      auto maskOp = outerProductOp.getMaskingOp();
      rewriter.setInsertionPoint(maskOp);
      rootOp = maskOp;
      // Decompose the 2-D result mask into two 1-D operand masks.
      auto operandMasks = decomposeResultMask(loc, maskOp.getMask(), rewriter);
      if (failed(operandMasks))
        return failure();
      std::tie(lhsMask, rhsMask) = *operandMasks;
    }

    rewriter.replaceOpWithNewOp<arm_sme::OuterProductOp>(
        rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(),
        outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc());

    return success();
  }

  /// Decomposes a 2-D result mask originating from a vector.create_mask into
  /// two 1-D operand masks (lhs rows, rhs columns). Fails for any other mask
  /// source.
  static FailureOr<std::pair<Value, Value>>
  decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) {
    // Attempt to extract masks from vector.create_mask.
    // TODO: Add support for other mask sources.
    auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return failure();

    auto maskType = createMaskOp.getVectorType();
    Value lhsMaskDim = createMaskOp.getOperand(0);
    Value rhsMaskDim = createMaskOp.getOperand(1);

    VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0);
    Value lhsMask =
        rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, lhsMaskDim);
    Value rhsMask =
        rewriter.create<vector::CreateMaskOp>(loc, operandMaskType, rhsMaskDim);

    return std::make_pair(lhsMask, rhsMask);
  }
};

/// Lower `vector.extract` using `arm_sme.extract_tile_slice`.
///
/// Example:
/// ```
/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
/// ```
/// Becomes:
/// ```
/// %slice = arm_sme.extract_tile_slice %tile[%row]
///            : vector<[4]xi32> from vector<[4]x[4]xi32>
/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
/// ```
struct VectorExtractToArmSMELowering
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    VectorType sourceType = extractOp.getSourceVectorType();
    if (!arm_sme::isValidSMETileVectorType(sourceType))
      return failure();

    auto loc = extractOp.getLoc();
    auto position = extractOp.getMixedPosition();

    Value sourceVector = extractOp.getVector();

    // Extract entire vector. Should be handled by folder, but just to be safe.
    if (position.empty()) {
      rewriter.replaceOp(extractOp, sourceVector);
      return success();
    }

    // The first index always selects a tile slice (row).
    Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
    auto extractTileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
        loc, sourceVector, sliceIndex);

    if (position.size() == 1) {
      // Single index case: Extracts a 1D slice.
      rewriter.replaceOp(extractOp, extractTileSlice);
      return success();
    }

    // Two indices case: Extracts a single element (from the extracted slice).
    assert(position.size() == 2);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(extractOp, extractTileSlice,
                                                   position[1]);

    return success();
  }
};

/// Lower `vector.insert` using `arm_sme.insert_tile_slice` and
/// `arm_sme.extract_tile_slice`.
///
/// Example:
/// ```
/// %new_tile = vector.insert %el, %tile[%row, %col]
///                     : i32 into vector<[4]x[4]xi32>
/// ```
/// Becomes:
/// ```
/// %slice = arm_sme.extract_tile_slice %tile[%row]
///            : vector<[4]xi32> from vector<[4]x[4]xi32>
/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
/// %new_tile = arm_sme.insert_tile_slice %new_slice, %tile[%row]
///               : vector<[4]xi32> into vector<[4]x[4]xi32>
/// ```
struct VectorInsertToArmSMELowering
    : public OpRewritePattern<vector::InsertOp> {
  using OpRewritePattern<vector::InsertOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
                                PatternRewriter &rewriter) const override {
    VectorType resultType = insertOp.getResult().getType();

    if (!arm_sme::isValidSMETileVectorType(resultType))
      return failure();

    auto loc = insertOp.getLoc();
    auto position = insertOp.getMixedPosition();

    Value source = insertOp.getSource();

    // Overwrite entire vector with value. Should be handled by folder, but
    // just to be safe.
    if (position.empty()) {
      rewriter.replaceOp(insertOp, source);
      return success();
    }

    Value tileSlice = source;
    Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
    if (position.size() == 2) {
      // Two indices case: Insert single element into tile.
      // We need to first extract the existing slice and update the element.
      tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
          loc, insertOp.getDest(), sliceIndex);
      tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
                                                    position[1]);
    }

    // Insert the slice into the destination tile.
    rewriter.replaceOpWithNewOp<arm_sme::InsertTileSliceOp>(
        insertOp, tileSlice, insertOp.getDest(), sliceIndex);
    return success();
  }
};

/// Lowers `vector.print` of a tile into a loop over the rows of the tile,
/// extracting them via `arm_sme.extract_tile_slice`, then printing with
/// a 1D `vector.print`.
///
/// BEFORE:
/// ```mlir
/// vector.print %tile : vector<[4]x[4]xf32>
/// ```
/// AFTER:
/// ```mlir
/// %c0 = arith.constant 0 : index
/// %c1 = arith.constant 1 : index
/// %c4 = arith.constant 4 : index
/// %vscale = vector.vscale
/// %svl_s = arith.muli %c4, %vscale : index
/// scf.for %i = %c0 to %svl_s step %c1 {
///   %tile_slice = arm_sme.extract_tile_slice %tile[%i]
///                   : vector<[4]xf32> from vector<[4]x[4]xf32>
///   vector.print %tile_slice : vector<[4]xf32>
/// }
/// ```
struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
  using OpRewritePattern<vector::PrintOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::PrintOp printOp,
                                PatternRewriter &rewriter) const override {
    // Only prints of a vector value (not punctuation-only prints) apply.
    if (!printOp.getSource())
      return failure();

    VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
    if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
      return failure();

    auto loc = printOp.getLoc();

    // Create a loop over the rows of the tile (dim 0 size x vscale rows).
    auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
    auto minTileRows =
        rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
    {
      // Loop body.
      rewriter.setInsertionPointToStart(forOp.getBody());
      // Extract the current row from the tile.
      Value rowIndex = forOp.getInductionVar();
      auto tileSlice = rewriter.create<arm_sme::ExtractTileSliceOp>(
          loc, printOp.getSource(), rowIndex);
      // Print the row with a 1D vector.print.
      rewriter.create<vector::PrintOp>(loc, tileSlice,
                                       printOp.getPunctuation());
    }

    rewriter.eraseOp(printOp);
    return success();
  }
};

/// Folds an ExtractTileSliceOp + TransferWriteOp to a StoreTileSliceOp.
///
/// BEFORE:
/// ```mlir
/// %slice = arm_sme.extract_tile_slice %tile[%index]
///            : vector<[4]xf32> from vector<[4]x[4]xf32>
/// vector.transfer_write %slice, %memref[%i, %j], %mask {in_bounds = [true]}
///   : vector<[4]xf32>, memref<?x?xf32>
/// ```
/// AFTER:
/// ```mlir
/// arm_sme.store_tile_slice %tile, %index, %mask, %memref[%i, %j]
///   : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
/// ```
struct FoldTransferWriteOfExtractTileSlice
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const final {
    if (!isa<MemRefType>(writeOp.getSource().getType()))
      return rewriter.notifyMatchFailure(writeOp, "destination not a memref");

    if (writeOp.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(writeOp,
                                         "not inbounds transfer write");

    // The written vector must come directly from an extracted tile slice.
    auto extractTileSlice =
        writeOp.getVector().getDefiningOp<arm_sme::ExtractTileSliceOp>();
    if (!extractTileSlice)
      return rewriter.notifyMatchFailure(
          writeOp, "vector to store not from ExtractTileSliceOp");

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp,
                                         "unsupported permutation map");

    // arm_sme.store_tile_slice always takes a mask; synthesize an all-true
    // mask when the transfer_write is unmasked.
    Value mask = writeOp.getMask();
    if (!mask) {
      auto maskType = writeOp.getVectorType().clone(rewriter.getI1Type());
      mask = rewriter.create<arith::ConstantOp>(
          writeOp.getLoc(), maskType, DenseElementsAttr::get(maskType, true));
    }

    rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
        writeOp, extractTileSlice.getTile(),
        extractTileSlice.getTileSliceIndex(), mask, writeOp.getSource(),
        writeOp.getIndices(), extractTileSlice.getLayout());
    return success();
  }
};

/// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to
/// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or
/// SVE 2.1), so this is currently the most logical place for this lowering.
///
/// Example:
/// ```mlir
/// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
/// %slice = vector.extract %mask[%index]
///            : vector<[8]xi1> from vector<[4]x[8]xi1>
/// ```
/// Becomes:
/// ```
/// %mask_rows = vector.create_mask %a : vector<[4]xi1>
/// %mask_cols = vector.create_mask %b : vector<[8]xi1>
/// %slice = arm_sve.psel %mask_cols, %mask_rows[%index]
///            : vector<[8]xi1>, vector<[4]xi1>
/// ```
struct ExtractFromCreateMaskToPselLowering
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    if (extractOp.getNumIndices() != 1)
      return rewriter.notifyMatchFailure(extractOp, "not single extract index");

    auto resultType = extractOp.getResult().getType();
    auto resultVectorType = dyn_cast<VectorType>(resultType);
    if (!resultVectorType)
      return rewriter.notifyMatchFailure(extractOp, "result not VectorType");

    auto createMaskOp =
        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");

    auto maskType = createMaskOp.getVectorType();
    if (maskType.getRank() != 2 || !maskType.allDimsScalable())
      return rewriter.notifyMatchFailure(createMaskOp, "not 2-D scalable mask");

    // A dim can map onto an SVE predicate only if its base size is a power of
    // two no larger than 16 (the predicate lane count at the minimum 128-bit
    // vector length).
    auto isSVEPredicateSize = [](int64_t size) {
      return size > 0 && size <= 16 && llvm::isPowerOf2_32(uint32_t(size));
    };

    auto rowsBaseSize = maskType.getDimSize(0);
    auto colsBaseSize = maskType.getDimSize(1);
    if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize))
      return rewriter.notifyMatchFailure(
          createMaskOp, "mask dimensions not SVE predicate-sized");

    auto loc = extractOp.getLoc();
    VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1);
    VectorType colMaskType = VectorType::Builder(maskType).dropDim(0);

    // Create the two 1-D masks at the location of the 2-D create_mask (which is
    // usually outside a loop). This prevents the need for later hoisting.
    rewriter.setInsertionPoint(createMaskOp);
    auto rowMask = rewriter.create<vector::CreateMaskOp>(
        loc, rowMaskType, createMaskOp.getOperand(0));
    auto colMask = rewriter.create<vector::CreateMaskOp>(
        loc, colMaskType, createMaskOp.getOperand(1));

    // Switch back to the extract's position to emit the psel.
    rewriter.setInsertionPoint(extractOp);
    auto position =
        vector::getAsValues(rewriter, loc, extractOp.getMixedPosition());
    rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colMask, rowMask,
                                                 position[0]);
    return success();
  }
};

} // namespace

/// Populates `patterns` with all Vector -> ArmSME lowering patterns defined in
/// this file.
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                          MLIRContext &ctx) {
  patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
               TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
               TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
               VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
               VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
               VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
               ExtractFromCreateMaskToPselLowering>(&ctx);
}