//===- ShardingInterfaceImpl.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/DialectRegistry.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "tosa-sharding-impl"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

using namespace mlir;
using namespace mlir::tosa;
using namespace mlir::mesh;

namespace {

// loop types: [parallel, parallel, parallel, reduction_sum]
// indexing maps:
// (d0, d1, d2, d3) -> (d0, d1, d3)
// (d0, d1, d2, d3) -> (d0, d3, d2)
// (d0, d1, d2, d3) -> (d0, d1, d2)
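//
// For example (illustrative shapes, not taken from a test): a tosa.matmul
// with lhs tensor<2x3x5xf32> and rhs tensor<2x5x4xf32> produces
// tensor<2x3x4xf32>; d0 is the batch dimension, d1 and d2 the free
// dimensions, and d3 the summed dimension of extent 5.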
struct MatMulOpSharding
    : public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
  // One loop per result dimension, all parallel, plus one trailing reduction
  // loop for the contracted dimension.
  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
    auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
    if (!tensorType)
      return {};

    SmallVector<utils::IteratorType> types(tensorType.getRank() + 1,
                                           utils::IteratorType::parallel);
    types[tensorType.getRank()] = utils::IteratorType::reduction;
    return types;
  }

  // The single reduction loop combines partial results by summation.
  SmallVector<ReductionKind>
  getReductionLoopIteratorKinds(Operation *op) const {
    return SmallVector<ReductionKind>(1, ReductionKind::Sum);
  }

  // Indexing maps for lhs, rhs, and result, in that order (see the comment
  // above the struct).
  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
    auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
    if (!tensorType)
      return {};
    MLIRContext *ctx = op->getContext();
    SmallVector<AffineMap> maps;
    maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx));
    maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx));
    maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx));
    return maps;
  }
};

template <typename OpType>
static void registerElemwiseOne(MLIRContext *ctx) {
  OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
}

/// Variadic helper: registers the elementwise sharding interface for every op
/// type in the parameter pack.
template <typename... OpTypes>
static void registerElemwiseAll(MLIRContext *ctx) {
  (registerElemwiseOne<OpTypes>(ctx), ...);
}
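// For instance (illustrative, not an existing call site),
// registerElemwiseAll<AddOp, SubOp>(ctx) expands the fold expression above to
// (registerElemwiseOne<AddOp>(ctx), registerElemwiseOne<SubOp>(ctx)).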

} // namespace

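/// A typical caller registers the models before creating the context; a
/// minimal sketch (the dialect set chosen here is an assumption):
///
///   DialectRegistry registry;                     // hypothetical setup
///   registry.insert<tosa::TosaDialect, mesh::MeshDialect>();
///   mlir::tosa::registerShardingInterfaceExternalModels(registry);
///   MLIRContext context(registry);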
void mlir::tosa::registerShardingInterfaceExternalModels(
    DialectRegistry &registry) {

  registry.addExtension(+[](MLIRContext *ctx, TosaDialect *dialect) {
    registerElemwiseAll<
        ClampOp, SigmoidOp, TanhOp, AddOp, ArithmeticRightShiftOp,
        BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp,
        LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp, LogicalXorOp,
        MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp,
        CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp, NegateOp,
        ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
        GreaterEqualOp>(ctx);

    MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
  });
}