import sys
import random

from numpy import *

from craigMath import CraigMath

# Module-wide CraigMath instance. Slab.f and Slab.fPrime call its logsig,
# hardlim and tansig methods and the matching *Prime methods.
# NOTE(review): this name is the same as the standard-library `cmath` module;
# importing that module in this file would clash with it.
cmath = CraigMath()

class Sample:
	"""A single training example: an input pattern and the wanted output.

	Both values are converted with ``mat`` and kept as matrices:
	  i -- the inputs
	  d -- the desired output
	"""

	def __init__(self, theInputs, desired):
		wanted = mat(desired)
		given = mat(theInputs)
		self.d = wanted
		self.i = given

	def __str__(self):
		"""Describe the example as 'On Input <inputs>, I want <desired>'."""
		return "On Input %s, I want %s" % (self.i, self.d)

class Slab:
	"""One layer of neurons that all share the same transfer function.

	Every attribute below is a numpy matrix:
	  x -- last input handed to computeU; created as (numInputs + 1) x 1,
	       i.e. one row more than numInputs
	  w -- weights, numNeurons x (numInputs + 1)
	  u -- last net input, w * x
	  y -- last output, the transfer function applied to u
	  s -- sensitivities; assigned from outside (Sandwich does this)
	  deltaWStore -- weight changes accumulated for batch training
	  dEdW, oldDEdW, updateValues -- state for the RProp-style update

	``type`` is the transfer function code:
	0 = linear, 1 = logsig, 2 = hardlim, 3 = tansig.
	"""

	def __init__(self, numInputs, numNeurons, funcType):
		"""Create a slab with all-zero weights.

		numInputs  -- number of inputs, not counting the extra column
		numNeurons -- number of neurons (rows of the weight matrix)
		funcType   -- transfer function code, see the class docstring
		"""
		weightShape = [numNeurons, numInputs + 1]

		# Forward-pass and back-propagation state.
		self.x = mat(zeros([numInputs + 1, 1]))
		self.w = mat(zeros(weightShape))
		self.u = mat(zeros([numNeurons, 1]))
		self.s = mat(zeros([numNeurons, 1]))
		self.y = mat(zeros([numNeurons, 1]))

		# RProp state: per-weight step sizes start at 1.
		self.updateValues = mat(ones(weightShape))
		self.oldDEdW = mat(zeros(weightShape))
		self.dEdW = mat(zeros(weightShape))

		# Accumulator used by updateWeightBatch / cookBatch.
		self.deltaWStore = mat(zeros(weightShape))

		self.type = funcType

	def __str__(self):
		"""Return the slab's type, its input and neuron counts, and w."""
		header = "  %s slab with %s inputs and %s neuron(s). \n" % (
			self.convertTypeIntToStr(self.type),
			self.w.shape[1] - 1,
			self.w.shape[0])
		return header + "w:  %s\n" % self.w

	def convertTypeIntToStr(self, input):
		"""Return the display name for a type code, or -1 if it is unknown."""
		if input == 0:
			return "Linear"
		if input == 1:
			return "Logsig"
		if input == 2:
			return "Hardlim"
		if input == 3:
			return "Tansig"
		return -1

	def f(self, input):
		"""Apply this slab's transfer function to one value.

		Raises ValueError when self.type is not a known code, rather than
		silently returning None.
		"""
		if self.type == 0:
			# linear: identity
			return input
		if self.type == 1:
			return cmath.logsig(input)
		if self.type == 2:
			return cmath.hardlim(input)
		if self.type == 3:
			return cmath.tansig(input)
		raise ValueError("Unknown transfer function type: %r" % (self.type,))

	def fPrime(self, input):
		"""Return the transfer function's derivative at one value.

		Raises ValueError when self.type is not a known code, rather than
		silently returning None.
		"""
		if self.type == 0:
			# linear: constant slope
			return 1
		if self.type == 1:
			return cmath.logsigPrime(input)
		if self.type == 2:
			return cmath.hardlimPrime(input)
		if self.type == 3:
			return cmath.tansigPrime(input)
		raise ValueError("Unknown transfer function type: %r" % (self.type,))

	# matrix versions
	def fM(self, input):
		"""Apply f to every element of a 2-D matrix.

		The matrix is modified IN PLACE and the same object is returned;
		pass a copy if the original values are still needed.
		"""
		rows, cols = input.shape[0], input.shape[1]
		for i in range(rows):
			for j in range(cols):
				input[i, j] = self.f(input[i, j])
		return input

	def fPrimeM(self, input):
		"""Apply fPrime to every element of a 2-D matrix.

		The matrix is modified IN PLACE and the same object is returned;
		pass a copy if the original values are still needed.
		"""
		rows, cols = input.shape[0], input.shape[1]
		for i in range(rows):
			for j in range(cols):
				input[i, j] = self.fPrime(input[i, j])
		return input

	def computeU(self, inputs):
		"""Store inputs as x, compute u = w * inputs and return a copy of u."""
		self.x = inputs
		self.u = self.w * inputs
		return self.u.copy()

	def computeY(self, inputs):
		"""Run the slab forward on inputs, store the result as y, return a copy."""
		# computeU already returns a fresh copy, so fM's in-place update
		# leaves self.u untouched.
		self.y = self.fM(self.computeU(inputs))
		return self.y.copy()

	def shiftWeights(self, shifts):
		"""Add shifts to the weight matrix.

		If the shapes differ, an error is printed and the process exits via
		sys.exit with status 1, so the failure is visible to the caller of
		the program.
		"""
		if self.w.shape != shifts.shape:
			# Single-argument print(...) prints the same text on Python 2 and 3.
			print("ERROR:  DIMENSIONS FOR SHIFT DO NOT MATCH")
			print("w: %s, shift: %s" % (self.w.shape, shifts.shape))
			sys.exit(1)
		self.w += shifts

	def shiftWeightsRandomly(self, ra):
		"""Add an independent random offset in [-ra/2, ra/2) to every weight.

		NOTE(review): this file runs `from numpy import *` after
		`import random`, so `random` here may be numpy.random rather than
		the standard-library module -- confirm before relying on seeding.
		"""
		rows, cols = self.w.shape[0], self.w.shape[1]
		draws = [random.random() * ra - ra / 2.0 for _ in range(rows * cols)]
		self.shiftWeights(mat(draws).reshape(rows, cols))

	def updateWeight(self, eta):
		"""Apply the change -eta * s * x.T to the weights immediately."""
		delta_w = -eta * self.s * self.x.T
		self.shiftWeights(delta_w)

	def updateWeightBatch(self, eta):
		"""Add -eta * s * x.T to the batch accumulator; weights are unchanged."""
		delta_w = -eta * self.s * self.x.T
		self.deltaWStore += delta_w

	def cookBatch(self):
		"""Apply the accumulated batch change to the weights and clear it."""
		self.shiftWeights(self.deltaWStore)
		self.deltaWStore = mat(zeros([self.deltaWStore.shape[0],
			self.deltaWStore.shape[1]]))

	def updateRProp(self):
		"""Subtract s * x.T from the running total kept in dEdW."""
		self.dEdW = self.dEdW - self.s * self.x.T

	def cookRProp(self):
		"""Take one sign-based step and adapt the per-weight step sizes.

		Each weight moves by its updateValues entry in the direction of the
		sign of the accumulated dEdW. Afterwards each step size is scaled by
		1.2 where the sign matches the previous call's, by 0.5 where it
		flipped and by 1.0 where the product of the two signs is zero. No
		lower or upper bound is applied to the step sizes.
		"""
		# update weights
		step = multiply(sign(self.dEdW), self.updateValues)
		self.shiftWeights(step)
		direction = sign(step)

		# update the update values
		theSigns = sign(multiply(direction, self.oldDEdW))
		scale = 1.2 * (1 == theSigns) + \
			1.0 * (0 == theSigns) + \
			0.5 * (-1 == theSigns)
		self.updateValues = multiply(self.updateValues, scale)

		# iteration stuff: remember this direction, restart the accumulator
		self.oldDEdW = direction.copy()
		self.dEdW = direction * 0
		

class Sandwich:
	"""This class serves as a model for a series of neuron slabs"""

	def __init__(self, numOfPuts, theTypes):
		"""Build one Slab per entry of theTypes.

		Slab i is created with numOfPuts[i] inputs, numOfPuts[i + 1] neurons
		and transfer function code theTypes[i], so numOfPuts needs one more
		entry than theTypes.
		"""
		self.slabs = []
		for index, funcType in enumerate(theTypes):
			self.slabs.append(
				Slab(numOfPuts[index], numOfPuts[index + 1], funcType))

	def __str__(self):
		"""Return a header line followed by the description of every slab."""
		pieces = ["Sandwich with %d slabs:  \n" % (len(self.slabs), )]
		for index, slab in enumerate(self.slabs):
			pieces.append(" Slab %d: %s" % (index, slab))
		pieces.append("\n")
		return "".join(pieces)

	def computeOutput(self, inputs):
		"""Feed inputs through every slab in order and return the last output.

		Before each slab a row of ones is appended below the current
		activations; its width is the column count of the original inputs.
		"""
		lastLayer = inputs
		for slab in self.slabs:
			biasRow = mat(ones([1, inputs.shape[1]]))
			lastLayer = slab.computeY(concatenate((lastLayer, biasRow)))
		return lastLayer

	def computeSensitivities(self, outputs):
		"""Set s on every slab, working from the last slab back to the first.

		The last slab gets s = (-2 * outputs) .* f'(u); the earlier slabs
		are then filled in by computeSensitivity.
		NOTE(review): the -2 factor suggests `outputs` is the output error
		under a squared-error criterion -- confirm against the callers.
		"""
		lastSlab = self.slabs[-1]
		arg1 = -2 * outputs
		# Element-wise derivative on a copy: fPrimeM works in place, and the
		# scalar fPrime is not meant to receive a whole matrix.
		arg2 = lastSlab.fPrimeM(lastSlab.u.copy())
		lastSlab.s = multiply(arg1, arg2)
		for i in range(len(self.slabs) - 2, -1, -1):
			self.computeSensitivity(self.slabs[i], self.slabs[i + 1])

	def computeSensitivity(self, slab_a, slab_b):
		"""Set slab_a.s from the sensitivities of the following slab_b.

		slab_a.s = f'(slab_a.u) .* (slab_b.w.T * slab_b.s) with the last row
		of the product dropped, so the shapes match.
		"""
		# fPrimeM modifies its argument in place, so hand it a copy to keep
		# slab_a.u intact.
		arg1 = slab_a.fPrimeM(slab_a.u.copy())
		arg2 = slab_b.w.T * slab_b.s
		slab_a.s = multiply(arg1, arg2[0:-1])

	def updateWeights(self, eta):
		"""Call updateWeight(eta) on every slab."""
		for slab in self.slabs:
			slab.updateWeight(eta)

	def updateWeightsBatch(self, eta):
		"""Call updateWeightBatch(eta) on every slab."""
		for slab in self.slabs:
			slab.updateWeightBatch(eta)

	def cookBatch(self):
		"""Call cookBatch() on every slab."""
		for slab in self.slabs:
			slab.cookBatch()

	def updateRProp(self):
		"""Call updateRProp() on every slab."""
		for slab in self.slabs:
			slab.updateRProp()

	def cookRProp(self):
		"""Call cookRProp() on every slab."""
		for slab in self.slabs:
			slab.cookRProp()

