#include <bla.hpp>

namespace ngbla
{
  using namespace ngstd;
  using namespace ngbla;

/*
   See J. Stoer, Einführung in die Numerische Mathematik, p. 146.
*/


  /// Positivity check for a general entry type (e.g. a block matrix):
  /// deliberately does nothing. Only the scalar `double` overload
  /// performs an actual check.
  template <class TM>
  void CheckPos (const TM & /* m */)
  { }

  /// Positivity check for a scalar diagonal entry.
  /// Prints the offending value and throws an Exception when m <= 0;
  /// any other value (including NaN, for which m <= 0 is false) passes.
  void CheckPos (const double & m)
  {
    const bool non_positive = (m <= 0);
    if (!non_positive)
      return;

    cout << "diag is " << m << endl;
    throw Exception ("diag is <= 0");
  }



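  // Compute the decomposition A = L D L^T, i.e.
  //   A_{ij} = \sum_k L_{ik} D_{kk} L_{jk}^T
  // L .. unit lower-left factor (its diagonal is not stored); column-wise
  //      storage of L^T, i.e. the band entries of each row of L are contiguous
  // D .. (block) diagonal, stored in mem[0..n-1]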

  /*
    Factor: compute the banded decomposition A = L D L^T of the symmetric
    band matrix a, writing D into mem[0..n-1] and the strict lower-left
    band of L into the remaining storage, accessed via Index / operator().

    Rows are processed left to right. For row i and every j in [i, i+bw):
      x = a(j,i) - \sum_k L(j,k) D_k L(i,k)^T   (k inside the band, k < i)
    j == i yields the diagonal block D_i = x.
    j >  i yields L(j,i) = x D_i^{-1}.
    While rows are being processed, mem[k] holds D_k. After the last row,
    every mem[i] is replaced by its inverse D_i^{-1} (the value Mult
    multiplies with in its diagonal step).

    a .. symmetric band matrix; only entries a(j,i) with j >= i are read.
         NOTE(review): assumed to have the same n and bw as *this; this is
         not checked here.
  */
  template <class T>
  void FlatBandCholeskyFactors<T> :: 
  Factor (const FlatSymBandMatrix<T> & a)
  {
    for (int i = 0; i < n; i++)
      {
        const int maxj = min2(n, i+bw);

        // D_i^{-1}: computed once per row, right after D_i is known, instead
        // of once per off-diagonal entry (mem[i] does not change while the
        // off-diagonal entries of column i are being computed).
        T invd;

        for (int j = i; j < maxj; j++)
          {
            T x = a(j,i);

            // only columns k with j-k < bw lie inside the band of row j
            const int mink = max2(0, j-bw+1);
            int ki = Index(i, mink);
            int kj = Index(j, mink);

            // band entries of one factor row are contiguous, so walk them by
            // index; this is x -= L(j,k) * D_k * Trans(L(i,k))
            for (int k = mink; k < i; k++, ki++, kj++)
              x -= mem[kj] * mem[k] * Trans (mem[ki]);

            if (j == i)
              {
                mem[i] = x;
                // invert only if this row has off-diagonal entries to scale
                if (maxj > i+1)
                  CalcInverse (mem[i], invd);
              }
            else
              (*this)(j,i) = x * invd;
          }
      }

    // replace D by D^{-1} on the diagonal
    for (int i = 0; i < n; i++)
      {
        T invd;
        CalcInverse (mem[i], invd);
        mem[i] = invd;
      }
  }
  
  
  /*
    Mult: y = (L D L^T)^{-1} x, using the factors produced by Factor.
      1. forward substitution with L    : z_i = x_i - \sum_{j<i} L(i,j) z_j
      2. diagonal scaling               : z_i = mem[i] * z_i  (D_i^{-1} after Factor)
      3. backward substitution with L^T : y_j -= L(i,j)^T y_i for i = n-1 .. 0
    x and y are indexed 0..n-1. x is only read; y receives the result.
  */
  template <class T>  
  void FlatBandCholeskyFactors<T> :: 
  Mult (const FlatVector<TV> & x, FlatVector<TV> & y) const
  {
    // Empty system: nothing to do. This also avoids forming &y(0) for an
    // empty vector below.
    if (n == 0) return;

    // Checked element access here covers every index 0..n-1 of x and y.
    for (int i = 0; i < n; i++)
      y(i) = x(i);

    // Raw views of the solution vector and the factor storage. Both are
    // loop-invariant, so they are set up once for all three sweeps.
    TV * py = &y(0);
    const T * pm = &mem[0];

    // 1. forward substitution with L (unit diagonal)
    for (int i = 0; i < n; i++)
      {
        TV sum = py[i];

        const int firstj = max2(0,i-bw+1);
        int jj = Index (i, firstj);

        for (int j = firstj; j < i; j++, jj++)
          sum -= pm[jj] * py[j];

        py[i] = sum;
      }

    // 2. diagonal scaling; the temporary keeps y(i) off both sides of one
    //    expression
    for (int i = 0; i < n; i++)
      {
        TV sum = pm[i] * py[i];
        py[i] = sum;
      }

    // 3. backward substitution with L^T, column-oriented: once y_i is
    //    final, eliminate it from every row j < i inside the band
    for (int i = n-1; i >= 0; i--)
      {
        TV val = py[i];

        const int firstj = max2(0,i-bw+1);
        int jj = Index (i, firstj);

        for (int j = firstj; j < i; j++, jj++)
          py[j] -= Trans (pm[jj]) * val;
      }
  }
  
  /*
    Print: write the factors to ost.
    First comes the header "Diag: " and one line "i: mem[i]" per diagonal
    entry. Then, for each row i, comes a line "i: " followed by the stored
    band entries (*this)(i,j), for j from max(0, i-bw+1) up to i-1, each
    followed by two spaces.
    Returns ost so calls can be chained.
  */
  template <class T>
  ostream & FlatBandCholeskyFactors<T> :: Print (ostream & ost) const
  {
    ost << "Diag: " << endl;
    for (int row = 0; row < n; ++row)
      ost << row << ": " << mem[row] << endl;

    for (int row = 0; row < n; ++row)
      {
        ost << row << ": ";
        int col = max2(0, row-bw+1);
        while (col < row)
          {
            ost << (*this)(row, col) << "  ";
            ++col;
          }
        ost << endl;
      }
    return ost;
  }

  // Explicit instantiations: the scalar entry types double and Complex,
  // plus square block entries Mat<D,D,double> / Mat<D,D,Complex> for the
  // block sizes D = 1..8. Each block size is enabled only when
  // MAX_SYS_DIM >= D.
  template class FlatBandCholeskyFactors<double>;
  template class FlatBandCholeskyFactors<Complex>;

// helper for the block-entry instantiations; removed again below
#define NGBLA_INSTANTIATE_BAND_CHOLESKY(D) \
  template class FlatBandCholeskyFactors<Mat<D,D,double> >; \
  template class FlatBandCholeskyFactors<Mat<D,D,Complex> >;

#if MAX_SYS_DIM >= 1
  NGBLA_INSTANTIATE_BAND_CHOLESKY(1)
#endif
#if MAX_SYS_DIM >= 2
  NGBLA_INSTANTIATE_BAND_CHOLESKY(2)
#endif
#if MAX_SYS_DIM >= 3
  NGBLA_INSTANTIATE_BAND_CHOLESKY(3)
#endif
#if MAX_SYS_DIM >= 4
  NGBLA_INSTANTIATE_BAND_CHOLESKY(4)
#endif
#if MAX_SYS_DIM >= 5
  NGBLA_INSTANTIATE_BAND_CHOLESKY(5)
#endif
#if MAX_SYS_DIM >= 6
  NGBLA_INSTANTIATE_BAND_CHOLESKY(6)
#endif
#if MAX_SYS_DIM >= 7
  NGBLA_INSTANTIATE_BAND_CHOLESKY(7)
#endif
#if MAX_SYS_DIM >= 8
  NGBLA_INSTANTIATE_BAND_CHOLESKY(8)
#endif

#undef NGBLA_INSTANTIATE_BAND_CHOLESKY
}
