00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 
00015 
00016 
00017 
00018 
00019 
00020 
00021 
00022 
00023 
00024 
00025 
00026 
00027 
00028 
00029 
00030 
00031 
00032 
00033 
00034 #ifndef _CLUS_BINARYDECISIONTREE_H_
00035 #define _CLUS_BINARYDECISIONTREE_H_
00036 
00037 #include "machine.h"
00038 #include "binarydecisiontreenode.h"
00039 #include "dctraingen.h"
00040 #include <iostream>
00041 
00042 using namespace std;
00043 
00044 namespace CLUS
00045 {
00046 
00047 
00048 template< class T_Splitter >
00049 class BinaryDecisionTree : public Machine
00050 {
00051 protected:
00052     BinaryDecisionTreeNode< T_Splitter >* root;
00053 
00054 
00055     const Vector<int>& dDomainSize;
00056 
00057 
00058     int dsplitDim;
00059     
00060 
00061     int csplitDim;
00062 
00063 
00064     int minMass; 
00065 
00066 public:
00067     BinaryDecisionTree(const Vector<int>& DDomainSize,
00068                        int CsplitDim):
00069             Machine(CsplitDim,1),dDomainSize(DDomainSize),
00070             dsplitDim(DDomainSize.dim()-1),
00071             csplitDim(CsplitDim)
00072     {
00073         minMass=10;
00074         root=NULL;
00075     }
00076 
00077     ~BinaryDecisionTree(void)
00078     {
00079         if (root!=0)
00080             delete root;
00081     }
00082 
00083     virtual int InDim(void)
00084     {
00085         return dsplitDim+csplitDim;
00086     }
00087     
00088     virtual string TypeName(void)
00089     {
00090         return string("BinaryDecisionTree");
00091     }
00092     
00093     virtual void Infer(void)
00094     {
00095         if (root==0)
00096             return;
00097         
00098         int Dvars[MAX_VARIABLES];
00099         for(int i=0; i<dsplitDim; i++)
00100         {
00101             Dvars[i]=(int)(*input)[i];
00102         }
00103 
00104         double Cvars[MAX_VARIABLES];
00105         for (int i=0; i<csplitDim; i++)
00106         {
00107             Cvars[i]=(*input)[i+dsplitDim];
00108         }
00109 
00110         output[0]=root->Infer(Dvars,Cvars);
00111     }
00112     
00113     virtual void Identify(void)
00114     {
00115         const Matrix<double>& ctrainData = training->GetTrainingData();
00116         const Matrix<int>& dtrainData = dynamic_cast< DCTrainingData* >( training )
00117                                         -> GetDiscreteTrainingData();
00118 
00119         int M=ctrainData.num_rows();
00120 
00121         
00122         root = new BinaryDecisionTreeNode< T_Splitter >
00123                (1, dDomainSize, csplitDim);
00124 
00125         do
00126         {
00127             root->StartLearningEpoch();
00128             for(int i=0; i<M; i++)
00129             {
00130                 int classLbl=dtrainData[i][dsplitDim];
00131 #ifdef DEBUG_PRINT
00132 
00133                 cout << " i=" << i << " CL=" << classLbl << endl;
00134 #endif
00135 
00136                 root->LearnSample(dtrainData[i],ctrainData[i],classLbl);
00137             }
00138         }
00139         while (root->StopLearningEpoch(minMass));
00140 
00141         cout << "End Learning" << endl;
00142         
00143     }
00144 
00145     virtual void Prune(void)
00146     {
00147         const Matrix<double>& ctrainData = pruning->GetTrainingData();
00148         const Matrix<int>& dtrainData = dynamic_cast< DCTrainingData* >( pruning )
00149                                         -> GetDiscreteTrainingData();
00150 
00151         int M=ctrainData.num_rows();
00152 
00153         root->InitializePruningStatistics();
00154         for(int i=0; i<M; i++)
00155         {
00156             int classLbl=dtrainData[i][dsplitDim];
00157             root->UpdatePruningStatistics(dtrainData[i], ctrainData[i], classLbl);
00158         }
00159         root->FinalizePruningStatistics();
00160 
00161         
00162         double cost=root->PruneSubtree();
00163         cout << "RMSN after pruning is:" << cost/M << endl;
00164         
00165     }
00166     
00167     virtual int SetOption(char* name, char* val)
00168     {
00169         if (strcmp(name,"MinMass")==0)
00170             minMass = atoi(val);
00171         else
00172             return Machine::SetOption(name,val);
00173         return 1;
00174     }
00175 
00176     virtual void SaveToStream(ostream& out)
00177     {
00178         out << TypeName() << " (  " << dsplitDim << " " << csplitDim << " ) { "  << endl;
00179 
00180         out << "[ ";
00181         for(int i=0; i<dsplitDim; i++)
00182             out << dDomainSize[i] << " ";
00183         out << "]" << endl;
00184 
00185         root->SaveToStream(out);
00186 
00187         out << '}' << endl;
00188     }
00189 };
00190 
00191 }
00192 
00193 #endif // _CLUS_BINARYDECISIONTREE_H_