Main Page | Namespace List | Class Hierarchy | Data Structures | File List | Namespace Members | Data Fields | Globals | Related Pages

multidecisiontreenode.h

Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (c) 2003, Cornell University
00004 All rights reserved.
00005 
00006 Redistribution and use in source and binary forms, with or without
00007 modification, are permitted provided that the following conditions are met:
00008 
00009    - Redistributions of source code must retain the above copyright notice,
00010        this list of conditions and the following disclaimer.
00011    - Redistributions in binary form must reproduce the above copyright
00012        notice, this list of conditions and the following disclaimer in the
00013        documentation and/or other materials provided with the distribution.
00014    - Neither the name of Cornell University nor the names of its
00015        contributors may be used to endorse or promote products derived from
00016        this software without specific prior written permission.
00017 
00018 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
00019 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00020 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
00021 ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
00022 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
00023 CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00024 SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
00025 INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
00026 CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00027 ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
00028 THE POSSIBILITY OF SUCH DAMAGE.
00029 
00030 */
00031 
00032 // -*- C++ -*-
00033 
00034 #if !defined _CLUS_MULTIDECISIONTREENODE_H
00035 #define _CLUS_MULTIDECISIONTREENODE_H
00036 
00037 #include <list>
00038 
00039 #include "vec.h"
00040 #include "continuouslineartransformation.h"
00041 #include "discretepermutationtransformation.h"
00042 #include "statisticsgatherers.h"
00043 
00044 using namespace TNT;
00045 
00046 namespace CLUS
00047 {
00048 
00049 template< class T_Splitter >
00050 class MultiDecisionTreeNode
00051 {
00052     /// unique identifier of the cluster for a regression tree.
00053     int nodeId;
00054     
00055     /// the state of the node. At creation em. At load stable
00056     enum state { stable, split} State;
00057     
00058     /// the children of this node
00059     MultiDecisionTreeNode< T_Splitter > * Children[2], *parent;
00060     
00061     /// the predicted class labels for each of the trees
00062     int classLabel;
00063     
00064     /// splitter for split criterion
00065     T_Splitter Splitter;
00066     
00067     /// statistics for pruning. Individual statistics are maintained for each dataset
00068     /// sum of squared differences between prediction and true value
00069     Vector<int> pruningMistakes;
00070 
00071     /// number of samples in this node for pruning 
00072     Vector<int> pruningSamples;
00073 
00074     double ComputePruningCost(void)
00075     {
00076         int totalMistakes=0;
00077         int totalSamples=0;
00078         for (int i=0; i<pruningMistakes.dim(); i++)
00079         {
00080             totalMistakes+=pruningMistakes[i];
00081             totalSamples+=pruningSamples[i];
00082         }
00083         if (totalMistakes==0)
00084             return 0.0;
00085         return ((double)totalMistakes)/*/totalSamples*/;
00086     }
00087 
00088     double ComputeNodeWeight(void)
00089     {
00090         /// @todo add more code to implement new weighing schemes
00091         if (State==stable)
00092             return 0.0;
00093         else
00094             return 1.0;
00095     }
00096 
00097 public:
00098     MultiDecisionTreeNode( MultiDecisionTreeNode< T_Splitter >* Parent,
00099                            int NodeId, const Vector<int> & DDomainSize,
00100                            int CsplitDim, int NoDatasets,
00101                            DiscretePermutationTransformation& discreteTransformer,
00102                            ContinuousLinearTransformation& continuousTransformer):
00103             nodeId(NodeId), State(split), parent(Parent), Splitter(DDomainSize,CsplitDim,NoDatasets,
00104                     discreteTransformer, continuousTransformer),
00105             pruningMistakes(NoDatasets), pruningSamples(NoDatasets)
00106     {
00107         Children[0]=Children[1]=0;
00108         Splitter.setNodeID(NodeId);
00109         cout << "Created node with ID=" << NodeId << endl;
00110     }
00111 
00112 
00113     ~MultiDecisionTreeNode(void)
00114     {
00115         if (Children[0]!=0)
00116             delete Children[0];
00117         Children[0]=0;
00118 
00119         if (Children[1]!=0)
00120             delete Children[1];
00121         Children[1]=0;
00122     }
00123 
00124     void StartLearningEpoch(void)
00125     {
00126         switch (State)
00127         {
00128         case stable:
00129             if (Children[0]!=0)
00130             {
00131                 Children[0]->StartLearningEpoch();
00132                 Children[1]->StartLearningEpoch();
00133             }
00134             break;
00135         case split:
00136             Splitter.InitializeSplitStatistics();
00137             break;
00138         }
00139     }
00140 
00141     void LearnSample(const int* Dvars, const double* Cvars, int classlabel, int datasetNo)
00142     {
00143         switch (State)
00144         {
00145         case stable:
00146             // Pass the sample to the right child
00147             if (Children[0]!=0)
00148             {
00149                 Children[Splitter.ChooseBranch(Dvars,Cvars)]->LearnSample(Dvars,Cvars,classlabel,datasetNo);
00150             }
00151             break;
00152         case split:
00153             Splitter.UpdateSplitStatistics(Dvars, Cvars, classlabel, datasetNo);
00154             break;
00155         }
00156     }
00157 
00158     /** If the node is in splitting stage find the split attribute and if
00159        it has no shift add attribute to the attList.
00160 
00161        @return false if we have a bogus split and we should make the parent a leaf
00162     */
00163     bool FindSplitAttributes(list<int>& attList)
00164     {
00165         switch (State)
00166         {
00167         case stable:
00168             if (Children[0]!=0)
00169             {
00170                 if (!Children[0]->FindSplitAttributes(attList) ||
00171                         !Children[1]->FindSplitAttributes(attList) )
00172                 {
00173                     delete Children[0];
00174                     Children[0]=0;
00175                     delete Children[1];
00176                     Children[1]=0;
00177                 }
00178             }
00179             break;
00180 
00181         case split:
00182 
00183             if (Splitter.GotNoData())
00184             {
00185                 // make the parent a leaf
00186                 return false;
00187             }
00188 
00189 
00190             //  cout << "I am node:" << nodeId << endl;
00191             if (!Splitter.ComputeSplitVariable(attList))
00192             {
00193                 // make the node a leaf
00194                 //  cout << "Making node a leaf: " << nodeId << endl;
00195                 State=stable;
00196                 Children[1]=0;
00197                 classLabel=Splitter.ComputeClassLabel();
00198                 // distroy datastructures
00199             }
00200             break;
00201         }
00202 
00203         return true;
00204     }
00205 
00206     double ComputeTotalWeight(void)
00207     {
00208         double totalWeight=ComputeNodeWeight();
00209         if (State==stable)
00210             if (Children[0]!=0)
00211                 totalWeight+=Children[0]->ComputeTotalWeight()+
00212                              Children[1]->ComputeTotalWeight();
00213         return totalWeight;
00214     }
00215 
00216     /// sums 1/v_n, for the subtree with this as root, where v_n is the variance of the split point for attribute, as computed at node n
00217     double computeSumOfVarianceInverted(int attribute, int dataSetIndex)
00218     {
00219         if (Splitter.hasContinuousData(attribute, dataSetIndex))
00220         {
00221             double sum = (double)1/Splitter.getSplitVariance(attribute, dataSetIndex);
00222             if (State==stable)
00223                 if (Children[0]!=0)
00224                 {
00225                     sum+= Children[0]->computeSumOfVarianceInverted(attribute, dataSetIndex) +
00226                           Children[1]->computeSumOfVarianceInverted(attribute, dataSetIndex);
00227                 }
00228             return sum;
00229         }
00230         else
00231             return 0.0;
00232     }
00233 
00234     /// Combines shifts for attribute for subtree with this as root.  Note: needs to be normalized.
00235     double combineSplits(int attribute, int dataSetIndex, double* sumInvVars)
00236     {
00237 
00238         double total=0;
00239         if (Splitter.hasContinuousData(attribute, dataSetIndex))
00240         {
00241             double var = Splitter.getSplitVariance(attribute, dataSetIndex) ;
00242             assert(!isnan(var) && var!=0);
00243 
00244             //  cerr << "SplitVariance: " << var << endl;
00245             double splitpt =  Splitter.getTentativeSplitPoint(attribute, dataSetIndex);
00246             assert(!isnan(splitpt));
00247             total = splitpt/var;
00248             *sumInvVars+=1.0/var;
00249 
00250 
00251             if (State==stable)
00252                 if (Children[0]!=0)
00253                 {
00254                     total+= Children[0]->combineSplits(attribute, dataSetIndex, sumInvVars) +
00255                             Children[0]->combineSplits(attribute, dataSetIndex, sumInvVars);
00256                 }
00257         }
00258         //  cerr<< "In combineSplits for nodeID " << nodeId <<endl;
00259         return total;
00260     }
00261 
00262     double combineLabeledCenters(int attribute, int dataSetIndex, double* sumInvVars)
00263     {
00264 
00265         double total = 0.0;
00266         double posWeight = Splitter.getLabeledCount(attribute, dataSetIndex, true);
00267         double negWeight = Splitter.getLabeledCount(attribute, dataSetIndex, false);
00268         double posMeani =  Splitter.getLabeledCenter(attribute, dataSetIndex, true);
00269         double negMeani = Splitter.getLabeledCenter(attribute, dataSetIndex, false);
00270         double posMean0 =  Splitter.getLabeledCenter(attribute, 0, true);
00271         double negMean0 = Splitter.getLabeledCenter(attribute, 0, false);
00272         double posVar=Splitter.getLabeledCenterVariance(attribute, dataSetIndex, true);
00273         double negVar = Splitter.getLabeledCenterVariance(attribute, dataSetIndex, false);
00274         double combinedVar = (posWeight*posVar + negWeight*negVar)/(posWeight+negWeight);
00275         double currShift= (posWeight*(posMean0-posMeani) + negWeight*(negMean0-negMeani))/ (posWeight+negWeight);
00276         if (!isnan(combinedVar) && NonZero(combinedVar) && !isnan(currShift))
00277         {
00278             total += currShift/combinedVar;
00279 
00280             *sumInvVars+=1.0/combinedVar;
00281         }
00282         if (State==stable)
00283             if (Children[0]!=0)
00284             {
00285                 total+= Children[0]->combineLabeledCenters(attribute, dataSetIndex, sumInvVars) +
00286                         Children[0]->combineLabeledCenters(attribute, dataSetIndex, sumInvVars);
00287             }
00288         // total+= (posWeight*(posMean0-posMeani) + negWeight*(negMean0-negMeani))/ posWeight+negWeight;
00289         return total;
00290     }
00291     bool labeledMeansSignificant(int attribute, int dataSetIndex)
00292     {
00293         return Splitter.labeledMeansSignificant(attribute, dataSetIndex);
00294     }
00295 
00296     bool negMeanLessThanPos(int attribute, int dataSetIndex)
00297     {
00298         return Splitter.negMeanLessThanPos(attribute, dataSetIndex);
00299     }
00300 
00301     double combineCenters(int attribute, int dataSetIndex, double* sumInvVars)
00302     {
00303 
00304         double total=0.0;
00305         if (Splitter.hasContinuousData(attribute, dataSetIndex))
00306         {
00307             double var = Splitter.getVariance(attribute, dataSetIndex) ;
00308 
00309 
00310             //  cerr << "Variance: " << var << endl;
00311             double splitpt =  Splitter.getTentativeCenter(attribute, dataSetIndex);
00312             // assert(!isnan(splitpt));
00313             if (!isnan(var) && NonZero(var) && !isnan(splitpt))
00314             {
00315                 total = splitpt/var;
00316                 *sumInvVars+=1.0/var;
00317             }
00318 
00319             if (State==stable)
00320                 if (Children[0]!=0)
00321                 {
00322                     total+= Children[0]->combineCenters(attribute, dataSetIndex, sumInvVars) +
00323                             Children[0]->combineCenters(attribute, dataSetIndex, sumInvVars);
00324                 }
00325         }
00326         //  cerr<< "In combineCenters for nodeID " << nodeId <<endl;
00327         return total;
00328     }
00329 
00330     void AddDiscreteShiftStatistics(int SplitAttribute, Vector< BinomialStatistics >& statistics)
00331     {
00332         Splitter.AddDiscreteShiftStatistics(SplitAttribute, ComputeNodeWeight(), statistics);
00333         if (State==stable)
00334             if (Children[0]!=0)
00335             {
00336                 Children[0]->AddDiscreteShiftStatistics(SplitAttribute,statistics);
00337                 Children[1]->AddDiscreteShiftStatistics(SplitAttribute,statistics);
00338             }
00339     }
00340 
00341     void AddContinuousShiftStatistics(int SplitAttribute, Vector< NormalStatistics >& statistics)
00342     {
00343         Splitter.AddContinuousShiftStatistics(SplitAttribute, ComputeNodeWeight(), statistics);
00344         if (State==stable)
00345             if (Children[0]!=0)
00346             {
00347                 Children[0]->AddContinuousShiftStatistics(SplitAttribute,statistics);
00348                 Children[1]->AddContinuousShiftStatistics(SplitAttribute,statistics);
00349             }
00350     }
00351 
00352     /// @return true if more learning has to be done in future
00353     bool StopLearningEpoch(int splitType, int min_no_datapoints)
00354     {
00355         switch (State)
00356         {
00357         case stable:
00358             if (Children[0]!=0)
00359                 return Children[0]->StopLearningEpoch(splitType, min_no_datapoints)
00360                        | Children[1]->StopLearningEpoch(splitType, min_no_datapoints); // || cannot be used since is shortcuted
00361             else
00362                 return false;
00363 
00364         case split:
00365 
00366             State=stable;
00367             Splitter.ComputeSplitPoint();
00368             classLabel=Splitter.ComputeClassLabel();
00369 
00370             // Splitter.DeleteTemporaryStatistics();
00371 
00372             if (Splitter.MoreSplits(min_no_datapoints, nodeId))
00373             {
00374                 Children[0] = new MultiDecisionTreeNode< T_Splitter >
00375                               ( this, nodeId*2, Splitter.GetDDomainSize(), Splitter.GetCSplitDim(),
00376                                 Splitter.GetNoDatasets(),
00377                                 Splitter.GetDiscreteTransformer(), Splitter.GetContinuousTransformer());
00378 
00379                 Children[1] = new MultiDecisionTreeNode< T_Splitter >
00380                               ( this, nodeId*2+1, Splitter.GetDDomainSize(), Splitter.GetCSplitDim(),
00381                                 Splitter.GetNoDatasets(),
00382                                 Splitter.GetDiscreteTransformer(), Splitter.GetContinuousTransformer());
00383             }
00384             else
00385             {
00386                 Children[0]=Children[1]=0;
00387             }
00388 
00389             return true;
00390             break;
00391 
00392         default:
00393             return true;
00394 
00395         }
00396     }
00397 
00398     Permutation ComputeDiscreteShift(bool label, int attribute, int datasetIndex)
00399     {
00400 
00401         return Splitter.ComputeDiscreteShift(label, attribute, datasetIndex);
00402 
00403     }
00404 
00405     /// Does the inference
00406     double Infer(const int* Dvars, const double* Cvars)
00407     {
00408         // cout << "I am node: " << nodeId << " and my classlabel is " << classLabel << endl;
00409         if (Children[0]==0)
00410         {
00411             // leaf node
00412             return classLabel;
00413         }
00414         else
00415         {
00416             return Children[Splitter.ChooseBranch(Dvars,Cvars)]->Infer(Dvars,Cvars);
00417         }
00418     }
00419 
00420     void InitializePruningStatistics(void)
00421     {
00422         pruningMistakes=0;
00423         pruningSamples=0;
00424         if (Children[0]!=0 && Children[1]!=0)
00425         {
00426             Children[0]->InitializePruningStatistics();
00427             Children[1]->InitializePruningStatistics();
00428         }
00429     }
00430 
00431     void UpdatePruningStatistics(const int* Dvars, const double* Cvars,  int classlabel /* true output */, int datasetNo  )
00432     {
00433         pruningSamples[datasetNo]++;
00434         if (classLabel!=classlabel)
00435             pruningMistakes[datasetNo]++;
00436         // update pruning statistics for the proper children
00437         if (Children[0]!=0 && Children[1]!=0)
00438             Children[Splitter.ChooseBranch(Dvars,Cvars)]->UpdatePruningStatistics(Dvars,Cvars,classlabel, datasetNo);
00439     }
00440 
00441     void FinalizePruningStatistics (void)
00442     {
00443         // nothing to do
00444     }
00445 
00446     /** Returns the optimal cost for this subtree and cuts the subtree to optimal size */
00447     double PruneSubtree(void)
00448     { // double alpha /* alpha is the cost for a leaf */){
00449         double pruningCost=ComputePruningCost();
00450         if (Children[0] == 0 && Children[1] == 0)
00451         {
00452             // node is a leaf
00453             return pruningCost; // CHANGE
00454         }
00455         else
00456         {
00457             // node is an intermediary node
00458             double pruningCostChildren=/*.5**/(Children[0]->PruneSubtree()+Children[1]->PruneSubtree());
00459 
00460             if (pruningCost<pruningCostChildren)
00461             {
00462                 // prune the tree here
00463                 delete Children[0];
00464                 Children[0]=0;
00465                 delete Children[1];
00466                 Children[1]=0;
00467 
00468                 return pruningCost;
00469             }
00470             else
00471             {
00472                 // tree is good as it is
00473                 return pruningCostChildren;
00474             }
00475         }
00476     }
00477 
00478     void SaveToStream(ostream& out)
00479     {
00480         out << "{" << endl << " nodeID " << nodeId <<  endl;
00481 
00482 
00483         out << endl;
00484 
00485 
00486         if (Children[0]!=0 && Children[1]!=0)
00487         {
00488             Splitter.SaveToStream(out, false);
00489             Children[0]->SaveToStream(out);
00490             Children[1]->SaveToStream(out);
00491         }
00492         else
00493         {
00494             out << " leaf " <<  endl;
00495             Splitter.SaveToStream(out, true);
00496         }
00497     }
00498 };
00499 
00500 }
00501 
00502 #endif // _CLUS_MULTIDECISIONTREENODE_H

Generated on Mon Jul 21 16:57:24 2003 for SECRET by doxygen 1.3.2