#if !defined _CLUS_MULTIDECISIONTREENODE_H
#define _CLUS_MULTIDECISIONTREENODE_H

#include <iostream>
#include <list>
#include <cmath>
#include <cassert>

#include "vec.h"
#include "continuouslineartransformation.h"
#include "discretepermutationtransformation.h"
#include "statisticsgatherers.h"

using namespace TNT;
using namespace std;

namespace CLUS
{

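/* A node of a decision tree built jointly over several related datasets.
   A node starts in the "split" state, where its T_Splitter accumulates
   statistics; once a split point is fixed it becomes "stable" and either
   routes samples to its two children or acts as a leaf. A rough sketch of
   the driver loop implied by this interface (the actual driver lives
   outside this header, so the exact protocol is an assumption):

       root->StartLearningEpoch();
       // feed every (Dvars, Cvars, label, datasetNo) to root->LearnSample(...)
       bool more = root->StopLearningEpoch(splitType, min_no_datapoints);
       // repeat the three steps above while more == true, then prune:
       root->InitializePruningStatistics();
       // feed the pruning data through root->UpdatePruningStatistics(...)
       root->PruneSubtree();
*/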
template< class T_Splitter >
class MultiDecisionTreeNode
{
    // Unique identifier of this node; its children get 2*nodeId and 2*nodeId+1.
    int nodeId;

    // A node is either still gathering split statistics ("split") or has a
    // fixed split point / acts as a leaf ("stable").
    enum state { stable, split } State;

    // Left and right children (both 0 for a leaf) and the parent node.
    MultiDecisionTreeNode< T_Splitter > *Children[2], *parent;

    // Class label predicted by this node when it acts as a leaf.
    int classLabel;

    // Split-statistics gatherer; also decides which branch a sample takes.
    T_Splitter Splitter;

    // Per-dataset misclassification counts collected on the pruning data.
    Vector<int> pruningMistakes;

    // Per-dataset sample counts collected on the pruning data.
    Vector<int> pruningSamples;

    // Cost of this node acting as a leaf during pruning: the total number of
    // mistakes it makes on the pruning data, summed over all datasets.
    // (pruningSamples is accumulated as well but not used by this cost.)
    double ComputePruningCost(void)
    {
        int totalMistakes=0;
        int totalSamples=0;
        for (int i=0; i<pruningMistakes.dim(); i++)
        {
            totalMistakes+=pruningMistakes[i];
            totalSamples+=pruningSamples[i];
        }
        if (totalMistakes==0)
            return 0.0;
        return ((double)totalMistakes);
    }

    // A node still gathering split statistics contributes weight 1.0;
    // a stable node contributes nothing.
    double ComputeNodeWeight(void)
    {
        if (State==stable)
            return 0.0;
        else
            return 1.0;
    }

public:
    // Builds a node in the "split" state; children are created later by
    // StopLearningEpoch() once a split point has been chosen.
    MultiDecisionTreeNode( MultiDecisionTreeNode< T_Splitter >* Parent,
                           int NodeId, const Vector<int> & DDomainSize,
                           int CsplitDim, int NoDatasets,
                           DiscretePermutationTransformation& discreteTransformer,
                           ContinuousLinearTransformation& continuousTransformer):
        nodeId(NodeId), State(split), parent(Parent),
        Splitter(DDomainSize, CsplitDim, NoDatasets,
                 discreteTransformer, continuousTransformer),
        pruningMistakes(NoDatasets), pruningSamples(NoDatasets)
    {
        Children[0]=Children[1]=0;
        Splitter.setNodeID(NodeId);
        cout << "Created node with ID=" << NodeId << endl;
    }

    ~MultiDecisionTreeNode(void)
    {
        if (Children[0]!=0)
            delete Children[0];
        Children[0]=0;

        if (Children[1]!=0)
            delete Children[1];
        Children[1]=0;
    }

    // Prepare this subtree for one pass over the training data.
    void StartLearningEpoch(void)
    {
        switch (State)
        {
        case stable:
            if (Children[0]!=0)
            {
                Children[0]->StartLearningEpoch();
                Children[1]->StartLearningEpoch();
            }
            break;
        case split:
            Splitter.InitializeSplitStatistics();
            break;
        }
    }

    // Route one training sample down the tree; a node still in the split
    // state feeds it to its splitter instead.
    void LearnSample(const int* Dvars, const double* Cvars, int classlabel, int datasetNo)
    {
        switch (State)
        {
        case stable:
            if (Children[0]!=0)
            {
                Children[Splitter.ChooseBranch(Dvars,Cvars)]->LearnSample(Dvars,Cvars,classlabel,datasetNo);
            }
            break;
        case split:
            Splitter.UpdateSplitStatistics(Dvars, Cvars, classlabel, datasetNo);
            break;
        }
    }

    // Select split attributes for every node still in the split state.
    // Returns false when the splitter received no data at all, in which
    // case the parent deletes this node (and its sibling).
    bool FindSplitAttributes(list<int>& attList)
    {
        switch (State)
        {
        case stable:
            if (Children[0]!=0)
            {
                if (!Children[0]->FindSplitAttributes(attList) ||
                    !Children[1]->FindSplitAttributes(attList) )
                {
                    delete Children[0];
                    Children[0]=0;
                    delete Children[1];
                    Children[1]=0;
                }
            }
            break;

        case split:
            if (Splitter.GotNoData())
            {
                return false;
            }

            if (!Splitter.ComputeSplitVariable(attList))
            {
                // No usable split variable: this node becomes a leaf.
                State=stable;
                Children[0]=Children[1]=0;
                classLabel=Splitter.ComputeClassLabel();
            }
            break;
        }

        return true;
    }

    // Number of nodes in this subtree that are still in the split state.
    double ComputeTotalWeight(void)
    {
        double totalWeight=ComputeNodeWeight();
        if (State==stable)
            if (Children[0]!=0)
                totalWeight+=Children[0]->ComputeTotalWeight()+
                             Children[1]->ComputeTotalWeight();
        return totalWeight;
    }

    // Accumulates 1/variance of the split statistics for `attribute` over
    // this subtree (dataset dataSetIndex); returns 0 when the attribute has
    // no continuous data at this node.
    double computeSumOfVarianceInverted(int attribute, int dataSetIndex)
    {
        if (Splitter.hasContinuousData(attribute, dataSetIndex))
        {
            double sum = (double)1/Splitter.getSplitVariance(attribute, dataSetIndex);
            if (State==stable)
                if (Children[0]!=0)
                {
                    sum+= Children[0]->computeSumOfVarianceInverted(attribute, dataSetIndex) +
                          Children[1]->computeSumOfVarianceInverted(attribute, dataSetIndex);
                }
            return sum;
        }
        else
            return 0.0;
    }

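    /* The combine*() methods below aggregate per-node estimates by
       inverse-variance weighting: each node adds estimate/variance to the
       returned total and 1/variance to *sumInvVars, so the caller can form
       the combined estimate as total / sumInvVars, i.e.
           x_combined = (sum_i x_i / var_i) / (sum_i 1 / var_i).
       (The caller is not part of this header, so that final division is an
       assumption based on how the accumulators are built here.) */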
    double combineSplits(int attribute, int dataSetIndex, double* sumInvVars)
    {
        double total=0;
        if (Splitter.hasContinuousData(attribute, dataSetIndex))
        {
            double var = Splitter.getSplitVariance(attribute, dataSetIndex);
            assert(!isnan(var) && var!=0);

            double splitpt = Splitter.getTentativeSplitPoint(attribute, dataSetIndex);
            assert(!isnan(splitpt));
            total = splitpt/var;
            *sumInvVars+=1.0/var;

            if (State==stable)
                if (Children[0]!=0)
                {
                    total+= Children[0]->combineSplits(attribute, dataSetIndex, sumInvVars) +
                            Children[1]->combineSplits(attribute, dataSetIndex, sumInvVars);
                }
        }

        return total;
    }

    double combineLabeledCenters(int attribute, int dataSetIndex, double* sumInvVars)
    {
        double total = 0.0;
        double posWeight = Splitter.getLabeledCount(attribute, dataSetIndex, true);
        double negWeight = Splitter.getLabeledCount(attribute, dataSetIndex, false);
        double posMeani = Splitter.getLabeledCenter(attribute, dataSetIndex, true);
        double negMeani = Splitter.getLabeledCenter(attribute, dataSetIndex, false);
        double posMean0 = Splitter.getLabeledCenter(attribute, 0, true);
        double negMean0 = Splitter.getLabeledCenter(attribute, 0, false);
        double posVar = Splitter.getLabeledCenterVariance(attribute, dataSetIndex, true);
        double negVar = Splitter.getLabeledCenterVariance(attribute, dataSetIndex, false);
        // Weighted averages of the per-class variances and of the shifts of
        // the class centers relative to dataset 0.
        double combinedVar = (posWeight*posVar + negWeight*negVar)/(posWeight+negWeight);
        double currShift = (posWeight*(posMean0-posMeani) + negWeight*(negMean0-negMeani))/(posWeight+negWeight);
        if (!isnan(combinedVar) && NonZero(combinedVar) && !isnan(currShift))
        {
            total += currShift/combinedVar;
            *sumInvVars+=1.0/combinedVar;
        }
        if (State==stable)
            if (Children[0]!=0)
            {
                total+= Children[0]->combineLabeledCenters(attribute, dataSetIndex, sumInvVars) +
                        Children[1]->combineLabeledCenters(attribute, dataSetIndex, sumInvVars);
            }

        return total;
    }

    bool labeledMeansSignificant(int attribute, int dataSetIndex)
    {
        return Splitter.labeledMeansSignificant(attribute, dataSetIndex);
    }

    bool negMeanLessThanPos(int attribute, int dataSetIndex)
    {
        return Splitter.negMeanLessThanPos(attribute, dataSetIndex);
    }

    double combineCenters(int attribute, int dataSetIndex, double* sumInvVars)
    {
        double total=0.0;
        if (Splitter.hasContinuousData(attribute, dataSetIndex))
        {
            double var = Splitter.getVariance(attribute, dataSetIndex);

            double splitpt = Splitter.getTentativeCenter(attribute, dataSetIndex);

            if (!isnan(var) && NonZero(var) && !isnan(splitpt))
            {
                total = splitpt/var;
                *sumInvVars+=1.0/var;
            }

            if (State==stable)
                if (Children[0]!=0)
                {
                    total+= Children[0]->combineCenters(attribute, dataSetIndex, sumInvVars) +
                            Children[1]->combineCenters(attribute, dataSetIndex, sumInvVars);
                }
        }

        return total;
    }

    void AddDiscreteShiftStatistics(int SplitAttribute, Vector< BinomialStatistics >& statistics)
    {
        Splitter.AddDiscreteShiftStatistics(SplitAttribute, ComputeNodeWeight(), statistics);
        if (State==stable)
            if (Children[0]!=0)
            {
                Children[0]->AddDiscreteShiftStatistics(SplitAttribute,statistics);
                Children[1]->AddDiscreteShiftStatistics(SplitAttribute,statistics);
            }
    }

    void AddContinuousShiftStatistics(int SplitAttribute, Vector< NormalStatistics >& statistics)
    {
        Splitter.AddContinuousShiftStatistics(SplitAttribute, ComputeNodeWeight(), statistics);
        if (State==stable)
            if (Children[0]!=0)
            {
                Children[0]->AddContinuousShiftStatistics(SplitAttribute,statistics);
                Children[1]->AddContinuousShiftStatistics(SplitAttribute,statistics);
            }
    }

    // Finish the current pass: a split node fixes its split point, computes
    // its class label and, if more data is available, creates two children.
    // Returns true if any node in this subtree was still splitting during
    // this epoch, i.e. another pass over the data may be needed.
    bool StopLearningEpoch(int splitType, int min_no_datapoints)
    {
        switch (State)
        {
        case stable:
            if (Children[0]!=0)
                // Bitwise | is deliberate: both children must finish their
                // epoch, so no short-circuit evaluation.
                return Children[0]->StopLearningEpoch(splitType, min_no_datapoints)
                       | Children[1]->StopLearningEpoch(splitType, min_no_datapoints);
            else
                return false;

        case split:
            State=stable;
            Splitter.ComputeSplitPoint();
            classLabel=Splitter.ComputeClassLabel();

            if (Splitter.MoreSplits(min_no_datapoints, nodeId))
            {
                Children[0] = new MultiDecisionTreeNode< T_Splitter >
                    ( this, nodeId*2, Splitter.GetDDomainSize(), Splitter.GetCSplitDim(),
                      Splitter.GetNoDatasets(),
                      Splitter.GetDiscreteTransformer(), Splitter.GetContinuousTransformer());

                Children[1] = new MultiDecisionTreeNode< T_Splitter >
                    ( this, nodeId*2+1, Splitter.GetDDomainSize(), Splitter.GetCSplitDim(),
                      Splitter.GetNoDatasets(),
                      Splitter.GetDiscreteTransformer(), Splitter.GetContinuousTransformer());
            }
            else
            {
                Children[0]=Children[1]=0;
            }

            return true;

        default:
            return true;
        }
    }

    Permutation ComputeDiscreteShift(bool label, int attribute, int datasetIndex)
    {
        return Splitter.ComputeDiscreteShift(label, attribute, datasetIndex);
    }

    // Predict the class label for a sample by walking down the tree.
    double Infer(const int* Dvars, const double* Cvars)
    {
        if (Children[0]==0)
        {
            // Leaf node: return its class label.
            return classLabel;
        }
        else
        {
            return Children[Splitter.ChooseBranch(Dvars,Cvars)]->Infer(Dvars,Cvars);
        }
    }

    void InitializePruningStatistics(void)
    {
        pruningMistakes=0;
        pruningSamples=0;
        if (Children[0]!=0 && Children[1]!=0)
        {
            Children[0]->InitializePruningStatistics();
            Children[1]->InitializePruningStatistics();
        }
    }

    // Record how this node would perform as a leaf on one pruning sample,
    // then pass the sample on to the appropriate child.
    void UpdatePruningStatistics(const int* Dvars, const double* Cvars, int classlabel, int datasetNo)
    {
        pruningSamples[datasetNo]++;
        if (classLabel!=classlabel)
            pruningMistakes[datasetNo]++;

        if (Children[0]!=0 && Children[1]!=0)
            Children[Splitter.ChooseBranch(Dvars,Cvars)]->UpdatePruningStatistics(Dvars,Cvars,classlabel, datasetNo);
    }

    void FinalizePruningStatistics(void)
    {
        // Nothing to do: the per-dataset counts are consumed directly by
        // PruneSubtree() via ComputePruningCost().
    }

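    /* Bottom-up reduced-error pruning: a subtree is replaced by a leaf when
       the leaf would make strictly fewer mistakes on the pruning data than
       the subtree does. For example, if this node mislabels 3 pruning
       samples while its two children together mislabel 5, the children are
       deleted and the cost 3 is reported upward. */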
    double PruneSubtree(void)
    {
        double pruningCost=ComputePruningCost();
        if (Children[0] == 0 && Children[1] == 0)
        {
            return pruningCost;
        }
        else
        {
            double pruningCostChildren=(Children[0]->PruneSubtree()+Children[1]->PruneSubtree());

            if (pruningCost<pruningCostChildren)
            {
                delete Children[0];
                Children[0]=0;
                delete Children[1];
                Children[1]=0;

                return pruningCost;
            }
            else
            {
                return pruningCostChildren;
            }
        }
    }

    void SaveToStream(ostream& out)
    {
        out << "{" << endl << " nodeID " << nodeId << endl;

        out << endl;

        if (Children[0]!=0 && Children[1]!=0)
        {
            Splitter.SaveToStream(out, false);
            Children[0]->SaveToStream(out);
            Children[1]->SaveToStream(out);
        }
        else
        {
            out << " leaf " << endl;
            Splitter.SaveToStream(out, true);
        }
    }
};

} // namespace CLUS

#endif // _CLUS_MULTIDECISIONTREENODE_H