Main Page | Namespace List | Class Hierarchy | Data Structures | File List | Namespace Members | Data Fields | Globals | Related Pages

binaryprobabilisticsplitter.h

Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (c) 2003, Cornell University
00004 All rights reserved.
00005 
00006 Redistribution and use in source and binary forms, with or without
00007 modification, are permitted provided that the following conditions are met:
00008 
00009    - Redistributions of source code must retain the above copyright notice,
00010        this list of conditions and the following disclaimer.
00011    - Redistributions in binary form must reproduce the above copyright
00012        notice, this list of conditions and the following disclaimer in the
00013        documentation and/or other materials provided with the distribution.
00014    - Neither the name of Cornell University nor the names of its
00015        contributors may be used to endorse or promote products derived from
00016        this software without specific prior written permission.
00017 
00018 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
00019 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00020 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
00021 ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
00022 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
00023 CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00024 SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
00025 INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
00026 CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00027 ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
00028 THE POSSIBILITY OF SUCH DAMAGE.
00029 
00030 */
00031 
00032 // -*- C++ -*-
00033 
00034 #if !defined _CLUS_BINARYPROBABILISTICSPLITTER_H_
00035 #define _CLUS_BINARYPROBABILISTICSPLITTER_H_
00036 
#include <cmath>
#include <ostream>

#include "general.h"
#include "statisticsgatherers.h"
#include "vec.h"

#ifdef DEBUG_PRINT
#include <iostream>
using namespace std;
#endif

using namespace TNT;
00047 
00048 namespace CLUS
00049 {
00050 class BinaryProbabilisticSplitter
00051 {
00052 protected:
00053     /// the number of discrete and continuous split variables
00054     int dsplitDim, csplitDim;
00055 
00056     /// statistics about datapoints seen
00057     double mass0, mass; // sum of probabilities
00058 
00059     /// list of discrete domain sizes
00060     const Vector<int>& dDomainSize;
00061     
00062     /** Indicates on what variable this node splits on. */
00063     int SplitVariable;
00064 
00065     /// split for continuous variables
00066     double splitPoint, splitSTD; 
00067 
00068     Vector<double> splitSetProbability;
00069 
00070     /** Statistics for split decision. One for each attributeXdataset
00071         If an attribute already has the shift determined keep statistics
00072         about all datasets in the location for the first dataset
00073     */
00074     Vector<ProbabilisticBinomialStatistics>  discreteStatistics;
00075     Vector<NormalStatistics>  continuousStatistics;
00076 
00077 public:
00078     BinaryProbabilisticSplitter(const Vector<int>& DDomainSize,int CsplitDim ):
00079             dsplitDim(DDomainSize.dim()-1), csplitDim(CsplitDim),
00080             dDomainSize(DDomainSize),
00081             discreteStatistics(0), continuousStatistics(0)
00082     {}
00083 
00084     bool GotNoData(void)
00085     {
00086         return mass==0.0;
00087     }
00088 
00089     int GetCSplitDim(void)
00090     {
00091         return csplitDim;
00092     }
00093     
00094     int GetDSplitDim(void)
00095     {
00096         return dsplitDim;
00097     }
00098     
00099     const Vector<int>& GetDDomainSize(void)
00100     {
00101         return dDomainSize;
00102     }
00103 
00104     /** Initializes the data structures used in split variable selection */
00105     void InitializeSplitStatistics(void)
00106     {
00107         mass0=mass=0.0;
00108 
00109         discreteStatistics.newsize(dsplitDim);
00110         for (int i=0; i<dsplitDim; i++)
00111         {
00112             discreteStatistics[i].ResetDomainSize(dDomainSize[i]);
00113         }
00114 
00115         continuousStatistics.newsize(csplitDim);
00116 
00117     }
00118 
00119     /** Computes the probability to take the left branch */
00120     double ProbabilityLeft(const int* Dvars, const double* Cvars)
00121     {
00122         if (SplitVariable <=-1)
00123         {
00124             // split on a continuous variable
00125             // the split variable is actually -SplitVariable-1
00126             int splitVar=-SplitVariable-1;
00127             // return the CDF of N(splitPoint,splitSTD)(current point)
00128             return PValueNormalDistribution(splitPoint,splitSTD,Cvars[splitVar]);
00129         }
00130         else
00131         {
00132             // split on a discrete variable
00133             int value=Dvars[SplitVariable];
00134             return splitSetProbability[value];
00135         }
00136     }
00137 
00138     void UpdateSplitStatistics( const int* Dvars, const double* Cvars,
00139                                 int classLabel, double probability)
00140     {
00141         mass+=probability;
00142         if (classLabel==0)
00143             mass0+=probability;
00144 
00145         // update discrete statistics
00146         for (int i=0; i<dsplitDim; i++)
00147         {
00148             int value=Dvars[i];
00149             discreteStatistics[i].UpdateStatistics(value, classLabel, probability);
00150         }
00151 
00152         // update continuous statistics
00153         for (int i=0; i<csplitDim; i++)
00154         {
00155             double value=Cvars[i];
00156             continuousStatistics[i].UpdateStatistics(value, classLabel, probability);
00157         }
00158     }
00159 
00160     bool ComputeSplitVariable(void)
00161     {
00162         // make node a leaf if not enough data to take a split decision
00163         if (mass0<2.0 || mass-mass0<2.0)
00164         {
00165 #ifdef DEBUG_PRINT
00166             cout << "Making the node a leaf" << endl;
00167 #endif
00168 
00169             return false;
00170         }
00171 
00172         double maxgini=0.0;
00173 
00174         // go over the discrete attributes and find the best one
00175         for (int i=0; i<dsplitDim; i++)
00176         {
00177             double curr_gini=discreteStatistics[i].ComputeGiniGain();
00178 #ifdef DEBUG_PRINT
00179 
00180             cout << "\tVariable: " << i << " gini=" << curr_gini << endl;
00181 #endif
00182 
00183             if (curr_gini>maxgini)
00184             {
00185                 maxgini=curr_gini;
00186                 SplitVariable=i;
00187             }
00188         }
00189 
00190         // go over continuous variable
00191         for (int i=0; i<csplitDim; i++)
00192         {
00193             double curr_gini=continuousStatistics[i].ComputeGiniGain();
00194 #ifdef DEBUG_PRINT
00195 
00196             cout << "\tVariable: " << (-i-1) << " gini=" << curr_gini << endl;
00197 #endif
00198 
00199             if (curr_gini>maxgini)
00200             {
00201                 maxgini=curr_gini;
00202                 SplitVariable=-(i+1);
00203             }
00204         }
00205 
00206         if (maxgini==0.0)
00207             return false; // make the node a leaf, nobody can do a reasonable split
00208 
00209         // not set the split point for the variable picked as the split point
00210         if (SplitVariable>=0)
00211             splitSetProbability=discreteStatistics[SplitVariable].GetProbabilitySet();
00212         else
00213         {
00214             int splitVar=-SplitVariable-1;
00215             splitPoint=continuousStatistics[splitVar].GetSplit();
00216             splitSTD=sqrt(continuousStatistics[splitVar].getSplitVariance());
00217         }
00218 
00219 #ifdef DEBUG_PRINT
00220         cout << "Split variable is " << SplitVariable << " and split point ";
00221         if (SplitVariable>=0)
00222             cout << splitSetProbability << endl;
00223         else
00224             cout << splitPoint << " " << splitSTD << endl;
00225 #endif
00226 
00227         return true;
00228     }
00229 
00230     double ComputeProbabilityFirstClass(void)
00231     {
00232         return (mass0/mass);
00233     }
00234 
00235     bool MoreSplits(double minMass, int nodeId)
00236     {
00237         return ( mass>=minMass && mass0!=0.0 && mass0!=mass );
00238     }
00239 
00240     void SaveToStream(ostream& out)
00241     {
00242         out << SplitVariable << " ";
00243         if (SplitVariable>=0)
00244         {
00245             out << "( " << mass << " ) [ ";
00246             for (int i=0; i<splitSetProbability.size(); i++)
00247                 out << splitSetProbability[i] << " ";
00248             out << " ] ";
00249         }
00250         else
00251         {
00252             out << "( " << splitPoint << " " << splitSTD << " ) ";
00253         }
00254     }
00255 };
00256 
00257 }
00258 
00259 #endif //_CLUS_BINARYPROBABILISTICSPLITTER_H_

Generated on Mon Jul 21 16:57:24 2003 for SECRET by doxygen 1.3.2