xref: /llvm-project/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===-- OneToNTypeConversion.cpp - Utils for 1:N type conversion-*- C++ -*-===//
2 //
3 // Licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Transforms/OneToNTypeConversion.h"

#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/ErrorHandling.h"

#include <unordered_map>
16 
17 using namespace llvm;
18 using namespace mlir;
19 
20 TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const {
21   TypeRange convertedTypes = getConvertedTypes();
22   if (auto mapping = getInputMapping(originalTypeNo))
23     return convertedTypes.slice(mapping->inputNo, mapping->size);
24   return {};
25 }
26 
27 ValueRange
28 OneToNTypeMapping::getConvertedValues(ValueRange convertedValues,
29                                       unsigned originalValueNo) const {
30   if (auto mapping = getInputMapping(originalValueNo))
31     return convertedValues.slice(mapping->inputNo, mapping->size);
32   return {};
33 }
34 
35 void OneToNTypeMapping::convertLocation(
36     Value originalValue, unsigned originalValueNo,
37     llvm::SmallVectorImpl<Location> &result) const {
38   if (auto mapping = getInputMapping(originalValueNo))
39     result.append(mapping->size, originalValue.getLoc());
40 }
41 
42 void OneToNTypeMapping::convertLocations(
43     ValueRange originalValues, llvm::SmallVectorImpl<Location> &result) const {
44   assert(originalValues.size() == getOriginalTypes().size());
45   for (auto [i, value] : llvm::enumerate(originalValues))
46     convertLocation(value, i, result);
47 }
48 
49 static bool isIdentityConversion(Type originalType, TypeRange convertedTypes) {
50   return convertedTypes.size() == 1 && convertedTypes[0] == originalType;
51 }
52 
53 bool OneToNTypeMapping::hasNonIdentityConversion() const {
54   // XXX: I think that the original types and the converted types are the same
55   //      iff there was no non-identity type conversion. If that is true, the
56   //      patterns could actually test whether there is anything useful to do
57   //      without having access to the signature conversion.
58   for (auto [i, originalType] : llvm::enumerate(originalTypes)) {
59     TypeRange types = getConvertedTypes(i);
60     if (!isIdentityConversion(originalType, types)) {
61       assert(TypeRange(originalTypes) != getConvertedTypes());
62       return true;
63     }
64   }
65   assert(TypeRange(originalTypes) == getConvertedTypes());
66   return false;
67 }
68 
namespace {
/// Kind of an `UnrealizedConversionCastOp` inserted by the 1:N conversion
/// utilities. The kind is recorded on the cast and decides which
/// materialization of the type converter replaces the cast if it is still
/// present after pattern application.
enum class CastKind {
  /// Casts block arguments in the target type back to the source type. (If
  /// necessary, this cast becomes an argument materialization.)
  Argument,
  /// Casts other values in the target type back to the source type. (If
  /// necessary, this cast becomes a source materialization.)
  Source,
  /// Casts values in the source type to the target type. (If necessary, this
  /// cast becomes a target materialization.)
  Target,
};
} // namespace
84 
85 /// Mapping of enum values to string values.
86 StringRef getCastKindName(CastKind kind) {
87   static const std::unordered_map<CastKind, StringRef> castKindNames = {
88       {CastKind::Argument, "argument"},
89       {CastKind::Source, "source"},
90       {CastKind::Target, "target"}};
91   return castKindNames.at(kind);
92 }
93 
/// Name of the string attribute attached to every unrealized cast inserted by
/// this file. Its value records the cast kind ("source", "argument", or
/// "target") and marks the cast as one that may need a materialization.
static constexpr char castKindAttrName[] =
    "__one-to-n-type-conversion_cast-kind__";
98 
99 /// Builds an `UnrealizedConversionCastOp` from the given inputs to the given
100 /// result types. Returns the result values of the cast.
101 static ValueRange buildUnrealizedCast(OpBuilder &builder, TypeRange resultTypes,
102                                       ValueRange inputs, CastKind kind) {
103   // Special case: 1-to-N conversion with N = 0. No need to build an
104   // UnrealizedConversionCastOp because the op will always be dead.
105   if (resultTypes.empty())
106     return ValueRange();
107 
108   // Create cast.
109   Location loc = builder.getUnknownLoc();
110   if (!inputs.empty())
111     loc = inputs.front().getLoc();
112   auto castOp =
113       builder.create<UnrealizedConversionCastOp>(loc, resultTypes, inputs);
114 
115   // Store cast kind as attribute.
116   auto kindAttr = StringAttr::get(builder.getContext(), getCastKindName(kind));
117   castOp->setAttr(castKindAttrName, kindAttr);
118 
119   return castOp->getResults();
120 }
121 
122 /// Builds one `UnrealizedConversionCastOp` for each of the given original
123 /// values using the respective target types given in the provided conversion
124 /// mapping and returns the results of these casts. If the conversion mapping of
125 /// a value maps a type to itself (i.e., is an identity conversion), then no
126 /// cast is inserted and the original value is returned instead.
127 /// Note that these unrealized casts are different from target materializations
128 /// in that they are *always* inserted, even if they immediately fold away, such
129 /// that patterns always see valid intermediate IR, whereas materializations are
130 /// only used in the places where the unrealized casts *don't* fold away.
131 static SmallVector<Value>
132 buildUnrealizedForwardCasts(ValueRange originalValues,
133                             OneToNTypeMapping &conversion,
134                             RewriterBase &rewriter, CastKind kind) {
135 
136   // Convert each operand one by one.
137   SmallVector<Value> convertedValues;
138   convertedValues.reserve(conversion.getConvertedTypes().size());
139   for (auto [idx, originalValue] : llvm::enumerate(originalValues)) {
140     TypeRange convertedTypes = conversion.getConvertedTypes(idx);
141 
142     // Identity conversion: keep operand as is.
143     if (isIdentityConversion(originalValue.getType(), convertedTypes)) {
144       convertedValues.push_back(originalValue);
145       continue;
146     }
147 
148     // Non-identity conversion: materialize target types.
149     ValueRange castResult =
150         buildUnrealizedCast(rewriter, convertedTypes, originalValue, kind);
151     convertedValues.append(castResult.begin(), castResult.end());
152   }
153 
154   return convertedValues;
155 }
156 
157 /// Builds one `UnrealizedConversionCastOp` for each sequence of the given
158 /// original values to one value of the type they originated from, i.e., a
159 /// "reverse" conversion from N converted values back to one value of the
160 /// original type, using the given (forward) type conversion. If a given value
161 /// was mapped to a value of the same type (i.e., the conversion in the mapping
162 /// is an identity conversion), then the "converted" value is returned without
163 /// cast.
164 /// Note that these unrealized casts are different from source materializations
165 /// in that they are *always* inserted, even if they immediately fold away, such
166 /// that patterns always see valid intermediate IR, whereas materializations are
167 /// only used in the places where the unrealized casts *don't* fold away.
168 static SmallVector<Value>
169 buildUnrealizedBackwardsCasts(ValueRange convertedValues,
170                               const OneToNTypeMapping &typeConversion,
171                               RewriterBase &rewriter) {
172   assert(typeConversion.getConvertedTypes() == convertedValues.getTypes());
173 
174   // Create unrealized cast op for each converted result of the op.
175   SmallVector<Value> recastValues;
176   TypeRange originalTypes = typeConversion.getOriginalTypes();
177   recastValues.reserve(originalTypes.size());
178   auto convertedValueIt = convertedValues.begin();
179   for (auto [idx, originalType] : llvm::enumerate(originalTypes)) {
180     TypeRange convertedTypes = typeConversion.getConvertedTypes(idx);
181     size_t numConvertedValues = convertedTypes.size();
182     if (isIdentityConversion(originalType, convertedTypes)) {
183       // Identity conversion: take result as is.
184       recastValues.push_back(*convertedValueIt);
185     } else {
186       // Non-identity conversion: cast back to source type.
187       ValueRange recastValue = buildUnrealizedCast(
188           rewriter, originalType,
189           ValueRange{convertedValueIt, convertedValueIt + numConvertedValues},
190           CastKind::Source);
191       assert(recastValue.size() == 1);
192       recastValues.push_back(recastValue.front());
193     }
194     convertedValueIt += numConvertedValues;
195   }
196 
197   return recastValues;
198 }
199 
200 void OneToNPatternRewriter::replaceOp(Operation *op, ValueRange newValues,
201                                       const OneToNTypeMapping &resultMapping) {
202   // Create a cast back to the original types and replace the results of the
203   // original op with those.
204   assert(newValues.size() == resultMapping.getConvertedTypes().size());
205   assert(op->getResultTypes() == resultMapping.getOriginalTypes());
206   PatternRewriter::InsertionGuard g(*this);
207   setInsertionPointAfter(op);
208   SmallVector<Value> castResults =
209       buildUnrealizedBackwardsCasts(newValues, resultMapping, *this);
210   replaceOp(op, castResults);
211 }
212 
213 Block *OneToNPatternRewriter::applySignatureConversion(
214     Block *block, OneToNTypeMapping &argumentConversion) {
215   PatternRewriter::InsertionGuard g(*this);
216 
217   // Split the block at the beginning to get a new block to use for the
218   // updated signature.
219   SmallVector<Location> locs;
220   argumentConversion.convertLocations(block->getArguments(), locs);
221   Block *newBlock =
222       createBlock(block, argumentConversion.getConvertedTypes(), locs);
223   replaceAllUsesWith(block, newBlock);
224 
225   // Create necessary casts in new block.
226   SmallVector<Value> castResults;
227   for (auto [i, arg] : llvm::enumerate(block->getArguments())) {
228     TypeRange convertedTypes = argumentConversion.getConvertedTypes(i);
229     ValueRange newArgs =
230         argumentConversion.getConvertedValues(newBlock->getArguments(), i);
231     if (isIdentityConversion(arg.getType(), convertedTypes)) {
232       // Identity conversion: take argument as is.
233       assert(newArgs.size() == 1);
234       castResults.push_back(newArgs.front());
235     } else {
236       // Non-identity conversion: cast the converted arguments to the original
237       // type.
238       PatternRewriter::InsertionGuard g(*this);
239       setInsertionPointToStart(newBlock);
240       ValueRange castResult = buildUnrealizedCast(*this, arg.getType(), newArgs,
241                                                   CastKind::Argument);
242       assert(castResult.size() == 1);
243       castResults.push_back(castResult.front());
244     }
245   }
246 
247   // Merge old block into new block such that we only have the latter with the
248   // new signature.
249   mergeBlocks(block, newBlock, castResults);
250 
251   return newBlock;
252 }
253 
254 LogicalResult
255 OneToNConversionPattern::matchAndRewrite(Operation *op,
256                                          PatternRewriter &rewriter) const {
257   auto *typeConverter = getTypeConverter();
258 
259   // Construct conversion mapping for results.
260   Operation::result_type_range originalResultTypes = op->getResultTypes();
261   OneToNTypeMapping resultMapping(originalResultTypes);
262   if (failed(typeConverter->convertSignatureArgs(originalResultTypes,
263                                                  resultMapping)))
264     return failure();
265 
266   // Construct conversion mapping for operands.
267   Operation::operand_type_range originalOperandTypes = op->getOperandTypes();
268   OneToNTypeMapping operandMapping(originalOperandTypes);
269   if (failed(typeConverter->convertSignatureArgs(originalOperandTypes,
270                                                  operandMapping)))
271     return failure();
272 
273   // Cast operands to target types.
274   SmallVector<Value> convertedOperands = buildUnrealizedForwardCasts(
275       op->getOperands(), operandMapping, rewriter, CastKind::Target);
276 
277   // Create a `OneToNPatternRewriter` for the pattern, which provides additional
278   // functionality.
279   // TODO(ingomueller): I guess it would be better to use only one rewriter
280   //                    throughout the whole pass, but that would require to
281   //                    drive the pattern application ourselves, which is a lot
282   //                    of additional boilerplate code. This seems to work fine,
283   //                    so I leave it like this for the time being.
284   OneToNPatternRewriter oneToNPatternRewriter(rewriter.getContext(),
285                                               rewriter.getListener());
286   oneToNPatternRewriter.restoreInsertionPoint(rewriter.saveInsertionPoint());
287 
288   // Apply actual pattern.
289   if (failed(matchAndRewrite(op, oneToNPatternRewriter, operandMapping,
290                              resultMapping, convertedOperands)))
291     return failure();
292 
293   return success();
294 }
295 
296 namespace mlir {
297 
298 // This function applies the provided patterns using
299 // `applyPatternsGreedily` and then replaces all newly inserted
300 // `UnrealizedConversionCastOps` that haven't folded away. ("Backward" casts
301 // from target to source types inserted by a `OneToNConversionPattern` normally
302 // fold away with the "forward" casts from source to target types inserted by
303 // the next pattern.) To understand which casts are "newly inserted", all casts
304 // inserted by this pass are annotated with a string attribute that also
305 // documents which kind of the cast (source, argument, or target).
306 LogicalResult
307 applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
308                              const FrozenRewritePatternSet &patterns) {
309 #ifndef NDEBUG
310   // Remember existing unrealized casts. This data structure is only used in
311   // asserts; building it only for that purpose may be an overkill.
312   SmallSet<UnrealizedConversionCastOp, 4> existingCasts;
313   op->walk([&](UnrealizedConversionCastOp castOp) {
314     assert(!castOp->hasAttr(castKindAttrName));
315     existingCasts.insert(castOp);
316   });
317 #endif // NDEBUG
318 
319   // Apply provided conversion patterns.
320   if (failed(applyPatternsGreedily(op, patterns))) {
321     emitError(op->getLoc()) << "failed to apply conversion patterns";
322     return failure();
323   }
324 
325   // Find all unrealized casts inserted by the pass that haven't folded away.
326   SmallVector<UnrealizedConversionCastOp> worklist;
327   op->walk([&](UnrealizedConversionCastOp castOp) {
328     if (castOp->hasAttr(castKindAttrName)) {
329       assert(!existingCasts.contains(castOp));
330       worklist.push_back(castOp);
331     }
332   });
333 
334   // Replace new casts with user materializations.
335   IRRewriter rewriter(op->getContext());
336   for (UnrealizedConversionCastOp castOp : worklist) {
337     TypeRange resultTypes = castOp->getResultTypes();
338     ValueRange operands = castOp->getOperands();
339     StringRef castKind =
340         castOp->getAttrOfType<StringAttr>(castKindAttrName).getValue();
341     rewriter.setInsertionPoint(castOp);
342 
343 #ifndef NDEBUG
344     // Determine whether operands or results are already legal to test some
345     // assumptions for the different kind of materializations. These properties
346     // are only used it asserts and it may be overkill to compute them.
347     bool areOperandTypesLegal = llvm::all_of(
348         operands.getTypes(), [&](Type t) { return typeConverter.isLegal(t); });
349     bool areResultsTypesLegal = llvm::all_of(
350         resultTypes, [&](Type t) { return typeConverter.isLegal(t); });
351 #endif // NDEBUG
352 
353     // Add materialization and remember materialized results.
354     SmallVector<Value> materializedResults;
355     if (castKind == getCastKindName(CastKind::Target)) {
356       // Target materialization.
357       assert(!areOperandTypesLegal && areResultsTypesLegal &&
358              operands.size() == 1 && "found unexpected target cast");
359       materializedResults = typeConverter.materializeTargetConversion(
360           rewriter, castOp->getLoc(), resultTypes, operands.front());
361       if (materializedResults.empty()) {
362         emitError(castOp->getLoc())
363             << "failed to create target materialization";
364         return failure();
365       }
366     } else {
367       // Source and argument materializations.
368       assert(areOperandTypesLegal && !areResultsTypesLegal &&
369              resultTypes.size() == 1 && "found unexpected cast");
370       std::optional<Value> maybeResult;
371       if (castKind == getCastKindName(CastKind::Source)) {
372         // Source materialization.
373         maybeResult = typeConverter.materializeSourceConversion(
374             rewriter, castOp->getLoc(), resultTypes.front(),
375             castOp.getOperands());
376       } else {
377         // Argument materialization.
378         assert(castKind == getCastKindName(CastKind::Argument) &&
379                "unexpected value of cast kind attribute");
380         assert(llvm::all_of(operands, llvm::IsaPred<BlockArgument>));
381         maybeResult = typeConverter.materializeArgumentConversion(
382             rewriter, castOp->getLoc(), resultTypes.front(),
383             castOp.getOperands());
384       }
385       if (!maybeResult.has_value() || !maybeResult.value()) {
386         emitError(castOp->getLoc())
387             << "failed to create " << castKind << " materialization";
388         return failure();
389       }
390       materializedResults = {maybeResult.value()};
391     }
392 
393     // Replace the cast with the result of the materialization.
394     rewriter.replaceOp(castOp, materializedResults);
395   }
396 
397   return success();
398 }
399 
400 namespace {
401 class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern {
402 public:
403   FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
404                                          MLIRContext *ctx,
405                                          const TypeConverter &converter)
406       : OneToNConversionPattern(converter, functionLikeOpName, /*benefit=*/1,
407                                 ctx) {}
408 
409   LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
410                                 const OneToNTypeMapping &operandMapping,
411                                 const OneToNTypeMapping &resultMapping,
412                                 ValueRange convertedOperands) const override {
413     auto funcOp = cast<FunctionOpInterface>(op);
414     auto *typeConverter = getTypeConverter();
415 
416     // Construct mapping for function arguments.
417     OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes());
418     if (failed(typeConverter->convertSignatureArgs(funcOp.getArgumentTypes(),
419                                                    argumentMapping)))
420       return failure();
421 
422     // Construct mapping for function results.
423     OneToNTypeMapping funcResultMapping(funcOp.getResultTypes());
424     if (failed(typeConverter->convertSignatureArgs(funcOp.getResultTypes(),
425                                                    funcResultMapping)))
426       return failure();
427 
428     // Nothing to do if the op doesn't have any non-identity conversions for its
429     // operands or results.
430     if (!argumentMapping.hasNonIdentityConversion() &&
431         !funcResultMapping.hasNonIdentityConversion())
432       return failure();
433 
434     // Update the function signature in-place.
435     auto newType = FunctionType::get(rewriter.getContext(),
436                                      argumentMapping.getConvertedTypes(),
437                                      funcResultMapping.getConvertedTypes());
438     rewriter.modifyOpInPlace(op, [&] { funcOp.setType(newType); });
439 
440     // Update block signatures.
441     if (!funcOp.isExternal()) {
442       Region *region = &funcOp.getFunctionBody();
443       Block *block = &region->front();
444       rewriter.applySignatureConversion(block, argumentMapping);
445     }
446 
447     return success();
448   }
449 };
450 } // namespace
451 
452 void populateOneToNFunctionOpInterfaceTypeConversionPattern(
453     StringRef functionLikeOpName, const TypeConverter &converter,
454     RewritePatternSet &patterns) {
455   patterns.add<FunctionOpInterfaceSignatureConversion>(
456       functionLikeOpName, patterns.getContext(), converter);
457 }
458 } // namespace mlir
459