// // Created by fluffymoo on 11.6.18. // #include "NeuronLogistic.h" Neuron* NeuronLogistic::get_copy( ){ NeuronLogistic* output = new NeuronLogistic( this->activation_function_parameters[0], this->activation_function_parameters[1]); return output; } NeuronLogistic::NeuronLogistic(double a, double b) { this->n_activation_function_parameters = 2; this->activation_function_parameters = new double[2]; this->activation_function_parameters[0] = a; this->activation_function_parameters[1] = b; this->edges_in = new std::vector<Connection*>(0); this->edges_out = new std::vector<Connection*>(0); } void NeuronLogistic::activate( ) { double a = this->activation_function_parameters[0]; double b = this->activation_function_parameters[1]; double x = this->potential; double ex = std::pow(E, b - x); this->state = std::pow(1.0 + ex, -a); } double NeuronLogistic::activation_function_eval_partial_derivative(int param_idx ) { double a = this->activation_function_parameters[0]; double b = this->activation_function_parameters[1]; double x = this->potential; if(param_idx == 0){ //partial derivative according to parameter 'a' double ex = std::pow(E, b - x); double exa= -std::pow(ex + 1.0, -a); return exa * std::log(ex + 1.0); } else if(param_idx == 1){ //partial derivative according to parameter 'b' /** * TODO * Could be write as activation_function_get_derivative() * -1 */ double ex = std::pow(E, b - x); double ex2 = std::pow(ex + 1.0, -a - 1.0); return -a * ex * ex2; } return 0.0; } double NeuronLogistic::activation_function_eval_derivative( ) { double a = this->activation_function_parameters[0]; double b = this->activation_function_parameters[1]; double x = this->potential; double ex = std::pow(E, b - x); double ex2 = std::pow(ex + 1.0, -a - 1.0); return a * ex * ex2; } Neuron* NeuronLogistic::get_derivative() { NeuronLogistic_d1 *output = nullptr; double a = this->activation_function_parameters[0]; double b = this->activation_function_parameters[1]; output = new NeuronLogistic_d1(a, b); output->set_potential( this->potential ); return output; }