Main Page | Namespace List | Class Hierarchy | Data Structures | File List | Namespace Members | Data Fields | Globals | Related Pages

binarydecisiontree.h

Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (c) 2003, Cornell University
00004 All rights reserved.
00005 
00006 Redistribution and use in source and binary forms, with or without
00007 modification, are permitted provided that the following conditions are met:
00008 
00009    - Redistributions of source code must retain the above copyright notice,
00010        this list of conditions and the following disclaimer.
00011    - Redistributions in binary form must reproduce the above copyright
00012        notice, this list of conditions and the following disclaimer in the
00013        documentation and/or other materials provided with the distribution.
00014    - Neither the name of Cornell University nor the names of its
00015        contributors may be used to endorse or promote products derived from
00016        this software without specific prior written permission.
00017 
00018 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
00019 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00020 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
00021 ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
00022 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
00023 CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00024 SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
00025 INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
00026 CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00027 ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
00028 THE POSSIBILITY OF SUCH DAMAGE.
00029 
00030 */
00031 
00032 // -*- C++ -*-
00033 
00034 #ifndef _CLUS_BINARYDECISIONTREE_H_
00035 #define _CLUS_BINARYDECISIONTREE_H_
00036 
00037 #include "machine.h"
00038 #include "binarydecisiontreenode.h"
00039 #include "dctraingen.h"
00040 #include <iostream>
00041 
00042 using namespace std;
00043 
00044 namespace CLUS
00045 {
00046 
00047 /// Implements the binary decision tree
00048 template< class T_Splitter >
00049 class BinaryDecisionTree : public Machine
00050 {
00051 protected:
00052     BinaryDecisionTreeNode< T_Splitter >* root;
00053 
00054     /// vector of discrete domain sizes
00055     const Vector<int>& dDomainSize;
00056 
00057     /// number of discrete split variables
00058     int dsplitDim;
00059     
00060     /// number of continuous split variables
00061     int csplitDim;
00062 
00063     /// the minimum mass (sum of weights) to continue splitting
00064     int minMass; 
00065 
00066 public:
00067     BinaryDecisionTree(const Vector<int>& DDomainSize,
00068                        int CsplitDim):
00069             Machine(CsplitDim,1),dDomainSize(DDomainSize),
00070             dsplitDim(DDomainSize.dim()-1),
00071             csplitDim(CsplitDim)
00072     {
00073         minMass=10;
00074         root=NULL;
00075     }
00076 
00077     ~BinaryDecisionTree(void)
00078     {
00079         if (root!=0)
00080             delete root;
00081     }
00082 
00083     virtual int InDim(void)
00084     {
00085         return dsplitDim+csplitDim;
00086     }
00087     
00088     virtual string TypeName(void)
00089     {
00090         return string("BinaryDecisionTree");
00091     }
00092     
00093     virtual void Infer(void)
00094     {
00095         if (root==0)
00096             return;
00097         // translate the first dsplitDim inputs into ints
00098         int Dvars[MAX_VARIABLES];
00099         for(int i=0; i<dsplitDim; i++)
00100         {
00101             Dvars[i]=(int)(*input)[i];
00102         }
00103 
00104         double Cvars[MAX_VARIABLES];
00105         for (int i=0; i<csplitDim; i++)
00106         {
00107             Cvars[i]=(*input)[i+dsplitDim];
00108         }
00109 
00110         output[0]=root->Infer(Dvars,Cvars);
00111     }
00112     
00113     virtual void Identify(void)
00114     {
00115         const Matrix<double>& ctrainData = training->GetTrainingData();
00116         const Matrix<int>& dtrainData = dynamic_cast< DCTrainingData* >( training )
00117                                         -> GetDiscreteTrainingData();
00118 
00119         int M=ctrainData.num_rows();
00120 
00121         // create the root and give it the Id 1
00122         root = new BinaryDecisionTreeNode< T_Splitter >
00123                (1, dDomainSize, csplitDim);
00124 
00125         do
00126         {
00127             root->StartLearningEpoch();
00128             for(int i=0; i<M; i++)
00129             {
00130                 int classLbl=dtrainData[i][dsplitDim];
00131 #ifdef DEBUG_PRINT
00132 
00133                 cout << " i=" << i << " CL=" << classLbl << endl;
00134 #endif
00135 
00136                 root->LearnSample(dtrainData[i],ctrainData[i],classLbl);
00137             }
00138         }
00139         while (root->StopLearningEpoch(minMass));
00140 
00141         cout << "End Learning" << endl;
00142         // PrintSizeTree();
00143     }
00144 
00145     virtual void Prune(void)
00146     {
00147         const Matrix<double>& ctrainData = pruning->GetTrainingData();
00148         const Matrix<int>& dtrainData = dynamic_cast< DCTrainingData* >( pruning )
00149                                         -> GetDiscreteTrainingData();
00150 
00151         int M=ctrainData.num_rows();
00152 
00153         root->InitializePruningStatistics();
00154         for(int i=0; i<M; i++)
00155         {
00156             int classLbl=dtrainData[i][dsplitDim];
00157             root->UpdatePruningStatistics(dtrainData[i], ctrainData[i], classLbl);
00158         }
00159         root->FinalizePruningStatistics();
00160 
00161         // now cut the tree to the right size
00162         double cost=root->PruneSubtree();
00163         cout << "RMSN after pruning is:" << cost/M << endl;
00164         // PrintSizeTree();
00165     }
00166     
00167     virtual int SetOption(char* name, char* val)
00168     {
00169         if (strcmp(name,"MinMass")==0)
00170             minMass = atoi(val);
00171         else
00172             return Machine::SetOption(name,val);
00173         return 1;
00174     }
00175 
00176     virtual void SaveToStream(ostream& out)
00177     {
00178         out << TypeName() << " (  " << dsplitDim << " " << csplitDim << " ) { "  << endl;
00179 
00180         out << "[ ";
00181         for(int i=0; i<dsplitDim; i++)
00182             out << dDomainSize[i] << " ";
00183         out << "]" << endl;
00184 
00185         root->SaveToStream(out);
00186 
00187         out << '}' << endl;
00188     }
00189 };
00190 
00191 }
00192 
00193 #endif // _CLUS_BINARYDECISIONTREE_H_

Generated on Mon Jul 21 16:57:23 2003 for SECRET by doxygen 1.3.2