Use python 3.11 for knowledge-graph linting #620

Merged · 2 commits · Jul 29, 2024
81 changes: 79 additions & 2 deletions libs/knowledge-graph/pyproject.toml
@@ -36,7 +36,7 @@ pytest-dotenv = "^0.5.2"
pytest-rerunfailures = "^14.0"
mypy = "^1.10.1"
types-pyyaml = "^6.0.1"
pydantic = "<2" # for compatibility between LangChain and pydantic-yaml type checking
pydantic = "^2.6.0"

[build-system]
requires = ["poetry-core"]
@@ -61,4 +61,81 @@ warn_unused_ignores = true

[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_mode = "auto"
asyncio_mode = "auto"

[tool.ruff]
target-version = "py311"

[tool.ruff.lint]
pydocstyle.convention = "google"
ignore = [
"COM812", # Messes with the formatter
"D100", # Do we want to activate (docstring in module) ?
"D104", # Do we want to activate (docstring in package) ?
"D105", # Do we want to activate (docstring in magic method) ?
"D107", # Do we want to activate (docstring in __init__) ?
"ERA", # Do we want to activate (no commented code) ?
"ISC001", # Messes with the formatter
"PERF203", # Incorrect detection
"PLR09", # TODO: do we enforce these ones (complexity) ?
"TRY003", # A bit too strict ?
"TD002", # We know the TODOs authors with git. Activate anyway ?
"TD003", # Do we want to activate (TODOs with issue reference) ?
]

select = [
"A",
"ARG",
"ASYNC",
"B",
"BLE",
"C4",
"COM",
"D",
"DTZ",
"E",
"EXE",
"F",
"FLY",
"FURB",
"G",
"I",
"ICN",
"INP",
"INT",
"ISC",
"LOG",
"N",
"NPY",
"PD",
"PERF",
"PGH",
"PIE",
"PL",
"PT",
"PYI",
"Q",
"RET",
"RSE",
"RUF",
"S",
"SIM",
"SLF",
"SLOT",
"T10",
"T20",
"TCH",
"TD",
"TID",
"TRY",
"UP",
"W",
"YTT",
]

[tool.ruff.lint.per-file-ignores]
"tests/*" = [
"D",
"S101",
"T20",
]
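
Note: the `target-version = "py311"` setting together with the `UP` (pyupgrade) rules in `select` is what drives the typing modernization in the Python diffs below. A minimal, purely illustrative sketch of the style those rules enforce (the function and names are invented, not taken from this PR):

```python
# Hypothetical example only - shows the annotations ruff's UP rules prefer
# when targeting Python 3.11; none of these names come from the PR.
#
# Old style, flagged by pyupgrade:
#   from typing import Dict, List, Optional, Union
#   def lookup(ids: List[str], cache: Optional[Dict[str, int]] = None) -> Union[int, None]: ...

def lookup(ids: list[str], cache: dict[str, int] | None = None) -> int | None:
    """Return the cached value for the first id that has one, else None."""
    cache = cache or {}
    for identifier in ids:
        if identifier in cache:
            return cache[identifier]
    return None
```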
libs/knowledge-graph/ragstack_knowledge_graph/cassandra_graph_store.py
@@ -1,4 +1,5 @@
from typing import Any, Dict, Iterable, List, Optional, Sequence, Union
from collections.abc import Iterable, Sequence
from typing import Any

from cassandra.cluster import Session
from langchain_community.graphs.graph_document import GraphDocument
@@ -12,7 +13,7 @@
from .traverse import Node, Relation


def _elements(documents: Iterable[GraphDocument]) -> Iterable[Union[Node, Relation]]:
def _elements(documents: Iterable[GraphDocument]) -> Iterable[Node | Relation]:
def _node(node: LangChainNode) -> Node:
return Node(name=str(node.id), type=node.type)

@@ -32,9 +33,9 @@ def __init__(
self,
node_table: str = "entities",
edge_table: str = "relationships",
text_embeddings: Optional[Embeddings] = None,
session: Optional[Session] = None,
keyspace: Optional[str] = None,
text_embeddings: Embeddings | None = None,
session: Session | None = None,
keyspace: str | None = None,
) -> None:
"""Create a Cassandra Graph Store.

@@ -51,16 +52,16 @@ def __init__(

@override
def add_graph_documents(
self, graph_documents: List[GraphDocument], include_source: bool = False
self, graph_documents: list[GraphDocument], include_source: bool = False
) -> None:
# TODO: Include source.
self.graph.insert(_elements(graph_documents))

# TODO: should this include the types of each node?
@override
def query(
self, query: str, params: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
self, query: str, params: dict[str, Any] | None = None
) -> list[dict[str, Any]]:
raise ValueError("Querying Cassandra should use `as_runnable`.")

@override
@@ -71,7 +72,7 @@ def get_schema(self) -> str:

@property
@override
def get_structured_schema(self) -> Dict[str, Any]:
def get_structured_schema(self) -> dict[str, Any]:
raise NotImplementedError

@override
@@ -80,7 +81,7 @@ def refresh_schema(self) -> None:

def as_runnable(
self, steps: int = 3, edge_filters: Sequence[str] = ()
) -> Runnable[Union[Node, Sequence[Node]], Iterable[Relation]]:
) -> Runnable[Node | Sequence[Node], Iterable[Relation]]:
"""Convert to a runnable.

Returns a runnable that retrieves the sub-graph near the input entity or
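
End of the changes to the Cassandra graph store module. Since `query()` now raises on purpose, retrieval goes through `as_runnable()`; a rough usage sketch follows (the class name `CassandraGraphStore`, the keyspace, and the example node are assumptions, and a live Cassandra session resolvable by `check_resolve_session` is presumed):

```python
# Hypothetical sketch - assumes the class is named CassandraGraphStore and that
# a Cassandra session/keyspace can be resolved (e.g. via cassio); not from the PR.
from ragstack_knowledge_graph.cassandra_graph_store import CassandraGraphStore
from ragstack_knowledge_graph.traverse import Node

store = CassandraGraphStore(keyspace="demo_keyspace")  # assumed keyspace
retrieve = store.as_runnable(steps=2)

# Retrieve relations near a starting entity instead of calling query().
for relation in retrieve.invoke(Node(name="Marie Curie", type="Person")):
    print(relation)
```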
12 changes: 7 additions & 5 deletions libs/knowledge-graph/ragstack_knowledge_graph/extraction.py
@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union, cast
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, cast

from langchain_community.graphs.graph_document import GraphDocument
from langchain_core.documents import Document
@@ -31,7 +32,8 @@
def _format_example(idx: int, example: Example) -> str:
from pydantic_yaml import to_yaml_str

return f"Example {idx}:\n```yaml\n{to_yaml_str(example)}\n```"
yaml_example = to_yaml_str(example) # type: ignore[arg-type]
return f"Example {idx}:\n```yaml\n{yaml_example}\n```"


class KnowledgeSchemaExtractor:
@@ -47,7 +49,7 @@ def __init__(
self._validator = KnowledgeSchemaValidator(schema)
self.strict = strict

messages: List[MessageLikeRepresentation] = [
messages: list[MessageLikeRepresentation] = [
SystemMessagePromptTemplate(
prompt=load_template(
"extraction.md", knowledge_schema_yaml=schema.to_yaml_str()
@@ -73,7 +75,7 @@ def __init__(
self._chain = prompt | structured_llm

def _process_response(
self, document: Document, response: Union[Dict[str, Any], BaseModel]
self, document: Document, response: dict[str, Any] | BaseModel
) -> GraphDocument:
raw_graph = cast(_Graph, response)
nodes = (
@@ -96,7 +98,7 @@ def _process_response(

return graph_document

def extract(self, documents: List[Document]) -> List[GraphDocument]:
def extract(self, documents: list[Document]) -> list[GraphDocument]:
"""Extract knowledge graphs from a list of documents."""
# TODO: Define an async version of extraction?
responses = self._chain.batch_as_completed(
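
End of the extraction.py changes. The `# type: ignore[arg-type]` presumably papers over a mismatch between the model classes used here and the pydantic version that pydantic-yaml's type hints expect; at runtime the call simply serializes a pydantic model to YAML. A minimal sketch with a stand-in model (not the project's `Example`):

```python
# Stand-in model for illustration only; the project's real Example model lives
# in knowledge_schema.py and is not redefined here.
from pydantic import BaseModel
from pydantic_yaml import to_yaml_str


class FewShotExample(BaseModel):
    text: str
    label: str


yaml_str = to_yaml_str(FewShotExample(text="Cassandra stores wide rows.", label="fact"))
print(yaml_str)  # roughly: "text: Cassandra stores wide rows.\nlabel: fact\n"
```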
25 changes: 13 additions & 12 deletions libs/knowledge-graph/ragstack_knowledge_graph/knowledge_graph.py
@@ -1,7 +1,8 @@
import json
import re
from collections.abc import Iterable, Sequence
from itertools import repeat
from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union, cast
from typing import Any, cast

from cassandra.cluster import ResponseFuture, Session
from cassandra.query import BatchStatement
@@ -12,12 +13,12 @@
from .utils import batched


def _serialize_md_dict(md_dict: Dict[str, Any]) -> str:
def _serialize_md_dict(md_dict: dict[str, Any]) -> str:
return json.dumps(md_dict, separators=(",", ":"), sort_keys=True)


def _deserialize_md_dict(md_string: str) -> Dict[str, Any]:
return cast(Dict[str, Any], json.loads(md_string))
def _deserialize_md_dict(md_string: str) -> dict[str, Any]:
return cast(dict[str, Any], json.loads(md_string))


def _parse_node(row: Any) -> Node:
@@ -52,9 +53,9 @@ def __init__(
self,
node_table: str = "entities",
edge_table: str = "relationships",
text_embeddings: Optional[Embeddings] = None,
session: Optional[Session] = None,
keyspace: Optional[str] = None,
text_embeddings: Embeddings | None = None,
session: Session | None = None,
keyspace: str | None = None,
apply_schema: bool = True,
) -> None:
session = check_resolve_session(session)
@@ -197,7 +198,7 @@ def query_nearest_nodes(self, nodes: Iterable[str], k: int = 1) -> Iterable[Node
# TODO: Introduce `ainsert` for async insertions.
def insert(
self,
elements: Iterable[Union[Node, Relation]],
elements: Iterable[Node | Relation],
) -> None:
"""Insert the given elements into the graph."""
for batch in batched(elements, n=4):
@@ -245,10 +246,10 @@ def insert(

def subgraph(
self,
start: Union[Node, Sequence[Node]],
start: Node | Sequence[Node],
edge_filters: Sequence[str] = (),
steps: int = 3,
) -> Tuple[Iterable[Node], Iterable[Relation]]:
) -> tuple[Iterable[Node], Iterable[Relation]]:
"""Retrieve the sub-graph from the given starting nodes."""
edges = self.traverse(start, edge_filters, steps)

@@ -274,7 +275,7 @@ def subgraph(

def traverse(
self,
start: Union[Node, Sequence[Node]],
start: Node | Sequence[Node],
edge_filters: Sequence[str] = (),
steps: int = 3,
) -> Iterable[Relation]:
@@ -306,7 +307,7 @@ def traverse(

async def atraverse(
self,
start: Union[Node, Sequence[Node]],
start: Node | Sequence[Node],
edge_filters: Sequence[str] = (),
steps: int = 3,
) -> Iterable[Relation]:
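
End of the knowledge_graph.py changes. The `batched(elements, n=4)` helper is imported from the local `.utils` module, which is not part of this diff; assuming it mirrors `itertools.batched` from Python 3.12, an equivalent sketch would be:

```python
# Sketch of a batching helper equivalent to the `batched` imported from .utils;
# the real implementation is not shown in this PR, so this is an assumption.
from collections.abc import Iterable, Iterator
from itertools import islice
from typing import TypeVar

T = TypeVar("T")


def batched(iterable: Iterable[T], n: int) -> Iterator[tuple[T, ...]]:
    """Yield tuples of at most n items until the iterable is exhausted."""
    if n < 1:
        raise ValueError("n must be at least one")
    iterator = iter(iterable)
    while batch := tuple(islice(iterator, n)):
        yield batch
```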
19 changes: 10 additions & 9 deletions libs/knowledge-graph/ragstack_knowledge_graph/knowledge_schema.py
@@ -1,5 +1,6 @@
from collections.abc import Sequence
from pathlib import Path
from typing import Dict, List, Self, Sequence, Union
from typing import Self

from langchain_community.graphs.graph_document import GraphDocument
from langchain_core.pydantic_v1 import BaseModel
@@ -33,10 +34,10 @@ class RelationshipSchema(BaseModel):
edge_type: str
"""The name of the edge type for the relationhsip."""

source_types: List[str]
source_types: list[str]
"""The node types for the source of the relationship."""

target_types: List[str]
target_types: list[str]
"""The node types for the target of the relationship."""

description: str
@@ -59,28 +60,28 @@ class Example(BaseModel):
class KnowledgeSchema(BaseModel):
"""Schema for a knowledge graph."""

nodes: List[NodeSchema]
nodes: list[NodeSchema]
"""Allowed node types for the knowledge schema."""

relationships: List[RelationshipSchema]
relationships: list[RelationshipSchema]
"""Allowed relationships for the knowledge schema."""

@classmethod
def from_file(cls, path: Union[str, Path]) -> Self:
def from_file(cls, path: str | Path) -> Self:
"""Load a KnowledgeSchema from a JSON or YAML file.

Args:
path: The path to the file to load.
"""
from pydantic_yaml import parse_yaml_file_as

return parse_yaml_file_as(cls, path)
return parse_yaml_file_as(cls, path) # type: ignore[type-var]

def to_yaml_str(self) -> str:
"""Convert the schema to a YAML string."""
from pydantic_yaml import to_yaml_str

return to_yaml_str(self)
return to_yaml_str(self) # type: ignore[arg-type]


class KnowledgeSchemaValidator:
@@ -91,7 +92,7 @@ def __init__(self, schema: KnowledgeSchema) -> None:

self._nodes = {node.type: node for node in schema.nodes}

self._relationships: Dict[str, List[RelationshipSchema]] = {}
self._relationships: dict[str, list[RelationshipSchema]] = {}
for r in schema.relationships:
self._relationships.setdefault(r.edge_type, []).append(r)

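
End of the knowledge_schema.py changes. A short usage sketch of the loading path (`from_file` delegates to `pydantic_yaml.parse_yaml_file_as`; the YAML file name and its contents are hypothetical):

```python
from ragstack_knowledge_graph.knowledge_schema import KnowledgeSchema

# "movie_schema.yaml" is a hypothetical file declaring `nodes:` and `relationships:`.
schema = KnowledgeSchema.from_file("movie_schema.yaml")

for relationship in schema.relationships:
    print(relationship.edge_type, relationship.source_types, "->", relationship.target_types)
```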
8 changes: 4 additions & 4 deletions libs/knowledge-graph/ragstack_knowledge_graph/render.py
@@ -1,4 +1,4 @@
from typing import Dict, Iterable, Tuple, Union
from collections.abc import Iterable

import graphviz
from langchain_community.graphs.graph_document import GraphDocument, Node
@@ -11,7 +11,7 @@ def _node_label(node: Node) -> str:


def print_graph_documents(
graph_documents: Union[GraphDocument, Iterable[GraphDocument]],
graph_documents: GraphDocument | Iterable[GraphDocument],
) -> None:
"""Prints the relationships in the graph documents."""
if isinstance(graph_documents, GraphDocument):
@@ -25,15 +25,15 @@ def print_graph_documents(


def render_graph_documents(
graph_documents: Union[GraphDocument, Iterable[GraphDocument]],
graph_documents: GraphDocument | Iterable[GraphDocument],
) -> graphviz.Digraph:
"""Renders the relationships in the graph documents."""
if isinstance(graph_documents, GraphDocument):
graph_documents = [graph_documents]

dot = graphviz.Digraph()

nodes: Dict[Tuple[Union[str, int], str], str] = {}
nodes: dict[tuple[str | int, str], str] = {}

def _node_id(node: Node) -> str:
node_key = (node.id, node.type)
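
End of the render.py changes. A brief, hypothetical usage sketch of `render_graph_documents` (the example nodes and output file name are invented; writing the file also requires the Graphviz binaries to be installed):

```python
from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship
from langchain_core.documents import Document

from ragstack_knowledge_graph.render import render_graph_documents

# Invented toy data for illustration only.
marie = Node(id="Marie Curie", type="Person")
prize = Node(id="Nobel Prize", type="Award")
doc = GraphDocument(
    nodes=[marie, prize],
    relationships=[Relationship(source=marie, target=prize, type="WON")],
    source=Document(page_content="Marie Curie won the Nobel Prize."),
)

dot = render_graph_documents(doc)             # returns a graphviz.Digraph
dot.render("knowledge_graph", format="svg")   # hypothetical output name
```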