xref: /llvm-project/mlir/lib/IR/Operation.cpp (revision 1f5330ac9028b577ff5496b57a96757512ca8dda)
1 //===- Operation.cpp - MLIR Operation Class -------------------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 
18 #include "mlir/IR/Operation.h"
19 #include "AttributeListStorage.h"
20 #include "mlir/IR/CFGFunction.h"
21 #include "mlir/IR/Dialect.h"
22 #include "mlir/IR/Instructions.h"
23 #include "mlir/IR/MLFunction.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/IR/OpDefinition.h"
26 #include "mlir/IR/OpImplementation.h"
27 #include "mlir/IR/Statements.h"
28 
29 using namespace mlir;
30 
31 /// Form the OperationName for an op with the specified string.  This either is
32 /// a reference to an AbstractOperation if one is known, or a uniqued Identifier
33 /// if not.
34 OperationName::OperationName(StringRef name, MLIRContext *context) {
35   if (auto *op = AbstractOperation::lookup(name, context))
36     representation = op;
37   else
38     representation = Identifier::get(name, context);
39 }
40 
41 /// Return the name of this operation.  This always succeeds.
42 StringRef OperationName::getStringRef() const {
43   if (auto *op = representation.dyn_cast<const AbstractOperation *>())
44     return op->name;
45   return representation.get<Identifier>().strref();
46 }
47 
48 const AbstractOperation *OperationName::getAbstractOperation() const {
49   return representation.dyn_cast<const AbstractOperation *>();
50 }
51 
52 OperationName OperationName::getFromOpaquePointer(void *pointer) {
53   return OperationName(RepresentationUnion::getFromOpaqueValue(pointer));
54 }
55 
56 OpAsmParser::~OpAsmParser() {}
57 
58 //===----------------------------------------------------------------------===//
59 // Operation class
60 //===----------------------------------------------------------------------===//
61 
62 Operation::Operation(bool isInstruction, OperationName name,
63                      ArrayRef<NamedAttribute> attrs, MLIRContext *context)
64     : nameAndIsInstruction(name, isInstruction) {
65   this->attrs = AttributeListStorage::get(attrs, context);
66 
67 #ifndef NDEBUG
68   for (auto elt : attrs)
69     assert(elt.second != nullptr && "Attributes cannot have null entries");
70 #endif
71 }
72 
73 Operation::~Operation() {}
74 
75 /// Return the context this operation is associated with.
76 MLIRContext *Operation::getContext() const {
77   if (auto *inst = llvm::dyn_cast<Instruction>(this))
78     return inst->getContext();
79   return llvm::cast<OperationStmt>(this)->getContext();
80 }
81 
82 /// The source location the operation was defined or derived from.  Note that
83 /// it is possible for this pointer to be null.
84 Location Operation::getLoc() const {
85   if (auto *inst = llvm::dyn_cast<Instruction>(this))
86     return inst->getLoc();
87   return llvm::cast<OperationStmt>(this)->getLoc();
88 }
89 
90 /// Set the source location the operation was defined or derived from.
91 void Operation::setLoc(Location loc) {
92   if (auto *inst = llvm::dyn_cast<Instruction>(this))
93     inst->setLoc(loc);
94   else
95     llvm::cast<OperationStmt>(this)->setLoc(loc);
96 }
97 
98 /// Return the function this operation is defined in.
99 Function *Operation::getOperationFunction() {
100   if (auto *inst = llvm::dyn_cast<Instruction>(this))
101     return inst->getFunction();
102   return llvm::cast<OperationStmt>(this)->findFunction();
103 }
104 
105 /// Return the number of operands this operation has.
106 unsigned Operation::getNumOperands() const {
107   if (auto *inst = llvm::dyn_cast<Instruction>(this))
108     return inst->getNumOperands();
109 
110   return llvm::cast<OperationStmt>(this)->getNumOperands();
111 }
112 
113 SSAValue *Operation::getOperand(unsigned idx) {
114   if (auto *inst = llvm::dyn_cast<Instruction>(this))
115     return inst->getOperand(idx);
116 
117   return llvm::cast<OperationStmt>(this)->getOperand(idx);
118 }
119 
120 void Operation::setOperand(unsigned idx, SSAValue *value) {
121   if (auto *inst = llvm::dyn_cast<Instruction>(this)) {
122     inst->setOperand(idx, llvm::cast<CFGValue>(value));
123   } else {
124     auto *stmt = llvm::cast<OperationStmt>(this);
125     stmt->setOperand(idx, llvm::cast<MLValue>(value));
126   }
127 }
128 
129 /// Return the number of results this operation has.
130 unsigned Operation::getNumResults() const {
131   if (auto *inst = llvm::dyn_cast<Instruction>(this))
132     return inst->getNumResults();
133 
134   return llvm::cast<OperationStmt>(this)->getNumResults();
135 }
136 
137 /// Return the indicated result.
138 SSAValue *Operation::getResult(unsigned idx) {
139   if (auto *inst = llvm::dyn_cast<Instruction>(this))
140     return inst->getResult(idx);
141 
142   return llvm::cast<OperationStmt>(this)->getResult(idx);
143 }
144 
145 unsigned Operation::getNumSuccessors() const {
146   assert(isTerminator() && "Only terminators have successors.");
147   if (llvm::isa<Instruction>(this))
148     return llvm::cast<Instruction>(this)->getNumSuccessors();
149 
150   // OperationStmt currently only has a return terminator.
151   assert(llvm::cast<OperationStmt>(this)->isReturn() &&
152          "Unhandled OperationStmt terminator.");
153   return 0;
154 }
155 
156 unsigned Operation::getNumSuccessorOperands(unsigned index) const {
157   assert(isTerminator() && "Only terminators have successors.");
158   assert(llvm::isa<Instruction>(this) && "Only instructions have successors.");
159   return llvm::cast<Instruction>(this)->getNumSuccessorOperands(index);
160 }
161 BasicBlock *Operation::getSuccessor(unsigned index) {
162   assert(isTerminator() && "Only terminators have successors.");
163   assert(llvm::isa<Instruction>(this) &&
164          "Only instructions have basic block successors.");
165   return llvm::cast<Instruction>(this)->getSuccessor(index);
166 }
167 void Operation::setSuccessor(BasicBlock *block, unsigned index) {
168   assert(isTerminator() && "Only terminators have successors.");
169   assert(llvm::isa<Instruction>(this) &&
170          "Only instructions have basic block successors.");
171   llvm::cast<Instruction>(this)->setSuccessor(block, index);
172 }
173 void Operation::addSuccessorOperand(unsigned index, SSAValue *value) {
174   assert(isTerminator() && "Only terminators have successors.");
175   assert(llvm::isa<Instruction>(this) && "Only instructions have successors.");
176   return llvm::cast<Instruction>(this)->addSuccessorOperand(
177       index, llvm::cast<CFGValue>(value));
178 }
179 void Operation::eraseSuccessorOperand(unsigned succIndex, unsigned opIndex) {
180   assert(isTerminator() && "Only terminators have successors.");
181   assert(llvm::isa<Instruction>(this) && "Only instructions have successors.");
182   return llvm::cast<Instruction>(this)->eraseSuccessorOperand(succIndex,
183                                                               opIndex);
184 }
185 auto Operation::getSuccessorOperands(unsigned index) const
186     -> llvm::iterator_range<const_operand_iterator> {
187   assert(isTerminator() && "Only terminators have successors.");
188   assert(llvm::isa<Instruction>(this) && "Only instructions have successors.");
189   unsigned succOperandIndex =
190       llvm::cast<Instruction>(this)->getSuccessorOperandIndex(index);
191   return {const_operand_iterator(this, succOperandIndex),
192           const_operand_iterator(this, succOperandIndex +
193                                            getNumSuccessorOperands(index))};
194 }
195 auto Operation::getSuccessorOperands(unsigned index)
196     -> llvm::iterator_range<operand_iterator> {
197   assert(isTerminator() && "Only terminators have successors.");
198   assert(llvm::isa<Instruction>(this) && "Only instructions have successors.");
199   unsigned succOperandIndex =
200       llvm::cast<Instruction>(this)->getSuccessorOperandIndex(index);
201   return {operand_iterator(this, succOperandIndex),
202           operand_iterator(this,
203                            succOperandIndex + getNumSuccessorOperands(index))};
204 }
205 
206 /// Return true if there are no users of any results of this operation.
207 bool Operation::use_empty() const {
208   for (auto *result : getResults())
209     if (!result->use_empty())
210       return false;
211   return true;
212 }
213 
214 void Operation::moveBefore(Operation *existingOp) {
215   if (auto *inst = llvm::dyn_cast<Instruction>(this))
216     return inst->moveBefore(llvm::cast<Instruction>(existingOp));
217   return llvm::cast<OperationStmt>(this)->moveBefore(
218       llvm::cast<OperationStmt>(existingOp));
219 }
220 
221 ArrayRef<NamedAttribute> Operation::getAttrs() const {
222   if (!attrs)
223     return {};
224   return attrs->getElements();
225 }
226 
227 /// If an attribute exists with the specified name, change it to the new
228 /// value.  Otherwise, add a new attribute with the specified name/value.
229 void Operation::setAttr(Identifier name, Attribute value) {
230   assert(value && "attributes may never be null");
231   auto origAttrs = getAttrs();
232 
233   SmallVector<NamedAttribute, 8> newAttrs(origAttrs.begin(), origAttrs.end());
234   auto *context = getContext();
235 
236   // If we already have this attribute, replace it.
237   for (auto &elt : newAttrs)
238     if (elt.first == name) {
239       elt.second = value;
240       attrs = AttributeListStorage::get(newAttrs, context);
241       return;
242     }
243 
244   // Otherwise, add it.
245   newAttrs.push_back({name, value});
246   attrs = AttributeListStorage::get(newAttrs, context);
247 }
248 
249 /// Remove the attribute with the specified name if it exists.  The return
250 /// value indicates whether the attribute was present or not.
251 auto Operation::removeAttr(Identifier name) -> RemoveResult {
252   auto origAttrs = getAttrs();
253   for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
254     if (origAttrs[i].first == name) {
255       SmallVector<NamedAttribute, 8> newAttrs;
256       newAttrs.reserve(origAttrs.size() - 1);
257       newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
258       newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
259       attrs = AttributeListStorage::get(newAttrs, getContext());
260       return RemoveResult::Removed;
261     }
262   }
263   return RemoveResult::NotFound;
264 }
265 
266 /// Emit a note about this operation, reporting up to any diagnostic
267 /// handlers that may be listening.
268 void Operation::emitNote(const Twine &message) const {
269   getContext()->emitDiagnostic(getLoc(), message,
270                                MLIRContext::DiagnosticKind::Note);
271 }
272 
273 /// Emit a warning about this operation, reporting up to any diagnostic
274 /// handlers that may be listening.
275 void Operation::emitWarning(const Twine &message) const {
276   getContext()->emitDiagnostic(getLoc(), message,
277                                MLIRContext::DiagnosticKind::Warning);
278 }
279 
280 /// Emit an error about fatal conditions with this operation, reporting up to
281 /// any diagnostic handlers that may be listening.  NOTE: This may terminate
282 /// the containing application, only use when the IR is in an inconsistent
283 /// state.
284 void Operation::emitError(const Twine &message) const {
285   getContext()->emitDiagnostic(getLoc(), message,
286                                MLIRContext::DiagnosticKind::Error);
287 }
288 
289 /// Emit an error with the op name prefixed, like "'dim' op " which is
290 /// convenient for verifiers.
291 bool Operation::emitOpError(const Twine &message) const {
292   emitError(Twine('\'') + getName().getStringRef() + "' op " + message);
293   return true;
294 }
295 
296 /// Remove this operation from its parent block and delete it.
297 void Operation::erase() {
298   if (auto *inst = llvm::dyn_cast<Instruction>(this))
299     return inst->erase();
300   return llvm::cast<OperationStmt>(this)->erase();
301 }
302 
303 /// Attempt to constant fold this operation with the specified constant
304 /// operand values.  If successful, this returns false and fills in the
305 /// results vector.  If not, this returns true and results is unspecified.
306 bool Operation::constantFold(ArrayRef<Attribute> operands,
307                              SmallVectorImpl<Attribute> &results) const {
308   if (auto *abstractOp = getAbstractOperation()) {
309     // If we have a registered operation definition matching this one, use it to
310     // try to constant fold the operation.
311     if (!abstractOp->constantFoldHook(this, operands, results))
312       return false;
313 
314     // Otherwise, fall back on the dialect hook to handle it.
315     return abstractOp->dialect.constantFoldHook(this, operands, results);
316   }
317 
318   // If this operation hasn't been registered or doesn't have abstract
319   // operation, fall back to a dialect which matches the prefix.
320   auto opName = getName().getStringRef();
321   if (auto *dialect = getContext()->getRegisteredDialect(opName)) {
322     return dialect->constantFoldHook(this, operands, results);
323   }
324 
325   return true;
326 }
327 
328 void Operation::print(raw_ostream &os) const {
329   if (auto *inst = llvm::dyn_cast<Instruction>(this))
330     return inst->print(os);
331   return llvm::cast<OperationStmt>(this)->print(os);
332 }
333 
334 void Operation::dump() const {
335   if (auto *inst = llvm::dyn_cast<Instruction>(this))
336     return inst->dump();
337   return llvm::cast<OperationStmt>(this)->dump();
338 }
339 
340 /// Methods for support type inquiry through isa, cast, and dyn_cast.
341 bool Operation::classof(const Statement *stmt) {
342   return stmt->getKind() == Statement::Kind::Operation;
343 }
344 bool Operation::classof(const IROperandOwner *ptr) {
345   return ptr->getKind() == IROperandOwner::Kind::Instruction ||
346          ptr->getKind() == IROperandOwner::Kind::OperationStmt;
347 }
348 
349 /// We need to teach the LLVM cast/dyn_cast etc logic how to cast from an
350 /// IROperandOwner* to Operation*.  This can't be done with a simple pointer to
351 /// pointer cast because the pointer adjustment depends on whether the Owner is
352 /// dynamically an Instruction or Statement, because of multiple inheritance.
353 Operation *
354 llvm::cast_convert_val<mlir::Operation, mlir::IROperandOwner *,
355                        mlir::IROperandOwner *>::doit(const mlir::IROperandOwner
356                                                          *value) {
357   const Operation *op;
358   if (auto *ptr = dyn_cast<OperationStmt>(value))
359     op = ptr;
360   else
361     op = cast<Instruction>(value);
362   return const_cast<Operation *>(op);
363 }
364 
365 //===----------------------------------------------------------------------===//
366 // OpState trait class.
367 //===----------------------------------------------------------------------===//
368 
369 // The fallback for the parser is to reject the short form.
370 bool OpState::parse(OpAsmParser *parser, OperationState *result) {
371   return parser->emitError(parser->getNameLoc(), "has no concise form");
372 }
373 
374 // The fallback for the printer is to print it the longhand form.
375 void OpState::print(OpAsmPrinter *p) const {
376   p->printDefaultOp(getOperation());
377 }
378 
379 /// Emit an error about fatal conditions with this operation, reporting up to
380 /// any diagnostic handlers that may be listening.  NOTE: This may terminate
381 /// the containing application, only use when the IR is in an inconsistent
382 /// state.
383 void OpState::emitError(const Twine &message) const {
384   getOperation()->emitError(message);
385 }
386 
387 /// Emit an error with the op name prefixed, like "'dim' op " which is
388 /// convenient for verifiers.
389 bool OpState::emitOpError(const Twine &message) const {
390   return getOperation()->emitOpError(message);
391 }
392 
393 /// Emit a warning about this operation, reporting up to any diagnostic
394 /// handlers that may be listening.
395 void OpState::emitWarning(const Twine &message) const {
396   getOperation()->emitWarning(message);
397 }
398 
399 /// Emit a note about this operation, reporting up to any diagnostic
400 /// handlers that may be listening.
401 void OpState::emitNote(const Twine &message) const {
402   getOperation()->emitNote(message);
403 }
404 
405 //===----------------------------------------------------------------------===//
406 // Op Trait implementations
407 //===----------------------------------------------------------------------===//
408 
409 bool OpTrait::impl::verifyZeroOperands(const Operation *op) {
410   if (op->getNumOperands() != 0)
411     return op->emitOpError("requires zero operands");
412   return false;
413 }
414 
415 bool OpTrait::impl::verifyOneOperand(const Operation *op) {
416   if (op->getNumOperands() != 1)
417     return op->emitOpError("requires a single operand");
418   return false;
419 }
420 
421 bool OpTrait::impl::verifyNOperands(const Operation *op, unsigned numOperands) {
422   if (op->getNumOperands() != numOperands) {
423     return op->emitOpError("expected " + Twine(numOperands) +
424                            " operands, but found " +
425                            Twine(op->getNumOperands()));
426   }
427   return false;
428 }
429 
430 bool OpTrait::impl::verifyAtLeastNOperands(const Operation *op,
431                                            unsigned numOperands) {
432   if (op->getNumOperands() < numOperands)
433     return op->emitOpError("expected " + Twine(numOperands) +
434                            " or more operands");
435   return false;
436 }
437 
438 /// If this is a vector type, or a tensor type, return the scalar element type
439 /// that it is built around, otherwise return the type unmodified.
440 static Type getTensorOrVectorElementType(Type type) {
441   if (auto vec = type.dyn_cast<VectorType>())
442     return vec.getElementType();
443 
444   // Look through tensor<vector<...>> to find the underlying element type.
445   if (auto tensor = type.dyn_cast<TensorType>())
446     return getTensorOrVectorElementType(tensor.getElementType());
447   return type;
448 }
449 
450 // Checks if the given type is an integer or an index type.  Following LLVM's
451 // convention, returns true if the check fails and false otherwise.
452 static inline bool checkIntegerLikeType(Type type) {
453   return !(type.isa<IntegerType>() || type.isa<IndexType>());
454 }
455 
456 bool OpTrait::impl::verifyOperandsAreIntegerLike(const Operation *op) {
457   for (auto *operand : op->getOperands()) {
458     auto type = getTensorOrVectorElementType(operand->getType());
459     if (checkIntegerLikeType(type))
460       return op->emitOpError("requires an integer or index type");
461   }
462   return false;
463 }
464 
465 bool OpTrait::impl::verifySameTypeOperands(const Operation *op) {
466   // Zero or one operand always have the "same" type.
467   unsigned nOperands = op->getNumOperands();
468   if (nOperands < 2)
469     return false;
470 
471   auto type = op->getOperand(0)->getType();
472   for (unsigned i = 1; i < nOperands; ++i) {
473     if (op->getOperand(i)->getType() != type)
474       return op->emitOpError("requires all operands to have the same type");
475   }
476   return false;
477 }
478 
479 bool OpTrait::impl::verifyZeroResult(const Operation *op) {
480   if (op->getNumResults() != 0)
481     return op->emitOpError("requires zero results");
482   return false;
483 }
484 
485 bool OpTrait::impl::verifyOneResult(const Operation *op) {
486   if (op->getNumResults() != 1)
487     return op->emitOpError("requires one result");
488   return false;
489 }
490 
491 bool OpTrait::impl::verifyNResults(const Operation *op, unsigned numOperands) {
492   if (op->getNumResults() != numOperands)
493     return op->emitOpError("expected " + Twine(numOperands) + " results");
494   return false;
495 }
496 
497 bool OpTrait::impl::verifyAtLeastNResults(const Operation *op,
498                                           unsigned numOperands) {
499   if (op->getNumResults() < numOperands)
500     return op->emitOpError("expected " + Twine(numOperands) +
501                            " or more results");
502   return false;
503 }
504 
505 /// Returns false if the given two types have the same shape. That is,
506 /// they are both scalars, or they are both vectors / ranked tensors with
507 /// the same dimension specifications. The element type does not matter.
508 static bool verifyShapeMatch(Type type1, Type type2) {
509   // Check scalar cases
510   if (type1.isa<IntegerType>() || type1.isa<FloatType>() ||
511       type1.isa<IndexType>())
512     return !(type2.isa<IntegerType>() || type2.isa<FloatType>() ||
513              type2.isa<IndexType>());
514 
515   // Check unranked tensor cases
516   if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>())
517     return true;
518 
519   // Check normal vector/tensor cases
520   if (auto vtType1 = type1.dyn_cast<VectorOrTensorType>()) {
521     auto vtType2 = type2.dyn_cast<VectorOrTensorType>();
522     return !(vtType2 && vtType1.getShape() == vtType2.getShape());
523   }
524 
525   return false;
526 }
527 
528 bool OpTrait::impl::verifySameOperandsAndResultShape(const Operation *op) {
529   if (op->getNumOperands() == 0 || op->getNumResults() == 0)
530     return true;
531 
532   auto type = op->getOperand(0)->getType();
533   for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) {
534     if (verifyShapeMatch(op->getResult(i)->getType(), type))
535       return op->emitOpError(
536           "requires the same shape for all operands and results");
537   }
538   for (unsigned i = 1, e = op->getNumOperands(); i < e; ++i) {
539     if (verifyShapeMatch(op->getOperand(i)->getType(), type))
540       return op->emitOpError(
541           "requires the same shape for all operands and results");
542   }
543   return false;
544 }
545 
546 bool OpTrait::impl::verifySameOperandsAndResultType(const Operation *op) {
547   if (op->getNumOperands() == 0 || op->getNumResults() == 0)
548     return true;
549 
550   auto type = op->getResult(0)->getType();
551   for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) {
552     if (op->getResult(i)->getType() != type)
553       return op->emitOpError(
554           "requires the same type for all operands and results");
555   }
556   for (unsigned i = 0, e = op->getNumOperands(); i < e; ++i) {
557     if (op->getOperand(i)->getType() != type)
558       return op->emitOpError(
559           "requires the same type for all operands and results");
560   }
561   return false;
562 }
563 
564 static bool verifyBBArguments(
565     llvm::iterator_range<Operation::const_operand_iterator> operands,
566     const BasicBlock *destBB, const Operation *op) {
567   unsigned operandCount = std::distance(operands.begin(), operands.end());
568   if (operandCount != destBB->getNumArguments()) {
569     op->emitError("branch has " + Twine(operandCount) +
570                   " operands, but target block has " +
571                   Twine(destBB->getNumArguments()));
572     return true;
573   }
574 
575   auto operandIt = operands.begin();
576   for (unsigned i = 0, e = operandCount; i != e; ++i, ++operandIt) {
577     if ((*operandIt)->getType() != destBB->getArgument(i)->getType()) {
578       op->emitError("type mismatch in bb argument #" + Twine(i));
579       return true;
580     }
581   }
582 
583   return false;
584 }
585 
586 static bool verifyTerminatorSuccessors(const Operation *op) {
587   // Verify that the operands lines up with the BB arguments in the successor.
588   const Function *fn = op->getOperationFunction();
589   for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
590     auto *succ = op->getSuccessor(i);
591     if (succ->getFunction() != fn) {
592       op->emitError("reference to block defined in another function");
593       return true;
594     }
595     if (verifyBBArguments(op->getSuccessorOperands(i), succ, op))
596       return true;
597   }
598   return false;
599 }
600 
601 bool OpTrait::impl::verifyIsTerminator(const Operation *op) {
602   // Verify that the operation is at the end of the respective parent block.
603   if (auto *stmt = dyn_cast<OperationStmt>(op)) {
604     StmtBlock *block = stmt->getBlock();
605     if (!block || !isa<MLFunction>(block) || &block->back() != stmt)
606       return op->emitOpError("must be the last statement in the ML function");
607   } else {
608     const Instruction *inst = cast<Instruction>(op);
609     const BasicBlock *block = inst->getBlock();
610     if (!block || &block->back() != inst)
611       return op->emitOpError(
612           "must be the last instruction in the parent basic block.");
613   }
614 
615   // Verify the state of the successor blocks.
616   if (op->getNumSuccessors() != 0 && verifyTerminatorSuccessors(op))
617     return true;
618   return false;
619 }
620 
621 bool OpTrait::impl::verifyResultsAreBoolLike(const Operation *op) {
622   for (auto *result : op->getResults()) {
623     auto elementType = getTensorOrVectorElementType(result->getType());
624     auto intType = elementType.dyn_cast<IntegerType>();
625     bool isBoolType = intType && intType.getWidth() == 1;
626     if (!isBoolType)
627       return op->emitOpError("requires a bool result type");
628   }
629 
630   return false;
631 }
632 
633 bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) {
634   for (auto *result : op->getResults()) {
635     if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>())
636       return op->emitOpError("requires a floating point type");
637   }
638 
639   return false;
640 }
641 
642 bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) {
643   for (auto *result : op->getResults()) {
644     auto type = getTensorOrVectorElementType(result->getType());
645     if (checkIntegerLikeType(type))
646       return op->emitOpError("requires an integer or index type");
647   }
648   return false;
649 }
650 
651 //===----------------------------------------------------------------------===//
652 // BinaryOp implementation
653 //===----------------------------------------------------------------------===//
654 
655 // These functions are out-of-line implementations of the methods in BinaryOp,
656 // which avoids them being template instantiated/duplicated.
657 
658 void impl::buildBinaryOp(Builder *builder, OperationState *result,
659                          SSAValue *lhs, SSAValue *rhs) {
660   assert(lhs->getType() == rhs->getType());
661   result->addOperands({lhs, rhs});
662   result->types.push_back(lhs->getType());
663 }
664 
665 bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) {
666   SmallVector<OpAsmParser::OperandType, 2> ops;
667   Type type;
668   return parser->parseOperandList(ops, 2) ||
669          parser->parseOptionalAttributeDict(result->attributes) ||
670          parser->parseColonType(type) ||
671          parser->resolveOperands(ops, type, result->operands) ||
672          parser->addTypeToList(type, result->types);
673 }
674 
675 void impl::printBinaryOp(const Operation *op, OpAsmPrinter *p) {
676   *p << op->getName() << ' ' << *op->getOperand(0) << ", "
677      << *op->getOperand(1);
678   p->printOptionalAttrDict(op->getAttrs());
679   *p << " : " << op->getResult(0)->getType();
680 }
681 
682 //===----------------------------------------------------------------------===//
683 // CastOp implementation
684 //===----------------------------------------------------------------------===//
685 
686 void impl::buildCastOp(Builder *builder, OperationState *result,
687                        SSAValue *source, Type destType) {
688   result->addOperands(source);
689   result->addTypes(destType);
690 }
691 
692 bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) {
693   OpAsmParser::OperandType srcInfo;
694   Type srcType, dstType;
695   return parser->parseOperand(srcInfo) || parser->parseColonType(srcType) ||
696          parser->resolveOperand(srcInfo, srcType, result->operands) ||
697          parser->parseKeywordType("to", dstType) ||
698          parser->addTypeToList(dstType, result->types);
699 }
700 
701 void impl::printCastOp(const Operation *op, OpAsmPrinter *p) {
702   *p << op->getName() << ' ' << *op->getOperand(0) << " : "
703      << op->getOperand(0)->getType() << " to " << op->getResult(0)->getType();
704 }
705