//===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
88358ddbeSLei Zhang
936550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
1016672dbaSJakub Kuderski #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
1101178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1201178654SLei Zhang #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
138358ddbeSLei Zhang #include "mlir/Pass/Pass.h"
148358ddbeSLei Zhang
158358ddbeSLei Zhang using namespace mlir;
168358ddbeSLei Zhang
//===----------------------------------------------------------------------===//
// Printing op availability pass
//===----------------------------------------------------------------------===//
208358ddbeSLei Zhang
218358ddbeSLei Zhang namespace {
228358ddbeSLei Zhang /// A pass for testing SPIR-V op availability.
2380aca1eaSRiver Riddle struct PrintOpAvailability
2458ceae95SRiver Riddle : public PassWrapper<PrintOpAvailability, OperationPass<func::FuncOp>> {
255e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrintOpAvailability)
265e50dd04SRiver Riddle
2741574554SRiver Riddle void runOnOperation() override;
getArgument__anon03031cee0111::PrintOpAvailability28b5e22e6dSMehdi Amini StringRef getArgument() const final { return "test-spirv-op-availability"; }
getDescription__anon03031cee0111::PrintOpAvailability29b5e22e6dSMehdi Amini StringRef getDescription() const final {
30b5e22e6dSMehdi Amini return "Test SPIR-V op availability";
31b5e22e6dSMehdi Amini }
328358ddbeSLei Zhang };
33be0a7e9fSMehdi Amini } // namespace
348358ddbeSLei Zhang
runOnOperation()3541574554SRiver Riddle void PrintOpAvailability::runOnOperation() {
3641574554SRiver Riddle auto f = getOperation();
378358ddbeSLei Zhang llvm::outs() << f.getName() << "\n";
388358ddbeSLei Zhang
395ab6ef75SJakub Kuderski Dialect *spirvDialect = getContext().getLoadedDialect("spirv");
408358ddbeSLei Zhang
41c4a04059SChristian Sigg f->walk([&](Operation *op) {
425ab6ef75SJakub Kuderski if (op->getDialect() != spirvDialect)
438358ddbeSLei Zhang return WalkResult::advance();
448358ddbeSLei Zhang
458358ddbeSLei Zhang auto opName = op->getName();
468358ddbeSLei Zhang auto &os = llvm::outs();
478358ddbeSLei Zhang
48cb395f66SLei Zhang if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
49e8bcc37fSRamkumar Ramachandra std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
50cb395f66SLei Zhang os << opName << " min version: ";
51cb395f66SLei Zhang if (minVersion)
52cb395f66SLei Zhang os << spirv::stringifyVersion(*minVersion) << "\n";
53cb395f66SLei Zhang else
54cb395f66SLei Zhang os << "None\n";
55cb395f66SLei Zhang }
568358ddbeSLei Zhang
57cb395f66SLei Zhang if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
58e8bcc37fSRamkumar Ramachandra std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
59cb395f66SLei Zhang os << opName << " max version: ";
60cb395f66SLei Zhang if (maxVersion)
61cb395f66SLei Zhang os << spirv::stringifyVersion(*maxVersion) << "\n";
62cb395f66SLei Zhang else
63cb395f66SLei Zhang os << "None\n";
64cb395f66SLei Zhang }
658358ddbeSLei Zhang
668358ddbeSLei Zhang if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) {
678358ddbeSLei Zhang os << opName << " extensions: [";
688358ddbeSLei Zhang for (const auto &exts : extension.getExtensions()) {
698358ddbeSLei Zhang os << " [";
702f21a579SRiver Riddle llvm::interleaveComma(exts, os, [&](spirv::Extension ext) {
718358ddbeSLei Zhang os << spirv::stringifyExtension(ext);
728358ddbeSLei Zhang });
738358ddbeSLei Zhang os << "]";
748358ddbeSLei Zhang }
758358ddbeSLei Zhang os << " ]\n";
768358ddbeSLei Zhang }
778358ddbeSLei Zhang
788358ddbeSLei Zhang if (auto capability = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
798358ddbeSLei Zhang os << opName << " capabilities: [";
808358ddbeSLei Zhang for (const auto &caps : capability.getCapabilities()) {
818358ddbeSLei Zhang os << " [";
822f21a579SRiver Riddle llvm::interleaveComma(caps, os, [&](spirv::Capability cap) {
838358ddbeSLei Zhang os << spirv::stringifyCapability(cap);
848358ddbeSLei Zhang });
858358ddbeSLei Zhang os << "]";
868358ddbeSLei Zhang }
878358ddbeSLei Zhang os << " ]\n";
888358ddbeSLei Zhang }
898358ddbeSLei Zhang os.flush();
908358ddbeSLei Zhang
918358ddbeSLei Zhang return WalkResult::advance();
928358ddbeSLei Zhang });
938358ddbeSLei Zhang }
948358ddbeSLei Zhang
958358ddbeSLei Zhang namespace mlir {
registerPrintSpirvAvailabilityPass()96cb395f66SLei Zhang void registerPrintSpirvAvailabilityPass() {
97b5e22e6dSMehdi Amini PassRegistration<PrintOpAvailability>();
988358ddbeSLei Zhang }
998358ddbeSLei Zhang } // namespace mlir
1008358ddbeSLei Zhang
//===----------------------------------------------------------------------===//
// Converting target environment pass
//===----------------------------------------------------------------------===//
1048358ddbeSLei Zhang
1058358ddbeSLei Zhang namespace {
1068358ddbeSLei Zhang /// A pass for testing SPIR-V op availability.
10780aca1eaSRiver Riddle struct ConvertToTargetEnv
10858ceae95SRiver Riddle : public PassWrapper<ConvertToTargetEnv, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon03031cee0511::ConvertToTargetEnv1095e50dd04SRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertToTargetEnv)
1105e50dd04SRiver Riddle
111b5e22e6dSMehdi Amini StringRef getArgument() const override { return "test-spirv-target-env"; }
getDescription__anon03031cee0511::ConvertToTargetEnv112b5e22e6dSMehdi Amini StringRef getDescription() const override {
113b5e22e6dSMehdi Amini return "Test SPIR-V target environment";
114b5e22e6dSMehdi Amini }
11541574554SRiver Riddle void runOnOperation() override;
1168358ddbeSLei Zhang };
1178358ddbeSLei Zhang
118b2536281SJakub Kuderski struct ConvertToAtomCmpExchangeWeak : RewritePattern {
ConvertToAtomCmpExchangeWeak__anon03031cee0511::ConvertToAtomCmpExchangeWeak119b2536281SJakub Kuderski ConvertToAtomCmpExchangeWeak(MLIRContext *context)
120b2536281SJakub Kuderski : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op", 1,
121b2536281SJakub Kuderski context, {"spirv.AtomicCompareExchangeWeak"}) {}
122b2536281SJakub Kuderski
matchAndRewrite__anon03031cee0511::ConvertToAtomCmpExchangeWeak1233145427dSRiver Riddle LogicalResult matchAndRewrite(Operation *op,
124b2536281SJakub Kuderski PatternRewriter &rewriter) const override {
125b2536281SJakub Kuderski Value ptr = op->getOperand(0);
126b2536281SJakub Kuderski Value value = op->getOperand(1);
127b2536281SJakub Kuderski Value comparator = op->getOperand(2);
128b2536281SJakub Kuderski
129b2536281SJakub Kuderski // Create a spirv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits
130b2536281SJakub Kuderski // in memory semantics to additionally require AtomicStorage capability.
131b2536281SJakub Kuderski rewriter.replaceOpWithNewOp<spirv::AtomicCompareExchangeWeakOp>(
132b2536281SJakub Kuderski op, value.getType(), ptr, spirv::Scope::Workgroup,
133b2536281SJakub Kuderski spirv::MemorySemantics::AcquireRelease |
134b2536281SJakub Kuderski spirv::MemorySemantics::AtomicCounterMemory,
135b2536281SJakub Kuderski spirv::MemorySemantics::Acquire, value, comparator);
136b2536281SJakub Kuderski return success();
137b2536281SJakub Kuderski }
1388358ddbeSLei Zhang };
1398358ddbeSLei Zhang
140b2536281SJakub Kuderski struct ConvertToBitReverse : RewritePattern {
ConvertToBitReverse__anon03031cee0511::ConvertToBitReverse141b2536281SJakub Kuderski ConvertToBitReverse(MLIRContext *context)
142b2536281SJakub Kuderski : RewritePattern("test.convert_to_bit_reverse_op", 1, context,
143b2536281SJakub Kuderski {"spirv.BitReverse"}) {}
144b2536281SJakub Kuderski
matchAndRewrite__anon03031cee0511::ConvertToBitReverse1453145427dSRiver Riddle LogicalResult matchAndRewrite(Operation *op,
146b2536281SJakub Kuderski PatternRewriter &rewriter) const override {
147b2536281SJakub Kuderski Value predicate = op->getOperand(0);
148b2536281SJakub Kuderski rewriter.replaceOpWithNewOp<spirv::BitReverseOp>(
149b2536281SJakub Kuderski op, op->getResult(0).getType(), predicate);
150b2536281SJakub Kuderski return success();
151b2536281SJakub Kuderski }
1528358ddbeSLei Zhang };
1538358ddbeSLei Zhang
154b2536281SJakub Kuderski struct ConvertToGroupNonUniformBallot : RewritePattern {
ConvertToGroupNonUniformBallot__anon03031cee0511::ConvertToGroupNonUniformBallot155b2536281SJakub Kuderski ConvertToGroupNonUniformBallot(MLIRContext *context)
156b2536281SJakub Kuderski : RewritePattern("test.convert_to_group_non_uniform_ballot_op", 1,
157b2536281SJakub Kuderski context, {"spirv.GroupNonUniformBallot"}) {}
158b2536281SJakub Kuderski
matchAndRewrite__anon03031cee0511::ConvertToGroupNonUniformBallot1593145427dSRiver Riddle LogicalResult matchAndRewrite(Operation *op,
160b2536281SJakub Kuderski PatternRewriter &rewriter) const override {
161b2536281SJakub Kuderski Value predicate = op->getOperand(0);
162b2536281SJakub Kuderski rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>(
163b2536281SJakub Kuderski op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate);
164b2536281SJakub Kuderski return success();
165b2536281SJakub Kuderski }
1668358ddbeSLei Zhang };
1678358ddbeSLei Zhang
168b2536281SJakub Kuderski struct ConvertToModule : RewritePattern {
ConvertToModule__anon03031cee0511::ConvertToModule169b2536281SJakub Kuderski ConvertToModule(MLIRContext *context)
170b2536281SJakub Kuderski : RewritePattern("test.convert_to_module_op", 1, context,
171b2536281SJakub Kuderski {"spirv.module"}) {}
172b2536281SJakub Kuderski
matchAndRewrite__anon03031cee0511::ConvertToModule1733145427dSRiver Riddle LogicalResult matchAndRewrite(Operation *op,
174b2536281SJakub Kuderski PatternRewriter &rewriter) const override {
175b2536281SJakub Kuderski rewriter.replaceOpWithNewOp<spirv::ModuleOp>(
176b2536281SJakub Kuderski op, spirv::AddressingModel::PhysicalStorageBuffer64,
177b2536281SJakub Kuderski spirv::MemoryModel::Vulkan);
178b2536281SJakub Kuderski return success();
179b2536281SJakub Kuderski }
1808358ddbeSLei Zhang };
1818358ddbeSLei Zhang
182b2536281SJakub Kuderski struct ConvertToSubgroupBallot : RewritePattern {
ConvertToSubgroupBallot__anon03031cee0511::ConvertToSubgroupBallot183b2536281SJakub Kuderski ConvertToSubgroupBallot(MLIRContext *context)
184b2536281SJakub Kuderski : RewritePattern("test.convert_to_subgroup_ballot_op", 1, context,
185b2536281SJakub Kuderski {"spirv.KHR.SubgroupBallot"}) {}
186b2536281SJakub Kuderski
matchAndRewrite__anon03031cee0511::ConvertToSubgroupBallot1873145427dSRiver Riddle LogicalResult matchAndRewrite(Operation *op,
188b2536281SJakub Kuderski PatternRewriter &rewriter) const override {
189b2536281SJakub Kuderski Value predicate = op->getOperand(0);
190b2536281SJakub Kuderski rewriter.replaceOpWithNewOp<spirv::KHRSubgroupBallotOp>(
191b2536281SJakub Kuderski op, op->getResult(0).getType(), predicate);
192b2536281SJakub Kuderski return success();
193b2536281SJakub Kuderski }
1948358ddbeSLei Zhang };
19503e6bf5fSJakub Kuderski
19603e6bf5fSJakub Kuderski template <const char *TestOpName, typename SPIRVOp>
19703e6bf5fSJakub Kuderski struct ConvertToIntegerDotProd : RewritePattern {
ConvertToIntegerDotProd__anon03031cee0511::ConvertToIntegerDotProd19803e6bf5fSJakub Kuderski ConvertToIntegerDotProd(MLIRContext *context)
19903e6bf5fSJakub Kuderski : RewritePattern(TestOpName, 1, context, {SPIRVOp::getOperationName()}) {}
20003e6bf5fSJakub Kuderski
matchAndRewrite__anon03031cee0511::ConvertToIntegerDotProd20103e6bf5fSJakub Kuderski LogicalResult matchAndRewrite(Operation *op,
20203e6bf5fSJakub Kuderski PatternRewriter &rewriter) const override {
20303e6bf5fSJakub Kuderski rewriter.replaceOpWithNewOp<SPIRVOp>(op, op->getResultTypes(),
20403e6bf5fSJakub Kuderski op->getOperands(), op->getAttrs());
20503e6bf5fSJakub Kuderski return success();
20603e6bf5fSJakub Kuderski }
20703e6bf5fSJakub Kuderski };
208be0a7e9fSMehdi Amini } // namespace
2098358ddbeSLei Zhang
runOnOperation()21041574554SRiver Riddle void ConvertToTargetEnv::runOnOperation() {
2118358ddbeSLei Zhang MLIRContext *context = &getContext();
21258ceae95SRiver Riddle func::FuncOp fn = getOperation();
2138358ddbeSLei Zhang
21416672dbaSJakub Kuderski auto targetEnv = dyn_cast_or_null<spirv::TargetEnvAttr>(
215*830b9b07SMehdi Amini fn.getOperation()->getDiscardableAttr(spirv::getTargetEnvAttrName()));
21658df5e6dSLei Zhang if (!targetEnv) {
2175ab6ef75SJakub Kuderski fn.emitError("missing 'spirv.target_env' attribute");
21858df5e6dSLei Zhang return signalPassFailure();
21958df5e6dSLei Zhang }
22058df5e6dSLei Zhang
2216dd07fa5SLei Zhang auto target = SPIRVConversionTarget::get(targetEnv);
2228358ddbeSLei Zhang
22303e6bf5fSJakub Kuderski static constexpr char sDotTestOpName[] = "test.convert_to_sdot_op";
22403e6bf5fSJakub Kuderski static constexpr char suDotTestOpName[] = "test.convert_to_sudot_op";
22503e6bf5fSJakub Kuderski static constexpr char uDotTestOpName[] = "test.convert_to_udot_op";
226f7f4dd67SJakub Kuderski static constexpr char sDotAccSatTestOpName[] =
227f7f4dd67SJakub Kuderski "test.convert_to_sdot_acc_sat_op";
228f7f4dd67SJakub Kuderski static constexpr char suDotAccSatTestOpName[] =
229f7f4dd67SJakub Kuderski "test.convert_to_sudot_acc_sat_op";
230f7f4dd67SJakub Kuderski static constexpr char uDotAccSatTestOpName[] =
231f7f4dd67SJakub Kuderski "test.convert_to_udot_acc_sat_op";
23203e6bf5fSJakub Kuderski
233dc4e913bSChris Lattner RewritePatternSet patterns(context);
234f7f4dd67SJakub Kuderski patterns.add<
235f7f4dd67SJakub Kuderski ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
236f7f4dd67SJakub Kuderski ConvertToGroupNonUniformBallot, ConvertToModule, ConvertToSubgroupBallot,
23703e6bf5fSJakub Kuderski ConvertToIntegerDotProd<sDotTestOpName, spirv::SDotOp>,
23803e6bf5fSJakub Kuderski ConvertToIntegerDotProd<suDotTestOpName, spirv::SUDotOp>,
239f7f4dd67SJakub Kuderski ConvertToIntegerDotProd<uDotTestOpName, spirv::UDotOp>,
240f7f4dd67SJakub Kuderski ConvertToIntegerDotProd<sDotAccSatTestOpName, spirv::SDotAccSatOp>,
241f7f4dd67SJakub Kuderski ConvertToIntegerDotProd<suDotAccSatTestOpName, spirv::SUDotAccSatOp>,
242f7f4dd67SJakub Kuderski ConvertToIntegerDotProd<uDotAccSatTestOpName, spirv::UDotAccSatOp>>(
243f7f4dd67SJakub Kuderski context);
2448358ddbeSLei Zhang
2453fffffa8SRiver Riddle if (failed(applyPartialConversion(fn, *target, std::move(patterns))))
2468358ddbeSLei Zhang return signalPassFailure();
2478358ddbeSLei Zhang }
2488358ddbeSLei Zhang
2498358ddbeSLei Zhang namespace mlir {
registerConvertToTargetEnvPass()2508358ddbeSLei Zhang void registerConvertToTargetEnvPass() {
251b5e22e6dSMehdi Amini PassRegistration<ConvertToTargetEnv>();
2528358ddbeSLei Zhang }
2538358ddbeSLei Zhang } // namespace mlir
254