00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034 #if !defined _CLUS_BINARYPROBABILISTICSPLITTER_H_
00035 #define _CLUS_BINARYPROBABILISTICSPLITTER_H_
00036
#include <cmath>
#include <ostream>

#include "general.h"
#include "statisticsgatherers.h"
#include "vec.h"

#ifdef DEBUG_PRINT
#include <iostream>
using namespace std;
#endif

using namespace TNT;
00047
00048 namespace CLUS
00049 {
00050 class BinaryProbabilisticSplitter
00051 {
00052 protected:
00053
00054 int dsplitDim, csplitDim;
00055
00056
00057 double mass0, mass;
00058
00059
00060 const Vector<int>& dDomainSize;
00061
00062
00063 int SplitVariable;
00064
00065
00066 double splitPoint, splitSTD;
00067
00068 Vector<double> splitSetProbability;
00069
00070
00071
00072
00073
00074 Vector<ProbabilisticBinomialStatistics> discreteStatistics;
00075 Vector<NormalStatistics> continuousStatistics;
00076
00077 public:
00078 BinaryProbabilisticSplitter(const Vector<int>& DDomainSize,int CsplitDim ):
00079 dsplitDim(DDomainSize.dim()-1), csplitDim(CsplitDim),
00080 dDomainSize(DDomainSize),
00081 discreteStatistics(0), continuousStatistics(0)
00082 {}
00083
00084 bool GotNoData(void)
00085 {
00086 return mass==0.0;
00087 }
00088
00089 int GetCSplitDim(void)
00090 {
00091 return csplitDim;
00092 }
00093
00094 int GetDSplitDim(void)
00095 {
00096 return dsplitDim;
00097 }
00098
00099 const Vector<int>& GetDDomainSize(void)
00100 {
00101 return dDomainSize;
00102 }
00103
00104
00105 void InitializeSplitStatistics(void)
00106 {
00107 mass0=mass=0.0;
00108
00109 discreteStatistics.newsize(dsplitDim);
00110 for (int i=0; i<dsplitDim; i++)
00111 {
00112 discreteStatistics[i].ResetDomainSize(dDomainSize[i]);
00113 }
00114
00115 continuousStatistics.newsize(csplitDim);
00116
00117 }
00118
00119
00120 double ProbabilityLeft(const int* Dvars, const double* Cvars)
00121 {
00122 if (SplitVariable <=-1)
00123 {
00124
00125
00126 int splitVar=-SplitVariable-1;
00127
00128 return PValueNormalDistribution(splitPoint,splitSTD,Cvars[splitVar]);
00129 }
00130 else
00131 {
00132
00133 int value=Dvars[SplitVariable];
00134 return splitSetProbability[value];
00135 }
00136 }
00137
00138 void UpdateSplitStatistics( const int* Dvars, const double* Cvars,
00139 int classLabel, double probability)
00140 {
00141 mass+=probability;
00142 if (classLabel==0)
00143 mass0+=probability;
00144
00145
00146 for (int i=0; i<dsplitDim; i++)
00147 {
00148 int value=Dvars[i];
00149 discreteStatistics[i].UpdateStatistics(value, classLabel, probability);
00150 }
00151
00152
00153 for (int i=0; i<csplitDim; i++)
00154 {
00155 double value=Cvars[i];
00156 continuousStatistics[i].UpdateStatistics(value, classLabel, probability);
00157 }
00158 }
00159
00160 bool ComputeSplitVariable(void)
00161 {
00162
00163 if (mass0<2.0 || mass-mass0<2.0)
00164 {
00165 #ifdef DEBUG_PRINT
00166 cout << "Making the node a leaf" << endl;
00167 #endif
00168
00169 return false;
00170 }
00171
00172 double maxgini=0.0;
00173
00174
00175 for (int i=0; i<dsplitDim; i++)
00176 {
00177 double curr_gini=discreteStatistics[i].ComputeGiniGain();
00178 #ifdef DEBUG_PRINT
00179
00180 cout << "\tVariable: " << i << " gini=" << curr_gini << endl;
00181 #endif
00182
00183 if (curr_gini>maxgini)
00184 {
00185 maxgini=curr_gini;
00186 SplitVariable=i;
00187 }
00188 }
00189
00190
00191 for (int i=0; i<csplitDim; i++)
00192 {
00193 double curr_gini=continuousStatistics[i].ComputeGiniGain();
00194 #ifdef DEBUG_PRINT
00195
00196 cout << "\tVariable: " << (-i-1) << " gini=" << curr_gini << endl;
00197 #endif
00198
00199 if (curr_gini>maxgini)
00200 {
00201 maxgini=curr_gini;
00202 SplitVariable=-(i+1);
00203 }
00204 }
00205
00206 if (maxgini==0.0)
00207 return false;
00208
00209
00210 if (SplitVariable>=0)
00211 splitSetProbability=discreteStatistics[SplitVariable].GetProbabilitySet();
00212 else
00213 {
00214 int splitVar=-SplitVariable-1;
00215 splitPoint=continuousStatistics[splitVar].GetSplit();
00216 splitSTD=sqrt(continuousStatistics[splitVar].getSplitVariance());
00217 }
00218
00219 #ifdef DEBUG_PRINT
00220 cout << "Split variable is " << SplitVariable << " and split point ";
00221 if (SplitVariable>=0)
00222 cout << splitSetProbability << endl;
00223 else
00224 cout << splitPoint << " " << splitSTD << endl;
00225 #endif
00226
00227 return true;
00228 }
00229
00230 double ComputeProbabilityFirstClass(void)
00231 {
00232 return (mass0/mass);
00233 }
00234
00235 bool MoreSplits(double minMass, int nodeId)
00236 {
00237 return ( mass>=minMass && mass0!=0.0 && mass0!=mass );
00238 }
00239
00240 void SaveToStream(ostream& out)
00241 {
00242 out << SplitVariable << " ";
00243 if (SplitVariable>=0)
00244 {
00245 out << "( " << mass << " ) [ ";
00246 for (int i=0; i<splitSetProbability.size(); i++)
00247 out << splitSetProbability[i] << " ";
00248 out << " ] ";
00249 }
00250 else
00251 {
00252 out << "( " << splitPoint << " " << splitSTD << " ) ";
00253 }
00254 }
00255 };
00256
00257 }
00258
00259 #endif //_CLUS_BINARYPROBABILISTICSPLITTER_H_