Lines Matching defs:odsBuilder

400 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
402 build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
405 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
407 build(odsBuilder, odsState,
409 odsBuilder.getIndexType()),
410 mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
413 void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
416 build(odsBuilder, odsState,
417 SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
418 MeshAxesAttr::get(odsBuilder.getContext(), axes));
757 void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder,
761 SmallVector<mlir::Type> resType(shape.size(), odsBuilder.getIndexType());
762 build(odsBuilder, odsState, resType, shape, sharding, device);
798 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
800 build(odsBuilder, odsState,
801 SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
805 void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
807 build(odsBuilder, odsState,
808 SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
809 MeshAxesAttr::get(odsBuilder.getContext(), axes));
830 void ProcessLinearIndexOp::build(OpBuilder &odsBuilder,
832 build(odsBuilder, odsState, mesh.getSymName());
1107 void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1110 build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
1138 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1142 build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
1146 void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
1149 build(odsBuilder, odsState, resultType, mesh, meshAxes, input,