
regression.cc

/*

Copyright (c) 2003, Cornell University
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

   - Redistributions of source code must retain the above copyright notice,
       this list of conditions and the following disclaimer.
   - Redistributions in binary form must reproduce the above copyright
       notice, this list of conditions and the following disclaimer in the
       documentation and/or other materials provided with the distribution.
   - Neither the name of Cornell University nor the names of its
       contributors may be used to endorse or promote products derived from
       this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
THE POSSIBILITY OF SUCH DAMAGE.

*/

#include "machine.h"
#include "regressiontree.h"
#include "binaryobliquesplitter.h"
#include "probabilisticregressiontree.h"
#include "binaryobliqueprobabilisticsplitter.h"
#include "multinormal.h"
#include "skinymultinormal.h"
#include "linearregressor.h"
#include "streamdctrain.h"
#include <iostream>
#include "syncobjlist.h"
#include "fileprod.h"
#include "rpms.h"
#include <stdlib.h>
#include <time.h>
#include "gridinput.h"
#include "filecons.h"
#include <getopt.h>
#include <gsl/gsl_errno.h>

using namespace CLUS;
using namespace std; // for the unqualified cout, cerr, endl and ofstream used below

int main(int argc, char** argv)
{

    char trainfilename[256];
    char prunefilename[256];
    int c_split = 0;
    int r_dim = 0;
    char modelfile[256];
    char testfilename[256];
    int griddim = 0;
    char rezfile[256];

    /// 0 means whole tree
    int MaxLevel = 0;

    char convLimit[256];
    char maxNoDatapoints[256];

    /// Split type: "0" = unidimensional ANOVA (default), "1" = LDA, "2" = QDA
    char * SplitType = "0";

    char maxiter[256];
    char noRestarts[256];
    bool isProbabilistic=false;

    // Put default values in parameters
    trainfilename[0]=0;
    prunefilename[0]=0;
    testfilename[0]=0;
    sprintf(modelfile,"modelfile");
    sprintf(rezfile,"rezfile");
    sprintf(maxiter,"30");
    sprintf(convLimit,"1e-6");
    sprintf(noRestarts,"3");
    sprintf(maxNoDatapoints,"12");

    int c; // getopt() returns an int, so c must not be a plain char

    while( (c = getopt ( argc, argv,"t:p:T:m:G:s:PS:r:l:ALQi:c:R:d:h"))!=-1)
        switch(c)
        {
        case 't':
            strcpy(trainfilename,optarg);
            break;

        case 'p':
            strcpy(prunefilename,optarg);
            break;

        case 'T':
            strcpy(testfilename,optarg);
            break;

        case 'm':
            strcpy(modelfile,optarg);
            break;

        case 'G':
            griddim = atoi(optarg);
            break;

        case 's':
            strcpy(rezfile,optarg);
            break;

        case 'P':
            isProbabilistic=true;
            break;

        case 'S':
            c_split = atoi(optarg);
            break;

        case 'r':
            r_dim = atoi(optarg);
            break;

        case 'l':
            MaxLevel = atoi(optarg);
            break;

        case 'A':
            SplitType = "0";
            break;

        case 'L':
            SplitType = "1";
            break;

        case 'Q':
            SplitType = "2";
            break;

        case 'i':
            strcpy(maxiter,optarg);
            break;

        case 'c':
            strcpy(convLimit,optarg);
            break;

        case 'R':
            strcpy(noRestarts,optarg);
            break;

        case 'd':
            strcpy(maxNoDatapoints,optarg);
            break;


        case 'h':
        default:
            cout << "regression: learns a regression tree" << endl;
            cout << "Options:" << endl;
            cout << "\t-h : this help message" << endl;
            cout << "\t-t filename : specify the training set" << endl;
            cout << "\t-p filename : specify the pruning set (absent means no pruning)" << endl;
            cout << "\t-T filename : specify the testing set (absent means no testing)" << endl;
            cout << "\t-m filename : where to put the resulting tree (default=modelfile)" << endl;
            cout << "\t-G number : number of isolines in the grid (if nonzero, a grid input is used instead of the -T file)" << endl;
            cout << "\t-s filename : where to put the result of testing" << endl;
            cout << "\t-P : probabilistic regression tree" << endl;
            cout << "\t-S number : number of continuous split variables" << endl;
            cout << "\t-r number : number of regressor variables" << endl;
            cout << "\t-l number : test the tree level by level up to this level (0 means the whole tree)" << endl;
            cout << "\t-A : use unidimensional ANOVA splits" << endl;
            cout << "\t-L : use multidimensional LDA splits" << endl;
            cout << "\t-Q : use multidimensional QDA splits" << endl;
            cout << "\t-i number : max number of iterations in EM algorithm" << endl;
            cout << "\t-c number : convergence tolerance" << endl;
            cout << "\t-R number : number of random restarts for EM" << endl;
            cout << "\t-d number : minimum datapoints in leaf" << endl;
            return 0; // print the help message and stop
        }

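    // Example invocation (the file names and counts below are hypothetical,
    // shown only to illustrate how the flags above combine):
    //   regression -t train.dat -p prune.dat -T test.dat -S 4 -r 2 -P -m tree.model -s result.dat
    // i.e. train a probabilistic tree with 4 continuous split variables and
    // 2 regressor variables, prune it on prune.dat, and test it on test.dat.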
    // Testing the parameters
    if (trainfilename[0]==0)
    {
        cerr << "No training data specified" << endl;
        exit(1);
    }

    bool prune = (prunefilename[0]!=0);
    bool test = (testfilename[0]!=0 || griddim!=0);


    DataProducer*   prod;
    RPMSConsumer*   cons;
    SyncObjList*    gear;
    Machine* rTree;

    try
    {
        srand( time(NULL) );

        gsl_set_error_handler_off(); // make sure that gsl errors do not stop the program

        StreamDCTrainingData* traindata = CreateStreamDCTrainingDataFromFile(trainfilename);
        StreamDCTrainingData* pruningdata = 0;

        if (prune)
            pruningdata = CreateStreamDCTrainingDataFromFile(prunefilename);

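        // A reading of the template parameters below (not documented in this
        // file, so treat this as an interpretation): MultiDimNormal appears to
        // model the continuous split attributes in each node, LinearRegressor
        // fits the per-leaf linear models, and the splitter is either the
        // deterministic BinaryObliqueSplitter or the probabilistic variant,
        // selected by the -P flag.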
        if (!isProbabilistic)
        {
            rTree = new BinaryRegressionTree< MultiDimNormal, LinearRegressor,
                        BinaryObliqueSplitter >( traindata->GetDDomainSizes(), c_split, r_dim );
        }
        else
        {
            rTree = new BinaryRegressionTree< MultiDimNormal, LinearRegressor,
                        BinaryObliqueProbabilisticSplitter >( traindata->GetDDomainSizes(), c_split, r_dim );
        }
        /*    BinaryRegressionTree< SkinyMultiDimNormal, LinearRegressor,
              BinaryObliqueSplitter > rTree( traindata->GetDDomainSizes(), c_split, r_dim );
        */

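        // The option strings below mirror the command-line flags parsed above:
        // EMMaxIterations <- -i, EMRestarts <- -R, ConvergenceLimit <- -c,
        // MaxNoDatapoints <- -d, SplitType <- -A/-L/-Q ("0"=ANOVA, "1"=LDA, "2"=QDA).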
        rTree->SetOption("EMMaxIterations",maxiter);
        rTree->SetOption("EMRestarts",noRestarts);
        rTree->SetOption("ConvergenceLimit",convLimit);
        rTree->SetOption("MaxNoDatapoints",maxNoDatapoints);
        rTree->SetOption("SplitType",SplitType);

        rTree->SetTrainingData(traindata);
        rTree->ScaleData(Scale::Interval);

        rTree->Identify();

        if (prune)
        {
            rTree->SaveToFile("intermfile");
            rTree->SetPruningData(pruningdata);
            rTree->Prune();
        }

        rTree->SaveToFile(modelfile);

        if (test)
        {
            // run the tree on the test data (or on a generated grid) and
            // write the predictions to rezfile
            gear=new SyncObjList();
            if (griddim==0)
            {
                prod = new FileDataProducer(rTree->InDim()+rTree->OutDim(),testfilename);
            }
            else
            {
                prod = new GridInputProducer(griddim,rTree->InDim(), rTree->GetScaleFactors());
            }

            cons=new RPMSConsumer(rTree->InDim(),rezfile,rTree->GetOutput(),
                                  prod->GetOutput());

            rTree->SetInput(prod->GetOutput());
            gear->AddObject(prod);
            gear->AddObject(rTree);
            gear->AddObject(cons);

            char filename[256];

            ofstream out("animation.gpl");

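            // With -l set, the tree is evaluated repeatedly while inference is
            // capped at InferMaxNodeId = 2^maxlevel, producing one result file
            // per level (rezfile is used here as a printf format, so it is
            // presumably expected to contain a %d in this mode); animation.gpl
            // collects a gnuplot "splot" command per file so the sequence of
            // levels can be stepped through.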
            if (MaxLevel>0)
            {
                for (int maxlevel=0; maxlevel <= MaxLevel; maxlevel++)
                {
                    cout << "Generating result for level: " << maxlevel;

                    sprintf(filename,rezfile,maxlevel);
                    cons->SetFileName(filename);

                    out << "splot \"" << filename << "\"; pause -1" << endl;

                    sprintf(filename,"%d",1<<maxlevel);
                    rTree->SetOption("InferMaxNodeId",filename);
                    gear->Run();
                    cout << "\tMSE=" << cons->GetMSE() << endl;

                }
            }
            else
            {
                cout << "Testing:" << endl;
                rTree->SetOption("InferMaxNodeId","200000000");
                cons->SetFileName(rezfile);
                gear->Run();

                cout << "The error on test data is: " << cons->GetMSE() << endl;
            }
        }
    }
    catch(ErrMsg msg)
    {
        msg.PrintMessage(cout);
    }
    catch(...)
    {
        cerr << "An unknown exception was caught" << endl;
    }
    return 0;

}
