// Source: llvm-project/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
//===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
10 #include "TransformsDetail.h"
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
13 #include "mlir/IR/BuiltinTypeInterfaces.h"
14 #include "mlir/IR/ImplicitLocOpBuilder.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/SymbolTable.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include <numeric>
20 #include <utility>
21 
22 namespace mlir {
23 namespace mesh {
24 
populateSimplificationPatterns(RewritePatternSet & patterns,SymbolTableCollection & symbolTableCollection)25 void populateSimplificationPatterns(
26     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
27   populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
28       patterns, ReductionKind::Sum);
29   populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
30       patterns, ReductionKind::Sum);
31 
32   populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
33       patterns, ReductionKind::Min);
34   populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
35       patterns, ReductionKind::Min);
36   populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
37       patterns, ReductionKind::Min);
38 
39   populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
40       patterns, ReductionKind::Max);
41   populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
42       patterns, ReductionKind::Max);
43   populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
44       patterns, ReductionKind::Max);
45 
46   // TODO: add simplifications for all-gather and other collectives.
47 
48   populateFoldingPatterns(patterns, symbolTableCollection);
49 }
50 
51 namespace {
52 
53 // This folding can not be done with an operation's fold method or
54 // DialectFoldInterface, because it needs a SymbolTableCollection to cache the
55 // symbol tables.
56 // We can't use DialectFoldInterface since the cache may be invalidated by some
57 // pass changing the referenced MeshOp ops.
58 struct MeshShapeFolder
59     : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
60   using OpRewritePatternWithSymbolTableCollection::
61       OpRewritePatternWithSymbolTableCollection;
matchAndRewritemlir::mesh::__anon2a0f5f890111::MeshShapeFolder62   LogicalResult matchAndRewrite(MeshShapeOp op,
63                                 PatternRewriter &rewriter) const override {
64     ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
65     MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
66         op.getOperation(), op.getMeshAttr());
67     if (!mesh) {
68       return failure();
69     }
70     ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
71     SmallVector<MeshAxis> opAxesIota;
72     if (opMeshAxes.empty()) {
73       opAxesIota.resize(mesh.getRank());
74       std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
75       opMeshAxes = opAxesIota;
76     }
77     if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
78           return ShapedType::isDynamic(mesh.getShape()[axis]);
79         })) {
80       // All mesh dimensions are dynamic. Nothing to fold.
81       return failure();
82     }
83 
84     SmallVector<Value> newResults(op->getResults().size());
85     SmallVector<MeshAxis> newShapeOpMeshAxes;
86     SmallVector<size_t> newToOldResultsIndexMap;
87 
88     for (size_t i = 0; i < opMeshAxes.size(); ++i) {
89       auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
90       if (ShapedType::isDynamic(meshAxisSize)) {
91         newToOldResultsIndexMap.push_back(i);
92         newShapeOpMeshAxes.push_back(opMeshAxes[i]);
93       } else {
94         // Fold static mesh axes.
95         newResults[i] = builder.create<arith::ConstantOp>(
96             builder.getIndexAttr(meshAxisSize));
97       }
98     }
99 
100     // Leave only the dynamic mesh axes to be queried.
101     if (!newShapeOpMeshAxes.empty()) {
102       MeshShapeOp newShapeOp =
103           builder.create<MeshShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
104       for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
105         newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
106       }
107     }
108     rewriter.replaceOp(op, newResults);
109 
110     return success();
111   }
112 };
113 
114 } // namespace
115 
populateFoldingPatterns(RewritePatternSet & patterns,SymbolTableCollection & symbolTableCollection)116 void populateFoldingPatterns(RewritePatternSet &patterns,
117                              SymbolTableCollection &symbolTableCollection) {
118   patterns.add<MeshShapeFolder>(symbolTableCollection, patterns.getContext());
119 }
120 
121 } // namespace mesh
122 } // namespace mlir
123