Lines Matching +full:batch +full:- +full:reduce
1 //===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
13 //===----------------------------------------------------------------------===//
41 //===----------------------------------------------------------------------===//
43 //===----------------------------------------------------------------------===//
50 //===----------------------------------------------------------------------===//
52 //===----------------------------------------------------------------------===//
56 //===--------------------------------------------------------------------===//
58 //===--------------------------------------------------------------------===//
69 return (isa<tosa::IfOp>(dest->getParentOp()) ||
70 isa<tosa::WhileOp>(dest->getParentOp()));
79 //===--------------------------------------------------------------------===//
91 //===--------------------------------------------------------------------===//
122 //===----------------------------------------------------------------------===//
124 //===----------------------------------------------------------------------===//
129 //===----------------------------------------------------------------------===//
131 //===----------------------------------------------------------------------===//
171 //===----------------------------------------------------------------------===//
173 //===----------------------------------------------------------------------===//
214 //===----------------------------------------------------------------------===//
216 //===----------------------------------------------------------------------===//
326 // Ensure output is of 32-bit integer
411 //===----------------------------------------------------------------------===//
413 //===----------------------------------------------------------------------===//
536 /// This builder is called on single-parameter unary operators that have scale
551 /// correctly. No pad_const is interpreted as zero-padding.
576 //===----------------------------------------------------------------------===//
578 //===----------------------------------------------------------------------===//
597 auto rankDiff = outShape.size() - shape.getRank();
632 outShape.reserve(inputShape.getRank() - 1);
659 // in the future e.g. [x, y, z] -> [x, y, z / 2 - 1]
698 // Copy shapes until the dim is non-dynamic.
707 " on the non-axis dimension ",
856 // if either padding for dim i is -1, output dim is unknown
886 return dim == -1 ? ShapedType::kDynamic : dim;
907 // if size[i] is -1, all remaining elements in dimension i are included
914 if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
917 // size[i] is not 0 and not < -1, and start[i] is in valid range
919 // input shape has unknown dim[i] - only valid if size[i] > 0
925 if (size[i] == -1) {
926 outputShape[i] = inputShape.getDimSize(i) - start[i];
991 // such as shift of mul op, so this is the only difference with the built-in
1007 llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);
1142 llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
1144 "expect element of 'multiples' to be positive integer or -1.");
1203 if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
1207 if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
1234 int missingDims = llvm::count(getNewShape(), -1);
1236 return emitOpError() << "expected at most one target dimension to be -1";
1261 // We cannot infer anything from a rank-0 "permutation" tensor.
1280 // Rank-0 means no permutations matter.
1372 constantPerms, [](int32_t v) -> int64_t { return v; }))))
1474 (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
1479 (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
1570 // All TOSA reduce Ops have input, output and axis.
1576 op.emitOpError("reduce axis must not be negative");
1585 << inputRank << ") to be larger than reduce axis (" << reduceAxis
1599 << outputRank << ") to be larger than reduce axis (" << reduceAxis
1701 // Batch and number of channels are identical for pooling layer.
1709 int64_t padded = height + pad[0] + pad[1] - kernel[0];
1714 int64_t padded = width + pad[2] + pad[3] - kernel[1];
1733 // Input shape describes input width/height and batch.
1765 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1766 int64_t unstridedResult = inputSize - filterSize + 1;
1767 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1773 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1774 int64_t unstridedResult = inputSize - filterSize + 1;
1775 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1802 // Input shape describes input width/height and batch.
1833 int32_t filterSize = (weightDepth - 1) * dilation[0] + 1;
1834 int32_t unstridedResult = inputSize - filterSize + 1;
1835 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1841 int32_t filterSize = (weightHeight - 1) * dilation[1] + 1;
1842 int32_t unstridedResult = inputSize - filterSize + 1;
1843 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1849 int32_t filterSize = (weightWidth - 1) * dilation[2] + 1;
1850 int32_t unstridedResult = inputSize - filterSize + 1;
1851 outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
1898 // Input shape describes input width/height and batch.
1940 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
1941 int64_t unstridedResult = inputSize - filterSize + 1;
1942 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
1948 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
1949 int64_t unstridedResult = inputSize - filterSize + 1;
1950 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
1976 // Input shape describes input width/height and batch.
2010 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
2018 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
2164 p << " -> (" << getResultTypes() << ")";
2182 p.printOptionalAttrDict((*this)->getAttrs());
2191 return emitOpError("expected non-negative reverse axis");
2280 parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
2283 //===----------------------------------------------------------------------===//
2285 //===----------------------------------------------------------------------===//
2300 for (auto v : op->getOperands()) {
2303 if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
2304 return op->emitOpError("shape operand is not compile time resolvable");
2312 for (auto type : op->getOperandTypes()) {
2314 return op->emitOpError("must have operands with tosa shape type");
2317 for (auto type : op->getResultTypes()) {
2319 return op->emitOpError("must have result with tosa shape type");
2335 auto operandTypes = op->getOperandTypes();
2336 auto resultTypes = op->getResultTypes();
2338 auto rank = getRank(*op->getOperandTypes().begin());
2341 return op->emitOpError("operands don't have matching ranks");
2346 return op->emitOpError("result shape has different rank than operands");
2352 //===----------------------------------------------------------------------===//
2354 //===----------------------------------------------------------------------===//
2368 //===----------------------------------------------------------------------===//
2370 //===----------------------------------------------------------------------===//
2375 //===----------------------------------------------------------------------===//
2377 //===----------------------------------------------------------------------===//
2381 //===----------------------------------------------------------------------===//
2383 //===----------------------------------------------------------------------===//