LevenbergMarquardtTraining
org.encog.neural.networks.training.lma

Class LevenbergMarquardtTraining

  • All Implemented Interfaces:
    MLTrain, MultiThreadable


    public class LevenbergMarquardtTraining extends BasicTraining implements MultiThreadable
    Trains a neural network using the Levenberg-Marquardt algorithm (LMA). This training technique is based on the mathematical technique of the same name. The LMA interpolates between the Gauss-Newton algorithm (GNA) and the method of gradient descent (similar to what is used by backpropagation). The lambda parameter determines the degree to which GNA and gradient descent are used. A lower lambda results in heavier use of GNA, whereas a higher lambda results in heavier use of gradient descent.

    Each iteration starts with a low lambda, which is increased if the resulting improvement to the neural network is not satisfactory. At some point lambda is high enough that the training method reverts entirely to gradient descent. This allows the neural network to be trained effectively in cases where GNA provides the optimal training time, while retaining the ability to fall back to the more primitive gradient descent method.

    LMA finds only a local minimum, not a global minimum.

    References:
      http://www.heatonresearch.com/wiki/LMA
      http://en.wikipedia.org/wiki/Levenberg%E2%80%93Marquardt_algorithm
      http://en.wikipedia.org/wiki/Finite_difference_method
      http://crsouza.blogspot.com/2009/11/neural-network-learning-by-levenberg_18.html
      http://mathworld.wolfram.com/FiniteDifference.html
      http://www-alg.ist.hokudai.ac.jp/~jan/alpha.pdf
      http://www.inference.phy.cam.ac.uk/mackay/Bayes_FAQ.html
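
The lambda schedule described in the class summary can be illustrated with a small, self-contained sketch. This is not the Encog implementation: it fits a single parameter b of the model y = exp(b * x) rather than the weights of a network, and the class name LmaSketch, the constants SCALE and MAX_LAMBDA, their values, and the starting lambda of 0.1 are all assumptions made for illustration. With one parameter the approximate Hessian J^T J is a scalar, so the damped step reduces to g / (h + lambda): close to a Gauss-Newton step when lambda is small, and close to a short gradient descent step when lambda is large.

```java
// Illustrative sketch only; not taken from the Encog source.
final class LmaSketch {
    // Hypothetical values chosen for this sketch, not read from Encog.
    static final double SCALE = 10.0;
    static final double MAX_LAMBDA = 1e25;

    /** Sum of squared residuals for the model y = exp(b * x). */
    static double sse(double b, double[] x, double[] y) {
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            double r = y[i] - Math.exp(b * x[i]);
            sum += r * r;
        }
        return sum;
    }

    /** Fits b with a damped Gauss-Newton loop in the style of LMA. */
    static double fit(double[] x, double[] y, double b, int maxIterations) {
        double lambda = 0.1; // assumed starting value
        double error = sse(b, x, y);
        for (int iter = 0; iter < maxIterations && error > 1e-20; iter++) {
            double g = 0.0; // gradient term J^T r
            double h = 0.0; // Gauss-Newton Hessian approximation J^T J
            for (int i = 0; i < x.length; i++) {
                double m = Math.exp(b * x[i]);
                double j = x[i] * m; // derivative of the model w.r.t. b
                g += j * (y[i] - m);
                h += j * j;
            }
            // Raise lambda until a step lowers the error or lambda hits its cap.
            boolean improved = false;
            while (!improved && lambda <= MAX_LAMBDA) {
                double candidate = b + g / (h + lambda);
                double candidateError = sse(candidate, x, y);
                if (candidateError < error) {
                    b = candidate;
                    error = candidateError;
                    lambda /= SCALE; // trust Gauss-Newton more next time
                    improved = true;
                } else {
                    lambda *= SCALE; // lean toward gradient descent
                }
            }
            if (!improved) {
                break; // no damping level helped; stop at this local minimum
            }
        }
        return b;
    }

    public static void main(String[] args) {
        double[] x = {0.0, 0.5, 1.0, 1.5, 2.0};
        double[] y = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            y[i] = Math.exp(0.7 * x[i]); // noise-free data generated with b = 0.7
        }
        System.out.println("fitted b = " + fit(x, y, 0.0, 100));
    }
}
```

Starting from b = 0, the first undamped step overshoots badly, so the sketch raises lambda until a step is accepted and then lowers it again as the fit approaches the minimum. A network trainer works the same way in principle, but solves a linear system over all weights instead of a scalar division.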
    • Field Detail

      • SCALE_LAMBDA

        public static final double SCALE_LAMBDA
        The amount to scale the lambda by.
        See Also:
        Constant Field Values
      • LAMBDA_MAX

        public static final double LAMBDA_MAX
        The maximum value allowed for lambda.
        See Also:
        Constant Field Values
    • Constructor Detail

      • LevenbergMarquardtTraining

        public LevenbergMarquardtTraining(BasicNetwork network, MLDataSet training)
        Construct the LMA object.
        Parameters:
        network - The network to train. Must have a single output neuron.
        training - The training data to use. Must be indexable.
      • LevenbergMarquardtTraining

        public LevenbergMarquardtTraining(BasicNetwork network, MLDataSet training, ComputeHessian h)
        Construct the LMA object.
        Parameters:
        network - The network to train. Must have a single output neuron.
        training - The training data to use. Must be indexable.
        h - The Hessian calculation method to use.
    • Method Detail

      • canContinue

        public boolean canContinue()
        Specified by:
        canContinue in interface MLTrain
        Returns:
        True if the training can be paused, and later continued.
      • getMethod

        public MLMethod getMethod()
        Description copied from interface: MLTrain
        Get the current best machine learning method from the training.
        Specified by:
        getMethod in interface MLTrain
        Returns:
        The trained network.
      • iteration

        public void iteration()
        Perform one iteration.
        Specified by:
        iteration in interface MLTrain
      • resume

        public void resume(TrainingContinuation state)
        Resume training.
        Specified by:
        resume in interface MLTrain
        Parameters:
        state - The training continuation object to use to continue.
      • updateWeights

        public void updateWeights()
        Update the weights in the neural network.
      • getHessian

        public ComputeHessian getHessian()
        Returns:
        The Hessian calculation method used.
      • getThreadCount

        public int getThreadCount()
        Specified by:
        getThreadCount in interface MultiThreadable
        Returns:
        The number of threads to use, 0 to automatically determine based on core count.
      • setThreadCount

        public void setThreadCount(int numThreads)
        Description copied from interface: MultiThreadable
        Set the number of threads to use.
        Specified by:
        setThreadCount in interface MultiThreadable
        Parameters:
        numThreads - The number of threads to use, or zero to automatically determine based on core count.
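
The "zero means automatic" convention for the thread count can be sketched as follows. This is a minimal illustration, not Encog's logic: the class ThreadCountSketch and its resolve method are hypothetical, and the sketch assumes that "automatic" simply means the processor count reported by the JVM. The library may weigh other factors, such as the size of the training set.

```java
// Illustrative sketch only; names are hypothetical and not part of Encog.
final class ThreadCountSketch {
    /**
     * Maps a requested thread count to the number of threads to start.
     * Zero is treated as "decide automatically from the core count".
     */
    static int resolve(int requested) {
        if (requested < 0) {
            throw new IllegalArgumentException("Thread count must not be negative: " + requested);
        }
        if (requested == 0) {
            // Assumption: automatic selection uses the JVM's processor count.
            return Runtime.getRuntime().availableProcessors();
        }
        return requested;
    }

    public static void main(String[] args) {
        System.out.println("auto: " + resolve(0));
        System.out.println("explicit: " + resolve(4));
    }
}
```

An explicit positive value passes through unchanged, which is what a caller would want when capping CPU usage on a shared machine.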

SCaVis 1.8 © jWork.org