Skip to content

Commit 861e50c

Browse files
committed
Fix AINode inference output type
1 parent 75d6855 commit 861e50c

2 files changed

Lines changed: 9 additions & 3 deletions

File tree

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeSharedClusterIT.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import java.sql.ResultSetMetaData;
4343
import java.sql.SQLException;
4444
import java.sql.Statement;
45+
import java.sql.Types;
4546
import java.util.Arrays;
4647
import java.util.HashSet;
4748
import java.util.LinkedList;
@@ -271,8 +272,10 @@ public static void callInferenceTest(Statement statement, AINodeTestUtils.FakeMo
271272
try (ResultSet resultSet = statement.executeQuery(callInferenceSQL)) {
272273
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
273274
checkHeader(resultSetMetaData, "Time,output");
275+
Assert.assertEquals(Types.DOUBLE, resultSetMetaData.getColumnType(2));
274276
int count = 0;
275277
while (resultSet.next()) {
278+
resultSet.getDouble("output");
276279
count++;
277280
}
278281
Assert.assertEquals(DEFAULT_OUTPUT_LENGTH, count);
@@ -288,8 +291,10 @@ public static void callInferenceByDefaultTest(
288291
try (ResultSet resultSet = statement.executeQuery(callInferenceSQL)) {
289292
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
290293
checkHeader(resultSetMetaData, "output");
294+
Assert.assertEquals(Types.DOUBLE, resultSetMetaData.getColumnType(1));
291295
int count = 0;
292296
while (resultSet.next()) {
297+
resultSet.getDouble("output");
293298
count++;
294299
}
295300
Assert.assertTrue(count > 0);

iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,10 +209,11 @@ def _do_inference_and_construct_resp(
209209
logger.error("[Inference] Unsupported pipeline type.")
210210
outputs = inference_pipeline.postprocess(outputs, **inference_attrs)
211211

212-
# convert tensor into tsblock for the output in each batch
212+
# DataNode currently exposes inference outputs as DOUBLE, so serialize the
213+
# physical TsBlock column as double even when model tensors are float32.
213214
resp_list = []
214-
for batch_idx, output in enumerate(outputs):
215-
resp = convert_tensor_to_tsblock(output)
215+
for output in outputs:
216+
resp = convert_tensor_to_tsblock(output.to(dtype=torch.float64))
216217
resp_list.append(resp)
217218
return resp_list
218219

0 commit comments

Comments
 (0)