xref: /llvm-project/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
179aa7762SBoian Petkantchin //===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
24b344677SBoian Petkantchin //
34b344677SBoian Petkantchin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44b344677SBoian Petkantchin // See https://llvm.org/LICENSE.txt for license information.
54b344677SBoian Petkantchin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
64b344677SBoian Petkantchin //
74b344677SBoian Petkantchin //===----------------------------------------------------------------------===//
84b344677SBoian Petkantchin 
94b344677SBoian Petkantchin #include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
10dc3258c6SBoian Petkantchin #include "TransformsDetail.h"
114b344677SBoian Petkantchin #include "mlir/Dialect/Arith/IR/Arith.h"
12ab590377SBoian Petkantchin #include "mlir/Dialect/Mesh/IR/MeshOps.h"
13ab590377SBoian Petkantchin #include "mlir/IR/BuiltinTypeInterfaces.h"
14ab590377SBoian Petkantchin #include "mlir/IR/ImplicitLocOpBuilder.h"
15ab590377SBoian Petkantchin #include "mlir/IR/PatternMatch.h"
16ab590377SBoian Petkantchin #include "mlir/IR/SymbolTable.h"
17ab590377SBoian Petkantchin #include "llvm/ADT/STLExtras.h"
18ab590377SBoian Petkantchin #include "llvm/ADT/SmallVector.h"
19ab590377SBoian Petkantchin #include <numeric>
20ab590377SBoian Petkantchin #include <utility>
214b344677SBoian Petkantchin 
224b344677SBoian Petkantchin namespace mlir {
234b344677SBoian Petkantchin namespace mesh {
244b344677SBoian Petkantchin 
populateSimplificationPatterns(RewritePatternSet & patterns,SymbolTableCollection & symbolTableCollection)25ab590377SBoian Petkantchin void populateSimplificationPatterns(
26ab590377SBoian Petkantchin     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
274b344677SBoian Petkantchin   populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
28*ff2720d1SBoian Petkantchin       patterns, ReductionKind::Sum);
294b344677SBoian Petkantchin   populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
30*ff2720d1SBoian Petkantchin       patterns, ReductionKind::Sum);
314b344677SBoian Petkantchin 
324b344677SBoian Petkantchin   populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
33*ff2720d1SBoian Petkantchin       patterns, ReductionKind::Min);
344b344677SBoian Petkantchin   populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
35*ff2720d1SBoian Petkantchin       patterns, ReductionKind::Min);
364b344677SBoian Petkantchin   populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
37*ff2720d1SBoian Petkantchin       patterns, ReductionKind::Min);
384b344677SBoian Petkantchin 
394b344677SBoian Petkantchin   populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
40*ff2720d1SBoian Petkantchin       patterns, ReductionKind::Max);
414b344677SBoian Petkantchin   populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
42*ff2720d1SBoian Petkantchin       patterns, ReductionKind::Max);
434b344677SBoian Petkantchin   populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
44*ff2720d1SBoian Petkantchin       patterns, ReductionKind::Max);
454b344677SBoian Petkantchin 
464b344677SBoian Petkantchin   // TODO: add simplifications for all-gather and other collectives.
47ab590377SBoian Petkantchin 
48ab590377SBoian Petkantchin   populateFoldingPatterns(patterns, symbolTableCollection);
49ab590377SBoian Petkantchin }
50ab590377SBoian Petkantchin 
51ab590377SBoian Petkantchin namespace {
52ab590377SBoian Petkantchin 
53ab590377SBoian Petkantchin // This folding can not be done with an operation's fold method or
54ab590377SBoian Petkantchin // DialectFoldInterface, because it needs a SymbolTableCollection to cache the
55ab590377SBoian Petkantchin // symbol tables.
56ab590377SBoian Petkantchin // We can't use DialectFoldInterface since the cache may be invalidated by some
579a8437f5SBoian Petkantchin // pass changing the referenced MeshOp ops.
58dc3258c6SBoian Petkantchin struct MeshShapeFolder
59dc3258c6SBoian Petkantchin     : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
60dc3258c6SBoian Petkantchin   using OpRewritePatternWithSymbolTableCollection::
61dc3258c6SBoian Petkantchin       OpRewritePatternWithSymbolTableCollection;
matchAndRewritemlir::mesh::__anon2a0f5f890111::MeshShapeFolder629a8437f5SBoian Petkantchin   LogicalResult matchAndRewrite(MeshShapeOp op,
63ab590377SBoian Petkantchin                                 PatternRewriter &rewriter) const override {
64ab590377SBoian Petkantchin     ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
659a8437f5SBoian Petkantchin     MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
66ab590377SBoian Petkantchin         op.getOperation(), op.getMeshAttr());
67ab590377SBoian Petkantchin     if (!mesh) {
68ab590377SBoian Petkantchin       return failure();
69ab590377SBoian Petkantchin     }
70ab590377SBoian Petkantchin     ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
71ab590377SBoian Petkantchin     SmallVector<MeshAxis> opAxesIota;
72ab590377SBoian Petkantchin     if (opMeshAxes.empty()) {
73ab590377SBoian Petkantchin       opAxesIota.resize(mesh.getRank());
74ab590377SBoian Petkantchin       std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
75ab590377SBoian Petkantchin       opMeshAxes = opAxesIota;
76ab590377SBoian Petkantchin     }
77ab590377SBoian Petkantchin     if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) {
785df2c00aSBoian Petkantchin           return ShapedType::isDynamic(mesh.getShape()[axis]);
79ab590377SBoian Petkantchin         })) {
80ab590377SBoian Petkantchin       // All mesh dimensions are dynamic. Nothing to fold.
81ab590377SBoian Petkantchin       return failure();
82ab590377SBoian Petkantchin     }
83ab590377SBoian Petkantchin 
84ab590377SBoian Petkantchin     SmallVector<Value> newResults(op->getResults().size());
85ab590377SBoian Petkantchin     SmallVector<MeshAxis> newShapeOpMeshAxes;
86ab590377SBoian Petkantchin     SmallVector<size_t> newToOldResultsIndexMap;
87ab590377SBoian Petkantchin 
88ab590377SBoian Petkantchin     for (size_t i = 0; i < opMeshAxes.size(); ++i) {
895df2c00aSBoian Petkantchin       auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
90ab590377SBoian Petkantchin       if (ShapedType::isDynamic(meshAxisSize)) {
91ab590377SBoian Petkantchin         newToOldResultsIndexMap.push_back(i);
92ab590377SBoian Petkantchin         newShapeOpMeshAxes.push_back(opMeshAxes[i]);
93ab590377SBoian Petkantchin       } else {
94ab590377SBoian Petkantchin         // Fold static mesh axes.
95ab590377SBoian Petkantchin         newResults[i] = builder.create<arith::ConstantOp>(
96ab590377SBoian Petkantchin             builder.getIndexAttr(meshAxisSize));
97ab590377SBoian Petkantchin       }
98ab590377SBoian Petkantchin     }
99ab590377SBoian Petkantchin 
100ab590377SBoian Petkantchin     // Leave only the dynamic mesh axes to be queried.
1010cb024b3SMatthias Springer     if (!newShapeOpMeshAxes.empty()) {
1029a8437f5SBoian Petkantchin       MeshShapeOp newShapeOp =
1039a8437f5SBoian Petkantchin           builder.create<MeshShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
104ab590377SBoian Petkantchin       for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
105ab590377SBoian Petkantchin         newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
106ab590377SBoian Petkantchin       }
1070cb024b3SMatthias Springer     }
1080cb024b3SMatthias Springer     rewriter.replaceOp(op, newResults);
109ab590377SBoian Petkantchin 
110ab590377SBoian Petkantchin     return success();
111ab590377SBoian Petkantchin   }
112ab590377SBoian Petkantchin };
113ab590377SBoian Petkantchin 
114ab590377SBoian Petkantchin } // namespace
115ab590377SBoian Petkantchin 
populateFoldingPatterns(RewritePatternSet & patterns,SymbolTableCollection & symbolTableCollection)116ab590377SBoian Petkantchin void populateFoldingPatterns(RewritePatternSet &patterns,
117ab590377SBoian Petkantchin                              SymbolTableCollection &symbolTableCollection) {
1189a8437f5SBoian Petkantchin   patterns.add<MeshShapeFolder>(symbolTableCollection, patterns.getContext());
1194b344677SBoian Petkantchin }
1204b344677SBoian Petkantchin 
1214b344677SBoian Petkantchin } // namespace mesh
1224b344677SBoian Petkantchin } // namespace mlir
123