xref: /llvm-project/mlir/test/Dialect/SparseTensor/sparse_3d.mlir (revision ced2fc7819d5ddea616ec330f18e08ff284c1868)
1// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
2// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification | FileCheck %s
3
// Encoding alias for a rank-1 tensor whose single level is dense.
4#Td = #sparse_tensor.encoding<{ map = (d0) -> (d0 : dense) }>
5
// Encoding aliases for rank-3 tensors, covering all eight combinations of
// dense and compressed levels. The three letters after "T" give the level
// type of d0, d1 and d2 in order: "d" = dense, "s" = compressed.
6#Tddd = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : dense) }>
7#Tdds = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed) }>
8#Tdsd = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense) }>
9#Tdss = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : compressed) }>
10#Tsdd = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed, d1 : dense, d2 : dense) }>
11#Tsds = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed, d1 : dense, d2 : compressed) }>
12#Tssd = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : dense) }>
13#Tsss = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed, d1 : compressed, d2 : compressed) }>
14
// Shared linalg.generic trait for the 3-D element-wise kernels in this file:
// both inputs and the output use the identity indexing map over (i,j,k) and
// all three iterators are parallel. "OP" in the doc string stands for the
// arithmetic op that each test supplies in its region body.
15#trait3 = {
16  indexing_maps = [
17    affine_map<(i,j,k) -> (i,j,k)>,  // A
18    affine_map<(i,j,k) -> (i,j,k)>,  // B
19    affine_map<(i,j,k) -> (i,j,k)>   // X (out)
20  ],
21  iterator_types = ["parallel", "parallel", "parallel"],
22  doc = "X(i,j,k) = A(i,j,k) OP B(i,j,k)"
23}
24
25// CHECK-LABEL:   func @add_ddd(
26// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse{{[0-9]*}}>,
27// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
28// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
29// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
30// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
31// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 16 : index
32// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 8 : index
33// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
34// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
35// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
36// CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
37// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
38// CHECK-DAG:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_11]] : memref<32x16x8xf32>)
39// CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] {
40// CHECK:             %[[VAL_14:.*]] = arith.muli %[[VAL_12]], %[[VAL_4]] : index
41// CHECK:             scf.for %[[VAL_13:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
42// CHECK:               %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : index
43// CHECK:               %[[VAL_17:.*]] = arith.muli %[[VAL_15]], %[[VAL_5]] : index
44// CHECK:               scf.for %[[VAL_16:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] {
45// CHECK:                 %[[VAL_18:.*]] = arith.addi %[[VAL_16]], %[[VAL_17]] : index
46// CHECK:                 %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xf32>
47// CHECK:                 %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_16]]] : memref<32x16x8xf32>
48// CHECK:                 %[[VAL_21:.*]] = arith.addf %[[VAL_19]], %[[VAL_20]] : f32
49// CHECK:                 memref.store %[[VAL_21]], %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_16]]] : memref<32x16x8xf32>
50// CHECK:               }
51// CHECK:             }
52// CHECK:           }
53// CHECK:           %[[VAL_22:.*]] = bufferization.to_tensor %[[VAL_11]] : memref<32x16x8xf32>
54// CHECK:           return %[[VAL_22]] : tensor<32x16x8xf32>
55// CHECK:         }
// Test input: element-wise f32 addition (#trait3) of %arga, annotated with
// the all-dense #Tddd encoding, and the unannotated tensor %argb, with %argx
// as the output operand. The expectations above describe three nested
// scf.for loops that address the values buffer with a linearized index.
56func.func @add_ddd(%arga: tensor<32x16x8xf32, #Tddd>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
57  %0 = linalg.generic #trait3
58     ins(%arga, %argb: tensor<32x16x8xf32, #Tddd>, tensor<32x16x8xf32>)
59    outs(%argx: tensor<32x16x8xf32>) {
60      ^bb(%a: f32, %b: f32, %x: f32):
61        %0 = arith.addf %a, %b : f32
62        linalg.yield %0 : f32
63  } -> tensor<32x16x8xf32>
64  return %0 : tensor<32x16x8xf32>
65}
66
67// CHECK-LABEL:   func @mul_ddd(
68// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse{{[0-9]*}}>,
69// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
70// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
71// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
72// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
73// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 16 : index
74// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 8 : index
75// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
76// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
77// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
78// CHECK-DAG:       %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
79// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
80// CHECK-DAG:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_11]] : memref<32x16x8xf32>)
81// CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] {
82// CHECK:             %[[VAL_14:.*]] = arith.muli %[[VAL_12]], %[[VAL_4]] : index
83// CHECK:             scf.for %[[VAL_13:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
84// CHECK:               %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : index
85// CHECK:               %[[VAL_17:.*]] = arith.muli %[[VAL_15]], %[[VAL_5]] : index
86// CHECK:               scf.for %[[VAL_16:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] {
87// CHECK:                 %[[VAL_18:.*]] = arith.addi %[[VAL_16]], %[[VAL_17]] : index
88// CHECK:                 %[[VAL_19:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xf32>
89// CHECK:                 %[[VAL_20:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_16]]] : memref<32x16x8xf32>
90// CHECK:                 %[[VAL_21:.*]] = arith.mulf %[[VAL_19]], %[[VAL_20]] : f32
91// CHECK:                 memref.store %[[VAL_21]], %[[VAL_11]]{{\[}}%[[VAL_12]], %[[VAL_13]], %[[VAL_16]]] : memref<32x16x8xf32>
92// CHECK:               }
93// CHECK:             }
94// CHECK:           }
95// CHECK:           %[[VAL_22:.*]] = bufferization.to_tensor %[[VAL_11]] : memref<32x16x8xf32>
96// CHECK:           return %[[VAL_22]] : tensor<32x16x8xf32>
97// CHECK:         }
// Test input: element-wise f32 multiplication (#trait3) of %arga, annotated
// with the all-dense #Tddd encoding, and the unannotated tensor %argb, with
// %argx as the output operand. The expectations above have the same
// three-deep scf.for nest as the addition case, with arith.mulf as the op.
98func.func @mul_ddd(%arga: tensor<32x16x8xf32, #Tddd>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
99  %0 = linalg.generic #trait3
100     ins(%arga, %argb: tensor<32x16x8xf32, #Tddd>, tensor<32x16x8xf32>)
101    outs(%argx: tensor<32x16x8xf32>) {
102      ^bb(%a: f32, %b: f32, %x: f32):
103        %0 = arith.mulf %a, %b : f32
104        linalg.yield %0 : f32
105  } -> tensor<32x16x8xf32>
106  return %0 : tensor<32x16x8xf32>
107}
108
109// CHECK-LABEL:   func @add_dds(
110// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse{{[0-9]*}}>,
111// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
112// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
113// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
114// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 32 : index
115// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 16 : index
116// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 8 : index
117// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 0 : index
118// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant true
119// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 1 : index
120// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 2 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
121// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 2 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
122// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
123// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
124// CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_2]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
125// CHECK-DAG:       linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_15]] : memref<32x16x8xf32>)
126// CHECK:           scf.for %[[VAL_16:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_9]] {
127// CHECK:             %[[VAL_18:.*]] = arith.muli %[[VAL_16]], %[[VAL_5]] : index
128// CHECK:             scf.for %[[VAL_17:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_9]] {
129// CHECK:               %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_18]] : index
130// CHECK:               %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<?xindex>
131// CHECK:               %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_9]] : index
132// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref<?xindex>
133// CHECK:               %[[VAL_23:.*]]:2 = scf.while (%[[VAL_24:.*]] = %[[VAL_20]], %[[VAL_25:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) {
134// CHECK:                 %[[VAL_26:.*]] = arith.cmpi ult, %[[VAL_24]], %[[VAL_22]] : index
135// CHECK:                 scf.condition(%[[VAL_26]]) %[[VAL_24]], %[[VAL_25]] : index, index
136// CHECK:               } do {
137// CHECK:               ^bb0(%[[VAL_27:.*]]: index, %[[VAL_28:.*]]: index):
138// CHECK:                 %[[VAL_29:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_27]]] : memref<?xindex>
139// CHECK:                 %[[VAL_30:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_28]] : index
140// CHECK:                 scf.if %[[VAL_30]] {
141// CHECK:                   %[[VAL_31:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_27]]] : memref<?xf32>
142// CHECK:                   %[[VAL_32:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_28]]] : memref<32x16x8xf32>
143// CHECK:                   %[[VAL_33:.*]] = arith.addf %[[VAL_31]], %[[VAL_32]] : f32
144// CHECK:                   memref.store %[[VAL_33]], %[[VAL_15]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_28]]] : memref<32x16x8xf32>
145// CHECK:                 } else {
146// CHECK:                   scf.if %[[VAL_8]] {
147// CHECK:                     %[[VAL_34:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_28]]] : memref<32x16x8xf32>
148// CHECK:                     memref.store %[[VAL_34]], %[[VAL_15]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_28]]] : memref<32x16x8xf32>
149// CHECK:                   } else {
150// CHECK:                   }
151// CHECK:                 }
152// CHECK:                 %[[VAL_35:.*]] = arith.cmpi eq, %[[VAL_29]], %[[VAL_28]] : index
153// CHECK:                 %[[VAL_36:.*]] = arith.addi %[[VAL_27]], %[[VAL_9]] : index
154// CHECK:                 %[[VAL_37:.*]] = arith.select %[[VAL_35]], %[[VAL_36]], %[[VAL_27]] : index
155// CHECK:                 %[[VAL_38:.*]] = arith.addi %[[VAL_28]], %[[VAL_9]] : index
156// CHECK:                 scf.yield %[[VAL_37]], %[[VAL_38]] : index, index
157// CHECK:               }
158// CHECK:               scf.for %[[VAL_39:.*]] = %[[VAL_40:.*]]#1 to %[[VAL_6]] step %[[VAL_9]] {
159// CHECK:                 %[[VAL_41:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_39]]] : memref<32x16x8xf32>
160// CHECK:                 memref.store %[[VAL_41]], %[[VAL_15]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_39]]] : memref<32x16x8xf32>
161// CHECK:               }
162// CHECK:             }
163// CHECK:           }
164// CHECK:           %[[VAL_42:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<32x16x8xf32>
165// CHECK:           return %[[VAL_42]] : tensor<32x16x8xf32>
166// CHECK:         }
// Test input: element-wise f32 addition (#trait3) where %arga uses #Tdds
// (d0 and d1 dense, d2 compressed) and %argb is unannotated. The
// expectations above describe two outer scf.for loops, an scf.while over the
// level-2 positions/coordinates, and a trailing scf.for that copies the
// remaining %argb entries into the output buffer.
167func.func @add_dds(%arga: tensor<32x16x8xf32, #Tdds>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
168  %0 = linalg.generic #trait3
169     ins(%arga, %argb: tensor<32x16x8xf32, #Tdds>, tensor<32x16x8xf32>)
170    outs(%argx: tensor<32x16x8xf32>) {
171      ^bb(%a: f32, %b: f32, %x: f32):
172        %0 = arith.addf %a, %b : f32
173        linalg.yield %0 : f32
174  } -> tensor<32x16x8xf32>
175  return %0 : tensor<32x16x8xf32>
176}
177
178// CHECK-LABEL:   func @mul_dds(
179// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse{{[0-9]*}}>,
180// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
181// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
182// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
183// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 32 : index
184// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 16 : index
185// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 0 : index
186// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 1 : index
187// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 2 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
188// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 2 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
189// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
190// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
191// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
192// CHECK-DAG:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_13]] : memref<32x16x8xf32>)
193// CHECK:           scf.for %[[VAL_14:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
194// CHECK:             %[[VAL_16:.*]] = arith.muli %[[VAL_14]], %[[VAL_5]] : index
195// CHECK:             scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] {
196// CHECK:               %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : index
197// CHECK:               %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xindex>
198// CHECK:               %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_7]] : index
199// CHECK:               %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_19]]] : memref<?xindex>
200// CHECK:               scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_20]] step %[[VAL_7]] {
201// CHECK:                 %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref<?xindex>
202// CHECK:                 %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref<?xf32>
203// CHECK:                 %[[VAL_24:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_14]], %[[VAL_15]], %[[VAL_22]]] : memref<32x16x8xf32>
204// CHECK:                 %[[VAL_25:.*]] = arith.mulf %[[VAL_23]], %[[VAL_24]] : f32
205// CHECK:                 memref.store %[[VAL_25]], %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_15]], %[[VAL_22]]] : memref<32x16x8xf32>
206// CHECK:               }
207// CHECK:             }
208// CHECK:           }
209// CHECK:           %[[VAL_26:.*]] = bufferization.to_tensor %[[VAL_13]] : memref<32x16x8xf32>
210// CHECK:           return %[[VAL_26]] : tensor<32x16x8xf32>
211// CHECK:         }
// Test input: element-wise f32 multiplication (#trait3) where %arga uses
// #Tdds (d0 and d1 dense, d2 compressed) and %argb is unannotated. The
// expectations above describe an innermost scf.for that runs only over the
// level-2 position range of %arga, with no scf.while and no trailing loop.
212func.func @mul_dds(%arga: tensor<32x16x8xf32, #Tdds>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
213  %0 = linalg.generic #trait3
214     ins(%arga, %argb: tensor<32x16x8xf32, #Tdds>, tensor<32x16x8xf32>)
215    outs(%argx: tensor<32x16x8xf32>) {
216      ^bb(%a: f32, %b: f32, %x: f32):
217        %0 = arith.mulf %a, %b : f32
218        linalg.yield %0 : f32
219  } -> tensor<32x16x8xf32>
220  return %0 : tensor<32x16x8xf32>
221}
222
223// CHECK-LABEL:   func @add_dsd(
224// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse{{[0-9]*}}>,
225// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
226// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
227// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
228// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
229// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 16 : index
230// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 8 : index
231// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant true
232// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 0 : index
233// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 1 : index
234// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
235// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
236// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
237// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
238// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
239// CHECK-DAG:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_14]] : memref<32x16x8xf32>)
240// CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_7]] to %[[VAL_3]] step %[[VAL_8]] {
241// CHECK:             %[[VAL_16:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_15]]] : memref<?xindex>
242// CHECK:             %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_8]] : index
243// CHECK:             %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_17]]] : memref<?xindex>
244// CHECK:             %[[VAL_19:.*]]:2 = scf.while (%[[VAL_20:.*]] = %[[VAL_16]], %[[VAL_21:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) {
245// CHECK:               %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_18]] : index
246// CHECK:               scf.condition(%[[VAL_22]]) %[[VAL_20]], %[[VAL_21]] : index, index
247// CHECK:             } do {
248// CHECK:             ^bb0(%[[VAL_23:.*]]: index, %[[VAL_24:.*]]: index):
249// CHECK:               %[[VAL_25:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_23]]] : memref<?xindex>
250// CHECK:               %[[VAL_26:.*]] = arith.cmpi eq, %[[VAL_25]], %[[VAL_24]] : index
251// CHECK:               scf.if %[[VAL_26]] {
252// CHECK:                 %[[VAL_28:.*]] = arith.muli %[[VAL_23]], %[[VAL_5]] : index
253// CHECK:                 scf.for %[[VAL_27:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
254// CHECK:                   %[[VAL_29:.*]] = arith.addi %[[VAL_27]], %[[VAL_28]] : index
255// CHECK:                   %[[VAL_30:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_29]]] : memref<?xf32>
256// CHECK:                   %[[VAL_31:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_15]], %[[VAL_24]], %[[VAL_27]]] : memref<32x16x8xf32>
257// CHECK:                   %[[VAL_32:.*]] = arith.addf %[[VAL_30]], %[[VAL_31]] : f32
258// CHECK:                   memref.store %[[VAL_32]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_24]], %[[VAL_27]]] : memref<32x16x8xf32>
259// CHECK:                 }
260// CHECK:               } else {
261// CHECK:                 scf.if %[[VAL_6]] {
262// CHECK:                   scf.for %[[VAL_33:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
263// CHECK:                     %[[VAL_34:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_15]], %[[VAL_24]], %[[VAL_33]]] : memref<32x16x8xf32>
264// CHECK:                     memref.store %[[VAL_34]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_24]], %[[VAL_33]]] : memref<32x16x8xf32>
265// CHECK:                   }
266// CHECK:                 } else {
267// CHECK:                 }
268// CHECK:               }
269// CHECK:               %[[VAL_35:.*]] = arith.cmpi eq, %[[VAL_25]], %[[VAL_24]] : index
270// CHECK:               %[[VAL_36:.*]] = arith.addi %[[VAL_23]], %[[VAL_8]] : index
271// CHECK:               %[[VAL_37:.*]] = arith.select %[[VAL_35]], %[[VAL_36]], %[[VAL_23]] : index
272// CHECK:               %[[VAL_38:.*]] = arith.addi %[[VAL_24]], %[[VAL_8]] : index
273// CHECK:               scf.yield %[[VAL_37]], %[[VAL_38]] : index, index
274// CHECK:             }
275// CHECK:             scf.for %[[VAL_39:.*]] = %[[VAL_40:.*]]#1 to %[[VAL_4]] step %[[VAL_8]] {
276// CHECK:               scf.for %[[VAL_41:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
277// CHECK:                 %[[VAL_42:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_15]], %[[VAL_39]], %[[VAL_41]]] : memref<32x16x8xf32>
278// CHECK:                 memref.store %[[VAL_42]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_39]], %[[VAL_41]]] : memref<32x16x8xf32>
279// CHECK:               }
280// CHECK:             }
281// CHECK:           }
282// CHECK:           %[[VAL_43:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<32x16x8xf32>
283// CHECK:           return %[[VAL_43]] : tensor<32x16x8xf32>
284// CHECK:         }
// Test input: element-wise f32 addition (#trait3) where %arga uses #Tdsd
// (d0 dense, d1 compressed, d2 dense) and %argb is unannotated. The
// expectations above describe an scf.while over the level-1
// positions/coordinates inside the outer scf.for, with dense scf.for loops
// over the innermost dimension in both branches and a trailing loop nest
// that copies the remaining %argb entries.
285func.func @add_dsd(%arga: tensor<32x16x8xf32, #Tdsd>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
286  %0 = linalg.generic #trait3
287     ins(%arga, %argb: tensor<32x16x8xf32, #Tdsd>, tensor<32x16x8xf32>)
288    outs(%argx: tensor<32x16x8xf32>) {
289      ^bb(%a: f32, %b: f32, %x: f32):
290        %0 = arith.addf %a, %b : f32
291        linalg.yield %0 : f32
292  } -> tensor<32x16x8xf32>
293  return %0 : tensor<32x16x8xf32>
294}
295
296// CHECK-LABEL:   func @mul_dsd(
297// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse{{[0-9]*}}>,
298// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
299// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
300// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
301// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
302// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 8 : index
303// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
304// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
305// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
306// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
307// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
308// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
309// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
310// CHECK-DAG:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_12]] : memref<32x16x8xf32>)
311// CHECK:           scf.for %[[VAL_13:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
312// CHECK:             %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_13]]] : memref<?xindex>
313// CHECK:             %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_6]] : index
314// CHECK:             %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
315// CHECK:             scf.for %[[VAL_17:.*]] = %[[VAL_14]] to %[[VAL_16]] step %[[VAL_6]] {
316// CHECK:               %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xindex>
317// CHECK:               %[[VAL_20:.*]] = arith.muli %[[VAL_17]], %[[VAL_4]] : index
318// CHECK:               scf.for %[[VAL_19:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
319// CHECK:                 %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_20]] : index
320// CHECK:                 %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref<?xf32>
321// CHECK:                 %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_13]], %[[VAL_18]], %[[VAL_19]]] : memref<32x16x8xf32>
322// CHECK:                 %[[VAL_24:.*]] = arith.mulf %[[VAL_22]], %[[VAL_23]] : f32
323// CHECK:                 memref.store %[[VAL_24]], %[[VAL_12]]{{\[}}%[[VAL_13]], %[[VAL_18]], %[[VAL_19]]] : memref<32x16x8xf32>
324// CHECK:               }
325// CHECK:             }
326// CHECK:           }
327// CHECK:           %[[VAL_25:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<32x16x8xf32>
328// CHECK:           return %[[VAL_25]] : tensor<32x16x8xf32>
329// CHECK:         }
// Test input: element-wise f32 multiplication (#trait3) where %arga uses
// #Tdsd (d0 dense, d1 compressed, d2 dense) and %argb is unannotated. The
// expectations above describe a middle scf.for over the level-1 position
// range of %arga and a dense innermost scf.for, with no scf.while.
330func.func @mul_dsd(%arga: tensor<32x16x8xf32, #Tdsd>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
331  %0 = linalg.generic #trait3
332     ins(%arga, %argb: tensor<32x16x8xf32, #Tdsd>, tensor<32x16x8xf32>)
333    outs(%argx: tensor<32x16x8xf32>) {
334      ^bb(%a: f32, %b: f32, %x: f32):
335        %0 = arith.mulf %a, %b : f32
336        linalg.yield %0 : f32
337  } -> tensor<32x16x8xf32>
338  return %0 : tensor<32x16x8xf32>
339}
340
341// CHECK-LABEL:   func @add_dss(
342// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse{{[0-9]*}}>,
343// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
344// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
345// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
346// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 32 : index
347// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 16 : index
348// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 8 : index
349// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant true
350// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 0 : index
351// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 1 : index
352// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
353// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
354// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 2 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
355// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 2 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
356// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
357// CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
358// CHECK-DAG:       %[[VAL_17:.*]] = bufferization.to_memref %[[VAL_2]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
359// CHECK-DAG:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_17]] : memref<32x16x8xf32>)
360// CHECK:           scf.for %[[VAL_18:.*]] = %[[VAL_8]] to %[[VAL_4]] step %[[VAL_9]] {
361// CHECK:             %[[VAL_19:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_18]]] : memref<?xindex>
362// CHECK:             %[[VAL_20:.*]] = arith.addi %[[VAL_18]], %[[VAL_9]] : index
363// CHECK:             %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_20]]] : memref<?xindex>
364// CHECK:             %[[VAL_22:.*]]:2 = scf.while (%[[VAL_23:.*]] = %[[VAL_19]], %[[VAL_24:.*]] = %[[VAL_8]]) : (index, index) -> (index, index) {
365// CHECK:               %[[VAL_25:.*]] = arith.cmpi ult, %[[VAL_23]], %[[VAL_21]] : index
366// CHECK:               scf.condition(%[[VAL_25]]) %[[VAL_23]], %[[VAL_24]] : index, index
367// CHECK:             } do {
368// CHECK:             ^bb0(%[[VAL_26:.*]]: index, %[[VAL_27:.*]]: index):
369// CHECK:               %[[VAL_28:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_26]]] : memref<?xindex>
370// CHECK:               %[[VAL_29:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_27]] : index
371// CHECK:               scf.if %[[VAL_29]] {
372// CHECK:                 %[[VAL_30:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_26]]] : memref<?xindex>
373// CHECK:                 %[[VAL_31:.*]] = arith.addi %[[VAL_26]], %[[VAL_9]] : index
374// CHECK:                 %[[VAL_32:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_31]]] : memref<?xindex>
375// CHECK:                 %[[VAL_33:.*]]:2 = scf.while (%[[VAL_34:.*]] = %[[VAL_30]], %[[VAL_35:.*]] = %[[VAL_8]]) : (index, index) -> (index, index) {
376// CHECK:                   %[[VAL_36:.*]] = arith.cmpi ult, %[[VAL_34]], %[[VAL_32]] : index
377// CHECK:                   scf.condition(%[[VAL_36]]) %[[VAL_34]], %[[VAL_35]] : index, index
378// CHECK:                 } do {
379// CHECK:                 ^bb0(%[[VAL_37:.*]]: index, %[[VAL_38:.*]]: index):
380// CHECK:                   %[[VAL_39:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_37]]] : memref<?xindex>
381// CHECK:                   %[[VAL_40:.*]] = arith.cmpi eq, %[[VAL_39]], %[[VAL_38]] : index
382// CHECK:                   scf.if %[[VAL_40]] {
383// CHECK:                     %[[VAL_41:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_37]]] : memref<?xf32>
384// CHECK:                     %[[VAL_42:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_38]]] : memref<32x16x8xf32>
385// CHECK:                     %[[VAL_43:.*]] = arith.addf %[[VAL_41]], %[[VAL_42]] : f32
386// CHECK:                     memref.store %[[VAL_43]], %[[VAL_17]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_38]]] : memref<32x16x8xf32>
387// CHECK:                   } else {
388// CHECK:                     scf.if %[[VAL_7]] {
389// CHECK:                       %[[VAL_44:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_38]]] : memref<32x16x8xf32>
390// CHECK:                       memref.store %[[VAL_44]], %[[VAL_17]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_38]]] : memref<32x16x8xf32>
391// CHECK:                     } else {
392// CHECK:                     }
393// CHECK:                   }
394// CHECK:                   %[[VAL_45:.*]] = arith.cmpi eq, %[[VAL_39]], %[[VAL_38]] : index
395// CHECK:                   %[[VAL_46:.*]] = arith.addi %[[VAL_37]], %[[VAL_9]] : index
396// CHECK:                   %[[VAL_47:.*]] = arith.select %[[VAL_45]], %[[VAL_46]], %[[VAL_37]] : index
397// CHECK:                   %[[VAL_48:.*]] = arith.addi %[[VAL_38]], %[[VAL_9]] : index
398// CHECK:                   scf.yield %[[VAL_47]], %[[VAL_48]] : index, index
399// CHECK:                 }
400// CHECK:                 scf.for %[[VAL_49:.*]] = %[[VAL_50:.*]]#1 to %[[VAL_6]] step %[[VAL_9]] {
401// CHECK:                   %[[VAL_51:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_49]]] : memref<32x16x8xf32>
402// CHECK:                   memref.store %[[VAL_51]], %[[VAL_17]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_49]]] : memref<32x16x8xf32>
403// CHECK:                 }
404// CHECK:               } else {
405// CHECK:                 scf.if %[[VAL_7]] {
406// CHECK:                   scf.for %[[VAL_52:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
407// CHECK:                     %[[VAL_53:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_52]]] : memref<32x16x8xf32>
408// CHECK:                     memref.store %[[VAL_53]], %[[VAL_17]]{{\[}}%[[VAL_18]], %[[VAL_27]], %[[VAL_52]]] : memref<32x16x8xf32>
409// CHECK:                   }
410// CHECK:                 } else {
411// CHECK:                 }
412// CHECK:               }
413// CHECK:               %[[VAL_54:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_27]] : index
414// CHECK:               %[[VAL_55:.*]] = arith.addi %[[VAL_26]], %[[VAL_9]] : index
415// CHECK:               %[[VAL_56:.*]] = arith.select %[[VAL_54]], %[[VAL_55]], %[[VAL_26]] : index
416// CHECK:               %[[VAL_57:.*]] = arith.addi %[[VAL_27]], %[[VAL_9]] : index
417// CHECK:               scf.yield %[[VAL_56]], %[[VAL_57]] : index, index
418// CHECK:             }
419// CHECK:             scf.for %[[VAL_58:.*]] = %[[VAL_59:.*]]#1 to %[[VAL_5]] step %[[VAL_9]] {
420// CHECK:               scf.for %[[VAL_60:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
421// CHECK:                 %[[VAL_61:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_18]], %[[VAL_58]], %[[VAL_60]]] : memref<32x16x8xf32>
422// CHECK:                 memref.store %[[VAL_61]], %[[VAL_17]]{{\[}}%[[VAL_18]], %[[VAL_58]], %[[VAL_60]]] : memref<32x16x8xf32>
423// CHECK:               }
424// CHECK:             }
425// CHECK:           }
426// CHECK:           %[[VAL_62:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<32x16x8xf32>
427// CHECK:           return %[[VAL_62]] : tensor<32x16x8xf32>
428// CHECK:         }
// Kernel under test: X(i,j,k) = A(i,j,k) + B(i,j,k) via #trait3, where A uses
// the #Tdss encoding (d0 dense, d1 compressed, d2 compressed) and both B and
// the output X are unannotated tensors.
func.func @add_dss(%arga: tensor<32x16x8xf32, #Tdss>,
                   %argb: tensor<32x16x8xf32>,
                   %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
  %result = linalg.generic #trait3
      ins(%arga, %argb : tensor<32x16x8xf32, #Tdss>, tensor<32x16x8xf32>)
      outs(%argx : tensor<32x16x8xf32>) {
    ^bb0(%lhs: f32, %rhs: f32, %out: f32):
      // The incoming output element (%out) is unused; it is overwritten.
      %sum = arith.addf %lhs, %rhs : f32
      linalg.yield %sum : f32
  } -> tensor<32x16x8xf32>
  return %result : tensor<32x16x8xf32>
}
439
440// CHECK-LABEL:   func @mul_dss(
441// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse{{[0-9]*}}>,
442// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
443// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
444// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
445// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 32 : index
446// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
447// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
448// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
449// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
450// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 2 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
451// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 2 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
452// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
453// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
454// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
455// CHECK-DAG:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_14]] : memref<32x16x8xf32>)
456// CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
457// CHECK:             %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_15]]] : memref<?xindex>
458// CHECK:             %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_6]] : index
459// CHECK:             %[[VAL_18:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_17]]] : memref<?xindex>
460// CHECK:             scf.for %[[VAL_19:.*]] = %[[VAL_16]] to %[[VAL_18]] step %[[VAL_6]] {
461// CHECK:               %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_19]]] : memref<?xindex>
462// CHECK:               %[[VAL_21:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_19]]] : memref<?xindex>
463// CHECK:               %[[VAL_22:.*]] = arith.addi %[[VAL_19]], %[[VAL_6]] : index
464// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_22]]] : memref<?xindex>
465// CHECK:               scf.for %[[VAL_24:.*]] = %[[VAL_21]] to %[[VAL_23]] step %[[VAL_6]] {
466// CHECK:                 %[[VAL_25:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_24]]] : memref<?xindex>
467// CHECK:                 %[[VAL_26:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_24]]] : memref<?xf32>
468// CHECK:                 %[[VAL_27:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_15]], %[[VAL_20]], %[[VAL_25]]] : memref<32x16x8xf32>
469// CHECK:                 %[[VAL_28:.*]] = arith.mulf %[[VAL_26]], %[[VAL_27]] : f32
470// CHECK:                 memref.store %[[VAL_28]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_20]], %[[VAL_25]]] : memref<32x16x8xf32>
471// CHECK:               }
472// CHECK:             }
473// CHECK:           }
474// CHECK:           %[[VAL_29:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<32x16x8xf32>
475// CHECK:           return %[[VAL_29]] : tensor<32x16x8xf32>
476// CHECK:         }
// Kernel under test: X(i,j,k) = A(i,j,k) * B(i,j,k) via #trait3, where A uses
// the #Tdss encoding (d0 dense, d1 compressed, d2 compressed) and both B and
// the output X are unannotated tensors.
func.func @mul_dss(%arga: tensor<32x16x8xf32, #Tdss>,
                   %argb: tensor<32x16x8xf32>,
                   %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
  %result = linalg.generic #trait3
      ins(%arga, %argb : tensor<32x16x8xf32, #Tdss>, tensor<32x16x8xf32>)
      outs(%argx : tensor<32x16x8xf32>) {
    ^bb0(%lhs: f32, %rhs: f32, %out: f32):
      // The incoming output element (%out) is unused; it is overwritten.
      %prod = arith.mulf %lhs, %rhs : f32
      linalg.yield %prod : f32
  } -> tensor<32x16x8xf32>
  return %result : tensor<32x16x8xf32>
}
487
488// CHECK-LABEL:   func @add_sdd(
489// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse{{[0-9]*}}>,
490// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
491// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
492// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
493// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
494// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 16 : index
495// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 8 : index
496// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant true
497// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 0 : index
498// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 1 : index
499// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
500// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
501// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
502// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
503// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
504// CHECK-DAG:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_14]] : memref<32x16x8xf32>)
505// CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref<?xindex>
506// CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_8]]] : memref<?xindex>
507// CHECK:           %[[VAL_17:.*]]:2 = scf.while (%[[VAL_18:.*]] = %[[VAL_15]], %[[VAL_19:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) {
508// CHECK:             %[[VAL_20:.*]] = arith.cmpi ult, %[[VAL_18]], %[[VAL_16]] : index
509// CHECK:             scf.condition(%[[VAL_20]]) %[[VAL_18]], %[[VAL_19]] : index, index
510// CHECK:           } do {
511// CHECK:           ^bb0(%[[VAL_21:.*]]: index, %[[VAL_22:.*]]: index):
512// CHECK:             %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_21]]] : memref<?xindex>
513// CHECK:             %[[VAL_24:.*]] = arith.cmpi eq, %[[VAL_23]], %[[VAL_22]] : index
514// CHECK:             scf.if %[[VAL_24]] {
515// CHECK:               %[[VAL_26:.*]] = arith.muli %[[VAL_21]], %[[VAL_4]] : index
516// CHECK:               scf.for %[[VAL_25:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] {
517// CHECK:                 %[[VAL_27:.*]] = arith.addi %[[VAL_25]], %[[VAL_26]] : index
518// CHECK:                  %[[VAL_29:.*]] = arith.muli %[[VAL_27]], %[[VAL_5]] : index
519// CHECK:                 scf.for %[[VAL_28:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
520// CHECK:                   %[[VAL_30:.*]] = arith.addi %[[VAL_28]], %[[VAL_29]] : index
521// CHECK:                   %[[VAL_31:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<?xf32>
522// CHECK:                   %[[VAL_32:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_22]], %[[VAL_25]], %[[VAL_28]]] : memref<32x16x8xf32>
523// CHECK:                   %[[VAL_33:.*]] = arith.addf %[[VAL_31]], %[[VAL_32]] : f32
524// CHECK:                   memref.store %[[VAL_33]], %[[VAL_14]]{{\[}}%[[VAL_22]], %[[VAL_25]], %[[VAL_28]]] : memref<32x16x8xf32>
525// CHECK:                 }
526// CHECK:               }
527// CHECK:             } else {
528// CHECK:               scf.if %[[VAL_6]] {
529// CHECK:                 scf.for %[[VAL_34:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] {
530// CHECK:                   scf.for %[[VAL_35:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
531// CHECK:                     %[[VAL_36:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_22]], %[[VAL_34]], %[[VAL_35]]] : memref<32x16x8xf32>
532// CHECK:                     memref.store %[[VAL_36]], %[[VAL_14]]{{\[}}%[[VAL_22]], %[[VAL_34]], %[[VAL_35]]] : memref<32x16x8xf32>
533// CHECK:                   }
534// CHECK:                 }
535// CHECK:               } else {
536// CHECK:               }
537// CHECK:             }
538// CHECK:             %[[VAL_37:.*]] = arith.cmpi eq, %[[VAL_23]], %[[VAL_22]] : index
539// CHECK:             %[[VAL_38:.*]] = arith.addi %[[VAL_21]], %[[VAL_8]] : index
540// CHECK:             %[[VAL_39:.*]] = arith.select %[[VAL_37]], %[[VAL_38]], %[[VAL_21]] : index
541// CHECK:             %[[VAL_40:.*]] = arith.addi %[[VAL_22]], %[[VAL_8]] : index
542// CHECK:             scf.yield %[[VAL_39]], %[[VAL_40]] : index, index
543// CHECK:           }
544// CHECK:           scf.for %[[VAL_41:.*]] = %[[VAL_42:.*]]#1 to %[[VAL_3]] step %[[VAL_8]] {
545// CHECK:             scf.for %[[VAL_43:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] {
546// CHECK:               scf.for %[[VAL_44:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
547// CHECK:                 %[[VAL_45:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_41]], %[[VAL_43]], %[[VAL_44]]] : memref<32x16x8xf32>
548// CHECK:                 memref.store %[[VAL_45]], %[[VAL_14]]{{\[}}%[[VAL_41]], %[[VAL_43]], %[[VAL_44]]] : memref<32x16x8xf32>
549// CHECK:               }
550// CHECK:             }
551// CHECK:           }
552// CHECK:           %[[VAL_46:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<32x16x8xf32>
553// CHECK:           return %[[VAL_46]] : tensor<32x16x8xf32>
554// CHECK:         }
// Kernel under test: X(i,j,k) = A(i,j,k) + B(i,j,k) via #trait3, where A uses
// the #Tsdd encoding (d0 compressed, d1 dense, d2 dense) and both B and the
// output X are unannotated tensors.
func.func @add_sdd(%arga: tensor<32x16x8xf32, #Tsdd>,
                   %argb: tensor<32x16x8xf32>,
                   %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
  %result = linalg.generic #trait3
      ins(%arga, %argb : tensor<32x16x8xf32, #Tsdd>, tensor<32x16x8xf32>)
      outs(%argx : tensor<32x16x8xf32>) {
    ^bb0(%lhs: f32, %rhs: f32, %out: f32):
      // The incoming output element (%out) is unused; it is overwritten.
      %sum = arith.addf %lhs, %rhs : f32
      linalg.yield %sum : f32
  } -> tensor<32x16x8xf32>
  return %result : tensor<32x16x8xf32>
}
565
566// CHECK-LABEL:   func @mul_sdd(
567// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse{{[0-9]*}}>,
568// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
569// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
570// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
571// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 16 : index
572// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 8 : index
573// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
574// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
575// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
576// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
577// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
578// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
579// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
580// CHECK-DAG:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_12]] : memref<32x16x8xf32>)
581// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref<?xindex>
582// CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
583// CHECK:           scf.for %[[VAL_15:.*]] = %[[VAL_13]] to %[[VAL_14]] step %[[VAL_6]] {
584// CHECK:             %[[VAL_16:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]]] : memref<?xindex>
585// CHECK:             %[[VAL_18:.*]] = arith.muli %[[VAL_15]], %[[VAL_3]] : index
586// CHECK:             scf.for %[[VAL_17:.*]] = %[[VAL_5]] to %[[VAL_3]] step %[[VAL_6]] {
587// CHECK:               %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_18]] : index
588// CHECK:               %[[VAL_21:.*]] = arith.muli %[[VAL_19]], %[[VAL_4]] : index
589// CHECK:               scf.for %[[VAL_20:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
590// CHECK:                 %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_21]] : index
591// CHECK:                 %[[VAL_23:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_22]]] : memref<?xf32>
592// CHECK:                 %[[VAL_24:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_20]]] : memref<32x16x8xf32>
593// CHECK:                 %[[VAL_25:.*]] = arith.mulf %[[VAL_23]], %[[VAL_24]] : f32
594// CHECK:                 memref.store %[[VAL_25]], %[[VAL_12]]{{\[}}%[[VAL_16]], %[[VAL_17]], %[[VAL_20]]] : memref<32x16x8xf32>
595// CHECK:               }
596// CHECK:             }
597// CHECK:           }
598// CHECK:           %[[VAL_26:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<32x16x8xf32>
599// CHECK:           return %[[VAL_26]] : tensor<32x16x8xf32>
600// CHECK:         }
// Kernel under test: X(i,j,k) = A(i,j,k) * B(i,j,k) via #trait3, where A uses
// the #Tsdd encoding (d0 compressed, d1 dense, d2 dense) and both B and the
// output X are unannotated tensors.
func.func @mul_sdd(%arga: tensor<32x16x8xf32, #Tsdd>,
                   %argb: tensor<32x16x8xf32>,
                   %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
  %result = linalg.generic #trait3
      ins(%arga, %argb : tensor<32x16x8xf32, #Tsdd>, tensor<32x16x8xf32>)
      outs(%argx : tensor<32x16x8xf32>) {
    ^bb0(%lhs: f32, %rhs: f32, %out: f32):
      // The incoming output element (%out) is unused; it is overwritten.
      %prod = arith.mulf %lhs, %rhs : f32
      linalg.yield %prod : f32
  } -> tensor<32x16x8xf32>
  return %result : tensor<32x16x8xf32>
}
611
612// CHECK-LABEL:   func @add_sds(
613// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse{{[0-9]*}}>,
614// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
615// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
616// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
617// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 32 : index
618// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 16 : index
619// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 8 : index
620// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant true
621// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 0 : index
622// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 1 : index
623// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
624// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
625// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 2 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
626// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 2 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
627// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
628// CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
629// CHECK-DAG:       %[[VAL_17:.*]] = bufferization.to_memref %[[VAL_2]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
630// CHECK-DAG:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_17]] : memref<32x16x8xf32>)
631// CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_8]]] : memref<?xindex>
632// CHECK:           %[[VAL_19:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_9]]] : memref<?xindex>
633// CHECK:           %[[VAL_20:.*]]:2 = scf.while (%[[VAL_21:.*]] = %[[VAL_18]], %[[VAL_22:.*]] = %[[VAL_8]]) : (index, index) -> (index, index) {
634// CHECK:             %[[VAL_23:.*]] = arith.cmpi ult, %[[VAL_21]], %[[VAL_19]] : index
635// CHECK:             scf.condition(%[[VAL_23]]) %[[VAL_21]], %[[VAL_22]] : index, index
636// CHECK:           } do {
637// CHECK:           ^bb0(%[[VAL_24:.*]]: index, %[[VAL_25:.*]]: index):
638// CHECK:             %[[VAL_26:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_24]]] : memref<?xindex>
639// CHECK:             %[[VAL_27:.*]] = arith.cmpi eq, %[[VAL_26]], %[[VAL_25]] : index
640// CHECK:             scf.if %[[VAL_27]] {
641// CHECK:               %[[VAL_29:.*]] = arith.muli %[[VAL_24]], %[[VAL_5]] : index
642// CHECK:               scf.for %[[VAL_28:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_9]] {
643// CHECK:                 %[[VAL_30:.*]] = arith.addi %[[VAL_28]], %[[VAL_29]] : index
644// CHECK:                 %[[VAL_31:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_30]]] : memref<?xindex>
645// CHECK:                 %[[VAL_32:.*]] = arith.addi %[[VAL_30]], %[[VAL_9]] : index
646// CHECK:                 %[[VAL_33:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_32]]] : memref<?xindex>
647// CHECK:                 %[[VAL_34:.*]]:2 = scf.while (%[[VAL_35:.*]] = %[[VAL_31]], %[[VAL_36:.*]] = %[[VAL_8]]) : (index, index) -> (index, index) {
648// CHECK:                   %[[VAL_37:.*]] = arith.cmpi ult, %[[VAL_35]], %[[VAL_33]] : index
649// CHECK:                   scf.condition(%[[VAL_37]]) %[[VAL_35]], %[[VAL_36]] : index, index
650// CHECK:                 } do {
651// CHECK:                 ^bb0(%[[VAL_38:.*]]: index, %[[VAL_39:.*]]: index):
652// CHECK:                   %[[VAL_40:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_38]]] : memref<?xindex>
653// CHECK:                   %[[VAL_41:.*]] = arith.cmpi eq, %[[VAL_40]], %[[VAL_39]] : index
654// CHECK:                   scf.if %[[VAL_41]] {
655// CHECK:                     %[[VAL_42:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_38]]] : memref<?xf32>
656// CHECK:                     %[[VAL_43:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_25]], %[[VAL_28]], %[[VAL_39]]] : memref<32x16x8xf32>
657// CHECK:                     %[[VAL_44:.*]] = arith.addf %[[VAL_42]], %[[VAL_43]] : f32
658// CHECK:                     memref.store %[[VAL_44]], %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_28]], %[[VAL_39]]] : memref<32x16x8xf32>
659// CHECK:                   } else {
660// CHECK:                     scf.if %[[VAL_7]] {
661// CHECK:                       %[[VAL_45:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_25]], %[[VAL_28]], %[[VAL_39]]] : memref<32x16x8xf32>
662// CHECK:                       memref.store %[[VAL_45]], %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_28]], %[[VAL_39]]] : memref<32x16x8xf32>
663// CHECK:                     } else {
664// CHECK:                     }
665// CHECK:                   }
666// CHECK:                   %[[VAL_46:.*]] = arith.cmpi eq, %[[VAL_40]], %[[VAL_39]] : index
667// CHECK:                   %[[VAL_47:.*]] = arith.addi %[[VAL_38]], %[[VAL_9]] : index
668// CHECK:                   %[[VAL_48:.*]] = arith.select %[[VAL_46]], %[[VAL_47]], %[[VAL_38]] : index
669// CHECK:                   %[[VAL_49:.*]] = arith.addi %[[VAL_39]], %[[VAL_9]] : index
670// CHECK:                   scf.yield %[[VAL_48]], %[[VAL_49]] : index, index
671// CHECK:                 }
672// CHECK:                 scf.for %[[VAL_50:.*]] = %[[VAL_51:.*]]#1 to %[[VAL_6]] step %[[VAL_9]] {
673// CHECK:                   %[[VAL_52:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_25]], %[[VAL_28]], %[[VAL_50]]] : memref<32x16x8xf32>
674// CHECK:                   memref.store %[[VAL_52]], %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_28]], %[[VAL_50]]] : memref<32x16x8xf32>
675// CHECK:                 }
676// CHECK:               }
677// CHECK:             } else {
678// CHECK:               scf.if %[[VAL_7]] {
679// CHECK:                 scf.for %[[VAL_53:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_9]] {
680// CHECK:                   scf.for %[[VAL_54:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
681// CHECK:                     %[[VAL_55:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_25]], %[[VAL_53]], %[[VAL_54]]] : memref<32x16x8xf32>
682// CHECK:                     memref.store %[[VAL_55]], %[[VAL_17]]{{\[}}%[[VAL_25]], %[[VAL_53]], %[[VAL_54]]] : memref<32x16x8xf32>
683// CHECK:                   }
684// CHECK:                 }
685// CHECK:               } else {
686// CHECK:               }
687// CHECK:             }
688// CHECK:             %[[VAL_56:.*]] = arith.cmpi eq, %[[VAL_26]], %[[VAL_25]] : index
689// CHECK:             %[[VAL_57:.*]] = arith.addi %[[VAL_24]], %[[VAL_9]] : index
690// CHECK:             %[[VAL_58:.*]] = arith.select %[[VAL_56]], %[[VAL_57]], %[[VAL_24]] : index
691// CHECK:             %[[VAL_59:.*]] = arith.addi %[[VAL_25]], %[[VAL_9]] : index
692// CHECK:             scf.yield %[[VAL_58]], %[[VAL_59]] : index, index
693// CHECK:           }
694// CHECK:           scf.for %[[VAL_60:.*]] = %[[VAL_61:.*]]#1 to %[[VAL_4]] step %[[VAL_9]] {
695// CHECK:             scf.for %[[VAL_62:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_9]] {
696// CHECK:               scf.for %[[VAL_63:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
697// CHECK:                 %[[VAL_64:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_60]], %[[VAL_62]], %[[VAL_63]]] : memref<32x16x8xf32>
698// CHECK:                 memref.store %[[VAL_64]], %[[VAL_17]]{{\[}}%[[VAL_60]], %[[VAL_62]], %[[VAL_63]]] : memref<32x16x8xf32>
699// CHECK:               }
700// CHECK:             }
701// CHECK:           }
702// CHECK:           %[[VAL_65:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<32x16x8xf32>
703// CHECK:           return %[[VAL_65]] : tensor<32x16x8xf32>
704// CHECK:         }
// Kernel under test: X(i,j,k) = A(i,j,k) + B(i,j,k) via #trait3, where A uses
// the #Tsds encoding (d0 compressed, d1 dense, d2 compressed) and both B and
// the output X are unannotated tensors.
func.func @add_sds(%arga: tensor<32x16x8xf32, #Tsds>,
                   %argb: tensor<32x16x8xf32>,
                   %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
  %result = linalg.generic #trait3
      ins(%arga, %argb : tensor<32x16x8xf32, #Tsds>, tensor<32x16x8xf32>)
      outs(%argx : tensor<32x16x8xf32>) {
    ^bb0(%lhs: f32, %rhs: f32, %out: f32):
      // The incoming output element (%out) is unused; it is overwritten.
      %sum = arith.addf %lhs, %rhs : f32
      linalg.yield %sum : f32
  } -> tensor<32x16x8xf32>
  return %result : tensor<32x16x8xf32>
}
715
716// CHECK-LABEL:   func @mul_sds(
717// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse{{[0-9]*}}>,
718// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
719// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
720// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
721// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 16 : index
722// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
723// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
724// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
725// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
726// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 2 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
727// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 2 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
728// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
729// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
730// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_2]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
731// CHECK-DAG:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_14]] : memref<32x16x8xf32>)
732// CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_5]]] : memref<?xindex>
733// CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
734// CHECK:           scf.for %[[VAL_17:.*]] = %[[VAL_15]] to %[[VAL_16]] step %[[VAL_6]] {
735// CHECK:             %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<?xindex>
736// CHECK:             %[[VAL_20:.*]] = arith.muli %[[VAL_17]], %[[VAL_4]] : index
737// CHECK:             scf.for %[[VAL_19:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_6]] {
738// CHECK:               %[[VAL_21:.*]] = arith.addi %[[VAL_19]], %[[VAL_20]] : index
739// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref<?xindex>
740// CHECK:               %[[VAL_23:.*]] = arith.addi %[[VAL_21]], %[[VAL_6]] : index
741// CHECK:               %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref<?xindex>
742// CHECK:               scf.for %[[VAL_25:.*]] = %[[VAL_22]] to %[[VAL_24]] step %[[VAL_6]] {
743// CHECK:                 %[[VAL_26:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref<?xindex>
744// CHECK:                 %[[VAL_27:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_25]]] : memref<?xf32>
745// CHECK:                 %[[VAL_28:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_18]], %[[VAL_19]], %[[VAL_26]]] : memref<32x16x8xf32>
746// CHECK:                 %[[VAL_29:.*]] = arith.mulf %[[VAL_27]], %[[VAL_28]] : f32
747// CHECK:                 memref.store %[[VAL_29]], %[[VAL_14]]{{\[}}%[[VAL_18]], %[[VAL_19]], %[[VAL_26]]] : memref<32x16x8xf32>
748// CHECK:               }
749// CHECK:             }
750// CHECK:           }
751// CHECK:           %[[VAL_30:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<32x16x8xf32>
752// CHECK:           return %[[VAL_30]] : tensor<32x16x8xf32>
753// CHECK:         }
// Kernel under test: X(i,j,k) = A(i,j,k) * B(i,j,k) via #trait3, where A uses
// the #Tsds encoding (d0 compressed, d1 dense, d2 compressed) and both B and
// the output X are unannotated tensors.
func.func @mul_sds(%arga: tensor<32x16x8xf32, #Tsds>,
                   %argb: tensor<32x16x8xf32>,
                   %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
  %result = linalg.generic #trait3
      ins(%arga, %argb : tensor<32x16x8xf32, #Tsds>, tensor<32x16x8xf32>)
      outs(%argx : tensor<32x16x8xf32>) {
    ^bb0(%lhs: f32, %rhs: f32, %out: f32):
      // The incoming output element (%out) is unused; it is overwritten.
      %prod = arith.mulf %lhs, %rhs : f32
      linalg.yield %prod : f32
  } -> tensor<32x16x8xf32>
  return %result : tensor<32x16x8xf32>
}
764
765// CHECK-LABEL:   func @add_ssd(
766// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse{{[0-9]*}}>,
767// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
768// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
769// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
770// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 32 : index
771// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 16 : index
772// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 8 : index
773// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant true
774// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 0 : index
775// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 1 : index
776// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
777// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
778// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
779// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
780// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
781// CHECK-DAG:       %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
782// CHECK-DAG:       %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_2]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
783// CHECK-DAG:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_16]] : memref<32x16x8xf32>)
784// CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_7]]] : memref<?xindex>
785// CHECK:           %[[VAL_18:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_8]]] : memref<?xindex>
786// CHECK:           %[[VAL_19:.*]]:2 = scf.while (%[[VAL_20:.*]] = %[[VAL_17]], %[[VAL_21:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) {
787// CHECK:             %[[VAL_22:.*]] = arith.cmpi ult, %[[VAL_20]], %[[VAL_18]] : index
788// CHECK:             scf.condition(%[[VAL_22]]) %[[VAL_20]], %[[VAL_21]] : index, index
789// CHECK:           } do {
790// CHECK:           ^bb0(%[[VAL_23:.*]]: index, %[[VAL_24:.*]]: index):
791// CHECK:             %[[VAL_25:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_23]]] : memref<?xindex>
792// CHECK:             %[[VAL_26:.*]] = arith.cmpi eq, %[[VAL_25]], %[[VAL_24]] : index
793// CHECK:             scf.if %[[VAL_26]] {
794// CHECK:               %[[VAL_27:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_23]]] : memref<?xindex>
795// CHECK:               %[[VAL_28:.*]] = arith.addi %[[VAL_23]], %[[VAL_8]] : index
796// CHECK:               %[[VAL_29:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_28]]] : memref<?xindex>
797// CHECK:               %[[VAL_30:.*]]:2 = scf.while (%[[VAL_31:.*]] = %[[VAL_27]], %[[VAL_32:.*]] = %[[VAL_7]]) : (index, index) -> (index, index) {
798// CHECK:                 %[[VAL_33:.*]] = arith.cmpi ult, %[[VAL_31]], %[[VAL_29]] : index
799// CHECK:                 scf.condition(%[[VAL_33]]) %[[VAL_31]], %[[VAL_32]] : index, index
800// CHECK:               } do {
801// CHECK:               ^bb0(%[[VAL_34:.*]]: index, %[[VAL_35:.*]]: index):
802// CHECK:                 %[[VAL_36:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_34]]] : memref<?xindex>
803// CHECK:                 %[[VAL_37:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_35]] : index
804// CHECK:                 scf.if %[[VAL_37]] {
805// CHECK:                   %[[VAL_39:.*]] = arith.muli %[[VAL_34]], %[[VAL_5]] : index
806// CHECK:                   scf.for %[[VAL_38:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
807// CHECK:                     %[[VAL_40:.*]] = arith.addi %[[VAL_38]], %[[VAL_39]] : index
808// CHECK:                     %[[VAL_41:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_40]]] : memref<?xf32>
809// CHECK:                     %[[VAL_42:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_24]], %[[VAL_35]], %[[VAL_38]]] : memref<32x16x8xf32>
810// CHECK:                     %[[VAL_43:.*]] = arith.addf %[[VAL_41]], %[[VAL_42]] : f32
811// CHECK:                     memref.store %[[VAL_43]], %[[VAL_16]]{{\[}}%[[VAL_24]], %[[VAL_35]], %[[VAL_38]]] : memref<32x16x8xf32>
812// CHECK:                   }
813// CHECK:                 } else {
814// CHECK:                   scf.if %[[VAL_6]] {
815// CHECK:                     scf.for %[[VAL_44:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
816// CHECK:                       %[[VAL_45:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_24]], %[[VAL_35]], %[[VAL_44]]] : memref<32x16x8xf32>
817// CHECK:                       memref.store %[[VAL_45]], %[[VAL_16]]{{\[}}%[[VAL_24]], %[[VAL_35]], %[[VAL_44]]] : memref<32x16x8xf32>
818// CHECK:                     }
819// CHECK:                   } else {
820// CHECK:                   }
821// CHECK:                 }
822// CHECK:                 %[[VAL_46:.*]] = arith.cmpi eq, %[[VAL_36]], %[[VAL_35]] : index
823// CHECK:                 %[[VAL_47:.*]] = arith.addi %[[VAL_34]], %[[VAL_8]] : index
824// CHECK:                 %[[VAL_48:.*]] = arith.select %[[VAL_46]], %[[VAL_47]], %[[VAL_34]] : index
825// CHECK:                 %[[VAL_49:.*]] = arith.addi %[[VAL_35]], %[[VAL_8]] : index
826// CHECK:                 scf.yield %[[VAL_48]], %[[VAL_49]] : index, index
827// CHECK:               }
828// CHECK:               scf.for %[[VAL_50:.*]] = %[[VAL_51:.*]]#1 to %[[VAL_4]] step %[[VAL_8]] {
829// CHECK:                 scf.for %[[VAL_52:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
830// CHECK:                   %[[VAL_53:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_24]], %[[VAL_50]], %[[VAL_52]]] : memref<32x16x8xf32>
831// CHECK:                   memref.store %[[VAL_53]], %[[VAL_16]]{{\[}}%[[VAL_24]], %[[VAL_50]], %[[VAL_52]]] : memref<32x16x8xf32>
832// CHECK:                 }
833// CHECK:               }
834// CHECK:             } else {
835// CHECK:               scf.if %[[VAL_6]] {
836// CHECK:                 scf.for %[[VAL_54:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] {
837// CHECK:                   scf.for %[[VAL_55:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
838// CHECK:                     %[[VAL_56:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_24]], %[[VAL_54]], %[[VAL_55]]] : memref<32x16x8xf32>
839// CHECK:                     memref.store %[[VAL_56]], %[[VAL_16]]{{\[}}%[[VAL_24]], %[[VAL_54]], %[[VAL_55]]] : memref<32x16x8xf32>
840// CHECK:                   }
841// CHECK:                 }
842// CHECK:               } else {
843// CHECK:               }
844// CHECK:             }
845// CHECK:             %[[VAL_57:.*]] = arith.cmpi eq, %[[VAL_25]], %[[VAL_24]] : index
846// CHECK:             %[[VAL_58:.*]] = arith.addi %[[VAL_23]], %[[VAL_8]] : index
847// CHECK:             %[[VAL_59:.*]] = arith.select %[[VAL_57]], %[[VAL_58]], %[[VAL_23]] : index
848// CHECK:             %[[VAL_60:.*]] = arith.addi %[[VAL_24]], %[[VAL_8]] : index
849// CHECK:             scf.yield %[[VAL_59]], %[[VAL_60]] : index, index
850// CHECK:           }
851// CHECK:           scf.for %[[VAL_61:.*]] = %[[VAL_62:.*]]#1 to %[[VAL_3]] step %[[VAL_8]] {
852// CHECK:             scf.for %[[VAL_63:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] {
853// CHECK:               scf.for %[[VAL_64:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
854// CHECK:                 %[[VAL_65:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_61]], %[[VAL_63]], %[[VAL_64]]] : memref<32x16x8xf32>
855// CHECK:                 memref.store %[[VAL_65]], %[[VAL_16]]{{\[}}%[[VAL_61]], %[[VAL_63]], %[[VAL_64]]] : memref<32x16x8xf32>
856// CHECK:               }
857// CHECK:             }
858// CHECK:           }
859// CHECK:           %[[VAL_66:.*]] = bufferization.to_tensor %[[VAL_16]] : memref<32x16x8xf32>
860// CHECK:           return %[[VAL_66]] : tensor<32x16x8xf32>
861// CHECK:         }
// Element-wise X(i,j,k) = A(i,j,k) + B(i,j,k) on 32x16x8 tensors. A uses the
// #Tssd encoding (compressed, compressed, dense); B and X are dense.
862func.func @add_ssd(%arga: tensor<32x16x8xf32, #Tssd>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
863  %result = linalg.generic #trait3
864    ins(%arga, %argb : tensor<32x16x8xf32, #Tssd>, tensor<32x16x8xf32>)
865    outs(%argx : tensor<32x16x8xf32>) {
866  ^bb0(%lhs: f32, %rhs: f32, %out: f32):
867    // The incoming output element is unused; only the two inputs are combined.
868    %sum = arith.addf %lhs, %rhs : f32
869    linalg.yield %sum : f32
870  } -> tensor<32x16x8xf32>
871  return %result : tensor<32x16x8xf32>
872}
872
873// CHECK-LABEL:   func @mul_ssd(
874// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse{{[0-9]*}}>,
875// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
876// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
877// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
878// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 8 : index
879// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
880// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
881// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
882// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
883// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
884// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
885// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
886// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
887// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_2]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
888// CHECK-DAG:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_13]] : memref<32x16x8xf32>)
889// CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
890// CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
891// CHECK:           scf.for %[[VAL_16:.*]] = %[[VAL_14]] to %[[VAL_15]] step %[[VAL_5]] {
892// CHECK:             %[[VAL_17:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_16]]] : memref<?xindex>
893// CHECK:             %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_16]]] : memref<?xindex>
894// CHECK:             %[[VAL_19:.*]] = arith.addi %[[VAL_16]], %[[VAL_5]] : index
895// CHECK:             %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_19]]] : memref<?xindex>
896// CHECK:             scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_20]] step %[[VAL_5]] {
897// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_21]]] : memref<?xindex>
898// CHECK:               %[[VAL_24:.*]] = arith.muli %[[VAL_21]], %[[VAL_3]] : index
899// CHECK:               scf.for %[[VAL_23:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
900// CHECK:                 %[[VAL_25:.*]] = arith.addi %[[VAL_23]], %[[VAL_24]] : index
901// CHECK:                 %[[VAL_26:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_25]]] : memref<?xf32>
902// CHECK:                 %[[VAL_27:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_17]], %[[VAL_22]], %[[VAL_23]]] : memref<32x16x8xf32>
903// CHECK:                 %[[VAL_28:.*]] = arith.mulf %[[VAL_26]], %[[VAL_27]] : f32
904// CHECK:                 memref.store %[[VAL_28]], %[[VAL_13]]{{\[}}%[[VAL_17]], %[[VAL_22]], %[[VAL_23]]] : memref<32x16x8xf32>
905// CHECK:               }
906// CHECK:             }
907// CHECK:           }
908// CHECK:           %[[VAL_29:.*]] = bufferization.to_tensor %[[VAL_13]] : memref<32x16x8xf32>
909// CHECK:           return %[[VAL_29]] : tensor<32x16x8xf32>
910// CHECK:         }
// Element-wise X(i,j,k) = A(i,j,k) * B(i,j,k) on 32x16x8 tensors. A uses the
// #Tssd encoding (compressed, compressed, dense); B and X are dense.
911func.func @mul_ssd(%arga: tensor<32x16x8xf32, #Tssd>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
912  %result = linalg.generic #trait3
913    ins(%arga, %argb : tensor<32x16x8xf32, #Tssd>, tensor<32x16x8xf32>)
914    outs(%argx : tensor<32x16x8xf32>) {
915  ^bb0(%lhs: f32, %rhs: f32, %out: f32):
916    // The incoming output element is unused; only the two inputs are combined.
917    %prod = arith.mulf %lhs, %rhs : f32
918    linalg.yield %prod : f32
919  } -> tensor<32x16x8xf32>
920  return %result : tensor<32x16x8xf32>
921}
921
922// CHECK-LABEL:   func @add_sss(
923// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse{{[0-9]*}}>,
924// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
925// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
926// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
927// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 32 : index
928// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 16 : index
929// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 8 : index
930// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant true
931// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 0 : index
932// CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 1 : index
933// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
934// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
935// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
936// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
937// CHECK-DAG:       %[[VAL_14:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 2 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
938// CHECK-DAG:       %[[VAL_15:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 2 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
939// CHECK-DAG:       %[[VAL_16:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
940// CHECK-DAG:       %[[VAL_17:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
941// CHECK-DAG:       %[[VAL_19:.*]] = bufferization.to_memref %[[VAL_2]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
942// CHECK-DAG:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_19]] : memref<32x16x8xf32>)
943// CHECK:           %[[VAL_20:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_8]]] : memref<?xindex>
944// CHECK:           %[[VAL_21:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_9]]] : memref<?xindex>
945// CHECK:           %[[VAL_22:.*]]:2 = scf.while (%[[VAL_23:.*]] = %[[VAL_20]], %[[VAL_24:.*]] = %[[VAL_8]]) : (index, index) -> (index, index) {
946// CHECK:             %[[VAL_25:.*]] = arith.cmpi ult, %[[VAL_23]], %[[VAL_21]] : index
947// CHECK:             scf.condition(%[[VAL_25]]) %[[VAL_23]], %[[VAL_24]] : index, index
948// CHECK:           } do {
949// CHECK:           ^bb0(%[[VAL_26:.*]]: index, %[[VAL_27:.*]]: index):
950// CHECK:             %[[VAL_28:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_26]]] : memref<?xindex>
951// CHECK:             %[[VAL_29:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_27]] : index
952// CHECK:             scf.if %[[VAL_29]] {
953// CHECK:               %[[VAL_30:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_26]]] : memref<?xindex>
954// CHECK:               %[[VAL_31:.*]] = arith.addi %[[VAL_26]], %[[VAL_9]] : index
955// CHECK:               %[[VAL_32:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_31]]] : memref<?xindex>
956// CHECK:               %[[VAL_33:.*]]:2 = scf.while (%[[VAL_34:.*]] = %[[VAL_30]], %[[VAL_35:.*]] = %[[VAL_8]]) : (index, index) -> (index, index) {
957// CHECK:                 %[[VAL_36:.*]] = arith.cmpi ult, %[[VAL_34]], %[[VAL_32]] : index
958// CHECK:                 scf.condition(%[[VAL_36]]) %[[VAL_34]], %[[VAL_35]] : index, index
959// CHECK:               } do {
960// CHECK:               ^bb0(%[[VAL_37:.*]]: index, %[[VAL_38:.*]]: index):
961// CHECK:                 %[[VAL_39:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_37]]] : memref<?xindex>
962// CHECK:                 %[[VAL_40:.*]] = arith.cmpi eq, %[[VAL_39]], %[[VAL_38]] : index
963// CHECK:                 scf.if %[[VAL_40]] {
964// CHECK:                   %[[VAL_41:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_37]]] : memref<?xindex>
965// CHECK:                   %[[VAL_42:.*]] = arith.addi %[[VAL_37]], %[[VAL_9]] : index
966// CHECK:                   %[[VAL_43:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_42]]] : memref<?xindex>
967// CHECK:                   %[[VAL_44:.*]]:2 = scf.while (%[[VAL_45:.*]] = %[[VAL_41]], %[[VAL_46:.*]] = %[[VAL_8]]) : (index, index) -> (index, index) {
968// CHECK:                     %[[VAL_47:.*]] = arith.cmpi ult, %[[VAL_45]], %[[VAL_43]] : index
969// CHECK:                     scf.condition(%[[VAL_47]]) %[[VAL_45]], %[[VAL_46]] : index, index
970// CHECK:                   } do {
971// CHECK:                   ^bb0(%[[VAL_48:.*]]: index, %[[VAL_49:.*]]: index):
972// CHECK:                     %[[VAL_50:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_48]]] : memref<?xindex>
973// CHECK:                     %[[VAL_51:.*]] = arith.cmpi eq, %[[VAL_50]], %[[VAL_49]] : index
974// CHECK:                     scf.if %[[VAL_51]] {
975// CHECK:                       %[[VAL_52:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_48]]] : memref<?xf32>
976// CHECK:                       %[[VAL_53:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_49]]] : memref<32x16x8xf32>
977// CHECK:                       %[[VAL_54:.*]] = arith.addf %[[VAL_52]], %[[VAL_53]] : f32
978// CHECK:                       memref.store %[[VAL_54]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_49]]] : memref<32x16x8xf32>
979// CHECK:                     } else {
980// CHECK:                       scf.if %[[VAL_7]] {
981// CHECK:                         %[[VAL_55:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_49]]] : memref<32x16x8xf32>
982// CHECK:                         memref.store %[[VAL_55]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_49]]] : memref<32x16x8xf32>
983// CHECK:                       } else {
984// CHECK:                       }
985// CHECK:                     }
986// CHECK:                     %[[VAL_56:.*]] = arith.cmpi eq, %[[VAL_50]], %[[VAL_49]] : index
987// CHECK:                     %[[VAL_57:.*]] = arith.addi %[[VAL_48]], %[[VAL_9]] : index
988// CHECK:                     %[[VAL_58:.*]] = arith.select %[[VAL_56]], %[[VAL_57]], %[[VAL_48]] : index
989// CHECK:                     %[[VAL_59:.*]] = arith.addi %[[VAL_49]], %[[VAL_9]] : index
990// CHECK:                     scf.yield %[[VAL_58]], %[[VAL_59]] : index, index
991// CHECK:                   }
992// CHECK:                   scf.for %[[VAL_60:.*]] = %[[VAL_61:.*]]#1 to %[[VAL_6]] step %[[VAL_9]] {
993// CHECK:                     %[[VAL_62:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_60]]] : memref<32x16x8xf32>
994// CHECK:                     memref.store %[[VAL_62]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_60]]] : memref<32x16x8xf32>
995// CHECK:                   }
996// CHECK:                 } else {
997// CHECK:                   scf.if %[[VAL_7]] {
998// CHECK:                     scf.for %[[VAL_63:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
999// CHECK:                       %[[VAL_64:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_63]]] : memref<32x16x8xf32>
1000// CHECK:                       memref.store %[[VAL_64]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_38]], %[[VAL_63]]] : memref<32x16x8xf32>
1001// CHECK:                     }
1002// CHECK:                   } else {
1003// CHECK:                   }
1004// CHECK:                 }
1005// CHECK:                 %[[VAL_65:.*]] = arith.cmpi eq, %[[VAL_39]], %[[VAL_38]] : index
1006// CHECK:                 %[[VAL_66:.*]] = arith.addi %[[VAL_37]], %[[VAL_9]] : index
1007// CHECK:                 %[[VAL_67:.*]] = arith.select %[[VAL_65]], %[[VAL_66]], %[[VAL_37]] : index
1008// CHECK:                 %[[VAL_68:.*]] = arith.addi %[[VAL_38]], %[[VAL_9]] : index
1009// CHECK:                 scf.yield %[[VAL_67]], %[[VAL_68]] : index, index
1010// CHECK:               }
1011// CHECK:               scf.for %[[VAL_69:.*]] = %[[VAL_70:.*]]#1 to %[[VAL_5]] step %[[VAL_9]] {
1012// CHECK:                 scf.for %[[VAL_71:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
1013// CHECK:                   %[[VAL_72:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]], %[[VAL_69]], %[[VAL_71]]] : memref<32x16x8xf32>
1014// CHECK:                   memref.store %[[VAL_72]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_69]], %[[VAL_71]]] : memref<32x16x8xf32>
1015// CHECK:                 }
1016// CHECK:               }
1017// CHECK:             } else {
1018// CHECK:               scf.if %[[VAL_7]] {
1019// CHECK:                 scf.for %[[VAL_73:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_9]] {
1020// CHECK:                   scf.for %[[VAL_74:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
1021// CHECK:                     %[[VAL_75:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]], %[[VAL_73]], %[[VAL_74]]] : memref<32x16x8xf32>
1022// CHECK:                     memref.store %[[VAL_75]], %[[VAL_19]]{{\[}}%[[VAL_27]], %[[VAL_73]], %[[VAL_74]]] : memref<32x16x8xf32>
1023// CHECK:                   }
1024// CHECK:                 }
1025// CHECK:               } else {
1026// CHECK:               }
1027// CHECK:             }
1028// CHECK:             %[[VAL_76:.*]] = arith.cmpi eq, %[[VAL_28]], %[[VAL_27]] : index
1029// CHECK:             %[[VAL_77:.*]] = arith.addi %[[VAL_26]], %[[VAL_9]] : index
1030// CHECK:             %[[VAL_78:.*]] = arith.select %[[VAL_76]], %[[VAL_77]], %[[VAL_26]] : index
1031// CHECK:             %[[VAL_79:.*]] = arith.addi %[[VAL_27]], %[[VAL_9]] : index
1032// CHECK:             scf.yield %[[VAL_78]], %[[VAL_79]] : index, index
1033// CHECK:           }
1034// CHECK:           scf.for %[[VAL_80:.*]] = %[[VAL_81:.*]]#1 to %[[VAL_4]] step %[[VAL_9]] {
1035// CHECK:             scf.for %[[VAL_82:.*]] = %[[VAL_8]] to %[[VAL_5]] step %[[VAL_9]] {
1036// CHECK:               scf.for %[[VAL_83:.*]] = %[[VAL_8]] to %[[VAL_6]] step %[[VAL_9]] {
1037// CHECK:                 %[[VAL_84:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_80]], %[[VAL_82]], %[[VAL_83]]] : memref<32x16x8xf32>
1038// CHECK:                 memref.store %[[VAL_84]], %[[VAL_19]]{{\[}}%[[VAL_80]], %[[VAL_82]], %[[VAL_83]]] : memref<32x16x8xf32>
1039// CHECK:               }
1040// CHECK:             }
1041// CHECK:           }
1042// CHECK:           %[[VAL_85:.*]] = bufferization.to_tensor %[[VAL_19]] : memref<32x16x8xf32>
1043// CHECK:           return %[[VAL_85]] : tensor<32x16x8xf32>
1044// CHECK:         }
// Element-wise X(i,j,k) = A(i,j,k) + B(i,j,k) on 32x16x8 tensors. A uses the
// #Tsss encoding (all three levels compressed); B and X are dense.
1045func.func @add_sss(%arga: tensor<32x16x8xf32, #Tsss>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
1046  %result = linalg.generic #trait3
1047    ins(%arga, %argb : tensor<32x16x8xf32, #Tsss>, tensor<32x16x8xf32>)
1048    outs(%argx : tensor<32x16x8xf32>) {
1049  ^bb0(%lhs: f32, %rhs: f32, %out: f32):
1050    // The incoming output element is unused; only the two inputs are combined.
1051    %sum = arith.addf %lhs, %rhs : f32
1052    linalg.yield %sum : f32
1053  } -> tensor<32x16x8xf32>
1054  return %result : tensor<32x16x8xf32>
1055}
1055
1056// CHECK-LABEL:   func @mul_sss(
1057// CHECK-SAME:      %[[VAL_0:.*]]: tensor<32x16x8xf32, #sparse{{[0-9]*}}>,
1058// CHECK-SAME:      %[[VAL_1:.*]]: tensor<32x16x8xf32>,
1059// CHECK-SAME:      %[[VAL_2:.*]]: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
1060// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
1061// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
1062// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
1063// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
1064// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
1065// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
1066// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
1067// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 2 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
1068// CHECK-DAG:       %[[VAL_11:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 2 : index} : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xindex>
1069// CHECK-DAG:       %[[VAL_12:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32x16x8xf32, #sparse{{[0-9]*}}> to memref<?xf32>
1070// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
1071// CHECK-DAG:       %[[VAL_15:.*]] = bufferization.to_memref %[[VAL_2]] : tensor<32x16x8xf32> to memref<32x16x8xf32>
1072// CHECK-DAG:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_15]] : memref<32x16x8xf32>)
1073// CHECK:           %[[VAL_16:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
1074// CHECK:           %[[VAL_17:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
1075// CHECK:           scf.for %[[VAL_18:.*]] = %[[VAL_16]] to %[[VAL_17]] step %[[VAL_5]] {
1076// CHECK:             %[[VAL_19:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_18]]] : memref<?xindex>
1077// CHECK:             %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
1078// CHECK:             %[[VAL_21:.*]] = arith.addi %[[VAL_18]], %[[VAL_5]] : index
1079// CHECK:             %[[VAL_22:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_21]]] : memref<?xindex>
1080// CHECK:             scf.for %[[VAL_23:.*]] = %[[VAL_20]] to %[[VAL_22]] step %[[VAL_5]] {
1081// CHECK:               %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref<?xindex>
1082// CHECK:               %[[VAL_25:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_23]]] : memref<?xindex>
1083// CHECK:               %[[VAL_26:.*]] = arith.addi %[[VAL_23]], %[[VAL_5]] : index
1084// CHECK:               %[[VAL_27:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_26]]] : memref<?xindex>
1085// CHECK:               scf.for %[[VAL_28:.*]] = %[[VAL_25]] to %[[VAL_27]] step %[[VAL_5]] {
1086// CHECK:                 %[[VAL_29:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_28]]] : memref<?xindex>
1087// CHECK:                 %[[VAL_30:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_28]]] : memref<?xf32>
1088// CHECK:                 %[[VAL_31:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_19]], %[[VAL_24]], %[[VAL_29]]] : memref<32x16x8xf32>
1089// CHECK:                 %[[VAL_32:.*]] = arith.mulf %[[VAL_30]], %[[VAL_31]] : f32
1090// CHECK:                 memref.store %[[VAL_32]], %[[VAL_15]]{{\[}}%[[VAL_19]], %[[VAL_24]], %[[VAL_29]]] : memref<32x16x8xf32>
1091// CHECK:               }
1092// CHECK:             }
1093// CHECK:           }
1094// CHECK:           %[[VAL_33:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<32x16x8xf32>
1095// CHECK:           return %[[VAL_33]] : tensor<32x16x8xf32>
1096// CHECK:         }
// Element-wise X(i,j,k) = A(i,j,k) * B(i,j,k) on 32x16x8 tensors. A uses the
// #Tsss encoding (all three levels compressed); B and X are dense.
1097func.func @mul_sss(%arga: tensor<32x16x8xf32, #Tsss>, %argb: tensor<32x16x8xf32>, %argx: tensor<32x16x8xf32>) -> tensor<32x16x8xf32> {
1098  %result = linalg.generic #trait3
1099    ins(%arga, %argb : tensor<32x16x8xf32, #Tsss>, tensor<32x16x8xf32>)
1100    outs(%argx : tensor<32x16x8xf32>) {
1101  ^bb0(%lhs: f32, %rhs: f32, %out: f32):
1102    // The incoming output element is unused; only the two inputs are combined.
1103    %prod = arith.mulf %lhs, %rhs : f32
1104    linalg.yield %prod : f32
1105  } -> tensor<32x16x8xf32>
1106  return %result : tensor<32x16x8xf32>
1107}
1107
// Trait for a linalg.generic over a 4-d iteration space (i,j,k,l). The first
// two iterators (i, j) are parallel and the last two (k, l) are reductions.
// Indexing maps are listed in operand order: inputs B, C, D, then output A.
1108#trait_kernel_3d = {
1109  indexing_maps = [
1110    affine_map<(i,j,k,l) -> (i,k,l)>,  // B
1111    affine_map<(i,j,k,l) -> (k,j)>,    // C
1112    affine_map<(i,j,k,l) -> (l,j)>,    // D
1113    affine_map<(i,j,k,l) -> (i,j)>     // A (out)
1114  ],
1115  iterator_types = ["parallel", "parallel", "reduction", "reduction"],
1116  doc = "A(i,j) += SUM_k,l B(i,k,l) * C(k,j) * D(l,j)"
1117}
1118
1119// CHECK-LABEL:   func @kernel_3d(
1120// CHECK-SAME:      %[[VAL_0:.*0]]: tensor<?x?xf32>,
1121// CHECK-SAME:      %[[VAL_1:.*1]]: tensor<?x?x?xf32, #sparse{{[0-9]*}}>,
1122// CHECK-SAME:      %[[VAL_2:.*2]]: tensor<?x?xf32>,
1123// CHECK-SAME:      %[[VAL_3:.*3]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
1124// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
1125// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 1 : index
1126// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_1]] {level = 2 : index} : tensor<?x?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
1127// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.coordinates %[[VAL_1]] {level = 2 : index} : tensor<?x?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
1128// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?x?x?xf32, #sparse{{[0-9]*}}> to memref<?xf32>
1129// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.lvl %[[VAL_1]], %[[VAL_6]] : tensor<?x?x?xf32, #sparse{{[0-9]*}}>
1130// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : tensor<?x?xf32> to memref<?x?xf32>
1131// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_3]] : tensor<?x?xf32> to memref<?x?xf32>
1132// CHECK-DAG:       %[[VAL_13:.*]] = sparse_tensor.lvl %[[VAL_1]], %[[VAL_5]] : tensor<?x?x?xf32, #sparse{{[0-9]*}}>
1133// CHECK-DAG:       %[[VAL_14:.*]] = tensor.dim %[[VAL_2]], %[[VAL_6]] : tensor<?x?xf32>
1134// CHECK-DAG:       %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_0]] : tensor<?x?xf32> to memref<?x?xf32>
1135// CHECK:           scf.for %[[VAL_17:.*]] = %[[VAL_5]] to %[[VAL_13]] step %[[VAL_6]] {
1136// CHECK:             %[[VAL_19:.*]] = arith.muli %[[VAL_17]], %[[VAL_10]] : index
1137// CHECK:             scf.for %[[VAL_18:.*]] = %[[VAL_5]] to %[[VAL_10]] step %[[VAL_6]] {
1138// CHECK:               %[[VAL_20:.*]] = arith.addi %[[VAL_18]], %[[VAL_19]] : index
1139// CHECK:               %[[VAL_21:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_20]]] : memref<?xindex>
1140// CHECK:               %[[VAL_22:.*]] = arith.addi %[[VAL_20]], %[[VAL_6]] : index
1141// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_22]]] : memref<?xindex>
1142// CHECK:               scf.for %[[VAL_24:.*]] = %[[VAL_21]] to %[[VAL_23]] step %[[VAL_6]] {
1143// CHECK:                 %[[VAL_25:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_24]]] : memref<?xindex>
1144// CHECK:                 %[[VAL_26:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_24]]] : memref<?xf32>
1145// CHECK:                 scf.for %[[VAL_27:.*]] = %[[VAL_5]] to %[[VAL_14]] step %[[VAL_6]] {
1146// CHECK:                   %[[VAL_28:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_18]], %[[VAL_27]]] : memref<?x?xf32>
1147// CHECK:                   %[[VAL_29:.*]] = arith.mulf %[[VAL_26]], %[[VAL_28]] : f32
1148// CHECK:                   %[[VAL_30:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_25]], %[[VAL_27]]] : memref<?x?xf32>
1149// CHECK:                   %[[VAL_31:.*]] = arith.mulf %[[VAL_29]], %[[VAL_30]] : f32
1150// CHECK:                   %[[VAL_32:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_17]], %[[VAL_27]]] : memref<?x?xf32>
1151// CHECK:                   %[[VAL_33:.*]] = arith.addf %[[VAL_31]], %[[VAL_32]] : f32
1152// CHECK:                   memref.store %[[VAL_33]], %[[VAL_16]]{{\[}}%[[VAL_17]], %[[VAL_27]]] : memref<?x?xf32>
1153// CHECK:                 }
1154// CHECK:               }
1155// CHECK:             }
1156// CHECK:           }
1157// CHECK:           %[[VAL_34:.*]] = bufferization.to_tensor %[[VAL_16]] : memref<?x?xf32>
1158// CHECK:           return %[[VAL_34]] : tensor<?x?xf32>
1159// CHECK:         }
// Computes A(i,j) += SUM_k,l B(i,k,l) * C(k,j) * D(l,j) on dynamically shaped
// tensors. B uses the #Tdds encoding (dense, dense, compressed); C, D and the
// output A are dense.
1160func.func @kernel_3d(%arga: tensor<?x?xf32>,
1161                     %argb: tensor<?x?x?xf32, #Tdds>,
1162                     %argc: tensor<?x?xf32>,
1163                     %argd: tensor<?x?xf32>) -> tensor<?x?xf32> {
1164  %result = linalg.generic #trait_kernel_3d
1165    ins(%argb, %argc, %argd : tensor<?x?x?xf32, #Tdds>, tensor<?x?xf32>, tensor<?x?xf32>)
1166    outs(%arga : tensor<?x?xf32>) {
1167  ^bb0(%b: f32, %c: f32, %d: f32, %acc: f32):
1168    %bc = arith.mulf %b, %c : f32
1169    %bcd = arith.mulf %bc, %d : f32
1170    // Accumulate the triple product into the current output element.
1171    %next = arith.addf %bcd, %acc : f32
1172    linalg.yield %next : f32
1173  } -> tensor<?x?xf32>
1174  return %result : tensor<?x?xf32>
1175}
1175
// Trait for a linalg.generic that reduces a 3-d input A into a scalar x.
// All three iterators (i, j, k) are reductions and the output map is empty,
// i.e. the output is a rank-0 value.
1176#trait_sum_reduction = {
1177  indexing_maps = [
1178    affine_map<(i,j,k) -> (i,j,k)>,  // A
1179    affine_map<(i,j,k) -> ()>        // x (scalar out)
1180  ],
1181  iterator_types = ["reduction", "reduction", "reduction"],
1182  doc = "x += SUM_ijk A(i,j,k)"
1183}
1184
1185// CHECK-LABEL:   func @sum_reduction(
1186// CHECK-SAME:      %[[VAL_0:.*]]: tensor<10x20x30xf32, #sparse{{[0-9]*}}>,
1187// CHECK-SAME:      %[[VAL_1:.*]]: tensor<f32>) -> tensor<f32> {
1188// CHECK-DAG:       %[[VAL_2:.*]] = arith.constant 0 : index
1189// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
1190// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<10x20x30xf32, #sparse{{[0-9]*}}>
1191// CHECK-DAG:       %[[VAL_6:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<10x20x30xf32, #sparse{{[0-9]*}}>
1192// CHECK-DAG:       %[[VAL_7:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 2 : index} : tensor<10x20x30xf32, #sparse{{[0-9]*}}>
1193// CHECK-DAG:       %[[VAL_8:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10x20x30xf32, #sparse{{[0-9]*}}>
1194// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<f32> to memref<f32>
1195// CHECK:           %[[VAL_11:.*]] = memref.load %[[VAL_10]][] : memref<f32>
1196// CHECK:           %[[VAL_12:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_2]]] : memref<?xindex>
1197// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
1198// CHECK:           %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_12]] to %[[VAL_13]] step %[[VAL_3]] iter_args(%[[VAL_16:.*]] = %[[VAL_11]]) -> (f32) {
1199// CHECK:             %[[VAL_17:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_15]]] : memref<?xindex>
1200// CHECK:             %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[VAL_3]] : index
1201// CHECK:             %[[VAL_19:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_18]]] : memref<?xindex>
1202// CHECK:             %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_17]] to %[[VAL_19]] step %[[VAL_3]] iter_args(%[[VAL_22:.*]] = %[[VAL_16]]) -> (f32) {
1203// CHECK:               %[[VAL_23:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_21]]] : memref<?xindex>
1204// CHECK:               %[[VAL_24:.*]] = arith.addi %[[VAL_21]], %[[VAL_3]] : index
1205// CHECK:               %[[VAL_25:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_24]]] : memref<?xindex>
1206// CHECK:               %[[VAL_26:.*]] = scf.for %[[VAL_27:.*]] = %[[VAL_23]] to %[[VAL_25]] step %[[VAL_3]] iter_args(%[[VAL_28:.*]] = %[[VAL_22]]) -> (f32) {
1207// CHECK:                 %[[VAL_29:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_27]]] : memref<?xf32>
1208// CHECK:                 %[[VAL_30:.*]] = arith.addf %[[VAL_28]], %[[VAL_29]] : f32
1209// CHECK:                 scf.yield %[[VAL_30]] : f32
1210// CHECK:               }
1211// CHECK:               scf.yield %[[VAL_26]] : f32
1212// CHECK:             }
1213// CHECK:             scf.yield %[[VAL_20]] : f32
1214// CHECK:           }
1215// CHECK:           memref.store %[[VAL_14]], %[[VAL_10]][] : memref<f32>
1216// CHECK:           %[[VAL_34:.*]] = bufferization.to_tensor %[[VAL_10]] : memref<f32>
1217// CHECK:           return %[[VAL_34]] : tensor<f32>
1218// CHECK:         }
// Full sum reduction x += SUM_ijk A(i,j,k) of a 10x20x30 tensor stored with
// the #Tsss encoding (all three levels compressed) into a rank-0 tensor.
1219func.func @sum_reduction(%arga: tensor<10x20x30xf32, #Tsss>, %argx: tensor<f32>) -> tensor<f32> {
1220  %total = linalg.generic #trait_sum_reduction
1221    ins(%arga : tensor<10x20x30xf32, #Tsss>)
1222    outs(%argx : tensor<f32>) {
1223  ^bb0(%elem: f32, %acc: f32):
1224    // Operand order (accumulator first) is kept as in the original kernel.
1225    %next = arith.addf %acc, %elem : f32
1226    linalg.yield %next : f32
1227  } -> tensor<f32>
1228  return %total : tensor<f32>
1229}
1229
// Trait for a linalg.generic that reduces the product of a 3-d input A and a
// 1-d input b into a scalar x. b is indexed by i only, so its value does not
// change across the j and k loops. All three iterators are reductions, hence
// the sum in the doc string ranges over i, j and k.
1230#trait_sum_reduction_inv = {
1231  indexing_maps = [
1232    affine_map<(i,j,k) -> (i,j,k)>,  // A
1233    affine_map<(i,j,k) -> (i)>,      // b
1234    affine_map<(i,j,k) -> ()>        // x (scalar out)
1235  ],
1236  iterator_types = ["reduction", "reduction", "reduction"],
1237  doc = "x += SUM_ijk A(i,j,k) * b(i)"
1238}
1239
1240// CHECK-LABEL:   func @sum_reduction_inv(
1241// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?x?xf32>,
1242// CHECK-SAME:      %[[VAL_1:.*]]: tensor<?xf32, #sparse{{[0-9]*}}>
1243// CHECK-SAME:      %[[VAL_2:.*]]: tensor<f32>) -> tensor<f32> {
1244// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 1 : index
1245// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 2 : index
1246// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 0 : index
1247// CHECK-DAG:       %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32>
1248// CHECK-DAG:       %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32>
1249// CHECK-DAG:       %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : tensor<?x?x?xf32> to memref<?x?x?xf32>
1250// CHECK-DAG:       %[[VAL_9:.*]] = tensor.dim %[[VAL_0]], %[[VAL_5]] : tensor<?x?x?xf32>
1251// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xf32, #sparse{{[0-9]*}}>
1252// CHECK-DAG:       %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : tensor<f32> to memref<f32>
1253// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_12]][] : memref<f32>
1254// CHECK:           %[[VAL_14:.*]] = scf.for %[[VAL_15:.*]] = %[[VAL_5]] to %[[VAL_9]] step %[[VAL_3]] iter_args(%[[VAL_16:.*]] = %[[VAL_13]]) -> (f32) {
1255// CHECK:             %[[VAL_17:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_15]]] : memref<?xf32>
1256// CHECK:             %[[VAL_18:.*]] = scf.for %[[VAL_19:.*]] = %[[VAL_5]] to %[[VAL_6]] step %[[VAL_3]] iter_args(%[[VAL_20:.*]] = %[[VAL_16]]) -> (f32) {
1257// CHECK:               %[[VAL_21:.*]] = scf.for %[[VAL_22:.*]] = %[[VAL_5]] to %[[VAL_7]] step %[[VAL_3]] iter_args(%[[VAL_23:.*]] = %[[VAL_20]]) -> (f32) {
1258// CHECK:                 %[[VAL_24:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_15]], %[[VAL_19]], %[[VAL_22]]] : memref<?x?x?xf32>
1259// CHECK:                 %[[VAL_25:.*]] = arith.mulf %[[VAL_24]], %[[VAL_17]] : f32
1260// CHECK:                 %[[VAL_26:.*]] = arith.addf %[[VAL_23]], %[[VAL_25]] : f32
1261// CHECK:                 scf.yield %[[VAL_26]] : f32
1262// CHECK:               }
1263// CHECK:               scf.yield %[[VAL_21]] : f32
1264// CHECK:             }
1265// CHECK:             scf.yield %[[VAL_18]] : f32
1266// CHECK:           }
1267// CHECK:           memref.store %[[VAL_14]], %[[VAL_12]][] : memref<f32>
1268// CHECK:           %[[VAL_30:.*]] = bufferization.to_tensor %[[VAL_12]] : memref<f32>
1269// CHECK:           return %[[VAL_30]] : tensor<f32>
1270// CHECK:         }
1271func.func @sum_reduction_inv(%arga: tensor<?x?x?xf32>,
1272                             %argb: tensor<?xf32, #Td>,
1273                             %argx: tensor<f32>) -> tensor<f32> {
1274  %total = linalg.generic #trait_sum_reduction_inv  // accumulates A(i,j,k) * b(i) into the scalar %argx
1275    ins(%arga, %argb : tensor<?x?x?xf32>, tensor<?xf32, #Td>)
1276    outs(%argx : tensor<f32>) {
1277      ^body(%lhs: f32, %rhs: f32, %acc: f32):  // %lhs from %arga, %rhs from %argb, %acc from %argx
1278        %prod = arith.mulf %lhs, %rhs : f32
1279        %next = arith.addf %acc, %prod : f32  // accumulator stays the first operand
1280        linalg.yield %next : f32
1281  } -> tensor<f32>
1282  return %total : tensor<f32>
1283}
1284
1285#trait_invariants = {
1286  indexing_maps = [
1287    affine_map<(i,j,k) -> (i)>,      // a: 1-D input, varies with i only
1288    affine_map<(i,j,k) -> (j)>,      // b: 1-D input, varies with j only
1289    affine_map<(i,j,k) -> (k)>,      // c: 1-D input, varies with k only
1290    affine_map<(i,j,k) -> (i,j,k)>   // X (out): indexed by all three dims
1291  ],
1292  iterator_types = ["parallel", "parallel", "parallel"],  // no reduction dims
1293  doc = "X(i,j,k) = a(i) * b(j) * c(k)"
1294}
1295
1296// CHECK-LABEL:   func @invariants(
1297// CHECK-SAME:      %[[VAL_0:.*]]: tensor<10xf32, #sparse{{[0-9]*}}>,
1298// CHECK-SAME:      %[[VAL_1:.*]]: tensor<20xf32>,
1299// CHECK-SAME:      %[[VAL_2:.*]]: tensor<30xf32>,
1300// CHECK-SAME:      %[[VAL_3:.*]]: tensor<10x20x30xf32>) -> tensor<10x20x30xf32> {
1301// CHECK-DAG:       %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
1302// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 10 : index
1303// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 20 : index
1304// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant 30 : index
1305// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant 0 : index
1306// CHECK-DAG:       %[[VAL_8:.*]] = arith.constant 1 : index
1307// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xf32, #sparse{{[0-9]*}}> to memref<?xf32>
1308// CHECK-DAG:       %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_1]] : tensor<20xf32> to memref<20xf32>
1309// CHECK-DAG:       %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_2]] : tensor<30xf32> to memref<30xf32>
1310// CHECK-DAG:       %[[VAL_13:.*]] = bufferization.to_memref %[[VAL_3]] : tensor<10x20x30xf32> to memref<10x20x30xf32>
1311// CHECK-DAG:           linalg.fill ins(%[[ZERO]] : f32) outs(%[[VAL_13]] : memref<10x20x30xf32>)
1312// CHECK:           scf.for %[[VAL_14:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] {
1313// CHECK:             %[[VAL_15:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_14]]] : memref<?xf32>
1314// CHECK:             scf.for %[[VAL_16:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
1315// CHECK:               %[[VAL_17:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_16]]] : memref<20xf32>
1316// CHECK:               scf.for %[[VAL_18:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_8]] {
1317// CHECK:                 %[[VAL_19:.*]] = arith.mulf %[[VAL_15]], %[[VAL_17]] : f32
1318// CHECK:                 %[[VAL_20:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_18]]] : memref<30xf32>
1319// CHECK:                 %[[VAL_21:.*]] = arith.mulf %[[VAL_19]], %[[VAL_20]] : f32
1320// CHECK:                 memref.store %[[VAL_21]], %[[VAL_13]]{{\[}}%[[VAL_14]], %[[VAL_16]], %[[VAL_18]]] : memref<10x20x30xf32>
1321// CHECK:               }
1322// CHECK:             }
1323// CHECK:           }
1324// CHECK:           %[[VAL_22:.*]] = bufferization.to_tensor %[[VAL_13]] : memref<10x20x30xf32>
1325// CHECK:           return %[[VAL_22]] : tensor<10x20x30xf32>
1326// CHECK:         }
1327func.func @invariants(%arga: tensor<10xf32, #Td>,
1328                      %argb: tensor<20xf32>,
1329                      %argc: tensor<30xf32>,
1330                      %argx: tensor<10x20x30xf32>) -> tensor<10x20x30xf32> {
1331  %result = linalg.generic #trait_invariants  // X(i,j,k) = a(i) * b(j) * c(k)
1332    ins(%arga, %argb, %argc : tensor<10xf32, #Td>, tensor<20xf32>, tensor<30xf32>)
1333    outs(%argx : tensor<10x20x30xf32>) {
1334      ^body(%va: f32, %vb: f32, %vc: f32, %out: f32):  // %out (current output value) is not read
1335        %ab = arith.mulf %va, %vb : f32
1336        %abc = arith.mulf %ab, %vc : f32
1337        linalg.yield %abc : f32
1338  } -> tensor<10x20x30xf32>
1339  return %result : tensor<10x20x30xf32>
1340}
1341