diff --git a/src/tests/NeuralNetwork_test.cpp b/src/tests/NeuralNetwork_test.cpp index a5b4350f097d23dd5273c5a842dd30f0e571c317..7f5f672f086803c394c484ba10d3d7ae610c5296 100644 --- a/src/tests/NeuralNetwork_test.cpp +++ b/src/tests/NeuralNetwork_test.cpp @@ -82,4 +82,135 @@ BOOST_AUTO_TEST_SUITE(NeuralNetwork_test) BOOST_CHECK_THROW(network.add_connection_general(0, 2, &f, para, w_array, 5), std::out_of_range); } + BOOST_AUTO_TEST_CASE(NeuralNetwork_get_subnet_test) { + Neuron *n1 = new NeuronLinear(1, 1); + Neuron *n2 = new NeuronLinear(2, 2); + Neuron *n3 = new NeuronLinear(3, 3); + Neuron *n4 = new NeuronLinear(4, 4); + + NeuralNetwork network; + network.add_neuron(n1); + network.add_neuron(n2); + network.add_neuron(n3); + network.add_neuron(n4); + + network.add_connection_simple(0, 1, 0, 2.5); + network.add_connection_simple(0, 3, 0, 2.5); + network.add_connection_simple(2, 1, 0, 2.5); + network.add_connection_simple(2, 3, 0, 2.5); + + std::vector<size_t> input_neuron_indices(1); + input_neuron_indices.push_back(0); + + std::vector<size_t> output_neuron_indices(1); + output_neuron_indices.push_back(1); + + NeuralNetwork *network2 = network.get_subnet(input_neuron_indices, output_neuron_indices); + + BOOST_CHECK_EQUAL(2, network2->add_neuron(n1)); + } + + BOOST_AUTO_TEST_CASE(NeuralNetwork_specify_input_neurons_test) { + Neuron *n1 = new NeuronLinear(1, 1); + Neuron *n2 = new NeuronLinear(2, 2); + NeuralNetwork network; + network.add_neuron(n1); + network.add_neuron(n2); + + network.add_connection_simple(0, 1, 0, 2.5); + + BOOST_CHECK_EQUAL(0, network.get_n_inputs()); + std::vector<size_t> input_neuron_indices(1); + input_neuron_indices[0] = (size_t) 0; + + network.specify_input_neurons(input_neuron_indices); + + BOOST_CHECK_EQUAL(1, network.get_n_inputs()); + } + + BOOST_AUTO_TEST_CASE(NeuralNetwork_specify_output_neurons_test) { + Neuron *n1 = new NeuronLinear(1, 1); + Neuron *n2 = new NeuronLinear(2, 2); + NeuralNetwork network; + network.add_neuron(n1); + network.add_neuron(n2); + + network.add_connection_simple(0, 1, 0, 2.5); + + BOOST_CHECK_EQUAL(0, network.get_n_outputs()); + std::vector<size_t> output_neuron_indices(1); + output_neuron_indices[0] = (size_t) 1; + + network.specify_output_neurons(output_neuron_indices); + + BOOST_CHECK_EQUAL(1, network.get_n_outputs()); + } + + BOOST_AUTO_TEST_CASE(NeuralNetwork_get_weights_test) { + Neuron *n1 = new NeuronLinear(1, 1); + Neuron *n2 = new NeuronLinear(2, 2); + NeuralNetwork network; + network.add_neuron(n1); + network.add_neuron(n2); + + network.add_connection_simple(0, 1, 0, 2.5); + + BOOST_CHECK_EQUAL(1, network.get_n_weights()); + } + + BOOST_AUTO_TEST_CASE(NeuralNetwork_eval_single_test) { + Neuron *n1 = new NeuronLinear(1, 1); + Neuron *n2 = new NeuronLinear(2, 2); + NeuralNetwork network; + network.add_neuron(n1); + network.add_neuron(n2); + + network.add_connection_simple(0, 1, -1, 2.5); + + std::vector<size_t> output_neuron_indices(1); + output_neuron_indices[0] = (size_t) 1; + network.specify_output_neurons(output_neuron_indices); + + std::vector<size_t> input_neuron_indices(1); + input_neuron_indices[0] = (size_t) 0; + network.specify_input_neurons(input_neuron_indices); + + std::vector<double> input; + input.push_back(1); + std::vector<double> output; + output.push_back(1); + + network.eval_single(input, output); + BOOST_CHECK_EQUAL(12, output.at(0)); + } + + + BOOST_AUTO_TEST_CASE(NeuralNetwork_eval_single_weights_test) { + Neuron *n1 = new NeuronLinear(1, 1); + Neuron *n2 = new NeuronLinear(2, 2); + NeuralNetwork network; + network.add_neuron(n1); + network.add_neuron(n2); + + network.add_connection_simple(0, 1, -1, 2.5); + + std::vector<size_t> output_neuron_indices(1); + output_neuron_indices[0] = (size_t) 1; + network.specify_output_neurons(output_neuron_indices); + + std::vector<size_t> input_neuron_indices(1); + input_neuron_indices[0] = (size_t) 0; + network.specify_input_neurons(input_neuron_indices); + + std::vector<double> input; + input.push_back(1); + std::vector<double> output; + output.push_back(1); + + double weights = 5; + network.get_n_weights(); + network.eval_single(input, output, &weights); + BOOST_CHECK_EQUAL(22, output.at(0)); + } + BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file