xref: /llvm-project/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp (revision 129f1001c3b1b5200de43917d53c0efbdf08f11f)
179aa7762SBoian Petkantchin //===- Transforms.cpp ---------------------------------------------- C++ --===//
279aa7762SBoian Petkantchin //
379aa7762SBoian Petkantchin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
479aa7762SBoian Petkantchin // See https://llvm.org/LICENSE.txt for license information.
579aa7762SBoian Petkantchin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
679aa7762SBoian Petkantchin //
779aa7762SBoian Petkantchin //===----------------------------------------------------------------------===//
879aa7762SBoian Petkantchin 
979aa7762SBoian Petkantchin #include "mlir/Dialect/Mesh/Transforms/Transforms.h"
10dc3258c6SBoian Petkantchin #include "TransformsDetail.h"
1179aa7762SBoian Petkantchin #include "mlir/Dialect/Affine/IR/AffineOps.h"
12dc3258c6SBoian Petkantchin #include "mlir/Dialect/Affine/Utils.h"
13dc3258c6SBoian Petkantchin #include "mlir/Dialect/Arith/IR/Arith.h"
14dc3258c6SBoian Petkantchin #include "mlir/Dialect/Arith/Utils/Utils.h"
15dc3258c6SBoian Petkantchin #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
16dc3258c6SBoian Petkantchin #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1731fc0a12SBoian Petkantchin #include "mlir/Dialect/Mesh/IR/MeshDialect.h"
1879aa7762SBoian Petkantchin #include "mlir/Dialect/Mesh/IR/MeshOps.h"
19dc3258c6SBoian Petkantchin #include "mlir/Dialect/Tensor/IR/Tensor.h"
20dc3258c6SBoian Petkantchin #include "mlir/Dialect/Utils/StaticValueUtils.h"
2179aa7762SBoian Petkantchin #include "mlir/IR/BuiltinTypes.h"
2279aa7762SBoian Petkantchin #include "mlir/IR/DialectRegistry.h"
2379aa7762SBoian Petkantchin #include "mlir/IR/ImplicitLocOpBuilder.h"
24dc3258c6SBoian Petkantchin #include "mlir/IR/OpDefinition.h"
2579aa7762SBoian Petkantchin #include "mlir/IR/PatternMatch.h"
2679aa7762SBoian Petkantchin #include "mlir/IR/Value.h"
2779aa7762SBoian Petkantchin #include "llvm/ADT/STLExtras.h"
2879aa7762SBoian Petkantchin #include "llvm/ADT/SmallVector.h"
2979aa7762SBoian Petkantchin #include <iterator>
3079aa7762SBoian Petkantchin #include <numeric>
3179aa7762SBoian Petkantchin 
3279aa7762SBoian Petkantchin namespace mlir::mesh {
3379aa7762SBoian Petkantchin 
3479aa7762SBoian Petkantchin namespace {
3579aa7762SBoian Petkantchin 
3679aa7762SBoian Petkantchin /// Lower `mesh.process_multi_index` into expression using
379a8437f5SBoian Petkantchin /// `mesh.process_linear_index` and `mesh.mesh_shape`.
38dc3258c6SBoian Petkantchin struct ProcessMultiIndexOpLowering
39dc3258c6SBoian Petkantchin     : OpRewritePatternWithSymbolTableCollection<ProcessMultiIndexOp> {
40dc3258c6SBoian Petkantchin   using OpRewritePatternWithSymbolTableCollection::
41dc3258c6SBoian Petkantchin       OpRewritePatternWithSymbolTableCollection;
4279aa7762SBoian Petkantchin 
4379aa7762SBoian Petkantchin   LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
4479aa7762SBoian Petkantchin                                 PatternRewriter &rewriter) const override {
45dc3258c6SBoian Petkantchin     MeshOp mesh = getMesh(op, symbolTableCollection);
4679aa7762SBoian Petkantchin     if (!mesh) {
4779aa7762SBoian Petkantchin       return failure();
4879aa7762SBoian Petkantchin     }
4979aa7762SBoian Petkantchin 
5079aa7762SBoian Petkantchin     ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
5179aa7762SBoian Petkantchin     builder.setInsertionPointAfter(op.getOperation());
5279aa7762SBoian Petkantchin     Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh);
539a8437f5SBoian Petkantchin     ValueRange meshShape = builder.create<MeshShapeOp>(mesh).getResults();
5479aa7762SBoian Petkantchin     SmallVector<Value> completeMultiIndex =
5579aa7762SBoian Petkantchin         builder.create<affine::AffineDelinearizeIndexOp>(linearIndex, meshShape)
5679aa7762SBoian Petkantchin             .getMultiIndex();
5779aa7762SBoian Petkantchin     SmallVector<Value> multiIndex;
5879aa7762SBoian Petkantchin     ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
5979aa7762SBoian Petkantchin     SmallVector<MeshAxis> opAxesIota;
6079aa7762SBoian Petkantchin     if (opMeshAxes.empty()) {
6179aa7762SBoian Petkantchin       opAxesIota.resize(mesh.getRank());
6279aa7762SBoian Petkantchin       std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
6379aa7762SBoian Petkantchin       opMeshAxes = opAxesIota;
6479aa7762SBoian Petkantchin     }
6579aa7762SBoian Petkantchin     llvm::transform(opMeshAxes, std::back_inserter(multiIndex),
6679aa7762SBoian Petkantchin                     [&completeMultiIndex](MeshAxis meshAxis) {
6779aa7762SBoian Petkantchin                       return completeMultiIndex[meshAxis];
6879aa7762SBoian Petkantchin                     });
6979aa7762SBoian Petkantchin     rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
7079aa7762SBoian Petkantchin     return success();
7179aa7762SBoian Petkantchin   }
72dc3258c6SBoian Petkantchin };
7379aa7762SBoian Petkantchin 
74dc3258c6SBoian Petkantchin struct AllSliceOpLowering
75dc3258c6SBoian Petkantchin     : OpRewritePatternWithSymbolTableCollection<AllSliceOp> {
76dc3258c6SBoian Petkantchin   using OpRewritePatternWithSymbolTableCollection::
77dc3258c6SBoian Petkantchin       OpRewritePatternWithSymbolTableCollection;
78dc3258c6SBoian Petkantchin 
79dc3258c6SBoian Petkantchin   LogicalResult matchAndRewrite(AllSliceOp op,
80dc3258c6SBoian Petkantchin                                 PatternRewriter &rewriter) const override {
81dc3258c6SBoian Petkantchin     // 1. Compute the process linear index inside the process group from its
82dc3258c6SBoian Petkantchin     // multi-index.
83dc3258c6SBoian Petkantchin     //
84dc3258c6SBoian Petkantchin     // 2. Extract a slice from the input tensor.
85dc3258c6SBoian Petkantchin     // All axes except the slicing axis are not interesting and take the full
86dc3258c6SBoian Petkantchin     // axis.
87dc3258c6SBoian Petkantchin     // The slice axis is split into equisized parts with count
88dc3258c6SBoian Petkantchin     // the number of processes in the collective process group induced by
89dc3258c6SBoian Petkantchin     // the mesh axes.
90dc3258c6SBoian Petkantchin     // The part for each process is determined by the corresponding
91dc3258c6SBoian Petkantchin     // linear-index in the process group.
92dc3258c6SBoian Petkantchin     //
93dc3258c6SBoian Petkantchin     // There are no collectives that require communication.
94dc3258c6SBoian Petkantchin     // Each process operates on its local tensor.
95dc3258c6SBoian Petkantchin 
96dc3258c6SBoian Petkantchin     MeshOp mesh = getMesh(op, symbolTableCollection);
97dc3258c6SBoian Petkantchin     if (!mesh) {
98dc3258c6SBoian Petkantchin       return failure();
99dc3258c6SBoian Petkantchin     }
100dc3258c6SBoian Petkantchin 
101dc3258c6SBoian Petkantchin     ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
102dc3258c6SBoian Petkantchin     builder.setInsertionPointAfter(op.getOperation());
103dc3258c6SBoian Petkantchin 
104dc3258c6SBoian Petkantchin     Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
105dc3258c6SBoian Petkantchin 
106dc3258c6SBoian Petkantchin     Operation::result_range processInGroupMultiIndex =
107dc3258c6SBoian Petkantchin         builder.create<ProcessMultiIndexOp>(mesh.getSymName(), op.getMeshAxes())
108dc3258c6SBoian Petkantchin             .getResults();
109dc3258c6SBoian Petkantchin 
110dc3258c6SBoian Petkantchin     Operation::result_range processGroupShape =
111dc3258c6SBoian Petkantchin         builder.create<MeshShapeOp>(mesh.getSymName(), op.getMeshAxes())
112dc3258c6SBoian Petkantchin             .getResult();
113dc3258c6SBoian Petkantchin     Value processGroupSize =
114dc3258c6SBoian Petkantchin         createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder);
115dc3258c6SBoian Petkantchin 
116dc3258c6SBoian Petkantchin     int64_t sliceAxis = op.getSliceAxis().getSExtValue();
117dc3258c6SBoian Petkantchin     Value operandSliceAxisSize =
118dc3258c6SBoian Petkantchin         builder.create<tensor::DimOp>(op.getOperand(), sliceAxis);
119dc3258c6SBoian Petkantchin     Value operandSliceAxisSizeModProcessGroupSize =
120dc3258c6SBoian Petkantchin         builder.create<arith::RemUIOp>(operandSliceAxisSize, processGroupSize);
121dc3258c6SBoian Petkantchin     Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
122dc3258c6SBoian Petkantchin         arith::CmpIPredicate::eq, operandSliceAxisSizeModProcessGroupSize,
123dc3258c6SBoian Petkantchin         zero);
124dc3258c6SBoian Petkantchin     builder.create<cf::AssertOp>(isTargetShapeExactlyDivisible,
125dc3258c6SBoian Petkantchin                                  "Slicing a tensor with axis size that is "
126dc3258c6SBoian Petkantchin                                  "not exactly divisible by the "
127dc3258c6SBoian Petkantchin                                  "mesh process group size is not supported.");
128dc3258c6SBoian Petkantchin     Value resultSliceAxisSize =
129dc3258c6SBoian Petkantchin         builder.create<arith::DivUIOp>(operandSliceAxisSize, processGroupSize);
130dc3258c6SBoian Petkantchin     OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
131dc3258c6SBoian Petkantchin         llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
132dc3258c6SBoian Petkantchin         llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
133dc3258c6SBoian Petkantchin 
134dc3258c6SBoian Petkantchin     // insert tensor.extract_slice
135dc3258c6SBoian Petkantchin     RankedTensorType operandType =
136a5757c5bSChristian Sigg         cast<RankedTensorType>(op.getOperand().getType());
137dc3258c6SBoian Petkantchin     SmallVector<OpFoldResult> sizes;
138dc3258c6SBoian Petkantchin     for (int64_t i = 0; i < operandType.getRank(); ++i) {
139dc3258c6SBoian Petkantchin       if (i == sliceAxis) {
140dc3258c6SBoian Petkantchin         sizes.emplace_back(resultSliceAxisSize);
141dc3258c6SBoian Petkantchin       } else {
142dc3258c6SBoian Petkantchin         Value dimSize = builder.create<tensor::DimOp>(op.getOperand(), i);
143dc3258c6SBoian Petkantchin         sizes.emplace_back(dimSize);
144dc3258c6SBoian Petkantchin       }
145dc3258c6SBoian Petkantchin     }
146dc3258c6SBoian Petkantchin     SmallVector<OpFoldResult> offsets(
147dc3258c6SBoian Petkantchin         operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 0));
148dc3258c6SBoian Petkantchin     offsets[sliceAxis] =
149dc3258c6SBoian Petkantchin         ArithBuilder(builder, builder.getLoc())
150dc3258c6SBoian Petkantchin             .mul(getValueOrCreateConstantIndexOp(builder, builder.getLoc(),
151dc3258c6SBoian Petkantchin                                                  processInGroupLinearIndex),
152dc3258c6SBoian Petkantchin                  resultSliceAxisSize);
153dc3258c6SBoian Petkantchin     SmallVector<OpFoldResult> strides(
154dc3258c6SBoian Petkantchin         operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 1));
155dc3258c6SBoian Petkantchin     Value slice = builder.create<tensor::ExtractSliceOp>(
156dc3258c6SBoian Petkantchin         op.getOperand(), offsets, sizes, strides);
157dc3258c6SBoian Petkantchin     Value newResult =
158dc3258c6SBoian Petkantchin         builder.create<tensor::CastOp>(op.getResult().getType(), slice);
159dc3258c6SBoian Petkantchin     rewriter.replaceAllUsesWith(op.getResult(), newResult);
160dc3258c6SBoian Petkantchin 
161dc3258c6SBoian Petkantchin     return success();
162dc3258c6SBoian Petkantchin   }
16379aa7762SBoian Petkantchin };
16479aa7762SBoian Petkantchin 
16579aa7762SBoian Petkantchin } // namespace
16679aa7762SBoian Petkantchin 
167dc3258c6SBoian Petkantchin void populateProcessMultiIndexOpLoweringPatterns(
16879aa7762SBoian Petkantchin     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
16979aa7762SBoian Petkantchin   patterns.add<ProcessMultiIndexOpLowering>(symbolTableCollection,
17079aa7762SBoian Petkantchin                                             patterns.getContext());
17179aa7762SBoian Petkantchin }
17279aa7762SBoian Petkantchin 
173dc3258c6SBoian Petkantchin void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry) {
17479aa7762SBoian Petkantchin   registry.insert<affine::AffineDialect, mesh::MeshDialect>();
17579aa7762SBoian Petkantchin }
17679aa7762SBoian Petkantchin 
177dc3258c6SBoian Petkantchin void populateAllSliceOpLoweringPatterns(
178dc3258c6SBoian Petkantchin     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
179dc3258c6SBoian Petkantchin   patterns.add<AllSliceOpLowering>(symbolTableCollection,
180dc3258c6SBoian Petkantchin                                    patterns.getContext());
181dc3258c6SBoian Petkantchin }
182dc3258c6SBoian Petkantchin 
183dc3258c6SBoian Petkantchin void registerAllSliceOpLoweringDialects(DialectRegistry &registry) {
184dc3258c6SBoian Petkantchin   registry.insert<affine::AffineDialect, arith::ArithDialect,
185dc3258c6SBoian Petkantchin                   cf::ControlFlowDialect, mesh::MeshDialect,
186dc3258c6SBoian Petkantchin                   tensor::TensorDialect>();
187dc3258c6SBoian Petkantchin }
188dc3258c6SBoian Petkantchin 
189dc3258c6SBoian Petkantchin void populateAllOpLoweringPatterns(
190dc3258c6SBoian Petkantchin     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
191dc3258c6SBoian Petkantchin   populateProcessMultiIndexOpLoweringPatterns(patterns, symbolTableCollection);
192dc3258c6SBoian Petkantchin   populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
193dc3258c6SBoian Petkantchin }
194dc3258c6SBoian Petkantchin 
195dc3258c6SBoian Petkantchin void registerAllOpLoweringDialects(DialectRegistry &registry) {
196dc3258c6SBoian Petkantchin   registerProcessMultiIndexOpLoweringDialects(registry);
197dc3258c6SBoian Petkantchin   registerAllSliceOpLoweringDialects(registry);
198dc3258c6SBoian Petkantchin }
199dc3258c6SBoian Petkantchin 
200dc3258c6SBoian Petkantchin TypedValue<IndexType>
201dc3258c6SBoian Petkantchin createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
202dc3258c6SBoian Petkantchin                                  ImplicitLocOpBuilder &builder) {
203dc3258c6SBoian Petkantchin   Operation::result_range meshShape =
204dc3258c6SBoian Petkantchin       builder.create<mesh::MeshShapeOp>(mesh, axes).getResults();
205a5757c5bSChristian Sigg   return cast<TypedValue<IndexType>>(arith::createProduct(
206a5757c5bSChristian Sigg       builder, builder.getLoc(), llvm::to_vector_of<Value>(meshShape),
207a5757c5bSChristian Sigg       builder.getIndexType()));
208dc3258c6SBoian Petkantchin }
209dc3258c6SBoian Petkantchin 
210fb582b6aSBoian Petkantchin TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
211fb582b6aSBoian Petkantchin                                                ArrayRef<MeshAxis> meshAxes,
212fb582b6aSBoian Petkantchin                                                ImplicitLocOpBuilder &builder) {
213fb582b6aSBoian Petkantchin   ResultRange processInGroupMultiIndex =
214fb582b6aSBoian Petkantchin       builder.create<ProcessMultiIndexOp>(mesh, meshAxes).getResults();
215fb582b6aSBoian Petkantchin   Operation::result_range processGroupShape =
216fb582b6aSBoian Petkantchin       builder.create<MeshShapeOp>(mesh, meshAxes).getResult();
217fb582b6aSBoian Petkantchin   OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
218fb582b6aSBoian Petkantchin       llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
219fb582b6aSBoian Petkantchin       llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
220*129f1001SKazu Hirata   return cast<TypedValue<IndexType>>(cast<Value>(processInGroupLinearIndex));
221fb582b6aSBoian Petkantchin }
222fb582b6aSBoian Petkantchin 
22379aa7762SBoian Petkantchin } // namespace mlir::mesh
224