1 //===- Operation.cpp - MLIR Operation Class -------------------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 18 #include "mlir/IR/Operation.h" 19 #include "AttributeListStorage.h" 20 #include "mlir/IR/CFGFunction.h" 21 #include "mlir/IR/Dialect.h" 22 #include "mlir/IR/Instructions.h" 23 #include "mlir/IR/MLFunction.h" 24 #include "mlir/IR/MLIRContext.h" 25 #include "mlir/IR/OpDefinition.h" 26 #include "mlir/IR/OpImplementation.h" 27 #include "mlir/IR/Statements.h" 28 29 using namespace mlir; 30 31 /// Form the OperationName for an op with the specified string. This either is 32 /// a reference to an AbstractOperation if one is known, or a uniqued Identifier 33 /// if not. 34 OperationName::OperationName(StringRef name, MLIRContext *context) { 35 if (auto *op = AbstractOperation::lookup(name, context)) 36 representation = op; 37 else 38 representation = Identifier::get(name, context); 39 } 40 41 /// Return the name of this operation. This always succeeds. 42 StringRef OperationName::getStringRef() const { 43 if (auto *op = representation.dyn_cast<const AbstractOperation *>()) 44 return op->name; 45 return representation.get<Identifier>().strref(); 46 } 47 48 const AbstractOperation *OperationName::getAbstractOperation() const { 49 return representation.dyn_cast<const AbstractOperation *>(); 50 } 51 52 OperationName OperationName::getFromOpaquePointer(void *pointer) { 53 return OperationName(RepresentationUnion::getFromOpaqueValue(pointer)); 54 } 55 56 OpAsmParser::~OpAsmParser() {} 57 58 //===----------------------------------------------------------------------===// 59 // Operation class 60 //===----------------------------------------------------------------------===// 61 62 Operation::Operation(bool isInstruction, OperationName name, 63 ArrayRef<NamedAttribute> attrs, MLIRContext *context) 64 : nameAndIsInstruction(name, isInstruction) { 65 this->attrs = AttributeListStorage::get(attrs, context); 66 67 #ifndef NDEBUG 68 for (auto elt : attrs) 69 assert(elt.second != nullptr && "Attributes cannot have null entries"); 70 #endif 71 } 72 73 Operation::~Operation() {} 74 75 /// Return the context this operation is associated with. 76 MLIRContext *Operation::getContext() const { 77 if (auto *inst = llvm::dyn_cast<Instruction>(this)) 78 return inst->getContext(); 79 return llvm::cast<OperationStmt>(this)->getContext(); 80 } 81 82 /// The source location the operation was defined or derived from. Note that 83 /// it is possible for this pointer to be null. 84 Location Operation::getLoc() const { 85 if (auto *inst = llvm::dyn_cast<Instruction>(this)) 86 return inst->getLoc(); 87 return llvm::cast<OperationStmt>(this)->getLoc(); 88 } 89 90 /// Set the source location the operation was defined or derived from. 91 void Operation::setLoc(Location loc) { 92 if (auto *inst = llvm::dyn_cast<Instruction>(this)) 93 inst->setLoc(loc); 94 else 95 llvm::cast<OperationStmt>(this)->setLoc(loc); 96 } 97 98 /// Return the function this operation is defined in. 99 Function *Operation::getOperationFunction() { 100 if (auto *inst = llvm::dyn_cast<Instruction>(this)) 101 return inst->getFunction(); 102 return llvm::cast<OperationStmt>(this)->findFunction(); 103 } 104 105 /// Return the number of operands this operation has. 106 unsigned Operation::getNumOperands() const { 107 if (auto *inst = llvm::dyn_cast<Instruction>(this)) 108 return inst->getNumOperands(); 109 110 return llvm::cast<OperationStmt>(this)->getNumOperands(); 111 } 112 113 SSAValue *Operation::getOperand(unsigned idx) { 114 if (auto *inst = llvm::dyn_cast<Instruction>(this)) 115 return inst->getOperand(idx); 116 117 return llvm::cast<OperationStmt>(this)->getOperand(idx); 118 } 119 120 void Operation::setOperand(unsigned idx, SSAValue *value) { 121 if (auto *inst = llvm::dyn_cast<Instruction>(this)) { 122 inst->setOperand(idx, llvm::cast<CFGValue>(value)); 123 } else { 124 auto *stmt = llvm::cast<OperationStmt>(this); 125 stmt->setOperand(idx, llvm::cast<MLValue>(value)); 126 } 127 } 128 129 /// Return the number of results this operation has. 130 unsigned Operation::getNumResults() const { 131 if (auto *inst = llvm::dyn_cast<Instruction>(this)) 132 return inst->getNumResults(); 133 134 return llvm::cast<OperationStmt>(this)->getNumResults(); 135 } 136 137 /// Return the indicated result. 138 SSAValue *Operation::getResult(unsigned idx) { 139 if (auto *inst = llvm::dyn_cast<Instruction>(this)) 140 return inst->getResult(idx); 141 142 return llvm::cast<OperationStmt>(this)->getResult(idx); 143 } 144 145 unsigned Operation::getNumSuccessors() const { 146 assert(isTerminator() && "Only terminators have successors."); 147 if (llvm::isa<Instruction>(this)) 148 return llvm::cast<Instruction>(this)->getNumSuccessors(); 149 150 // OperationStmt currently only has a return terminator. 151 assert(llvm::cast<OperationStmt>(this)->isReturn() && 152 "Unhandled OperationStmt terminator."); 153 return 0; 154 } 155 156 unsigned Operation::getNumSuccessorOperands(unsigned index) const { 157 assert(isTerminator() && "Only terminators have successors."); 158 assert(llvm::isa<Instruction>(this) && "Only instructions have successors."); 159 return llvm::cast<Instruction>(this)->getNumSuccessorOperands(index); 160 } 161 BasicBlock *Operation::getSuccessor(unsigned index) { 162 assert(isTerminator() && "Only terminators have successors."); 163 assert(llvm::isa<Instruction>(this) && 164 "Only instructions have basic block successors."); 165 return llvm::cast<Instruction>(this)->getSuccessor(index); 166 } 167 void Operation::setSuccessor(BasicBlock *block, unsigned index) { 168 assert(isTerminator() && "Only terminators have successors."); 169 assert(llvm::isa<Instruction>(this) && 170 "Only instructions have basic block successors."); 171 llvm::cast<Instruction>(this)->setSuccessor(block, index); 172 } 173 void Operation::addSuccessorOperand(unsigned index, SSAValue *value) { 174 assert(isTerminator() && "Only terminators have successors."); 175 assert(llvm::isa<Instruction>(this) && "Only instructions have successors."); 176 return llvm::cast<Instruction>(this)->addSuccessorOperand( 177 index, llvm::cast<CFGValue>(value)); 178 } 179 void Operation::eraseSuccessorOperand(unsigned succIndex, unsigned opIndex) { 180 assert(isTerminator() && "Only terminators have successors."); 181 assert(llvm::isa<Instruction>(this) && "Only instructions have successors."); 182 return llvm::cast<Instruction>(this)->eraseSuccessorOperand(succIndex, 183 opIndex); 184 } 185 auto Operation::getSuccessorOperands(unsigned index) const 186 -> llvm::iterator_range<const_operand_iterator> { 187 assert(isTerminator() && "Only terminators have successors."); 188 assert(llvm::isa<Instruction>(this) && "Only instructions have successors."); 189 unsigned succOperandIndex = 190 llvm::cast<Instruction>(this)->getSuccessorOperandIndex(index); 191 return {const_operand_iterator(this, succOperandIndex), 192 const_operand_iterator(this, succOperandIndex + 193 getNumSuccessorOperands(index))}; 194 } 195 auto Operation::getSuccessorOperands(unsigned index) 196 -> llvm::iterator_range<operand_iterator> { 197 assert(isTerminator() && "Only terminators have successors."); 198 assert(llvm::isa<Instruction>(this) && "Only instructions have successors."); 199 unsigned succOperandIndex = 200 llvm::cast<Instruction>(this)->getSuccessorOperandIndex(index); 201 return {operand_iterator(this, succOperandIndex), 202 operand_iterator(this, 203 succOperandIndex + getNumSuccessorOperands(index))}; 204 } 205 206 /// Return true if there are no users of any results of this operation. 207 bool Operation::use_empty() const { 208 for (auto *result : getResults()) 209 if (!result->use_empty()) 210 return false; 211 return true; 212 } 213 214 void Operation::moveBefore(Operation *existingOp) { 215 if (auto *inst = llvm::dyn_cast<Instruction>(this)) 216 return inst->moveBefore(llvm::cast<Instruction>(existingOp)); 217 return llvm::cast<OperationStmt>(this)->moveBefore( 218 llvm::cast<OperationStmt>(existingOp)); 219 } 220 221 ArrayRef<NamedAttribute> Operation::getAttrs() const { 222 if (!attrs) 223 return {}; 224 return attrs->getElements(); 225 } 226 227 /// If an attribute exists with the specified name, change it to the new 228 /// value. Otherwise, add a new attribute with the specified name/value. 229 void Operation::setAttr(Identifier name, Attribute value) { 230 assert(value && "attributes may never be null"); 231 auto origAttrs = getAttrs(); 232 233 SmallVector<NamedAttribute, 8> newAttrs(origAttrs.begin(), origAttrs.end()); 234 auto *context = getContext(); 235 236 // If we already have this attribute, replace it. 237 for (auto &elt : newAttrs) 238 if (elt.first == name) { 239 elt.second = value; 240 attrs = AttributeListStorage::get(newAttrs, context); 241 return; 242 } 243 244 // Otherwise, add it. 245 newAttrs.push_back({name, value}); 246 attrs = AttributeListStorage::get(newAttrs, context); 247 } 248 249 /// Remove the attribute with the specified name if it exists. The return 250 /// value indicates whether the attribute was present or not. 251 auto Operation::removeAttr(Identifier name) -> RemoveResult { 252 auto origAttrs = getAttrs(); 253 for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) { 254 if (origAttrs[i].first == name) { 255 SmallVector<NamedAttribute, 8> newAttrs; 256 newAttrs.reserve(origAttrs.size() - 1); 257 newAttrs.append(origAttrs.begin(), origAttrs.begin() + i); 258 newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end()); 259 attrs = AttributeListStorage::get(newAttrs, getContext()); 260 return RemoveResult::Removed; 261 } 262 } 263 return RemoveResult::NotFound; 264 } 265 266 /// Emit a note about this operation, reporting up to any diagnostic 267 /// handlers that may be listening. 268 void Operation::emitNote(const Twine &message) const { 269 getContext()->emitDiagnostic(getLoc(), message, 270 MLIRContext::DiagnosticKind::Note); 271 } 272 273 /// Emit a warning about this operation, reporting up to any diagnostic 274 /// handlers that may be listening. 275 void Operation::emitWarning(const Twine &message) const { 276 getContext()->emitDiagnostic(getLoc(), message, 277 MLIRContext::DiagnosticKind::Warning); 278 } 279 280 /// Emit an error about fatal conditions with this operation, reporting up to 281 /// any diagnostic handlers that may be listening. NOTE: This may terminate 282 /// the containing application, only use when the IR is in an inconsistent 283 /// state. 284 void Operation::emitError(const Twine &message) const { 285 getContext()->emitDiagnostic(getLoc(), message, 286 MLIRContext::DiagnosticKind::Error); 287 } 288 289 /// Emit an error with the op name prefixed, like "'dim' op " which is 290 /// convenient for verifiers. 291 bool Operation::emitOpError(const Twine &message) const { 292 emitError(Twine('\'') + getName().getStringRef() + "' op " + message); 293 return true; 294 } 295 296 /// Remove this operation from its parent block and delete it. 297 void Operation::erase() { 298 if (auto *inst = llvm::dyn_cast<Instruction>(this)) 299 return inst->erase(); 300 return llvm::cast<OperationStmt>(this)->erase(); 301 } 302 303 /// Attempt to constant fold this operation with the specified constant 304 /// operand values. If successful, this returns false and fills in the 305 /// results vector. If not, this returns true and results is unspecified. 306 bool Operation::constantFold(ArrayRef<Attribute> operands, 307 SmallVectorImpl<Attribute> &results) const { 308 if (auto *abstractOp = getAbstractOperation()) { 309 // If we have a registered operation definition matching this one, use it to 310 // try to constant fold the operation. 311 if (!abstractOp->constantFoldHook(this, operands, results)) 312 return false; 313 314 // Otherwise, fall back on the dialect hook to handle it. 315 return abstractOp->dialect.constantFoldHook(this, operands, results); 316 } 317 318 // If this operation hasn't been registered or doesn't have abstract 319 // operation, fall back to a dialect which matches the prefix. 320 auto opName = getName().getStringRef(); 321 if (auto *dialect = getContext()->getRegisteredDialect(opName)) { 322 return dialect->constantFoldHook(this, operands, results); 323 } 324 325 return true; 326 } 327 328 void Operation::print(raw_ostream &os) const { 329 if (auto *inst = llvm::dyn_cast<Instruction>(this)) 330 return inst->print(os); 331 return llvm::cast<OperationStmt>(this)->print(os); 332 } 333 334 void Operation::dump() const { 335 if (auto *inst = llvm::dyn_cast<Instruction>(this)) 336 return inst->dump(); 337 return llvm::cast<OperationStmt>(this)->dump(); 338 } 339 340 /// Methods for support type inquiry through isa, cast, and dyn_cast. 341 bool Operation::classof(const Statement *stmt) { 342 return stmt->getKind() == Statement::Kind::Operation; 343 } 344 bool Operation::classof(const IROperandOwner *ptr) { 345 return ptr->getKind() == IROperandOwner::Kind::Instruction || 346 ptr->getKind() == IROperandOwner::Kind::OperationStmt; 347 } 348 349 /// We need to teach the LLVM cast/dyn_cast etc logic how to cast from an 350 /// IROperandOwner* to Operation*. This can't be done with a simple pointer to 351 /// pointer cast because the pointer adjustment depends on whether the Owner is 352 /// dynamically an Instruction or Statement, because of multiple inheritance. 353 Operation * 354 llvm::cast_convert_val<mlir::Operation, mlir::IROperandOwner *, 355 mlir::IROperandOwner *>::doit(const mlir::IROperandOwner 356 *value) { 357 const Operation *op; 358 if (auto *ptr = dyn_cast<OperationStmt>(value)) 359 op = ptr; 360 else 361 op = cast<Instruction>(value); 362 return const_cast<Operation *>(op); 363 } 364 365 //===----------------------------------------------------------------------===// 366 // OpState trait class. 367 //===----------------------------------------------------------------------===// 368 369 // The fallback for the parser is to reject the short form. 370 bool OpState::parse(OpAsmParser *parser, OperationState *result) { 371 return parser->emitError(parser->getNameLoc(), "has no concise form"); 372 } 373 374 // The fallback for the printer is to print it the longhand form. 375 void OpState::print(OpAsmPrinter *p) const { 376 p->printDefaultOp(getOperation()); 377 } 378 379 /// Emit an error about fatal conditions with this operation, reporting up to 380 /// any diagnostic handlers that may be listening. NOTE: This may terminate 381 /// the containing application, only use when the IR is in an inconsistent 382 /// state. 383 void OpState::emitError(const Twine &message) const { 384 getOperation()->emitError(message); 385 } 386 387 /// Emit an error with the op name prefixed, like "'dim' op " which is 388 /// convenient for verifiers. 389 bool OpState::emitOpError(const Twine &message) const { 390 return getOperation()->emitOpError(message); 391 } 392 393 /// Emit a warning about this operation, reporting up to any diagnostic 394 /// handlers that may be listening. 395 void OpState::emitWarning(const Twine &message) const { 396 getOperation()->emitWarning(message); 397 } 398 399 /// Emit a note about this operation, reporting up to any diagnostic 400 /// handlers that may be listening. 401 void OpState::emitNote(const Twine &message) const { 402 getOperation()->emitNote(message); 403 } 404 405 //===----------------------------------------------------------------------===// 406 // Op Trait implementations 407 //===----------------------------------------------------------------------===// 408 409 bool OpTrait::impl::verifyZeroOperands(const Operation *op) { 410 if (op->getNumOperands() != 0) 411 return op->emitOpError("requires zero operands"); 412 return false; 413 } 414 415 bool OpTrait::impl::verifyOneOperand(const Operation *op) { 416 if (op->getNumOperands() != 1) 417 return op->emitOpError("requires a single operand"); 418 return false; 419 } 420 421 bool OpTrait::impl::verifyNOperands(const Operation *op, unsigned numOperands) { 422 if (op->getNumOperands() != numOperands) { 423 return op->emitOpError("expected " + Twine(numOperands) + 424 " operands, but found " + 425 Twine(op->getNumOperands())); 426 } 427 return false; 428 } 429 430 bool OpTrait::impl::verifyAtLeastNOperands(const Operation *op, 431 unsigned numOperands) { 432 if (op->getNumOperands() < numOperands) 433 return op->emitOpError("expected " + Twine(numOperands) + 434 " or more operands"); 435 return false; 436 } 437 438 /// If this is a vector type, or a tensor type, return the scalar element type 439 /// that it is built around, otherwise return the type unmodified. 440 static Type getTensorOrVectorElementType(Type type) { 441 if (auto vec = type.dyn_cast<VectorType>()) 442 return vec.getElementType(); 443 444 // Look through tensor<vector<...>> to find the underlying element type. 445 if (auto tensor = type.dyn_cast<TensorType>()) 446 return getTensorOrVectorElementType(tensor.getElementType()); 447 return type; 448 } 449 450 // Checks if the given type is an integer or an index type. Following LLVM's 451 // convention, returns true if the check fails and false otherwise. 452 static inline bool checkIntegerLikeType(Type type) { 453 return !(type.isa<IntegerType>() || type.isa<IndexType>()); 454 } 455 456 bool OpTrait::impl::verifyOperandsAreIntegerLike(const Operation *op) { 457 for (auto *operand : op->getOperands()) { 458 auto type = getTensorOrVectorElementType(operand->getType()); 459 if (checkIntegerLikeType(type)) 460 return op->emitOpError("requires an integer or index type"); 461 } 462 return false; 463 } 464 465 bool OpTrait::impl::verifySameTypeOperands(const Operation *op) { 466 // Zero or one operand always have the "same" type. 467 unsigned nOperands = op->getNumOperands(); 468 if (nOperands < 2) 469 return false; 470 471 auto type = op->getOperand(0)->getType(); 472 for (unsigned i = 1; i < nOperands; ++i) { 473 if (op->getOperand(i)->getType() != type) 474 return op->emitOpError("requires all operands to have the same type"); 475 } 476 return false; 477 } 478 479 bool OpTrait::impl::verifyZeroResult(const Operation *op) { 480 if (op->getNumResults() != 0) 481 return op->emitOpError("requires zero results"); 482 return false; 483 } 484 485 bool OpTrait::impl::verifyOneResult(const Operation *op) { 486 if (op->getNumResults() != 1) 487 return op->emitOpError("requires one result"); 488 return false; 489 } 490 491 bool OpTrait::impl::verifyNResults(const Operation *op, unsigned numOperands) { 492 if (op->getNumResults() != numOperands) 493 return op->emitOpError("expected " + Twine(numOperands) + " results"); 494 return false; 495 } 496 497 bool OpTrait::impl::verifyAtLeastNResults(const Operation *op, 498 unsigned numOperands) { 499 if (op->getNumResults() < numOperands) 500 return op->emitOpError("expected " + Twine(numOperands) + 501 " or more results"); 502 return false; 503 } 504 505 /// Returns false if the given two types have the same shape. That is, 506 /// they are both scalars, or they are both vectors / ranked tensors with 507 /// the same dimension specifications. The element type does not matter. 508 static bool verifyShapeMatch(Type type1, Type type2) { 509 // Check scalar cases 510 if (type1.isa<IntegerType>() || type1.isa<FloatType>() || 511 type1.isa<IndexType>()) 512 return !(type2.isa<IntegerType>() || type2.isa<FloatType>() || 513 type2.isa<IndexType>()); 514 515 // Check unranked tensor cases 516 if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) 517 return true; 518 519 // Check normal vector/tensor cases 520 if (auto vtType1 = type1.dyn_cast<VectorOrTensorType>()) { 521 auto vtType2 = type2.dyn_cast<VectorOrTensorType>(); 522 return !(vtType2 && vtType1.getShape() == vtType2.getShape()); 523 } 524 525 return false; 526 } 527 528 bool OpTrait::impl::verifySameOperandsAndResultShape(const Operation *op) { 529 if (op->getNumOperands() == 0 || op->getNumResults() == 0) 530 return true; 531 532 auto type = op->getOperand(0)->getType(); 533 for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) { 534 if (verifyShapeMatch(op->getResult(i)->getType(), type)) 535 return op->emitOpError( 536 "requires the same shape for all operands and results"); 537 } 538 for (unsigned i = 1, e = op->getNumOperands(); i < e; ++i) { 539 if (verifyShapeMatch(op->getOperand(i)->getType(), type)) 540 return op->emitOpError( 541 "requires the same shape for all operands and results"); 542 } 543 return false; 544 } 545 546 bool OpTrait::impl::verifySameOperandsAndResultType(const Operation *op) { 547 if (op->getNumOperands() == 0 || op->getNumResults() == 0) 548 return true; 549 550 auto type = op->getResult(0)->getType(); 551 for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) { 552 if (op->getResult(i)->getType() != type) 553 return op->emitOpError( 554 "requires the same type for all operands and results"); 555 } 556 for (unsigned i = 0, e = op->getNumOperands(); i < e; ++i) { 557 if (op->getOperand(i)->getType() != type) 558 return op->emitOpError( 559 "requires the same type for all operands and results"); 560 } 561 return false; 562 } 563 564 static bool verifyBBArguments( 565 llvm::iterator_range<Operation::const_operand_iterator> operands, 566 const BasicBlock *destBB, const Operation *op) { 567 unsigned operandCount = std::distance(operands.begin(), operands.end()); 568 if (operandCount != destBB->getNumArguments()) { 569 op->emitError("branch has " + Twine(operandCount) + 570 " operands, but target block has " + 571 Twine(destBB->getNumArguments())); 572 return true; 573 } 574 575 auto operandIt = operands.begin(); 576 for (unsigned i = 0, e = operandCount; i != e; ++i, ++operandIt) { 577 if ((*operandIt)->getType() != destBB->getArgument(i)->getType()) { 578 op->emitError("type mismatch in bb argument #" + Twine(i)); 579 return true; 580 } 581 } 582 583 return false; 584 } 585 586 static bool verifyTerminatorSuccessors(const Operation *op) { 587 // Verify that the operands lines up with the BB arguments in the successor. 588 const Function *fn = op->getOperationFunction(); 589 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) { 590 auto *succ = op->getSuccessor(i); 591 if (succ->getFunction() != fn) { 592 op->emitError("reference to block defined in another function"); 593 return true; 594 } 595 if (verifyBBArguments(op->getSuccessorOperands(i), succ, op)) 596 return true; 597 } 598 return false; 599 } 600 601 bool OpTrait::impl::verifyIsTerminator(const Operation *op) { 602 // Verify that the operation is at the end of the respective parent block. 603 if (auto *stmt = dyn_cast<OperationStmt>(op)) { 604 StmtBlock *block = stmt->getBlock(); 605 if (!block || !isa<MLFunction>(block) || &block->back() != stmt) 606 return op->emitOpError("must be the last statement in the ML function"); 607 } else { 608 const Instruction *inst = cast<Instruction>(op); 609 const BasicBlock *block = inst->getBlock(); 610 if (!block || &block->back() != inst) 611 return op->emitOpError( 612 "must be the last instruction in the parent basic block."); 613 } 614 615 // Verify the state of the successor blocks. 616 if (op->getNumSuccessors() != 0 && verifyTerminatorSuccessors(op)) 617 return true; 618 return false; 619 } 620 621 bool OpTrait::impl::verifyResultsAreBoolLike(const Operation *op) { 622 for (auto *result : op->getResults()) { 623 auto elementType = getTensorOrVectorElementType(result->getType()); 624 auto intType = elementType.dyn_cast<IntegerType>(); 625 bool isBoolType = intType && intType.getWidth() == 1; 626 if (!isBoolType) 627 return op->emitOpError("requires a bool result type"); 628 } 629 630 return false; 631 } 632 633 bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) { 634 for (auto *result : op->getResults()) { 635 if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>()) 636 return op->emitOpError("requires a floating point type"); 637 } 638 639 return false; 640 } 641 642 bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) { 643 for (auto *result : op->getResults()) { 644 auto type = getTensorOrVectorElementType(result->getType()); 645 if (checkIntegerLikeType(type)) 646 return op->emitOpError("requires an integer or index type"); 647 } 648 return false; 649 } 650 651 //===----------------------------------------------------------------------===// 652 // BinaryOp implementation 653 //===----------------------------------------------------------------------===// 654 655 // These functions are out-of-line implementations of the methods in BinaryOp, 656 // which avoids them being template instantiated/duplicated. 657 658 void impl::buildBinaryOp(Builder *builder, OperationState *result, 659 SSAValue *lhs, SSAValue *rhs) { 660 assert(lhs->getType() == rhs->getType()); 661 result->addOperands({lhs, rhs}); 662 result->types.push_back(lhs->getType()); 663 } 664 665 bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) { 666 SmallVector<OpAsmParser::OperandType, 2> ops; 667 Type type; 668 return parser->parseOperandList(ops, 2) || 669 parser->parseOptionalAttributeDict(result->attributes) || 670 parser->parseColonType(type) || 671 parser->resolveOperands(ops, type, result->operands) || 672 parser->addTypeToList(type, result->types); 673 } 674 675 void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) { 676 *p << op->getName() << ' ' << *op->getOperand(0) << ", " 677 << *op->getOperand(1); 678 p->printOptionalAttrDict(op->getAttrs()); 679 *p << " : " << op->getResult(0)->getType(); 680 } 681 682 //===----------------------------------------------------------------------===// 683 // CastOp implementation 684 //===----------------------------------------------------------------------===// 685 686 void impl::buildCastOp(Builder *builder, OperationState *result, 687 SSAValue *source, Type destType) { 688 result->addOperands(source); 689 result->addTypes(destType); 690 } 691 692 bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) { 693 OpAsmParser::OperandType srcInfo; 694 Type srcType, dstType; 695 return parser->parseOperand(srcInfo) || parser->parseColonType(srcType) || 696 parser->resolveOperand(srcInfo, srcType, result->operands) || 697 parser->parseKeywordType("to", dstType) || 698 parser->addTypeToList(dstType, result->types); 699 } 700 701 void impl::printCastOp(const Operation *op, OpAsmPrinter *p) { 702 *p << op->getName() << ' ' << *op->getOperand(0) << " : " 703 << op->getOperand(0)->getType() << " to " << op->getResult(0)->getType(); 704 } 705