Main Page | Namespace List | Class Hierarchy | Data Structures | File List | Namespace Members | Data Fields | Globals | Related Pages

regressiontreenode.h

Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (c) 2003, Cornell University
00004 All rights reserved.
00005 
00006 Redistribution and use in source and binary forms, with or without
00007 modification, are permitted provided that the following conditions are met:
00008 
00009    - Redistributions of source code must retain the above copyright notice,
00010        this list of conditions and the following disclaimer.
00011    - Redistributions in binary form must reproduce the above copyright
00012        notice, this list of conditions and the following disclaimer in the
00013        documentation and/or other materials provided with the distribution.
00014    - Neither the name of Cornell University nor the names of its
00015        contributors may be used to endorse or promote products derived from
00016        this software without specific prior written permission.
00017 
00018 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
00019 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00020 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
00021 ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
00022 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
00023 CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00024 SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
00025 INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
00026 CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00027 ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
00028 THE POSSIBILITY OF SUCH DAMAGE.
00029 
00030 */
00031 
00032 // -*- C++ -*-
00033 
00034 #if !defined _CLUS_REGRESSIONTREENODE_H_
00035 #define _CLUS_REGRESSIONTREENODE_H_
00036 
00037 #include "vec.h"
00038 #include "distribution.h"
00039 #include "dynamicbuffer.h"
00040 
00041 using namespace TNT;
00042 
00043 #define EMMAXTRIALS 3
00044 
00045 namespace CLUS
00046 {
00047 
00048 /** Class used in building regression trees.
00049     Every class is a node in the tree, there can be 0 or 2 children
00050     NodeId determines the position of the node in the tree. The nodeId
00051     od children is alwais 2*nodeId and 2*nodeId+1. The nodeId starts from 1
00052     for the root. 
00053  
00054     T_Distribution is the distribution that approximates the data
00055     T_Regressor is the regressor that commes with T_Distribution
00056     T_Spliter is the class that in a final scan can compute the 
00057     splitting predicate
00058 */
00059 template< class T_Distribution, class T_Regressor, class T_Splitter >
00060 class BinaryRegressionTreeNode
00061 {
00062 protected:
00063     /// unique identifier of the cluster for a regression tree.
00064     /// the id of the children is always 2*Id, 2*Id+1
00065     int nodeId; 
00066 
00067     /// how many times EM was attempted
00068     int emTrials;
00069 
00070     /// the number of continuous split variables
00071     int csDim;
00072 
00073     /// the number of regressors
00074     int regDim;
00075     
00076     /// the state of the node. At creation em. At load stable
00077     enum state { stable, em, split, regression } State;
00078     
00079     /// the children of this node
00080     BinaryRegressionTreeNode< T_Distribution, T_Regressor, T_Splitter > * Children[2];
00081     
00082     /// splitter for split criterion
00083     T_Splitter Splitter;
00084     
00085     /// regressor for the node
00086     T_Regressor Regressor;
00087     
00088     /// distributions for the children for the learning process
00089     T_Distribution* Distributions[2];
00090     
00091     /// the distribution of the parent. Usefull in the learning process to pick initial values for the new distributions
00092     T_Distribution* parentDistribution;
00093 
00094     /// keep the EM samples here temporarily
00095     DynamicBuffer* buffer; 
00096 
00097     /// sum of squared differences between prediction and true value
00098     double pruningCost;
00099 
00100     /// number of samples in this node for pruning
00101     int pruningSamples; 
00102 
00103     void RandomDistributions(void)
00104     {
00105         /*      if (parentDistribution != 0){
00106          Distributions[0]->RandomDistribution(2,*parentDistribution);
00107          Distributions[1]->RandomDistribution(2,*parentDistribution);
00108          } else { 
00109         */
00110 
00111         // Distributions are build on normalized data
00112         Distributions[0]->RandomDistribution(2);
00113         Distributions[1]->RandomDistribution(2);
00114 
00115         // }
00116     }
00117 
00118     /**
00119     Peform an EM step on the data in buff and return the convergence factor 
00120     and compute the likelihood. 
00121     */
00122     double EMStep(double& Likelihood)
00123     {
00124         Likelihood = 0.0;
00125         for (double* X=buffer->begin(); X<buffer->end(); X+=regDim+2)
00126         {
00127             // Do the EM learning on the two distributions
00128             double probability=X[regDim+1];
00129 
00130             double p0=Distributions[0]->LearnProbability(X);
00131             double p1=Distributions[1]->LearnProbability(X);
00132 
00133             if ( finite(p0+p1) )
00134             {
00135                 Likelihood+=log((p0+p1)/probability);
00136                 Distributions[0]->NormalizeLearnProbability((p0+p1)/probability,2);
00137                 Distributions[1]->NormalizeLearnProbability((p0+p1)/probability,2);
00138             }
00139         }
00140 
00141         double c0=Distributions[0]->UpdateParameters();
00142         double c1=Distributions[1]->UpdateParameters();
00143 
00144         return ( (c0+c1)/2.0 );
00145     }
00146 
00147 public:
00148     /// Constructor for a leaf
00149     BinaryRegressionTreeNode(int NodeId, int CsDim, T_Regressor& regressor,  T_Distribution* ParentDistribution = 0):
00150             nodeId(NodeId), csDim(CsDim), Splitter(), Regressor(regressor), parentDistribution(ParentDistribution)
00151     {
00152         Children[0]=Children[1]=0;
00153         Distributions[0]=Distributions[1]=0;
00154         State=stable; // do no learning on this node
00155         emTrials=0;
00156         buffer=0;
00157     }
00158 
00159     /// Constructor for an intermediate node
00160     BinaryRegressionTreeNode( int NodeId,
00161                               const Vector<int> & DDomainSize,
00162                               int CsplitDim, int RegDim,
00163                               T_Regressor& regressor,
00164                               T_Distribution* ParentDistribution = 0):
00165             nodeId(NodeId),csDim(CsplitDim), regDim(RegDim), State(em),
00166             Splitter(DDomainSize,CsplitDim,RegDim), Regressor(regressor),
00167             parentDistribution(ParentDistribution)
00168     {
00169 
00170         Children[0]=Children[1]=0;
00171         Distributions[0] = new T_Distribution(regDim);
00172         Distributions[1] = new T_Distribution(regDim);
00173         RandomDistributions();
00174         emTrials=0;
00175         buffer=0;
00176     }
00177 
00178     ~BinaryRegressionTreeNode(void)
00179     {
00180         if (Children[0]!=0)
00181             delete Children[0];
00182         Children[0]=0;
00183 
00184         if (Children[1]!=0)
00185             delete Children[1];
00186         Children[1]=0;
00187 
00188         if (Distributions[0]!=0)
00189             delete Distributions[0];
00190 
00191         if (Distributions[1]!=0)
00192             delete Distributions[1];
00193     }
00194 
00195     int GetNodeId(void)
00196     {
00197         return nodeId;
00198     }
00199 
00200     void ComputeSizesTree(int& nodes, int& term_nodes)
00201     {
00202         nodes++;
00203         if (Children[0]!=0 && Children[1]!=0)
00204         {
00205             Children[0]->ComputeSizesTree(nodes,term_nodes);
00206             Children[1]->ComputeSizesTree(nodes,term_nodes);
00207         }
00208         else
00209         {
00210             term_nodes++;
00211         }
00212     }
00213 
00214     void StartLearningEpoch(void)
00215     {
00216         switch (State)
00217         {
00218         case stable:
00219             if (Children[0]!=0)
00220             {
00221                 Children[0]->StartLearningEpoch();
00222                 Children[1]->StartLearningEpoch();
00223             }
00224             break;
00225         case em:
00226             buffer = new DynamicBuffer(regDim+2); // allocate the temporary buffer
00227             break;
00228         case split:
00229             // delete parentDistribution; // we dont' need the parent distribution anymore and we get rid of it
00230             // parentDistribution=0;
00231             Splitter.InitializeSplitStatistics();
00232             break;
00233         case regression:
00234             // nothing to do
00235             break;
00236         }
00237     }
00238 
00239     void LearnSample(const int* Dvars, const double* Cvars,
00240                      double probability, double threshold=.01)
00241     {
00242         double p0,p1;
00243 
00244         if (probability<threshold)
00245             return;
00246 
00247         switch (State)
00248         {
00249         case stable:
00250             // Pass the sample to the right child
00251             if (Children[0]==0 || Children[1]==0)
00252                 return;
00253 
00254             // propagate the learning
00255             {
00256                 double probabilityLeft = probability*Splitter.ProbabilityLeft(Dvars,Cvars);
00257                 double probabilityRight = probability-probabilityLeft;
00258                 if (probabilityLeft>=threshold)
00259                 {
00260                     Children[0]->LearnSample(Dvars,Cvars,probabilityLeft,threshold);
00261                 }
00262 
00263                 if (probabilityRight>=threshold)
00264                 {
00265                     Children[1]->LearnSample(Dvars,Cvars,probabilityRight,threshold);
00266                 }
00267             }
00268 
00269             break;
00270         case em:
00271             // First normalize, otherwise em converges to wrong things (no substructure identified)
00272             // The normalization makes transforms the data to be centered on 0 and have identity
00273             // spartity matrix (look superficially like a ball)
00274 
00275             assert(buffer!=0);
00276             {
00277                 double* cBufferLine=buffer->next();
00278                 parentDistribution->NormalizeData(Cvars+csDim,cBufferLine);
00279                 cBufferLine[regDim+1]=probability;
00280             }
00281             break;
00282         case split:
00283             p0=Distributions[0]->LearnProbability(Cvars+csDim);
00284             p1=Distributions[1]->LearnProbability(Cvars+csDim);
00285 
00286             if (p0+p1>0.0)
00287                 Splitter.UpdateSplitStatistics(Dvars, Cvars, p0/(p0+p1), p1/(p0+p1), probability );
00288             break;
00289         case regression:
00290             {
00291                 double pChild1=Splitter.ProbabilityLeft(Dvars,Cvars);
00292 
00293                 if (pChild1 > threshold)
00294                 {
00295                     p0=Distributions[0]->LearnProbability(Cvars+csDim);
00296                     Distributions[0]->NormalizeLearnProbability(p0/(pChild1*probability));
00297                 }
00298                 if (1.0-pChild1 > threshold)
00299                 {
00300                     p1=Distributions[1]->LearnProbability(Cvars+csDim);
00301                     Distributions[1]->NormalizeLearnProbability(p1/((1.0-pChild1)*probability));
00302                 }
00303             }
00304             break;
00305         }
00306     }
00307 
00308     /// Returns true if more learning has to be donne in future
00309     bool StopLearningEpoch(int splitType, int emRestarts, int emMaxIterations,
00310                            double convergenceLim, int min_no_datapoints)
00311     {
00312         bool moresplits=false;
00313         int emIterations;
00314 
00315         T_Regressor* regressor0=0, * regressor1=0;
00316         switch (State)
00317         {
00318         case stable:
00319             if (Children[0]!=0)
00320                 return Children[0]->StopLearningEpoch(splitType, emRestarts, emMaxIterations,
00321                                                       convergenceLim, min_no_datapoints)
00322                        | Children[1]->StopLearningEpoch(splitType, emRestarts, emMaxIterations,
00323                                                         convergenceLim, min_no_datapoints); // || cannot be used since is shortcuted
00324             else
00325                 return false;
00326         case em:
00327             emIterations=0;
00328 
00329             {
00330                 // cout << "XXXNodeID: " << nodeId << " noDatapoints: " << buffer->dim() << endl;
00331 
00332                 // do a number of random restarts and pick the starting point with the maximum likelihood
00333                 T_Distribution best_d0(0);
00334                 T_Distribution best_d1(0);
00335                 double best_Likelihood=-1.0e+100;
00336                 for (int repetition=0; repetition<emRestarts; repetition++)
00337                 {
00338                     // restart randomly
00339                     RandomDistributions();
00340                     double Likelihood;
00341                     // do two EM steps
00342                     EMStep(Likelihood);
00343                     EMStep(Likelihood);
00344 
00345                     //cout << "NodeID=" << nodeId << " repetition=" << repetition ;
00346                     //cout << " Likelihood=" << Likelihood << endl;
00347 
00348                     if (Likelihood > best_Likelihood)
00349                     {
00350                         best_d0 = *(Distributions[0]);
00351                         best_d1 = *(Distributions[1]);
00352                         best_Likelihood = Likelihood;
00353                     }
00354                 }
00355 
00356                 *(Distributions[0])=best_d0;
00357                 *(Distributions[1])=best_d1;
00358             }
00359 
00360             // continue the learning process with the best starting point
00361 
00362             while (emIterations<emMaxIterations)
00363             {
00364                 double Likelihood;
00365                 double convFactor = EMStep(Likelihood);
00366 
00367                 // cout << "NodeID=" << nodeId << " Likelihood=" << Likelihood << endl;
00368 
00369                 emIterations++;
00370 
00371                 if ( Distributions[0]->HasZeroWeight() || Distributions[1]->HasZeroWeight() )
00372                 {
00373                     if (emTrials < EMMAXTRIALS)
00374                     {
00375                         emTrials++;
00376                         cerr << "One of the distributions got killed. Starting again" << endl;
00377                         RandomDistributions();
00378                         emIterations=0;
00379                     }
00380                     else
00381                     {
00382                         cerr << "Tried " << EMMAXTRIALS << " times and didn't work. Making the node a leaf." << endl;
00383                         goto makeleaf;
00384                     }
00385                 }
00386 
00387                 if (!finite(convFactor) || convFactor <= convergenceLim)
00388                     break;
00389             }
00390 
00391             // denormalize parameters for the distributions
00392             Distributions[0]->DenormalizeParameters(parentDistribution);
00393             Distributions[1]->DenormalizeParameters(parentDistribution);
00394             State=split;
00395 
00396             delete buffer;
00397             buffer=0; // to avoid deleting it again
00398 
00399             return true;
00400 
00401         case split:
00402 
00403             // cout << "Node: " << nodeId << " being split" << endl;
00404 
00405             State=regression;
00406 
00407             if (Splitter.ComputeSplitVariable(splitType)!=0)
00408                 goto makeleaf;
00409 
00410             Splitter.DeleteTemporaryStatistics();
00411 
00412             return true;
00413             break;
00414 
00415         case regression:
00416             // cout << "Node:" << nodeId << " finishing regression" << endl;
00417 
00418             Distributions[0]->UpdateParameters();
00419             Distributions[1]->UpdateParameters();
00420 
00421             // print the distribution information for the two distributions
00422             /*
00423               cout << "Left distribution ";
00424               Distributions[0]->SaveToStream(cout);
00425               cout << endl;
00426               cout << "Right distribution ";
00427               Distributions[1]->SaveToStream(cout);
00428               cout << endl;
00429             */
00430 
00431             regressor0 = dynamic_cast<T_Regressor*> ( Distributions[0]->CreateRegressor() );
00432             regressor1 = dynamic_cast<T_Regressor*> ( Distributions[1]->CreateRegressor() );
00433 
00434             if ( !regressor0 || !regressor1 )
00435                 goto makeleaf;
00436 
00437             State=stable;
00438 
00439             if (Splitter.MoreSplits(0, min_no_datapoints))
00440             {
00441                 // more splits in future for child 1, create normal child
00442                 moresplits=true;
00443                 Children[0] = new BinaryRegressionTreeNode< T_Distribution, T_Regressor, T_Splitter >
00444                               ( nodeId*2, Splitter.GetDDomainSize(), Splitter.GetCSplitDim(),
00445                                 Splitter.GetRegDim(), *regressor0, Distributions[0] );
00446             }
00447             else
00448             {
00449                 // no more splits in future, create leaf children
00450                 Children[0]=new BinaryRegressionTreeNode< T_Distribution, T_Regressor, T_Splitter >
00451                             ( nodeId*2, csDim, *regressor0, Distributions[0] );
00452             }
00453 
00454             if (Splitter.MoreSplits(1, min_no_datapoints))
00455             {
00456                 moresplits=true;
00457                 Children[1] = new BinaryRegressionTreeNode< T_Distribution, T_Regressor, T_Splitter >
00458                               ( nodeId*2+1, Splitter.GetDDomainSize(), Splitter.GetCSplitDim(),
00459                                 Splitter.GetRegDim(), *regressor1, Distributions[1] );
00460             }
00461             else
00462             {
00463                 Children[1]=new BinaryRegressionTreeNode< T_Distribution, T_Regressor, T_Splitter >
00464                             ( nodeId*2+1, csDim, *regressor1, Distributions[1] );
00465             }
00466 
00467             Distributions[0]=Distributions[1]=0; // the children are responsible with dealocating distribs when they don't need them anymore
00468 
00469             delete regressor0;
00470             delete regressor1;
00471 
00472             return moresplits;
00473 
00474         default:
00475             return false;
00476         }
00477 
00478 makeleaf:
00479         cerr << "Something went wrong. Making node " << nodeId << "  a leaf." << endl;
00480         Splitter.DeleteTemporaryStatistics();
00481 
00482         Children[0]=Children[1]=0;
00483 
00484         if (buffer!=0)
00485         {
00486             delete buffer;
00487         }
00488 
00489         delete Distributions[0];
00490         Distributions[0]=0;
00491         delete Distributions[1];
00492         Distributions[1]=0;
00493 
00494         if (regressor0)
00495             delete regressor0;
00496         if (regressor1)
00497             delete regressor1;
00498 
00499         State=stable;
00500         return false;
00501     }
00502 
00503     /// Does the inference
00504     double Infer(const int* Dvars, const double* Cvars, int maxNodeId, double threshold)
00505     {
00506         if (Children[0]==0 || nodeId>maxNodeId)
00507         {
00508             // leaf node or level cut
00509             return Regressor.Y(Cvars+csDim);
00510             //return nodeId;
00511         }
00512         else
00513         {
00514             double pChild1=Splitter.ProbabilityLeft(Dvars,Cvars);
00515 
00516             return (pChild1>=threshold ? Children[0]->Infer(Dvars,Cvars,maxNodeId,threshold) : 0.0 )*pChild1+
00517                    ( 1.0-pChild1>=threshold ? Children[1]->Infer(Dvars,Cvars,maxNodeId,threshold) : 0.0 )*(1.0-pChild1);
00518         }
00519     }
00520 
00521     void InitializePruningStatistics(void)
00522     {
00523         pruningCost=0.0;
00524         pruningSamples=0;
00525         if (Children[0]!=0 && Children[1]!=0)
00526         {
00527             Children[0]->InitializePruningStatistics();
00528             Children[1]->InitializePruningStatistics();
00529         }
00530     }
00531 
00532 
00533     void UpdatePruningStatistics(const int* Dvars, const double* Cvars, double y /* true output */,
00534                                  double probability, double threshold)
00535     {
00536         //pruningTotalMass+=probability;
00537 
00538         double predY=Regressor.Y(Cvars+csDim);
00539         pruningCost+=pow2(y-predY)*probability;
00540 
00541         // update pruning statistics for the proper children
00542         if (Children[0]==0 || Children[1]==0)
00543             return; // stop the process
00544 
00545         double probabilityLeft = probability*Splitter.ProbabilityLeft(Dvars,Cvars);
00546         //pruningTotalMassLeft+=probabilityLeft;
00547         double probabilityRight = probability-probabilityLeft;
00548 
00549         if (probabilityLeft>=threshold)
00550         {
00551             Children[0]->UpdatePruningStatistics(Dvars,Cvars,y,probabilityLeft,threshold);
00552         }
00553 
00554         if (probabilityRight>=threshold)
00555         {
00556             Children[1]->UpdatePruningStatistics(Dvars,Cvars,y,probabilityRight,threshold);
00557         }
00558     }
00559 
00560     void FinalizePruningStatistics (void)
00561     {
00562         // nothing to do
00563     }
00564 
00565     /** Returns the optimal cost for this subtree and cuts the subtree to optimal size */
00566     double PruneSubtree(void)
00567     { // double alpha /* alpha is the cost for a leaf */){
00568         if (Children[0] == 0 && Children[1] == 0)
00569         {
00570             // node is a leaf
00571             return pruningCost;
00572         }
00573         else
00574         {
00575             // node is an intermediary node
00576             double pruningCostChildren=Children[0]->PruneSubtree()+
00577                                        Children[1]->PruneSubtree();
00578 
00579             if (pruningCost<=pruningCostChildren)
00580             {
00581                 // prune the tree here
00582                 delete Children[0];
00583                 Children[0]=0;
00584                 delete Children[1];
00585                 Children[1]=0;
00586 
00587                 return pruningCost;
00588             }
00589             else
00590             {
00591                 // tree is good as it is
00592                 return pruningCostChildren;
00593             }
00594         }
00595     }
00596 
00597     void SaveToStream(ostream& out)
00598     {
00599         out << "{ " << nodeId << " [ ";
00600         if (Children[0]!=0 && Children[1]!=0)
00601         {
00602             // intermediate node
00603             Splitter.SaveToStream(out);
00604         } // else leaf, leave empty
00605         out << " ] ( ";
00606         Regressor.SaveToStream(out);
00607         out << " ) }";
00608 
00609         out << endl;
00610 
00611         /*
00612         if (parentDistribution!=0)
00613         parentDistribution->SaveToStream(out);
00614         */
00615 
00616         if (Children[0]!=0 && Children[1]!=0)
00617         {
00618             Children[0]->SaveToStream(out);
00619             Children[1]->SaveToStream(out);
00620         }
00621     }
00622 };
00623 }
00624 
00625 
00626 #endif //  _CLUS_REGRESSIONTREENODE_H_

Generated on Mon Jul 21 16:57:25 2003 for SECRET by doxygen 1.3.2