Skip to content

Commit 6aef55a

Browse files
committed
Add unit tests for getCategoricalMask defensive code paths
Drive the get_categorical_mask instruction directly to cover branches the script-level transform test cannot reach: the inline distinct-count prefix in metadata cells, default column metadata yielding zero columns, non id-based spec rejection, and the unsupported opcode guard. The error cases assert the specific exception message so they verify the intended failure rather than any wrapped exception.
1 parent c528fa2 commit 6aef55a

1 file changed

Lines changed: 148 additions & 0 deletions

File tree

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.component.frame.transform;
21+
22+
import static org.junit.Assert.assertArrayEquals;
23+
import static org.junit.Assert.assertEquals;
24+
import static org.junit.Assert.assertTrue;
25+
import static org.junit.Assert.fail;
26+
27+
import org.apache.commons.logging.Log;
28+
import org.apache.commons.logging.LogFactory;
29+
import org.apache.sysds.common.Types.DataType;
30+
import org.apache.sysds.common.Types.FileFormat;
31+
import org.apache.sysds.common.Types.ValueType;
32+
import org.apache.sysds.runtime.DMLRuntimeException;
33+
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
34+
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
35+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
36+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
37+
import org.apache.sysds.runtime.frame.data.FrameBlock;
38+
import org.apache.sysds.runtime.instructions.InstructionUtils;
39+
import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
40+
import org.apache.sysds.runtime.instructions.cp.BinaryFrameScalarCPInstruction;
41+
import org.apache.sysds.runtime.instructions.cp.StringObject;
42+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
43+
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
44+
import org.apache.sysds.runtime.meta.MetaDataFormat;
45+
import org.junit.BeforeClass;
46+
import org.junit.Test;
47+
48+
/**
49+
* Unit tests that drive the get_categorical_mask instruction directly to exercise the defensive code
50+
* paths (distinct-count prefix in the metadata frame, default column metadata, non id-based specs and
51+
* the unsupported opcode guard) that the script-level transform tests cannot reach.
52+
*/
53+
public class GetCategoricalMaskInstructionTest {
54+
protected static final Log LOG = LogFactory.getLog(GetCategoricalMaskInstructionTest.class.getName());
55+
56+
private static final String MASK_OPCODE = "get_categorical_mask";
57+
58+
@BeforeClass
59+
public static void init() throws java.io.IOException {
60+
CacheableData.initCaching("get_categorical_mask_instruction_test");
61+
}
62+
63+
@Test
64+
public void dummycodeReadsDistinctCountFromMetadataPrefix() {
65+
// a metadata cell prefixed with '¿' encodes the number of distinct values inline
66+
FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"¿3"}});
67+
MatrixBlock res = run(meta, "{\"ids\": true, \"dummycode\": [1]}");
68+
69+
assertEquals(1, res.getNumRows());
70+
assertEquals(3, res.getNumColumns());
71+
assertArrayEquals(new double[] {1, 1, 1}, res.getDenseBlockValues(), 0.0);
72+
}
73+
74+
@Test
75+
public void dummycodeDefaultMetadataContributesNoColumns() {
76+
// first column is dummycoded but carries default metadata (no distinct count) -> 0 columns,
77+
// the trailing pass-through column keeps the output non-empty
78+
FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING, ValueType.STRING},
79+
new String[][] {{"x", "y"}});
80+
MatrixBlock res = run(meta, "{\"ids\": true, \"dummycode\": [1]}");
81+
82+
assertEquals(1, res.getNumRows());
83+
assertEquals(1, res.getNumColumns());
84+
assertEquals(0.0, res.get(0, 0), 0.0);
85+
}
86+
87+
@Test
88+
public void nonIdSpecMissingIdsKeyThrows() {
89+
// a spec without the "ids" key must be rejected, not silently mis-interpreted
90+
FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"a"}});
91+
assertThrowsMessage("non ID based spec", () -> run(meta, "{\"recode\": [1]}"));
92+
}
93+
94+
@Test
95+
public void nonIdSpecIdsFalseThrows() {
96+
// "ids": false is equally unsupported
97+
FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"a"}});
98+
assertThrowsMessage("non ID based spec", () -> run(meta, "{\"ids\": false, \"recode\": [1]}"));
99+
}
100+
101+
@Test
102+
public void unsupportedOpcodeThrows() {
103+
// any frame-scalar binary opcode other than get_categorical_mask must be rejected
104+
ExecutionContext ec = ExecutionContextFactory.createContext();
105+
ec.setAutoCreateVars(true);
106+
ec.setVariable("F", frameObject(new FrameBlock(new ValueType[] {ValueType.STRING},
107+
new String[][] {{"a"}})));
108+
assertThrowsMessage("Unsupported operation", () -> maskInstruction("+").processInstruction(ec));
109+
}
110+
111+
/** Assert the action throws a DMLRuntimeException whose message chain contains the expected text. */
112+
private static void assertThrowsMessage(String expected, Runnable action) {
113+
try {
114+
action.run();
115+
fail("Expected DMLRuntimeException containing \"" + expected + "\" but nothing was thrown");
116+
}
117+
catch(DMLRuntimeException e) {
118+
StringBuilder chain = new StringBuilder();
119+
for(Throwable t = e; t != null; t = t.getCause())
120+
chain.append(t.getMessage()).append(" | ");
121+
assertTrue("Exception chain [" + chain + "] should contain \"" + expected + "\"",
122+
chain.toString().contains(expected));
123+
}
124+
}
125+
126+
private static MatrixBlock run(FrameBlock meta, String spec) {
127+
ExecutionContext ec = ExecutionContextFactory.createContext();
128+
ec.setAutoCreateVars(true);
129+
maskInstruction(MASK_OPCODE).processGetCategorical(ec, meta, new StringObject(spec));
130+
return ec.getMatrixObject("out").acquireReadAndRelease();
131+
}
132+
133+
private static BinaryFrameScalarCPInstruction maskInstruction(String opcode) {
134+
String in1 = InstructionUtils.concatOperandParts("F", DataType.FRAME.name(), ValueType.STRING.name(), "false");
135+
String in2 = InstructionUtils.concatOperandParts("spec", DataType.SCALAR.name(), ValueType.STRING.name(), "true");
136+
String out = InstructionUtils.concatOperandParts("out", DataType.MATRIX.name(), ValueType.FP64.name(), "false");
137+
String str = InstructionUtils.concatOperands("CP", opcode, in1, in2, out);
138+
return (BinaryFrameScalarCPInstruction) BinaryCPInstruction.parseInstruction(str);
139+
}
140+
141+
private static FrameObject frameObject(FrameBlock fb) {
142+
MatrixCharacteristics mc = new MatrixCharacteristics(fb.getNumRows(), fb.getNumColumns(), -1, -1);
143+
FrameObject fo = new FrameObject("F", new MetaDataFormat(mc, FileFormat.BINARY), fb.getSchema());
144+
fo.acquireModify(fb);
145+
fo.release();
146+
return fo;
147+
}
148+
}

0 commit comments

Comments
 (0)