1 //===- EnableArmStreaming.cpp - Enable Armv9 Streaming SVE mode -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass enables the Armv9 Scalable Matrix Extension (SME) Streaming SVE
10 // (SSVE) mode [1][2] by adding either of the following attributes to
11 // 'func.func' ops:
12 //
13 // * 'arm_streaming' (default)
14 // * 'arm_locally_streaming'
15 //
16 // It can also optionally enable the ZA storage array.
17 //
18 // Streaming-mode is part of the interface (ABI) for functions with the
19 // first attribute and it's the responsibility of the caller to manage
20 // PSTATE.SM on entry/exit to functions with this attribute [3]. The LLVM
21 // backend will emit 'smstart sm' / 'smstop sm' [4] around calls to
22 // streaming functions.
23 //
24 // In locally streaming functions PSTATE.SM is kept internal and managed by
25 // the callee on entry/exit. The LLVM backend will emit 'smstart sm' /
26 // 'smstop sm' in the prologue / epilogue for functions with this
27 // attribute.
28 //
29 // [1] https://developer.arm.com/documentation/ddi0616/aa
30 // [2] https://llvm.org/docs/AArch64SME.html
31 // [3] https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#671pstatesm-interfaces
32 // [4] https://developer.arm.com/documentation/ddi0602/2023-03/Base-Instructions/SMSTART--Enables-access-to-Streaming-SVE-mode-and-SME-architectural-state--an-alias-of-MSR--immediate--
33 //
34 //===----------------------------------------------------------------------===//
35
36 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
37 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
38 #include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc"
39
40 #include "mlir/Dialect/Func/IR/FuncOps.h"
41
42 #define DEBUG_TYPE "enable-arm-streaming"
43
44 namespace mlir {
45 namespace arm_sme {
46 #define GEN_PASS_DEF_ENABLEARMSTREAMING
47 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
48 } // namespace arm_sme
49 } // namespace mlir
50
51 using namespace mlir;
52 using namespace mlir::arm_sme;
53 namespace {
54
55 constexpr StringLiteral
56 kEnableArmStreamingIgnoreAttr("enable_arm_streaming_ignore");
57
58 template <typename... Ops>
opList()59 constexpr auto opList() {
60 return std::array{TypeID::get<Ops>()...};
61 }
62
isScalableVector(Type type)63 bool isScalableVector(Type type) {
64 if (auto vectorType = dyn_cast<VectorType>(type))
65 return vectorType.isScalable();
66 return false;
67 }
68
69 struct EnableArmStreamingPass
70 : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
EnableArmStreamingPass__anon1e0b46640111::EnableArmStreamingPass71 EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
72 bool ifRequiredByOps, bool ifScalableAndSupported) {
73 this->streamingMode = streamingMode;
74 this->zaMode = zaMode;
75 this->ifRequiredByOps = ifRequiredByOps;
76 this->ifScalableAndSupported = ifScalableAndSupported;
77 }
runOnOperation__anon1e0b46640111::EnableArmStreamingPass78 void runOnOperation() override {
79 auto function = getOperation();
80
81 if (ifRequiredByOps && ifScalableAndSupported) {
82 function->emitOpError(
83 "enable-arm-streaming: `if-required-by-ops` and "
84 "`if-scalable-and-supported` are mutually exclusive");
85 return signalPassFailure();
86 }
87
88 if (ifRequiredByOps) {
89 bool foundTileOp = false;
90 function.walk([&](Operation *op) {
91 if (llvm::isa<ArmSMETileOpInterface>(op)) {
92 foundTileOp = true;
93 return WalkResult::interrupt();
94 }
95 return WalkResult::advance();
96 });
97 if (!foundTileOp)
98 return;
99 }
100
101 if (ifScalableAndSupported) {
102 // FIXME: This should be based on target information (i.e., the presence
103 // of FEAT_SME_FA64). This currently errs on the side of caution. If
104 // possible gathers/scatters should be lowered regular vector loads/stores
105 // before invoking this pass.
106 auto disallowedOperations = opList<vector::GatherOp, vector::ScatterOp>();
107 bool isCompatibleScalableFunction = false;
108 function.walk([&](Operation *op) {
109 if (llvm::is_contained(disallowedOperations,
110 op->getName().getTypeID())) {
111 isCompatibleScalableFunction = false;
112 return WalkResult::interrupt();
113 }
114 if (!isCompatibleScalableFunction &&
115 (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
116 llvm::any_of(op->getResultTypes(), isScalableVector))) {
117 isCompatibleScalableFunction = true;
118 }
119 return WalkResult::advance();
120 });
121 if (!isCompatibleScalableFunction)
122 return;
123 }
124
125 if (function->getAttr(kEnableArmStreamingIgnoreAttr) ||
126 streamingMode == ArmStreamingMode::Disabled)
127 return;
128
129 auto unitAttr = UnitAttr::get(&getContext());
130
131 function->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
132
133 // The pass currently only supports enabling ZA when in streaming-mode, but
134 // ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
135 // streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
136 // supporting this later.
137 if (zaMode != ArmZaMode::Disabled)
138 function->setAttr(stringifyArmZaMode(zaMode), unitAttr);
139 }
140 };
141 } // namespace
142
createEnableArmStreamingPass(const ArmStreamingMode streamingMode,const ArmZaMode zaMode,bool ifRequiredByOps,bool ifScalableAndSupported)143 std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
144 const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
145 bool ifRequiredByOps, bool ifScalableAndSupported) {
146 return std::make_unique<EnableArmStreamingPass>(
147 streamingMode, zaMode, ifRequiredByOps, ifScalableAndSupported);
148 }
149