Lines Matching defs:mesh
44 using MeshAxis = mesh::MeshAxis;
45 using ReductionKind = mesh::ReductionKind;
46 using MeshSharding = mesh::MeshSharding;
47 using ShardingArray = mesh::ShardingArray;
48 using MeshOp = mesh::MeshOp;
50 // Returns the corresponding mesh reduction kind for the given arith op.
110 return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
116 return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
125 // mesh axes.
135 Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
208 Value reducedValue = builder.create<mesh::AllReduceOp>(
235 MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable);
236 SmallVector<MeshAxis> reductionMeshAxes = mesh::getReductionMeshAxes(
239 createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands,
269 : public mesh::ShardingInterface::ExternalModel<
298 mesh::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp);
326 if (mesh::isAtLeastOneReductionIteratorSharded(