xref: /llvm-project/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (revision f09db6a3af971ab7d9bbc7ba574a8dc0c10b2940)
1940d3e08STatWai Chong //===- TosaValidation.cpp ------------------------------------------------===//
2940d3e08STatWai Chong //
3940d3e08STatWai Chong // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4940d3e08STatWai Chong // See https://llvm.org/LICENSE.txt for license information.
5940d3e08STatWai Chong // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6940d3e08STatWai Chong //
7940d3e08STatWai Chong //===----------------------------------------------------------------------===//
8940d3e08STatWai Chong //
// Validate whether TOSA dialect input matches the specification for the given
// requirements.
11940d3e08STatWai Chong //
12940d3e08STatWai Chong //===----------------------------------------------------------------------===//
13940d3e08STatWai Chong 
14940d3e08STatWai Chong #include "mlir/Dialect/Tosa/Transforms/Passes.h"
15940d3e08STatWai Chong #include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
16940d3e08STatWai Chong 
17af972f01STai Ly #include <string>
18af972f01STai Ly 
19940d3e08STatWai Chong #include "mlir/Dialect/Func/IR/FuncOps.h"
20940d3e08STatWai Chong #include "mlir/Dialect/Tosa/IR/TosaOps.h"
21940d3e08STatWai Chong #include "mlir/IR/Builders.h"
22940d3e08STatWai Chong #include "mlir/IR/BuiltinOps.h"
23940d3e08STatWai Chong #include "mlir/IR/Matchers.h"
24940d3e08STatWai Chong #include "mlir/IR/TypeUtilities.h"
25940d3e08STatWai Chong #include "mlir/Pass/Pass.h"
26940d3e08STatWai Chong #include "mlir/Transforms/DialectConversion.h"
27940d3e08STatWai Chong 
28940d3e08STatWai Chong namespace mlir {
29940d3e08STatWai Chong namespace tosa {
30940d3e08STatWai Chong #define GEN_PASS_DEF_TOSAVALIDATION
31940d3e08STatWai Chong #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
32940d3e08STatWai Chong } // namespace tosa
33940d3e08STatWai Chong } // namespace mlir
34940d3e08STatWai Chong 
35940d3e08STatWai Chong using namespace mlir;
36940d3e08STatWai Chong using namespace mlir::tosa;
37940d3e08STatWai Chong 
38940d3e08STatWai Chong namespace {
39940d3e08STatWai Chong 
4008b0977aSTatWai Chong static LogicalResult checkConstantOperandPad(Operation *op) {
413745e708STai Ly   if (auto padOp = dyn_cast<tosa::PadOp>(op)) {
4208b0977aSTatWai Chong     DenseElementsAttr paddings;
433745e708STai Ly     if (!matchPattern(padOp.getPadding(), m_Constant(&paddings)))
4408b0977aSTatWai Chong       return op->emitOpError("padding of pad is not constant");
4508b0977aSTatWai Chong 
463745e708STai Ly     DenseElementsAttr padConst;
473745e708STai Ly     // Assume this op is zero-padding if padConst is not presented.
483745e708STai Ly     if (padOp.getPadConst() &&
493745e708STai Ly         !matchPattern(padOp.getPadConst(), m_Constant(&padConst)))
5008b0977aSTatWai Chong       return op->emitOpError("pad_const of pad is not constant");
5108b0977aSTatWai Chong   }
5208b0977aSTatWai Chong   return success();
5308b0977aSTatWai Chong }
5408b0977aSTatWai Chong 
5508b0977aSTatWai Chong static LogicalResult checkConstantOperandTranspose(Operation *op) {
563745e708STai Ly   if (auto transposeOp = dyn_cast<tosa::TransposeOp>(op)) {
5708b0977aSTatWai Chong     DenseElementsAttr perms;
583745e708STai Ly     if (!matchPattern(transposeOp.getPerms(), m_Constant(&perms)))
5908b0977aSTatWai Chong       return op->emitOpError("perms of transpose is not constant");
6008b0977aSTatWai Chong   }
6108b0977aSTatWai Chong   return success();
6208b0977aSTatWai Chong }
6308b0977aSTatWai Chong 
6408b0977aSTatWai Chong static LogicalResult checkConstantOperandFullyConnected(Operation *op) {
653745e708STai Ly   if (auto fcOp = dyn_cast<tosa::FullyConnectedOp>(op)) {
6608b0977aSTatWai Chong     DenseElementsAttr weight;
673745e708STai Ly     if (!matchPattern(fcOp.getWeight(), m_Constant(&weight)))
6808b0977aSTatWai Chong       return op->emitOpError("weight of fully_connected is not constant");
6908b0977aSTatWai Chong 
7008b0977aSTatWai Chong     DenseElementsAttr bias;
713745e708STai Ly     if (!matchPattern(fcOp.getBias(), m_Constant(&bias)))
7208b0977aSTatWai Chong       return op->emitOpError("bias of fully_connected is not constant");
7308b0977aSTatWai Chong   }
7408b0977aSTatWai Chong   return success();
7508b0977aSTatWai Chong }
7608b0977aSTatWai Chong 
/// Upper bounds imposed by a TOSA level. A default-constructed value (all
/// limits zero) equals TOSA_LEVEL_NONE, for which level checks are skipped.
struct TosaLevel {
  int32_t MAX_RANK = 0;
  int32_t MAX_KERNEL = 0;
  int32_t MAX_STRIDE = 0;
  int32_t MAX_SCALE = 0;

  // TODO: MAX_LOG2_SIZE value and checks.

  // const (and constexpr) so that const/constexpr levels such as
  // TOSA_LEVEL_NONE can appear on either side of the comparison.
  constexpr bool operator==(const TosaLevel &rhs) const {
    return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
           MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE;
  }
  constexpr bool operator!=(const TosaLevel &rhs) const {
    return !(*this == rhs);
  }
};

// Limits of the 8K level, in field order {MAX_RANK, MAX_KERNEL, MAX_STRIDE,
// MAX_SCALE}.
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256};
// No level selected: every limit is zero.
static constexpr TosaLevel TOSA_LEVEL_NONE = {0, 0, 0, 0};
93d713a002STai Ly 
94940d3e08STatWai Chong //===----------------------------------------------------------------------===//
95940d3e08STatWai Chong // TOSA Validation Pass.
96940d3e08STatWai Chong //===----------------------------------------------------------------------===//
97940d3e08STatWai Chong 
98940d3e08STatWai Chong struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
99940d3e08STatWai Chong public:
10008b0977aSTatWai Chong   explicit TosaValidation() { populateConstantOperandChecks(); }
101af972f01STai Ly   explicit TosaValidation(const TosaValidationOptions &options)
102af972f01STai Ly       : TosaValidation() {
10332b7c1ffSBenjamin Maxwell     this->profile = options.profile;
104af972f01STai Ly     this->StrictOperationSpecAlignment = options.StrictOperationSpecAlignment;
10532b7c1ffSBenjamin Maxwell     this->level = options.level;
10632b7c1ffSBenjamin Maxwell   }
107af972f01STai Ly   void runOnOperation() final;
108940d3e08STatWai Chong 
10908b0977aSTatWai Chong   LogicalResult applyConstantOperandCheck(Operation *op) {
1103745e708STai Ly     for (auto &checker : constCheckers) {
11108b0977aSTatWai Chong       if (failed(checker(op)))
11208b0977aSTatWai Chong         return failure();
11308b0977aSTatWai Chong     }
11408b0977aSTatWai Chong     return success();
11508b0977aSTatWai Chong   }
11608b0977aSTatWai Chong 
117d713a002STai Ly   LogicalResult applyLevelCheck(Operation *op);
118d713a002STai Ly 
119af972f01STai Ly   // check variable read/write data types against variable declarations
120af972f01STai Ly   LogicalResult applyVariableCheck(Operation *op);
121af972f01STai Ly 
12208b0977aSTatWai Chong private:
12308b0977aSTatWai Chong   void populateConstantOperandChecks() {
1243745e708STai Ly     constCheckers.emplace_back(checkConstantOperandPad);
1253745e708STai Ly     constCheckers.emplace_back(checkConstantOperandTranspose);
1263745e708STai Ly     constCheckers.emplace_back(checkConstantOperandFullyConnected);
12708b0977aSTatWai Chong   }
12808b0977aSTatWai Chong 
129d713a002STai Ly   bool levelCheckKernel(Operation *op, int32_t v,
1303745e708STai Ly                         const std::string &checkDesc) {
1313745e708STai Ly     if (v > tosaLevel.MAX_KERNEL) {
1323745e708STai Ly       op->emitOpError() << "failed level check: " << checkDesc;
133d713a002STai Ly       return false;
134d713a002STai Ly     }
135d713a002STai Ly     return true;
136d713a002STai Ly   }
137940d3e08STatWai Chong 
138d713a002STai Ly   bool levelCheckStride(Operation *op, int32_t v,
1393745e708STai Ly                         const std::string &checkDesc) {
1403745e708STai Ly     if (v > tosaLevel.MAX_STRIDE) {
1413745e708STai Ly       op->emitOpError() << "failed level check: " << checkDesc;
142d713a002STai Ly       return false;
143d713a002STai Ly     }
144d713a002STai Ly     return true;
145d713a002STai Ly   }
146d713a002STai Ly 
1473745e708STai Ly   bool levelCheckScale(Operation *op, int32_t v, const std::string &checkDesc) {
1483745e708STai Ly     if (v > tosaLevel.MAX_SCALE) {
1493745e708STai Ly       op->emitOpError() << "failed level check: " << checkDesc;
150d713a002STai Ly       return false;
151d713a002STai Ly     }
152d713a002STai Ly     return true;
153d713a002STai Ly   }
154d713a002STai Ly 
155d713a002STai Ly   bool levelCheckRank(Operation *op, const Value &v,
1563745e708STai Ly                       const std::string &checkDesc) {
157d713a002STai Ly     if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
1583651f377SSarthak Gupta       if (!type.hasRank()) {
1593651f377SSarthak Gupta         op->emitOpError() << "failed level check: unranked tensor";
1603651f377SSarthak Gupta         return false;
1613651f377SSarthak Gupta       }
1623745e708STai Ly       if (type.getRank() > tosaLevel.MAX_RANK) {
1633745e708STai Ly         op->emitOpError() << "failed level check: " << checkDesc;
164d713a002STai Ly         return false;
165d713a002STai Ly       }
166d713a002STai Ly     }
167d713a002STai Ly     return true;
168d713a002STai Ly   }
169d713a002STai Ly 
170d713a002STai Ly   template <typename T>
171d713a002STai Ly   bool levelCheckRanksFor(Operation *op) {
172d713a002STai Ly     if (dyn_cast<T>(op)) {
173d713a002STai Ly       // level check ranks of all operands and results
174d713a002STai Ly       for (auto v : op->getOperands()) {
175d713a002STai Ly         if (!levelCheckRank(op, v, "operand rank(shape) <= MAX_RANK"))
176d713a002STai Ly           return false;
177d713a002STai Ly       }
178d713a002STai Ly       for (auto v : op->getResults()) {
179d713a002STai Ly         if (!levelCheckRank(op, v, "result rank(shape) <= MAX_RANK"))
180d713a002STai Ly           return false;
181d713a002STai Ly       }
182d713a002STai Ly     }
183d713a002STai Ly     return true;
184d713a002STai Ly   }
185d713a002STai Ly 
186d713a002STai Ly   bool levelCheckRanks(Operation *op) {
1873745e708STai Ly #define CHECK_RANKS_FOR(tosaOp)                                                \
1883745e708STai Ly   if (!levelCheckRanksFor<tosaOp##Op>(op))                                     \
189d713a002STai Ly     return false;
190d713a002STai Ly 
191d713a002STai Ly     // tensor operators:
192d713a002STai Ly     CHECK_RANKS_FOR(ArgMax);
193d713a002STai Ly     // all activation functions:
194d713a002STai Ly     CHECK_RANKS_FOR(Clamp);
195d713a002STai Ly     CHECK_RANKS_FOR(Sigmoid);
196d713a002STai Ly     CHECK_RANKS_FOR(Tanh);
197d713a002STai Ly     // all elementwise binary operators:
198d713a002STai Ly     CHECK_RANKS_FOR(Add);
199d713a002STai Ly     CHECK_RANKS_FOR(ArithmeticRightShift);
200d713a002STai Ly     CHECK_RANKS_FOR(BitwiseAnd);
201d713a002STai Ly     CHECK_RANKS_FOR(BitwiseOr);
202d713a002STai Ly     CHECK_RANKS_FOR(BitwiseXor);
20382383d5fSTai Ly     CHECK_RANKS_FOR(IntDiv);
204d713a002STai Ly     CHECK_RANKS_FOR(LogicalAnd);
205d713a002STai Ly     CHECK_RANKS_FOR(LogicalLeftShift);
206d713a002STai Ly     CHECK_RANKS_FOR(LogicalRightShift);
207d713a002STai Ly     CHECK_RANKS_FOR(LogicalOr);
208d713a002STai Ly     CHECK_RANKS_FOR(LogicalXor);
209d713a002STai Ly     CHECK_RANKS_FOR(Maximum);
210d713a002STai Ly     CHECK_RANKS_FOR(Minimum);
211d713a002STai Ly     CHECK_RANKS_FOR(Mul);
212d713a002STai Ly     CHECK_RANKS_FOR(Pow);
213d713a002STai Ly     CHECK_RANKS_FOR(Sub);
214d713a002STai Ly     CHECK_RANKS_FOR(Table);
215d713a002STai Ly     // all elementwise unary operators:
216d713a002STai Ly     CHECK_RANKS_FOR(Abs);
217d713a002STai Ly     CHECK_RANKS_FOR(BitwiseNot);
218d713a002STai Ly     CHECK_RANKS_FOR(Ceil);
219d713a002STai Ly     CHECK_RANKS_FOR(Clz);
220d713a002STai Ly     CHECK_RANKS_FOR(Exp);
221d713a002STai Ly     CHECK_RANKS_FOR(Floor);
222d713a002STai Ly     CHECK_RANKS_FOR(Log);
223d713a002STai Ly     CHECK_RANKS_FOR(LogicalNot);
224d713a002STai Ly     CHECK_RANKS_FOR(Negate);
225d713a002STai Ly     CHECK_RANKS_FOR(Reciprocal);
226d713a002STai Ly     CHECK_RANKS_FOR(Rsqrt);
227d713a002STai Ly     // all elementwise ternary operators:
228d713a002STai Ly     CHECK_RANKS_FOR(Select);
229d713a002STai Ly     // all comparison operators:
230d713a002STai Ly     CHECK_RANKS_FOR(Equal);
231d713a002STai Ly     CHECK_RANKS_FOR(Greater);
232d713a002STai Ly     CHECK_RANKS_FOR(GreaterEqual);
233d713a002STai Ly     // all reduction operators:
234d713a002STai Ly     CHECK_RANKS_FOR(ReduceAll);
235d713a002STai Ly     CHECK_RANKS_FOR(ReduceAny);
236d713a002STai Ly     CHECK_RANKS_FOR(ReduceMax);
237d713a002STai Ly     CHECK_RANKS_FOR(ReduceMin);
238d713a002STai Ly     CHECK_RANKS_FOR(ReduceProd);
239d713a002STai Ly     CHECK_RANKS_FOR(ReduceSum);
240d713a002STai Ly     // all data layout operators:
241d713a002STai Ly     CHECK_RANKS_FOR(Concat);
242d713a002STai Ly     CHECK_RANKS_FOR(Pad);
243d713a002STai Ly     CHECK_RANKS_FOR(Reshape);
244d713a002STai Ly     CHECK_RANKS_FOR(Reverse);
245d713a002STai Ly     CHECK_RANKS_FOR(Slice);
246d713a002STai Ly     CHECK_RANKS_FOR(Tile);
247d713a002STai Ly     CHECK_RANKS_FOR(Transpose);
248d713a002STai Ly     // all type conversion operators:
249d713a002STai Ly     CHECK_RANKS_FOR(Cast);
250d713a002STai Ly     CHECK_RANKS_FOR(Rescale);
251d713a002STai Ly     // all data nodes operators:
252d713a002STai Ly     CHECK_RANKS_FOR(Const);
253d713a002STai Ly     CHECK_RANKS_FOR(Identity);
254d713a002STai Ly 
255d713a002STai Ly #undef CHECK_RANKS_FOR
256d713a002STai Ly     return true;
257d713a002STai Ly   }
258d713a002STai Ly 
259d713a002STai Ly   // Pool Op: level check kernel/stride/pad values
260d713a002STai Ly   template <typename T>
261d713a002STai Ly   bool levelCheckPool(Operation *op) {
2623745e708STai Ly     if (auto poolOp = dyn_cast<T>(op)) {
2633745e708STai Ly       for (auto k : poolOp.getKernel()) {
264d713a002STai Ly         if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) {
265d713a002STai Ly           return false;
266d713a002STai Ly         }
267d713a002STai Ly       }
2683745e708STai Ly       for (auto s : poolOp.getStride()) {
269d713a002STai Ly         if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
270d713a002STai Ly           return false;
271d713a002STai Ly         }
272d713a002STai Ly       }
2733745e708STai Ly       for (auto p : poolOp.getPad()) {
274d713a002STai Ly         if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
275d713a002STai Ly           return false;
276d713a002STai Ly         }
277d713a002STai Ly       }
278d713a002STai Ly     }
279d713a002STai Ly     return true;
280d713a002STai Ly   }
281d713a002STai Ly 
282d713a002STai Ly   // Conv Op: level check dilation/stride/pad values
283d713a002STai Ly   template <typename T>
284d713a002STai Ly   bool levelCheckConv(Operation *op) {
2853745e708STai Ly     if (auto convOp = dyn_cast<T>(op)) {
286d713a002STai Ly 
2873745e708STai Ly       for (auto k : convOp.getDilation()) {
288d713a002STai Ly         if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) {
289d713a002STai Ly           return false;
290d713a002STai Ly         }
291d713a002STai Ly       }
2923745e708STai Ly       for (auto p : convOp.getPad()) {
293d713a002STai Ly         if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
294d713a002STai Ly           return false;
295d713a002STai Ly         }
296d713a002STai Ly       }
2973745e708STai Ly       for (auto s : convOp.getStride()) {
298d713a002STai Ly         if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
299d713a002STai Ly           return false;
300d713a002STai Ly         }
301d713a002STai Ly       }
3023745e708STai Ly       auto dilation = convOp.getDilation();
3033745e708STai Ly       if (ShapedType weightType =
304d713a002STai Ly               dyn_cast<ShapedType>(op->getOperand(1).getType())) {
3053745e708STai Ly         auto shape = weightType.getShape();
306d713a002STai Ly         if (isa<tosa::Conv2DOp>(op)) {
307d713a002STai Ly           assert(shape.size() == 4);
308d713a002STai Ly           assert(dilation.size() == 2);
309d713a002STai Ly           if (!levelCheckKernel(op, dilation[0] * shape[1],
310d713a002STai Ly                                 "dilation_y * KH <= MAX_KERNEL)") ||
311d713a002STai Ly               !levelCheckKernel(op, dilation[1] * shape[2],
312d713a002STai Ly                                 "dilation_x * KW <= MAX_KERNEL)"))
313d713a002STai Ly             return false;
314d713a002STai Ly         } else if (isa<tosa::Conv3DOp>(op)) {
315d713a002STai Ly           assert(shape.size() == 5);
316d713a002STai Ly           assert(dilation.size() == 3);
317d713a002STai Ly           if (!levelCheckKernel(op, dilation[0] * shape[1],
318d713a002STai Ly                                 "dilation_d * KD <= MAX_KERNEL)") ||
319d713a002STai Ly               !levelCheckKernel(op, dilation[1] * shape[2],
320d713a002STai Ly                                 "dilation_y * KH <= MAX_KERNEL)") ||
321d713a002STai Ly               !levelCheckKernel(op, dilation[2] * shape[3],
322d713a002STai Ly                                 "dilation_x * KW <= MAX_KERNEL)"))
323d713a002STai Ly             return false;
324d713a002STai Ly         } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
325d713a002STai Ly           assert(shape.size() == 4);
326d713a002STai Ly           assert(dilation.size() == 2);
327d713a002STai Ly           if (!levelCheckKernel(op, dilation[0] * shape[0],
328d713a002STai Ly                                 "dilation_y * KH <= MAX_KERNEL)") ||
329d713a002STai Ly               !levelCheckKernel(op, dilation[1] * shape[1],
330d713a002STai Ly                                 "dilation_x * KW <= MAX_KERNEL)"))
331d713a002STai Ly             return false;
332d713a002STai Ly         }
333d713a002STai Ly       }
334d713a002STai Ly     }
335d713a002STai Ly     return true;
336d713a002STai Ly   }
337d713a002STai Ly 
338d713a002STai Ly   // FFT op: level check H, W in input shape [N,H,W]
339d713a002STai Ly   template <typename T>
340d713a002STai Ly   bool levelCheckFFT(Operation *op) {
341d713a002STai Ly     if (isa<T>(op)) {
342d713a002STai Ly       for (auto v : op->getOperands()) {
343d713a002STai Ly         if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
344d713a002STai Ly           auto shape = type.getShape();
345d713a002STai Ly           assert(shape.size() == 3);
346d713a002STai Ly           if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") ||
347d713a002STai Ly               !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) {
348d713a002STai Ly             return false;
349d713a002STai Ly           }
350d713a002STai Ly         }
351d713a002STai Ly       }
352d713a002STai Ly     }
353d713a002STai Ly     return true;
354d713a002STai Ly   }
355d713a002STai Ly 
356d713a002STai Ly   // TransposeConv2d op: level check kH/kW, outpad, and stride
357d713a002STai Ly   bool levelCheckTransposeConv2d(Operation *op) {
358d713a002STai Ly     if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
3593745e708STai Ly       if (ShapedType filterType =
360a5757c5bSChristian Sigg               dyn_cast<ShapedType>(transpose.getFilter().getType())) {
3613745e708STai Ly         auto shape = filterType.getShape();
362d713a002STai Ly         assert(shape.size() == 4);
363d713a002STai Ly         // level check kernel sizes for kH and KW
364d713a002STai Ly         if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") ||
365d713a002STai Ly             !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) {
366d713a002STai Ly           return false;
367d713a002STai Ly         }
368d713a002STai Ly       }
369d713a002STai Ly       for (auto p : transpose.getOutPad()) {
370d713a002STai Ly         if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
371d713a002STai Ly           return false;
372d713a002STai Ly         }
373d713a002STai Ly       }
374d713a002STai Ly       for (auto s : transpose.getStride()) {
375d713a002STai Ly         if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
376d713a002STai Ly           return false;
377d713a002STai Ly         }
378d713a002STai Ly       }
379d713a002STai Ly     }
380d713a002STai Ly     return true;
381d713a002STai Ly   }
382d713a002STai Ly 
383d713a002STai Ly   // Resize op: level check max scales
384d713a002STai Ly   bool levelCheckResize(Operation *op) {
385d713a002STai Ly     if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
386d713a002STai Ly       auto scale = resize.getScale();
3873745e708STai Ly       int16_t scaleYN = scale[0];
3883745e708STai Ly       int16_t scaleYD = scale[1];
3893745e708STai Ly       int16_t scaleXN = scale[2];
3903745e708STai Ly       int16_t scaleXD = scale[3];
3913745e708STai Ly       if (!levelCheckScale(op, scaleYN / scaleYD,
392d713a002STai Ly                            "scale_y_n/scale_y_d <= MAX_SCALE") ||
3933745e708STai Ly           !levelCheckScale(op, scaleXN / scaleXD,
394d713a002STai Ly                            "scale_x_n/scale_x_d <= MAX_SCALE")) {
395d713a002STai Ly         return false;
396d713a002STai Ly       }
397d713a002STai Ly     }
398d713a002STai Ly     return true;
399d713a002STai Ly   }
400d713a002STai Ly 
401d713a002STai Ly   // configure profile and level values from pass options profileName and
402d713a002STai Ly   // levelName
403d713a002STai Ly   void configLevelAndProfile() {
4043745e708STai Ly     tosaLevel = TOSA_LEVEL_NONE;
40532b7c1ffSBenjamin Maxwell     if (level == TosaLevelEnum::EightK) {
4063745e708STai Ly       tosaLevel = TOSA_LEVEL_EIGHTK;
407d713a002STai Ly     }
408cc9e7cb9STatWai Chong 
409cc9e7cb9STatWai Chong     if (!profile.empty()) {
410cc9e7cb9STatWai Chong       for (std::string &prof : profile) {
411cc9e7cb9STatWai Chong         auto profSymbol = symbolizeTosaProfileEnum(prof);
412cc9e7cb9STatWai Chong         if (profSymbol) {
413cc9e7cb9STatWai Chong           enabled_profiles.push_back(profSymbol.value());
414cc9e7cb9STatWai Chong         }
415cc9e7cb9STatWai Chong       }
416cc9e7cb9STatWai Chong     }
417d713a002STai Ly   }
418d713a002STai Ly 
419af972f01STai Ly   bool CheckVariable(Operation *op);
420af972f01STai Ly   bool CheckVariableReadOrWrite(Operation *op);
421af972f01STai Ly 
422c6d419c1SMatthias Gehre   bool isValidElementType(Type type);
423cc9e7cb9STatWai Chong   bool isEnabledProfile(TosaProfileEnum prof) {
424cc9e7cb9STatWai Chong     return std::find(enabled_profiles.begin(), enabled_profiles.end(), prof) !=
425cc9e7cb9STatWai Chong            std::end(enabled_profiles);
426cc9e7cb9STatWai Chong   }
427c6d419c1SMatthias Gehre 
4283745e708STai Ly   SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
429cc9e7cb9STatWai Chong   SmallVector<TosaProfileEnum, 3> enabled_profiles;
4303745e708STai Ly   TosaLevel tosaLevel;
4313745e708STai Ly   DenseMap<StringAttr, mlir::Type> variablesMap;
432d713a002STai Ly };
433d713a002STai Ly 
434d713a002STai Ly LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
4353745e708STai Ly   if (tosaLevel == TOSA_LEVEL_NONE) {
436d713a002STai Ly     // no need to do level checks
437d713a002STai Ly     return success();
438d713a002STai Ly   }
439d713a002STai Ly 
440d713a002STai Ly   if (!levelCheckRanks(op)) {
441d713a002STai Ly     return failure();
442d713a002STai Ly   }
443d713a002STai Ly 
444d713a002STai Ly   // additional level checks from spec 0.70
445d713a002STai Ly   if (!levelCheckPool<tosa::AvgPool2dOp>(op) ||
446d713a002STai Ly       !levelCheckConv<tosa::Conv2DOp>(op) ||
447d713a002STai Ly       !levelCheckConv<tosa::Conv3DOp>(op) ||
448d713a002STai Ly       !levelCheckConv<tosa::DepthwiseConv2DOp>(op) ||
449d713a002STai Ly       !levelCheckFFT<tosa::FFT2dOp>(op) ||
450d713a002STai Ly       !levelCheckPool<tosa::MaxPool2dOp>(op) ||
451d713a002STai Ly       !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) ||
452d713a002STai Ly       !levelCheckResize(op)) {
453d713a002STai Ly     return failure();
454d713a002STai Ly   }
455d713a002STai Ly 
456d713a002STai Ly   return success();
457d713a002STai Ly }
458d713a002STai Ly 
459af972f01STai Ly inline bool CompatibleTypes(const mlir::Type &type,
4603745e708STai Ly                             const mlir::Type &declaredType) {
461af972f01STai Ly   // for now, simply use type equality comparison
4623745e708STai Ly   return type == declaredType;
463af972f01STai Ly }
464af972f01STai Ly 
465af972f01STai Ly bool TosaValidation::CheckVariable(Operation *op) {
466af972f01STai Ly   if (isa<mlir::tosa::VariableOp>(op)) {
4673745e708STai Ly     auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
468af972f01STai Ly 
4693745e708STai Ly     if (variablesMap.count(nameAttr)) {
470af972f01STai Ly       op->emitOpError() << "name has already been declared";
471af972f01STai Ly       return false;
472af972f01STai Ly     }
473af972f01STai Ly 
4743745e708STai Ly     auto typeAttr = cast<mlir::TypeAttr>(op->getAttr("type"));
4753745e708STai Ly     mlir::Type type = typeAttr.getValue();
476af972f01STai Ly 
4773745e708STai Ly     variablesMap[nameAttr] = type;
478af972f01STai Ly   }
479af972f01STai Ly 
480af972f01STai Ly   return true;
481af972f01STai Ly }
482af972f01STai Ly 
483af972f01STai Ly bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
484af972f01STai Ly   if (isa<mlir::tosa::VariableReadOp>(op) ||
485af972f01STai Ly       isa<mlir::tosa::VariableWriteOp>(op)) {
4863745e708STai Ly     auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name"));
487af972f01STai Ly 
4883745e708STai Ly     if (!variablesMap.count(nameAttr)) {
489af972f01STai Ly       op->emitOpError() << "name has not been declared";
490af972f01STai Ly       return false;
491af972f01STai Ly     }
492af972f01STai Ly 
4933745e708STai Ly     auto varType = variablesMap[nameAttr];
494af972f01STai Ly 
495af972f01STai Ly     for (auto v : op->getOperands()) {
496af972f01STai Ly       auto type = v.getType();
4973745e708STai Ly       if (!CompatibleTypes(type, varType)) {
498af972f01STai Ly         op->emitOpError() << "operand type does not equal variable type";
499af972f01STai Ly         return false;
500af972f01STai Ly       }
501af972f01STai Ly     }
502af972f01STai Ly 
503af972f01STai Ly     for (auto v : op->getResults()) {
504af972f01STai Ly       auto type = v.getType();
5053745e708STai Ly       if (!CompatibleTypes(type, varType)) {
506af972f01STai Ly         op->emitOpError() << "result type does not equal variable type";
507af972f01STai Ly         return false;
508af972f01STai Ly       }
509af972f01STai Ly     }
510af972f01STai Ly   }
511af972f01STai Ly 
512af972f01STai Ly   return true;
513af972f01STai Ly }
514af972f01STai Ly 
515af972f01STai Ly LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
516af972f01STai Ly   if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) {
517af972f01STai Ly     return failure();
518af972f01STai Ly   }
519af972f01STai Ly   return success();
520af972f01STai Ly }
521af972f01STai Ly 
522c6d419c1SMatthias Gehre bool TosaValidation::isValidElementType(Type type) {
523ea238974SMatthias Gehre   if (isa<FloatType>(type)) {
524cc9e7cb9STatWai Chong     if (!isEnabledProfile(TosaProfileEnum::MainInference))
525c6d419c1SMatthias Gehre       return false;
526ea238974SMatthias Gehre     return type.isF32() || type.isF16() || type.isBF16();
5279472c5fcSLuke Hutton   } else if (auto intTy = dyn_cast<IntegerType>(type)) {
5289472c5fcSLuke Hutton     if (intTy.isSignless()) {
529c6d419c1SMatthias Gehre       switch (intTy.getWidth()) {
530c6d419c1SMatthias Gehre       case 1:
531c6d419c1SMatthias Gehre       case 4:
532c6d419c1SMatthias Gehre       case 8:
533c6d419c1SMatthias Gehre       case 16:
534c6d419c1SMatthias Gehre       case 32:
535c6d419c1SMatthias Gehre       case 48:
536c6d419c1SMatthias Gehre         return true;
5379472c5fcSLuke Hutton       }
538c6d419c1SMatthias Gehre     }
539*f09db6a3SJerry-Ge   } else if (mlir::isa<tosa::shapeType>(type)) {
540*f09db6a3SJerry-Ge     return true;
541c6d419c1SMatthias Gehre   }
542c6d419c1SMatthias Gehre   return false;
543c6d419c1SMatthias Gehre }
544c6d419c1SMatthias Gehre 
545d713a002STai Ly void TosaValidation::runOnOperation() {
546d713a002STai Ly   configLevelAndProfile();
54739e93eeeSLuke Hutton 
54839e93eeeSLuke Hutton   TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
54939e93eeeSLuke Hutton   if (!tosaDialect)
55039e93eeeSLuke Hutton     return;
55139e93eeeSLuke Hutton 
552940d3e08STatWai Chong   getOperation().walk([&](Operation *op) {
55339e93eeeSLuke Hutton     if (op->getDialect() != tosaDialect)
554e4351f27SLuke Hutton       return;
555e4351f27SLuke Hutton 
556940d3e08STatWai Chong     for (Value operand : op->getOperands()) {
557c6d419c1SMatthias Gehre       auto elementTy = getElementTypeOrSelf(operand);
558c6d419c1SMatthias Gehre       if (!isValidElementType(elementTy)) {
559c6d419c1SMatthias Gehre         op->emitOpError() << "is not profile-aligned: element type "
560c6d419c1SMatthias Gehre                           << elementTy << " is not legal";
561940d3e08STatWai Chong         return signalPassFailure();
562940d3e08STatWai Chong       }
563c6d419c1SMatthias Gehre     }
564c6d419c1SMatthias Gehre     for (Type resultTy : op->getResultTypes()) {
565c6d419c1SMatthias Gehre       auto elementTy = getElementTypeOrSelf(resultTy);
566c6d419c1SMatthias Gehre       if (!isValidElementType(elementTy)) {
567c6d419c1SMatthias Gehre         op->emitOpError() << "is not profile-aligned: element type "
568c6d419c1SMatthias Gehre                           << elementTy << " is not legal";
569a2dcd994SAmosLewis         return signalPassFailure();
570a2dcd994SAmosLewis       }
571940d3e08STatWai Chong     }
57208b0977aSTatWai Chong 
573af972f01STai Ly     // Some uses of TOSA rely on the constant operands of particular
574af972f01STai Ly     // operations.
57508b0977aSTatWai Chong     if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op)))
57608b0977aSTatWai Chong       signalPassFailure();
577d713a002STai Ly 
578d713a002STai Ly     // do level checks
579d713a002STai Ly     if (failed(applyLevelCheck(op)))
580d713a002STai Ly       signalPassFailure();
581af972f01STai Ly 
582af972f01STai Ly     // do variable type checks
583af972f01STai Ly     if (failed(applyVariableCheck(op)))
584af972f01STai Ly       signalPassFailure();
585940d3e08STatWai Chong   });
586940d3e08STatWai Chong }
587940d3e08STatWai Chong } // namespace
588