import pylab
import matplotlib
import networkx as nx
import scipy

def plot_corrloads(R, pc1=0, pc2=1, s=20, c='b', zorder=5, expvar=None,
                   ax=None, drawback=True, labels=None):
    """Draw a correlation loading plot for two components of ``R``.

    Column ``pc1`` of ``R`` is scattered against column ``pc2`` on axes
    fixed to [-1, 1] in both directions.

    Parameters
    ----------
    R : 2-D array
        Indexed as ``R[:, pc1]`` / ``R[:, pc2]``; one row per plotted point.
    pc1, pc2 : int
        Zero-based column indices used for the x and y axis.
    s, c, zorder :
        Passed straight through to ``ax.scatter``.
    expvar : sequence, optional
        Indexed with ``pc1`` and ``pc2``; the values are printed in the
        axis labels.  No axis labels are set when omitted.
    ax : matplotlib axes, optional
        Axes to draw on.  Defaults to the current pylab axes.
    drawback : bool
        Draw the background (two gray discs of radius 1 and 0.5 plus the
        zero lines).  The background is always drawn when ``ax`` is not
        given, matching the historical behaviour of this function.
    labels : sequence of str, optional
        One text label per row of ``R``.

    Raises
    ------
    ValueError
        If ``labels`` does not have one entry per row of ``R``.
    """
    # Decide on the background before ``ax`` gets its default, so that a
    # caller-supplied axes object is never swapped for pylab.gca().
    draw_background = ax is None or bool(drawback)
    if ax is None:
        ax = pylab.gca()

    if draw_background:
        radius = 1
        center = (0, 0)
        c100 = matplotlib.patches.Circle(center, radius=radius,
                                         facecolor='gray',
                                         alpha=.1,
                                         zorder=1)
        c50 = matplotlib.patches.Circle(center, radius=radius / 2.0,
                                        facecolor='gray',
                                        alpha=.1,
                                        zorder=2)
        ax.add_patch(c100)
        ax.add_patch(c50)
        ax.axhline(lw=1.5, color='k')
        ax.axvline(lw=1.5, color='k')

    # corrloads
    ax.scatter(R[:, pc1], R[:, pc2], s=s, c=c, zorder=zorder)
    ax.set_xlim([-1, 1])
    ax.set_ylim([-1, 1])

    # ``is not None`` rather than ``!= None``: the latter compares
    # elementwise on arrays and cannot be used as an ``if`` condition.
    if expvar is not None:
        # Label ``ax`` itself; the pylab state machine would target the
        # *current* axes, which need not be the one passed in.
        ax.set_xlabel("Comp: %d   expl.var:  %.1f " % (pc1 + 1, expvar[pc1]))
        ax.set_ylabel("Comp: %d   expl.var.:  %.1f " % (pc2 + 1, expvar[pc2]))

    if labels is not None:
        if len(labels) != R.shape[0]:
            raise ValueError("expected %d labels (one per row of R), got %d"
                             % (R.shape[0], len(labels)))
        for name, r in zip(labels, R):
            ax.text(r[pc1], r[pc2], "  " + str(name))

def plot_dag(edge_dict, node_color='b', node_size=30, labels=None,
             nodelist=None, pos=None):
    """Draw the directed graph described by ``edge_dict`` with networkx.

    Parameters
    ----------
    edge_dict : dict
        Maps each node name to a list of neighbour names (the format
        accepted by ``nx.from_dict_of_lists``).
    node_color, node_size, nodelist :
        Passed through to ``nx.draw_networkx``.
    labels : optional
        Only used as a flag: node labels are drawn when this is not None.
        Its content is not forwarded to networkx.
    pos : dict, optional
        Node positions keyed by the node names used in ``edge_dict``.
        When omitted, a ``dot`` layout is computed through pydot.

    Raises
    ------
    ValueError
        If ``node_color`` is a sequence of several colours whose length
        does not match the number of nodes being drawn.
    """
    G = nx.from_dict_of_lists(edge_dict, nx.DiGraph(name='GO'))

    if pos is None:
        # networkx does not play well with colon in node names, so the
        # layout is computed on a copy of the graph with ':' replaced.
        clean_edges = {}
        for head, neigb in edge_dict.items():
            clean_edges[head.replace(":", "_")] = [
                i.replace(":", "_") for i in neigb]
        layout_graph = nx.from_dict_of_lists(clean_edges,
                                             nx.DiGraph(name='GO'))
        # NOTE(review): newer networkx releases appear to expose this as
        # nx.nx_pydot.pydot_layout -- confirm against the installed version.
        clean_pos = nx.pydot_layout(layout_graph, prog='dot')
        # Re-key the layout by the original names: G (the graph actually
        # drawn) still carries the colons.
        pos = {}
        for node in G.nodes():
            pos[node] = clean_pos[node.replace(":", "_")]

    # A colour name such as 'red' has a length too, but it is one colour
    # for all nodes, not one entry per node.
    if not isinstance(node_color, str):
        try:
            n_colors = len(node_color)
        except TypeError:
            n_colors = 1  # scalar colour value
        if n_colors > 1:
            # With nodelist=None networkx draws every node of G.
            if nodelist is None:
                n_nodes = G.number_of_nodes()
            else:
                n_nodes = len(nodelist)
            if n_colors != n_nodes:
                raise ValueError("got %d node colours for %d nodes"
                                 % (n_colors, n_nodes))

    with_labels = labels is not None

    nx.draw_networkx(G, pos, with_labels=with_labels, node_size=node_size,
                     node_color=node_color, nodelist=nodelist)
    

def plot_ZXcorr(gene_ids, term_ids, gene2go, X, D, scale=True):
    """ Plot correlation/covariance between genes as a function of
    semantic difference.

    NOTE: this function is unfinished.  It walks every gene pair and the
    terms of the first gene, but nothing is plotted or returned yet.

    input: gene_ids  sequence of gene identifiers (keys of gene2go)
           term_ids  sequence of term identifiers
           gene2go   mapping gene id -> container of term ids
           X (n, p) data matrix
           D (p, p) gene-gene semantic similarity matrix (currently unused)
           scale     currently unused
    """
    # scipy.corrcoef was only a re-export of the NumPy function and is not
    # available in newer SciPy releases; call NumPy directly.
    import numpy

    # Kept in its own name so the caller-supplied D is not shadowed.
    # NOTE(review): corrcoef treats the *rows* of X as variables, while the
    # docstring describes X as (n, p) -- confirm the intended orientation.
    corr = numpy.corrcoef(X)

    for i, gene_i in enumerate(gene_ids):
        for j, gene_j in enumerate(gene_ids):
            if j >= i:
                # Only the strict lower triangle (j < i) is visited.
                break
            r2 = corr[i, j]
            terms_i = gene2go[gene_i]
            terms_j = gene2go[gene_j]
            for term in term_ids:
                if term in terms_i:
                    # TODO: relate r2 to the shared terms; not implemented.
                    pass


def clustering_index(T, Yg):
    """Placeholder: not implemented yet.

    Both arguments are accepted but ignored, and ``None`` is returned.
    """
    return None