00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034 #if !defined _CLUS_REGRESSIONTREE_H_
00035 #define _CLUS_REGRESSIONTREE_H_
00036
#include "machine.h"
#include "regressiontreenode.h"
#include "dctraingen.h"

#include <climits>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>
00041
00042
00043
00044 namespace CLUS
00045 {
00046
00047 template< class T_Distribution, class T_Regressor, class T_Splitter >
00048 class BinaryRegressionTree : public Machine
00049 {
00050 protected:
00051 BinaryRegressionTreeNode< T_Distribution, T_Regressor, T_Splitter>* root;
00052
00053
00054 const Vector<int>& dDomainSize;
00055
00056
00057 int dsplitDim;
00058
00059
00060 int csplitDim;
00061
00062
00063 int regDim;
00064
00065
00066 int emMaxIterations;
00067
00068
00069 int emRestarts;
00070
00071
00072 int min_no_datapoints;
00073
00074
00075 int splitType;
00076
00077
00078 double threshold;
00079
00080 T_Distribution* rootDistribution;
00081 int inferMaxNodeId;
00082
00083 void PrintSizesTree(void)
00084 {
00085 int nodes=0;
00086 int term_nodes=0;
00087 root->ComputeSizesTree(nodes,term_nodes);
00088 cout << "Nuber of nodes=" << nodes << "\tNumber of terminal nodes=" << term_nodes << endl;
00089 }
00090
00091 public:
00092 BinaryRegressionTree(const Vector<int>& DDomainSize, int CsplitDim, int RegDim):
00093 Machine(CsplitDim+RegDim,1),dDomainSize(DDomainSize),
00094 dsplitDim(DDomainSize.dim()),csplitDim(CsplitDim), regDim(RegDim)
00095 {
00096 emRestarts = 3;
00097 emMaxIterations = 30;
00098 min_no_datapoints = 10;
00099 splitType = 0;
00100 rootDistribution = 0;
00101 inferMaxNodeId = INT_MAX;
00102 threshold=.01;
00103 }
00104
00105 virtual ~BinaryRegressionTree(void)
00106 {
00107 if (rootDistribution)
00108 delete rootDistribution;
00109 }
00110
00111 virtual int InDim(void)
00112 {
00113 return dsplitDim+csplitDim+regDim;
00114 }
00115
00116 virtual string TypeName(void)
00117 {
00118 return string("BinaryRegressionTree");
00119 }
00120
00121 virtual void Infer(void)
00122 {
00123 if (root==0)
00124 return;
00125
00126
00127 int Dvars[MAX_VARIABLES];
00128 for(int i=0; i<dsplitDim; i++)
00129 Dvars[i]=(int)(*input)[i];
00130
00131
00132 double scaledInput[MAX_VARIABLES];
00133 for (int i=0; i<csplitDim+regDim; i++)
00134 scaledInput[i]=scale[i].Transform( (*input)[i+dsplitDim] );
00135
00136 output[0]=scale[csplitDim+regDim].Transform( root->Infer(Dvars,scaledInput,inferMaxNodeId, threshold) );
00137 }
00138
00139 virtual void Identify(void)
00140 {
00141 const Matrix<double>& ctrainData = training->GetTrainingData();
00142 const Matrix<int>& dtrainData = dynamic_cast< DCTrainingData* >( training )
00143 -> GetDiscreteTrainingData();
00144
00145 int M=ctrainData.num_rows();
00146
00147
00148 if (!rootDistribution)
00149 rootDistribution = new T_Distribution(regDim);
00150
00151 rootDistribution->RandomDistribution(1);
00152 double Coef;
00153 for (int i=0; i<M; i++)
00154 {
00155 Coef = rootDistribution->LearnProbability(ctrainData[i]+csplitDim);
00156 rootDistribution->NormalizeLearnProbability(Coef);
00157 }
00158 rootDistribution->UpdateParameters();
00159
00160
00161
00162
00163
00164
00165
00166 T_Regressor* regressor = dynamic_cast<T_Regressor*>( rootDistribution->CreateRegressor() );
00167
00168 if (regressor==NULL)
00169 regressor=new T_Regressor();
00170
00171
00172 root = new BinaryRegressionTreeNode<T_Distribution, T_Regressor, T_Splitter>
00173 (1, dDomainSize, csplitDim, regDim, *regressor, rootDistribution);
00174
00175 rootDistribution = 0;
00176
00177 do
00178 {
00179 root->StartLearningEpoch();
00180 for(int i=0; i<M; i++)
00181 root->LearnSample(dtrainData[i],ctrainData[i], 1.0, threshold);
00182 }
00183 while (root->StopLearningEpoch(splitType, emRestarts, emMaxIterations,
00184 convergenceLim, min_no_datapoints));
00185 cout << "End Learning" << endl;
00186 PrintSizesTree();
00187 }
00188
00189 virtual void Prune(void)
00190 {
00191 const Matrix<double>& ctrainData = pruning->GetTrainingData();
00192 const Matrix<int>& dtrainData = dynamic_cast< DCTrainingData* >( pruning )
00193 -> GetDiscreteTrainingData();
00194
00195 int M=ctrainData.num_rows();
00196
00197 root->InitializePruningStatistics();
00198 for(int i=0; i<M; i++)
00199 {
00200
00201 double scaledInput[MAX_VARIABLES];
00202 for (int j=0; j<csplitDim+regDim; j++)
00203 scaledInput[j]=scale[j].Transform( ctrainData[i][j] );
00204
00205 double y=scale[csplitDim+regDim].InverseTransform( ctrainData[i][csplitDim+regDim] );
00206
00207 root->UpdatePruningStatistics(dtrainData[i], scaledInput, y, 1.0, threshold);
00208 }
00209 root->FinalizePruningStatistics();
00210
00211
00212 double cost=root->PruneSubtree();
00213 cout << "RMSN after pruning is:" << cost/M << endl;
00214 PrintSizesTree();
00215 }
00216
00217 virtual int SetOption(char* name, char* val)
00218 {
00219 if (strcmp(name,"EMMaxIterations")==0)
00220 emMaxIterations = atoi(val);
00221 else
00222 if (strcmp(name,"EMRestarts")==0)
00223 emRestarts = atoi(val);
00224 else
00225 if (strcmp(name,"InferMaxNodeId")==0)
00226 inferMaxNodeId = atoi(val);
00227 else
00228 if (strcmp(name,"MaxNoDatapoints")==0)
00229 min_no_datapoints = 2*atoi(val)-2;
00230 else
00231 if (strcmp(name,"SplitType")==0)
00232 splitType = atoi(val);
00233 else
00234 if (strcmp(name,"Threshold")==0)
00235 threshold = atof(val);
00236 else
00237 return Machine::SetOption(name,val);
00238 return 1;
00239 }
00240
00241 virtual void SaveToStream(ostream& out)
00242 {
00243 out << TypeName() << " ( " << dsplitDim << " " << csplitDim << " ";
00244 out << regDim << " ) { " << endl;
00245
00246 out << "[ ";
00247 for(int i=0; i<dsplitDim; i++)
00248 out << dDomainSize[i] << " ";
00249 out << "]" << endl;
00250
00251 for(int i=0; i<csplitDim+regDim+1; i++)
00252 scale[i].SaveToStream(out);
00253 out << endl;
00254
00255 root->SaveToStream(out);
00256
00257 out << '}' << endl;
00258 }
00259 };
00260
00261 }
00262
00263 #endif // _CLUS_REGRESSIONTREE_H_