xref: /llvm-project/mlir/lib/CAPI/Transforms/Rewrite.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- Rewrite.cpp - C API for Rewrite Patterns ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Rewrite.h"
10 
11 #include "mlir-c/Transforms.h"
12 #include "mlir/CAPI/IR.h"
13 #include "mlir/CAPI/Rewrite.h"
14 #include "mlir/CAPI/Support.h"
15 #include "mlir/CAPI/Wrap.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 
20 using namespace mlir;
21 
22 //===----------------------------------------------------------------------===//
23 /// RewriterBase API inherited from OpBuilder
24 //===----------------------------------------------------------------------===//
25 
26 MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) {
27   return wrap(unwrap(rewriter)->getContext());
28 }
29 
30 //===----------------------------------------------------------------------===//
31 /// Insertion points methods
32 
33 void mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter) {
34   unwrap(rewriter)->clearInsertionPoint();
35 }
36 
37 void mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter,
38                                              MlirOperation op) {
39   unwrap(rewriter)->setInsertionPoint(unwrap(op));
40 }
41 
42 void mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter,
43                                             MlirOperation op) {
44   unwrap(rewriter)->setInsertionPointAfter(unwrap(op));
45 }
46 
47 void mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter,
48                                                  MlirValue value) {
49   unwrap(rewriter)->setInsertionPointAfterValue(unwrap(value));
50 }
51 
52 void mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter,
53                                               MlirBlock block) {
54   unwrap(rewriter)->setInsertionPointToStart(unwrap(block));
55 }
56 
57 void mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter,
58                                             MlirBlock block) {
59   unwrap(rewriter)->setInsertionPointToEnd(unwrap(block));
60 }
61 
62 MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter) {
63   return wrap(unwrap(rewriter)->getInsertionBlock());
64 }
65 
66 MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) {
67   return wrap(unwrap(rewriter)->getBlock());
68 }
69 
70 //===----------------------------------------------------------------------===//
71 /// Block and operation creation/insertion/cloning
72 
73 MlirBlock mlirRewriterBaseCreateBlockBefore(MlirRewriterBase rewriter,
74                                             MlirBlock insertBefore,
75                                             intptr_t nArgTypes,
76                                             MlirType const *argTypes,
77                                             MlirLocation const *locations) {
78   SmallVector<Type, 4> args;
79   ArrayRef<Type> unwrappedArgs = unwrapList(nArgTypes, argTypes, args);
80   SmallVector<Location, 4> locs;
81   ArrayRef<Location> unwrappedLocs = unwrapList(nArgTypes, locations, locs);
82   return wrap(unwrap(rewriter)->createBlock(unwrap(insertBefore), unwrappedArgs,
83                                             unwrappedLocs));
84 }
85 
86 MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter,
87                                      MlirOperation op) {
88   return wrap(unwrap(rewriter)->insert(unwrap(op)));
89 }
90 
91 // Other methods of OpBuilder
92 
93 MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter,
94                                     MlirOperation op) {
95   return wrap(unwrap(rewriter)->clone(*unwrap(op)));
96 }
97 
98 MlirOperation mlirRewriterBaseCloneWithoutRegions(MlirRewriterBase rewriter,
99                                                   MlirOperation op) {
100   return wrap(unwrap(rewriter)->cloneWithoutRegions(*unwrap(op)));
101 }
102 
103 void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter,
104                                        MlirRegion region, MlirBlock before) {
105 
106   unwrap(rewriter)->cloneRegionBefore(*unwrap(region), unwrap(before));
107 }
108 
109 //===----------------------------------------------------------------------===//
110 /// RewriterBase API
111 //===----------------------------------------------------------------------===//
112 
113 void mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter,
114                                         MlirRegion region, MlirBlock before) {
115   unwrap(rewriter)->inlineRegionBefore(*unwrap(region), unwrap(before));
116 }
117 
118 void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter,
119                                          MlirOperation op, intptr_t nValues,
120                                          MlirValue const *values) {
121   SmallVector<Value, 4> vals;
122   ArrayRef<Value> unwrappedVals = unwrapList(nValues, values, vals);
123   unwrap(rewriter)->replaceOp(unwrap(op), unwrappedVals);
124 }
125 
126 void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter,
127                                             MlirOperation op,
128                                             MlirOperation newOp) {
129   unwrap(rewriter)->replaceOp(unwrap(op), unwrap(newOp));
130 }
131 
132 void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op) {
133   unwrap(rewriter)->eraseOp(unwrap(op));
134 }
135 
136 void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, MlirBlock block) {
137   unwrap(rewriter)->eraseBlock(unwrap(block));
138 }
139 
140 void mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter,
141                                        MlirBlock source, MlirOperation op,
142                                        intptr_t nArgValues,
143                                        MlirValue const *argValues) {
144   SmallVector<Value, 4> vals;
145   ArrayRef<Value> unwrappedVals = unwrapList(nArgValues, argValues, vals);
146 
147   unwrap(rewriter)->inlineBlockBefore(unwrap(source), unwrap(op),
148                                       unwrappedVals);
149 }
150 
151 void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, MlirBlock source,
152                                  MlirBlock dest, intptr_t nArgValues,
153                                  MlirValue const *argValues) {
154   SmallVector<Value, 4> args;
155   ArrayRef<Value> unwrappedArgs = unwrapList(nArgValues, argValues, args);
156   unwrap(rewriter)->mergeBlocks(unwrap(source), unwrap(dest), unwrappedArgs);
157 }
158 
159 void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, MlirOperation op,
160                                   MlirOperation existingOp) {
161   unwrap(rewriter)->moveOpBefore(unwrap(op), unwrap(existingOp));
162 }
163 
164 void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, MlirOperation op,
165                                  MlirOperation existingOp) {
166   unwrap(rewriter)->moveOpAfter(unwrap(op), unwrap(existingOp));
167 }
168 
169 void mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block,
170                                      MlirBlock existingBlock) {
171   unwrap(rewriter)->moveBlockBefore(unwrap(block), unwrap(existingBlock));
172 }
173 
174 void mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter,
175                                          MlirOperation op) {
176   unwrap(rewriter)->startOpModification(unwrap(op));
177 }
178 
179 void mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter,
180                                             MlirOperation op) {
181   unwrap(rewriter)->finalizeOpModification(unwrap(op));
182 }
183 
184 void mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter,
185                                           MlirOperation op) {
186   unwrap(rewriter)->cancelOpModification(unwrap(op));
187 }
188 
189 void mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter,
190                                         MlirValue from, MlirValue to) {
191   unwrap(rewriter)->replaceAllUsesWith(unwrap(from), unwrap(to));
192 }
193 
194 void mlirRewriterBaseReplaceAllValueRangeUsesWith(MlirRewriterBase rewriter,
195                                                   intptr_t nValues,
196                                                   MlirValue const *from,
197                                                   MlirValue const *to) {
198   SmallVector<Value, 4> fromVals;
199   ArrayRef<Value> unwrappedFromVals = unwrapList(nValues, from, fromVals);
200   SmallVector<Value, 4> toVals;
201   ArrayRef<Value> unwrappedToVals = unwrapList(nValues, to, toVals);
202   unwrap(rewriter)->replaceAllUsesWith(unwrappedFromVals, unwrappedToVals);
203 }
204 
205 void mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter,
206                                                     MlirOperation from,
207                                                     intptr_t nTo,
208                                                     MlirValue const *to) {
209   SmallVector<Value, 4> toVals;
210   ArrayRef<Value> unwrappedToVals = unwrapList(nTo, to, toVals);
211   unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrappedToVals);
212 }
213 
214 void mlirRewriterBaseReplaceAllOpUsesWithOperation(MlirRewriterBase rewriter,
215                                                    MlirOperation from,
216                                                    MlirOperation to) {
217   unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrap(to));
218 }
219 
220 void mlirRewriterBaseReplaceOpUsesWithinBlock(MlirRewriterBase rewriter,
221                                               MlirOperation op,
222                                               intptr_t nNewValues,
223                                               MlirValue const *newValues,
224                                               MlirBlock block) {
225   SmallVector<Value, 4> vals;
226   ArrayRef<Value> unwrappedVals = unwrapList(nNewValues, newValues, vals);
227   unwrap(rewriter)->replaceOpUsesWithinBlock(unwrap(op), unwrappedVals,
228                                              unwrap(block));
229 }
230 
231 void mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter,
232                                           MlirValue from, MlirValue to,
233                                           MlirOperation exceptedUser) {
234   unwrap(rewriter)->replaceAllUsesExcept(unwrap(from), unwrap(to),
235                                          unwrap(exceptedUser));
236 }
237 
238 //===----------------------------------------------------------------------===//
239 /// IRRewriter API
240 //===----------------------------------------------------------------------===//
241 
242 MlirRewriterBase mlirIRRewriterCreate(MlirContext context) {
243   return wrap(new IRRewriter(unwrap(context)));
244 }
245 
246 MlirRewriterBase mlirIRRewriterCreateFromOp(MlirOperation op) {
247   return wrap(new IRRewriter(unwrap(op)));
248 }
249 
250 void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
251   delete static_cast<IRRewriter *>(unwrap(rewriter));
252 }
253 
254 //===----------------------------------------------------------------------===//
255 /// RewritePatternSet and FrozenRewritePatternSet API
256 //===----------------------------------------------------------------------===//
257 
258 inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
259   assert(module.ptr && "unexpected null module");
260   return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
261 }
262 
263 inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
264   return {module};
265 }
266 
267 inline mlir::FrozenRewritePatternSet *
268 unwrap(MlirFrozenRewritePatternSet module) {
269   assert(module.ptr && "unexpected null module");
270   return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr);
271 }
272 
273 inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) {
274   return {module};
275 }
276 
277 MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
278   auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
279   op.ptr = nullptr;
280   return wrap(m);
281 }
282 
283 void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
284   delete unwrap(op);
285   op.ptr = nullptr;
286 }
287 
288 MlirLogicalResult
289 mlirApplyPatternsAndFoldGreedily(MlirModule op,
290                                  MlirFrozenRewritePatternSet patterns,
291                                  MlirGreedyRewriteDriverConfig) {
292   return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
293 }
294 
295 //===----------------------------------------------------------------------===//
296 /// PDLPatternModule API
297 //===----------------------------------------------------------------------===//
298 
299 #if MLIR_ENABLE_PDL_IN_PATTERNMATCH
300 inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
301   assert(module.ptr && "unexpected null module");
302   return static_cast<mlir::PDLPatternModule *>(module.ptr);
303 }
304 
305 inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
306   return {module};
307 }
308 
309 MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) {
310   return wrap(new mlir::PDLPatternModule(
311       mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op))));
312 }
313 
314 void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op) {
315   delete unwrap(op);
316   op.ptr = nullptr;
317 }
318 
319 MlirRewritePatternSet
320 mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
321   auto *m = new mlir::RewritePatternSet(std::move(*unwrap(op)));
322   op.ptr = nullptr;
323   return wrap(m);
324 }
325 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
326