-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathrunnables.py
75 lines (63 loc) · 2.85 KB
/
runnables.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from typing import Any, Dict, List, Optional
from langchain_core.language_models import BaseChatModel
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
from langchain_experimental.graph_transformers.llm import optional_enum_field
from .traverse import Node
# Default prompt for ``extract_entities``. It must expose the ``{question}``
# and ``{format_instructions}`` template variables, which that function checks.
# NOTE(review): the text contains a doubled space before "that"; harmless for
# the LLM, but it could be normalized in a deliberate prompt change.
QUERY_ENTITY_EXTRACT_PROMPT = (
    "A question is provided below. Given the question, extract up to 5 entities "
    "(name and type) from the text. Focus on extracting the entities  that we "
    "can use to best lookup answers to the question. Avoid stopwords.\n"
    "---------------------\n"
    "{question}\n"
    "---------------------\n"
    "{format_instructions}\n"
)
# TODO: Use a knowledge schema when extracting entities,
# to get the right kinds of nodes.
def extract_entities(
    llm: BaseChatModel,
    keyword_extraction_prompt: str = QUERY_ENTITY_EXTRACT_PROMPT,
    node_types: Optional[List[str]] = None,
) -> Runnable[Dict[str, Any], List[Node]]:
    """Return a runnable that extracts graph entities from a question.

    The returned runnable expects a dictionary containing the `"question"` to
    extract entities from. It fills in the extraction prompt, invokes ``llm``,
    parses the JSON answer, and converts each extracted entry into a ``Node``.

    Args:
        llm: The LLM to use for extracting entities.
        keyword_extraction_prompt: The prompt to use for requesting entities.
            This should include the `{question}` being asked as well as the
            `{format_instructions}` which describe how to produce the output.
        node_types: Optional list of node types to extract. It is forwarded
            to ``optional_enum_field`` to describe/constrain the ``type``
            field of each extracted node.

    Returns:
        A runnable mapping a dict with a ``"question"`` key to a list of
        ``Node`` objects built from the LLM's ``id``/``type`` answers.

    Raises:
        ValueError: If the prompt lacks a ``{question}`` or a
            ``{format_instructions}`` placeholder.
    """
    # Build the template once; the same validated instance is used in the chain.
    prompt = ChatPromptTemplate.from_messages([keyword_extraction_prompt])
    if "question" not in prompt.input_variables:
        raise ValueError(
            "Missing 'question' placeholder in extraction prompt template."
        )
    if "format_instructions" not in prompt.input_variables:
        raise ValueError(
            "Missing 'format_instructions' placeholder in extraction prompt template."
        )

    # The schemas are defined per call so `node_types` can shape the `type` field.
    class SimpleNode(BaseModel):
        """Represents a node in a graph with associated properties."""

        id: str = Field(description="Name or human-readable unique identifier.")
        type: str = optional_enum_field(
            node_types, description="The type or label of the node."
        )

    class SimpleNodeList(BaseModel):
        """Represents a list of simple nodes."""

        nodes: List[SimpleNode]

    output_parser = JsonOutputParser(pydantic_object=SimpleNodeList)
    # The instructions depend only on the schema, so compute them once here
    # instead of on every invocation of the chain.
    format_instructions = output_parser.get_format_instructions()

    def _to_nodes(parsed: Dict[str, Any]) -> List[Node]:
        """Convert the parsed JSON payload into ``Node`` objects."""
        return [Node(entry["id"], entry["type"]) for entry in parsed["nodes"]]

    return (
        RunnablePassthrough.assign(
            format_instructions=lambda _: format_instructions,
        )
        | prompt
        | llm
        | output_parser
        | RunnableLambda(_to_nodes)
    )