Skip to content

Commit 82f224e

Browse files
Chahat chugh
authored and committed
use cheap llm for similar prompts stored in cache
1 parent 9bd6125 commit 82f224e

2 files changed

Lines changed: 374 additions & 3 deletions

File tree

src/__tests__/db-query/unit/nodes/sql-generation.node.unit.ts

Lines changed: 366 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,4 +476,370 @@ Do not use any DB concepts like enum numbers, joins, CTEs, subqueries etc. in th
476476
/This was generated for the following question - \nGet employee name by id/,
477477
);
478478
});
479+
480+
describe('Cheap LLM usage optimization', () => {
  /** Shape of a single column entry used by the schema fixtures below. */
  type ColumnFixture = {type: string; required: boolean; id: boolean};

  let smartLLMStub: sinon.SinonStub;
  let cheapLLMStub: sinon.SinonStub;
  let nodeWithTwoLLMs: SqlGenerationNode;
  let originalEnv: string | undefined;

  /** Cached sample reused by every test that simulates a similar-prompt cache hit. */
  const cachedSample = {
    sql: 'SELECT name FROM employees WHERE id = 5',
    prompt: 'Get employee name',
  };

  /**
   * Builds the stubbed LLM reply in the `<sql>…</sql><description>…</description>`
   * format that the tests feed to the node.
   */
  const llmReply = (sql: string, description: string) => ({
    content: `<sql>${sql}</sql><description>${description}</description>`,
  });

  /**
   * Builds a table fixture with the common `id`/`name` columns.
   * `extraColumns` are appended after the common columns.
   */
  const buildTable = (
    description: string,
    hash: string,
    extraColumns: Record<string, ColumnFixture> = {},
  ) => ({
    columns: {
      id: {type: 'number', required: true, id: true},
      name: {type: 'string', required: true, id: false},
      ...extraColumns,
    },
    primaryKey: ['id'],
    description,
    context: [],
    hash,
  });

  const employeesTable = (extraColumns?: Record<string, ColumnFixture>) =>
    buildTable('Employee table', 'hash1', extraColumns);

  const departmentsTable = () => buildTable('Department table', 'hash2');

  /** Schema with both tables, i.e. one that does not qualify as "single table". */
  const employeesAndDepartments = () => ({
    employees: employeesTable(),
    departments: departmentsTable(),
  });

  /**
   * Builds the state object passed to `execute`.
   * When `sample` is given, the state carries `sampleSql`/`sampleSqlPrompt`
   * and `fromCache: true`; otherwise those are `undefined` and `fromCache`
   * is `false`.
   */
  const buildState = (
    prompt: string,
    tables: Record<string, ReturnType<typeof buildTable>>,
    sample?: {sql: string; prompt: string},
  ) => ({
    prompt,
    schema: {
      tables,
      relations: [],
    },
    feedbacks: [],
    sampleSql: sample?.sql,
    sampleSqlPrompt: sample?.prompt,
    done: false,
    sql: undefined,
    status: undefined,
    id: '123',
    replyToUser: undefined,
    datasetId: undefined,
    fromCache: sample !== undefined,
    resultArray: undefined,
    directCall: false,
    description: undefined,
  });

  /** Asserts that exactly one of the two LLM stubs handled the request. */
  const expectOnlyCalled = (
    used: sinon.SinonStub,
    unused: sinon.SinonStub,
  ) => {
    sinon.assert.calledOnce(used);
    sinon.assert.notCalled(unused);
  };

  beforeEach(() => {
    smartLLMStub = sinon.stub();
    cheapLLMStub = sinon.stub();
    // Remember the ambient value so afterEach can restore it exactly.
    originalEnv = process.env.OPTIMIZE_CACHED_QUERIES;

    const smartLLM = smartLLMStub as unknown as LLMProvider;
    const cheapLLM = cheapLLMStub as unknown as LLMProvider;

    nodeWithTwoLLMs = new SqlGenerationNode(
      smartLLM,
      cheapLLM,
      {
        db: {
          dialect: SupportedDBs.SQLite,
        },
        models: [],
      },
      schemaHelper,
      ['test context'],
    );
  });

  afterEach(() => {
    // Assigning `undefined` to process.env would store the string "undefined",
    // so an originally-unset variable has to be deleted instead.
    if (originalEnv === undefined) {
      delete process.env.OPTIMIZE_CACHED_QUERIES;
    } else {
      process.env.OPTIMIZE_CACHED_QUERIES = originalEnv;
    }
  });

  it('should use cheap LLM when OPTIMIZE_CACHED_QUERIES is true and sampleSql exists', async () => {
    process.env.OPTIMIZE_CACHED_QUERIES = 'true';
    cheapLLMStub.resolves(
      llmReply('SELECT * FROM employees WHERE id = 1;', 'Get employee by id'),
    );

    const state = buildState(
      'Get employee by id 1',
      employeesAndDepartments(),
      cachedSample,
    );

    const result = await nodeWithTwoLLMs.execute(state, {});

    expect(result.sql).to.equal('SELECT * FROM employees WHERE id = 1;');
    expectOnlyCalled(cheapLLMStub, smartLLMStub);
  });

  it('should use smart LLM when OPTIMIZE_CACHED_QUERIES is false and sampleSql exists', async () => {
    process.env.OPTIMIZE_CACHED_QUERIES = 'false';
    smartLLMStub.resolves(
      llmReply('SELECT * FROM employees WHERE id = 1;', 'Get employee by id'),
    );

    const state = buildState(
      'Get employee by id 1',
      employeesAndDepartments(),
      cachedSample,
    );

    const result = await nodeWithTwoLLMs.execute(state, {});

    expect(result.sql).to.equal('SELECT * FROM employees WHERE id = 1;');
    expectOnlyCalled(smartLLMStub, cheapLLMStub);
  });

  it('should use cheap LLM for single table schemas regardless of cache', async () => {
    process.env.OPTIMIZE_CACHED_QUERIES = 'false';
    cheapLLMStub.resolves(
      llmReply('SELECT * FROM employees;', 'Get all employees'),
    );

    // Only one table and no cached sample.
    const state = buildState('Get all employees', {
      employees: employeesTable(),
    });

    const result = await nodeWithTwoLLMs.execute(state, {});

    expect(result.sql).to.equal('SELECT * FROM employees;');
    expectOnlyCalled(cheapLLMStub, smartLLMStub);
  });

  it('should use smart LLM for multiple tables without cached queries', async () => {
    process.env.OPTIMIZE_CACHED_QUERIES = 'true';
    smartLLMStub.resolves(
      llmReply(
        'SELECT e.name, d.name FROM employees e JOIN departments d ON e.dept_id = d.id;',
        'Get employee and department names',
      ),
    );

    const state = buildState('Get employees with their departments', {
      employees: employeesTable({
        deptId: {type: 'number', required: false, id: false},
      }),
      departments: departmentsTable(),
    });

    const result = await nodeWithTwoLLMs.execute(state, {});

    expect(result.sql).to.equal(
      'SELECT e.name, d.name FROM employees e JOIN departments d ON e.dept_id = d.id;',
    );
    expectOnlyCalled(smartLLMStub, cheapLLMStub);
  });

  it('should default to true for OPTIMIZE_CACHED_QUERIES when env var is not set', async () => {
    delete process.env.OPTIMIZE_CACHED_QUERIES;
    cheapLLMStub.resolves(
      llmReply('SELECT * FROM employees WHERE id = 1;', 'Get employee by id'),
    );

    const state = buildState(
      'Get employee by id 1',
      employeesAndDepartments(),
      cachedSample,
    );

    const result = await nodeWithTwoLLMs.execute(state, {});

    expect(result.sql).to.equal('SELECT * FROM employees WHERE id = 1;');
    expectOnlyCalled(cheapLLMStub, smartLLMStub);
  });

  // The state built here carries `sampleSql: undefined` (not `null`).
  it('should use smart LLM when sampleSql is null despite optimization being enabled', async () => {
    process.env.OPTIMIZE_CACHED_QUERIES = 'true';
    smartLLMStub.resolves(
      llmReply('SELECT * FROM employees, departments;', 'Get all data'),
    );

    const state = buildState('Get all data', employeesAndDepartments());

    const result = await nodeWithTwoLLMs.execute(state, {});

    expect(result.sql).to.equal('SELECT * FROM employees, departments;');
    expectOnlyCalled(smartLLMStub, cheapLLMStub);
  });
});
479845
});

src/components/db-query/nodes/sql-generation.node.ts

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,16 @@ In the last attempt, you generated this SQL query -
101101
): Promise<DbQueryState> {
102102
let llm = this.sqlLLM;
103103

104-
if (
104+
// Prefer the cheap LLM when a similar cached query (state.sampleSql) is available; OPTIMIZE_CACHED_QUERIES defaults to 'true' when unset
105+
const useCheapLLMForCachedQueries =
106+
(process.env.OPTIMIZE_CACHED_QUERIES ?? 'true') === 'true';
107+
108+
const isSingleTable =
105109
this.config.nodes?.sqlGenerationNode?.generateDescription !== false &&
106110
state.schema.tables &&
107-
Object.keys(state.schema.tables).length === 1
108-
) {
111+
Object.keys(state.schema.tables).length === 1;
112+
113+
if ((useCheapLLMForCachedQueries && !!state.sampleSql) || isSingleTable) {
109114
llm = this.cheapllm;
110115
}
111116

0 commit comments

Comments
 (0)