xref: /llvm-project/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir (revision a01097faca35a9a8927c8b0c514bc35dcebec00f)
1// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
2
// Positive (round-trip) tests. Each function in this module holds one
// well-formed SPIR-V matrix op; the FileCheck patterns require the op to be
// printed back with the same operand and result types, and no diagnostic is
// expected anywhere in this section.
3spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
  // MatrixTimesScalar on a !spirv.matrix: f32 components scaled by an f32 value.
4  // CHECK-LABEL: @matrix_times_scalar_1
5  spirv.func @matrix_times_scalar_1(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spirv.matrix<3 x vector<3xf32>> "None" {
6    // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, f32
7    %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32
8    spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf32>>
9  }
10
  // MatrixTimesScalar on a !spirv.coopmatrix operand: f16 elements scaled by
  // an f16 value.
11  // CHECK-LABEL: @matrix_times_scalar_2
12  spirv.func @matrix_times_scalar_2(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA> "None" {
13    // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>, f16
14    %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>, f16
15    spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixA>
16  }
17
  // Transpose, non-square case: matrix<3 x vector<2xf32>> becomes
  // matrix<2 x vector<3xf32>>.
18  // CHECK-LABEL: @matrix_transpose_1
19  spirv.func @matrix_transpose_1(%arg0 : !spirv.matrix<3 x vector<2xf32>>) -> !spirv.matrix<2 x vector<3xf32>> "None" {
20    // CHECK: {{%.*}} = spirv.Transpose {{%.*}} : !spirv.matrix<3 x vector<2xf32>> -> !spirv.matrix<2 x vector<3xf32>>
21    %result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<2xf32>> -> !spirv.matrix<2 x vector<3xf32>>
22    spirv.ReturnValue %result : !spirv.matrix<2 x vector<3xf32>>
23  }
24
  // Transpose, square case: operand and result types are identical.
25  // CHECK-LABEL: @matrix_transpose_2
26  spirv.func @matrix_transpose_2(%arg0 : !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None" {
27    // CHECK: {{%.*}} = spirv.Transpose {{%.*}} : !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
28    %result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
29    spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf32>>
30  }
31
  // MatrixTimesVector: matrix<3 x vector<4xf32>> times vector<3xf32> yields
  // vector<4xf32>.
32  // CHECK-LABEL: @matrix_times_vector_1
33  spirv.func @matrix_times_vector_1(%arg0: !spirv.matrix<3 x vector<4xf32>>, %arg1: vector<3xf32>) -> vector<4xf32> "None" {
34    // CHECK: {{%.*}} = spirv.MatrixTimesVector {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
35    %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
36    spirv.ReturnValue %result : vector<4xf32>
37  }
38
  // VectorTimesMatrix: vector<3xf32> times matrix<4 x vector<3xf32>> yields
  // vector<4xf32>.
39  // CHECK-LABEL: @vector_times_matrix_1
40  spirv.func @vector_times_matrix_1(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) -> vector<4xf32> "None" {
41    // CHECK: {{%.*}} = spirv.VectorTimesMatrix {{%.*}}, {{%.*}} : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
42    %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
43    spirv.ReturnValue %result : vector<4xf32>
44  }
45
  // MatrixTimesMatrix, square case: all three matrix types are identical.
46  // CHECK-LABEL: @matrix_times_matrix_1
47  spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
48    // CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
49    %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
50    spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf32>>
51  }
52
  // MatrixTimesMatrix, non-square case: matrix<3 x vector<2xf32>> times
  // matrix<2 x vector<3xf32>> yields matrix<2 x vector<2xf32>>.
53  // CHECK-LABEL: @matrix_times_matrix_2
54  spirv.func @matrix_times_matrix_2(%arg0: !spirv.matrix<3 x vector<2xf32>>, %arg1: !spirv.matrix<2 x vector<3xf32>>) -> !spirv.matrix<2 x vector<2xf32>> "None"{
55    // CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<2xf32>>, !spirv.matrix<2 x vector<3xf32>> -> !spirv.matrix<2 x vector<2xf32>>
56    %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<2xf32>>, !spirv.matrix<2 x vector<3xf32>> -> !spirv.matrix<2 x vector<2xf32>>
57    spirv.ReturnValue %result : !spirv.matrix<2 x vector<2xf32>>
58  }
59}
60
61// -----
62
// Negative test for MatrixTimesScalar: the matrix has f32 components while the
// scalar is f16, so the op must be rejected with the diagnostic annotated
// directly above it.
63func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f16) {
64  // expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
65  %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f16
66  return
67}
68
69// -----
70
// Negative test for MatrixTimesScalar: same check as the f16 variant, this
// time with an f64 scalar against f32 matrix components. The function name is
// reused; the split marker above keeps the two definitions in separate
// sections of the input.
71func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f64) {
72  // expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
73  %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f64
74  return
75}
76
77// -----
78
// Negative test for Transpose: the input's column vectors have 4 components
// (its row count), but the result type declares only 3 columns.
79func.func @transpose_op_shape_mismatch_1(%arg0 : !spirv.matrix<3 x vector<4xf32>>) {
80   // expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}}
81   %result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<4xf32>> -> !spirv.matrix<3 x vector<3xf32>>
82   return
83}
84
85// -----
86
// Negative test for Transpose: the input's column vectors have 4 components
// (its row count), but the result type declares 2 columns. The same diagnostic
// as in the first shape-mismatch test is expected.
87func.func @transpose_op_shape_mismatch_2(%arg0 : !spirv.matrix<3 x vector<4xf32>>) {
88   // expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}}
89   %result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<4xf32>> -> !spirv.matrix<2 x vector<4xf32>>
90   return
91}
92
93// -----
94
// Negative test for Transpose: the shapes are consistent (3 x vector<4> to
// 4 x vector<3>), but the component type changes from f32 to f16.
95func.func @transpose_op_type_mismatch(%arg0 : !spirv.matrix<3 x vector<4xf32>>) {
96   // expected-error @+1 {{input and output matrices must have the same component type}}
97   %result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<4xf32>> -> !spirv.matrix<4 x vector<3xf16>>
98   return
99}
100
101// -----
102
// Negative test for MatrixTimesMatrix: the right operand has 2 columns while
// the result type declares 3 columns.
103func.func @matrix_times_matrix_invalid_input_shape_1(%arg0 : !spirv.matrix<3 x vector<2xf32>>, %arg1 : !spirv.matrix<2 x vector<3xf32>>){
104   // expected-error @+1 {{right and result matrices must have equal columns' count}}
105   %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<2xf32>>, !spirv.matrix<2 x vector<3xf32>> -> !spirv.matrix<3 x vector<2xf32>>
106   return
107}
108
109// -----
110
// Negative test for MatrixTimesMatrix: the left operand's column vectors have
// 2 components (2 rows) while the result's column vectors have 3.
111func.func @matrix_times_matrix_invalid_input_shape_2(%arg0 : !spirv.matrix<3 x vector<2xf32>>, %arg1 : !spirv.matrix<2 x vector<3xf32>>){
112   // expected-error @+1 {{left and result matrices must have equal rows' count}}
113   %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<2xf32>>, !spirv.matrix<2 x vector<3xf32>> -> !spirv.matrix<2 x vector<3xf32>>
114   return
115}
116
117// -----
118
// Negative test for MatrixTimesMatrix: the left operand has 3 columns, but the
// right operand's column vectors have only 2 components (2 rows), so the
// operands cannot be multiplied.
119func.func @matrix_times_matrix_inputs_shape_mismatch(%arg0 : !spirv.matrix<3 x vector<2xf32>>, %arg1 : !spirv.matrix<2 x vector<2xf32>>){
120   // expected-error @+1 {{left matrix columns' count must be equal to the right matrix rows' count}}
121   %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<2xf32>>, !spirv.matrix<2 x vector<2xf32>> -> !spirv.matrix<2 x vector<2xf32>>
122   return
123}
124
125// -----
126
// Negative test for MatrixTimesMatrix: both operands have f32 components but
// the result type uses f64; the expected diagnostic names the right operand
// and the result.
127func.func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : !spirv.matrix<3x vector<3xf32>>){
128   // expected-error @+1 {{right and result matrices' component type must be the same}}
129   %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf64>>
130   return
131}
132
133// -----
134
// Negative test for MatrixTimesMatrix: the left operand has f64 components
// while the right operand and the result use f32; the expected diagnostic
// names the left operand and the result.
135func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3 x vector<3xf64>>, %arg1 : !spirv.matrix<3x vector<3xf32>>){
136   // expected-error @+1 {{left and result matrices' component type must be the same}}
137   %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf64>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
138   return
139}
140
141// -----
142
// Negative test for MatrixTimesVector: the vector operand has f16 elements
// while the result vector has f32 elements. The sizes are otherwise
// consistent (4 columns with a 4-element vector, 3 rows with a 3-element
// result).
143func.func @matrix_times_vector_element_type_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<4xf16>) {
144  // expected-error @+1 {{op failed to verify that all of {vector, result} have same element type}}
145  %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<4xf16> -> vector<3xf32>
146  return
147}
148
149// -----
150
// Negative test for MatrixTimesVector: the matrix has 3 rows (its column
// vectors have 3 components) but the result vector has 4 elements; the
// expected diagnostic reports both numbers.
151func.func @matrix_times_vector_row_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<4xf32>) {
152  // expected-error @+1 {{spirv.MatrixTimesVector' op result size (4) must match the matrix rows (3)}}
153  %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<4xf32> -> vector<4xf32>
154  return
155}
156
157// -----
158
// Negative test for MatrixTimesVector: the matrix has 4 columns but the vector
// operand has only 3 elements; the expected diagnostic reports both numbers.
159func.func @matrix_times_vector_column_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<3xf32>) {
160  // expected-error @+1 {{spirv.MatrixTimesVector' op matrix columns (4) must match vector operand size (3)}}
161  %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<3xf32> -> vector<3xf32>
162  return
163}
164
165// -----
166
// Negative test for VectorTimesMatrix: the vector operand has 4 components,
// but each column of the matrix has 3.
167func.func @vector_times_matrix_vector_matrix_mismatch(%arg0: vector<4xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) {
168  // expected-error @+1 {{number of components in vector must equal the number of components in each column in matrix}}
169  %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<4xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<3xf32>
170  return
171}
172
173// -----
174
// Negative test for VectorTimesMatrix: the matrix has 4 columns, but the
// result vector has only 3 components.
175func.func @vector_times_matrix_result_matrix_mismatch(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) {
176  // expected-error @+1 {{number of columns in matrix must equal the number of components in result}}
177  %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<3xf32>
178  return
179}
180
181// -----
182
// Negative test for VectorTimesMatrix: the vector operand has f16 elements
// while the result vector has f32 elements; the shapes are otherwise
// consistent.
183func.func @vector_times_matrix_vector_type_mismatch(%arg0: vector<3xf16>, %arg1: !spirv.matrix<4 x vector<3xf32>>) {
184  // expected-error @+1 {{op failed to verify that all of {vector, result} have same element type}}
185  %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf16>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
186  return
187}
188
189// -----
190
// Negative test for VectorTimesMatrix: the matrix has f16 components while the
// vector operand and the result use f32; the shapes are otherwise consistent.
191func.func @vector_times_matrix_matrix_type_mismatch(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf16>>) {
192  // expected-error @+1 {{matrix must be a matrix with the same component type as the component type in result}}
193  %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf16>> -> vector<4xf32>
194  return
195}
196