Lines Matching defs:sourceShard
61 TypedValue<ShapedType> sourceShard) {
64 return {sourceShard, sourceSharding};
84 return {sourceShard, sourceSharding};
87 builder.setInsertionPointAfterValue(sourceShard);
90 .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
92 allReduceMeshAxes, sourceShard,
134 TypedValue<ShapedType> sourceShard, MeshOp mesh,
138 .create<AllSliceOp>(sourceShard, mesh,
188 TypedValue<ShapedType> sourceShard) {
192 return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
261 TypedValue<ShapedType> sourceShard, MeshOp mesh,
264 builder.setInsertionPointAfterValue(sourceShard);
269 sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
273 mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
287 TypedValue<ShapedType> sourceShard) {
292 sourceUnshardedShape, sourceShard, mesh,
392 TypedValue<ShapedType> sourceShard,
396 builder.setInsertionPointAfterValue(sourceShard);
401 sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
406 mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
420 TypedValue<ShapedType> sourceShard) {
425 builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
441 TypedValue<ShapedType> sourceShard) {
458 sourceShard.getType().hasStaticShape()) &&
460 auto rank = sourceShard.getType().getRank();
463 strides(rank, 1), outShape(sourceShard.getType().getShape()),
464 coreShape(sourceShard.getType().getShape());
483 sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
485 sourceShard.getLoc(),
486 RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
487 sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
489 sourceShard.getLoc(), core, initVal, noVals, noVals, noVals, tgtCoreOffs,
496 sourceShard.getLoc(),
498 sourceShard.getType().getElementType()),
516 TypedValue<ShapedType> sourceShard) {
517 assert(sourceShard.getType() ==
521 assert(sourceShard.getType().getRank() == targetShardType.getRank());
526 sourceShard);
562 TypedValue<ShapedType> sourceShard) {
565 return sourceShard;
572 sourceUnshardedValue.getType(), sourceShard)) {
580 sourceUnshardedValue, sourceShard);