xref: /inferno-os/appl/math/linalg.b (revision 37da2899f40661e3e9631e497da8dc59b971cbd0)
1implement LinAlg;
2
3include "sys.m";
4sys: Sys;
5print: import sys;
6
7include "math.m";
8math: Math;
9ceil, fabs, floor, Infinity, log10, pow10, sqrt: import math;
10dot, gemm, iamax: import math;
11
12include "linalg.m";
13
# printmat prints the m-by-n matrix held in a, column by column
# (element (i,j) is a[i+lda*j]), in MATLAB-compatible form:
#	% m by n matrix
#	label = [a00, a01, ...;
#	a10, a11, ...]
# Matrices larger than 30 rows or 10 columns are silently skipped.
# A matrix with no rows or no columns is printed as "label = []"
# (no element of a is read in that case).
printmat(label:string, a:array of real, lda, m, n:int)
{
	if(m>30 || n>10)
		return;
	if(sys==nil){
		sys = load Sys Sys->PATH;
		math = load Math Math->PATH;
	}
	print("%% %d by %d matrix\n",m,n);
	if(m <= 0 || n <= 0){
		# nothing to print; still emit a well-formed (empty) assignment
		print("%s = []\n",label);
		return;
	}
	print("%s = [",label);
	for(i:=0; i<m; i++){
		print("%.4g",a[i]);
		for(j:=1; j<n; j++)
			print(", %.4g",a[i+lda*j]);
		# last row closes the bracket, earlier rows end with a row separator
		if(i==m-1)
			print("]\n");
		else
			print(";\n");
	}
}
35
36
# daxpy: constant times a vector plus a vector, i.e. dy = da*dx + dy,
# updating dy in place over the first len dx elements.
# dy is assumed to hold at least len dx elements.
# The work is delegated to Math->gemm as C = alpha*A*B + beta*C with
# alpha = da, B = dx, beta = 1 and C = dy; A is passed as nil.
# NOTE(review): this relies on gemm treating a nil A as the identity —
# confirm against math-linalg(2).
daxpy(da:real, dx:array of real, dy:array of real)
{
	gemm('N', 'N', len dx, 1, len dx, da, nil, 0, dx, len dx, 1., dy, len dx);
}
43
# dscal: scales a vector by a constant, dx = da*dx, in place.
# Implemented through Math->gemm (C = alpha*A*B + beta*C) with alpha = 0,
# A and B nil, beta = da and C = dx, so only the beta*C term contributes.
# NOTE(review): relies on gemm accepting nil A/B when alpha is 0 —
# confirm against math-linalg(2).
dscal(da:real, dx:array of real)
{
	gemm('N', 'N', len dx, 1, len dx, 0., nil, 0, nil, 0, da, dx, len dx);
}
50
# gaussian elimination with partial pivoting
#   dgefa factors a double precision matrix by gaussian elimination.
#   It is a translation of the LINPACK routine of the same name (in
#   LINPACK it is usually called by dgeco, which also estimates rcond).
#   The matrix is stored column by column: element (i,j) is a[lda*j+i].
#   on entry
#      a       array of real, column-major, at least lda*(n-1)+n elements
#	      the matrix to be factored.
#      lda     integer
#	      the leading dimension of the array  a .
#      n       integer
#	      the order of the matrix  a .  n <= 0 is treated as an
#	      empty matrix: nothing is touched and 0 is returned.
#      ipvt    array of int, at least n elements
#   on return
#      a       an upper triangular matrix and the multipliers
#	      which were used to obtain it.
#	      the factorization can be written  a = l*u  where
#	      l  is a product of permutation and unit lower
#	      triangular matrices and  u  is upper triangular.
#      ipvt    the (0-based) pivot row chosen for each column.
#      info    integer, the return value
#	      = 0  normal value.
#	      = k  if  u[k][k] == 0.0  (the largest such k is reported).
#		   this is not an error condition for this subroutine,
#		   but it does indicate that dgesl will divide by zero
#		   if called.
#	      NOTE: because indices are 0-based, a zero pivot found only in
#	      column 0 also yields info = 0; callers that must detect that
#	      case should test a[0] == 0.0 after the call.
dgefa(a:array of real, lda, n:int, ipvt:array of int): int
{
	if(sys==nil){
		sys = load Sys Sys->PATH;
		math = load Math Math->PATH;
	}
	# nothing to factor; without this guard the final ipvt[n-1] and
	# a[lda*(n-1)+(n-1)] accesses would index out of bounds
	if(n <= 0)
		return 0;
	info := 0;
	nm1 := n - 1;
	for(k := 0; k < nm1; k++){
		kp1 := k + 1;
		ldak := lda*k;	# offset of column k

		# find l = pivot index (iamax presumably returns the index of the
		# largest-magnitude entry; it is relative to row k, hence + k)
		l := iamax(a[ldak+k:ldak+n]) + k;
		ipvt[k] = l;

		# zero pivot implies this column already triangularized
		if(a[ldak+l] != 0.){

			# interchange if necessary
			if(l != k){
				t := a[ldak+l];
				a[ldak+l] = a[ldak+k];
				a[ldak+k] = t;
			}

			# compute multipliers: scale the part of column k below the
			# diagonal by -1/pivot (the slice shares storage with a)
			t := -1./a[ldak+k];
			dscal(t, a[ldak+k+1:ldak+n]);

			# row elimination with column indexing
			for(j := kp1; j < n; j++){
				ldaj := lda*j;
				t = a[ldaj+l];
				if(l != k){
					a[ldaj+l] = a[ldaj+k];
					a[ldaj+k] = t;
				}
				daxpy(t, a[ldak+k+1:ldak+n], a[ldaj+k+1:ldaj+n]);
			}
		}else
			info = k;
	}
	# the last column has no elimination step; it is its own pivot
	ipvt[nm1] = nm1;
	if(a[lda*nm1+nm1] == 0.)
		info = nm1;
	return info;
}
127
128
#   dgesl solves the double precision system
#   a * x = b  or  trans(a) * x = b
#   using the factors computed by dgefa (a translation of the LINPACK
#   routine of the same name).
#   The matrix is stored column by column: element (i,j) is a[lda*j+i].
#   on entry
#      a       array of real, column-major
#	      the output from dgefa.
#      lda     integer
#	      the leading dimension of the array  a .
#      n       integer
#	      the order of the matrix  a .
#      ipvt    array of int, at least n elements
#	      the pivot vector from dgefa.
#      b       array of real, at least n elements
#	      the right hand side vector.
#      job     integer
#	      = 0	 to solve  a*x = b ,
#	      = nonzero   to solve  trans(a)*x = b  where
#			  trans(a)  is the transpose.
#  on return
#      b       the solution vector  x  (overwritten in place).
#   error condition
#      a division by zero will occur if the input factor contains a
#      zero on the diagonal.  technically this indicates singularity
#      but it is often caused by improper arguments or improper
#      setting of lda.
dgesl(a:array of real, lda, n:int, ipvt:array of int, b:array of real, job:int)
{
	# dot and daxpy (via gemm) need the Math module; load it here as the
	# other entry points do, so dgesl also works when called first
	if(sys==nil){
		sys = load Sys Sys->PATH;
		math = load Math Math->PATH;
	}
	k, l: int;
	t: real;
	if(job == 0){	# job = 0 , solve  a * x = b
		# first solve  l*y = b: apply each recorded row interchange,
		# then the multipliers stored below the diagonal of column k
		for(k = 0; k < n-1; k++){
			l = ipvt[k];
			t = b[l];
			if(l != k){
				b[l] = b[k];
				b[k] = t;
			}
			daxpy(t, a[lda*k+k+1:lda*k+n], b[k+1:n]);
		}

		# now solve  u*x = y  by back substitution
		for(k = n-1; k >= 0; k--){
			b[k] = b[k]/a[lda*k+k];
			t = -b[k];
			daxpy(t, a[lda*k:lda*k+k], b[0:k]);
		}
	}else{	# job = nonzero, solve  trans(a) * x = b
		# first solve  trans(u)*y = b  by forward substitution
		for(k = 0; k < n; k++){
			t = dot(a[lda*k:lda*k+k], b[0:k]);
			b[k] = (b[k] - t)/a[lda*k+k];
		}

		# now solve trans(l)*x = y, columns n-2 down to 0.
		# column 0 must be included: its multipliers and the ipvt[0]
		# interchange are part of l just like the others.
		for(k = n-2; k >= 0; k--){
			b[k] += dot(a[lda*k+k+1:lda*k+n], b[k+1:n]);
			l = ipvt[k];
			if(l != k){
				t = b[l];
				b[l] = b[k];
				b[k] = t;
			}
		}
	}
}
198