//===- Rewrite.cpp - C API for Rewrite Patterns ---------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir-c/Rewrite.h" #include "mlir-c/Transforms.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Rewrite.h" #include "mlir/CAPI/Support.h" #include "mlir/CAPI/Wrap.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; //===----------------------------------------------------------------------===// /// RewriterBase API inherited from OpBuilder //===----------------------------------------------------------------------===// MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) { return wrap(unwrap(rewriter)->getContext()); } //===----------------------------------------------------------------------===// /// Insertion points methods void mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter) { unwrap(rewriter)->clearInsertionPoint(); } void mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter, MlirOperation op) { unwrap(rewriter)->setInsertionPoint(unwrap(op)); } void mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter, MlirOperation op) { unwrap(rewriter)->setInsertionPointAfter(unwrap(op)); } void mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter, MlirValue value) { unwrap(rewriter)->setInsertionPointAfterValue(unwrap(value)); } void mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter, MlirBlock block) { unwrap(rewriter)->setInsertionPointToStart(unwrap(block)); } void mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter, MlirBlock block) { unwrap(rewriter)->setInsertionPointToEnd(unwrap(block)); } MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter) { return wrap(unwrap(rewriter)->getInsertionBlock()); } MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) { return wrap(unwrap(rewriter)->getBlock()); } //===----------------------------------------------------------------------===// /// Block and operation creation/insertion/cloning MlirBlock mlirRewriterBaseCreateBlockBefore(MlirRewriterBase rewriter, MlirBlock insertBefore, intptr_t nArgTypes, MlirType const *argTypes, MlirLocation const *locations) { SmallVector args; ArrayRef unwrappedArgs = unwrapList(nArgTypes, argTypes, args); SmallVector locs; ArrayRef unwrappedLocs = unwrapList(nArgTypes, locations, locs); return wrap(unwrap(rewriter)->createBlock(unwrap(insertBefore), unwrappedArgs, unwrappedLocs)); } MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter, MlirOperation op) { return wrap(unwrap(rewriter)->insert(unwrap(op))); } // Other methods of OpBuilder MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter, MlirOperation op) { return wrap(unwrap(rewriter)->clone(*unwrap(op))); } MlirOperation mlirRewriterBaseCloneWithoutRegions(MlirRewriterBase rewriter, MlirOperation op) { return wrap(unwrap(rewriter)->cloneWithoutRegions(*unwrap(op))); } void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter, MlirRegion region, MlirBlock before) { unwrap(rewriter)->cloneRegionBefore(*unwrap(region), unwrap(before)); } //===----------------------------------------------------------------------===// /// RewriterBase API //===----------------------------------------------------------------------===// void mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter, MlirRegion region, MlirBlock before) { unwrap(rewriter)->inlineRegionBefore(*unwrap(region), unwrap(before)); } void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter, MlirOperation op, intptr_t nValues, MlirValue const *values) { SmallVector vals; ArrayRef unwrappedVals = unwrapList(nValues, values, vals); unwrap(rewriter)->replaceOp(unwrap(op), unwrappedVals); } void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter, MlirOperation op, MlirOperation newOp) { unwrap(rewriter)->replaceOp(unwrap(op), unwrap(newOp)); } void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op) { unwrap(rewriter)->eraseOp(unwrap(op)); } void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, MlirBlock block) { unwrap(rewriter)->eraseBlock(unwrap(block)); } void mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter, MlirBlock source, MlirOperation op, intptr_t nArgValues, MlirValue const *argValues) { SmallVector vals; ArrayRef unwrappedVals = unwrapList(nArgValues, argValues, vals); unwrap(rewriter)->inlineBlockBefore(unwrap(source), unwrap(op), unwrappedVals); } void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, MlirBlock source, MlirBlock dest, intptr_t nArgValues, MlirValue const *argValues) { SmallVector args; ArrayRef unwrappedArgs = unwrapList(nArgValues, argValues, args); unwrap(rewriter)->mergeBlocks(unwrap(source), unwrap(dest), unwrappedArgs); } void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, MlirOperation op, MlirOperation existingOp) { unwrap(rewriter)->moveOpBefore(unwrap(op), unwrap(existingOp)); } void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, MlirOperation op, MlirOperation existingOp) { unwrap(rewriter)->moveOpAfter(unwrap(op), unwrap(existingOp)); } void mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block, MlirBlock existingBlock) { unwrap(rewriter)->moveBlockBefore(unwrap(block), unwrap(existingBlock)); } void mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter, MlirOperation op) { unwrap(rewriter)->startOpModification(unwrap(op)); } void mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter, MlirOperation op) { unwrap(rewriter)->finalizeOpModification(unwrap(op)); } void mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter, MlirOperation op) { unwrap(rewriter)->cancelOpModification(unwrap(op)); } void mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter, MlirValue from, MlirValue to) { unwrap(rewriter)->replaceAllUsesWith(unwrap(from), unwrap(to)); } void mlirRewriterBaseReplaceAllValueRangeUsesWith(MlirRewriterBase rewriter, intptr_t nValues, MlirValue const *from, MlirValue const *to) { SmallVector fromVals; ArrayRef unwrappedFromVals = unwrapList(nValues, from, fromVals); SmallVector toVals; ArrayRef unwrappedToVals = unwrapList(nValues, to, toVals); unwrap(rewriter)->replaceAllUsesWith(unwrappedFromVals, unwrappedToVals); } void mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter, MlirOperation from, intptr_t nTo, MlirValue const *to) { SmallVector toVals; ArrayRef unwrappedToVals = unwrapList(nTo, to, toVals); unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrappedToVals); } void mlirRewriterBaseReplaceAllOpUsesWithOperation(MlirRewriterBase rewriter, MlirOperation from, MlirOperation to) { unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrap(to)); } void mlirRewriterBaseReplaceOpUsesWithinBlock(MlirRewriterBase rewriter, MlirOperation op, intptr_t nNewValues, MlirValue const *newValues, MlirBlock block) { SmallVector vals; ArrayRef unwrappedVals = unwrapList(nNewValues, newValues, vals); unwrap(rewriter)->replaceOpUsesWithinBlock(unwrap(op), unwrappedVals, unwrap(block)); } void mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter, MlirValue from, MlirValue to, MlirOperation exceptedUser) { unwrap(rewriter)->replaceAllUsesExcept(unwrap(from), unwrap(to), unwrap(exceptedUser)); } //===----------------------------------------------------------------------===// /// IRRewriter API //===----------------------------------------------------------------------===// MlirRewriterBase mlirIRRewriterCreate(MlirContext context) { return wrap(new IRRewriter(unwrap(context))); } MlirRewriterBase mlirIRRewriterCreateFromOp(MlirOperation op) { return wrap(new IRRewriter(unwrap(op))); } void mlirIRRewriterDestroy(MlirRewriterBase rewriter) { delete static_cast(unwrap(rewriter)); } //===----------------------------------------------------------------------===// /// RewritePatternSet and FrozenRewritePatternSet API //===----------------------------------------------------------------------===// inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) { assert(module.ptr && "unexpected null module"); return *(static_cast(module.ptr)); } inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) { return {module}; } inline mlir::FrozenRewritePatternSet * unwrap(MlirFrozenRewritePatternSet module) { assert(module.ptr && "unexpected null module"); return static_cast(module.ptr); } inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) { return {module}; } MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) { auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op))); op.ptr = nullptr; return wrap(m); } void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) { delete unwrap(op); op.ptr = nullptr; } MlirLogicalResult mlirApplyPatternsAndFoldGreedily(MlirModule op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig) { return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns))); } //===----------------------------------------------------------------------===// /// PDLPatternModule API //===----------------------------------------------------------------------===// #if MLIR_ENABLE_PDL_IN_PATTERNMATCH inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) { assert(module.ptr && "unexpected null module"); return static_cast(module.ptr); } inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) { return {module}; } MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) { return wrap(new mlir::PDLPatternModule( mlir::OwningOpRef(unwrap(op)))); } void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op) { delete unwrap(op); op.ptr = nullptr; } MlirRewritePatternSet mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) { auto *m = new mlir::RewritePatternSet(std::move(*unwrap(op))); op.ptr = nullptr; return wrap(m); } #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH