00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034 #if !defined _CLUS_MULTIDECISIONTREE_H_
00035 #define _CLUS_MULTIDECISIONTREE_H_
00036
00037 #include "machine.h"
00038 #include "multidecisiontreenode.h"
00039 #include "dctraingen.h"
00040 #include "continuouslineartransformation.h"
00041 #include "discretepermutationtransformation.h"
00042 #include "statisticsgatherers.h"
00043 #include "general.h"
00044
00045 #include <iostream>
00046
00047
00048
00049 namespace CLUS
00050 {
00051
00052 template< class T_Splitter >
00053 class MultiDecisionTree : public Machine
00054 {
00055 protected:
00056 MultiDecisionTreeNode< T_Splitter >* root;
00057
00058
00059 const Vector<int>& dDomainSize;
00060
00061
00062 int dsplitDim;
00063
00064
00065 int csplitDim;
00066
00067
00068 int noDatasets;
00069
00070
00071 int min_no_datapoints;
00072
00073
00074 int splitType;
00075
00076 bool schemaMatch;
00077 bool labeled;
00078 bool alignSplits;
00079
00080 ContinuousLinearTransformation* continuousTransformer;
00081 DiscretePermutationTransformation* discreteTransformer;
00082 public:
00083 MultiDecisionTree(const Vector<int>& DDomainSize, int CsplitDim, int NoDatasets):
00084 Machine(CsplitDim,1),dDomainSize(DDomainSize),
00085 dsplitDim(DDomainSize.dim()-1),csplitDim(CsplitDim),noDatasets(NoDatasets)
00086 {
00087 ofstream file("x");
00088 continuousTransformer=0;
00089 discreteTransformer=0;
00090 min_no_datapoints = 10;
00091 splitType = 0;
00092 }
00093
00094 ~MultiDecisionTree(void)
00095 {
00096 if (root!=0)
00097 delete root;
00098
00099 if (continuousTransformer!=0)
00100 delete continuousTransformer;
00101
00102 if (discreteTransformer!=0)
00103 delete discreteTransformer;
00104 }
00105
00106 virtual int InDim(void)
00107 {
00108 return dsplitDim+csplitDim+1;
00109 }
00110
00111 virtual string TypeName(void)
00112 {
00113 return string("MultiDecisionTree");
00114 }
00115
00116 virtual void Infer(void)
00117 {
00118 if (root==0)
00119 return;
00120
00121 int Dvars[MAX_VARIABLES];
00122 for(int i=0; i<dsplitDim; i++)
00123 {
00124 Dvars[i]=(int)(*input)[i];
00125
00126 }
00127
00128 double Cvars[MAX_VARIABLES];
00129 for (int i=0; i<csplitDim; i++)
00130 {
00131 Cvars[i]=(*input)[i+dsplitDim];
00132
00133 }
00134
00135
00136 int datasetIndex=(int)(*input)[dsplitDim+csplitDim];
00137
00138 assert (datasetIndex<noDatasets);
00139
00140
00141 discreteTransformer->ApplyShiftToTuple(Dvars, datasetIndex);
00142 continuousTransformer->ApplyShiftToTuple(Cvars, datasetIndex);
00143
00144
00145
00146
00147
00148
00149
00150
00151
00152 output[0]=root->Infer(Dvars, Cvars);
00153
00154 }
00155
00156 void printBitFlipData()
00157 {
00158 ofstream afile("hello");
00159 afile << "HELLO" << endl;
00160 bool significant[csplitDim];
00161
00162
00163
00164
00165
00166
00167
00168
00169 ofstream file("bitFlipStuff");
00170 bool bits[csplitDim];
00171 file << noDatasets;
00172 for (int j=0; j< csplitDim; j++)
00173 {
00174 bits[j]=root->negMeanLessThanPos(j, 0);
00175 }
00176 for (int i=1; i<noDatasets; i++)
00177 {
00178 file << "in bit flip loop" << endl;
00179
00180 file << "Dataset " << i << ":" << endl;
00181 for (int j=0; j< csplitDim; j++)
00182 {
00183 if (root->labeledMeansSignificant(j,0))
00184 {
00185 file << (root->negMeanLessThanPos(j, i)!= bits[j]) << endl;
00186 }
00187 }
00188 }
00189 }
00190
00191 virtual void Identify()
00192 {
00193
00194
00195 Matrix<double> ctrainData = training->GetTrainingData();
00196
00197 continuousTransformer=new ContinuousLinearTransformation(csplitDim, noDatasets,
00198 ctrainData) ;
00199 ofstream file("xyz");
00200 file << noDatasets << endl;
00201
00202
00203 Matrix<int> dtrainData = dynamic_cast< DCTrainingData* >( training )
00204 -> GetDiscreteTrainingData();
00205
00206 discreteTransformer=new DiscretePermutationTransformation(dsplitDim, noDatasets,
00207 ctrainData, dtrainData,
00208 dDomainSize);
00209
00210 int M=ctrainData.num_rows();
00211
00212 root = new MultiDecisionTreeNode<T_Splitter>
00213 (NULL, 1, dDomainSize, csplitDim, noDatasets,
00214 *discreteTransformer, *continuousTransformer);
00215
00216
00217 file << "here I am";
00218 file << noDatasets << endl;
00219
00220 if (schemaMatch)
00221 {
00222
00223
00224
00225 root->StartLearningEpoch();
00226 for(int i=0; i<M; i++)
00227 {
00228 int datasetIndex=(int)ctrainData[i][csplitDim];
00229
00230 ExitIf(datasetIndex>=noDatasets, "More datasets enountered than declared");
00231 int classLabel=dtrainData[i][dsplitDim];
00232 root->LearnSample(dtrainData[i],ctrainData[i], classLabel, datasetIndex);
00233 }
00234 file << noDatasets;
00235
00236 for (int att =0; att<dsplitDim; att++)
00237 {
00238
00239
00240 Vector<Permutation> shifts(noDatasets);
00241 shifts[0]=Permutation(dDomainSize[att]);
00242 for (int i=1; i<noDatasets; i++)
00243 shifts[i]= root->ComputeDiscreteShift(true, att, i);
00244 discreteTransformer->SetShiftsAttribute(att, shifts);
00245
00246 }
00247
00248 for (int att =0; att<csplitDim; att++)
00249 {
00250
00251 double sumInvVars=0.0;
00252 double split0;
00253 if (labeled)
00254 {
00255 split0 = 0.0;
00256
00257 }
00258 else
00259 {
00260 if (!alignSplits)
00261 split0= root->combineCenters(att, 0, &sumInvVars)/sumInvVars;
00262 else
00263 split0 == root->combineSplits(att, 0, &sumInvVars)/sumInvVars;
00264 }
00265
00266 if (!NonZero(sumInvVars))
00267 split0 =0.0;
00268
00269
00270
00271 Vector<double> shifts(noDatasets);
00272 shifts[0]=0.0;
00273 for (int i=1; i<noDatasets; i++)
00274 {
00275 double sumInvVars=0.0;
00276 double spliti;
00277 if (labeled)
00278 {
00279
00280
00281 spliti = root->combineLabeledCenters(att, i, &sumInvVars)/sumInvVars;
00282 if (!NonZero(sumInvVars))
00283 spliti=split0;
00284 }
00285 else
00286 spliti = root->combineCenters(att, i, &sumInvVars)/sumInvVars;
00287
00288
00289 if (!NonZero(sumInvVars))
00290 spliti=split0;
00291
00292
00293
00294 shifts[i]= -(split0 - spliti);
00295
00296 }
00297 continuousTransformer->SetShiftsAttribute(att, shifts);
00298 }
00299
00300
00301 }
00302
00303
00304
00305
00306
00307
00308
00309
00310 do
00311 {
00312 root->StartLearningEpoch();
00313
00314 for(int i=0; i<M; i++)
00315 {
00316 int datasetIndex=(int)ctrainData[i][csplitDim];
00317
00318 ExitIf(datasetIndex>=noDatasets, "More datasets enountered than declared");
00319
00320 int classLabel=dtrainData[i][dsplitDim];
00321 root->LearnSample(dtrainData[i],ctrainData[i], classLabel, datasetIndex);
00322 }
00323
00324 list<int> attList;
00325 root->FindSplitAttributes(attList);
00326
00327 attList.sort();
00328 attList.unique();
00329
00330
00331 list<int>::iterator itrt;
00332 for (itrt=attList.begin();
00333 itrt!=attList.end(); itrt++)
00334 {
00335 int SplitAttribute=*itrt;
00336 if (SplitAttribute>=0)
00337 {
00338
00339
00340
00341
00342
00343
00344
00345 Vector<Permutation> shifts(noDatasets);
00346 shifts[0]=Permutation(dDomainSize[SplitAttribute]);
00347 for (int i=1; i<noDatasets; i++)
00348 shifts[i]= root->ComputeDiscreteShift(true, SplitAttribute, i);
00349 discreteTransformer->SetShiftsAttribute(SplitAttribute, shifts);
00350
00351 }
00352 else
00353 {
00354
00355
00356
00357
00358
00359
00360
00361
00362
00363
00364
00365
00366 double sumInvVars=0.0;
00367 double split0;
00368 if (labeled)
00369 {
00370 split0 = root->combineSplits(-SplitAttribute-1, 0, &sumInvVars)/sumInvVars;
00371 split0=0.0;
00372 }
00373 else
00374 {
00375 if (!alignSplits)
00376 split0 = root->combineCenters(-SplitAttribute-1, 0, &sumInvVars)/sumInvVars;
00377 else
00378 split0 = root->combineSplits(-SplitAttribute-1, 0, &sumInvVars)/sumInvVars;
00379 }
00380
00381 if (sumInvVars==0.0)
00382 split0=0.0;
00383
00384
00385
00386 Vector<double> shifts(noDatasets);
00387 shifts[0]=0.0;
00388 for (int i=1; i<noDatasets; i++)
00389 {
00390 double sumInvVars=0.0;
00391 double spliti=0.0;
00392 if (labeled)
00393 {
00394
00395
00396 shifts[i] = root->combineLabeledCenters(-SplitAttribute-1, i, &sumInvVars);
00397 }
00398 else
00399 {
00400 if (!alignSplits)
00401 spliti= root->combineCenters(-SplitAttribute-1, i,&sumInvVars)/sumInvVars;
00402 else
00403 spliti = root->combineSplits(-SplitAttribute-1, i,&sumInvVars)/sumInvVars;
00404
00405 }
00406 if (sumInvVars==0.0)
00407 spliti=split0;
00408 else
00409 {
00410
00411
00412
00413
00414
00415 shifts[i]= -(split0 - spliti);
00416 }
00417 }
00418 continuousTransformer->SetShiftsAttribute(-SplitAttribute-1, shifts);
00419 }
00420 }
00421 }
00422 while (root->StopLearningEpoch(splitType, min_no_datapoints));
00423
00424 }
00425
00426 virtual void Prune(void)
00427 {
00428 const Matrix<double>& ctrainData = pruning->GetTrainingData();
00429 const Matrix<int>& dtrainData = dynamic_cast< DCTrainingData* >( pruning )
00430 -> GetDiscreteTrainingData();
00431
00432 int M=ctrainData.num_rows();
00433
00434 root->InitializePruningStatistics();
00435 for(int i=0; i<M; i++)
00436 {
00437 int datasetIndex=(int)ctrainData[i][csplitDim];
00438
00439 int Dvars[MAX_VARIABLES];
00440 for(int j=0; j<dsplitDim; j++)
00441 Dvars[j]=dtrainData[i][j];
00442
00443 double Cvars[MAX_VARIABLES];
00444 for (int j=0; j<csplitDim; j++)
00445 Cvars[j]=ctrainData[i][j];
00446
00447
00448 discreteTransformer->ApplyShiftToTuple(Dvars, datasetIndex);
00449 continuousTransformer->ApplyShiftToTuple(Cvars, datasetIndex);
00450 int classLabel=dtrainData[i][dsplitDim];
00451 root->UpdatePruningStatistics(Dvars, Cvars, classLabel, datasetIndex );
00452 }
00453 root->FinalizePruningStatistics();
00454
00455
00456 double cost=root->PruneSubtree();
00457 cout << "RMSN after pruning is:" << cost/M << endl;
00458 }
00459
00460 virtual int SetOption(char* name, char* val)
00461 {
00462
00463 if (strcmp(name, "SchemaMatch")==0)
00464 {
00465 if (strcmp(val, "true")==0)
00466 schemaMatch=true;
00467 else
00468 schemaMatch=false;
00469 }
00470 else if (strcmp(name, "labeled")==0)
00471 {
00472 if (strcmp(val, "true")==0)
00473 labeled=true;
00474 else
00475 labeled=false;
00476 }
00477 else if (strcmp(name, "alignSplits")==0)
00478 {
00479
00480 if (strcmp(val, "true")==0)
00481 alignSplits = true;
00482 else
00483 alignSplits=false;
00484
00485 }
00486 else
00487 return Machine::SetOption(name,val);
00488 return 1;
00489 }
00490
00491 virtual void SaveToStream(ostream& out)
00492 {
00493
00494 out << "Tree: with " << dsplitDim << " discrete and " << csplitDim << " continuous attributes" << endl;
00495 out << "Discrete attribute domain sizes: " << endl;
00496
00497 out << "[ ";
00498 for(int i=0; i<dsplitDim; i++)
00499 out << dDomainSize[i] << " ";
00500 out << "]" << endl;
00501
00502 out << "Shifts: " << endl;
00503
00504 for(int i=0; i<dsplitDim; i++)
00505 {
00506 for (int j=1; j<noDatasets; j++)
00507 {
00508 out << "Shift to dataset " << j << " for discrete attribute " << i << ": " << endl;
00509 discreteTransformer->saveToStream(out, i, j);
00510 }
00511 }
00512
00513 for(int i=0; i<csplitDim; i++)
00514 {
00515 for (int j=1; j < noDatasets; j++)
00516 {
00517 if (continuousTransformer->HasAttributeShifts(i))
00518 {
00519 out << "Shift to dataset " << j << " for continuous attribute " << i << ": "
00520 << continuousTransformer->getShift(i, j) << endl;
00521 }
00522 else
00523 {
00524 out << "Shift to dataset " << j << " for continuous attribute " << i << ": NOSHIFT" << endl;
00525 }
00526
00527 }
00528 }
00529 out << endl;
00530 root->SaveToStream(out);
00531 }
00532
00533 };
00534 }
00535
00536 #endif // _CLUS_MULTIDECISIONTREE_H_
00537