xref: /llvm-project/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (revision f09db6a3af971ab7d9bbc7ba574a8dc0c10b2940)
1 //===- TosaValidation.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Validate whether TOSA dialect input matches the specification for the
10 // given requirements.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
15 #include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
16 
17 #include <string>
18 
19 #include "mlir/Dialect/Func/IR/FuncOps.h"
20 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/TypeUtilities.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 
28 namespace mlir {
29 namespace tosa {
30 #define GEN_PASS_DEF_TOSAVALIDATION
31 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
32 } // namespace tosa
33 } // namespace mlir
34 
35 using namespace mlir;
36 using namespace mlir::tosa;
37 
38 namespace {
39 
40 static LogicalResult checkConstantOperandPad(Operation *op) {
41   if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
42     DenseElementsAttr paddings;
43     if (!matchPattern(padOp.getPadding(), m_Constant(&paddings)))
44       return op->emitOpError("padding of pad is not constant");
45 
46     DenseElementsAttr padConst;
47     // Assume this op is zero-padding if padConst is not presented.
48     if (padOp.getPadConst() &&
49         !matchPattern(padOp.getPadConst(), m_Constant(&padConst)))
50       return op->emitOpError("pad_const of pad is not constant");
51   }
52   return success();
53 }
54 
55 static LogicalResult checkConstantOperandTranspose(Operation *op) {
56   if (auto transposeOp = dyn_cast<tosa::TransposeOp>(op)) {
57     DenseElementsAttr perms;
58     if (!matchPattern(transposeOp.getPerms(), m_Constant(&perms)))
59       return op->emitOpError("perms of transpose is not constant");
60   }
61   return success();
62 }
63 
64 static LogicalResult checkConstantOperandFullyConnected(Operation *op) {
65   if (auto fcOp = dyn_cast<tosa::FullyConnectedOp>(op)) {
66     DenseElementsAttr weight;
67     if (!matchPattern(fcOp.getWeight(), m_Constant(&weight)))
68       return op->emitOpError("weight of fully_connected is not constant");
69 
70     DenseElementsAttr bias;
71     if (!matchPattern(fcOp.getBias(), m_Constant(&bias)))
72       return op->emitOpError("bias of fully_connected is not constant");
73   }
74   return success();
75 }
76 
/// Limits of a TOSA level that the validation pass checks against. The
/// all-zero TOSA_LEVEL_NONE is the "no level" configuration.
struct TosaLevel {
  int32_t MAX_RANK = 0;
  int32_t MAX_KERNEL = 0;
  int32_t MAX_STRIDE = 0;
  int32_t MAX_SCALE = 0;

  // @todo: MAX_LOG2_SIZE value and checks

  // const-qualified (and constexpr) so that either side of the comparison may
  // be a const object, such as the constexpr levels defined below.
  constexpr bool operator==(const TosaLevel &rhs) const {
    return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
           MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE;
  }

  constexpr bool operator!=(const TosaLevel &rhs) const {
    return !(*this == rhs);
  }
};

static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256};
static constexpr TosaLevel TOSA_LEVEL_NONE = {0, 0, 0, 0};
93 
94 //===----------------------------------------------------------------------===//
95 // TOSA Validation Pass.
96 //===----------------------------------------------------------------------===//
97 
98 struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
99 public:
100   explicit TosaValidation() { populateConstantOperandChecks(); }
101   explicit TosaValidation(const TosaValidationOptions &options)
102       : TosaValidation() {
103     this->profile = options.profile;
104     this->StrictOperationSpecAlignment = options.StrictOperationSpecAlignment;
105     this->level = options.level;
106   }
107   void runOnOperation() final;
108 
109   LogicalResult applyConstantOperandCheck(Operation *op) {
110     for (auto &checker : constCheckers) {
111       if (failed(checker(op)))
112         return failure();
113     }
114     return success();
115   }
116 
117   LogicalResult applyLevelCheck(Operation *op);
118 
119   // check variable read/write data types against variable declarations
120   LogicalResult applyVariableCheck(Operation *op);
121 
122 private:
123   void populateConstantOperandChecks() {
124     constCheckers.emplace_back(checkConstantOperandPad);
125     constCheckers.emplace_back(checkConstantOperandTranspose);
126     constCheckers.emplace_back(checkConstantOperandFullyConnected);
127   }
128 
129   bool levelCheckKernel(Operation *op, int32_t v,
130                         const std::string &checkDesc) {
131     if (v > tosaLevel.MAX_KERNEL) {
132       op->emitOpError() << "failed level check: " << checkDesc;
133       return false;
134     }
135     return true;
136   }
137 
138   bool levelCheckStride(Operation *op, int32_t v,
139                         const std::string &checkDesc) {
140     if (v > tosaLevel.MAX_STRIDE) {
141       op->emitOpError() << "failed level check: " << checkDesc;
142       return false;
143     }
144     return true;
145   }
146 
147   bool levelCheckScale(Operation *op, int32_t v, const std::string &checkDesc) {
148     if (v > tosaLevel.MAX_SCALE) {
149       op->emitOpError() << "failed level check: " << checkDesc;
150       return false;
151     }
152     return true;
153   }
154 
155   bool levelCheckRank(Operation *op, const Value &v,
156                       const std::string &checkDesc) {
157     if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
158       if (!type.hasRank()) {
159         op->emitOpError() << "failed level check: unranked tensor";
160         return false;
161       }
162       if (type.getRank() > tosaLevel.MAX_RANK) {
163         op->emitOpError() << "failed level check: " << checkDesc;
164         return false;
165       }
166     }
167     return true;
168   }
169 
170   template <typename T>
171   bool levelCheckRanksFor(Operation *op) {
172     if (dyn_cast<T>(op)) {
173       // level check ranks of all operands and results
174       for (auto v : op->getOperands()) {
175         if (!levelCheckRank(op, v, "operand rank(shape) <= MAX_RANK"))
176           return false;
177       }
178       for (auto v : op->getResults()) {
179         if (!levelCheckRank(op, v, "result rank(shape) <= MAX_RANK"))
180           return false;
181       }
182     }
183     return true;
184   }
185 
186   bool levelCheckRanks(Operation *op) {
187 #define CHECK_RANKS_FOR(tosaOp)                                                \
188   if (!levelCheckRanksFor<tosaOp##Op>(op))                                     \
189     return false;
190 
191     // tensor operators:
192     CHECK_RANKS_FOR(ArgMax);
193     // all activation functions:
194     CHECK_RANKS_FOR(Clamp);
195     CHECK_RANKS_FOR(Sigmoid);
196     CHECK_RANKS_FOR(Tanh);
197     // all elementwise binary operators:
198     CHECK_RANKS_FOR(Add);
199     CHECK_RANKS_FOR(ArithmeticRightShift);
200     CHECK_RANKS_FOR(BitwiseAnd);
201     CHECK_RANKS_FOR(BitwiseOr);
202     CHECK_RANKS_FOR(BitwiseXor);
203     CHECK_RANKS_FOR(IntDiv);
204     CHECK_RANKS_FOR(LogicalAnd);
205     CHECK_RANKS_FOR(LogicalLeftShift);
206     CHECK_RANKS_FOR(LogicalRightShift);
207     CHECK_RANKS_FOR(LogicalOr);
208     CHECK_RANKS_FOR(LogicalXor);
209     CHECK_RANKS_FOR(Maximum);
210     CHECK_RANKS_FOR(Minimum);
211     CHECK_RANKS_FOR(Mul);
212     CHECK_RANKS_FOR(Pow);
213     CHECK_RANKS_FOR(Sub);
214     CHECK_RANKS_FOR(Table);
215     // all elementwise unary operators:
216     CHECK_RANKS_FOR(Abs);
217     CHECK_RANKS_FOR(BitwiseNot);
218     CHECK_RANKS_FOR(Ceil);
219     CHECK_RANKS_FOR(Clz);
220     CHECK_RANKS_FOR(Exp);
221     CHECK_RANKS_FOR(Floor);
222     CHECK_RANKS_FOR(Log);
223     CHECK_RANKS_FOR(LogicalNot);
224     CHECK_RANKS_FOR(Negate);
225     CHECK_RANKS_FOR(Reciprocal);
226     CHECK_RANKS_FOR(Rsqrt);
227     // all elementwise ternary operators:
228     CHECK_RANKS_FOR(Select);
229     // all comparison operators:
230     CHECK_RANKS_FOR(Equal);
231     CHECK_RANKS_FOR(Greater);
232     CHECK_RANKS_FOR(GreaterEqual);
233     // all reduction operators:
234     CHECK_RANKS_FOR(ReduceAll);
235     CHECK_RANKS_FOR(ReduceAny);
236     CHECK_RANKS_FOR(ReduceMax);
237     CHECK_RANKS_FOR(ReduceMin);
238     CHECK_RANKS_FOR(ReduceProd);
239     CHECK_RANKS_FOR(ReduceSum);
240     // all data layout operators:
241     CHECK_RANKS_FOR(Concat);
242     CHECK_RANKS_FOR(Pad);
243     CHECK_RANKS_FOR(Reshape);
244     CHECK_RANKS_FOR(Reverse);
245     CHECK_RANKS_FOR(Slice);
246     CHECK_RANKS_FOR(Tile);
247     CHECK_RANKS_FOR(Transpose);
248     // all type conversion operators:
249     CHECK_RANKS_FOR(Cast);
250     CHECK_RANKS_FOR(Rescale);
251     // all data nodes operators:
252     CHECK_RANKS_FOR(Const);
253     CHECK_RANKS_FOR(Identity);
254 
255 #undef CHECK_RANKS_FOR
256     return true;
257   }
258 
259   // Pool Op: level check kernel/stride/pad values
260   template <typename T>
261   bool levelCheckPool(Operation *op) {
262     if (auto poolOp = dyn_cast<T>(op)) {
263       for (auto k : poolOp.getKernel()) {
264         if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) {
265           return false;
266         }
267       }
268       for (auto s : poolOp.getStride()) {
269         if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
270           return false;
271         }
272       }
273       for (auto p : poolOp.getPad()) {
274         if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
275           return false;
276         }
277       }
278     }
279     return true;
280   }
281 
282   // Conv Op: level check dilation/stride/pad values
283   template <typename T>
284   bool levelCheckConv(Operation *op) {
285     if (auto convOp = dyn_cast<T>(op)) {
286 
287       for (auto k : convOp.getDilation()) {
288         if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) {
289           return false;
290         }
291       }
292       for (auto p : convOp.getPad()) {
293         if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
294           return false;
295         }
296       }
297       for (auto s : convOp.getStride()) {
298         if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
299           return false;
300         }
301       }
302       auto dilation = convOp.getDilation();
303       if (ShapedType weightType =
304               dyn_cast<ShapedType>(op->getOperand(1).getType())) {
305         auto shape = weightType.getShape();
306         if (isa<tosa::Conv2DOp>(op)) {
307           assert(shape.size() == 4);
308           assert(dilation.size() == 2);
309           if (!levelCheckKernel(op, dilation[0] * shape[1],
310                                 "dilation_y * KH <= MAX_KERNEL)") ||
311               !levelCheckKernel(op, dilation[1] * shape[2],
312                                 "dilation_x * KW <= MAX_KERNEL)"))
313             return false;
314         } else if (isa<tosa::Conv3DOp>(op)) {
315           assert(shape.size() == 5);
316           assert(dilation.size() == 3);
317           if (!levelCheckKernel(op, dilation[0] * shape[1],
318                                 "dilation_d * KD <= MAX_KERNEL)") ||
319               !levelCheckKernel(op, dilation[1] * shape[2],
320                                 "dilation_y * KH <= MAX_KERNEL)") ||
321               !levelCheckKernel(op, dilation[2] * shape[3],
322                                 "dilation_x * KW <= MAX_KERNEL)"))
323             return false;
324         } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
325           assert(shape.size() == 4);
326           assert(dilation.size() == 2);
327           if (!levelCheckKernel(op, dilation[0] * shape[0],
328                                 "dilation_y * KH <= MAX_KERNEL)") ||
329               !levelCheckKernel(op, dilation[1] * shape[1],
330                                 "dilation_x * KW <= MAX_KERNEL)"))
331             return false;
332         }
333       }
334     }
335     return true;
336   }
337 
338   // FFT op: level check H, W in input shape [N,H,W]
339   template <typename T>
340   bool levelCheckFFT(Operation *op) {
341     if (isa<T>(op)) {
342       for (auto v : op->getOperands()) {
343         if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
344           auto shape = type.getShape();
345           assert(shape.size() == 3);
346           if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") ||
347               !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) {
348             return false;
349           }
350         }
351       }
352     }
353     return true;
354   }
355 
356   // TransposeConv2d op: level check kH/kW, outpad, and stride
357   bool levelCheckTransposeConv2d(Operation *op) {
358     if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
359       if (ShapedType filterType =
360               dyn_cast<ShapedType>(transpose.getFilter().getType())) {
361         auto shape = filterType.getShape();
362         assert(shape.size() == 4);
363         // level check kernel sizes for kH and KW
364         if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") ||
365             !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) {
366           return false;
367         }
368       }
369       for (auto p : transpose.getOutPad()) {
370         if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
371           return false;
372         }
373       }
374       for (auto s : transpose.getStride()) {
375         if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
376           return false;
377         }
378       }
379     }
380     return true;
381   }
382 
383   // Resize op: level check max scales
384   bool levelCheckResize(Operation *op) {
385     if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
386       auto scale = resize.getScale();
387       int16_t scaleYN = scale[0];
388       int16_t scaleYD = scale[1];
389       int16_t scaleXN = scale[2];
390       int16_t scaleXD = scale[3];
391       if (!levelCheckScale(op, scaleYN / scaleYD,
392                            "scale_y_n/scale_y_d <= MAX_SCALE") ||
393           !levelCheckScale(op, scaleXN / scaleXD,
394                            "scale_x_n/scale_x_d <= MAX_SCALE")) {
395         return false;
396       }
397     }
398     return true;
399   }
400 
401   // configure profile and level values from pass options profileName and
402   // levelName
403   void configLevelAndProfile() {
404     tosaLevel = TOSA_LEVEL_NONE;
405     if (level == TosaLevelEnum::EightK) {
406       tosaLevel = TOSA_LEVEL_EIGHTK;
407     }
408 
409     if (!profile.empty()) {
410       for (std::string &prof : profile) {
411         auto profSymbol = symbolizeTosaProfileEnum(prof);
412         if (profSymbol) {
413           enabled_profiles.push_back(profSymbol.value());
414         }
415       }
416     }
417   }
418 
419   bool CheckVariable(Operation *op);
420   bool CheckVariableReadOrWrite(Operation *op);
421 
422   bool isValidElementType(Type type);
423   bool isEnabledProfile(TosaProfileEnum prof) {
424     return std::find(enabled_profiles.begin(), enabled_profiles.end(), prof) !=
425            std::end(enabled_profiles);
426   }
427 
428   SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
429   SmallVector<TosaProfileEnum, 3> enabled_profiles;
430   TosaLevel tosaLevel;
431   DenseMap<StringAttr, mlir::Type> variablesMap;
432 };
433 
434 LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
435   if (tosaLevel == TOSA_LEVEL_NONE) {
436     // no need to do level checks
437     return success();
438   }
439 
440   if (!levelCheckRanks(op)) {
441     return failure();
442   }
443 
444   // additional level checks from spec 0.70
445   if (!levelCheckPool<tosa::AvgPool2dOp>(op) ||
446       !levelCheckConv<tosa::Conv2DOp>(op) ||
447       !levelCheckConv<tosa::Conv3DOp>(op) ||
448       !levelCheckConv<tosa::DepthwiseConv2DOp>(op) ||
449       !levelCheckFFT<tosa::FFT2dOp>(op) ||
450       !levelCheckPool<tosa::MaxPool2dOp>(op) ||
451       !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) ||
452       !levelCheckResize(op)) {
453     return failure();
454   }
455 
456   return success();
457 }
458 
459 inline bool CompatibleTypes(const mlir::Type &type,
460                             const mlir::Type &declaredType) {
461   // for now, simply use type equality comparison
462   return type == declaredType;
463 }
464 
465 bool TosaValidation::CheckVariable(Operation *op) {
466   if (isa<mlir::tosa::VariableOp>(op)) {
467     auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
468 
469     if (variablesMap.count(nameAttr)) {
470       op->emitOpError() << "name has already been declared";
471       return false;
472     }
473 
474     auto typeAttr = cast<mlir::TypeAttr>(op->getAttr("type"));
475     mlir::Type type = typeAttr.getValue();
476 
477     variablesMap[nameAttr] = type;
478   }
479 
480   return true;
481 }
482 
483 bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
484   if (isa<mlir::tosa::VariableReadOp>(op) ||
485       isa<mlir::tosa::VariableWriteOp>(op)) {
486     auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
487 
488     if (!variablesMap.count(nameAttr)) {
489       op->emitOpError() << "name has not been declared";
490       return false;
491     }
492 
493     auto varType = variablesMap[nameAttr];
494 
495     for (auto v : op->getOperands()) {
496       auto type = v.getType();
497       if (!CompatibleTypes(type, varType)) {
498         op->emitOpError() << "operand type does not equal variable type";
499         return false;
500       }
501     }
502 
503     for (auto v : op->getResults()) {
504       auto type = v.getType();
505       if (!CompatibleTypes(type, varType)) {
506         op->emitOpError() << "result type does not equal variable type";
507         return false;
508       }
509     }
510   }
511 
512   return true;
513 }
514 
515 LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
516   if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
517     return failure();
518   }
519   return success();
520 }
521 
522 bool TosaValidation::isValidElementType(Type type) {
523   if (isa<FloatType>(type)) {
524     if (!isEnabledProfile(TosaProfileEnum::MainInference))
525       return false;
526     return type.isF32() || type.isF16() || type.isBF16();
527   } else if (auto intTy = dyn_cast<IntegerType>(type)) {
528     if (intTy.isSignless()) {
529       switch (intTy.getWidth()) {
530       case 1:
531       case 4:
532       case 8:
533       case 16:
534       case 32:
535       case 48:
536         return true;
537       }
538     }
539   } else if (mlir::isa<tosa::shapeType>(type)) {
540     return true;
541   }
542   return false;
543 }
544 
545 void TosaValidation::runOnOperation() {
546   configLevelAndProfile();
547 
548   TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
549   if (!tosaDialect)
550     return;
551 
552   getOperation().walk([&](Operation *op) {
553     if (op->getDialect() != tosaDialect)
554       return;
555 
556     for (Value operand : op->getOperands()) {
557       auto elementTy = getElementTypeOrSelf(operand);
558       if (!isValidElementType(elementTy)) {
559         op->emitOpError() << "is not profile-aligned: element type "
560                           << elementTy << " is not legal";
561         return signalPassFailure();
562       }
563     }
564     for (Type resultTy : op->getResultTypes()) {
565       auto elementTy = getElementTypeOrSelf(resultTy);
566       if (!isValidElementType(elementTy)) {
567         op->emitOpError() << "is not profile-aligned: element type "
568                           << elementTy << " is not legal";
569         return signalPassFailure();
570       }
571     }
572 
573     // Some uses of TOSA rely on the constant operands of particular
574     // operations.
575     if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op)))
576       signalPassFailure();
577 
578     // do level checks
579     if (failed(applyLevelCheck(op)))
580       signalPassFailure();
581 
582     // do variable type checks
583     if (failed(applyVariableCheck(op)))
584       signalPassFailure();
585   });
586 }
587 } // namespace
588