00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034 #ifndef _CLUS_BINARYPROBABILISTICDECISIONTREE_H_
00035 #define _CLUS_BINARYPROBABILISTICDECISIONTREE_H_
00036
00037 #include "machine.h"
00038 #include "binaryprobabilisticdecisiontreenode.h"
00039 #include "dctraingen.h"
00040 #include <iostream>
00041
00042 using namespace std;
00043
00044 namespace CLUS
00045 {
00046
00047 template< class T_Splitter >
00048 class BinaryProbabilisticDecisionTree : public Machine
00049 {
00050 protected:
00051 BinaryProbabilisticDecisionTreeNode< T_Splitter >* root;
00052
00053
00054 const Vector<int>& dDomainSize;
00055
00056
00057 int dsplitDim;
00058
00059
00060 int csplitDim;
00061
00062
00063 double minMass;
00064
00065
00066 double threshold;
00067
00068
00069 bool bootstrapping;
00070
00071
00072 int bootstrappingRepetitions;
00073
00074 public:
00075 BinaryProbabilisticDecisionTree(const Vector<int>& DDomainSize,
00076 int CsplitDim):
00077 Machine(CsplitDim,1),dDomainSize(DDomainSize),
00078 dsplitDim(DDomainSize.dim()-1),
00079 csplitDim(CsplitDim)
00080 {
00081 minMass=10.0;
00082 threshold=.01;
00083 root=NULL;
00084 bootstrapping=false;
00085 bootstrappingRepetitions=1000;
00086 }
00087
00088 ~BinaryProbabilisticDecisionTree(void)
00089 {
00090 if (root!=0)
00091 delete root;
00092 }
00093
00094 virtual int InDim(void)
00095 {
00096 return dsplitDim+csplitDim;
00097 }
00098
00099 virtual string TypeName(void)
00100 {
00101 return string("BinaryProbabilisticDecisionTree");
00102 }
00103
00104 virtual void Infer(void)
00105 {
00106 if (root==0)
00107 return;
00108
00109 int Dvars[MAX_VARIABLES];
00110 for(int i=0; i<dsplitDim; i++)
00111 {
00112 Dvars[i]=(int)(*input)[i];
00113 }
00114
00115 double Cvars[MAX_VARIABLES];
00116 for (int i=0; i<csplitDim; i++)
00117 {
00118 Cvars[i]=(*input)[i+dsplitDim];
00119 }
00120
00121
00122
00123
00124 if (root->ProbabilityFirstClass(Dvars,Cvars,1.0,threshold)>.5)
00125 output[0]=0;
00126 else
00127 output[0]=1;
00128
00129 #ifdef DEBUG_PRINT
00130
00131 cout << "\t\t";
00132 for(int i=0; i<dsplitDim; i++)
00133 cout << Dvars[i] << " ";
00134 cout << "\t";
00135 for (int i=0; i<csplitDim; i++)
00136 cout << Cvars[i] << " ";
00137
00138 double pFC=root->ProbabilityFirstClass(Dvars,Cvars,1.0,threshold);
00139 cout << pFC << " - " << output[0] << endl;
00140 #endif
00141
00142 }
00143
00144 virtual void Identify(void)
00145 {
00146 const Matrix<double>& ctrainData = training->GetTrainingData();
00147 const Matrix<int>& dtrainData = dynamic_cast< DCTrainingData* >( training )
00148 -> GetDiscreteTrainingData();
00149
00150 int M=ctrainData.num_rows();
00151
00152
00153 root = new BinaryProbabilisticDecisionTreeNode< T_Splitter >
00154 (1, dDomainSize, csplitDim);
00155
00156 do
00157 {
00158 root->StartLearningEpoch();
00159 for(int i=0; i<M; i++)
00160 {
00161 int classLbl=dtrainData[i][dsplitDim];
00162 root->LearnSample(dtrainData[i],ctrainData[i],classLbl,1.0,threshold);
00163 }
00164 }
00165 while (root->StopLearningEpoch(minMass));
00166
00167 cout << "End Learning" << endl;
00168
00169 }
00170
00171 virtual void Prune(void)
00172 {
00173 const Matrix<double>& ctrainData = pruning->GetTrainingData();
00174 const Matrix<int>& dtrainData = dynamic_cast< DCTrainingData* >( pruning )
00175 -> GetDiscreteTrainingData();
00176
00177 int M=ctrainData.num_rows();
00178
00179 root->InitializePruningStatistics();
00180 for(int i=0; i<M; i++)
00181 {
00182 int classLbl=dtrainData[i][dsplitDim];
00183 root->UpdatePruningStatistics(dtrainData[i], ctrainData[i], classLbl, 1.0, threshold);
00184 }
00185 root->FinalizePruningStatistics();
00186
00187
00188 double cost=root->PruneSubtree();
00189 cout << "RMSN after pruning is:" << cost/M << endl;
00190
00191 }
00192
00193 virtual int SetOption(char* name, char* val)
00194 {
00195 if (strcmp(name,"MinMass")==0)
00196 minMass = atof(val);
00197 else
00198 if (strcmp(name,"Threshold")==0)
00199 threshold = atof(val);
00200 else
00201 return Machine::SetOption(name,val);
00202 return 1;
00203 }
00204
00205 virtual void SaveToStream(ostream& out)
00206 {
00207 out << TypeName() << " ( " << dsplitDim << " " << csplitDim << " ) { " << endl;
00208
00209 out << "[ ";
00210 for(int i=0; i<dsplitDim; i++)
00211 out << dDomainSize[i] << " ";
00212 out << "]" << endl;
00213
00214 root->SaveToStream(out);
00215
00216 out << '}' << endl;
00217 }
00218 };
00219
00220 }
00221
00222 #endif // _CLUS_BINARYPROBABILISTICDECISIONTREE_H_