00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034 #if !defined _CLUS_PROBABILISTICREGRESSIONTREENODE_H_
00035 #define _CLUS_PROBABILISTICREGRESSIONTREENODE_H_
00036
00037 #include "vec.h"
00038 #include "distribution.h"
00039 #include "dynamicbuffer.h"
00040
00041 using namespace TNT;
00042
00043 #define EMMAXTRIALS 3
00044
00045 namespace CLUS
00046 {
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059 template< class T_Distribution, class T_Regressor, class T_Splitter >
00060 class BinaryProbabilisticRegressionTreeNode
00061 {
00062 protected:
00063
00064
00065 int nodeId;
00066
00067
00068 int emTrials;
00069
00070
00071 int csDim;
00072
00073
00074 int regDim;
00075
00076
00077 enum state { stable, em, split, regression } State;
00078
00079
00080 BinaryProbabilisticRegressionTreeNode< T_Distribution, T_Regressor, T_Splitter > * Children[2];
00081
00082
00083 T_Splitter Splitter;
00084
00085
00086 T_Regressor Regressor;
00087
00088
00089 T_Distribution* Distributions[2];
00090
00091
00092 T_Distribution* parentDistribution;
00093
00094
00095 DynamicBuffer* buffer;
00096
00097
00098 double pruningCost;
00099
00100
00101 int pruningSamples;
00102
00103 void RandomDistributions(void)
00104 {
00105
00106
00107
00108
00109
00110
00111
00112 Distributions[0]->RandomDistribution(2);
00113 Distributions[1]->RandomDistribution(2);
00114
00115
00116 }
00117
00118
00119
00120
00121
00122 double EMStep(double& Likelihood)
00123 {
00124 Likelihood = 0.0;
00125 for (double* X=buffer->begin(); X<buffer->end(); X+=regDim+2)
00126 {
00127
00128 double probability=X[regDim+1];
00129
00130 double p0=Distributions[0]->LearnProbability(X);
00131 double p1=Distributions[1]->LearnProbability(X);
00132 Likelihood+=log((p0+p1)/probability);
00133 Distributions[0]->NormalizeLearnProbability((p0+p1)/probability,2);
00134 Distributions[1]->NormalizeLearnProbability((p0+p1)/probability,2);
00135 }
00136
00137 double c0=Distributions[0]->UpdateParameters();
00138 double c1=Distributions[1]->UpdateParameters();
00139
00140 return ( (c0+c1)/2.0 );
00141 }
00142
00143
00144 public:
00145
00146 BinaryProbabilisticRegressionTreeNode(int NodeId, int CsDim, T_Regressor& regressor, T_Distribution* ParentDistribution = 0):
00147 nodeId(NodeId), csDim(CsDim), Splitter(), Regressor(regressor), parentDistribution(ParentDistribution)
00148 {
00149 Children[0]=Children[1]=0;
00150 Distributions[0]=Distributions[0]=0;
00151 State=stable;
00152 emTrials=0;
00153 buffer=0;
00154 }
00155
00156
00157 BinaryProbabilisticRegressionTreeNode( int NodeId,
00158 const Vector<int> & DDomainSize,
00159 int CsplitDim, int RegDim,
00160 T_Regressor& regressor,
00161 T_Distribution* ParentDistribution = 0):
00162 nodeId(NodeId),csDim(CsplitDim), regDim(RegDim), State(em),
00163 Splitter(DDomainSize,CsplitDim,RegDim), Regressor(regressor),
00164 parentDistribution(ParentDistribution)
00165 {
00166
00167 Children[0]=Children[1]=0;
00168 Distributions[0] = new T_Distribution(regDim);
00169 Distributions[1] = new T_Distribution(regDim);
00170 RandomDistributions();
00171 emTrials=0;
00172 buffer=0;
00173 }
00174
00175 ~BinaryProbabilisticRegressionTreeNode(void)
00176 {
00177 if (Children[0]!=0)
00178 delete Children[0];
00179 Children[0]=0;
00180
00181 if (Children[1]!=0)
00182 delete Children[1];
00183 Children[1]=0;
00184
00185
00186
00187
00188
00189
00190 }
00191
00192 int GetNodeId(void)
00193 {
00194 return nodeId;
00195 }
00196
00197 void ComputeSizesTree(int& nodes, int& term_nodes)
00198 {
00199 nodes++;
00200 if (Children[0]!=0 && Children[1]!=0)
00201 {
00202 Children[0]->ComputeSizesTree(nodes,term_nodes);
00203 Children[1]->ComputeSizesTree(nodes,term_nodes);
00204 }
00205 else
00206 {
00207 term_nodes++;
00208 }
00209 }
00210
00211 void StartLearningEpoch(void)
00212 {
00213 switch (State)
00214 {
00215 case stable:
00216 if (Children[0]!=0)
00217 {
00218 Children[0]->StartLearningEpoch();
00219 Children[1]->StartLearningEpoch();
00220 }
00221 break;
00222 case em:
00223 buffer = new DynamicBuffer(regDim+2);
00224 break;
00225 case split:
00226
00227
00228 Splitter.InitializeSplitStatistics();
00229 break;
00230 case regression:
00231
00232 break;
00233 }
00234 }
00235
00236 void LearnSample(const int* Dvars, const double* Cvars,
00237 double probability, double threshold)
00238 {
00239 double p0,p1;
00240
00241 if (probability<threshold)
00242 return;
00243
00244 switch (State)
00245 {
00246 case stable:
00247 if (Children[0]!=0)
00248 {
00249 double pChild1=Splitter.ProbabilityFirstBranch(Dvars,Cvars);
00250 Children[0]->LearnSample(Dvars,Cvars,pChild1*probability, threshold);
00251 Children[1]->LearnSample(Dvars,Cvars,(1.0-pChild1)*probability, threshold);
00252 }
00253 break;
00254 case em:
00255
00256
00257
00258
00259
00260
00261 assert(buffer!=0);
00262 {
00263 double* cBufferLine=buffer->next();
00264 parentDistribution->NormalizeData(Cvars+csDim,cBufferLine);
00265 cBufferLine[regDim+1]=probability;
00266 }
00267 break;
00268 case split:
00269 p0=Distributions[0]->LearnProbability(Cvars+csDim);
00270 p1=Distributions[1]->LearnProbability(Cvars+csDim);
00271 Splitter.UpdateSplitStatistics(Dvars, Cvars, p0/(p0+p1), p1/(p0+p1), probability );
00272 break;
00273 case regression:
00274 {
00275 double pChild1=Splitter.ProbabilityLeft(Dvars,Cvars);
00276
00277 if (pChild1 > threshold)
00278 {
00279 p0=Distributions[0]->LearnProbability(Cvars+csDim);
00280 Distributions[0]->NormalizeLearnProbability(p0/(pChild1*probability));
00281 }
00282 if (1.0-pChild1 > threshold)
00283 {
00284 p1=Distributions[1]->LearnProbability(Cvars+csDim);
00285 Distributions[1]->NormalizeLearnProbability(p1/((1.0-pChild1)*probability));
00286 }
00287 }
00288 break;
00289 }
00290 }
00291
00292
00293 bool StopLearningEpoch(int splitType, int emRestarts, int emMaxIterations,
00294 double convergenceLim, int min_no_datapoints)
00295 {
00296 bool moresplits=false;
00297 int emIterations;
00298
00299 T_Regressor* regressor0=0, * regressor1=0;
00300 switch (State)
00301 {
00302 case stable:
00303 if (Children[0]!=0)
00304 return Children[0]->StopLearningEpoch(splitType, emRestarts, emMaxIterations,
00305 convergenceLim, min_no_datapoints)
00306 | Children[1]->StopLearningEpoch(splitType, emRestarts, emMaxIterations,
00307 convergenceLim, min_no_datapoints);
00308 else
00309 return false;
00310 case em:
00311 emIterations=0;
00312
00313 {
00314 cout << "XXXNodeID: " << nodeId << " noDatapoints: " << buffer->dim() << endl;
00315
00316
00317 T_Distribution best_d0(0);
00318 T_Distribution best_d1(0);
00319 double best_Likelihood=-1.0e+100;
00320 for (int repetition=0; repetition<emRestarts; repetition++)
00321 {
00322
00323 RandomDistributions();
00324 double Likelihood;
00325
00326 EMStep(Likelihood);
00327 EMStep(Likelihood);
00328
00329 cout << "NodeID=" << nodeId << " repetition=" << repetition ;
00330 cout << " Likelihood=" << Likelihood << endl;
00331
00332 if (Likelihood > best_Likelihood)
00333 {
00334 best_d0 = *(Distributions[0]);
00335 best_d1 = *(Distributions[1]);
00336 best_Likelihood = Likelihood;
00337 }
00338 }
00339
00340 *(Distributions[0])=best_d0;
00341 *(Distributions[1])=best_d1;
00342 }
00343
00344
00345
00346 while (emIterations<emMaxIterations)
00347 {
00348 double Likelihood;
00349 double convFactor = EMStep(Likelihood);
00350
00351 cout << "NodeID=" << nodeId << " Likelihood=" << Likelihood << endl;
00352
00353 emIterations++;
00354
00355 if ( Distributions[0]->HasZeroWeight() || Distributions[1]->HasZeroWeight() )
00356 {
00357 if (emTrials < EMMAXTRIALS)
00358 {
00359 emTrials++;
00360 cerr << "One of the distributions got killed. Starting again" << endl;
00361 RandomDistributions();
00362 emIterations=0;
00363 }
00364 else
00365 {
00366 cerr << "Tried " << EMMAXTRIALS << " times and didn't work. Making the node a leaf." << endl;
00367 goto makeleaf;
00368 }
00369 }
00370
00371 if (!finite(convFactor) || convFactor <= convergenceLim)
00372 break;
00373 }
00374
00375
00376 Distributions[0]->DenormalizeParameters(parentDistribution);
00377 Distributions[1]->DenormalizeParameters(parentDistribution);
00378 State=split;
00379
00380 delete buffer;
00381 buffer=0;
00382
00383
00384 Distributions[0]->SaveToStream(cout);
00385 cout << endl;
00386 Distributions[1]->SaveToStream(cout);
00387 cout << endl;
00388
00389 return true;
00390
00391 case split:
00392
00393
00394
00395 State=regression;
00396
00397 if (Splitter.ComputeSplitVariable(splitType)!=0)
00398 goto makeleaf;
00399
00400 Splitter.DeleteTemporaryStatistics();
00401
00402 return true;
00403 break;
00404
00405 case regression:
00406 cout << "Node:" << nodeId << " finishing regression" << endl;
00407
00408 Distributions[0]->UpdateParameters();
00409 Distributions[1]->UpdateParameters();
00410
00411 regressor0 = dynamic_cast<T_Regressor*> ( Distributions[0]->CreateRegressor() );
00412 regressor1 = dynamic_cast<T_Regressor*> ( Distributions[1]->CreateRegressor() );
00413
00414 if ( !regressor0 || !regressor1 )
00415 goto makeleaf;
00416
00417 State=stable;
00418
00419 if (Splitter.MoreSplits(0, min_no_datapoints))
00420 {
00421
00422 moresplits=true;
00423 Children[0] = new BinaryProbabilisticRegressionTreeNode< T_Distribution, T_Regressor, T_Splitter >
00424 ( nodeId*2, Splitter.GetDDomainSize(), Splitter.GetCSplitDim(),
00425 Splitter.GetRegDim(), *regressor0, Distributions[0] );
00426 }
00427 else
00428 {
00429
00430 Children[0]=new BinaryProbabilisticRegressionTreeNode< T_Distribution, T_Regressor, T_Splitter >
00431 ( nodeId*2, csDim, *regressor0, Distributions[0] );
00432 }
00433
00434 if (Splitter.MoreSplits(1, min_no_datapoints))
00435 {
00436 moresplits=true;
00437 Children[1] = new BinaryProbabilisticRegressionTreeNode< T_Distribution, T_Regressor, T_Splitter >
00438 ( nodeId*2+1, Splitter.GetDDomainSize(), Splitter.GetCSplitDim(),
00439 Splitter.GetRegDim(), *regressor1, Distributions[1] );
00440 }
00441 else
00442 {
00443 Children[1]=new BinaryProbabilisticRegressionTreeNode< T_Distribution, T_Regressor, T_Splitter >
00444 ( nodeId*2+1, csDim, *regressor1, Distributions[1] );
00445 }
00446
00447 Distributions[0]=Distributions[1]=0;
00448
00449 delete regressor0;
00450 delete regressor1;
00451
00452 return moresplits;
00453
00454 default:
00455 return false;
00456 }
00457
00458 makeleaf:
00459 cerr << "Something went wrong. Making node " << nodeId << " a leaf." << endl;
00460 Splitter.DeleteTemporaryStatistics();
00461
00462 Children[0]=Children[1]=0;
00463
00464 if (buffer!=0)
00465 {
00466 delete buffer;
00467 }
00468
00469 delete Distributions[0];
00470 delete Distributions[1];
00471
00472 if (regressor0)
00473 delete regressor0;
00474 if (regressor1)
00475 delete regressor1;
00476
00477 State=stable;
00478 return false;
00479 }
00480
00481
00482 double Infer(const int* Dvars, const double* Cvars, int maxNodeId, double threshold)
00483 {
00484 if (Children[0]==0 || nodeId>maxNodeId)
00485 {
00486
00487 return Regressor.Y(Cvars+csDim);
00488
00489 }
00490 else
00491 {
00492 double pChild1=Splitter.ProbabilityLeft(Dvars,Cvars);
00493
00494 return (pChild1>=threshold ? Children[0]->Infer(Dvars,Cvars,maxNodeId) : 0.0 )*pChild1+
00495 ( 1.0-pChild1>=threshold ? Children[1]->Infer(Dvars,Cvars,maxNodeId) : 0.0 )*(1.0-pChild1);
00496 }
00497 }
00498
00499 void InitializePruningStatistics(void)
00500 {
00501 pruningCost=0.0;
00502 pruningSamples=0;
00503 if (Children[0]!=0 && Children[1]!=0)
00504 {
00505 Children[0]->InitializePruningStatistics();
00506 Children[1]->InitializePruningStatistics();
00507 }
00508 }
00509
00510 void UpdatePruningStatistics(const int* Dvars, const double* Cvars, double y ,
00511 double probability)
00512 {
00513 pruningSamples++;
00514 double predY=Regressor.Y(Cvars+csDim);
00515 pruningCost+=pow2(y-predY)*probability;
00516
00517 if (Children[0]!=0 && Children[1]!=0)
00518 {
00519 double pChild1=Splitter.ProbabilityFirstBranch(Dvars,Cvars);
00520 Children[0]->UpdatePruningStatistics(Dvars,Cvars,y,probability*pChild1);
00521 Children[1]->UpdatePruningStatistics(Dvars,Cvars,y,probability*(1.0-pChild1));
00522 }
00523 }
00524
00525 void FinalizePruningStatistics (void)
00526 {
00527
00528 }
00529
00530
00531 double PruneSubtree(void)
00532 {
00533 if (Children[0] == 0 && Children[1] == 0)
00534 {
00535
00536 cout << "nodeID=" << nodeId << " pruningCost=" << pruningCost << endl;
00537 return pruningCost;
00538 }
00539 else
00540 {
00541
00542 double pruningCostChildren=Children[0]->PruneSubtree()+Children[1]->PruneSubtree();
00543
00544 cout << "nodeID=" << nodeId << " pruningCost=" << pruningCost << " pruningCostChildren=" <<
00545 pruningCostChildren << endl;
00546
00547 if (pruningCost<=pruningCostChildren)
00548 {
00549
00550 delete Children[0];
00551 Children[0]=0;
00552 delete Children[1];
00553 Children[1]=0;
00554
00555 return pruningCost;
00556 }
00557 else
00558 {
00559
00560 return pruningCostChildren;
00561 }
00562 }
00563 }
00564
00565 void SaveToStream(ostream& out)
00566 {
00567 out << "{ " << nodeId << " [ ";
00568 if (Children[0]!=0 && Children[1]!=0)
00569 {
00570
00571 Splitter.SaveToStream(out);
00572 }
00573 out << " ] ( ";
00574 Regressor.SaveToStream(out);
00575 out << " ) }";
00576
00577 out << endl;
00578
00579
00580
00581
00582
00583
00584 if (Children[0]!=0 && Children[1]!=0)
00585 {
00586 Children[0]->SaveToStream(out);
00587 Children[1]->SaveToStream(out);
00588 }
00589 }
00590 };
00591 }
00592
00593
00594 #endif // _CLUS_PROBABILISTICREGRESSIONTREENODE_H_