00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034 #ifndef _CLUS_BINARYDECISIONTREE_H_
00035 #define _CLUS_BINARYDECISIONTREE_H_
00036
00037 #include "machine.h"
00038 #include "binarydecisiontreenode.h"
00039 #include "dctraingen.h"
00040 #include <iostream>
00041
00042 using namespace std;
00043
00044 namespace CLUS
00045 {
00046
00047
00048 template< class T_Splitter >
00049 class BinaryDecisionTree : public Machine
00050 {
00051 protected:
00052 BinaryDecisionTreeNode< T_Splitter >* root;
00053
00054
00055 const Vector<int>& dDomainSize;
00056
00057
00058 int dsplitDim;
00059
00060
00061 int csplitDim;
00062
00063
00064 int minMass;
00065
00066 public:
00067 BinaryDecisionTree(const Vector<int>& DDomainSize,
00068 int CsplitDim):
00069 Machine(CsplitDim,1),dDomainSize(DDomainSize),
00070 dsplitDim(DDomainSize.dim()-1),
00071 csplitDim(CsplitDim)
00072 {
00073 minMass=10;
00074 root=NULL;
00075 }
00076
00077 ~BinaryDecisionTree(void)
00078 {
00079 if (root!=0)
00080 delete root;
00081 }
00082
00083 virtual int InDim(void)
00084 {
00085 return dsplitDim+csplitDim;
00086 }
00087
00088 virtual string TypeName(void)
00089 {
00090 return string("BinaryDecisionTree");
00091 }
00092
00093 virtual void Infer(void)
00094 {
00095 if (root==0)
00096 return;
00097
00098 int Dvars[MAX_VARIABLES];
00099 for(int i=0; i<dsplitDim; i++)
00100 {
00101 Dvars[i]=(int)(*input)[i];
00102 }
00103
00104 double Cvars[MAX_VARIABLES];
00105 for (int i=0; i<csplitDim; i++)
00106 {
00107 Cvars[i]=(*input)[i+dsplitDim];
00108 }
00109
00110 output[0]=root->Infer(Dvars,Cvars);
00111 }
00112
00113 virtual void Identify(void)
00114 {
00115 const Matrix<double>& ctrainData = training->GetTrainingData();
00116 const Matrix<int>& dtrainData = dynamic_cast< DCTrainingData* >( training )
00117 -> GetDiscreteTrainingData();
00118
00119 int M=ctrainData.num_rows();
00120
00121
00122 root = new BinaryDecisionTreeNode< T_Splitter >
00123 (1, dDomainSize, csplitDim);
00124
00125 do
00126 {
00127 root->StartLearningEpoch();
00128 for(int i=0; i<M; i++)
00129 {
00130 int classLbl=dtrainData[i][dsplitDim];
00131 #ifdef DEBUG_PRINT
00132
00133 cout << " i=" << i << " CL=" << classLbl << endl;
00134 #endif
00135
00136 root->LearnSample(dtrainData[i],ctrainData[i],classLbl);
00137 }
00138 }
00139 while (root->StopLearningEpoch(minMass));
00140
00141 cout << "End Learning" << endl;
00142
00143 }
00144
00145 virtual void Prune(void)
00146 {
00147 const Matrix<double>& ctrainData = pruning->GetTrainingData();
00148 const Matrix<int>& dtrainData = dynamic_cast< DCTrainingData* >( pruning )
00149 -> GetDiscreteTrainingData();
00150
00151 int M=ctrainData.num_rows();
00152
00153 root->InitializePruningStatistics();
00154 for(int i=0; i<M; i++)
00155 {
00156 int classLbl=dtrainData[i][dsplitDim];
00157 root->UpdatePruningStatistics(dtrainData[i], ctrainData[i], classLbl);
00158 }
00159 root->FinalizePruningStatistics();
00160
00161
00162 double cost=root->PruneSubtree();
00163 cout << "RMSN after pruning is:" << cost/M << endl;
00164
00165 }
00166
00167 virtual int SetOption(char* name, char* val)
00168 {
00169 if (strcmp(name,"MinMass")==0)
00170 minMass = atoi(val);
00171 else
00172 return Machine::SetOption(name,val);
00173 return 1;
00174 }
00175
00176 virtual void SaveToStream(ostream& out)
00177 {
00178 out << TypeName() << " ( " << dsplitDim << " " << csplitDim << " ) { " << endl;
00179
00180 out << "[ ";
00181 for(int i=0; i<dsplitDim; i++)
00182 out << dDomainSize[i] << " ";
00183 out << "]" << endl;
00184
00185 root->SaveToStream(out);
00186
00187 out << '}' << endl;
00188 }
00189 };
00190
00191 }
00192
00193 #endif // _CLUS_BINARYDECISIONTREE_H_