Lines Matching defs:genericOp
85 LogicalResult matchAndRewrite(GenericOp genericOp,
87 if (!genericOp.hasPureTensorSemantics())
89 if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
92 auto outputOperands = genericOp.getDpsInitsMutable();
95 if (genericOp.getMatchingBlockArgument(&op).use_empty())
104 int64_t origNumInput = genericOp.getNumDpsInputs();
105 SmallVector<Value> newInputOperands = genericOp.getDpsInputs();
106 SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
112 newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
117 Location loc = genericOp.getLoc();
119 llvm::to_vector(genericOp.getDpsInits());
127 unsigned start = genericOp.getDpsInits().getBeginOperandIndex();
132 loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands,
133 newIndexingMaps, genericOp.getIteratorTypesArray(),
134 /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
140 for (auto bbarg : genericOp.getRegionInputArgs())
144 BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
149 BlockArgument bbarg = genericOp.getMatchingBlockArgument(&op);
156 for (auto &op : genericOp.getBody()->getOperations()) {
159 rewriter.replaceOp(genericOp, newOp.getResults());
229 replaceUnitDimIndexOps(GenericOp genericOp,
233 llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
337 MLIRContext *context, GenericOp genericOp, OpOperand *opOperand,
343 AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
344 ArrayRef<int64_t> operandShape = genericOp.getShape(opOperand);
386 linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
388 SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
398 return rewriter.notifyMatchFailure(genericOp,
401 SmallVector<int64_t> dims = genericOp.getStaticShape();
404 SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp);
407 genericOp, "control function returns no allowed unit dims to prune");
427 llvm::enumerate(genericOp.getIteratorTypesArray())) {
465 for (OpOperand &opOperand : genericOp->getOpOperands()) {
466 auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand);
467 ArrayRef<int64_t> shape = genericOp.getShape(&opOperand);
478 rewriter.getContext(), genericOp, &opOperand, oldDimToNewDimMap,
494 Location loc = genericOp.getLoc();
500 for (OpOperand &opOperand : genericOp->getOpOperands()) {
514 ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
516 ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
518 resultTypes.reserve(genericOp.getNumResults());
519 for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
524 rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
535 Value origDest = genericOp.getDpsInitOperand(index)->get();
555 LogicalResult matchAndRewrite(GenericOp genericOp,
558 dropUnitDims(rewriter, genericOp, options);
562 rewriter.replaceOp(genericOp, result->replacements);