#!/usr/bin/python

# The robust least-squares example of section 9.1.  

from math import sqrt, ceil, floor
from cvxopt import base, solvers, blas, lapack
from cvxopt.base import matrix, spmatrix, spdiag, sqrt, mul, div
import pylab

def robls(A, b, rho):
    """Robust least-squares fit via cvxopt's convex solver.

    Minimizes the smooth, convex penalty

        sum_k sqrt(rho + (A*x - b)_k**2)

    over x, and returns the minimizer reported by ``solvers.cp``.

    A   -- coefficient matrix (its ``size`` is unpacked as (rows, cols)).
    b   -- right-hand side, compatible with ``A*x``.
    rho -- smoothing parameter added under the square root.
    """
    _, cols = A.size

    def oracle(x=None, z=None):
        # Called with no arguments: report 0 nonlinear constraints and a
        # starting point (the zero vector of length cols).
        if x is None:
            return 0, matrix(0.0, (cols, 1))

        resid = A*x - b
        # Elementwise sqrt (cvxopt's sqrt, not math.sqrt).
        root = sqrt(rho + resid**2)
        objective = sum(root)
        # Gradient row vector: sum_k (r_k / w_k) * A[k, :].
        gradient = div(resid, root).T * A

        if z is None:
            return objective, gradient

        # Second derivative of sqrt(rho + y**2) is rho / (rho + y**2)**1.5,
        # i.e. rho / w**3; scale by the objective multiplier z[0].
        weights = spdiag(z[0] * rho * (root**-3))
        return objective, gradient, A.T * weights * A

    result = solvers.cp(oracle)
    return result['x']


# Random test problem with m = 500 rows and n = 100 unknowns.  The seed is
# fixed so the generated data (and therefore the plots) are reproducible.
base.setseed()
m, n = 500, 100
A = base.normal(m, n)
b = base.normal(m, 1)

# Robust least-squares fit with rho = 0.1.
xh = robls(A, b, 0.1)

# Compare residual histograms of the ordinary and robust fits; the
# try/except skips plotting if pylab cannot be imported here.
try:
    import pylab
except ImportError:
    pass
else:

    # Least-squares solution.  gels overwrites its arguments, so it is
    # given copies (+A, +b); the solution lands in the first n entries.
    pylab.subplot(211)
    xls = +b
    lapack.gels(+A, xls)
    rls = A*xls[:n] - b
    # Integer bin count: m/5 would be a float under Python 3.
    pylab.hist(rls, m // 5)
    pylab.title('Least-squares solution')
    pylab.xlabel('Residual')
    # Symmetric x-range sized by the largest |residual| so that neither
    # tail of the histogram is clipped.
    mr = ceil(max(abs(r) for r in rls))
    pylab.axis([-mr, mr, 0, 25])

    # Robust least-squares solution with rho = 0.1.
    pylab.subplot(212)
    rh = A*xh - b
    pylab.hist(rh, m // 5)
    mr = ceil(max(abs(r) for r in rh))
    pylab.title('Robust least-squares solution')
    pylab.xlabel('Residual')
    pylab.axis([-mr, mr, 0, 50])

    pylab.show()
