|
14 | 14 | ) |
15 | 15 | from langchain.chains.openai_functions import create_openai_fn_chain |
16 | 16 | from langchain.chains.summarize import load_summarize_chain |
17 | | -from langchain.chat_models import AzureChatOpenAI, ChatOpenAI |
| 17 | +from langchain_community.chat_models import AzureChatOpenAI, ChatOpenAI |
18 | 18 | from langchain.document_loaders import TextLoader |
19 | 19 | from langchain.embeddings.openai import OpenAIEmbeddings |
20 | 20 | from langchain.llms import OpenAI |
@@ -1302,3 +1302,77 @@ def _identifying_params(self) -> Mapping[str, Any]: |
1302 | 1302 |
|
1303 | 1303 | assert custom_generation.output == "This is a " |
1304 | 1304 | assert custom_generation.model is None |
| 1305 | + |
| 1306 | + |
| 1307 | +def test_names_on_spans_lcel(): |
| 1308 | + from langchain_core.output_parsers import StrOutputParser |
| 1309 | + from langchain_core.runnables import RunnablePassthrough |
| 1310 | + from langchain_openai import OpenAIEmbeddings |
| 1311 | + |
| 1312 | + callback = CallbackHandler(debug=False) |
| 1313 | + model = ChatOpenAI(temperature=0) |
| 1314 | + |
| 1315 | + template = """Answer the question based only on the following context: |
| 1316 | + {context} |
| 1317 | +
|
| 1318 | + Question: {question} |
| 1319 | + """ |
| 1320 | + prompt = ChatPromptTemplate.from_template(template) |
| 1321 | + |
| 1322 | + loader = TextLoader("./static/state_of_the_union.txt", encoding="utf8") |
| 1323 | + |
| 1324 | + documents = loader.load() |
| 1325 | + text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0) |
| 1326 | + texts = text_splitter.split_documents(documents) |
| 1327 | + |
| 1328 | + embeddings = OpenAIEmbeddings(openai_api_key=os.environ.get("OPENAI_API_KEY")) |
| 1329 | + docsearch = Chroma.from_documents(texts, embeddings) |
| 1330 | + |
| 1331 | + retriever = docsearch.as_retriever() |
| 1332 | + |
| 1333 | + retrieval_chain = ( |
| 1334 | + { |
| 1335 | + "context": retriever.with_config(run_name="Docs"), |
| 1336 | + "question": RunnablePassthrough(), |
| 1337 | + } |
| 1338 | + | prompt |
| 1339 | + | model.with_config(run_name="my_llm") |
| 1340 | + | StrOutputParser() |
| 1341 | + ) |
| 1342 | + |
| 1343 | + retrieval_chain.invoke( |
| 1344 | + "What did the president say about Ketanji Brown Jackson?", |
| 1345 | + config={ |
| 1346 | + "callbacks": [callback], |
| 1347 | + }, |
| 1348 | + ) |
| 1349 | + |
| 1350 | + callback.flush() |
| 1351 | + api = get_api() |
| 1352 | + trace = api.trace.get(callback.get_trace_id()) |
| 1353 | + |
| 1354 | + assert len(trace.observations) == 7 |
| 1355 | + |
| 1356 | + assert ( |
| 1357 | + len( |
| 1358 | + list( |
| 1359 | + filter( |
| 1360 | + lambda x: x.type == "GENERATION" and x.name == "my_llm", |
| 1361 | + trace.observations, |
| 1362 | + ) |
| 1363 | + ) |
| 1364 | + ) |
| 1365 | + == 1 |
| 1366 | + ) |
| 1367 | + |
| 1368 | + assert ( |
| 1369 | + len( |
| 1370 | + list( |
| 1371 | + filter( |
| 1372 | + lambda x: x.type == "SPAN" and x.name == "Docs", |
| 1373 | + trace.observations, |
| 1374 | + ) |
| 1375 | + ) |
| 1376 | + ) |
| 1377 | + == 1 |
| 1378 | + ) |
0 commit comments