Main Page | Namespace List | Class Hierarchy | Data Structures | File List | Namespace Members | Data Fields | Globals | Related Pages

multidecisiontree.h

Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (c) 2003, Cornell University
00004 All rights reserved.
00005 
00006 Redistribution and use in source and binary forms, with or without
00007 modification, are permitted provided that the following conditions are met:
00008 
00009    - Redistributions of source code must retain the above copyright notice,
00010        this list of conditions and the following disclaimer.
00011    - Redistributions in binary form must reproduce the above copyright
00012        notice, this list of conditions and the following disclaimer in the
00013        documentation and/or other materials provided with the distribution.
00014    - Neither the name of Cornell University nor the names of its
00015        contributors may be used to endorse or promote products derived from
00016        this software without specific prior written permission.
00017 
00018 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
00019 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00020 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
00021 ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
00022 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
00023 CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00024 SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
00025 INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
00026 CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00027 ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
00028 THE POSSIBILITY OF SUCH DAMAGE.
00029 
00030 */
00031 
00032 // -*- C++ -*-
00033 
00034 #if !defined _CLUS_MULTIDECISIONTREE_H_
00035 #define _CLUS_MULTIDECISIONTREE_H_
00036 
00037 #include "machine.h"
00038 #include "multidecisiontreenode.h"
00039 #include "dctraingen.h"
00040 #include "continuouslineartransformation.h"
00041 #include "discretepermutationtransformation.h"
00042 #include "statisticsgatherers.h"
00043 #include "general.h"
00044 
00045 #include <iostream>
00046 
00047 // using namespace TNT;
00048 
00049 namespace CLUS
00050 {
00051 
00052 template< class T_Splitter >
00053 class MultiDecisionTree : public Machine
00054 {
00055 protected:
00056     MultiDecisionTreeNode< T_Splitter >* root;
00057 
00058    /// list of discrete domain sizes
00059     const Vector<int>& dDomainSize;
00060 
00061     /// num of discrete variables
00062     int dsplitDim;
00063 
00064     /// num of continuous+split variables
00065     int csplitDim;
00066 
00067     /// the number of trees built and outputs produces
00068     int noDatasets;
00069     
00070     /// the minimum number of datapoints in a node to split further
00071     int min_no_datapoints;
00072 
00073     /// type of split to be passed to splitter, splitter dependent
00074     int splitType;
00075 
00076     bool schemaMatch;
00077     bool labeled;
00078     bool alignSplits;
00079 
00080     ContinuousLinearTransformation* continuousTransformer;
00081     DiscretePermutationTransformation* discreteTransformer;
00082 public:
00083     MultiDecisionTree(const Vector<int>& DDomainSize, int CsplitDim, int NoDatasets):
00084             Machine(CsplitDim,1),dDomainSize(DDomainSize),
00085             dsplitDim(DDomainSize.dim()-1),csplitDim(CsplitDim),noDatasets(NoDatasets)
00086     {
00087         ofstream file("x");
00088         continuousTransformer=0;
00089         discreteTransformer=0;
00090         min_no_datapoints = 10;
00091         splitType = 0;
00092     }
00093 
00094     ~MultiDecisionTree(void)
00095     {
00096         if (root!=0)
00097             delete root;
00098 
00099         if (continuousTransformer!=0)
00100             delete continuousTransformer;
00101 
00102         if (discreteTransformer!=0)
00103             delete discreteTransformer;
00104     }
00105 
00106     virtual int InDim(void)
00107     {
00108         return dsplitDim+csplitDim+1;
00109     }
00110 
00111     virtual string TypeName(void)
00112     {
00113         return string("MultiDecisionTree");
00114     }
00115 
00116     virtual void Infer(void)
00117     {
00118         if (root==0)
00119             return;
00120         // translate the first dsplitDim inputs into int
00121         int Dvars[MAX_VARIABLES];
00122         for(int i=0; i<dsplitDim; i++)
00123         {
00124             Dvars[i]=(int)(*input)[i];
00125             //    cout << Dvars[i] << " ";
00126         }
00127         //    cout << " :: ";
00128         double Cvars[MAX_VARIABLES];
00129         for (int i=0; i<csplitDim; i++)
00130         {
00131             Cvars[i]=(*input)[i+dsplitDim];
00132             //    cout << Cvars[i] << " ";
00133         }
00134 
00135         // shift the data
00136         int datasetIndex=(int)(*input)[dsplitDim+csplitDim];
00137 
00138         assert (datasetIndex<noDatasets);
00139 
00140         //    cout << " D:" << datasetIndex << " | " << endl;
00141         discreteTransformer->ApplyShiftToTuple(Dvars, datasetIndex);
00142         continuousTransformer->ApplyShiftToTuple(Cvars, datasetIndex);
00143 
00144         //     for(int i=0; i<dsplitDim; i++){
00145         //    cout << Dvars[i] << " ";
00146         //     }
00147         //    cout << " :: ";
00148         //   for (int i=0; i<csplitDim; i++){
00149         //cout << Cvars[i] << " ";
00150         //    }
00151 
00152         output[0]=root->Infer(Dvars, Cvars);
00153         //   cout << " O: " << output[0] << endl;
00154     }
00155 
00156     void printBitFlipData()
00157     {
00158         ofstream afile("hello");
00159         afile << "HELLO" << endl;
00160         bool significant[csplitDim];
00161         /*  for (int j=0; j<csplitDim; j++)
00162         {
00163         if (root->labeledMeansSignificant(j, 0))
00164         {
00165         significant[j]=true;
00166         }
00167         else significant[j] = false;
00168         }*/
00169         ofstream file("bitFlipStuff");
00170         bool bits[csplitDim];
00171         file << noDatasets;
00172         for (int j=0; j< csplitDim; j++)
00173         {
00174             bits[j]=root->negMeanLessThanPos(j, 0);
00175         }
00176         for (int i=1; i<noDatasets; i++)
00177         {
00178             file << "in bit flip loop" << endl;
00179 
00180             file << "Dataset " << i << ":" << endl;
00181             for (int j=0; j< csplitDim; j++)
00182             {
00183                 if (root->labeledMeansSignificant(j,0))
00184                 {
00185                     file << (root->negMeanLessThanPos(j, i)!= bits[j]) << endl;
00186                 }
00187             }
00188         }
00189     }
00190 
00191     virtual void Identify()
00192     {
00193         // copy the data since we do not want to modify the original
00194         // The value of the last discrete attribute is the class label
00195         Matrix<double> ctrainData = training->GetTrainingData();
00196 
00197         continuousTransformer=new ContinuousLinearTransformation(csplitDim, noDatasets,
00198                               ctrainData) ;
00199         ofstream file("xyz");
00200         file << noDatasets << endl;
00201         // The value of the last continuous attribute designates the subdataset from
00202         // where the the tuple is comming from
00203         Matrix<int> dtrainData = dynamic_cast< DCTrainingData* >( training )
00204                                  -> GetDiscreteTrainingData();
00205 
00206         discreteTransformer=new DiscretePermutationTransformation(dsplitDim, noDatasets,
00207                             ctrainData, dtrainData,
00208                             dDomainSize);
00209 
00210         int M=ctrainData.num_rows();
00211 
00212         root = new MultiDecisionTreeNode<T_Splitter>
00213                (NULL, 1, dDomainSize, csplitDim, noDatasets,
00214                 *discreteTransformer, *continuousTransformer);
00215 
00216 
00217         file << "here I am";
00218         file << noDatasets << endl;
00219         // ------------------- DOING SCHEMA MATCHING ---------------------
00220         if (schemaMatch)
00221         {
00222             //    cerr << "DOING SCHEMA MATCHING" << endl;
00223             // create the root and give it the Id 1
00224 
00225             root->StartLearningEpoch();
00226             for(int i=0; i<M; i++)
00227             {
00228                 int datasetIndex=(int)ctrainData[i][csplitDim];
00229 
00230                 ExitIf(datasetIndex>=noDatasets, "More datasets enountered than declared");
00231                 int classLabel=dtrainData[i][dsplitDim];
00232                 root->LearnSample(dtrainData[i],ctrainData[i], classLabel, datasetIndex);
00233             }
00234             file << noDatasets;
00235             //printBitFlipData(noDatasets);
00236             for (int att =0; att<dsplitDim; att++)
00237             { // discrete attribute
00238 
00239                 // find and propagate the shift
00240                 Vector<Permutation> shifts(noDatasets);
00241                 shifts[0]=Permutation(dDomainSize[att]);
00242                 for (int i=1; i<noDatasets; i++)
00243                     shifts[i]= root->ComputeDiscreteShift(true, att, i);
00244                 discreteTransformer->SetShiftsAttribute(att, shifts);
00245 
00246             }
00247 
00248             for (int att =0; att<csplitDim; att++)
00249             { // continuous attribute
00250 
00251                 double sumInvVars=0.0;
00252                 double split0;
00253                 if (labeled)
00254                 {
00255                     split0 = 0.0;
00256                     //split0 = root->combineSplits(att, 0, &sumInvVars)/sumInvVars;
00257                 }
00258                 else
00259                 {
00260                     if (!alignSplits)
00261                         split0= root->combineCenters(att, 0, &sumInvVars)/sumInvVars;
00262                     else
00263                         split0 == root->combineSplits(att, 0, &sumInvVars)/sumInvVars;
00264                 }
00265 
00266                 if (!NonZero(sumInvVars))
00267                     split0 =0.0;
00268                 //        cerr << "split0";
00269 
00270 
00271                 Vector<double> shifts(noDatasets);
00272                 shifts[0]=0.0;
00273                 for (int i=1; i<noDatasets; i++)
00274                 {
00275                     double sumInvVars=0.0;
00276                     double spliti;
00277                     if (labeled)
00278                     {
00279                         // this aligns split points.  want to align centers
00280                         //spliti = root->combineSplits(att, i,&sumInvVars)/sumInvVars;
00281                         spliti = root->combineLabeledCenters(att, i, &sumInvVars)/sumInvVars;
00282                         if (!NonZero(sumInvVars))
00283                             spliti=split0;
00284                     }
00285                     else
00286                         spliti = root->combineCenters(att, i, &sumInvVars)/sumInvVars;
00287                     //          cerr << "here";
00288 
00289                     if (!NonZero(sumInvVars))
00290                         spliti=split0;
00291                     // else {
00292                     //        cerr << "sumInvVars=" << sumInvVars << endl;
00293 
00294                     shifts[i]= -(split0 - spliti);
00295                     //  }
00296                 }
00297                 continuousTransformer->SetShiftsAttribute(att, shifts);
00298             }
00299 
00300             // goto StartedLearning;
00301         }
00302         //      cerr << "finished schema matching" << endl;
00303 
00304         //  // create the root and give it the Id 1
00305         //root = new MultiDecisionTreeNode<T_Splitter>
00306         //    (1, dDomainSize, csplitDim, noDatasets,
00307         //     *discreteTransformer, *continuousTransformer);
00308 
00309         // learn in stages until nobody wants to learn anymore
00310         do
00311         {
00312             root->StartLearningEpoch();
00313             //printBitFlipData();
00314             for(int i=0; i<M; i++)
00315             {
00316                 int datasetIndex=(int)ctrainData[i][csplitDim];
00317 
00318                 ExitIf(datasetIndex>=noDatasets, "More datasets enountered than declared");
00319 
00320                 int classLabel=dtrainData[i][dsplitDim];
00321                 root->LearnSample(dtrainData[i],ctrainData[i], classLabel, datasetIndex);
00322             }
00323             //      StartedLearning:
00324             list<int> attList; // list of attributes; empty for now
00325             root->FindSplitAttributes(attList);
00326             // remove duplicates
00327             attList.sort();
00328             attList.unique();
00329 
00330 
00331             list<int>::iterator itrt;
00332             for (itrt=attList.begin();
00333                     itrt!=attList.end(); itrt++)
00334             {
00335                 int SplitAttribute=*itrt;
00336                 if (SplitAttribute>=0)
00337                 { // discrete attribute
00338 
00339                     /*Vector< BinomialStatistics > statistics(noDatasets, BinomialStatistics(dDomainSize[SplitAttribute]) );
00340                       root->AddDiscreteShiftStatistics(SplitAttribute, statistics);
00341                       for (int i=0; i<noDatasets; i++)
00342                       statistics[i].CorrectWeightedStatistics(totalWeight);*/
00343 
00344                     // find and propagate the shift
00345                     Vector<Permutation> shifts(noDatasets);
00346                     shifts[0]=Permutation(dDomainSize[SplitAttribute]);
00347                     for (int i=1; i<noDatasets; i++)
00348                         shifts[i]= root->ComputeDiscreteShift(true, SplitAttribute, i);
00349                     discreteTransformer->SetShiftsAttribute(SplitAttribute, shifts);
00350 
00351                 }
00352                 else
00353                 { // continuous attribute
00354 
00355                     /*Vector< NormalStatistics > statistics(noDatasets);
00356                       root->AddContinuousShiftStatistics(-SplitAttribute-1, statistics);
00357                       for (int i=0; i<noDatasets; i++)
00358                       statistics[i].CorrectWeightedStatistics(totalWeight);
00359 
00360                       // find and propagate the shift
00361                       Vector<double> shifts(noDatasets);
00362                       shifts[0]=0.0;
00363                       for (int i=1; i<noDatasets; i++)
00364                       shifts[i]=statistics[0].ComputeShift(labeled, statistics[i]);*/
00365 
00366                     double sumInvVars=0.0;
00367                     double split0;
00368                     if (labeled)
00369                     {
00370                         split0 = root->combineSplits(-SplitAttribute-1, 0, &sumInvVars)/sumInvVars;
00371                         split0=0.0;
00372                     }
00373                     else
00374                     {
00375                         if (!alignSplits)
00376                             split0 = root->combineCenters(-SplitAttribute-1, 0, &sumInvVars)/sumInvVars;
00377                         else
00378                             split0 = root->combineSplits(-SplitAttribute-1, 0, &sumInvVars)/sumInvVars;
00379                     }
00380 
00381                     if (sumInvVars==0.0)
00382                         split0=0.0;
00383 
00384 
00385                     // split0 = split0/ root->computeSumOfVarianceInverted(-SplitAttribute-1, 0);
00386                     Vector<double> shifts(noDatasets);
00387                     shifts[0]=0.0;
00388                     for (int i=1; i<noDatasets; i++)
00389                     {
00390                         double sumInvVars=0.0;
00391                         double spliti=0.0;
00392                         if (labeled)
00393                         {
00394                             // this aligns splitpnt.  want to align labeled centers
00395                             //spliti= root->combineSplits(-SplitAttribute-1, i,&sumInvVars)/sumInvVars;
00396                             shifts[i] = root->combineLabeledCenters(-SplitAttribute-1, i, &sumInvVars);
00397                         }
00398                         else
00399                         {
00400                             if (!alignSplits)
00401                                 spliti= root->combineCenters(-SplitAttribute-1, i,&sumInvVars)/sumInvVars;
00402                             else
00403                                 spliti = root->combineSplits(-SplitAttribute-1, i,&sumInvVars)/sumInvVars;
00404 
00405                         }
00406                         if (sumInvVars==0.0)
00407                             spliti=split0;
00408                         else
00409                         {
00410                             //        cerr << "sumInvVars=" << sumInvVars << endl;
00411 
00412                             //           cerr << "spliti " <<  spliti;
00413                             //           double denom = root->computeSumOfVarianceInverted(-SplitAttribute-1, i);
00414                             //           spliti = spliti/denom;
00415                             shifts[i]= -(split0 - spliti);
00416                         }
00417                     }
00418                     continuousTransformer->SetShiftsAttribute(-SplitAttribute-1, shifts);
00419                 }
00420             }
00421         }
00422         while (root->StopLearningEpoch(splitType, min_no_datapoints));
00423         //      cout << "End Learning" << endl;
00424     }
00425 
00426     virtual void Prune(void)
00427     {
00428         const Matrix<double>& ctrainData = pruning->GetTrainingData();
00429         const Matrix<int>& dtrainData = dynamic_cast< DCTrainingData* >( pruning )
00430                                         -> GetDiscreteTrainingData();
00431 
00432         int M=ctrainData.num_rows();
00433 
00434         root->InitializePruningStatistics();
00435         for(int i=0; i<M; i++)
00436         {
00437             int datasetIndex=(int)ctrainData[i][csplitDim];
00438             // create copies of tuple
00439             int Dvars[MAX_VARIABLES];
00440             for(int j=0; j<dsplitDim; j++)
00441                 Dvars[j]=dtrainData[i][j];
00442 
00443             double Cvars[MAX_VARIABLES];
00444             for (int j=0; j<csplitDim; j++)
00445                 Cvars[j]=ctrainData[i][j];
00446 
00447             // apply the transformations directly to the data
00448             discreteTransformer->ApplyShiftToTuple(Dvars, datasetIndex);
00449             continuousTransformer->ApplyShiftToTuple(Cvars, datasetIndex);
00450             int classLabel=dtrainData[i][dsplitDim];
00451             root->UpdatePruningStatistics(Dvars, Cvars, classLabel, datasetIndex );
00452         }
00453         root->FinalizePruningStatistics();
00454 
00455         // now cut the tree to the right size
00456         double cost=root->PruneSubtree();
00457         cout << "RMSN after pruning is:" << cost/M << endl;
00458     }
00459 
00460     virtual int SetOption(char* name, char* val)
00461     {
00462 
00463         if (strcmp(name, "SchemaMatch")==0)
00464         {
00465             if (strcmp(val, "true")==0)
00466                 schemaMatch=true;
00467             else
00468                 schemaMatch=false;
00469         }
00470         else if (strcmp(name, "labeled")==0)
00471         {
00472             if (strcmp(val, "true")==0)
00473                 labeled=true;
00474             else
00475                 labeled=false;
00476         }
00477         else if (strcmp(name, "alignSplits")==0)
00478         {
00479 
00480             if (strcmp(val, "true")==0)
00481                 alignSplits = true;
00482             else
00483                 alignSplits=false;
00484 
00485         }
00486         else
00487             return Machine::SetOption(name,val);
00488         return 1;
00489     }
00490 
00491     virtual void SaveToStream(ostream& out)
00492     {
00493         // save the description in a stream
00494         out << "Tree: with "  << dsplitDim << " discrete and " << csplitDim << " continuous attributes" << endl;
00495         out << "Discrete attribute domain sizes: " << endl;
00496 
00497         out << "[ ";
00498         for(int i=0; i<dsplitDim; i++)
00499             out << dDomainSize[i] << " ";
00500         out << "]" << endl;
00501 
00502         out << "Shifts: " << endl;
00503 
00504         for(int i=0; i<dsplitDim; i++)
00505         {
00506             for (int j=1; j<noDatasets; j++)
00507             {
00508                 out << "Shift to dataset " << j << " for discrete attribute " << i << ": " << endl;
00509                 discreteTransformer->saveToStream(out, i, j);
00510             }
00511         }
00512 
00513         for(int i=0; i<csplitDim; i++)
00514         {
00515             for (int j=1; j < noDatasets; j++)
00516             {
00517                 if (continuousTransformer->HasAttributeShifts(i))
00518                 {
00519                     out << "Shift to dataset " << j << " for continuous attribute " << i << ": "
00520                     << continuousTransformer->getShift(i, j) << endl;
00521                 }
00522                 else
00523                 {
00524                     out << "Shift to dataset " << j << " for continuous attribute " << i << ": NOSHIFT" << endl;
00525                 }
00526 
00527             }
00528         }
00529         out << endl;
00530         root->SaveToStream(out);
00531     }
00532 
00533 };
00534 }
00535 
00536 #endif // _CLUS_MULTIDECISIONTREE_H_
00537 

Generated on Mon Jul 21 16:57:24 2003 for SECRET by doxygen 1.3.2