Main Page | Namespace List | Class Hierarchy | Data Structures | File List | Namespace Members | Data Fields | Globals | Related Pages

probabilisticclassification.cc

Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (c) 2003, Cornell University
00004 All rights reserved.
00005 
00006 Redistribution and use in source and binary forms, with or without
00007 modification, are permitted provided that the following conditions are met:
00008 
00009    - Redistributions of source code must retain the above copyright notice,
00010        this list of conditions and the following disclaimer.
00011    - Redistributions in binary form must reproduce the above copyright
00012        notice, this list of conditions and the following disclaimer in the
00013        documentation and/or other materials provided with the distribution.
00014    - Neither the name of Cornell University nor the names of its
00015        contributors may be used to endorse or promote products derived from
00016        this software without specific prior written permission.
00017 
00018 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
00019 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00020 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
00021 ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
00022 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
00023 CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00024 SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
00025 INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
00026 CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00027 ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
00028 THE POSSIBILITY OF SUCH DAMAGE.
00029 
00030 */
00031 
00032 #include "machine.h"
00033 #include "binaryprobabilisticdecisiontree.h"
00034 #include "binaryprobabilisticsplitter.h"
00035 #include "streamdctrain.h"
00036 #include "gridinput.h"
00037 #include "filecons.h"
00038 #include "syncobjlist.h"
00039 #include "fileprod.h"
00040 #include "rpms.h"
00041 
00042 #include <iostream>
00043 #include <stdlib.h>
00044 #include <time.h>
00045 #include <getopt.h>
00046 
00047 using namespace CLUS;
00048 
00049 int main(int argc, char** argv)
00050 {
00051 
00052     char trainfilename[256];
00053     char prunefilename[256];
00054     char modelfile[256];
00055     char testfilename[256];
00056     int griddim = 0;
00057     char rezfile[256];
00058 
00059     char threshold[256];
00060     char minMass[256];
00061 
00062     // Put default values in parameters
00063     trainfilename[0]=0;
00064     prunefilename[0]=0;
00065     testfilename[0]=0;
00066     sprintf(modelfile,"modelfile");
00067     sprintf(rezfile,"rezfile");
00068     sprintf(threshold, "0.01");
00069     sprintf(minMass, "2");
00070 
00071     char c;
00072 
00073     while( (c = getopt ( argc, argv,"t:p:T:m:G:s:d:r:h"))!=-1)
00074         switch(c)
00075         {
00076         case 't':
00077             strcpy(trainfilename,optarg);
00078             break;
00079 
00080         case 'p':
00081             strcpy(prunefilename,optarg);
00082             break;
00083 
00084         case 'T':
00085             strcpy(testfilename,optarg);
00086             break;
00087 
00088         case 'm':
00089             strcpy(modelfile,optarg);
00090             break;
00091 
00092         case 'G':
00093             griddim = atoi(optarg);
00094             break;
00095 
00096         case 's':
00097             strcpy(rezfile,optarg);
00098             break;
00099 
00100         case 'd':
00101             strcpy(minMass,optarg);
00102             break;
00103 
00104         case 'r':
00105             strcpy(threshold,optarg);
00106             break;
00107 
00108         case 'h':
00109         default:
00110             cout << "classification: learns classification trees" << endl;
00111             cout << "Options:" << endl;
00112             cout << "\t-h : this help message" << endl;
00113             cout << "\t-t filename : specify the training set" << endl;
00114             cout << "\t-p filename : specify the prunning set (absent means no prunning)" << endl;
00115             cout << "\t-T filename : specify the testing set (absent means no testing)" << endl;
00116             cout << "\t-m filename : where to put the resulting tree (default=modelfile)" << endl;
00117             cout << "\t-s filename : where to put the result of testing" << endl;
00118             cout << "\t-d : minMass" << endl;
00119             cout << "\t-r : threshold" << endl;
00120             return 1;
00121         }
00122 
00123     // Testing the parameters
00124     if (trainfilename[0]==0)
00125     {
00126         cerr << "No training data specified" << endl;
00127         exit(1);
00128     }
00129 
00130     bool prune = (prunefilename[0]!=0);
00131     bool test = (testfilename[0]!=0 || griddim!=0);
00132 
00133 
00134     DataProducer*   prod;
00135     RPMSConsumer*   cons;
00136     SyncObjList*    gear;
00137     Machine* dTree;
00138 
00139     try
00140     {
00141         srand( time(NULL) );
00142 
00143         StreamDCTrainingData* traindata = CreateStreamDCTrainingDataFromFile(trainfilename);
00144         StreamDCTrainingData* pruningdata = 0;
00145 
00146         if (prune)
00147             pruningdata = CreateStreamDCTrainingDataFromFile(prunefilename);
00148 
00149         dTree = new BinaryProbabilisticDecisionTree< BinaryProbabilisticSplitter >
00150                 ( traindata->GetDDomainSizes(), traindata->NumCols() );
00151 
00152         dTree->SetOption("MinMass", minMass);
00153         dTree->SetOption("Threshold", threshold);
00154 
00155         dTree->SetTrainingData(traindata);
00156 
00157         dTree->Identify();
00158 
00159         if (prune)
00160         {
00161             dTree->SaveToFile("intermfile");
00162             dTree->SetPruningData(pruningdata);
00163             dTree->Prune();
00164         }
00165 
00166         dTree->SaveToFile(modelfile);
00167 
00168         if (test)
00169         {
00170             // test now on the same data and put result in testfile
00171             gear=new SyncObjList();
00172             if (griddim==0)
00173             {
00174                 prod = new FileDataProducer(dTree->InDim()+dTree->OutDim(),testfilename);
00175             }
00176             else
00177             {
00178                 prod = new GridInputProducer(griddim,dTree->InDim(), dTree->GetScaleFactors());
00179             }
00180 
00181             cons=new RPMSConsumer(dTree->InDim(),rezfile,dTree->GetOutput(),
00182                                   prod->GetOutput());
00183 
00184             dTree->SetInput(prod->GetOutput());
00185             gear->AddObject(prod);
00186             gear->AddObject(dTree);
00187             gear->AddObject(cons);
00188 
00189             //cout << "Testing:" << endl;
00190             cons->SetFileName(rezfile);
00191             gear->Run();
00192 
00193             cout << "The error on test data is: " << cons->GetMSE() << endl;
00194         }
00195     }
00196     catch(ErrMsg msg)
00197     {
00198         msg.PrintMessage(cerr);
00199     }
00200     catch(...)
00201     {
00202         cerr << "A weird exception cought" << endl;
00203     }
00204     return 0;
00205 
00206 }

Generated on Mon Jul 21 16:57:24 2003 for SECRET by doxygen 1.3.2