Lines Matching defs:targetSharding

54 // targetSharding = <@mesh_1d, [[]]>
60 MeshSharding targetSharding,
63 targetSharding.getPartialAxes().empty()) {
66 assert(targetSharding.getPartialAxes().empty() ||
68 sourceSharding.getPartialType() == targetSharding.getPartialType()));
73 AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
74 targetSharding.getPartialAxes().end());
142 MeshSharding targetSharding = targetShardingInSplitLastAxis(
144 return {targetShard, targetSharding};
154 MeshSharding targetSharding) {
155 for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
159 targetSharding.getSplitAxes()[tensorAxis].size()) {
165 targetSharding.getSplitAxes()[tensorAxis]
168 targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
173 if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
179 targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
187 MeshSharding targetSharding,
190 detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
204 MeshSharding targetSharding) {
207 if (targetSharding.getSplitAxes().size() > tensorAxis) {
209 targetSharding.getSplitAxes()[tensorAxis].size() + 1)
218 targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
266 MeshSharding targetSharding =
276 shardShapedType(sourceUnshardedShape, mesh, targetSharding);
279 return {targetShard, targetSharding};
285 MeshSharding targetSharding,
289 detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
306 MeshSharding targetSharding) {
311 targetTensorAxis < targetSharding.getSplitAxes().size();
316 targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
318 targetSharding.getSplitAxes()[targetTensorAxis]
330 llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
333 targetSharding.getSplitAxes()[targetTensorAxis]
398 MeshSharding targetSharding = targetShardingInMoveLastAxis(
409 shardShapedType(sourceUnshardedShape, mesh, targetSharding);
412 return {targetShard, targetSharding};
418 MeshSharding targetSharding,
422 detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
439 MeshSharding targetSharding,
444 if (!sourceSharding.equalSplitAndPartialAxes(targetSharding) ||
446 !targetSharding.getPartialAxes().empty() ||
448 !targetSharding.getStaticShardedDimsOffsets().empty() ||
449 sourceSharding.equalHaloSizes(targetSharding)) {
454 auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
502 targetSharding.getDynamicHaloSizes(),
503 targetSharding.getStaticHaloSizes())
506 targetSharding);
514 MeshSharding sourceSharding, MeshSharding targetSharding,
520 shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
525 handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
528 if (reducedSourceSharding == targetSharding) {
535 targetSharding.getStaticShardedDimsOffsets().empty() &&
537 targetSharding.getStaticHaloSizes().empty()) {
539 builder, mesh, reducedSourceSharding, targetSharding,
543 builder, mesh, reducedSourceSharding, targetSharding,
547 builder, mesh, reducedSourceSharding, targetSharding,
553 assert(actualTargetSharding == targetSharding);
560 MeshSharding targetSharding,
564 if (sourceSharding == targetSharding) {
571 builder, mesh, sourceSharding, targetSharding,
579 return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
588 auto targetSharding = target.getSharding();
590 return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,