Main Page | Namespace List | Class Hierarchy | Data Structures | File List | Namespace Members | Data Fields | Globals | Related Pages

binarymulticlassificationsplitter.h

Go to the documentation of this file.
00001 /*
00002 
00003 Copyright (c) 2003, Cornell University
00004 All rights reserved.
00005 
00006 Redistribution and use in source and binary forms, with or without
00007 modification, are permitted provided that the following conditions are met:
00008 
00009    - Redistributions of source code must retain the above copyright notice,
00010        this list of conditions and the following disclaimer.
00011    - Redistributions in binary form must reproduce the above copyright
00012        notice, this list of conditions and the following disclaimer in the
00013        documentation and/or other materials provided with the distribution.
00014    - Neither the name of Cornell University nor the names of its
00015        contributors may be used to endorse or promote products derived from
00016        this software without specific prior written permission.
00017 
00018 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
00019 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
00020 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
00021 ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
00022 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
00023 CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
00024 SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
00025 INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
00026 CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
00027 ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
00028 THE POSSIBILITY OF SUCH DAMAGE.
00029 
00030 */
00031 
00032 // -*- C++ -*-
00033 
00034 #if !defined _CLUS_BINARYMULTICLASSICICATIONSPLITTER_H
00035 #define _CLUS_BINARYMULTICLASSICICATIONSPLITTER_H
00036 
00037 #include "statisticsgatherers.h"
00038 #include "discretepermutationtransformation.h"
00039 
00040 using namespace TNT;
00041 
00042 namespace CLUS
00043 {
00044 class BinaryMultiClassificationSplitter
00045 {
00046     /// number of discrete split variables
00047     int dsplitDim;
00048 
00049     /// number of continuous split variables
00050     int csplitDim;
00051 
00052     /// id of owner node
00053     int id;
00054     
00055     bool noWeighting;
00056 
00057     int noDatasets;
00058     
00059     const Vector<int>& dDomainSize;
00060     
00061     /// indicates on what variable this node splits on
00062     int SplitVariable;
00063     
00064     // split point if continuous variable is split variable
00065     double splitPoint;
00066     
00067     /** The list of values for the left child for the separating variable if
00068     the split is not oblique. Contains values in order so that
00069     binary search can be used. */
00070     Vector<int> SeparatingSet;
00071 
00072     DiscretePermutationTransformation& discreteTransformer;
00073     ContinuousLinearTransformation& continuousTransformer;
00074 
00075     /** Statistics for split decision. One for each attributeXdataset
00076     If an attribute already has the shift determined keep statistics 
00077     about all datasets in the location for the first dataset
00078     */
00079     Vector< Vector<BinomialStatistics> > discreteStatistics;
00080     Vector< Vector<NormalStatistics> > continuousStatistics;
00081 
00082     /// one per attribute
00083     Vector< Vector<Permutation> > discreteShifts;
00084 
00085     /// one per attribute
00086     Vector< Vector<double> > continuousShifts;
00087     
00088     Vector<int> examplesSeen;
00089     Vector<int> examples0Seen;
00090 
00091     /// statistics for class label determination
00092     int count;
00093     int countC0;
00094     
00095 public:
00096     BinaryMultiClassificationSplitter(const Vector<int>& DDomainSize,int CsplitDim,
00097                                       int NoDatasets,
00098                                       DiscretePermutationTransformation& DiscreteTransformer,
00099                                       ContinuousLinearTransformation& ContinuousTransformer):
00100             dsplitDim(DDomainSize.dim()-1), csplitDim(CsplitDim), noWeighting(true), /* was false */
00101             noDatasets(NoDatasets), dDomainSize(DDomainSize),
00102             discreteTransformer(DiscreteTransformer), continuousTransformer(ContinuousTransformer),
00103             discreteStatistics(0), continuousStatistics(0), 
00104             discreteShifts(0), continuousShifts(0),
00105             examplesSeen(noDatasets), examples0Seen(noDatasets)
00106     {}
00107 
00108     ~BinaryMultiClassificationSplitter(void)
00109     {}
00110     void setNodeID(int ID)
00111     {
00112         id=ID;
00113     }
00114 
00115     bool GotNoData(void)
00116     {
00117         return count==0;
00118     }
00119 
00120     int GetCSplitDim(void)
00121     {
00122         return csplitDim;
00123     }
00124     
00125     int GetDSplitDim(void)
00126     {
00127         return dsplitDim;
00128     }
00129     
00130     int GetNoDatasets(void)
00131     {
00132         return noDatasets;
00133     }
00134     
00135     const Vector<int>& GetDDomainSize(void)
00136     {
00137         return dDomainSize;
00138     }
00139 
00140     double getLabeledCount(int attribute, int dataSetIndex, bool b)
00141     {
00142         if (b)
00143             return continuousStatistics[attribute][dataSetIndex].getcountC1();
00144         else
00145             return continuousStatistics[attribute][dataSetIndex].getcountC0();
00146     }
00147 
00148     DiscretePermutationTransformation& GetDiscreteTransformer(void)
00149     {
00150         return discreteTransformer;
00151     }
00152 
00153     ContinuousLinearTransformation& GetContinuousTransformer(void)
00154     {
00155         return continuousTransformer;
00156     }
00157 
00158     /// Initializes the datastructures used in split variable selection
00159     void InitializeSplitStatistics(void)
00160     {
00161         count=0;
00162         countC0=0;
00163 
00164         discreteStatistics.newsize(dsplitDim);
00165         for (int i=0; i<dsplitDim; i++)
00166         {
00167             // for each discrete attribute determine if the shifts are fixed
00168             if (discreteTransformer.HasAttributeShifts(i))
00169             {
00170                 discreteStatistics[i].newsize(1);
00171                 discreteStatistics[i][0].ResetDomainSize(dDomainSize[i]);
00172             }
00173             else
00174             {
00175                 discreteStatistics[i].newsize(noDatasets);
00176                 for (int j=0; j<noDatasets; j++)
00177                     discreteStatistics[i][j].ResetDomainSize(dDomainSize[i]);
00178             }
00179         }
00180 
00181         continuousStatistics.newsize(csplitDim);
00182         for (int i=0; i<csplitDim; i++)
00183         {
00184             // for each discrete attribute determine if the shifts are fixed
00185             if (continuousTransformer.HasAttributeShifts(i))
00186             {
00187                 continuousStatistics[i].newsize(1);
00188             }
00189             else
00190             {
00191                 continuousStatistics[i].newsize(noDatasets);
00192             }
00193         }
00194     }
00195 
00196     /// Choose a branch
00197     ///
00198     /// @param Dvars         discrete variables
00199     /// @param Cvars         continuous variables
00200     /// @return              0 is left and 1 right
00201     int ChooseBranch( const int* Dvars, const double* Cvars)
00202     {
00203         if (SplitVariable <=-1)
00204         {
00205             // split on a continuous variable
00206             // the split variable is actually -SplitVariable-1
00207             if (Cvars[-SplitVariable-1]<=splitPoint)
00208                 return 0;
00209             else
00210                 return 1;
00211         }
00212         else
00213         {
00214             // split on a discrete variable
00215             int value=Dvars[SplitVariable];
00216             // look for value in SeparatingSet
00217             bool pickleft=false;
00218             int l=0, r=SeparatingSet.dim()-1;
00219 
00220             assert(r>=0);
00221 
00222             while ( l<=r && !pickleft )
00223             {
00224                 int m=(l+r)/2;
00225                 int vm=SeparatingSet[m];
00226                 if ( vm ==value )
00227                 {
00228                     pickleft=true;
00229                     break;
00230                 }
00231                 if ( vm < value )
00232                     l=m+1;
00233                 else
00234                     r=m-1;
00235             }
00236             if (pickleft)
00237                 return 0;
00238             else
00239                 return 1;
00240         }
00241     }
00242 
00243     /// Update split stats with new data
00244     ///
00245     /// @param Dvars         discrete variables
00246     /// @param Cvars         continuous variables
00247     /// @param classLabel    classification label
00248     /// @param datasetNo     dataset number
00249     void UpdateSplitStatistics( const int* Dvars, const double* Cvars,
00250                                 int classLabel, int datasetNo)
00251     {
00252         count++;
00253         examplesSeen[datasetNo]++;
00254         if (classLabel==0)
00255         {
00256             countC0++;
00257             examples0Seen[datasetNo]++;
00258         }
00259 
00260         // update discrete statistics
00261         for (int i=0; i<dsplitDim; i++)
00262         {
00263             // for each discrete attribute determine if the shifts are fixed
00264             if (discreteTransformer.HasAttributeShifts(i))
00265                 (discreteStatistics[i])[0].UpdateStatistics(Dvars[i],classLabel);
00266             else
00267                 (discreteStatistics[i])[datasetNo].UpdateStatistics(Dvars[i],classLabel);
00268         }
00269 
00270         for (int i=0; i<csplitDim; i++)
00271         {
00272             // for each discrete attribute determine if the shifts are fixed
00273             if (continuousTransformer.HasAttributeShifts(i))
00274                 (continuousStatistics[i])[0].UpdateStatistics(Cvars[i],classLabel);
00275             else
00276                 (continuousStatistics[i])[datasetNo].UpdateStatistics(Cvars[i],classLabel);
00277         }
00278     }
00279     
00280     bool labeledMeansSignificant(int attribute, int dataSetIndex)
00281     {
00282         return continuousStatistics[attribute][dataSetIndex].labeledMeansSignificant();
00283     }
00284 
00285     bool negMeanLessThanPos(int attribute, int dataSetIndex)
00286     {
00287         return continuousStatistics[attribute][dataSetIndex].negMeanLessThanPos();
00288     }
00289 
00290     bool hasContinuousData(int attribute, int dataSetIndex)
00291     {
00292         return continuousStatistics[attribute][dataSetIndex].hasData();
00293     }
00294 
00295     void DeleteTemporaryStatistics(void)
00296     {
00297         // discreteStatistics.newsize(0);
00298         // continuousStatistics.newsize(0);
00299     }
00300 
00301     /** Computes the split variable. If the variable needs a shift (doesn't have one already)
00302         add the variable to the attList
00303 
00304         @param attList       list of attribute values
00305         @return              If not enough data return false: make node a leaf
00306     */
00307     bool ComputeSplitVariable(list<int>& attList)
00308     {
00309         cerr << "nodeID " << id<< endl;
00310         cout << "Counts:"  << countC0 << "," << count-countC0 << endl;
00311         if (countC0<2 || count-countC0<2)
00312         {
00313 #ifdef DEBUG_PRINT
00314             cout << "Making the node a leaf. Counts: " << countC0 << "," << count-countC0 << endl;
00315 #endif
00316 
00317             return false; // make the node a leaf; not enough data to compute splits
00318         }
00319 
00320         // type is ignored, gini is used always
00321         double maxgini=0.0;
00322 
00323         // go over the discrete attributes and find the best one
00324         // when an attribute without shifts set is encountered the gini is average of ginies over datasets
00325         for (int i=0; i<dsplitDim; i++)
00326         {
00327             double curr_gini=0.0;
00328 
00329             // printing statistics
00330             //for (int j=0; j<noDatasets; j++){
00331             //cout << "DiscreteStatistics " << i << " " << j << endl;
00332             //discreteStatistics[i][j].Print();
00333             //}
00334 
00335             if (discreteTransformer.HasAttributeShifts(i))
00336             {
00337                 curr_gini=discreteStatistics[i][0].ComputeGiniGain();
00338             }
00339             else
00340             {
00341                 double total=0;
00342                 double weight=0;
00343                 // compute average gini
00344                 for (int j=0; j<noDatasets; j++)
00345                 {
00346                     weight = discreteStatistics[i][j].getCount();
00347                     curr_gini+=weight*discreteStatistics[i][j].ComputeGiniGain();
00348                     total += weight;
00349                 }
00350 
00351                 curr_gini/=total;
00352             }
00353             if (curr_gini>maxgini)
00354             {
00355                 maxgini=curr_gini;
00356                 SplitVariable=i;
00357             }
00358         }
00359 
00360         // continuous attributes now
00361         for (int i=0; i<csplitDim; i++)
00362         {
00363             double curr_gini=0.0;
00364 
00365             // printing statistics
00366             //  for (int j=0; j<noDatasets; j++){
00367             // cout << "ContinuousStatistics " << i << " " << j << endl;
00368             //continuousStatistics[i][j].Print();
00369             //  }
00370 
00371 
00372             if (continuousTransformer.HasAttributeShifts(i))
00373             {
00374                 curr_gini=continuousStatistics[i][0].ComputeGiniGain();
00375             }
00376             else
00377             {
00378                 double weight =0;
00379                 double total=0;
00380 
00381                 for (int j=0; j<noDatasets; j++)
00382                 {
00383 
00384                     NormalStatistics stat = continuousStatistics[i][j];
00385                     if (stat.nonZero())
00386                     {
00387                         weight = continuousStatistics[i][j].getcountC0() + continuousStatistics[i][j].getcountC1();
00388                         curr_gini+=weight*continuousStatistics[i][j].ComputeGiniGain();
00389                         total+=weight;
00390                     }
00391                 }
00392 
00393                 curr_gini/=total;
00394             }
00395             if (curr_gini>maxgini)
00396             {
00397                 maxgini=curr_gini;
00398                 SplitVariable=-(i+1);
00399             }
00400         }
00401 
00402 
00403 
00404         //      cout << "Maxgini=" << maxgini << endl;
00405 
00406         if (maxgini==0.0)
00407             return false; // make the node a leaf, nobody can do a reasonable split
00408 
00409         // put the variable in attList if it requires a shift computation and not already there
00410         cout << "Chosen split variable" << SplitVariable << " maxgini=" << maxgini << endl;
00411         if ( (SplitVariable>=0 && !discreteTransformer.HasAttributeShifts(SplitVariable)) ||
00412                 (SplitVariable<0 && !continuousTransformer.HasAttributeShifts(-SplitVariable-1)) )
00413         {
00414             // add attribute to the list
00415             attList.push_front(SplitVariable);
00416         }
00417 
00418         return true;
00419     }
00420 
00421     Permutation ComputeDiscreteShift(bool label, int attribute, int datasetIndex)
00422     {
00423         if (label)
00424             return discreteStatistics[attribute][0].ComputeShift(labeled, discreteStatistics[attribute][datasetIndex]);
00425         else
00426             return discreteStatistics[attribute][0].ComputeShift(unlabeled, discreteStatistics[attribute][datasetIndex]);
00427     }
00428     
00429     void AddDiscreteShiftStatistics(int SplitAttribute, double weight,
00430                                     Vector< BinomialStatistics >& statistics)
00431     {
00432         if (weight==0.0)
00433             return;
00434 
00435         for (int i=0; i<statistics.dim(); i++)
00436             statistics[i].AddWeightedStatistics(discreteStatistics[SplitAttribute][i], weight);
00437     }
00438 
00439     // used from silly weighting scheme
00440     /*
00441     void AddContinuousShiftStatistics(int SplitAttribute, double weight,
00442           Vector< NormalStatistics >& statistics) {
00443 
00444         if (weight==0.0)
00445             return;
00446 
00447         for (int i=0; i<statistics.dim(); i++)
00448             statistics[i].AddWeightedStatistics(continuousStatistics[SplitAttribute][i], weight);
00449     }*/
00450 
00451     void ComputeCenter(void)
00452     {
00453 
00454         // don't forget to propagate shifts if you have to
00455         if (SplitVariable>=0)
00456         {
00457             // discrete variable
00458             int noDatasets=discreteStatistics[SplitVariable].dim();
00459             if (noDatasets>1)
00460             {
00461                 // variable was shifted, have to propagate the shifts
00462                 for (int i=1; i<noDatasets; i++)
00463                 {
00464 
00465                     discreteStatistics[SplitVariable][0].AddStatisticsShifted(discreteStatistics[SplitVariable][i],
00466                             discreteTransformer.GetShift(SplitVariable, i));
00467                 }
00468             }
00469             discreteStatistics[SplitVariable][0].ComputeGiniGain(); // computes the shift also
00470             SeparatingSet=discreteStatistics[SplitVariable][0].GetSplit();
00471 
00472 
00473         }
00474         else
00475         {
00476             int splitVar=-SplitVariable-1;
00477             if (continuousStatistics[splitVar].dim()>1)
00478             {
00479                 // variable was shifted, have to propagate the shifts
00480                 for (int i=1; i<noDatasets; i++)
00481                 {
00482                     continuousStatistics[splitVar][0].AddStatisticsShifted(continuousStatistics[splitVar][i],
00483                             continuousTransformer.GetShift(splitVar, i));
00484                 }
00485             }
00486             continuousStatistics[splitVar][0].ComputeGiniGain();
00487             splitPoint=continuousStatistics[splitVar][0].GetCenter();
00488             //  cout << "Selected split point of " << splitPoint << "for var " << SplitVariable << endl;
00489         }
00490     }
00491 
00492     double getLabeledCenter(int attribute, int dataSetIndex, bool b)
00493     {
00494         return continuousStatistics[attribute][dataSetIndex].getLabeledCenter(b);
00495     }
00496 
00497     double getLabeledCenterVariance(int attribute, int dataSetIndex, bool b)
00498     {
00499         return continuousStatistics[attribute][dataSetIndex].getLCVariance(b);
00500     }
00501 
00502     double getVariance(int attribute, int dataSetIndex)
00503     {
00504         return continuousStatistics[attribute][dataSetIndex].getVariance();
00505     }
00506 
00507     /// For continuous attributes only
00508     ///
00509     /// @param attribute
00510     /// @param dataSetIndex
00511     double getSplitVariance(int attribute, int dataSetIndex)
00512     {
00513         if (noWeighting)
00514             return 1;
00515         return continuousStatistics[attribute][dataSetIndex].getSplitVariance();
00516     }
00517 
00518     /// For continuous attributes not chosen as split attribute at current node
00519     ///
00520     /// @param attribute
00521     /// @param dataSetIndex
00522     double getTentativeSplitPoint(int attribute, int dataSetIndex)
00523     {
00524         return continuousStatistics[attribute][dataSetIndex].GetSplit();
00525         //return continuousStatistics[attribute][dataSetIndex].GetSplit();
00526     }
00527 
00528     // For continuous attributes not chosen as split attribute at current node
00529     ///
00530     /// @param attribute
00531     /// @param dataSetIndex
00532     double getTentativeCenter(int attribute, int dataSetIndex)
00533     {
00534         return continuousStatistics[attribute][dataSetIndex].GetCenter();
00535         //return continuousStatistics[attribute][dataSetIndex].GetSplit();
00536     }
00537 
00538     void ComputeSplitPoint(void)
00539     {
00540         // don't forget to propagate shifts if you have to
00541         if (SplitVariable>=0)
00542         {
00543             // discrete variable
00544             int noDatasets=discreteStatistics[SplitVariable].dim();
00545             if (noDatasets>1)
00546             {
00547                 // variable was shifted, have to propagate the shifts
00548                 for (int i=1; i<noDatasets; i++)
00549                 {
00550 
00551                     discreteStatistics[SplitVariable][0].AddStatisticsShifted(discreteStatistics[SplitVariable][i],
00552                             discreteTransformer.GetShift(SplitVariable, i));
00553                 }
00554             }
00555             discreteStatistics[SplitVariable][0].ComputeGiniGain(); // computes the shift also
00556             SeparatingSet=discreteStatistics[SplitVariable][0].GetSplit();
00557         }
00558         else
00559         {
00560             int splitVar=-SplitVariable-1;
00561             if (continuousStatistics[splitVar].dim()>1)
00562             {
00563                 // variable was shifted, have to propagate the shifts
00564                 for (int i=1; i<noDatasets; i++)
00565                 {
00566                     continuousStatistics[splitVar][0].AddStatisticsShifted(continuousStatistics[splitVar][i],
00567                             continuousTransformer.GetShift(splitVar, i));
00568                 }
00569             }
00570             continuousStatistics[splitVar][0].ComputeGiniGain();
00571             splitPoint=continuousStatistics[splitVar][0].GetSplit();
00572             //cout << "Selected split point of " << splitPoint << "for var " << SplitVariable << endl;
00573         }
00574     }
00575 
00576     int ComputeClassLabel(void)
00577     {
00578         if (countC0>=count-countC0)
00579             return 0;
00580         else
00581             return 1;
00582     }
00583 
00584     double getContinuousShift(int attribute, int dataSetIndex)
00585     {
00586         return continuousTransformer.GetShift(attribute, dataSetIndex);
00587     }
00588 
00589     bool MoreSplits(int min_no_datapoints, int nodeID)
00590     {
00591         bool discreteSplitGoneBad = (SplitVariable>=0 && SeparatingSet.size()==0);
00592         return ( (count>=min_no_datapoints) && countC0!=0 && countC0!=count && !discreteSplitGoneBad); //REBA ADDED SECOND CLAUSES
00593     }
00594 
00595     void SaveToStream(ostream& out, bool isLeaf)
00596     {
00597 
00598         out << " label " << ComputeClassLabel()  << endl;
00599         for (int j=0; j < noDatasets; j++)
00600         {
00601             //  out << " Seen " <<  examplesSeen[j] << " examples from data set " << j << ", "
00602             //      << examples0Seen[j] << " of them labelled 0" << endl;
00603         }
00604         out << "}" << endl << endl;
00605 
00606         if (!isLeaf)
00607         {
00608             out << " split attribute: " << SplitVariable;
00609             if (SplitVariable<0)
00610                 out << ", split point: " << splitPoint << endl << "}" << endl << endl;
00611             else
00612             {
00613                 out << ", split set: ";
00614                 for (int i=0; i<SeparatingSet.dim(); i++)
00615                     out << SeparatingSet[i] << " ";
00616                 out << endl << "}" << endl << endl;
00617             }
00618 
00619         }
00620     }
00621 
00622 };
00623 }
00624 
00625 #endif // _CLUS_BINARYMULTICLASSICICATIONSPLITTER_H

Generated on Mon Jul 21 16:57:23 2003 for SECRET by doxygen 1.3.2