00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034 #if !defined _CLUS_BINARYOBLIQUESPLITTER_H
00035 #define _CLUS_BINARYOBLIQUESPLITTER_H
00036
00037 #include "general.h"
00038 #include "binarysplitter.h"
00039 #include "splitpointcomputation.h"
00040 #include "statisticsgatherers.h"
00041 #include <math.h>
00042
00043 #include <stdlib.h>
00044 #include <iostream>
00045
00046 using namespace TNT;
00047
00048 namespace CLUS
00049 {
00050
00051 class BinaryObliqueSplitter: public BinarySplitter
00052 {
00053 protected:
00054 int N;
00055 Vector<ProbabilisticBinomialStatistics> discreteStatistics;
00056 MultidimNormalStatistics continuousStatistics;
00057
00058 bool purePart;
00059
00060 double ProbabilityLeftPrivate(const int* Dvars, const double* Cvars)
00061 {
00062
00063 if (SplitVariable==-1)
00064 {
00065
00066 double crit=SeparatingHyperplane[0];
00067 for (int i=0; i<csplitDim+regDim; i++)
00068 crit+=Cvars[i]*SeparatingHyperplane[i+1];
00069
00070 return 1.0-PValueNormalDistribution(0.0,1.0,crit);
00071 }
00072 else
00073 if (SplitVariable <=-2)
00074 {
00075
00076 double crit=SeparatingHyperplane[0]+
00077 Cvars[-SplitVariable-2]*SeparatingHyperplane[-SplitVariable-1];
00078
00079 return 1.0-PValueNormalDistribution(0.0,1.0,crit);
00080 }
00081 else
00082 {
00083
00084 int value=Dvars[SplitVariable];
00085 return splitSetProbability[value];
00086 }
00087 }
00088
00089 public:
00090 BinaryObliqueSplitter():BinarySplitter(), N(0), purePart(false)
00091 {}
00092 BinaryObliqueSplitter(const Vector<int>& DDomainSize,int CsplitDim, int RegDim):
00093 BinarySplitter(DDomainSize,CsplitDim,RegDim), N(0),purePart(false)
00094 {}
00095 BinaryObliqueSplitter(BinarySplitter& aux):
00096 BinarySplitter(aux)
00097 { }
00098
00099
00100
00101 enum MultidimNormalStatistics::SeparationType CSepHypType;
00102
00103
00104
00105
00106 bool MoreSplits(int branch, int Min_no_datapoints)
00107 {
00108 if (purePart)
00109 return false;
00110
00111
00112
00113 if (branch==0)
00114 return continuousStatistics.GetS_P1()>=Min_no_datapoints;
00115 else
00116 return continuousStatistics.GetS_P2()>=Min_no_datapoints;
00117 }
00118
00119
00120 double ProbabilityLeft(const int* Dvars, const double* Cvars)
00121 {
00122 if (ProbabilityLeftPrivate(Dvars,Cvars)>.5)
00123 return 1.0;
00124 else
00125 return 0.0;
00126 }
00127
00128 int ChooseBranch( const int* Dvars, const double* Cvars)
00129 {
00130 if (ProbabilityLeftPrivate(Dvars,Cvars)>.5)
00131 return 0;
00132 else
00133 return 1;
00134 }
00135
00136 void InitializeSplitStatistics(void)
00137 {
00138 discreteStatistics.newsize(dsplitDim);
00139 for (int i=0; i<dsplitDim; i++)
00140 {
00141 discreteStatistics[i].ResetDomainSize(dDomainSize[i]);
00142 }
00143
00144 continuousStatistics.Resize(csplitDim+regDim);
00145 }
00146
00147 void UpdateSplitStatistics( const int* Dvars, const double* Cvars,
00148 double p1I, double p2I, double probability)
00149 {
00150 N++;
00151
00152 double p1=p1I*probability;
00153 double p2=p2I*probability;
00154
00155
00156 for(int i=0; i<dsplitDim; i++)
00157 {
00158 discreteStatistics[i].UpdateStatisticsP(Dvars[i],p1,p2);
00159 }
00160
00161 continuousStatistics.UpdateStatistics(Cvars,p1,p2);
00162 }
00163
00164 void DeleteTemporaryStatistics(void)
00165 {
00166
00167 discreteStatistics.newsize(0);
00168 continuousStatistics.Resize(0);
00169 }
00170
00171 int ComputeSplitVariable(int type = MultidimNormalStatistics::LDA)
00172 {
00173 double maxgini;
00174
00175 if (N==0)
00176 goto error;
00177
00178
00179 maxgini=continuousStatistics.ComputeGiniGain(type);
00180 SplitVariable=continuousStatistics.GetSplitVariable();
00181
00182
00183
00184 for (int i=0; i<dsplitDim; i++)
00185 {
00186 double gini=discreteStatistics[i].ComputeGiniGain();
00187
00188
00189
00190 if (gini>maxgini)
00191 {
00192 maxgini=gini;
00193 SplitVariable=i;
00194 }
00195 }
00196
00197 if (maxgini==0.0)
00198 goto error;
00199
00200 if (SplitVariable>=0)
00201 {
00202 splitSetProbability=discreteStatistics[SplitVariable].GetProbabilitySet();
00203 }
00204 else
00205 {
00206 SeparatingHyperplane=continuousStatistics.GetSeparatingHyperplane();
00207 }
00208
00209
00210
00211
00212 if (fabs(maxgini-continuousStatistics.MaxGini()) < TNNearlyZero )
00213 purePart=true;
00214
00215 #if 0
00216
00217 cout << "SplitVariable=" << SplitVariable << endl;
00218 if (SplitVariable>=0)
00219 {
00220 cout << " [[ ";
00221 for (int i=0; i<splitSetProbability.size(); i++)
00222 cout << splitSetProbability[i] << " ";
00223 cout << " ]]" << endl;
00224 ;
00225 }
00226 #endif
00227
00228 return 0;
00229
00230 error:
00231 cout << "Error encountered, killing node" << endl;
00232
00233 return -1;
00234 }
00235 };
00236
00237 }
00238
00239 #endif // _CLUS_BINARYOBLIQUESPLITTER_H