Class SGD

All Implemented Interfaces:
Serializable, Cloneable, Classifier, UpdateableClassifier, Aggregateable<SGD>, BatchPredictor, CapabilitiesHandler, CapabilitiesIgnorer, CommandlineRunnable, OptionHandler, Randomizable, RevisionHandler

Implements stochastic gradient descent for learning various linear models (binary class SVM, binary class logistic regression, squared loss, Huber loss and epsilon-insensitive loss linear regression). Globally replaces all missing values and transforms nominal attributes into binary ones. It also normalizes all attributes, so the coefficients in the output are based on the normalized data.
For numeric class attributes, the squared, Huber or epsilon-insensitive loss function must be used. Epsilon-insensitive and Huber loss may require a much higher learning rate.
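
As a rough illustration of the five losses named above, the sketch below writes them out in their common textbook forms. It is not taken from the SGD source: the class name `LossSketch`, the 0.5 scaling on the squared loss, and the use of epsilon as the quadratic-to-linear switch point for Huber loss are all assumptions made for the example.

```java
/** Textbook definitions of the five losses; not taken from the SGD source. */
final class LossSketch {
    // Classification: z = y * f(x) is the margin, with the label y coded as -1 or +1.
    static double hinge(double z) {
        return Math.max(0.0, 1.0 - z);
    }

    static double logLoss(double z) {
        return Math.log1p(Math.exp(-z));
    }

    // Regression: r = y - f(x) is the residual.
    static double squared(double r) {
        return 0.5 * r * r;
    }

    // Residuals smaller than eps cost nothing; larger ones grow linearly.
    static double epsilonInsensitive(double r, double eps) {
        return Math.max(0.0, Math.abs(r) - eps);
    }

    // Assumption: eps marks the switch from quadratic to linear growth.
    static double huber(double r, double eps) {
        double a = Math.abs(r);
        return a <= eps ? 0.5 * r * r : eps * (a - 0.5 * eps);
    }

    public static void main(String[] args) {
        System.out.println("hinge(0.5) = " + hinge(0.5));
        System.out.println("logLoss(0.0) = " + logLoss(0.0));
        System.out.println("huber(3.0, 1.0) = " + huber(3.0, 1.0));
    }
}
```

With an epsilon as small as the documented default of 1e-3, the epsilon-insensitive and Huber forms above have gradients bounded by 1 and by epsilon respectively, which is one plausible reason a higher learning rate can be needed for them.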

Valid options are:

 -F
  Set the loss function to minimize.
  0 = hinge loss (SVM), 1 = log loss (logistic regression),
  2 = squared loss (regression), 3 = epsilon insensitive loss (regression),
  4 = Huber loss (regression).
  (default = 0)
 -L
  The learning rate. If normalization is
  turned off (as it is automatically for streaming data), then the
  default learning rate will need to be reduced (try 0.0001).
  (default = 0.01).
 -R <double>
  The lambda regularization constant (default = 0.0001)
 -E <integer>
  The number of epochs to perform (batch learning only, default = 500)
 -C <double>
  The epsilon threshold (epsilon-insensitive and Huber loss only, default = 1e-3)
 -N
  Don't normalize the data
 -M
  Don't replace missing values
 -S <num>
  Random number seed.
  (default 1)
 -output-debug-info
  If set, classifier is run in debug mode and
  may output additional info to the console
 -do-not-check-capabilities
  If set, classifier capabilities are not checked before classifier is built
  (use with caution).
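
To show how the learning rate (-L), the regularization constant (-R), the number of epochs (-E) and the seed (-S) interact, here is a minimal, self-contained SGD loop for a hinge-loss linear model. It is a sketch under stated assumptions, not the code this class runs: the class `SgdLoopSketch`, the per-epoch shuffle, the unregularized bias term and the -1/+1 label coding are all choices made for the example.

```java
import java.util.Random;

final class SgdLoopSketch {
    /** Returns the weights; the last entry is the bias. Labels y must be -1 or +1. */
    static double[] train(double[][] x, int[] y, double learningRate,
                          double lambda, int epochs, long seed) {
        int n = x.length;
        int d = x[0].length;
        double[] w = new double[d + 1];
        int[] order = new int[n];
        for (int i = 0; i < n; i++) order[i] = i;
        Random rnd = new Random(seed);

        for (int e = 0; e < epochs; e++) {
            // Fisher-Yates shuffle so each epoch visits every example once.
            for (int k = n - 1; k > 0; k--) {
                int s = rnd.nextInt(k + 1);
                int tmp = order[k];
                order[k] = order[s];
                order[s] = tmp;
            }
            for (int i : order) {
                double score = score(w, x[i]);
                // L2 regularization shrinks the feature weights (not the bias).
                for (int j = 0; j < d; j++) w[j] *= 1.0 - learningRate * lambda;
                // Hinge loss has a non-zero subgradient only inside the margin.
                if (y[i] * score < 1.0) {
                    for (int j = 0; j < d; j++) w[j] += learningRate * y[i] * x[i][j];
                    w[d] += learningRate * y[i];
                }
            }
        }
        return w;
    }

    /** Linear score w.x + bias. */
    static double score(double[] w, double[] x) {
        double s = w[x.length];
        for (int j = 0; j < x.length; j++) s += w[j] * x[j];
        return s;
    }

    public static void main(String[] args) {
        double[][] x = {{0.0}, {0.1}, {0.9}, {1.0}};
        int[] y = {-1, -1, 1, 1};
        // Same values as the documented defaults: -L 0.01 -R 0.0001 -E 500 -S 1
        double[] w = train(x, y, 0.01, 0.0001, 500, 1L);
        System.out.println("score(0.1) = " + score(w, new double[] {0.1}));
        System.out.println("score(0.9) = " + score(w, new double[] {0.9}));
    }
}
```

The toy data in `main` is already scaled to [0, 1]. With unscaled attributes each step grows with the attribute magnitude, which is consistent with the advice above to lower the learning rate when normalization is off.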
Version:
$Revision: 15519 $
Author:
Eibe Frank (eibe{[at]}cs{[dot]}waikato{[dot]}ac{[dot]}nz), Mark Hall (mhall{[at]}pentaho{[dot]}com)
  • Field Details

    • HINGE

      public static final int HINGE
      The hinge loss function.
    • LOGLOSS

      public static final int LOGLOSS
      The log loss function.
    • SQUAREDLOSS

      public static final int SQUAREDLOSS
      The squared loss function.
    • EPSILON_INSENSITIVE

      public static final int EPSILON_INSENSITIVE
      The epsilon-insensitive loss function.
    • HUBER

      public static final int HUBER
      The Huber loss function.
    • TAGS_SELECTION

      public static final Tag[] TAGS_SELECTION
      Loss functions to choose from.
  • Constructor Details

    • SGD

      public SGD()
  • Method Details

    • getCapabilities

      public Capabilities getCapabilities()
      Returns default capabilities of the classifier.
      Specified by:
      getCapabilities in interface CapabilitiesHandler
      Specified by:
      getCapabilities in interface Classifier
      Overrides:
      getCapabilities in class AbstractClassifier
      Returns:
      the capabilities of this classifier
    • epsilonTipText

      public String epsilonTipText()
      Returns the tip text for this property
      Returns:
      tip text for this property suitable for displaying in the explorer/experimenter gui
    • setEpsilon

      public void setEpsilon(double e)
      Set the epsilon threshold on the error for epsilon insensitive and Huber loss functions
      Parameters:
      e - the value of epsilon to use
    • getEpsilon

      public double getEpsilon()
      Get the epsilon threshold on the error for epsilon insensitive and Huber loss functions
      Returns:
      the value of epsilon to use
    • lambdaTipText

      public String lambdaTipText()
      Returns the tip text for this property
      Returns:
      tip text for this property suitable for displaying in the explorer/experimenter gui
    • setLambda

      public void setLambda(double lambda)
      Set the value of lambda to use
      Parameters:
      lambda - the value of lambda to use
    • getLambda

      public double getLambda()
      Get the current value of lambda
      Returns:
      the current value of lambda
    • setLearningRate

      public void setLearningRate(double lr)
      Set the learning rate.
      Parameters:
      lr - the learning rate to use.
    • getLearningRate

      public double getLearningRate()
      Get the learning rate.
      Returns:
      the learning rate
    • learningRateTipText

      public String learningRateTipText()
      Returns the tip text for this property
      Returns:
      tip text for this property suitable for displaying in the explorer/experimenter gui
    • epochsTipText

      public String epochsTipText()
      Returns the tip text for this property
      Returns:
      tip text for this property suitable for displaying in the explorer/experimenter gui
    • setEpochs

      public void setEpochs(int e)
      Set the number of epochs to use
      Parameters:
      e - the number of epochs to use
    • getEpochs

      public int getEpochs()
      Get current number of epochs
      Returns:
      the current number of epochs
    • setDontNormalize

      public void setDontNormalize(boolean m)
      Turn normalization off/on.
      Parameters:
      m - true if normalization is to be disabled.
    • getDontNormalize

      public boolean getDontNormalize()
      Get whether normalization has been turned off.
      Returns:
      true if normalization has been disabled.
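
For orientation, the sketch below shows one common meaning of "normalization": rescaling each attribute to the range [0, 1]. The page does not state which scheme this class applies, so min-max scaling, the class name `NormalizeSketch` and the handling of constant columns are assumptions.

```java
final class NormalizeSketch {
    /** Rescales each column of data to [0, 1] in place; constant columns become 0. */
    static void minMaxNormalize(double[][] data) {
        int d = data[0].length;
        for (int j = 0; j < d; j++) {
            double min = Double.POSITIVE_INFINITY;
            double max = Double.NEGATIVE_INFINITY;
            for (double[] row : data) {
                min = Math.min(min, row[j]);
                max = Math.max(max, row[j]);
            }
            double range = max - min;
            for (double[] row : data) {
                // Guard against division by zero when every value is identical.
                row[j] = range == 0.0 ? 0.0 : (row[j] - min) / range;
            }
        }
    }

    public static void main(String[] args) {
        double[][] data = {{10.0, 5.0}, {20.0, 5.0}, {30.0, 5.0}};
        minMaxNormalize(data);
        for (double[] row : data) {
            System.out.println(row[0] + ", " + row[1]);
        }
    }
}
```

A scheme like this needs the minimum and maximum of every attribute up front, so it only works when the whole training set is available at once.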
    • dontNormalizeTipText

      public String dontNormalizeTipText()
      Returns the tip text for this property
      Returns:
      tip text for this property suitable for displaying in the explorer/experimenter gui
    • setDontReplaceMissing

      public void setDontReplaceMissing(boolean m)
      Turn global replacement of missing values off/on. If turned off, then missing values are effectively ignored.
      Parameters:
      m - true if global replacement of missing values is to be turned off.
    • getDontReplaceMissing

      public boolean getDontReplaceMissing()
      Get whether global replacement of missing values has been disabled.
      Returns:
      true if global replacement of missing values has been turned off
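
The sketch below illustrates what global replacement of missing values can look like for numeric attributes: each missing entry is filled with the mean of the observed values in its column. The page does not say which statistic is used, so the mean, the use of `Double.NaN` as the missing marker and the class name `ImputeSketch` are assumptions.

```java
final class ImputeSketch {
    /** Replaces NaN entries in each column with that column's mean, in place. */
    static void replaceMissingWithMean(double[][] data) {
        int d = data[0].length;
        for (int j = 0; j < d; j++) {
            double sum = 0.0;
            int seen = 0;
            for (double[] row : data) {
                if (!Double.isNaN(row[j])) {
                    sum += row[j];
                    seen++;
                }
            }
            // A column with no observed values is filled with 0.
            double mean = seen == 0 ? 0.0 : sum / seen;
            for (double[] row : data) {
                if (Double.isNaN(row[j])) row[j] = mean;
            }
        }
    }

    public static void main(String[] args) {
        double[][] data = {{1.0}, {Double.NaN}, {3.0}};
        replaceMissingWithMean(data);
        System.out.println(data[1][0]);
    }
}
```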
    • dontReplaceMissingTipText

      public String dontReplaceMissingTipText()
      Returns the tip text for this property
      Returns:
      tip text for this property suitable for displaying in the explorer/experimenter gui
    • setLossFunction

      public void setLossFunction(SelectedTag function)
      Set the loss function to use.
      Parameters:
      function - the loss function to use.
    • getLossFunction

      public SelectedTag getLossFunction()
      Get the current loss function.
      Returns:
      the current loss function.
    • lossFunctionTipText

      public String lossFunctionTipText()
      Returns the tip text for this property
      Returns:
      tip text for this property suitable for displaying in the explorer/experimenter gui
    • listOptions

      public Enumeration<Option> listOptions()
      Returns an enumeration describing the available options.
      Specified by:
      listOptions in interface OptionHandler
      Overrides:
      listOptions in class RandomizableClassifier
      Returns:
      an enumeration of all the available options.
    • setOptions

      public void setOptions(String[] options) throws Exception
      Parses a given list of options.

      Valid options are:

       -F
        Set the loss function to minimize.
        0 = hinge loss (SVM), 1 = log loss (logistic regression),
        2 = squared loss (regression), 3 = epsilon insensitive loss (regression),
        4 = Huber loss (regression).
        (default = 0)
       -L
        The learning rate. If normalization is
        turned off (as it is automatically for streaming data), then the
        default learning rate will need to be reduced (try 0.0001).
        (default = 0.01).
       -R <double>
        The lambda regularization constant (default = 0.0001)
       -E <integer>
        The number of epochs to perform (batch learning only, default = 500)
       -C <double>
        The epsilon threshold (epsilon-insensitive and Huber loss only, default = 1e-3)
       -N
        Don't normalize the data
       -M
        Don't replace missing values
       -S <num>
        Random number seed.
        (default 1)
       -output-debug-info
        If set, classifier is run in debug mode and
        may output additional info to the console
       -do-not-check-capabilities
        If set, classifier capabilities are not checked before classifier is built
        (use with caution).
      Specified by:
      setOptions in interface OptionHandler
      Overrides:
      setOptions in class RandomizableClassifier
      Parameters:
      options - the list of options as an array of strings
      Throws:
      Exception - if an option is not supported
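
Since setOptions takes the flags as an array of strings, each flag and its value occupy separate array entries. The helper below is hypothetical (`SgdOptionsSketch.buildOptions` is not part of the library); it only assembles such an array from the flags documented above, and the resulting array is the kind of argument setOptions expects.

```java
final class SgdOptionsSketch {
    /** Hypothetical helper: builds an option array from the documented flags. */
    static String[] buildOptions(int loss, double learningRate, double lambda, int epochs) {
        return new String[] {
            "-F", Integer.toString(loss),          // loss function code, 0 to 4
            "-L", Double.toString(learningRate),   // learning rate
            "-R", Double.toString(lambda),         // lambda regularization constant
            "-E", Integer.toString(epochs)         // number of epochs (batch learning)
        };
    }

    public static void main(String[] args) {
        // Log loss (1) with a learning rate of 0.01, lambda 0.05 and 100 epochs.
        String[] opts = buildOptions(1, 0.01, 0.05, 100);
        System.out.println(String.join(" ", opts));
    }
}
```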
    • getOptions

      public String[] getOptions()
      Gets the current settings of the classifier.
      Specified by:
      getOptions in interface OptionHandler
      Overrides:
      getOptions in class RandomizableClassifier
      Returns:
      an array of strings suitable for passing to setOptions
    • globalInfo

      public String globalInfo()
      Returns a string describing the classifier.
      Returns:
      a description suitable for displaying in the explorer/experimenter gui
    • reset

      public void reset()
      Reset the classifier.
    • buildClassifier

      public void buildClassifier(Instances data) throws Exception
      Method for building the classifier.
      Specified by:
      buildClassifier in interface Classifier
      Parameters:
      data - the set of training instances.
      Throws:
      Exception - if the classifier can't be built successfully.
    • updateClassifier

      public void updateClassifier(Instance instance) throws Exception
      Updates the classifier with the given instance.
      Specified by:
      updateClassifier in interface UpdateableClassifier
      Parameters:
      instance - the new training instance to include in the model
      Throws:
      Exception - if the instance could not be incorporated in the model.
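
An incremental update applies one gradient step for the single instance it is given. The sketch below shows such a step for squared loss on a linear model. It is an illustration only: `OnlineUpdateSketch`, the weight layout (bias stored last) and the unregularized bias are assumptions, not the internals of updateClassifier.

```java
final class OnlineUpdateSketch {
    /** One squared-loss SGD step, in place; the last entry of w is the bias. */
    static void update(double[] w, double[] x, double target,
                       double learningRate, double lambda) {
        int d = x.length;
        double prediction = w[d];
        for (int j = 0; j < d; j++) prediction += w[j] * x[j];
        double residual = target - prediction;
        for (int j = 0; j < d; j++) {
            // Shrink by the L2 penalty, then move along the negative gradient.
            w[j] = w[j] * (1.0 - learningRate * lambda) + learningRate * residual * x[j];
        }
        w[d] += learningRate * residual;
    }

    public static void main(String[] args) {
        double[] w = new double[2];               // one feature weight plus bias
        update(w, new double[] {1.0}, 2.0, 0.1, 0.0);
        System.out.println(w[0] + ", " + w[1]);
    }
}
```

Each call sees only one instance, so statistics that need the full data set, such as attribute ranges for normalization, are not available in this mode.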
    • distributionForInstance

      public double[] distributionForInstance(Instance inst) throws Exception
      Computes the distribution for a given instance
      Specified by:
      distributionForInstance in interface Classifier
      Overrides:
      distributionForInstance in class AbstractClassifier
      Parameters:
      inst - the instance for which distribution is computed
      Returns:
      the distribution
      Throws:
      Exception - if the distribution can't be computed successfully
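
For a binary logistic regression model, a class distribution is commonly obtained by passing the linear score through the logistic function. The sketch below shows that mapping. Whether this class computes its distribution this way for each loss function is not stated on this page, so treat `DistributionSketch` as an assumption-laden illustration.

```java
final class DistributionSketch {
    /** Maps a linear score to {P(class 0), P(class 1)} with the logistic function. */
    static double[] distribution(double score) {
        // Branch on the sign so Math.exp never receives a large positive argument.
        double p1 = score >= 0.0
            ? 1.0 / (1.0 + Math.exp(-score))
            : Math.exp(score) / (1.0 + Math.exp(score));
        return new double[] {1.0 - p1, p1};
    }

    public static void main(String[] args) {
        double[] dist = distribution(2.0);
        System.out.println(dist[0] + ", " + dist[1]);
    }
}
```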
    • getWeights

      public double[] getWeights()
    • toString

      public String toString()
      Prints out the classifier.
      Overrides:
      toString in class Object
      Returns:
      a description of the classifier as a string
    • getRevision

      public String getRevision()
      Returns the revision string.
      Specified by:
      getRevision in interface RevisionHandler
      Overrides:
      getRevision in class AbstractClassifier
      Returns:
      the revision
    • aggregate

      public SGD aggregate(SGD toAggregate) throws Exception
      Aggregate an object with this one
      Specified by:
      aggregate in interface Aggregateable<SGD>
      Parameters:
      toAggregate - the object to aggregate
      Returns:
      the result of aggregation
      Throws:
      Exception - if the supplied object can't be aggregated for some reason
    • finalizeAggregation

      public void finalizeAggregation() throws Exception
      Call to complete the aggregation process. Allows implementers to do any final processing based on how many objects were aggregated.
      Specified by:
      finalizeAggregation in interface Aggregateable<SGD>
      Throws:
      Exception - if the aggregation can't be finalized for some reason
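
One plausible reading of the aggregate and finalizeAggregation pair, for linear models trained separately on different data partitions, is "sum the weight vectors, then divide by the number of models". The sketch below implements that reading. The class `AggregateSketch` and the averaging rule are assumptions, not a description of this class's internals.

```java
final class AggregateSketch {
    private final double[] sum;
    private int count;

    AggregateSketch(double[] weights) {
        this.sum = weights.clone();
        this.count = 1;
    }

    /** Adds another model's weights to the running sum. */
    AggregateSketch aggregate(double[] other) {
        if (other.length != sum.length) {
            throw new IllegalArgumentException("weight vectors differ in length");
        }
        for (int j = 0; j < sum.length; j++) sum[j] += other[j];
        count++;
        return this;
    }

    /** Divides the summed weights by the number of aggregated models. */
    double[] finalizeAggregation() {
        double[] avg = new double[sum.length];
        for (int j = 0; j < sum.length; j++) avg[j] = sum[j] / count;
        return avg;
    }

    public static void main(String[] args) {
        double[] avg = new AggregateSketch(new double[] {1.0, 2.0})
            .aggregate(new double[] {3.0, 4.0})
            .finalizeAggregation();
        System.out.println(avg[0] + ", " + avg[1]);
    }
}
```

Deferring the division to a final step is what lets the result depend on how many objects were aggregated, as the description above notes.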
    • main

      public static void main(String[] args)
      Main method for testing this class.