xref: /llvm-project/mlir/lib/Interfaces/CastInterfaces.cpp (revision e66f2beba8b38b148d3a892326a7133c388ffbfb)
16ccf2d62SRiver Riddle //===- CastInterfaces.cpp -------------------------------------------------===//
26ccf2d62SRiver Riddle //
36ccf2d62SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
46ccf2d62SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
56ccf2d62SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66ccf2d62SRiver Riddle //
76ccf2d62SRiver Riddle //===----------------------------------------------------------------------===//
86ccf2d62SRiver Riddle 
96ccf2d62SRiver Riddle #include "mlir/Interfaces/CastInterfaces.h"
106ccf2d62SRiver Riddle 
11*e66f2bebSMatthias Springer #include "mlir/IR/BuiltinDialect.h"
12*e66f2bebSMatthias Springer #include "mlir/IR/BuiltinOps.h"
13*e66f2bebSMatthias Springer 
146ccf2d62SRiver Riddle using namespace mlir;
156ccf2d62SRiver Riddle 
166ccf2d62SRiver Riddle //===----------------------------------------------------------------------===//
17*e66f2bebSMatthias Springer // Helper functions for CastOpInterface
18*e66f2bebSMatthias Springer //===----------------------------------------------------------------------===//
19*e66f2bebSMatthias Springer 
20*e66f2bebSMatthias Springer /// Attempt to fold the given cast operation.
21*e66f2bebSMatthias Springer LogicalResult
foldCastInterfaceOp(Operation * op,ArrayRef<Attribute> attrOperands,SmallVectorImpl<OpFoldResult> & foldResults)22*e66f2bebSMatthias Springer impl::foldCastInterfaceOp(Operation *op, ArrayRef<Attribute> attrOperands,
23*e66f2bebSMatthias Springer                           SmallVectorImpl<OpFoldResult> &foldResults) {
24*e66f2bebSMatthias Springer   OperandRange operands = op->getOperands();
25*e66f2bebSMatthias Springer   if (operands.empty())
26*e66f2bebSMatthias Springer     return failure();
27*e66f2bebSMatthias Springer   ResultRange results = op->getResults();
28*e66f2bebSMatthias Springer 
29*e66f2bebSMatthias Springer   // Check for the case where the input and output types match 1-1.
30*e66f2bebSMatthias Springer   if (operands.getTypes() == results.getTypes()) {
31*e66f2bebSMatthias Springer     foldResults.append(operands.begin(), operands.end());
32*e66f2bebSMatthias Springer     return success();
33*e66f2bebSMatthias Springer   }
34*e66f2bebSMatthias Springer 
35*e66f2bebSMatthias Springer   return failure();
36*e66f2bebSMatthias Springer }
37*e66f2bebSMatthias Springer 
38*e66f2bebSMatthias Springer /// Attempt to verify the given cast operation.
verifyCastInterfaceOp(Operation * op)39*e66f2bebSMatthias Springer LogicalResult impl::verifyCastInterfaceOp(Operation *op) {
40*e66f2bebSMatthias Springer   auto resultTypes = op->getResultTypes();
41*e66f2bebSMatthias Springer   if (resultTypes.empty())
42*e66f2bebSMatthias Springer     return op->emitOpError()
43*e66f2bebSMatthias Springer            << "expected at least one result for cast operation";
44*e66f2bebSMatthias Springer 
45*e66f2bebSMatthias Springer   auto operandTypes = op->getOperandTypes();
46*e66f2bebSMatthias Springer   if (!cast<CastOpInterface>(op).areCastCompatible(operandTypes, resultTypes)) {
47*e66f2bebSMatthias Springer     InFlightDiagnostic diag = op->emitOpError("operand type");
48*e66f2bebSMatthias Springer     if (operandTypes.empty())
49*e66f2bebSMatthias Springer       diag << "s []";
50*e66f2bebSMatthias Springer     else if (llvm::size(operandTypes) == 1)
51*e66f2bebSMatthias Springer       diag << " " << *operandTypes.begin();
52*e66f2bebSMatthias Springer     else
53*e66f2bebSMatthias Springer       diag << "s " << operandTypes;
54*e66f2bebSMatthias Springer     return diag << " and result type" << (resultTypes.size() == 1 ? " " : "s ")
55*e66f2bebSMatthias Springer                 << resultTypes << " are cast incompatible";
56*e66f2bebSMatthias Springer   }
57*e66f2bebSMatthias Springer 
58*e66f2bebSMatthias Springer   return success();
59*e66f2bebSMatthias Springer }
60*e66f2bebSMatthias Springer 
61*e66f2bebSMatthias Springer //===----------------------------------------------------------------------===//
62*e66f2bebSMatthias Springer // External model for BuiltinDialect ops
63*e66f2bebSMatthias Springer //===----------------------------------------------------------------------===//
64*e66f2bebSMatthias Springer 
65*e66f2bebSMatthias Springer namespace mlir {
66*e66f2bebSMatthias Springer namespace {
67*e66f2bebSMatthias Springer // This interface cannot be implemented directly on the op because the IR build
68*e66f2bebSMatthias Springer // unit cannot depend on the Interfaces build unit.
69*e66f2bebSMatthias Springer struct UnrealizedConversionCastOpInterface
70*e66f2bebSMatthias Springer     : CastOpInterface::ExternalModel<UnrealizedConversionCastOpInterface,
71*e66f2bebSMatthias Springer                                      UnrealizedConversionCastOp> {
areCastCompatiblemlir::__anond689e21e0111::UnrealizedConversionCastOpInterface72*e66f2bebSMatthias Springer   static bool areCastCompatible(TypeRange inputs, TypeRange outputs) {
73*e66f2bebSMatthias Springer     // `UnrealizedConversionCastOp` is agnostic of the input/output types.
74*e66f2bebSMatthias Springer     return true;
75*e66f2bebSMatthias Springer   }
76*e66f2bebSMatthias Springer };
77*e66f2bebSMatthias Springer } // namespace
78*e66f2bebSMatthias Springer } // namespace mlir
79*e66f2bebSMatthias Springer 
registerCastOpInterfaceExternalModels(DialectRegistry & registry)80*e66f2bebSMatthias Springer void mlir::builtin::registerCastOpInterfaceExternalModels(
81*e66f2bebSMatthias Springer     DialectRegistry &registry) {
82*e66f2bebSMatthias Springer   registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
83*e66f2bebSMatthias Springer     UnrealizedConversionCastOp::attachInterface<
84*e66f2bebSMatthias Springer         UnrealizedConversionCastOpInterface>(*ctx);
85*e66f2bebSMatthias Springer   });
86*e66f2bebSMatthias Springer }
87*e66f2bebSMatthias Springer 
88*e66f2bebSMatthias Springer //===----------------------------------------------------------------------===//
896ccf2d62SRiver Riddle // Table-generated class definitions
906ccf2d62SRiver Riddle //===----------------------------------------------------------------------===//
916ccf2d62SRiver Riddle 
926ccf2d62SRiver Riddle #include "mlir/Interfaces/CastInterfaces.cpp.inc"
93