Main Page | Namespace List | Class Hierarchy | Data Structures | File List | Namespace Members | Data Fields | Globals | Related Pages

binaryprobabilisticdecisiontreenode.h

Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (c) 2003, Cornell University
00004 All rights reserved.
00005 
00006 Redistribution and use in source and binary forms, with or without
00007 modification, are permitted provided that the following conditions are met:
00008 
00009    - Redistributions of source code must retain the above copyright notice,
00010        this list of conditions and the following disclaimer.
00011    - Redistributions in binary form must reproduce the above copyright
00012        notice, this list of conditions and the following disclaimer in the
00013        documentation and/or other materials provided with the distribution.
00014    - Neither the name of Cornell University nor the names of its
00015        contributors may be used to endorse or promote products derived from
00016        this software without specific prior written permission.
00017 
00018 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
00019 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00020 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
00021 ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
00022 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
00023 CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00024 SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
00025 INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
00026 CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00027 ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
00028 THE POSSIBILITY OF SUCH DAMAGE.
00029 
00030 */
00031 
00032 // -*- C++ -*-
00033 
00034 #ifndef _CLUS_BINARYPROBABILISTICDECISIONTREENODE_H_
00035 #define _CLUS_BINARYPROBABILISTICDECISIONTREENODE_H_
00036 
#include <iostream>

#include "vec.h"

#ifdef DEBUG_PRINT
using namespace std;
#endif

using namespace TNT;
00045 
00046 namespace CLUS
00047 {
00048 
00049 template< class T_Splitter >
00050 class BinaryProbabilisticDecisionTreeNode
00051 {
00052 protected:
00053 
00054     /// unique identifier of the cluster for a regression tree
00055     int nodeId;
00056     
00057     /// the state of the node. At creation em. At load stable
00058     enum state { stable, split, bootstrap } State;
00059     
00060     /// the children of this node
00061     BinaryProbabilisticDecisionTreeNode< T_Splitter > * Children[2];    
00062 
00063     /// the probability to return first class, for the second probability is 1-probFirstClass
00064     double probFirstClass;
00065 
00066     /// Splitter for split criterion
00067     T_Splitter Splitter;
00068 
00069     /// pruning statistics. Their ratio is the error
00070     double pruningError, pruningTotalMass, pruningTotalMassLeft; 
00071 
00072 public:
00073     BinaryProbabilisticDecisionTreeNode(int NodeId, const Vector<int> & DDomainSize,
00074                                         int CsplitDim):
00075             nodeId(NodeId), State(split), Splitter(DDomainSize,CsplitDim)
00076     {
00077         probFirstClass=1.0; // predict first class by default
00078     }
00079 
00080     ~BinaryProbabilisticDecisionTreeNode(void)
00081     {
00082         if (Children[0]!=0)
00083             delete Children[0];
00084         Children[0]=0;
00085 
00086 
00087         if (Children[1]!=0)
00088             delete Children[1];
00089         Children[1]=0;
00090     }
00091 
00092     void StartLearningEpoch(void)
00093     {
00094         switch (State)
00095         {
00096         case stable:
00097             if (Children[0]!=0)
00098             {
00099                 Children[0]->StartLearningEpoch();
00100                 Children[1]->StartLearningEpoch();
00101             }
00102             break;
00103         case split:
00104             Splitter.InitializeSplitStatistics();
00105             break;
00106 #if 0
00107 
00108         case bootstrap:
00109             Splitter.InitializeBootstrapSplitStatistics();
00110             break;
00111 #endif
00112 
00113         }
00114     }
00115 
00116     void LearnSample(const int* Dvars, const double* Cvars,
00117                      int classlabel, double probability, double threshold)
00118     {
00119         switch (State)
00120         {
00121         case stable:
00122             {
00123                 // if leaf do nothing
00124                 if (Children[0]==0 || Children[1]==0)
00125                     return;
00126 
00127                 // propagate the learning
00128                 double probabilityLeft = probability*Splitter.ProbabilityLeft(Dvars,Cvars);
00129                 double probabilityRight = probability-probabilityLeft;
00130                 if (probabilityLeft>=threshold)
00131                 {
00132                     Children[0]->LearnSample(Dvars,Cvars,classlabel,probabilityLeft,threshold);
00133                 }
00134 
00135                 if (probabilityRight>=threshold)
00136                 {
00137                     Children[1]->LearnSample(Dvars,Cvars,classlabel,probabilityRight,threshold);
00138                 }
00139             }
00140             break;
00141         case split:
00142             Splitter.UpdateSplitStatistics(Dvars, Cvars, classlabel, probability);
00143             break;
00144         }
00145     }
00146 
00147     bool StopLearningEpoch(double minMass)
00148     {
00149         switch (State)
00150         {
00151         case stable:
00152             if (Children[0]!=0)
00153                 return Children[0]->StopLearningEpoch(minMass)
00154                        | Children[1]->StopLearningEpoch(minMass); // || cannot be used since is shortcuted
00155             else
00156                 return false;
00157 
00158         case split:
00159             State=stable;
00160 
00161             probFirstClass=Splitter.ComputeProbabilityFirstClass();
00162 
00163             if (!Splitter.ComputeSplitVariable())
00164             {
00165                 // have to make the node a leaf
00166                 goto make_node_leaf;
00167             }
00168 
00169 #ifdef DEBUG_PRINT
00170             cout << "NodeId: " << nodeId << " probFirstClass=" << probFirstClass << endl;
00171 #endif
00172 
00173             if (Splitter.MoreSplits(minMass, nodeId))
00174             {
00175                 Children[0] = new BinaryProbabilisticDecisionTreeNode< T_Splitter >
00176                               ( nodeId*2, Splitter.GetDDomainSize(), Splitter.GetCSplitDim() );
00177 
00178                 Children[1] = new BinaryProbabilisticDecisionTreeNode< T_Splitter >
00179                               ( nodeId*2+1, Splitter.GetDDomainSize(), Splitter.GetCSplitDim() );
00180             }
00181             else
00182                 goto make_node_leaf;
00183 
00184             return true;
00185             break;
00186 
00187         default:
00188             return false;
00189         }
00190 
00191 make_node_leaf:
00192 #ifdef DEBUG_PRINT
00193 
00194         cout << "Making the node " << nodeId << " a leaf" << endl;
00195 #endif
00196 
00197         Children[0]=Children[1]=0;
00198         return false;
00199     }
00200 
00201     /// Use Baesian decision mode
00202     double ProbabilityFirstClass(const int* Dvars, const double* Cvars,
00203                                  double probability, double threshold)
00204     {
00205         if (Children[0]==0)
00206         {
00207             // leaf node
00208             return probability*probFirstClass;
00209         }
00210         else
00211         {
00212             double rez=0.0;
00213             double probabilityLeft = probability*Splitter.ProbabilityLeft(Dvars,Cvars);
00214             double probabilityRight = probability-probabilityLeft;
00215 
00216             if (probabilityLeft>=threshold)
00217             {
00218                 rez+=Children[0]->ProbabilityFirstClass(Dvars,Cvars,probabilityLeft,threshold);
00219             }
00220 
00221             if (probabilityRight>=threshold)
00222             {
00223                 rez+=Children[1]->ProbabilityFirstClass(Dvars,Cvars,probabilityRight,threshold);
00224             }
00225 
00226             return rez;
00227         }
00228     }
00229 
00230     void InitializePruningStatistics(void)
00231     {
00232         pruningError=0.0;
00233         pruningTotalMass=0.0;
00234         pruningTotalMassLeft=0.0;
00235         if (Children[0]!=0 && Children[1]!=0)
00236         {
00237             Children[0]->InitializePruningStatistics();
00238             Children[1]->InitializePruningStatistics();
00239         }
00240     }
00241 
00242     void UpdatePruningStatistics(const int* Dvars, const double* Cvars,  int classlabel /* true output */,
00243                                  double probability, double threshold)
00244     {
00245 
00246 #if 1
00247         int predClass = probFirstClass>.5 ? 0 : 1;
00248 
00249         if (classlabel!=predClass)
00250             pruningError+=probability;
00251 #else
00252 
00253         if (classlabel==0)
00254         {
00255             // true result is first class label
00256             pruningError+=probability*pow2(1.0-probFirstClass);
00257         }
00258         else
00259         {
00260             pruningError+=probability*pow2(probFirstClass);
00261         }
00262 #endif
00263 
00264         pruningTotalMass+=probability;
00265 
00266         if (Children[0]==0 || Children[1]==0)
00267             return; // stop the process
00268 
00269         double probabilityLeft = probability*Splitter.ProbabilityLeft(Dvars,Cvars);
00270         pruningTotalMassLeft+=probabilityLeft;
00271         double probabilityRight = probability-probabilityLeft;
00272 
00273         if (probabilityLeft>=threshold*probability)
00274         {
00275             Children[0]->UpdatePruningStatistics(Dvars,Cvars,classlabel,probabilityLeft,threshold);
00276         }
00277 
00278         if (probabilityRight>=threshold*probability)
00279         {
00280             Children[1]->UpdatePruningStatistics(Dvars,Cvars,classlabel,probabilityRight,threshold);
00281         }
00282     }
00283 
00284     void FinalizePruningStatistics (void)
00285     {
00286         // nothing to do
00287     }
00288 
00289     /// Returns the optimal cost for this subtree and cuts the subtree to optimal size
00290     double PruneSubtree(void)
00291     {
00292         //      double error=pruningTotalMass>0.0 ?
00293         //pruningError/pruningTotalMass : 0.0;  // normalized error
00294 
00295         double error=pruningError;
00296 #ifdef DEBUG_PRINT
00297 
00298         cout << "Pruneerror of " <<nodeId << " is: " << error
00299         << " " << error/pruningTotalMass << endl;
00300 #endif
00301 
00302         if (Children[0] == 0 && Children[1] == 0)
00303         {
00304             // node is a leaf
00305             return error;
00306         }
00307         else
00308         {
00309             //double errorChildren = pruningTotalMassLeft/pruningTotalMass*Children[0]->PruneSubtree()+
00310             //  (pruningTotalMass-pruningTotalMassLeft)/pruningTotalMass*Children[1]->PruneSubtree();
00311 
00312             double errorChildren = Children[0]->PruneSubtree()+
00313                                    Children[1]->PruneSubtree();
00314 #ifdef DEBUG_PRINT
00315 
00316             cout << "Childrenerror of " << nodeId << " is: " << errorChildren
00317             << " " << errorChildren/pruningTotalMass << endl;
00318 #endif
00319 
00320             if (error<=errorChildren)
00321             {
00322 #ifdef DEBUG_PRINT
00323                 cout << "Prunning at node " << nodeId << endl;
00324 #endif
00325                 // prune the tree here
00326                 delete Children[0];
00327                 Children[0]=0;
00328                 delete Children[1];
00329                 Children[1]=0;
00330 
00331                 return error;
00332             }
00333             else
00334             {
00335                 // tree is good as it is
00336                 return errorChildren;
00337             }
00338         }
00339     }
00340 
00341     void SaveToStream(ostream& out)
00342     {
00343         out << nodeId << " ";
00344         //if (Children[0] != 0 && Children[1] != 0)
00345         Splitter.SaveToStream(out);
00346         out << " " << probFirstClass << endl;
00347 
00348         if (Children[0] != 0 && Children[1] != 0)
00349         {
00350             Children[0]->SaveToStream(out);
00351             Children[1]->SaveToStream(out);
00352         }
00353     }
00354 
00355 };
00356 }
00357 
00358 #endif // _CLUS_BINARYPROBABILISTICDECISIONTREENODE_H_

Generated on Mon Jul 21 16:57:24 2003 for SECRET by doxygen 1.3.2