from numpy import *

# Flag recording whether the compiled power-iteration backend was imported.
HAS_SYMPOWER = True
try:
    from _numpy_module import sym_powerit
except ImportError:
    # Catch only a failed import: a bare ``except`` would also relabel
    # KeyboardInterrupt/SystemExit and unrelated errors as "module not
    # present".  The flag is updated before re-raising so that it is never
    # left claiming the backend is available.
    HAS_SYMPOWER = False
    raise ImportError("Sym_powerit module not present")

class SymPowerException(Exception):
    """Error raised when input to the symmetric power-iteration solver is invalid."""

# Messages keyed by the integer status code looked up after the solver call;
# code 0 maps to the empty string.
_ERRORCODES = {
    1: ("Some eigenvectors did not converge, try to increase \n"
        "the number of iterations or lower the tolerance level"),
    0: "",
}


def sympowerit(xx, T0=None, mn_center=False, a_max=10, tol=1e-7, maxiter=100,
               verbose=0):
    """Estimate eigenvectors of a symmetric matrix using the power method.

    *Parameters*:

        xx : {array}
            Symmetric square array (m, m)
        T0 : {array}
            Initial solution (m, a_max), optional.  When omitted, an
            all-zero array whose first row is set to one is used.
        mn_center : {boolean}, optional
            Mean centering
        a_max : {integer}, optional
            Number of components to extract (a falsy value falls back to 10)
        tol : {float}, optional
            Tolerance level of eigenvector solver
        maxiter : {integer}
            Maximum number of poweriterations to use
        verbose : {integer}
            Debug output (==1)

    *Returns*:
        v : {array}
            Eigenvectors of xx, (m , a_max)

    *Raises*:
        SymPowerException
            If xx does not have a single/double precision real or complex
            dtype, is not square, fails the symmetry spot check, or if T0
            does not have shape (m, a_max).

    *Notes*:
        Symmetry is only spot-checked on a handful of randomly chosen
        entries, so an asymmetric array may pass validation.  A non-zero
        status from the solver is reported only when ``verbose`` is set;
        the result is returned either way.
    """
    valid_types = ('D', 'd', 'F', 'f')
    # Named ``typecode`` so that numpy's star-imported ``dtype`` is not shadowed.
    typecode = xx.dtype.char
    n, m = xx.shape
    if typecode not in valid_types:
        msg = "Array type: (%s) needs to be a float or double" % typecode
        raise SymPowerException(msg)
    if m != n:
        # Report the shape in the same (rows, cols) order as ``xx.shape``.
        msg = "Input array needs to be square, input: (%d,%d)" % (n, m)
        raise SymPowerException(msg)

    # Small test of symmetry: draw N random indices and compare each sampled
    # entry in the first drawn column against its transposed counterpart.
    N = 5
    num = random.randint(0, n, N)
    pivot = num[0]
    for k in num:
        if abs(xx[k, pivot] - xx[pivot, k]) > 1e-15:
            msg = "Array needs to be symmetric"
            raise SymPowerException(msg)

    if not a_max:
        a_max = 10

    # Identity test: ``T0 != None`` on an ndarray is an elementwise comparison
    # whose truth value is ambiguous.
    if T0 is not None:
        tn, tm = T0.shape
        if tn != n:
            msg = "Start eigenvectors need to match input array ()"
            raise SymPowerException(msg)
        if tm != a_max:
            msg = "Start eigenvectors need to match input a_max ()"
            raise SymPowerException(msg)
    else:
        T0 = zeros((n, a_max), 'd')
        T0[0, :] = ones((a_max,), 'd')

    if mn_center:
        xx = _center(xx)

    # call c-function
    T, info = sym_powerit(xx, T0, n, a_max, tol, maxiter, verbose)

    if info != 0 and verbose:
        # Single-argument call form works as both the Python 2 print
        # statement and the Python 3 print function.
        print(_ERRORCODES.get(info, "Dont know this error"))
    return T


def _center(xx, ret_mn=False):
    """Return a mean centered copy of a symmetric kernel matrix.

    The centering term is built from the column sums of ``xx``: entry (i, j)
    of the term is ``(s_i + s_j - mean(s)) / n`` where ``s`` holds the column
    sums and ``n`` is the number of rows.

    *Parameters*:

        xx : {array}
            Symmetric kernel matrix
        ret_mn : {boolean}, optional
            Also return the subtracted centering term

    *Returns*:
        xxc : {array}
            ``xx`` minus the centering term
        mn_a : {array}
            The centering term, only when ``ret_mn`` is true
    """
    size = xx.shape[0]
    col_sums = xx.sum(0)[:, newaxis]
    # Half of the mean goes into each of the two added terms below, so the
    # full mean is subtracted exactly once from their sum.
    offsets = (col_sums - mean(col_sums) / 2) / size
    mn_a = offsets + offsets.T
    xxc = xx - mn_a
    return (xxc, mn_a) if ret_mn else xxc