/llvm-project/mlir/lib/Dialect/Mesh/IR/ |
H A D | MeshOps.cpp | 130 mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable); in verifyMeshAxes() argument 103 mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable); getMeshAndVerify() local 163 shardShapedType(ShapedType shape,MeshOp mesh,MeshShardingAttr sharding) shardShapedType() argument 172 shardType(Type type,MeshOp mesh,MeshShardingAttr sharding) shardType() argument 289 auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); verifySymbolUses() local 308 build(OpBuilder & odsBuilder,OperationState & odsState,MeshOp mesh) build() argument 313 build(OpBuilder & odsBuilder,OperationState & odsState,MeshOp mesh,ArrayRef<MeshAxis> axes) build() argument 321 build(OpBuilder & odsBuilder,OperationState & odsState,StringRef mesh,ArrayRef<MeshAxis> axes) build() argument 420 auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); verifySymbolUses() local 439 build(OpBuilder & odsBuilder,OperationState & odsState,MeshOp mesh) build() argument 446 build(OpBuilder & odsBuilder,OperationState & odsState,StringRef mesh,ArrayRef<MeshAxis> axes) build() argument 463 auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); verifySymbolUses() local 471 build(OpBuilder & odsBuilder,OperationState & odsState,MeshOp mesh) build() argument 536 auto mesh = getMeshAndVerifyAxes() local 686 sliceResultType(Type operandType,MeshOp mesh,ArrayRef<MeshAxis> meshAxes,int64_t sliceAxis) sliceResultType() argument 708 auto mesh = getMeshAndVerifyAxes(*this, symbolTable); verifySymbolUses() local 743 build(OpBuilder & odsBuilder,OperationState & odsState,Value input,StringRef mesh,ArrayRef<MeshAxis> meshAxes,ReductionKind reduction) build() argument 759 auto mesh = getMeshAndVerifyAxes(*this, symbolTable); verifySymbolUses() local 774 build(OpBuilder & odsBuilder,OperationState & odsState,Value input,MeshOp mesh,ArrayRef<MeshAxis> meshAxes,int64_t sliceAxis) build() argument 782 build(OpBuilder & odsBuilder,OperationState & odsState,Type resultType,Value input,StringRef mesh,ArrayRef<MeshAxis> meshAxes,int64_t sliceAxis) build() argument 798 auto mesh = getMeshAndVerifyAxes(*this, symbolTable); verifySymbolUses() local 824 auto mesh = getMeshAndVerifyAxes(*this, symbolTable); verifySymbolUses() local 852 auto mesh = getMeshAndVerifyAxes(*this, symbolTable); verifySymbolUses() local 883 auto mesh = getMeshAndVerifyAxes(*this, symbolTable); verifySymbolUses() local 910 auto mesh = getMeshAndVerifyAxes(*this, symbolTable); verifySymbolUses() local 939 auto mesh = getMeshAndVerifyAxes(*this, symbolTable); verifySymbolUses() local 964 auto mesh = getMeshAndVerifyAxes(*this, symbolTable); verifySymbolUses() local 995 auto mesh = getMeshAndVerifyAxes(*this, symbolTable); verifySymbolUses() local 1021 auto mesh = getMeshAndVerifyAxes(*this, symbolTable); verifySymbolUses() local [all...] |
/llvm-project/mlir/lib/Dialect/Mesh/Transforms/ |
H A D | Transforms.cpp | 32 namespace mlir::mesh { namespace 45 MeshOp mesh = getMesh(op, symbolTableCollection); in matchAndRewrite() local 96 MeshOp mesh = getMesh(op, symbolTableCollection); matchAndRewrite() local 201 createCollectiveProcessGroupSize(MeshOp mesh,ArrayRef<MeshAxis> axes,ImplicitLocOpBuilder & builder) createCollectiveProcessGroupSize() argument 210 createProcessLinearIndex(StringRef mesh,ArrayRef<MeshAxis> meshAxes,ImplicitLocOpBuilder & builder) createProcessLinearIndex() argument [all...] |
H A D | Simplifications.cpp | 23 namespace mesh { namespace 65 MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>( in matchAndRewrite() local
|
H A D | Spmdization.cpp | 41 namespace mlir::mesh { namespace 134 TypedValue<ShapedType> sourceShard, MeshOp mesh, in splitLastAxisInResharding() argument 185 trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, in trySplitLastAxisInResharding() argument 262 unsplitLastAxisInResharding(ImplicitLocOpBuilder & builder,MeshShardingAttr sourceSharding,ShapedType sourceUnshardedShape,TypedValue<ShapedType> sourceShard,MeshOp mesh,int64_t splitTensorAxis,MeshAxis splitMeshAxis) unsplitLastAxisInResharding() argument 284 tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder & builder,MeshOp mesh,MeshShardingAttr sourceSharding,MeshShardingAttr targetSharding,ShapedType sourceUnshardedShape,TypedValue<ShapedType> sourceShard) tryUnsplitLastAxisInResharding() argument 390 moveLastSplitAxisInResharding(ImplicitLocOpBuilder & builder,MeshOp mesh,MeshShardingAttr sourceSharding,ShapedType sourceUnshardedShape,TypedValue<ShapedType> sourceShard,int64_t sourceTensorAxis,int64_t targetTensorAxis,MeshAxis meshAxis) moveLastSplitAxisInResharding() argument 417 tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder & builder,MeshOp mesh,MeshShardingAttr sourceSharding,MeshShardingAttr targetSharding,ShapedType sourceUnshardedShape,TypedValue<ShapedType> sourceShard) tryMoveLastSplitAxisInResharding() argument 437 reshardOn1DMesh(ImplicitLocOpBuilder & builder,MeshOp mesh,MeshShardingAttr sourceSharding,MeshShardingAttr targetSharding,TypedValue<ShapedType> sourceUnshardedValue,TypedValue<ShapedType> sourceShard) reshardOn1DMesh() argument 480 reshard(ImplicitLocOpBuilder & builder,MeshOp mesh,MeshShardingAttr sourceSharding,MeshShardingAttr targetSharding,TypedValue<ShapedType> sourceUnshardedValue,TypedValue<ShapedType> sourceShard) reshard() argument 492 reshard(OpBuilder & builder,MeshOp mesh,ShardOp source,ShardOp target,TypedValue<ShapedType> sourceShardValue) reshard() argument 539 MeshOp mesh = getMesh(shardOp, symbolTableCollection); shardedBlockArgumentTypes() local [all...] |
H A D | ShardingPropagation.cpp | 27 namespace mesh { namespace [all...] |
/llvm-project/mlir/include/mlir/Dialect/Mesh/Interfaces/ |
H A D | ShardingInterface.h | 32 FlatSymbolRefAttr mesh = nullptr; member
|
/llvm-project/mlir/test/lib/Dialect/Mesh/ |
H A D | TestReshardingSpmdization.cpp | 39 mesh::MeshOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::MeshOp>( in matchAndRewrite() local [all...] |
/llvm-project/mlir/include/mlir/Dialect/Mesh/IR/ |
H A D | MeshOps.h | 106 collectiveProcessGroupSize(MeshAxesRange && meshAxes,MeshOp mesh) collectiveProcessGroupSize() argument [all...] |
/llvm-project/mlir/lib/Dialect/Linalg/Transforms/ |
H A D | MeshShardingInterfaceImpl.cpp | 228 MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable); spmdizeLinalgOpWithShardedReduction() local [all...] |
/llvm-project/mlir/lib/Dialect/Mesh/Interfaces/ |
H A D | ShardingInterface.cpp | 221 fillShardingOption(Operation * op,ShardingOption & shardingOption,FlatSymbolRefAttr mesh,ArrayRef<MeshAxis> meshAxes,unsigned loopIdx) fillShardingOption() argument [all...] |