Main Page | Namespace List | Class Hierarchy | Data Structures | File List | Namespace Members | Data Fields | Globals | Related Pages

statisticsgatherers.h

Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (c) 2003, Cornell University
00004 All rights reserved.
00005 
00006 Redistribution and use in source and binary forms, with or without
00007 modification, are permitted provided that the following conditions are met:
00008 
00009    - Redistributions of source code must retain the above copyright notice,
00010        this list of conditions and the following disclaimer.
00011    - Redistributions in binary form must reproduce the above copyright
00012        notice, this list of conditions and the following disclaimer in the
00013        documentation and/or other materials provided with the distribution.
00014    - Neither the name of Cornell University nor the names of its
00015        contributors may be used to endorse or promote products derived from
00016        this software without specific prior written permission.
00017 
00018 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
00019 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00020 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
00021 ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
00022 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
00023 CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00024 SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
00025 INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
00026 CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00027 ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
00028 THE POSSIBILITY OF SUCH DAMAGE.
00029 
00030 */
00031 
00032 // -*- C++ -*-
00033 
00034 #if !defined _CLUS_STATISTICSGATHERERS_H
00035 #define _CLUS_STATISTICSGATHERERS_H
00036 
00037 #include "general.h"
00038 #include "splitpointcomputation.h" // for gini and split point computation
00039 #include "discretepermutationtransformation.h"
00040 #include "continuouslineartransformation.h"
00041 
00042 #include <vector>
00043 #include <math.h>
00044 
00045 #ifdef DEBUG_PRINT
00046 #include <iostream>
00047 #endif
00048 
00049 using namespace TNT;
00050 using namespace std;
00051 
00052 namespace CLUS
00053 {
00054 
/// How the shift between two statistics objects is computed in ComputeShift:
/// `labeled` uses class-label information, `unlabeled` ignores it.
enum ShiftType
{
    labeled = 0,
    unlabeled = 1
};
00056 
/// A (index, value) pair that orders and compares by value only, so that a
/// vector of them can be sorted by value while remembering original indices.
class IndexedValue
{
    int index;
    double value;

public:
    IndexedValue(int i=0, double v=0.0):
            index(i), value(v)
    {}

    /// Original position of the value (const so it is usable on const objects,
    /// matching getValue()).
    int getIndex(void) const
    {
        return index;
    }

    double getValue(void) const
    {
        return value;
    }

    /// Orders by value only; the index does not participate.
    bool operator < (const IndexedValue& B) const
    {
        return (value < B.getValue());
    }

    /// Compares values only; the index does not participate.
    bool operator != (const IndexedValue& B) const
    {
        return (value != B.getValue());
    }
};
00087 
00088 class BasicBinomialStatistics
00089 {
00090 protected:
00091     /// Counts for each value of the attribute for class 0    
00092     Vector<double> countsC0;
00093     
00094     /// Counts for each value of the attribute (total)
00095     Vector<double> counts;
00096 
00097     bool gainComputed;
00098 
00099     double ComputeSplitPoint(double N, double alpha_1)
00100     {
00101         return 0.0;
00102     }
00103 
00104 public:
00105     BasicBinomialStatistics(int DomainSize=0):
00106             countsC0(DomainSize), counts(DomainSize), gainComputed(false)
00107     {}
00108 
00109     double getCount()
00110     {
00111         double total=0;
00112         for (int i=0; i<countsC0.size(); i++)
00113             total+= counts[i];
00114         return total;
00115     }
00116 
00117     void Print(void)
00118     {
00119         cout << "CountC0 [ ";
00120         for (int i=0; i<countsC0.size(); i++)
00121             cout << countsC0[i] << " ";
00122         cout << " ]" << endl;
00123 
00124         cout << "Count [";
00125         for (int i=0; i<counts.size(); i++)
00126             cout << counts[i] << " ";
00127         cout << " ]" << endl;
00128     }
00129 
00130     void ResetDomainSize(int DomainSize)
00131     {
00132         countsC0.newsize(DomainSize);
00133         counts.newsize(DomainSize);
00134         for (int i =0; i< DomainSize; i++)
00135         {
00136             countsC0[i]=0.0;
00137             counts[i]=0.0;
00138         }
00139     }
00140 
00141     void UpdateStatistics(int value, int classLabel, double probability=1.0)
00142     {
00143         assert(probability>=0.0 && probability<=1.0);
00144 
00145         if (classLabel==0)
00146             countsC0[value]+=probability;
00147         counts[value]+=probability;
00148 #ifdef DEBUG_PRINT
00149 
00150         cout << " value=" << value <<  " counts[value]=" << counts[value]
00151         << " countsC0[value]=" << countsC0[value] << endl;
00152 #endif
00153 
00154         assert (isfinite(counts[value]));
00155     }
00156 
00157     void UpdateStatisticsP(int value, double p0, double p1)
00158     {
00159         countsC0[value]+=p0;
00160         counts[value]+=p0+p1;
00161         assert (isfinite(counts[value]));
00162     }
00163 };
00164 
00165 
00166 class BinomialStatistics : public BasicBinomialStatistics
00167 {
00168 protected:
00169     /// the values in the left set; split point
00170     Vector<int> Split; 
00171 
00172     double ComputeSplitPoint(double N, double alpha_1)
00173     {
00174         return DiscreteGiniGain(countsC0, counts, N, alpha_1, Split);
00175     }
00176 
00177 public:
00178     BinomialStatistics(int DomainSize=0): BasicBinomialStatistics(DomainSize),
00179             Split(0)
00180     {}
00181 
00182     Vector<int>& GetSplit(void)
00183     {
00184         if (gainComputed)
00185             return Split;
00186         else
00187         {
00188             ComputeGiniGain();
00189             return Split;
00190         }
00191     }
00192 
00193     double ComputeGiniGain(void)
00194     {
00195         // create a temporary vector of doubles out of countsC0 so that
00196         // DiscreteGiniGain function can be used
00197         gainComputed = true;
00198         double N=0.0; // keeps the total count
00199         double NC0=0.0;
00200         for (int i=0; i<counts.dim(); i++)
00201         {
00202             N+=counts[i];
00203             NC0+=countsC0[i];
00204         }
00205 #ifdef DEBUG_PRINT
00206         cout << "Discrete. ComputeGiniGain " << N << " " << NC0 << "\t";
00207 #endif
00208 
00209         double alpha_1=NC0/N;
00210         //cerr << alpha_1 << "?" << endl;
00211         if (alpha_1!=alpha_1 || N==0.0 )
00212         {
00213             //cerr << "no discrete data"<< endl;
00214             return 0;
00215         }
00216 
00217         double val = DiscreteGiniGain(countsC0, counts, N, alpha_1, Split);
00218 #ifdef DEBUG_PRINT
00219 
00220         cout << val << endl;
00221 #endif
00222 
00223         return val;
00224     }
00225     
00226     // entries r_i/q_i
00227     vector<IndexedValue> getConditionalProbs(void)
00228     {
00229         vector<IndexedValue> myVals(counts.size());
00230         double C0_total=0;//how many C0 examples were seen
00231         double total = 0;//how many examples were seen
00232         for (int i=0; i<counts.size(); i++)
00233         {
00234             C0_total+= countsC0[i];
00235             total += counts[i];
00236         }
00237         // cerr << "0's: " << C0_total << " total: " << total << endl;
00238         for (int i=0; i<counts.size(); i++)
00239         {
00240             double val = (double)(countsC0[i]/C0_total)/(counts[i]/total);
00241             //  cerr << "(" << i << ", " << val << endl;
00242             IndexedValue iv(i, val);
00243             myVals[i] = iv;
00244         }
00245         return myVals;
00246     }
00247 
00248     Permutation ComputeShift(ShiftType shift, BinomialStatistics& aux)
00249     {
00250         // compute the permutation such that the rank of the
00251         // probabilities to see classLabel=0 in both is the same
00252         if(!gainComputed)
00253         {
00254             ComputeGiniGain();
00255         }
00256         vector<IndexedValue> myVals = getConditionalProbs();
00257         vector<IndexedValue> refVals = aux.getConditionalProbs();
00258         sort(myVals.begin(), myVals.end());
00259         sort(refVals.begin(), refVals.end());
00260         Vector<int> perm(counts.size());
00261         for (int i =0; i<counts.size(); i++)
00262         {
00263             int myNextIndex = myVals[i].getIndex();
00264             int refNextIndex = refVals[i].getIndex();
00265             //  cerr << "mapping " << refNextIndex << " to " <<  myNextIndex << endl;
00266             perm[refNextIndex] = myNextIndex;
00267         }
00268         return Permutation(perm);
00269     }
00270 
00271     void AddWeightedStatistics(BinomialStatistics& aux, double weight)
00272     {
00273         assert(counts.dim() == aux.counts.dim());
00274 
00275         if (weight==0.0)
00276             return;
00277 
00278         for (int i=0; i<counts.dim(); i++)
00279         {
00280             counts[i]+=weight*aux.counts[i];
00281             countsC0[i]+=weight*aux.countsC0[i];
00282         }
00283     }
00284 
00285     void CorrectWeightedStatistics(double totalWeight)
00286     {
00287         for (int i=0; i<counts.dim(); i++)
00288         {
00289             counts[i]/=totalWeight;
00290             countsC0[i]/=totalWeight;
00291         }
00292     }
00293 
00294     void AddStatisticsShifted(BinomialStatistics& aux, Permutation shift)
00295     {
00296         assert(counts.dim() == aux.counts.dim());
00297         // apply the shift when adding the statistics
00298         for (int i=0; i<counts.dim(); i++)
00299         {
00300             int shiftedIndex=shift.ApplyPermutation(i);
00301             counts[i]+=aux.counts[shiftedIndex];
00302             countsC0[i]+=aux.countsC0[shiftedIndex];
00303         }
00304     }
00305 
00306 };
00307 
00308 class ProbabilisticBinomialStatistics : public BasicBinomialStatistics
00309 {
00310 protected:
00311     Vector<double> probSet; // probability to be at the left for each value
00312 
00313 public:
00314     ProbabilisticBinomialStatistics(int DomainSize=0): BasicBinomialStatistics(DomainSize),
00315             probSet(DomainSize)
00316     {}
00317 
00318     Vector<double>& GetProbabilitySet(void)
00319     {
00320         if (gainComputed)
00321             return probSet;
00322         else
00323         {
00324             ComputeGiniGain();
00325             return probSet;
00326         }
00327     }
00328 
00329     double ComputeGiniGain(void)
00330     {
00331         // create a temporary vector of doubles out of countsC0 so that
00332         // DiscreteGiniGain function can be used
00333         gainComputed = true;
00334         double N=0.0; // keeps the total count
00335         double NC0=0.0;
00336         for (int i=0; i<counts.dim(); i++)
00337         {
00338 #ifdef DEBUG_PRINT
00339             cout << " counts[" << i << "]=" << counts[i]
00340             << " countsC0[" << i << "]=" << countsC0[i] << endl;
00341 #endif
00342 
00343             N+=counts[i];
00344             NC0+=countsC0[i];
00345         }
00346 #ifdef DEBUG_PRINT
00347         cout << "Discrete. ComputeGiniGain " << N << " " << NC0 << "\t";
00348 #endif
00349 
00350         double alpha_1=NC0/N;
00351         //cerr << alpha_1 << "?" << endl;
00352         if (alpha_1!=alpha_1 || N==0.0 )
00353         {
00354             //cerr << "no discrete data"<< endl;
00355             return 0;
00356         }
00357 
00358         double val = ProbabilisticDiscreteGiniGain(countsC0, counts, N, alpha_1, probSet);
00359 #ifdef DEBUG_PRINT
00360 
00361         cout << val << endl;
00362 #endif
00363 
00364         return val;
00365     }
00366 
00367 
00368 };
00369 
00370 class NormalStatistics
00371 {
00372 protected:
00373     double countC0;
00374     double countC1;
00375 
00376     /// maintain the sum of values for class label 0
00377     double sumC0; 
00378     double sumC1;
00379 
00380     /// maintain the sum of squares for class label 0
00381     double sum2C0; 
00382     double sum2C1;
00383     double split;
00384 
00385     /// Solution returned by QDA
00386     int whichSol; 
00387     double splitVariance;
00388 
00389     bool splitVarComputed;
00390 
00391 
00392 
00393 public:
00394     NormalStatistics(int dummy=0)
00395     {
00396         Reset();
00397     }
00398 
00399     void Reset(void)
00400     {
00401         countC0=0.0;
00402         countC1=0.0;
00403         sumC0=0.0;
00404         sumC1=0.0;
00405         sum2C0=0.0;
00406         sum2C1=0.0;
00407         splitVarComputed=false;
00408     }
00409 
00410     void Print(void)
00411     {
00412         cout << "Statistics: " << countC0 << " " << countC1 << " ";
00413         cout << sumC0 << " " << sumC1 << " " << sum2C0 << " " << sum2C1 << endl;
00414     }
00415 
00416     double getVariance(void)
00417     {
00418         return (sum2C0+sum2C1)/(countC0+countC1) - pow2((sumC0+sumC1)/(countC0+countC1));
00419     }
00420 
00421     double getSplitVariance(void)
00422     {
00423         if (!splitVarComputed)
00424         {
00425             ComputeGiniGain();
00426         }
00427         //return 1.0e-32; // FIXTHIS
00428         return splitVariance;// is this variance???
00429     }
00430 
00431     void UpdateStatistics(double value, int classLabel, double probability=1.0)
00432     {
00433         if (classLabel==0)
00434         {
00435             countC0+=probability;
00436             sumC0+=probability*value;
00437             sum2C0+=probability*pow2(value);
00438         }
00439         else
00440         {
00441             countC1+=probability;
00442             sumC1+=probability*value;
00443             sum2C1+=probability*pow2(value);
00444         }
00445         assert(isfinite(sumC0));// NaN
00446     }
00447 
00448     double GetCenter(void)
00449     {
00450         return (sumC0+sumC1)/(countC0+countC1);
00451     }
00452 
00453     double ComputeGiniGain(void)
00454     {
00455         if (countC0+countC1==0)
00456             return 0.0;
00457 
00458         double alpha_1=(double)countC0/(countC0+countC1);
00459         double alpha_2=(double)countC1/(countC0+countC1);
00460         double eta1=sumC0/countC0;
00461         double eta2=sumC1/countC1;
00462         double var1=sum2C0/countC0-pow2(eta1);
00463         double var2=sum2C1/countC1-pow2(eta2);
00464 
00465 #ifdef DEBUG_PRINT
00466 
00467         cout << "ComputeGiniGain " << alpha_1 << " " << alpha_2 << " " << eta1 << " ";
00468         cout << eta2 << " " << var1 << " " << var2 << "\t";
00469 #endif
00470 
00471         split=UnidimensionalQDA(alpha_1,eta1,var1,alpha_2,eta2,var2, whichSol);
00472         assert(isfinite(split));
00473         splitVariance=UnidimensionalQDAVariance(countC0, eta1, var1,
00474                                                 countC1, eta2, var2, whichSol);
00475         splitVarComputed=true;
00476 
00477         // compute gini like in function ComputeSeparatingHyperplane_Anova
00478         double p11=alpha_1*PValueNormalDistribution(eta1, sqrt(var1), split);
00479         double p1_=p11+alpha_2*PValueNormalDistribution(eta2, sqrt(var2), split);
00480         double gini=BinaryGiniGain(p11,alpha_1,p1_);
00481 #ifdef DEBUG_PRINT
00482 
00483         cout << split << " " << gini << endl;
00484 #endif
00485 
00486         return gini;
00487     }
00488 
00489     double GetSplit(void)
00490     {
00491         return split;
00492     }
00493 
00494     /* --------------------- Stuff only for MultiDecTrees -------------*/
00495 
00496     bool nonZero(void)
00497     {
00498         return (countC0!=0 && countC1!=0);
00499     }
00500 
00501     bool hasData(void)
00502     {
00503         return (countC0!=0 && countC1!=0);
00504     }
00505 
00506     double ComputeShift(ShiftType shift, NormalStatistics& aux)
00507     {
00508         //      cerr << "computing shift" << endl;
00509         if (shift==unlabeled)
00510             return ( (aux.sumC0+aux.sumC1)/(aux.countC0+aux.countC1) -
00511                      (sumC0+sumC1)/(countC0+countC1) );
00512         /* The ComputeGiniGain for (*this) has to be called before */
00513         if (shift==labeled)
00514         {
00515             aux.ComputeGiniGain();
00516             assert (!(split!=split));
00517             return aux.split-split;
00518         }
00519         return 0.0;
00520     }
00521 
00522     void AddStatisticsShifted(NormalStatistics& aux, double shift)
00523     {
00524         // the shift has to be substracted from the statistics of aux
00525         countC0+=aux.countC0;
00526         countC1+=aux.countC1;
00527         // simulate the situation where the shift is substracted from every datapoint
00528         sumC0+=aux.sumC0-aux.countC0*shift;
00529         assert(!(sumC0!=sumC0)); //Nan
00530         assert(!(shift!=shift));
00531         sumC1+=aux.sumC1-aux.countC1*shift;
00532         sum2C0+=aux.sum2C0-2*aux.sumC0*shift+pow2(shift)*aux.countC0;
00533         sum2C1+=aux.sum2C1-2*aux.sumC1*shift+pow2(shift)*aux.countC1;
00534     }
00535 
00536     void AddWeightedStatistics(NormalStatistics& aux, double weight)
00537     {
00538         if (weight==0.0)
00539             return;
00540 
00541         countC0+=aux.countC0*weight;
00542         countC1+=aux.countC1*weight;
00543         sumC0+=aux.sumC0*weight;
00544         sumC1+=aux.sumC1*weight;
00545         sum2C0+=pow2(weight)*(aux.sum2C0*aux.countC0-pow2(aux.sumC0));
00546         sum2C1+=pow2(weight)*(aux.sum2C1*aux.countC1-pow2(aux.sumC1));
00547     }
00548 
00549     void CorrectWeightedStatistics(double totalWeight)
00550     {
00551         countC0/=totalWeight;
00552         countC1/=totalWeight;
00553         sumC0*=countC0/totalWeight;
00554         sumC1*=countC1/totalWeight;
00555         sum2C0=sum2C0/pow2(countC0*totalWeight)+pow2(sumC0/countC0);
00556         sum2C1=sum2C1/pow2(countC1*totalWeight)+pow2(sumC1/countC1);
00557     }
00558 
00559     bool negMeanLessThanPos(void)
00560     {
00561         double negMean = sumC0/countC0;
00562         double posMean = sumC1/countC1;
00563         return (negMean < posMean);
00564     }
00565 
00566     bool labeledMeansSignificant(void)
00567     {
00568         double negMean = sumC0/countC0;
00569         double posMean = sumC1/countC1;
00570         double diff = posMean-negMean;
00571         if (!negMeanLessThanPos())
00572         {
00573             diff = -1.0*diff;
00574         }
00575 
00576         /*
00577         double posVar = getLCVariance(true);
00578         double negVar = getLCVariance(false);
00579         */
00580         if (diff >= 0)// .0000000001*(posVar+negVar))
00581         {
00582             return true;
00583         }
00584         else
00585             return false;
00586 
00587     }
00588 
00589     double getLabeledCenter(bool b)
00590     {
00591         if (b)
00592             return sumC1/countC1;
00593         else
00594             return sumC0/countC0;
00595     }
00596 
00597     double  getLCVariance(bool b)
00598     {
00599         if (b)
00600         {
00601             return sum2C1/countC1 - pow2(sumC1/countC1);
00602         }
00603         else
00604             return sum2C0/countC0 - pow2(sumC0/countC0);
00605     }
00606 
00607     double getcountC0()
00608     {
00609         return countC0;
00610     }
00611     
00612     double getcountC1()
00613     {
00614         return countC1;
00615     }
00616 
00617 };
00618 
00619 
00620 /** Class implements a multidimentional normal distribution
00621  */
00622 class MultidimNormalStatistics
00623 {
00624 protected:
00625     int dim;
00626     double s_p1, s_p2; // Sum p1i, Sum p2i
00627     Vector<double> s_p1_x, s_p2_x; // Sum p1i*x, Sum p2i*x
00628     Fortran_Matrix<double> S_p1_xxT, S_p2_xxT; // Sum p1i*xx^T, Sum p2i*xx^T
00629 
00630     int SplitVariable;
00631     Vector<double> SeparatingHyperplane;
00632 
00633 public:
00634     MultidimNormalStatistics()
00635     {
00636         dim=0;
00637     } // all the default constructors called for vectors and matrices
00638 
00639     enum SeparationType { ANOVA=0, LDA=1, QDA=2 };
00640 
00641     double GetS_P1(void)
00642     {
00643         return s_p1;
00644     }
00645     
00646     double GetS_P2(void)
00647     {
00648         return s_p2;
00649     }
00650 
00651     /** changes the dimention and reinitializes everything */
00652     void Resize(int Dim)
00653     {
00654         if (dim!=Dim)
00655         {
00656             dim=Dim;
00657 
00658             s_p1_x.newsize(dim);
00659             s_p2_x.newsize(dim);
00660             S_p1_xxT.newsize(dim,dim);
00661             S_p2_xxT.newsize(dim,dim);
00662             SeparatingHyperplane.newsize(dim+1);
00663         }
00664 
00665         if (dim>0) /* otherwise somebody is just freeing the memory */
00666             Reset();
00667     }
00668 
00669     void Reset(void)
00670     {
00671         s_p1=s_p2=0.0;
00672         s_p1_x=0.0;
00673         s_p2_x=0.0;
00674         S_p1_xxT=0.0;
00675         S_p2_xxT=0.0;
00676     }
00677 
00678     void UpdateStatistics(const double* values, double p1, double p2)
00679     {
00680         if (!finite(p1+p2))
00681             return;
00682 
00683         s_p1+=p1;
00684         s_p2+=p2;
00685 
00686         for(int i=0; i<dim; i++)
00687         {
00688             double aux1=p1*values[i];
00689             double aux2=p2*values[i];
00690             s_p1_x[i]+=aux1;
00691             s_p2_x[i]+=aux2;
00692             for(int j=0; j<dim; j++)
00693             {
00694                 S_p1_xxT(i+1,j+1)+=aux1*values[j];
00695                 S_p2_xxT(i+1,j+1)+=aux2*values[j];
00696             }
00697         }
00698     }
00699 
00700     /* must be called before the split point can be computed */
00701     void UpdateParameters(void)
00702     {
00703         for(int i=0; i<dim; i++)
00704         {
00705             // compute mu1, mu2 in s_p1_x and s_p2_x
00706             s_p1_x[i]/=s_p1;
00707             s_p2_x[i]/=s_p2;
00708         }
00709         // compute Sigma1 and Sigma2 in S_p1_xxT and S_p2_xxT
00710         for(int i=0; i<dim; i++)
00711         {
00712             for(int j=0; j<dim; j++)
00713             { // AlinTODO: scan only half of the matrix
00714                 S_p1_xxT(i+1,j+1)=S_p1_xxT(i+1,j+1)/s_p1-s_p1_x[i]*s_p1_x[j];
00715                 S_p2_xxT(i+1,j+1)=S_p2_xxT(i+1,j+1)/s_p2-s_p2_x[i]*s_p2_x[j];
00716 
00717                 // fix in S_xxT the almost 0.0 entries. This helpes a lot
00718                 if (fabs(S_p1_xxT(i+1,j+1))<SMALL_POZ_VALUE)
00719                 {
00720                     S_p1_xxT(i+1,j+1)=0.0;
00721                 }
00722                 if (fabs(S_p2_xxT(i+1,j+1))<SMALL_POZ_VALUE)
00723                 {
00724                     S_p2_xxT(i+1,j+1)=0.0;
00725                 }
00726 
00727             }
00728         }
00729     }
00730 
00731     int GetSplitVariable(void)
00732     {
00733         return SplitVariable;
00734     }
00735 
00736     Vector<double>& GetSeparatingHyperplane(void)
00737     {
00738         return SeparatingHyperplane;
00739     }
00740 
00741     double ComputeGiniGain(int type = LDA)
00742     {
00743         double gini, alpha_1, alpha_2;
00744 
00745         if (dim<=1)
00746             return 0.0;
00747 
00748         SplitVariable=-1;
00749 
00750         if (s_p1+s_p2==0.00)
00751             return 0.0;
00752 
00753         // compute the best oblique split
00754         alpha_1 =s_p1/(s_p1+s_p2);
00755         alpha_2 =s_p2/(s_p1+s_p2);
00756         double mass=s_p1+s_p2;
00757 
00758         UpdateParameters();
00759 
00760         /* All the SeparatingHyperplane computation funcions are normalized
00761         with respect to variance of the split point */
00762 
00763         switch (type)
00764         {
00765         case ANOVA:
00766             gini = ComputeSeparatingHyperplane_Anova(mass,
00767                     alpha_1, s_p1_x, S_p1_xxT,
00768                     alpha_2, s_p2_x, S_p2_xxT, SeparatingHyperplane);
00769             break;
00770         case LDA:
00771             gini = ComputeSeparatingHyperplane_LDA(mass,
00772                                                    alpha_1, s_p1_x, S_p1_xxT,
00773                                                    alpha_2, s_p2_x, S_p2_xxT, SeparatingHyperplane);
00774             break;
00775         case QDA:
00776             gini = ComputeSeparatingHyperplane_QDA(mass,
00777                                                    alpha_1, s_p1_x, S_p1_xxT,
00778                                                    alpha_2, s_p2_x, S_p2_xxT, SeparatingHyperplane);
00779             break;
00780         default:
00781             gini = 0.0; // should never be reached
00782         }
00783 
00784         SplitVariable=-1;
00785         {
00786             // See if we have a simple split
00787             bool isSimple = true;
00788             int posSplitVar = -1;
00789             for (int i=0; i<SeparatingHyperplane.dim()-1; i++)
00790             {
00791                 double cval=SeparatingHyperplane[i+1];
00792                 if (cval != 0.0)
00793                     if (posSplitVar==-1)
00794                         posSplitVar=i;
00795                     else
00796                     {
00797                         isSimple=false;
00798                         break;
00799                     }
00800                 else
00801                     continue;
00802             }
00803 
00804             if (isSimple && posSplitVar!=-1)
00805             {
00806                 SplitVariable = -(posSplitVar+2);
00807                 cout << "Split variable is simple=" << posSplitVar << endl;
00808             }
00809             else
00810             {
00811                 cout << "Simple separation detection failed at: " << posSplitVar << endl;
00812             }
00813 
00814         }
00815 
00816         return gini;
00817     }
00818 
00819     double MaxGini(void)
00820     {
00821         return 2*s_p1*s_p2/pow2(s_p1+s_p2);
00822     }
00823 
00824 };
00825 
00826 }
00827 
00828 #endif // _CLUS_STATISTICSGATHERERS_H

Generated on Mon Jul 21 16:57:25 2003 for SECRET by doxygen 1.3.2