00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034 #if !defined _CLUS_BINARYMULTICLASSICICATIONSPLITTER_H
00035 #define _CLUS_BINARYMULTICLASSICICATIONSPLITTER_H
00036
00037 #include "statisticsgatherers.h"
00038 #include "discretepermutationtransformation.h"
00039
00040 using namespace TNT;
00041
00042 namespace CLUS
00043 {
00044 class BinaryMultiClassificationSplitter
00045 {
00046
00047 int dsplitDim;
00048
00049
00050 int csplitDim;
00051
00052
00053 int id;
00054
00055 bool noWeighting;
00056
00057 int noDatasets;
00058
00059 const Vector<int>& dDomainSize;
00060
00061
00062 int SplitVariable;
00063
00064
00065 double splitPoint;
00066
00067
00068
00069
00070 Vector<int> SeparatingSet;
00071
00072 DiscretePermutationTransformation& discreteTransformer;
00073 ContinuousLinearTransformation& continuousTransformer;
00074
00075
00076
00077
00078
00079 Vector< Vector<BinomialStatistics> > discreteStatistics;
00080 Vector< Vector<NormalStatistics> > continuousStatistics;
00081
00082
00083 Vector< Vector<Permutation> > discreteShifts;
00084
00085
00086 Vector< Vector<double> > continuousShifts;
00087
00088 Vector<int> examplesSeen;
00089 Vector<int> examples0Seen;
00090
00091
00092 int count;
00093 int countC0;
00094
00095 public:
00096 BinaryMultiClassificationSplitter(const Vector<int>& DDomainSize,int CsplitDim,
00097 int NoDatasets,
00098 DiscretePermutationTransformation& DiscreteTransformer,
00099 ContinuousLinearTransformation& ContinuousTransformer):
00100 dsplitDim(DDomainSize.dim()-1), csplitDim(CsplitDim), noWeighting(true),
00101 noDatasets(NoDatasets), dDomainSize(DDomainSize),
00102 discreteTransformer(DiscreteTransformer), continuousTransformer(ContinuousTransformer),
00103 discreteStatistics(0), continuousStatistics(0),
00104 discreteShifts(0), continuousShifts(0),
00105 examplesSeen(noDatasets), examples0Seen(noDatasets)
00106 {}
00107
00108 ~BinaryMultiClassificationSplitter(void)
00109 {}
00110 void setNodeID(int ID)
00111 {
00112 id=ID;
00113 }
00114
00115 bool GotNoData(void)
00116 {
00117 return count==0;
00118 }
00119
00120 int GetCSplitDim(void)
00121 {
00122 return csplitDim;
00123 }
00124
00125 int GetDSplitDim(void)
00126 {
00127 return dsplitDim;
00128 }
00129
00130 int GetNoDatasets(void)
00131 {
00132 return noDatasets;
00133 }
00134
00135 const Vector<int>& GetDDomainSize(void)
00136 {
00137 return dDomainSize;
00138 }
00139
00140 double getLabeledCount(int attribute, int dataSetIndex, bool b)
00141 {
00142 if (b)
00143 return continuousStatistics[attribute][dataSetIndex].getcountC1();
00144 else
00145 return continuousStatistics[attribute][dataSetIndex].getcountC0();
00146 }
00147
00148 DiscretePermutationTransformation& GetDiscreteTransformer(void)
00149 {
00150 return discreteTransformer;
00151 }
00152
00153 ContinuousLinearTransformation& GetContinuousTransformer(void)
00154 {
00155 return continuousTransformer;
00156 }
00157
00158
00159 void InitializeSplitStatistics(void)
00160 {
00161 count=0;
00162 countC0=0;
00163
00164 discreteStatistics.newsize(dsplitDim);
00165 for (int i=0; i<dsplitDim; i++)
00166 {
00167
00168 if (discreteTransformer.HasAttributeShifts(i))
00169 {
00170 discreteStatistics[i].newsize(1);
00171 discreteStatistics[i][0].ResetDomainSize(dDomainSize[i]);
00172 }
00173 else
00174 {
00175 discreteStatistics[i].newsize(noDatasets);
00176 for (int j=0; j<noDatasets; j++)
00177 discreteStatistics[i][j].ResetDomainSize(dDomainSize[i]);
00178 }
00179 }
00180
00181 continuousStatistics.newsize(csplitDim);
00182 for (int i=0; i<csplitDim; i++)
00183 {
00184
00185 if (continuousTransformer.HasAttributeShifts(i))
00186 {
00187 continuousStatistics[i].newsize(1);
00188 }
00189 else
00190 {
00191 continuousStatistics[i].newsize(noDatasets);
00192 }
00193 }
00194 }
00195
00196
00197
00198
00199
00200
00201 int ChooseBranch( const int* Dvars, const double* Cvars)
00202 {
00203 if (SplitVariable <=-1)
00204 {
00205
00206
00207 if (Cvars[-SplitVariable-1]<=splitPoint)
00208 return 0;
00209 else
00210 return 1;
00211 }
00212 else
00213 {
00214
00215 int value=Dvars[SplitVariable];
00216
00217 bool pickleft=false;
00218 int l=0, r=SeparatingSet.dim()-1;
00219
00220 assert(r>=0);
00221
00222 while ( l<=r && !pickleft )
00223 {
00224 int m=(l+r)/2;
00225 int vm=SeparatingSet[m];
00226 if ( vm ==value )
00227 {
00228 pickleft=true;
00229 break;
00230 }
00231 if ( vm < value )
00232 l=m+1;
00233 else
00234 r=m-1;
00235 }
00236 if (pickleft)
00237 return 0;
00238 else
00239 return 1;
00240 }
00241 }
00242
00243
00244
00245
00246
00247
00248
00249 void UpdateSplitStatistics( const int* Dvars, const double* Cvars,
00250 int classLabel, int datasetNo)
00251 {
00252 count++;
00253 examplesSeen[datasetNo]++;
00254 if (classLabel==0)
00255 {
00256 countC0++;
00257 examples0Seen[datasetNo]++;
00258 }
00259
00260
00261 for (int i=0; i<dsplitDim; i++)
00262 {
00263
00264 if (discreteTransformer.HasAttributeShifts(i))
00265 (discreteStatistics[i])[0].UpdateStatistics(Dvars[i],classLabel);
00266 else
00267 (discreteStatistics[i])[datasetNo].UpdateStatistics(Dvars[i],classLabel);
00268 }
00269
00270 for (int i=0; i<csplitDim; i++)
00271 {
00272
00273 if (continuousTransformer.HasAttributeShifts(i))
00274 (continuousStatistics[i])[0].UpdateStatistics(Cvars[i],classLabel);
00275 else
00276 (continuousStatistics[i])[datasetNo].UpdateStatistics(Cvars[i],classLabel);
00277 }
00278 }
00279
00280 bool labeledMeansSignificant(int attribute, int dataSetIndex)
00281 {
00282 return continuousStatistics[attribute][dataSetIndex].labeledMeansSignificant();
00283 }
00284
00285 bool negMeanLessThanPos(int attribute, int dataSetIndex)
00286 {
00287 return continuousStatistics[attribute][dataSetIndex].negMeanLessThanPos();
00288 }
00289
00290 bool hasContinuousData(int attribute, int dataSetIndex)
00291 {
00292 return continuousStatistics[attribute][dataSetIndex].hasData();
00293 }
00294
00295 void DeleteTemporaryStatistics(void)
00296 {
00297
00298
00299 }
00300
00301
00302
00303
00304
00305
00306
00307 bool ComputeSplitVariable(list<int>& attList)
00308 {
00309 cerr << "nodeID " << id<< endl;
00310 cout << "Counts:" << countC0 << "," << count-countC0 << endl;
00311 if (countC0<2 || count-countC0<2)
00312 {
00313 #ifdef DEBUG_PRINT
00314 cout << "Making the node a leaf. Counts: " << countC0 << "," << count-countC0 << endl;
00315 #endif
00316
00317 return false;
00318 }
00319
00320
00321 double maxgini=0.0;
00322
00323
00324
00325 for (int i=0; i<dsplitDim; i++)
00326 {
00327 double curr_gini=0.0;
00328
00329
00330
00331
00332
00333
00334
00335 if (discreteTransformer.HasAttributeShifts(i))
00336 {
00337 curr_gini=discreteStatistics[i][0].ComputeGiniGain();
00338 }
00339 else
00340 {
00341 double total=0;
00342 double weight=0;
00343
00344 for (int j=0; j<noDatasets; j++)
00345 {
00346 weight = discreteStatistics[i][j].getCount();
00347 curr_gini+=weight*discreteStatistics[i][j].ComputeGiniGain();
00348 total += weight;
00349 }
00350
00351 curr_gini/=total;
00352 }
00353 if (curr_gini>maxgini)
00354 {
00355 maxgini=curr_gini;
00356 SplitVariable=i;
00357 }
00358 }
00359
00360
00361 for (int i=0; i<csplitDim; i++)
00362 {
00363 double curr_gini=0.0;
00364
00365
00366
00367
00368
00369
00370
00371
00372 if (continuousTransformer.HasAttributeShifts(i))
00373 {
00374 curr_gini=continuousStatistics[i][0].ComputeGiniGain();
00375 }
00376 else
00377 {
00378 double weight =0;
00379 double total=0;
00380
00381 for (int j=0; j<noDatasets; j++)
00382 {
00383
00384 NormalStatistics stat = continuousStatistics[i][j];
00385 if (stat.nonZero())
00386 {
00387 weight = continuousStatistics[i][j].getcountC0() + continuousStatistics[i][j].getcountC1();
00388 curr_gini+=weight*continuousStatistics[i][j].ComputeGiniGain();
00389 total+=weight;
00390 }
00391 }
00392
00393 curr_gini/=total;
00394 }
00395 if (curr_gini>maxgini)
00396 {
00397 maxgini=curr_gini;
00398 SplitVariable=-(i+1);
00399 }
00400 }
00401
00402
00403
00404
00405
00406 if (maxgini==0.0)
00407 return false;
00408
00409
00410 cout << "Chosen split variable" << SplitVariable << " maxgini=" << maxgini << endl;
00411 if ( (SplitVariable>=0 && !discreteTransformer.HasAttributeShifts(SplitVariable)) ||
00412 (SplitVariable<0 && !continuousTransformer.HasAttributeShifts(-SplitVariable-1)) )
00413 {
00414
00415 attList.push_front(SplitVariable);
00416 }
00417
00418 return true;
00419 }
00420
00421 Permutation ComputeDiscreteShift(bool label, int attribute, int datasetIndex)
00422 {
00423 if (label)
00424 return discreteStatistics[attribute][0].ComputeShift(labeled, discreteStatistics[attribute][datasetIndex]);
00425 else
00426 return discreteStatistics[attribute][0].ComputeShift(unlabeled, discreteStatistics[attribute][datasetIndex]);
00427 }
00428
00429 void AddDiscreteShiftStatistics(int SplitAttribute, double weight,
00430 Vector< BinomialStatistics >& statistics)
00431 {
00432 if (weight==0.0)
00433 return;
00434
00435 for (int i=0; i<statistics.dim(); i++)
00436 statistics[i].AddWeightedStatistics(discreteStatistics[SplitAttribute][i], weight);
00437 }
00438
00439
00440
00441
00442
00443
00444
00445
00446
00447
00448
00449
00450
00451 void ComputeCenter(void)
00452 {
00453
00454
00455 if (SplitVariable>=0)
00456 {
00457
00458 int noDatasets=discreteStatistics[SplitVariable].dim();
00459 if (noDatasets>1)
00460 {
00461
00462 for (int i=1; i<noDatasets; i++)
00463 {
00464
00465 discreteStatistics[SplitVariable][0].AddStatisticsShifted(discreteStatistics[SplitVariable][i],
00466 discreteTransformer.GetShift(SplitVariable, i));
00467 }
00468 }
00469 discreteStatistics[SplitVariable][0].ComputeGiniGain();
00470 SeparatingSet=discreteStatistics[SplitVariable][0].GetSplit();
00471
00472
00473 }
00474 else
00475 {
00476 int splitVar=-SplitVariable-1;
00477 if (continuousStatistics[splitVar].dim()>1)
00478 {
00479
00480 for (int i=1; i<noDatasets; i++)
00481 {
00482 continuousStatistics[splitVar][0].AddStatisticsShifted(continuousStatistics[splitVar][i],
00483 continuousTransformer.GetShift(splitVar, i));
00484 }
00485 }
00486 continuousStatistics[splitVar][0].ComputeGiniGain();
00487 splitPoint=continuousStatistics[splitVar][0].GetCenter();
00488
00489 }
00490 }
00491
00492 double getLabeledCenter(int attribute, int dataSetIndex, bool b)
00493 {
00494 return continuousStatistics[attribute][dataSetIndex].getLabeledCenter(b);
00495 }
00496
00497 double getLabeledCenterVariance(int attribute, int dataSetIndex, bool b)
00498 {
00499 return continuousStatistics[attribute][dataSetIndex].getLCVariance(b);
00500 }
00501
00502 double getVariance(int attribute, int dataSetIndex)
00503 {
00504 return continuousStatistics[attribute][dataSetIndex].getVariance();
00505 }
00506
00507
00508
00509
00510
00511 double getSplitVariance(int attribute, int dataSetIndex)
00512 {
00513 if (noWeighting)
00514 return 1;
00515 return continuousStatistics[attribute][dataSetIndex].getSplitVariance();
00516 }
00517
00518
00519
00520
00521
00522 double getTentativeSplitPoint(int attribute, int dataSetIndex)
00523 {
00524 return continuousStatistics[attribute][dataSetIndex].GetSplit();
00525
00526 }
00527
00528
00529
00530
00531
00532 double getTentativeCenter(int attribute, int dataSetIndex)
00533 {
00534 return continuousStatistics[attribute][dataSetIndex].GetCenter();
00535
00536 }
00537
00538 void ComputeSplitPoint(void)
00539 {
00540
00541 if (SplitVariable>=0)
00542 {
00543
00544 int noDatasets=discreteStatistics[SplitVariable].dim();
00545 if (noDatasets>1)
00546 {
00547
00548 for (int i=1; i<noDatasets; i++)
00549 {
00550
00551 discreteStatistics[SplitVariable][0].AddStatisticsShifted(discreteStatistics[SplitVariable][i],
00552 discreteTransformer.GetShift(SplitVariable, i));
00553 }
00554 }
00555 discreteStatistics[SplitVariable][0].ComputeGiniGain();
00556 SeparatingSet=discreteStatistics[SplitVariable][0].GetSplit();
00557 }
00558 else
00559 {
00560 int splitVar=-SplitVariable-1;
00561 if (continuousStatistics[splitVar].dim()>1)
00562 {
00563
00564 for (int i=1; i<noDatasets; i++)
00565 {
00566 continuousStatistics[splitVar][0].AddStatisticsShifted(continuousStatistics[splitVar][i],
00567 continuousTransformer.GetShift(splitVar, i));
00568 }
00569 }
00570 continuousStatistics[splitVar][0].ComputeGiniGain();
00571 splitPoint=continuousStatistics[splitVar][0].GetSplit();
00572
00573 }
00574 }
00575
00576 int ComputeClassLabel(void)
00577 {
00578 if (countC0>=count-countC0)
00579 return 0;
00580 else
00581 return 1;
00582 }
00583
00584 double getContinuousShift(int attribute, int dataSetIndex)
00585 {
00586 return continuousTransformer.GetShift(attribute, dataSetIndex);
00587 }
00588
00589 bool MoreSplits(int min_no_datapoints, int nodeID)
00590 {
00591 bool discreteSplitGoneBad = (SplitVariable>=0 && SeparatingSet.size()==0);
00592 return ( (count>=min_no_datapoints) && countC0!=0 && countC0!=count && !discreteSplitGoneBad);
00593 }
00594
00595 void SaveToStream(ostream& out, bool isLeaf)
00596 {
00597
00598 out << " label " << ComputeClassLabel() << endl;
00599 for (int j=0; j < noDatasets; j++)
00600 {
00601
00602
00603 }
00604 out << "}" << endl << endl;
00605
00606 if (!isLeaf)
00607 {
00608 out << " split attribute: " << SplitVariable;
00609 if (SplitVariable<0)
00610 out << ", split point: " << splitPoint << endl << "}" << endl << endl;
00611 else
00612 {
00613 out << ", split set: ";
00614 for (int i=0; i<SeparatingSet.dim(); i++)
00615 out << SeparatingSet[i] << " ";
00616 out << endl << "}" << endl << endl;
00617 }
00618
00619 }
00620 }
00621
00622 };
00623 }
00624
#endif // _CLUS_BINARYMULTICLASSICICATIONSPLITTER_H