Lines Matching +full:vector +full:- +full:matrix
1 // RUN: mlir-opt %s -test-lower-to-llvm | \
2 // RUN: mlir-runner -e entry -entry-point-result=void \
3 // RUN: -shared-libs=%mlir_c_runner_utils | \
7 affine_map<(i) -> (i)>,
8 affine_map<(i) -> (i)>,
9 affine_map<(i) -> ()>
17 affine_map<(i, j) -> (i, j)>,
18 affine_map<(i, j) -> (j)>,
19 affine_map<(i, j) -> (i)>
27 affine_map<(i, j) -> (j, i)>,
28 affine_map<(i, j) -> (j)>,
29 affine_map<(i, j) -> (i)>
37 affine_map<(i, j, k) -> (i, k)>,
38 affine_map<(i, j, k) -> (k, j)>,
39 affine_map<(i, j, k) -> (i, j)>
47 affine_map<(i, j, k) -> (k, i)>,
48 affine_map<(i, j, k) -> (k, j)>,
49 affine_map<(i, j, k) -> (i, j)>
57 affine_map<(i, j, k) -> (i, k)>,
58 affine_map<(i, j, k) -> (j, k)>,
59 affine_map<(i, j, k) -> (i, j)>
67 affine_map<(i, j, k) -> (k, i)>,
68 affine_map<(i, j, k) -> (j, k)>,
69 affine_map<(i, j, k) -> (i, j)>
77 affine_map<(i, j, k) -> (i, k)>,
78 affine_map<(i, j, k) -> (k, j)>,
79 affine_map<(i, j, k) -> (j, i)>
87 affine_map<(i, j) -> (i, j)>,
88 affine_map<(i, j) -> (i, j)>,
89 affine_map<(i, j) -> ()>
97 affine_map<(i, j) -> (j, i)>,
98 affine_map<(i, j) -> (j, i)>,
99 affine_map<(i, j) -> ()>
107 affine_map<(i, j) -> (i, j)>,
108 affine_map<(i, j) -> (j, i)>,
109 affine_map<(i, j) -> ()>
117 affine_map<(i, j) -> (j, i)>,
118 affine_map<(i, j) -> (i, j)>,
119 affine_map<(i, j) -> ()>
127 affine_map<(i, j, k) -> (k, j)>,
128 affine_map<(i, j, k) -> (i, k)>,
129 affine_map<(i, j, k) -> (j, i)>
148 %z1 = vector.broadcast %f0 : f32 to vector<2xf32>
149 %z2 = vector.broadcast %f0 : f32 to vector<2x2xf32>
150 %z3 = vector.broadcast %f0 : f32 to vector<3x4xf32>
153 %0 = vector.broadcast %f1 : f32 to vector<2xf32>
154 %a = vector.insert %f2, %0[1] : f32 into vector<2xf32>
155 %1 = vector.broadcast %f3 : f32 to vector<2xf32>
156 %b = vector.insert %f4, %1[1] : f32 into vector<2xf32>
157 %2 = vector.broadcast %f5 : f32 to vector<2xf32>
158 %c = vector.insert %f6, %2[1] : f32 into vector<2xf32>
159 %3 = vector.broadcast %f7 : f32 to vector<2xf32>
160 %d = vector.insert %f8, %3[1] : f32 into vector<2xf32>
162 vector.print %a : vector<2xf32>
163 vector.print %b : vector<2xf32>
164 vector.print %c : vector<2xf32>
165 vector.print %d : vector<2xf32>
175 %4 = vector.broadcast %f0 : f32 to vector<2x2xf32>
176 %5 = vector.insert %a, %4[0] : vector<2xf32> into vector<2x2xf32>
177 %A = vector.insert %b, %5[1] : vector<2xf32> into vector<2x2xf32>
178 %6 = vector.broadcast %f0 : f32 to vector<2x2xf32>
179 %7 = vector.insert %c, %6[0] : vector<2xf32> into vector<2x2xf32>
180 %B = vector.insert %d, %7[1] : vector<2xf32> into vector<2x2xf32>
181 %8 = vector.broadcast %f0 : f32 to vector<3x2xf32>
182 %9 = vector.insert %a, %8[0] : vector<2xf32> into vector<3x2xf32>
183 %10 = vector.insert %b, %9[1] : vector<2xf32> into vector<3x2xf32>
184 %C = vector.insert %c, %10[2] : vector<2xf32> into vector<3x2xf32>
185 %cst = arith.constant dense<0.000000e+00> : vector<2x4xf32>
186 %11 = vector.insert_strided_slice %A, %cst {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x4xf32>
187 %D = vector.insert_strided_slice %B, %11 {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<2x4xf32>
189 vector.print %A : vector<2x2xf32>
190 vector.print %B : vector<2x2xf32>
191 vector.print %C : vector<3x2xf32>
192 vector.print %D : vector<2x4xf32>
201 // Contraction: dot-product a x b
202 %dp1 = vector.contract #dotp_trait %a, %b, %f0
203 : vector<2xf32>, vector<2xf32> into f32
204 %dp2 = vector.contract #dotp_trait %a, %b, %f1
205 : vector<2xf32>, vector<2xf32> into f32
207 vector.print %dp1 : f32
208 vector.print %dp2 : f32
215 // Contraction: matrix-vector A x c
216 %mv1 = vector.contract #matvec_trait %A, %c, %z1
217 : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
218 %mv2 = vector.contract #matvec_trait %A, %c, %a
219 : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
221 vector.print %mv1 : vector<2xf32>
222 vector.print %mv2 : vector<2xf32>
224 // matrix x vector:
229 // Contraction: matrix-trans-vector A^T x c
230 %mv3 = vector.contract #mattransvec_trait %A, %c, %z1
231 : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
232 %mv4 = vector.contract #mattransvec_trait %A, %c, %a
233 : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
235 vector.print %mv3 : vector<2xf32>
236 vector.print %mv4 : vector<2xf32>
238 // matrix x vector:
243 // Contraction: matrix-matrix A x B
244 %mm1 = vector.contract #matmat_trait %A, %B, %z2
245 : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
246 %mm2 = vector.contract #matmat_trait %A, %B, %A
247 : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
249 vector.print %mm1 : vector<2x2xf32>
250 vector.print %mm2 : vector<2x2xf32>
252 // matrix x matrix:
257 // Contraction: matrix-matrix A x B where A, B, C have column-major layout.
263 vector.contract #column_major_matmat_trait %A, %B, %z2
264 : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
266 vector.contract #column_major_matmat_trait %A, %B, %A
267 : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
269 vector.print %llvm_matrix_column_major_mm0 : vector<2x2xf32>
270 vector.print %llvm_matrix_column_major_mm1 : vector<2x2xf32>
272 // matrix x matrix:
277 // Contraction: matrix-trans-matrix A^T x B
278 %mm3 = vector.contract #mattransmat_trait %A, %B, %z2
279 : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
280 %mm4 = vector.contract #mattransmat_trait %A, %B, %A
281 : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
283 vector.print %mm3 : vector<2x2xf32>
284 vector.print %mm4 : vector<2x2xf32>
286 // matrix x matrix:
291 // Contraction: matrix-matrix-trans A x B^T
292 %mm5 = vector.contract #matmattrans_trait %A, %B, %z2
293 : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
294 %mm6 = vector.contract #matmattrans_trait %A, %B, %A
295 : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
297 vector.print %mm5 : vector<2x2xf32>
298 vector.print %mm6 : vector<2x2xf32>
300 // matrix x matrix:
305 // Contraction: matrix-trans-matrix-trans A^T x B^T
306 %mm7 = vector.contract #mattransmattrans_trait %A, %B, %z2
307 : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
308 %mm8 = vector.contract #mattransmattrans_trait %A, %B, %A
309 : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
311 vector.print %mm7 : vector<2x2xf32>
312 vector.print %mm8 : vector<2x2xf32>
314 // matrix x matrix:
319 // Contraction: matrix-matrix-then-trans (A x B)^T
320 %mm9 = vector.contract #matmat_then_trans_trait %A, %B, %z2
321 : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
322 %mm10 = vector.contract #matmat_then_trans_trait %A, %B, %A
323 : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
325 vector.print %mm9 : vector<2x2xf32>
326 vector.print %mm10 : vector<2x2xf32>
328 // matrix x matrix:
333 // Contraction: matrix-matrix C x D
334 %mm11 = vector.contract #matmat_trait %C, %D, %z3
335 : vector<3x2xf32>, vector<2x4xf32> into vector<3x4xf32>
336 %mm12 = vector.contract #matmat_trait %C, %D, %mm11
337 : vector<3x2xf32>, vector<2x4xf32> into vector<3x4xf32>
339 vector.print %mm11 : vector<3x4xf32>
340 vector.print %mm12 : vector<3x4xf32>
345 %c1 = vector.contract #contract2d_trait %A, %B, %f0
346 : vector<2x2xf32>, vector<2x2xf32> into f32
347 %c2 = vector.contract #contract2d_trait %A, %B, %f1
348 : vector<2x2xf32>, vector<2x2xf32> into f32
349 %c3 = vector.contract #contract2d_alt_trait %A, %B, %f0
350 : vector<2x2xf32>, vector<2x2xf32> into f32
351 %c4 = vector.contract #contract2d_alt_trait %A, %B, %f1
352 : vector<2x2xf32>, vector<2x2xf32> into f32
353 %c5 = vector.contract #contract2d_trans_trait %A, %B, %f0
354 : vector<2x2xf32>, vector<2x2xf32> into f32
355 %c6 = vector.contract #contract2d_trans_trait %A, %B, %f1
356 : vector<2x2xf32>, vector<2x2xf32> into f32
357 %c7 = vector.contract #contract2d_trans_alt_trait %A, %B, %f0
358 : vector<2x2xf32>, vector<2x2xf32> into f32
359 %c8 = vector.contract #contract2d_trans_alt_trait %A, %B, %f1
360 : vector<2x2xf32>, vector<2x2xf32> into f32
362 vector.print %c1 : f32
363 vector.print %c2 : f32
364 vector.print %c3 : f32
365 vector.print %c4 : f32
366 vector.print %c5 : f32
367 vector.print %c6 : f32
368 vector.print %c7 : f32
369 vector.print %c8 : f32