#!/usr/bin/env python3

import sys

import numpy as np

class ann:
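    # A fully connected three-layer perceptron: one hidden tanh layer and a
    # tanh output layer, trained with plain stochastic gradient descent.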
    
    def __init__(self, nof_input, nof_hidden, nof_output):

        # Learning rate
        self.learning_rate = 0.1

        # Number of neurons in each layer
        self.nof_input  = nof_input   
        self.nof_hidden = nof_hidden  
        self.nof_output = nof_output
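
        # Layout convention: activations are stored as column vectors; each
        # row of a weight matrix holds one neuron's incoming weights, with the
        # bias weight in the last column.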

        # Outputs of neurons
        self.input_value = np.zeros((self.nof_input + 1, 1), dtype=float)    # + bias neuron
        self.hidden_value = np.zeros((self.nof_hidden + 1, 1), dtype=float)  # + bias neuron
        self.output_value = np.zeros((self.nof_output, 1), dtype=float)

        # Deltas for hidden and output layers
        self.hidden_delta = np.zeros((self.nof_hidden, 1), dtype=float)
        self.output_delta = np.zeros((self.nof_output, 1), dtype=float)

        # Activation level of neurons (weighted sum of their inputs)
        self.hidden_activation = np.zeros((self.nof_hidden, 1), dtype=float)
        self.output_activation = np.zeros((self.nof_output, 1), dtype=float)

        # Weights of the two layers, drawn uniformly from [-1, 1) so training
        # starts from a mix of positive and negative weights
        self.inner_weights = np.random.uniform(-1, 1, (self.nof_hidden, self.nof_input + 1))
        self.outer_weights = np.random.uniform(-1, 1, (self.nof_output, self.nof_hidden + 1))

    def feed_forward(self, input_list):
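        # One full pass: h = tanh(W_inner @ x), y = tanh(W_outer @ h), where x
        # and h each carry a constant bias entry of 1.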

        # Input layer
        self.input_value[:-1, 0] = input_list

        # Set bias neuron in input layer to 1.0
        self.input_value[-1, 0] = 1.0

        # Hidden layer: weighted sum of inputs, then tanh activation
        self.hidden_activation = np.dot(self.inner_weights, self.input_value)
        self.hidden_value[:-1, :] = np.tanh(self.hidden_activation)

        # Set bias neuron in hidden layer
        self.hidden_value[-1, :] = 1.0

        # Output layer
        self.output_activation = np.dot(self.outer_weights, self.hidden_value)
        self.output_value = np.tanh(self.output_activation)

    def back_propagation(self, correct):

        # Error between the network output and the target, as a column vector
        error = self.output_value - np.reshape(correct, (-1, 1))

        # Delta of the output neurons; the activation is tanh, whose
        # derivative is 1 - tanh(x)^2 = 1 - output^2
        self.output_delta = (1 - self.output_value ** 2) * error

        # Delta of the hidden neurons: propagate the output deltas back
        # through the outer weights (without the bias column) and apply the
        # tanh derivative of the hidden activations
        self.hidden_delta = (1 - np.tanh(self.hidden_activation) ** 2) * np.dot(
            self.outer_weights[:, :-1].transpose(), self.output_delta)

        # Apply weight changes
        self.inner_weights -= self.learning_rate * np.dot(self.hidden_delta, self.input_value.transpose())
        self.outer_weights -= self.learning_rate * np.dot(self.output_delta, self.hidden_value.transpose())
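        # Sanity check of the update rule: for squared error E = 0.5*(y - t)^2
        # with y = tanh(w @ v), dE/dw = (y - t) * (1 - y^2) * v = delta * v,
        # which is exactly the outer product applied above.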

    def print_ann(self):
        print()
        print(self.inner_weights)
        print(self.outer_weights)
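
# A minimal convenience sketch, not part of the original script: train a fresh
# 2-2-1 network on the full XOR truth table for a fixed number of epochs and
# return it, rather than looping forever like xor_ann() below. The function
# name and the epoch count are illustrative assumptions.
def train_xor(epochs=20000):
    net = ann(2, 2, 1)
    xor_in = [[0, 0], [0, 1], [1, 0], [1, 1]]
    xor_out = [[0], [1], [1], [0]]
    for _ in range(epochs):
        for x, t in zip(xor_in, xor_out):
            net.feed_forward(x)
            net.back_propagation(t)
    return net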

def xor_ann():
    xor_in = [[0, 0], [0, 1], [1, 0], [1, 1]]
    xor_out = [[0], [1], [1], [0]]
    net = ann(2, 2, 1)

    cnt = 0
    try:
        while True:
            # Pick a random training sample and do one learning step
            sample = np.random.randint(0, 4)

            net.feed_forward(xor_in[sample])
            net.back_propagation(xor_out[sample])

            result = net.output_value[0, 0]

            print(cnt, xor_in[sample], result, end=' ')
            if result > 0.8:
                print('\033[92mTRUE\033[0m')
            elif 0 <= result < 0.2:
                print('\033[93mFALSE\033[0m')
            else:
                print('\033[91mNO ANSWER\033[0m')
            net.print_ann()
            cnt += 1

    except (KeyboardInterrupt, SystemExit):
        net.print_ann()


def and_ann():
    and_in = [[0, 0], [0, 1], [1, 0], [1, 1]]
    and_out = [[0], [0], [0], [1]]  # targets shown for reference; this demo only runs the forward pass
    net = ann(2, 1, 1)

    # Hand-picked weights instead of trained ones; stored as numpy arrays so
    # feed_forward can treat them like the random weights
    net.inner_weights = np.array([[1.0, 1.0, -1.5]])
    net.outer_weights = np.array([[1.0, -0.5]])
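    # Why these weights implement AND: the hidden neuron computes
    # tanh(x1 + x2 - 1.5), which is positive only when both inputs are 1
    # (1 + 1 - 1.5 = 0.5) and negative otherwise.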

    print(net.inner_weights)
    print(net.input_value)

    cnt = 0
    try:
        while True:
            sample = np.random.randint(0, 4)

            net.feed_forward(and_in[sample])

            result = net.hidden_value[0, 0]

            print(cnt, and_in[sample], result, end=' ')
            if result > 0.1:
                print('\033[92mTRUE\033[0m')
            elif result < -0.1:
                print('\033[93mFALSE\033[0m')
            else:
                print('\033[91mNO ANSWER\033[0m')
            cnt += 1

    except (KeyboardInterrupt, SystemExit):
        net.print_ann()


def or_ann():
    or_in = [[0, 0], [0, 1], [1, 0], [1, 1]]
    or_out = [[0], [1], [1], [1]]  # targets shown for reference; this demo only runs the forward pass
    net = ann(2, 1, 1)

    # Hand-picked weights instead of trained ones
    net.inner_weights = np.array([[0.5, 0.5, -0.3]])
    net.outer_weights = np.array([[1.0, -0.5]])
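    # Why these weights implement OR: the hidden neuron computes
    # tanh(0.5*x1 + 0.5*x2 - 0.3), which is positive as soon as at least one
    # input is 1 (0.5 - 0.3 = 0.2) and negative for [0, 0].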

    print(net.inner_weights)
    print(net.input_value)

    cnt = 0
    try:
        while True:
            sample = np.random.randint(0, 4)

            net.feed_forward(or_in[sample])

            result = net.hidden_value[0, 0]

            print(cnt, or_in[sample], result, end=' ')
            if result > 0.1:
                print('\033[92mTRUE\033[0m')
            elif result < -0.1:
                print('\033[93mFALSE\033[0m')
            else:
                print('\033[91mNO ANSWER\033[0m')
            cnt += 1

    except (KeyboardInterrupt, SystemExit):
        net.print_ann()

if __name__ == '__main__':
    if len(sys.argv) > 1:
        if sys.argv[1] == 'xor':
            xor_ann()
        elif sys.argv[1] == 'and':
            and_ann()
        elif sys.argv[1] == 'or':
            or_ann()
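    else:
        # Brief usage hint when no mode is given on the command line
        print('usage: %s xor|and|or' % sys.argv[0])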