00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032 #include "machine.h"
00033 #include "binaryprobabilisticdecisiontree.h"
00034 #include "binaryprobabilisticsplitter.h"
00035 #include "streamdctrain.h"
00036 #include "gridinput.h"
00037 #include "filecons.h"
00038 #include "syncobjlist.h"
00039 #include "fileprod.h"
00040 #include "rpms.h"
00041
00042 #include <iostream>
00043 #include <stdlib.h>
00044 #include <time.h>
00045 #include <getopt.h>
00046
00047 using namespace CLUS;
00048
00049 int main(int argc, char** argv)
00050 {
00051
00052 char trainfilename[256];
00053 char prunefilename[256];
00054 char modelfile[256];
00055 char testfilename[256];
00056 int griddim = 0;
00057 char rezfile[256];
00058
00059 char threshold[256];
00060 char minMass[256];
00061
00062
00063 trainfilename[0]=0;
00064 prunefilename[0]=0;
00065 testfilename[0]=0;
00066 sprintf(modelfile,"modelfile");
00067 sprintf(rezfile,"rezfile");
00068 sprintf(threshold, "0.01");
00069 sprintf(minMass, "2");
00070
00071 char c;
00072
00073 while( (c = getopt ( argc, argv,"t:p:T:m:G:s:d:r:h"))!=-1)
00074 switch(c)
00075 {
00076 case 't':
00077 strcpy(trainfilename,optarg);
00078 break;
00079
00080 case 'p':
00081 strcpy(prunefilename,optarg);
00082 break;
00083
00084 case 'T':
00085 strcpy(testfilename,optarg);
00086 break;
00087
00088 case 'm':
00089 strcpy(modelfile,optarg);
00090 break;
00091
00092 case 'G':
00093 griddim = atoi(optarg);
00094 break;
00095
00096 case 's':
00097 strcpy(rezfile,optarg);
00098 break;
00099
00100 case 'd':
00101 strcpy(minMass,optarg);
00102 break;
00103
00104 case 'r':
00105 strcpy(threshold,optarg);
00106 break;
00107
00108 case 'h':
00109 default:
00110 cout << "classification: learns classification trees" << endl;
00111 cout << "Options:" << endl;
00112 cout << "\t-h : this help message" << endl;
00113 cout << "\t-t filename : specify the training set" << endl;
00114 cout << "\t-p filename : specify the prunning set (absent means no prunning)" << endl;
00115 cout << "\t-T filename : specify the testing set (absent means no testing)" << endl;
00116 cout << "\t-m filename : where to put the resulting tree (default=modelfile)" << endl;
00117 cout << "\t-s filename : where to put the result of testing" << endl;
00118 cout << "\t-d : minMass" << endl;
00119 cout << "\t-r : threshold" << endl;
00120 return 1;
00121 }
00122
00123
00124 if (trainfilename[0]==0)
00125 {
00126 cerr << "No training data specified" << endl;
00127 exit(1);
00128 }
00129
00130 bool prune = (prunefilename[0]!=0);
00131 bool test = (testfilename[0]!=0 || griddim!=0);
00132
00133
00134 DataProducer* prod;
00135 RPMSConsumer* cons;
00136 SyncObjList* gear;
00137 Machine* dTree;
00138
00139 try
00140 {
00141 srand( time(NULL) );
00142
00143 StreamDCTrainingData* traindata = CreateStreamDCTrainingDataFromFile(trainfilename);
00144 StreamDCTrainingData* pruningdata = 0;
00145
00146 if (prune)
00147 pruningdata = CreateStreamDCTrainingDataFromFile(prunefilename);
00148
00149 dTree = new BinaryProbabilisticDecisionTree< BinaryProbabilisticSplitter >
00150 ( traindata->GetDDomainSizes(), traindata->NumCols() );
00151
00152 dTree->SetOption("MinMass", minMass);
00153 dTree->SetOption("Threshold", threshold);
00154
00155 dTree->SetTrainingData(traindata);
00156
00157 dTree->Identify();
00158
00159 if (prune)
00160 {
00161 dTree->SaveToFile("intermfile");
00162 dTree->SetPruningData(pruningdata);
00163 dTree->Prune();
00164 }
00165
00166 dTree->SaveToFile(modelfile);
00167
00168 if (test)
00169 {
00170
00171 gear=new SyncObjList();
00172 if (griddim==0)
00173 {
00174 prod = new FileDataProducer(dTree->InDim()+dTree->OutDim(),testfilename);
00175 }
00176 else
00177 {
00178 prod = new GridInputProducer(griddim,dTree->InDim(), dTree->GetScaleFactors());
00179 }
00180
00181 cons=new RPMSConsumer(dTree->InDim(),rezfile,dTree->GetOutput(),
00182 prod->GetOutput());
00183
00184 dTree->SetInput(prod->GetOutput());
00185 gear->AddObject(prod);
00186 gear->AddObject(dTree);
00187 gear->AddObject(cons);
00188
00189
00190 cons->SetFileName(rezfile);
00191 gear->Run();
00192
00193 cout << "The error on test data is: " << cons->GetMSE() << endl;
00194 }
00195 }
00196 catch(ErrMsg msg)
00197 {
00198 msg.PrintMessage(cerr);
00199 }
00200 catch(...)
00201 {
00202 cerr << "A weird exception cought" << endl;
00203 }
00204 return 0;
00205
00206 }