1 #include "os.h"
2 #include <mp.h>
3 #include <libsec.h>
4 #include "dat.h"
5
6 mpint*
mpfactorial(ulong n)7 mpfactorial(ulong n)
8 {
9 int i;
10 ulong k;
11 unsigned cnt;
12 int max, mmax;
13 mpdigit p, pp[2];
14 mpint *r, *s, *stk[31];
15
16 cnt = 0;
17 max = mmax = -1;
18 p = 1;
19 r = mpnew(0);
20 for(k=2; k<=n; k++){
21 pp[0] = 0;
22 pp[1] = 0;
23 mpvecdigmuladd(&p, 1, (mpdigit)k, pp);
24 if(pp[1] == 0) /* !overflow */
25 p = pp[0];
26 else{
27 cnt++;
28 if((cnt & 1) == 0){
29 s = stk[max];
30 mpbits(r, Dbits*(s->top+1+1));
31 memset(r->p, 0, Dbytes*(s->top+1+1));
32 mpvecmul(s->p, s->top, &p, 1, r->p);
33 r->sign = 1;
34 r->top = s->top+1+1; /* XXX: norm */
35 mpassign(r, s);
36 for(i=4; (cnt & (i-1)) == 0; i=i<<1){
37 mpmul(stk[max], stk[max-1], r);
38 mpassign(r, stk[max-1]);
39 max--;
40 }
41 }else{
42 max++;
43 if(max > mmax){
44 mmax++;
45 if(max >= nelem(stk)){
46 while(--max >= 0)
47 mpfree(stk[max]);
48 mpfree(r);
49 sysfatal("mpfactorial: stack overflow");
50 }
51 stk[max] = mpnew(Dbits);
52 }
53 stk[max]->top = 1;
54 stk[max]->p[0] = p;
55 }
56 p = (mpdigit)k;
57 }
58 }
59 if(max < 0){
60 mpbits(r, Dbits);
61 r->top = 1;
62 r->sign = 1;
63 r->p[0] = p;
64 }else{
65 s = stk[max--];
66 mpbits(r, Dbits*(s->top+1+1));
67 memset(r->p, 0, Dbytes*(s->top+1+1));
68 mpvecmul(s->p, s->top, &p, 1, r->p);
69 r->sign = 1;
70 r->top = s->top+1+1; /* XXX: norm */
71 }
72
73 while(max >= 0)
74 mpmul(r, stk[max--], r);
75 for(max=mmax; max>=0; max--)
76 mpfree(stk[max]);
77 mpnorm(r);
78 return r;
79 }
80