Lines Matching defs:mesh
41 namespace mlir::mesh {
128 // Split a replicated tensor along a mesh axis.
134 TypedValue<ShapedType> sourceShard, MeshOp mesh,
138 .create<AllSliceOp>(sourceShard, mesh,
149 // If detected, returns the corresponding tensor axis mesh axis pair.
185 trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
192 return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
201 // If detected, returns the corresponding tensor axis mesh axis pair.
261 TypedValue<ShapedType> sourceShard, MeshOp mesh,
269 sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
273 mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
276 shardShapedType(sourceUnshardedShape, mesh, targetSharding);
283 tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
292 sourceUnshardedShape, sourceShard, mesh,
389 moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
401 sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
406 mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
409 shardShapedType(sourceUnshardedShape, mesh, targetSharding);
416 tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
425 builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
437 tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
459 "dynamic shapes/halos are not supported yet for mesh-spmdization");
499 initOprnd, mesh.getSymName(),
509 // Handles only resharding on a 1D mesh.
511 // mesh axis size.
513 reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh,
518 shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
520 shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
522 assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");
539 builder, mesh, reducedSourceSharding, targetSharding,
543 builder, mesh, reducedSourceSharding, targetSharding,
547 builder, mesh, reducedSourceSharding, targetSharding,
558 TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
569 // sizes are different. Supports arbitrary mesh dimensionality.
571 builder, mesh, sourceSharding, targetSharding,
579 return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
583 TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
590 return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
605 registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
632 MeshOp mesh = getMesh(shardOp, symbolTableCollection);
633 return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh,
848 registry.insert<mesh::MeshDialect>();
854 } // namespace mlir::mesh