package mydl.layer;
import java.util.ArrayList;
import org.ejml.EjmlUnitTests;
import org.ejml.data.DMatrixRMaj;
import org.junit.Test;
import mydl.tensor.Tensor1D;
import mydl.tensor.Tensor3D;
public class ReshapeTest {
/**
 * Round-trips a 2x2x2 {@link Tensor3D} through {@link Reshape}.
 *
 * <p>The input is filled so that element {@code [d][r][c]} holds its flat
 * row-major index {@code d*rows*cols + r*cols + c}. The forward pass is
 * expected to yield the flat vector {@code 0..total-1}. Feeding that vector
 * through the backward pass is expected to reproduce the original matrices
 * of {@code x.darray} one by one.
 */
@Test
public void successTest() {
    final int depth = 2;
    final int rows = 2;
    final int cols = 2;
    final int plane = rows * cols;
    final int total = depth * plane;

    double[][][] xx = new double[depth][rows][cols];
    double[] yy = new double[total];
    for (int i = 0; i < total; i++) {
        // Row-major placement: flat index i lands at [i/plane][(i%plane)/cols][i%cols].
        xx[i / plane][(i % plane) / cols][i % cols] = i;
        yy[i] = i;
    }
    Tensor1D y = new Tensor1D(yy);
    Tensor3D x = new Tensor3D(xx);
    Reshape layer = new Reshape(x.size, y.size);

    EjmlUnitTests.assertEquals(((Tensor1D) layer.forward(x)).darray, y.darray);

    ArrayList<DMatrixRMaj> c1 = ((Tensor3D) layer.backward(y)).darray;
    // Check the matrix count first. Without this, extra or missing matrices
    // from backward() would be ignored or surface as an index error rather
    // than a test failure.
    if (c1.size() != x.darray.size()) {
        throw new AssertionError("backward produced " + c1.size()
                + " matrices, expected " + x.darray.size());
    }
    for (int i = 0; i < x.darray.size(); i++) {
        EjmlUnitTests.assertEquals(c1.get(i), x.darray.get(i));
    }
}