Lines Matching defs:spmdizationMap
643 IRMapping &spmdizationMap,
650 ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap,
657 resultShardings, spmdizationMap,
661 resultShardings, spmdizationMap,
667 assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
668 return spmdizationMap.contains(result);
716 spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
726 targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
730 cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp));
735 assert(!spmdizationMap.contains(shardOp.getResult()));
736 spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
741 spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
750 return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
756 [&spmdizationMap](Value operand) {
757 assert(spmdizationMap.contains(operand));
758 return spmdizationMap.lookup(operand);
761 getResultShardings(op), spmdizationMap,
765 static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
776 spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
782 if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection,
792 spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
803 if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
838 IRMapping spmdizationMap;
840 if (failed(spmdizeFuncOp(getOperation(), spmdizationMap,