//===-- Passes.td - ArmSME pass definition file ------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD
#define MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD

include "mlir/Pass/PassBase.td"
include "mlir/IR/EnumAttr.td"

// How Armv9 Streaming SVE mode is managed for a function. Each case carries a
// C++ enumerator name, an integer value and a string form (e.g.
// "arm_streaming"). Only the enum itself is generated (no specialized
// attribute class); it lives in the `mlir::arm_sme` C++ namespace.
def ArmStreamingMode : I32EnumAttr<"ArmStreamingMode", "Armv9 Streaming SVE mode",
  [
    // Disabled: streaming mode is disabled.
    I32EnumAttrCase<"Disabled", 0, "disabled">,
    // Streaming: streaming mode is part of the function interface (ABI); the
    // caller manages PSTATE.SM on entry/exit.
    I32EnumAttrCase<"Streaming", 1, "arm_streaming">,
    // StreamingLocally: streaming mode is internal to the function; PSTATE.SM
    // is kept internal and the callee manages it on entry/exit.
    I32EnumAttrCase<"StreamingLocally", 2, "arm_locally_streaming">,
    // StreamingCompatible: the function may be entered either in
    // non-streaming mode (PSTATE.SM=0) or in streaming mode (PSTATE.SM=1).
    I32EnumAttrCase<"StreamingCompatible", 3, "arm_streaming_compatible">
  ]> {
  let genSpecializedAttr = 0;
  let cppNamespace = "mlir::arm_sme";
}

// ZA storage attributes, as specified by the ACLE:
//   https://arm-software.github.io/acle/main/acle.html#sme-attributes-relating-to-za
// See also the LLVM definitions: https://llvm.org/docs/AArch64SME.html
//
// Various frontends (e.g. Flang) that build on top of this may restrict or
// enforce how these attributes are used, both individually and in terms of
// the combinations that are allowed.
//
// The MLIR interface here makes no attempt to perform any checking; it is up
// to the higher level to ensure that these attributes are used in a way that
// both makes sense and is legal according to the Arm architecture.
def ArmZaMode : I32EnumAttr<"ArmZaMode", "Armv9 ZA storage mode",
  [
    // ZA storage is disabled.
    I32EnumAttrCase<"Disabled", 0, "disabled">,
    // A function's ZA state is created on entry and destroyed on exit.
    I32EnumAttrCase<"NewZA", 1, "arm_new_za">,
    // A function with a Shared-ZA interface that takes ZA as input.
    I32EnumAttrCase<"InZA", 2, "arm_in_za">,
    // A function with a Shared-ZA interface that returns ZA as output.
    I32EnumAttrCase<"OutZA", 3, "arm_out_za">,
    // A function with a Shared-ZA interface that takes ZA as input and
    // returns ZA as output.
    I32EnumAttrCase<"InOutZA", 4, "arm_inout_za">,
    // A function with a Shared-ZA interface that does not read ZA and
    // returns with ZA unchanged.
    I32EnumAttrCase<"PreservesZA", 5, "arm_preserves_za">,
  ]>{
  let cppNamespace = "mlir::arm_sme";
  let genSpecializedAttr = 0;
}

def EnableArmStreaming
    : Pass<"enable-arm-streaming", "mlir::func::FuncOp"> {
  let summary = "Enable Armv9 Streaming SVE mode";
  let description = [{
    Enables the Armv9 Streaming SVE mode [1] for func.func ops by annotating
    them with attributes. See options for more details.

    [1] https://developer.arm.com/documentation/ddi0616/aa
  }];
  let constructor = "mlir::arm_sme::createEnableArmStreamingPass()";
  let options = [
    Option<"streamingMode", "streaming-mode", "mlir::arm_sme::ArmStreamingMode",
           /*default=*/"mlir::arm_sme::ArmStreamingMode::Streaming",
           "Select how streaming-mode is managed at the function-level.",
           [{::llvm::cl::values(
                 clEnumValN(mlir::arm_sme::ArmStreamingMode::Disabled,
                            "disabled", "Streaming mode is disabled."),
                 clEnumValN(mlir::arm_sme::ArmStreamingMode::Streaming,
                            "streaming",
                            "Streaming mode is part of the function interface "
                            "(ABI), caller manages PSTATE.SM on entry/exit."),
                 clEnumValN(mlir::arm_sme::ArmStreamingMode::StreamingLocally,
                            "streaming-locally",
                            "Streaming mode is internal to the function, callee "
                            "manages PSTATE.SM on entry/exit."),
                 clEnumValN(mlir::arm_sme::ArmStreamingMode::StreamingCompatible,
                            "streaming-compatible",
                            "Function supports both streaming and non-streaming "
                            "modes.")
           )}]>,
    Option<"zaMode", "za-mode", "mlir::arm_sme::ArmZaMode",
           /*default=*/"mlir::arm_sme::ArmZaMode::Disabled",
           "Select how ZA-storage is managed at the function-level.",
           [{::llvm::cl::values(
                 clEnumValN(mlir::arm_sme::ArmZaMode::Disabled,
                            "disabled", "ZA storage is disabled."),
                 clEnumValN(mlir::arm_sme::ArmZaMode::NewZA,
                            "new-za",
                            "The function has ZA state. The ZA state is "
                            "created on entry and destroyed on exit."),
                 clEnumValN(mlir::arm_sme::ArmZaMode::InZA,
                            "in-za",
                            "The function uses ZA state. The ZA state may "
                            "be used for input."),
                 clEnumValN(mlir::arm_sme::ArmZaMode::OutZA,
                            "out-za",
                            "The function uses ZA state. The ZA state may "
                            "be used for output."),
                 clEnumValN(mlir::arm_sme::ArmZaMode::InOutZA,
                            "inout-za",
                            "The function uses ZA state. The ZA state may "
                            "be used for input and/or output."),
                 // Help text ends with a period, like every sibling entry.
                 clEnumValN(mlir::arm_sme::ArmZaMode::PreservesZA,
                            "preserves-za",
                            "The function shares ZA state. The ZA state may "
                            "not be used for input and/or output and the "
                            "function must return with ZA unchanged.")
           )}]>,
    Option<"ifRequiredByOps", "if-required-by-ops", "bool",
           /*default=*/"false",
           "Only apply the selected streaming/ZA modes if the function contains"
           " ops that implement the ArmSMETileOpInterface.">,
    Option<"ifScalableAndSupported", "if-scalable-and-supported",
           "bool", /*default=*/"false",
           "Only apply the selected streaming/ZA modes if the function contains"
           " supported scalable vector operations.">
  ];
  let dependentDialects = ["func::FuncDialect"];
}

def TestTileAllocation
    : Pass<"test-arm-sme-tile-allocation", "mlir::func::FuncOp"> {
  let summary = "Tests SME 'virtual tile' allocation";
  let description = [{
    This pass does tile allocation for SME "virtual tiles". It is run at the
    'func.func' op level, and assigns tile IDs (via an attribute) to all ops
    that implement the `ArmSMETileOpInterface`. Note: This pass is only intended
    to be used for testing, tile allocation is done as part of the ArmSME to
    LLVM conversion (`convert-arm-sme-to-llvm`).
  }];
  let options = [
    Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
           "bool", /*default=*/"false",
           "Dump the live ranges of SME tiles (for debugging)">,
    Option<"preprocessOnly", "preprocess-only", "bool", /*default=*/"false",
           "Only preprocess IR so it is ready for tile allocation "
           "(but do not allocate any tiles)">
  ];
  let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect"];
}

def OuterProductFusion
    : Pass<"arm-sme-outer-product-fusion", "mlir::func::FuncOp"> {
  let summary = "Fuse 'arm_sme.outerproduct' operations into 2-way or 4-way widening variants";
  let description = [{
    This pass fuses 'arm_sme.outerproduct' operations that are chained via the
    accumulator into 2-way or 4-way ArmSME outer product operations.

    For example:
    ```mlir
    %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
    %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
    %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
    %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>

    %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
    %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
    ```

    Becomes:

    ```mlir
    %a_packed = vector.interleave %a0, %a1 : vector<[4]xf16> -> vector<[8]xf16>
    %b_packed = vector.interleave %b0, %b1 : vector<[4]xf16> -> vector<[8]xf16>
    %0 = arm_sme.fmopa_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
    ```

    For further information on the 2-way or 4-way widening ops see:
    https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smefmopa_2way-arm_smefmopa_2wayop
    https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smesmopa_4way-arm_smesmopa_4wayop
  }];
  let constructor = "mlir::arm_sme::createOuterProductFusionPass()";
  let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect"];
}

def VectorLegalization
  : Pass<"arm-sme-vector-legalization", "mlir::ModuleOp"> {
  let summary = "Legalize vectors for ArmSME";
  let description = [{
    This pass legalizes vector operations so that they can be lowered to ArmSME.
    This includes decomposing operations that operate on vector types larger
    than a single SME tile (e.g. `vector<[8]x[8]xf32>`) into multiple SME
    tile-sized operations, as well as rewrites needed to get operations into
    forms compatible with SME lowerings.

    Note: Decomposition is currently limited to vector types that are an exact
    multiple of SME tiles. That is, scalable in two dimensions, with both the
    rows and columns divisible by the SVE vector length for the element type.
  }];
  let constructor = "mlir::arm_sme::createVectorLegalizationPass()";
  let dependentDialects = [
    "func::FuncDialect",
    "arm_sme::ArmSMEDialect",
    "vector::VectorDialect",
    "arith::ArithDialect",
    "index::IndexDialect"
  ];
}

#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD