from scipy import apply_along_axis,newaxis,zeros,\
     median,round_,nonzero,dot,argmax,any,sqrt,ndarray,\
     trace,zeros_like,sign,sort,real,argsort,rand,array,\
     matrix,nan
from scipy.linalg import norm,svd,inv,eig
from numpy import median

def normalise(a, axis=0, return_scales=False):
    """Scale `a` so that every 1-d slice along `axis` has unit norm.

        a -- array, data (the reshaping below is written for 2-d input)
        axis -- dim, the axis the norms are computed along
        return_scales -- bool, also return the norms used for scaling

    output:
            a_scaled, [scales]

    The scales keep the reduced axis as a length-one dimension so that
    they broadcast against `a`.  Slices with zero norm are not treated
    specially, the division is carried out as is.
    """
    scales = apply_along_axis(norm, axis, a)

    # Put back the axis that was collapsed by the reduction.
    scales = scales[newaxis] if axis == 0 else scales[:, newaxis]

    scaled = a / scales
    return (scaled, scales) if return_scales else scaled

def sub2ind(shape, i, j):
    """Flat indices from subscripts. Only support for 2d.

        shape -- (rows, cols) of the 2d array being indexed
        i -- sequence of row subscripts
        j -- sequence of column subscripts

    output:
            list with one flat index, i*cols + j, for every combination
            of an entry of `i` with an entry of `j`; the entries of `j`
            vary fastest.

    Raises an error from the unpacking if `shape` does not hold exactly
    two values.
    """
    # Only the column count enters the formula; unpacking both values
    # still rejects shapes that are not 2d.
    _, col = shape
    return [r * col + c for r in i for c in j]


def sorted_eig(a, b=None, sort_by='sm'):
    """
    Just eig with the real part of the output, sorted.
    This is for convenience only, not general!

    a, b -- passed straight on to eig(a, b)
    sort_by='sm': smallest eigenvalue first (default)
            'lm': largest eigenvalue first
    Any other value of sort_by behaves like 'sm'.

    The ordering is on the real part of the eigenvalues as signed
    numbers (plain argsort), not on their absolute value.  Imaginary
    parts of both eigenvalues and eigenvectors are dropped.

    output: just as eig with 2 outputs
            -- s,v (eigvals,eigenvectors)
    (This is reversed output compared to matlab)

    """
    eigvals, eigvecs = eig(a, b)

    # dont expect any imaginary part
    eigvals, eigvecs = real(eigvals), real(eigvecs)

    order = argsort(eigvals)
    if sort_by == 'lm':
        order = order[::-1]

    # Eigenvectors are the columns of eigvecs, so reorder along axis 1.
    return eigvals.take(order), eigvecs.take(order, 1)

def str2num(string_number):
    """Convert input (string number) into a number.

        string_number -- value handed to float()

    output:
            float(string_number) when the conversion works, nan when it
            fails and the input is one of the missing-value markers
            '', 'nan', 'NaN' or 'NA'.

    Any other input that float() rejects with a TypeError or ValueError
    is printed and the error is re-raised.
    """
    missings = ('', 'nan', 'NaN', 'NA')
    try:
        return float(string_number)
    except (TypeError, ValueError):
        # Only conversion failures are handled here, so that things like
        # KeyboardInterrupt are never mistaken for a bad entry.
        if string_number in missings:
            return nan
        # Wrapping the entry in a tuple keeps the formatting working
        # when the entry is itself a tuple.
        print("Found strange entry: %s" % (string_number,))
        raise

def randperm(n):
    """Return a random permutation of the integers 0..n-1.

        n -- int, number of entries

    output:
            1-d integer array (dtype code 'i') of length n
    """
    # The ordering that sorts n random draws is itself a random
    # permutation of their positions.  argsort yields a valid permutation
    # even if two draws happen to be equal.
    return argsort(rand(n)).astype('i')

def mat_center(X, axis=0, ret_mn=False):
    """Mean center matrix along axis.

        X -- matrix, data (expected to be two-dimensional)
        axis -- dim, 0 or 1, the axis the mean is taken over
        ret_mn -- bool, return mean

    output:
            Xc, [mnX]

    axis=0 (default) subtracts from every column its own mean,
    axis=1 subtracts from every row its own mean.

    mnX is X.mean(axis) with a new leading axis, so for 2-d ndarray
    input it is a single row for both values of axis.

    Raises ValueError if axis is neither 0 nor 1.
    """
    if axis not in (0, 1):
        raise ValueError("axis must be 0 or 1, got %r" % (axis,))

    if len(X.shape) != 2:
        # Only a notice: the computation below is still attempted, as
        # callers may rely on this best-effort behaviour.
        print("The X data needs to be two-dimensional")

    mnX = X.mean(axis)[newaxis]
    if axis == 0:
        Xs = X - mnX
    else:
        # The means form a row, so subtract them from the transpose and
        # flip the result back.
        Xs = (X.T - mnX).T

    if ret_mn:
        return Xs, mnX
    return Xs

def m_shape(array):
    """Return the shape `array` has once converted to a numpy.matrix.

        array -- anything the matrix constructor accepts

    output:
            the .shape of matrix(array)
    """
    as_matrix = matrix(array)
    return as_matrix.shape