Home > . > LDA.m

LDA

PURPOSE ^

Make a Linear Discriminant Classifier

SYNOPSIS ^

function [fractionCorrect,classifierOrClasses] = LDA(data,CV,selectInOneTrialF,command)

DESCRIPTION ^

Make a Linear Discriminant Classifier. With command 'make', trains an LDA classifier on whitened, Fisher-projected features; without a command, applies the classifier stored in CV.classifier.

CROSS-REFERENCE INFORMATION ^

This function calls: collectPartitions, ccakirby
This function is called by: (none listed)

SOURCE CODE ^

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% Make a Linear Discriminant Classifier
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [fractionCorrect,classifierOrClasses] = LDA(data,CV,selectInOneTrialF,command)
%%% [fractionCorrect,classifierOrClasses] = LDA(data,CV,selectInOneTrialF,command)
%%%
%%% Trains (command == 'make') or applies (command absent) a linear
%%% discriminant classifier.  Features are centered and whitened, projected
%%% onto the directions returned by ccakirby, truncated to the first
%%% CV.parms.fisherDim dimensions, and LDA is performed in that space.
%%%
%%% Inputs
%%%   data.features: ntasks x ntrials cell array of featurized data
%%%       .classes:  ntasks x ntrials matrix of class labels
%%%       (size(data.features,1) is used as the number of classes K)
%%%   CV fields read here (other fields are consumed by collectPartitions()):
%%%     .parms.firstMode, .parms.nModes  (set to 1 in make mode when absent)
%%%     .parms.fisherDim   number of Fisher dimensions kept (make mode)
%%%     .validateIndices   if non-empty, validation accuracy is computed (make mode)
%%%     .classifier        struct from a previous 'make' call (use mode)
%%%   selectInOneTrialF: passed through unchanged to collectPartitions()
%%%   command:  Optional. If present, must be 'make' and the classifier is created.
%%%
%%% Returns
%%%   make mode:
%%%     fractionCorrect:     fraction of validation samples classified correctly,
%%%                          or [] when CV.validateIndices is empty
%%%     classifierOrClasses: classifier struct with fields
%%%        .means .sq .proj .fisherDim  whitening and Fisher projection
%%%        .weights .bias               one discriminant column/entry per class
%%%        .classNames                  label for discriminant columns 1..numel
%%%   use mode:
%%%     fractionCorrect:     fraction of samples classified correctly
%%%     classifierOrClasses: predicted class label for each sample

if nargin > 3 && strcmp(command,'make')

  %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
  %%% make the classifier
  %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

  if ~isfield(CV.parms,'firstMode')
    CV.parms.firstMode = 1;  %so this works if data.features are not cell array of
    CV.parms.nModes = 1;     %windows, but just nchannels x nsamples matrix
  end

  %% X, Xval: features x samples.  Y, Yval: one class label per sample.
  [X,Y,Xval,Yval] = collectPartitions(data,CV,selectInOneTrialF,'train');

  N = size(X,2);
  K = size(data.features,1);

  %% Whitening transform (center, then multiply by the inverse matrix
  %% square root of the covariance), estimated from training samples only.
  Xt = X';
  featureMeans = mean(Xt,1);
  sq = inv(sqrtm(cov(Xt)));
  Xs = (Xt - repmat(featureMeans,N,1)) * sq;

  %% One-of-K indicator matrix; labels are used directly as column indices,
  %% so they are assumed to be integers in 1..K.
  nLabels = length(Y);
  Yindicator = zeros(nLabels,K);
  Yindicator(sub2ind(size(Yindicator),(1:nLabels)',Y(:))) = 1;
  %% The last indicator column is dropped because the columns sum to one.
  A = ccakirby(Xs,Yindicator(:,1:end-1));

  classifier.means = featureMeans;
  classifier.sq = sq;
  classifier.proj = A;
  classifier.fisherDim = CV.parms.fisherDim;

  %% Training, validation and test data must all go through the SAME
  %% centering + whitening + projection, otherwise the learned means and
  %% covariance live in a different space than the data they are applied to.
  X = fisherProject(X,classifier);
  Xval = fisherProject(Xval,classifier);

  p = classifier.fisherDim;

  %% Calculate class priors, class means and pooled within-class scatter.
  %% Row/column i corresponds to classNames(i).  Rows beyond numel(classNames)
  %% (classes absent from training) keep prior 0, hence bias -Inf.
  classNames = unique(Y);
  priors = zeros(1,K);
  classMeans = zeros(K,p);
  covsum = zeros(p,p);
  for i = 1:length(classNames)
    mask = Y == classNames(i);
    Nthisclass = sum(mask);
    priors(i) = Nthisclass / N;
    Xmat = X(:,mask);
    classMeans(i,:) = mean(Xmat,2)';   % dim 2: still a p-vector for one sample
    XmeanZero = Xmat - repmat(classMeans(i,:)',1,Nthisclass);
    covsum = covsum + XmeanZero * XmeanZero';
  end
  covariance = covsum / (N-K);

  %% Discriminant g_k(x) = w_k'*x + b_k with w_k = covariance \ mu_k and
  %% b_k = -0.5*mu_k'*w_k + log(prior_k).  Solve rather than invert.
  weights = covariance \ classMeans';
  bias = zeros(K,1);
  for k = 1:K
    bias(k) = - 0.5 * classMeans(k,:) * weights(:,k) + log(priors(k));
  end

  classifier.weights = weights;
  classifier.bias = bias;
  classifier.classNames = classNames;

  if ~isempty(CV.validateIndices)
    classes = predictClasses(classifier,Xval,classNames);
    validateCorrect = sum(classes(:) == Yval(:)) / length(Yval);
  else
    validateCorrect = [];
  end

  classifierOrClasses = classifier;
  fractionCorrect = validateCorrect;

elseif nargin > 3 && ~strcmp(command,'make')

  error('Command to LDA must be ''make'' or absent.');

else

  %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
  %%% use the classifier
  %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

  [X,Y] = collectPartitions(data,CV,selectInOneTrialF);

  %% fisher projection, identical to the one used in training
  Z = fisherProject(X,CV.classifier);

  %% Map discriminant columns back to the labels seen during training.
  %% Classifiers created before classNames was stored fall back to the
  %% previous behavior of deriving labels from the test set.
  if isfield(CV.classifier,'classNames')
    classNames = CV.classifier.classNames;
  else
    classNames = unique(Y);
  end
  classes = predictClasses(CV.classifier,Z,classNames);
  testCorrect = sum(classes(:) == Y(:)) / length(Y);

  classifierOrClasses = classes;
  fractionCorrect = testCorrect;

end


function Z = fisherProject(X,classifier)
%%% Center and whiten X (features x samples) with classifier.means/.sq,
%%% project onto classifier.proj, and keep the first classifier.fisherDim
%%% dimensions.  Returns a fisherDim x samples matrix.
if isfield(classifier,'proj')
  proj = classifier.proj;
else
  proj = classifier.A;   % accept structs that stored the projection as .A
end
nSamples = size(X,2);
if nSamples == 0
  Z = zeros(classifier.fisherDim,0);
  return;
end
Z = (X' - repmat(classifier.means,nSamples,1)) * classifier.sq * proj;
Z = Z(:,1:classifier.fisherDim)';


function classes = predictClasses(classifier,Z,classNames)
%%% Label each column of Z (fisherDim x samples) with classNames(k), where
%%% k is the class whose linear discriminant value is largest.
disc = classifier.weights' * Z + repmat(classifier.bias,1,size(Z,2));
[junk,whichmax] = max(disc,[],1);   % dim 1 so a single class still works
classes = classNames(whichmax);

Generated on Tue 07-Feb-2006 12:02:57 by m2html © 2003