""" Logistic Regression with Stochastic Gradient Descent. Copyright (c) 2009, Naoaki Okazaki This code illustrates an implementation of logistic regression models trained by Stochastic Gradient Decent (SGD). This program reads a training set from STDIN, trains a logistic regression model, evaluates the model on a test set (given by the first argument) if specified, and outputs the feature weights to STDOUT. This is the typical usage of this problem: $ ./logistic_regression_sgd.py test.txt < train.txt Each line in a data set represents an instance that consists of binary features and label separated by TAB characters. This is the BNF notation of the data format: <line> ::= <label> ('\t' <feature>)+ '\n' <label> ::= '1' | '0' <feature> ::= <string> The following topics are not covered for simplicity: - bias term - regularization - real-valued features - multiclass logistic regression (maximum entropy model) - two or more iterations for training - calibration of learning rate This code requires Python 2.5 or later for collections.defaultdict(). """ import collections import math import sys N = 17997 # Change this to present the number of training instances. eta0 = 0.1 # Initial learning rate; change this if desired. def update(W, X, l, eta): # Compute the inner product of features and their weights. a = sum([W[x] for x in X]) # Compute the gradient of the error function (avoiding +Inf overflow). g = ((1. / (1. + math.exp(-a))) - l) if -100. < a else (0. - l) # Update the feature weights by Stochastic Gradient Descent. for x in X: W[x] -= eta * g def train(fi): t = 1 W = collections.defaultdict(float) # Loop for instances. for line in fi: fields = line.strip('\n').split('\t') update(W, fields[1:], float(fields[0]), eta0 / (1 + t / float(N))) t += 1 return W def classify(W, X): return 1 if 0. < sum([W[x] for x in X]) else 0 def test(W, fi): m = 0 n = 0 for line in fi: fields = line.strip('\n').split('\t') l = classify(W, fields[1:]) m += (1 - (l ^ int(fields[0]))) n += 1 print('Accuracy = %f (%d/%d)' % (m / float(n), m, n)) if __name__ == '__main__': W = train(sys.stdin) if 1 < len(sys.argv): test(W, open(sys.argv[1])) else: for name, value in W.iteritems(): print('%f\t%s' % (value, name))