Lines Matching full:op
40 static LogicalResult checkConstantOperandPad(Operation *op) {
41 if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
44 return op->emitOpError("padding of pad is not constant");
47 // Assume this op is zero-padding if padConst is not presented.
50 return op->emitOpError("pad_const of pad is not constant");
55 static LogicalResult checkConstantOperandTranspose(Operation *op) {
56 if (auto transposeOp = dyn_cast<tosa::TransposeOp>(op)) {
59 return op->emitOpError("perms of transpose is not constant");
64 static LogicalResult checkConstantOperandFullyConnected(Operation *op) {
65 if (auto fcOp = dyn_cast<tosa::FullyConnectedOp>(op)) {
68 return op->emitOpError("weight of fully_connected is not constant");
72 return op->emitOpError("bias of fully_connected is not constant");
109 LogicalResult applyConstantOperandCheck(Operation *op) {
111 if (failed(checker(op)))
117 LogicalResult applyLevelCheck(Operation *op);
120 LogicalResult applyVariableCheck(Operation *op);
129 bool levelCheckKernel(Operation *op, int32_t v,
132 op->emitOpError() << "failed level check: " << checkDesc;
138 bool levelCheckStride(Operation *op, int32_t v,
141 op->emitOpError() << "failed level check: " << checkDesc;
147 bool levelCheckScale(Operation *op, int32_t v, const std::string &checkDesc) {
149 op->emitOpError() << "failed level check: " << checkDesc;
155 bool levelCheckRank(Operation *op, const Value &v,
159 op->emitOpError() << "failed level check: unranked tensor";
163 op->emitOpError() << "failed level check: " << checkDesc;
171 bool levelCheckRanksFor(Operation *op) {
172 if (dyn_cast<T>(op)) {
174 for (auto v : op->getOperands()) {
175 if (!levelCheckRank(op, v, "operand rank(shape) <= MAX_RANK"))
178 for (auto v : op->getResults()) {
179 if (!levelCheckRank(op, v, "result rank(shape) <= MAX_RANK"))
186 bool levelCheckRanks(Operation *op) {
188 if (!levelCheckRanksFor<tosaOp##Op>(op)) \
259 // Pool Op: level check kernel/stride/pad values
261 bool levelCheckPool(Operation *op) {
262 if (auto poolOp = dyn_cast<T>(op)) {
264 if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) {
269 if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
274 if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
282 // Conv Op: level check dilation/stride/pad values
284 bool levelCheckConv(Operation *op) {
285 if (auto convOp = dyn_cast<T>(op)) {
288 if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) {
293 if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
298 if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
304 dyn_cast<ShapedType>(op->getOperand(1).getType())) {
306 if (isa<tosa::Conv2DOp>(op)) {
309 if (!levelCheckKernel(op, dilation[0] * shape[1],
311 !levelCheckKernel(op, dilation[1] * shape[2],
314 } else if (isa<tosa::Conv3DOp>(op)) {
317 if (!levelCheckKernel(op, dilation[0] * shape[1],
319 !levelCheckKernel(op, dilation[1] * shape[2],
321 !levelCheckKernel(op, dilation[2] * shape[3],
324 } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
327 if (!levelCheckKernel(op, dilation[0] * shape[0],
329 !levelCheckKernel(op, dilation[1] * shape[1],
338 // FFT op: level check H, W in input shape [N,H,W]
340 bool levelCheckFFT(Operation *op) {
341 if (isa<T>(op)) {
342 for (auto v : op->getOperands()) {
346 if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") ||
347 !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) {
356 // TransposeConv2d op: level check kH/kW, outpad, and stride
357 bool levelCheckTransposeConv2d(Operation *op) {
358 if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
364 if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") ||
365 !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) {
370 if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
375 if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
383 // Resize op: level check max scales
384 bool levelCheckResize(Operation *op) {
385 if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
391 if (!levelCheckScale(op, scaleYN / scaleYD,
393 !levelCheckScale(op, scaleXN / scaleXD,
419 bool CheckVariable(Operation *op);
420 bool CheckVariableReadOrWrite(Operation *op);
434 LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
440 if (!levelCheckRanks(op)) {
445 if (!levelCheckPool<tosa::AvgPool2dOp>(op) ||
446 !levelCheckConv<tosa::Conv2DOp>(op) ||
447 !levelCheckConv<tosa::Conv3DOp>(op) ||
448 !levelCheckConv<tosa::DepthwiseConv2DOp>(op) ||
449 !levelCheckFFT<tosa::FFT2dOp>(op) ||
450 !levelCheckPool<tosa::MaxPool2dOp>(op) ||
451 !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) ||
452 !levelCheckResize(op)) {
465 bool TosaValidation::CheckVariable(Operation *op) {
466 if (isa<mlir::tosa::VariableOp>(op)) {
467 auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
470 op->emitOpError() << "name has already been declared";
474 auto typeAttr = cast<mlir::TypeAttr>(op->getAttr("type"));
483 bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
484 if (isa<mlir::tosa::VariableReadOp>(op) ||
485 isa<mlir::tosa::VariableWriteOp>(op)) {
486 auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
489 op->emitOpError() << "name has not been declared";
495 for (auto v : op->getOperands()) {
498 op->emitOpError() << "operand type does not equal variable type";
503 for (auto v : op->getResults()) {
506 op->emitOpError() << "result type does not equal variable type";
515 LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
516 if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
552 getOperation().walk([&](Operation *op) {
553 if (op->getDialect() != tosaDialect)
556 for (Value operand : op->getOperands()) {
559 op->emitOpError() << "is not profile-aligned: element type "
564 for (Type resultTy : op->getResultTypes()) {
567 op->emitOpError() << "is not profile-aligned: element type "
575 if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op)))
579 if (failed(applyLevelCheck(op)))
583 if (failed(applyVariableCheck(op)))