summaryrefslogtreecommitdiff
path: root/flower/choleski.cc
blob: 82740a27c18addcbeca4278ac401dca765c2b773 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#include "choleski.hh"
// Relative tolerance used by the assert()-based consistency checks in
// this file (symmetry of the input, reconstruction of the original
// matrix, accuracy of the inverse).  Deliberately hard coded.
const Real EPS(1e-7);

/*
  Solve P x = rhs using the stored factors P = L D L^T:
    1. forward substitution for L z = rhs,
    2. division of z by the diagonal D,
    3. backward substitution for L^T x = z.
  Asserts that rhs has the factor's dimension and that no entry of D
  is zero.
  */
Vector
Choleski_decomposition::solve(Vector rhs)const
{
    int dim = rhs.dim();
    assert(dim == L.dim());

    // Forward pass: rows top to bottom, using already-solved entries.
    Vector z(dim);
    for (int row = 0; row < dim; row++) {
        Real acc(0.0);
        for (int col = 0; col < row; col++)
            acc += z(col) * L(row, col);
        z(row) = (rhs(row) - acc) / L(row, row);
    }

    // Diagonal pass.  Must stay separate from the forward pass, which
    // reads the *unscaled* z values of earlier rows.
    for (int row = 0; row < dim; row++) {
        assert(D(row));
        z(row) /= D(row);
    }

    // Backward pass: rows bottom to top; L(col,row) is L^T(row,col).
    Vector sol(dim);
    for (int row = dim; row-- > 0; ) {
        Real acc(0.0);
        for (int col = row + 1; col < dim; col++)
            acc += L(col, row) * sol(col);
        sol(row) = (z(row) - acc) / L(row, row);
    }
    return sol;
}

/*
  Factor a symmetric matrix as P = L D L^T, with L lower triangular
  (its diagonal comes from L.unit(), presumably all ones) and the
  diagonal matrix D stored as a Vector.  This is the square-root-free
  Choleski variant.  No pivoting is done: a zero pivot D(j) makes the
  computation of later rows divide by zero.

  Precondition (asserted): P is symmetric to within relative
  tolerance EPS.
  */

Choleski_decomposition::Choleski_decomposition(Matrix P)
    : L(P.dim()), D(P.dim())
{
    int n = P.dim();
    // Multiplied form instead of dividing by P.norm(): for the all-zero
    // matrix the quotient would be 0/0 = NaN and the assert would fire.
    assert((P-P.transposed()).norm() <= EPS * P.norm());

    L.unit();
    for (int k = 0; k < n; k++) {
        // Strictly-lower entries of row k:
        //   L(k,j) = (P(k,j) - sum_{l<j} L(k,l) L(j,l) D(l)) / D(j)
        for (int j = 0; j < k; j++) {
            Real sum(0.0);
            for (int l = 0; l < j; l++)
                sum += L(k,l)*L(j,l)*D(l);
            L(k,j) = (P(k,j) - sum)/D(j);
        }

        // Pivot of row k:
        //   D(k) = P(k,k) - sum_{l<k} L(k,l)^2 D(l)
        Real sum(0.0);
        for (int l = 0; l < k; l++)
            sum += sqr(L(k,l))*D(l);
        D(k) = P(k,k) - sum;
    }

#ifndef NDEBUG
    // Self-check that L D L^T reproduces P.  It is only meaningful when
    // assert() is active; under NDEBUG assert() expands to nothing, so
    // guarding with #ifdef NDEBUG would never run it.
    assert((original()-P).norm() <= EPS * P.norm());
#endif
}
     
/*
  Rebuild the factored matrix from its factors: returns L * diag(D) * L^T.
  */
Matrix
Choleski_decomposition::original() const
{
    Matrix diag(L.dim());
    diag.set_diag(D);
    Matrix scaled = L * diag;          // L D
    return scaled * L.transposed();    // (L D) L^T
}

/*
  Return P^{-1} for the factored matrix P.  Column i of the result is
  the solution of P x = e_i, obtained with solve().  Each call to
  solve() asserts that no entry of D is zero.
  */
Matrix
Choleski_decomposition::inverse() const
{
    int n = L.dim();
    Matrix invm(n);
    Vector e_i(n);
    for (int i = 0; i < n; i++) {
        e_i.set_unit(i);   // assumed to reset e_i to the i-th unit vector
        Vector inv(solve(e_i));
        // inv = P^{-1} e_i is column i of the inverse.  Storing it as a
        // column makes the check below test P * P^{-1} directly rather
        // than relying on the symmetry of P^{-1}.
        for (int j = 0; j < n; j++)
            invm(j,i) = inv(j);
    }

#ifndef NDEBUG
    // Self-check: P * P^{-1} should be the identity.  Reconstruct P once
    // and reuse it.  The old code called a non-existent
    // `original.norm()` inside a never-compiled #ifdef NDEBUG block.
    Matrix I1(n), I2(original());
    I1.unit();
    assert((I1-I2*invm).norm() <= EPS * I2.norm());
#endif

    return invm;
}