1 //===- BuiltinTypes.cpp - C Interface to MLIR Builtin Types ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/BuiltinTypes.h" 10 #include "mlir-c/AffineMap.h" 11 #include "mlir-c/IR.h" 12 #include "mlir-c/Support.h" 13 #include "mlir/CAPI/AffineMap.h" 14 #include "mlir/CAPI/IR.h" 15 #include "mlir/CAPI/Support.h" 16 #include "mlir/IR/AffineMap.h" 17 #include "mlir/IR/BuiltinTypes.h" 18 #include "mlir/IR/Types.h" 19 20 #include <algorithm> 21 22 using namespace mlir; 23 24 //===----------------------------------------------------------------------===// 25 // Integer types. 26 //===----------------------------------------------------------------------===// 27 28 MlirTypeID mlirIntegerTypeGetTypeID() { return wrap(IntegerType::getTypeID()); } 29 30 bool mlirTypeIsAInteger(MlirType type) { 31 return llvm::isa<IntegerType>(unwrap(type)); 32 } 33 34 MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) { 35 return wrap(IntegerType::get(unwrap(ctx), bitwidth)); 36 } 37 38 MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) { 39 return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Signed)); 40 } 41 42 MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) { 43 return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Unsigned)); 44 } 45 46 unsigned mlirIntegerTypeGetWidth(MlirType type) { 47 return llvm::cast<IntegerType>(unwrap(type)).getWidth(); 48 } 49 50 bool mlirIntegerTypeIsSignless(MlirType type) { 51 return llvm::cast<IntegerType>(unwrap(type)).isSignless(); 52 } 53 54 bool mlirIntegerTypeIsSigned(MlirType type) { 55 return llvm::cast<IntegerType>(unwrap(type)).isSigned(); 56 } 57 58 bool mlirIntegerTypeIsUnsigned(MlirType type) { 59 return llvm::cast<IntegerType>(unwrap(type)).isUnsigned(); 60 } 61 62 //===----------------------------------------------------------------------===// 63 // Index type. 64 //===----------------------------------------------------------------------===// 65 66 MlirTypeID mlirIndexTypeGetTypeID() { return wrap(IndexType::getTypeID()); } 67 68 bool mlirTypeIsAIndex(MlirType type) { 69 return llvm::isa<IndexType>(unwrap(type)); 70 } 71 72 MlirType mlirIndexTypeGet(MlirContext ctx) { 73 return wrap(IndexType::get(unwrap(ctx))); 74 } 75 76 //===----------------------------------------------------------------------===// 77 // Floating-point types. 78 //===----------------------------------------------------------------------===// 79 80 bool mlirTypeIsAFloat(MlirType type) { 81 return llvm::isa<FloatType>(unwrap(type)); 82 } 83 84 unsigned mlirFloatTypeGetWidth(MlirType type) { 85 return llvm::cast<FloatType>(unwrap(type)).getWidth(); 86 } 87 88 MlirTypeID mlirFloat4E2M1FNTypeGetTypeID() { 89 return wrap(Float4E2M1FNType::getTypeID()); 90 } 91 92 bool mlirTypeIsAFloat4E2M1FN(MlirType type) { 93 return llvm::isa<Float4E2M1FNType>(unwrap(type)); 94 } 95 96 MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx) { 97 return wrap(Float4E2M1FNType::get(unwrap(ctx))); 98 } 99 100 MlirTypeID mlirFloat6E2M3FNTypeGetTypeID() { 101 return wrap(Float6E2M3FNType::getTypeID()); 102 } 103 104 bool mlirTypeIsAFloat6E2M3FN(MlirType type) { 105 return llvm::isa<Float6E2M3FNType>(unwrap(type)); 106 } 107 108 MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx) { 109 return wrap(Float6E2M3FNType::get(unwrap(ctx))); 110 } 111 112 MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() { 113 return wrap(Float6E3M2FNType::getTypeID()); 114 } 115 116 bool mlirTypeIsAFloat6E3M2FN(MlirType type) { 117 return llvm::isa<Float6E3M2FNType>(unwrap(type)); 118 } 119 120 MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx) { 121 return wrap(Float6E3M2FNType::get(unwrap(ctx))); 122 } 123 124 MlirTypeID mlirFloat8E5M2TypeGetTypeID() { 125 return wrap(Float8E5M2Type::getTypeID()); 126 } 127 128 bool mlirTypeIsAFloat8E5M2(MlirType type) { 129 return llvm::isa<Float8E5M2Type>(unwrap(type)); 130 } 131 132 MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) { 133 return wrap(Float8E5M2Type::get(unwrap(ctx))); 134 } 135 136 MlirTypeID mlirFloat8E4M3TypeGetTypeID() { 137 return wrap(Float8E4M3Type::getTypeID()); 138 } 139 140 bool mlirTypeIsAFloat8E4M3(MlirType type) { 141 return llvm::isa<Float8E4M3Type>(unwrap(type)); 142 } 143 144 MlirType mlirFloat8E4M3TypeGet(MlirContext ctx) { 145 return wrap(Float8E4M3Type::get(unwrap(ctx))); 146 } 147 148 MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() { 149 return wrap(Float8E4M3FNType::getTypeID()); 150 } 151 152 bool mlirTypeIsAFloat8E4M3FN(MlirType type) { 153 return llvm::isa<Float8E4M3FNType>(unwrap(type)); 154 } 155 156 MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) { 157 return wrap(Float8E4M3FNType::get(unwrap(ctx))); 158 } 159 160 MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() { 161 return wrap(Float8E5M2FNUZType::getTypeID()); 162 } 163 164 bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) { 165 return llvm::isa<Float8E5M2FNUZType>(unwrap(type)); 166 } 167 168 MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) { 169 return wrap(Float8E5M2FNUZType::get(unwrap(ctx))); 170 } 171 172 MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() { 173 return wrap(Float8E4M3FNUZType::getTypeID()); 174 } 175 176 bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) { 177 return llvm::isa<Float8E4M3FNUZType>(unwrap(type)); 178 } 179 180 MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) { 181 return wrap(Float8E4M3FNUZType::get(unwrap(ctx))); 182 } 183 184 MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() { 185 return wrap(Float8E4M3B11FNUZType::getTypeID()); 186 } 187 188 bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) { 189 return llvm::isa<Float8E4M3B11FNUZType>(unwrap(type)); 190 } 191 192 MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) { 193 return wrap(Float8E4M3B11FNUZType::get(unwrap(ctx))); 194 } 195 196 MlirTypeID mlirFloat8E3M4TypeGetTypeID() { 197 return wrap(Float8E3M4Type::getTypeID()); 198 } 199 200 bool mlirTypeIsAFloat8E3M4(MlirType type) { 201 return llvm::isa<Float8E3M4Type>(unwrap(type)); 202 } 203 204 MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) { 205 return wrap(Float8E3M4Type::get(unwrap(ctx))); 206 } 207 208 MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID() { 209 return wrap(Float8E8M0FNUType::getTypeID()); 210 } 211 212 bool mlirTypeIsAFloat8E8M0FNU(MlirType type) { 213 return llvm::isa<Float8E8M0FNUType>(unwrap(type)); 214 } 215 216 MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx) { 217 return wrap(Float8E8M0FNUType::get(unwrap(ctx))); 218 } 219 220 MlirTypeID mlirBFloat16TypeGetTypeID() { 221 return wrap(BFloat16Type::getTypeID()); 222 } 223 224 bool mlirTypeIsABF16(MlirType type) { 225 return llvm::isa<BFloat16Type>(unwrap(type)); 226 } 227 228 MlirType mlirBF16TypeGet(MlirContext ctx) { 229 return wrap(BFloat16Type::get(unwrap(ctx))); 230 } 231 232 MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); } 233 234 bool mlirTypeIsAF16(MlirType type) { 235 return llvm::isa<Float16Type>(unwrap(type)); 236 } 237 238 MlirType mlirF16TypeGet(MlirContext ctx) { 239 return wrap(Float16Type::get(unwrap(ctx))); 240 } 241 242 MlirTypeID mlirFloatTF32TypeGetTypeID() { 243 return wrap(FloatTF32Type::getTypeID()); 244 } 245 246 bool mlirTypeIsATF32(MlirType type) { 247 return llvm::isa<FloatTF32Type>(unwrap(type)); 248 } 249 250 MlirType mlirTF32TypeGet(MlirContext ctx) { 251 return wrap(FloatTF32Type::get(unwrap(ctx))); 252 } 253 254 MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); } 255 256 bool mlirTypeIsAF32(MlirType type) { 257 return llvm::isa<Float32Type>(unwrap(type)); 258 } 259 260 MlirType mlirF32TypeGet(MlirContext ctx) { 261 return wrap(Float32Type::get(unwrap(ctx))); 262 } 263 264 MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); } 265 266 bool mlirTypeIsAF64(MlirType type) { 267 return llvm::isa<Float64Type>(unwrap(type)); 268 } 269 270 MlirType mlirF64TypeGet(MlirContext ctx) { 271 return wrap(Float64Type::get(unwrap(ctx))); 272 } 273 274 //===----------------------------------------------------------------------===// 275 // None type. 276 //===----------------------------------------------------------------------===// 277 278 MlirTypeID mlirNoneTypeGetTypeID() { return wrap(NoneType::getTypeID()); } 279 280 bool mlirTypeIsANone(MlirType type) { 281 return llvm::isa<NoneType>(unwrap(type)); 282 } 283 284 MlirType mlirNoneTypeGet(MlirContext ctx) { 285 return wrap(NoneType::get(unwrap(ctx))); 286 } 287 288 //===----------------------------------------------------------------------===// 289 // Complex type. 290 //===----------------------------------------------------------------------===// 291 292 MlirTypeID mlirComplexTypeGetTypeID() { return wrap(ComplexType::getTypeID()); } 293 294 bool mlirTypeIsAComplex(MlirType type) { 295 return llvm::isa<ComplexType>(unwrap(type)); 296 } 297 298 MlirType mlirComplexTypeGet(MlirType elementType) { 299 return wrap(ComplexType::get(unwrap(elementType))); 300 } 301 302 MlirType mlirComplexTypeGetElementType(MlirType type) { 303 return wrap(llvm::cast<ComplexType>(unwrap(type)).getElementType()); 304 } 305 306 //===----------------------------------------------------------------------===// 307 // Shaped type. 308 //===----------------------------------------------------------------------===// 309 310 bool mlirTypeIsAShaped(MlirType type) { 311 return llvm::isa<ShapedType>(unwrap(type)); 312 } 313 314 MlirType mlirShapedTypeGetElementType(MlirType type) { 315 return wrap(llvm::cast<ShapedType>(unwrap(type)).getElementType()); 316 } 317 318 bool mlirShapedTypeHasRank(MlirType type) { 319 return llvm::cast<ShapedType>(unwrap(type)).hasRank(); 320 } 321 322 int64_t mlirShapedTypeGetRank(MlirType type) { 323 return llvm::cast<ShapedType>(unwrap(type)).getRank(); 324 } 325 326 bool mlirShapedTypeHasStaticShape(MlirType type) { 327 return llvm::cast<ShapedType>(unwrap(type)).hasStaticShape(); 328 } 329 330 bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) { 331 return llvm::cast<ShapedType>(unwrap(type)) 332 .isDynamicDim(static_cast<unsigned>(dim)); 333 } 334 335 int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) { 336 return llvm::cast<ShapedType>(unwrap(type)) 337 .getDimSize(static_cast<unsigned>(dim)); 338 } 339 340 int64_t mlirShapedTypeGetDynamicSize() { return ShapedType::kDynamic; } 341 342 bool mlirShapedTypeIsDynamicSize(int64_t size) { 343 return ShapedType::isDynamic(size); 344 } 345 346 bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) { 347 return ShapedType::isDynamic(val); 348 } 349 350 int64_t mlirShapedTypeGetDynamicStrideOrOffset() { 351 return ShapedType::kDynamic; 352 } 353 354 //===----------------------------------------------------------------------===// 355 // Vector type. 356 //===----------------------------------------------------------------------===// 357 358 MlirTypeID mlirVectorTypeGetTypeID() { return wrap(VectorType::getTypeID()); } 359 360 bool mlirTypeIsAVector(MlirType type) { 361 return llvm::isa<VectorType>(unwrap(type)); 362 } 363 364 MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape, 365 MlirType elementType) { 366 return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)), 367 unwrap(elementType))); 368 } 369 370 MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank, 371 const int64_t *shape, MlirType elementType) { 372 return wrap(VectorType::getChecked( 373 unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)), 374 unwrap(elementType))); 375 } 376 377 MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape, 378 const bool *scalable, MlirType elementType) { 379 return wrap(VectorType::get( 380 llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType), 381 llvm::ArrayRef(scalable, static_cast<size_t>(rank)))); 382 } 383 384 MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank, 385 const int64_t *shape, 386 const bool *scalable, 387 MlirType elementType) { 388 return wrap(VectorType::getChecked( 389 unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)), 390 unwrap(elementType), 391 llvm::ArrayRef(scalable, static_cast<size_t>(rank)))); 392 } 393 394 bool mlirVectorTypeIsScalable(MlirType type) { 395 return cast<VectorType>(unwrap(type)).isScalable(); 396 } 397 398 bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) { 399 return cast<VectorType>(unwrap(type)).getScalableDims()[dim]; 400 } 401 402 //===----------------------------------------------------------------------===// 403 // Ranked / Unranked tensor type. 404 //===----------------------------------------------------------------------===// 405 406 bool mlirTypeIsATensor(MlirType type) { 407 return llvm::isa<TensorType>(unwrap(type)); 408 } 409 410 MlirTypeID mlirRankedTensorTypeGetTypeID() { 411 return wrap(RankedTensorType::getTypeID()); 412 } 413 414 bool mlirTypeIsARankedTensor(MlirType type) { 415 return llvm::isa<RankedTensorType>(unwrap(type)); 416 } 417 418 MlirTypeID mlirUnrankedTensorTypeGetTypeID() { 419 return wrap(UnrankedTensorType::getTypeID()); 420 } 421 422 bool mlirTypeIsAUnrankedTensor(MlirType type) { 423 return llvm::isa<UnrankedTensorType>(unwrap(type)); 424 } 425 426 MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape, 427 MlirType elementType, MlirAttribute encoding) { 428 return wrap( 429 RankedTensorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)), 430 unwrap(elementType), unwrap(encoding))); 431 } 432 433 MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank, 434 const int64_t *shape, 435 MlirType elementType, 436 MlirAttribute encoding) { 437 return wrap(RankedTensorType::getChecked( 438 unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)), 439 unwrap(elementType), unwrap(encoding))); 440 } 441 442 MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type) { 443 return wrap(llvm::cast<RankedTensorType>(unwrap(type)).getEncoding()); 444 } 445 446 MlirType mlirUnrankedTensorTypeGet(MlirType elementType) { 447 return wrap(UnrankedTensorType::get(unwrap(elementType))); 448 } 449 450 MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc, 451 MlirType elementType) { 452 return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType))); 453 } 454 455 MlirType mlirUnrankedTensorTypeGetElementType(MlirType type) { 456 return wrap(llvm::cast<UnrankedTensorType>(unwrap(type)).getElementType()); 457 } 458 459 //===----------------------------------------------------------------------===// 460 // Ranked / Unranked MemRef type. 461 //===----------------------------------------------------------------------===// 462 463 MlirTypeID mlirMemRefTypeGetTypeID() { return wrap(MemRefType::getTypeID()); } 464 465 bool mlirTypeIsAMemRef(MlirType type) { 466 return llvm::isa<MemRefType>(unwrap(type)); 467 } 468 469 MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, 470 const int64_t *shape, MlirAttribute layout, 471 MlirAttribute memorySpace) { 472 return wrap(MemRefType::get( 473 llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType), 474 mlirAttributeIsNull(layout) 475 ? MemRefLayoutAttrInterface() 476 : llvm::cast<MemRefLayoutAttrInterface>(unwrap(layout)), 477 unwrap(memorySpace))); 478 } 479 480 MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, 481 intptr_t rank, const int64_t *shape, 482 MlirAttribute layout, 483 MlirAttribute memorySpace) { 484 return wrap(MemRefType::getChecked( 485 unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)), 486 unwrap(elementType), 487 mlirAttributeIsNull(layout) 488 ? MemRefLayoutAttrInterface() 489 : llvm::cast<MemRefLayoutAttrInterface>(unwrap(layout)), 490 unwrap(memorySpace))); 491 } 492 493 MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank, 494 const int64_t *shape, 495 MlirAttribute memorySpace) { 496 return wrap(MemRefType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)), 497 unwrap(elementType), MemRefLayoutAttrInterface(), 498 unwrap(memorySpace))); 499 } 500 501 MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc, 502 MlirType elementType, intptr_t rank, 503 const int64_t *shape, 504 MlirAttribute memorySpace) { 505 return wrap(MemRefType::getChecked( 506 unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)), 507 unwrap(elementType), MemRefLayoutAttrInterface(), unwrap(memorySpace))); 508 } 509 510 MlirAttribute mlirMemRefTypeGetLayout(MlirType type) { 511 return wrap(llvm::cast<MemRefType>(unwrap(type)).getLayout()); 512 } 513 514 MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type) { 515 return wrap(llvm::cast<MemRefType>(unwrap(type)).getLayout().getAffineMap()); 516 } 517 518 MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) { 519 return wrap(llvm::cast<MemRefType>(unwrap(type)).getMemorySpace()); 520 } 521 522 MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type, 523 int64_t *strides, 524 int64_t *offset) { 525 MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type)); 526 SmallVector<int64_t> strides_; 527 if (failed(memrefType.getStridesAndOffset(strides_, *offset))) 528 return mlirLogicalResultFailure(); 529 530 (void)std::copy(strides_.begin(), strides_.end(), strides); 531 return mlirLogicalResultSuccess(); 532 } 533 534 MlirTypeID mlirUnrankedMemRefTypeGetTypeID() { 535 return wrap(UnrankedMemRefType::getTypeID()); 536 } 537 538 bool mlirTypeIsAUnrankedMemRef(MlirType type) { 539 return llvm::isa<UnrankedMemRefType>(unwrap(type)); 540 } 541 542 MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, 543 MlirAttribute memorySpace) { 544 return wrap( 545 UnrankedMemRefType::get(unwrap(elementType), unwrap(memorySpace))); 546 } 547 548 MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc, 549 MlirType elementType, 550 MlirAttribute memorySpace) { 551 return wrap(UnrankedMemRefType::getChecked(unwrap(loc), unwrap(elementType), 552 unwrap(memorySpace))); 553 } 554 555 MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) { 556 return wrap(llvm::cast<UnrankedMemRefType>(unwrap(type)).getMemorySpace()); 557 } 558 559 //===----------------------------------------------------------------------===// 560 // Tuple type. 561 //===----------------------------------------------------------------------===// 562 563 MlirTypeID mlirTupleTypeGetTypeID() { return wrap(TupleType::getTypeID()); } 564 565 bool mlirTypeIsATuple(MlirType type) { 566 return llvm::isa<TupleType>(unwrap(type)); 567 } 568 569 MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements, 570 MlirType const *elements) { 571 SmallVector<Type, 4> types; 572 ArrayRef<Type> typeRef = unwrapList(numElements, elements, types); 573 return wrap(TupleType::get(unwrap(ctx), typeRef)); 574 } 575 576 intptr_t mlirTupleTypeGetNumTypes(MlirType type) { 577 return llvm::cast<TupleType>(unwrap(type)).size(); 578 } 579 580 MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) { 581 return wrap( 582 llvm::cast<TupleType>(unwrap(type)).getType(static_cast<size_t>(pos))); 583 } 584 585 //===----------------------------------------------------------------------===// 586 // Function type. 587 //===----------------------------------------------------------------------===// 588 589 MlirTypeID mlirFunctionTypeGetTypeID() { 590 return wrap(FunctionType::getTypeID()); 591 } 592 593 bool mlirTypeIsAFunction(MlirType type) { 594 return llvm::isa<FunctionType>(unwrap(type)); 595 } 596 597 MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs, 598 MlirType const *inputs, intptr_t numResults, 599 MlirType const *results) { 600 SmallVector<Type, 4> inputsList; 601 SmallVector<Type, 4> resultsList; 602 (void)unwrapList(numInputs, inputs, inputsList); 603 (void)unwrapList(numResults, results, resultsList); 604 return wrap(FunctionType::get(unwrap(ctx), inputsList, resultsList)); 605 } 606 607 intptr_t mlirFunctionTypeGetNumInputs(MlirType type) { 608 return llvm::cast<FunctionType>(unwrap(type)).getNumInputs(); 609 } 610 611 intptr_t mlirFunctionTypeGetNumResults(MlirType type) { 612 return llvm::cast<FunctionType>(unwrap(type)).getNumResults(); 613 } 614 615 MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos) { 616 assert(pos >= 0 && "pos in array must be positive"); 617 return wrap(llvm::cast<FunctionType>(unwrap(type)) 618 .getInput(static_cast<unsigned>(pos))); 619 } 620 621 MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) { 622 assert(pos >= 0 && "pos in array must be positive"); 623 return wrap(llvm::cast<FunctionType>(unwrap(type)) 624 .getResult(static_cast<unsigned>(pos))); 625 } 626 627 //===----------------------------------------------------------------------===// 628 // Opaque type. 629 //===----------------------------------------------------------------------===// 630 631 MlirTypeID mlirOpaqueTypeGetTypeID() { return wrap(OpaqueType::getTypeID()); } 632 633 bool mlirTypeIsAOpaque(MlirType type) { 634 return llvm::isa<OpaqueType>(unwrap(type)); 635 } 636 637 MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace, 638 MlirStringRef typeData) { 639 return wrap( 640 OpaqueType::get(StringAttr::get(unwrap(ctx), unwrap(dialectNamespace)), 641 unwrap(typeData))); 642 } 643 644 MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type) { 645 return wrap( 646 llvm::cast<OpaqueType>(unwrap(type)).getDialectNamespace().strref()); 647 } 648 649 MlirStringRef mlirOpaqueTypeGetData(MlirType type) { 650 return wrap(llvm::cast<OpaqueType>(unwrap(type)).getTypeData()); 651 } 652