00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034 #if !defined _CLUS_MULTICLASSCONTINUOUSDISTRIBUTION_H
00035 #define _CLUS_MULTICLASSCONTINUOUSDISTRIBUTION_H
00036
00037 #include "distribution.h"
00038 #include "multiclassdistribution.h"
00039 #include "statfct.h"
00040
00041 using namespace TNT;
00042
00043 namespace CLUS
00044 {
00045
00046
00047
00048
00049 template< class T_Distribution >
00050 class MulticlassContinuousDistribution : public MulticlassDistribution
00051 {
00052 protected:
00053 Vector<T_Distribution> distributions;
00054
00055 public:
00056
00057
00058
00059
00060
00061
00062 MulticlassContinuousDistribution(int NoClasses, T_Distribution& D):
00063 MulticlassDistribution(NoClasses), distributions(NoClasses,D)
00064 { }
00065
00066
00067
00068
00069 virtual void Infer(const double* cdata, const int* ddata, double* result)
00070 {
00071 double Coef=0.0;
00072 for (int i=0; i<noClasses; i++)
00073 {
00074 Coef+=distributions[i].Probability(cdata);
00075 }
00076
00077 for (int i=0; i<noClasses; i++)
00078 result[i]=distributions[i].NormalizeProbability(Coef);
00079
00080 }
00081
00082
00083
00084
00085 virtual void MultiplicativeInfer(const double* cdata, const int* ddata, double* result)
00086 {
00087 for (int i=0; i<noClasses; i++)
00088 result[i]*=distributions[i].Probability(cdata);
00089 }
00090
00091
00092 virtual void StartLearning(void)
00093 {
00094 for (int i=0; i<noClasses; i++)
00095 distributions[i].InitializeStatistics();
00096 }
00097
00098 virtual void LearnSample(const double* cdata, const int* ddata, int classLabel, double weightSample=1.0)
00099 {
00100 distributions[classLabel].UpdateStatistics(cdata, weightSample);
00101 }
00102
00103 virtual void LearnSample(const double* cdata, const int* ddata,
00104 double* classProbabilities, double weightSample=1.0)
00105 {
00106 for (int i=0; i<noClasses; i++)
00107 distributions[i].UpdateStatistics(cdata, weightSample*classProbabilities[i]);
00108 }
00109
00110
00111 virtual void StopLearning(void)
00112 {
00113 for (int i=0; i<noClasses; i++)
00114 distributions[i].UpdateParameters();
00115 }
00116
00117 virtual double PValueStatisticalTest(void);
00118
00119 #ifdef CLUS_USE_XML
00120
00121 virtual void PrintToXmlStream(ostream& output)
00122 {
00123 output << "<MulticlassContinuousDistribution";
00124 PrintAttribute(output, "typeDistribution", T_Distribution::TypeName());
00125 output << ">" << endl;
00126
00127 PrintVectorOfElements(output, distributions);
00128
00129 output << "</MulticlassContinuousDistribution>" << endl;
00130 }
00131 #endif
00132
00133 virtual bool IsClassLabelAbsent(int index)
00134 {
00135 return distributions[index].HasZeroWeight();
00136 }
00137 };
00138
00139 }
00140
00141 #endif // _CLUS_MULTICLASSCONTINUOUSDISTRIBUTION_H