xref: /llvm-project/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (revision 66953c9582b0ad960ccbb8463158273cdc40cd8f)
1//===-- Passes.td - ArmSME pass definition file ------------*- tablegen -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD
10#define MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD
11
12include "mlir/Pass/PassBase.td"
13include "mlir/IR/EnumAttr.td"
14
// Enumeration of the function-level "Armv9 Streaming SVE mode" settings.
// Each case is declared as (C++ symbol, integer value, string form). This enum
// is the value type of the `streaming-mode` option of the
// `enable-arm-streaming` pass declared in this file.
15def ArmStreamingMode : I32EnumAttr<"ArmStreamingMode", "Armv9 Streaming SVE mode",
16    [
      // Disabled: streaming mode is disabled.
17      I32EnumAttrCase<"Disabled", 0, "disabled">,
18      // Streaming: Streaming-mode is part of the function interface (ABI).
      // The caller manages PSTATE.SM on entry/exit.
19      I32EnumAttrCase<"Streaming", 1, "arm_streaming">,
20      // StreamingLocally: PSTATE.SM is kept internal and the callee manages it
21      // on entry/exit.
22      I32EnumAttrCase<"StreamingLocally", 2, "arm_locally_streaming">,
23      // StreamingCompatible: the function may be entered in either
24      // non-streaming mode (PSTATE.SM=0) or in streaming mode (PSTATE.SM=1).
25      I32EnumAttrCase<"StreamingCompatible", 3, "arm_streaming_compatible">,
26    ]>{
  // The enum is referred to elsewhere in this file as
  // `mlir::arm_sme::ArmStreamingMode`.
27  let cppNamespace = "mlir::arm_sme";
  // NOTE(review): presumably this disables generation of a dedicated attribute
  // class so that only the C++ enum and its helpers are emitted -- confirm
  // against mlir/IR/EnumAttr.td.
28  let genSpecializedAttr = 0;
29}
30
// Enumeration of the function-level "Armv9 ZA storage mode" settings. This
// enum is the value type of the `za-mode` option of the `enable-arm-streaming`
// pass declared in this file.
//
// For the meaning of the individual modes see the ACLE attributes relating to
// ZA:
31// https://arm-software.github.io/acle/main/acle.html#sme-attributes-relating-to-za
32// See also the LLVM definitions: https://llvm.org/docs/AArch64SME.html
33//
34// Various frontends (e.g. Flang) that build on top of this may restrict or
35// enforce how these attributes are used, both individually and in terms of
36// combinations that are allowed.
37//
38// The MLIR interface here does not make any attempt to perform any checking,
39// it is up to the higher level to ensure that these attributes are used in a
40// way that both makes sense and is legal according to the Arm architecture.
41def ArmZaMode : I32EnumAttr<"ArmZaMode", "Armv9 ZA storage mode",
42    [
      // ZA storage is disabled.
43      I32EnumAttrCase<"Disabled", 0, "disabled">,
44      // A function's ZA state is created on entry and destroyed on exit.
45      I32EnumAttrCase<"NewZA", 1, "arm_new_za">,
46      // A function with a Shared-ZA interface that takes ZA as input.
47      I32EnumAttrCase<"InZA", 2, "arm_in_za">,
48      // A function with a Shared-ZA interface that returns ZA as output.
49      I32EnumAttrCase<"OutZA", 3, "arm_out_za">,
50      // A function with a Shared-ZA interface that takes ZA as input and
51      // returns ZA as output.
52      I32EnumAttrCase<"InOutZA", 4, "arm_inout_za">,
53      // A function with a Shared-ZA interface that does not read ZA and
54      // returns with ZA unchanged.
55      I32EnumAttrCase<"PreservesZA", 5, "arm_preserves_za">,
56    ]>{
  // The enum is referred to elsewhere in this file as
  // `mlir::arm_sme::ArmZaMode`.
57  let cppNamespace = "mlir::arm_sme";
  // NOTE(review): presumably this disables generation of a dedicated attribute
  // class so that only the C++ enum and its helpers are emitted -- confirm
  // against mlir/IR/EnumAttr.td.
58  let genSpecializedAttr = 0;
59}
60
// Declares the `enable-arm-streaming` pass on `mlir::func::FuncOp`.
//
// Per its description the pass annotates func.func ops with attributes that
// enable Armv9 Streaming SVE mode. Which streaming mode and which ZA storage
// mode are selected, and whether they are applied conditionally, is controlled
// by the four options below.
61def EnableArmStreaming
62    : Pass<"enable-arm-streaming", "mlir::func::FuncOp"> {
63  let summary = "Enable Armv9 Streaming SVE mode";
64  let description = [{
65    Enables the Armv9 Streaming SVE mode [1] for func.func ops by annotating
66    them with attributes. See options for more details.
67
68    [1] https://developer.arm.com/documentation/ddi0616/aa
69  }];
  // C++ expression used to construct the pass.
70  let constructor = "mlir::arm_sme::createEnableArmStreamingPass()";
71  let options = [
    // `streaming-mode`: an `ArmStreamingMode` value, defaulting to `Streaming`.
    // The trailing code block maps each enum case to its command-line
    // spelling and help text.
72    Option<"streamingMode", "streaming-mode", "mlir::arm_sme::ArmStreamingMode",
73          /*default=*/"mlir::arm_sme::ArmStreamingMode::Streaming",
74          "Select how streaming-mode is managed at the function-level.",
75          [{::llvm::cl::values(
76                clEnumValN(mlir::arm_sme::ArmStreamingMode::Disabled,
77                           "disabled", "Streaming mode is disabled."),
78                clEnumValN(mlir::arm_sme::ArmStreamingMode::Streaming,
79                           "streaming",
80                           "Streaming mode is part of the function interface "
81                           "(ABI), caller manages PSTATE.SM on entry/exit."),
82                clEnumValN(mlir::arm_sme::ArmStreamingMode::StreamingLocally,
83                           "streaming-locally",
84                           "Streaming mode is internal to the function, callee "
85                           "manages PSTATE.SM on entry/exit."),
86                clEnumValN(mlir::arm_sme::ArmStreamingMode::StreamingCompatible,
87                           "streaming-compatible",
88                           "Function supports both streaming and non-streaming "
89                           "modes.")
90          )}]>,
    // `za-mode`: an `ArmZaMode` value, defaulting to `Disabled`. The trailing
    // code block maps each enum case to its command-line spelling and help
    // text.
91    Option<"zaMode", "za-mode", "mlir::arm_sme::ArmZaMode",
92           /*default=*/"mlir::arm_sme::ArmZaMode::Disabled",
93           "Select how ZA-storage is managed at the function-level.",
94           [{::llvm::cl::values(
95                 clEnumValN(mlir::arm_sme::ArmZaMode::Disabled,
96                            "disabled", "ZA storage is disabled."),
97                 clEnumValN(mlir::arm_sme::ArmZaMode::NewZA,
98                            "new-za",
99                            "The function has ZA state. The ZA state is "
100                            "created on entry and destroyed on exit."),
101                 clEnumValN(mlir::arm_sme::ArmZaMode::InZA,
102                            "in-za",
103                            "The function uses ZA state. The ZA state may "
104                            "be used for input."),
105                 clEnumValN(mlir::arm_sme::ArmZaMode::OutZA,
106                            "out-za",
107                            "The function uses ZA state. The ZA state may "
108                            "be used for output."),
109                 clEnumValN(mlir::arm_sme::ArmZaMode::InOutZA,
110                            "inout-za",
111                            "The function uses ZA state. The ZA state may "
112                            "be used for input and/or output."),
113                 clEnumValN(mlir::arm_sme::ArmZaMode::PreservesZA,
114                            "preserves-za",
115                            "The function shares ZA state. The ZA state may "
116                            "not be used for input and/or output and the "
117                            "function must return with ZA unchanged")
118           )}]>,
    // `if-required-by-ops`: boolean, off by default. Restricts the pass to
    // functions containing ops that implement the ArmSMETileOpInterface.
119    Option<"ifRequiredByOps", "if-required-by-ops", "bool",
120           /*default=*/"false",
121           "Only apply the selected streaming/ZA modes if the function contains"
122           " ops that implement the ArmSMETileOpInterface.">,
    // `if-scalable-and-supported`: boolean, off by default. Restricts the pass
    // to functions containing supported scalable vector operations.
    // NOTE(review): how this interacts with `if-required-by-ops` when both are
    // set is not visible here -- confirm in the pass implementation.
123    Option<"ifScalableAndSupported", "if-scalable-and-supported",
124           "bool", /*default=*/"false",
125           "Only apply the selected streaming/ZA modes if the function contains"
126           " supported scalable vector operations.">
127  ];
  // Dialects declared as dependencies of this pass.
128  let dependentDialects = ["func::FuncDialect"];
129}
130
// Declares the `test-arm-sme-tile-allocation` pass on `mlir::func::FuncOp`.
//
// Per its description this is a testing-only entry point for SME "virtual
// tile" allocation; in the regular flow tile allocation happens as part of
// `convert-arm-sme-to-llvm`.
//
// NOTE(review): unlike the other passes in this file, no `constructor` is
// specified here; presumably the default generated constructor is used --
// confirm.
131def TestTileAllocation
132    : Pass<"test-arm-sme-tile-allocation", "mlir::func::FuncOp"> {
133  let summary = "Tests SME 'virtual tile' allocation";
134  let description = [{
135    This pass does tile allocation for SME "virtual tiles". It is run at the
136    'func.func' op level, and assigns tile IDs (via an attribute) to all ops
137    that implement the `ArmSMETileOpInterface`. Note: This pass is only intended
138    to be used for testing, tile allocation is done as part of the ArmSME to
139    LLVM conversion (`convert-arm-sme-to-llvm`).
140  }];
141  let options = [
    // `dump-tile-live-ranges`: boolean, off by default; a debugging aid.
142    Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
143           "bool", /*default=*/"false",
144           "Dump the live ranges of SME tiles (for debugging)">,
    // `preprocess-only`: boolean, off by default; stops after preparing the IR
    // for tile allocation, without allocating any tiles.
145    Option<"preprocessOnly", "preprocess-only", "bool", /*default=*/"false",
146           "Only preprocess IR so it is ready for tile allocation "
147           "(but do not allocate any tiles)">
148  ];
  // Dialects declared as dependencies of this pass.
149  let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect"];
150}
151
// Declares the `arm-sme-outer-product-fusion` pass on `mlir::func::FuncOp`.
//
// Per its description the pass fuses 'arm_sme.outerproduct' operations that
// are chained through their accumulator into 2-way or 4-way widening outer
// product operations. The pass declares no options.
152def OuterProductFusion
153    : Pass<"arm-sme-outer-product-fusion", "mlir::func::FuncOp"> {
154  let summary = "Fuse 'arm_sme.outerproduct' operations into 2-way or 4-way widening variants";
155  let description = [{
156    This pass fuses 'arm_sme.outerproduct' operations that are chained via the
157    accumulator into 2-way or 4-way ArmSME outer product operations.
158
159    For example:
160    ```mlir
161    %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
162    %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
163    %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
164    %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
165
166    %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
167    %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
168    ```
169
170    Becomes:
171
172    ```mlir
173    %a_packed = vector.interleave %a0, %a1 : vector<[4]xf16> -> vector<[8]xf16>
174    %b_packed = vector.interleave %b0, %b1 : vector<[4]xf16> -> vector<[8]xf16>
175    %0 = arm_sme.fmopa_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
176    ```
177
178    For further information on the 2-way or 4-way widening ops see:
179    https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smefmopa_2way-arm_smefmopa_2wayop
180    https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smesmopa_4way-arm_smesmopa_4wayop
181  }];
  // C++ expression used to construct the pass.
182  let constructor = "mlir::arm_sme::createOuterProductFusionPass()";
  // Dialects declared as dependencies of this pass.
  // NOTE(review): the example above produces `vector.interleave` ops but the
  // vector dialect is not listed here -- confirm whether it needs to be.
183  let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect"];
184}
185
// Declares the `arm-sme-vector-legalization` pass. Unlike the other passes in
// this file, which are declared on `mlir::func::FuncOp`, this one is declared
// on `mlir::ModuleOp`.
//
// Per its description the pass rewrites vector operations into forms that can
// be lowered to ArmSME, including splitting operations on vector types larger
// than one SME tile into tile-sized operations. The pass declares no options.
186def VectorLegalization
187  : Pass<"arm-sme-vector-legalization", "mlir::ModuleOp"> {
188  let summary = "Legalize vectors for ArmSME";
189  let description = [{
190    This pass legalizes vector operations so that they can be lowered to ArmSME.
191    This includes decomposing operations that operate on vector types larger
192    than a single SME tile (e.g. `vector<[8]x[8]xf32>`) into multiple SME
193    tile-sized operations, as well as rewrites needed to get operations into
194    forms compatible with SME lowerings.
195
196    Note: Decomposition is currently limited to vector types that are an exact
197    multiple of SME tiles. That is scalable in two dimensions, with both the
198    rows and columns divisible by the SVE vector length for the element type.
199  }];
  // C++ expression used to construct the pass.
200  let constructor = "mlir::arm_sme::createVectorLegalizationPass()";
  // Dialects declared as dependencies of this pass.
201  let dependentDialects = [
202    "func::FuncDialect",
203    "arm_sme::ArmSMEDialect",
204    "vector::VectorDialect",
205    "arith::ArithDialect",
206    "index::IndexDialect"
207  ];
208}
209
210#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD
211