179aa7762SBoian Petkantchin //===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
24b344677SBoian Petkantchin //
34b344677SBoian Petkantchin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44b344677SBoian Petkantchin // See https://llvm.org/LICENSE.txt for license information.
54b344677SBoian Petkantchin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
64b344677SBoian Petkantchin //
74b344677SBoian Petkantchin //===----------------------------------------------------------------------===//
84b344677SBoian Petkantchin
94b344677SBoian Petkantchin #include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
10dc3258c6SBoian Petkantchin #include "TransformsDetail.h"
114b344677SBoian Petkantchin #include "mlir/Dialect/Arith/IR/Arith.h"
12ab590377SBoian Petkantchin #include "mlir/Dialect/Mesh/IR/MeshOps.h"
13ab590377SBoian Petkantchin #include "mlir/IR/BuiltinTypeInterfaces.h"
14ab590377SBoian Petkantchin #include "mlir/IR/ImplicitLocOpBuilder.h"
15ab590377SBoian Petkantchin #include "mlir/IR/PatternMatch.h"
16ab590377SBoian Petkantchin #include "mlir/IR/SymbolTable.h"
17ab590377SBoian Petkantchin #include "llvm/ADT/STLExtras.h"
18ab590377SBoian Petkantchin #include "llvm/ADT/SmallVector.h"
19ab590377SBoian Petkantchin #include <numeric>
20ab590377SBoian Petkantchin #include <utility>
214b344677SBoian Petkantchin
224b344677SBoian Petkantchin namespace mlir {
234b344677SBoian Petkantchin namespace mesh {
244b344677SBoian Petkantchin
populateSimplificationPatterns(RewritePatternSet & patterns,SymbolTableCollection & symbolTableCollection)25ab590377SBoian Petkantchin void populateSimplificationPatterns(
26ab590377SBoian Petkantchin RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
274b344677SBoian Petkantchin populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
28*ff2720d1SBoian Petkantchin patterns, ReductionKind::Sum);
294b344677SBoian Petkantchin populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
30*ff2720d1SBoian Petkantchin patterns, ReductionKind::Sum);
314b344677SBoian Petkantchin
324b344677SBoian Petkantchin populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
33*ff2720d1SBoian Petkantchin patterns, ReductionKind::Min);
344b344677SBoian Petkantchin populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
35*ff2720d1SBoian Petkantchin patterns, ReductionKind::Min);
364b344677SBoian Petkantchin populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
37*ff2720d1SBoian Petkantchin patterns, ReductionKind::Min);
384b344677SBoian Petkantchin
394b344677SBoian Petkantchin populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
40*ff2720d1SBoian Petkantchin patterns, ReductionKind::Max);
414b344677SBoian Petkantchin populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
42*ff2720d1SBoian Petkantchin patterns, ReductionKind::Max);
434b344677SBoian Petkantchin populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
44*ff2720d1SBoian Petkantchin patterns, ReductionKind::Max);
454b344677SBoian Petkantchin
464b344677SBoian Petkantchin // TODO: add simplifications for all-gather and other collectives.
47ab590377SBoian Petkantchin
48ab590377SBoian Petkantchin populateFoldingPatterns(patterns, symbolTableCollection);
49ab590377SBoian Petkantchin }
50ab590377SBoian Petkantchin
51ab590377SBoian Petkantchin namespace {
52ab590377SBoian Petkantchin
53ab590377SBoian Petkantchin // This folding can not be done with an operation's fold method or
54ab590377SBoian Petkantchin // DialectFoldInterface, because it needs a SymbolTableCollection to cache the
55ab590377SBoian Petkantchin // symbol tables.
56ab590377SBoian Petkantchin // We can't use DialectFoldInterface since the cache may be invalidated by some
579a8437f5SBoian Petkantchin // pass changing the referenced MeshOp ops.
58dc3258c6SBoian Petkantchin struct MeshShapeFolder
59dc3258c6SBoian Petkantchin : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
60dc3258c6SBoian Petkantchin using OpRewritePatternWithSymbolTableCollection::
61dc3258c6SBoian Petkantchin OpRewritePatternWithSymbolTableCollection;
matchAndRewritemlir::mesh::__anon2a0f5f890111::MeshShapeFolder629a8437f5SBoian Petkantchin LogicalResult matchAndRewrite(MeshShapeOp op,
63ab590377SBoian Petkantchin PatternRewriter &rewriter) const override {
64ab590377SBoian Petkantchin ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
659a8437f5SBoian Petkantchin MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
66ab590377SBoian Petkantchin op.getOperation(), op.getMeshAttr());
67ab590377SBoian Petkantchin if (!mesh) {
68ab590377SBoian Petkantchin return failure();
69ab590377SBoian Petkantchin }
70ab590377SBoian Petkantchin ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
71ab590377SBoian Petkantchin SmallVector<MeshAxis> opAxesIota;
72ab590377SBoian Petkantchin if (opMeshAxes.empty()) {
73ab590377SBoian Petkantchin opAxesIota.resize(mesh.getRank());
74ab590377SBoian Petkantchin std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
75ab590377SBoian Petkantchin opMeshAxes = opAxesIota;
76ab590377SBoian Petkantchin }
77ab590377SBoian Petkantchin if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
785df2c00aSBoian Petkantchin return ShapedType::isDynamic(mesh.getShape()[axis]);
79ab590377SBoian Petkantchin })) {
80ab590377SBoian Petkantchin // All mesh dimensions are dynamic. Nothing to fold.
81ab590377SBoian Petkantchin return failure();
82ab590377SBoian Petkantchin }
83ab590377SBoian Petkantchin
84ab590377SBoian Petkantchin SmallVector<Value> newResults(op->getResults().size());
85ab590377SBoian Petkantchin SmallVector<MeshAxis> newShapeOpMeshAxes;
86ab590377SBoian Petkantchin SmallVector<size_t> newToOldResultsIndexMap;
87ab590377SBoian Petkantchin
88ab590377SBoian Petkantchin for (size_t i = 0; i < opMeshAxes.size(); ++i) {
895df2c00aSBoian Petkantchin auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
90ab590377SBoian Petkantchin if (ShapedType::isDynamic(meshAxisSize)) {
91ab590377SBoian Petkantchin newToOldResultsIndexMap.push_back(i);
92ab590377SBoian Petkantchin newShapeOpMeshAxes.push_back(opMeshAxes[i]);
93ab590377SBoian Petkantchin } else {
94ab590377SBoian Petkantchin // Fold static mesh axes.
95ab590377SBoian Petkantchin newResults[i] = builder.create<arith::ConstantOp>(
96ab590377SBoian Petkantchin builder.getIndexAttr(meshAxisSize));
97ab590377SBoian Petkantchin }
98ab590377SBoian Petkantchin }
99ab590377SBoian Petkantchin
100ab590377SBoian Petkantchin // Leave only the dynamic mesh axes to be queried.
1010cb024b3SMatthias Springer if (!newShapeOpMeshAxes.empty()) {
1029a8437f5SBoian Petkantchin MeshShapeOp newShapeOp =
1039a8437f5SBoian Petkantchin builder.create<MeshShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
104ab590377SBoian Petkantchin for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
105ab590377SBoian Petkantchin newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
106ab590377SBoian Petkantchin }
1070cb024b3SMatthias Springer }
1080cb024b3SMatthias Springer rewriter.replaceOp(op, newResults);
109ab590377SBoian Petkantchin
110ab590377SBoian Petkantchin return success();
111ab590377SBoian Petkantchin }
112ab590377SBoian Petkantchin };
113ab590377SBoian Petkantchin
114ab590377SBoian Petkantchin } // namespace
115ab590377SBoian Petkantchin
populateFoldingPatterns(RewritePatternSet & patterns,SymbolTableCollection & symbolTableCollection)116ab590377SBoian Petkantchin void populateFoldingPatterns(RewritePatternSet &patterns,
117ab590377SBoian Petkantchin SymbolTableCollection &symbolTableCollection) {
1189a8437f5SBoian Petkantchin patterns.add<MeshShapeFolder>(symbolTableCollection, patterns.getContext());
1194b344677SBoian Petkantchin }
1204b344677SBoian Petkantchin
1214b344677SBoian Petkantchin } // namespace mesh
1224b344677SBoian Petkantchin } // namespace mlir
123