// lstm.cpp // // The main file for the LSTM network with forget gates. // David R. Morrison #include #include #include #include #include #include "network.hpp" using namespace std; void initNetwork(int inputSize, int outputSize); int getSamples(int, int, vector< vector< vector > >&, vector< vector< vector > >&, int trace); double randomFloat(); Network* nw; int main(int argc, char* args[]) { double eta = 0.1; // Learning rate. double target_err = 0.0001; // Goal for mean squared error int epochs = 500; // Training time in generations int trace = 0; // Process arguments; the switch is allowed to fall through to lower // cases on purpose. Argument order is as follows: // switch (argc) { case 5: trace = atoi(args[4]); case 4: epochs = atoi(args[3]); case 3: target_err = atof(args[2]); case 2: eta = atof(args[1]); break; } // Dimension of input and output vectors; int input_num = 0; int output_num = 0; cin >> output_num; cin >> input_num; // Input/output vectors. vector< vector< vector > > train_outputs; vector< vector< vector > > train_inputs; vector< vector< vector > > test_outputs; vector< vector< vector > > test_inputs; if (getSamples(input_num, output_num, train_inputs, train_outputs, trace) == -1) return -1; if (getSamples(input_num, output_num, test_inputs, test_outputs, trace) == -1) return -1; cout << test_inputs.size() << endl; initNetwork(input_num, output_num); int i = 0; while(true) { cout << "Epoch " << i << " - "; double mse = 0; mse = nw->train(train_inputs, train_outputs, eta, trace); cout << "Training Mean Square Error: " << mse << endl; if (mse <= target_err) { cout << endl << "Network achieved classification in " << i + 1 << " epoch"; (i + 1 == 1) ? cout << "." << endl : cout << "s." << endl; break; } cout << endl; ++i; } double mse = nw->test(test_inputs, test_outputs); cout << "Testing Mean Square Error: " << mse << endl << endl; delete nw; } void initNetwork(int inputSize, int outputSize) { vector blocks; vector blockSize; ifstream properties(".lstmrc"); string layerInfo; while (getline(properties, layerInfo)) { int breakpoint = layerInfo.find(" "); if (breakpoint == string::npos && layerInfo.size() > 0) { cerr << "Malformed network descriptor: " << layerInfo << ". Ignoring." << endl; continue; } else if (breakpoint == string::npos) continue; blocks.push_back(atoi(layerInfo.substr(0, breakpoint).c_str())); blockSize.push_back(atoi(layerInfo.substr(breakpoint + 1, layerInfo.size()).c_str())); } nw = new Network(inputSize, outputSize, blocks, blockSize); cout << *nw << endl; } int getSamples(int inputNum, int outputNum, vector< vector< vector > >& inputs, vector< vector< vector > >& outputs, int trace) { vector< vector > sample_in; vector< vector > sample_out; vector timeStep_in(inputNum); vector timeStep_out(outputNum); int sampleNum = 0; int timeStep = 0; string line = ""; cout << "Reading Samples..." << endl << endl; // Read in all data from the file while (getline(cin, line)) { if (line.compare("") == 0) continue; if (line.compare("#end") == 0) { inputs.push_back(sample_in); outputs.push_back(sample_out); sample_in.clear(); sample_out.clear(); ++sampleNum; timeStep = 0; continue; } if (line.compare("#dataend") == 0) break; for (int i = 0; i < inputNum; ++i) timeStep_in[i] = line[i] - 48; for (int i = 0; i < outputNum; ++i) timeStep_out[i] = line[inputNum + i + 1] - 48; sample_in.push_back(timeStep_in); sample_out.push_back(timeStep_out); if (trace) { cout << "Sample In: "; for (int i = 0; i < timeStep_in.size(); ++i) cout << timeStep_in[i]; cout << endl; cout << "TimeStep: "; for (int i = 0; i < timeStep_out.size(); ++i) cout << timeStep_out[i]; cout << endl; } ++timeStep; } return 0; } double randomFloat() { long divisor = LONG_MAX; double candidate; do { candidate = (double)random() / (double)divisor; } while (candidate >= 1); return candidate; }