xref: /llvm-project/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp (revision 1b64ed0e0c7fde1b65d55bfb7954beadc0f60e28)
1 //===- EnableArmStreaming.cpp - Enable Armv9 Streaming SVE mode -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass enables the Armv9 Scalable Matrix Extension (SME) Streaming SVE
10 // (SSVE) mode [1][2] by adding either of the following attributes to
11 // 'func.func' ops:
12 //
13 //   * 'arm_streaming' (default)
14 //   * 'arm_locally_streaming'
15 //
16 // It can also optionally enable the ZA storage array.
17 //
18 // Streaming-mode is part of the interface (ABI) for functions with the
19 // first attribute and it's the responsibility of the caller to manage
20 // PSTATE.SM on entry/exit to functions with this attribute [3]. The LLVM
21 // backend will emit 'smstart sm' / 'smstop sm' [4] around calls to
22 // streaming functions.
23 //
24 // In locally streaming functions PSTATE.SM is kept internal and managed by
25 // the callee on entry/exit. The LLVM backend will emit 'smstart sm' /
26 // 'smstop sm' in the prologue / epilogue for functions with this
27 // attribute.
28 //
29 // [1] https://developer.arm.com/documentation/ddi0616/aa
30 // [2] https://llvm.org/docs/AArch64SME.html
31 // [3] https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#671pstatesm-interfaces
32 // [4] https://developer.arm.com/documentation/ddi0602/2023-03/Base-Instructions/SMSTART--Enables-access-to-Streaming-SVE-mode-and-SME-architectural-state--an-alias-of-MSR--immediate--
33 //
34 //===----------------------------------------------------------------------===//
35 
36 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
37 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
38 #include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc"
39 
40 #include "mlir/Dialect/Func/IR/FuncOps.h"
41 
42 #define DEBUG_TYPE "enable-arm-streaming"
43 
44 namespace mlir {
45 namespace arm_sme {
46 #define GEN_PASS_DEF_ENABLEARMSTREAMING
47 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
48 } // namespace arm_sme
49 } // namespace mlir
50 
51 using namespace mlir;
52 using namespace mlir::arm_sme;
53 namespace {
54 
55 constexpr StringLiteral
56     kEnableArmStreamingIgnoreAttr("enable_arm_streaming_ignore");
57 
58 template <typename... Ops>
opList()59 constexpr auto opList() {
60   return std::array{TypeID::get<Ops>()...};
61 }
62 
isScalableVector(Type type)63 bool isScalableVector(Type type) {
64   if (auto vectorType = dyn_cast<VectorType>(type))
65     return vectorType.isScalable();
66   return false;
67 }
68 
69 struct EnableArmStreamingPass
70     : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
EnableArmStreamingPass__anon1e0b46640111::EnableArmStreamingPass71   EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
72                          bool ifRequiredByOps, bool ifScalableAndSupported) {
73     this->streamingMode = streamingMode;
74     this->zaMode = zaMode;
75     this->ifRequiredByOps = ifRequiredByOps;
76     this->ifScalableAndSupported = ifScalableAndSupported;
77   }
runOnOperation__anon1e0b46640111::EnableArmStreamingPass78   void runOnOperation() override {
79     auto function = getOperation();
80 
81     if (ifRequiredByOps && ifScalableAndSupported) {
82       function->emitOpError(
83           "enable-arm-streaming: `if-required-by-ops` and "
84           "`if-scalable-and-supported` are mutually exclusive");
85       return signalPassFailure();
86     }
87 
88     if (ifRequiredByOps) {
89       bool foundTileOp = false;
90       function.walk([&](Operation *op) {
91         if (llvm::isa<ArmSMETileOpInterface>(op)) {
92           foundTileOp = true;
93           return WalkResult::interrupt();
94         }
95         return WalkResult::advance();
96       });
97       if (!foundTileOp)
98         return;
99     }
100 
101     if (ifScalableAndSupported) {
102       // FIXME: This should be based on target information (i.e., the presence
103       // of FEAT_SME_FA64). This currently errs on the side of caution. If
104       // possible gathers/scatters should be lowered regular vector loads/stores
105       // before invoking this pass.
106       auto disallowedOperations = opList<vector::GatherOp, vector::ScatterOp>();
107       bool isCompatibleScalableFunction = false;
108       function.walk([&](Operation *op) {
109         if (llvm::is_contained(disallowedOperations,
110                                op->getName().getTypeID())) {
111           isCompatibleScalableFunction = false;
112           return WalkResult::interrupt();
113         }
114         if (!isCompatibleScalableFunction &&
115             (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
116              llvm::any_of(op->getResultTypes(), isScalableVector))) {
117           isCompatibleScalableFunction = true;
118         }
119         return WalkResult::advance();
120       });
121       if (!isCompatibleScalableFunction)
122         return;
123     }
124 
125     if (function->getAttr(kEnableArmStreamingIgnoreAttr) ||
126         streamingMode == ArmStreamingMode::Disabled)
127       return;
128 
129     auto unitAttr = UnitAttr::get(&getContext());
130 
131     function->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
132 
133     // The pass currently only supports enabling ZA when in streaming-mode, but
134     // ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
135     // streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
136     // supporting this later.
137     if (zaMode != ArmZaMode::Disabled)
138       function->setAttr(stringifyArmZaMode(zaMode), unitAttr);
139   }
140 };
141 } // namespace
142 
createEnableArmStreamingPass(const ArmStreamingMode streamingMode,const ArmZaMode zaMode,bool ifRequiredByOps,bool ifScalableAndSupported)143 std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
144     const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
145     bool ifRequiredByOps, bool ifScalableAndSupported) {
146   return std::make_unique<EnableArmStreamingPass>(
147       streamingMode, zaMode, ifRequiredByOps, ifScalableAndSupported);
148 }
149