#if !defined _CLUS_REGRESSIONTREENODE_H_
#define _CLUS_REGRESSIONTREENODE_H_

#include <cassert>
#include <cmath>
#include <iostream>

#include "vec.h"
#include "distribution.h"
#include "dynamicbuffer.h"

using namespace TNT;
using namespace std;

// Maximum number of EM restarts allowed when one of the two
// distributions collapses before the node gives up and becomes a leaf.
#define EMMAXTRIALS 3

namespace CLUS
{
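
/** A node of a binary probabilistic regression tree.

    Each node moves through a sequence of learning states, one pass
    over the data per state:
      - em:         fit a two-component mixture (Distributions[0,1])
                    to the buffered, normalized samples with EM;
      - split:      use the mixture responsibilities to choose a split
                    variable and split point via the Splitter;
      - regression: re-estimate the two distributions on either side of
                    the chosen split and build a regressor for each;
      - stable:     this node is finished; learning continues in the
                    children, if any.

    T_Distribution models the data reaching a node, T_Regressor
    produces predictions, and T_Splitter selects and evaluates splits.
*/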
template< class T_Distribution, class T_Regressor, class T_Splitter >
class BinaryRegressionTreeNode
{
protected:
    /// Unique node identifier; the children of node i have ids 2*i and 2*i+1
    int nodeId;

    /// Number of times EM has been restarted for this node
    int emTrials;

    /// Number of continuous split variables; the regression variables
    /// start at offset csDim in the continuous data
    int csDim;

    /// Number of regression variables
    int regDim;

    /// Learning state of the node
    enum state { stable, em, split, regression } State;

    /// The two children of the node; both null for a leaf
    BinaryRegressionTreeNode< T_Distribution, T_Regressor, T_Splitter >* Children[2];

    /// Decides to which child a sample belongs
    T_Splitter Splitter;

    /// Produces the prediction when the node acts as a leaf
    T_Regressor Regressor;

    /// The two mixture components fitted by EM, one per prospective child
    T_Distribution* Distributions[2];

    /// Distribution of the parent node, used to (de)normalize the data
    T_Distribution* parentDistribution;

    /// Buffer holding the normalized samples collected during the EM state
    DynamicBuffer* buffer;

    /// Accumulated squared-error cost of this node, for pruning
    double pruningCost;

    /// Number of pruning samples; initialized but not otherwise used here
    int pruningSamples;

    /// Reinitializes both mixture components at random
    void RandomDistributions(void)
    {
        Distributions[0]->RandomDistribution(2);
        Distributions[1]->RandomDistribution(2);
    }
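
    /** Performs one EM iteration over the buffered samples.

        E step: for every buffered line X (normalized regression data
        followed by the sample weight), each component i computes
        p_i = LearnProbability(X) and its accumulated statistics are
        rescaled by NormalizeLearnProbability so that component i
        receives the fraction p_i/(p0+p1) of the sample's weight.
        M step: UpdateParameters() refits both components from the
        accumulated statistics.

        Likelihood accumulates sum of log((p0+p1)/weight); the return
        value is the mean of the two convergence factors reported by
        UpdateParameters() and is compared against the convergence
        limit by the caller.
    */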
    double EMStep(double& Likelihood)
    {
        Likelihood = 0.0;
        for (double* X=buffer->begin(); X<buffer->end(); X+=regDim+2)
        {
            // The last entry of the buffer line is the sample weight.
            double probability=X[regDim+1];

            double p0=Distributions[0]->LearnProbability(X);
            double p1=Distributions[1]->LearnProbability(X);

            if ( isfinite(p0+p1) )
            {
                Likelihood+=log((p0+p1)/probability);
                Distributions[0]->NormalizeLearnProbability((p0+p1)/probability,2);
                Distributions[1]->NormalizeLearnProbability((p0+p1)/probability,2);
            }
        }

        // M step: refit both components; each returns a convergence factor.
        double c0=Distributions[0]->UpdateParameters();
        double c1=Distributions[1]->UpdateParameters();

        return ( (c0+c1)/2.0 );
    }

public:
    /// Constructor for leaf nodes; the node never splits further.
    /// regDim is initialized to 0 since leaves do no regression-space EM.
    BinaryRegressionTreeNode(int NodeId, int CsDim, T_Regressor& regressor,
                             T_Distribution* ParentDistribution = 0):
        nodeId(NodeId), csDim(CsDim), regDim(0), Splitter(), Regressor(regressor),
        parentDistribution(ParentDistribution)
    {
        Children[0]=Children[1]=0;
        Distributions[0]=Distributions[1]=0;
        State=stable;
        emTrials=0;
        buffer=0;
    }

    /// Constructor for internal nodes; learning starts in the EM state.
    BinaryRegressionTreeNode( int NodeId,
                              const Vector<int> & DDomainSize,
                              int CsplitDim, int RegDim,
                              T_Regressor& regressor,
                              T_Distribution* ParentDistribution = 0):
        nodeId(NodeId), csDim(CsplitDim), regDim(RegDim), State(em),
        Splitter(DDomainSize,CsplitDim,RegDim), Regressor(regressor),
        parentDistribution(ParentDistribution)
    {
        Children[0]=Children[1]=0;
        Distributions[0] = new T_Distribution(regDim);
        Distributions[1] = new T_Distribution(regDim);
        RandomDistributions();
        emTrials=0;
        buffer=0;
    }

    ~BinaryRegressionTreeNode(void)
    {
        if (Children[0]!=0)
            delete Children[0];
        Children[0]=0;

        if (Children[1]!=0)
            delete Children[1];
        Children[1]=0;

        if (Distributions[0]!=0)
            delete Distributions[0];

        if (Distributions[1]!=0)
            delete Distributions[1];

        // Release the EM buffer in case the node is destroyed mid-training.
        if (buffer!=0)
            delete buffer;
    }

    int GetNodeId(void)
    {
        return nodeId;
    }

    /// Counts all nodes and terminal nodes in the subtree rooted here.
    void ComputeSizesTree(int& nodes, int& term_nodes)
    {
        nodes++;
        if (Children[0]!=0 && Children[1]!=0)
        {
            Children[0]->ComputeSizesTree(nodes,term_nodes);
            Children[1]->ComputeSizesTree(nodes,term_nodes);
        }
        else
        {
            term_nodes++;
        }
    }

    /// Prepares the node (and, for stable nodes, its subtree) for a
    /// new pass over the training data.
    void StartLearningEpoch(void)
    {
        switch (State)
        {
        case stable:
            if (Children[0]!=0)
            {
                Children[0]->StartLearningEpoch();
                Children[1]->StartLearningEpoch();
            }
            break;
        case em:
            // Samples are buffered so that EM can iterate locally.
            buffer = new DynamicBuffer(regDim+2);
            break;
        case split:
            Splitter.InitializeSplitStatistics();
            break;
        case regression:
            // Nothing to prepare.
            break;
        }
    }
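
    /** Routes one training sample through the node, probability being
        the sample's weight in this part of the tree.

        Depending on the state, the sample is: forwarded to both
        children weighted by the Splitter (stable), buffered in
        normalized form for EM (em), used to update the split
        statistics with EM responsibilities (split), or used to
        re-estimate the per-child distributions (regression). Samples
        whose weight falls below threshold are ignored.
    */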
    void LearnSample(const int* Dvars, const double* Cvars,
                     double probability, double threshold=.01)
    {
        double p0,p1;

        if (probability<threshold)
            return;

        switch (State)
        {
        case stable:
            if (Children[0]==0 || Children[1]==0)
                return;

            {
                double probabilityLeft = probability*Splitter.ProbabilityLeft(Dvars,Cvars);
                double probabilityRight = probability-probabilityLeft;
                if (probabilityLeft>=threshold)
                {
                    Children[0]->LearnSample(Dvars,Cvars,probabilityLeft,threshold);
                }

                if (probabilityRight>=threshold)
                {
                    Children[1]->LearnSample(Dvars,Cvars,probabilityRight,threshold);
                }
            }
            break;
        case em:
            assert(buffer!=0);
            {
                // Store the normalized regression variables followed by
                // the sample's weight.
                double* cBufferLine=buffer->next();
                parentDistribution->NormalizeData(Cvars+csDim,cBufferLine);
                cBufferLine[regDim+1]=probability;
            }
            break;
        case split:
            p0=Distributions[0]->LearnProbability(Cvars+csDim);
            p1=Distributions[1]->LearnProbability(Cvars+csDim);

            if (p0+p1>0.0)
                Splitter.UpdateSplitStatistics(Dvars, Cvars, p0/(p0+p1), p1/(p0+p1), probability );
            break;
        case regression:
            {
                double pChild1=Splitter.ProbabilityLeft(Dvars,Cvars);

                if (pChild1 > threshold)
                {
                    p0=Distributions[0]->LearnProbability(Cvars+csDim);
                    Distributions[0]->NormalizeLearnProbability(p0/(pChild1*probability));
                }
                if (1.0-pChild1 > threshold)
                {
                    p1=Distributions[1]->LearnProbability(Cvars+csDim);
                    Distributions[1]->NormalizeLearnProbability(p1/((1.0-pChild1)*probability));
                }
            }
            break;
        }
    }
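
    /** Finishes the current learning state and advances the node to
        the next one. Returns true if another pass over the data is
        needed anywhere in the subtree.
    */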
    bool StopLearningEpoch(int splitType, int emRestarts, int emMaxIterations,
                           double convergenceLim, int min_no_datapoints)
    {
        bool moresplits=false;
        int emIterations;

        T_Regressor* regressor0=0, * regressor1=0;
        switch (State)
        {
        case stable:
            // Non-short-circuit | so that both subtrees finish their epoch.
            if (Children[0]!=0)
                return Children[0]->StopLearningEpoch(splitType, emRestarts, emMaxIterations,
                                                      convergenceLim, min_no_datapoints)
                     | Children[1]->StopLearningEpoch(splitType, emRestarts, emMaxIterations,
                                                      convergenceLim, min_no_datapoints);
            else
                return false;
        case em:
            emIterations=0;

            {
                // Try several random restarts and keep the pair of
                // distributions with the best likelihood after two steps.
                T_Distribution best_d0(0);
                T_Distribution best_d1(0);
                double best_Likelihood=-1.0e+100;
                for (int repetition=0; repetition<emRestarts; repetition++)
                {
                    RandomDistributions();
                    double Likelihood;

                    EMStep(Likelihood);
                    EMStep(Likelihood);

                    if (Likelihood > best_Likelihood)
                    {
                        best_d0 = *(Distributions[0]);
                        best_d1 = *(Distributions[1]);
                        best_Likelihood = Likelihood;
                    }
                }

                *(Distributions[0])=best_d0;
                *(Distributions[1])=best_d1;
            }

            // Run EM to convergence from the best starting point.
            while (emIterations<emMaxIterations)
            {
                double Likelihood;
                double convFactor = EMStep(Likelihood);

                emIterations++;

                if ( Distributions[0]->HasZeroWeight() || Distributions[1]->HasZeroWeight() )
                {
                    if (emTrials < EMMAXTRIALS)
                    {
                        emTrials++;
                        cerr << "One of the distributions got killed. Starting again." << endl;
                        RandomDistributions();
                        emIterations=0;
                    }
                    else
                    {
                        cerr << "Tried " << EMMAXTRIALS << " times and it didn't work. Making the node a leaf." << endl;
                        goto makeleaf;
                    }
                }

                if (!isfinite(convFactor) || convFactor <= convergenceLim)
                    break;
            }

            // The components were fitted on normalized data; map the
            // parameters back to the original space.
            Distributions[0]->DenormalizeParameters(parentDistribution);
            Distributions[1]->DenormalizeParameters(parentDistribution);
            State=split;

            delete buffer;
            buffer=0;

            return true;

        case split:
            // Choose the split variable and point; on failure the node
            // becomes a leaf.
            State=regression;

            if (Splitter.ComputeSplitVariable(splitType)!=0)
                goto makeleaf;

            Splitter.DeleteTemporaryStatistics();

            return true;

        case regression:
            Distributions[0]->UpdateParameters();
            Distributions[1]->UpdateParameters();

            // Build a regressor for each side of the split; if either
            // distribution cannot produce one, make the node a leaf.
            regressor0 = dynamic_cast<T_Regressor*> ( Distributions[0]->CreateRegressor() );
            regressor1 = dynamic_cast<T_Regressor*> ( Distributions[1]->CreateRegressor() );

            if ( !regressor0 || !regressor1 )
                goto makeleaf;

            State=stable;

            if (Splitter.MoreSplits(0, min_no_datapoints))
            {
                // Enough data on the left: create a child that can split further.
                moresplits=true;
                Children[0] = new BinaryRegressionTreeNode< T_Distribution, T_Regressor, T_Splitter >
                    ( nodeId*2, Splitter.GetDDomainSize(), Splitter.GetCSplitDim(),
                      Splitter.GetRegDim(), *regressor0, Distributions[0] );
            }
            else
            {
                // Too little data: the left child is a leaf.
                Children[0]=new BinaryRegressionTreeNode< T_Distribution, T_Regressor, T_Splitter >
                    ( nodeId*2, csDim, *regressor0, Distributions[0] );
            }

            if (Splitter.MoreSplits(1, min_no_datapoints))
            {
                moresplits=true;
                Children[1] = new BinaryRegressionTreeNode< T_Distribution, T_Regressor, T_Splitter >
                    ( nodeId*2+1, Splitter.GetDDomainSize(), Splitter.GetCSplitDim(),
                      Splitter.GetRegDim(), *regressor1, Distributions[1] );
            }
            else
            {
                Children[1]=new BinaryRegressionTreeNode< T_Distribution, T_Regressor, T_Splitter >
                    ( nodeId*2+1, csDim, *regressor1, Distributions[1] );
            }

            // The children keep pointers to these distributions (as their
            // parentDistribution), so this node must not delete them.
            Distributions[0]=Distributions[1]=0;

            delete regressor0;
            delete regressor1;

            return moresplits;

        default:
            return false;
        }

    makeleaf:
        // Recovery path: give up splitting this node and turn it into a leaf.
        cerr << "Something went wrong. Making node " << nodeId << " a leaf." << endl;
        Splitter.DeleteTemporaryStatistics();

        Children[0]=Children[1]=0;

        if (buffer!=0)
        {
            delete buffer;
            buffer=0;
        }

        delete Distributions[0];
        Distributions[0]=0;
        delete Distributions[1];
        Distributions[1]=0;

        if (regressor0)
            delete regressor0;
        if (regressor1)
            delete regressor1;

        State=stable;
        return false;
    }
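
    /** Computes the prediction for a sample as a probability-weighted
        combination of the children's predictions:

            y = pLeft * yLeft + (1 - pLeft) * yRight

        A child whose branch probability falls below threshold
        contributes 0. Leaves, and nodes beyond maxNodeId, answer with
        their own Regressor, which allows evaluating the tree as if it
        were truncated at a given size.
    */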
    double Infer(const int* Dvars, const double* Cvars, int maxNodeId, double threshold)
    {
        if (Children[0]==0 || nodeId>maxNodeId)
        {
            // Leaf behavior: predict with this node's regressor.
            return Regressor.Y(Cvars+csDim);
        }
        else
        {
            double pChild1=Splitter.ProbabilityLeft(Dvars,Cvars);

            return (pChild1>=threshold ? Children[0]->Infer(Dvars,Cvars,maxNodeId,threshold) : 0.0 )*pChild1+
                   ( 1.0-pChild1>=threshold ? Children[1]->Infer(Dvars,Cvars,maxNodeId,threshold) : 0.0 )*(1.0-pChild1);
        }
    }
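
    /** Pruning proceeds in three phases: InitializePruningStatistics,
        one pass of UpdatePruningStatistics over the pruning data, and
        a final bottom-up PruneSubtree that discards subtrees which do
        not improve on their root's own cost.
    */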
    void InitializePruningStatistics(void)
    {
        pruningCost=0.0;
        pruningSamples=0;
        if (Children[0]!=0 && Children[1]!=0)
        {
            Children[0]->InitializePruningStatistics();
            Children[1]->InitializePruningStatistics();
        }
    }

    /// Accumulates the weighted squared error of this node's regressor
    /// and pushes the sample down to the children.
    void UpdatePruningStatistics(const int* Dvars, const double* Cvars, double y,
                                 double probability, double threshold)
    {
        // Squared error of this node's own prediction, weighted by the
        // sample's probability.
        double predY=Regressor.Y(Cvars+csDim);
        pruningCost+=(y-predY)*(y-predY)*probability;

        if (Children[0]==0 || Children[1]==0)
            return;

        double probabilityLeft = probability*Splitter.ProbabilityLeft(Dvars,Cvars);
        double probabilityRight = probability-probabilityLeft;

        if (probabilityLeft>=threshold)
        {
            Children[0]->UpdatePruningStatistics(Dvars,Cvars,y,probabilityLeft,threshold);
        }

        if (probabilityRight>=threshold)
        {
            Children[1]->UpdatePruningStatistics(Dvars,Cvars,y,probabilityRight,threshold);
        }
    }

    /// Nothing to finalize; the statistics are complete after the update pass.
    void FinalizePruningStatistics(void)
    {
    }
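
    /** Bottom-up pruning: returns the cost of the best subtree rooted
        here. If this node's own cost does not exceed the combined cost
        of its (already pruned) children, the children are deleted and
        the node becomes a leaf.
    */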
    double PruneSubtree(void)
    {
        if (Children[0] == 0 && Children[1] == 0)
        {
            // Already a leaf.
            return pruningCost;
        }
        else
        {
            double pruningCostChildren=Children[0]->PruneSubtree()+
                                       Children[1]->PruneSubtree();

            if (pruningCost<=pruningCostChildren)
            {
                // The node alone is at least as good as its subtree: prune.
                delete Children[0];
                Children[0]=0;
                delete Children[1];
                Children[1]=0;

                return pruningCost;
            }
            else
            {
                // The subtree is better: keep it.
                return pruningCostChildren;
            }
        }
    }
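
    /** Writes the subtree in preorder, one node per line, in the form
        { nodeId [ splitter ] ( regressor ) }; the splitter part is
        empty for leaves.
    */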
    void SaveToStream(ostream& out)
    {
        out << "{ " << nodeId << " [ ";
        if (Children[0]!=0 && Children[1]!=0)
        {
            Splitter.SaveToStream(out);
        }
        out << " ] ( ";
        Regressor.SaveToStream(out);
        out << " ) }";

        out << endl;

        if (Children[0]!=0 && Children[1]!=0)
        {
            Children[0]->SaveToStream(out);
            Children[1]->SaveToStream(out);
        }
    }
};

} // namespace CLUS
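
/* Typical driving loop (a sketch, not part of the original file; the
   concrete Dist/Reg/Split types, the data source, and all parameter
   values below are assumptions). Note that the root needs a
   distribution over the whole data set as its parentDistribution,
   since the EM state normalizes samples through it:

       BinaryRegressionTreeNode<Dist,Reg,Split>
           root(1, dDomains, csDim, regDim, reg0, &wholeDataDist);
       bool more = true;
       while (more)
       {
           root.StartLearningEpoch();
           for (each training sample)              // pseudo-code
               root.LearnSample(dVars, cVars, 1.0);
           more = root.StopLearningEpoch(splitType, 10, 100, 1e-3, 50);
       }
       root.InitializePruningStatistics();
       for (each pruning sample)                   // pseudo-code
           root.UpdatePruningStatistics(dVars, cVars, y, 1.0, .01);
       root.PruneSubtree();
*/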

#endif // _CLUS_REGRESSIONTREENODE_H_