Skip to content

Commit 07f8e02

Browse files
authored
Add type checking of ragstack-langchain (#619)
1 parent e59b277 commit 07f8e02

File tree

9 files changed

+118
-107
lines changed

9 files changed

+118
-107
lines changed

.github/workflows/ci-unit-tests.yml

+6
Original file line number | Diff line number | Diff line change
@@ -95,6 +95,12 @@ jobs:
9595
- name: "Type check (knowledge-store)"
9696
run: tox -e type -c libs/knowledge-store && rm -rf libs/knowledge-store/.tox
9797

98+
- name: "Type check (langchain)"
99+
run: tox -e type -c libs/langchain && rm -rf libs/langchain/.tox
100+
101+
- name: "Type check (llama-index)"
102+
run: tox -e type -c libs/llamaindex && rm -rf libs/llamaindex/.tox
103+
98104
- name: "Type check (ragulate)"
99105
run: tox -e type -c libs/ragulate && rm -rf libs/ragulate/.tox
100106

libs/langchain/pyproject.toml

+20
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,26 @@ ragstack-ai-colbert = { path = "../colbert", develop = true }
4040
ragstack-ai-knowledge-store = { path = "../knowledge-store", develop = true }
4141
pytest-asyncio = "^0.23.6"
4242

43+
[tool.poetry.group.dev.dependencies]
44+
mypy = "^1.11.0"
45+
46+
[tool.mypy]
47+
disallow_any_generics = true
48+
disallow_incomplete_defs = true
49+
disallow_untyped_calls = true
50+
disallow_untyped_decorators = true
51+
disallow_untyped_defs = true
52+
follow_imports = "normal"
53+
ignore_missing_imports = true
54+
no_implicit_reexport = true
55+
show_error_codes = true
56+
show_error_context = true
57+
strict_equality = true
58+
strict_optional = true
59+
warn_redundant_casts = true
60+
warn_return_any = true
61+
warn_unused_ignores = true
62+
4363
[tool.pytest.ini_options]
4464
asyncio_mode = "auto"
4565

libs/langchain/ragstack_langchain/colbert/colbert_retriever.py

+3-4
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
1+
from typing import TYPE_CHECKING, List, Optional, Tuple
22

33
from langchain_core.callbacks.manager import (
44
AsyncCallbackManagerForRetrieverRun,
@@ -39,11 +39,10 @@ class ColbertRetriever(BaseRetriever):
3939
def __init__(
4040
self,
4141
retriever: ColbertBaseRetriever,
42-
k: Optional[int] = 5,
42+
k: int = 5,
4343
query_maxlen: Optional[int] = None,
44-
**kwargs: Any,
4544
):
46-
super().__init__(retriever=retriever, k=k, **kwargs)
45+
super().__init__()
4746
self.retriever = retriever
4847
self.k = k
4948
self.query_maxlen = query_maxlen

libs/langchain/ragstack_langchain/colbert/colbert_vector_store.py

+37-53
Original file line number | Diff line number | Diff line change
@@ -1,21 +1,25 @@
1-
from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
24

35
from langchain_core.documents import Document
4-
from langchain_core.embeddings import Embeddings
56
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
67
from ragstack_colbert import Chunk
78
from ragstack_colbert import ColbertVectorStore as RagstackColbertVectorStore
8-
from ragstack_colbert.base_database import BaseDatabase as ColbertBaseDatabase
9-
from ragstack_colbert.base_embedding_model import (
10-
BaseEmbeddingModel as ColbertBaseEmbeddingModel,
11-
)
12-
from ragstack_colbert.base_retriever import BaseRetriever as ColbertBaseRetriever
13-
from ragstack_colbert.base_vector_store import BaseVectorStore as ColbertBaseVectorStore
14-
from typing_extensions import override
9+
from typing_extensions import Self, override
1510

1611
from ragstack_langchain.colbert.embedding import TokensEmbeddings
1712

18-
CVS = TypeVar("CVS", bound="ColbertVectorStore")
13+
if TYPE_CHECKING:
14+
from langchain_core.embeddings import Embeddings
15+
from ragstack_colbert.base_database import BaseDatabase as ColbertBaseDatabase
16+
from ragstack_colbert.base_embedding_model import (
17+
BaseEmbeddingModel as ColbertBaseEmbeddingModel,
18+
)
19+
from ragstack_colbert.base_retriever import BaseRetriever as ColbertBaseRetriever
20+
from ragstack_colbert.base_vector_store import (
21+
BaseVectorStore as ColbertBaseVectorStore,
22+
)
1923

2024

2125
class ColbertVectorStore(VectorStore):
@@ -35,7 +39,7 @@ def _initialize(
3539
self,
3640
database: ColbertBaseDatabase,
3741
embedding_model: ColbertBaseEmbeddingModel,
38-
):
42+
) -> None:
3943
self._vector_store = RagstackColbertVectorStore(
4044
database=database, embedding_model=embedding_model
4145
)
@@ -45,7 +49,7 @@ def _initialize(
4549
def add_texts(
4650
self,
4751
texts: Iterable[str],
48-
metadatas: Optional[List[dict]] = None,
52+
metadatas: Optional[List[Dict[str, Any]]] = None,
4953
doc_id: Optional[str] = None,
5054
**kwargs: Any,
5155
) -> List[str]:
@@ -60,17 +64,18 @@ def add_texts(
6064
Returns:
6165
List of ids from adding the texts into the vectorstore.
6266
"""
63-
return self._vector_store.add_texts(
67+
results = self._vector_store.add_texts(
6468
texts=list(texts), metadatas=metadatas, doc_id=doc_id
6569
)
70+
return [results[0][0]] if results else []
6671

6772
@override
6873
async def aadd_texts(
6974
self,
7075
texts: Iterable[str],
71-
metadatas: Optional[List[dict]] = None,
76+
metadatas: Optional[List[Dict[str, Any]]] = None,
7277
doc_id: Optional[str] = None,
73-
concurrent_inserts: Optional[int] = 100,
78+
concurrent_inserts: int = 100,
7479
**kwargs: Any,
7580
) -> List[str]:
7681
"""Run more texts through the embeddings and add to the vectorstore.
@@ -86,51 +91,30 @@ async def aadd_texts(
8691
Returns:
8792
List of ids from adding the texts into the vectorstore.
8893
"""
89-
return await self._vector_store.aadd_texts(
94+
results = await self._vector_store.aadd_texts(
9095
texts=list(texts),
9196
metadatas=metadatas,
9297
doc_id=doc_id,
9398
concurrent_inserts=concurrent_inserts,
9499
)
100+
return [results[0][0]] if results else []
95101

96102
@override
97103
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
98-
"""Delete by vector ID or other criteria.
99-
100-
Args:
101-
ids: List of ids to delete.
102-
**kwargs: Other keyword arguments that subclasses might use.
103-
104-
Returns:
105-
Optional[bool]: True if deletion is successful,
106-
False otherwise, None if not implemented.
107-
"""
108-
return None if ids is None else self._vector_store.delete(ids=ids)
104+
return None if ids is None else self._vector_store.delete_chunks(doc_ids=ids)
109105

110106
@override
111107
async def adelete(
112108
self,
113109
ids: Optional[List[str]] = None,
114-
concurrent_deletes: Optional[int] = 100,
110+
concurrent_deletes: int = 100,
115111
**kwargs: Any,
116112
) -> Optional[bool]:
117-
"""Delete by vector ID or other criteria.
118-
119-
Args:
120-
ids: List of ids to delete.
121-
concurrent_deletes: How many concurrent deletes to make to the database.
122-
Defaults to 100.
123-
**kwargs: Other keyword arguments that subclasses might use.
124-
125-
Returns:
126-
Optional[bool]: True if deletion is successful,
127-
False otherwise, None if not implemented.
128-
"""
129113
return (
130114
None
131115
if ids is None
132-
else await self._vector_store.adelete(
133-
ids=ids, concurrent_deletes=concurrent_deletes
116+
else await self._vector_store.adelete_chunks(
117+
doc_ids=ids, concurrent_deletes=concurrent_deletes
134118
)
135119
)
136120

@@ -215,7 +199,7 @@ def from_documents(
215199
*,
216200
database: Optional[ColbertBaseDatabase] = None,
217201
**kwargs: Any,
218-
) -> CVS:
202+
) -> Self:
219203
"""Return VectorStore initialized from documents and embeddings."""
220204
texts = [d.page_content for d in documents]
221205
metadatas = [d.metadata for d in documents]
@@ -230,14 +214,14 @@ def from_documents(
230214
@classmethod
231215
@override
232216
async def afrom_documents(
233-
cls: Type[CVS],
217+
cls,
234218
documents: List[Document],
235219
embedding: Embeddings,
236220
*,
237221
database: Optional[ColbertBaseDatabase] = None,
238-
concurrent_inserts: Optional[int] = 100,
222+
concurrent_inserts: int = 100,
239223
**kwargs: Any,
240-
) -> CVS:
224+
) -> Self:
241225
"""Return VectorStore initialized from documents and embeddings."""
242226
texts = [d.page_content for d in documents]
243227
metadatas = [d.metadata for d in documents]
@@ -253,14 +237,14 @@ async def afrom_documents(
253237
@classmethod
254238
@override
255239
def from_texts(
256-
cls: Type[CVS],
240+
cls,
257241
texts: List[str],
258242
embedding: Embeddings,
259-
metadatas: Optional[List[dict]] = None,
243+
metadatas: Optional[List[Dict[str, Any]]] = None,
260244
*,
261245
database: Optional[ColbertBaseDatabase] = None,
262246
**kwargs: Any,
263-
) -> CVS:
247+
) -> Self:
264248
if not isinstance(embedding, TokensEmbeddings):
265249
raise TypeError("ColbertVectorStore requires a TokensEmbeddings embedding.")
266250
if database is None:
@@ -276,15 +260,15 @@ def from_texts(
276260
@classmethod
277261
@override
278262
async def afrom_texts(
279-
cls: Type[CVS],
263+
cls,
280264
texts: List[str],
281265
embedding: Embeddings,
282-
metadatas: Optional[List[dict]] = None,
266+
metadatas: Optional[List[Dict[str, Any]]] = None,
283267
*,
284268
database: Optional[ColbertBaseDatabase] = None,
285-
concurrent_inserts: Optional[int] = 100,
269+
concurrent_inserts: int = 100,
286270
**kwargs: Any,
287-
) -> CVS:
271+
) -> Self:
288272
if not isinstance(embedding, TokensEmbeddings):
289273
raise TypeError("ColbertVectorStore requires a TokensEmbeddings embedding.")
290274
if database is None:

libs/langchain/ragstack_langchain/colbert/embedding.py

+6-5
Original file line number | Diff line number | Diff line change
@@ -3,13 +3,13 @@
33
from langchain_core.embeddings import Embeddings
44
from ragstack_colbert import DEFAULT_COLBERT_MODEL, ColbertEmbeddingModel
55
from ragstack_colbert.base_embedding_model import BaseEmbeddingModel
6-
from typing_extensions import override
6+
from typing_extensions import Self, override
77

88

99
class TokensEmbeddings(Embeddings):
1010
"""Adapter for token-based embedding models and the LangChain Embeddings."""
1111

12-
def __init__(self, embedding: BaseEmbeddingModel = None):
12+
def __init__(self, embedding: Optional[BaseEmbeddingModel] = None):
1313
self.embedding = embedding or ColbertEmbeddingModel()
1414

1515
@override
@@ -32,8 +32,9 @@ def get_embedding_model(self) -> BaseEmbeddingModel:
3232
"""Get the embedding model."""
3333
return self.embedding
3434

35-
@staticmethod
35+
@classmethod
3636
def colbert(
37+
cls,
3738
checkpoint: str = DEFAULT_COLBERT_MODEL,
3839
doc_maxlen: int = 256,
3940
nbits: int = 2,
@@ -42,9 +43,9 @@ def colbert(
4243
query_maxlen: Optional[int] = None,
4344
verbose: int = 3,
4445
chunk_batch_size: int = 640,
45-
):
46+
) -> Self:
4647
"""Create a new ColBERT embedding model."""
47-
return TokensEmbeddings(
48+
return cls(
4849
ColbertEmbeddingModel(
4950
checkpoint,
5051
doc_maxlen,

libs/langchain/tests/integration_tests/conftest.py

+3-2
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
from typing import Iterator
22

33
import pytest
4+
from _pytest.fixtures import FixtureRequest
45
from cassandra.cluster import Session
56
from dotenv import load_dotenv
67
from ragstack_tests_utils import AstraDBTestStore, LocalCassandraTestStore
@@ -21,13 +22,13 @@ def astra_db() -> AstraDBTestStore:
2122
return AstraDBTestStore()
2223

2324

24-
def get_session(request) -> Session:
25+
def get_session(request: FixtureRequest) -> Session:
2526
test_store = request.getfixturevalue(request.param)
2627
session = test_store.create_cassandra_session()
2728
session.default_timeout = 180
2829
return session
2930

3031

3132
@pytest.fixture()
32-
def session(request) -> Session:
33+
def session(request: FixtureRequest) -> Session:
3334
return get_session(request)

0 commit comments

Comments
 (0)