xref: /llvm-project/mlir/test/lib/Dialect/SCF/TestWhileOpBuilder.cpp (revision b716bf84eaba25e0f83d1778288f65a671e85f98)
1 //===- TestWhileOpBuilder.cpp - Pass to test WhileOp::build ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test some builder functions of WhileOp. It
10 // tests the regression explained in https://reviews.llvm.org/D142952, where
11 // a WhileOp::build overload crashed when fed with operands of different types
12 // than the result types.
13 //
// To test the build function, the pass visits each WhileOp found in the body
// of a FuncOp and inserts, right before it, an additional WhileOp with the same
// operands and result types (but dummy computations), created with the builder
// in question.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "mlir/Dialect/Arith/IR/Arith.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h"
22 #include "mlir/Dialect/SCF/IR/SCF.h"
23 #include "mlir/IR/BuiltinOps.h"
24 #include "mlir/IR/ImplicitLocOpBuilder.h"
25 #include "mlir/Pass/Pass.h"
26 
27 using namespace mlir;
28 using namespace mlir::arith;
29 using namespace mlir::scf;
30 
31 namespace {
32 struct TestSCFWhileOpBuilderPass
33     : public PassWrapper<TestSCFWhileOpBuilderPass,
34                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon80dcf31a0111::TestSCFWhileOpBuilderPass35   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFWhileOpBuilderPass)
36 
37   StringRef getArgument() const final { return "test-scf-while-op-builder"; }
getDescription__anon80dcf31a0111::TestSCFWhileOpBuilderPass38   StringRef getDescription() const final {
39     return "test build functions of scf.while";
40   }
41   explicit TestSCFWhileOpBuilderPass() = default;
42   TestSCFWhileOpBuilderPass(const TestSCFWhileOpBuilderPass &pass) = default;
43 
runOnOperation__anon80dcf31a0111::TestSCFWhileOpBuilderPass44   void runOnOperation() override {
45     func::FuncOp func = getOperation();
46     func.walk([&](WhileOp whileOp) {
47       Location loc = whileOp->getLoc();
48       ImplicitLocOpBuilder builder(loc, whileOp);
49 
50       // Create a WhileOp with the same operands and result types.
51       TypeRange resultTypes = whileOp->getResultTypes();
52       ValueRange operands = whileOp->getOperands();
53       builder.create<WhileOp>(
54           loc, resultTypes, operands, /*beforeBuilder=*/
55           [&](OpBuilder &b, Location loc, ValueRange args) {
56             // Just cast the before args into the right types for condition.
57             ImplicitLocOpBuilder builder(loc, b);
58             auto castOp =
59                 builder.create<UnrealizedConversionCastOp>(resultTypes, args);
60             auto cmp = builder.create<ConstantIntOp>(/*value=*/1, /*width=*/1);
61             builder.create<ConditionOp>(cmp, castOp->getResults());
62           },
63           /*afterBuilder=*/
64           [&](OpBuilder &b, Location loc, ValueRange args) {
65             // Just cast the after args into the right types for yield.
66             ImplicitLocOpBuilder builder(loc, b);
67             auto castOp = builder.create<UnrealizedConversionCastOp>(
68                 operands.getTypes(), args);
69             builder.create<YieldOp>(castOp->getResults());
70           });
71     });
72   }
73 };
74 } // namespace
75 
76 namespace mlir {
77 namespace test {
registerTestSCFWhileOpBuilderPass()78 void registerTestSCFWhileOpBuilderPass() {
79   PassRegistration<TestSCFWhileOpBuilderPass>();
80 }
81 } // namespace test
82 } // namespace mlir
83