00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
#include "machine.h"
#include "multidecisiontree.h"
#include "binarymulticlassificationsplitter.h"
#include "binarydecisiontree.h"
#include "simplebinarysplitter.h"
#include "streamdctrain.h"
#include "gridinput.h"
#include "filecons.h"
#include "syncobjlist.h"
#include "fileprod.h"
#include "rpms.h"

#include <iostream>
#include <getopt.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
00048
00049 using namespace CLUS;
00050
00051
00052
00053
00054 int main(int argc, char** argv)
00055 {
00056
00057 char trainfilename[256];
00058 char prunefilename[256];
00059 char modelfile[256];
00060 char testfilename[256];
00061 int griddim = 0;
00062 char rezfile[256];
00063 int noDatasets = 1;
00064 bool multiclassifier = false;
00065 char minMass[256];
00066 char schemaMatchString[256];
00067 char labeledString[256];
00068 char alignSplitsString[256];
00069
00070
00071 trainfilename[0]=0;
00072 prunefilename[0]=0;
00073 testfilename[0]=0;
00074 labeledString[0]=0;
00075 schemaMatchString[0]=0;
00076 sprintf(modelfile,"modelfile");
00077 sprintf(rezfile,"rezfile");
00078 sprintf(minMass,"2");
00079
00080 char c;
00081
00082 while( (c = getopt ( argc, argv,"g:t:p:T:m:G:s:MD:d:hl:x:"))!=-1)
00083 switch(c)
00084 {
00085 case 't':
00086 strcpy(trainfilename,optarg);
00087 break;
00088
00089 case 'p':
00090 strcpy(prunefilename,optarg);
00091 break;
00092
00093 case 'T':
00094 strcpy(testfilename,optarg);
00095 break;
00096
00097 case 'm':
00098 strcpy(modelfile,optarg);
00099 break;
00100
00101 case 'G':
00102 griddim = atoi(optarg);
00103 break;
00104
00105 case 's':
00106 strcpy(rezfile,optarg);
00107 break;
00108
00109 case 'M':
00110 multiclassifier=true;
00111 break;
00112
00113 case 'D':
00114 noDatasets=atoi(optarg);
00115 break;
00116
00117 case 'd':
00118 strcpy(minMass,optarg);
00119 break;
00120
00121 case 'g':
00122 strcpy(schemaMatchString, optarg);
00123 break;
00124
00125 case 'l':
00126 strcpy(labeledString, optarg);
00127 break;
00128
00129 case 'x':
00130 strcpy(alignSplitsString, optarg);
00131 break;
00132
00133 case 'h':
00134 default:
00135 cout << "classification: learns classification trees" << endl;
00136 cout << "Options:" << endl;
00137 cout << "\t-h : this help message" << endl;
00138 cout << "\t-t filename : specify the training set" << endl;
00139 cout << "\t-p filename : specify the prunning set (absent means no prunning)" << endl;
00140 cout << "\t-T filename : specify the testing set (absent means no testing)" << endl;
00141 cout << "\t-m filename : where to put the resulting tree (default=modelfile)" << endl;
00142 cout << "\t-s filename : where to put the result of testing" << endl;
00143 cout << "\t-M : use multiclassifier instead of normal classifier" << endl;
00144 cout << "\t-D number : number of datasets for multiclassifier" << endl;
00145 cout << "\t-g boolean: do schema matching first" << endl;
00146 cout << "\t-l boolean: labeled choice of split point" << endl;
00147 cout << "\t-d number : minimum no datapoints to consider splitting" << endl;
00148 return 1;
00149 }
00150
00151
00152 if (trainfilename[0]==0)
00153 {
00154 cerr << "No training data specified" << endl;
00155 exit(1);
00156 }
00157
00158 bool prune = (prunefilename[0]!=0);
00159 bool test = (testfilename[0]!=0 || griddim!=0);
00160
00161
00162 DataProducer* prod;
00163 RPMSConsumer* cons;
00164 SyncObjList* gear;
00165 Machine* dTree;
00166
00167 try
00168 {
00169 srand( time(NULL) );
00170
00171 StreamDCTrainingData* traindata = CreateStreamDCTrainingDataFromFile(trainfilename);
00172 StreamDCTrainingData* pruningdata = 0;
00173
00174 if (prune)
00175 pruningdata = CreateStreamDCTrainingDataFromFile(prunefilename);
00176
00177 if (!multiclassifier)
00178 {
00179
00180 dTree = new BinaryDecisionTree< SimpleBinarySplitter >
00181 ( traindata->GetDDomainSizes(), traindata->NumCols());
00182
00183 dTree->SetOption("MinMass",minMass);
00184 }
00185 else
00186 {
00187
00188 dTree = new MultiDecisionTree< BinaryMultiClassificationSplitter >
00189 ( traindata->GetDDomainSizes(), traindata->NumCols()-1, noDatasets);
00190
00191 dTree->SetOption("SchemaMatch", schemaMatchString);
00192 dTree->SetOption("labeled", labeledString);
00193 dTree->SetOption("alignSplits", alignSplitsString);
00194 }
00195
00196
00197 dTree->SetTrainingData(traindata);
00198
00199 dTree->Identify();
00200
00201 if (prune)
00202 {
00203 dTree->SaveToFile("intermfile");
00204 dTree->SetPruningData(pruningdata);
00205 dTree->Prune();
00206 }
00207
00208 dTree->SaveToFile(modelfile);
00209
00210 if (test)
00211 {
00212
00213 gear=new SyncObjList();
00214 if (griddim==0)
00215 {
00216 prod = new FileDataProducer(dTree->InDim()+dTree->OutDim(),testfilename);
00217 }
00218 else
00219 {
00220 prod = new GridInputProducer(griddim,dTree->InDim(), dTree->GetScaleFactors());
00221 }
00222
00223 cons=new RPMSConsumer(dTree->InDim(),rezfile,dTree->GetOutput(),
00224 prod->GetOutput());
00225
00226 dTree->SetInput(prod->GetOutput());
00227 gear->AddObject(prod);
00228 gear->AddObject(dTree);
00229 gear->AddObject(cons);
00230
00231
00232 cons->SetFileName(rezfile);
00233 gear->Run();
00234
00235 cout << "The error on test data is: " << cons->GetMSE() << endl;
00236 }
00237 }
00238 catch(ErrMsg msg)
00239 {
00240 msg.PrintMessage(cerr);
00241 }
00242 catch(...)
00243 {
00244 cerr << "A weird exception cought" << endl;
00245 }
00246 return 0;
00247
00248 }
00249
00250
00251
00252
00253
00254
00255
00256
00257
00258
00259
00260
00261
00262
00263
00264
00265
00266
00267
00268
00269
00270
00271
00272
00273
00274
00275
00276
00277
00278
00279
00280
00281
00282
00283
00284
00285
00286
00287
00288
00289
00290
00291
00292
00293
00294
00295
00296
00297
00298
00299
00300
00301
00302
00303
00304
00305
00306
00307
00308
00309