00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034 #ifndef _CLUS_BINARYDECISIONTREENODE_H_
00035 #define _CLUS_BINARYDECISIONTREENODE_H_
00036
00037 #include "vec.h"
00038 #include "general.h"
00039
00040 #ifdef DEBUG_PRINT
00041 #include <iostream>
00042 using namespace std;
00043 #endif
00044
00045 using namespace TNT;
00046
00047 namespace CLUS
00048 {
00049
00050
00051 template< class T_Splitter >
00052 class BinaryDecisionTreeNode
00053 {
00054 protected:
00055
00056
00057 int nodeId;
00058
00059
00060 enum state { stable, split} State;
00061
00062
00063 int classLabel;
00064
00065
00066 BinaryDecisionTreeNode< T_Splitter > * Children[2];
00067
00068
00069 double probFirstClass;
00070
00071
00072 T_Splitter Splitter;
00073
00074
00075 int pruningError;
00076
00077
00078 int pruningTotalMass;
00079
00080 public:
00081 BinaryDecisionTreeNode(int NodeId, const Vector<int> & DDomainSize,
00082 int CsplitDim):
00083 nodeId(NodeId), State(split), Splitter(DDomainSize,CsplitDim)
00084 {
00085
00086 probFirstClass=1.0;
00087 }
00088
00089 ~BinaryDecisionTreeNode(void)
00090 {
00091 if (Children[0]!=0)
00092 delete Children[0];
00093 Children[0]=0;
00094
00095
00096 if (Children[1]!=0)
00097 delete Children[1];
00098 Children[1]=0;
00099 }
00100
00101
00102 void StartLearningEpoch(void)
00103 {
00104 switch (State)
00105 {
00106 case stable:
00107 if (Children[0]!=0)
00108 {
00109 Children[0]->StartLearningEpoch();
00110 Children[1]->StartLearningEpoch();
00111 }
00112 break;
00113 case split:
00114 Splitter.InitializeSplitStatistics();
00115 break;
00116 }
00117 }
00118
00119
00120
00121
00122
00123
00124 void LearnSample(const int* Dvars, const double* Cvars, int classlabel)
00125 {
00126 switch (State)
00127 {
00128 case stable:
00129
00130 if (Children[0]==0 || Children[1]==0)
00131 return;
00132
00133 if (Splitter.ChooseBranch(Dvars,Cvars)==0)
00134 Children[0]->LearnSample(Dvars,Cvars,classlabel);
00135 else
00136 Children[1]->LearnSample(Dvars,Cvars,classlabel);
00137 break;
00138
00139 case split:
00140 Splitter.UpdateSplitStatistics(Dvars, Cvars, classlabel);
00141 break;
00142 }
00143 }
00144
00145
00146
00147
00148
00149 bool StopLearningEpoch(int minMass)
00150 {
00151 switch (State)
00152 {
00153 case stable:
00154 if (Children[0]!=0)
00155 return Children[0]->StopLearningEpoch(minMass)
00156 | Children[1]->StopLearningEpoch(minMass);
00157 else
00158 return false;
00159
00160 case split:
00161 State=stable;
00162
00163 classLabel=Splitter.ComputeClassLabel();
00164
00165 if (!Splitter.ComputeSplitVariable())
00166 {
00167
00168 goto make_node_leaf;
00169 }
00170
00171 #ifdef DEBUG_PRINT
00172 cout << "NodeId: " << nodeId << " probFirstClass=" << probFirstClass << endl;
00173 #endif
00174
00175 if (Splitter.MoreSplits(minMass, nodeId))
00176 {
00177 Children[0] = new BinaryDecisionTreeNode< T_Splitter >
00178 ( nodeId*2, Splitter.GetDDomainSize(), Splitter.GetCSplitDim() );
00179
00180 Children[1] = new BinaryDecisionTreeNode< T_Splitter >
00181 ( nodeId*2+1, Splitter.GetDDomainSize(), Splitter.GetCSplitDim() );
00182 }
00183 else
00184 goto make_node_leaf;
00185
00186 return true;
00187 break;
00188
00189 default:
00190 return false;
00191 }
00192
00193 make_node_leaf:
00194 #ifdef DEBUG_PRINT
00195
00196 cout << "Making the node " << nodeId << " a leaf" << endl;
00197 #endif
00198
00199 Children[0]=Children[1]=0;
00200 return false;
00201 }
00202
00203
00204
00205
00206
00207 double Infer(const int* Dvars, const double* Cvars)
00208 {
00209
00210 if (Children[0]==0 || Children[1]==0)
00211 {
00212
00213 return classLabel;
00214 }
00215 else
00216 {
00217 return Children[Splitter.ChooseBranch(Dvars,Cvars)]->Infer(Dvars,Cvars);
00218 }
00219 }
00220
00221
00222 void InitializePruningStatistics(void)
00223 {
00224 pruningError=0;
00225 pruningTotalMass=0;
00226
00227 if (Children[0]!=0 && Children[1]!=0)
00228 {
00229 Children[0]->InitializePruningStatistics();
00230 Children[1]->InitializePruningStatistics();
00231 }
00232 }
00233
00234
00235
00236
00237
00238
00239 void UpdatePruningStatistics(const int* Dvars, const double* Cvars,
00240 int classlabel )
00241 {
00242
00243 if (classlabel!=classLabel)
00244 pruningError+=1;
00245
00246 pruningTotalMass+=1;
00247
00248 if (Children[0]==0 || Children[1]==0)
00249 return;
00250 else
00251 Children[Splitter.ChooseBranch(Dvars,Cvars)]->UpdatePruningStatistics(Dvars,Cvars,classlabel);
00252
00253 }
00254
00255 void FinalizePruningStatistics (void)
00256 {
00257
00258 }
00259
00260
00261 double PruneSubtree(void)
00262 {
00263 #ifdef DEBUG_PRINT
00264 cout << "Pruneerror of " <<nodeId << " is: " << pruningError
00265 << " " << 1.0*pruningError/pruningTotalMass << endl;
00266 #endif
00267
00268 if (Children[0] == 0 && Children[1] == 0)
00269 {
00270
00271 return pruningError;
00272 }
00273 else
00274 {
00275 double errorChildren = Children[0]->PruneSubtree()+
00276 Children[1]->PruneSubtree();
00277 #ifdef DEBUG_PRINT
00278
00279 cout << "Childrenerror of " << nodeId << " is: " << errorChildren
00280 << " " << 1.0*errorChildren/pruningTotalMass << endl;
00281 #endif
00282
00283 if (pruningError<=errorChildren)
00284 {
00285 #ifdef DEBUG_PRINT
00286 cout << "Prunning at node " << nodeId << endl;
00287 #endif
00288
00289 delete Children[0];
00290 Children[0]=0;
00291 delete Children[1];
00292 Children[1]=0;
00293
00294 return pruningError;
00295 }
00296 else
00297 {
00298
00299 return errorChildren;
00300 }
00301 }
00302 }
00303
00304
00305
00306
00307 void SaveToStream(ostream& out)
00308 {
00309 out << nodeId << " ";
00310
00311 Splitter.SaveToStream(out);
00312 out << " " << classLabel << endl;
00313
00314 if (Children[0] != 0 && Children[1] != 0)
00315 {
00316 Children[0]->SaveToStream(out);
00317 Children[1]->SaveToStream(out);
00318 }
00319 }
00320
00321 };
00322 }
00323
00324 #endif // _CLUS_BINARYDECISIONTREENODE_H_