1 //===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
10 #include "TransformsDetail.h"
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
13 #include "mlir/IR/BuiltinTypeInterfaces.h"
14 #include "mlir/IR/ImplicitLocOpBuilder.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/SymbolTable.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include <numeric>
20 #include <utility>
21
22 namespace mlir {
23 namespace mesh {
24
populateSimplificationPatterns(RewritePatternSet & patterns,SymbolTableCollection & symbolTableCollection)25 void populateSimplificationPatterns(
26 RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
27 populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
28 patterns, ReductionKind::Sum);
29 populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
30 patterns, ReductionKind::Sum);
31
32 populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
33 patterns, ReductionKind::Min);
34 populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
35 patterns, ReductionKind::Min);
36 populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
37 patterns, ReductionKind::Min);
38
39 populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
40 patterns, ReductionKind::Max);
41 populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
42 patterns, ReductionKind::Max);
43 populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
44 patterns, ReductionKind::Max);
45
46 // TODO: add simplifications for all-gather and other collectives.
47
48 populateFoldingPatterns(patterns, symbolTableCollection);
49 }
50
51 namespace {
52
53 // This folding can not be done with an operation's fold method or
54 // DialectFoldInterface, because it needs a SymbolTableCollection to cache the
55 // symbol tables.
56 // We can't use DialectFoldInterface since the cache may be invalidated by some
57 // pass changing the referenced MeshOp ops.
58 struct MeshShapeFolder
59 : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
60 using OpRewritePatternWithSymbolTableCollection::
61 OpRewritePatternWithSymbolTableCollection;
matchAndRewritemlir::mesh::__anon2a0f5f890111::MeshShapeFolder62 LogicalResult matchAndRewrite(MeshShapeOp op,
63 PatternRewriter &rewriter) const override {
64 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
65 MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
66 op.getOperation(), op.getMeshAttr());
67 if (!mesh) {
68 return failure();
69 }
70 ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
71 SmallVector<MeshAxis> opAxesIota;
72 if (opMeshAxes.empty()) {
73 opAxesIota.resize(mesh.getRank());
74 std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
75 opMeshAxes = opAxesIota;
76 }
77 if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
78 return ShapedType::isDynamic(mesh.getShape()[axis]);
79 })) {
80 // All mesh dimensions are dynamic. Nothing to fold.
81 return failure();
82 }
83
84 SmallVector<Value> newResults(op->getResults().size());
85 SmallVector<MeshAxis> newShapeOpMeshAxes;
86 SmallVector<size_t> newToOldResultsIndexMap;
87
88 for (size_t i = 0; i < opMeshAxes.size(); ++i) {
89 auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
90 if (ShapedType::isDynamic(meshAxisSize)) {
91 newToOldResultsIndexMap.push_back(i);
92 newShapeOpMeshAxes.push_back(opMeshAxes[i]);
93 } else {
94 // Fold static mesh axes.
95 newResults[i] = builder.create<arith::ConstantOp>(
96 builder.getIndexAttr(meshAxisSize));
97 }
98 }
99
100 // Leave only the dynamic mesh axes to be queried.
101 if (!newShapeOpMeshAxes.empty()) {
102 MeshShapeOp newShapeOp =
103 builder.create<MeshShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
104 for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
105 newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
106 }
107 }
108 rewriter.replaceOp(op, newResults);
109
110 return success();
111 }
112 };
113
114 } // namespace
115
populateFoldingPatterns(RewritePatternSet & patterns,SymbolTableCollection & symbolTableCollection)116 void populateFoldingPatterns(RewritePatternSet &patterns,
117 SymbolTableCollection &symbolTableCollection) {
118 patterns.add<MeshShapeFolder>(symbolTableCollection, patterns.getContext());
119 }
120
121 } // namespace mesh
122 } // namespace mlir
123