00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032 #include "machine.h"
00033 #include "regressiontree.h"
00034 #include "binaryobliquesplitter.h"
00035 #include "probabilisticregressiontree.h"
00036 #include "binaryobliqueprobabilisticsplitter.h"
00037 #include "multinormal.h"
00038 #include "skinymultinormal.h"
00039 #include "linearregressor.h"
00040 #include "streamdctrain.h"
00041 #include <iostream>
00042 #include "syncobjlist.h"
00043 #include "fileprod.h"
00044 #include "rpms.h"
00045 #include <stdlib.h>
00046 #include <time.h>
00047 #include "gridinput.h"
00048 #include "filecons.h"
00049 #include <getopt.h>
00050 #include <gsl/gsl_errno.h>
00051
00052 using namespace CLUS;
00053
00054 int main(int argc, char** argv)
00055 {
00056
00057 char trainfilename[256];
00058 char prunefilename[256];
00059 int c_split = 0;
00060 int r_dim = 0;
00061 char modelfile[256];
00062 char testfilename[256];
00063 int griddim = 0;
00064 char rezfile[256];
00065
00066
00067 int MaxLevel = 0;
00068
00069 char convLimit[256];
00070 char maxNoDatapoints[256];
00071
00072
00073 char * SplitType = "0";
00074
00075 char maxiter[256];
00076 char noRestarts[256];
00077 bool isProbabilistic=false;
00078
00079
00080 trainfilename[0]=0;
00081 prunefilename[0]=0;
00082 testfilename[0]=0;
00083 sprintf(modelfile,"modelfile");
00084 sprintf(rezfile,"rezfile");
00085 sprintf(maxiter,"30");
00086 sprintf(convLimit,"1e-6");
00087 sprintf(noRestarts,"3");
00088 sprintf(maxNoDatapoints,"12");
00089
00090 char c;
00091
00092 while( (c = getopt ( argc, argv,"t:p:T:m:G:s:PS:r:l:ALQi:c:R:d:h"))!=-1)
00093 switch(c)
00094 {
00095 case 't':
00096 strcpy(trainfilename,optarg);
00097 break;
00098
00099 case 'p':
00100 strcpy(prunefilename,optarg);
00101 break;
00102
00103 case 'T':
00104 strcpy(testfilename,optarg);
00105 break;
00106
00107 case 'm':
00108 strcpy(modelfile,optarg);
00109 break;
00110
00111 case 'G':
00112 griddim = atoi(optarg);
00113 break;
00114
00115 case 's':
00116 strcpy(rezfile,optarg);
00117 break;
00118
00119 case 'P':
00120 isProbabilistic=true;
00121 break;
00122
00123 case 'S':
00124 c_split = atoi(optarg);
00125 break;
00126
00127 case 'r':
00128 r_dim = atoi(optarg);
00129 break;
00130
00131 case 'l':
00132 MaxLevel = atoi(optarg);
00133 break;
00134
00135 case 'A':
00136 SplitType = "0";
00137 break;
00138
00139 case 'L':
00140 SplitType = "1";
00141 break;
00142
00143 case 'Q':
00144 SplitType = "2";
00145 break;
00146
00147 case 'i':
00148 strcpy(maxiter,optarg);
00149 break;
00150
00151 case 'c':
00152 strcpy(convLimit,optarg);
00153 break;
00154
00155 case 'R':
00156 strcpy(noRestarts,optarg);
00157 break;
00158
00159 case 'd':
00160 strcpy(maxNoDatapoints,optarg);
00161 break;
00162
00163
00164 case 'h':
00165 default:
00166 cout << "regression: learns a regression tree" << endl;
00167 cout << "Options:" << endl;
00168 cout << "\t-h : this help message" << endl;
00169 cout << "\t-t filename : specify the training set" << endl;
00170 cout << "\t-p filename : specify the prunning set (absent means no prunning)" << endl;
00171 cout << "\t-T filename : specify the testing set (absent means no testing)" << endl;
00172 cout << "\t-m filename : where to put the resulting tree (default=modelfile)" << endl;
00173 cout << "\t-G number : number of isolines in the grid (no effect if testing)" << endl;
00174 cout << "\t-s filename : where to put the result of testing" << endl;
00175 cout << "\t-P : probabilistic regression tree" << endl;
00176 cout << "\t-S number : numer of continuous split variables" << endl;
00177 cout << "\t-r number : number of regressor variables" << endl;
00178 cout << "\t-l number : the biggest nodeID to be used in testing" << endl;
00179 cout << "\t-A : use unidimentional ANOVA splits" << endl;
00180 cout << "\t-L : use multidimentional LDA splits" << endl;
00181 cout << "\t-Q : use multidimentional QDA splits" << endl;
00182 cout << "\t-i number : max number of iterations in EM algorithm" << endl;
00183 cout << "\t-c number : convergence tolerance" << endl;
00184 cout << "\t-R number : number of random restarts for EM" << endl;
00185 cout << "\t-d number : minimum datapoints in leaf" << endl;
00186
00187 }
00188
00189
00190 if (trainfilename[0]==0)
00191 {
00192 cerr << "No training data specified" << endl;
00193 exit(1);
00194 }
00195
00196 bool prune = (prunefilename[0]!=0);
00197 bool test = (testfilename[0]!=0 || griddim!=0);
00198
00199
00200 DataProducer* prod;
00201 RPMSConsumer* cons;
00202 SyncObjList* gear;
00203 Machine* rTree;
00204
00205 try
00206 {
00207 srand( time(NULL) );
00208
00209 gsl_set_error_handler_off();
00210
00211 StreamDCTrainingData* traindata = CreateStreamDCTrainingDataFromFile(trainfilename);
00212 StreamDCTrainingData* pruningdata = 0;
00213
00214 if (prune)
00215 pruningdata = CreateStreamDCTrainingDataFromFile(prunefilename);
00216
00217 if (!isProbabilistic)
00218 {
00219 rTree = new
00220 BinaryRegressionTree< MultiDimNormal, LinearRegressor,
00221 BinaryObliqueSplitter > ( traindata->GetDDomainSizes(), c_split, r_dim );
00222 }
00223 else
00224 {
00225 rTree = new
00226 BinaryRegressionTree< MultiDimNormal, LinearRegressor,
00227 BinaryObliqueProbabilisticSplitter > ( traindata->GetDDomainSizes(), c_split, r_dim );
00228 }
00229
00230
00231
00232
00233 rTree->SetOption("EMMaxIterations",maxiter);
00234 rTree->SetOption("EMRestarts",noRestarts);
00235 rTree->SetOption("ConvergenceLimit",convLimit);
00236 rTree->SetOption("MaxNoDatapoints",maxNoDatapoints);
00237 rTree->SetOption("SplitType",SplitType);
00238
00239 rTree->SetTrainingData(traindata);
00240 rTree->ScaleData(Scale::Interval);
00241
00242 rTree->Identify();
00243
00244 if (prune)
00245 {
00246 rTree->SaveToFile("intermfile");
00247 rTree->SetPruningData(pruningdata);
00248 rTree->Prune();
00249 }
00250
00251 rTree->SaveToFile(modelfile);
00252
00253 if (test)
00254 {
00255
00256 gear=new SyncObjList();
00257 if (griddim==0)
00258 {
00259 prod = new FileDataProducer(rTree->InDim()+rTree->OutDim(),testfilename);
00260 }
00261 else
00262 {
00263 prod = new GridInputProducer(griddim,rTree->InDim(), rTree->GetScaleFactors());
00264 }
00265
00266 cons=new RPMSConsumer(rTree->InDim(),rezfile,rTree->GetOutput(),
00267 prod->GetOutput());
00268
00269 rTree->SetInput(prod->GetOutput());
00270 gear->AddObject(prod);
00271 gear->AddObject(rTree);
00272 gear->AddObject(cons);
00273
00274 char filename[256];
00275
00276 ofstream out("animation.gpl");
00277
00278 if (MaxLevel>0)
00279 {
00280 for (int maxlevel=0; maxlevel <= MaxLevel; maxlevel++)
00281 {
00282 cout << "Generating resutlt for level: " << maxlevel;
00283
00284 sprintf(filename,rezfile,maxlevel);
00285 cons->SetFileName(filename);
00286
00287 out << "splot \"" << filename << "\"; pause -1" << endl;
00288
00289 sprintf(filename,"%d",1<<maxlevel);
00290 rTree->SetOption("InferMaxNodeId",filename);
00291 gear->Run();
00292 cout << "\tMSE=" << cons->GetMSE() << endl;
00293
00294 }
00295 }
00296 else
00297 {
00298 cout << "Testing:" << endl;
00299 rTree->SetOption("InferMaxNodeId","200000000");
00300 cons->SetFileName(rezfile);
00301 gear->Run();
00302
00303 cout << "The error on test data is: " << cons->GetMSE() << endl;
00304 }
00305 }
00306 }
00307 catch(ErrMsg msg)
00308 {
00309 msg.PrintMessage(cout);
00310 }
00311 catch(...)
00312 {
00313 cerr << "A weird exception cought" << endl;
00314 }
00315 return 0;
00316
00317 }