xref: /llvm-project/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp (revision 82383d5f3fa8289688dcd314f7a89ce5599bbdb2)
1 //===- ShardingInterfaceImpl.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
10 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
11 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
12 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
13 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/DialectRegistry.h"
16 #include "llvm/Support/Debug.h"
17 
18 #define DEBUG_TYPE "tosa-sharding-impl"
19 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
20 
21 using namespace mlir;
22 using namespace mlir::tosa;
23 using namespace mlir::mesh;
24 
25 namespace {
26 
27 // loop types: [parallel, parallel, parallel, reduction_sum]
28 // indexing maps:
29 // (d0, d1, d2, d3) -> (d0, d1, d3)
30 // (d0, d1, d2, d3) -> (d0, d3, d2)
31 // (d0, d1, d2, d3) -> (d0, d1, d2)
32 struct MatMulOpSharding
33     : public ShardingInterface::ExternalModel<MatMulOpSharding, MatMulOp> {
getLoopIteratorTypes__anondf8635240111::MatMulOpSharding34   SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
35     auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
36     if (!tensorType)
37       return {};
38 
39     SmallVector<utils::IteratorType> types(tensorType.getRank() + 1,
40                                            utils::IteratorType::parallel);
41     types[tensorType.getRank()] = utils::IteratorType::reduction;
42     return types;
43   }
44 
45   SmallVector<ReductionKind>
getReductionLoopIteratorKinds__anondf8635240111::MatMulOpSharding46   getReductionLoopIteratorKinds(Operation *op) const {
47     return SmallVector<ReductionKind>(1, ReductionKind::Sum);
48   }
49 
getIndexingMaps__anondf8635240111::MatMulOpSharding50   SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
51     auto tensorType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
52     if (!tensorType)
53       return {};
54     MLIRContext *ctx = op->getContext();
55     SmallVector<AffineMap> maps;
56     maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx));
57     maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx));
58     maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx));
59     return maps;
60   }
61 };
62 
63 template <typename OpType>
registerElemwiseOne(MLIRContext * ctx)64 static void registerElemwiseOne(MLIRContext *ctx) {
65   OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
66 }
67 
68 /// Variadic helper function.
69 template <typename... OpTypes>
registerElemwiseAll(MLIRContext * ctx)70 static void registerElemwiseAll(MLIRContext *ctx) {
71   (registerElemwiseOne<OpTypes>(ctx), ...);
72 }
73 
74 } // namespace
75 
registerShardingInterfaceExternalModels(DialectRegistry & registry)76 void mlir::tosa::registerShardingInterfaceExternalModels(
77     DialectRegistry &registry) {
78 
79   registry.addExtension(+[](MLIRContext *ctx, TosaDialect *dialect) {
80     registerElemwiseAll<
81         ClampOp, SigmoidOp, TanhOp, AddOp, ArithmeticRightShiftOp, BitwiseAndOp,
82         BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp, LogicalLeftShiftOp,
83         LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
84         MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
85         LogOp, LogicalNotOp, NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
86         GreaterOp, GreaterEqualOp>(ctx);
87 
88     MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
89   });
90 }
91