00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 
00015 
00016 
00017 
00018 
00019 
00020 
00021 
00022 
00023 
00024 
00025 
00026 
00027 
00028 
00029 
00030 
00031 
00032 
00033 
00034 #if !defined _CLUS_MULTIDECISIONTREE_H_
00035 #define _CLUS_MULTIDECISIONTREE_H_
00036 
#include "machine.h"
#include "multidecisiontreenode.h"
#include "dctraingen.h"
#include "continuouslineartransformation.h"
#include "discretepermutationtransformation.h"
#include "statisticsgatherers.h"
#include "general.h"

#include <cassert>
#include <cstring>
#include <fstream>
#include <iostream>
#include <list>
#include <string>
#include <vector>
00046 
00047 
00048 
00049 namespace CLUS
00050 {
00051 
00052 template< class T_Splitter >
00053 class MultiDecisionTree : public Machine
00054 {
00055 protected:
00056     MultiDecisionTreeNode< T_Splitter >* root;
00057 
00058 
00059     const Vector<int>& dDomainSize;
00060 
00061 
00062     int dsplitDim;
00063 
00064 
00065     int csplitDim;
00066 
00067 
00068     int noDatasets;
00069     
00070 
00071     int min_no_datapoints;
00072 
00073 
00074     int splitType;
00075 
00076     bool schemaMatch;
00077     bool labeled;
00078     bool alignSplits;
00079 
00080     ContinuousLinearTransformation* continuousTransformer;
00081     DiscretePermutationTransformation* discreteTransformer;
00082 public:
00083     MultiDecisionTree(const Vector<int>& DDomainSize, int CsplitDim, int NoDatasets):
00084             Machine(CsplitDim,1),dDomainSize(DDomainSize),
00085             dsplitDim(DDomainSize.dim()-1),csplitDim(CsplitDim),noDatasets(NoDatasets)
00086     {
00087         ofstream file("x");
00088         continuousTransformer=0;
00089         discreteTransformer=0;
00090         min_no_datapoints = 10;
00091         splitType = 0;
00092     }
00093 
00094     ~MultiDecisionTree(void)
00095     {
00096         if (root!=0)
00097             delete root;
00098 
00099         if (continuousTransformer!=0)
00100             delete continuousTransformer;
00101 
00102         if (discreteTransformer!=0)
00103             delete discreteTransformer;
00104     }
00105 
00106     virtual int InDim(void)
00107     {
00108         return dsplitDim+csplitDim+1;
00109     }
00110 
00111     virtual string TypeName(void)
00112     {
00113         return string("MultiDecisionTree");
00114     }
00115 
00116     virtual void Infer(void)
00117     {
00118         if (root==0)
00119             return;
00120         
00121         int Dvars[MAX_VARIABLES];
00122         for(int i=0; i<dsplitDim; i++)
00123         {
00124             Dvars[i]=(int)(*input)[i];
00125             
00126         }
00127         
00128         double Cvars[MAX_VARIABLES];
00129         for (int i=0; i<csplitDim; i++)
00130         {
00131             Cvars[i]=(*input)[i+dsplitDim];
00132             
00133         }
00134 
00135         
00136         int datasetIndex=(int)(*input)[dsplitDim+csplitDim];
00137 
00138         assert (datasetIndex<noDatasets);
00139 
00140         
00141         discreteTransformer->ApplyShiftToTuple(Dvars, datasetIndex);
00142         continuousTransformer->ApplyShiftToTuple(Cvars, datasetIndex);
00143 
00144         
00145         
00146         
00147         
00148         
00149         
00150         
00151 
00152         output[0]=root->Infer(Dvars, Cvars);
00153         
00154     }
00155 
00156     void printBitFlipData()
00157     {
00158         ofstream afile("hello");
00159         afile << "HELLO" << endl;
00160         bool significant[csplitDim];
00161         
00162 
00163 
00164 
00165 
00166 
00167 
00168 
00169         ofstream file("bitFlipStuff");
00170         bool bits[csplitDim];
00171         file << noDatasets;
00172         for (int j=0; j< csplitDim; j++)
00173         {
00174             bits[j]=root->negMeanLessThanPos(j, 0);
00175         }
00176         for (int i=1; i<noDatasets; i++)
00177         {
00178             file << "in bit flip loop" << endl;
00179 
00180             file << "Dataset " << i << ":" << endl;
00181             for (int j=0; j< csplitDim; j++)
00182             {
00183                 if (root->labeledMeansSignificant(j,0))
00184                 {
00185                     file << (root->negMeanLessThanPos(j, i)!= bits[j]) << endl;
00186                 }
00187             }
00188         }
00189     }
00190 
00191     virtual void Identify()
00192     {
00193         
00194         
00195         Matrix<double> ctrainData = training->GetTrainingData();
00196 
00197         continuousTransformer=new ContinuousLinearTransformation(csplitDim, noDatasets,
00198                               ctrainData) ;
00199         ofstream file("xyz");
00200         file << noDatasets << endl;
00201         
00202         
00203         Matrix<int> dtrainData = dynamic_cast< DCTrainingData* >( training )
00204                                  -> GetDiscreteTrainingData();
00205 
00206         discreteTransformer=new DiscretePermutationTransformation(dsplitDim, noDatasets,
00207                             ctrainData, dtrainData,
00208                             dDomainSize);
00209 
00210         int M=ctrainData.num_rows();
00211 
00212         root = new MultiDecisionTreeNode<T_Splitter>
00213                (NULL, 1, dDomainSize, csplitDim, noDatasets,
00214                 *discreteTransformer, *continuousTransformer);
00215 
00216 
00217         file << "here I am";
00218         file << noDatasets << endl;
00219         
00220         if (schemaMatch)
00221         {
00222             
00223             
00224 
00225             root->StartLearningEpoch();
00226             for(int i=0; i<M; i++)
00227             {
00228                 int datasetIndex=(int)ctrainData[i][csplitDim];
00229 
00230                 ExitIf(datasetIndex>=noDatasets, "More datasets enountered than declared");
00231                 int classLabel=dtrainData[i][dsplitDim];
00232                 root->LearnSample(dtrainData[i],ctrainData[i], classLabel, datasetIndex);
00233             }
00234             file << noDatasets;
00235             
00236             for (int att =0; att<dsplitDim; att++)
00237             { 
00238 
00239                 
00240                 Vector<Permutation> shifts(noDatasets);
00241                 shifts[0]=Permutation(dDomainSize[att]);
00242                 for (int i=1; i<noDatasets; i++)
00243                     shifts[i]= root->ComputeDiscreteShift(true, att, i);
00244                 discreteTransformer->SetShiftsAttribute(att, shifts);
00245 
00246             }
00247 
00248             for (int att =0; att<csplitDim; att++)
00249             { 
00250 
00251                 double sumInvVars=0.0;
00252                 double split0;
00253                 if (labeled)
00254                 {
00255                     split0 = 0.0;
00256                     
00257                 }
00258                 else
00259                 {
00260                     if (!alignSplits)
00261                         split0= root->combineCenters(att, 0, &sumInvVars)/sumInvVars;
00262                     else
00263                         split0 == root->combineSplits(att, 0, &sumInvVars)/sumInvVars;
00264                 }
00265 
00266                 if (!NonZero(sumInvVars))
00267                     split0 =0.0;
00268                 
00269 
00270 
00271                 Vector<double> shifts(noDatasets);
00272                 shifts[0]=0.0;
00273                 for (int i=1; i<noDatasets; i++)
00274                 {
00275                     double sumInvVars=0.0;
00276                     double spliti;
00277                     if (labeled)
00278                     {
00279                         
00280                         
00281                         spliti = root->combineLabeledCenters(att, i, &sumInvVars)/sumInvVars;
00282                         if (!NonZero(sumInvVars))
00283                             spliti=split0;
00284                     }
00285                     else
00286                         spliti = root->combineCenters(att, i, &sumInvVars)/sumInvVars;
00287                     
00288 
00289                     if (!NonZero(sumInvVars))
00290                         spliti=split0;
00291                     
00292                     
00293 
00294                     shifts[i]= -(split0 - spliti);
00295                     
00296                 }
00297                 continuousTransformer->SetShiftsAttribute(att, shifts);
00298             }
00299 
00300             
00301         }
00302         
00303 
00304         
00305         
00306         
00307         
00308 
00309         
00310         do
00311         {
00312             root->StartLearningEpoch();
00313             
00314             for(int i=0; i<M; i++)
00315             {
00316                 int datasetIndex=(int)ctrainData[i][csplitDim];
00317 
00318                 ExitIf(datasetIndex>=noDatasets, "More datasets enountered than declared");
00319 
00320                 int classLabel=dtrainData[i][dsplitDim];
00321                 root->LearnSample(dtrainData[i],ctrainData[i], classLabel, datasetIndex);
00322             }
00323             
00324             list<int> attList; 
00325             root->FindSplitAttributes(attList);
00326             
00327             attList.sort();
00328             attList.unique();
00329 
00330 
00331             list<int>::iterator itrt;
00332             for (itrt=attList.begin();
00333                     itrt!=attList.end(); itrt++)
00334             {
00335                 int SplitAttribute=*itrt;
00336                 if (SplitAttribute>=0)
00337                 { 
00338 
00339                     
00340 
00341 
00342 
00343 
00344                     
00345                     Vector<Permutation> shifts(noDatasets);
00346                     shifts[0]=Permutation(dDomainSize[SplitAttribute]);
00347                     for (int i=1; i<noDatasets; i++)
00348                         shifts[i]= root->ComputeDiscreteShift(true, SplitAttribute, i);
00349                     discreteTransformer->SetShiftsAttribute(SplitAttribute, shifts);
00350 
00351                 }
00352                 else
00353                 { 
00354 
00355                     
00356 
00357 
00358 
00359 
00360 
00361 
00362 
00363 
00364 
00365 
00366                     double sumInvVars=0.0;
00367                     double split0;
00368                     if (labeled)
00369                     {
00370                         split0 = root->combineSplits(-SplitAttribute-1, 0, &sumInvVars)/sumInvVars;
00371                         split0=0.0;
00372                     }
00373                     else
00374                     {
00375                         if (!alignSplits)
00376                             split0 = root->combineCenters(-SplitAttribute-1, 0, &sumInvVars)/sumInvVars;
00377                         else
00378                             split0 = root->combineSplits(-SplitAttribute-1, 0, &sumInvVars)/sumInvVars;
00379                     }
00380 
00381                     if (sumInvVars==0.0)
00382                         split0=0.0;
00383 
00384 
00385                     
00386                     Vector<double> shifts(noDatasets);
00387                     shifts[0]=0.0;
00388                     for (int i=1; i<noDatasets; i++)
00389                     {
00390                         double sumInvVars=0.0;
00391                         double spliti=0.0;
00392                         if (labeled)
00393                         {
00394                             
00395                             
00396                             shifts[i] = root->combineLabeledCenters(-SplitAttribute-1, i, &sumInvVars);
00397                         }
00398                         else
00399                         {
00400                             if (!alignSplits)
00401                                 spliti= root->combineCenters(-SplitAttribute-1, i,&sumInvVars)/sumInvVars;
00402                             else
00403                                 spliti = root->combineSplits(-SplitAttribute-1, i,&sumInvVars)/sumInvVars;
00404 
00405                         }
00406                         if (sumInvVars==0.0)
00407                             spliti=split0;
00408                         else
00409                         {
00410                             
00411 
00412                             
00413                             
00414                             
00415                             shifts[i]= -(split0 - spliti);
00416                         }
00417                     }
00418                     continuousTransformer->SetShiftsAttribute(-SplitAttribute-1, shifts);
00419                 }
00420             }
00421         }
00422         while (root->StopLearningEpoch(splitType, min_no_datapoints));
00423         
00424     }
00425 
00426     virtual void Prune(void)
00427     {
00428         const Matrix<double>& ctrainData = pruning->GetTrainingData();
00429         const Matrix<int>& dtrainData = dynamic_cast< DCTrainingData* >( pruning )
00430                                         -> GetDiscreteTrainingData();
00431 
00432         int M=ctrainData.num_rows();
00433 
00434         root->InitializePruningStatistics();
00435         for(int i=0; i<M; i++)
00436         {
00437             int datasetIndex=(int)ctrainData[i][csplitDim];
00438             
00439             int Dvars[MAX_VARIABLES];
00440             for(int j=0; j<dsplitDim; j++)
00441                 Dvars[j]=dtrainData[i][j];
00442 
00443             double Cvars[MAX_VARIABLES];
00444             for (int j=0; j<csplitDim; j++)
00445                 Cvars[j]=ctrainData[i][j];
00446 
00447             
00448             discreteTransformer->ApplyShiftToTuple(Dvars, datasetIndex);
00449             continuousTransformer->ApplyShiftToTuple(Cvars, datasetIndex);
00450             int classLabel=dtrainData[i][dsplitDim];
00451             root->UpdatePruningStatistics(Dvars, Cvars, classLabel, datasetIndex );
00452         }
00453         root->FinalizePruningStatistics();
00454 
00455         
00456         double cost=root->PruneSubtree();
00457         cout << "RMSN after pruning is:" << cost/M << endl;
00458     }
00459 
00460     virtual int SetOption(char* name, char* val)
00461     {
00462 
00463         if (strcmp(name, "SchemaMatch")==0)
00464         {
00465             if (strcmp(val, "true")==0)
00466                 schemaMatch=true;
00467             else
00468                 schemaMatch=false;
00469         }
00470         else if (strcmp(name, "labeled")==0)
00471         {
00472             if (strcmp(val, "true")==0)
00473                 labeled=true;
00474             else
00475                 labeled=false;
00476         }
00477         else if (strcmp(name, "alignSplits")==0)
00478         {
00479 
00480             if (strcmp(val, "true")==0)
00481                 alignSplits = true;
00482             else
00483                 alignSplits=false;
00484 
00485         }
00486         else
00487             return Machine::SetOption(name,val);
00488         return 1;
00489     }
00490 
00491     virtual void SaveToStream(ostream& out)
00492     {
00493         
00494         out << "Tree: with "  << dsplitDim << " discrete and " << csplitDim << " continuous attributes" << endl;
00495         out << "Discrete attribute domain sizes: " << endl;
00496 
00497         out << "[ ";
00498         for(int i=0; i<dsplitDim; i++)
00499             out << dDomainSize[i] << " ";
00500         out << "]" << endl;
00501 
00502         out << "Shifts: " << endl;
00503 
00504         for(int i=0; i<dsplitDim; i++)
00505         {
00506             for (int j=1; j<noDatasets; j++)
00507             {
00508                 out << "Shift to dataset " << j << " for discrete attribute " << i << ": " << endl;
00509                 discreteTransformer->saveToStream(out, i, j);
00510             }
00511         }
00512 
00513         for(int i=0; i<csplitDim; i++)
00514         {
00515             for (int j=1; j < noDatasets; j++)
00516             {
00517                 if (continuousTransformer->HasAttributeShifts(i))
00518                 {
00519                     out << "Shift to dataset " << j << " for continuous attribute " << i << ": "
00520                     << continuousTransformer->getShift(i, j) << endl;
00521                 }
00522                 else
00523                 {
00524                     out << "Shift to dataset " << j << " for continuous attribute " << i << ": NOSHIFT" << endl;
00525                 }
00526 
00527             }
00528         }
00529         out << endl;
00530         root->SaveToStream(out);
00531     }
00532 
00533 };
00534 }
00535 
00536 #endif // _CLUS_MULTIDECISIONTREE_H_
00537