Lines Matching defs:meshShape
193 static void shardShape(const InShape &inShape, const MeshShape &meshShape,
201 auto isDynShape = ShapedType::isDynamicShape(meshShape);
213 numShards += meshShape[i];
231 collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape));
551 auto meshShape = mesh.value().getShape();
552 assert(!ShapedType::isDynamicShape(meshShape));
558 numShards += meshShape[i];
890 ArrayRef<int64_t> meshShape) {
900 !ShapedType::isDynamic(meshShape[meshAxes[i]]) &&
901 meshShape[meshAxes[i]] <= device[i]) {
906 << (meshShape[meshAxes[i]] - 1) << "].";
943 ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
954 DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
970 ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
988 DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));
1015 ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) {
1029 DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape));