Skip to content

Commit 7b59304

Browse files
committed
feat(component): minor performance improvements
1 parent e89811a commit 7b59304

5 files changed

Lines changed: 31 additions & 8 deletions

File tree

src/__tests__/db-query/unit/nodes/semantic-validator.node.unit.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ describe('SemanticValidatorNode Unit', function () {
3636
.stub(schemaHelper, 'getTablesContext')
3737
.returns(['employee salary must be converted to USD']);
3838

39-
node = new SemanticValidatorNode(llm, schemaHelper, ['test context']);
39+
node = new SemanticValidatorNode(llm, llm, {models: []}, schemaHelper, [
40+
'test context',
41+
]);
4042
});
4143

4244
afterEach(() => {

src/__tests__/db-query/unit/nodes/sql-generation.node.unit.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ describe('SqlGenerationNode Unit', function () {
3636
.returns(['Table employees contains employee information']);
3737

3838
node = new SqlGenerationNode(
39+
llm,
3940
llm,
4041
{
4142
db: {

src/components/db-query/nodes/semantic-validator.node.ts

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,17 @@ import {DbQueryAIExtensionBindings} from '../keys';
1212
import {DbQueryNodes} from '../nodes.enum';
1313
import {DbSchemaHelperService} from '../services';
1414
import {DbQueryState} from '../state';
15-
import {EvaluationResult} from '../types';
15+
import {DbQueryConfig, EvaluationResult} from '../types';
1616

1717
@graphNode(DbQueryNodes.SemanticValidator)
1818
export class SemanticValidatorNode implements IGraphNode<DbQueryState> {
1919
constructor(
2020
@inject(AiIntegrationBindings.SmartLLM)
21-
private readonly llm: LLMProvider,
21+
private readonly smartllm: LLMProvider,
22+
@inject(AiIntegrationBindings.CheapLLM)
23+
private readonly cheapllm: LLMProvider,
24+
@inject(DbQueryAIExtensionBindings.Config)
25+
private readonly config: DbQueryConfig,
2226
@service(DbSchemaHelperService)
2327
private readonly schemaHelper: DbSchemaHelperService,
2428
@inject(DbQueryAIExtensionBindings.GlobalContext, {optional: true})
@@ -87,7 +91,10 @@ Keep these feedbacks in mind while validating the new query.
8791
type: LLMStreamEventType.Log,
8892
data: `Validating the query semantically.`,
8993
});
90-
const chain = RunnableSequence.from([this.prompt, this.llm]);
94+
const useSmartLLM =
95+
this.config.nodes?.semanticValidatorNode?.useSmartLLM ?? false;
96+
const llm = useSmartLLM ? this.smartllm : this.cheapllm;
97+
const chain = RunnableSequence.from([this.prompt, llm]);
9198
const output = await chain.invoke({
9299
query: state.sql,
93100
prompt: state.prompt,

src/components/db-query/nodes/sql-generation.node.ts

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ In the last attempt, you generated this SQL query -
8686
constructor(
8787
@inject(AiIntegrationBindings.SmartLLM)
8888
private readonly sqlLLM: LLMProvider,
89+
@inject(AiIntegrationBindings.CheapLLM)
90+
private readonly cheapllm: LLMProvider,
8991
@inject(DbQueryAIExtensionBindings.Config)
9092
private readonly config: DbQueryConfig,
9193
@service(DbSchemaHelperService)
@@ -97,10 +99,17 @@ In the last attempt, you generated this SQL query -
9799
state: DbQueryState,
98100
config: LangGraphRunnableConfig,
99101
): Promise<DbQueryState> {
100-
const chain = RunnableSequence.from([
101-
this.sqlGenerationPrompt,
102-
this.sqlLLM,
103-
]);
102+
let llm = this.sqlLLM;
103+
104+
if (
105+
this.config.nodes?.sqlGenerationNode?.generateDescription !== false &&
106+
state.schema.tables &&
107+
Object.keys(state.schema.tables).length === 1
108+
) {
109+
llm = this.cheapllm;
110+
}
111+
112+
const chain = RunnableSequence.from([this.sqlGenerationPrompt, llm]);
104113

105114
config.writer?.({
106115
type: LLMStreamEventType.Log,

src/components/db-query/types.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,14 @@ export type DbQueryConfig = {
115115
nodes?: {
116116
sqlGenerationNode?: {
117117
generateDescription?: boolean;
118+
useSmartLLMForSingleTableQueries?: boolean;
118119
};
119120
getTablesNode?: {
120121
useSmartLLM?: boolean;
121122
};
123+
semanticValidatorNode?: {
124+
useSmartLLM?: boolean;
125+
};
122126
};
123127
columnSelection?: boolean;
124128
};

0 commit comments

Comments
 (0)