1181254a7Smrg /* Implementation of the MATMUL intrinsic
2*b1e83836Smrg Copyright (C) 2002-2022 Free Software Foundation, Inc.
3181254a7Smrg Contributed by Paul Brook <paul@nowt.org>
4181254a7Smrg
5181254a7Smrg This file is part of the GNU Fortran runtime library (libgfortran).
6181254a7Smrg
7181254a7Smrg Libgfortran is free software; you can redistribute it and/or
8181254a7Smrg modify it under the terms of the GNU General Public
9181254a7Smrg License as published by the Free Software Foundation; either
10181254a7Smrg version 3 of the License, or (at your option) any later version.
11181254a7Smrg
12181254a7Smrg Libgfortran is distributed in the hope that it will be useful,
13181254a7Smrg but WITHOUT ANY WARRANTY; without even the implied warranty of
14181254a7Smrg MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15181254a7Smrg GNU General Public License for more details.
16181254a7Smrg
17181254a7Smrg Under Section 7 of GPL version 3, you are granted additional
18181254a7Smrg permissions described in the GCC Runtime Library Exception, version
19181254a7Smrg 3.1, as published by the Free Software Foundation.
20181254a7Smrg
21181254a7Smrg You should have received a copy of the GNU General Public License and
22181254a7Smrg a copy of the GCC Runtime Library Exception along with this program;
23181254a7Smrg see the files COPYING3 and COPYING.RUNTIME respectively. If not, see
24181254a7Smrg <http://www.gnu.org/licenses/>. */
25181254a7Smrg
26181254a7Smrg #include "libgfortran.h"
27181254a7Smrg #include <assert.h>
28181254a7Smrg
29181254a7Smrg
30181254a7Smrg #if defined (HAVE_GFC_LOGICAL_4)
31181254a7Smrg
32181254a7Smrg /* Dimensions: retarray(x,y) a(x, count) b(count,y).
33181254a7Smrg Either a or b can be rank 1. In this case x or y is 1. */
34181254a7Smrg
35181254a7Smrg extern void matmul_l4 (gfc_array_l4 * const restrict,
36181254a7Smrg gfc_array_l1 * const restrict, gfc_array_l1 * const restrict);
37181254a7Smrg export_proto(matmul_l4);
38181254a7Smrg
39181254a7Smrg void
matmul_l4(gfc_array_l4 * const restrict retarray,gfc_array_l1 * const restrict a,gfc_array_l1 * const restrict b)40181254a7Smrg matmul_l4 (gfc_array_l4 * const restrict retarray,
41181254a7Smrg gfc_array_l1 * const restrict a, gfc_array_l1 * const restrict b)
42181254a7Smrg {
43181254a7Smrg const GFC_LOGICAL_1 * restrict abase;
44181254a7Smrg const GFC_LOGICAL_1 * restrict bbase;
45181254a7Smrg GFC_LOGICAL_4 * restrict dest;
46181254a7Smrg index_type rxstride;
47181254a7Smrg index_type rystride;
48181254a7Smrg index_type xcount;
49181254a7Smrg index_type ycount;
50181254a7Smrg index_type xstride;
51181254a7Smrg index_type ystride;
52181254a7Smrg index_type x;
53181254a7Smrg index_type y;
54181254a7Smrg int a_kind;
55181254a7Smrg int b_kind;
56181254a7Smrg
57181254a7Smrg const GFC_LOGICAL_1 * restrict pa;
58181254a7Smrg const GFC_LOGICAL_1 * restrict pb;
59181254a7Smrg index_type astride;
60181254a7Smrg index_type bstride;
61181254a7Smrg index_type count;
62181254a7Smrg index_type n;
63181254a7Smrg
64181254a7Smrg assert (GFC_DESCRIPTOR_RANK (a) == 2
65181254a7Smrg || GFC_DESCRIPTOR_RANK (b) == 2);
66181254a7Smrg
67181254a7Smrg if (retarray->base_addr == NULL)
68181254a7Smrg {
69181254a7Smrg if (GFC_DESCRIPTOR_RANK (a) == 1)
70181254a7Smrg {
71181254a7Smrg GFC_DIMENSION_SET(retarray->dim[0], 0,
72181254a7Smrg GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
73181254a7Smrg }
74181254a7Smrg else if (GFC_DESCRIPTOR_RANK (b) == 1)
75181254a7Smrg {
76181254a7Smrg GFC_DIMENSION_SET(retarray->dim[0], 0,
77181254a7Smrg GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
78181254a7Smrg }
79181254a7Smrg else
80181254a7Smrg {
81181254a7Smrg GFC_DIMENSION_SET(retarray->dim[0], 0,
82181254a7Smrg GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
83181254a7Smrg
84181254a7Smrg GFC_DIMENSION_SET(retarray->dim[1], 0,
85181254a7Smrg GFC_DESCRIPTOR_EXTENT(b,1) - 1,
86181254a7Smrg GFC_DESCRIPTOR_EXTENT(retarray,0));
87181254a7Smrg }
88181254a7Smrg
89181254a7Smrg retarray->base_addr
90181254a7Smrg = xmallocarray (size0 ((array_t *) retarray), sizeof (GFC_LOGICAL_4));
91181254a7Smrg retarray->offset = 0;
92181254a7Smrg }
93181254a7Smrg else if (unlikely (compile_options.bounds_check))
94181254a7Smrg {
95181254a7Smrg index_type ret_extent, arg_extent;
96181254a7Smrg
97181254a7Smrg if (GFC_DESCRIPTOR_RANK (a) == 1)
98181254a7Smrg {
99181254a7Smrg arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
100181254a7Smrg ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
101181254a7Smrg if (arg_extent != ret_extent)
102181254a7Smrg runtime_error ("Incorrect extent in return array in"
103181254a7Smrg " MATMUL intrinsic: is %ld, should be %ld",
104181254a7Smrg (long int) ret_extent, (long int) arg_extent);
105181254a7Smrg }
106181254a7Smrg else if (GFC_DESCRIPTOR_RANK (b) == 1)
107181254a7Smrg {
108181254a7Smrg arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
109181254a7Smrg ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
110181254a7Smrg if (arg_extent != ret_extent)
111181254a7Smrg runtime_error ("Incorrect extent in return array in"
112181254a7Smrg " MATMUL intrinsic: is %ld, should be %ld",
113181254a7Smrg (long int) ret_extent, (long int) arg_extent);
114181254a7Smrg }
115181254a7Smrg else
116181254a7Smrg {
117181254a7Smrg arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
118181254a7Smrg ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
119181254a7Smrg if (arg_extent != ret_extent)
120181254a7Smrg runtime_error ("Incorrect extent in return array in"
121181254a7Smrg " MATMUL intrinsic for dimension 1:"
122181254a7Smrg " is %ld, should be %ld",
123181254a7Smrg (long int) ret_extent, (long int) arg_extent);
124181254a7Smrg
125181254a7Smrg arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
126181254a7Smrg ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
127181254a7Smrg if (arg_extent != ret_extent)
128181254a7Smrg runtime_error ("Incorrect extent in return array in"
129181254a7Smrg " MATMUL intrinsic for dimension 2:"
130181254a7Smrg " is %ld, should be %ld",
131181254a7Smrg (long int) ret_extent, (long int) arg_extent);
132181254a7Smrg }
133181254a7Smrg }
134181254a7Smrg
135181254a7Smrg abase = a->base_addr;
136181254a7Smrg a_kind = GFC_DESCRIPTOR_SIZE (a);
137181254a7Smrg
138181254a7Smrg if (a_kind == 1 || a_kind == 2 || a_kind == 4 || a_kind == 8
139181254a7Smrg #ifdef HAVE_GFC_LOGICAL_16
140181254a7Smrg || a_kind == 16
141181254a7Smrg #endif
142181254a7Smrg )
143181254a7Smrg abase = GFOR_POINTER_TO_L1 (abase, a_kind);
144181254a7Smrg else
145181254a7Smrg internal_error (NULL, "Funny sized logical array");
146181254a7Smrg
147181254a7Smrg bbase = b->base_addr;
148181254a7Smrg b_kind = GFC_DESCRIPTOR_SIZE (b);
149181254a7Smrg
150181254a7Smrg if (b_kind == 1 || b_kind == 2 || b_kind == 4 || b_kind == 8
151181254a7Smrg #ifdef HAVE_GFC_LOGICAL_16
152181254a7Smrg || b_kind == 16
153181254a7Smrg #endif
154181254a7Smrg )
155181254a7Smrg bbase = GFOR_POINTER_TO_L1 (bbase, b_kind);
156181254a7Smrg else
157181254a7Smrg internal_error (NULL, "Funny sized logical array");
158181254a7Smrg
159181254a7Smrg dest = retarray->base_addr;
160181254a7Smrg
161181254a7Smrg
162181254a7Smrg if (GFC_DESCRIPTOR_RANK (retarray) == 1)
163181254a7Smrg {
164181254a7Smrg rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
165181254a7Smrg rystride = rxstride;
166181254a7Smrg }
167181254a7Smrg else
168181254a7Smrg {
169181254a7Smrg rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
170181254a7Smrg rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
171181254a7Smrg }
172181254a7Smrg
173181254a7Smrg /* If we have rank 1 parameters, zero the absent stride, and set the size to
174181254a7Smrg one. */
175181254a7Smrg if (GFC_DESCRIPTOR_RANK (a) == 1)
176181254a7Smrg {
177181254a7Smrg astride = GFC_DESCRIPTOR_STRIDE_BYTES(a,0);
178181254a7Smrg count = GFC_DESCRIPTOR_EXTENT(a,0);
179181254a7Smrg xstride = 0;
180181254a7Smrg rxstride = 0;
181181254a7Smrg xcount = 1;
182181254a7Smrg }
183181254a7Smrg else
184181254a7Smrg {
185181254a7Smrg astride = GFC_DESCRIPTOR_STRIDE_BYTES(a,1);
186181254a7Smrg count = GFC_DESCRIPTOR_EXTENT(a,1);
187181254a7Smrg xstride = GFC_DESCRIPTOR_STRIDE_BYTES(a,0);
188181254a7Smrg xcount = GFC_DESCRIPTOR_EXTENT(a,0);
189181254a7Smrg }
190181254a7Smrg if (GFC_DESCRIPTOR_RANK (b) == 1)
191181254a7Smrg {
192181254a7Smrg bstride = GFC_DESCRIPTOR_STRIDE_BYTES(b,0);
193181254a7Smrg assert(count == GFC_DESCRIPTOR_EXTENT(b,0));
194181254a7Smrg ystride = 0;
195181254a7Smrg rystride = 0;
196181254a7Smrg ycount = 1;
197181254a7Smrg }
198181254a7Smrg else
199181254a7Smrg {
200181254a7Smrg bstride = GFC_DESCRIPTOR_STRIDE_BYTES(b,0);
201181254a7Smrg assert(count == GFC_DESCRIPTOR_EXTENT(b,0));
202181254a7Smrg ystride = GFC_DESCRIPTOR_STRIDE_BYTES(b,1);
203181254a7Smrg ycount = GFC_DESCRIPTOR_EXTENT(b,1);
204181254a7Smrg }
205181254a7Smrg
206181254a7Smrg for (y = 0; y < ycount; y++)
207181254a7Smrg {
208181254a7Smrg for (x = 0; x < xcount; x++)
209181254a7Smrg {
210181254a7Smrg /* Do the summation for this element. For real and integer types
211181254a7Smrg this is the same as DOT_PRODUCT. For complex types we use do
212181254a7Smrg a*b, not conjg(a)*b. */
213181254a7Smrg pa = abase;
214181254a7Smrg pb = bbase;
215181254a7Smrg *dest = 0;
216181254a7Smrg
217181254a7Smrg for (n = 0; n < count; n++)
218181254a7Smrg {
219181254a7Smrg if (*pa && *pb)
220181254a7Smrg {
221181254a7Smrg *dest = 1;
222181254a7Smrg break;
223181254a7Smrg }
224181254a7Smrg pa += astride;
225181254a7Smrg pb += bstride;
226181254a7Smrg }
227181254a7Smrg
228181254a7Smrg dest += rxstride;
229181254a7Smrg abase += xstride;
230181254a7Smrg }
231181254a7Smrg abase -= xstride * xcount;
232181254a7Smrg bbase += ystride;
233181254a7Smrg dest += rystride - (rxstride * xcount);
234181254a7Smrg }
235181254a7Smrg }
236181254a7Smrg
237181254a7Smrg #endif
238181254a7Smrg
239