1 //===- Rewrite.cpp - C API for Rewrite Patterns ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/Rewrite.h" 10 11 #include "mlir-c/Transforms.h" 12 #include "mlir/CAPI/IR.h" 13 #include "mlir/CAPI/Rewrite.h" 14 #include "mlir/CAPI/Support.h" 15 #include "mlir/CAPI/Wrap.h" 16 #include "mlir/IR/PatternMatch.h" 17 #include "mlir/Rewrite/FrozenRewritePatternSet.h" 18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 19 20 using namespace mlir; 21 22 //===----------------------------------------------------------------------===// 23 /// RewriterBase API inherited from OpBuilder 24 //===----------------------------------------------------------------------===// 25 26 MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) { 27 return wrap(unwrap(rewriter)->getContext()); 28 } 29 30 //===----------------------------------------------------------------------===// 31 /// Insertion points methods 32 33 void mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter) { 34 unwrap(rewriter)->clearInsertionPoint(); 35 } 36 37 void mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter, 38 MlirOperation op) { 39 unwrap(rewriter)->setInsertionPoint(unwrap(op)); 40 } 41 42 void mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter, 43 MlirOperation op) { 44 unwrap(rewriter)->setInsertionPointAfter(unwrap(op)); 45 } 46 47 void mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter, 48 MlirValue value) { 49 unwrap(rewriter)->setInsertionPointAfterValue(unwrap(value)); 50 } 51 52 void mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter, 53 MlirBlock block) { 54 unwrap(rewriter)->setInsertionPointToStart(unwrap(block)); 55 } 56 57 void mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter, 58 MlirBlock block) { 59 unwrap(rewriter)->setInsertionPointToEnd(unwrap(block)); 60 } 61 62 MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter) { 63 return wrap(unwrap(rewriter)->getInsertionBlock()); 64 } 65 66 MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) { 67 return wrap(unwrap(rewriter)->getBlock()); 68 } 69 70 //===----------------------------------------------------------------------===// 71 /// Block and operation creation/insertion/cloning 72 73 MlirBlock mlirRewriterBaseCreateBlockBefore(MlirRewriterBase rewriter, 74 MlirBlock insertBefore, 75 intptr_t nArgTypes, 76 MlirType const *argTypes, 77 MlirLocation const *locations) { 78 SmallVector<Type, 4> args; 79 ArrayRef<Type> unwrappedArgs = unwrapList(nArgTypes, argTypes, args); 80 SmallVector<Location, 4> locs; 81 ArrayRef<Location> unwrappedLocs = unwrapList(nArgTypes, locations, locs); 82 return wrap(unwrap(rewriter)->createBlock(unwrap(insertBefore), unwrappedArgs, 83 unwrappedLocs)); 84 } 85 86 MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter, 87 MlirOperation op) { 88 return wrap(unwrap(rewriter)->insert(unwrap(op))); 89 } 90 91 // Other methods of OpBuilder 92 93 MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter, 94 MlirOperation op) { 95 return wrap(unwrap(rewriter)->clone(*unwrap(op))); 96 } 97 98 MlirOperation mlirRewriterBaseCloneWithoutRegions(MlirRewriterBase rewriter, 99 MlirOperation op) { 100 return wrap(unwrap(rewriter)->cloneWithoutRegions(*unwrap(op))); 101 } 102 103 void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter, 104 MlirRegion region, MlirBlock before) { 105 106 unwrap(rewriter)->cloneRegionBefore(*unwrap(region), unwrap(before)); 107 } 108 109 //===----------------------------------------------------------------------===// 110 /// RewriterBase API 111 //===----------------------------------------------------------------------===// 112 113 void mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter, 114 MlirRegion region, MlirBlock before) { 115 unwrap(rewriter)->inlineRegionBefore(*unwrap(region), unwrap(before)); 116 } 117 118 void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter, 119 MlirOperation op, intptr_t nValues, 120 MlirValue const *values) { 121 SmallVector<Value, 4> vals; 122 ArrayRef<Value> unwrappedVals = unwrapList(nValues, values, vals); 123 unwrap(rewriter)->replaceOp(unwrap(op), unwrappedVals); 124 } 125 126 void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter, 127 MlirOperation op, 128 MlirOperation newOp) { 129 unwrap(rewriter)->replaceOp(unwrap(op), unwrap(newOp)); 130 } 131 132 void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op) { 133 unwrap(rewriter)->eraseOp(unwrap(op)); 134 } 135 136 void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, MlirBlock block) { 137 unwrap(rewriter)->eraseBlock(unwrap(block)); 138 } 139 140 void mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter, 141 MlirBlock source, MlirOperation op, 142 intptr_t nArgValues, 143 MlirValue const *argValues) { 144 SmallVector<Value, 4> vals; 145 ArrayRef<Value> unwrappedVals = unwrapList(nArgValues, argValues, vals); 146 147 unwrap(rewriter)->inlineBlockBefore(unwrap(source), unwrap(op), 148 unwrappedVals); 149 } 150 151 void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, MlirBlock source, 152 MlirBlock dest, intptr_t nArgValues, 153 MlirValue const *argValues) { 154 SmallVector<Value, 4> args; 155 ArrayRef<Value> unwrappedArgs = unwrapList(nArgValues, argValues, args); 156 unwrap(rewriter)->mergeBlocks(unwrap(source), unwrap(dest), unwrappedArgs); 157 } 158 159 void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, MlirOperation op, 160 MlirOperation existingOp) { 161 unwrap(rewriter)->moveOpBefore(unwrap(op), unwrap(existingOp)); 162 } 163 164 void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, MlirOperation op, 165 MlirOperation existingOp) { 166 unwrap(rewriter)->moveOpAfter(unwrap(op), unwrap(existingOp)); 167 } 168 169 void mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block, 170 MlirBlock existingBlock) { 171 unwrap(rewriter)->moveBlockBefore(unwrap(block), unwrap(existingBlock)); 172 } 173 174 void mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter, 175 MlirOperation op) { 176 unwrap(rewriter)->startOpModification(unwrap(op)); 177 } 178 179 void mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter, 180 MlirOperation op) { 181 unwrap(rewriter)->finalizeOpModification(unwrap(op)); 182 } 183 184 void mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter, 185 MlirOperation op) { 186 unwrap(rewriter)->cancelOpModification(unwrap(op)); 187 } 188 189 void mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter, 190 MlirValue from, MlirValue to) { 191 unwrap(rewriter)->replaceAllUsesWith(unwrap(from), unwrap(to)); 192 } 193 194 void mlirRewriterBaseReplaceAllValueRangeUsesWith(MlirRewriterBase rewriter, 195 intptr_t nValues, 196 MlirValue const *from, 197 MlirValue const *to) { 198 SmallVector<Value, 4> fromVals; 199 ArrayRef<Value> unwrappedFromVals = unwrapList(nValues, from, fromVals); 200 SmallVector<Value, 4> toVals; 201 ArrayRef<Value> unwrappedToVals = unwrapList(nValues, to, toVals); 202 unwrap(rewriter)->replaceAllUsesWith(unwrappedFromVals, unwrappedToVals); 203 } 204 205 void mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter, 206 MlirOperation from, 207 intptr_t nTo, 208 MlirValue const *to) { 209 SmallVector<Value, 4> toVals; 210 ArrayRef<Value> unwrappedToVals = unwrapList(nTo, to, toVals); 211 unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrappedToVals); 212 } 213 214 void mlirRewriterBaseReplaceAllOpUsesWithOperation(MlirRewriterBase rewriter, 215 MlirOperation from, 216 MlirOperation to) { 217 unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrap(to)); 218 } 219 220 void mlirRewriterBaseReplaceOpUsesWithinBlock(MlirRewriterBase rewriter, 221 MlirOperation op, 222 intptr_t nNewValues, 223 MlirValue const *newValues, 224 MlirBlock block) { 225 SmallVector<Value, 4> vals; 226 ArrayRef<Value> unwrappedVals = unwrapList(nNewValues, newValues, vals); 227 unwrap(rewriter)->replaceOpUsesWithinBlock(unwrap(op), unwrappedVals, 228 unwrap(block)); 229 } 230 231 void mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter, 232 MlirValue from, MlirValue to, 233 MlirOperation exceptedUser) { 234 unwrap(rewriter)->replaceAllUsesExcept(unwrap(from), unwrap(to), 235 unwrap(exceptedUser)); 236 } 237 238 //===----------------------------------------------------------------------===// 239 /// IRRewriter API 240 //===----------------------------------------------------------------------===// 241 242 MlirRewriterBase mlirIRRewriterCreate(MlirContext context) { 243 return wrap(new IRRewriter(unwrap(context))); 244 } 245 246 MlirRewriterBase mlirIRRewriterCreateFromOp(MlirOperation op) { 247 return wrap(new IRRewriter(unwrap(op))); 248 } 249 250 void mlirIRRewriterDestroy(MlirRewriterBase rewriter) { 251 delete static_cast<IRRewriter *>(unwrap(rewriter)); 252 } 253 254 //===----------------------------------------------------------------------===// 255 /// RewritePatternSet and FrozenRewritePatternSet API 256 //===----------------------------------------------------------------------===// 257 258 inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) { 259 assert(module.ptr && "unexpected null module"); 260 return *(static_cast<mlir::RewritePatternSet *>(module.ptr)); 261 } 262 263 inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) { 264 return {module}; 265 } 266 267 inline mlir::FrozenRewritePatternSet * 268 unwrap(MlirFrozenRewritePatternSet module) { 269 assert(module.ptr && "unexpected null module"); 270 return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr); 271 } 272 273 inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) { 274 return {module}; 275 } 276 277 MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) { 278 auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op))); 279 op.ptr = nullptr; 280 return wrap(m); 281 } 282 283 void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) { 284 delete unwrap(op); 285 op.ptr = nullptr; 286 } 287 288 MlirLogicalResult 289 mlirApplyPatternsAndFoldGreedily(MlirModule op, 290 MlirFrozenRewritePatternSet patterns, 291 MlirGreedyRewriteDriverConfig) { 292 return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns))); 293 } 294 295 //===----------------------------------------------------------------------===// 296 /// PDLPatternModule API 297 //===----------------------------------------------------------------------===// 298 299 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH 300 inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) { 301 assert(module.ptr && "unexpected null module"); 302 return static_cast<mlir::PDLPatternModule *>(module.ptr); 303 } 304 305 inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) { 306 return {module}; 307 } 308 309 MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) { 310 return wrap(new mlir::PDLPatternModule( 311 mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op)))); 312 } 313 314 void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op) { 315 delete unwrap(op); 316 op.ptr = nullptr; 317 } 318 319 MlirRewritePatternSet 320 mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) { 321 auto *m = new mlir::RewritePatternSet(std::move(*unwrap(op))); 322 op.ptr = nullptr; 323 return wrap(m); 324 } 325 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH 326