xref: /llvm-project/mlir/test/Examples/Toy/Ch7/shape_inference.mlir (revision ee2c6cd9069fe0d8e7386ce53300e7645e4db792)
// RUN: toyc-ch7 %s -emit=mlir -opt 2>&1 | FileCheck %s

// Check the result of inlining + shape inference on an input module.

// Helper taking two unranked f64 tensors: transposes each operand, then
// combines the two transposed values with toy.mul. Every type here is
// unranked (tensor<*xf64>); the concrete shapes are left to the optimizer.
// Under -opt this private function is expected to be inlined into @main and
// then erased (the expected output must not mention it).
toy.func private @multiply_transpose(%lhs: tensor<*xf64>, %rhs: tensor<*xf64>) -> tensor<*xf64> {
  %lhs_t = toy.transpose(%lhs : tensor<*xf64>) to tensor<*xf64>
  %rhs_t = toy.transpose(%rhs : tensor<*xf64>) to tensor<*xf64>
  %product = toy.mul %lhs_t, %rhs_t : tensor<*xf64>
  toy.return %product : tensor<*xf64>
}
// Entry point of the test module.
//  - %a: a 2x3 constant passed through a same-type (no-op) reshape.
//  - %b: a flat 6-element constant holding the same values, reshaped to 2x3.
//  - @multiply_transpose is called twice, once with (a, b) and once with
//    (b, a). Only the second result is printed; the first is deliberately
//    left unused.
// The expected optimized output further down contains a single constant, a
// single transpose and a single mul. So the reshapes, the duplicate constant
// and the dead first call are all expected to disappear under -opt.
toy.func @main() {
  %matrix = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %a = toy.reshape(%matrix : tensor<2x3xf64>) to tensor<2x3xf64>
  %flat = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
  %b = toy.reshape(%flat : tensor<6xf64>) to tensor<2x3xf64>
  %unused = toy.generic_call @multiply_transpose(%a, %b) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
  %result = toy.generic_call @multiply_transpose(%b, %a) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
  toy.print %result : tensor<*xf64>
  toy.return
}

// CHECK-NOT: func @multiply_transpose
// CHECK-NOT: tensor<*xf64>

// CHECK-LABEL: func @main()
// CHECK:         [[VAL_0:%.*]] = toy.constant dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
// CHECK:         [[VAL_1:%.*]] = toy.transpose([[VAL_0]] : tensor<2x3xf64>) to tensor<3x2xf64>
// CHECK:         [[VAL_2:%.*]] = toy.mul [[VAL_1]], [[VAL_1]] : tensor<3x2xf64>
// CHECK:         toy.print [[VAL_2]] : tensor<3x2xf64>
// CHECK:         toy.return