// Source: llvm-project/mlir/test/Integration/Dialect/Vector/CPU/broadcast.mlir (revision eb206e9ea84eff0a0596fed2de8316d924f946d1)
// RUN: mlir-opt %s -test-lower-to-llvm  | \
// RUN: mlir-runner -e entry -entry-point-result=void  \
// RUN:   -shared-libs=%mlir_c_runner_utils | \
// RUN: FileCheck %s

// Entry point invoked by mlir-runner (`-e entry` in the RUN lines). Exercises
// vector.broadcast on scalars and vectors after lowering to LLVM, printing every
// result with vector.print.
//
// Each vector.print emits exactly one output line (see the CHECK-SAME
// continuation for %m6 below). The first line is matched with CHECK and every
// later line with CHECK-NEXT, so a missing, extra, or reordered print fails the
// test. Plain CHECK would let FileCheck skip lines or match in the middle of a
// later line.
func.func @entry() {
  // Maximum signed i32 (2^31 - 1) and i64 (2^63 - 1) values.
  %i = arith.constant 2147483647 : i32
  %l = arith.constant 9223372036854775807 : i64

  %f0 = arith.constant 0.0 : f32
  %f1 = arith.constant 1.0 : f32
  %f2 = arith.constant 2.0 : f32
  %f3 = arith.constant 3.0 : f32
  %f4 = arith.constant 4.0 : f32
  %f5 = arith.constant 5.0 : f32

  // Test simple broadcasts of a scalar into 1-D and 3-D vectors.
  %vi = vector.broadcast %i : i32 to vector<2xi32>
  %vl = vector.broadcast %l : i64 to vector<2xi64>
  %vf = vector.broadcast %f1 : f32 to vector<2x2x2xf32>
  vector.print %vi : vector<2xi32>
  vector.print %vl : vector<2xi64>
  vector.print %vf : vector<2x2x2xf32>
  // CHECK: ( 2147483647, 2147483647 )
  // CHECK-NEXT: ( 9223372036854775807, 9223372036854775807 )
  // CHECK-NEXT: ( ( ( 1, 1 ), ( 1, 1 ) ), ( ( 1, 1 ), ( 1, 1 ) ) )

  // Test "duplication" in leading dimensions: build %v3 = ( 1, 2, 3, 4 ), then
  // replicate it along one or two new leading dimensions.
  %v0 = vector.broadcast %f1 : f32 to vector<4xf32>
  %v1 = vector.insert %f2, %v0[1] : f32 into vector<4xf32>
  %v2 = vector.insert %f3, %v1[2] : f32 into vector<4xf32>
  %v3 = vector.insert %f4, %v2[3] : f32 into vector<4xf32>
  %v4 = vector.broadcast %v3 : vector<4xf32> to vector<3x4xf32>
  %v5 = vector.broadcast %v3 : vector<4xf32> to vector<2x2x4xf32>
  vector.print %v3 : vector<4xf32>
  vector.print %v4 : vector<3x4xf32>
  vector.print %v5 : vector<2x2x4xf32>
  // CHECK-NEXT: ( 1, 2, 3, 4 )
  // CHECK-NEXT: ( ( 1, 2, 3, 4 ), ( 1, 2, 3, 4 ), ( 1, 2, 3, 4 ) )
  // CHECK-NEXT: ( ( ( 1, 2, 3, 4 ), ( 1, 2, 3, 4 ) ), ( ( 1, 2, 3, 4 ), ( 1, 2, 3, 4 ) ) )

  // Test straightforward "stretch" of a size-1 1-D vector to size 8.
  %x = vector.broadcast %f5 : f32 to vector<1xf32>
  %y = vector.broadcast %x : vector<1xf32> to vector<8xf32>
  vector.print %y : vector<8xf32>
  // CHECK-NEXT: ( 5, 5, 5, 5, 5, 5, 5, 5 )

  // Test "stretch" in leading dimension: 1x4 -> 3x4.
  %s = vector.broadcast %v3 : vector<4xf32> to vector<1x4xf32>
  %t = vector.broadcast %s : vector<1x4xf32> to vector<3x4xf32>
  vector.print %s : vector<1x4xf32>
  vector.print %t : vector<3x4xf32>
  // CHECK-NEXT: ( ( 1, 2, 3, 4 ) )
  // CHECK-NEXT: ( ( 1, 2, 3, 4 ), ( 1, 2, 3, 4 ), ( 1, 2, 3, 4 ) )

  // Test "stretch" in trailing dimension: 3x1 -> 3x4, each row's single
  // element is repeated across the row.
  %a0 = vector.broadcast %f1 : f32 to vector<3x1xf32>
  %a1 = vector.insert %f2, %a0[1, 0] : f32 into vector<3x1xf32>
  %a2 = vector.insert %f3, %a1[2, 0] : f32 into vector<3x1xf32>
  %a3 = vector.broadcast %a2 : vector<3x1xf32> to vector<3x4xf32>
  vector.print %a2 : vector<3x1xf32>
  vector.print %a3 : vector<3x4xf32>
  // CHECK-NEXT: ( ( 1 ), ( 2 ), ( 3 ) )
  // CHECK-NEXT: ( ( 1, 1, 1, 1 ), ( 2, 2, 2, 2 ), ( 3, 3, 3, 3 ) )

  // Test "stretch" in middle dimension: 3x1x2 -> 3x4x2. %m5 is filled with
  // 0..5 in row-major order, starting from an all-zero vector.
  %m0 = vector.broadcast %f0 : f32 to vector<3x1x2xf32>
  %m1 = vector.insert %f1, %m0[0, 0, 1] : f32 into vector<3x1x2xf32>
  %m2 = vector.insert %f2, %m1[1, 0, 0] : f32 into vector<3x1x2xf32>
  %m3 = vector.insert %f3, %m2[1, 0, 1] : f32 into vector<3x1x2xf32>
  %m4 = vector.insert %f4, %m3[2, 0, 0] : f32 into vector<3x1x2xf32>
  %m5 = vector.insert %f5, %m4[2, 0, 1] : f32 into vector<3x1x2xf32>
  %m6 = vector.broadcast %m5 : vector<3x1x2xf32> to vector<3x4x2xf32>
  vector.print %m5 : vector<3x1x2xf32>
  vector.print %m6 : vector<3x4x2xf32>
  // CHECK-NEXT: ( ( ( 0, 1 ) ), ( ( 2, 3 ) ), ( ( 4, 5 ) ) )
  // CHECK-NEXT: ( ( ( 0, 1 ), ( 0, 1 ), ( 0, 1 ), ( 0, 1 ) ),
  // CHECK-SAME: ( ( 2, 3 ), ( 2, 3 ), ( 2, 3 ), ( 2, 3 ) ),
  // CHECK-SAME: ( ( 4, 5 ), ( 4, 5 ), ( 4, 5 ), ( 4, 5 ) ) )

  return
}