Lines Matching defs:sourceSharding

53 // sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
59 MeshSharding sourceSharding,
62 if (sourceSharding.getPartialAxes().empty() &&
64 return {sourceShard, sourceSharding};
67 (!sourceSharding.getPartialAxes().empty() &&
68 sourceSharding.getPartialType() == targetSharding.getPartialType()));
69 using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
71 AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
72 sourceSharding.getPartialAxes().end());
84 return {sourceShard, sourceSharding};
91 sourceSharding.getMeshAttr().getLeafReference(),
93 sourceSharding.getPartialType())
103 sourceSharding.getMeshAttr(), sourceSharding.getSplitAxes(),
104 remainingPartialAxes, sourceSharding.getPartialType());
109 MeshSharding sourceSharding,
113 llvm::to_vector(sourceSharding.getSplitAxes());
124 sourceSharding.getMeshAttr(), targetShardingSplitAxes,
125 sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
133 MeshSharding sourceSharding,
143 builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
153 detectSplitLastAxisInResharding(MeshSharding sourceSharding,
157 if (sourceSharding.getSplitAxes().size() > tensorAxis) {
158 if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
163 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
186 MeshSharding sourceSharding,
190 detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
192 return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
203 detectUnsplitLastAxisInResharding(MeshSharding sourceSharding,
205 for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
208 if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
213 sourceSharding.getSplitAxes()[tensorAxis]
216 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
221 if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
226 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
232 MeshSharding sourceSharding,
235 llvm::to_vector(sourceSharding.getSplitAxes());
245 sourceSharding.getMeshAttr(), targetShardingSplitAxes,
246 sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
259 MeshSharding sourceSharding,
267 targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
284 MeshSharding sourceSharding,
289 detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
291 return unsplitLastAxisInResharding(builder, sourceSharding,
305 detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding,
308 sourceTensorAxis < sourceSharding.getSplitAxes().size();
315 if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
317 sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
323 llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
326 sourceSharding.getSplitAxes()[sourceTensorAxis]
340 sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
347 MeshSharding sourceSharding,
351 llvm::to_vector(sourceSharding.getSplitAxes());
372 sourceSharding.getMeshAttr(), targetShardingSplitAxes,
373 sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
390 MeshSharding sourceSharding,
399 ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
417 MeshSharding sourceSharding,
422 detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
425 builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
438 MeshSharding sourceSharding,
444 if (!sourceSharding.equalSplitAndPartialAxes(targetSharding) ||
445 !sourceSharding.getPartialAxes().empty() ||
447 !sourceSharding.getStaticShardedDimsOffsets().empty() ||
449 sourceSharding.equalHaloSizes(targetSharding)) {
453 auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
461 auto splitAxes = sourceSharding.getSplitAxes();
501 sourceSharding.getSplitAxes()),
514 MeshSharding sourceSharding, MeshSharding targetSharding,
518 shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
525 handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
559 MeshSharding sourceSharding,
564 if (sourceSharding == targetSharding) {
571 builder, mesh, sourceSharding, targetSharding,
579 return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
587 auto sourceSharding = source.getSharding();
590 return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,