Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ public List<Map<String, Object>> handleTask(List<Map<String, Object>> flowData)

if (!this.checkModelAvailable(businessData)) {
this.doOnAgentError(llmMeta, "statusCode=500");
return flowData;
}

// 待add多模态,期望使用image的url,当前传入的历史记录里面没有image
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
import modelengine.jade.common.globalization.LocaleService;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand Down Expand Up @@ -119,19 +118,6 @@ public class LlmComponentTest {
@Mock
private AippModelCenter aippModelCenter;

@BeforeEach
void setUp() {
Mockito.when(toolProvider.getTool(any())).thenReturn(Collections.emptyList());
doAnswer(invocationOnMock -> {
Object advice = invocationOnMock.getArgument(0);
Object context = invocationOnMock.getArgument(1);
return new PromptBuilderStub().build(ObjectUtils.cast(advice), ObjectUtils.cast(context));
}).when(this.promptBuilderChain).build(any(), any());

when(this.aippModelCenter.getModelAccessInfo(any(), any(), any())).thenReturn(
ModelAccessInfo.builder().tag("tag").build());
}

static class PromptBuilderStub implements PromptBuilder {
@Override
public Optional<PromptMessage> build(UserAdvice userAdvice, Map<String, Object> context) {
Expand Down Expand Up @@ -232,6 +218,7 @@ public List<ToolInfo> getTool(List<String> name) {
@Disabled("多线程阻塞,无法唤醒")
void shouldOkWhenWaterFlowAgentWithoutAsyncTool() throws InterruptedException {
// stub
this.prepareModel();
AbstractAgent<Prompt, Prompt> agent = this.getWaterFlowAgent(this.buildChatStreamModel(null), false);
LlmComponent llmComponent = getLlmComponent(agent);

Expand All @@ -253,6 +240,7 @@ void shouldOkWhenWaterFlowAgentWithoutAsyncTool() throws InterruptedException {
@Test
void shouldFailWhenWaterFlowAgentThrowException() throws InterruptedException {
// stub
this.prepareModel();
AbstractAgent<Prompt, Prompt> agent = this.getWaterFlowAgent(this.buildChatStreamModel("exceptionMsg"), false);
LlmComponent llmComponent = getLlmComponent(agent);

Expand All @@ -268,6 +256,7 @@ void shouldFailWhenWaterFlowAgentThrowException() throws InterruptedException {
@Disabled("多线程阻塞,无法唤醒")
void shouldOkWhenWaterFlowAgentWithAsyncTool() throws InterruptedException {
// stub
this.prepareModel();
AbstractAgent<Prompt, Prompt> agent = this.getWaterFlowAgent(this.buildChatStreamModel(null), true);
LlmComponent llmComponent = getLlmComponent(agent);

Expand Down Expand Up @@ -305,6 +294,7 @@ void shouldOkWhenWaterFlowAgentWithAsyncTool() throws InterruptedException {
@Test
void shouldOkWhenNoTool() throws InterruptedException {
// stub
this.prepareModel();
AiProcessFlow<Prompt, Prompt> testAgent = AiFlows.<Prompt>create()
.map(m -> ObjectUtils.<Prompt>cast(ChatMessages.from(new AiMessage("bad"))))
.close();
Expand All @@ -322,6 +312,7 @@ void shouldOkWhenNoTool() throws InterruptedException {
@Test
void shouldFailedWhenNoTool() throws InterruptedException {
// stub
this.prepareModel();
AiProcessFlow<Prompt, Prompt> testAgent = AiFlows.<Prompt>create().just(m -> {
int err = 1 / 0;
}).close();
Expand All @@ -341,6 +332,7 @@ void shouldFailedWhenNoTool() throws InterruptedException {
void shouldOkWhenUseWorkflowNoReturn() throws InterruptedException {
AtomicReference<Prompt> prompt = new AtomicReference<>();
// stub
this.prepareModel();
AiProcessFlow<Prompt, Prompt> testAgent = AiFlows.<Prompt>create()
.just(prompt::set)
.map(m -> ObjectUtils.<Prompt>cast(ChatMessages.from(new ToolMessage("1", "\"tool_async\""))))
Expand Down Expand Up @@ -372,6 +364,7 @@ void shouldOkWhenUseWorkflowNoReturn() throws InterruptedException {
@Test
void shouldOkWhenUseWorkflowNormalReturn() throws InterruptedException {
// stub
this.prepareModel();
AtomicBoolean flag = new AtomicBoolean(false);
List<Prompt> prompts = new ArrayList<>();
AiProcessFlow<Prompt, Prompt> testAgent = AiFlows.<Prompt>create().just(m -> prompts.add(m)).map(m -> {
Expand Down Expand Up @@ -428,6 +421,7 @@ private void generateBusinessDataAndCallBack(String childInstanceId, Map<String,
@Test
void shouldOkWhenUseMaxMemoryRounds() throws InterruptedException {
// stub
this.prepareModel();
AiProcessFlow<Prompt, Prompt> testAgent = AiFlows.<Prompt>create().just(m -> {
List<? extends ChatMessage> messages = m.messages();
Assertions.assertEquals(2, messages.size());
Expand All @@ -450,6 +444,7 @@ void shouldOkWhenUseMaxMemoryRounds() throws InterruptedException {
@Test
void shouldFailLLmNodeWhenHandleGivenWorkflowException() throws InterruptedException {
// given
this.prepareModel();
AbstractAgent<Prompt, Prompt> agent = this.getWaterFlowAgent(this.buildChatStreamModel(null), true);
LlmComponent llmComponent = getLlmComponent(agent);

Expand Down Expand Up @@ -501,4 +496,16 @@ private LlmComponent getLlmComponent(final AbstractAgent<Prompt, Prompt> agent)
return new LlmComponent(flowInstanceService, metaInstanceService, toolProvider, agent, aippLogService,
aippLogStreamService, client, serializer, localeService, aippModelCenter, promptBuilderChain);
}

private void prepareModel() {
Mockito.when(toolProvider.getTool(any())).thenReturn(Collections.emptyList());
doAnswer(invocationOnMock -> {
Object advice = invocationOnMock.getArgument(0);
Object context = invocationOnMock.getArgument(1);
return new PromptBuilderStub().build(ObjectUtils.cast(advice), ObjectUtils.cast(context));
}).when(this.promptBuilderChain).build(any(), any());

when(this.aippModelCenter.getModelAccessInfo(any(), any(), any())).thenReturn(
ModelAccessInfo.builder().tag("tag").build());
}
}