AdaBoost.java
import java.util.ArrayList;
import java.util.List;
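
/**
 * AdaBoost ensemble built on top of this project's RandomForest learner.
 *
 * NO_TREES weak classifiers are trained (via the inherited learn method) on
 * per-tree training sets filled by load(...); their predictions on the full
 * training set and on the test set are cached. Boosting then assigns each
 * classifier a vote weight (wHat) while re-weighting the training instances
 * (alpha), and the final test predictions are a weighted majority vote.
 *
 * Relies on classes and helpers defined elsewhere in this project:
 * RandomForest, Split, Instance, load, loadTest, learn, classify, computeAccuracy.
 */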
public class AdaBoost extends RandomForest {

    static int NO_TREES = 30;                              // Number of weak classifiers (trees)
    static final int NO_INSTANCES = 32561;                 // Number of training instances in adult.txt
    static double[] alpha = new double[NO_INSTANCES];      // Weight of each training instance
    static double[] weightedError = new double[NO_TREES];  // Weighted error of each classifier
    static double[] wHat = new double[NO_TREES];           // Vote weight (wHat) of each classifier
    static ArrayList<List<Integer>> classifierPredictions = new ArrayList<List<Integer>>();            // Each classifier's predictions on the training data
    static ArrayList<List<Integer>> classifierPredictionsOnTestData = new ArrayList<List<Integer>>();  // Each classifier's predictions on the test data
    static int classifiersize = 0;

    AdaBoost() {
        // Start from uniform instance weights.
        for (int i = 0; i < NO_INSTANCES; i++)
            alpha[i] = 1.0 / NO_INSTANCES;
    }
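    /**
     * Computes the weighted training error of classifier d: the sum of the
     * instance weights alpha[i] over the instances it misclassifies
     * (Split.b holds the training labels), normalised by the total weight.
     */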
    public static void calcWeightedError(int d) {
        double classifierWeightedError = 0;
        double total_alpha = 0;
        for (int i = 0; i < NO_INSTANCES; i++) {
            if (Split.b[i] != classifierPredictions.get(d).get(i)) {
                classifierWeightedError += alpha[i];
            }
            total_alpha += alpha[i];
        }
        weightedError[d] = classifierWeightedError / total_alpha;
    }
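    /**
     * Picks the classifier with the smallest weighted error and assigns it the
     * AdaBoost vote weight wHat = 0.5 * ln((1 - error) / error). The chosen
     * classifier's stored error is then overwritten with Double.MAX_VALUE and
     * its index is returned.
     */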
    public static int calcWHat() {
        int minindex = 0;
        for (int i = 1; i < weightedError.length; i++) {
            if (weightedError[i] < weightedError[minindex])
                minindex = i;
        }
        wHat[minindex] = 0.5 * Math.log((1 - weightedError[minindex]) / weightedError[minindex]);
        weightedError[minindex] = Double.MAX_VALUE;
        return minindex;
    }
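    /**
     * AdaBoost re-weighting step: instances misclassified by the chosen classifier
     * are multiplied by exp(wHat), correctly classified ones by exp(-wHat), and the
     * weights are then normalised to sum to 1.
     */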
    public static void updateAlpha(int minindex) {
        double sum = 0;
        for (int i = 0; i < NO_INSTANCES; i++) {
            if (Split.b[i] != classifierPredictions.get(minindex).get(i))
                alpha[i] = alpha[i] * Math.exp(wHat[minindex]);
            else
                alpha[i] = alpha[i] * Math.exp(-1 * wHat[minindex]);
            sum += alpha[i];
        }
        for (int i = 0; i < NO_INSTANCES; i++)
            alpha[i] /= sum;
    }
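    /**
     * Combines the weak classifiers on the test data: each classifier votes with
     * weight wHat[i], a stored prediction of 0 counting as -1, and the sign of the
     * weighted sum becomes the ensemble prediction (-1, 0, or +1) for that instance.
     */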
    public static List<Integer> adaBoost() {
        List<Integer> predictions = new ArrayList<Integer>();
        for (int j = 0; j < classifierPredictionsOnTestData.get(0).size(); j++) {
            double signum = 0;
            for (int i = 0; i < NO_TREES; i++) {
                int ft = classifierPredictionsOnTestData.get(i).get(j);
                if (ft == 0)
                    signum += wHat[i] * -1;
                else
                    signum += wHat[i] * ft;
            }
            predictions.add((int) Math.signum(signum));
        }
        return predictions;
    }
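    /**
     * Driver: loads the test and full training data, trains NO_TREES classifiers on
     * training sets filled by load(...), runs the boosting rounds to set wHat and
     * alpha, then predicts on the test set, maps -1 back to the 0 label, and reports
     * the learning time and test accuracy.
     */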
    public static void main(String[] args) {
        long startTime = System.nanoTime();
        AdaBoost ab;
        List<Instance> testInstances = new ArrayList<Instance>();
        List<Instance> completeTrainInstances = new ArrayList<Instance>();
        List<Integer> predictions = new ArrayList<Integer>();
        loadTest("src/test.txt", testInstances);
        loadTest("src/adult.txt", completeTrainInstances);
        Split sp = new Split();

        // Train the weak classifiers and cache their predictions on the
        // full training set and on the test set.
        for (int i = 0; i < NO_TREES; i++) {
            ArrayList<Instance> trainInstances = new ArrayList<Instance>();
            load(trainInstances);
            ab = new AdaBoost();
            ab.learn(trainInstances);
            classifierPredictions.add(ab.classify(completeTrainInstances));
            classifierPredictionsOnTestData.add(ab.classify(testInstances));
        }

        // Boosting rounds: recompute every classifier's weighted error, pick the
        // best one, set its vote weight, and re-weight the instances.
        for (int j = 0; j < NO_TREES; j++) {
            for (int i = 0; i < NO_TREES; i++) {
                calcWeightedError(i);
            }
            int minindex = calcWHat();
            updateAlpha(minindex);
        }
        long endTime = System.nanoTime();

        predictions = adaBoost();
        // Map the ensemble's -1 votes back to the 0 class label.
        for (int i = 0; i < predictions.size(); i++) {
            if (predictions.get(i) == -1)
                predictions.set(i, 0);
        }
        System.out.println("Learning time taken in seconds is\t" + (endTime - startTime) / 1e9);
        System.out.println("Accuracy of test data using AdaBoost technique is = " + computeAccuracy(predictions, testInstances) * 100 + " %");
    }
}