Main Page | Namespace List | Class Hierarchy | Data Structures | File List | Namespace Members | Data Fields | Globals | Related Pages

simplebinarysplitter.h

Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (c) 2003, Cornell University
00004 All rights reserved.
00005 
00006 Redistribution and use in source and binary forms, with or without
00007 modification, are permitted provided that the following conditions are met:
00008 
00009    - Redistributions of source code must retain the above copyright notice,
00010        this list of conditions and the following disclaimer.
00011    - Redistributions in binary form must reproduce the above copyright
00012        notice, this list of conditions and the following disclaimer in the
00013        documentation and/or other materials provided with the distribution.
00014    - Neither the name of Cornell University nor the names of its
00015        contributors may be used to endorse or promote products derived from
00016        this software without specific prior written permission.
00017 
00018 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
00019 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00020 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
00021 ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
00022 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
00023 CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00024 SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
00025 INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
00026 CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00027 ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
00028 THE POSSIBILITY OF SUCH DAMAGE.
00029 
00030 */
00031 
00032 // -*- C++ -*-
00033 
00034 #if !defined _CLUS_SIMPLEBINARYSPLITTER_H_
00035 #define _CLUS_SIMPLEBINARYSPLITTER_H_
00036 
00037 #include "statisticsgatherers.h"
00038 
00039 using namespace TNT;
00040 
00041 namespace CLUS
00042 {
00043 
00044 /// Splitter for decision trees
00045 class SimpleBinarySplitter
00046 {
00047 protected:
00048     /// the number of discrete and continuous split variables
00049     int dsplitDim, csplitDim;
00050 
00051 
00052     /// List of discrete domain sizes
00053     const Vector<int>& dDomainSize;
00054 
00055     /** Indicates on what variable this node splits on. */
00056     int SplitVariable;
00057     
00058     /** split point if continuous variable is split variable */
00059     double splitPoint;
00060     
00061     /** The list of values for the left child for the separating variable if
00062     the split is not oblique. Contains values in order so that
00063     binary search can be used. */
00064     Vector<int> SeparatingSet;
00065 
00066     /** Statistics for split decision. One for each attributeXdataset
00067     If an attribute already has the shift determined keep statistics 
00068     about all datasets in the location for the first dataset
00069     */
00070     Vector<BinomialStatistics> discreteStatistics;
00071     Vector<NormalStatistics> continuousStatistics;
00072 
00073     /// statistics for class label determination
00074     int count;
00075     int countC0;
00076 public:
00077     SimpleBinarySplitter(const Vector<int>& DDomainSize,int CsplitDim):
00078             dsplitDim(DDomainSize.dim()-1), csplitDim(CsplitDim),
00079             dDomainSize(DDomainSize)
00080     {}
00081 
00082     ~SimpleBinarySplitter(void)
00083     {}
00084 
00085     int GetCSplitDim(void)
00086     {
00087         return csplitDim;
00088     }
00089     int GetDSplitDim(void)
00090     {
00091         return dsplitDim;
00092     }
00093     const Vector<int>& GetDDomainSize(void)
00094     {
00095         return dDomainSize;
00096     }
00097 
00098     /** Initializes the datastructures used in split variable selection */
00099     void InitializeSplitStatistics(void)
00100     {
00101         count=0;
00102         countC0=0;
00103 
00104         discreteStatistics.newsize(dsplitDim);
00105         for (int i=0; i<dsplitDim; i++)
00106         {
00107             discreteStatistics[i].ResetDomainSize(dDomainSize[i]);
00108         }
00109 
00110         continuousStatistics.newsize(csplitDim);
00111     }
00112 
00113     /** 0 is left and 1 right */
00114     int ChooseBranch( const int* Dvars, const double* Cvars)
00115     {
00116         if (SplitVariable <=-1)
00117         {
00118             // split on a continuous variable
00119             // the split variable is actually -SplitVariable-1
00120             if (Cvars[-SplitVariable-1]<=splitPoint)
00121                 return 0;
00122             else
00123                 return 1;
00124         }
00125         else
00126         {
00127             // split on a discrete variable
00128             int value=Dvars[SplitVariable];
00129             // look for value in SeparatingSet
00130             if (IsPointInSet<int>(value, SeparatingSet))
00131                 return 0;
00132             else
00133                 return 1;
00134         }
00135     }
00136 
00137     void UpdateSplitStatistics( const int* Dvars, const double* Cvars,
00138                                 int classLabel)
00139     {
00140 #ifdef DEBUG_PRINT
00141         cout << "CL=" << classLabel << " ";
00142 #endif
00143 
00144         count++;
00145         if (classLabel==0)
00146             countC0++;
00147 
00148 #ifdef DEBUG_PRINT
00149 
00150         cout << "count=" << count << " countC0=" << countC0 << endl;
00151 #endif
00152 
00153         for (int i=0; i<dsplitDim; i++)
00154         {
00155             int value=Dvars[i];
00156             discreteStatistics[i].UpdateStatistics(value, classLabel);
00157         }
00158 
00159         // update continuous statistics
00160         for (int i=0; i<csplitDim; i++)
00161         {
00162             double value=Cvars[i];
00163             continuousStatistics[i].UpdateStatistics(value, classLabel);
00164         }
00165     }
00166 
00167     bool ComputeSplitVariable(void)
00168     {
00169         // make node a leaf if not enough data to take a split decision
00170         if (countC0<2.0 || count-countC0<2.0)
00171         {
00172 #ifdef DEBUG_PRINT
00173             cout << "Making the node a leaf. countC0=" << countC0
00174             << " count=" << count << endl;
00175 #endif
00176 
00177             return false;
00178         }
00179 
00180         double maxgini=0.0;
00181 
00182         // go over the discrete attributes and find the best one
00183         for (int i=0; i<dsplitDim; i++)
00184         {
00185             double curr_gini=discreteStatistics[i].ComputeGiniGain();
00186 #ifdef DEBUG_PRINT
00187 
00188             cout << "\tVariable: " << i << " gini=" << curr_gini << endl;
00189 #endif
00190 
00191             if (curr_gini>maxgini)
00192             {
00193                 maxgini=curr_gini;
00194                 SplitVariable=i;
00195             }
00196         }
00197 
00198         // go over continuous variable
00199         for (int i=0; i<csplitDim; i++)
00200         {
00201             double curr_gini=continuousStatistics[i].ComputeGiniGain();
00202 #ifdef DEBUG_PRINT
00203 
00204             cout << "\tVariable: " << (-i-1) << " gini=" << curr_gini << endl;
00205 #endif
00206 
00207             if (curr_gini>maxgini)
00208             {
00209                 maxgini=curr_gini;
00210                 SplitVariable=-(i+1);
00211             }
00212         }
00213 
00214         if (maxgini==0.0)
00215             return false; // make the node a leaf, nobody can do a reasonable split
00216 
00217         // not set the split point for the variable picked as the split point
00218         if (SplitVariable>=0)
00219             SeparatingSet=discreteStatistics[SplitVariable].GetSplit();
00220         else
00221         {
00222             int splitVar=-SplitVariable-1;
00223             splitPoint=continuousStatistics[splitVar].GetSplit();
00224         }
00225 
00226 #ifdef DEBUG_PRINT
00227         cout << "Split variable is " << SplitVariable << " and split point ";
00228         if (SplitVariable>=0)
00229             cout << SeparatingSet << endl;
00230         else
00231             cout << splitPoint << endl;
00232 #endif
00233 
00234         return true;
00235     }
00236 
00237     int ComputeClassLabel(void)
00238     {
00239         if (countC0>=count-countC0)
00240             return 0;
00241         else
00242             return 1;
00243     }
00244 
00245     bool MoreSplits(int minMass, int nodeId)
00246     {
00247         return ( count>=minMass && countC0!=0 && countC0!=count );
00248     }
00249 
00250     void SaveToStream(ostream& out)
00251     {
00252         out << SplitVariable << " ";
00253         if (SplitVariable>=0)
00254         {
00255             out << "( " << count << " ) [ ";
00256             for (int i=0; i<SeparatingSet.size(); i++)
00257                 out << SeparatingSet[i] << " ";
00258             out << " ] ";
00259         }
00260         else
00261         {
00262             out << "( " << splitPoint << " ) ";
00263         }
00264     }
00265 };
00266 }
00267 
00268 #endif // _CLUS_SIMPLEBINARYSPLITTER_H_

Generated on Mon Jul 21 16:57:25 2003 for SECRET by doxygen 1.3.2