Home > . > sensitivity.m

sensitivity

PURPOSE ^

[accuracyByParm,names,allvalues] = sensitivity(cv)

SYNOPSIS ^

function [accuracyByParm,names,allvalues] = sensitivity(cv)

DESCRIPTION ^

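  [accuracyByParm,names,allvalues] = sensitivity(cv)

  Plots cross-validation accuracy ("Fraction Correct") against each tuned
  parameter, for every test trial, in two ways:
    - Figure 1 varies the parameter while the other parameters are held at their best values.
    - Figure 2 takes, for each parameter value, the best accuracy over all settings of the other parameters.
  The figure-2 curves are the ones returned in accuracyByParm.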

CROSS-REFERENCE INFORMATION ^

This function calls: (none listed by m2html)
This function is called by: (none listed by m2html)

SOURCE CODE ^

function [accuracyByParm,names,allvalues] = sensitivity(cv)
%SENSITIVITY Plot cross-validation accuracy against each tuned parameter.
%  [accuracyByParm,names,allvalues] = sensitivity(cv)
%
%  For every test trial t in CV.testResults, the best parameter record
%  cv.testResults{t}{1} is unpacked with GETFIELDS into
%  [bestvalues, names]. Two sets of curves are then built and plotted:
%
%    Figure 1, 'Sensitivity Using Best Values': each parameter is varied
%      while every other parameter whose best value is not NaN is held at
%      trial t's own best value.
%    Figure 2, 'Sensitivity Using All Values': for each value of the
%      parameter, the maximum accuracy over all rows of trial t that use
%      that value.
%
%  Fields of CV that are read:
%    cv.testResults           cv.testResults{t}{1} is passed to GETFIELDS.
%    cv.cvresults             numeric matrix. Rows whose column 1 equals t
%                             belong to test trial t. The last
%                             cv.validate.repetitions columns are averaged
%                             into a single accuracy value per row.
%    cv.cvresultsNames        cellstr naming the columns of cv.cvresults.
%    cv.validate.repetitions  number of trailing repetition columns.
%
%  Outputs:
%    accuracyByParm  accuracyByParm{t}{p} is an N-by-2 matrix
%                    [parameterValue accuracy] from the figure-2 pass.
%                    Parameters whose best value is NaN are left empty.
%    names           parameter names returned by GETFIELDS for the last
%                    test trial.
%    allvalues       allvalues{t}{p} is the row vector of unique values of
%                    parameter p found in cv.cvresults.

nTestTrials = length(cv.testResults);
names = {};
allvalues = cell(1,nTestTrials);
bestByTrial = cell(1,nTestTrials);   % per-trial best values (not just the last trial's)
namesByTrial = cell(1,nTestTrials);

%%% Collect each trial's best values and all unique values of each parameter
for ttrial = 1:nTestTrials
  [bestByTrial{ttrial},namesByTrial{ttrial}] = getfields(cv.testResults{ttrial}{1});
  trialNames = namesByTrial{ttrial};
  allvalues{ttrial} = cell(1,length(trialNames));
  for i = 1:length(trialNames)
    col = find(strcmp(cv.cvresultsNames,trialNames{i}));
    allvalues{ttrial}{i} = unique(cv.cvresults(:,col))';
  end
end
if nTestTrials > 0
  names = namesByTrial{end};
end

%%% Average the repetition columns; the mean becomes the last column of mat
nreps = cv.validate.repetitions;
mat = [cv.cvresults(:,1:end-nreps) mean(cv.cvresults(:,end-nreps+1:end),2)];

accuracyByParm = bestValueCurves(cv,mat,bestByTrial,namesByTrial);
plotCurves(1,accuracyByParm,names,'Sensitivity Using Best Values');

% The figure-2 curves replace the figure-1 curves as the returned value.
accuracyByParm = allValueCurves(cv,mat,bestByTrial,namesByTrial,allvalues);
plotCurves(2,accuracyByParm,names,'Sensitivity Using All Values');


function curves = bestValueCurves(cv,mat,bestByTrial,namesByTrial)
% Accuracy vs. each parameter, other parameters fixed at that trial's best values.
nTestTrials = length(bestByTrial);
curves = cell(1,nTestTrials);
for ttrial = 1:nTestTrials
  best = bestByTrial{ttrial};
  trialNames = namesByTrial{ttrial};
  nparms = length(trialNames);
  curves{ttrial} = cell(1,nparms);
  for parm = 1:nparms
    if isnan(best(parm))
      continue;   % parameter not tuned for this trial
    end
    rows = mat(:,1) == ttrial;
    for p = setdiff(1:nparms,parm)
      if isnan(best(p))
        continue;
      end
      rows = rows & (mat(:,paramColumn(cv,trialNames{p})) == best(p));
    end
    col = paramColumn(cv,trialNames{parm});
    curves{ttrial}{parm} = [mat(rows,col) mat(rows,end)];
  end
end


function curves = allValueCurves(cv,mat,bestByTrial,namesByTrial,allvalues)
% For each value of each parameter, the best accuracy over all other settings.
nTestTrials = length(bestByTrial);
curves = cell(1,nTestTrials);
for ttrial = 1:nTestTrials
  best = bestByTrial{ttrial};
  trialNames = namesByTrial{ttrial};
  nparms = length(trialNames);
  curves{ttrial} = cell(1,nparms);
  trialRows = mat(:,1) == ttrial;
  for parm = 1:nparms
    if isnan(best(parm))
      continue;
    end
    col = paramColumn(cv,trialNames{parm});
    parmvals = allvalues{ttrial}{parm};
    curve = NaN(numel(parmvals),2);
    for k = 1:numel(parmvals)
      curve(k,1) = parmvals(k);
      accs = mat(trialRows & (mat(:,col) == parmvals(k)),end);
      if ~isempty(accs)
        % max([]) would be [] and break the concatenation; leave NaN instead
        curve(k,2) = max(accs);
      end
    end
    curves{ttrial}{parm} = curve;
  end
end


function plotCurves(figNum,curves,names,titleStr)
% One subplot per parameter, overlaying the curve of every test trial.
if isempty(curves)
  return;
end
figure(figNum);
nPlots = length(curves{1});
ncols = 3;
nrows = max(2,ceil(nPlots/ncols));   % 2x3 as before; grows past 6 parameters
for p = 1:nPlots
  subplot(nrows,ncols,p);
  cla;
  for i = 1:length(curves)
    if p > length(curves{i}) || isempty(curves{i}{p})
      continue;   % skipped parameter (NaN best value) for this trial
    end
    r = curves{i}{p};
    z = plot(r(:,1),r(:,2),'o-k');
    set(z,'MarkerSize',3);
    hold on;
  end
  if p <= numel(names)
    xlabel(names{p});
  end
  ylabel('Fraction Correct');
  if p == 2
    title(titleStr);
  end
  hold off;
end


function col = paramColumn(cv,name)
% Index of the cv.cvresults column named NAME; error if missing or duplicated.
col = find(strcmp(cv.cvresultsNames,name));
if numel(col) ~= 1
  error('sensitivity:column', ...
        'Expected exactly one cv.cvresults column named ''%s'', found %d.', ...
        name,numel(col));
end

Generated on Tue 07-Feb-2006 12:02:57 by m2html © 2003