#ifndef _CLUS_BINARYPROBABILISTICDECISIONTREENODE_H_
#define _CLUS_BINARYPROBABILISTICDECISIONTREENODE_H_

// <iostream> is needed for the ostream used by SaveToStream() and for the
// DEBUG_PRINT output.
#include <iostream>

#include "vec.h"

using namespace std;
using namespace TNT;

namespace CLUS
{
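/** A node of a binary probabilistic decision tree. Each sample carries a
    probability mass that the splitter divides between the two children, so a
    sample can follow both branches; branches that would receive less mass than
    the caller-supplied threshold are skipped. T_Splitter accumulates the split
    statistics and chooses the split variable.

    A minimal training sketch, assuming a hypothetical splitter type
    ProbabilisticSplitter and caller-provided data and parameters
    (dDomainSize, cSplitDim, Dvars, Cvars, classlabel, threshold, minMass),
    none of which are defined in this header:

        BinaryProbabilisticDecisionTreeNode<ProbabilisticSplitter>
            root(1, dDomainSize, cSplitDim);
        do
        {
            root.StartLearningEpoch();
            // feed every training sample with full mass 1.0:
            //     root.LearnSample(Dvars, Cvars, classlabel, 1.0, threshold);
        } while (root.StopLearningEpoch(minMass));   // true while nodes keep splitting

        double p0 = root.ProbabilityFirstClass(Dvars, Cvars, 1.0, threshold);
*/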
template< class T_Splitter >
class BinaryProbabilisticDecisionTreeNode
{
protected:
    /// Unique identifier of the node; the children of node i get ids 2*i and 2*i+1.
    int nodeId;

    /// Learning state of the node (the bootstrap state is currently disabled).
    enum state { stable, split, bootstrap } State;

    /// The two children of the node; both are null while the node is a leaf.
    BinaryProbabilisticDecisionTreeNode< T_Splitter > * Children[2];

    /// Probability that a sample reaching this node has the first class label (0).
    double probFirstClass;

    /// Splitter that accumulates the split statistics and picks the split for this node.
    T_Splitter Splitter;

    /// Pruning statistics: misclassified mass, total mass seen, and mass routed to the left child.
    double pruningError, pruningTotalMass, pruningTotalMassLeft;

public:
    BinaryProbabilisticDecisionTreeNode(int NodeId, const Vector<int> & DDomainSize,
                                        int CsplitDim):
            nodeId(NodeId), State(split), Splitter(DDomainSize,CsplitDim)
    {
        probFirstClass=1.0;
        // Start out as a leaf; children are created only when the node is split.
        Children[0]=Children[1]=0;
    }

    ~BinaryProbabilisticDecisionTreeNode(void)
    {
        if (Children[0]!=0)
            delete Children[0];
        Children[0]=0;

        if (Children[1]!=0)
            delete Children[1];
        Children[1]=0;
    }

    /// Prepare the subtree for a new pass over the training data.
    void StartLearningEpoch(void)
    {
        switch (State)
        {
        case stable:
            // Interior node: let the children prepare themselves.
            if (Children[0]!=0)
            {
                Children[0]->StartLearningEpoch();
                Children[1]->StartLearningEpoch();
            }
            break;
        case split:
            Splitter.InitializeSplitStatistics();
            break;
#if 0
        case bootstrap:
            Splitter.InitializeBootstrapSplitStatistics();
            break;
#endif
        default:
            break;
        }
    }

    /// Route a training sample, carrying the given probability mass, through the
    /// subtree. Branches that would receive less than threshold mass are skipped.
    void LearnSample(const int* Dvars, const double* Cvars,
                     int classlabel, double probability, double threshold)
    {
        switch (State)
        {
        case stable:
            {
                // Leaves have nothing left to learn.
                if (Children[0]==0 || Children[1]==0)
                    return;

                // Divide the probability mass between the two children.
                double probabilityLeft = probability*Splitter.ProbabilityLeft(Dvars,Cvars);
                double probabilityRight = probability-probabilityLeft;

                if (probabilityLeft>=threshold)
                {
                    Children[0]->LearnSample(Dvars,Cvars,classlabel,probabilityLeft,threshold);
                }

                if (probabilityRight>=threshold)
                {
                    Children[1]->LearnSample(Dvars,Cvars,classlabel,probabilityRight,threshold);
                }
            }
            break;
        case split:
            Splitter.UpdateSplitStatistics(Dvars, Cvars, classlabel, probability);
            break;
        default:
            break;
        }
    }

    /// Finish the current learning epoch. Returns true if some node in the
    /// subtree was split and therefore needs at least one more epoch.
    bool StopLearningEpoch(double minMass)
    {
        switch (State)
        {
        case stable:
            if (Children[0]!=0)
                // Use | instead of || so that both children are always visited.
                return Children[0]->StopLearningEpoch(minMass)
                       | Children[1]->StopLearningEpoch(minMass);
            else
                return false;

        case split:
            State=stable;

            probFirstClass=Splitter.ComputeProbabilityFirstClass();

            if (!Splitter.ComputeSplitVariable())
            {
                // No usable split variable was found.
                goto make_node_leaf;
            }

#ifdef DEBUG_PRINT
            cout << "NodeId: " << nodeId << " probFirstClass=" << probFirstClass << endl;
#endif

            if (Splitter.MoreSplits(minMass, nodeId))
            {
                Children[0] = new BinaryProbabilisticDecisionTreeNode< T_Splitter >
                    ( nodeId*2, Splitter.GetDDomainSize(), Splitter.GetCSplitDim() );

                Children[1] = new BinaryProbabilisticDecisionTreeNode< T_Splitter >
                    ( nodeId*2+1, Splitter.GetDDomainSize(), Splitter.GetCSplitDim() );
            }
            else
                goto make_node_leaf;

            return true;

        default:
            return false;
        }

    make_node_leaf:
#ifdef DEBUG_PRINT
        cout << "Making the node " << nodeId << " a leaf" << endl;
#endif
        Children[0]=Children[1]=0;
        return false;
    }

    /// Return the portion of the given probability mass with which this sample
    /// is predicted to belong to the first class.
    double ProbabilityFirstClass(const int* Dvars, const double* Cvars,
                                 double probability, double threshold)
    {
        if (Children[0]==0)
        {
            // Leaf: scale the class probability by the mass that reached this node.
            return probability*probFirstClass;
        }
        else
        {
            double rez=0.0;
            double probabilityLeft = probability*Splitter.ProbabilityLeft(Dvars,Cvars);
            double probabilityRight = probability-probabilityLeft;

            if (probabilityLeft>=threshold)
            {
                rez+=Children[0]->ProbabilityFirstClass(Dvars,Cvars,probabilityLeft,threshold);
            }

            if (probabilityRight>=threshold)
            {
                rez+=Children[1]->ProbabilityFirstClass(Dvars,Cvars,probabilityRight,threshold);
            }

            return rez;
        }
    }

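    // Illustrative note (not part of the original interface): a hard class
    // prediction can be obtained by thresholding the returned mass at 0.5,
    // mirroring the predicted class used in UpdatePruningStatistics below, e.g.
    //
    //     int predicted = (node.ProbabilityFirstClass(Dvars, Cvars, 1.0, threshold) > 0.5) ? 0 : 1;
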
    /// Reset the pruning statistics in the whole subtree.
    void InitializePruningStatistics(void)
    {
        pruningError=0.0;
        pruningTotalMass=0.0;
        pruningTotalMassLeft=0.0;
        if (Children[0]!=0 && Children[1]!=0)
        {
            Children[0]->InitializePruningStatistics();
            Children[1]->InitializePruningStatistics();
        }
    }

    /// Accumulate pruning statistics for a (validation) sample in every node it reaches.
    void UpdatePruningStatistics(const int* Dvars, const double* Cvars, int classlabel,
                                 double probability, double threshold)
    {
#if 1
        // 0/1 misclassification loss weighted by the probability mass.
        int predClass = probFirstClass>.5 ? 0 : 1;

        if (classlabel!=predClass)
            pruningError+=probability;
#else
        // Alternative: squared-error loss on the class probability.
        if (classlabel==0)
        {
            pruningError+=probability*pow2(1.0-probFirstClass);
        }
        else
        {
            pruningError+=probability*pow2(probFirstClass);
        }
#endif

        pruningTotalMass+=probability;

        if (Children[0]==0 || Children[1]==0)
            return;

        double probabilityLeft = probability*Splitter.ProbabilityLeft(Dvars,Cvars);
        pruningTotalMassLeft+=probabilityLeft;
        double probabilityRight = probability-probabilityLeft;

        // Note: here the threshold is relative to the mass that reached this node.
        if (probabilityLeft>=threshold*probability)
        {
            Children[0]->UpdatePruningStatistics(Dvars,Cvars,classlabel,probabilityLeft,threshold);
        }

        if (probabilityRight>=threshold*probability)
        {
            Children[1]->UpdatePruningStatistics(Dvars,Cvars,classlabel,probabilityRight,threshold);
        }
    }

    /// Nothing needs to be finalized for this node type.
    void FinalizePruningStatistics(void)
    {
        // Intentionally empty.
    }

    /// Prune the subtree bottom-up and return the pruning error of the result:
    /// a subtree is replaced by a leaf whenever the leaf's error is no larger
    /// than the combined error of its children.
    double PruneSubtree(void)
    {
        // Error of this node if it were turned into a leaf.
        double error=pruningError;
#ifdef DEBUG_PRINT
        cout << "Prune error of " << nodeId << " is: " << error
             << " " << error/pruningTotalMass << endl;
#endif

        if (Children[0] == 0 && Children[1] == 0)
        {
            // Already a leaf.
            return error;
        }
        else
        {
            // Error of the subtree if the children are kept (pruned recursively).
            double errorChildren = Children[0]->PruneSubtree()+
                                   Children[1]->PruneSubtree();
#ifdef DEBUG_PRINT
            cout << "Children error of " << nodeId << " is: " << errorChildren
                 << " " << errorChildren/pruningTotalMass << endl;
#endif

            if (error<=errorChildren)
            {
#ifdef DEBUG_PRINT
                cout << "Pruning at node " << nodeId << endl;
#endif
                // Keeping the children does not help; make this node a leaf.
                delete Children[0];
                Children[0]=0;
                delete Children[1];
                Children[1]=0;

                return error;
            }
            else
            {
                return errorChildren;
            }
        }
    }

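    // A minimal pruning sketch over a validation set (illustrative only; the
    // root node and per-sample variables are assumed to exist in the caller):
    //
    //     root.InitializePruningStatistics();
    //     // for every validation sample:
    //     //     root.UpdatePruningStatistics(Dvars, Cvars, classlabel, 1.0, threshold);
    //     root.FinalizePruningStatistics();
    //     root.PruneSubtree();   // collapse subtrees that do not lower the error
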
    /// Write the subtree to a stream in pre-order: node id, splitter, class probability.
    void SaveToStream(ostream& out)
    {
        out << nodeId << " ";

        Splitter.SaveToStream(out);
        out << " " << probFirstClass << endl;

        if (Children[0] != 0 && Children[1] != 0)
        {
            Children[0]->SaveToStream(out);
            Children[1]->SaveToStream(out);
        }
    }

};
}

#endif // _CLUS_BINARYPROBABILISTICDECISIONTREENODE_H_