xref: /netbsd-src/external/gpl3/gcc.old/dist/libgfortran/m4/matmull.m4 (revision 4c3eb207d36f67d31994830c0a694161fc1ca39b)
1627f7eb2Smrg`/* Implementation of the MATMUL intrinsic
2*4c3eb207Smrg   Copyright (C) 2002-2020 Free Software Foundation, Inc.
3627f7eb2Smrg   Contributed by Paul Brook <paul@nowt.org>
4627f7eb2Smrg
5627f7eb2SmrgThis file is part of the GNU Fortran runtime library (libgfortran).
6627f7eb2Smrg
7627f7eb2SmrgLibgfortran is free software; you can redistribute it and/or
8627f7eb2Smrgmodify it under the terms of the GNU General Public
9627f7eb2SmrgLicense as published by the Free Software Foundation; either
10627f7eb2Smrgversion 3 of the License, or (at your option) any later version.
11627f7eb2Smrg
12627f7eb2SmrgLibgfortran is distributed in the hope that it will be useful,
13627f7eb2Smrgbut WITHOUT ANY WARRANTY; without even the implied warranty of
14627f7eb2SmrgMERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15627f7eb2SmrgGNU General Public License for more details.
16627f7eb2Smrg
17627f7eb2SmrgUnder Section 7 of GPL version 3, you are granted additional
18627f7eb2Smrgpermissions described in the GCC Runtime Library Exception, version
19627f7eb2Smrg3.1, as published by the Free Software Foundation.
20627f7eb2Smrg
21627f7eb2SmrgYou should have received a copy of the GNU General Public License and
22627f7eb2Smrga copy of the GCC Runtime Library Exception along with this program;
23627f7eb2Smrgsee the files COPYING3 and COPYING.RUNTIME respectively.  If not, see
24627f7eb2Smrg<http://www.gnu.org/licenses/>.  */
25627f7eb2Smrg
26627f7eb2Smrg#include "libgfortran.h"
27627f7eb2Smrg#include <assert.h>'
28627f7eb2Smrg
29627f7eb2Smrginclude(iparm.m4)dnl
30627f7eb2Smrg
31627f7eb2Smrg`#if defined (HAVE_'rtype_name`)
32627f7eb2Smrg
33627f7eb2Smrg/* Dimensions: retarray(x,y) a(x, count) b(count,y).
34627f7eb2Smrg   Either a or b can be rank 1.  In this case x or y is 1.  */
35627f7eb2Smrg
/* Prototype and symbol export for the logical MATMUL entry point.  The
   function name suffix and the result descriptor type are filled in by
   the m4 macros rtype_code and rtype (NOTE(review): presumably defined
   by the included iparm.m4; confirm there).  Both operands are passed
   as gfc_array_l1 descriptors.  */
36627f7eb2Smrgextern void matmul_'rtype_code` ('rtype` * const restrict,
37627f7eb2Smrg	gfc_array_l1 * const restrict, gfc_array_l1 * const restrict);
38627f7eb2Smrgexport_proto(matmul_'rtype_code`);
39627f7eb2Smrg
/* Logical MATMUL: store in RETARRAY the product of the logical arrays
   A and B.  An element of the result is 1 if some index n has both the
   A element and the B element nonzero, and 0 otherwise.

   RETARRAY  result descriptor.  If its base_addr is NULL, the bounds are
             filled in from the extents of A and B and the storage is
             obtained with xmallocarray.  Otherwise, when
             compile_options.bounds_check is set, its extents are
             compared with those of A and B and runtime_error is called
             on a mismatch.
   A, B      logical operands.  At least one of them must have rank 2
             (asserted); the other may have rank 1.  Their element size
             is read from the descriptor and must be 1, 2, 4 or 8 (or 16
             when HAVE_GFC_LOGICAL_16 is defined), else internal_error
             is called.

   The extent shared by A and B is only checked with assert.  */
40627f7eb2Smrgvoid
41627f7eb2Smrgmatmul_'rtype_code` ('rtype` * const restrict retarray,
42627f7eb2Smrg	gfc_array_l1 * const restrict a, gfc_array_l1 * const restrict b)
43627f7eb2Smrg{
  /* Cursors over A, B and the result for the outer x/y loops.  */
44627f7eb2Smrg  const GFC_LOGICAL_1 * restrict abase;
45627f7eb2Smrg  const GFC_LOGICAL_1 * restrict bbase;
46627f7eb2Smrg  'rtype_name` * restrict dest;
  /* Result strides and the x/y trip counts and operand strides.  */
47627f7eb2Smrg  index_type rxstride;
48627f7eb2Smrg  index_type rystride;
49627f7eb2Smrg  index_type xcount;
50627f7eb2Smrg  index_type ycount;
51627f7eb2Smrg  index_type xstride;
52627f7eb2Smrg  index_type ystride;
53627f7eb2Smrg  index_type x;
54627f7eb2Smrg  index_type y;
  /* Element sizes of A and B as reported by GFC_DESCRIPTOR_SIZE.  */
55627f7eb2Smrg  int a_kind;
56627f7eb2Smrg  int b_kind;
57627f7eb2Smrg
  /* Cursors, strides and trip count for the innermost loop over n.  */
58627f7eb2Smrg  const GFC_LOGICAL_1 * restrict pa;
59627f7eb2Smrg  const GFC_LOGICAL_1 * restrict pb;
60627f7eb2Smrg  index_type astride;
61627f7eb2Smrg  index_type bstride;
62627f7eb2Smrg  index_type count;
63627f7eb2Smrg  index_type n;
64627f7eb2Smrg
65627f7eb2Smrg  assert (GFC_DESCRIPTOR_RANK (a) == 2
66627f7eb2Smrg          || GFC_DESCRIPTOR_RANK (b) == 2);
67627f7eb2Smrg
  /* Unallocated result: derive zero-based bounds from the operands, then
     allocate.  */
68627f7eb2Smrg  if (retarray->base_addr == NULL)
69627f7eb2Smrg    {
70627f7eb2Smrg      if (GFC_DESCRIPTOR_RANK (a) == 1)
71627f7eb2Smrg        {
72627f7eb2Smrg	  GFC_DIMENSION_SET(retarray->dim[0], 0,
73627f7eb2Smrg	                    GFC_DESCRIPTOR_EXTENT(b,1) - 1, 1);
74627f7eb2Smrg        }
75627f7eb2Smrg      else if (GFC_DESCRIPTOR_RANK (b) == 1)
76627f7eb2Smrg        {
77627f7eb2Smrg	  GFC_DIMENSION_SET(retarray->dim[0], 0,
78627f7eb2Smrg	                    GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
79627f7eb2Smrg        }
80627f7eb2Smrg      else
81627f7eb2Smrg        {
82627f7eb2Smrg	  GFC_DIMENSION_SET(retarray->dim[0], 0,
83627f7eb2Smrg	                    GFC_DESCRIPTOR_EXTENT(a,0) - 1, 1);
84627f7eb2Smrg
	  /* The stride of the second dimension is the extent of the first
	     one, which was set just above.  */
85627f7eb2Smrg          GFC_DIMENSION_SET(retarray->dim[1], 0,
86627f7eb2Smrg	                    GFC_DESCRIPTOR_EXTENT(b,1) - 1,
87627f7eb2Smrg			    GFC_DESCRIPTOR_EXTENT(retarray,0));
88627f7eb2Smrg        }
89627f7eb2Smrg
90627f7eb2Smrg      retarray->base_addr
91627f7eb2Smrg	= xmallocarray (size0 ((array_t *) retarray), sizeof ('rtype_name`));
92627f7eb2Smrg      retarray->offset = 0;
93627f7eb2Smrg    }
94627f7eb2Smrg    else if (unlikely (compile_options.bounds_check))
95627f7eb2Smrg      {
	/* The result already has storage: with bounds checking enabled,
	   verify that its shape matches what the operands imply.  */
96627f7eb2Smrg	index_type ret_extent, arg_extent;
97627f7eb2Smrg
98627f7eb2Smrg	if (GFC_DESCRIPTOR_RANK (a) == 1)
99627f7eb2Smrg	  {
100627f7eb2Smrg	    arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
101627f7eb2Smrg	    ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
102627f7eb2Smrg	    if (arg_extent != ret_extent)
103627f7eb2Smrg	      runtime_error ("Incorrect extent in return array in"
104627f7eb2Smrg			     " MATMUL intrinsic: is %ld, should be %ld",
105627f7eb2Smrg			     (long int) ret_extent, (long int) arg_extent);
106627f7eb2Smrg	  }
107627f7eb2Smrg	else if (GFC_DESCRIPTOR_RANK (b) == 1)
108627f7eb2Smrg	  {
109627f7eb2Smrg	    arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
110627f7eb2Smrg	    ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
111627f7eb2Smrg	    if (arg_extent != ret_extent)
112627f7eb2Smrg	      runtime_error ("Incorrect extent in return array in"
113627f7eb2Smrg			     " MATMUL intrinsic: is %ld, should be %ld",
114627f7eb2Smrg			     (long int) ret_extent, (long int) arg_extent);
115627f7eb2Smrg	  }
116627f7eb2Smrg	else
117627f7eb2Smrg	  {
118627f7eb2Smrg	    arg_extent = GFC_DESCRIPTOR_EXTENT(a,0);
119627f7eb2Smrg	    ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,0);
120627f7eb2Smrg	    if (arg_extent != ret_extent)
121627f7eb2Smrg	      runtime_error ("Incorrect extent in return array in"
122627f7eb2Smrg			     " MATMUL intrinsic for dimension 1:"
123627f7eb2Smrg			     " is %ld, should be %ld",
124627f7eb2Smrg			     (long int) ret_extent, (long int) arg_extent);
125627f7eb2Smrg
126627f7eb2Smrg	    arg_extent = GFC_DESCRIPTOR_EXTENT(b,1);
127627f7eb2Smrg	    ret_extent = GFC_DESCRIPTOR_EXTENT(retarray,1);
128627f7eb2Smrg	    if (arg_extent != ret_extent)
129627f7eb2Smrg	      runtime_error ("Incorrect extent in return array in"
130627f7eb2Smrg			     " MATMUL intrinsic for dimension 2:"
131627f7eb2Smrg			     " is %ld, should be %ld",
132627f7eb2Smrg			     (long int) ret_extent, (long int) arg_extent);
133627f7eb2Smrg	  }
134627f7eb2Smrg      }
135627f7eb2Smrg
  /* Validate the element size of A and pass its base address through
     GFOR_POINTER_TO_L1 so that it can be read as GFC_LOGICAL_1.
     NOTE(review): presumably the macro selects the byte that carries the
     truth value; confirm against its definition.  */
136627f7eb2Smrg  abase = a->base_addr;
137627f7eb2Smrg  a_kind = GFC_DESCRIPTOR_SIZE (a);
138627f7eb2Smrg
139627f7eb2Smrg  if (a_kind == 1 || a_kind == 2 || a_kind == 4 || a_kind == 8
140627f7eb2Smrg#ifdef HAVE_GFC_LOGICAL_16
141627f7eb2Smrg     || a_kind == 16
142627f7eb2Smrg#endif
143627f7eb2Smrg     )
144627f7eb2Smrg    abase = GFOR_POINTER_TO_L1 (abase, a_kind);
145627f7eb2Smrg  else
146627f7eb2Smrg    internal_error (NULL, "Funny sized logical array");
147627f7eb2Smrg
  /* Same validation and conversion for B.  */
148627f7eb2Smrg  bbase = b->base_addr;
149627f7eb2Smrg  b_kind = GFC_DESCRIPTOR_SIZE (b);
150627f7eb2Smrg
151627f7eb2Smrg  if (b_kind == 1 || b_kind == 2 || b_kind == 4 || b_kind == 8
152627f7eb2Smrg#ifdef HAVE_GFC_LOGICAL_16
153627f7eb2Smrg     || b_kind == 16
154627f7eb2Smrg#endif
155627f7eb2Smrg     )
156627f7eb2Smrg    bbase = GFOR_POINTER_TO_L1 (bbase, b_kind);
157627f7eb2Smrg  else
158627f7eb2Smrg    internal_error (NULL, "Funny sized logical array");
159627f7eb2Smrg
160627f7eb2Smrg  dest = retarray->base_addr;
  /* The m4 sinclude that follows may splice an optional per-type code
     fragment in at this point; its contents are not visible here.  */
161627f7eb2Smrg'
162627f7eb2Smrgsinclude(`matmul_asm_'rtype_code`.m4')dnl
163627f7eb2Smrg`
  /* Result strides are used to step dest directly.  For a rank 1 result
     the single stride is copied to both variables; the one belonging to
     the absent dimension is zeroed further down.  */
164627f7eb2Smrg  if (GFC_DESCRIPTOR_RANK (retarray) == 1)
165627f7eb2Smrg    {
166627f7eb2Smrg      rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
167627f7eb2Smrg      rystride = rxstride;
168627f7eb2Smrg    }
169627f7eb2Smrg  else
170627f7eb2Smrg    {
171627f7eb2Smrg      rxstride = GFC_DESCRIPTOR_STRIDE(retarray,0);
172627f7eb2Smrg      rystride = GFC_DESCRIPTOR_STRIDE(retarray,1);
173627f7eb2Smrg    }
174627f7eb2Smrg
175627f7eb2Smrg  /* If we have rank 1 parameters, zero the absent stride, and set the size to
176627f7eb2Smrg     one.  */
  /* Operand strides come from GFC_DESCRIPTOR_STRIDE_BYTES and are added
     to the GFC_LOGICAL_1 cursors.  count is the length of the dimension
     that A and B have in common.  */
177627f7eb2Smrg  if (GFC_DESCRIPTOR_RANK (a) == 1)
178627f7eb2Smrg    {
179627f7eb2Smrg      astride = GFC_DESCRIPTOR_STRIDE_BYTES(a,0);
180627f7eb2Smrg      count = GFC_DESCRIPTOR_EXTENT(a,0);
181627f7eb2Smrg      xstride = 0;
182627f7eb2Smrg      rxstride = 0;
183627f7eb2Smrg      xcount = 1;
184627f7eb2Smrg    }
185627f7eb2Smrg  else
186627f7eb2Smrg    {
187627f7eb2Smrg      astride = GFC_DESCRIPTOR_STRIDE_BYTES(a,1);
188627f7eb2Smrg      count = GFC_DESCRIPTOR_EXTENT(a,1);
189627f7eb2Smrg      xstride = GFC_DESCRIPTOR_STRIDE_BYTES(a,0);
190627f7eb2Smrg      xcount = GFC_DESCRIPTOR_EXTENT(a,0);
191627f7eb2Smrg    }
192627f7eb2Smrg  if (GFC_DESCRIPTOR_RANK (b) == 1)
193627f7eb2Smrg    {
194627f7eb2Smrg      bstride = GFC_DESCRIPTOR_STRIDE_BYTES(b,0);
195627f7eb2Smrg      assert(count == GFC_DESCRIPTOR_EXTENT(b,0));
196627f7eb2Smrg      ystride = 0;
197627f7eb2Smrg      rystride = 0;
198627f7eb2Smrg      ycount = 1;
199627f7eb2Smrg    }
200627f7eb2Smrg  else
201627f7eb2Smrg    {
202627f7eb2Smrg      bstride = GFC_DESCRIPTOR_STRIDE_BYTES(b,0);
203627f7eb2Smrg      assert(count == GFC_DESCRIPTOR_EXTENT(b,0));
204627f7eb2Smrg      ystride = GFC_DESCRIPTOR_STRIDE_BYTES(b,1);
205627f7eb2Smrg      ycount = GFC_DESCRIPTOR_EXTENT(b,1);
206627f7eb2Smrg    }
207627f7eb2Smrg
  /* Outer loop over the second result index, inner loop over the
     first.  */
208627f7eb2Smrg  for (y = 0; y < ycount; y++)
209627f7eb2Smrg    {
210627f7eb2Smrg      for (x = 0; x < xcount; x++)
211627f7eb2Smrg        {
212627f7eb2Smrg          /* Compute one result element: it becomes 1 as soon as some n
213627f7eb2Smrg             has both the A and the B element nonzero (the scan stops
214627f7eb2Smrg             there), and stays 0 if no such n exists.  */
215627f7eb2Smrg          pa = abase;
216627f7eb2Smrg          pb = bbase;
217627f7eb2Smrg          *dest = 0;
218627f7eb2Smrg
219627f7eb2Smrg          for (n = 0; n < count; n++)
220627f7eb2Smrg            {
221627f7eb2Smrg              if (*pa && *pb)
222627f7eb2Smrg                {
223627f7eb2Smrg                  *dest = 1;
224627f7eb2Smrg                  break;
225627f7eb2Smrg                }
226627f7eb2Smrg              pa += astride;
227627f7eb2Smrg              pb += bstride;
228627f7eb2Smrg            }
229627f7eb2Smrg
230627f7eb2Smrg          dest += rxstride;
231627f7eb2Smrg          abase += xstride;
232627f7eb2Smrg        }
      /* Rewind the A cursor to where the x loop started, advance the B
         cursor by one y step, and move dest from the end of this x sweep
         to the start of the next one.  */
233627f7eb2Smrg      abase -= xstride * xcount;
234627f7eb2Smrg      bbase += ystride;
235627f7eb2Smrg      dest += rystride - (rxstride * xcount);
236627f7eb2Smrg    }
237627f7eb2Smrg}
238627f7eb2Smrg
239627f7eb2Smrg#endif
240627f7eb2Smrg'
241