1 //===- TosaValidation.cpp ------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Validate if TOSA dialect input matchs with the specification for given 10 // requirements. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Tosa/Transforms/Passes.h" 15 #include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc" 16 17 #include <string> 18 19 #include "mlir/Dialect/Func/IR/FuncOps.h" 20 #include "mlir/Dialect/Tosa/IR/TosaOps.h" 21 #include "mlir/IR/Builders.h" 22 #include "mlir/IR/BuiltinOps.h" 23 #include "mlir/IR/Matchers.h" 24 #include "mlir/IR/TypeUtilities.h" 25 #include "mlir/Pass/Pass.h" 26 #include "mlir/Transforms/DialectConversion.h" 27 28 namespace mlir { 29 namespace tosa { 30 #define GEN_PASS_DEF_TOSAVALIDATION 31 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" 32 } // namespace tosa 33 } // namespace mlir 34 35 using namespace mlir; 36 using namespace mlir::tosa; 37 38 namespace { 39 40 static LogicalResult checkConstantOperandPad(Operation *op) { 41 if (auto padOp = dyn_cast<tosa::PadOp>(op)) { 42 DenseElementsAttr paddings; 43 if (!matchPattern(padOp.getPadding(), m_Constant(&paddings))) 44 return op->emitOpError("padding of pad is not constant"); 45 46 DenseElementsAttr padConst; 47 // Assume this op is zero-padding if padConst is not presented. 48 if (padOp.getPadConst() && 49 !matchPattern(padOp.getPadConst(), m_Constant(&padConst))) 50 return op->emitOpError("pad_const of pad is not constant"); 51 } 52 return success(); 53 } 54 55 static LogicalResult checkConstantOperandTranspose(Operation *op) { 56 if (auto transposeOp = dyn_cast<tosa::TransposeOp>(op)) { 57 DenseElementsAttr perms; 58 if (!matchPattern(transposeOp.getPerms(), m_Constant(&perms))) 59 return op->emitOpError("perms of transpose is not constant"); 60 } 61 return success(); 62 } 63 64 static LogicalResult checkConstantOperandFullyConnected(Operation *op) { 65 if (auto fcOp = dyn_cast<tosa::FullyConnectedOp>(op)) { 66 DenseElementsAttr weight; 67 if (!matchPattern(fcOp.getWeight(), m_Constant(&weight))) 68 return op->emitOpError("weight of fully_connected is not constant"); 69 70 DenseElementsAttr bias; 71 if (!matchPattern(fcOp.getBias(), m_Constant(&bias))) 72 return op->emitOpError("bias of fully_connected is not constant"); 73 } 74 return success(); 75 } 76 77 struct TosaLevel { 78 int32_t MAX_RANK = 0; 79 int32_t MAX_KERNEL = 0; 80 int32_t MAX_STRIDE = 0; 81 int32_t MAX_SCALE = 0; 82 83 // @todo: MAX_LOG2_SIZE value and checks 84 85 bool operator==(const TosaLevel &rhs) { 86 return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL && 87 MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE; 88 } 89 }; 90 91 static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256}; 92 static constexpr TosaLevel TOSA_LEVEL_NONE = {0, 0, 0, 0}; 93 94 //===----------------------------------------------------------------------===// 95 // TOSA Validation Pass. 96 //===----------------------------------------------------------------------===// 97 98 struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> { 99 public: 100 explicit TosaValidation() { populateConstantOperandChecks(); } 101 explicit TosaValidation(const TosaValidationOptions &options) 102 : TosaValidation() { 103 this->profile = options.profile; 104 this->StrictOperationSpecAlignment = options.StrictOperationSpecAlignment; 105 this->level = options.level; 106 } 107 void runOnOperation() final; 108 109 LogicalResult applyConstantOperandCheck(Operation *op) { 110 for (auto &checker : constCheckers) { 111 if (failed(checker(op))) 112 return failure(); 113 } 114 return success(); 115 } 116 117 LogicalResult applyLevelCheck(Operation *op); 118 119 // check variable read/write data types against variable declarations 120 LogicalResult applyVariableCheck(Operation *op); 121 122 private: 123 void populateConstantOperandChecks() { 124 constCheckers.emplace_back(checkConstantOperandPad); 125 constCheckers.emplace_back(checkConstantOperandTranspose); 126 constCheckers.emplace_back(checkConstantOperandFullyConnected); 127 } 128 129 bool levelCheckKernel(Operation *op, int32_t v, 130 const std::string &checkDesc) { 131 if (v > tosaLevel.MAX_KERNEL) { 132 op->emitOpError() << "failed level check: " << checkDesc; 133 return false; 134 } 135 return true; 136 } 137 138 bool levelCheckStride(Operation *op, int32_t v, 139 const std::string &checkDesc) { 140 if (v > tosaLevel.MAX_STRIDE) { 141 op->emitOpError() << "failed level check: " << checkDesc; 142 return false; 143 } 144 return true; 145 } 146 147 bool levelCheckScale(Operation *op, int32_t v, const std::string &checkDesc) { 148 if (v > tosaLevel.MAX_SCALE) { 149 op->emitOpError() << "failed level check: " << checkDesc; 150 return false; 151 } 152 return true; 153 } 154 155 bool levelCheckRank(Operation *op, const Value &v, 156 const std::string &checkDesc) { 157 if (ShapedType type = dyn_cast<ShapedType>(v.getType())) { 158 if (!type.hasRank()) { 159 op->emitOpError() << "failed level check: unranked tensor"; 160 return false; 161 } 162 if (type.getRank() > tosaLevel.MAX_RANK) { 163 op->emitOpError() << "failed level check: " << checkDesc; 164 return false; 165 } 166 } 167 return true; 168 } 169 170 template <typename T> 171 bool levelCheckRanksFor(Operation *op) { 172 if (dyn_cast<T>(op)) { 173 // level check ranks of all operands and results 174 for (auto v : op->getOperands()) { 175 if (!levelCheckRank(op, v, "operand rank(shape) <= MAX_RANK")) 176 return false; 177 } 178 for (auto v : op->getResults()) { 179 if (!levelCheckRank(op, v, "result rank(shape) <= MAX_RANK")) 180 return false; 181 } 182 } 183 return true; 184 } 185 186 bool levelCheckRanks(Operation *op) { 187 #define CHECK_RANKS_FOR(tosaOp) \ 188 if (!levelCheckRanksFor<tosaOp##Op>(op)) \ 189 return false; 190 191 // tensor operators: 192 CHECK_RANKS_FOR(ArgMax); 193 // all activation functions: 194 CHECK_RANKS_FOR(Clamp); 195 CHECK_RANKS_FOR(Sigmoid); 196 CHECK_RANKS_FOR(Tanh); 197 // all elementwise binary operators: 198 CHECK_RANKS_FOR(Add); 199 CHECK_RANKS_FOR(ArithmeticRightShift); 200 CHECK_RANKS_FOR(BitwiseAnd); 201 CHECK_RANKS_FOR(BitwiseOr); 202 CHECK_RANKS_FOR(BitwiseXor); 203 CHECK_RANKS_FOR(IntDiv); 204 CHECK_RANKS_FOR(LogicalAnd); 205 CHECK_RANKS_FOR(LogicalLeftShift); 206 CHECK_RANKS_FOR(LogicalRightShift); 207 CHECK_RANKS_FOR(LogicalOr); 208 CHECK_RANKS_FOR(LogicalXor); 209 CHECK_RANKS_FOR(Maximum); 210 CHECK_RANKS_FOR(Minimum); 211 CHECK_RANKS_FOR(Mul); 212 CHECK_RANKS_FOR(Pow); 213 CHECK_RANKS_FOR(Sub); 214 CHECK_RANKS_FOR(Table); 215 // all elementwise unary operators: 216 CHECK_RANKS_FOR(Abs); 217 CHECK_RANKS_FOR(BitwiseNot); 218 CHECK_RANKS_FOR(Ceil); 219 CHECK_RANKS_FOR(Clz); 220 CHECK_RANKS_FOR(Exp); 221 CHECK_RANKS_FOR(Floor); 222 CHECK_RANKS_FOR(Log); 223 CHECK_RANKS_FOR(LogicalNot); 224 CHECK_RANKS_FOR(Negate); 225 CHECK_RANKS_FOR(Reciprocal); 226 CHECK_RANKS_FOR(Rsqrt); 227 // all elementwise ternary operators: 228 CHECK_RANKS_FOR(Select); 229 // all comparison operators: 230 CHECK_RANKS_FOR(Equal); 231 CHECK_RANKS_FOR(Greater); 232 CHECK_RANKS_FOR(GreaterEqual); 233 // all reduction operators: 234 CHECK_RANKS_FOR(ReduceAll); 235 CHECK_RANKS_FOR(ReduceAny); 236 CHECK_RANKS_FOR(ReduceMax); 237 CHECK_RANKS_FOR(ReduceMin); 238 CHECK_RANKS_FOR(ReduceProd); 239 CHECK_RANKS_FOR(ReduceSum); 240 // all data layout operators: 241 CHECK_RANKS_FOR(Concat); 242 CHECK_RANKS_FOR(Pad); 243 CHECK_RANKS_FOR(Reshape); 244 CHECK_RANKS_FOR(Reverse); 245 CHECK_RANKS_FOR(Slice); 246 CHECK_RANKS_FOR(Tile); 247 CHECK_RANKS_FOR(Transpose); 248 // all type conversion operators: 249 CHECK_RANKS_FOR(Cast); 250 CHECK_RANKS_FOR(Rescale); 251 // all data nodes operators: 252 CHECK_RANKS_FOR(Const); 253 CHECK_RANKS_FOR(Identity); 254 255 #undef CHECK_RANKS_FOR 256 return true; 257 } 258 259 // Pool Op: level check kernel/stride/pad values 260 template <typename T> 261 bool levelCheckPool(Operation *op) { 262 if (auto poolOp = dyn_cast<T>(op)) { 263 for (auto k : poolOp.getKernel()) { 264 if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) { 265 return false; 266 } 267 } 268 for (auto s : poolOp.getStride()) { 269 if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) { 270 return false; 271 } 272 } 273 for (auto p : poolOp.getPad()) { 274 if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) { 275 return false; 276 } 277 } 278 } 279 return true; 280 } 281 282 // Conv Op: level check dilation/stride/pad values 283 template <typename T> 284 bool levelCheckConv(Operation *op) { 285 if (auto convOp = dyn_cast<T>(op)) { 286 287 for (auto k : convOp.getDilation()) { 288 if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) { 289 return false; 290 } 291 } 292 for (auto p : convOp.getPad()) { 293 if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) { 294 return false; 295 } 296 } 297 for (auto s : convOp.getStride()) { 298 if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) { 299 return false; 300 } 301 } 302 auto dilation = convOp.getDilation(); 303 if (ShapedType weightType = 304 dyn_cast<ShapedType>(op->getOperand(1).getType())) { 305 auto shape = weightType.getShape(); 306 if (isa<tosa::Conv2DOp>(op)) { 307 assert(shape.size() == 4); 308 assert(dilation.size() == 2); 309 if (!levelCheckKernel(op, dilation[0] * shape[1], 310 "dilation_y * KH <= MAX_KERNEL)") || 311 !levelCheckKernel(op, dilation[1] * shape[2], 312 "dilation_x * KW <= MAX_KERNEL)")) 313 return false; 314 } else if (isa<tosa::Conv3DOp>(op)) { 315 assert(shape.size() == 5); 316 assert(dilation.size() == 3); 317 if (!levelCheckKernel(op, dilation[0] * shape[1], 318 "dilation_d * KD <= MAX_KERNEL)") || 319 !levelCheckKernel(op, dilation[1] * shape[2], 320 "dilation_y * KH <= MAX_KERNEL)") || 321 !levelCheckKernel(op, dilation[2] * shape[3], 322 "dilation_x * KW <= MAX_KERNEL)")) 323 return false; 324 } else if (isa<tosa::DepthwiseConv2DOp>(op)) { 325 assert(shape.size() == 4); 326 assert(dilation.size() == 2); 327 if (!levelCheckKernel(op, dilation[0] * shape[0], 328 "dilation_y * KH <= MAX_KERNEL)") || 329 !levelCheckKernel(op, dilation[1] * shape[1], 330 "dilation_x * KW <= MAX_KERNEL)")) 331 return false; 332 } 333 } 334 } 335 return true; 336 } 337 338 // FFT op: level check H, W in input shape [N,H,W] 339 template <typename T> 340 bool levelCheckFFT(Operation *op) { 341 if (isa<T>(op)) { 342 for (auto v : op->getOperands()) { 343 if (ShapedType type = dyn_cast<ShapedType>(v.getType())) { 344 auto shape = type.getShape(); 345 assert(shape.size() == 3); 346 if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") || 347 !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) { 348 return false; 349 } 350 } 351 } 352 } 353 return true; 354 } 355 356 // TransposeConv2d op: level check kH/kW, outpad, and stride 357 bool levelCheckTransposeConv2d(Operation *op) { 358 if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) { 359 if (ShapedType filterType = 360 dyn_cast<ShapedType>(transpose.getFilter().getType())) { 361 auto shape = filterType.getShape(); 362 assert(shape.size() == 4); 363 // level check kernel sizes for kH and KW 364 if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") || 365 !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) { 366 return false; 367 } 368 } 369 for (auto p : transpose.getOutPad()) { 370 if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) { 371 return false; 372 } 373 } 374 for (auto s : transpose.getStride()) { 375 if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) { 376 return false; 377 } 378 } 379 } 380 return true; 381 } 382 383 // Resize op: level check max scales 384 bool levelCheckResize(Operation *op) { 385 if (auto resize = dyn_cast<tosa::ResizeOp>(op)) { 386 auto scale = resize.getScale(); 387 int16_t scaleYN = scale[0]; 388 int16_t scaleYD = scale[1]; 389 int16_t scaleXN = scale[2]; 390 int16_t scaleXD = scale[3]; 391 if (!levelCheckScale(op, scaleYN / scaleYD, 392 "scale_y_n/scale_y_d <= MAX_SCALE") || 393 !levelCheckScale(op, scaleXN / scaleXD, 394 "scale_x_n/scale_x_d <= MAX_SCALE")) { 395 return false; 396 } 397 } 398 return true; 399 } 400 401 // configure profile and level values from pass options profileName and 402 // levelName 403 void configLevelAndProfile() { 404 tosaLevel = TOSA_LEVEL_NONE; 405 if (level == TosaLevelEnum::EightK) { 406 tosaLevel = TOSA_LEVEL_EIGHTK; 407 } 408 409 if (!profile.empty()) { 410 for (std::string &prof : profile) { 411 auto profSymbol = symbolizeTosaProfileEnum(prof); 412 if (profSymbol) { 413 enabled_profiles.push_back(profSymbol.value()); 414 } 415 } 416 } 417 } 418 419 bool CheckVariable(Operation *op); 420 bool CheckVariableReadOrWrite(Operation *op); 421 422 bool isValidElementType(Type type); 423 bool isEnabledProfile(TosaProfileEnum prof) { 424 return std::find(enabled_profiles.begin(), enabled_profiles.end(), prof) != 425 std::end(enabled_profiles); 426 } 427 428 SmallVector<std::function<LogicalResult(Operation *)>> constCheckers; 429 SmallVector<TosaProfileEnum, 3> enabled_profiles; 430 TosaLevel tosaLevel; 431 DenseMap<StringAttr, mlir::Type> variablesMap; 432 }; 433 434 LogicalResult TosaValidation::applyLevelCheck(Operation *op) { 435 if (tosaLevel == TOSA_LEVEL_NONE) { 436 // no need to do level checks 437 return success(); 438 } 439 440 if (!levelCheckRanks(op)) { 441 return failure(); 442 } 443 444 // additional level checks from spec 0.70 445 if (!levelCheckPool<tosa::AvgPool2dOp>(op) || 446 !levelCheckConv<tosa::Conv2DOp>(op) || 447 !levelCheckConv<tosa::Conv3DOp>(op) || 448 !levelCheckConv<tosa::DepthwiseConv2DOp>(op) || 449 !levelCheckFFT<tosa::FFT2dOp>(op) || 450 !levelCheckPool<tosa::MaxPool2dOp>(op) || 451 !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) || 452 !levelCheckResize(op)) { 453 return failure(); 454 } 455 456 return success(); 457 } 458 459 inline bool CompatibleTypes(const mlir::Type &type, 460 const mlir::Type &declaredType) { 461 // for now, simply use type equality comparison 462 return type == declaredType; 463 } 464 465 bool TosaValidation::CheckVariable(Operation *op) { 466 if (isa<mlir::tosa::VariableOp>(op)) { 467 auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name")); 468 469 if (variablesMap.count(nameAttr)) { 470 op->emitOpError() << "name has already been declared"; 471 return false; 472 } 473 474 auto typeAttr = cast<mlir::TypeAttr>(op->getAttr("type")); 475 mlir::Type type = typeAttr.getValue(); 476 477 variablesMap[nameAttr] = type; 478 } 479 480 return true; 481 } 482 483 bool TosaValidation::CheckVariableReadOrWrite(Operation *op) { 484 if (isa<mlir::tosa::VariableReadOp>(op) || 485 isa<mlir::tosa::VariableWriteOp>(op)) { 486 auto nameAttr = cast<mlir::StringAttr>(op->getAttr("name")); 487 488 if (!variablesMap.count(nameAttr)) { 489 op->emitOpError() << "name has not been declared"; 490 return false; 491 } 492 493 auto varType = variablesMap[nameAttr]; 494 495 for (auto v : op->getOperands()) { 496 auto type = v.getType(); 497 if (!CompatibleTypes(type, varType)) { 498 op->emitOpError() << "operand type does not equal variable type"; 499 return false; 500 } 501 } 502 503 for (auto v : op->getResults()) { 504 auto type = v.getType(); 505 if (!CompatibleTypes(type, varType)) { 506 op->emitOpError() << "result type does not equal variable type"; 507 return false; 508 } 509 } 510 } 511 512 return true; 513 } 514 515 LogicalResult TosaValidation::applyVariableCheck(Operation *op) { 516 if (!CheckVariable(op) || !CheckVariableReadOrWrite(op)) { 517 return failure(); 518 } 519 return success(); 520 } 521 522 bool TosaValidation::isValidElementType(Type type) { 523 if (isa<FloatType>(type)) { 524 if (!isEnabledProfile(TosaProfileEnum::MainInference)) 525 return false; 526 return type.isF32() || type.isF16() || type.isBF16(); 527 } else if (auto intTy = dyn_cast<IntegerType>(type)) { 528 if (intTy.isSignless()) { 529 switch (intTy.getWidth()) { 530 case 1: 531 case 4: 532 case 8: 533 case 16: 534 case 32: 535 case 48: 536 return true; 537 } 538 } 539 } else if (mlir::isa<tosa::shapeType>(type)) { 540 return true; 541 } 542 return false; 543 } 544 545 void TosaValidation::runOnOperation() { 546 configLevelAndProfile(); 547 548 TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>(); 549 if (!tosaDialect) 550 return; 551 552 getOperation().walk([&](Operation *op) { 553 if (op->getDialect() != tosaDialect) 554 return; 555 556 for (Value operand : op->getOperands()) { 557 auto elementTy = getElementTypeOrSelf(operand); 558 if (!isValidElementType(elementTy)) { 559 op->emitOpError() << "is not profile-aligned: element type " 560 << elementTy << " is not legal"; 561 return signalPassFailure(); 562 } 563 } 564 for (Type resultTy : op->getResultTypes()) { 565 auto elementTy = getElementTypeOrSelf(resultTy); 566 if (!isValidElementType(elementTy)) { 567 op->emitOpError() << "is not profile-aligned: element type " 568 << elementTy << " is not legal"; 569 return signalPassFailure(); 570 } 571 } 572 573 // Some uses of TOSA rely on the constant operands of particular 574 // operations. 575 if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op))) 576 signalPassFailure(); 577 578 // do level checks 579 if (failed(applyLevelCheck(op))) 580 signalPassFailure(); 581 582 // do variable type checks 583 if (failed(applyVariableCheck(op))) 584 signalPassFailure(); 585 }); 586 } 587 } // namespace 588