xref: /inferno-os/libmath/gemm.c (revision 37da2899f40661e3e9631e497da8dc59b971cbd0)
1 #include "lib9.h"
2 #include "mathi.h"
3 void
gemm(int transa,int transb,int m,int n,int k,double alpha,double * a,int lda,double * b,int ldb,double beta,double * c,int ldc)4 gemm(int transa, int transb, int m, int n, int k, double alpha,
5 	double *a, int lda,
6 	double *b, int ldb, double beta,
7 	double *c, int ldc)
8 {
9     int i1, i2, i3, nota, notb, i, j, jb, jc, l, la;
10     double temp;
11 
12     nota = transa=='N';
13     notb = transb=='N';
14 
15     if(m == 0 || n == 0 || (alpha == 0. || k == 0) && beta == 1.){
16 	return;
17     }
18     if(alpha == 0.){
19 	if(beta == 0.){
20 	    i1 = n;
21 	    for(j = 0; j < i1; ++j){
22 		jc = j*ldc;
23 		i2 = m;
24 		for(i = 0; i < i2; ++i){
25 		    c[i + jc] = 0.;
26 		}
27 	    }
28 	}else{
29 	    i1 = n;
30 	    for(j = 0; j < i1; ++j){
31 		jc = j*ldc;
32 		i2 = m;
33 		for(i = 0; i < i2; ++i){
34 		    c[i + jc] = beta * c[i + jc];
35 		}
36 	    }
37 	}
38 	return;
39     }
40 
41     if(!a){
42 	if(notb){   /* C := alpha*B + beta*C. */
43 	    i1 = n;
44 	    for(j = 0; j < i1; ++j){
45 		jb = j*ldb;
46 		jc = j*ldc;
47 		i2 = m;
48 		for(i = 0; i < i2; ++i){
49 		    c[i + jc] = alpha*b[i+jb] + beta*c[i+jc];
50 		}
51 	    }
52 	}else{   /* C := alpha*B' + beta*C. */
53 	    i1 = n;
54 	    for(j = 0; j < i1; ++j){
55 		jc = j*ldc;
56 		i2 = m;
57 		for(i = 0; i < i2; ++i){
58 		    c[i + jc] = alpha*b[j+i*ldb] + beta*c[i+jc];
59 		}
60 	    }
61 	}
62 	return;
63     }
64 
65     if(notb){
66 	if(nota){
67 
68 /*          Form  C := alpha*A*B + beta*C. */
69 	    i1 = n;
70 	    for(j = 0; j < i1; ++j){
71 		jc = j*ldc;
72 		if(beta == 0.){
73 		    i2 = m;
74 		    for(i = 0; i < i2; ++i){
75 			c[i + jc] = 0.;
76 		    }
77 		}else if(beta != 1.){
78 		    i2 = m;
79 		    for(i = 0; i < i2; ++i){
80 			c[i + jc] = beta * c[i + jc];
81 		    }
82 		}
83 		i2 = k;
84 		for(l = 0; l < i2; ++l){
85 		    la = l*lda;
86 		    if(b[l + j*ldb] != 0.){
87 			temp = alpha * b[l + j*ldb];
88 			i3 = m;
89 			for(i = 0; i < i3; ++i){
90 			    c[i + jc] += temp * a[i + la];
91 			}
92 		    }
93 		}
94 	    }
95 	}else{
96 
97 /*          Form  C := alpha*A'*B + beta*C */
98 	    i1 = n;
99 	    for(j = 0; j < i1; ++j){
100 		jc = j*ldc;
101 		i2 = m;
102 		for(i = 0; i < i2; ++i){
103 		    temp = 0.;
104 		    i3 = k;
105 		    for(l = 0; l < i3; ++l){
106 			temp += a[l + i*lda] * b[l + j*ldb];
107 		    }
108 		    if(beta == 0.){
109 			c[i + jc] = alpha * temp;
110 		    }else{
111 			c[i + jc] = alpha * temp + beta * c[i + jc];
112 		    }
113 		}
114 	    }
115 	}
116     }else{
117 	if(nota){
118 
119 /*          Form  C := alpha*A*B' + beta*C */
120 	    i1 = n;
121 	    for(j = 0; j < i1; ++j){
122 		jc = j*ldc;
123 		if(beta == 0.){
124 		    i2 = m;
125 		    for(i = 0; i < i2; ++i){
126 			c[i + jc] = 0.;
127 		    }
128 		}else if(beta != 1.){
129 		    i2 = m;
130 		    for(i = 0; i < i2; ++i){
131 			c[i + jc] = beta * c[i + jc];
132 		    }
133 		}
134 		i2 = k;
135 		for(l = 0; l < i2; ++l){
136 		    if(b[j + l*ldb] != 0.){
137 			temp = alpha * b[j + l*ldb];
138 			i3 = m;
139 			for(i = 0; i < i3; ++i){
140 			    c[i + jc] += temp * a[i + l*lda];
141 			}
142 		    }
143 		}
144 	    }
145 	}else{
146 
147 /*          Form  C := alpha*A'*B' + beta*C */
148 	    i1 = n;
149 	    for(j = 0; j < i1; ++j){
150 		jc = j*ldc;
151 		i2 = m;
152 		for(i = 0; i < i2; ++i){
153 		    temp = 0.;
154 		    i3 = k;
155 		    for(l = 0; l < i3; ++l){
156 			temp += a[l + i*lda] * b[j + l*ldb];
157 		    }
158 		    if(beta == 0.){
159 			c[i + jc] = alpha * temp;
160 		    }else{
161 			c[i + jc] = alpha * temp + beta * c[i + jc];
162 		    }
163 		}
164 	    }
165 	}
166     }
167 }
168