Main Page | Namespace List | Class Hierarchy | Data Structures | File List | Namespace Members | Data Fields | Globals | Related Pages

binaryprobabilisticdecisiontree.h

Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (c) 2003, Cornell University
00004 All rights reserved.
00005 
00006 Redistribution and use in source and binary forms, with or without
00007 modification, are permitted provided that the following conditions are met:
00008 
00009    - Redistributions of source code must retain the above copyright notice,
00010        this list of conditions and the following disclaimer.
00011    - Redistributions in binary form must reproduce the above copyright
00012        notice, this list of conditions and the following disclaimer in the
00013        documentation and/or other materials provided with the distribution.
00014    - Neither the name of Cornell University nor the names of its
00015        contributors may be used to endorse or promote products derived from
00016        this software without specific prior written permission.
00017 
00018 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
00019 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00020 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
00021 ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
00022 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
00023 CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00024 SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
00025 INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
00026 CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00027 ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
00028 THE POSSIBILITY OF SUCH DAMAGE.
00029 
00030 */
00031 
00032 // -*- C++ -*-
00033 
00034 #ifndef _CLUS_BINARYPROBABILISTICDECISIONTREE_H_
00035 #define _CLUS_BINARYPROBABILISTICDECISIONTREE_H_
00036 
00037 #include "machine.h"
00038 #include "binaryprobabilisticdecisiontreenode.h"
00039 #include "dctraingen.h"
00040 #include <iostream>
00041 
00042 using namespace std;
00043 
00044 namespace CLUS
00045 {
00046 
00047 template< class T_Splitter >
00048 class BinaryProbabilisticDecisionTree : public Machine
00049 {
00050 protected:
00051     BinaryProbabilisticDecisionTreeNode< T_Splitter >* root;
00052 
00053     /// vector of discrete domain sizes
00054     const Vector<int>& dDomainSize;
00055     
00056     /// number of discrete split variables
00057     int dsplitDim;
00058 
00059     /// number of continuous split variables
00060     int csplitDim;
00061 
00062     /// the minimum mass (sum of weights) to continue splitting
00063     double minMass;
00064 
00065     /// the minimum value of the probability to belong to a partition to be considered
00066     double threshold;
00067 
00068     /// do we do bootstrapping
00069     bool bootstrapping;
00070 
00071     /// number of repetitions for bootstrapping
00072     int bootstrappingRepetitions;
00073 
00074 public:
00075     BinaryProbabilisticDecisionTree(const Vector<int>& DDomainSize,
00076                                     int CsplitDim):
00077             Machine(CsplitDim,1),dDomainSize(DDomainSize),
00078             dsplitDim(DDomainSize.dim()-1),
00079             csplitDim(CsplitDim)
00080     {
00081         minMass=10.0;
00082         threshold=.01;
00083         root=NULL;
00084         bootstrapping=false;
00085         bootstrappingRepetitions=1000;
00086     }
00087 
00088     ~BinaryProbabilisticDecisionTree(void)
00089     {
00090         if (root!=0)
00091             delete root;
00092     }
00093     
00094     virtual int InDim(void)
00095     {
00096         return dsplitDim+csplitDim;
00097     }
00098 
00099     virtual string TypeName(void)
00100     {
00101         return string("BinaryProbabilisticDecisionTree");
00102     }
00103 
00104     virtual void Infer(void)
00105     {
00106         if (root==0)
00107             return;
00108         // translate the first dsplitDim inputs into ints
00109         int Dvars[MAX_VARIABLES];
00110         for(int i=0; i<dsplitDim; i++)
00111         {
00112             Dvars[i]=(int)(*input)[i];
00113         }
00114 
00115         double Cvars[MAX_VARIABLES];
00116         for (int i=0; i<csplitDim; i++)
00117         {
00118             Cvars[i]=(*input)[i+dsplitDim];
00119         }
00120 
00121         // ask the root what is the probability to have class label 0
00122         // if >.5 return 0 otherwise 1
00123         // this has a small bias for first class
00124         if (root->ProbabilityFirstClass(Dvars,Cvars,1.0,threshold)>.5)
00125             output[0]=0;
00126         else
00127             output[0]=1;
00128 
00129 #ifdef DEBUG_PRINT
00130 
00131         cout << "\t\t";
00132         for(int i=0; i<dsplitDim; i++)
00133             cout << Dvars[i] << " ";
00134         cout << "\t";
00135         for (int i=0; i<csplitDim; i++)
00136             cout << Cvars[i] << " ";
00137 
00138         double pFC=root->ProbabilityFirstClass(Dvars,Cvars,1.0,threshold);
00139         cout << pFC << " - " << output[0] << endl;
00140 #endif
00141 
00142     }
00143 
00144     virtual void Identify(void)
00145     {
00146         const Matrix<double>& ctrainData = training->GetTrainingData();
00147         const Matrix<int>& dtrainData = dynamic_cast< DCTrainingData* >( training )
00148                                         -> GetDiscreteTrainingData();
00149 
00150         int M=ctrainData.num_rows();
00151 
00152         // create the root and give it the Id 1
00153         root = new BinaryProbabilisticDecisionTreeNode< T_Splitter >
00154                (1, dDomainSize, csplitDim);
00155 
00156         do
00157         {
00158             root->StartLearningEpoch();
00159             for(int i=0; i<M; i++)
00160             {
00161                 int classLbl=dtrainData[i][dsplitDim];
00162                 root->LearnSample(dtrainData[i],ctrainData[i],classLbl,1.0,threshold);
00163             }
00164         }
00165         while (root->StopLearningEpoch(minMass));
00166 
00167         cout << "End Learning" << endl;
00168         // PrintSizeTree();
00169     }
00170 
00171     virtual void Prune(void)
00172     {
00173         const Matrix<double>& ctrainData = pruning->GetTrainingData();
00174         const Matrix<int>& dtrainData = dynamic_cast< DCTrainingData* >( pruning )
00175                                         -> GetDiscreteTrainingData();
00176 
00177         int M=ctrainData.num_rows();
00178 
00179         root->InitializePruningStatistics();
00180         for(int i=0; i<M; i++)
00181         {
00182             int classLbl=dtrainData[i][dsplitDim];
00183             root->UpdatePruningStatistics(dtrainData[i], ctrainData[i], classLbl, 1.0, threshold);
00184         }
00185         root->FinalizePruningStatistics();
00186 
00187         // now cut the tree to the right size
00188         double cost=root->PruneSubtree();
00189         cout << "RMSN after pruning is:" << cost/M << endl;
00190         // PrintSizeTree();
00191     }
00192 
00193     virtual int SetOption(char* name, char* val)
00194     {
00195         if (strcmp(name,"MinMass")==0)
00196             minMass = atof(val);
00197         else
00198             if (strcmp(name,"Threshold")==0)
00199                 threshold = atof(val);
00200             else
00201                 return Machine::SetOption(name,val);
00202         return 1;
00203     }
00204 
00205     virtual void SaveToStream(ostream& out)
00206     {
00207         out << TypeName() << " (  " << dsplitDim << " " << csplitDim << " ) { "  << endl;
00208 
00209         out << "[ ";
00210         for(int i=0; i<dsplitDim; i++)
00211             out << dDomainSize[i] << " ";
00212         out << "]" << endl;
00213 
00214         root->SaveToStream(out);
00215 
00216         out << '}' << endl;
00217     }
00218 };
00219 
00220 }
00221 
00222 #endif // _CLUS_BINARYPROBABILISTICDECISIONTREE_H_

Generated on Mon Jul 21 16:57:24 2003 for SECRET by doxygen 1.3.2