1 //===- IR.cpp - C Interface for Core MLIR APIs ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/IR.h" 10 #include "mlir-c/Support.h" 11 12 #include "mlir/AsmParser/AsmParser.h" 13 #include "mlir/Bytecode/BytecodeWriter.h" 14 #include "mlir/CAPI/IR.h" 15 #include "mlir/CAPI/Support.h" 16 #include "mlir/CAPI/Utils.h" 17 #include "mlir/IR/Attributes.h" 18 #include "mlir/IR/BuiltinAttributes.h" 19 #include "mlir/IR/BuiltinOps.h" 20 #include "mlir/IR/Diagnostics.h" 21 #include "mlir/IR/Dialect.h" 22 #include "mlir/IR/Location.h" 23 #include "mlir/IR/Operation.h" 24 #include "mlir/IR/OperationSupport.h" 25 #include "mlir/IR/Types.h" 26 #include "mlir/IR/Value.h" 27 #include "mlir/IR/Verifier.h" 28 #include "mlir/IR/Visitors.h" 29 #include "mlir/Interfaces/InferTypeOpInterface.h" 30 #include "mlir/Parser/Parser.h" 31 #include "llvm/ADT/SmallPtrSet.h" 32 #include "llvm/Support/ThreadPool.h" 33 34 #include <cstddef> 35 #include <memory> 36 #include <optional> 37 38 using namespace mlir; 39 40 //===----------------------------------------------------------------------===// 41 // Context API. 42 //===----------------------------------------------------------------------===// 43 44 MlirContext mlirContextCreate() { 45 auto *context = new MLIRContext; 46 return wrap(context); 47 } 48 49 static inline MLIRContext::Threading toThreadingEnum(bool threadingEnabled) { 50 return threadingEnabled ? MLIRContext::Threading::ENABLED 51 : MLIRContext::Threading::DISABLED; 52 } 53 54 MlirContext mlirContextCreateWithThreading(bool threadingEnabled) { 55 auto *context = new MLIRContext(toThreadingEnum(threadingEnabled)); 56 return wrap(context); 57 } 58 59 MlirContext mlirContextCreateWithRegistry(MlirDialectRegistry registry, 60 bool threadingEnabled) { 61 auto *context = 62 new MLIRContext(*unwrap(registry), toThreadingEnum(threadingEnabled)); 63 return wrap(context); 64 } 65 66 bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) { 67 return unwrap(ctx1) == unwrap(ctx2); 68 } 69 70 void mlirContextDestroy(MlirContext context) { delete unwrap(context); } 71 72 void mlirContextSetAllowUnregisteredDialects(MlirContext context, bool allow) { 73 unwrap(context)->allowUnregisteredDialects(allow); 74 } 75 76 bool mlirContextGetAllowUnregisteredDialects(MlirContext context) { 77 return unwrap(context)->allowsUnregisteredDialects(); 78 } 79 intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) { 80 return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size()); 81 } 82 83 void mlirContextAppendDialectRegistry(MlirContext ctx, 84 MlirDialectRegistry registry) { 85 unwrap(ctx)->appendDialectRegistry(*unwrap(registry)); 86 } 87 88 // TODO: expose a cheaper way than constructing + sorting a vector only to take 89 // its size. 90 intptr_t mlirContextGetNumLoadedDialects(MlirContext context) { 91 return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size()); 92 } 93 94 MlirDialect mlirContextGetOrLoadDialect(MlirContext context, 95 MlirStringRef name) { 96 return wrap(unwrap(context)->getOrLoadDialect(unwrap(name))); 97 } 98 99 bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) { 100 return unwrap(context)->isOperationRegistered(unwrap(name)); 101 } 102 103 void mlirContextEnableMultithreading(MlirContext context, bool enable) { 104 return unwrap(context)->enableMultithreading(enable); 105 } 106 107 void mlirContextLoadAllAvailableDialects(MlirContext context) { 108 unwrap(context)->loadAllAvailableDialects(); 109 } 110 111 void mlirContextSetThreadPool(MlirContext context, 112 MlirLlvmThreadPool threadPool) { 113 unwrap(context)->setThreadPool(*unwrap(threadPool)); 114 } 115 116 //===----------------------------------------------------------------------===// 117 // Dialect API. 118 //===----------------------------------------------------------------------===// 119 120 MlirContext mlirDialectGetContext(MlirDialect dialect) { 121 return wrap(unwrap(dialect)->getContext()); 122 } 123 124 bool mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) { 125 return unwrap(dialect1) == unwrap(dialect2); 126 } 127 128 MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) { 129 return wrap(unwrap(dialect)->getNamespace()); 130 } 131 132 //===----------------------------------------------------------------------===// 133 // DialectRegistry API. 134 //===----------------------------------------------------------------------===// 135 136 MlirDialectRegistry mlirDialectRegistryCreate() { 137 return wrap(new DialectRegistry()); 138 } 139 140 void mlirDialectRegistryDestroy(MlirDialectRegistry registry) { 141 delete unwrap(registry); 142 } 143 144 //===----------------------------------------------------------------------===// 145 // AsmState API. 146 //===----------------------------------------------------------------------===// 147 148 MlirAsmState mlirAsmStateCreateForOperation(MlirOperation op, 149 MlirOpPrintingFlags flags) { 150 return wrap(new AsmState(unwrap(op), *unwrap(flags))); 151 } 152 153 static Operation *findParent(Operation *op, bool shouldUseLocalScope) { 154 do { 155 // If we are printing local scope, stop at the first operation that is 156 // isolated from above. 157 if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>()) 158 break; 159 160 // Otherwise, traverse up to the next parent. 161 Operation *parentOp = op->getParentOp(); 162 if (!parentOp) 163 break; 164 op = parentOp; 165 } while (true); 166 return op; 167 } 168 169 MlirAsmState mlirAsmStateCreateForValue(MlirValue value, 170 MlirOpPrintingFlags flags) { 171 Operation *op; 172 mlir::Value val = unwrap(value); 173 if (auto result = llvm::dyn_cast<OpResult>(val)) { 174 op = result.getOwner(); 175 } else { 176 op = llvm::cast<BlockArgument>(val).getOwner()->getParentOp(); 177 if (!op) { 178 emitError(val.getLoc()) << "<<UNKNOWN SSA VALUE>>"; 179 return {nullptr}; 180 } 181 } 182 op = findParent(op, unwrap(flags)->shouldUseLocalScope()); 183 return wrap(new AsmState(op, *unwrap(flags))); 184 } 185 186 /// Destroys printing flags created with mlirAsmStateCreate. 187 void mlirAsmStateDestroy(MlirAsmState state) { delete unwrap(state); } 188 189 //===----------------------------------------------------------------------===// 190 // Printing flags API. 191 //===----------------------------------------------------------------------===// 192 193 MlirOpPrintingFlags mlirOpPrintingFlagsCreate() { 194 return wrap(new OpPrintingFlags()); 195 } 196 197 void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) { 198 delete unwrap(flags); 199 } 200 201 void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, 202 intptr_t largeElementLimit) { 203 unwrap(flags)->elideLargeElementsAttrs(largeElementLimit); 204 } 205 206 void mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags, 207 intptr_t largeResourceLimit) { 208 unwrap(flags)->elideLargeResourceString(largeResourceLimit); 209 } 210 211 void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, 212 bool prettyForm) { 213 unwrap(flags)->enableDebugInfo(enable, /*prettyForm=*/prettyForm); 214 } 215 216 void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) { 217 unwrap(flags)->printGenericOpForm(); 218 } 219 220 void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) { 221 unwrap(flags)->useLocalScope(); 222 } 223 224 void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) { 225 unwrap(flags)->assumeVerified(); 226 } 227 228 void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags) { 229 unwrap(flags)->skipRegions(); 230 } 231 //===----------------------------------------------------------------------===// 232 // Bytecode printing flags API. 233 //===----------------------------------------------------------------------===// 234 235 MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate() { 236 return wrap(new BytecodeWriterConfig()); 237 } 238 239 void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config) { 240 delete unwrap(config); 241 } 242 243 void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, 244 int64_t version) { 245 unwrap(flags)->setDesiredBytecodeVersion(version); 246 } 247 248 //===----------------------------------------------------------------------===// 249 // Location API. 250 //===----------------------------------------------------------------------===// 251 252 MlirAttribute mlirLocationGetAttribute(MlirLocation location) { 253 return wrap(LocationAttr(unwrap(location))); 254 } 255 256 MlirLocation mlirLocationFromAttribute(MlirAttribute attribute) { 257 return wrap(Location(llvm::cast<LocationAttr>(unwrap(attribute)))); 258 } 259 260 MlirLocation mlirLocationFileLineColGet(MlirContext context, 261 MlirStringRef filename, unsigned line, 262 unsigned col) { 263 return wrap(Location( 264 FileLineColLoc::get(unwrap(context), unwrap(filename), line, col))); 265 } 266 267 MlirLocation 268 mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename, 269 unsigned startLine, unsigned startCol, 270 unsigned endLine, unsigned endCol) { 271 return wrap( 272 Location(FileLineColRange::get(unwrap(context), unwrap(filename), 273 startLine, startCol, endLine, endCol))); 274 } 275 276 MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) { 277 return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller)))); 278 } 279 280 MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, 281 MlirLocation const *locations, 282 MlirAttribute metadata) { 283 SmallVector<Location, 4> locs; 284 ArrayRef<Location> unwrappedLocs = unwrapList(nLocations, locations, locs); 285 return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx))); 286 } 287 288 MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, 289 MlirLocation childLoc) { 290 if (mlirLocationIsNull(childLoc)) 291 return wrap( 292 Location(NameLoc::get(StringAttr::get(unwrap(context), unwrap(name))))); 293 return wrap(Location(NameLoc::get( 294 StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc)))); 295 } 296 297 MlirLocation mlirLocationUnknownGet(MlirContext context) { 298 return wrap(Location(UnknownLoc::get(unwrap(context)))); 299 } 300 301 bool mlirLocationEqual(MlirLocation l1, MlirLocation l2) { 302 return unwrap(l1) == unwrap(l2); 303 } 304 305 MlirContext mlirLocationGetContext(MlirLocation location) { 306 return wrap(unwrap(location).getContext()); 307 } 308 309 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, 310 void *userData) { 311 detail::CallbackOstream stream(callback, userData); 312 unwrap(location).print(stream); 313 } 314 315 //===----------------------------------------------------------------------===// 316 // Module API. 317 //===----------------------------------------------------------------------===// 318 319 MlirModule mlirModuleCreateEmpty(MlirLocation location) { 320 return wrap(ModuleOp::create(unwrap(location))); 321 } 322 323 MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) { 324 OwningOpRef<ModuleOp> owning = 325 parseSourceString<ModuleOp>(unwrap(module), unwrap(context)); 326 if (!owning) 327 return MlirModule{nullptr}; 328 return MlirModule{owning.release().getOperation()}; 329 } 330 331 MlirContext mlirModuleGetContext(MlirModule module) { 332 return wrap(unwrap(module).getContext()); 333 } 334 335 MlirBlock mlirModuleGetBody(MlirModule module) { 336 return wrap(unwrap(module).getBody()); 337 } 338 339 void mlirModuleDestroy(MlirModule module) { 340 // Transfer ownership to an OwningOpRef<ModuleOp> so that its destructor is 341 // called. 342 OwningOpRef<ModuleOp>(unwrap(module)); 343 } 344 345 MlirOperation mlirModuleGetOperation(MlirModule module) { 346 return wrap(unwrap(module).getOperation()); 347 } 348 349 MlirModule mlirModuleFromOperation(MlirOperation op) { 350 return wrap(dyn_cast<ModuleOp>(unwrap(op))); 351 } 352 353 //===----------------------------------------------------------------------===// 354 // Operation state API. 355 //===----------------------------------------------------------------------===// 356 357 MlirOperationState mlirOperationStateGet(MlirStringRef name, MlirLocation loc) { 358 MlirOperationState state; 359 state.name = name; 360 state.location = loc; 361 state.nResults = 0; 362 state.results = nullptr; 363 state.nOperands = 0; 364 state.operands = nullptr; 365 state.nRegions = 0; 366 state.regions = nullptr; 367 state.nSuccessors = 0; 368 state.successors = nullptr; 369 state.nAttributes = 0; 370 state.attributes = nullptr; 371 state.enableResultTypeInference = false; 372 return state; 373 } 374 375 #define APPEND_ELEMS(type, sizeName, elemName) \ 376 state->elemName = \ 377 (type *)realloc(state->elemName, (state->sizeName + n) * sizeof(type)); \ 378 memcpy(state->elemName + state->sizeName, elemName, n * sizeof(type)); \ 379 state->sizeName += n; 380 381 void mlirOperationStateAddResults(MlirOperationState *state, intptr_t n, 382 MlirType const *results) { 383 APPEND_ELEMS(MlirType, nResults, results); 384 } 385 386 void mlirOperationStateAddOperands(MlirOperationState *state, intptr_t n, 387 MlirValue const *operands) { 388 APPEND_ELEMS(MlirValue, nOperands, operands); 389 } 390 void mlirOperationStateAddOwnedRegions(MlirOperationState *state, intptr_t n, 391 MlirRegion const *regions) { 392 APPEND_ELEMS(MlirRegion, nRegions, regions); 393 } 394 void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n, 395 MlirBlock const *successors) { 396 APPEND_ELEMS(MlirBlock, nSuccessors, successors); 397 } 398 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, 399 MlirNamedAttribute const *attributes) { 400 APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes); 401 } 402 403 void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) { 404 state->enableResultTypeInference = true; 405 } 406 407 //===----------------------------------------------------------------------===// 408 // Operation API. 409 //===----------------------------------------------------------------------===// 410 411 static LogicalResult inferOperationTypes(OperationState &state) { 412 MLIRContext *context = state.getContext(); 413 std::optional<RegisteredOperationName> info = state.name.getRegisteredInfo(); 414 if (!info) { 415 emitError(state.location) 416 << "type inference was requested for the operation " << state.name 417 << ", but the operation was not registered; ensure that the dialect " 418 "containing the operation is linked into MLIR and registered with " 419 "the context"; 420 return failure(); 421 } 422 423 auto *inferInterface = info->getInterface<InferTypeOpInterface>(); 424 if (!inferInterface) { 425 emitError(state.location) 426 << "type inference was requested for the operation " << state.name 427 << ", but the operation does not support type inference; result " 428 "types must be specified explicitly"; 429 return failure(); 430 } 431 432 DictionaryAttr attributes = state.attributes.getDictionary(context); 433 OpaqueProperties properties = state.getRawProperties(); 434 435 if (!properties && info->getOpPropertyByteSize() > 0 && !attributes.empty()) { 436 auto prop = std::make_unique<char[]>(info->getOpPropertyByteSize()); 437 properties = OpaqueProperties(prop.get()); 438 if (properties) { 439 auto emitError = [&]() { 440 return mlir::emitError(state.location) 441 << " failed properties conversion while building " 442 << state.name.getStringRef() << " with `" << attributes << "`: "; 443 }; 444 if (failed(info->setOpPropertiesFromAttribute(state.name, properties, 445 attributes, emitError))) 446 return failure(); 447 } 448 if (succeeded(inferInterface->inferReturnTypes( 449 context, state.location, state.operands, attributes, properties, 450 state.regions, state.types))) { 451 return success(); 452 } 453 // Diagnostic emitted by interface. 454 return failure(); 455 } 456 457 if (succeeded(inferInterface->inferReturnTypes( 458 context, state.location, state.operands, attributes, properties, 459 state.regions, state.types))) 460 return success(); 461 462 // Diagnostic emitted by interface. 463 return failure(); 464 } 465 466 MlirOperation mlirOperationCreate(MlirOperationState *state) { 467 assert(state); 468 OperationState cppState(unwrap(state->location), unwrap(state->name)); 469 SmallVector<Type, 4> resultStorage; 470 SmallVector<Value, 8> operandStorage; 471 SmallVector<Block *, 2> successorStorage; 472 cppState.addTypes(unwrapList(state->nResults, state->results, resultStorage)); 473 cppState.addOperands( 474 unwrapList(state->nOperands, state->operands, operandStorage)); 475 cppState.addSuccessors( 476 unwrapList(state->nSuccessors, state->successors, successorStorage)); 477 478 cppState.attributes.reserve(state->nAttributes); 479 for (intptr_t i = 0; i < state->nAttributes; ++i) 480 cppState.addAttribute(unwrap(state->attributes[i].name), 481 unwrap(state->attributes[i].attribute)); 482 483 for (intptr_t i = 0; i < state->nRegions; ++i) 484 cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i]))); 485 486 free(state->results); 487 free(state->operands); 488 free(state->successors); 489 free(state->regions); 490 free(state->attributes); 491 492 // Infer result types. 493 if (state->enableResultTypeInference) { 494 assert(cppState.types.empty() && 495 "result type inference enabled and result types provided"); 496 if (failed(inferOperationTypes(cppState))) 497 return {nullptr}; 498 } 499 500 return wrap(Operation::create(cppState)); 501 } 502 503 MlirOperation mlirOperationCreateParse(MlirContext context, 504 MlirStringRef sourceStr, 505 MlirStringRef sourceName) { 506 507 return wrap( 508 parseSourceString(unwrap(sourceStr), unwrap(context), unwrap(sourceName)) 509 .release()); 510 } 511 512 MlirOperation mlirOperationClone(MlirOperation op) { 513 return wrap(unwrap(op)->clone()); 514 } 515 516 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } 517 518 void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(op)->remove(); } 519 520 bool mlirOperationEqual(MlirOperation op, MlirOperation other) { 521 return unwrap(op) == unwrap(other); 522 } 523 524 MlirContext mlirOperationGetContext(MlirOperation op) { 525 return wrap(unwrap(op)->getContext()); 526 } 527 528 MlirLocation mlirOperationGetLocation(MlirOperation op) { 529 return wrap(unwrap(op)->getLoc()); 530 } 531 532 MlirTypeID mlirOperationGetTypeID(MlirOperation op) { 533 if (auto info = unwrap(op)->getRegisteredInfo()) 534 return wrap(info->getTypeID()); 535 return {nullptr}; 536 } 537 538 MlirIdentifier mlirOperationGetName(MlirOperation op) { 539 return wrap(unwrap(op)->getName().getIdentifier()); 540 } 541 542 MlirBlock mlirOperationGetBlock(MlirOperation op) { 543 return wrap(unwrap(op)->getBlock()); 544 } 545 546 MlirOperation mlirOperationGetParentOperation(MlirOperation op) { 547 return wrap(unwrap(op)->getParentOp()); 548 } 549 550 intptr_t mlirOperationGetNumRegions(MlirOperation op) { 551 return static_cast<intptr_t>(unwrap(op)->getNumRegions()); 552 } 553 554 MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { 555 return wrap(&unwrap(op)->getRegion(static_cast<unsigned>(pos))); 556 } 557 558 MlirRegion mlirOperationGetFirstRegion(MlirOperation op) { 559 Operation *cppOp = unwrap(op); 560 if (cppOp->getNumRegions() == 0) 561 return wrap(static_cast<Region *>(nullptr)); 562 return wrap(&cppOp->getRegion(0)); 563 } 564 565 MlirRegion mlirRegionGetNextInOperation(MlirRegion region) { 566 Region *cppRegion = unwrap(region); 567 Operation *parent = cppRegion->getParentOp(); 568 intptr_t next = cppRegion->getRegionNumber() + 1; 569 if (parent->getNumRegions() > next) 570 return wrap(&parent->getRegion(next)); 571 return wrap(static_cast<Region *>(nullptr)); 572 } 573 574 MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { 575 return wrap(unwrap(op)->getNextNode()); 576 } 577 578 intptr_t mlirOperationGetNumOperands(MlirOperation op) { 579 return static_cast<intptr_t>(unwrap(op)->getNumOperands()); 580 } 581 582 MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { 583 return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos))); 584 } 585 586 void mlirOperationSetOperand(MlirOperation op, intptr_t pos, 587 MlirValue newValue) { 588 unwrap(op)->setOperand(static_cast<unsigned>(pos), unwrap(newValue)); 589 } 590 591 void mlirOperationSetOperands(MlirOperation op, intptr_t nOperands, 592 MlirValue const *operands) { 593 SmallVector<Value> ops; 594 unwrap(op)->setOperands(unwrapList(nOperands, operands, ops)); 595 } 596 597 intptr_t mlirOperationGetNumResults(MlirOperation op) { 598 return static_cast<intptr_t>(unwrap(op)->getNumResults()); 599 } 600 601 MlirValue mlirOperationGetResult(MlirOperation op, intptr_t pos) { 602 return wrap(unwrap(op)->getResult(static_cast<unsigned>(pos))); 603 } 604 605 intptr_t mlirOperationGetNumSuccessors(MlirOperation op) { 606 return static_cast<intptr_t>(unwrap(op)->getNumSuccessors()); 607 } 608 609 MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { 610 return wrap(unwrap(op)->getSuccessor(static_cast<unsigned>(pos))); 611 } 612 613 MLIR_CAPI_EXPORTED bool 614 mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name) { 615 std::optional<Attribute> attr = unwrap(op)->getInherentAttr(unwrap(name)); 616 return attr.has_value(); 617 } 618 619 MlirAttribute mlirOperationGetInherentAttributeByName(MlirOperation op, 620 MlirStringRef name) { 621 std::optional<Attribute> attr = unwrap(op)->getInherentAttr(unwrap(name)); 622 if (attr.has_value()) 623 return wrap(*attr); 624 return {}; 625 } 626 627 void mlirOperationSetInherentAttributeByName(MlirOperation op, 628 MlirStringRef name, 629 MlirAttribute attr) { 630 unwrap(op)->setInherentAttr( 631 StringAttr::get(unwrap(op)->getContext(), unwrap(name)), unwrap(attr)); 632 } 633 634 intptr_t mlirOperationGetNumDiscardableAttributes(MlirOperation op) { 635 return static_cast<intptr_t>( 636 llvm::range_size(unwrap(op)->getDiscardableAttrs())); 637 } 638 639 MlirNamedAttribute mlirOperationGetDiscardableAttribute(MlirOperation op, 640 intptr_t pos) { 641 NamedAttribute attr = 642 *std::next(unwrap(op)->getDiscardableAttrs().begin(), pos); 643 return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())}; 644 } 645 646 MlirAttribute mlirOperationGetDiscardableAttributeByName(MlirOperation op, 647 MlirStringRef name) { 648 return wrap(unwrap(op)->getDiscardableAttr(unwrap(name))); 649 } 650 651 void mlirOperationSetDiscardableAttributeByName(MlirOperation op, 652 MlirStringRef name, 653 MlirAttribute attr) { 654 unwrap(op)->setDiscardableAttr(unwrap(name), unwrap(attr)); 655 } 656 657 bool mlirOperationRemoveDiscardableAttributeByName(MlirOperation op, 658 MlirStringRef name) { 659 return !!unwrap(op)->removeDiscardableAttr(unwrap(name)); 660 } 661 662 void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, 663 MlirBlock block) { 664 unwrap(op)->setSuccessor(unwrap(block), static_cast<unsigned>(pos)); 665 } 666 667 intptr_t mlirOperationGetNumAttributes(MlirOperation op) { 668 return static_cast<intptr_t>(unwrap(op)->getAttrs().size()); 669 } 670 671 MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { 672 NamedAttribute attr = unwrap(op)->getAttrs()[pos]; 673 return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())}; 674 } 675 676 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, 677 MlirStringRef name) { 678 return wrap(unwrap(op)->getAttr(unwrap(name))); 679 } 680 681 void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, 682 MlirAttribute attr) { 683 unwrap(op)->setAttr(unwrap(name), unwrap(attr)); 684 } 685 686 bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name) { 687 return !!unwrap(op)->removeAttr(unwrap(name)); 688 } 689 690 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, 691 void *userData) { 692 detail::CallbackOstream stream(callback, userData); 693 unwrap(op)->print(stream); 694 } 695 696 void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, 697 MlirStringCallback callback, void *userData) { 698 detail::CallbackOstream stream(callback, userData); 699 unwrap(op)->print(stream, *unwrap(flags)); 700 } 701 702 void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state, 703 MlirStringCallback callback, void *userData) { 704 detail::CallbackOstream stream(callback, userData); 705 if (state.ptr) 706 unwrap(op)->print(stream, *unwrap(state)); 707 unwrap(op)->print(stream); 708 } 709 710 void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, 711 void *userData) { 712 detail::CallbackOstream stream(callback, userData); 713 // As no desired version is set, no failure can occur. 714 (void)writeBytecodeToFile(unwrap(op), stream); 715 } 716 717 MlirLogicalResult mlirOperationWriteBytecodeWithConfig( 718 MlirOperation op, MlirBytecodeWriterConfig config, 719 MlirStringCallback callback, void *userData) { 720 detail::CallbackOstream stream(callback, userData); 721 return wrap(writeBytecodeToFile(unwrap(op), stream, *unwrap(config))); 722 } 723 724 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } 725 726 bool mlirOperationVerify(MlirOperation op) { 727 return succeeded(verify(unwrap(op))); 728 } 729 730 void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) { 731 return unwrap(op)->moveAfter(unwrap(other)); 732 } 733 734 void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) { 735 return unwrap(op)->moveBefore(unwrap(other)); 736 } 737 738 static mlir::WalkResult unwrap(MlirWalkResult result) { 739 switch (result) { 740 case MlirWalkResultAdvance: 741 return mlir::WalkResult::advance(); 742 743 case MlirWalkResultInterrupt: 744 return mlir::WalkResult::interrupt(); 745 746 case MlirWalkResultSkip: 747 return mlir::WalkResult::skip(); 748 } 749 llvm_unreachable("unknown result in WalkResult::unwrap"); 750 } 751 752 void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, 753 void *userData, MlirWalkOrder walkOrder) { 754 switch (walkOrder) { 755 756 case MlirWalkPreOrder: 757 unwrap(op)->walk<mlir::WalkOrder::PreOrder>( 758 [callback, userData](Operation *op) { 759 return unwrap(callback(wrap(op), userData)); 760 }); 761 break; 762 case MlirWalkPostOrder: 763 unwrap(op)->walk<mlir::WalkOrder::PostOrder>( 764 [callback, userData](Operation *op) { 765 return unwrap(callback(wrap(op), userData)); 766 }); 767 } 768 } 769 770 //===----------------------------------------------------------------------===// 771 // Region API. 772 //===----------------------------------------------------------------------===// 773 774 MlirRegion mlirRegionCreate() { return wrap(new Region); } 775 776 bool mlirRegionEqual(MlirRegion region, MlirRegion other) { 777 return unwrap(region) == unwrap(other); 778 } 779 780 MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { 781 Region *cppRegion = unwrap(region); 782 if (cppRegion->empty()) 783 return wrap(static_cast<Block *>(nullptr)); 784 return wrap(&cppRegion->front()); 785 } 786 787 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block) { 788 unwrap(region)->push_back(unwrap(block)); 789 } 790 791 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos, 792 MlirBlock block) { 793 auto &blockList = unwrap(region)->getBlocks(); 794 blockList.insert(std::next(blockList.begin(), pos), unwrap(block)); 795 } 796 797 void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference, 798 MlirBlock block) { 799 Region *cppRegion = unwrap(region); 800 if (mlirBlockIsNull(reference)) { 801 cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block)); 802 return; 803 } 804 805 assert(unwrap(reference)->getParent() == unwrap(region) && 806 "expected reference block to belong to the region"); 807 cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)), 808 unwrap(block)); 809 } 810 811 void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, 812 MlirBlock block) { 813 if (mlirBlockIsNull(reference)) 814 return mlirRegionAppendOwnedBlock(region, block); 815 816 assert(unwrap(reference)->getParent() == unwrap(region) && 817 "expected reference block to belong to the region"); 818 unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)), 819 unwrap(block)); 820 } 821 822 void mlirRegionDestroy(MlirRegion region) { 823 delete static_cast<Region *>(region.ptr); 824 } 825 826 void mlirRegionTakeBody(MlirRegion target, MlirRegion source) { 827 unwrap(target)->takeBody(*unwrap(source)); 828 } 829 830 //===----------------------------------------------------------------------===// 831 // Block API. 832 //===----------------------------------------------------------------------===// 833 834 MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, 835 MlirLocation const *locs) { 836 Block *b = new Block; 837 for (intptr_t i = 0; i < nArgs; ++i) 838 b->addArgument(unwrap(args[i]), unwrap(locs[i])); 839 return wrap(b); 840 } 841 842 bool mlirBlockEqual(MlirBlock block, MlirBlock other) { 843 return unwrap(block) == unwrap(other); 844 } 845 846 MlirOperation mlirBlockGetParentOperation(MlirBlock block) { 847 return wrap(unwrap(block)->getParentOp()); 848 } 849 850 MlirRegion mlirBlockGetParentRegion(MlirBlock block) { 851 return wrap(unwrap(block)->getParent()); 852 } 853 854 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { 855 return wrap(unwrap(block)->getNextNode()); 856 } 857 858 MlirOperation mlirBlockGetFirstOperation(MlirBlock block) { 859 Block *cppBlock = unwrap(block); 860 if (cppBlock->empty()) 861 return wrap(static_cast<Operation *>(nullptr)); 862 return wrap(&cppBlock->front()); 863 } 864 865 MlirOperation mlirBlockGetTerminator(MlirBlock block) { 866 Block *cppBlock = unwrap(block); 867 if (cppBlock->empty()) 868 return wrap(static_cast<Operation *>(nullptr)); 869 Operation &back = cppBlock->back(); 870 if (!back.hasTrait<OpTrait::IsTerminator>()) 871 return wrap(static_cast<Operation *>(nullptr)); 872 return wrap(&back); 873 } 874 875 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation) { 876 unwrap(block)->push_back(unwrap(operation)); 877 } 878 879 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos, 880 MlirOperation operation) { 881 auto &opList = unwrap(block)->getOperations(); 882 opList.insert(std::next(opList.begin(), pos), unwrap(operation)); 883 } 884 885 void mlirBlockInsertOwnedOperationAfter(MlirBlock block, 886 MlirOperation reference, 887 MlirOperation operation) { 888 Block *cppBlock = unwrap(block); 889 if (mlirOperationIsNull(reference)) { 890 cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation)); 891 return; 892 } 893 894 assert(unwrap(reference)->getBlock() == unwrap(block) && 895 "expected reference operation to belong to the block"); 896 cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)), 897 unwrap(operation)); 898 } 899 900 void mlirBlockInsertOwnedOperationBefore(MlirBlock block, 901 MlirOperation reference, 902 MlirOperation operation) { 903 if (mlirOperationIsNull(reference)) 904 return mlirBlockAppendOwnedOperation(block, operation); 905 906 assert(unwrap(reference)->getBlock() == unwrap(block) && 907 "expected reference operation to belong to the block"); 908 unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)), 909 unwrap(operation)); 910 } 911 912 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } 913 914 void mlirBlockDetach(MlirBlock block) { 915 Block *b = unwrap(block); 916 b->getParent()->getBlocks().remove(b); 917 } 918 919 intptr_t mlirBlockGetNumArguments(MlirBlock block) { 920 return static_cast<intptr_t>(unwrap(block)->getNumArguments()); 921 } 922 923 MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, 924 MlirLocation loc) { 925 return wrap(unwrap(block)->addArgument(unwrap(type), unwrap(loc))); 926 } 927 928 void mlirBlockEraseArgument(MlirBlock block, unsigned index) { 929 return unwrap(block)->eraseArgument(index); 930 } 931 932 MlirValue mlirBlockInsertArgument(MlirBlock block, intptr_t pos, MlirType type, 933 MlirLocation loc) { 934 return wrap(unwrap(block)->insertArgument(pos, unwrap(type), unwrap(loc))); 935 } 936 937 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { 938 return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); 939 } 940 941 void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, 942 void *userData) { 943 detail::CallbackOstream stream(callback, userData); 944 unwrap(block)->print(stream); 945 } 946 947 //===----------------------------------------------------------------------===// 948 // Value API. 949 //===----------------------------------------------------------------------===// 950 951 bool mlirValueEqual(MlirValue value1, MlirValue value2) { 952 return unwrap(value1) == unwrap(value2); 953 } 954 955 bool mlirValueIsABlockArgument(MlirValue value) { 956 return llvm::isa<BlockArgument>(unwrap(value)); 957 } 958 959 bool mlirValueIsAOpResult(MlirValue value) { 960 return llvm::isa<OpResult>(unwrap(value)); 961 } 962 963 MlirBlock mlirBlockArgumentGetOwner(MlirValue value) { 964 return wrap(llvm::cast<BlockArgument>(unwrap(value)).getOwner()); 965 } 966 967 intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) { 968 return static_cast<intptr_t>( 969 llvm::cast<BlockArgument>(unwrap(value)).getArgNumber()); 970 } 971 972 void mlirBlockArgumentSetType(MlirValue value, MlirType type) { 973 llvm::cast<BlockArgument>(unwrap(value)).setType(unwrap(type)); 974 } 975 976 MlirOperation mlirOpResultGetOwner(MlirValue value) { 977 return wrap(llvm::cast<OpResult>(unwrap(value)).getOwner()); 978 } 979 980 intptr_t mlirOpResultGetResultNumber(MlirValue value) { 981 return static_cast<intptr_t>( 982 llvm::cast<OpResult>(unwrap(value)).getResultNumber()); 983 } 984 985 MlirType mlirValueGetType(MlirValue value) { 986 return wrap(unwrap(value).getType()); 987 } 988 989 void mlirValueSetType(MlirValue value, MlirType type) { 990 unwrap(value).setType(unwrap(type)); 991 } 992 993 void mlirValueDump(MlirValue value) { unwrap(value).dump(); } 994 995 void mlirValuePrint(MlirValue value, MlirStringCallback callback, 996 void *userData) { 997 detail::CallbackOstream stream(callback, userData); 998 unwrap(value).print(stream); 999 } 1000 1001 void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state, 1002 MlirStringCallback callback, void *userData) { 1003 detail::CallbackOstream stream(callback, userData); 1004 Value cppValue = unwrap(value); 1005 cppValue.printAsOperand(stream, *unwrap(state)); 1006 } 1007 1008 MlirOpOperand mlirValueGetFirstUse(MlirValue value) { 1009 Value cppValue = unwrap(value); 1010 if (cppValue.use_empty()) 1011 return {}; 1012 1013 OpOperand *opOperand = cppValue.use_begin().getOperand(); 1014 1015 return wrap(opOperand); 1016 } 1017 1018 void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) { 1019 unwrap(oldValue).replaceAllUsesWith(unwrap(newValue)); 1020 } 1021 1022 void mlirValueReplaceAllUsesExcept(MlirValue oldValue, MlirValue newValue, 1023 intptr_t numExceptions, 1024 MlirOperation *exceptions) { 1025 Value oldValueCpp = unwrap(oldValue); 1026 Value newValueCpp = unwrap(newValue); 1027 1028 llvm::SmallPtrSet<mlir::Operation *, 4> exceptionSet; 1029 for (intptr_t i = 0; i < numExceptions; ++i) { 1030 exceptionSet.insert(unwrap(exceptions[i])); 1031 } 1032 1033 oldValueCpp.replaceAllUsesExcept(newValueCpp, exceptionSet); 1034 } 1035 1036 //===----------------------------------------------------------------------===// 1037 // OpOperand API. 1038 //===----------------------------------------------------------------------===// 1039 1040 bool mlirOpOperandIsNull(MlirOpOperand opOperand) { return !opOperand.ptr; } 1041 1042 MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand) { 1043 return wrap(unwrap(opOperand)->getOwner()); 1044 } 1045 1046 MlirValue mlirOpOperandGetValue(MlirOpOperand opOperand) { 1047 return wrap(unwrap(opOperand)->get()); 1048 } 1049 1050 unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand) { 1051 return unwrap(opOperand)->getOperandNumber(); 1052 } 1053 1054 MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand) { 1055 if (mlirOpOperandIsNull(opOperand)) 1056 return {}; 1057 1058 OpOperand *nextOpOperand = static_cast<OpOperand *>( 1059 unwrap(opOperand)->getNextOperandUsingThisValue()); 1060 1061 if (!nextOpOperand) 1062 return {}; 1063 1064 return wrap(nextOpOperand); 1065 } 1066 1067 //===----------------------------------------------------------------------===// 1068 // Type API. 1069 //===----------------------------------------------------------------------===// 1070 1071 MlirType mlirTypeParseGet(MlirContext context, MlirStringRef type) { 1072 return wrap(mlir::parseType(unwrap(type), unwrap(context))); 1073 } 1074 1075 MlirContext mlirTypeGetContext(MlirType type) { 1076 return wrap(unwrap(type).getContext()); 1077 } 1078 1079 MlirTypeID mlirTypeGetTypeID(MlirType type) { 1080 return wrap(unwrap(type).getTypeID()); 1081 } 1082 1083 MlirDialect mlirTypeGetDialect(MlirType type) { 1084 return wrap(&unwrap(type).getDialect()); 1085 } 1086 1087 bool mlirTypeEqual(MlirType t1, MlirType t2) { 1088 return unwrap(t1) == unwrap(t2); 1089 } 1090 1091 void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { 1092 detail::CallbackOstream stream(callback, userData); 1093 unwrap(type).print(stream); 1094 } 1095 1096 void mlirTypeDump(MlirType type) { unwrap(type).dump(); } 1097 1098 //===----------------------------------------------------------------------===// 1099 // Attribute API. 1100 //===----------------------------------------------------------------------===// 1101 1102 MlirAttribute mlirAttributeParseGet(MlirContext context, MlirStringRef attr) { 1103 return wrap(mlir::parseAttribute(unwrap(attr), unwrap(context))); 1104 } 1105 1106 MlirContext mlirAttributeGetContext(MlirAttribute attribute) { 1107 return wrap(unwrap(attribute).getContext()); 1108 } 1109 1110 MlirType mlirAttributeGetType(MlirAttribute attribute) { 1111 Attribute attr = unwrap(attribute); 1112 if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) 1113 return wrap(typedAttr.getType()); 1114 return wrap(NoneType::get(attr.getContext())); 1115 } 1116 1117 MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) { 1118 return wrap(unwrap(attr).getTypeID()); 1119 } 1120 1121 MlirDialect mlirAttributeGetDialect(MlirAttribute attr) { 1122 return wrap(&unwrap(attr).getDialect()); 1123 } 1124 1125 bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { 1126 return unwrap(a1) == unwrap(a2); 1127 } 1128 1129 void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, 1130 void *userData) { 1131 detail::CallbackOstream stream(callback, userData); 1132 unwrap(attr).print(stream); 1133 } 1134 1135 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); } 1136 1137 MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, 1138 MlirAttribute attr) { 1139 return MlirNamedAttribute{name, attr}; 1140 } 1141 1142 //===----------------------------------------------------------------------===// 1143 // Identifier API. 1144 //===----------------------------------------------------------------------===// 1145 1146 MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) { 1147 return wrap(StringAttr::get(unwrap(context), unwrap(str))); 1148 } 1149 1150 MlirContext mlirIdentifierGetContext(MlirIdentifier ident) { 1151 return wrap(unwrap(ident).getContext()); 1152 } 1153 1154 bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) { 1155 return unwrap(ident) == unwrap(other); 1156 } 1157 1158 MlirStringRef mlirIdentifierStr(MlirIdentifier ident) { 1159 return wrap(unwrap(ident).strref()); 1160 } 1161 1162 //===----------------------------------------------------------------------===// 1163 // Symbol and SymbolTable API. 1164 //===----------------------------------------------------------------------===// 1165 1166 MlirStringRef mlirSymbolTableGetSymbolAttributeName() { 1167 return wrap(SymbolTable::getSymbolAttrName()); 1168 } 1169 1170 MlirStringRef mlirSymbolTableGetVisibilityAttributeName() { 1171 return wrap(SymbolTable::getVisibilityAttrName()); 1172 } 1173 1174 MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) { 1175 if (!unwrap(operation)->hasTrait<OpTrait::SymbolTable>()) 1176 return wrap(static_cast<SymbolTable *>(nullptr)); 1177 return wrap(new SymbolTable(unwrap(operation))); 1178 } 1179 1180 void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) { 1181 delete unwrap(symbolTable); 1182 } 1183 1184 MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, 1185 MlirStringRef name) { 1186 return wrap(unwrap(symbolTable)->lookup(StringRef(name.data, name.length))); 1187 } 1188 1189 MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, 1190 MlirOperation operation) { 1191 return wrap((Attribute)unwrap(symbolTable)->insert(unwrap(operation))); 1192 } 1193 1194 void mlirSymbolTableErase(MlirSymbolTable symbolTable, 1195 MlirOperation operation) { 1196 unwrap(symbolTable)->erase(unwrap(operation)); 1197 } 1198 1199 MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, 1200 MlirStringRef newSymbol, 1201 MlirOperation from) { 1202 auto *cppFrom = unwrap(from); 1203 auto *context = cppFrom->getContext(); 1204 auto oldSymbolAttr = StringAttr::get(context, unwrap(oldSymbol)); 1205 auto newSymbolAttr = StringAttr::get(context, unwrap(newSymbol)); 1206 return wrap(SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr, 1207 unwrap(from))); 1208 } 1209 1210 void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, 1211 void (*callback)(MlirOperation, bool, 1212 void *userData), 1213 void *userData) { 1214 SymbolTable::walkSymbolTables(unwrap(from), allSymUsesVisible, 1215 [&](Operation *foundOpCpp, bool isVisible) { 1216 callback(wrap(foundOpCpp), isVisible, 1217 userData); 1218 }); 1219 } 1220