|
4 | 4 | from dataall.base.db import exceptions |
5 | 5 | from dataall.base.aws.sts import SessionHelper |
6 | 6 | from typing import List, Optional |
7 | | -from langchain.prompts import PromptTemplate |
| 7 | +from langchain_core.prompts import PromptTemplate |
8 | 8 | from langchain_core.pydantic_v1 import BaseModel |
9 | | -from langchain_aws import BedrockLLM |
| 9 | +from langchain_aws import ChatBedrock as BedrockChat |
10 | 10 | from langchain_core.output_parsers import JsonOutputParser |
11 | 11 |
|
12 | 12 | log = logging.getLogger(__name__) |
@@ -34,65 +34,68 @@ class BedrockClient: |
34 | 34 | def __init__(self): |
35 | 35 | session = SessionHelper.get_session() |
36 | 36 | self._client = session.client('bedrock-runtime', region_name=os.getenv('AWS_REGION', 'eu-west-1')) |
37 | | - model_id = 'anthropic.claude-3-5-sonnet-20240620-v1:0' |
| 37 | + model_id = 'eu.anthropic.claude-3-5-sonnet-20240620-v1:0' |
38 | 38 | model_kwargs = { |
39 | 39 | 'max_tokens': 4096, |
40 | 40 | 'temperature': 0.5, |
41 | 41 | 'top_k': 250, |
42 | 42 | 'top_p': 0.5, |
43 | 43 | 'stop_sequences': ['\n\nHuman'], |
44 | 44 | } |
45 | | - self._model = BedrockLLM(model_id=model_id, client=self._client, model_kwargs=model_kwargs) |
| 45 | + self._model = BedrockChat(client=self._client, model_id=model_id, model_kwargs=model_kwargs) |
46 | 46 |
|
47 | 47 | def invoke_model_dataset_metadata(self, metadata_types, dataset, tables, folders): |
48 | | - prompt_template = PromptTemplate.from_file(METADATA_GENERATION_DATASET_TEMPLATE_PATH) |
49 | | - parser = JsonOutputParser(pydantic_object=MetadataOutput) |
50 | | - chain = prompt_template | self._model | parser |
51 | | - context = { |
52 | | - 'metadata_types': metadata_types, |
53 | | - 'label': dataset.label, |
54 | | - 'description': dataset.description, |
55 | | - 'tags': dataset.tags, |
56 | | - 'table_labels': [t.label for t in tables], |
57 | | - 'table_descriptions': [t.description for t in tables], |
58 | | - 'folder_labels': [f.label for f in folders], |
59 | | - } |
60 | | - response = chain.invoke(context) |
61 | | - if response.startswith('Error:'): |
62 | | - raise exceptions.ModelGuardrailException(response) |
63 | | - return response |
| 48 | + try: |
| 49 | + prompt_template = PromptTemplate.from_file(METADATA_GENERATION_DATASET_TEMPLATE_PATH) |
| 50 | + parser = JsonOutputParser(pydantic_object=MetadataOutput) |
| 51 | + chain = prompt_template | self._model | parser |
| 52 | + context = { |
| 53 | + 'metadata_types': metadata_types, |
| 54 | + 'dataset_label': dataset.label, |
| 55 | + 'description': dataset.description, |
| 56 | + 'tags': dataset.tags, |
| 57 | + 'topics': dataset.topics, |
| 58 | + 'table_names': [t.label for t in tables], |
| 59 | + 'table_descriptions': [t.description for t in tables], |
| 60 | + 'folder_names': [f.label for f in folders], |
| 61 | + } |
| 62 | + return chain.invoke(context) |
| 63 | + except Exception as e: |
| 64 | + raise e |
64 | 65 |
|
65 | 66 | def invoke_model_table_metadata(self, metadata_types, table, columns, sample_data, generate_columns_metadata=False): |
66 | | - prompt_template = PromptTemplate.from_file(METADATA_GENERATION_TABLE_TEMPLATE_PATH) |
67 | | - parser = JsonOutputParser(pydantic_object=MetadataOutput) |
68 | | - chain = prompt_template | self._model | parser |
69 | | - context = { |
70 | | - 'metadata_types': metadata_types, |
71 | | - 'generate_columns_metadata': generate_columns_metadata, |
72 | | - 'label': table.label, |
73 | | - 'description': table.description, |
74 | | - 'tags': table.tags, |
75 | | - 'column_labels': [c.label for c in columns], |
76 | | - 'column_descriptions': [c.description for c in columns], |
77 | | - 'sample_data': sample_data, |
78 | | - } |
79 | | - response = chain.invoke(context) |
80 | | - if response.startswith('Error:'): |
81 | | - raise exceptions.ModelGuardrailException(response) |
82 | | - return response |
| 67 | + try: |
| 68 | + prompt_template = PromptTemplate.from_file(METADATA_GENERATION_TABLE_TEMPLATE_PATH) |
| 69 | + parser = JsonOutputParser(pydantic_object=MetadataOutput) |
| 70 | + chain = prompt_template | self._model | parser |
| 71 | + context = { |
| 72 | + 'metadata_types': metadata_types, |
| 73 | + 'generate_columns_metadata': generate_columns_metadata, |
| 74 | + 'label': table.label, |
| 75 | + 'description': table.description, |
| 76 | + 'tags': table.tags, |
| 77 | + 'topics': table.topics, |
| 78 | + 'column_labels': [c.label for c in columns], |
| 79 | + 'column_descriptions': [c.description for c in columns], |
| 80 | + 'sample_data': sample_data, |
| 81 | + } |
| 82 | + return chain.invoke(context) |
| 83 | + except Exception as e: |
| 84 | + raise e |
83 | 85 |
|
84 | 86 | def invoke_model_folder_metadata(self, metadata_types, folder, files): |
85 | | - prompt_template = PromptTemplate.from_file(METADATA_GENERATION_FOLDER_TEMPLATE_PATH) |
86 | | - parser = JsonOutputParser(pydantic_object=MetadataOutput) |
87 | | - chain = prompt_template | self._model | parser |
88 | | - context = { |
89 | | - 'metadata_types': metadata_types, |
90 | | - 'label': folder.label, |
91 | | - 'description': folder.description, |
92 | | - 'tags': folder.tags, |
93 | | - 'file_names': files, |
94 | | - } |
95 | | - response = chain.invoke(context) |
96 | | - if response.startswith('Error:'): |
97 | | - raise exceptions.ModelGuardrailException(response) |
98 | | - return response |
| 87 | + try: |
| 88 | + prompt_template = PromptTemplate.from_file(METADATA_GENERATION_FOLDER_TEMPLATE_PATH) |
| 89 | + parser = JsonOutputParser(pydantic_object=MetadataOutput) |
| 90 | + chain = prompt_template | self._model | parser |
| 91 | + context = { |
| 92 | + 'metadata_types': metadata_types, |
| 93 | + 'label': folder.label, |
| 94 | + 'description': folder.description, |
| 95 | + 'tags': folder.tags, |
| 96 | + 'topics': folder.topics, |
| 97 | + 'file_names': files, |
| 98 | + } |
| 99 | + return chain.invoke(context) |
| 100 | + except Exception as e: |
| 101 | + raise e |
0 commit comments