xref: /llvm-project/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h (revision 041baf2f60ac3e107399641aea04c77019e7eab8)
1 //===- Utils.h - General ArmSME transformation utilities --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines prototypes for various utilities for the ArmSME
10 // dialect. These are not passes by themselves but are used either by passes,
11 // optimization sequences, or in turn by other transformation utilities.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #ifndef MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
16 #define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
17 
18 #include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
19 #include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h"
20 #include "mlir/Dialect/SCF/IR/SCF.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/Interfaces/FunctionInterfaces.h"
23 #include <optional>
24 
25 namespace mlir { // Forward declarations of types named in the prototypes below.
26 class Location;
27 class PatternRewriter;
28 class Value;
29 } // namespace mlir
30 
31 namespace mlir::arm_sme {
32 
/// Minimum streaming vector length (SVL), in bits.
/// NOTE(review): presumably the architectural lower bound on SVL for SME --
/// confirm against the Arm SME specification.
33 constexpr unsigned MinStreamingVectorLengthInBits = 128;
34 
35 /// Returns the minimum number of elements of the given element `type` in
36 /// a vector of SVL bits.
37 unsigned getSMETileSliceMinNumElts(Type type);
38 
39 /// Returns true if `type` is a valid element type for an SME tile, and
40 /// false otherwise.
41 bool isValidSMETileElementType(Type type);
42 
43 /// Returns true if `vType` is a valid vector type for an SME tile, and
44 /// false otherwise.
45 bool isValidSMETileVectorType(VectorType vType);
46 
isValidSMETileVectorType(Type type)47 inline bool isValidSMETileVectorType(Type type) {
48   auto vType = dyn_cast<VectorType>(type);
49   return vType && isValidSMETileVectorType(vType);
50 }
51 
52 /// Returns the type of SME tile this vector type corresponds to, or
53 /// `std::nullopt` if the vector type does not fit within an SME tile.
54 std::optional<ArmSMETileType> getSMETileType(VectorType);
55 
56 /// Verifies that the tile ID (if set) on this tile operation is valid.
57 LogicalResult verifyOperationHasValidTileId(Operation *);
58 
59 /// Generates a for loop over ZA tile slices where the induction variable is
60 /// the tile slice index and each iteration yields a new tile. Loop body is
61 /// built via `makeLoopBody`, which returns the next tile value.
/// NOTE(review): the two `Value` arguments passed to `makeLoopBody` are
/// presumably the tile slice index and the current tile -- confirm the order
/// against the definition.
62 scf::ForOp createLoopOverTileSlices(
63     PatternRewriter &rewriter, Location loc, Value initTile,
64     std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody);
65 
66 /// Returns true if `vType` is a multiple of an SME tile size. Note that this
67 /// returns false if `vType` exactly matches the size of an SME tile.
68 bool isMultipleOfSMETileVectorType(VectorType vType);
69 
70 /// Creates the vector type for the SME tile of element type `elementType`.
71 VectorType getSMETileTypeForElement(Type elementType);
72 
73 /// Erases trivially dead tile ops from `function`.
74 void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
75                                FunctionOpInterface function);
76 
77 /// Returns true if `tileOp` is trivially cloneable. A tile operation is
78 /// trivially cloneable if:
79 ///
80 ///  1. It has no operands (and only a single tile result)
81 ///  2. It is 'Pure'
82 ///
83 /// This ensures that the clone will not share any dependencies with the
84 /// original operation (which could otherwise also need to be considered), and
85 /// that inserting the clone at a different point in the program won't change
86 /// the semantics of the program (as it has no side effects).
87 bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp);
88 
89 /// Returns true if `tileOp` produces a tile result, and false otherwise.
90 bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp);
91 
92 /// Returns the tile `OpOperand` for this `tileOp`, or `nullptr` if none.
93 OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp);
94 
95 /// Returns true if `typeA` is >= (in terms of bytes) `typeB`.
96 bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB);
97 
98 } // namespace mlir::arm_sme
99 
100 #endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
101