112648492SCullen Rhodes //===- EnableArmStreaming.cpp - Enable Armv9 Streaming SVE mode -----------===//
212648492SCullen Rhodes //
312648492SCullen Rhodes // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
412648492SCullen Rhodes // See https://llvm.org/LICENSE.txt for license information.
512648492SCullen Rhodes // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
612648492SCullen Rhodes //
712648492SCullen Rhodes //===----------------------------------------------------------------------===//
812648492SCullen Rhodes //
912648492SCullen Rhodes // This pass enables the Armv9 Scalable Matrix Extension (SME) Streaming SVE
1012648492SCullen Rhodes // (SSVE) mode [1][2] by adding either of the following attributes to
1112648492SCullen Rhodes // 'func.func' ops:
1212648492SCullen Rhodes //
1312648492SCullen Rhodes // * 'arm_streaming' (default)
1412648492SCullen Rhodes // * 'arm_locally_streaming'
1512648492SCullen Rhodes //
16e947e760SCullen Rhodes // It can also optionally enable the ZA storage array.
17e947e760SCullen Rhodes //
1812648492SCullen Rhodes // Streaming-mode is part of the interface (ABI) for functions with the
1912648492SCullen Rhodes // first attribute and it's the responsibility of the caller to manage
2012648492SCullen Rhodes // PSTATE.SM on entry/exit to functions with this attribute [3]. The LLVM
2112648492SCullen Rhodes // backend will emit 'smstart sm' / 'smstop sm' [4] around calls to
2212648492SCullen Rhodes // streaming functions.
2312648492SCullen Rhodes //
2412648492SCullen Rhodes // In locally streaming functions PSTATE.SM is kept internal and managed by
2512648492SCullen Rhodes // the callee on entry/exit. The LLVM backend will emit 'smstart sm' /
2612648492SCullen Rhodes // 'smstop sm' in the prologue / epilogue for functions with this
2712648492SCullen Rhodes // attribute.
2812648492SCullen Rhodes //
2912648492SCullen Rhodes // [1] https://developer.arm.com/documentation/ddi0616/aa
3012648492SCullen Rhodes // [2] https://llvm.org/docs/AArch64SME.html
3112648492SCullen Rhodes // [3] https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#671pstatesm-interfaces
3212648492SCullen Rhodes // [4] https://developer.arm.com/documentation/ddi0602/2023-03/Base-Instructions/SMSTART--Enables-access-to-Streaming-SVE-mode-and-SME-architectural-state--an-alias-of-MSR--immediate--
3312648492SCullen Rhodes //
3412648492SCullen Rhodes //===----------------------------------------------------------------------===//
3512648492SCullen Rhodes
36f7d91faaSBenjamin Maxwell #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
3712648492SCullen Rhodes #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
38783ac3b6SBenjamin Maxwell #include "mlir/Dialect/ArmSME/Transforms/PassesEnums.cpp.inc"
3912648492SCullen Rhodes
4012648492SCullen Rhodes #include "mlir/Dialect/Func/IR/FuncOps.h"
4112648492SCullen Rhodes
4212648492SCullen Rhodes #define DEBUG_TYPE "enable-arm-streaming"
4312648492SCullen Rhodes
4412648492SCullen Rhodes namespace mlir {
4512648492SCullen Rhodes namespace arm_sme {
4612648492SCullen Rhodes #define GEN_PASS_DEF_ENABLEARMSTREAMING
4712648492SCullen Rhodes #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
4812648492SCullen Rhodes } // namespace arm_sme
4912648492SCullen Rhodes } // namespace mlir
5012648492SCullen Rhodes
5112648492SCullen Rhodes using namespace mlir;
5212648492SCullen Rhodes using namespace mlir::arm_sme;
5312648492SCullen Rhodes namespace {
54783ac3b6SBenjamin Maxwell
55783ac3b6SBenjamin Maxwell constexpr StringLiteral
56783ac3b6SBenjamin Maxwell kEnableArmStreamingIgnoreAttr("enable_arm_streaming_ignore");
57783ac3b6SBenjamin Maxwell
58*1b64ed0eSBenjamin Maxwell template <typename... Ops>
opList()59*1b64ed0eSBenjamin Maxwell constexpr auto opList() {
60*1b64ed0eSBenjamin Maxwell return std::array{TypeID::get<Ops>()...};
61*1b64ed0eSBenjamin Maxwell }
62*1b64ed0eSBenjamin Maxwell
isScalableVector(Type type)63*1b64ed0eSBenjamin Maxwell bool isScalableVector(Type type) {
64*1b64ed0eSBenjamin Maxwell if (auto vectorType = dyn_cast<VectorType>(type))
65*1b64ed0eSBenjamin Maxwell return vectorType.isScalable();
66*1b64ed0eSBenjamin Maxwell return false;
67*1b64ed0eSBenjamin Maxwell }
68*1b64ed0eSBenjamin Maxwell
6912648492SCullen Rhodes struct EnableArmStreamingPass
7012648492SCullen Rhodes : public arm_sme::impl::EnableArmStreamingBase<EnableArmStreamingPass> {
EnableArmStreamingPass__anon1e0b46640111::EnableArmStreamingPass71f7d91faaSBenjamin Maxwell EnableArmStreamingPass(ArmStreamingMode streamingMode, ArmZaMode zaMode,
72*1b64ed0eSBenjamin Maxwell bool ifRequiredByOps, bool ifScalableAndSupported) {
73783ac3b6SBenjamin Maxwell this->streamingMode = streamingMode;
74783ac3b6SBenjamin Maxwell this->zaMode = zaMode;
75d319fc41SBenjamin Maxwell this->ifRequiredByOps = ifRequiredByOps;
76*1b64ed0eSBenjamin Maxwell this->ifScalableAndSupported = ifScalableAndSupported;
77e947e760SCullen Rhodes }
runOnOperation__anon1e0b46640111::EnableArmStreamingPass7812648492SCullen Rhodes void runOnOperation() override {
79d319fc41SBenjamin Maxwell auto function = getOperation();
80f7d91faaSBenjamin Maxwell
81*1b64ed0eSBenjamin Maxwell if (ifRequiredByOps && ifScalableAndSupported) {
82d319fc41SBenjamin Maxwell function->emitOpError(
83d319fc41SBenjamin Maxwell "enable-arm-streaming: `if-required-by-ops` and "
84*1b64ed0eSBenjamin Maxwell "`if-scalable-and-supported` are mutually exclusive");
85d319fc41SBenjamin Maxwell return signalPassFailure();
86d319fc41SBenjamin Maxwell }
87d319fc41SBenjamin Maxwell
88d319fc41SBenjamin Maxwell if (ifRequiredByOps) {
89f7d91faaSBenjamin Maxwell bool foundTileOp = false;
90d319fc41SBenjamin Maxwell function.walk([&](Operation *op) {
91f7d91faaSBenjamin Maxwell if (llvm::isa<ArmSMETileOpInterface>(op)) {
92f7d91faaSBenjamin Maxwell foundTileOp = true;
93f7d91faaSBenjamin Maxwell return WalkResult::interrupt();
94f7d91faaSBenjamin Maxwell }
95f7d91faaSBenjamin Maxwell return WalkResult::advance();
96f7d91faaSBenjamin Maxwell });
97f7d91faaSBenjamin Maxwell if (!foundTileOp)
98f7d91faaSBenjamin Maxwell return;
99f7d91faaSBenjamin Maxwell }
100f7d91faaSBenjamin Maxwell
101*1b64ed0eSBenjamin Maxwell if (ifScalableAndSupported) {
102*1b64ed0eSBenjamin Maxwell // FIXME: This should be based on target information (i.e., the presence
103*1b64ed0eSBenjamin Maxwell // of FEAT_SME_FA64). This currently errs on the side of caution. If
104*1b64ed0eSBenjamin Maxwell // possible gathers/scatters should be lowered regular vector loads/stores
105*1b64ed0eSBenjamin Maxwell // before invoking this pass.
106*1b64ed0eSBenjamin Maxwell auto disallowedOperations = opList<vector::GatherOp, vector::ScatterOp>();
107*1b64ed0eSBenjamin Maxwell bool isCompatibleScalableFunction = false;
108d319fc41SBenjamin Maxwell function.walk([&](Operation *op) {
109*1b64ed0eSBenjamin Maxwell if (llvm::is_contained(disallowedOperations,
110*1b64ed0eSBenjamin Maxwell op->getName().getTypeID())) {
111*1b64ed0eSBenjamin Maxwell isCompatibleScalableFunction = false;
112d319fc41SBenjamin Maxwell return WalkResult::interrupt();
113d319fc41SBenjamin Maxwell }
114*1b64ed0eSBenjamin Maxwell if (!isCompatibleScalableFunction &&
115*1b64ed0eSBenjamin Maxwell (llvm::any_of(op->getOperandTypes(), isScalableVector) ||
116*1b64ed0eSBenjamin Maxwell llvm::any_of(op->getResultTypes(), isScalableVector))) {
117*1b64ed0eSBenjamin Maxwell isCompatibleScalableFunction = true;
118*1b64ed0eSBenjamin Maxwell }
119d319fc41SBenjamin Maxwell return WalkResult::advance();
120d319fc41SBenjamin Maxwell });
121*1b64ed0eSBenjamin Maxwell if (!isCompatibleScalableFunction)
122d319fc41SBenjamin Maxwell return;
123d319fc41SBenjamin Maxwell }
124d319fc41SBenjamin Maxwell
125d319fc41SBenjamin Maxwell if (function->getAttr(kEnableArmStreamingIgnoreAttr) ||
126783ac3b6SBenjamin Maxwell streamingMode == ArmStreamingMode::Disabled)
12706f22c9aSBenjamin Maxwell return;
128783ac3b6SBenjamin Maxwell
129783ac3b6SBenjamin Maxwell auto unitAttr = UnitAttr::get(&getContext());
130783ac3b6SBenjamin Maxwell
131d319fc41SBenjamin Maxwell function->setAttr(stringifyArmStreamingMode(streamingMode), unitAttr);
132e947e760SCullen Rhodes
133e947e760SCullen Rhodes // The pass currently only supports enabling ZA when in streaming-mode, but
134e947e760SCullen Rhodes // ZA can be accessed by the SME LDR, STR and ZERO instructions when not in
135e947e760SCullen Rhodes // streaming-mode (see section B1.1.1, IDGNQM of spec [1]). It may be worth
136e947e760SCullen Rhodes // supporting this later.
137783ac3b6SBenjamin Maxwell if (zaMode != ArmZaMode::Disabled)
138d319fc41SBenjamin Maxwell function->setAttr(stringifyArmZaMode(zaMode), unitAttr);
13912648492SCullen Rhodes }
14012648492SCullen Rhodes };
14112648492SCullen Rhodes } // namespace
14212648492SCullen Rhodes
createEnableArmStreamingPass(const ArmStreamingMode streamingMode,const ArmZaMode zaMode,bool ifRequiredByOps,bool ifScalableAndSupported)143783ac3b6SBenjamin Maxwell std::unique_ptr<Pass> mlir::arm_sme::createEnableArmStreamingPass(
144f7d91faaSBenjamin Maxwell const ArmStreamingMode streamingMode, const ArmZaMode zaMode,
145*1b64ed0eSBenjamin Maxwell bool ifRequiredByOps, bool ifScalableAndSupported) {
146d319fc41SBenjamin Maxwell return std::make_unique<EnableArmStreamingPass>(
147*1b64ed0eSBenjamin Maxwell streamingMode, zaMode, ifRequiredByOps, ifScalableAndSupported);
14812648492SCullen Rhodes }
149