xref: /llvm-project/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp (revision d5746d73cedcf7a593dc4b4f2ce2465e2d45750b)
1 //===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of ArmSME operations to LLVM intrinsics.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
14 
15 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
16 #include "mlir/Conversion/LLVMCommon/Pattern.h"
17 #include "mlir/Dialect/Arith/IR/Arith.h"
18 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
19 #include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
20 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
21 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h"
23 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
24 #include "mlir/Dialect/MemRef/IR/MemRef.h"
25 #include "mlir/Dialect/Vector/IR/VectorOps.h"
26 #include "mlir/Pass/Pass.h"
27 #include "mlir/Transforms/DialectConversion.h"
28 #include "llvm/ADT/ScopeExit.h"
29 
30 namespace mlir {
31 #define GEN_PASS_DEF_CONVERTARMSMETOLLVM
32 #include "mlir/Conversion/Passes.h.inc"
33 } // namespace mlir
34 
35 using namespace mlir;
36 
37 namespace {
38 
/// Attribute used to tag allocas that back in-memory (spilled) SME tiles with
/// the tile ID they hold (see `getOrCreateAllocaForTile`).
static constexpr StringLiteral kInMemoryTileIdAttr("arm_sme.in_memory_tile_id");
40 
41 /// Helper to create an arm_sme.intr.ld1*.(horiz|vert)' intrinsic.
42 static Operation *createLoadTileSliceIntrinsic(
43     RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
44     arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
45     IntegerAttr tileId, Value tileSliceI32) {
46   if (layout == arm_sme::TileSliceLayout::Horizontal) {
47     switch (type) {
48     case arm_sme::ArmSMETileType::ZAB:
49       return rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
50           loc, maskOp, ptr, tileId, tileSliceI32);
51     case arm_sme::ArmSMETileType::ZAH:
52       return rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
53           loc, maskOp, ptr, tileId, tileSliceI32);
54     case arm_sme::ArmSMETileType::ZAS:
55       return rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
56           loc, maskOp, ptr, tileId, tileSliceI32);
57     case arm_sme::ArmSMETileType::ZAD:
58       return rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
59           loc, maskOp, ptr, tileId, tileSliceI32);
60     case arm_sme::ArmSMETileType::ZAQ:
61       return rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
62           loc, maskOp, ptr, tileId, tileSliceI32);
63     }
64   } else {
65     switch (type) {
66     case arm_sme::ArmSMETileType::ZAB:
67       return rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(
68           loc, maskOp, ptr, tileId, tileSliceI32);
69     case arm_sme::ArmSMETileType::ZAH:
70       return rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(
71           loc, maskOp, ptr, tileId, tileSliceI32);
72     case arm_sme::ArmSMETileType::ZAS:
73       return rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(
74           loc, maskOp, ptr, tileId, tileSliceI32);
75     case arm_sme::ArmSMETileType::ZAD:
76       return rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(
77           loc, maskOp, ptr, tileId, tileSliceI32);
78     case arm_sme::ArmSMETileType::ZAQ:
79       return rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(
80           loc, maskOp, ptr, tileId, tileSliceI32);
81       break;
82     }
83   }
84   llvm_unreachable("unknown type in createLoadTileSliceIntrinsic");
85 }
86 
87 /// Helper to create an arm_sme.intr.st1*.(horiz|vert)' intrinsic.
88 static Operation *createStoreTileSliceIntrinsic(
89     RewriterBase &rewriter, Location loc, arm_sme::ArmSMETileType type,
90     arm_sme::TileSliceLayout layout, Value maskOp, Value ptr,
91     IntegerAttr tileId, Value tileSliceI32) {
92   if (layout == arm_sme::TileSliceLayout::Horizontal) {
93     switch (type) {
94     case arm_sme::ArmSMETileType::ZAB:
95       return rewriter.create<arm_sme::aarch64_sme_st1b_horiz>(
96           loc, maskOp, ptr, tileId, tileSliceI32);
97     case arm_sme::ArmSMETileType::ZAH:
98       return rewriter.create<arm_sme::aarch64_sme_st1h_horiz>(
99           loc, maskOp, ptr, tileId, tileSliceI32);
100     case arm_sme::ArmSMETileType::ZAS:
101       return rewriter.create<arm_sme::aarch64_sme_st1w_horiz>(
102           loc, maskOp, ptr, tileId, tileSliceI32);
103     case arm_sme::ArmSMETileType::ZAD:
104       return rewriter.create<arm_sme::aarch64_sme_st1d_horiz>(
105           loc, maskOp, ptr, tileId, tileSliceI32);
106     case arm_sme::ArmSMETileType::ZAQ:
107       return rewriter.create<arm_sme::aarch64_sme_st1q_horiz>(
108           loc, maskOp, ptr, tileId, tileSliceI32);
109     }
110   } else {
111     switch (type) {
112     case arm_sme::ArmSMETileType::ZAB:
113       return rewriter.create<arm_sme::aarch64_sme_st1b_vert>(
114           loc, maskOp, ptr, tileId, tileSliceI32);
115     case arm_sme::ArmSMETileType::ZAH:
116       return rewriter.create<arm_sme::aarch64_sme_st1h_vert>(
117           loc, maskOp, ptr, tileId, tileSliceI32);
118     case arm_sme::ArmSMETileType::ZAS:
119       return rewriter.create<arm_sme::aarch64_sme_st1w_vert>(
120           loc, maskOp, ptr, tileId, tileSliceI32);
121     case arm_sme::ArmSMETileType::ZAD:
122       return rewriter.create<arm_sme::aarch64_sme_st1d_vert>(
123           loc, maskOp, ptr, tileId, tileSliceI32);
124     case arm_sme::ArmSMETileType::ZAQ:
125       return rewriter.create<arm_sme::aarch64_sme_st1q_vert>(
126           loc, maskOp, ptr, tileId, tileSliceI32);
127     }
128   }
129   llvm_unreachable("unknown type in createStoreTileSliceIntrinsic");
130 }
131 
132 IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
133   auto tileId = op.getTileId();
134   if (!tileId)
135     op.emitOpError(
136         "expected tile ID to be allocated before conversion to LLVM");
137   return tileId;
138 }
139 
140 /// Creates an alloca matching the size of tile used by `tileOp`. The alloca is
141 /// placed in the first block of the function.
142 static memref::AllocaOp
143 createAllocaForTile(RewriterBase &rewriter, Location loc,
144                     FunctionOpInterface func,
145                     arm_sme::ArmSMETileOpInterface tileOp) {
146   RewriterBase::InsertionGuard g(rewriter);
147   // Move to the first operation in the function.
148   rewriter.setInsertionPointToStart(&func.getBlocks().front());
149   // Create an alloca matching the tile size of the `tileOp`.
150   auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
151   auto tileElementType = tileOp.getTileType().getElementType();
152   auto memrefType = MemRefType::get(
153       {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType);
154   unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType);
155   auto minElementsOp =
156       rewriter.create<arith::ConstantIndexOp>(loc, minElements);
157   auto vectorLen = rewriter.create<arith::MulIOp>(loc, vscale, minElementsOp);
158   auto alloca = rewriter.create<memref::AllocaOp>(
159       loc, memrefType, ValueRange{vectorLen, vectorLen});
160   return alloca;
161 }
162 
163 /// Finds or creates an alloca for a spill of a tile.
164 static memref::AllocaOp getOrCreateAllocaForTile(
165     RewriterBase &rewriter, Location loc, FunctionOpInterface func,
166     arm_sme::ArmSMETileOpInterface tileOp, unsigned tileId) {
167   // Find an alloca at the top of the function tagged with a
168   // 'arm_sme.in_memory_tile_id' that matches `tileId`.
169   for (auto &op : func.getBlocks().front()) {
170     auto alloca = llvm::dyn_cast<memref::AllocaOp>(op);
171     if (!alloca)
172       continue;
173     auto inMemoryTileId = llvm::dyn_cast_or_null<IntegerAttr>(
174         alloca->getDiscardableAttr(kInMemoryTileIdAttr));
175     if (!inMemoryTileId)
176       continue;
177     if (inMemoryTileId.getInt() == tileId)
178       return alloca;
179   }
180   // Otherwise, create a new alloca:
181   auto alloca = createAllocaForTile(rewriter, loc, func, tileOp);
182   alloca->setDiscardableAttr(kInMemoryTileIdAttr,
183                              rewriter.getI32IntegerAttr(tileId));
184   return alloca;
185 }
186 
187 /// Very naive lowering of in-memory tiles (i.e. tiles that were not assigned a
188 /// hardware tile ID) to ArmSME intrinsics. Currently, this works by assigning
189 /// the op to tile 0, then emitting a full tile swap between ZA and memory
190 /// before + after the tile op.
191 ///
192 /// Example:
193 ///
194 ///    // Note: <IN MEMORY TILE> = tile ID >= 16.
195 ///    arm_sme.tile_op { tile_id = <IN MEMORY TILE> }
196 ///
197 /// is converted to:
198 ///     // At function entry:
199 ///     %spill = memref.alloca ... : memref<?x?xty>
200 ///
201 ///     // Around op:
202 ///     scf.for %slice_idx {
203 ///       %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
204 ///       "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx)  <{tile_id = 0 : i32}>
205 ///       vector.store %slice_to_save, %spill[%slice_idx, %c0]
206 ///     }
207 ///     arm_sme.tile_op { tile_id = 0 }
208 ///     scf.for %slice_idx {
209 ///       %slice_to_save = "arm_sme.intr.read.horiz" ... <{tile_id = 0 : i32}>
210 ///       "arm_sme.intr.ld1h.horiz"(%spill, %slice_idx)  <{tile_id = 0 : i32}>
211 ///       vector.store %slice_to_save, %spill[%slice_idx, %c0]
212 ///     }
213 ///
214 /// Note that these spills/fills are not inserted earlier as concept of a
215 /// register, and the need to swap the contents, can't really be represented
216 /// correctly at a high level in MLIR.
217 ///
218 /// TODO: Reduce the spills/reloads to single slices where possible (and omit
219 /// redundant reloads). This could be done via a method on the
220 /// `ArmSMETileOpInterface` which returns how the operation uses ZA. E.g.:
221 ///
222 /// `tileOp.getZaUsage()` could return:
223 ///
224 /// struct ArmSMEOpZAUsage {
225 ///   enum class Kind {
226 ///     TileRead,        // Omit store after tile operation.
227 ///     TileWrite,       // Omit load before tile operation.
228 ///     TileReadWrite,   // Needs both tile load and store.
229 ///     SliceRead,       // Spill single slice and omit store after operation.
230 ///     SliceWrite,      // Spill single slice and omit load before operation.
231 ///     SliceReadWrite   // Spill single slice.
232 ///   };
233 ///   Value sliceIndex {};
234 ///   TileSliceLayout sliceLayout { TileSliceLayout::Horizontal };
235 /// };
236 ///
struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {

  /// Constructs the spill/fill pattern for a single op kind (`rootOpName`).
  /// Callers are expected to pass a very high `benefit` so this pattern runs
  /// before the regular op conversion for the same op.
  ConvertArmSMESpillsAndFillsToLLVM(StringRef rootOpName,
                                    const LLVMTypeConverter &typeConverter,
                                    PatternBenefit benefit)
      : ConvertToLLVMPattern(rootOpName, &typeConverter.getContext(),
                             typeConverter, benefit) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto tileOp = cast<arm_sme::ArmSMETileOpInterface>(op);
    // Tile has a real (hardware) tile. No spills/reloads required.
    if (!tileOp.isInMemoryTile())
      return failure();

    tileOp->emitWarning(
        "failed to allocate SME virtual tile to operation, tile value will go "
        "through memory, expect degraded performance");

    // Step 1. Create an alloca for the tile at the top of the function (if one
    // does not already exist).
    auto loc = tileOp.getLoc();
    auto func = tileOp->getParentOfType<FunctionOpInterface>();
    auto tileAlloca = getOrCreateAllocaForTile(rewriter, loc, func, tileOp,
                                               tileOp.getTileId().getInt());

    // Step 2. Assign the op a real tile ID.
    // For simplicity, we always use tile 0 (which always exists).
    auto zeroTileId = rewriter.getI32IntegerAttr(0);
    rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(zeroTileId); });

    VectorType tileVectorType = tileOp.getTileType();
    // A slice is the tile type with the leading (row) dimension dropped.
    auto sliceType = VectorType::Builder(tileVectorType).dropDim(0);
    auto swapInMemoryTileWithSMETileZero = [&] {
      emitFullTileSwap(rewriter, loc, tileAlloca,
                       *arm_sme::getSMETileType(tileVectorType), sliceType,
                       zeroTileId);
    };

    // Step 3. Emit tile swaps before and after the op.
    // TODO: Reduce the amount spilled to the amount of data the `tileOp`
    // touches (i.e. a single tile slice).
    {
      rewriter.setInsertionPoint(op);
      // Swap the contents of ZA and the in-memory tile before the op.
      swapInMemoryTileWithSMETileZero();
      rewriter.setInsertionPointAfter(op);
      // Swap the tile back out to memory again after the op.
      swapInMemoryTileWithSMETileZero();
    }

    return success();
  }

  /// Extracts a pointer to a slice of an in-memory tile.
  Value getInMemoryTileSlicePtr(RewriterBase &rewriter, Location loc,
                                Value tileMemory, Value sliceIndex) const {
    // Bridge from the memref value to its LLVM descriptor via an unrealized
    // cast (the memref itself has not been converted at this point).
    auto llvmType = getTypeConverter()->convertType(tileMemory.getType());
    auto descriptor =
        rewriter.create<UnrealizedConversionCastOp>(loc, llvmType, tileMemory);
    auto zero = rewriter.create<arith::ConstantIntOp>(loc, 0, /*width=*/64);
    auto sliceIndexI64 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI64Type(), sliceIndex);
    // Pointer to element [sliceIndex, 0] of the tile-sized memref.
    return getStridedElementPtr(
        loc, llvm::cast<MemRefType>(tileMemory.getType()),
        descriptor.getResult(0), {sliceIndexI64, zero},
        static_cast<ConversionPatternRewriter &>(rewriter));
  }

  /// Emits an in-place swap of a slice of a tile in ZA and a slice of a
  /// tile-sized memref (`tileAlloca`).
  void emitSliceSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
                     arm_sme::ArmSMETileType tileType, VectorType sliceType,
                     IntegerAttr tileId, Value sliceIndex) const {
    // Cast the slice index to an i32.
    auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI32Type(), sliceIndex);
    // Create an all-true predicate for the slice.
    auto predicateType = sliceType.clone(rewriter.getI1Type());
    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(predicateType, true));
    // Create padding vector (never used due to all-true predicate).
    auto padVector = rewriter.create<LLVM::UndefOp>(loc, sliceType);
    // Get a pointer to the current slice.
    auto slicePtr =
        getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex);
    // Read the value of the current slice from ZA.
    auto currentTileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
        loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32);
    // Load the new tile slice back from memory into ZA.
    createLoadTileSliceIntrinsic(
        rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal,
        allTruePredicate, slicePtr, tileId, sliceIndexI32);
    // Store the current tile slice to memory.
    // Note: the read above captured the old ZA contents, so this ordering
    // (read ZA -> load from memory -> store old value) performs a swap.
    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    rewriter.create<vector::StoreOp>(loc, currentTileSlice, tileAlloca,
                                     ValueRange{sliceIndex, zero});
  }

  /// Emits a full in-place swap of the contents of a tile in ZA and a
  /// tile-sized memref (`tileAlloca`).
  void emitFullTileSwap(RewriterBase &rewriter, Location loc, Value tileAlloca,
                        arm_sme::ArmSMETileType tileType, VectorType sliceType,
                        IntegerAttr tileId) const {
    RewriterBase::InsertionGuard guard(rewriter);
    // Create an scf.for over all tile slices.
    // The trip count is minNumElts * vscale, the runtime number of slices.
    auto minNumElts =
        rewriter.create<arith::ConstantIndexOp>(loc, sliceType.getDimSize(0));
    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto upperBound = rewriter.create<arith::MulIOp>(
        loc, minNumElts, rewriter.create<vector::VectorScaleOp>(loc));
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
    // Emit a swap for each tile slice.
    rewriter.setInsertionPointToStart(forOp.getBody());
    auto sliceIndex = forOp.getInductionVar();
    emitSliceSwap(rewriter, loc, tileAlloca, tileType, sliceType, tileId,
                  sliceIndex);
  }
};
358 
/// Controls whether `addArmSMEConversionPattern` also registers the
/// spill/fill pattern (`ConvertArmSMESpillsAndFillsToLLVM`) for an op.
enum class RequiresSpillsAndFills { Yes, No };

/// Base class for ArmSME to LLVM conversion patterns. By default, this adds
/// spills and fills around ArmSME ops that use in-memory tile IDs. This can be
/// disabled by setting the `requiresSpillsAndFills` template parameter to
/// `RequiresSpillsAndFills::No`.
template <typename SourceOp, RequiresSpillsAndFills requiresSpillsAndFills =
                                 RequiresSpillsAndFills::Yes>
struct ConvertArmSMEOpToLLVMPattern : ConvertOpToLLVMPattern<SourceOp> {
  // The ArmSME op converted by this pattern (inspected by the registration
  // helpers below).
  using ArmSMEOp = SourceOp;
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;

  /// Returns true if spill/fill conversions should be registered for this
  /// pattern's op.
  static constexpr bool requiresSpillsAndFillsConversion() {
    return requiresSpillsAndFills == RequiresSpillsAndFills::Yes;
  }
};
375 
/// Registers a single ArmSME conversion pattern, plus (when the op implements
/// `ArmSMETileOpInterface` and the pattern requests it) the spill/fill
/// pattern for the same op.
template <typename Pattern>
static void addArmSMEConversionPattern(RewritePatternSet &patterns,
                                       LLVMTypeConverter const &typeConverter) {
  // Register spills/fills for ops that implement the
  // `ArmSMETileOpInterface` and have `requiresSpillsAndFills` set to
  // `RequiresSpillsAndFills::Yes`.
  if constexpr (Pattern::requiresSpillsAndFillsConversion() &&
                std::is_base_of_v<arm_sme::ArmSMETileOpInterface::Trait<
                                      typename Pattern::ArmSMEOp>,
                                  typename Pattern::ArmSMEOp>) {
    // Add spill/fill conversions with a very high benefit to ensure
    // they are lowered first.
    patterns.add<ConvertArmSMESpillsAndFillsToLLVM>(
        Pattern::ArmSMEOp::getOperationName(), typeConverter,
        /*benefit=*/1337);
  }
  patterns.add<Pattern>(typeConverter);
}
394 
/// Helper to register `ConvertArmSMEOpToLLVMPattern` patterns.
template <typename... Patterns>
static void
addArmSMEConversionPatterns(RewritePatternSet &patterns,
                            LLVMTypeConverter const &typeConverter) {
  // Fold expression: registers each pattern (and, where required, its
  // spill/fill companion) in pack order.
  (addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
}
402 
403 /// Lower 'arm_sme.zero' to SME intrinsics.
404 ///
405 ///  BEFORE:
406 ///  ```mlir
407 ///     %v = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
408 ///  ```
409 ///
410 ///  AFTER:
411 ///  ```mlir
412 ///     "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
413 ///     %v = arm_sme.get_tile : vector<[4]x[4]xi32>
414 ///  ```
415 ///
416 ///  The 'arm_sme.get_tile' (which models the return) will fold away once all
417 ///  ArmSME ops have been converted to LLVM intrinsics.
418 struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
419   using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
420 
421   LogicalResult
422   matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor,
423                   ConversionPatternRewriter &rewriter) const override {
424     auto loc = zero.getLoc();
425 
426     auto tileId = getTileIdOrError(zero);
427     if (!tileId)
428       return failure();
429 
430     // Get the base mask for tile based on the element size.
431     // The base mask is just the mask to zero the first tile (of a size).
432     // These masks are derived from:
433     // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
434     arm_sme::ArmSMETileType tileType =
435         *arm_sme::getSMETileType(zero.getTileType());
436     auto baseMaskForSize = [&] {
437       switch (tileType) {
438       case arm_sme::ArmSMETileType::ZAB:
439         // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight
440         // 64-bit element tiles named ZA0.D to ZA7.D.
441         return 0b1111'1111;
442       case arm_sme::ArmSMETileType::ZAH:
443         // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit
444         // element tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. Shift this left
445         // once for ZA1.H.
446         return 0b0101'0101;
447       case arm_sme::ArmSMETileType::ZAS:
448         // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit
449         // element tiles named ZA0.D and ZA4.D.
450         // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S.
451         return 0b0001'0001;
452       case arm_sme::ArmSMETileType::ZAD:
453         // Zeroing one of the a 64-bit tiles ZA0.D to ZA7.D just requires
454         // setting the bit for that tile.
455         return 0b0000'0001;
456       default:
457         llvm_unreachable("bad element size");
458       }
459     }();
460 
461     // The actual mask is just the base mask shifted by the tile ID.
462     // This will be folded to a constant after tile allocation.
463     //
464     // The shift is just derived from the layout of the tiles, and that the tile
465     // ID is the index of the tile. For example, looking at the 32-bit ZAx.S
466     // tiles:
467     //
468     // ZA0.S = ZA0.D and ZA4.D
469     //  * Tile ID -> 0
470     //  * Mask    -> 00010001 = (00010001 << 0)
471     // ZA1.S = ZA1.D and ZA5.D
472     //  * Tile ID -> 1
473     //  * Mask    -> 00100010 = (00010001 << 1)
474     // ZA2.S = ZA2.D and ZA6.D
475     //  * Tile ID -> 2
476     //  * Mask    -> 01000100 = (00010001 << 2)
477     // ZA3.S = ZA3.D and ZA7.D
478     //  * Tile ID -> 3
479     //  * Mask    -> 10001000 = (00010001 << 3)
480     //
481     // This holds for all tile sizes.
482     int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
483     rewriter.create<arm_sme::aarch64_sme_zero>(
484         loc, rewriter.getI32IntegerAttr(zeroMask));
485 
486     // Create a placeholder op to preserve dataflow.
487     // Note: Place the `get_tile` op at the start of the block. This ensures
488     // that if there are multiple `zero` ops the intrinsics will be consecutive.
489     rewriter.setInsertionPointToStart(zero->getBlock());
490     rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());
491 
492     return success();
493   }
494 };
495 
496 /// Lower `arm_sme.load_tile_slice` to SME intrinsics.
497 struct LoadTileSliceConversion
498     : public ConvertArmSMEOpToLLVMPattern<arm_sme::LoadTileSliceOp> {
499   using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
500 
501   LogicalResult
502   matchAndRewrite(arm_sme::LoadTileSliceOp loadTileSliceOp,
503                   arm_sme::LoadTileSliceOp::Adaptor adaptor,
504                   ConversionPatternRewriter &rewriter) const override {
505     auto loc = loadTileSliceOp.getLoc();
506     auto tileId = getTileIdOrError(loadTileSliceOp);
507     if (!tileId)
508       return failure();
509 
510     Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
511                                            adaptor.getBase(),
512                                            adaptor.getIndices(), rewriter);
513 
514     auto tileSlice = loadTileSliceOp.getTileSliceIndex();
515 
516     // Cast tile slice to i32 for intrinsic.
517     auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
518         loc, rewriter.getI32Type(), tileSlice);
519 
520     // Create all active predicate mask.
521     auto maskOp = loadTileSliceOp.getMask();
522 
523     auto tileVectorType = loadTileSliceOp.getVectorType();
524     arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
525     arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
526 
527     // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
528     createLoadTileSliceIntrinsic(rewriter, loc, tileType, layout, maskOp, ptr,
529                                  tileId, tileSliceI32);
530 
531     // The load intrinsics have no result, replace 'arm_sme.tile_load' with
532     // the input tile to preserve dataflow.
533     rewriter.replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());
534 
535     return success();
536   }
537 };
538 
539 /// Lower for `arm_sme.store_tile_slice` to SME intrinsics.
540 struct StoreTileSliceConversion
541     : public ConvertArmSMEOpToLLVMPattern<arm_sme::StoreTileSliceOp> {
542   using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
543 
544   LogicalResult
545   matchAndRewrite(arm_sme::StoreTileSliceOp storeTileSliceOp,
546                   arm_sme::StoreTileSliceOp::Adaptor adaptor,
547                   ConversionPatternRewriter &rewriter) const override {
548     auto loc = storeTileSliceOp.getLoc();
549     auto tileVectorType = storeTileSliceOp.getVectorType();
550 
551     auto tileId = getTileIdOrError(storeTileSliceOp);
552     if (!tileId)
553       return failure();
554 
555     // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
556     Value ptr = this->getStridedElementPtr(
557         loc, storeTileSliceOp.getMemRefType(), adaptor.getBase(),
558         adaptor.getIndices(), rewriter);
559 
560     auto tileSlice = storeTileSliceOp.getTileSliceIndex();
561 
562     // Cast tile slice to i32 for intrinsic.
563     auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
564         loc, rewriter.getI32Type(), tileSlice);
565 
566     auto maskOp = storeTileSliceOp.getMask();
567 
568     arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
569     arm_sme::ArmSMETileType tileType = *arm_sme::getSMETileType(tileVectorType);
570 
571     rewriter.replaceOp(storeTileSliceOp,
572                        createStoreTileSliceIntrinsic(rewriter, loc, tileType,
573                                                      layout, maskOp, ptr,
574                                                      tileId, tileSliceI32));
575 
576     return success();
577   }
578 };
579 
580 /// Lower `arm_sme.insert_tile_slice` to SME intrinsics.
581 struct InsertTileSliceConversion
582     : public ConvertArmSMEOpToLLVMPattern<arm_sme::InsertTileSliceOp> {
583   using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
584 
585   LogicalResult
586   matchAndRewrite(arm_sme::InsertTileSliceOp insertTileSliceOp,
587                   arm_sme::InsertTileSliceOp::Adaptor adaptor,
588                   ConversionPatternRewriter &rewriter) const override {
589     auto loc = insertTileSliceOp.getLoc();
590     auto tileType = insertTileSliceOp.getTileType();
591 
592     auto tileId = getTileIdOrError(insertTileSliceOp);
593     if (!tileId)
594       return failure();
595 
596     auto tileSlice = insertTileSliceOp.getTileSliceIndex();
597 
598     // Cast tile slice from index to i32 for intrinsic.
599     auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
600         loc, rewriter.getI32Type(), tileSlice);
601 
602     // Create all active predicate mask.
603     auto one = rewriter.create<arith::ConstantOp>(
604         loc, rewriter.getI1Type(),
605         rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
606     auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
607                                   /*scalableDims=*/{true});
608     auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
609 
610     // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
611     switch (insertTileSliceOp.getLayout()) {
612     case arm_sme::TileSliceLayout::Horizontal:
613       rewriter.create<arm_sme::aarch64_sme_write_horiz>(
614           loc, tileId, tileSliceI32, allActiveMask,
615           insertTileSliceOp.getVector());
616       break;
617     case arm_sme::TileSliceLayout::Vertical:
618       rewriter.create<arm_sme::aarch64_sme_write_vert>(
619           loc, tileId, tileSliceI32, allActiveMask,
620           insertTileSliceOp.getVector());
621       break;
622     }
623 
624     // Intrinsic has no result, replace 'arm_sme.insert_tile_slice' with
625     // the input tile to preserve dataflow.
626     rewriter.replaceOp(insertTileSliceOp, insertTileSliceOp.getTile());
627 
628     return success();
629   }
630 };
631 
632 /// Lower `arm_sme.extract_tile_slice` to SME intrinsics.
633 struct ExtractTileSliceConversion
634     : public ConvertArmSMEOpToLLVMPattern<arm_sme::ExtractTileSliceOp> {
635   using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
636 
637   LogicalResult
638   matchAndRewrite(arm_sme::ExtractTileSliceOp extractTileSlice, OpAdaptor,
639                   ConversionPatternRewriter &rewriter) const override {
640     auto loc = extractTileSlice.getLoc();
641     auto sliceType = extractTileSlice.getSliceType();
642     auto sliceIndex = extractTileSlice.getTileSliceIndex();
643 
644     auto tileId = getTileIdOrError(extractTileSlice);
645     if (!tileId)
646       return failure();
647 
648     // Create an 'all true' predicate for the tile slice.
649     auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
650     auto allTruePredicate = rewriter.create<arith::ConstantOp>(
651         loc, DenseElementsAttr::get(predicateType, true));
652 
653     // Zero destination/fallback for tile slice extraction.
654     auto zeroVector = rewriter.create<arith::ConstantOp>(
655         loc, sliceType, rewriter.getZeroAttr(sliceType));
656 
657     // Cast tile slice from index to i32 for intrinsic.
658     auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
659         loc, rewriter.getI32Type(), sliceIndex);
660 
661     // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
662     switch (extractTileSlice.getLayout()) {
663     case arm_sme::TileSliceLayout::Horizontal:
664       rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
665           extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
666           sliceIndexI32);
667       break;
668     case arm_sme::TileSliceLayout::Vertical:
669       rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
670           extractTileSlice, sliceType, zeroVector, allTruePredicate, tileId,
671           sliceIndexI32);
672       break;
673     }
674 
675     return success();
676   }
677 };
678 
679 /// Lower `arm_sme.outerproduct` to SME MOPA intrinsics.
680 ///
681 /// Example:
682 ///
683 ///   %0 = arm_sme.outerproduct %lhs, %rhs acc(%acc)
684 ///     : vector<[4]xf32>, vector<[4]xf32>
685 ///
686 /// is converted to:
687 ///
688 ///   "arm_sme.intr.mopa"(%ptrue_s, %ptrue_s, %lhs, %rhs) <{tile_id = 0 : i32}>
689 ///     : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
690 ///        vector<[4]xf32>) -> ()
691 ///
692 /// Currently only supports FMOPA and BFMOPA (non-widening).
693 struct OuterProductOpConversion
694     : public ConvertArmSMEOpToLLVMPattern<arm_sme::OuterProductOp> {
695   using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
696 
697   LogicalResult
698   matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
699                   arm_sme::OuterProductOp::Adaptor adaptor,
700                   ConversionPatternRewriter &rewriter) const override {
701     auto tileId = getTileIdOrError(outerProductOp);
702     if (!tileId)
703       return failure();
704 
705     auto isSupportedType = [](VectorType vectorType) {
706       // TODO: the FP outer product instruction variants are predicated on
707       // different features [1]:
708       //
709       // * FMOPA (non-widening)
710       //   * half-precision   - +sme2p1,+sme-f16f16
711       //   * single-precision - +sme
712       //   * double-precision - +sme-f64f64
713       // * BFMOPA
714       //   * half-precision   - +sme2p1,+b16b16
715       //
716       // It should be possible to control lowering based on target features.
717       // [1]
718       // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
719       if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
720         return false;
721 
722       auto elementType = vectorType.getElementType();
723 
724       if (!elementType.isF16() && !elementType.isBF16() &&
725           !elementType.isF32() && !elementType.isF64())
726         return false;
727 
728       unsigned minNumElts = arm_sme::MinStreamingVectorLengthInBits /
729                             vectorType.getElementTypeBitWidth();
730       return vectorType.getShape() ==
731              ArrayRef<int64_t>({minNumElts, minNumElts});
732     };
733 
734     // TODO: Support CombiningKind::Sub for outer products.
735     if (outerProductOp.getKind() != arm_sme::CombiningKind::Add)
736       return outerProductOp.emitError("unsupported kind");
737 
738     auto resultVectorType = outerProductOp.getResultType();
739     if (!isSupportedType(resultVectorType))
740       return outerProductOp.emitError("unsupported type");
741 
742     auto loc = outerProductOp.getLoc();
743 
744     Value acc = outerProductOp.getAcc();
745     if (!acc) {
746       // Initalize accumulator with zero.
747       auto zero = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
748       zero.setTileId(tileId);
749       acc = zero;
750     }
751 
752     Value lhsMask = outerProductOp.getLhsMask();
753     Value rhsMask = outerProductOp.getRhsMask();
754 
755     if (!lhsMask || !rhsMask) {
756       auto predTy =
757           outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type());
758       Value allActiveMask = rewriter.create<arith::ConstantOp>(
759           loc, DenseElementsAttr::get(predTy, true));
760       lhsMask = allActiveMask;
761       rhsMask = allActiveMask;
762     }
763 
764     // Create 'arm_sme.intr.mopa' outer product intrinsic.
765     rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
766                                                outerProductOp.getLhs(),
767                                                outerProductOp.getRhs());
768 
769     // The outerproduct intrinsics have no result, replace
770     // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
771     rewriter.replaceOp(outerProductOp, acc);
772 
773     return success();
774   }
775 };
776 
777 /// Lower 2-way and 4-way widening outer products to intrinsics.
778 template <class OuterProductWideningOp, class OuterProductWideningIntrOp>
779 struct OuterProductWideningOpConversion
780     : public ConvertArmSMEOpToLLVMPattern<OuterProductWideningOp> {
781   using ConvertArmSMEOpToLLVMPattern<
782       OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern;
783 
784   LogicalResult
785   matchAndRewrite(OuterProductWideningOp op,
786                   typename OuterProductWideningOp::Adaptor adaptor,
787                   ConversionPatternRewriter &rewriter) const override {
788     auto tileId = getTileIdOrError(op);
789     if (!tileId)
790       return failure();
791 
792     auto loc = op.getLoc();
793     Value acc = op.getAcc();
794     if (!acc) {
795       // Initalize accumulator with zero.
796       auto zero = rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType());
797       zero.setTileId(tileId);
798       acc = zero;
799     }
800 
801     Value lhsMask = op.getLhsMask();
802     Value rhsMask = op.getRhsMask();
803     if (!lhsMask || !rhsMask) {
804       auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
805       Value allActiveMask = rewriter.create<arith::ConstantOp>(
806           loc, DenseElementsAttr::get(predTy, true));
807       lhsMask = allActiveMask;
808       rhsMask = allActiveMask;
809     }
810 
811     rewriter.create<OuterProductWideningIntrOp>(
812         loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());
813 
814     // The outerproduct intrinsics have no result, replace
815     // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
816     rewriter.replaceOp(op, acc);
817 
818     return success();
819   }
820 };
821 
822 /// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics.
823 ///
824 /// Example:
825 ///
826 ///   %0 = arm_sme.streaming_vl <half>
827 ///
828 /// is converted to:
829 ///
830 ///   %cnt = "arm_sme.intr.cntsh"() : () -> i64
831 ///   %0 = arith.index_cast %cnt : i64 to index
832 ///
833 struct StreamingVLOpConversion
834     : public ConvertArmSMEOpToLLVMPattern<arm_sme::StreamingVLOp,
835                                           RequiresSpillsAndFills::No> {
836   using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
837 
838   LogicalResult
839   matchAndRewrite(arm_sme::StreamingVLOp streamingVlOp,
840                   arm_sme::StreamingVLOp::Adaptor adaptor,
841                   ConversionPatternRewriter &rewriter) const override {
842     auto loc = streamingVlOp.getLoc();
843     auto i64Type = rewriter.getI64Type();
844     auto *intrOp = [&]() -> Operation * {
845       switch (streamingVlOp.getTypeSize()) {
846       case arm_sme::TypeSize::Byte:
847         return rewriter.create<arm_sme::aarch64_sme_cntsb>(loc, i64Type);
848       case arm_sme::TypeSize::Half:
849         return rewriter.create<arm_sme::aarch64_sme_cntsh>(loc, i64Type);
850       case arm_sme::TypeSize::Word:
851         return rewriter.create<arm_sme::aarch64_sme_cntsw>(loc, i64Type);
852       case arm_sme::TypeSize::Double:
853         return rewriter.create<arm_sme::aarch64_sme_cntsd>(loc, i64Type);
854       }
855       llvm_unreachable("unknown type size in StreamingVLOpConversion");
856     }();
857     rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
858         streamingVlOp, rewriter.getIndexType(), intrOp->getResult(0));
859     return success();
860   }
861 };
862 
863 /// Merges consecutive `arm_sme.intr.zero` operations in a block by bitwise
864 /// or-ing the zero masks. Note: In future the backend _should_ handle this.
865 static void mergeConsecutiveTileZerosInBlock(Block *block) {
866   uint32_t mergedZeroMask = 0;
867   SmallVector<arm_sme::aarch64_sme_zero, 16> zeroOpsToMerge;
868   auto replaceMergedZeroOps = [&] {
869     auto cleanup = llvm::make_scope_exit([&] {
870       mergedZeroMask = 0;
871       zeroOpsToMerge.clear();
872     });
873     if (zeroOpsToMerge.size() <= 1)
874       return;
875     IRRewriter rewriter(zeroOpsToMerge.front());
876     rewriter.create<arm_sme::aarch64_sme_zero>(
877         zeroOpsToMerge.front().getLoc(),
878         rewriter.getI32IntegerAttr(mergedZeroMask));
879     for (auto zeroOp : zeroOpsToMerge)
880       rewriter.eraseOp(zeroOp);
881   };
882   for (Operation &op : *block) {
883     if (auto zeroOp = dyn_cast<arm_sme::aarch64_sme_zero>(op)) {
884       mergedZeroMask |= zeroOp.getTileMask();
885       zeroOpsToMerge.push_back(zeroOp);
886     } else {
887       replaceMergedZeroOps();
888     }
889   }
890   replaceMergedZeroOps();
891 }
892 
893 } // namespace
894 
895 namespace {
896 
897 struct ConvertArmSMEToLLVMPass
898     : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
899   ConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
900     this->dumpTileLiveRanges = dumpTileLiveRanges;
901   }
902   void runOnOperation() override {
903     auto function = getOperation();
904 
905     if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
906       return signalPassFailure();
907 
908     LLVMConversionTarget target(getContext());
909     RewritePatternSet patterns(&getContext());
910     LLVMTypeConverter converter(&getContext());
911     configureArmSMEToLLVMConversionLegality(target);
912     populateArmSMEToLLVMConversionPatterns(converter, patterns);
913 
914     if (failed(applyPartialConversion(function, target, std::move(patterns))))
915       signalPassFailure();
916 
917     function->walk(mergeConsecutiveTileZerosInBlock);
918 
919     // Walk the function and fail if there are unexpected operations on SME
920     // tile types after conversion.
921     function->walk([&](Operation *op) {
922       // These ops are legal post conversion, skip these.
923       if (isa<arm_sme::CopyTileOp, arm_sme::GetTileOp, cf::BranchOp>(op) ||
924           !op->isRegistered())
925         return;
926       auto isSMETileType = [](Type type) {
927         return arm_sme::isValidSMETileVectorType(type);
928       };
929       if (llvm::any_of(op->getResultTypes(), isSMETileType) ||
930           llvm::any_of(op->getOperandTypes(), isSMETileType)) {
931         op->emitOpError("unexpected operation with SME tile type after "
932                         "conversion to LLVM");
933         signalPassFailure();
934       }
935     });
936   }
937 };
938 
939 } // namespace
940 
941 void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
942   target.addIllegalDialect<arm_sme::ArmSMEDialect>();
943   target.addLegalOp<
944       arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
945       arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
946       arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
947       arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
948       arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
949       arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
950       arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
951       arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
952       arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
953       arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
954       arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
955       arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
956       arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
957       arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
958       arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
959       arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
960       arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
961       arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
962       arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
963       arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
964       arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
965       arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
966       arm_sme::aarch64_sme_cntsd>();
967   target.addLegalDialect<arith::ArithDialect,
968                          /* The following are used to lower tile spills/fills */
969                          vector::VectorDialect, scf::SCFDialect,
970                          memref::MemRefDialect>();
971   // Pseudo operations. These cannot be code-generated but may exist in the
972   // input IR, or be generated during the conversion. They need to be eliminated
973   // before the final conversion to LLVM IR (and likely will be due to DCE).
974   target.addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
975                     UnrealizedConversionCastOp>();
976 }
977 
978 void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
979                                                   RewritePatternSet &patterns) {
980   converter.addConversion([&](VectorType type) -> std::optional<Type> {
981     // There's no LLVM type for SME tiles, but after lowering to intrinsics all
982     // SME vector types should be eliminated.
983     if (arm_sme::isValidSMETileVectorType(type))
984       return type;
985     return std::nullopt;
986   });
987 
988   addArmSMEConversionPatterns<
989       LoadTileSliceConversion, ExtractTileSliceConversion,
990       InsertTileSliceConversion, StoreTileSliceConversion,
991       StreamingVLOpConversion, OuterProductOpConversion,
992       OuterProductWideningOpConversion<arm_sme::FMopa2WayOp,
993                                        arm_sme::aarch64_sme_mopa_wide>,
994       OuterProductWideningOpConversion<arm_sme::FMops2WayOp,
995                                        arm_sme::aarch64_sme_mops_wide>,
996       OuterProductWideningOpConversion<arm_sme::SMopa2WayOp,
997                                        arm_sme::aarch64_sme_smopa_za32>,
998       OuterProductWideningOpConversion<arm_sme::SMops2WayOp,
999                                        arm_sme::aarch64_sme_smops_za32>,
1000       OuterProductWideningOpConversion<arm_sme::UMopa2WayOp,
1001                                        arm_sme::aarch64_sme_umopa_za32>,
1002       OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
1003                                        arm_sme::aarch64_sme_umops_za32>,
1004       OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
1005                                        arm_sme::aarch64_sme_smopa_wide>,
1006       OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
1007                                        arm_sme::aarch64_sme_smops_wide>,
1008       OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
1009                                        arm_sme::aarch64_sme_umopa_wide>,
1010       OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
1011                                        arm_sme::aarch64_sme_umops_wide>,
1012       OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
1013                                        arm_sme::aarch64_sme_sumopa_wide>,
1014       OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
1015                                        arm_sme::aarch64_sme_sumops_wide>,
1016       OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
1017                                        arm_sme::aarch64_sme_usmopa_wide>,
1018       OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
1019                                        arm_sme::aarch64_sme_usmops_wide>,
1020       ZeroOpConversion>(patterns, converter);
1021 }
1022 
1023 std::unique_ptr<Pass>
1024 mlir::createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
1025   return std::make_unique<ConvertArmSMEToLLVMPass>(dumpTileLiveRanges);
1026 }
1027