//===- Transforms.cpp ---------------------------------------------- C++ --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Mesh/Transforms/Transforms.h" #include "TransformsDetail.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Mesh/IR/MeshDialect.h" #include "mlir/Dialect/Mesh/IR/MeshOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include #include namespace mlir::mesh { namespace { /// Lower `mesh.process_multi_index` into expression using /// `mesh.process_linear_index` and `mesh.mesh_shape`. struct ProcessMultiIndexOpLowering : OpRewritePatternWithSymbolTableCollection { using OpRewritePatternWithSymbolTableCollection:: OpRewritePatternWithSymbolTableCollection; LogicalResult matchAndRewrite(ProcessMultiIndexOp op, PatternRewriter &rewriter) const override { MeshOp mesh = getMesh(op, symbolTableCollection); if (!mesh) { return failure(); } ImplicitLocOpBuilder builder(op->getLoc(), rewriter); builder.setInsertionPointAfter(op.getOperation()); Value linearIndex = builder.create(mesh); ValueRange meshShape = builder.create(mesh).getResults(); SmallVector completeMultiIndex = builder.create(linearIndex, meshShape) .getMultiIndex(); SmallVector multiIndex; ArrayRef opMeshAxes = op.getAxes(); SmallVector opAxesIota; if (opMeshAxes.empty()) { opAxesIota.resize(mesh.getRank()); std::iota(opAxesIota.begin(), opAxesIota.end(), 0); opMeshAxes = opAxesIota; } llvm::transform(opMeshAxes, std::back_inserter(multiIndex), [&completeMultiIndex](MeshAxis meshAxis) { return completeMultiIndex[meshAxis]; }); rewriter.replaceAllUsesWith(op.getResults(), multiIndex); return success(); } }; struct AllSliceOpLowering : OpRewritePatternWithSymbolTableCollection { using OpRewritePatternWithSymbolTableCollection:: OpRewritePatternWithSymbolTableCollection; LogicalResult matchAndRewrite(AllSliceOp op, PatternRewriter &rewriter) const override { // 1. Compute the process linear index inside the process group from its // multi-index. // // 2. Extract a slice from the input tensor. // All axes except the slicing axis are not interesting and take the full // axis. // The slice axis is split into equisized parts with count // the number of processes in the collective process group induced by // the mesh axes. // The part for each process is determined by the corresponding // linear-index in the process group. // // There are no collectives that require communication. // Each process operates on its local tensor. MeshOp mesh = getMesh(op, symbolTableCollection); if (!mesh) { return failure(); } ImplicitLocOpBuilder builder(op->getLoc(), rewriter); builder.setInsertionPointAfter(op.getOperation()); Value zero = builder.create(builder.getIndexAttr(0)); Operation::result_range processInGroupMultiIndex = builder.create(mesh.getSymName(), op.getMeshAxes()) .getResults(); Operation::result_range processGroupShape = builder.create(mesh.getSymName(), op.getMeshAxes()) .getResult(); Value processGroupSize = createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder); int64_t sliceAxis = op.getSliceAxis().getSExtValue(); Value operandSliceAxisSize = builder.create(op.getOperand(), sliceAxis); Value operandSliceAxisSizeModProcessGroupSize = builder.create(operandSliceAxisSize, processGroupSize); Value isTargetShapeExactlyDivisible = builder.create( arith::CmpIPredicate::eq, operandSliceAxisSizeModProcessGroupSize, zero); builder.create(isTargetShapeExactlyDivisible, "Slicing a tensor with axis size that is " "not exactly divisible by the " "mesh process group size is not supported."); Value resultSliceAxisSize = builder.create(operandSliceAxisSize, processGroupSize); OpFoldResult processInGroupLinearIndex = affine::linearizeIndex( llvm::to_vector_of(processInGroupMultiIndex), llvm::to_vector_of(processGroupShape), builder); // insert tensor.extract_slice RankedTensorType operandType = cast(op.getOperand().getType()); SmallVector sizes; for (int64_t i = 0; i < operandType.getRank(); ++i) { if (i == sliceAxis) { sizes.emplace_back(resultSliceAxisSize); } else { Value dimSize = builder.create(op.getOperand(), i); sizes.emplace_back(dimSize); } } SmallVector offsets( operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 0)); offsets[sliceAxis] = ArithBuilder(builder, builder.getLoc()) .mul(getValueOrCreateConstantIndexOp(builder, builder.getLoc(), processInGroupLinearIndex), resultSliceAxisSize); SmallVector strides( operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 1)); Value slice = builder.create( op.getOperand(), offsets, sizes, strides); Value newResult = builder.create(op.getResult().getType(), slice); rewriter.replaceAllUsesWith(op.getResult(), newResult); return success(); } }; } // namespace void populateProcessMultiIndexOpLoweringPatterns( RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) { patterns.add(symbolTableCollection, patterns.getContext()); } void registerProcessMultiIndexOpLoweringDialects(DialectRegistry ®istry) { registry.insert(); } void populateAllSliceOpLoweringPatterns( RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) { patterns.add(symbolTableCollection, patterns.getContext()); } void registerAllSliceOpLoweringDialects(DialectRegistry ®istry) { registry.insert(); } void populateAllOpLoweringPatterns( RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) { populateProcessMultiIndexOpLoweringPatterns(patterns, symbolTableCollection); populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection); } void registerAllOpLoweringDialects(DialectRegistry ®istry) { registerProcessMultiIndexOpLoweringDialects(registry); registerAllSliceOpLoweringDialects(registry); } TypedValue createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef axes, ImplicitLocOpBuilder &builder) { Operation::result_range meshShape = builder.create(mesh, axes).getResults(); return cast>(arith::createProduct( builder, builder.getLoc(), llvm::to_vector_of(meshShape), builder.getIndexType())); } TypedValue createProcessLinearIndex(StringRef mesh, ArrayRef meshAxes, ImplicitLocOpBuilder &builder) { ResultRange processInGroupMultiIndex = builder.create(mesh, meshAxes).getResults(); Operation::result_range processGroupShape = builder.create(mesh, meshAxes).getResult(); OpFoldResult processInGroupLinearIndex = affine::linearizeIndex( llvm::to_vector_of(processInGroupMultiIndex), llvm::to_vector_of(processGroupShape), builder); return cast>(cast(processInGroupLinearIndex)); } } // namespace mlir::mesh