Lines Matching defs:targetTensorAxis
310 for (size_t targetTensorAxis = 0;
311 targetTensorAxis < targetSharding.getSplitAxes().size();
312 ++targetTensorAxis) {
313 if (sourceTensorAxis == targetTensorAxis)
316 targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
318 targetSharding.getSplitAxes()[targetTensorAxis]
330 llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
333 targetSharding.getSplitAxes()[targetTensorAxis]
339 sourceTensorAxis, targetTensorAxis,
349 int64_t targetTensorAxis) {
353 targetTensorAxis) {
366 llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
368 targetShardingSplitAxes[targetTensorAxis] =
379 int64_t targetTensorAxis) {
383 targetShape[targetTensorAxis] =
384 shardDimension(targetShape[targetTensorAxis], splitCount);
394 int64_t targetTensorAxis, MeshAxis meshAxis) {
399 ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
402 targetTensorAxis);
407 APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
423 auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
426 sourceTensorAxis, targetTensorAxis, meshAxis);