|
16 | 16 |
|
17 | 17 | package io.serverlessworkflow.impl.executors.ai; |
18 | 18 |
|
19 | | -import io.serverlessworkflow.ai.api.types.CallAILangChainChatModel; |
20 | 19 | import io.serverlessworkflow.api.types.TaskBase; |
21 | 20 | import io.serverlessworkflow.api.types.ai.AbstractCallAIChatModelTask; |
22 | 21 | import io.serverlessworkflow.api.types.ai.CallAIChatModel; |
| 22 | +import io.serverlessworkflow.api.types.ai.CallAILangChainChatModel; |
23 | 23 | import io.serverlessworkflow.impl.TaskContext; |
24 | 24 | import io.serverlessworkflow.impl.WorkflowApplication; |
25 | 25 | import io.serverlessworkflow.impl.WorkflowContext; |
|
31 | 31 |
|
32 | 32 | public class AIChatModelCallExecutor implements CallableTask<AbstractCallAIChatModelTask> { |
33 | 33 |
|
| 34 | + private AIChatModelExecutor executor; |
| 35 | + |
34 | 36 | @Override |
35 | 37 | public void init( |
36 | | - AbstractCallAIChatModelTask task, WorkflowApplication application, ResourceLoader loader) {} |
| 38 | + AbstractCallAIChatModelTask task, WorkflowApplication application, ResourceLoader loader) { |
| 39 | + if (task instanceof CallAILangChainChatModel model) { |
| 40 | + executor = new CallAILangChainChatModelExecutor(model); |
| 41 | + } else if (task instanceof CallAIChatModel model) { |
| 42 | + executor = new CallAIChatModelExecutor(model); |
| 43 | + } |
| 44 | + } |
37 | 45 |
|
38 | 46 | @Override |
39 | 47 | public CompletableFuture<WorkflowModel> apply( |
40 | 48 | WorkflowContext workflowContext, TaskContext taskContext, WorkflowModel input) { |
41 | 49 | WorkflowModelFactory modelFactory = workflowContext.definition().application().modelFactory(); |
42 | | - if (taskContext.task() instanceof CallAILangChainChatModel callAILangChainChatModel) { |
43 | | - return CompletableFuture.completedFuture( |
44 | | - modelFactory.fromAny( |
45 | | - new CallAILangChainChatModelExecutor() |
46 | | - .apply(callAILangChainChatModel, input.asJavaObject()))); |
47 | | - } else if (taskContext.task() instanceof CallAIChatModel callAIChatModel) { |
48 | | - return CompletableFuture.completedFuture( |
49 | | - modelFactory.fromAny( |
50 | | - new CallAIChatModelExecutor().apply(callAIChatModel, input.asJavaObject()))); |
51 | | - } |
52 | | - throw new IllegalArgumentException( |
53 | | - "AIChatModelCallExecutor can only process CallAIChatModel tasks, but received: " |
54 | | - + taskContext.task().getClass().getName()); |
| 50 | + return CompletableFuture.completedFuture( |
| 51 | + modelFactory.fromAny(executor.apply(input.asJavaObject()))); |
55 | 52 | } |
56 | 53 |
|
57 | 54 | @Override |
|
0 commit comments