Lines Matching defs:sourceTensorAxis
307 for (size_t sourceTensorAxis = 0;
308 sourceTensorAxis < sourceSharding.getSplitAxes().size();
309 ++sourceTensorAxis) {
313 if (sourceTensorAxis == targetTensorAxis)
315 if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
317 sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
323 llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
326 sourceSharding.getSplitAxes()[sourceTensorAxis]
339 sourceTensorAxis, targetTensorAxis,
340 sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
348 int64_t sourceTensorAxis,
358 llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
362 targetShardingSplitAxes[sourceTensorAxis] =
378 int64_t sourceTensorAxis,
381 targetShape[sourceTensorAxis] =
382 gatherDimension(targetShape[sourceTensorAxis], splitCount);
393 int64_t sourceTensorAxis,
399 ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
401 sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
407 APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
423 auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
426 sourceTensorAxis, targetTensorAxis, meshAxis);