Main Page | Namespace List | Class Hierarchy | Data Structures | File List | Namespace Members | Data Fields | Globals | Related Pages

probabilisticregressiontree.h

Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (c) 2003, Cornell University
00004 All rights reserved.
00005 
00006 Redistribution and use in source and binary forms, with or without
00007 modification, are permitted provided that the following conditions are met:
00008 
00009    - Redistributions of source code must retain the above copyright notice,
00010        this list of conditions and the following disclaimer.
00011    - Redistributions in binary form must reproduce the above copyright
00012        notice, this list of conditions and the following disclaimer in the
00013        documentation and/or other materials provided with the distribution.
00014    - Neither the name of Cornell University nor the names of its
00015        contributors may be used to endorse or promote products derived from
00016        this software without specific prior written permission.
00017 
00018 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
00019 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00020 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
00021 ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
00022 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
00023 CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00024 SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
00025 INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
00026 CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00027 ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
00028 THE POSSIBILITY OF SUCH DAMAGE.
00029 
00030 */
00031 
00032 // -*- C++ -*-
00033 
00034 #if !defined _CLUS_PROBABILISTICREGRESSIONTREE_H_
00035 #define _CLUS_PROBABILISTICREGRESSIONTREE_H_
00036 
00037 #include "machine.h"
00038 #include "probabilisticregressiontreenode.h"
00039 #include "dctraingen.h"
00040 #include <iostream>
00041 
00042 // using namespace TNT;
00043 
00044 namespace CLUS
00045 {
00046 
00047 template< class T_Distribution, class T_Regressor, class T_Splitter >
00048 class BinaryProbabilisticRegressionTree : public Machine
00049 {
00050 protected:
00051     BinaryProbabilisticRegressionTreeNode< T_Distribution, T_Regressor, T_Splitter>* root;
00052     
00053     /// list of discrete domain sizes
00054     const Vector<int>& dDomainSize;
00055 
00056     /// num of discrete variables
00057     int dsplitDim;
00058 
00059     /// num of continuous+split variables
00060     int csplitDim;
00061 
00062     /// num of regression variables
00063     int regDim;
00064 
00065     /// num of iterations to get convergence for EM
00066     int emMaxIterations;
00067 
00068     /// num of restarts of EM to get a good initial starting point
00069     int emRestarts;
00070 
00071     /// the minimum number of datapoints in a node to split further
00072     int min_no_datapoints;
00073 
00074     /// type of split to be passed to splitter, splitter dependent
00075     int splitType;
00076 
00077     T_Distribution* rootDistribution;
00078     int inferMaxNodeId;
00079 
00080     void PrintSizesTree(void)
00081     {
00082         int nodes=0;
00083         int term_nodes=0;
00084         root->ComputeSizesTree(nodes,term_nodes);
00085         cout << "Nuber of nodes=" << nodes << "\tNumber of terminal nodes=" << term_nodes << endl;
00086     }
00087 
00088 public:
00089     BinaryProbabilisticRegressionTree(const Vector<int>& DDomainSize, int CsplitDim, int RegDim):
00090             Machine(CsplitDim+RegDim,1),dDomainSize(DDomainSize),
00091             dsplitDim(DDomainSize.dim()),csplitDim(CsplitDim), regDim(RegDim)
00092     {
00093         emRestarts = 3;
00094         emMaxIterations = 30;
00095         min_no_datapoints = 10;
00096         splitType = 0;
00097         rootDistribution = 0;
00098         inferMaxNodeId = INT_MAX;
00099     }
00100 
00101     virtual ~BinaryProbabilisticRegressionTree(void)
00102     {
00103         if (rootDistribution)
00104             delete rootDistribution;
00105     }
00106 
00107     virtual int InDim(void)
00108     {
00109         return dsplitDim+csplitDim+regDim;
00110     }
00111 
00112     virtual string TypeName(void)
00113     {
00114         return string("BinaryRegressionTree");
00115     }
00116 
00117     virtual void Infer(void)
00118     {
00119         if (root==0)
00120             return;
00121             
00122         // translate the first dsplitDim inputs into int
00123         int Dvars[MAX_VARIABLES];
00124         
00125         for(int i=0; i<dsplitDim; i++)
00126             Dvars[i]=(int)(*input)[i];
00127             
00128         // scale the continuous inputs
00129         double scaledInput[MAX_VARIABLES];
00130         for (int i=0; i<csplitDim+regDim; i++)
00131             scaledInput[i]=scale[i].Transform( (*input)[i+dsplitDim] );
00132 
00133         output[0]=scale[csplitDim+regDim].Transform( root->Infer(Dvars,scaledInput,inferMaxNodeId) );
00134     }
00135 
00136     virtual void Identify(void)
00137     {
00138         const Matrix<double>& ctrainData = training->GetTrainingData();
00139         const Matrix<int>& dtrainData = dynamic_cast< DCTrainingData* >( training )
00140                                         -> GetDiscreteTrainingData();
00141 
00142         int M=ctrainData.num_rows();
00143 
00144         // find a linear regressor for the first node
00145         if (!rootDistribution)
00146             rootDistribution = new T_Distribution(regDim);
00147 
00148         rootDistribution->RandomDistribution(1);
00149         double Coef;
00150         for (int i=0; i<M; i++)
00151         {
00152             Coef = rootDistribution->LearnProbability(ctrainData[i]+csplitDim);
00153             rootDistribution->NormalizeLearnProbability(Coef);
00154         }
00155         rootDistribution->UpdateParameters();
00156 
00157         /*
00158         cout << "Printing root distribution" << endl;
00159         rootDistribution->SaveToStream(cout);
00160         cout << endl;
00161         */
00162 
00163         T_Regressor* regressor = dynamic_cast<T_Regressor*>( rootDistribution->CreateRegressor() );
00164         // create the root and give it the Id 1
00165         root = new BinaryProbabilisticRegressionTreeNode<T_Distribution, T_Regressor, T_Splitter>
00166                (1, dDomainSize, csplitDim, regDim, *regressor, rootDistribution);
00167 
00168         rootDistribution = 0; // root will dealocate rootDistribution
00169         // learn in stages until nobody wants to learn anymore
00170         do
00171         {
00172             root->StartLearningEpoch();
00173             for(int i=0; i<M; i++)
00174                 root->LearnSample(dtrainData[i],ctrainData[i],1.0,0.01);
00175         }
00176         while (root->StopLearningEpoch(splitType, emRestarts, emMaxIterations,
00177                                        convergenceLim, min_no_datapoints));
00178         cout << "End Learning" << endl;
00179         PrintSizesTree();
00180     }
00181 
00182     virtual void Prune(void)
00183     {
00184         const Matrix<double>& ctrainData = pruning->GetTrainingData();
00185         const Matrix<int>& dtrainData = dynamic_cast< DCTrainingData* >( pruning )
00186                                         -> GetDiscreteTrainingData();
00187 
00188         int M=ctrainData.num_rows();
00189 
00190         root->InitializePruningStatistics();
00191         for(int i=0; i<M; i++)
00192         {
00193             // scale the continuous inputs
00194             double scaledInput[MAX_VARIABLES];
00195             for (int j=0; j<csplitDim+regDim; j++)
00196                 scaledInput[j]=scale[j].Transform( ctrainData[i][j] );
00197             // scale the output
00198             double y=scale[csplitDim+regDim].InverseTransform( ctrainData[i][csplitDim+regDim] );
00199 
00200             root->UpdatePruningStatistics(dtrainData[i], scaledInput, y, 1.0);
00201         }
00202         root->FinalizePruningStatistics();
00203 
00204         // now cut the tree to the right size
00205         double cost=root->PruneSubtree();
00206         cout << "RMSN after pruning is:" << cost/M << endl;
00207         PrintSizesTree();
00208     }
00209 
00210     virtual int SetOption(char* name, char* val)
00211     {
00212         if (strcmp(name,"EMMaxIterations")==0)
00213             emMaxIterations = atoi(val);
00214         else
00215             if (strcmp(name,"EMRestarts")==0)
00216                 emRestarts = atoi(val);
00217             else
00218                 if (strcmp(name,"InferMaxNodeId")==0)
00219                     inferMaxNodeId = atoi(val);
00220                 else
00221                     if (strcmp(name,"MaxNoDatapoints")==0)
00222                         min_no_datapoints = 2*atoi(val)-2;
00223                     else
00224                         if (strcmp(name,"SplitType")==0)
00225                             splitType = atoi(val);
00226                         else
00227                             return Machine::SetOption(name,val);
00228         return 1;
00229     }
00230 
00231     virtual void SaveToStream(ostream& out)
00232     {
00233         out << TypeName() << " (  " << dsplitDim << " " << csplitDim << " ";
00234         out << regDim << " ) { "  << endl;
00235 
00236         out << "[ ";
00237         for(int i=0; i<dsplitDim; i++)
00238             out << dDomainSize[i] << " ";
00239         out << "]" << endl;
00240 
00241         for(int i=0; i<csplitDim+regDim+1; i++)
00242             scale[i].SaveToStream(out);
00243         out << endl;
00244 
00245         root->SaveToStream(out);
00246 
00247         out << '}' << endl;
00248     }
00249 };
00250 
00251 }
00252 
00253 #endif // _CLUS_PROBABILISTICREGRESSIONTREE_H_

Generated on Mon Jul 21 16:57:24 2003 for SECRET by doxygen 1.3.2