xref: /llvm-project/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp (revision 830b9b072d8458ee89c48f00d4de59456c9f467f)
18358ddbeSLei Zhang //===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===//
28358ddbeSLei Zhang //
38358ddbeSLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
48358ddbeSLei Zhang // See https://llvm.org/LICENSE.txt for license information.
58358ddbeSLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
68358ddbeSLei Zhang //
78358ddbeSLei Zhang //===----------------------------------------------------------------------===//
88358ddbeSLei Zhang 
936550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1016672dbaSJakub Kuderski #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
1101178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1201178654SLei Zhang #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
138358ddbeSLei Zhang #include "mlir/Pass/Pass.h"
148358ddbeSLei Zhang 
158358ddbeSLei Zhang using namespace mlir;
168358ddbeSLei Zhang 
178358ddbeSLei Zhang //===----------------------------------------------------------------------===//
188358ddbeSLei Zhang // Printing op availability pass
198358ddbeSLei Zhang //===----------------------------------------------------------------------===//
208358ddbeSLei Zhang 
218358ddbeSLei Zhang namespace {
228358ddbeSLei Zhang /// A pass for testing SPIR-V op availability.
2380aca1eaSRiver Riddle struct PrintOpAvailability
2458ceae95SRiver Riddle     : public PassWrapper<PrintOpAvailability, OperationPass<func::FuncOp>> {
255e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrintOpAvailability)
265e50dd04SRiver Riddle 
2741574554SRiver Riddle   void runOnOperation() override;
getArgument__anon03031cee0111::PrintOpAvailability28b5e22e6dSMehdi Amini   StringRef getArgument() const final { return "test-spirv-op-availability"; }
getDescription__anon03031cee0111::PrintOpAvailability29b5e22e6dSMehdi Amini   StringRef getDescription() const final {
30b5e22e6dSMehdi Amini     return "Test SPIR-V op availability";
31b5e22e6dSMehdi Amini   }
328358ddbeSLei Zhang };
33be0a7e9fSMehdi Amini } // namespace
348358ddbeSLei Zhang 
runOnOperation()3541574554SRiver Riddle void PrintOpAvailability::runOnOperation() {
3641574554SRiver Riddle   auto f = getOperation();
378358ddbeSLei Zhang   llvm::outs() << f.getName() << "\n";
388358ddbeSLei Zhang 
395ab6ef75SJakub Kuderski   Dialect *spirvDialect = getContext().getLoadedDialect("spirv");
408358ddbeSLei Zhang 
41c4a04059SChristian Sigg   f->walk([&](Operation *op) {
425ab6ef75SJakub Kuderski     if (op->getDialect() != spirvDialect)
438358ddbeSLei Zhang       return WalkResult::advance();
448358ddbeSLei Zhang 
458358ddbeSLei Zhang     auto opName = op->getName();
468358ddbeSLei Zhang     auto &os = llvm::outs();
478358ddbeSLei Zhang 
48cb395f66SLei Zhang     if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
49e8bcc37fSRamkumar Ramachandra       std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
50cb395f66SLei Zhang       os << opName << " min version: ";
51cb395f66SLei Zhang       if (minVersion)
52cb395f66SLei Zhang         os << spirv::stringifyVersion(*minVersion) << "\n";
53cb395f66SLei Zhang       else
54cb395f66SLei Zhang         os << "None\n";
55cb395f66SLei Zhang     }
568358ddbeSLei Zhang 
57cb395f66SLei Zhang     if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
58e8bcc37fSRamkumar Ramachandra       std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
59cb395f66SLei Zhang       os << opName << " max version: ";
60cb395f66SLei Zhang       if (maxVersion)
61cb395f66SLei Zhang         os << spirv::stringifyVersion(*maxVersion) << "\n";
62cb395f66SLei Zhang       else
63cb395f66SLei Zhang         os << "None\n";
64cb395f66SLei Zhang     }
658358ddbeSLei Zhang 
668358ddbeSLei Zhang     if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) {
678358ddbeSLei Zhang       os << opName << " extensions: [";
688358ddbeSLei Zhang       for (const auto &exts : extension.getExtensions()) {
698358ddbeSLei Zhang         os << " [";
702f21a579SRiver Riddle         llvm::interleaveComma(exts, os, [&](spirv::Extension ext) {
718358ddbeSLei Zhang           os << spirv::stringifyExtension(ext);
728358ddbeSLei Zhang         });
738358ddbeSLei Zhang         os << "]";
748358ddbeSLei Zhang       }
758358ddbeSLei Zhang       os << " ]\n";
768358ddbeSLei Zhang     }
778358ddbeSLei Zhang 
788358ddbeSLei Zhang     if (auto capability = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
798358ddbeSLei Zhang       os << opName << " capabilities: [";
808358ddbeSLei Zhang       for (const auto &caps : capability.getCapabilities()) {
818358ddbeSLei Zhang         os << " [";
822f21a579SRiver Riddle         llvm::interleaveComma(caps, os, [&](spirv::Capability cap) {
838358ddbeSLei Zhang           os << spirv::stringifyCapability(cap);
848358ddbeSLei Zhang         });
858358ddbeSLei Zhang         os << "]";
868358ddbeSLei Zhang       }
878358ddbeSLei Zhang       os << " ]\n";
888358ddbeSLei Zhang     }
898358ddbeSLei Zhang     os.flush();
908358ddbeSLei Zhang 
918358ddbeSLei Zhang     return WalkResult::advance();
928358ddbeSLei Zhang   });
938358ddbeSLei Zhang }
948358ddbeSLei Zhang 
958358ddbeSLei Zhang namespace mlir {
registerPrintSpirvAvailabilityPass()96cb395f66SLei Zhang void registerPrintSpirvAvailabilityPass() {
97b5e22e6dSMehdi Amini   PassRegistration<PrintOpAvailability>();
988358ddbeSLei Zhang }
998358ddbeSLei Zhang } // namespace mlir
1008358ddbeSLei Zhang 
1018358ddbeSLei Zhang //===----------------------------------------------------------------------===//
1028358ddbeSLei Zhang // Converting target environment pass
1038358ddbeSLei Zhang //===----------------------------------------------------------------------===//
1048358ddbeSLei Zhang 
1058358ddbeSLei Zhang namespace {
1068358ddbeSLei Zhang /// A pass for testing SPIR-V op availability.
10780aca1eaSRiver Riddle struct ConvertToTargetEnv
10858ceae95SRiver Riddle     : public PassWrapper<ConvertToTargetEnv, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon03031cee0511::ConvertToTargetEnv1095e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertToTargetEnv)
1105e50dd04SRiver Riddle 
111b5e22e6dSMehdi Amini   StringRef getArgument() const override { return "test-spirv-target-env"; }
getDescription__anon03031cee0511::ConvertToTargetEnv112b5e22e6dSMehdi Amini   StringRef getDescription() const override {
113b5e22e6dSMehdi Amini     return "Test SPIR-V target environment";
114b5e22e6dSMehdi Amini   }
11541574554SRiver Riddle   void runOnOperation() override;
1168358ddbeSLei Zhang };
1178358ddbeSLei Zhang 
118b2536281SJakub Kuderski struct ConvertToAtomCmpExchangeWeak : RewritePattern {
ConvertToAtomCmpExchangeWeak__anon03031cee0511::ConvertToAtomCmpExchangeWeak119b2536281SJakub Kuderski   ConvertToAtomCmpExchangeWeak(MLIRContext *context)
120b2536281SJakub Kuderski       : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op", 1,
121b2536281SJakub Kuderski                        context, {"spirv.AtomicCompareExchangeWeak"}) {}
122b2536281SJakub Kuderski 
matchAndRewrite__anon03031cee0511::ConvertToAtomCmpExchangeWeak1233145427dSRiver Riddle   LogicalResult matchAndRewrite(Operation *op,
124b2536281SJakub Kuderski                                 PatternRewriter &rewriter) const override {
125b2536281SJakub Kuderski     Value ptr = op->getOperand(0);
126b2536281SJakub Kuderski     Value value = op->getOperand(1);
127b2536281SJakub Kuderski     Value comparator = op->getOperand(2);
128b2536281SJakub Kuderski 
129b2536281SJakub Kuderski     // Create a spirv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits
130b2536281SJakub Kuderski     // in memory semantics to additionally require AtomicStorage capability.
131b2536281SJakub Kuderski     rewriter.replaceOpWithNewOp<spirv::AtomicCompareExchangeWeakOp>(
132b2536281SJakub Kuderski         op, value.getType(), ptr, spirv::Scope::Workgroup,
133b2536281SJakub Kuderski         spirv::MemorySemantics::AcquireRelease |
134b2536281SJakub Kuderski             spirv::MemorySemantics::AtomicCounterMemory,
135b2536281SJakub Kuderski         spirv::MemorySemantics::Acquire, value, comparator);
136b2536281SJakub Kuderski     return success();
137b2536281SJakub Kuderski   }
1388358ddbeSLei Zhang };
1398358ddbeSLei Zhang 
140b2536281SJakub Kuderski struct ConvertToBitReverse : RewritePattern {
ConvertToBitReverse__anon03031cee0511::ConvertToBitReverse141b2536281SJakub Kuderski   ConvertToBitReverse(MLIRContext *context)
142b2536281SJakub Kuderski       : RewritePattern("test.convert_to_bit_reverse_op", 1, context,
143b2536281SJakub Kuderski                        {"spirv.BitReverse"}) {}
144b2536281SJakub Kuderski 
matchAndRewrite__anon03031cee0511::ConvertToBitReverse1453145427dSRiver Riddle   LogicalResult matchAndRewrite(Operation *op,
146b2536281SJakub Kuderski                                 PatternRewriter &rewriter) const override {
147b2536281SJakub Kuderski     Value predicate = op->getOperand(0);
148b2536281SJakub Kuderski     rewriter.replaceOpWithNewOp<spirv::BitReverseOp>(
149b2536281SJakub Kuderski         op, op->getResult(0).getType(), predicate);
150b2536281SJakub Kuderski     return success();
151b2536281SJakub Kuderski   }
1528358ddbeSLei Zhang };
1538358ddbeSLei Zhang 
154b2536281SJakub Kuderski struct ConvertToGroupNonUniformBallot : RewritePattern {
ConvertToGroupNonUniformBallot__anon03031cee0511::ConvertToGroupNonUniformBallot155b2536281SJakub Kuderski   ConvertToGroupNonUniformBallot(MLIRContext *context)
156b2536281SJakub Kuderski       : RewritePattern("test.convert_to_group_non_uniform_ballot_op", 1,
157b2536281SJakub Kuderski                        context, {"spirv.GroupNonUniformBallot"}) {}
158b2536281SJakub Kuderski 
matchAndRewrite__anon03031cee0511::ConvertToGroupNonUniformBallot1593145427dSRiver Riddle   LogicalResult matchAndRewrite(Operation *op,
160b2536281SJakub Kuderski                                 PatternRewriter &rewriter) const override {
161b2536281SJakub Kuderski     Value predicate = op->getOperand(0);
162b2536281SJakub Kuderski     rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>(
163b2536281SJakub Kuderski         op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate);
164b2536281SJakub Kuderski     return success();
165b2536281SJakub Kuderski   }
1668358ddbeSLei Zhang };
1678358ddbeSLei Zhang 
168b2536281SJakub Kuderski struct ConvertToModule : RewritePattern {
ConvertToModule__anon03031cee0511::ConvertToModule169b2536281SJakub Kuderski   ConvertToModule(MLIRContext *context)
170b2536281SJakub Kuderski       : RewritePattern("test.convert_to_module_op", 1, context,
171b2536281SJakub Kuderski                        {"spirv.module"}) {}
172b2536281SJakub Kuderski 
matchAndRewrite__anon03031cee0511::ConvertToModule1733145427dSRiver Riddle   LogicalResult matchAndRewrite(Operation *op,
174b2536281SJakub Kuderski                                 PatternRewriter &rewriter) const override {
175b2536281SJakub Kuderski     rewriter.replaceOpWithNewOp<spirv::ModuleOp>(
176b2536281SJakub Kuderski         op, spirv::AddressingModel::PhysicalStorageBuffer64,
177b2536281SJakub Kuderski         spirv::MemoryModel::Vulkan);
178b2536281SJakub Kuderski     return success();
179b2536281SJakub Kuderski   }
1808358ddbeSLei Zhang };
1818358ddbeSLei Zhang 
182b2536281SJakub Kuderski struct ConvertToSubgroupBallot : RewritePattern {
ConvertToSubgroupBallot__anon03031cee0511::ConvertToSubgroupBallot183b2536281SJakub Kuderski   ConvertToSubgroupBallot(MLIRContext *context)
184b2536281SJakub Kuderski       : RewritePattern("test.convert_to_subgroup_ballot_op", 1, context,
185b2536281SJakub Kuderski                        {"spirv.KHR.SubgroupBallot"}) {}
186b2536281SJakub Kuderski 
matchAndRewrite__anon03031cee0511::ConvertToSubgroupBallot1873145427dSRiver Riddle   LogicalResult matchAndRewrite(Operation *op,
188b2536281SJakub Kuderski                                 PatternRewriter &rewriter) const override {
189b2536281SJakub Kuderski     Value predicate = op->getOperand(0);
190b2536281SJakub Kuderski     rewriter.replaceOpWithNewOp<spirv::KHRSubgroupBallotOp>(
191b2536281SJakub Kuderski         op, op->getResult(0).getType(), predicate);
192b2536281SJakub Kuderski     return success();
193b2536281SJakub Kuderski   }
1948358ddbeSLei Zhang };
19503e6bf5fSJakub Kuderski 
19603e6bf5fSJakub Kuderski template <const char *TestOpName, typename SPIRVOp>
19703e6bf5fSJakub Kuderski struct ConvertToIntegerDotProd : RewritePattern {
ConvertToIntegerDotProd__anon03031cee0511::ConvertToIntegerDotProd19803e6bf5fSJakub Kuderski   ConvertToIntegerDotProd(MLIRContext *context)
19903e6bf5fSJakub Kuderski       : RewritePattern(TestOpName, 1, context, {SPIRVOp::getOperationName()}) {}
20003e6bf5fSJakub Kuderski 
matchAndRewrite__anon03031cee0511::ConvertToIntegerDotProd20103e6bf5fSJakub Kuderski   LogicalResult matchAndRewrite(Operation *op,
20203e6bf5fSJakub Kuderski                                 PatternRewriter &rewriter) const override {
20303e6bf5fSJakub Kuderski     rewriter.replaceOpWithNewOp<SPIRVOp>(op, op->getResultTypes(),
20403e6bf5fSJakub Kuderski                                          op->getOperands(), op->getAttrs());
20503e6bf5fSJakub Kuderski     return success();
20603e6bf5fSJakub Kuderski   }
20703e6bf5fSJakub Kuderski };
208be0a7e9fSMehdi Amini } // namespace
2098358ddbeSLei Zhang 
runOnOperation()21041574554SRiver Riddle void ConvertToTargetEnv::runOnOperation() {
2118358ddbeSLei Zhang   MLIRContext *context = &getContext();
21258ceae95SRiver Riddle   func::FuncOp fn = getOperation();
2138358ddbeSLei Zhang 
21416672dbaSJakub Kuderski   auto targetEnv = dyn_cast_or_null<spirv::TargetEnvAttr>(
215*830b9b07SMehdi Amini       fn.getOperation()->getDiscardableAttr(spirv::getTargetEnvAttrName()));
21658df5e6dSLei Zhang   if (!targetEnv) {
2175ab6ef75SJakub Kuderski     fn.emitError("missing 'spirv.target_env' attribute");
21858df5e6dSLei Zhang     return signalPassFailure();
21958df5e6dSLei Zhang   }
22058df5e6dSLei Zhang 
2216dd07fa5SLei Zhang   auto target = SPIRVConversionTarget::get(targetEnv);
2228358ddbeSLei Zhang 
22303e6bf5fSJakub Kuderski   static constexpr char sDotTestOpName[] = "test.convert_to_sdot_op";
22403e6bf5fSJakub Kuderski   static constexpr char suDotTestOpName[] = "test.convert_to_sudot_op";
22503e6bf5fSJakub Kuderski   static constexpr char uDotTestOpName[] = "test.convert_to_udot_op";
226f7f4dd67SJakub Kuderski   static constexpr char sDotAccSatTestOpName[] =
227f7f4dd67SJakub Kuderski       "test.convert_to_sdot_acc_sat_op";
228f7f4dd67SJakub Kuderski   static constexpr char suDotAccSatTestOpName[] =
229f7f4dd67SJakub Kuderski       "test.convert_to_sudot_acc_sat_op";
230f7f4dd67SJakub Kuderski   static constexpr char uDotAccSatTestOpName[] =
231f7f4dd67SJakub Kuderski       "test.convert_to_udot_acc_sat_op";
23203e6bf5fSJakub Kuderski 
233dc4e913bSChris Lattner   RewritePatternSet patterns(context);
234f7f4dd67SJakub Kuderski   patterns.add<
235f7f4dd67SJakub Kuderski       ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
236f7f4dd67SJakub Kuderski       ConvertToGroupNonUniformBallot, ConvertToModule, ConvertToSubgroupBallot,
23703e6bf5fSJakub Kuderski       ConvertToIntegerDotProd<sDotTestOpName, spirv::SDotOp>,
23803e6bf5fSJakub Kuderski       ConvertToIntegerDotProd<suDotTestOpName, spirv::SUDotOp>,
239f7f4dd67SJakub Kuderski       ConvertToIntegerDotProd<uDotTestOpName, spirv::UDotOp>,
240f7f4dd67SJakub Kuderski       ConvertToIntegerDotProd<sDotAccSatTestOpName, spirv::SDotAccSatOp>,
241f7f4dd67SJakub Kuderski       ConvertToIntegerDotProd<suDotAccSatTestOpName, spirv::SUDotAccSatOp>,
242f7f4dd67SJakub Kuderski       ConvertToIntegerDotProd<uDotAccSatTestOpName, spirv::UDotAccSatOp>>(
243f7f4dd67SJakub Kuderski       context);
2448358ddbeSLei Zhang 
2453fffffa8SRiver Riddle   if (failed(applyPartialConversion(fn, *target, std::move(patterns))))
2468358ddbeSLei Zhang     return signalPassFailure();
2478358ddbeSLei Zhang }
2488358ddbeSLei Zhang 
2498358ddbeSLei Zhang namespace mlir {
registerConvertToTargetEnvPass()2508358ddbeSLei Zhang void registerConvertToTargetEnvPass() {
251b5e22e6dSMehdi Amini   PassRegistration<ConvertToTargetEnv>();
2528358ddbeSLei Zhang }
2538358ddbeSLei Zhang } // namespace mlir
254