xref: /llvm-project/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp (revision 830b9b072d8458ee89c48f00d4de59456c9f467f)
1 //===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Func/IR/FuncOps.h"
10 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
11 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
12 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
13 #include "mlir/Pass/Pass.h"
14 
15 using namespace mlir;
16 
17 //===----------------------------------------------------------------------===//
18 // Printing op availability pass
19 //===----------------------------------------------------------------------===//
20 
21 namespace {
22 /// A pass for testing SPIR-V op availability.
23 struct PrintOpAvailability
24     : public PassWrapper<PrintOpAvailability, OperationPass<func::FuncOp>> {
25   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrintOpAvailability)
26 
27   void runOnOperation() override;
getArgument__anon03031cee0111::PrintOpAvailability28   StringRef getArgument() const final { return "test-spirv-op-availability"; }
getDescription__anon03031cee0111::PrintOpAvailability29   StringRef getDescription() const final {
30     return "Test SPIR-V op availability";
31   }
32 };
33 } // namespace
34 
runOnOperation()35 void PrintOpAvailability::runOnOperation() {
36   auto f = getOperation();
37   llvm::outs() << f.getName() << "\n";
38 
39   Dialect *spirvDialect = getContext().getLoadedDialect("spirv");
40 
41   f->walk([&](Operation *op) {
42     if (op->getDialect() != spirvDialect)
43       return WalkResult::advance();
44 
45     auto opName = op->getName();
46     auto &os = llvm::outs();
47 
48     if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
49       std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
50       os << opName << " min version: ";
51       if (minVersion)
52         os << spirv::stringifyVersion(*minVersion) << "\n";
53       else
54         os << "None\n";
55     }
56 
57     if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
58       std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
59       os << opName << " max version: ";
60       if (maxVersion)
61         os << spirv::stringifyVersion(*maxVersion) << "\n";
62       else
63         os << "None\n";
64     }
65 
66     if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) {
67       os << opName << " extensions: [";
68       for (const auto &exts : extension.getExtensions()) {
69         os << " [";
70         llvm::interleaveComma(exts, os, [&](spirv::Extension ext) {
71           os << spirv::stringifyExtension(ext);
72         });
73         os << "]";
74       }
75       os << " ]\n";
76     }
77 
78     if (auto capability = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
79       os << opName << " capabilities: [";
80       for (const auto &caps : capability.getCapabilities()) {
81         os << " [";
82         llvm::interleaveComma(caps, os, [&](spirv::Capability cap) {
83           os << spirv::stringifyCapability(cap);
84         });
85         os << "]";
86       }
87       os << " ]\n";
88     }
89     os.flush();
90 
91     return WalkResult::advance();
92   });
93 }
94 
95 namespace mlir {
registerPrintSpirvAvailabilityPass()96 void registerPrintSpirvAvailabilityPass() {
97   PassRegistration<PrintOpAvailability>();
98 }
99 } // namespace mlir
100 
101 //===----------------------------------------------------------------------===//
102 // Converting target environment pass
103 //===----------------------------------------------------------------------===//
104 
105 namespace {
106 /// A pass for testing SPIR-V op availability.
107 struct ConvertToTargetEnv
108     : public PassWrapper<ConvertToTargetEnv, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon03031cee0511::ConvertToTargetEnv109   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertToTargetEnv)
110 
111   StringRef getArgument() const override { return "test-spirv-target-env"; }
getDescription__anon03031cee0511::ConvertToTargetEnv112   StringRef getDescription() const override {
113     return "Test SPIR-V target environment";
114   }
115   void runOnOperation() override;
116 };
117 
118 struct ConvertToAtomCmpExchangeWeak : RewritePattern {
ConvertToAtomCmpExchangeWeak__anon03031cee0511::ConvertToAtomCmpExchangeWeak119   ConvertToAtomCmpExchangeWeak(MLIRContext *context)
120       : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op", 1,
121                        context, {"spirv.AtomicCompareExchangeWeak"}) {}
122 
matchAndRewrite__anon03031cee0511::ConvertToAtomCmpExchangeWeak123   LogicalResult matchAndRewrite(Operation *op,
124                                 PatternRewriter &rewriter) const override {
125     Value ptr = op->getOperand(0);
126     Value value = op->getOperand(1);
127     Value comparator = op->getOperand(2);
128 
129     // Create a spirv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits
130     // in memory semantics to additionally require AtomicStorage capability.
131     rewriter.replaceOpWithNewOp<spirv::AtomicCompareExchangeWeakOp>(
132         op, value.getType(), ptr, spirv::Scope::Workgroup,
133         spirv::MemorySemantics::AcquireRelease |
134             spirv::MemorySemantics::AtomicCounterMemory,
135         spirv::MemorySemantics::Acquire, value, comparator);
136     return success();
137   }
138 };
139 
140 struct ConvertToBitReverse : RewritePattern {
ConvertToBitReverse__anon03031cee0511::ConvertToBitReverse141   ConvertToBitReverse(MLIRContext *context)
142       : RewritePattern("test.convert_to_bit_reverse_op", 1, context,
143                        {"spirv.BitReverse"}) {}
144 
matchAndRewrite__anon03031cee0511::ConvertToBitReverse145   LogicalResult matchAndRewrite(Operation *op,
146                                 PatternRewriter &rewriter) const override {
147     Value predicate = op->getOperand(0);
148     rewriter.replaceOpWithNewOp<spirv::BitReverseOp>(
149         op, op->getResult(0).getType(), predicate);
150     return success();
151   }
152 };
153 
154 struct ConvertToGroupNonUniformBallot : RewritePattern {
ConvertToGroupNonUniformBallot__anon03031cee0511::ConvertToGroupNonUniformBallot155   ConvertToGroupNonUniformBallot(MLIRContext *context)
156       : RewritePattern("test.convert_to_group_non_uniform_ballot_op", 1,
157                        context, {"spirv.GroupNonUniformBallot"}) {}
158 
matchAndRewrite__anon03031cee0511::ConvertToGroupNonUniformBallot159   LogicalResult matchAndRewrite(Operation *op,
160                                 PatternRewriter &rewriter) const override {
161     Value predicate = op->getOperand(0);
162     rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>(
163         op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate);
164     return success();
165   }
166 };
167 
168 struct ConvertToModule : RewritePattern {
ConvertToModule__anon03031cee0511::ConvertToModule169   ConvertToModule(MLIRContext *context)
170       : RewritePattern("test.convert_to_module_op", 1, context,
171                        {"spirv.module"}) {}
172 
matchAndRewrite__anon03031cee0511::ConvertToModule173   LogicalResult matchAndRewrite(Operation *op,
174                                 PatternRewriter &rewriter) const override {
175     rewriter.replaceOpWithNewOp<spirv::ModuleOp>(
176         op, spirv::AddressingModel::PhysicalStorageBuffer64,
177         spirv::MemoryModel::Vulkan);
178     return success();
179   }
180 };
181 
182 struct ConvertToSubgroupBallot : RewritePattern {
ConvertToSubgroupBallot__anon03031cee0511::ConvertToSubgroupBallot183   ConvertToSubgroupBallot(MLIRContext *context)
184       : RewritePattern("test.convert_to_subgroup_ballot_op", 1, context,
185                        {"spirv.KHR.SubgroupBallot"}) {}
186 
matchAndRewrite__anon03031cee0511::ConvertToSubgroupBallot187   LogicalResult matchAndRewrite(Operation *op,
188                                 PatternRewriter &rewriter) const override {
189     Value predicate = op->getOperand(0);
190     rewriter.replaceOpWithNewOp<spirv::KHRSubgroupBallotOp>(
191         op, op->getResult(0).getType(), predicate);
192     return success();
193   }
194 };
195 
196 template <const char *TestOpName, typename SPIRVOp>
197 struct ConvertToIntegerDotProd : RewritePattern {
ConvertToIntegerDotProd__anon03031cee0511::ConvertToIntegerDotProd198   ConvertToIntegerDotProd(MLIRContext *context)
199       : RewritePattern(TestOpName, 1, context, {SPIRVOp::getOperationName()}) {}
200 
matchAndRewrite__anon03031cee0511::ConvertToIntegerDotProd201   LogicalResult matchAndRewrite(Operation *op,
202                                 PatternRewriter &rewriter) const override {
203     rewriter.replaceOpWithNewOp<SPIRVOp>(op, op->getResultTypes(),
204                                          op->getOperands(), op->getAttrs());
205     return success();
206   }
207 };
208 } // namespace
209 
runOnOperation()210 void ConvertToTargetEnv::runOnOperation() {
211   MLIRContext *context = &getContext();
212   func::FuncOp fn = getOperation();
213 
214   auto targetEnv = dyn_cast_or_null<spirv::TargetEnvAttr>(
215       fn.getOperation()->getDiscardableAttr(spirv::getTargetEnvAttrName()));
216   if (!targetEnv) {
217     fn.emitError("missing 'spirv.target_env' attribute");
218     return signalPassFailure();
219   }
220 
221   auto target = SPIRVConversionTarget::get(targetEnv);
222 
223   static constexpr char sDotTestOpName[] = "test.convert_to_sdot_op";
224   static constexpr char suDotTestOpName[] = "test.convert_to_sudot_op";
225   static constexpr char uDotTestOpName[] = "test.convert_to_udot_op";
226   static constexpr char sDotAccSatTestOpName[] =
227       "test.convert_to_sdot_acc_sat_op";
228   static constexpr char suDotAccSatTestOpName[] =
229       "test.convert_to_sudot_acc_sat_op";
230   static constexpr char uDotAccSatTestOpName[] =
231       "test.convert_to_udot_acc_sat_op";
232 
233   RewritePatternSet patterns(context);
234   patterns.add<
235       ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
236       ConvertToGroupNonUniformBallot, ConvertToModule, ConvertToSubgroupBallot,
237       ConvertToIntegerDotProd<sDotTestOpName, spirv::SDotOp>,
238       ConvertToIntegerDotProd<suDotTestOpName, spirv::SUDotOp>,
239       ConvertToIntegerDotProd<uDotTestOpName, spirv::UDotOp>,
240       ConvertToIntegerDotProd<sDotAccSatTestOpName, spirv::SDotAccSatOp>,
241       ConvertToIntegerDotProd<suDotAccSatTestOpName, spirv::SUDotAccSatOp>,
242       ConvertToIntegerDotProd<uDotAccSatTestOpName, spirv::UDotAccSatOp>>(
243       context);
244 
245   if (failed(applyPartialConversion(fn, *target, std::move(patterns))))
246     return signalPassFailure();
247 }
248 
249 namespace mlir {
registerConvertToTargetEnvPass()250 void registerConvertToTargetEnvPass() {
251   PassRegistration<ConvertToTargetEnv>();
252 }
253 } // namespace mlir
254