1 //===-- OneToNTypeConversion.cpp - Utils for 1:N type conversion-*- C++ -*-===// 2 // 3 // Licensed under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Transforms/OneToNTypeConversion.h" 10 11 #include "mlir/Interfaces/FunctionInterfaces.h" 12 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 13 #include "llvm/ADT/SmallSet.h" 14 15 #include <unordered_map> 16 17 using namespace llvm; 18 using namespace mlir; 19 20 TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const { 21 TypeRange convertedTypes = getConvertedTypes(); 22 if (auto mapping = getInputMapping(originalTypeNo)) 23 return convertedTypes.slice(mapping->inputNo, mapping->size); 24 return {}; 25 } 26 27 ValueRange 28 OneToNTypeMapping::getConvertedValues(ValueRange convertedValues, 29 unsigned originalValueNo) const { 30 if (auto mapping = getInputMapping(originalValueNo)) 31 return convertedValues.slice(mapping->inputNo, mapping->size); 32 return {}; 33 } 34 35 void OneToNTypeMapping::convertLocation( 36 Value originalValue, unsigned originalValueNo, 37 llvm::SmallVectorImpl<Location> &result) const { 38 if (auto mapping = getInputMapping(originalValueNo)) 39 result.append(mapping->size, originalValue.getLoc()); 40 } 41 42 void OneToNTypeMapping::convertLocations( 43 ValueRange originalValues, llvm::SmallVectorImpl<Location> &result) const { 44 assert(originalValues.size() == getOriginalTypes().size()); 45 for (auto [i, value] : llvm::enumerate(originalValues)) 46 convertLocation(value, i, result); 47 } 48 49 static bool isIdentityConversion(Type originalType, TypeRange convertedTypes) { 50 return convertedTypes.size() == 1 && convertedTypes[0] == originalType; 51 } 52 53 bool OneToNTypeMapping::hasNonIdentityConversion() const { 54 // XXX: I think that the original types and the converted types are the same 55 // iff there was no non-identity type conversion. If that is true, the 56 // patterns could actually test whether there is anything useful to do 57 // without having access to the signature conversion. 58 for (auto [i, originalType] : llvm::enumerate(originalTypes)) { 59 TypeRange types = getConvertedTypes(i); 60 if (!isIdentityConversion(originalType, types)) { 61 assert(TypeRange(originalTypes) != getConvertedTypes()); 62 return true; 63 } 64 } 65 assert(TypeRange(originalTypes) == getConvertedTypes()); 66 return false; 67 } 68 69 namespace { 70 enum class CastKind { 71 // Casts block arguments in the target type back to the source type. (If 72 // necessary, this cast becomes an argument materialization.) 73 Argument, 74 75 // Casts other values in the target type back to the source type. (If 76 // necessary, this cast becomes a source materialization.) 77 Source, 78 79 // Casts values in the source type to the target type. (If necessary, this 80 // cast becomes a target materialization.) 81 Target 82 }; 83 } // namespace 84 85 /// Mapping of enum values to string values. 86 StringRef getCastKindName(CastKind kind) { 87 static const std::unordered_map<CastKind, StringRef> castKindNames = { 88 {CastKind::Argument, "argument"}, 89 {CastKind::Source, "source"}, 90 {CastKind::Target, "target"}}; 91 return castKindNames.at(kind); 92 } 93 94 /// Attribute name that is used to annotate inserted unrealized casts with their 95 /// kind (source, argument, or target). 96 static const char *const castKindAttrName = 97 "__one-to-n-type-conversion_cast-kind__"; 98 99 /// Builds an `UnrealizedConversionCastOp` from the given inputs to the given 100 /// result types. Returns the result values of the cast. 101 static ValueRange buildUnrealizedCast(OpBuilder &builder, TypeRange resultTypes, 102 ValueRange inputs, CastKind kind) { 103 // Special case: 1-to-N conversion with N = 0. No need to build an 104 // UnrealizedConversionCastOp because the op will always be dead. 105 if (resultTypes.empty()) 106 return ValueRange(); 107 108 // Create cast. 109 Location loc = builder.getUnknownLoc(); 110 if (!inputs.empty()) 111 loc = inputs.front().getLoc(); 112 auto castOp = 113 builder.create<UnrealizedConversionCastOp>(loc, resultTypes, inputs); 114 115 // Store cast kind as attribute. 116 auto kindAttr = StringAttr::get(builder.getContext(), getCastKindName(kind)); 117 castOp->setAttr(castKindAttrName, kindAttr); 118 119 return castOp->getResults(); 120 } 121 122 /// Builds one `UnrealizedConversionCastOp` for each of the given original 123 /// values using the respective target types given in the provided conversion 124 /// mapping and returns the results of these casts. If the conversion mapping of 125 /// a value maps a type to itself (i.e., is an identity conversion), then no 126 /// cast is inserted and the original value is returned instead. 127 /// Note that these unrealized casts are different from target materializations 128 /// in that they are *always* inserted, even if they immediately fold away, such 129 /// that patterns always see valid intermediate IR, whereas materializations are 130 /// only used in the places where the unrealized casts *don't* fold away. 131 static SmallVector<Value> 132 buildUnrealizedForwardCasts(ValueRange originalValues, 133 OneToNTypeMapping &conversion, 134 RewriterBase &rewriter, CastKind kind) { 135 136 // Convert each operand one by one. 137 SmallVector<Value> convertedValues; 138 convertedValues.reserve(conversion.getConvertedTypes().size()); 139 for (auto [idx, originalValue] : llvm::enumerate(originalValues)) { 140 TypeRange convertedTypes = conversion.getConvertedTypes(idx); 141 142 // Identity conversion: keep operand as is. 143 if (isIdentityConversion(originalValue.getType(), convertedTypes)) { 144 convertedValues.push_back(originalValue); 145 continue; 146 } 147 148 // Non-identity conversion: materialize target types. 149 ValueRange castResult = 150 buildUnrealizedCast(rewriter, convertedTypes, originalValue, kind); 151 convertedValues.append(castResult.begin(), castResult.end()); 152 } 153 154 return convertedValues; 155 } 156 157 /// Builds one `UnrealizedConversionCastOp` for each sequence of the given 158 /// original values to one value of the type they originated from, i.e., a 159 /// "reverse" conversion from N converted values back to one value of the 160 /// original type, using the given (forward) type conversion. If a given value 161 /// was mapped to a value of the same type (i.e., the conversion in the mapping 162 /// is an identity conversion), then the "converted" value is returned without 163 /// cast. 164 /// Note that these unrealized casts are different from source materializations 165 /// in that they are *always* inserted, even if they immediately fold away, such 166 /// that patterns always see valid intermediate IR, whereas materializations are 167 /// only used in the places where the unrealized casts *don't* fold away. 168 static SmallVector<Value> 169 buildUnrealizedBackwardsCasts(ValueRange convertedValues, 170 const OneToNTypeMapping &typeConversion, 171 RewriterBase &rewriter) { 172 assert(typeConversion.getConvertedTypes() == convertedValues.getTypes()); 173 174 // Create unrealized cast op for each converted result of the op. 175 SmallVector<Value> recastValues; 176 TypeRange originalTypes = typeConversion.getOriginalTypes(); 177 recastValues.reserve(originalTypes.size()); 178 auto convertedValueIt = convertedValues.begin(); 179 for (auto [idx, originalType] : llvm::enumerate(originalTypes)) { 180 TypeRange convertedTypes = typeConversion.getConvertedTypes(idx); 181 size_t numConvertedValues = convertedTypes.size(); 182 if (isIdentityConversion(originalType, convertedTypes)) { 183 // Identity conversion: take result as is. 184 recastValues.push_back(*convertedValueIt); 185 } else { 186 // Non-identity conversion: cast back to source type. 187 ValueRange recastValue = buildUnrealizedCast( 188 rewriter, originalType, 189 ValueRange{convertedValueIt, convertedValueIt + numConvertedValues}, 190 CastKind::Source); 191 assert(recastValue.size() == 1); 192 recastValues.push_back(recastValue.front()); 193 } 194 convertedValueIt += numConvertedValues; 195 } 196 197 return recastValues; 198 } 199 200 void OneToNPatternRewriter::replaceOp(Operation *op, ValueRange newValues, 201 const OneToNTypeMapping &resultMapping) { 202 // Create a cast back to the original types and replace the results of the 203 // original op with those. 204 assert(newValues.size() == resultMapping.getConvertedTypes().size()); 205 assert(op->getResultTypes() == resultMapping.getOriginalTypes()); 206 PatternRewriter::InsertionGuard g(*this); 207 setInsertionPointAfter(op); 208 SmallVector<Value> castResults = 209 buildUnrealizedBackwardsCasts(newValues, resultMapping, *this); 210 replaceOp(op, castResults); 211 } 212 213 Block *OneToNPatternRewriter::applySignatureConversion( 214 Block *block, OneToNTypeMapping &argumentConversion) { 215 PatternRewriter::InsertionGuard g(*this); 216 217 // Split the block at the beginning to get a new block to use for the 218 // updated signature. 219 SmallVector<Location> locs; 220 argumentConversion.convertLocations(block->getArguments(), locs); 221 Block *newBlock = 222 createBlock(block, argumentConversion.getConvertedTypes(), locs); 223 replaceAllUsesWith(block, newBlock); 224 225 // Create necessary casts in new block. 226 SmallVector<Value> castResults; 227 for (auto [i, arg] : llvm::enumerate(block->getArguments())) { 228 TypeRange convertedTypes = argumentConversion.getConvertedTypes(i); 229 ValueRange newArgs = 230 argumentConversion.getConvertedValues(newBlock->getArguments(), i); 231 if (isIdentityConversion(arg.getType(), convertedTypes)) { 232 // Identity conversion: take argument as is. 233 assert(newArgs.size() == 1); 234 castResults.push_back(newArgs.front()); 235 } else { 236 // Non-identity conversion: cast the converted arguments to the original 237 // type. 238 PatternRewriter::InsertionGuard g(*this); 239 setInsertionPointToStart(newBlock); 240 ValueRange castResult = buildUnrealizedCast(*this, arg.getType(), newArgs, 241 CastKind::Argument); 242 assert(castResult.size() == 1); 243 castResults.push_back(castResult.front()); 244 } 245 } 246 247 // Merge old block into new block such that we only have the latter with the 248 // new signature. 249 mergeBlocks(block, newBlock, castResults); 250 251 return newBlock; 252 } 253 254 LogicalResult 255 OneToNConversionPattern::matchAndRewrite(Operation *op, 256 PatternRewriter &rewriter) const { 257 auto *typeConverter = getTypeConverter(); 258 259 // Construct conversion mapping for results. 260 Operation::result_type_range originalResultTypes = op->getResultTypes(); 261 OneToNTypeMapping resultMapping(originalResultTypes); 262 if (failed(typeConverter->convertSignatureArgs(originalResultTypes, 263 resultMapping))) 264 return failure(); 265 266 // Construct conversion mapping for operands. 267 Operation::operand_type_range originalOperandTypes = op->getOperandTypes(); 268 OneToNTypeMapping operandMapping(originalOperandTypes); 269 if (failed(typeConverter->convertSignatureArgs(originalOperandTypes, 270 operandMapping))) 271 return failure(); 272 273 // Cast operands to target types. 274 SmallVector<Value> convertedOperands = buildUnrealizedForwardCasts( 275 op->getOperands(), operandMapping, rewriter, CastKind::Target); 276 277 // Create a `OneToNPatternRewriter` for the pattern, which provides additional 278 // functionality. 279 // TODO(ingomueller): I guess it would be better to use only one rewriter 280 // throughout the whole pass, but that would require to 281 // drive the pattern application ourselves, which is a lot 282 // of additional boilerplate code. This seems to work fine, 283 // so I leave it like this for the time being. 284 OneToNPatternRewriter oneToNPatternRewriter(rewriter.getContext(), 285 rewriter.getListener()); 286 oneToNPatternRewriter.restoreInsertionPoint(rewriter.saveInsertionPoint()); 287 288 // Apply actual pattern. 289 if (failed(matchAndRewrite(op, oneToNPatternRewriter, operandMapping, 290 resultMapping, convertedOperands))) 291 return failure(); 292 293 return success(); 294 } 295 296 namespace mlir { 297 298 // This function applies the provided patterns using 299 // `applyPatternsGreedily` and then replaces all newly inserted 300 // `UnrealizedConversionCastOps` that haven't folded away. ("Backward" casts 301 // from target to source types inserted by a `OneToNConversionPattern` normally 302 // fold away with the "forward" casts from source to target types inserted by 303 // the next pattern.) To understand which casts are "newly inserted", all casts 304 // inserted by this pass are annotated with a string attribute that also 305 // documents which kind of the cast (source, argument, or target). 306 LogicalResult 307 applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter, 308 const FrozenRewritePatternSet &patterns) { 309 #ifndef NDEBUG 310 // Remember existing unrealized casts. This data structure is only used in 311 // asserts; building it only for that purpose may be an overkill. 312 SmallSet<UnrealizedConversionCastOp, 4> existingCasts; 313 op->walk([&](UnrealizedConversionCastOp castOp) { 314 assert(!castOp->hasAttr(castKindAttrName)); 315 existingCasts.insert(castOp); 316 }); 317 #endif // NDEBUG 318 319 // Apply provided conversion patterns. 320 if (failed(applyPatternsGreedily(op, patterns))) { 321 emitError(op->getLoc()) << "failed to apply conversion patterns"; 322 return failure(); 323 } 324 325 // Find all unrealized casts inserted by the pass that haven't folded away. 326 SmallVector<UnrealizedConversionCastOp> worklist; 327 op->walk([&](UnrealizedConversionCastOp castOp) { 328 if (castOp->hasAttr(castKindAttrName)) { 329 assert(!existingCasts.contains(castOp)); 330 worklist.push_back(castOp); 331 } 332 }); 333 334 // Replace new casts with user materializations. 335 IRRewriter rewriter(op->getContext()); 336 for (UnrealizedConversionCastOp castOp : worklist) { 337 TypeRange resultTypes = castOp->getResultTypes(); 338 ValueRange operands = castOp->getOperands(); 339 StringRef castKind = 340 castOp->getAttrOfType<StringAttr>(castKindAttrName).getValue(); 341 rewriter.setInsertionPoint(castOp); 342 343 #ifndef NDEBUG 344 // Determine whether operands or results are already legal to test some 345 // assumptions for the different kind of materializations. These properties 346 // are only used it asserts and it may be overkill to compute them. 347 bool areOperandTypesLegal = llvm::all_of( 348 operands.getTypes(), [&](Type t) { return typeConverter.isLegal(t); }); 349 bool areResultsTypesLegal = llvm::all_of( 350 resultTypes, [&](Type t) { return typeConverter.isLegal(t); }); 351 #endif // NDEBUG 352 353 // Add materialization and remember materialized results. 354 SmallVector<Value> materializedResults; 355 if (castKind == getCastKindName(CastKind::Target)) { 356 // Target materialization. 357 assert(!areOperandTypesLegal && areResultsTypesLegal && 358 operands.size() == 1 && "found unexpected target cast"); 359 materializedResults = typeConverter.materializeTargetConversion( 360 rewriter, castOp->getLoc(), resultTypes, operands.front()); 361 if (materializedResults.empty()) { 362 emitError(castOp->getLoc()) 363 << "failed to create target materialization"; 364 return failure(); 365 } 366 } else { 367 // Source and argument materializations. 368 assert(areOperandTypesLegal && !areResultsTypesLegal && 369 resultTypes.size() == 1 && "found unexpected cast"); 370 std::optional<Value> maybeResult; 371 if (castKind == getCastKindName(CastKind::Source)) { 372 // Source materialization. 373 maybeResult = typeConverter.materializeSourceConversion( 374 rewriter, castOp->getLoc(), resultTypes.front(), 375 castOp.getOperands()); 376 } else { 377 // Argument materialization. 378 assert(castKind == getCastKindName(CastKind::Argument) && 379 "unexpected value of cast kind attribute"); 380 assert(llvm::all_of(operands, llvm::IsaPred<BlockArgument>)); 381 maybeResult = typeConverter.materializeArgumentConversion( 382 rewriter, castOp->getLoc(), resultTypes.front(), 383 castOp.getOperands()); 384 } 385 if (!maybeResult.has_value() || !maybeResult.value()) { 386 emitError(castOp->getLoc()) 387 << "failed to create " << castKind << " materialization"; 388 return failure(); 389 } 390 materializedResults = {maybeResult.value()}; 391 } 392 393 // Replace the cast with the result of the materialization. 394 rewriter.replaceOp(castOp, materializedResults); 395 } 396 397 return success(); 398 } 399 400 namespace { 401 class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern { 402 public: 403 FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName, 404 MLIRContext *ctx, 405 const TypeConverter &converter) 406 : OneToNConversionPattern(converter, functionLikeOpName, /*benefit=*/1, 407 ctx) {} 408 409 LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter, 410 const OneToNTypeMapping &operandMapping, 411 const OneToNTypeMapping &resultMapping, 412 ValueRange convertedOperands) const override { 413 auto funcOp = cast<FunctionOpInterface>(op); 414 auto *typeConverter = getTypeConverter(); 415 416 // Construct mapping for function arguments. 417 OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes()); 418 if (failed(typeConverter->convertSignatureArgs(funcOp.getArgumentTypes(), 419 argumentMapping))) 420 return failure(); 421 422 // Construct mapping for function results. 423 OneToNTypeMapping funcResultMapping(funcOp.getResultTypes()); 424 if (failed(typeConverter->convertSignatureArgs(funcOp.getResultTypes(), 425 funcResultMapping))) 426 return failure(); 427 428 // Nothing to do if the op doesn't have any non-identity conversions for its 429 // operands or results. 430 if (!argumentMapping.hasNonIdentityConversion() && 431 !funcResultMapping.hasNonIdentityConversion()) 432 return failure(); 433 434 // Update the function signature in-place. 435 auto newType = FunctionType::get(rewriter.getContext(), 436 argumentMapping.getConvertedTypes(), 437 funcResultMapping.getConvertedTypes()); 438 rewriter.modifyOpInPlace(op, [&] { funcOp.setType(newType); }); 439 440 // Update block signatures. 441 if (!funcOp.isExternal()) { 442 Region *region = &funcOp.getFunctionBody(); 443 Block *block = ®ion->front(); 444 rewriter.applySignatureConversion(block, argumentMapping); 445 } 446 447 return success(); 448 } 449 }; 450 } // namespace 451 452 void populateOneToNFunctionOpInterfaceTypeConversionPattern( 453 StringRef functionLikeOpName, const TypeConverter &converter, 454 RewritePatternSet &patterns) { 455 patterns.add<FunctionOpInterfaceSignatureConversion>( 456 functionLikeOpName, patterns.getContext(), converter); 457 } 458 } // namespace mlir 459