Lines Matching defs:mesh
41 #define DEBUG_TYPE "mesh-ops"
45 using namespace mlir::mesh;
130 mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable);
131 if (!mesh) {
132 return op->emitError() << "Undefined required mesh symbol \""
136 return mesh;
157 MeshOp mesh) {
164 MeshAxis rank = mesh.getRank();
168 << "0-based mesh axis index " << axis
169 << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
180 auto mesh =
182 if (failed(mesh)) {
185 if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) {
188 return mesh;
254 ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh,
258 shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(),
264 Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
267 return shardShapedType(rankedTensorType, mesh, sharding);
272 void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
306 void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
314 void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding,
358 // mesh.mesh op
365 return emitOpError("rank of mesh is expected to be a positive integer");
369 return emitOpError("dimension size of a mesh is expected to be "
377 // mesh.mesh_shape op
382 auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
383 if (failed(mesh)) {
386 if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
391 getAxes().empty() ? mesh->getRank() : getAxes().size();
401 MeshOp mesh) {
402 build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>());
406 MeshOp mesh, ArrayRef<MeshAxis> axes) {
408 SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(),
410 mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes));
414 StringRef mesh, ArrayRef<MeshAxis> axes) {
417 SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
427 // mesh.sharding
431 FlatSymbolRefAttr mesh,
434 mesh::ReductionKind partial_type,
438 b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes),
440 ::mlir::mesh::ReductionKindAttr::get(b.getContext(), partial_type),
448 FlatSymbolRefAttr mesh,
451 b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
452 ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
458 FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes,
466 b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), {},
467 ::mlir::mesh::ReductionKindAttr::get(b.getContext(), ReductionKind::Sum),
473 mlir::mesh::MeshSharding from) {
480 ::mlir::mesh::ReductionKindAttr::get(b.getContext(),
498 return emitError() << "mesh axis is expected to be non-negative";
500 return emitError() << "mesh axis duplicated";
539 auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
540 if (failed(mesh)) {
543 if (mlir::ShapedType::isDynamicShape(mesh->getShape()) &&
551 auto meshShape = mesh.value().getShape();
731 res.mesh = mesh_;
754 // mesh.shard_shape
766 // mesh.shard op
775 // mesh.process_multi_index op
780 auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
781 if (failed(mesh)) {
784 if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
789 getAxes().empty() ? mesh->getRank() : getAxes().size();
799 MeshOp mesh) {
801 SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
802 mesh.getSymName(), ArrayRef<MeshAxis>());
806 StringRef mesh, ArrayRef<MeshAxis> axes) {
808 SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
818 // mesh.process_linear_index op
823 auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
824 if (failed(mesh)) {
831 OperationState &odsState, MeshOp mesh) {
832 build(odsBuilder, odsState, mesh.getSymName());
841 // mesh.neighbors_linear_indices op
846 auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
847 if (failed(mesh)) {
1051 static RankedTensorType sliceResultType(Type operandType, MeshOp mesh,
1063 DimensionSize(collectiveProcessGroupSize(meshAxes, mesh));
1068 // mesh.all_gather op
1073 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1074 if (failed(mesh)) {
1080 mesh.value().getShape());
1094 // mesh.all_reduce op
1108 Value input, StringRef mesh,
1110 build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
1120 // mesh.all_slice op
1124 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1125 if (failed(mesh)) {
1130 mesh.value().getShape());
1139 Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes,
1141 Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis);
1142 build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes,
1147 Type resultType, Value input, StringRef mesh,
1149 build(odsBuilder, odsState, resultType, mesh, meshAxes, input,
1159 // mesh.all_to_all op
1163 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1164 if (failed(mesh)) {
1170 getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape());
1184 // mesh.broadcast op
1189 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1190 if (failed(mesh)) {
1195 mesh.value().getShape()))) {
1213 // mesh.gather op
1217 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1218 if (failed(mesh)) {
1223 mesh.value().getShape()))) {
1230 mesh.value().getShape());
1244 // mesh.recv op
1248 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1249 if (failed(mesh)) {
1255 getMeshAxes(), mesh.value().getShape()))) {
1271 // mesh.reduce op
1275 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1276 if (failed(mesh)) {
1281 mesh.value().getShape()))) {
1299 // mesh.reduce_scatter op
1304 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1305 if (failed(mesh)) {
1311 mesh.value().getShape());
1325 // mesh.scatter op
1329 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1330 if (failed(mesh)) {
1335 mesh.value().getShape()))) {
1342 mesh.value().getShape());
1356 // mesh.send op
1360 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1361 if (failed(mesh)) {
1366 getMeshAxes(), mesh.value().getShape()))) {
1382 // mesh.shift op
1386 auto mesh = getMeshAndVerifyAxes(*this, symbolTable);
1387 if (failed(mesh)) {
1395 << ". It must be one of the grouping mesh axes.";
1413 // mesh.update_halo op
1418 auto mesh = getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable);
1419 if (failed(mesh)) {