00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034 #if !defined _CLUS_SIMPLEBINARYSPLITTER_H_
00035 #define _CLUS_SIMPLEBINARYSPLITTER_H_
00036
00037 #include "statisticsgatherers.h"
00038
00039 using namespace TNT;
00040
00041 namespace CLUS
00042 {
00043
00044
00045 class SimpleBinarySplitter
00046 {
00047 protected:
00048
00049 int dsplitDim, csplitDim;
00050
00051
00052
00053 const Vector<int>& dDomainSize;
00054
00055
00056 int SplitVariable;
00057
00058
00059 double splitPoint;
00060
00061
00062
00063
00064 Vector<int> SeparatingSet;
00065
00066
00067
00068
00069
00070 Vector<BinomialStatistics> discreteStatistics;
00071 Vector<NormalStatistics> continuousStatistics;
00072
00073
00074 int count;
00075 int countC0;
00076 public:
00077 SimpleBinarySplitter(const Vector<int>& DDomainSize,int CsplitDim):
00078 dsplitDim(DDomainSize.dim()-1), csplitDim(CsplitDim),
00079 dDomainSize(DDomainSize)
00080 {}
00081
00082 ~SimpleBinarySplitter(void)
00083 {}
00084
00085 int GetCSplitDim(void)
00086 {
00087 return csplitDim;
00088 }
00089 int GetDSplitDim(void)
00090 {
00091 return dsplitDim;
00092 }
00093 const Vector<int>& GetDDomainSize(void)
00094 {
00095 return dDomainSize;
00096 }
00097
00098
00099 void InitializeSplitStatistics(void)
00100 {
00101 count=0;
00102 countC0=0;
00103
00104 discreteStatistics.newsize(dsplitDim);
00105 for (int i=0; i<dsplitDim; i++)
00106 {
00107 discreteStatistics[i].ResetDomainSize(dDomainSize[i]);
00108 }
00109
00110 continuousStatistics.newsize(csplitDim);
00111 }
00112
00113
00114 int ChooseBranch( const int* Dvars, const double* Cvars)
00115 {
00116 if (SplitVariable <=-1)
00117 {
00118
00119
00120 if (Cvars[-SplitVariable-1]<=splitPoint)
00121 return 0;
00122 else
00123 return 1;
00124 }
00125 else
00126 {
00127
00128 int value=Dvars[SplitVariable];
00129
00130 if (IsPointInSet<int>(value, SeparatingSet))
00131 return 0;
00132 else
00133 return 1;
00134 }
00135 }
00136
00137 void UpdateSplitStatistics( const int* Dvars, const double* Cvars,
00138 int classLabel)
00139 {
00140 #ifdef DEBUG_PRINT
00141 cout << "CL=" << classLabel << " ";
00142 #endif
00143
00144 count++;
00145 if (classLabel==0)
00146 countC0++;
00147
00148 #ifdef DEBUG_PRINT
00149
00150 cout << "count=" << count << " countC0=" << countC0 << endl;
00151 #endif
00152
00153 for (int i=0; i<dsplitDim; i++)
00154 {
00155 int value=Dvars[i];
00156 discreteStatistics[i].UpdateStatistics(value, classLabel);
00157 }
00158
00159
00160 for (int i=0; i<csplitDim; i++)
00161 {
00162 double value=Cvars[i];
00163 continuousStatistics[i].UpdateStatistics(value, classLabel);
00164 }
00165 }
00166
00167 bool ComputeSplitVariable(void)
00168 {
00169
00170 if (countC0<2.0 || count-countC0<2.0)
00171 {
00172 #ifdef DEBUG_PRINT
00173 cout << "Making the node a leaf. countC0=" << countC0
00174 << " count=" << count << endl;
00175 #endif
00176
00177 return false;
00178 }
00179
00180 double maxgini=0.0;
00181
00182
00183 for (int i=0; i<dsplitDim; i++)
00184 {
00185 double curr_gini=discreteStatistics[i].ComputeGiniGain();
00186 #ifdef DEBUG_PRINT
00187
00188 cout << "\tVariable: " << i << " gini=" << curr_gini << endl;
00189 #endif
00190
00191 if (curr_gini>maxgini)
00192 {
00193 maxgini=curr_gini;
00194 SplitVariable=i;
00195 }
00196 }
00197
00198
00199 for (int i=0; i<csplitDim; i++)
00200 {
00201 double curr_gini=continuousStatistics[i].ComputeGiniGain();
00202 #ifdef DEBUG_PRINT
00203
00204 cout << "\tVariable: " << (-i-1) << " gini=" << curr_gini << endl;
00205 #endif
00206
00207 if (curr_gini>maxgini)
00208 {
00209 maxgini=curr_gini;
00210 SplitVariable=-(i+1);
00211 }
00212 }
00213
00214 if (maxgini==0.0)
00215 return false;
00216
00217
00218 if (SplitVariable>=0)
00219 SeparatingSet=discreteStatistics[SplitVariable].GetSplit();
00220 else
00221 {
00222 int splitVar=-SplitVariable-1;
00223 splitPoint=continuousStatistics[splitVar].GetSplit();
00224 }
00225
00226 #ifdef DEBUG_PRINT
00227 cout << "Split variable is " << SplitVariable << " and split point ";
00228 if (SplitVariable>=0)
00229 cout << SeparatingSet << endl;
00230 else
00231 cout << splitPoint << endl;
00232 #endif
00233
00234 return true;
00235 }
00236
00237 int ComputeClassLabel(void)
00238 {
00239 if (countC0>=count-countC0)
00240 return 0;
00241 else
00242 return 1;
00243 }
00244
00245 bool MoreSplits(int minMass, int nodeId)
00246 {
00247 return ( count>=minMass && countC0!=0 && countC0!=count );
00248 }
00249
00250 void SaveToStream(ostream& out)
00251 {
00252 out << SplitVariable << " ";
00253 if (SplitVariable>=0)
00254 {
00255 out << "( " << count << " ) [ ";
00256 for (int i=0; i<SeparatingSet.size(); i++)
00257 out << SeparatingSet[i] << " ";
00258 out << " ] ";
00259 }
00260 else
00261 {
00262 out << "( " << splitPoint << " ) ";
00263 }
00264 }
00265 };
00266 }
00267
00268 #endif // _CLUS_SIMPLEBINARYSPLITTER_H_