diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java index 9bf0927bb622..852827aa8ab9 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java @@ -40,6 +40,7 @@ import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTree; @RunWith(IoTDBTestRunner.class) @@ -72,6 +73,7 @@ public void callInferenceTest() throws SQLException { Statement statement = connection.createStatement()) { callInferenceTest(statement, modelInfo); callInferenceByDefaultTest(statement, modelInfo); + callInferenceErrorTest(statement, modelInfo); } } } @@ -118,4 +120,16 @@ public static void callInferenceByDefaultTest( } } } + + public static void callInferenceErrorTest( + Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) { + String multiVariateSQL = + String.format( + "CALL INFERENCE(%s, \"SELECT s0,s1 FROM root.AI LIMIT 128\", generateTime=true, outputLength=10)", + modelInfo.getModelId()); + errorTest( + statement, + multiVariateSQL, + "701: Call inference function should not contain more than one input column, found [2] input columns."); + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java index 8bf9f5ead14b..d39d81aff5be 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java @@ -24,6 +24,7 @@ import org.apache.iotdb.commons.client.exception.ClientManagerException; import org.apache.iotdb.db.exception.ainode.AINodeConnectionException; import org.apache.iotdb.db.exception.runtime.ModelInferenceProcessException; +import org.apache.iotdb.db.exception.sql.SemanticException; import org.apache.iotdb.db.protocol.client.an.AINodeClient; import org.apache.iotdb.db.protocol.client.an.AINodeClientManager; import org.apache.iotdb.db.queryengine.execution.MemoryEstimationHelper; @@ -225,6 +226,12 @@ private void appendTsBlockToBuilder(TsBlock inputTsBlock) { maxTimestamp = Math.max(maxTimestamp, timestamp); } timeColumnBuilder.writeLong(timestamp); + if (inputTsBlock.getValueColumnCount() > 1) { + throw new SemanticException( + String.format( + "Call inference function should not contain more than one input column, found [%d] input columns.", + inputTsBlock.getValueColumnCount())); + } for (int columnIndex = 0; columnIndex < inputTsBlock.getValueColumnCount(); columnIndex++) { columnBuilders[columnIndexes[columnIndex]].write(inputTsBlock.getColumn(columnIndex), i); }