// xref: /llvm-project/mlir/test/Integration/Dialect/Vector/CPU/outerproduct-i64.mlir (revision eb206e9ea84eff0a0596fed2de8316d924f946d1)
// RUN: mlir-opt %s -test-lower-to-llvm  | \
// RUN: mlir-runner -e entry -entry-point-result=void  \
// RUN:   -shared-libs=%mlir_c_runner_utils | \
// RUN: FileCheck %s

// Operand and result types for the 8x8 splat outer product.
!vector_type_A = vector<8xi64>
!vector_type_B = vector<8xi64>
!vector_type_C = vector<8x8xi64>

// Operand and result types for the mixed-size outer product:
// a 2-element LHS and a 3-element RHS produce a 2x3 result.
!vector_type_X = vector<2xi64>
!vector_type_Y = vector<3xi64>
!vector_type_Z = vector<2x3xi64>

// Vector type used by the AXPY (vector x scalar) outer products in @entry.
!vector_type_R = vector<7xi64>

// Broadcasts %ia and %ib into 8-element vectors and %ic into an 8x8
// accumulator, then returns the accumulating outer product, so every element
// of the result equals %ia * %ib + %ic.
//
// The operands are built with vector.broadcast rather than the deprecated
// vector.splat. This matches how @entry materializes its vectors.
func.func @vector_outerproduct_splat_8x8(%ia: i64, %ib: i64, %ic: i64) -> !vector_type_C {
  %a = vector.broadcast %ia : i64 to !vector_type_A
  %b = vector.broadcast %ib : i64 to !vector_type_B
  %c = vector.broadcast %ic : i64 to !vector_type_C
  %d = vector.outerproduct %a, %b, %c : !vector_type_A, !vector_type_B
  return %d : !vector_type_C
}

// Non-accumulating outer product of a 2-vector and a 3-vector:
//   result[i][j] = x[i] * y[j]
func.func @vector_outerproduct_vec_2x3(%x : !vector_type_X, %y : !vector_type_Y)
    -> !vector_type_Z {
  %prod = vector.outerproduct %x, %y : !vector_type_X, !vector_type_Y
  return %prod : !vector_type_Z
}

// Accumulating outer product of a 2-vector and a 3-vector:
//   result[i][j] = x[i] * y[j] + z[i][j]
func.func @vector_outerproduct_vec_2x3_acc(%x : !vector_type_X, %y : !vector_type_Y,
                                           %z : !vector_type_Z) -> !vector_type_Z {
  %sum = vector.outerproduct %x, %y, %z : !vector_type_X, !vector_type_Y
  return %sum : !vector_type_Z
}

// Test driver: exercises each outer-product form and prints the results so
// FileCheck can verify them.
func.func @entry() {
  %zero  = arith.constant 0 : i64
  %one   = arith.constant 1 : i64
  %two   = arith.constant 2 : i64
  %three = arith.constant 3 : i64
  %four  = arith.constant 4 : i64
  %five  = arith.constant 5 : i64
  %ten   = arith.constant 10 : i64

  // Case 1: broadcast scalars into vectors, then take the outer product with
  // an accumulator. Every element is 1 * 2 + 10 = 12.
  %splat_res = call @vector_outerproduct_splat_8x8(%one, %two, %ten)
      : (i64, i64, i64) -> (!vector_type_C)
  vector.print %splat_res : !vector_type_C
  // CHECK-COUNT-8: ( 12, 12, 12, 12, 12, 12, 12, 12 )

  // Case 2: direct outer product of vectors with different sizes.
  //   lhs = ( 1, 2 ), rhs = ( 3, 4, 5 )
  %lhs_init = vector.broadcast %one : i64 to !vector_type_X
  %lhs = vector.insert %two, %lhs_init[1] : i64 into !vector_type_X
  %rhs_init = vector.broadcast %three : i64 to !vector_type_Y
  %rhs_partial = vector.insert %four, %rhs_init[1] : i64 into !vector_type_Y
  %rhs = vector.insert %five, %rhs_partial[2] : i64 into !vector_type_Y

  %prod = call @vector_outerproduct_vec_2x3(%lhs, %rhs)
      : (!vector_type_X, !vector_type_Y) -> (!vector_type_Z)
  vector.print %prod : !vector_type_Z
  // CHECK: ( ( 3, 4, 5 ), ( 6, 8, 10 ) )

  // Case 3: same operands, accumulating into the Case 2 result, which
  // doubles every element.
  %acc_prod = call @vector_outerproduct_vec_2x3_acc(%lhs, %rhs, %prod)
      : (!vector_type_X, !vector_type_Y, !vector_type_Z) -> (!vector_type_Z)
  vector.print %acc_prod : !vector_type_Z
  // CHECK: ( ( 6, 8, 10 ), ( 12, 16, 20 ) )

  // Case 4: AXPY form (vector x scalar), first without and then with an
  // all-ones accumulator.
  //   vec = ( 0, 1, 2, 3, 4, 5, 10 )
  %r0 = vector.broadcast %zero : i64 to !vector_type_R
  %r1 = vector.insert %one,   %r0[1] : i64 into !vector_type_R
  %r2 = vector.insert %two,   %r1[2] : i64 into !vector_type_R
  %r3 = vector.insert %three, %r2[3] : i64 into !vector_type_R
  %r4 = vector.insert %four,  %r3[4] : i64 into !vector_type_R
  %r5 = vector.insert %five,  %r4[5] : i64 into !vector_type_R
  %vec = vector.insert %ten,  %r5[6] : i64 into !vector_type_R

  %ones = vector.broadcast %one : i64 to !vector_type_R

  %axpy_plain = vector.outerproduct %vec, %two : !vector_type_R, i64
  %axpy_acc   = vector.outerproduct %vec, %two, %ones : !vector_type_R, i64

  vector.print %axpy_plain : !vector_type_R
  vector.print %axpy_acc : !vector_type_R
  // CHECK: ( 0, 2, 4, 6, 8, 10, 20 )
  // CHECK: ( 1, 3, 5, 7, 9, 11, 21 )

  return
}
