#! /usr/bin/python2.4

import sys
import random
import Image
import math
import psyco

from numpy import *

from neuralNets import Sample
from neuralNets import Slab
from neuralNets import Sandwich
from craigMath import CraigMath

from pic import Pic

# Shared CraigMath helper instance.  NOTE(review): it is not referenced by
# any code visible in this file; its name also matches the stdlib `cmath`
# module, so an `import cmath` in this module would collide with it.
cmath = CraigMath()

# Debug flag (swap the commented line to enable).  NOTE(review): DEBUG is
# not read by any code visible in this file.
#DEBUG = True
DEBUG = False

#f=open("C:\\prog\\image\\images\\test.log", 'w')


def convertNeuStrToInt(input):
    """Translate a transfer-function name into its integer neuron-type code.

    "linear" -> 0, "logsig" -> 1, "hardlim" -> 2, "tansig" -> 3.
    Anything else yields -1, which parseLayers reports as an invalid type.
    """
    # The position of each name in this list is its code.
    for code, name in enumerate(["linear", "logsig", "hardlim", "tansig"]):
        if input == name:
            return code
    return -1

def parseLayers(layArgs):
    """Parse flat (neuron-count, neuron-type) tokens into layer descriptions.

    layArgs is a flat sequence such as ["10", "logsig", "1", "linear"]: each
    even-indexed item is the neuron count of one layer and the item after it
    names that layer's transfer function (see convertNeuStrToInt).  An empty
    or None layArgs falls back to ["2", "logsig", "1", "linear"].

    Returns [numbers, types]: numbers is the list of per-layer neuron counts
    (ints) and types the matching list of integer neuron-type codes.

    On malformed input an error message is printed and the process exits
    with status 1 via sys.exit (this is a command-line style helper).
    """
    if not layArgs:
        print("***using default layers***")
        layArgs = ["2", "logsig", "1", "linear"]
    if len(layArgs) % 2 != 0:
        # Every layer needs both a count and a type; an unpaired trailing
        # token used to crash with IndexError on layArgs[i + 1].
        print("ERROR - LAYER ARGUMENTS MUST COME IN (COUNT, TYPE) PAIRS")
        sys.exit(1)
    numbers = []
    types = []
    for i in range(0, len(layArgs), 2):
        try:
            neuNo = int(layArgs[i])
        except (TypeError, ValueError):
            print("ERROR - INVALID NEURON COUNT: %r" % (layArgs[i],))
            sys.exit(1)
        # Reject negative counts too, not just zero, as the message promises.
        if neuNo < 1:
            print("ERROR - MUST:  >= 1 NEURON PER LAYER")
            sys.exit(1)
        neuType = convertNeuStrToInt(layArgs[i + 1])
        if neuType == -1:
            print("ERROR - INVALID ARGUMENT FOR LAYER TYPE")
            sys.exit(1)
        numbers.append(neuNo)
        types.append(neuType)
    return [numbers, types]


def getRandomShift(dim):
    """Return a dim[0] x dim[1] matrix of values drawn from random.random().

    Values are drawn in row-major order, so row 0 receives the first dim[1]
    draws, row 1 the next dim[1], and so on.  Both dimensions are expected
    to be non-negative ints (e.g. a 2-D weight matrix's .shape).

    NOTE(review): this file does `import random` and later
    `from numpy import *`; NumPy versions whose __all__ exports the `random`
    submodule rebind `random` to numpy.random, so confirm which generator is
    actually being used here.
    """
    rows, cols = dim[0], dim[1]
    draws = [random.random() for _ in range(rows * cols)]
    # reshape() instead of the in-place resize(): a matrix built from a list
    # is a view over a temporary array and does not own its buffer, and
    # ndarray.resize refuses to resize arrays that do not own their data.
    return mat(draws).reshape(rows, cols)

def makeSandwich(inputDim, layers):
    """Build a Sandwich network and randomly shift each slab's weights.

    inputDim -- number of inputs fed to the first slab.
    layers   -- [neuronCounts, neuronTypes], as returned by parseLayers.

    The Sandwich receives the size list [inputDim] + neuronCounts together
    with neuronTypes; every slab then has shiftWeights() applied with a
    getRandomShift matrix shaped like that slab's weight matrix `w`.
    """
    counts, kinds = layers[0], layers[1]
    sandwich = Sandwich([inputDim] + list(counts), kinds)
    for slab in sandwich.slabs:
        noise = getRandomShift(slab.w.shape)
        slab.shiftWeights(noise)
    return sandwich


def getTheSamples(theImage):
    """Create one training Sample per pixel of theImage, in shuffled order.

    Each Sample pairs the coordinate list [x, y] with theImage.getPixel(x, y).
    Pixels are visited x-outer / y-inner, and the resulting list is shuffled
    in place before it is returned.
    """
    samples = [Sample([x, y], theImage.getPixel(x, y))
               for x in range(theImage.width)
               for y in range(theImage.height)]
    random.shuffle(samples)
    return samples


def runEpochs(theSandwich, theSamples, startEpoch, stopEpoch, skipSave=4):
    """Train theSandwich for epochs startEpoch .. stopEpoch - 1.

    For every sample, the network output for sample.i.T is compared with
    sample.d.  The difference (epsilon) is accumulated into the squared-error
    and absolute-error totals and passed to computeSensitivities().  Then
    updateRProp() is called.  cookRProp() runs once at the end of each epoch.
    The mean squared error and the mean absolute error are printed per epoch.

    skipSave -- a snapshot is saved through savePici(epoch // skipSave)
                whenever epoch % skipSave == 0; 0 or None disables snapshots.

    Returns theSandwich, which is trained in place.
    """
    count = len(theSamples)
    for epoch in range(startEpoch, stopEpoch):
        mse = 0
        er = 0
        for sample in theSamples:
            actual = theSandwich.computeOutput(sample.i.T)
            epsilon = sample.d - actual
            mse += epsilon * epsilon
            er += abs(epsilon)
            theSandwich.computeSensitivities(epsilon)
            theSandwich.updateRProp()
        theSandwich.cookRProp()
        # An empty sample list used to raise ZeroDivisionError here; the
        # reported errors simply stay 0 in that case.
        if count:
            mse /= count
            er /= count
        print("iteration %s, mse:  %f, mean error:  %f" % (epoch, mse, er))
        if skipSave and epoch % skipSave == 0:
            # showPic() can be called here instead to display progress.
            # Floor division keeps the snapshot index an int.
            savePici(epoch // skipSave)
    return theSandwich


def see(theImage, theSand):
    """Render theSand's output over theImage's pixel grid as a new image.

    Returns a new greyscale ("L" mode) PIL image with theImage's width and
    height, in which pixel (x, y) is set to the value
    theSand.computeOutput(mat([x, y]).T) returns for that coordinate.
    """
    width, height = theImage.width, theImage.height
    # The background fill was 256, which is outside the 0-255 range of an
    # 8-bit "L" image, so PIL clipped it to 255.  Use 255 (white) explicitly.
    # Every pixel is overwritten below in any case.
    rendered = Image.new("L", (width, height), 255)
    for x in range(width):
        for y in range(height):
            rendered.putpixel((x, y), theSand.computeOutput(mat([x, y]).T))
    return rendered


def setupPic(fileName, imageDir="C:\\prog\\image\\images\\"):
    """Load the bitmap <imageDir><fileName>.bmp as a Pic and return it.

    imageDir defaults to the original hard-coded Windows image folder.  It is
    concatenated as-is, so it must end with a path separator.
    """
    theImage = Pic(imageDir + fileName + ".bmp")
    print("got pic")
    return theImage

def setupSamples(theImage):
    """Return the shuffled per-pixel training samples for theImage."""
    return getTheSamples(theImage)


def setupSand(layargs):
    """Build a Sandwich with 2 inputs (pixel x and y) from layer tokens.

    layargs is the flat (count, type, count, type, ...) token list accepted
    by parseLayers.  makeSandwich randomly shifts the initial weights.
    """
    return makeSandwich(2, parseLayers(layargs))


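# The helpers below (runMore, showPic, savePic, savePici) read the
# module-level globals sand, samp, pic and picName, which the main loop at
# the bottom of this file rebinds before each training turn.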
def runMore(startR, stopR):
    """Train the global `sand` on the global `samp` for epochs startR .. stopR - 1."""
    runEpochs(theSandwich=sand, theSamples=samp,
              startEpoch=startR, stopEpoch=stopR)

def showPic():
    """Display the global `sand`'s rendering of the global `pic`."""
    preview = see(pic, sand)
    preview.show()

def savePic(imageDir="C:\\prog\\image\\images\\"):
    """Save the global `sand`'s rendering of the global `pic` to disk.

    The file is written to <imageDir><picName>2.bmp, where picName is the
    module-level global.  imageDir defaults to the original hard-coded
    Windows folder and is concatenated as-is, so it must end with a path
    separator.
    """
    see(pic, sand).save(imageDir + picName + "2.bmp")

def savePici(i, imageDir="C:\\prog\\image\\images\\method1\\"):
    """Save numbered snapshot i of the current rendering (global pic/sand).

    The file is <imageDir><picName>\\<picName><i as 4 zero-padded digits>.bmp,
    e.g. ...\\method1\\circle\\circle0003.bmp.  The <picName> sub-folder is
    not created here, so it must already exist.  imageDir defaults to the
    original hard-coded Windows folder and must end with a path separator.
    """
    # Parenthesised so the %-formatting visibly applies only to the suffix.
    a = imageDir + picName + "\\" + picName + ("%04d.bmp" % i)
    see(pic, sand).save(a)

def setPicName(name):
    """Set the module-level picName global that savePic/savePici read.

    Without the `global` declaration the assignment only created a local
    variable and the function had no effect.
    """
    global picName
    picName = name

def setupTest(picLabel, weightToRun, layArgs=None):
    """Load one training image and build everything needed to train on it.

    picLabel    -- image base name, loaded through setupPic.
    weightToRun -- number of epochs to run each time this test gets a turn
                   in the main loop.
    layArgs     -- optional flat (count, type, ...) layer tokens passed to
                   setupSand.  The default is two 10-neuron logsig layers
                   followed by a 1-neuron linear output layer.

    Returns the list [picName, pic, samples, sandwich, weightToRun,
    epochsDone], with epochsDone starting at 0.  It is deliberately a mutable
    list because the main loop advances element 5 in place.
    """
    if layArgs is None:
        layArgs = ["10", "logsig", "10", "logsig", "1", "linear"]
    thePic = setupPic(picLabel)
    theSamples = setupSamples(thePic)
    theSand = setupSand(layArgs)
    return [picLabel, thePic, theSamples, theSand, weightToRun, 0]

# Each entry is the list built by setupTest:
# [picName, pic, samples, sandwich, epochsPerTurn, epochsDone].
# The main loop below cycles through these entries; uncomment entries to
# add more images to the rotation.
tests = [
    #setupTest("target", 120),
    #setupTest("stripes", 120),
    #setupTest("cross", 120),
    #setupTest("diag", 60),
    #setupTest("lrgrad", 120),
    setupTest("circle", 60),
    #setupTest("checker", 120),
    #setupTest("a", 60),
    #setupTest("squiggle", 60),
    #setupTest("dots", 60),
]

psyco.full()

# Round-robin over `tests`.  Each turn trains one test's network for its
# epochsPerTurn epochs, continuing from where its previous turn stopped.
# picName, pic, samp and sand are deliberately module-level globals,
# because runMore, showPic, savePic and savePici read them.
for i in range(100000):
    test = tests[i % len(tests)]
    picName, pic, samp, sand, epochsPerTurn, epochsDone = test
    # Shuffles the test's own sample list in place (samp aliases test[2]).
    random.shuffle(samp)
    print("--------%s--------" % picName)
    runMore(epochsDone, epochsDone + epochsPerTurn)
    test[5] = epochsDone + epochsPerTurn

