1 //===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements lowering of ArmSME operations to LLVM intrinsics. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h" 14 15 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 16 #include "mlir/Conversion/LLVMCommon/Pattern.h" 17 #include "mlir/Dialect/Arith/IR/Arith.h" 18 #include "mlir/Dialect/ArmSME/IR/ArmSME.h" 19 #include "mlir/Dialect/ArmSME/Transforms/Transforms.h" 20 #include "mlir/Dialect/ArmSME/Utils/Utils.h" 21 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 22 #include "mlir/Dialect/Func/IR/FuncOps.h" 23 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 24 #include "mlir/Dialect/MemRef/IR/MemRef.h" 25 #include "mlir/Dialect/Vector/IR/VectorOps.h" 26 #include "mlir/Pass/Pass.h" 27 #include "mlir/Transforms/DialectConversion.h" 28 #include "llvm/ADT/ScopeExit.h" 29 30 namespace mlir { 31 #define GEN_PASS_DEF_CONVERTARMSMETOLLVM 32 #include "mlir/Conversion/Passes.h.inc" 33 } // namespace mlir 34 35 using namespace mlir; 36 37 namespace { 38 39 static constexpr StringLiteral kInMemoryTileIdAttr("arm_sme.in_memory_tile_id"); 40 41 /// Helper to create an arm_sme.intr.ld1*.(horiz|vert)' intrinsic. 
42 static Operation *createLoadTileSliceIntrinsic( 43 RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type, 44 arm_sme::TileSliceLayout layout, Value maskOp, Value ptr, 45 IntegerAttr tileId, Value tileSliceI32) { 46 if (layout == arm_sme::TileSliceLayout::Horizontal) { 47 switch (type) { 48 case arm_sme::ArmSMETileType::ZAB: 49 return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>( 50 loc, maskOp, ptr, tileId, tileSliceI32); 51 case arm_sme::ArmSMETileType::ZAH: 52 return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>( 53 loc, maskOp, ptr, tileId, tileSliceI32); 54 case arm_sme::ArmSMETileType::ZAS: 55 return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>( 56 loc, maskOp, ptr, tileId, tileSliceI32); 57 case arm_sme::ArmSMETileType::ZAD: 58 return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>( 59 loc, maskOp, ptr, tileId, tileSliceI32); 60 case arm_sme::ArmSMETileType::ZAQ: 61 return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>( 62 loc, maskOp, ptr, tileId, tileSliceI32); 63 } 64 } else { 65 switch (type) { 66 case arm_sme::ArmSMETileType::ZAB: 67 return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>( 68 loc, maskOp, ptr, tileId, tileSliceI32); 69 case arm_sme::ArmSMETileType::ZAH: 70 return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>( 71 loc, maskOp, ptr, tileId, tileSliceI32); 72 case arm_sme::ArmSMETileType::ZAS: 73 return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>( 74 loc, maskOp, ptr, tileId, tileSliceI32); 75 case arm_sme::ArmSMETileType::ZAD: 76 return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>( 77 loc, maskOp, ptr, tileId, tileSliceI32); 78 case arm_sme::ArmSMETileType::ZAQ: 79 return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>( 80 loc, maskOp, ptr, tileId, tileSliceI32); 81 break; 82 } 83 } 84 llvm_unreachable("unknown type in createLoadTileSliceIntrinsic"); 85 } 86 87 /// Helper to create an arm_sme.intr.st1*.(horiz|vert)' intrinsic. 
static Operation *createStoreTileSliceIntrinsic(
    RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
    arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
    IntegerAttr tileId, Value tileSliceI32) {
  // Dispatch on layout first, then element size; each case returns, so the
  // llvm_unreachable below is only hit for a corrupt enum value.
  if (layout == arm_sme::TileSliceLayout::Horizontal) {
    switch (type) {
    case arm_sme::ArmSMETileType::ZAB:
      return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAH:
      return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAS:
      return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAD:
      return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAQ:
      return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    }
  } else {
    switch (type) {
    case arm_sme::ArmSMETileType::ZAB:
      return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAH:
      return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAS:
      return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAD:
      return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    case arm_sme::ArmSMETileType::ZAQ:
      return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
          loc, maskOp, ptr, tileId, tileSliceI32);
    }
  }
  llvm_unreachable("unknown type in createStoreTileSliceIntrinsic");
}

/// Returns the tile ID of `op`, or a null attribute if none was allocated
/// (emitting an op error in that case). Tile allocation must run before this
/// conversion, so a null result is a hard failure for the calling pattern.
IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
  auto tileId = op.getTileId();
  if (!tileId)
    op.emitOpError(
        "expected tile ID to be allocated before conversion to LLVM");
  return tileId;
}

/// Creates an alloca matching the size of tile used by `tileOp`. The alloca is
/// placed in the first block of the function.
static memref::AllocaOp
createAllocaForTile(RewriterBase &rewriter, Location loc,
                    FunctionOpInterface func,
                    arm_sme::ArmSMETileOpInterface tileOp) {
  RewriterBase::InsertionGuard g(rewriter);
  // Move to the first operation in the function.
  rewriter.setInsertionPointToStart(&func.getBlocks().front());
  // Create an alloca matching the tile size of the `tileOp`. Both dimensions
  // are dynamic since the tile size depends on the runtime vector length.
  auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
  auto tileElementType = tileOp.getTileType().getElementType();
  auto memrefType = MemRefType::get(
      {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
  unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType);
  auto minElementsOp =
      rewriter.create<arith::ConstantIndexOp>(loc, minElements);
  // Each side of the (square) tile is `minElements * vscale` long.
  auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp);
  auto alloca = rewriter.create<memref::AllocaOp>(
      loc, memrefType, ValueRange{vectorLen, vectorLen});
  return alloca;
}

/// Finds or creates an alloca for a spill of a tile.
static memref::AllocaOp getOrCreateAllocaForTile(
    RewriterBase &rewriter, Location loc, FunctionOpInterface func,
    arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {
  // Find an alloca at the top of the function tagged with a
  // 'arm_sme.in_memory_tile_id' that matches `tileId`.
  for (auto &op : func.getBlocks().front()) {
    auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
    if (!alloca)
      continue;
    auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
        alloca->getDiscardableAttr(kInMemoryTileIdAttr));
    if (!inMemoryTileId)
      continue;
    if (inMemoryTileId.getInt() == tileId)
      return alloca;
  }
  // Otherwise, create a new alloca (tagged so later lookups can reuse it):
  auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
  alloca->setDiscardableAttr(kInMemoryTileIdAttr,
                             rewriter.getI32IntegerAttr(tileId));
  return alloca;
}

/// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a
/// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning
/// the op to tile 0, then emitting a full tile swap between ZA and memory
/// before + after the tile op.
///
/// Example:
///
///   // Note: <IN MEMORY TILE> = tile ID >= 16.
///   arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
///
/// is converted to:
///   // At function entry:
///   %spill = memref.alloca ... : memref<?x?xty>
///
///   // Around op:
///   scf.for %slice_idx {
///     %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
///     "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
///     vector.store %slice_to_save, %spill[%slice_idx, %c0]
///   }
///   arm_sme.tile_op { tile_id = 0 }
///   scf.for %slice_idx {
///     %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
///     "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx) <{tile_id = 0 : i32}>
///     vector.store %slice_to_save, %spill[%slice_idx, %c0]
///   }
///
/// Note that these spills/fills are not inserted earlier as concept of a
/// register, and the need to swap the contents, can't really be represented
/// correctly at a high level in MLIR.
217 /// 218 /// TODO: Reduce the spills/reloads to single slices where possible (and omit 219 /// redundant reloads). This could be done via a method on the 220 /// `ArmSMETileOpInterface` which returns how the operation uses ZA. E.g.: 221 /// 222 /// `tileOp.getZaUsage()` could return: 223 /// 224 /// struct ArmSMEOpZAUsage { 225 /// enum class Kind { 226 /// TileRead, // Omit store after tile operation. 227 /// TileWrite, // Omit load before tile operation. 228 /// TileReadWrite, // Needs both tile load and store. 229 /// SliceRead, // Spill single slice and omit store after operation. 230 /// SliceWrite, // Spill single slice and omit load before operation. 231 /// SliceReadWrite // Spill single slice. 232 /// }; 233 /// Value sliceIndex {}; 234 /// TileSliceLayout sliceLayout { TileSliceLayout::Horizontal }; 235 /// }; 236 /// 237 struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { 238 239 ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName, 240 const LLVMTypeConverter &typeConverter, 241 PatternBenefit benefit) 242 : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(), 243 typeConverter, benefit) {} 244 245 LogicalResult 246 matchAndRewrite(Operation *op, ArrayRef<Value> operands, 247 ConversionPatternRewriter &rewriter) const override { 248 auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op); 249 // Tile has a real (hardware) tile. No spills/reloads required. 250 if (!tileOp.isInMemoryTile()) 251 return failure(); 252 253 tileOp->emitWarning( 254 "failed to allocate SME virtual tile to operation, tile value will go " 255 "through memory, expect degraded performance"); 256 257 // Step 1. Create an alloca for the tile at the top of the function (if one 258 // does not already exist). 259 auto loc = tileOp.getLoc(); 260 auto func = tileOp->getParentOfType<FunctionOpInterface>(); 261 auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp, 262 tileOp.getTileId().getInt()); 263 264 // Step 2. 
Assign the op a real tile ID. 265 // For simplicity, we always use tile 0 (which always exists). 266 auto zeroTileId = rewriter.getI32IntegerAttr(0); 267 rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); }); 268 269 VectorType tileVectorType = tileOp.getTileType(); 270 auto sliceType = VectorType::Builder(tileVectorType).dropDim(0); 271 auto swapInMemoryTileWithSMETileZero = [&] { 272 emitFullTileSwap(rewriter, loc, tileAlloca, 273 *arm_sme::getSMETileType(tileVectorType), sliceType, 274 zeroTileId); 275 }; 276 277 // Step 3. Emit tile swaps before and after the op. 278 // TODO: Reduce the amount spilled to the amount of data the `tileOp` 279 // touches (i.e. a single tile slice). 280 { 281 rewriter.setInsertionPoint(op); 282 // Swap the contents of ZA and the in-memory tile before the op. 283 swapInMemoryTileWithSMETileZero(); 284 rewriter.setInsertionPointAfter(op); 285 // Swap the tile back out to memory again after the op. 286 swapInMemoryTileWithSMETileZero(); 287 } 288 289 return success(); 290 } 291 292 /// Extracts a pointer to a slice of an in-memory tile. 293 Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc, 294 Value tileMemory, Value sliceIndex) const { 295 auto llvmType = getTypeConverter()->convertType(tileMemory.getType()); 296 auto descriptor = 297 rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory); 298 auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64); 299 auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>( 300 loc, rewriter.getI64Type(), sliceIndex); 301 return getStridedElementPtr( 302 loc, llvm::cast<MemRefType>(tileMemory.getType()), 303 descriptor.getResult(0), {sliceIndexI64, zero}, 304 static_cast<ConversionPatternRewriter &>(rewriter)); 305 } 306 307 /// Emits an in-place swap of a slice of a tile in ZA and a slice of a 308 /// tile-sized memref (`tileAlloca`). 
309 void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca, 310 arm_sme::ArmSMETileType tileType, VectorType sliceType, 311 IntegerAttr tileId, Value sliceIndex) const { 312 // Cast the slice index to an i32. 313 auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>( 314 loc, rewriter.getI32Type(), sliceIndex); 315 // Create an all-true predicate for the slice. 316 auto predicateType = sliceType.clone(rewriter.getI1Type()); 317 auto allTruePredicate = rewriter.create<arith::ConstantOp>( 318 loc, DenseElementsAttr::get(predicateType, true)); 319 // Create padding vector (never used due to all-true predicate). 320 auto padVector = rewriter.create<LLVM::UndefOp>(loc, sliceType); 321 // Get a pointer to the current slice. 322 auto slicePtr = 323 getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex); 324 // Read the value of the current slice from ZA. 325 auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>( 326 loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32); 327 // Load the new tile slice back from memory into ZA. 328 createLoadTileSliceIntrinsic( 329 rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal, 330 allTruePredicate, slicePtr, tileId, sliceIndexI32); 331 // Store the current tile slice to memory. 332 auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); 333 rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca, 334 ValueRange{sliceIndex, zero}); 335 } 336 337 /// Emits a full in-place swap of the contents of a tile in ZA and a 338 /// tile-sized memref (`tileAlloca`). 339 void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca, 340 arm_sme::ArmSMETileType tileType, VectorType sliceType, 341 IntegerAttr tileId) const { 342 RewriterBase::InsertionGuard guard(rewriter); 343 // Create an scf.for over all tile slices. 
344 auto minNumElts = 345 rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0)); 346 auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); 347 auto upperBound = rewriter.create<arith::MulIOp>( 348 loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc)); 349 auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); 350 auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step); 351 // Emit a swap for each tile slice. 352 rewriter.setInsertionPointToStart(forOp.getBody()); 353 auto sliceIndex = forOp.getInductionVar(); 354 emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId, 355 sliceIndex); 356 } 357 }; 358 359 enum class RequiresSpillsAndFills { Yes, No }; 360 361 /// Base class for ArmSME to LLVM conversion patterns. By default, this adds 362 /// spills and fills around ArmSME ops that use in-memory tile IDs. This can be 363 /// disabled by setting the `requiresSpillsAndFills` template parameter to 364 /// `RequiresSpillsAndFills::No`. 365 template <typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills = 366 RequiresSpillsAndFills::Yes> 367 struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern<SourceOp> { 368 using ArmSMEOp = SourceOp; 369 using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern; 370 371 static constexpr bool requiresSpillsAndFillsConversion() { 372 return requiresSpillsAndFills == RequiresSpillsAndFills::Yes; 373 } 374 }; 375 376 template <typename Pattern> 377 static void addArmSMEConversionPattern(RewritePatternSet &patterns, 378 LLVMTypeConverter const &typeConverter) { 379 // Register spills/fills for ops that implement the 380 // `ArmSMETileOpInterface` and have `requiresSpillsAndFills` set to 381 // `RequiresSpillsAndFills::Yes`. 
382 if constexpr (Pattern::requiresSpillsAndFillsConversion() && 383 std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait< 384 typename Pattern::ArmSMEOp>, 385 typename Pattern::ArmSMEOp>) { 386 // Add spill/fill conversions with a very high benefit to ensure 387 // they are lowered first. 388 patterns.add<ConvertArmSMESpillsAndFillsToLLVM>( 389 Pattern::ArmSMEOp::getOperationName(), typeConverter, 390 /*benefit=*/1337); 391 } 392 patterns.add<Pattern>(typeConverter); 393 } 394 395 /// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns. 396 template <typename... Patterns> 397 static void 398 addArmSMEConversionPatterns(RewritePatternSet &patterns, 399 LLVMTypeConverter const &typeConverter) { 400 (addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...); 401 } 402 403 /// Lower 'arm_sme.zero' to SME intrinsics. 404 /// 405 /// BEFORE: 406 /// ```mlir 407 /// %v = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32> 408 /// ``` 409 /// 410 /// AFTER: 411 /// ```mlir 412 /// "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> () 413 /// %v = arm_sme.get_tile : vector<[4]x[4]xi32> 414 /// ``` 415 /// 416 /// The 'arm_sme.get_tile' (which models the return) will fold away once all 417 /// ArmSME ops have been converted to LLVM intrinsics. 418 struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> { 419 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; 420 421 LogicalResult 422 matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor, 423 ConversionPatternRewriter &rewriter) const override { 424 auto loc = zero.getLoc(); 425 426 auto tileId = getTileIdOrError(zero); 427 if (!tileId) 428 return failure(); 429 430 // Get the base mask for tile based on the element size. 431 // The base mask is just the mask to zero the first tile (of a size). 
432 // These masks are derived from: 433 // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles- 434 arm_sme::ArmSMETileType tileType = 435 *arm_sme::getSMETileType(zero.getTileType()); 436 auto baseMaskForSize = [&] { 437 switch (tileType) { 438 case arm_sme::ArmSMETileType::ZAB: 439 // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight 440 // 64-bit element tiles named ZA0.D to ZA7.D. 441 return 0b1111'1111; 442 case arm_sme::ArmSMETileType::ZAH: 443 // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit 444 // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left 445 // once for ZA1.H. 446 return 0b0101'0101; 447 case arm_sme::ArmSMETileType::ZAS: 448 // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit 449 // element tiles named ZA0.D and ZA4.D. 450 // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S. 451 return 0b0001'0001; 452 case arm_sme::ArmSMETileType::ZAD: 453 // Zeroing one of the a 64-bit tiles ZA0.D to ZA7.D just requires 454 // setting the bit for that tile. 455 return 0b0000'0001; 456 default: 457 llvm_unreachable("bad element size"); 458 } 459 }(); 460 461 // The actual mask is just the base mask shifted by the tile ID. 462 // This will be folded to a constant after tile allocation. 463 // 464 // The shift is just derived from the layout of the tiles, and that the tile 465 // ID is the index of the tile. For example, looking at the 32-bit ZAx.S 466 // tiles: 467 // 468 // ZA0.S = ZA0.D and ZA4.D 469 // * Tile ID -> 0 470 // * Mask -> 00010001 = (00010001 << 0) 471 // ZA1.S = ZA1.D and ZA5.D 472 // * Tile ID -> 1 473 // * Mask -> 00100010 = (00010001 << 1) 474 // ZA2.S = ZA2.D and ZA6.D 475 // * Tile ID -> 2 476 // * Mask -> 01000100 = (00010001 << 2) 477 // ZA3.S = ZA3.D and ZA7.D 478 // * Tile ID -> 3 479 // * Mask -> 10001000 = (00010001 << 3) 480 // 481 // This holds for all tile sizes. 
482 int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt()); 483 rewriter.create<arm_sme::aarch64_sme_zero>( 484 loc, rewriter.getI32IntegerAttr(zeroMask)); 485 486 // Create a placeholder op to preserve dataflow. 487 // Note: Place the `get_tile` op at the start of the block. This ensures 488 // that if there are multiple `zero` ops the intrinsics will be consecutive. 489 rewriter.setInsertionPointToStart(zero->getBlock()); 490 rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType()); 491 492 return success(); 493 } 494 }; 495 496 /// Lower `arm_sme.load_tile_slice` to SME intrinsics. 497 struct LoadTileSliceConversion 498 : public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> { 499 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; 500 501 LogicalResult 502 matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp, 503 arm_sme::LoadTileSliceOp::Adaptor adaptor, 504 ConversionPatternRewriter &rewriter) const override { 505 auto loc = loadTileSliceOp.getLoc(); 506 auto tileId = getTileIdOrError(loadTileSliceOp); 507 if (!tileId) 508 return failure(); 509 510 Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(), 511 adaptor.getBase(), 512 adaptor.getIndices(), rewriter); 513 514 auto tileSlice = loadTileSliceOp.getTileSliceIndex(); 515 516 // Cast tile slice to i32 for intrinsic. 517 auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>( 518 loc, rewriter.getI32Type(), tileSlice); 519 520 // Create all active predicate mask. 521 auto maskOp = loadTileSliceOp.getMask(); 522 523 auto tileVectorType = loadTileSliceOp.getVectorType(); 524 arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType); 525 arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout(); 526 527 // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice. 
528 createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr, 529 tileId, tileSliceI32); 530 531 // The load intrinsics have no result, replace 'arm_sme.tile_load' with 532 // the input tile to preserve dataflow. 533 rewriter.replaceOp(loadTileSliceOp, loadTileSliceOp.getTile()); 534 535 return success(); 536 } 537 }; 538 539 /// Lower for `arm_sme.store_tile_slice` to SME intrinsics. 540 struct StoreTileSliceConversion 541 : public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> { 542 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; 543 544 LogicalResult 545 matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp, 546 arm_sme::StoreTileSliceOp::Adaptor adaptor, 547 ConversionPatternRewriter &rewriter) const override { 548 auto loc = storeTileSliceOp.getLoc(); 549 auto tileVectorType = storeTileSliceOp.getVectorType(); 550 551 auto tileId = getTileIdOrError(storeTileSliceOp); 552 if (!tileId) 553 return failure(); 554 555 // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice. 556 Value ptr = this->getStridedElementPtr( 557 loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(), 558 adaptor.getIndices(), rewriter); 559 560 auto tileSlice = storeTileSliceOp.getTileSliceIndex(); 561 562 // Cast tile slice to i32 for intrinsic. 563 auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>( 564 loc, rewriter.getI32Type(), tileSlice); 565 566 auto maskOp = storeTileSliceOp.getMask(); 567 568 arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout(); 569 arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType); 570 571 rewriter.replaceOp(storeTileSliceOp, 572 createStoreTileSliceIntrinsic(rewriter, loc, tileType, 573 layout, maskOp, ptr, 574 tileId, tileSliceI32)); 575 576 return success(); 577 } 578 }; 579 580 /// Lower `arm_sme.insert_tile_slice` to SME intrinsics. 
581 struct InsertTileSliceConversion 582 : public ConvertArmSMEOpToLLVMPattern<arm_sme::InsertTileSliceOp> { 583 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; 584 585 LogicalResult 586 matchAndRewrite(arm_sme::InsertTileSliceOp insertTileSliceOp, 587 arm_sme::InsertTileSliceOp::Adaptor adaptor, 588 ConversionPatternRewriter &rewriter) const override { 589 auto loc = insertTileSliceOp.getLoc(); 590 auto tileType = insertTileSliceOp.getTileType(); 591 592 auto tileId = getTileIdOrError(insertTileSliceOp); 593 if (!tileId) 594 return failure(); 595 596 auto tileSlice = insertTileSliceOp.getTileSliceIndex(); 597 598 // Cast tile slice from index to i32 for intrinsic. 599 auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>( 600 loc, rewriter.getI32Type(), tileSlice); 601 602 // Create all active predicate mask. 603 auto one = rewriter.create<arith::ConstantOp>( 604 loc, rewriter.getI1Type(), 605 rewriter.getIntegerAttr(rewriter.getI1Type(), 1)); 606 auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(), 607 /*scalableDims=*/{true}); 608 auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one); 609 610 // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice. 611 switch (insertTileSliceOp.getLayout()) { 612 case arm_sme::TileSliceLayout::Horizontal: 613 rewriter.create<arm_sme::aarch64_sme_write_horiz>( 614 loc, tileId, tileSliceI32, allActiveMask, 615 insertTileSliceOp.getVector()); 616 break; 617 case arm_sme::TileSliceLayout::Vertical: 618 rewriter.create<arm_sme::aarch64_sme_write_vert>( 619 loc, tileId, tileSliceI32, allActiveMask, 620 insertTileSliceOp.getVector()); 621 break; 622 } 623 624 // Intrinsic has no result, replace 'arm_sme.insert_tile_slice' with 625 // the input tile to preserve dataflow. 626 rewriter.replaceOp(insertTileSliceOp, insertTileSliceOp.getTile()); 627 628 return success(); 629 } 630 }; 631 632 /// Lower `arm_sme.extract_tile_slice` to SME intrinsics. 
633 struct ExtractTileSliceConversion 634 : public ConvertArmSMEOpToLLVMPattern<arm_sme::ExtractTileSliceOp> { 635 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; 636 637 LogicalResult 638 matchAndRewrite(arm_sme::ExtractTileSliceOp extractTileSlice, OpAdaptor, 639 ConversionPatternRewriter &rewriter) const override { 640 auto loc = extractTileSlice.getLoc(); 641 auto sliceType = extractTileSlice.getSliceType(); 642 auto sliceIndex = extractTileSlice.getTileSliceIndex(); 643 644 auto tileId = getTileIdOrError(extractTileSlice); 645 if (!tileId) 646 return failure(); 647 648 // Create an 'all true' predicate for the tile slice. 649 auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type()); 650 auto allTruePredicate = rewriter.create<arith::ConstantOp>( 651 loc, DenseElementsAttr::get(predicateType, true)); 652 653 // Zero destination/fallback for tile slice extraction. 654 auto zeroVector = rewriter.create<arith::ConstantOp>( 655 loc, sliceType, rewriter.getZeroAttr(sliceType)); 656 657 // Cast tile slice from index to i32 for intrinsic. 658 auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>( 659 loc, rewriter.getI32Type(), sliceIndex); 660 661 // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice. 662 switch (extractTileSlice.getLayout()) { 663 case arm_sme::TileSliceLayout::Horizontal: 664 rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>( 665 extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId, 666 sliceIndexI32); 667 break; 668 case arm_sme::TileSliceLayout::Vertical: 669 rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>( 670 extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId, 671 sliceIndexI32); 672 break; 673 } 674 675 return success(); 676 } 677 }; 678 679 /// Lower `arm_sme.outerproduct` to SME MOPA intrinsics. 
680 /// 681 /// Example: 682 /// 683 /// %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc) 684 /// : vector<[4]xf32>, vector<[4]xf32> 685 /// 686 /// is converted to: 687 /// 688 /// "arm_sme.intr.mopa"(%ptrue_s, %ptrue_s, %lhs, %rhs) <{tile_id = 0 : i32}> 689 /// : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, 690 /// vector<[4]xf32>) -> () 691 /// 692 /// Currently only supports FMOPA and BFMOPA (non-widening). 693 struct OuterProductOpConversion 694 : public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> { 695 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; 696 697 LogicalResult 698 matchAndRewrite(arm_sme::OuterProductOp outerProductOp, 699 arm_sme::OuterProductOp::Adaptor adaptor, 700 ConversionPatternRewriter &rewriter) const override { 701 auto tileId = getTileIdOrError(outerProductOp); 702 if (!tileId) 703 return failure(); 704 705 auto isSupportedType = [](VectorType vectorType) { 706 // TODO: the FP outer product instruction variants are predicated on 707 // different features [1]: 708 // 709 // * FMOPA (non-widening) 710 // * half-precision - +sme2p1,+sme-f16f16 711 // * single-precision - +sme 712 // * double-precision - +sme-f64f64 713 // * BFMOPA 714 // * half-precision - +sme2p1,+b16b16 715 // 716 // It should be possible to control lowering based on target features. 717 // [1] 718 // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile 719 if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable()) 720 return false; 721 722 auto elementType = vectorType.getElementType(); 723 724 if (!elementType.isF16() && !elementType.isBF16() && 725 !elementType.isF32() && !elementType.isF64()) 726 return false; 727 728 unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits / 729 vectorType.getElementTypeBitWidth(); 730 return vectorType.getShape() == 731 ArrayRef<int64_t>({minNumElts, minNumElts}); 732 }; 733 734 // TODO: Support CombiningKind::Sub for outer products. 
735 if (outerProductOp.getKind() != arm_sme::CombiningKind::Add) 736 return outerProductOp.emitError("unsupported kind"); 737 738 auto resultVectorType = outerProductOp.getResultType(); 739 if (!isSupportedType(resultVectorType)) 740 return outerProductOp.emitError("unsupported type"); 741 742 auto loc = outerProductOp.getLoc(); 743 744 Value acc = outerProductOp.getAcc(); 745 if (!acc) { 746 // Initalize accumulator with zero. 747 auto zero = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType); 748 zero.setTileId(tileId); 749 acc = zero; 750 } 751 752 Value lhsMask = outerProductOp.getLhsMask(); 753 Value rhsMask = outerProductOp.getRhsMask(); 754 755 if (!lhsMask || !rhsMask) { 756 auto predTy = 757 outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type()); 758 Value allActiveMask = rewriter.create<arith::ConstantOp>( 759 loc, DenseElementsAttr::get(predTy, true)); 760 lhsMask = allActiveMask; 761 rhsMask = allActiveMask; 762 } 763 764 // Create 'arm_sme.intr.mopa' outer product intrinsic. 765 rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask, 766 outerProductOp.getLhs(), 767 outerProductOp.getRhs()); 768 769 // The outerproduct intrinsics have no result, replace 770 // 'arm_sme.outerproduct' with the input tile to preserve dataflow. 771 rewriter.replaceOp(outerProductOp, acc); 772 773 return success(); 774 } 775 }; 776 777 /// Lower 2-way and 4-way widening outer products to intrinsics. 
template <class OuterProductWideningOp, class OuterProductWideningIntrOp>
struct OuterProductWideningOpConversion
    : public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
  using ConvertArmSMEOpToLLVMPattern<
      OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(OuterProductWideningOp op,
                  typename OuterProductWideningOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileId = getTileIdOrError(op);
    if (!tileId)
      return failure();

    auto loc = op.getLoc();
    Value acc = op.getAcc();
    if (!acc) {
      // Initialize accumulator with zero.
      auto zero = rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType());
      zero.setTileId(tileId);
      acc = zero;
    }

    // When no masks are given, use a single all-active predicate for both
    // sides.
    Value lhsMask = op.getLhsMask();
    Value rhsMask = op.getRhsMask();
    if (!lhsMask || !rhsMask) {
      auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
      Value allActiveMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(predTy, true));
      lhsMask = allActiveMask;
      rhsMask = allActiveMask;
    }

    rewriter.create<OuterProductWideningIntrOp>(
        loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());

    // The widening outer product intrinsics have no result; replace the op
    // with the accumulator tile to preserve dataflow.
    rewriter.replaceOp(op, acc);

    return success();
  }
};

/// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
823 /// 824 /// Example: 825 /// 826 /// %0 = arm_sme.streaming_vl <half> 827 /// 828 /// is converted to: 829 /// 830 /// %cnt = "arm_sme.intr.cntsh"() : () -> i64 831 /// %0 = arith.index_cast %cnt : i64 to index 832 /// 833 struct StreamingVLOpConversion 834 : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp, 835 RequiresSpillsAndFills::No> { 836 using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern; 837 838 LogicalResult 839 matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp, 840 arm_sme::StreamingVLOp::Adaptor adaptor, 841 ConversionPatternRewriter &rewriter) const override { 842 auto loc = streamingVlOp.getLoc(); 843 auto i64Type = rewriter.getI64Type(); 844 auto *intrOp = [&]() -> Operation * { 845 switch (streamingVlOp.getTypeSize()) { 846 case arm_sme::TypeSize::Byte: 847 return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type); 848 case arm_sme::TypeSize::Half: 849 return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type); 850 case arm_sme::TypeSize::Word: 851 return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type); 852 case arm_sme::TypeSize::Double: 853 return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type); 854 } 855 llvm_unreachable("unknown type size in StreamingVLOpConversion"); 856 }(); 857 rewriter.replaceOpWithNewOp<arith::IndexCastOp>( 858 streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0)); 859 return success(); 860 } 861 }; 862 863 /// Merges consecutive `arm_sme.intr.zero` operations in a block by bitwise 864 /// or-ing the zero masks. Note: In future the backend _should_ handle this. 
static void mergeConsecutiveTileZerosInBlock(Block *block) {
  // Running OR of the tile masks of the zero ops collected so far.
  uint32_t mergedZeroMask = 0;
  SmallVector<arm_sme::aarch64_sme_zero, 16> zeroOpsToMerge;
  // Replaces the collected run of zero ops with a single zero op carrying the
  // merged mask. The scope_exit guarantees the accumulator state is reset on
  // every call, even when there is nothing (or only one op) to merge.
  auto replaceMergedZeroOps = [&] {
    auto cleanup = llvm::make_scope_exit([&] {
      mergedZeroMask = 0;
      zeroOpsToMerge.clear();
    });
    // A run of zero or one zero ops needs no merging.
    if (zeroOpsToMerge.size() <= 1)
      return;
    // Insert the merged op at the position of the first op in the run.
    IRRewriter rewriter(zeroOpsToMerge.front());
    rewriter.create<arm_sme::aarch64_sme_zero>(
        zeroOpsToMerge.front().getLoc(),
        rewriter.getI32IntegerAttr(mergedZeroMask));
    for (auto zeroOp : zeroOpsToMerge)
      rewriter.eraseOp(zeroOp);
  };
  // Scan the block: adjacent zero ops accumulate into the current run; any
  // other op ends the run and triggers a merge.
  for (Operation &op : *block) {
    if (auto zeroOp = dyn_cast<arm_sme::aarch64_sme_zero>(op)) {
      mergedZeroMask |= zeroOp.getTileMask();
      zeroOpsToMerge.push_back(zeroOp);
    } else {
      replaceMergedZeroOps();
    }
  }
  // Flush a run that extends to the end of the block.
  replaceMergedZeroOps();
}

} // namespace

namespace {

/// Pass that lowers ArmSME operations to LLVM intrinsics: allocates SME
/// tiles, runs the dialect conversion, merges adjacent zero intrinsics, and
/// verifies no unexpected ops on SME tile types survive the conversion.
struct ConvertArmSMEToLLVMPass
    : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
  ConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
    this->dumpTileLiveRanges = dumpTileLiveRanges;
  }
  void runOnOperation() override {
    auto function = getOperation();

    // Tile ids must be assigned before lowering; the conversion patterns
    // read them (see getTileIdOrError uses above).
    if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
      return signalPassFailure();

    LLVMConversionTarget target(getContext());
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext());
    configureArmSMEToLLVMConversionLegality(target);
    populateArmSMEToLLVMConversionPatterns(converter, patterns);

    // Partial conversion: ops marked legal in
    // configureArmSMEToLLVMConversionLegality may remain.
    if (failed(applyPartialConversion(function, target, std::move(patterns))))
      signalPassFailure();

    // Fold runs of adjacent 'arm_sme.intr.zero' ops into single ops.
    function->walk(mergeConsecutiveTileZerosInBlock);

    // Walk the function and fail if there are unexpected operations on SME
    // tile types after conversion.
    function->walk([&](Operation *op) {
      // These ops are legal post conversion, skip these.
      if (isa<arm_sme::CopyTileOp, arm_sme::GetTileOp, cf::BranchOp>(op) ||
          !op->isRegistered())
        return;
      auto isSMETileType = [](Type type) {
        return arm_sme::isValidSMETileVectorType(type);
      };
      if (llvm::any_of(op->getResultTypes(), isSMETileType) ||
          llvm::any_of(op->getOperandTypes(), isSMETileType)) {
        op->emitOpError("unexpected operation with SME tile type after "
                        "conversion to LLVM");
        signalPassFailure();
      }
    });
  }
};

} // namespace

void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
  // High-level ArmSME ops must be converted away; only the intrinsic
  // (aarch64_sme_*) ops listed below remain legal after this pass.
  target.addIllegalDialect<arm_sme::ArmSMEDialect>();
  target.addLegalOp<
      arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
      arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
      arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
      arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
      arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
      arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
      arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
      arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
      arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
      arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
      arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
      arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
      arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
      arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
      arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
      arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
      arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
      arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
      arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
      arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
      arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
      arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
      arm_sme::aarch64_sme_cntsd>();
  target.addLegalDialect<arith::ArithDialect,
                         /* The following are used to lower tile spills/fills */
                         vector::VectorDialect, scf::SCFDialect,
                         memref::MemRefDialect>();
  // Pseudo operations. These cannot be code-generated but may exist in the
  // input IR, or be generated during the conversion. They need to be eliminated
  // before the final conversion to LLVM IR (and likely will be due to DCE).
  target.addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
                    UnrealizedConversionCastOp>();
}

void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  converter.addConversion([&](VectorType type) -> std::optional<Type> {
    // There's no LLVM type for SME tiles, but after lowering to intrinsics all
    // SME vector types should be eliminated.
    if (arm_sme::isValidSMETileVectorType(type))
      return type;
    return std::nullopt;
  });

  // Register one conversion pattern per ArmSME op; each widening outer
  // product op maps to its corresponding SME intrinsic.
  addArmSMEConversionPatterns<
      LoadTileSliceConversion, ExtractTileSliceConversion,
      InsertTileSliceConversion, StoreTileSliceConversion,
      StreamingVLOpConversion, OuterProductOpConversion,
      OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
                                       arm_sme::aarch64_sme_mopa_wide>,
      OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
                                       arm_sme::aarch64_sme_mops_wide>,
      OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
                                       arm_sme::aarch64_sme_smopa_za32>,
      OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
                                       arm_sme::aarch64_sme_smops_za32>,
      OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
                                       arm_sme::aarch64_sme_umopa_za32>,
      OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
                                       arm_sme::aarch64_sme_umops_za32>,
      OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
                                       arm_sme::aarch64_sme_smopa_wide>,
      OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
                                       arm_sme::aarch64_sme_smops_wide>,
      OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
                                       arm_sme::aarch64_sme_umopa_wide>,
      OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
                                       arm_sme::aarch64_sme_umops_wide>,
      OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
                                       arm_sme::aarch64_sme_sumopa_wide>,
      OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
                                       arm_sme::aarch64_sme_sumops_wide>,
      OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
                                       arm_sme::aarch64_sme_usmopa_wide>,
      OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
                                       arm_sme::aarch64_sme_usmops_wide>,
      ZeroOpConversion>(patterns, converter);
}

std::unique_ptr<Pass>
mlir::createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
  return std::make_unique<ConvertArmSMEToLLVMPass>(dumpTileLiveRanges);
}