Skip to content
Snippets Groups Projects
cs_chol.c 2.81 KiB
Newer Older
  • Learn to ignore specific revisions
  • Lubomir Riha's avatar
    Lubomir Riha committed
    #include "cs.h"
    /* L = chol (A, [pinv parent cp]), pinv is optional */
    csn *cs_chol (const cs *A, const css *S)
    {
        CS_ENTRY d, lki, *Lx, *x, *Cx ;
        CS_INT top, i, p, k, n, *Li, *Lp, *cp, *pinv, *s, *c, *parent, *Cp, *Ci ;
        cs *L, *C, *E ;
        csn *N ;
        if (!CS_CSC (A) || !S || !S->cp || !S->parent) return (NULL) ;
        n = A->n ;
        N = cs_calloc (1, sizeof (csn)) ;       /* allocate result */
        c = cs_malloc (2*n, sizeof (CS_INT)) ;     /* get CS_INT workspace */
        x = cs_malloc (n, sizeof (CS_ENTRY)) ;    /* get CS_ENTRY workspace */
        cp = S->cp ; pinv = S->pinv ; parent = S->parent ;
        C = pinv ? cs_symperm (A, pinv, 1) : ((cs *) A) ;
        E = pinv ? C : NULL ;           /* E is alias for A, or a copy E=A(p,p) */
        if (!N || !c || !x || !C) return (cs_ndone (N, E, c, x, 0)) ;
        s = c + n ;
        Cp = C->p ; Ci = C->i ; Cx = C->x ;
        N->L = L = cs_spalloc (n, n, cp [n], 1, 0) ;    /* allocate result */
        if (!L) return (cs_ndone (N, E, c, x, 0)) ;
        Lp = L->p ; Li = L->i ; Lx = L->x ;
        for (k = 0 ; k < n ; k++) Lp [k] = c [k] = cp [k] ;
        for (k = 0 ; k < n ; k++)       /* compute L(k,:) for L*L' = C */
        {
            /* --- Nonzero pattern of L(k,:) ------------------------------------ */
            top = cs_ereach (C, k, parent, s, c) ;      /* find pattern of L(k,:) */
            x [k] = 0 ;                                 /* x (0:k) is now zero */
            for (p = Cp [k] ; p < Cp [k+1] ; p++)       /* x = full(triu(C(:,k))) */
            {
                if (Ci [p] <= k) x [Ci [p]] = Cx [p] ;
            }
            d = x [k] ;                     /* d = C(k,k) */
            x [k] = 0 ;                     /* clear x for k+1st iteration */
            /* --- Triangular solve --------------------------------------------- */
            for ( ; top < n ; top++)    /* solve L(0:k-1,0:k-1) * x = C(:,k) */
            {
                i = s [top] ;               /* s [top..n-1] is pattern of L(k,:) */
                lki = x [i] / Lx [Lp [i]] ; /* L(k,i) = x (i) / L(i,i) */
                x [i] = 0 ;                 /* clear x for k+1st iteration */
                for (p = Lp [i] + 1 ; p < c [i] ; p++)
                {
                    x [Li [p]] -= Lx [p] * lki ;
                }
                d -= lki * CS_CONJ (lki) ;            /* d = d - L(k,i)*L(k,i) */
                p = c [i]++ ;
                Li [p] = k ;                /* store L(k,i) in column i */
                Lx [p] = CS_CONJ (lki) ;
            }
            /* --- Compute L(k,k) ----------------------------------------------- */
            if (CS_REAL (d) <= 0 || CS_IMAG (d) != 0)
    	    return (cs_ndone (N, E, c, x, 0)) ; /* not pos def */
            p = c [k]++ ;
            Li [p] = k ;                /* store L(k,k) = sqrt (d) in column k */
            Lx [p] = sqrt (d) ;
        }
        Lp [n] = cp [n] ;               /* finalize L */
        return (cs_ndone (N, E, c, x, 1)) ; /* success: free E,s,x; return N */
    }