Lines Matching +full:non +full:- +full:batch
1 //===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
40 //===----------------------------------------------------------------------===//
42 //===----------------------------------------------------------------------===//
47 for (auto &opOperand : linalgOp->getOpOperands()) {
61 //===----------------------------------------------------------------------===//
63 //===----------------------------------------------------------------------===//
76 return llvm::hasSingleElement(op.getBlock()->getOperations());
79 //===----------------------------------------------------------------------===//
81 //===----------------------------------------------------------------------===//
96 return value->get();
99 //===----------------------------------------------------------------------===//
101 //===----------------------------------------------------------------------===//
109 auto srcTy = op.getDpsInputOperand(0)->get().getType();
110 auto dstTy = op.getDpsInitOperand(0)->get().getType();
134 if (i > 0 && pos <= position[i - 1])
149 //===----------------------------------------------------------------------===//
151 //===----------------------------------------------------------------------===//
181 //===----------------------------------------------------------------------===//
182 // Elementwise Single Unary/Binary-OpInterface implementation
183 //===----------------------------------------------------------------------===//
190 // Check there are arity-inputs, 1-output and all are identity-maps.
201 // as resulting from producer-consumer fusion. Here, we restrict to two ops in
205 if (body->getOperations().size() != 2)
208 Operation *oper = &body->front();
209 if (oper->getNumOperands() != arity || oper->getNumResults() != 1)
212 auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
214 yieldOp->getOperand(0).getDefiningOp() != oper)
243 //===----------------------------------------------------------------------===//
245 //===----------------------------------------------------------------------===//
247 /// If the value is defined by a chain of unary side effect-free, go up the
248 /// use-def chain until the first value that isn't defined by such an op.
249 // TODO: relax to multi-operands with constants, which are technically unary ops
253 while (op && op->getNumOperands() == 1) {
257 value = op->getOperand(0);
277 if (terminator->getNumOperands() != 1) {
282 Value yielded = getSourceSkipUnary(terminator->getOperand(0));
284 if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
289 Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
290 Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));
302 if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||
303 elementwiseOp->getNumOperands() != 2) {
313 Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
314 Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
354 /// - It is a single AffineDimExpr.
355 /// - It is the only result involving this AffineDimExpr.
396 /// 1. The m dimension is involved in an outer-product along LHS
398 /// 2. The n dimension is involved in an outer-product along RHS
402 /// 5. Optional batch dimensions that appear in all operands are captured.
415 // A & C - B are the iterators involved in an outer-product along A (the LHS).
419 // B & C - A are the iterators involved in an outer-product along B (the RHS).
423 // A & B & C are the "batch" dimensions.
441 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
492 // clang-format off
501 // clang-format on
546 /// operations that may change the type (e.g. for mixed-precision).
555 return op->emitError(getMatchContractionMessage(res));
559 //===----------------------------------------------------------------------===//
561 //===----------------------------------------------------------------------===//
573 /// - AffineDimExpr
574 /// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
603 int64_t pairedDim = it->second;
628 // In pre-order visit, top level op has to be an add op.
692 assert(constantExpr && "Found non-constant stride/dilation");
718 // unConvolvedDims & outputDims - filterDims are the batch iterators.
719 llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
720 llvm::set_intersect(batch, outputDims);
721 llvm::set_subtract(batch, filterDims);
727 // filterDims & outputDims - unConvolvedDims are the output channel iterators.
754 SmallVector<unsigned, 2>(batch.begin(), batch.end()),
762 llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
770 auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
780 linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
798 /// 1. Optional batch dimensions that appear in the input and filter.
799 /// 2. The output_image dimension is involved in a cross-correlation along LHS
802 /// 3. Optional output_channel dimension is involved in an outer-product along
808 /// represents the shape of the kernel cross-correlated along a
880 // - Batch loop : present in output, as non-convolved in input, not present in
882 // - Output image dimension : present in output, convolved dims in input, not
884 // - Output channel dimension : present in output, not present in input,
886 // - Filter loop dimension : present in filter, convolved in input, not
888 // - Input channel dimension : unconvolved in input, not present in output,
890 // - Depth multiplier : unconvolved in input, present in output, present in
897 // Batch dimension.
1000 return "expected convolved dim to be non-empty";
1017 return op->emitError(getMatchConvolutionMessage(res));
1021 //===----------------------------------------------------------------------===//
1023 //===----------------------------------------------------------------------===//
1049 return op->emitError("expected a LinalgOp");
1051 return op->emitError("expected op with 1 input and 1 output");
1053 return op->emitError("expected op with scalar input");
1058 //===----------------------------------------------------------------------===//
1060 //===----------------------------------------------------------------------===//
1065 for (OpOperand &opOperand : getOperation()->getOpOperands()) {
1075 for (OpOperand &opOperand : getOperation()->getOpOperands())
1151 // loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
1152 // subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
1153 // shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
1155 // = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
1166 resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1175 Location loc = getOperation()->getLoc();
1207 auto operandNumber = opOperand->getOperandNumber();
1208 auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1215 return cast<DestinationStyleOpInterface>(*this->getOperation())
1217 operandNumber - start;
1224 !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
1225 return op->emitOpError("expected to have pure tensor or buffer semantics");
1235 linalgOp->getNumOperands())
1236 return op->emitOpError("expected the number of indexing_map (")
1239 << linalgOp->getNumOperands() << ")";
1243 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
1248 return op->emitOpError("unexpected symbols in indexing_map #")
1254 return op->emitOpError("expected indexing_map #")
1261 return op->emitOpError("expected operand rank (")
1270 return op->emitOpError("expected the shape-to-loops map to be non-null");
1279 range -= 1;
1280 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
1301 // -> (d1 - d0)
1310 return op->emitOpError(
1316 return op->emitOpError("inferred input/output operand #")
1323 return op->emitOpError("inferred input/output operand #")
1334 if (linalgOp->getNumRegions() != 1 ||
1335 !llvm::hasSingleElement(linalgOp->getRegion(0)))
1336 return op->emitOpError("expects to have 1 region with 1 block");
1338 // Simplifying assumption: bbargs match 1-1 with shape operands elemental
1344 Block &block = linalgOp->getRegion(0).front();
1347 return op->emitOpError("expected as many non-induction variable region "
1351 Type elementType = opOperand->get().getType();
1353 elementType = getElementTypeOrSelf(opOperand->get().getType());
1354 Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
1356 return op->emitOpError("expected type of bb argument #")
1357 << opOperand->getOperandNumber() << " (" << argType << ")"