package mydl;
import java.util.ArrayList;
import mydl.layer.Linear1D;
import mydl.layer.ReLU;
import mydl.layer.Softmax;
import mydl.loss.MSE;
import mydl.model.Sequential;
import mydl.optimizer.SGD;
import mydl.tensor.Tensor;
import mydl.tensor.Tensor1D;
/**
 * A simple XOR test. Fits the XOR operator with a small fully connected network:
 * Linear1D(2, 8) -> ReLU -> Linear1D(8, 4) -> ReLU -> Linear1D(4, 2) -> ReLU
 * -> Linear1D(2, 2) -> Softmax, compiled with an SGD optimizer and an MSE loss.
 *
 * <p>Targets are one-hot encoded over two classes: {@code [1, 0]} stands for
 * "a xor b = 0" and {@code [0, 1]} stands for "a xor b = 1".
 */
public class XOR {
    /** The four XOR input pairs, added in the order [0,0], [0,1], [1,0], [1,1]. */
    ArrayList<Tensor> inputs;
    /** One-hot targets; {@code targets.get(k)} belongs to {@code inputs.get(k)}. */
    ArrayList<Tensor> targets;
    /** Result of {@code model.predict(inputs)}; assigned by {@link #validate()}. */
    ArrayList<Tensor> predicts;
    /** The network; created by {@link #train()} and null until then. */
    Sequential model;

    /**
     * Builds the full XOR truth table as training data.
     *
     * <pre>
     * input = [0, 0], 0 xor 0 = 0, so target = [1, 0]
     * input = [0, 1], 0 xor 1 = 1, so target = [0, 1]
     * input = [1, 0], 1 xor 0 = 1, so target = [0, 1]
     * input = [1, 1], 1 xor 1 = 0, so target = [1, 0]
     * </pre>
     */
    public XOR() {
        inputs = new ArrayList<Tensor>();
        targets = new ArrayList<Tensor>();
        for (int i = 0; i < 2; i++) {
            for (int j = 0; j < 2; j++) {
                double[] input = {i, j};
                // The target gets its own array. Reusing the input array would
                // overwrite the input sample whenever Tensor1D keeps a reference
                // to the array it is given instead of copying it.
                double[] target = ((i ^ j) == 1)
                        ? new double[] {0, 1}
                        : new double[] {1, 0};
                inputs.add(new Tensor1D(input));
                targets.add(new Tensor1D(target));
            }
        }
    }

    /**
     * Creates the network, compiles it with SGD and MSE, and fits it on the
     * XOR truth table. Replaces any previously trained {@link #model}.
     */
    public void train() {
        model = new Sequential();
        model.add(new Linear1D(2, 8));
        model.add(new ReLU());
        model.add(new Linear1D(8, 4));
        model.add(new ReLU());
        model.add(new Linear1D(4, 2));
        model.add(new ReLU());
        model.add(new Linear1D(2, 2));
        model.add(new Softmax());
        // NOTE(review): 1e-7 is presumably the SGD learning rate and looks very
        // small for this problem - confirm it is intentional.
        model.compile(new SGD(0.0000001), new MSE());
        // NOTE(review): presumably 1000 epochs with batch size 4; the meaning of
        // the two boolean flags is defined by Sequential.fit - confirm there.
        model.fit(inputs, targets, 1000, 4, true, true);
    }

    /**
     * Runs the trained model on the training inputs and prints, for every
     * sample, the input, the prediction and the expected target.
     *
     * @throws IllegalStateException if {@link #train()} has not been called yet
     */
    public void validate() {
        if (model == null) {
            throw new IllegalStateException("train() must be called before validate()");
        }
        predicts = model.predict(inputs);
        // NOTE(review): if Tensor1D.darray is a primitive array, concatenating it
        // prints the array's identity rather than its elements - confirm, and
        // consider java.util.Arrays.toString in that case.
        for (int i = 0; i < predicts.size(); i++) {
            System.out.println("train:" + ((Tensor1D) inputs.get(i)).darray);
            System.out.println("predict:" + ((Tensor1D) predicts.get(i)).darray);
            System.out.println("real:" + ((Tensor1D) targets.get(i)).darray);
        }
    }

    /**
     * Entry point: builds the data set, trains the network and prints the
     * predictions for the training inputs.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        XOR xor = new XOR();
        xor.train();
        xor.validate();
    }
}
// {jcomments on}