xref: /llvm-project/mlir/test/lib/Dialect/SCF/TestWhileOpBuilder.cpp (revision b716bf84eaba25e0f83d1778288f65a671e85f98)
//===- TestWhileOpBuilder.cpp - Pass to test WhileOp::build ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to test some builder functions of WhileOp. It
// tests the regression explained in https://reviews.llvm.org/D142952, where
// a WhileOp::build overload crashed when fed with operands of different types
// than the result types.
//
// To test the build function, the pass copies each WhileOp found in the body
// of a FuncOp and adds an additional WhileOp with the same operands and result
// types (but dummy computations) using the builder in question.
//
//===----------------------------------------------------------------------===//
19*b716bf84SIngo Müller 
20*b716bf84SIngo Müller #include "mlir/Dialect/Arith/IR/Arith.h"
21*b716bf84SIngo Müller #include "mlir/Dialect/Func/IR/FuncOps.h"
22*b716bf84SIngo Müller #include "mlir/Dialect/SCF/IR/SCF.h"
23*b716bf84SIngo Müller #include "mlir/IR/BuiltinOps.h"
24*b716bf84SIngo Müller #include "mlir/IR/ImplicitLocOpBuilder.h"
25*b716bf84SIngo Müller #include "mlir/Pass/Pass.h"
26*b716bf84SIngo Müller 
27*b716bf84SIngo Müller using namespace mlir;
28*b716bf84SIngo Müller using namespace mlir::arith;
29*b716bf84SIngo Müller using namespace mlir::scf;
30*b716bf84SIngo Müller 
31*b716bf84SIngo Müller namespace {
32*b716bf84SIngo Müller struct TestSCFWhileOpBuilderPass
33*b716bf84SIngo Müller     : public PassWrapper<TestSCFWhileOpBuilderPass,
34*b716bf84SIngo Müller                          OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon80dcf31a0111::TestSCFWhileOpBuilderPass35*b716bf84SIngo Müller   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFWhileOpBuilderPass)
36*b716bf84SIngo Müller 
37*b716bf84SIngo Müller   StringRef getArgument() const final { return "test-scf-while-op-builder"; }
getDescription__anon80dcf31a0111::TestSCFWhileOpBuilderPass38*b716bf84SIngo Müller   StringRef getDescription() const final {
39*b716bf84SIngo Müller     return "test build functions of scf.while";
40*b716bf84SIngo Müller   }
41*b716bf84SIngo Müller   explicit TestSCFWhileOpBuilderPass() = default;
42*b716bf84SIngo Müller   TestSCFWhileOpBuilderPass(const TestSCFWhileOpBuilderPass &pass) = default;
43*b716bf84SIngo Müller 
runOnOperation__anon80dcf31a0111::TestSCFWhileOpBuilderPass44*b716bf84SIngo Müller   void runOnOperation() override {
45*b716bf84SIngo Müller     func::FuncOp func = getOperation();
46*b716bf84SIngo Müller     func.walk([&](WhileOp whileOp) {
47*b716bf84SIngo Müller       Location loc = whileOp->getLoc();
48*b716bf84SIngo Müller       ImplicitLocOpBuilder builder(loc, whileOp);
49*b716bf84SIngo Müller 
50*b716bf84SIngo Müller       // Create a WhileOp with the same operands and result types.
51*b716bf84SIngo Müller       TypeRange resultTypes = whileOp->getResultTypes();
52*b716bf84SIngo Müller       ValueRange operands = whileOp->getOperands();
53*b716bf84SIngo Müller       builder.create<WhileOp>(
54*b716bf84SIngo Müller           loc, resultTypes, operands, /*beforeBuilder=*/
55*b716bf84SIngo Müller           [&](OpBuilder &b, Location loc, ValueRange args) {
56*b716bf84SIngo Müller             // Just cast the before args into the right types for condition.
57*b716bf84SIngo Müller             ImplicitLocOpBuilder builder(loc, b);
58*b716bf84SIngo Müller             auto castOp =
59*b716bf84SIngo Müller                 builder.create<UnrealizedConversionCastOp>(resultTypes, args);
60*b716bf84SIngo Müller             auto cmp = builder.create<ConstantIntOp>(/*value=*/1, /*width=*/1);
61*b716bf84SIngo Müller             builder.create<ConditionOp>(cmp, castOp->getResults());
62*b716bf84SIngo Müller           },
63*b716bf84SIngo Müller           /*afterBuilder=*/
64*b716bf84SIngo Müller           [&](OpBuilder &b, Location loc, ValueRange args) {
65*b716bf84SIngo Müller             // Just cast the after args into the right types for yield.
66*b716bf84SIngo Müller             ImplicitLocOpBuilder builder(loc, b);
67*b716bf84SIngo Müller             auto castOp = builder.create<UnrealizedConversionCastOp>(
68*b716bf84SIngo Müller                 operands.getTypes(), args);
69*b716bf84SIngo Müller             builder.create<YieldOp>(castOp->getResults());
70*b716bf84SIngo Müller           });
71*b716bf84SIngo Müller     });
72*b716bf84SIngo Müller   }
73*b716bf84SIngo Müller };
74*b716bf84SIngo Müller } // namespace
75*b716bf84SIngo Müller 
76*b716bf84SIngo Müller namespace mlir {
77*b716bf84SIngo Müller namespace test {
registerTestSCFWhileOpBuilderPass()78*b716bf84SIngo Müller void registerTestSCFWhileOpBuilderPass() {
79*b716bf84SIngo Müller   PassRegistration<TestSCFWhileOpBuilderPass>();
80*b716bf84SIngo Müller }
81*b716bf84SIngo Müller } // namespace test
82*b716bf84SIngo Müller } // namespace mlir
83