0001
0002
0003
0004
function [fractionCorrect,classifierOrClasses] = LDA(data,CV,selectInOneTrialF,command)
%LDA  Whitening + projection followed by a linear discriminant classifier.
%
%   [fractionCorrect,classifier] = LDA(data,CV,selectInOneTrialF,'make')
%   builds a classifier from the training partition returned by
%   collectPartitions(data,CV,selectInOneTrialF,'train'):
%     1. training samples are centered and whitened with inv(sqrtm(cov));
%     2. they are projected with the matrix A returned by ccakirby
%        (NOTE(review): presumably a CCA between the whitened features and
%        the class-indicator matrix -- confirm against ccakirby), keeping the
%        first CV.parms.fisherDim components;
%     3. a linear discriminant with a pooled within-class covariance and
%        empirical class priors is fit in that reduced space.
%   The returned struct has fields means, sq, proj, fisherDim, weights,
%   bias and classNames. If CV.validateIndices is non-empty, fractionCorrect
%   is the fraction of validation samples classified correctly; otherwise it
%   is [].
%
%   [fractionCorrect,classes] = LDA(data,CV,selectInOneTrialF)
%   applies the classifier in CV.classifier to the samples returned by
%   collectPartitions(data,CV,selectInOneTrialF) and returns the predicted
%   labels together with the fraction that equal the true labels.
%
%   Any fourth argument other than 'make' raises an error.
%
%   Inputs (as used in this function):
%     data              - struct; size(data.features,1) is used as the
%                         number of classes K.
%     CV                - struct with fields parms (fisherDim, optionally
%                         firstMode/nModes), validateIndices and, for the
%                         apply call, classifier.
%     selectInOneTrialF - passed through to collectPartitions.
%     command           - optional; must be 'make' when given.
%
%   NOTE(review): labels are used directly as column indices into a
%   K-column indicator matrix, so they are assumed to be integers in 1..K.

if nargin > 3 && strcmp(command,'make')

    % Default mode selection (consumed by collectPartitions).
    if ~isfield(CV.parms,'firstMode')
        CV.parms.firstMode = 1;
        CV.parms.nModes = 1;
    end

    % X and Xval come back features-by-samples; Y/Yval hold one label per sample.
    [X,Y,Xval,Yval] = collectPartitions(data,CV,selectInOneTrialF,'train');

    N = size(X,2);
    K = size(data.features,1);

    % Samples in rows for centering and whitening.
    X = X';
    featureMeans = mean(X,1);
    sq = inv(sqrtm(cov(X)));
    Xs = (X - repmat(featureMeans,N,1)) * sq;

    % One-hot class indicator. The last column is dropped before calling
    % ccakirby, presumably because the full set of indicator columns sums to
    % one and is therefore linearly dependent.
    Yindicator = zeros(length(Y),K);
    for i = 1:length(Y)
        Yindicator(i,Y(i)) = 1;
    end
    A = ccakirby(Xs,Yindicator(:,1:end-1));

    % Project the centered/whitened training data -- the same transform that
    % is applied to validation and test data -- and keep the leading
    % fisherDim components. Back to features-by-samples afterwards.
    X = Xs * A;
    X = X(:,1:CV.parms.fisherDim)';

    classifier.means = featureMeans;
    classifier.sq = sq;
    classifier.proj = A;
    classifier.fisherDim = CV.parms.fisherDim;

    p = CV.parms.fisherDim;

    % Priors, class means and pooled within-class scatter in the reduced
    % space. Entry i of priors / row i of classMeans belongs to classNames(i).
    classNames = unique(Y);
    priors = zeros(1,K);
    classMeans = zeros(K,p);
    covsum = zeros(p,p);
    for i = 1:length(classNames)
        mask = Y == classNames(i);
        Nthisclass = sum(mask);
        priors(i) = Nthisclass / N;
        Xmat = X(:,mask);
        % Average along dim 2 so a one-sample class still yields a p-vector.
        classMeans(i,:) = mean(Xmat,2)';
        XmeanZero = Xmat - repmat(classMeans(i,:)',1,Nthisclass);
        covsum = covsum + XmeanZero * XmeanZero';
    end
    % Pooled within-class covariance with N-K degrees of freedom.
    covariance = covsum / (N-K);

    % Linear discriminant g_k(x) = w_k'*x + b_k with w_k = S\mu_k and
    % b_k = -0.5*mu_k'*w_k + log(prior_k). A linear solve replaces inv(S).
    weights = covariance \ classMeans';                               % p x K
    bias = (-0.5 * sum(classMeans' .* weights,1) + log(priors))';     % K x 1

    classifier.weights = weights;
    classifier.bias = bias;
    % Kept so that apply-time predictions map back to the training labels.
    classifier.classNames = classNames;

    if ~isempty(CV.validateIndices)
        Xval = Xval';
        Xval = (Xval - repmat(featureMeans,size(Xval,1),1)) * sq * A;
        Xval = Xval(:,1:CV.parms.fisherDim)';
        disc = weights' * Xval + repmat(bias,1,size(Xval,2));
        [junk,whichmax] = max(disc); %#ok<ASGLU>
        classes = classNames(whichmax);
        fractionCorrect = sum(classes == Yval) / length(Yval);
    else
        fractionCorrect = [];
    end

    classifierOrClasses = classifier;

elseif nargin > 3

    error('Command to LDA must be ''make'' or absent.');

else

    classifier = CV.classifier;
    [X,Y] = collectPartitions(data,CV,selectInOneTrialF);

    % The make branch stores the projection as 'proj'; 'A' is the field name
    % this branch previously read, kept as a fallback.
    if isfield(classifier,'proj')
        proj = classifier.proj;
    else
        proj = classifier.A;
    end

    % Same centering, whitening and projection as used in training.
    X = X';
    X = (X - repmat(classifier.means,size(X,1),1)) * classifier.sq * proj;
    X = X(:,1:classifier.fisherDim)';

    disc = classifier.weights' * X + repmat(classifier.bias,1,size(X,2));
    [junk,whichmax] = max(disc); %#ok<ASGLU>

    % Map discriminant indices to the labels the classifier was trained on;
    % fall back to the test labels for classifiers without classNames.
    if isfield(classifier,'classNames')
        classNames = classifier.classNames;
    else
        classNames = unique(Y);
    end
    classes = reshape(classNames(whichmax),size(Y));

    classifierOrClasses = classes;
    fractionCorrect = sum(classes == Y) / length(Y);

end