Main Page | Namespace List | Class Hierarchy | Data Structures | File List | Namespace Members | Data Fields | Globals | Related Pages

regressiontree.h

Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (c) 2003, Cornell University
00004 All rights reserved.
00005 
00006 Redistribution and use in source and binary forms, with or without
00007 modification, are permitted provided that the following conditions are met:
00008 
00009    - Redistributions of source code must retain the above copyright notice,
00010        this list of conditions and the following disclaimer.
00011    - Redistributions in binary form must reproduce the above copyright
00012        notice, this list of conditions and the following disclaimer in the
00013        documentation and/or other materials provided with the distribution.
00014    - Neither the name of Cornell University nor the names of its
00015        contributors may be used to endorse or promote products derived from
00016        this software without specific prior written permission.
00017 
00018 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
00019 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00020 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
00021 ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
00022 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
00023 CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00024 SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
00025 INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
00026 CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00027 ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
00028 THE POSSIBILITY OF SUCH DAMAGE.
00029 
00030 */
00031 
00032 // -*- C++ -*-
00033 
00034 #if !defined _CLUS_REGRESSIONTREE_H_
00035 #define _CLUS_REGRESSIONTREE_H_
00036 
#include <climits>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>

#include "machine.h"
#include "regressiontreenode.h"
#include "dctraingen.h"
00041 
00042 // using namespace TNT;
00043 
00044 namespace CLUS
00045 {
00046 
00047 template< class T_Distribution, class T_Regressor, class T_Splitter >
00048 class BinaryRegressionTree : public Machine
00049 {
00050 protected:
00051     BinaryRegressionTreeNode< T_Distribution, T_Regressor, T_Splitter>* root;
00052 
00053     /// list of discrete domain sizes
00054     const Vector<int>& dDomainSize;
00055 
00056     /// num of discrete variables
00057     int dsplitDim;
00058 
00059     /// num of continuous+split variables
00060     int csplitDim;
00061 
00062     /// num of regression variables
00063     int regDim;
00064 
00065     /// num of iterations to get convergence for EM
00066     int emMaxIterations;
00067 
00068     /// num of restarts of EM to get a good initial starting point
00069     int emRestarts;
00070 
00071     /// the minimum number of datapoints in a node to split further
00072     int min_no_datapoints;
00073 
00074     /// type of split to be passed to splitter, splitter dependent
00075     int splitType;
00076 
00077     /// threshold for considering a branch irrelevant
00078     double threshold; 
00079 
00080     T_Distribution* rootDistribution;
00081     int inferMaxNodeId;
00082 
00083     void PrintSizesTree(void)
00084     {
00085         int nodes=0;
00086         int term_nodes=0;
00087         root->ComputeSizesTree(nodes,term_nodes);
00088         cout << "Nuber of nodes=" << nodes << "\tNumber of terminal nodes=" << term_nodes << endl;
00089     }
00090 
00091 public:
00092     BinaryRegressionTree(const Vector<int>& DDomainSize, int CsplitDim, int RegDim):
00093             Machine(CsplitDim+RegDim,1),dDomainSize(DDomainSize),
00094             dsplitDim(DDomainSize.dim()),csplitDim(CsplitDim), regDim(RegDim)
00095     {
00096         emRestarts = 3;
00097         emMaxIterations = 30;
00098         min_no_datapoints = 10;
00099         splitType = 0;
00100         rootDistribution = 0;
00101         inferMaxNodeId = INT_MAX;
00102         threshold=.01;
00103     }
00104 
00105     virtual ~BinaryRegressionTree(void)
00106     {
00107         if (rootDistribution)
00108             delete rootDistribution;
00109     }
00110 
00111     virtual int InDim(void)
00112     {
00113         return dsplitDim+csplitDim+regDim;
00114     }
00115 
00116     virtual string TypeName(void)
00117     {
00118         return string("BinaryRegressionTree");
00119     }
00120 
00121     virtual void Infer(void)
00122     {
00123         if (root==0)
00124             return;
00125             
00126         // translate the first dsplitDim inputs into int
00127         int Dvars[MAX_VARIABLES];
00128         for(int i=0; i<dsplitDim; i++)
00129             Dvars[i]=(int)(*input)[i];
00130             
00131         // scale the continuous inputs
00132         double scaledInput[MAX_VARIABLES];
00133         for (int i=0; i<csplitDim+regDim; i++)
00134             scaledInput[i]=scale[i].Transform( (*input)[i+dsplitDim] );
00135 
00136         output[0]=scale[csplitDim+regDim].Transform( root->Infer(Dvars,scaledInput,inferMaxNodeId, threshold) );
00137     }
00138 
00139     virtual void Identify(void)
00140     {
00141         const Matrix<double>& ctrainData = training->GetTrainingData();
00142         const Matrix<int>& dtrainData = dynamic_cast< DCTrainingData* >( training )
00143                                         -> GetDiscreteTrainingData();
00144 
00145         int M=ctrainData.num_rows();
00146 
00147         // find a linear regressor for the first node
00148         if (!rootDistribution)
00149             rootDistribution = new T_Distribution(regDim);
00150 
00151         rootDistribution->RandomDistribution(1);
00152         double Coef;
00153         for (int i=0; i<M; i++)
00154         {
00155             Coef = rootDistribution->LearnProbability(ctrainData[i]+csplitDim);
00156             rootDistribution->NormalizeLearnProbability(Coef);
00157         }
00158         rootDistribution->UpdateParameters();
00159 
00160         /*
00161         cout << "Printing root distribution" << endl;
00162         rootDistribution->SaveToStream(cout);
00163         cout << endl;
00164         */
00165 
00166         T_Regressor* regressor = dynamic_cast<T_Regressor*>( rootDistribution->CreateRegressor() );
00167 
00168         if (regressor==NULL)
00169             regressor=new T_Regressor();
00170 
00171         // create the root and give it the Id 1
00172         root = new BinaryRegressionTreeNode<T_Distribution, T_Regressor, T_Splitter>
00173                (1, dDomainSize, csplitDim, regDim, *regressor, rootDistribution);
00174 
00175         rootDistribution = 0; // root will dealocate rootDistribution
00176         // learn in stages until nobody wants to learn anymore
00177         do
00178         {
00179             root->StartLearningEpoch();
00180             for(int i=0; i<M; i++)
00181                 root->LearnSample(dtrainData[i],ctrainData[i], 1.0, threshold);
00182         }
00183         while (root->StopLearningEpoch(splitType, emRestarts, emMaxIterations,
00184                                        convergenceLim, min_no_datapoints));
00185         cout << "End Learning" << endl;
00186         PrintSizesTree();
00187     }
00188 
00189     virtual void Prune(void)
00190     {
00191         const Matrix<double>& ctrainData = pruning->GetTrainingData();
00192         const Matrix<int>& dtrainData = dynamic_cast< DCTrainingData* >( pruning )
00193                                         -> GetDiscreteTrainingData();
00194 
00195         int M=ctrainData.num_rows();
00196 
00197         root->InitializePruningStatistics();
00198         for(int i=0; i<M; i++)
00199         {
00200             // scale the continuous inputs
00201             double scaledInput[MAX_VARIABLES];
00202             for (int j=0; j<csplitDim+regDim; j++)
00203                 scaledInput[j]=scale[j].Transform( ctrainData[i][j] );
00204             // scale the output
00205             double y=scale[csplitDim+regDim].InverseTransform( ctrainData[i][csplitDim+regDim] );
00206 
00207             root->UpdatePruningStatistics(dtrainData[i], scaledInput, y, 1.0, threshold);
00208         }
00209         root->FinalizePruningStatistics();
00210 
00211         // now cut the tree to the right size
00212         double cost=root->PruneSubtree();
00213         cout << "RMSN after pruning is:" << cost/M << endl;
00214         PrintSizesTree();
00215     }
00216 
00217     virtual int SetOption(char* name, char* val)
00218     {
00219         if (strcmp(name,"EMMaxIterations")==0)
00220             emMaxIterations = atoi(val);
00221         else
00222             if (strcmp(name,"EMRestarts")==0)
00223                 emRestarts = atoi(val);
00224             else
00225                 if (strcmp(name,"InferMaxNodeId")==0)
00226                     inferMaxNodeId = atoi(val);
00227                 else
00228                     if (strcmp(name,"MaxNoDatapoints")==0)
00229                         min_no_datapoints = 2*atoi(val)-2;
00230                     else
00231                         if (strcmp(name,"SplitType")==0)
00232                             splitType = atoi(val);
00233                         else
00234                             if (strcmp(name,"Threshold")==0)
00235                                 threshold = atof(val);
00236                             else
00237                                 return Machine::SetOption(name,val);
00238         return 1;
00239     }
00240 
00241     virtual void SaveToStream(ostream& out)
00242     {
00243         out << TypeName() << " (  " << dsplitDim << " " << csplitDim << " ";
00244         out << regDim << " ) { "  << endl;
00245 
00246         out << "[ ";
00247         for(int i=0; i<dsplitDim; i++)
00248             out << dDomainSize[i] << " ";
00249         out << "]" << endl;
00250 
00251         for(int i=0; i<csplitDim+regDim+1; i++)
00252             scale[i].SaveToStream(out);
00253         out << endl;
00254 
00255         root->SaveToStream(out);
00256 
00257         out << '}' << endl;
00258     }
00259 };
00260 
00261 }
00262 
00263 #endif // _CLUS_REGRESSIONTREE_H_

Generated on Mon Jul 21 16:57:25 2003 for SECRET by doxygen 1.3.2