179aa7762SBoian Petkantchin //===- Transforms.cpp ---------------------------------------------- C++ --===// 279aa7762SBoian Petkantchin // 379aa7762SBoian Petkantchin // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 479aa7762SBoian Petkantchin // See https://llvm.org/LICENSE.txt for license information. 579aa7762SBoian Petkantchin // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 679aa7762SBoian Petkantchin // 779aa7762SBoian Petkantchin //===----------------------------------------------------------------------===// 879aa7762SBoian Petkantchin 979aa7762SBoian Petkantchin #include "mlir/Dialect/Mesh/Transforms/Transforms.h" 10dc3258c6SBoian Petkantchin #include "TransformsDetail.h" 1179aa7762SBoian Petkantchin #include "mlir/Dialect/Affine/IR/AffineOps.h" 12dc3258c6SBoian Petkantchin #include "mlir/Dialect/Affine/Utils.h" 13dc3258c6SBoian Petkantchin #include "mlir/Dialect/Arith/IR/Arith.h" 14dc3258c6SBoian Petkantchin #include "mlir/Dialect/Arith/Utils/Utils.h" 15dc3258c6SBoian Petkantchin #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" 16dc3258c6SBoian Petkantchin #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 1731fc0a12SBoian Petkantchin #include "mlir/Dialect/Mesh/IR/MeshDialect.h" 1879aa7762SBoian Petkantchin #include "mlir/Dialect/Mesh/IR/MeshOps.h" 19dc3258c6SBoian Petkantchin #include "mlir/Dialect/Tensor/IR/Tensor.h" 20dc3258c6SBoian Petkantchin #include "mlir/Dialect/Utils/StaticValueUtils.h" 2179aa7762SBoian Petkantchin #include "mlir/IR/BuiltinTypes.h" 2279aa7762SBoian Petkantchin #include "mlir/IR/DialectRegistry.h" 2379aa7762SBoian Petkantchin #include "mlir/IR/ImplicitLocOpBuilder.h" 24dc3258c6SBoian Petkantchin #include "mlir/IR/OpDefinition.h" 2579aa7762SBoian Petkantchin #include "mlir/IR/PatternMatch.h" 2679aa7762SBoian Petkantchin #include "mlir/IR/Value.h" 2779aa7762SBoian Petkantchin #include "llvm/ADT/STLExtras.h" 2879aa7762SBoian Petkantchin #include "llvm/ADT/SmallVector.h" 2979aa7762SBoian Petkantchin #include <iterator> 3079aa7762SBoian Petkantchin #include <numeric> 3179aa7762SBoian Petkantchin 3279aa7762SBoian Petkantchin namespace mlir::mesh { 3379aa7762SBoian Petkantchin 3479aa7762SBoian Petkantchin namespace { 3579aa7762SBoian Petkantchin 3679aa7762SBoian Petkantchin /// Lower `mesh.process_multi_index` into expression using 379a8437f5SBoian Petkantchin /// `mesh.process_linear_index` and `mesh.mesh_shape`. 38dc3258c6SBoian Petkantchin struct ProcessMultiIndexOpLowering 39dc3258c6SBoian Petkantchin : OpRewritePatternWithSymbolTableCollection<ProcessMultiIndexOp> { 40dc3258c6SBoian Petkantchin using OpRewritePatternWithSymbolTableCollection:: 41dc3258c6SBoian Petkantchin OpRewritePatternWithSymbolTableCollection; 4279aa7762SBoian Petkantchin 4379aa7762SBoian Petkantchin LogicalResult matchAndRewrite(ProcessMultiIndexOp op, 4479aa7762SBoian Petkantchin PatternRewriter &rewriter) const override { 45dc3258c6SBoian Petkantchin MeshOp mesh = getMesh(op, symbolTableCollection); 4679aa7762SBoian Petkantchin if (!mesh) { 4779aa7762SBoian Petkantchin return failure(); 4879aa7762SBoian Petkantchin } 4979aa7762SBoian Petkantchin 5079aa7762SBoian Petkantchin ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 5179aa7762SBoian Petkantchin builder.setInsertionPointAfter(op.getOperation()); 5279aa7762SBoian Petkantchin Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh); 539a8437f5SBoian Petkantchin ValueRange meshShape = builder.create<MeshShapeOp>(mesh).getResults(); 5479aa7762SBoian Petkantchin SmallVector<Value> completeMultiIndex = 5579aa7762SBoian Petkantchin builder.create<affine::AffineDelinearizeIndexOp>(linearIndex, meshShape) 5679aa7762SBoian Petkantchin .getMultiIndex(); 5779aa7762SBoian Petkantchin SmallVector<Value> multiIndex; 5879aa7762SBoian Petkantchin ArrayRef<MeshAxis> opMeshAxes = op.getAxes(); 5979aa7762SBoian Petkantchin SmallVector<MeshAxis> opAxesIota; 6079aa7762SBoian Petkantchin if (opMeshAxes.empty()) { 6179aa7762SBoian Petkantchin opAxesIota.resize(mesh.getRank()); 6279aa7762SBoian Petkantchin std::iota(opAxesIota.begin(), opAxesIota.end(), 0); 6379aa7762SBoian Petkantchin opMeshAxes = opAxesIota; 6479aa7762SBoian Petkantchin } 6579aa7762SBoian Petkantchin llvm::transform(opMeshAxes, std::back_inserter(multiIndex), 6679aa7762SBoian Petkantchin [&completeMultiIndex](MeshAxis meshAxis) { 6779aa7762SBoian Petkantchin return completeMultiIndex[meshAxis]; 6879aa7762SBoian Petkantchin }); 6979aa7762SBoian Petkantchin rewriter.replaceAllUsesWith(op.getResults(), multiIndex); 7079aa7762SBoian Petkantchin return success(); 7179aa7762SBoian Petkantchin } 72dc3258c6SBoian Petkantchin }; 7379aa7762SBoian Petkantchin 74dc3258c6SBoian Petkantchin struct AllSliceOpLowering 75dc3258c6SBoian Petkantchin : OpRewritePatternWithSymbolTableCollection<AllSliceOp> { 76dc3258c6SBoian Petkantchin using OpRewritePatternWithSymbolTableCollection:: 77dc3258c6SBoian Petkantchin OpRewritePatternWithSymbolTableCollection; 78dc3258c6SBoian Petkantchin 79dc3258c6SBoian Petkantchin LogicalResult matchAndRewrite(AllSliceOp op, 80dc3258c6SBoian Petkantchin PatternRewriter &rewriter) const override { 81dc3258c6SBoian Petkantchin // 1. Compute the process linear index inside the process group from its 82dc3258c6SBoian Petkantchin // multi-index. 83dc3258c6SBoian Petkantchin // 84dc3258c6SBoian Petkantchin // 2. Extract a slice from the input tensor. 85dc3258c6SBoian Petkantchin // All axes except the slicing axis are not interesting and take the full 86dc3258c6SBoian Petkantchin // axis. 87dc3258c6SBoian Petkantchin // The slice axis is split into equisized parts with count 88dc3258c6SBoian Petkantchin // the number of processes in the collective process group induced by 89dc3258c6SBoian Petkantchin // the mesh axes. 90dc3258c6SBoian Petkantchin // The part for each process is determined by the corresponding 91dc3258c6SBoian Petkantchin // linear-index in the process group. 92dc3258c6SBoian Petkantchin // 93dc3258c6SBoian Petkantchin // There are no collectives that require communication. 94dc3258c6SBoian Petkantchin // Each process operates on its local tensor. 95dc3258c6SBoian Petkantchin 96dc3258c6SBoian Petkantchin MeshOp mesh = getMesh(op, symbolTableCollection); 97dc3258c6SBoian Petkantchin if (!mesh) { 98dc3258c6SBoian Petkantchin return failure(); 99dc3258c6SBoian Petkantchin } 100dc3258c6SBoian Petkantchin 101dc3258c6SBoian Petkantchin ImplicitLocOpBuilder builder(op->getLoc(), rewriter); 102dc3258c6SBoian Petkantchin builder.setInsertionPointAfter(op.getOperation()); 103dc3258c6SBoian Petkantchin 104dc3258c6SBoian Petkantchin Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0)); 105dc3258c6SBoian Petkantchin 106dc3258c6SBoian Petkantchin Operation::result_range processInGroupMultiIndex = 107dc3258c6SBoian Petkantchin builder.create<ProcessMultiIndexOp>(mesh.getSymName(), op.getMeshAxes()) 108dc3258c6SBoian Petkantchin .getResults(); 109dc3258c6SBoian Petkantchin 110dc3258c6SBoian Petkantchin Operation::result_range processGroupShape = 111dc3258c6SBoian Petkantchin builder.create<MeshShapeOp>(mesh.getSymName(), op.getMeshAxes()) 112dc3258c6SBoian Petkantchin .getResult(); 113dc3258c6SBoian Petkantchin Value processGroupSize = 114dc3258c6SBoian Petkantchin createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder); 115dc3258c6SBoian Petkantchin 116dc3258c6SBoian Petkantchin int64_t sliceAxis = op.getSliceAxis().getSExtValue(); 117dc3258c6SBoian Petkantchin Value operandSliceAxisSize = 118dc3258c6SBoian Petkantchin builder.create<tensor::DimOp>(op.getOperand(), sliceAxis); 119dc3258c6SBoian Petkantchin Value operandSliceAxisSizeModProcessGroupSize = 120dc3258c6SBoian Petkantchin builder.create<arith::RemUIOp>(operandSliceAxisSize, processGroupSize); 121dc3258c6SBoian Petkantchin Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>( 122dc3258c6SBoian Petkantchin arith::CmpIPredicate::eq, operandSliceAxisSizeModProcessGroupSize, 123dc3258c6SBoian Petkantchin zero); 124dc3258c6SBoian Petkantchin builder.create<cf::AssertOp>(isTargetShapeExactlyDivisible, 125dc3258c6SBoian Petkantchin "Slicing a tensor with axis size that is " 126dc3258c6SBoian Petkantchin "not exactly divisible by the " 127dc3258c6SBoian Petkantchin "mesh process group size is not supported."); 128dc3258c6SBoian Petkantchin Value resultSliceAxisSize = 129dc3258c6SBoian Petkantchin builder.create<arith::DivUIOp>(operandSliceAxisSize, processGroupSize); 130dc3258c6SBoian Petkantchin OpFoldResult processInGroupLinearIndex = affine::linearizeIndex( 131dc3258c6SBoian Petkantchin llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex), 132dc3258c6SBoian Petkantchin llvm::to_vector_of<OpFoldResult>(processGroupShape), builder); 133dc3258c6SBoian Petkantchin 134dc3258c6SBoian Petkantchin // insert tensor.extract_slice 135dc3258c6SBoian Petkantchin RankedTensorType operandType = 136a5757c5bSChristian Sigg cast<RankedTensorType>(op.getOperand().getType()); 137dc3258c6SBoian Petkantchin SmallVector<OpFoldResult> sizes; 138dc3258c6SBoian Petkantchin for (int64_t i = 0; i < operandType.getRank(); ++i) { 139dc3258c6SBoian Petkantchin if (i == sliceAxis) { 140dc3258c6SBoian Petkantchin sizes.emplace_back(resultSliceAxisSize); 141dc3258c6SBoian Petkantchin } else { 142dc3258c6SBoian Petkantchin Value dimSize = builder.create<tensor::DimOp>(op.getOperand(), i); 143dc3258c6SBoian Petkantchin sizes.emplace_back(dimSize); 144dc3258c6SBoian Petkantchin } 145dc3258c6SBoian Petkantchin } 146dc3258c6SBoian Petkantchin SmallVector<OpFoldResult> offsets( 147dc3258c6SBoian Petkantchin operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 0)); 148dc3258c6SBoian Petkantchin offsets[sliceAxis] = 149dc3258c6SBoian Petkantchin ArithBuilder(builder, builder.getLoc()) 150dc3258c6SBoian Petkantchin .mul(getValueOrCreateConstantIndexOp(builder, builder.getLoc(), 151dc3258c6SBoian Petkantchin processInGroupLinearIndex), 152dc3258c6SBoian Petkantchin resultSliceAxisSize); 153dc3258c6SBoian Petkantchin SmallVector<OpFoldResult> strides( 154dc3258c6SBoian Petkantchin operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 1)); 155dc3258c6SBoian Petkantchin Value slice = builder.create<tensor::ExtractSliceOp>( 156dc3258c6SBoian Petkantchin op.getOperand(), offsets, sizes, strides); 157dc3258c6SBoian Petkantchin Value newResult = 158dc3258c6SBoian Petkantchin builder.create<tensor::CastOp>(op.getResult().getType(), slice); 159dc3258c6SBoian Petkantchin rewriter.replaceAllUsesWith(op.getResult(), newResult); 160dc3258c6SBoian Petkantchin 161dc3258c6SBoian Petkantchin return success(); 162dc3258c6SBoian Petkantchin } 16379aa7762SBoian Petkantchin }; 16479aa7762SBoian Petkantchin 16579aa7762SBoian Petkantchin } // namespace 16679aa7762SBoian Petkantchin 167dc3258c6SBoian Petkantchin void populateProcessMultiIndexOpLoweringPatterns( 16879aa7762SBoian Petkantchin RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) { 16979aa7762SBoian Petkantchin patterns.add<ProcessMultiIndexOpLowering>(symbolTableCollection, 17079aa7762SBoian Petkantchin patterns.getContext()); 17179aa7762SBoian Petkantchin } 17279aa7762SBoian Petkantchin 173dc3258c6SBoian Petkantchin void registerProcessMultiIndexOpLoweringDialects(DialectRegistry ®istry) { 17479aa7762SBoian Petkantchin registry.insert<affine::AffineDialect, mesh::MeshDialect>(); 17579aa7762SBoian Petkantchin } 17679aa7762SBoian Petkantchin 177dc3258c6SBoian Petkantchin void populateAllSliceOpLoweringPatterns( 178dc3258c6SBoian Petkantchin RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) { 179dc3258c6SBoian Petkantchin patterns.add<AllSliceOpLowering>(symbolTableCollection, 180dc3258c6SBoian Petkantchin patterns.getContext()); 181dc3258c6SBoian Petkantchin } 182dc3258c6SBoian Petkantchin 183dc3258c6SBoian Petkantchin void registerAllSliceOpLoweringDialects(DialectRegistry ®istry) { 184dc3258c6SBoian Petkantchin registry.insert<affine::AffineDialect, arith::ArithDialect, 185dc3258c6SBoian Petkantchin cf::ControlFlowDialect, mesh::MeshDialect, 186dc3258c6SBoian Petkantchin tensor::TensorDialect>(); 187dc3258c6SBoian Petkantchin } 188dc3258c6SBoian Petkantchin 189dc3258c6SBoian Petkantchin void populateAllOpLoweringPatterns( 190dc3258c6SBoian Petkantchin RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) { 191dc3258c6SBoian Petkantchin populateProcessMultiIndexOpLoweringPatterns(patterns, symbolTableCollection); 192dc3258c6SBoian Petkantchin populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection); 193dc3258c6SBoian Petkantchin } 194dc3258c6SBoian Petkantchin 195dc3258c6SBoian Petkantchin void registerAllOpLoweringDialects(DialectRegistry ®istry) { 196dc3258c6SBoian Petkantchin registerProcessMultiIndexOpLoweringDialects(registry); 197dc3258c6SBoian Petkantchin registerAllSliceOpLoweringDialects(registry); 198dc3258c6SBoian Petkantchin } 199dc3258c6SBoian Petkantchin 200dc3258c6SBoian Petkantchin TypedValue<IndexType> 201dc3258c6SBoian Petkantchin createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes, 202dc3258c6SBoian Petkantchin ImplicitLocOpBuilder &builder) { 203dc3258c6SBoian Petkantchin Operation::result_range meshShape = 204dc3258c6SBoian Petkantchin builder.create<mesh::MeshShapeOp>(mesh, axes).getResults(); 205a5757c5bSChristian Sigg return cast<TypedValue<IndexType>>(arith::createProduct( 206a5757c5bSChristian Sigg builder, builder.getLoc(), llvm::to_vector_of<Value>(meshShape), 207a5757c5bSChristian Sigg builder.getIndexType())); 208dc3258c6SBoian Petkantchin } 209dc3258c6SBoian Petkantchin 210fb582b6aSBoian Petkantchin TypedValue<IndexType> createProcessLinearIndex(StringRef mesh, 211fb582b6aSBoian Petkantchin ArrayRef<MeshAxis> meshAxes, 212fb582b6aSBoian Petkantchin ImplicitLocOpBuilder &builder) { 213fb582b6aSBoian Petkantchin ResultRange processInGroupMultiIndex = 214fb582b6aSBoian Petkantchin builder.create<ProcessMultiIndexOp>(mesh, meshAxes).getResults(); 215fb582b6aSBoian Petkantchin Operation::result_range processGroupShape = 216fb582b6aSBoian Petkantchin builder.create<MeshShapeOp>(mesh, meshAxes).getResult(); 217fb582b6aSBoian Petkantchin OpFoldResult processInGroupLinearIndex = affine::linearizeIndex( 218fb582b6aSBoian Petkantchin llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex), 219fb582b6aSBoian Petkantchin llvm::to_vector_of<OpFoldResult>(processGroupShape), builder); 220*129f1001SKazu Hirata return cast<TypedValue<IndexType>>(cast<Value>(processInGroupLinearIndex)); 221fb582b6aSBoian Petkantchin } 222fb582b6aSBoian Petkantchin 22379aa7762SBoian Petkantchin } // namespace mlir::mesh 224