Main Page | Namespace List | Class Hierarchy | Data Structures | File List | Namespace Members | Data Fields | Globals | Related Pages

binarydecisiontreenode.h

Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (c) 2003, Cornell University
00004 All rights reserved.
00005 
00006 Redistribution and use in source and binary forms, with or without
00007 modification, are permitted provided that the following conditions are met:
00008 
00009    - Redistributions of source code must retain the above copyright notice,
00010        this list of conditions and the following disclaimer.
00011    - Redistributions in binary form must reproduce the above copyright
00012        notice, this list of conditions and the following disclaimer in the
00013        documentation and/or other materials provided with the distribution.
00014    - Neither the name of Cornell University nor the names of its
00015        contributors may be used to endorse or promote products derived from
00016        this software without specific prior written permission.
00017 
00018 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
00019 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00020 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
00021 ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
00022 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
00023 CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00024 SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
00025 INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
00026 CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00027 ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
00028 THE POSSIBILITY OF SUCH DAMAGE.
00029 
00030 */
00031 
00032 // -*- C++ -*-
00033 
00034 #ifndef _CLUS_BINARYDECISIONTREENODE_H_
00035 #define _CLUS_BINARYDECISIONTREENODE_H_
00036 
00037 #include "vec.h"
00038 #include "general.h"
00039 
00040 #ifdef DEBUG_PRINT
00041 #include <iostream>
00042 using namespace std;
00043 #endif
00044 
00045 using namespace TNT;
00046 
00047 namespace CLUS
00048 {
00049 
00050 /// Implements a node of the binary decision tree
00051 template< class T_Splitter >
00052 class BinaryDecisionTreeNode
00053 {
00054 protected:
00055 
00056     /// unique identifier of the cluster for a regression tree.
00057     int nodeId;
00058 
00059     /// the state of the node. At creation em. At load stable
00060     enum state { stable, split} State;
00061 
00062     /// the predicted class label
00063     int classLabel;
00064 
00065     /// the children of this node
00066     BinaryDecisionTreeNode< T_Splitter > * Children[2];
00067 
00068     /// the probability to return first class, for the seccond probability is 1-probFirstClass
00069     double probFirstClass;
00070 
00071     /// splitter for split criterion
00072     T_Splitter Splitter;
00073 
00074     /// pruning error statistic
00075     int pruningError;
00076 
00077     /// pruning total mass statistic
00078     int pruningTotalMass;
00079 
00080 public:
00081     BinaryDecisionTreeNode(int NodeId, const Vector<int> & DDomainSize,
00082                            int CsplitDim):
00083             nodeId(NodeId), State(split), Splitter(DDomainSize,CsplitDim)
00084     {
00085         // predict first class by default
00086         probFirstClass=1.0; 
00087     }
00088 
00089     ~BinaryDecisionTreeNode(void)
00090     {
00091         if (Children[0]!=0)
00092             delete Children[0];
00093         Children[0]=0;
00094 
00095 
00096         if (Children[1]!=0)
00097             delete Children[1];
00098         Children[1]=0;
00099     }
00100 
00101     /// Begin the learning process
00102     void StartLearningEpoch(void)
00103     {
00104         switch (State)
00105         {
00106         case stable:
00107             if (Children[0]!=0)
00108             {
00109                 Children[0]->StartLearningEpoch();
00110                 Children[1]->StartLearningEpoch();
00111             }
00112             break;
00113         case split:
00114             Splitter.InitializeSplitStatistics();
00115             break;
00116         }
00117     }
00118 
00119     /// Learn a data sample
00120     ///
00121     /// @param Dvars         discrete variables
00122     /// @param Cvars         continuous variables
00123     /// @param classlabel    classification label
00124     void LearnSample(const int* Dvars, const double* Cvars, int classlabel)
00125     {
00126         switch (State)
00127         {
00128         case stable:
00129             // if leaf do nothing
00130             if (Children[0]==0 || Children[1]==0)
00131                 return;
00132 
00133             if (Splitter.ChooseBranch(Dvars,Cvars)==0)
00134                 Children[0]->LearnSample(Dvars,Cvars,classlabel);
00135             else
00136                 Children[1]->LearnSample(Dvars,Cvars,classlabel);
00137             break;
00138 
00139         case split:
00140             Splitter.UpdateSplitStatistics(Dvars, Cvars, classlabel);
00141             break;
00142         }
00143     }
00144 
00145     /// Stop the learning process
00146     ///
00147     /// @return              true if stopped successfully
00148     ///                      false if still processing data
00149     bool StopLearningEpoch(int minMass)
00150     {
00151         switch (State)
00152         {
00153         case stable:
00154             if (Children[0]!=0)
00155                 return Children[0]->StopLearningEpoch(minMass)
00156                        | Children[1]->StopLearningEpoch(minMass); // || cannot be used since is shortcuted
00157             else
00158                 return false;
00159 
00160         case split:
00161             State=stable;
00162 
00163             classLabel=Splitter.ComputeClassLabel();
00164 
00165             if (!Splitter.ComputeSplitVariable())
00166             {
00167                 // have to make the node a leaf
00168                 goto make_node_leaf;
00169             }
00170 
00171 #ifdef DEBUG_PRINT
00172             cout << "NodeId: " << nodeId << " probFirstClass=" << probFirstClass << endl;
00173 #endif
00174 
00175             if (Splitter.MoreSplits(minMass, nodeId))
00176             {
00177                 Children[0] = new BinaryDecisionTreeNode< T_Splitter >
00178                               ( nodeId*2, Splitter.GetDDomainSize(), Splitter.GetCSplitDim() );
00179 
00180                 Children[1] = new BinaryDecisionTreeNode< T_Splitter >
00181                               ( nodeId*2+1, Splitter.GetDDomainSize(), Splitter.GetCSplitDim() );
00182             }
00183             else
00184                 goto make_node_leaf;
00185 
00186             return true;
00187             break;
00188 
00189         default:
00190             return false;
00191         }
00192 
00193 make_node_leaf:
00194 #ifdef DEBUG_PRINT
00195 
00196         cout << "Making the node " << nodeId << " a leaf" << endl;
00197 #endif
00198 
00199         Children[0]=Children[1]=0;
00200         return false;
00201     }
00202 
00203     /// Do the inference
00204     ///
00205     /// @param Dvars         discrete variables
00206     /// @param Cvars         continuous variables
00207     double Infer(const int* Dvars, const double* Cvars)
00208     {
00209         // cout << "I am node: " << nodeId << " and my classlabel is " << classLabel << endl;
00210         if (Children[0]==0 || Children[1]==0)
00211         {
00212             // leaf node
00213             return classLabel;
00214         }
00215         else
00216         {
00217             return Children[Splitter.ChooseBranch(Dvars,Cvars)]->Infer(Dvars,Cvars);
00218         }
00219     }
00220 
00221     /// Initialize stats about pruning
00222     void InitializePruningStatistics(void)
00223     {
00224         pruningError=0;
00225         pruningTotalMass=0;
00226 
00227         if (Children[0]!=0 && Children[1]!=0)
00228         {
00229             Children[0]->InitializePruningStatistics();
00230             Children[1]->InitializePruningStatistics();
00231         }
00232     }
00233 
00234     /// Update pruning stats with new data
00235     ///
00236     /// @param Dvars         discrete variables
00237     /// @param Cvars         continuous variables
00238     /// @param classlabel    classification label
00239     void UpdatePruningStatistics(const int* Dvars, const double* Cvars,
00240                                  int classlabel /* true output */)
00241     {
00242 
00243         if (classlabel!=classLabel)
00244             pruningError+=1;
00245 
00246         pruningTotalMass+=1;
00247 
00248         if (Children[0]==0 || Children[1]==0)
00249             return; // stop the process
00250         else
00251             Children[Splitter.ChooseBranch(Dvars,Cvars)]->UpdatePruningStatistics(Dvars,Cvars,classlabel);
00252 
00253     }
00254 
00255     void FinalizePruningStatistics (void)
00256     {
00257         // nothing to do
00258     }
00259 
00260     /// Return the optimal cost for this subtree and cuts the subtree to optimal size
00261     double PruneSubtree(void)
00262     {
00263 #ifdef DEBUG_PRINT
00264         cout << "Pruneerror of " <<nodeId << " is: " << pruningError
00265         << " " << 1.0*pruningError/pruningTotalMass << endl;
00266 #endif
00267 
00268         if (Children[0] == 0 && Children[1] == 0)
00269         {
00270             // node is a leaf
00271             return pruningError;
00272         }
00273         else
00274         {
00275             double errorChildren = Children[0]->PruneSubtree()+
00276                                    Children[1]->PruneSubtree();
00277 #ifdef DEBUG_PRINT
00278 
00279             cout << "Childrenerror of " << nodeId << " is: " << errorChildren
00280             << " " << 1.0*errorChildren/pruningTotalMass << endl;
00281 #endif
00282 
00283             if (pruningError<=errorChildren)
00284             {
00285 #ifdef DEBUG_PRINT
00286                 cout << "Prunning at node " << nodeId << endl;
00287 #endif
00288                 // prune the tree here
00289                 delete Children[0];
00290                 Children[0]=0;
00291                 delete Children[1];
00292                 Children[1]=0;
00293 
00294                 return pruningError;
00295             }
00296             else
00297             {
00298                 // tree is good as it is
00299                 return errorChildren;
00300             }
00301         }
00302     }
00303 
00304     /// Output the node data to a stream
00305     ///
00306     /// @param out           stream for output
00307     void SaveToStream(ostream& out)
00308     {
00309         out << nodeId << " ";
00310         //if (Children[0] != 0 && Children[1] != 0)
00311         Splitter.SaveToStream(out);
00312         out << " " << classLabel << endl;
00313 
00314         if (Children[0] != 0 && Children[1] != 0)
00315         {
00316             Children[0]->SaveToStream(out);
00317             Children[1]->SaveToStream(out);
00318         }
00319     }
00320 
00321 };
00322 }
00323 
00324 #endif // _CLUS_BINARYDECISIONTREENODE_H_

Generated on Mon Jul 21 16:57:23 2003 for SECRET by doxygen 1.3.2