// Cell.cpp // David R. Morrison #include "cell.hpp" // Constructor Cell::Cell(int input_vector_size) { // Add one to account for bias inputVectorSize = input_vector_size; weights = getRandomWeights(inputVectorSize + 1); inGateDerivs.resize(inputVectorSize + 1); forgetGateDerivs.resize(inputVectorSize + 1); cellDerivs.resize(inputVectorSize + 1); clear(); } double Cell::evaluateSample(vector inps, double inputGate, double forgetGate, double outputGate) { // Store the input for later use inputs = inps; // Bias input = weights[0]; for (int i = 1; i < inputVectorSize + 1; ++i) input += weights[i] * inputs[i - 1]; // If the cell needs to be cleared, do that now // Store the (squashed) information in the CEC (Constant Error Carousel), moderated by the // input gate. If the value from the input gate is low (close to 0), then the value in the // cell doesn't get updated. On the other hand, if it's high (close to 1), the CEC will // change. cec = f(forgetGate) * cec + g(input) * f(inputGate); // Similarly, we only output the (squashed) value in the CEC if the output gate tells us to. output = h(cec) * f(outputGate); return output; } void Cell::clear() { input = 0.0; output = 0.0; cec = 0.0; inputs.clear(); for (int i = 0; i < inputVectorSize + 1; ++i) { inGateDerivs[i] = 0.0; forgetGateDerivs[i] = 0.0; cellDerivs[i] = 0.0; } } // Calculate cell derivatives to use in training; see LSTM paper for details void Cell::calcDerivs(double inputGate, double forgetGate) { inGateDerivs[0] = inGateDerivs[0] * f(forgetGate) + g(input) * fPrime(inputGate); forgetGateDerivs[0] = forgetGateDerivs[0] * f(forgetGate) + cec * fPrime(forgetGate); cellDerivs[0] = cellDerivs[0] * f(forgetGate) + gPrime(input) * fPrime(inputGate); for (int i = 1; i < inputVectorSize + 1; ++i) { inGateDerivs[i] = inGateDerivs[i] * f(forgetGate) + g(input) * fPrime(inputGate) * inputs[i - 1]; forgetGateDerivs[i] = forgetGateDerivs[i] * f(forgetGate) + cec * fPrime(forgetGate) * inputs[i - 1]; cellDerivs[i] = cellDerivs[i] * f(forgetGate) + gPrime(input) * f(inputGate) * inputs[i - 1]; } } // Update the weight at index i by delta void Cell::changeWeight(unsigned int i, double delta) { assert(i < inputVectorSize + 1); weights[i] += delta; } // Return the value of the derivative of weight i to the input gate double Cell::getInGateDeriv(unsigned int i) { assert(i < inputVectorSize + 1); return inGateDerivs[i]; } // Return the value of the derivative of weight i to the forget gate double Cell::getForgetGateDeriv(unsigned int i) { assert(i < inputVectorSize + 1); return forgetGateDerivs[i]; } // Return the value of the derivative of weight i to the cell double Cell::getCellDeriv(unsigned int i) { assert(i < inputVectorSize + 1); return cellDerivs[i]; } // Get the internal state of the cell double Cell::getCECValue() { return cec; } // Get the value of the ith weight to the cell double Cell::getWeight(unsigned int i) { assert(i < inputVectorSize + 1); return weights[i]; } ostream& Cell::print(ostream& out) { out.precision(4); out << "[ "; for (int i = 0; i < weights.size(); ++i) out << weights[i] << " "; out << "] CEC: [ " << cec << " ]" << endl; } ostream& operator<<(ostream& out, Cell c) { return c.print(out); }