Lines Matching defs:splitTensorAxis
110 int64_t splitTensorAxis,
115 splitTensorAxis) {
119 llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
121 targetShardingSplitAxes[splitTensorAxis] =
135 int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
140 splitTensorAxis)
143 builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
233 int64_t splitTensorAxis) {
237 splitTensorAxis);
239 llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
242 targetShardingSplitAxes[splitTensorAxis] =
250 ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
252 targetShape[splitTensorAxis] =
253 gatherDimension(targetShape[splitTensorAxis], splitCount);
262 int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
267 targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
269 sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
274 APInt(64, splitTensorAxis));