Skip to content

Commit a8f2619

Browse files
committed
[MINOR] Fix Codestyle Issues
1 parent a78b38f commit a8f2619

1 file changed

Lines changed: 73 additions & 68 deletions

File tree

src/test/java/org/apache/sysds/test/component/tensor/TransposeLinDataTest.java

Lines changed: 73 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -24,65 +24,66 @@
2424

2525
import org.apache.sysds.common.Types.ValueType;
2626
import org.apache.sysds.runtime.data.TensorBlock;
27-
import java.util.Arrays;
27+
28+
import java.util.Arrays;
2829

2930
public class TransposeLinDataTest {
3031

3132
@Test
32-
public void testRightElem(){
33+
public void testRightElem() {
3334
int[] shape = {2, 3, 4};
3435
TensorBlock tensor = TensorUtils.createArangeTensor(shape);
3536

36-
Assert.assertArrayEquals(new int[]{2, 3, 4}, tensor.getDims());
37-
Assert.assertEquals(0.0, tensor.get(new int[]{0, 0, 0}));
38-
Assert.assertEquals(23.0, tensor.get(new int[]{1, 2, 3}));
39-
Assert.assertEquals(6.0, tensor.get(new int[]{0, 1, 2}));
40-
Assert.assertEquals(12.0, tensor.get(new int[]{1, 0, 0}));
37+
Assert.assertArrayEquals(new int[] {2, 3, 4}, tensor.getDims());
38+
Assert.assertEquals(0.0, tensor.get(new int[] {0, 0, 0}));
39+
Assert.assertEquals(23.0, tensor.get(new int[] {1, 2, 3}));
40+
Assert.assertEquals(6.0, tensor.get(new int[] {0, 1, 2}));
41+
Assert.assertEquals(12.0, tensor.get(new int[] {1, 0, 0}));
4142
printTensor(tensor);
4243

43-
4444
int[] permutation = {1, 0, 2};
45-
TensorBlock outTensor = PermuteIt.permute(tensor, permutation);
46-
printTensor(outTensor);
47-
48-
Assert.assertArrayEquals(new int[]{3, 2, 4}, outTensor.getDims());
49-
Assert.assertEquals(0.0, outTensor.get(new int[]{0,0,0}));
50-
Assert.assertEquals(23.0, outTensor.get(new int[]{2, 1, 3}));
51-
Assert.assertEquals(12.0, outTensor.get(new int[]{0, 1, 0}));
52-
Assert.assertEquals(17.0, outTensor.get(new int[]{1, 1, 1}));
53-
54-
int[] second_permutation = {2, 1, 0};
55-
TensorBlock perm2Block = PermuteIt.permute(tensor, second_permutation);
56-
printTensor(perm2Block);
57-
58-
Assert.assertArrayEquals(new int[]{4, 3, 2}, perm2Block.getDims());
59-
Assert.assertEquals(0.0, perm2Block.get(new int[]{0, 0, 0}));
60-
Assert.assertEquals(12.0, perm2Block.get(new int[]{0, 0, 1}));
61-
Assert.assertEquals(11.0, perm2Block.get(new int[]{3, 2, 0}));
62-
Assert.assertEquals(23.0, perm2Block.get(new int[]{3, 2, 1}));
45+
TensorBlock outTensor = PermuteIt.permute(tensor, permutation);
46+
printTensor(outTensor);
47+
48+
Assert.assertArrayEquals(new int[] {3, 2, 4}, outTensor.getDims());
49+
Assert.assertEquals(0.0, outTensor.get(new int[] {0, 0, 0}));
50+
Assert.assertEquals(23.0, outTensor.get(new int[] {2, 1, 3}));
51+
Assert.assertEquals(12.0, outTensor.get(new int[] {0, 1, 0}));
52+
Assert.assertEquals(17.0, outTensor.get(new int[] {1, 1, 1}));
53+
54+
int[] second_permutation = {2, 1, 0};
55+
TensorBlock perm2Block = PermuteIt.permute(tensor, second_permutation);
56+
printTensor(perm2Block);
57+
58+
Assert.assertArrayEquals(new int[] {4, 3, 2}, perm2Block.getDims());
59+
Assert.assertEquals(0.0, perm2Block.get(new int[] {0, 0, 0}));
60+
Assert.assertEquals(12.0, perm2Block.get(new int[] {0, 0, 1}));
61+
Assert.assertEquals(11.0, perm2Block.get(new int[] {3, 2, 0}));
62+
Assert.assertEquals(23.0, perm2Block.get(new int[] {3, 2, 1}));
6363
}
6464

6565
public class TensorUtils {
6666

6767
public static TensorBlock createArangeTensor(int[] shape) {
6868
TensorBlock tb = new TensorBlock(ValueType.FP64, shape);
6969
tb.allocateBlock();
70-
double[] counter = { 0.0 };
70+
double[] counter = {0.0};
7171
int[] currentIndices = new int[shape.length];
72-
72+
7373
fillRecursively(tb, shape, 0, currentIndices, counter);
74-
74+
7575
return tb;
7676
}
7777

78-
private static void fillRecursively(TensorBlock tb, int[] shape, int dim, int[] currentIndices, double[] counter) {
79-
if (dim == shape.length) {
78+
private static void fillRecursively(TensorBlock tb, int[] shape, int dim, int[] currentIndices,
79+
double[] counter) {
80+
if(dim == shape.length) {
8081
tb.set(currentIndices, counter[0]);
81-
counter[0]++;
82+
counter[0]++;
8283
return;
8384
}
8485

85-
for (int i = 0; i < shape[dim]; i++) {
86+
for(int i = 0; i < shape[dim]; i++) {
8687
currentIndices[dim] = i;
8788

8889
fillRecursively(tb, shape, dim + 1, currentIndices, counter);
@@ -91,42 +92,41 @@ private static void fillRecursively(TensorBlock tb, int[] shape, int dim, int[]
9192
}
9293

9394
public class PermuteIt {
94-
public static TensorBlock permute(TensorBlock tensor, int[] permute_dims) {
95-
int anz_dims = tensor.getNumDims();
95+
public static TensorBlock permute(TensorBlock tensor, int[] permute_dims) {
96+
int anz_dims = tensor.getNumDims();
9697
int[] dims = tensor.getDims();
9798
ValueType tensorType = tensor.getValueType();
9899

99-
int[] out_shape = new int[anz_dims];
100+
int[] out_shape = new int[anz_dims];
100101

101-
for (int idx = 0; idx < anz_dims; idx++){
102+
for(int idx = 0; idx < anz_dims; idx++) {
102103
out_shape[idx] = dims[permute_dims[idx]];
103104
}
104105

105-
TensorBlock outTensor = new TensorBlock(tensorType, out_shape);
106+
TensorBlock outTensor = new TensorBlock(tensorType, out_shape);
106107
outTensor.allocateBlock();
107108

108-
int[] inIndex = new int[anz_dims];
109-
int[] outIndex = new int[anz_dims];
109+
int[] inIndex = new int[anz_dims];
110+
int[] outIndex = new int[anz_dims];
110111

111-
recursion(tensor, outTensor, permute_dims, dims, 0, inIndex, outIndex);
112+
recursion(tensor, outTensor, permute_dims, dims, 0, inIndex, outIndex);
112113
return outTensor;
113114
}
114115

115-
public static void recursion(TensorBlock inTensor, TensorBlock outTensor,
116-
int[] permutation, int[] inShape, int dim, int[] inIndex, int[]outIndex)
117-
{
118-
if (dim == inShape.length) {
119-
for(int idx = 0; idx < permutation.length; idx++){
120-
outIndex[idx] = inIndex[permutation[idx]];
116+
public static void recursion(TensorBlock inTensor, TensorBlock outTensor, int[] permutation, int[] inShape,
117+
int dim, int[] inIndex, int[] outIndex) {
118+
if(dim == inShape.length) {
119+
for(int idx = 0; idx < permutation.length; idx++) {
120+
outIndex[idx] = inIndex[permutation[idx]];
121121
}
122122
double val = (double) inTensor.get(inIndex);
123123
outTensor.set(outIndex, val);
124-
return;
124+
return;
125125
}
126126

127-
for(int idx = 0; idx < inShape[dim]; idx++){
128-
inIndex[dim] = idx;
129-
recursion(inTensor, outTensor, permutation, inShape, dim+1, inIndex, outIndex);
127+
for(int idx = 0; idx < inShape[dim]; idx++) {
128+
inIndex[dim] = idx;
129+
recursion(inTensor, outTensor, permutation, inShape, dim + 1, inIndex, outIndex);
130130
}
131131
}
132132
}
@@ -135,42 +135,47 @@ public static void printTensor(TensorBlock tb) {
135135
StringBuilder sb = new StringBuilder();
136136
int[] shape = tb.getDims();
137137
int[] currentIndices = new int[shape.length];
138-
138+
139139
sb.append("Tensor(").append(Arrays.toString(shape)).append("):\n");
140140
printRecursive(tb, shape, 0, currentIndices, sb, 0);
141-
141+
142142
System.out.println(sb.toString());
143143
}
144144

145-
private static void printRecursive(TensorBlock tb, int[] shape, int dim, int[] indices, StringBuilder sb, int indent) {
146-
for (int k = 0; k < indent; k++) sb.append(" ");
145+
private static void printRecursive(TensorBlock tb, int[] shape, int dim, int[] indices, StringBuilder sb,
146+
int indent) {
147+
for(int k = 0; k < indent; k++)
148+
sb.append(" ");
147149

148150
sb.append("[");
149151

150-
if (dim == shape.length - 1) {
151-
for (int i = 0; i < shape[dim]; i++) {
152+
if(dim == shape.length - 1) {
153+
for(int i = 0; i < shape[dim]; i++) {
152154
indices[dim] = i;
153-
double val = (double) tb.get(indices);
154-
sb.append(String.format("%.1f", val));
155-
if (i < shape[dim] - 1) sb.append(", ");
155+
double val = (double) tb.get(indices);
156+
sb.append(String.format("%.1f", val));
157+
if(i < shape[dim] - 1)
158+
sb.append(", ");
156159
}
157160
sb.append("]");
158-
}
161+
}
159162

160163
else {
161164
sb.append("\n");
162-
for (int i = 0; i < shape[dim]; i++) {
165+
for(int i = 0; i < shape[dim]; i++) {
163166
indices[dim] = i;
164167
printRecursive(tb, shape, dim + 1, indices, sb, indent + 2);
165-
166-
if (i < shape[dim] - 1) {
168+
169+
if(i < shape[dim] - 1) {
167170
sb.append(",");
168-
sb.append("\n");
169-
if (shape.length - dim > 2) sb.append("\n");
171+
sb.append("\n");
172+
if(shape.length - dim > 2)
173+
sb.append("\n");
170174
}
171175
}
172-
sb.append("\n");
173-
for (int k = 0; k < indent; k++) sb.append(" ");
176+
sb.append("\n");
177+
for(int k = 0; k < indent; k++)
178+
sb.append(" ");
174179
sb.append("]");
175180
}
176181
}

0 commit comments

Comments
 (0)