Lines Matching defs:targetShape
251 SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
252 targetShape[splitTensorAxis] =
253 gatherDimension(targetShape[splitTensorAxis], splitCount);
254 return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
275 ShapedType targetShape =
278 builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
380 SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
381 targetShape[sourceTensorAxis] =
382 gatherDimension(targetShape[sourceTensorAxis], splitCount);
383 targetShape[targetTensorAxis] =
384 shardDimension(targetShape[targetTensorAxis], splitCount);
385 return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
408 ShapedType targetShape =
411 builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());