xref: /llvm-project/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp (revision 129f1001c3b1b5200de43917d53c0efbdf08f11f)
1 //===- Transforms.cpp ---------------------------------------------- C++ --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Mesh/Transforms/Transforms.h"
10 #include "TransformsDetail.h"
11 #include "mlir/Dialect/Affine/IR/AffineOps.h"
12 #include "mlir/Dialect/Affine/Utils.h"
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Arith/Utils/Utils.h"
15 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
16 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
17 #include "mlir/Dialect/Mesh/IR/MeshDialect.h"
18 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/Dialect/Utils/StaticValueUtils.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/DialectRegistry.h"
23 #include "mlir/IR/ImplicitLocOpBuilder.h"
24 #include "mlir/IR/OpDefinition.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/Value.h"
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ADT/SmallVector.h"
29 #include <iterator>
30 #include <numeric>
31 
32 namespace mlir::mesh {
33 
34 namespace {
35 
36 /// Lower `mesh.process_multi_index` into expression using
37 /// `mesh.process_linear_index` and `mesh.mesh_shape`.
38 struct ProcessMultiIndexOpLowering
39     : OpRewritePatternWithSymbolTableCollection<ProcessMultiIndexOp> {
40   using OpRewritePatternWithSymbolTableCollection::
41       OpRewritePatternWithSymbolTableCollection;
42 
43   LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
44                                 PatternRewriter &rewriter) const override {
45     MeshOp mesh = getMesh(op, symbolTableCollection);
46     if (!mesh) {
47       return failure();
48     }
49 
50     ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
51     builder.setInsertionPointAfter(op.getOperation());
52     Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh);
53     ValueRange meshShape = builder.create<MeshShapeOp>(mesh).getResults();
54     SmallVector<Value> completeMultiIndex =
55         builder.create<affine::AffineDelinearizeIndexOp>(linearIndex, meshShape)
56             .getMultiIndex();
57     SmallVector<Value> multiIndex;
58     ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
59     SmallVector<MeshAxis> opAxesIota;
60     if (opMeshAxes.empty()) {
61       opAxesIota.resize(mesh.getRank());
62       std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
63       opMeshAxes = opAxesIota;
64     }
65     llvm::transform(opMeshAxes, std::back_inserter(multiIndex),
66                     [&completeMultiIndex](MeshAxis meshAxis) {
67                       return completeMultiIndex[meshAxis];
68                     });
69     rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
70     return success();
71   }
72 };
73 
74 struct AllSliceOpLowering
75     : OpRewritePatternWithSymbolTableCollection<AllSliceOp> {
76   using OpRewritePatternWithSymbolTableCollection::
77       OpRewritePatternWithSymbolTableCollection;
78 
79   LogicalResult matchAndRewrite(AllSliceOp op,
80                                 PatternRewriter &rewriter) const override {
81     // 1. Compute the process linear index inside the process group from its
82     // multi-index.
83     //
84     // 2. Extract a slice from the input tensor.
85     // All axes except the slicing axis are not interesting and take the full
86     // axis.
87     // The slice axis is split into equisized parts with count
88     // the number of processes in the collective process group induced by
89     // the mesh axes.
90     // The part for each process is determined by the corresponding
91     // linear-index in the process group.
92     //
93     // There are no collectives that require communication.
94     // Each process operates on its local tensor.
95 
96     MeshOp mesh = getMesh(op, symbolTableCollection);
97     if (!mesh) {
98       return failure();
99     }
100 
101     ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
102     builder.setInsertionPointAfter(op.getOperation());
103 
104     Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
105 
106     Operation::result_range processInGroupMultiIndex =
107         builder.create<ProcessMultiIndexOp>(mesh.getSymName(), op.getMeshAxes())
108             .getResults();
109 
110     Operation::result_range processGroupShape =
111         builder.create<MeshShapeOp>(mesh.getSymName(), op.getMeshAxes())
112             .getResult();
113     Value processGroupSize =
114         createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder);
115 
116     int64_t sliceAxis = op.getSliceAxis().getSExtValue();
117     Value operandSliceAxisSize =
118         builder.create<tensor::DimOp>(op.getOperand(), sliceAxis);
119     Value operandSliceAxisSizeModProcessGroupSize =
120         builder.create<arith::RemUIOp>(operandSliceAxisSize, processGroupSize);
121     Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
122         arith::CmpIPredicate::eq, operandSliceAxisSizeModProcessGroupSize,
123         zero);
124     builder.create<cf::AssertOp>(isTargetShapeExactlyDivisible,
125                                  "Slicing a tensor with axis size that is "
126                                  "not exactly divisible by the "
127                                  "mesh process group size is not supported.");
128     Value resultSliceAxisSize =
129         builder.create<arith::DivUIOp>(operandSliceAxisSize, processGroupSize);
130     OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
131         llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
132         llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
133 
134     // insert tensor.extract_slice
135     RankedTensorType operandType =
136         cast<RankedTensorType>(op.getOperand().getType());
137     SmallVector<OpFoldResult> sizes;
138     for (int64_t i = 0; i < operandType.getRank(); ++i) {
139       if (i == sliceAxis) {
140         sizes.emplace_back(resultSliceAxisSize);
141       } else {
142         Value dimSize = builder.create<tensor::DimOp>(op.getOperand(), i);
143         sizes.emplace_back(dimSize);
144       }
145     }
146     SmallVector<OpFoldResult> offsets(
147         operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 0));
148     offsets[sliceAxis] =
149         ArithBuilder(builder, builder.getLoc())
150             .mul(getValueOrCreateConstantIndexOp(builder, builder.getLoc(),
151                                                  processInGroupLinearIndex),
152                  resultSliceAxisSize);
153     SmallVector<OpFoldResult> strides(
154         operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 1));
155     Value slice = builder.create<tensor::ExtractSliceOp>(
156         op.getOperand(), offsets, sizes, strides);
157     Value newResult =
158         builder.create<tensor::CastOp>(op.getResult().getType(), slice);
159     rewriter.replaceAllUsesWith(op.getResult(), newResult);
160 
161     return success();
162   }
163 };
164 
165 } // namespace
166 
167 void populateProcessMultiIndexOpLoweringPatterns(
168     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
169   patterns.add<ProcessMultiIndexOpLowering>(symbolTableCollection,
170                                             patterns.getContext());
171 }
172 
173 void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry) {
174   registry.insert<affine::AffineDialect, mesh::MeshDialect>();
175 }
176 
177 void populateAllSliceOpLoweringPatterns(
178     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
179   patterns.add<AllSliceOpLowering>(symbolTableCollection,
180                                    patterns.getContext());
181 }
182 
183 void registerAllSliceOpLoweringDialects(DialectRegistry &registry) {
184   registry.insert<affine::AffineDialect, arith::ArithDialect,
185                   cf::ControlFlowDialect, mesh::MeshDialect,
186                   tensor::TensorDialect>();
187 }
188 
189 void populateAllOpLoweringPatterns(
190     RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
191   populateProcessMultiIndexOpLoweringPatterns(patterns, symbolTableCollection);
192   populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
193 }
194 
195 void registerAllOpLoweringDialects(DialectRegistry &registry) {
196   registerProcessMultiIndexOpLoweringDialects(registry);
197   registerAllSliceOpLoweringDialects(registry);
198 }
199 
200 TypedValue<IndexType>
201 createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
202                                  ImplicitLocOpBuilder &builder) {
203   Operation::result_range meshShape =
204       builder.create<mesh::MeshShapeOp>(mesh, axes).getResults();
205   return cast<TypedValue<IndexType>>(arith::createProduct(
206       builder, builder.getLoc(), llvm::to_vector_of<Value>(meshShape),
207       builder.getIndexType()));
208 }
209 
210 TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
211                                                ArrayRef<MeshAxis> meshAxes,
212                                                ImplicitLocOpBuilder &builder) {
213   ResultRange processInGroupMultiIndex =
214       builder.create<ProcessMultiIndexOp>(mesh, meshAxes).getResults();
215   Operation::result_range processGroupShape =
216       builder.create<MeshShapeOp>(mesh, meshAxes).getResult();
217   OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
218       llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
219       llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
220   return cast<TypedValue<IndexType>>(cast<Value>(processInGroupLinearIndex));
221 }
222 
223 } // namespace mlir::mesh
224