1 //===- Transforms.cpp ---------------------------------------------- C++ --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Mesh/Transforms/Transforms.h" 10 #include "TransformsDetail.h" 11 #include "mlir/Dialect/Affine/IR/AffineOps.h" 12 #include "mlir/Dialect/Affine/Utils.h" 13 #include "mlir/Dialect/Arith/IR/Arith.h" 14 #include "mlir/Dialect/Arith/Utils/Utils.h" 15 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" 16 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 17 #include "mlir/Dialect/Mesh/IR/MeshDialect.h" 18 #include "mlir/Dialect/Mesh/IR/MeshOps.h" 19 #include "mlir/Dialect/Tensor/IR/Tensor.h" 20 #include "mlir/Dialect/Utils/StaticValueUtils.h" 21 #include "mlir/IR/BuiltinTypes.h" 22 #include "mlir/IR/DialectRegistry.h" 23 #include "mlir/IR/ImplicitLocOpBuilder.h" 24 #include "mlir/IR/OpDefinition.h" 25 #include "mlir/IR/PatternMatch.h" 26 #include "mlir/IR/Value.h" 27 #include "llvm/ADT/STLExtras.h" 28 #include "llvm/ADT/SmallVector.h" 29 #include <iterator> 30 #include <numeric> 31 32 namespace mlir::mesh { 33 34 namespace { 35 36 /// Lower `mesh.process_multi_index` into expression using 37 /// `mesh.process_linear_index` and `mesh.mesh_shape`. 38 struct ProcessMultiIndexOpLowering 39 : OpRewritePatternWithSymbolTableCollection<ProcessMultiIndexOp> { 40 using OpRewritePatternWithSymbolTableCollection:: 41 OpRewritePatternWithSymbolTableCollection; 42 43 LogicalResult matchAndRewrite(ProcessMultiIndexOp op, 44 PatternRewriter &rewriter) const override { 45 MeshOp mesh = getMesh(op, symbolTableCollection); 46 if (!mesh) { 47 return failure(); 48 } 49 50 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 51 builder.setInsertionPointAfter(op.getOperation()); 52 Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh); 53 ValueRange meshShape = builder.create<MeshShapeOp>(mesh).getResults(); 54 SmallVector<Value> completeMultiIndex = 55 builder.create<affine::AffineDelinearizeIndexOp>(linearIndex, meshShape) 56 .getMultiIndex(); 57 SmallVector<Value> multiIndex; 58 ArrayRef<MeshAxis> opMeshAxes = op.getAxes(); 59 SmallVector<MeshAxis> opAxesIota; 60 if (opMeshAxes.empty()) { 61 opAxesIota.resize(mesh.getRank()); 62 std::iota(opAxesIota.begin(), opAxesIota.end(), 0); 63 opMeshAxes = opAxesIota; 64 } 65 llvm::transform(opMeshAxes, std::back_inserter(multiIndex), 66 [&completeMultiIndex](MeshAxis meshAxis) { 67 return completeMultiIndex[meshAxis]; 68 }); 69 rewriter.replaceAllUsesWith(op.getResults(), multiIndex); 70 return success(); 71 } 72 }; 73 74 struct AllSliceOpLowering 75 : OpRewritePatternWithSymbolTableCollection<AllSliceOp> { 76 using OpRewritePatternWithSymbolTableCollection:: 77 OpRewritePatternWithSymbolTableCollection; 78 79 LogicalResult matchAndRewrite(AllSliceOp op, 80 PatternRewriter &rewriter) const override { 81 // 1. Compute the process linear index inside the process group from its 82 // multi-index. 83 // 84 // 2. Extract a slice from the input tensor. 85 // All axes except the slicing axis are not interesting and take the full 86 // axis. 87 // The slice axis is split into equisized parts with count 88 // the number of processes in the collective process group induced by 89 // the mesh axes. 90 // The part for each process is determined by the corresponding 91 // linear-index in the process group. 92 // 93 // There are no collectives that require communication. 94 // Each process operates on its local tensor. 95 96 MeshOp mesh = getMesh(op, symbolTableCollection); 97 if (!mesh) { 98 return failure(); 99 } 100 101 ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 102 builder.setInsertionPointAfter(op.getOperation()); 103 104 Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0)); 105 106 Operation::result_range processInGroupMultiIndex = 107 builder.create<ProcessMultiIndexOp>(mesh.getSymName(), op.getMeshAxes()) 108 .getResults(); 109 110 Operation::result_range processGroupShape = 111 builder.create<MeshShapeOp>(mesh.getSymName(), op.getMeshAxes()) 112 .getResult(); 113 Value processGroupSize = 114 createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder); 115 116 int64_t sliceAxis = op.getSliceAxis().getSExtValue(); 117 Value operandSliceAxisSize = 118 builder.create<tensor::DimOp>(op.getOperand(), sliceAxis); 119 Value operandSliceAxisSizeModProcessGroupSize = 120 builder.create<arith::RemUIOp>(operandSliceAxisSize, processGroupSize); 121 Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>( 122 arith::CmpIPredicate::eq, operandSliceAxisSizeModProcessGroupSize, 123 zero); 124 builder.create<cf::AssertOp>(isTargetShapeExactlyDivisible, 125 "Slicing a tensor with axis size that is " 126 "not exactly divisible by the " 127 "mesh process group size is not supported."); 128 Value resultSliceAxisSize = 129 builder.create<arith::DivUIOp>(operandSliceAxisSize, processGroupSize); 130 OpFoldResult processInGroupLinearIndex = affine::linearizeIndex( 131 llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex), 132 llvm::to_vector_of<OpFoldResult>(processGroupShape), builder); 133 134 // insert tensor.extract_slice 135 RankedTensorType operandType = 136 cast<RankedTensorType>(op.getOperand().getType()); 137 SmallVector<OpFoldResult> sizes; 138 for (int64_t i = 0; i < operandType.getRank(); ++i) { 139 if (i == sliceAxis) { 140 sizes.emplace_back(resultSliceAxisSize); 141 } else { 142 Value dimSize = builder.create<tensor::DimOp>(op.getOperand(), i); 143 sizes.emplace_back(dimSize); 144 } 145 } 146 SmallVector<OpFoldResult> offsets( 147 operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 0)); 148 offsets[sliceAxis] = 149 ArithBuilder(builder, builder.getLoc()) 150 .mul(getValueOrCreateConstantIndexOp(builder, builder.getLoc(), 151 processInGroupLinearIndex), 152 resultSliceAxisSize); 153 SmallVector<OpFoldResult> strides( 154 operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 1)); 155 Value slice = builder.create<tensor::ExtractSliceOp>( 156 op.getOperand(), offsets, sizes, strides); 157 Value newResult = 158 builder.create<tensor::CastOp>(op.getResult().getType(), slice); 159 rewriter.replaceAllUsesWith(op.getResult(), newResult); 160 161 return success(); 162 } 163 }; 164 165 } // namespace 166 167 void populateProcessMultiIndexOpLoweringPatterns( 168 RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) { 169 patterns.add<ProcessMultiIndexOpLowering>(symbolTableCollection, 170 patterns.getContext()); 171 } 172 173 void registerProcessMultiIndexOpLoweringDialects(DialectRegistry ®istry) { 174 registry.insert<affine::AffineDialect, mesh::MeshDialect>(); 175 } 176 177 void populateAllSliceOpLoweringPatterns( 178 RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) { 179 patterns.add<AllSliceOpLowering>(symbolTableCollection, 180 patterns.getContext()); 181 } 182 183 void registerAllSliceOpLoweringDialects(DialectRegistry ®istry) { 184 registry.insert<affine::AffineDialect, arith::ArithDialect, 185 cf::ControlFlowDialect, mesh::MeshDialect, 186 tensor::TensorDialect>(); 187 } 188 189 void populateAllOpLoweringPatterns( 190 RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) { 191 populateProcessMultiIndexOpLoweringPatterns(patterns, symbolTableCollection); 192 populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection); 193 } 194 195 void registerAllOpLoweringDialects(DialectRegistry ®istry) { 196 registerProcessMultiIndexOpLoweringDialects(registry); 197 registerAllSliceOpLoweringDialects(registry); 198 } 199 200 TypedValue<IndexType> 201 createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes, 202 ImplicitLocOpBuilder &builder) { 203 Operation::result_range meshShape = 204 builder.create<mesh::MeshShapeOp>(mesh, axes).getResults(); 205 return cast<TypedValue<IndexType>>(arith::createProduct( 206 builder, builder.getLoc(), llvm::to_vector_of<Value>(meshShape), 207 builder.getIndexType())); 208 } 209 210 TypedValue<IndexType> createProcessLinearIndex(StringRef mesh, 211 ArrayRef<MeshAxis> meshAxes, 212 ImplicitLocOpBuilder &builder) { 213 ResultRange processInGroupMultiIndex = 214 builder.create<ProcessMultiIndexOp>(mesh, meshAxes).getResults(); 215 Operation::result_range processGroupShape = 216 builder.create<MeshShapeOp>(mesh, meshAxes).getResult(); 217 OpFoldResult processInGroupLinearIndex = affine::linearizeIndex( 218 llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex), 219 llvm::to_vector_of<OpFoldResult>(processGroupShape), builder); 220 return cast<TypedValue<IndexType>>(cast<Value>(processInGroupLinearIndex)); 221 } 222 223 } // namespace mlir::mesh 224