1 //===- SCFToSPIRV.cpp - SCF to SPIR-V Patterns ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert SCF dialect to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" 14 #include "mlir/Dialect/SCF/IR/SCF.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 18 #include "mlir/IR/BuiltinOps.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 #include "llvm/Support/FormatVariadic.h" 21 22 using namespace mlir; 23 24 //===----------------------------------------------------------------------===// 25 // Context 26 //===----------------------------------------------------------------------===// 27 28 namespace mlir { 29 struct ScfToSPIRVContextImpl { 30 // Map between the spirv region control flow operation (spirv.mlir.loop or 31 // spirv.mlir.selection) to the VariableOp created to store the region 32 // results. The order of the VariableOp matches the order of the results. 33 DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars; 34 }; 35 } // namespace mlir 36 37 /// We use ScfToSPIRVContext to store information about the lowering of the scf 38 /// region that need to be used later on. When we lower scf.for/scf.if we create 39 /// VariableOp to store the results. We need to keep track of the VariableOp 40 /// created as we need to insert stores into them when lowering Yield. Those 41 /// StoreOp cannot be created earlier as they may use a different type than 42 /// yield operands. 43 ScfToSPIRVContext::ScfToSPIRVContext() { 44 impl = std::make_unique<::ScfToSPIRVContextImpl>(); 45 } 46 47 ScfToSPIRVContext::~ScfToSPIRVContext() = default; 48 49 namespace { 50 51 //===----------------------------------------------------------------------===// 52 // Helper Functions 53 //===----------------------------------------------------------------------===// 54 55 /// Replaces SCF op outputs with SPIR-V variable loads. 56 /// We create VariableOp to handle the results value of the control flow region. 57 /// spirv.mlir.loop/spirv.mlir.selection currently don't yield value. Right 58 /// after the loop we load the value from the allocation and use it as the SCF 59 /// op result. 60 template <typename ScfOp, typename OpTy> 61 void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp, 62 ConversionPatternRewriter &rewriter, 63 ScfToSPIRVContextImpl *scfToSPIRVContext, 64 ArrayRef<Type> returnTypes) { 65 66 Location loc = scfOp.getLoc(); 67 auto &allocas = scfToSPIRVContext->outputVars[newOp]; 68 // Clearing the allocas is necessary in case a dialect conversion path failed 69 // previously, and this is the second attempt of this conversion. 70 allocas.clear(); 71 SmallVector<Value, 8> resultValue; 72 for (Type convertedType : returnTypes) { 73 auto pointerType = 74 spirv::PointerType::get(convertedType, spirv::StorageClass::Function); 75 rewriter.setInsertionPoint(newOp); 76 auto alloc = rewriter.create<spirv::VariableOp>( 77 loc, pointerType, spirv::StorageClass::Function, 78 /*initializer=*/nullptr); 79 allocas.push_back(alloc); 80 rewriter.setInsertionPointAfter(newOp); 81 Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc); 82 resultValue.push_back(loadResult); 83 } 84 rewriter.replaceOp(scfOp, resultValue); 85 } 86 87 Region::iterator getBlockIt(Region ®ion, unsigned index) { 88 return std::next(region.begin(), index); 89 } 90 91 //===----------------------------------------------------------------------===// 92 // Conversion Patterns 93 //===----------------------------------------------------------------------===// 94 95 /// Common class for all vector to GPU patterns. 96 template <typename OpTy> 97 class SCFToSPIRVPattern : public OpConversionPattern<OpTy> { 98 public: 99 SCFToSPIRVPattern(MLIRContext *context, const SPIRVTypeConverter &converter, 100 ScfToSPIRVContextImpl *scfToSPIRVContext) 101 : OpConversionPattern<OpTy>::OpConversionPattern(converter, context), 102 scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {} 103 104 protected: 105 ScfToSPIRVContextImpl *scfToSPIRVContext; 106 // FIXME: We explicitly keep a reference of the type converter here instead of 107 // passing it to OpConversionPattern during construction. This effectively 108 // bypasses the conversion framework's automation on type conversion. This is 109 // needed right now because the conversion framework will unconditionally 110 // legalize all types used by SCF ops upon discovering them, for example, the 111 // types of loop carried values. We use SPIR-V variables for those loop 112 // carried values. Depending on the available capabilities, the SPIR-V 113 // variable can be different, for example, cooperative matrix or normal 114 // variable. We'd like to detach the conversion of the loop carried values 115 // from the SCF ops (which is mainly a region). So we need to "mark" types 116 // used by SCF ops as legal, if to use the conversion framework for type 117 // conversion. There isn't a straightforward way to do that yet, as when 118 // converting types, ops aren't taken into consideration. Therefore, we just 119 // bypass the framework's type conversion for now. 120 const SPIRVTypeConverter &typeConverter; 121 }; 122 123 //===----------------------------------------------------------------------===// 124 // scf::ForOp 125 //===----------------------------------------------------------------------===// 126 127 /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp. 128 struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> { 129 using SCFToSPIRVPattern::SCFToSPIRVPattern; 130 131 LogicalResult 132 matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor, 133 ConversionPatternRewriter &rewriter) const override { 134 // scf::ForOp can be lowered to the structured control flow represented by 135 // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop 136 // latch and the merge block the exit block. The resulting spirv::LoopOp has 137 // a single back edge from the continue to header block, and a single exit 138 // from header to merge. 139 auto loc = forOp.getLoc(); 140 auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None); 141 loopOp.addEntryAndMergeBlock(rewriter); 142 143 OpBuilder::InsertionGuard guard(rewriter); 144 // Create the block for the header. 145 Block *header = rewriter.createBlock(&loopOp.getBody(), 146 getBlockIt(loopOp.getBody(), 1)); 147 rewriter.setInsertionPointAfter(loopOp); 148 149 // Create the new induction variable to use. 150 Value adapLowerBound = adaptor.getLowerBound(); 151 BlockArgument newIndVar = 152 header->addArgument(adapLowerBound.getType(), adapLowerBound.getLoc()); 153 for (Value arg : adaptor.getInitArgs()) 154 header->addArgument(arg.getType(), arg.getLoc()); 155 Block *body = forOp.getBody(); 156 157 // Apply signature conversion to the body of the forOp. It has a single 158 // block, with argument which is the induction variable. That has to be 159 // replaced with the new induction variable. 160 TypeConverter::SignatureConversion signatureConverter( 161 body->getNumArguments()); 162 signatureConverter.remapInput(0, newIndVar); 163 for (unsigned i = 1, e = body->getNumArguments(); i < e; i++) 164 signatureConverter.remapInput(i, header->getArgument(i)); 165 body = rewriter.applySignatureConversion(&forOp.getRegion().front(), 166 signatureConverter); 167 168 // Move the blocks from the forOp into the loopOp. This is the body of the 169 // loopOp. 170 rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.getBody(), 171 getBlockIt(loopOp.getBody(), 2)); 172 173 SmallVector<Value, 8> args(1, adaptor.getLowerBound()); 174 args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end()); 175 // Branch into it from the entry. 176 rewriter.setInsertionPointToEnd(&(loopOp.getBody().front())); 177 rewriter.create<spirv::BranchOp>(loc, header, args); 178 179 // Generate the rest of the loop header. 180 rewriter.setInsertionPointToEnd(header); 181 auto *mergeBlock = loopOp.getMergeBlock(); 182 auto cmpOp = rewriter.create<spirv::SLessThanOp>( 183 loc, rewriter.getI1Type(), newIndVar, adaptor.getUpperBound()); 184 185 rewriter.create<spirv::BranchConditionalOp>( 186 loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>()); 187 188 // Generate instructions to increment the step of the induction variable and 189 // branch to the header. 190 Block *continueBlock = loopOp.getContinueBlock(); 191 rewriter.setInsertionPointToEnd(continueBlock); 192 193 // Add the step to the induction variable and branch to the header. 194 Value updatedIndVar = rewriter.create<spirv::IAddOp>( 195 loc, newIndVar.getType(), newIndVar, adaptor.getStep()); 196 rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar); 197 198 // Infer the return types from the init operands. Vector type may get 199 // converted to CooperativeMatrix or to Vector type, to avoid having complex 200 // extra logic to figure out the right type we just infer it from the Init 201 // operands. 202 SmallVector<Type, 8> initTypes; 203 for (auto arg : adaptor.getInitArgs()) 204 initTypes.push_back(arg.getType()); 205 replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext, 206 initTypes); 207 return success(); 208 } 209 }; 210 211 //===----------------------------------------------------------------------===// 212 // scf::IfOp 213 //===----------------------------------------------------------------------===// 214 215 /// Pattern to convert a scf::IfOp within kernel functions into 216 /// spirv::SelectionOp. 217 struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> { 218 using SCFToSPIRVPattern::SCFToSPIRVPattern; 219 220 LogicalResult 221 matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor, 222 ConversionPatternRewriter &rewriter) const override { 223 // When lowering `scf::IfOp` we explicitly create a selection header block 224 // before the control flow diverges and a merge block where control flow 225 // subsequently converges. 226 auto loc = ifOp.getLoc(); 227 228 // Create `spirv.selection` operation, selection header block and merge 229 // block. 230 auto selectionOp = 231 rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None); 232 auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(), 233 selectionOp.getBody().end()); 234 rewriter.create<spirv::MergeOp>(loc); 235 236 OpBuilder::InsertionGuard guard(rewriter); 237 auto *selectionHeaderBlock = 238 rewriter.createBlock(&selectionOp.getBody().front()); 239 240 // Inline `then` region before the merge block and branch to it. 241 auto &thenRegion = ifOp.getThenRegion(); 242 auto *thenBlock = &thenRegion.front(); 243 rewriter.setInsertionPointToEnd(&thenRegion.back()); 244 rewriter.create<spirv::BranchOp>(loc, mergeBlock); 245 rewriter.inlineRegionBefore(thenRegion, mergeBlock); 246 247 auto *elseBlock = mergeBlock; 248 // If `else` region is not empty, inline that region before the merge block 249 // and branch to it. 250 if (!ifOp.getElseRegion().empty()) { 251 auto &elseRegion = ifOp.getElseRegion(); 252 elseBlock = &elseRegion.front(); 253 rewriter.setInsertionPointToEnd(&elseRegion.back()); 254 rewriter.create<spirv::BranchOp>(loc, mergeBlock); 255 rewriter.inlineRegionBefore(elseRegion, mergeBlock); 256 } 257 258 // Create a `spirv.BranchConditional` operation for selection header block. 259 rewriter.setInsertionPointToEnd(selectionHeaderBlock); 260 rewriter.create<spirv::BranchConditionalOp>(loc, adaptor.getCondition(), 261 thenBlock, ArrayRef<Value>(), 262 elseBlock, ArrayRef<Value>()); 263 264 SmallVector<Type, 8> returnTypes; 265 for (auto result : ifOp.getResults()) { 266 auto convertedType = typeConverter.convertType(result.getType()); 267 if (!convertedType) 268 return rewriter.notifyMatchFailure( 269 loc, 270 llvm::formatv("failed to convert type '{0}'", result.getType())); 271 272 returnTypes.push_back(convertedType); 273 } 274 replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext, 275 returnTypes); 276 return success(); 277 } 278 }; 279 280 //===----------------------------------------------------------------------===// 281 // scf::YieldOp 282 //===----------------------------------------------------------------------===// 283 284 struct TerminatorOpConversion final : SCFToSPIRVPattern<scf::YieldOp> { 285 public: 286 using SCFToSPIRVPattern::SCFToSPIRVPattern; 287 288 LogicalResult 289 matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor, 290 ConversionPatternRewriter &rewriter) const override { 291 ValueRange operands = adaptor.getOperands(); 292 293 Operation *parent = terminatorOp->getParentOp(); 294 295 // TODO: Implement conversion for the remaining `scf` ops. 296 if (parent->getDialect()->getNamespace() == 297 scf::SCFDialect::getDialectNamespace() && 298 !isa<scf::IfOp, scf::ForOp, scf::WhileOp>(parent)) 299 return rewriter.notifyMatchFailure( 300 terminatorOp, 301 llvm::formatv("conversion not supported for parent op: '{0}'", 302 parent->getName())); 303 304 // If the region return values, store each value into the associated 305 // VariableOp created during lowering of the parent region. 306 if (!operands.empty()) { 307 auto &allocas = scfToSPIRVContext->outputVars[parent]; 308 if (allocas.size() != operands.size()) 309 return failure(); 310 311 auto loc = terminatorOp.getLoc(); 312 for (unsigned i = 0, e = operands.size(); i < e; i++) 313 rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]); 314 if (isa<spirv::LoopOp>(parent)) { 315 // For loops we also need to update the branch jumping back to the 316 // header. 317 auto br = cast<spirv::BranchOp>( 318 rewriter.getInsertionBlock()->getTerminator()); 319 SmallVector<Value, 8> args(br.getBlockArguments()); 320 args.append(operands.begin(), operands.end()); 321 rewriter.setInsertionPoint(br); 322 rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(), 323 args); 324 rewriter.eraseOp(br); 325 } 326 } 327 rewriter.eraseOp(terminatorOp); 328 return success(); 329 } 330 }; 331 332 //===----------------------------------------------------------------------===// 333 // scf::WhileOp 334 //===----------------------------------------------------------------------===// 335 336 struct WhileOpConversion final : SCFToSPIRVPattern<scf::WhileOp> { 337 using SCFToSPIRVPattern::SCFToSPIRVPattern; 338 339 LogicalResult 340 matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor, 341 ConversionPatternRewriter &rewriter) const override { 342 auto loc = whileOp.getLoc(); 343 auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None); 344 loopOp.addEntryAndMergeBlock(rewriter); 345 346 Region &beforeRegion = whileOp.getBefore(); 347 Region &afterRegion = whileOp.getAfter(); 348 349 if (failed(rewriter.convertRegionTypes(&beforeRegion, typeConverter)) || 350 failed(rewriter.convertRegionTypes(&afterRegion, typeConverter))) 351 return rewriter.notifyMatchFailure(whileOp, 352 "Failed to convert region types"); 353 354 OpBuilder::InsertionGuard guard(rewriter); 355 356 Block &entryBlock = *loopOp.getEntryBlock(); 357 Block &beforeBlock = beforeRegion.front(); 358 Block &afterBlock = afterRegion.front(); 359 Block &mergeBlock = *loopOp.getMergeBlock(); 360 361 auto cond = cast<scf::ConditionOp>(beforeBlock.getTerminator()); 362 SmallVector<Value> condArgs; 363 if (failed(rewriter.getRemappedValues(cond.getArgs(), condArgs))) 364 return failure(); 365 366 Value conditionVal = rewriter.getRemappedValue(cond.getCondition()); 367 if (!conditionVal) 368 return failure(); 369 370 auto yield = cast<scf::YieldOp>(afterBlock.getTerminator()); 371 SmallVector<Value> yieldArgs; 372 if (failed(rewriter.getRemappedValues(yield.getResults(), yieldArgs))) 373 return failure(); 374 375 // Move the while before block as the initial loop header block. 376 rewriter.inlineRegionBefore(beforeRegion, loopOp.getBody(), 377 getBlockIt(loopOp.getBody(), 1)); 378 379 // Move the while after block as the initial loop body block. 380 rewriter.inlineRegionBefore(afterRegion, loopOp.getBody(), 381 getBlockIt(loopOp.getBody(), 2)); 382 383 // Jump from the loop entry block to the loop header block. 384 rewriter.setInsertionPointToEnd(&entryBlock); 385 rewriter.create<spirv::BranchOp>(loc, &beforeBlock, adaptor.getInits()); 386 387 auto condLoc = cond.getLoc(); 388 389 SmallVector<Value> resultValues(condArgs.size()); 390 391 // For other SCF ops, the scf.yield op yields the value for the whole SCF 392 // op. So we use the scf.yield op as the anchor to create/load/store SPIR-V 393 // local variables. But for the scf.while op, the scf.yield op yields a 394 // value for the before region, which may not matching the whole op's 395 // result. Instead, the scf.condition op returns values matching the whole 396 // op's results. So we need to create/load/store variables according to 397 // that. 398 for (const auto &it : llvm::enumerate(condArgs)) { 399 auto res = it.value(); 400 auto i = it.index(); 401 auto pointerType = 402 spirv::PointerType::get(res.getType(), spirv::StorageClass::Function); 403 404 // Create local variables before the scf.while op. 405 rewriter.setInsertionPoint(loopOp); 406 auto alloc = rewriter.create<spirv::VariableOp>( 407 condLoc, pointerType, spirv::StorageClass::Function, 408 /*initializer=*/nullptr); 409 410 // Load the final result values after the scf.while op. 411 rewriter.setInsertionPointAfter(loopOp); 412 auto loadResult = rewriter.create<spirv::LoadOp>(condLoc, alloc); 413 resultValues[i] = loadResult; 414 415 // Store the current iteration's result value. 416 rewriter.setInsertionPointToEnd(&beforeBlock); 417 rewriter.create<spirv::StoreOp>(condLoc, alloc, res); 418 } 419 420 rewriter.setInsertionPointToEnd(&beforeBlock); 421 rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>( 422 cond, conditionVal, &afterBlock, condArgs, &mergeBlock, std::nullopt); 423 424 // Convert the scf.yield op to a branch back to the header block. 425 rewriter.setInsertionPointToEnd(&afterBlock); 426 rewriter.replaceOpWithNewOp<spirv::BranchOp>(yield, &beforeBlock, 427 yieldArgs); 428 429 rewriter.replaceOp(whileOp, resultValues); 430 return success(); 431 } 432 }; 433 } // namespace 434 435 //===----------------------------------------------------------------------===// 436 // Public API 437 //===----------------------------------------------------------------------===// 438 439 void mlir::populateSCFToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, 440 ScfToSPIRVContext &scfToSPIRVContext, 441 RewritePatternSet &patterns) { 442 patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion, 443 WhileOpConversion>(patterns.getContext(), typeConverter, 444 scfToSPIRVContext.getImpl()); 445 } 446