1940d3e08STatWai Chong //===- TosaValidation.cpp ------------------------------------------------===// 2940d3e08STatWai Chong // 3940d3e08STatWai Chong // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4940d3e08STatWai Chong // See https://llvm.org/LICENSE.txt for license information. 5940d3e08STatWai Chong // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6940d3e08STatWai Chong // 7940d3e08STatWai Chong //===----------------------------------------------------------------------===// 8940d3e08STatWai Chong // 9940d3e08STatWai Chong // Validate if TOSA dialect input matchs with the specification for given 10940d3e08STatWai Chong // requirements. 11940d3e08STatWai Chong // 12940d3e08STatWai Chong //===----------------------------------------------------------------------===// 13940d3e08STatWai Chong 14940d3e08STatWai Chong #include "mlir/Dialect/Tosa/Transforms/Passes.h" 15940d3e08STatWai Chong #include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc" 16940d3e08STatWai Chong 17af972f01STai Ly #include <string> 18af972f01STai Ly 19940d3e08STatWai Chong #include "mlir/Dialect/Func/IR/FuncOps.h" 20940d3e08STatWai Chong #include "mlir/Dialect/Tosa/IR/TosaOps.h" 21940d3e08STatWai Chong #include "mlir/IR/Builders.h" 22940d3e08STatWai Chong #include "mlir/IR/BuiltinOps.h" 23940d3e08STatWai Chong #include "mlir/IR/Matchers.h" 24940d3e08STatWai Chong #include "mlir/IR/TypeUtilities.h" 25940d3e08STatWai Chong #include "mlir/Pass/Pass.h" 26940d3e08STatWai Chong #include "mlir/Transforms/DialectConversion.h" 27940d3e08STatWai Chong 28940d3e08STatWai Chong namespace mlir { 29940d3e08STatWai Chong namespace tosa { 30940d3e08STatWai Chong #define GEN_PASS_DEF_TOSAVALIDATION 31940d3e08STatWai Chong #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" 32940d3e08STatWai Chong } // namespace tosa 33940d3e08STatWai Chong } // namespace mlir 34940d3e08STatWai Chong 35940d3e08STatWai Chong using namespace mlir; 36940d3e08STatWai Chong using namespace mlir::tosa; 37940d3e08STatWai Chong 38940d3e08STatWai Chong namespace { 39940d3e08STatWai Chong 4008b0977aSTatWai Chong static LogicalResult checkConstantOperandPad(Operation *op) { 413745e708STai Ly if (auto padOp = dyn_cast<tosa::PadOp>(op)) { 4208b0977aSTatWai Chong DenseElementsAttr paddings; 433745e708STai Ly if (!matchPattern(padOp.getPadding(), m_Constant(&paddings))) 4408b0977aSTatWai Chong return op->emitOpError("padding of pad is not constant"); 4508b0977aSTatWai Chong 463745e708STai Ly DenseElementsAttr padConst; 473745e708STai Ly // Assume this op is zero-padding if padConst is not presented. 483745e708STai Ly if (padOp.getPadConst() && 493745e708STai Ly !matchPattern(padOp.getPadConst(), m_Constant(&padConst))) 5008b0977aSTatWai Chong return op->emitOpError("pad_const of pad is not constant"); 5108b0977aSTatWai Chong } 5208b0977aSTatWai Chong return success(); 5308b0977aSTatWai Chong } 5408b0977aSTatWai Chong 5508b0977aSTatWai Chong static LogicalResult checkConstantOperandTranspose(Operation *op) { 563745e708STai Ly if (auto transposeOp = dyn_cast<tosa::TransposeOp>(op)) { 5708b0977aSTatWai Chong DenseElementsAttr perms; 583745e708STai Ly if (!matchPattern(transposeOp.getPerms(), m_Constant(&perms))) 5908b0977aSTatWai Chong return op->emitOpError("perms of transpose is not constant"); 6008b0977aSTatWai Chong } 6108b0977aSTatWai Chong return success(); 6208b0977aSTatWai Chong } 6308b0977aSTatWai Chong 6408b0977aSTatWai Chong static LogicalResult checkConstantOperandFullyConnected(Operation *op) { 653745e708STai Ly if (auto fcOp = dyn_cast<tosa::FullyConnectedOp>(op)) { 6608b0977aSTatWai Chong DenseElementsAttr weight; 673745e708STai Ly if (!matchPattern(fcOp.getWeight(), m_Constant(&weight))) 6808b0977aSTatWai Chong return op->emitOpError("weight of fully_connected is not constant"); 6908b0977aSTatWai Chong 7008b0977aSTatWai Chong DenseElementsAttr bias; 713745e708STai Ly if (!matchPattern(fcOp.getBias(), m_Constant(&bias))) 7208b0977aSTatWai Chong return op->emitOpError("bias of fully_connected is not constant"); 7308b0977aSTatWai Chong } 7408b0977aSTatWai Chong return success(); 7508b0977aSTatWai Chong } 7608b0977aSTatWai Chong 773745e708STai Ly struct TosaLevel { 78d713a002STai Ly int32_t MAX_RANK = 0; 79d713a002STai Ly int32_t MAX_KERNEL = 0; 80d713a002STai Ly int32_t MAX_STRIDE = 0; 81d713a002STai Ly int32_t MAX_SCALE = 0; 82d713a002STai Ly 83d713a002STai Ly // @todo: MAX_LOG2_SIZE value and checks 84d713a002STai Ly 853745e708STai Ly bool operator==(const TosaLevel &rhs) { 86d713a002STai Ly return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL && 87d713a002STai Ly MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE; 88d713a002STai Ly } 89d713a002STai Ly }; 90d713a002STai Ly 913745e708STai Ly static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256}; 923745e708STai Ly static constexpr TosaLevel TOSA_LEVEL_NONE = {0, 0, 0, 0}; 93d713a002STai Ly 94940d3e08STatWai Chong //===----------------------------------------------------------------------===// 95940d3e08STatWai Chong // TOSA Validation Pass. 96940d3e08STatWai Chong //===----------------------------------------------------------------------===// 97940d3e08STatWai Chong 98940d3e08STatWai Chong struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> { 99940d3e08STatWai Chong public: 10008b0977aSTatWai Chong explicit TosaValidation() { populateConstantOperandChecks(); } 101af972f01STai Ly explicit TosaValidation(const TosaValidationOptions &options) 102af972f01STai Ly : TosaValidation() { 10332b7c1ffSBenjamin Maxwell this->profile = options.profile; 104af972f01STai Ly this->StrictOperationSpecAlignment = options.StrictOperationSpecAlignment; 10532b7c1ffSBenjamin Maxwell this->level = options.level; 10632b7c1ffSBenjamin Maxwell } 107af972f01STai Ly void runOnOperation() final; 108940d3e08STatWai Chong 10908b0977aSTatWai Chong LogicalResult applyConstantOperandCheck(Operation *op) { 1103745e708STai Ly for (auto &checker : constCheckers) { 11108b0977aSTatWai Chong if (failed(checker(op))) 11208b0977aSTatWai Chong return failure(); 11308b0977aSTatWai Chong } 11408b0977aSTatWai Chong return success(); 11508b0977aSTatWai Chong } 11608b0977aSTatWai Chong 117d713a002STai Ly LogicalResult applyLevelCheck(Operation *op); 118d713a002STai Ly 119af972f01STai Ly // check variable read/write data types against variable declarations 120af972f01STai Ly LogicalResult applyVariableCheck(Operation *op); 121af972f01STai Ly 12208b0977aSTatWai Chong private: 12308b0977aSTatWai Chong void populateConstantOperandChecks() { 1243745e708STai Ly constCheckers.emplace_back(checkConstantOperandPad); 1253745e708STai Ly constCheckers.emplace_back(checkConstantOperandTranspose); 1263745e708STai Ly constCheckers.emplace_back(checkConstantOperandFullyConnected); 12708b0977aSTatWai Chong } 12808b0977aSTatWai Chong 129d713a002STai Ly bool levelCheckKernel(Operation *op, int32_t v, 1303745e708STai Ly const std::string &checkDesc) { 1313745e708STai Ly if (v > tosaLevel.MAX_KERNEL) { 1323745e708STai Ly op->emitOpError() << "failed level check: " << checkDesc; 133d713a002STai Ly return false; 134d713a002STai Ly } 135d713a002STai Ly return true; 136d713a002STai Ly } 137940d3e08STatWai Chong 138d713a002STai Ly bool levelCheckStride(Operation *op, int32_t v, 1393745e708STai Ly const std::string &checkDesc) { 1403745e708STai Ly if (v > tosaLevel.MAX_STRIDE) { 1413745e708STai Ly op->emitOpError() << "failed level check: " << checkDesc; 142d713a002STai Ly return false; 143d713a002STai Ly } 144d713a002STai Ly return true; 145d713a002STai Ly } 146d713a002STai Ly 1473745e708STai Ly bool levelCheckScale(Operation *op, int32_t v, const std::string &checkDesc) { 1483745e708STai Ly if (v > tosaLevel.MAX_SCALE) { 1493745e708STai Ly op->emitOpError() << "failed level check: " << checkDesc; 150d713a002STai Ly return false; 151d713a002STai Ly } 152d713a002STai Ly return true; 153d713a002STai Ly } 154d713a002STai Ly 155d713a002STai Ly bool levelCheckRank(Operation *op, const Value &v, 1563745e708STai Ly const std::string &checkDesc) { 157d713a002STai Ly if (ShapedType type = dyn_cast<ShapedType>(v.getType())) { 1583651f377SSarthak Gupta if (!type.hasRank()) { 1593651f377SSarthak Gupta op->emitOpError() << "failed level check: unranked tensor"; 1603651f377SSarthak Gupta return false; 1613651f377SSarthak Gupta } 1623745e708STai Ly if (type.getRank() > tosaLevel.MAX_RANK) { 1633745e708STai Ly op->emitOpError() << "failed level check: " << checkDesc; 164d713a002STai Ly return false; 165d713a002STai Ly } 166d713a002STai Ly } 167d713a002STai Ly return true; 168d713a002STai Ly } 169d713a002STai Ly 170d713a002STai Ly template <typename T> 171d713a002STai Ly bool levelCheckRanksFor(Operation *op) { 172d713a002STai Ly if (dyn_cast<T>(op)) { 173d713a002STai Ly // level check ranks of all operands and results 174d713a002STai Ly for (auto v : op->getOperands()) { 175d713a002STai Ly if (!levelCheckRank(op, v, "operand rank(shape) <= MAX_RANK")) 176d713a002STai Ly return false; 177d713a002STai Ly } 178d713a002STai Ly for (auto v : op->getResults()) { 179d713a002STai Ly if (!levelCheckRank(op, v, "result rank(shape) <= MAX_RANK")) 180d713a002STai Ly return false; 181d713a002STai Ly } 182d713a002STai Ly } 183d713a002STai Ly return true; 184d713a002STai Ly } 185d713a002STai Ly 186d713a002STai Ly bool levelCheckRanks(Operation *op) { 1873745e708STai Ly #define CHECK_RANKS_FOR(tosaOp) \ 1883745e708STai Ly if (!levelCheckRanksFor<tosaOp##Op>(op)) \ 189d713a002STai Ly return false; 190d713a002STai Ly 191d713a002STai Ly // tensor operators: 192d713a002STai Ly CHECK_RANKS_FOR(ArgMax); 193d713a002STai Ly // all activation functions: 194d713a002STai Ly CHECK_RANKS_FOR(Clamp); 195d713a002STai Ly CHECK_RANKS_FOR(Sigmoid); 196d713a002STai Ly CHECK_RANKS_FOR(Tanh); 197d713a002STai Ly // all elementwise binary operators: 198d713a002STai Ly CHECK_RANKS_FOR(Add); 199d713a002STai Ly CHECK_RANKS_FOR(ArithmeticRightShift); 200d713a002STai Ly CHECK_RANKS_FOR(BitwiseAnd); 201d713a002STai Ly CHECK_RANKS_FOR(BitwiseOr); 202d713a002STai Ly CHECK_RANKS_FOR(BitwiseXor); 20382383d5fSTai Ly CHECK_RANKS_FOR(IntDiv); 204d713a002STai Ly CHECK_RANKS_FOR(LogicalAnd); 205d713a002STai Ly CHECK_RANKS_FOR(LogicalLeftShift); 206d713a002STai Ly CHECK_RANKS_FOR(LogicalRightShift); 207d713a002STai Ly CHECK_RANKS_FOR(LogicalOr); 208d713a002STai Ly CHECK_RANKS_FOR(LogicalXor); 209d713a002STai Ly CHECK_RANKS_FOR(Maximum); 210d713a002STai Ly CHECK_RANKS_FOR(Minimum); 211d713a002STai Ly CHECK_RANKS_FOR(Mul); 212d713a002STai Ly CHECK_RANKS_FOR(Pow); 213d713a002STai Ly CHECK_RANKS_FOR(Sub); 214d713a002STai Ly CHECK_RANKS_FOR(Table); 215d713a002STai Ly // all elementwise unary operators: 216d713a002STai Ly CHECK_RANKS_FOR(Abs); 217d713a002STai Ly CHECK_RANKS_FOR(BitwiseNot); 218d713a002STai Ly CHECK_RANKS_FOR(Ceil); 219d713a002STai Ly CHECK_RANKS_FOR(Clz); 220d713a002STai Ly CHECK_RANKS_FOR(Exp); 221d713a002STai Ly CHECK_RANKS_FOR(Floor); 222d713a002STai Ly CHECK_RANKS_FOR(Log); 223d713a002STai Ly CHECK_RANKS_FOR(LogicalNot); 224d713a002STai Ly CHECK_RANKS_FOR(Negate); 225d713a002STai Ly CHECK_RANKS_FOR(Reciprocal); 226d713a002STai Ly CHECK_RANKS_FOR(Rsqrt); 227d713a002STai Ly // all elementwise ternary operators: 228d713a002STai Ly CHECK_RANKS_FOR(Select); 229d713a002STai Ly // all comparison operators: 230d713a002STai Ly CHECK_RANKS_FOR(Equal); 231d713a002STai Ly CHECK_RANKS_FOR(Greater); 232d713a002STai Ly CHECK_RANKS_FOR(GreaterEqual); 233d713a002STai Ly // all reduction operators: 234d713a002STai Ly CHECK_RANKS_FOR(ReduceAll); 235d713a002STai Ly CHECK_RANKS_FOR(ReduceAny); 236d713a002STai Ly CHECK_RANKS_FOR(ReduceMax); 237d713a002STai Ly CHECK_RANKS_FOR(ReduceMin); 238d713a002STai Ly CHECK_RANKS_FOR(ReduceProd); 239d713a002STai Ly CHECK_RANKS_FOR(ReduceSum); 240d713a002STai Ly // all data layout operators: 241d713a002STai Ly CHECK_RANKS_FOR(Concat); 242d713a002STai Ly CHECK_RANKS_FOR(Pad); 243d713a002STai Ly CHECK_RANKS_FOR(Reshape); 244d713a002STai Ly CHECK_RANKS_FOR(Reverse); 245d713a002STai Ly CHECK_RANKS_FOR(Slice); 246d713a002STai Ly CHECK_RANKS_FOR(Tile); 247d713a002STai Ly CHECK_RANKS_FOR(Transpose); 248d713a002STai Ly // all type conversion operators: 249d713a002STai Ly CHECK_RANKS_FOR(Cast); 250d713a002STai Ly CHECK_RANKS_FOR(Rescale); 251d713a002STai Ly // all data nodes operators: 252d713a002STai Ly CHECK_RANKS_FOR(Const); 253d713a002STai Ly CHECK_RANKS_FOR(Identity); 254d713a002STai Ly 255d713a002STai Ly #undef CHECK_RANKS_FOR 256d713a002STai Ly return true; 257d713a002STai Ly } 258d713a002STai Ly 259d713a002STai Ly // Pool Op: level check kernel/stride/pad values 260d713a002STai Ly template <typename T> 261d713a002STai Ly bool levelCheckPool(Operation *op) { 2623745e708STai Ly if (auto poolOp = dyn_cast<T>(op)) { 2633745e708STai Ly for (auto k : poolOp.getKernel()) { 264d713a002STai Ly if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) { 265d713a002STai Ly return false; 266d713a002STai Ly } 267d713a002STai Ly } 2683745e708STai Ly for (auto s : poolOp.getStride()) { 269d713a002STai Ly if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) { 270d713a002STai Ly return false; 271d713a002STai Ly } 272d713a002STai Ly } 2733745e708STai Ly for (auto p : poolOp.getPad()) { 274d713a002STai Ly if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) { 275d713a002STai Ly return false; 276d713a002STai Ly } 277d713a002STai Ly } 278d713a002STai Ly } 279d713a002STai Ly return true; 280d713a002STai Ly } 281d713a002STai Ly 282d713a002STai Ly // Conv Op: level check dilation/stride/pad values 283d713a002STai Ly template <typename T> 284d713a002STai Ly bool levelCheckConv(Operation *op) { 2853745e708STai Ly if (auto convOp = dyn_cast<T>(op)) { 286d713a002STai Ly 2873745e708STai Ly for (auto k : convOp.getDilation()) { 288d713a002STai Ly if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) { 289d713a002STai Ly return false; 290d713a002STai Ly } 291d713a002STai Ly } 2923745e708STai Ly for (auto p : convOp.getPad()) { 293d713a002STai Ly if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) { 294d713a002STai Ly return false; 295d713a002STai Ly } 296d713a002STai Ly } 2973745e708STai Ly for (auto s : convOp.getStride()) { 298d713a002STai Ly if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) { 299d713a002STai Ly return false; 300d713a002STai Ly } 301d713a002STai Ly } 3023745e708STai Ly auto dilation = convOp.getDilation(); 3033745e708STai Ly if (ShapedType weightType = 304d713a002STai Ly dyn_cast<ShapedType>(op->getOperand(1).getType())) { 3053745e708STai Ly auto shape = weightType.getShape(); 306d713a002STai Ly if (isa<tosa::Conv2DOp>(op)) { 307d713a002STai Ly assert(shape.size() == 4); 308d713a002STai Ly assert(dilation.size() == 2); 309d713a002STai Ly if (!levelCheckKernel(op, dilation[0] * shape[1], 310d713a002STai Ly "dilation_y * KH <= MAX_KERNEL)") || 311d713a002STai Ly !levelCheckKernel(op, dilation[1] * shape[2], 312d713a002STai Ly "dilation_x * KW <= MAX_KERNEL)")) 313d713a002STai Ly return false; 314d713a002STai Ly } else if (isa<tosa::Conv3DOp>(op)) { 315d713a002STai Ly assert(shape.size() == 5); 316d713a002STai Ly assert(dilation.size() == 3); 317d713a002STai Ly if (!levelCheckKernel(op, dilation[0] * shape[1], 318d713a002STai Ly "dilation_d * KD <= MAX_KERNEL)") || 319d713a002STai Ly !levelCheckKernel(op, dilation[1] * shape[2], 320d713a002STai Ly "dilation_y * KH <= MAX_KERNEL)") || 321d713a002STai Ly !levelCheckKernel(op, dilation[2] * shape[3], 322d713a002STai Ly "dilation_x * KW <= MAX_KERNEL)")) 323d713a002STai Ly return false; 324d713a002STai Ly } else if (isa<tosa::DepthwiseConv2DOp>(op)) { 325d713a002STai Ly assert(shape.size() == 4); 326d713a002STai Ly assert(dilation.size() == 2); 327d713a002STai Ly if (!levelCheckKernel(op, dilation[0] * shape[0], 328d713a002STai Ly "dilation_y * KH <= MAX_KERNEL)") || 329d713a002STai Ly !levelCheckKernel(op, dilation[1] * shape[1], 330d713a002STai Ly "dilation_x * KW <= MAX_KERNEL)")) 331d713a002STai Ly return false; 332d713a002STai Ly } 333d713a002STai Ly } 334d713a002STai Ly } 335d713a002STai Ly return true; 336d713a002STai Ly } 337d713a002STai Ly 338d713a002STai Ly // FFT op: level check H, W in input shape [N,H,W] 339d713a002STai Ly template <typename T> 340d713a002STai Ly bool levelCheckFFT(Operation *op) { 341d713a002STai Ly if (isa<T>(op)) { 342d713a002STai Ly for (auto v : op->getOperands()) { 343d713a002STai Ly if (ShapedType type = dyn_cast<ShapedType>(v.getType())) { 344d713a002STai Ly auto shape = type.getShape(); 345d713a002STai Ly assert(shape.size() == 3); 346d713a002STai Ly if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") || 347d713a002STai Ly !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) { 348d713a002STai Ly return false; 349d713a002STai Ly } 350d713a002STai Ly } 351d713a002STai Ly } 352d713a002STai Ly } 353d713a002STai Ly return true; 354d713a002STai Ly } 355d713a002STai Ly 356d713a002STai Ly // TransposeConv2d op: level check kH/kW, outpad, and stride 357d713a002STai Ly bool levelCheckTransposeConv2d(Operation *op) { 358d713a002STai Ly if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) { 3593745e708STai Ly if (ShapedType filterType = 360a5757c5bSChristian Sigg dyn_cast<ShapedType>(transpose.getFilter().getType())) { 3613745e708STai Ly auto shape = filterType.getShape(); 362d713a002STai Ly assert(shape.size() == 4); 363d713a002STai Ly // level check kernel sizes for kH and KW 364d713a002STai Ly if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") || 365d713a002STai Ly !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) { 366d713a002STai Ly return false; 367d713a002STai Ly } 368d713a002STai Ly } 369d713a002STai Ly for (auto p : transpose.getOutPad()) { 370d713a002STai Ly if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) { 371d713a002STai Ly return false; 372d713a002STai Ly } 373d713a002STai Ly } 374d713a002STai Ly for (auto s : transpose.getStride()) { 375d713a002STai Ly if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) { 376d713a002STai Ly return false; 377d713a002STai Ly } 378d713a002STai Ly } 379d713a002STai Ly } 380d713a002STai Ly return true; 381d713a002STai Ly } 382d713a002STai Ly 383d713a002STai Ly // Resize op: level check max scales 384d713a002STai Ly bool levelCheckResize(Operation *op) { 385d713a002STai Ly if (auto resize = dyn_cast<tosa::ResizeOp>(op)) { 386d713a002STai Ly auto scale = resize.getScale(); 3873745e708STai Ly int16_t scaleYN = scale[0]; 3883745e708STai Ly int16_t scaleYD = scale[1]; 3893745e708STai Ly int16_t scaleXN = scale[2]; 3903745e708STai Ly int16_t scaleXD = scale[3]; 3913745e708STai Ly if (!levelCheckScale(op, scaleYN / scaleYD, 392d713a002STai Ly "scale_y_n/scale_y_d <= MAX_SCALE") || 3933745e708STai Ly !levelCheckScale(op, scaleXN / scaleXD, 394d713a002STai Ly "scale_x_n/scale_x_d <= MAX_SCALE")) { 395d713a002STai Ly return false; 396d713a002STai Ly } 397d713a002STai Ly } 398d713a002STai Ly return true; 399d713a002STai Ly } 400d713a002STai Ly 401d713a002STai Ly // configure profile and level values from pass options profileName and 402d713a002STai Ly // levelName 403d713a002STai Ly void configLevelAndProfile() { 4043745e708STai Ly tosaLevel = TOSA_LEVEL_NONE; 40532b7c1ffSBenjamin Maxwell if (level == TosaLevelEnum::EightK) { 4063745e708STai Ly tosaLevel = TOSA_LEVEL_EIGHTK; 407d713a002STai Ly } 408cc9e7cb9STatWai Chong 409cc9e7cb9STatWai Chong if (!profile.empty()) { 410cc9e7cb9STatWai Chong for (std::string &prof : profile) { 411cc9e7cb9STatWai Chong auto profSymbol = symbolizeTosaProfileEnum(prof); 412cc9e7cb9STatWai Chong if (profSymbol) { 413cc9e7cb9STatWai Chong enabled_profiles.push_back(profSymbol.value()); 414cc9e7cb9STatWai Chong } 415cc9e7cb9STatWai Chong } 416cc9e7cb9STatWai Chong } 417d713a002STai Ly } 418d713a002STai Ly 419af972f01STai Ly bool CheckVariable(Operation *op); 420af972f01STai Ly bool CheckVariableReadOrWrite(Operation *op); 421af972f01STai Ly 422c6d419c1SMatthias Gehre bool isValidElementType(Type type); 423cc9e7cb9STatWai Chong bool isEnabledProfile(TosaProfileEnum prof) { 424cc9e7cb9STatWai Chong return std::find(enabled_profiles.begin(), enabled_profiles.end(), prof) != 425cc9e7cb9STatWai Chong std::end(enabled_profiles); 426cc9e7cb9STatWai Chong } 427c6d419c1SMatthias Gehre 4283745e708STai Ly SmallVector<std::function<LogicalResult(Operation *)>> constCheckers; 429cc9e7cb9STatWai Chong SmallVector<TosaProfileEnum, 3> enabled_profiles; 4303745e708STai Ly TosaLevel tosaLevel; 4313745e708STai Ly DenseMap<StringAttr, mlir::Type> variablesMap; 432d713a002STai Ly }; 433d713a002STai Ly 434d713a002STai Ly LogicalResult TosaValidation::applyLevelCheck(Operation *op) { 4353745e708STai Ly if (tosaLevel == TOSA_LEVEL_NONE) { 436d713a002STai Ly // no need to do level checks 437d713a002STai Ly return success(); 438d713a002STai Ly } 439d713a002STai Ly 440d713a002STai Ly if (!levelCheckRanks(op)) { 441d713a002STai Ly return failure(); 442d713a002STai Ly } 443d713a002STai Ly 444d713a002STai Ly // additional level checks from spec 0.70 445d713a002STai Ly if (!levelCheckPool<tosa::AvgPool2dOp>(op) || 446d713a002STai Ly !levelCheckConv<tosa::Conv2DOp>(op) || 447d713a002STai Ly !levelCheckConv<tosa::Conv3DOp>(op) || 448d713a002STai Ly !levelCheckConv<tosa::DepthwiseConv2DOp>(op) || 449d713a002STai Ly !levelCheckFFT<tosa::FFT2dOp>(op) || 450d713a002STai Ly !levelCheckPool<tosa::MaxPool2dOp>(op) || 451d713a002STai Ly !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) || 452d713a002STai Ly !levelCheckResize(op)) { 453d713a002STai Ly return failure(); 454d713a002STai Ly } 455d713a002STai Ly 456d713a002STai Ly return success(); 457d713a002STai Ly } 458d713a002STai Ly 459af972f01STai Ly inline bool CompatibleTypes(const mlir::Type &type, 4603745e708STai Ly const mlir::Type &declaredType) { 461af972f01STai Ly // for now, simply use type equality comparison 4623745e708STai Ly return type == declaredType; 463af972f01STai Ly } 464af972f01STai Ly 465af972f01STai Ly bool TosaValidation::CheckVariable(Operation *op) { 466af972f01STai Ly if (isa<mlir::tosa::VariableOp>(op)) { 4673745e708STai Ly auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name")); 468af972f01STai Ly 4693745e708STai Ly if (variablesMap.count(nameAttr)) { 470af972f01STai Ly op->emitOpError() << "name has already been declared"; 471af972f01STai Ly return false; 472af972f01STai Ly } 473af972f01STai Ly 4743745e708STai Ly auto typeAttr = cast<mlir::TypeAttr>(op->getAttr("type")); 4753745e708STai Ly mlir::Type type = typeAttr.getValue(); 476af972f01STai Ly 4773745e708STai Ly variablesMap[nameAttr] = type; 478af972f01STai Ly } 479af972f01STai Ly 480af972f01STai Ly return true; 481af972f01STai Ly } 482af972f01STai Ly 483af972f01STai Ly bool TosaValidation::CheckVariableReadOrWrite(Operation *op) { 484af972f01STai Ly if (isa<mlir::tosa::VariableReadOp>(op) || 485af972f01STai Ly isa<mlir::tosa::VariableWriteOp>(op)) { 4863745e708STai Ly auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name")); 487af972f01STai Ly 4883745e708STai Ly if (!variablesMap.count(nameAttr)) { 489af972f01STai Ly op->emitOpError() << "name has not been declared"; 490af972f01STai Ly return false; 491af972f01STai Ly } 492af972f01STai Ly 4933745e708STai Ly auto varType = variablesMap[nameAttr]; 494af972f01STai Ly 495af972f01STai Ly for (auto v : op->getOperands()) { 496af972f01STai Ly auto type = v.getType(); 4973745e708STai Ly if (!CompatibleTypes(type, varType)) { 498af972f01STai Ly op->emitOpError() << "operand type does not equal variable type"; 499af972f01STai Ly return false; 500af972f01STai Ly } 501af972f01STai Ly } 502af972f01STai Ly 503af972f01STai Ly for (auto v : op->getResults()) { 504af972f01STai Ly auto type = v.getType(); 5053745e708STai Ly if (!CompatibleTypes(type, varType)) { 506af972f01STai Ly op->emitOpError() << "result type does not equal variable type"; 507af972f01STai Ly return false; 508af972f01STai Ly } 509af972f01STai Ly } 510af972f01STai Ly } 511af972f01STai Ly 512af972f01STai Ly return true; 513af972f01STai Ly } 514af972f01STai Ly 515af972f01STai Ly LogicalResult TosaValidation::applyVariableCheck(Operation *op) { 516af972f01STai Ly if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) { 517af972f01STai Ly return failure(); 518af972f01STai Ly } 519af972f01STai Ly return success(); 520af972f01STai Ly } 521af972f01STai Ly 522c6d419c1SMatthias Gehre bool TosaValidation::isValidElementType(Type type) { 523ea238974SMatthias Gehre if (isa<FloatType>(type)) { 524cc9e7cb9STatWai Chong if (!isEnabledProfile(TosaProfileEnum::MainInference)) 525c6d419c1SMatthias Gehre return false; 526ea238974SMatthias Gehre return type.isF32() || type.isF16() || type.isBF16(); 5279472c5fcSLuke Hutton } else if (auto intTy = dyn_cast<IntegerType>(type)) { 5289472c5fcSLuke Hutton if (intTy.isSignless()) { 529c6d419c1SMatthias Gehre switch (intTy.getWidth()) { 530c6d419c1SMatthias Gehre case 1: 531c6d419c1SMatthias Gehre case 4: 532c6d419c1SMatthias Gehre case 8: 533c6d419c1SMatthias Gehre case 16: 534c6d419c1SMatthias Gehre case 32: 535c6d419c1SMatthias Gehre case 48: 536c6d419c1SMatthias Gehre return true; 5379472c5fcSLuke Hutton } 538c6d419c1SMatthias Gehre } 539*f09db6a3SJerry-Ge } else if (mlir::isa<tosa::shapeType>(type)) { 540*f09db6a3SJerry-Ge return true; 541c6d419c1SMatthias Gehre } 542c6d419c1SMatthias Gehre return false; 543c6d419c1SMatthias Gehre } 544c6d419c1SMatthias Gehre 545d713a002STai Ly void TosaValidation::runOnOperation() { 546d713a002STai Ly configLevelAndProfile(); 54739e93eeeSLuke Hutton 54839e93eeeSLuke Hutton TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>(); 54939e93eeeSLuke Hutton if (!tosaDialect) 55039e93eeeSLuke Hutton return; 55139e93eeeSLuke Hutton 552940d3e08STatWai Chong getOperation().walk([&](Operation *op) { 55339e93eeeSLuke Hutton if (op->getDialect() != tosaDialect) 554e4351f27SLuke Hutton return; 555e4351f27SLuke Hutton 556940d3e08STatWai Chong for (Value operand : op->getOperands()) { 557c6d419c1SMatthias Gehre auto elementTy = getElementTypeOrSelf(operand); 558c6d419c1SMatthias Gehre if (!isValidElementType(elementTy)) { 559c6d419c1SMatthias Gehre op->emitOpError() << "is not profile-aligned: element type " 560c6d419c1SMatthias Gehre << elementTy << " is not legal"; 561940d3e08STatWai Chong return signalPassFailure(); 562940d3e08STatWai Chong } 563c6d419c1SMatthias Gehre } 564c6d419c1SMatthias Gehre for (Type resultTy : op->getResultTypes()) { 565c6d419c1SMatthias Gehre auto elementTy = getElementTypeOrSelf(resultTy); 566c6d419c1SMatthias Gehre if (!isValidElementType(elementTy)) { 567c6d419c1SMatthias Gehre op->emitOpError() << "is not profile-aligned: element type " 568c6d419c1SMatthias Gehre << elementTy << " is not legal"; 569a2dcd994SAmosLewis return signalPassFailure(); 570a2dcd994SAmosLewis } 571940d3e08STatWai Chong } 57208b0977aSTatWai Chong 573af972f01STai Ly // Some uses of TOSA rely on the constant operands of particular 574af972f01STai Ly // operations. 57508b0977aSTatWai Chong if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op))) 57608b0977aSTatWai Chong signalPassFailure(); 577d713a002STai Ly 578d713a002STai Ly // do level checks 579d713a002STai Ly if (failed(applyLevelCheck(op))) 580d713a002STai Ly signalPassFailure(); 581af972f01STai Ly 582af972f01STai Ly // do variable type checks 583af972f01STai Ly if (failed(applyVariableCheck(op))) 584af972f01STai Ly signalPassFailure(); 585940d3e08STatWai Chong }); 586940d3e08STatWai Chong } 587940d3e08STatWai Chong } // namespace 588