Skip to content

Commit

Permalink
fix: Use stable narwhals imports
Browse files Browse the repository at this point in the history
  • Loading branch information
dangotbanned committed Nov 22, 2024
1 parent de03046 commit e7974d9
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 20 deletions.
8 changes: 5 additions & 3 deletions altair/datasets/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar, get_args

import narwhals.stable.v1 as nw
from narwhals.dependencies import get_pyarrow
from narwhals.typing import IntoDataFrameT, IntoFrameT
from narwhals.stable.v1 import dependencies as nw_dep
from narwhals.stable.v1.typing import IntoDataFrameT, IntoFrameT

from altair.datasets._typing import VERSION_LATEST

Expand Down Expand Up @@ -151,7 +151,9 @@ def clear(self) -> None:
.get_column("sha_suffix")
)
names = set[str](
ser.to_list() if nw.get_native_namespace(ser) is get_pyarrow() else ser
ser.to_list()
if nw.get_native_namespace(ser) is nw_dep.get_pyarrow()
else ser
)
for fp in self:
if fp.name in names:
Expand Down
2 changes: 1 addition & 1 deletion altair/datasets/_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TYPE_CHECKING, Generic, final, overload

from narwhals.typing import IntoDataFrameT, IntoFrameT
from narwhals.stable.v1.typing import IntoDataFrameT, IntoFrameT

from altair.datasets._readers import _Reader, backend

Expand Down
2 changes: 1 addition & 1 deletion altair/datasets/_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
)

import narwhals.stable.v1 as nw
from narwhals.typing import IntoDataFrameT, IntoExpr, IntoFrameT
from narwhals.stable.v1.typing import IntoDataFrameT, IntoExpr, IntoFrameT

from altair.datasets._cache import DatasetCache
from altair.datasets._typing import EXTENSION_SUFFIXES, is_ext_read
Expand Down
25 changes: 10 additions & 15 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,8 @@
from urllib.error import URLError

import pytest
from narwhals.dependencies import (
is_into_dataframe,
is_pandas_dataframe,
is_polars_dataframe,
is_pyarrow_table,
)
from narwhals.stable import v1 as nw
from narwhals.stable.v1 import dependencies as nw_dep

from altair.datasets import Loader, url
from altair.datasets._readers import _METADATA, AltairDatasetsError
Expand Down Expand Up @@ -227,11 +222,11 @@ def test_load_call(monkeypatch: pytest.MonkeyPatch) -> None:
default_2 = load("cars")
df_polars = load("cars", backend="polars")

assert is_polars_dataframe(default)
assert is_pyarrow_table(df_pyarrow)
assert is_pandas_dataframe(df_pandas)
assert is_polars_dataframe(default_2)
assert is_polars_dataframe(df_polars)
assert nw_dep.is_polars_dataframe(default)
assert nw_dep.is_pyarrow_table(df_pyarrow)
assert nw_dep.is_pandas_dataframe(df_pandas)
assert nw_dep.is_polars_dataframe(default_2)
assert nw_dep.is_polars_dataframe(df_polars)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -320,7 +315,7 @@ def test_loader_call(backend: _Backend, monkeypatch: pytest.MonkeyPatch) -> None

data = Loader.from_backend(backend)
frame = data("stocks", ".csv")
assert is_into_dataframe(frame)
assert nw_dep.is_into_dataframe(frame)
nw_frame = nw.from_native(frame)
assert set(nw_frame.columns) == {"symbol", "date", "price"}

Expand Down Expand Up @@ -493,7 +488,7 @@ def test_reader_cache(
cached_paths = tuple(data.cache)
assert len(cached_paths) == 4

if is_polars_dataframe(lookup_groups):
if nw_dep.is_polars_dataframe(lookup_groups):
left, right = (
lookup_groups,
cast(pl.DataFrame, data("lookup_groups", tag="v2.5.3")),
Expand Down Expand Up @@ -664,7 +659,7 @@ def test_all_datasets(
) -> None:
"""Ensure all annotated datasets can be loaded with the most reliable backend."""
frame = polars_loader(name, suffix, tag=tag)
assert is_polars_dataframe(frame)
assert nw_dep.is_polars_dataframe(frame)


def _raise_exception(e: type[Exception], *args: Any, **kwds: Any):
Expand Down Expand Up @@ -698,7 +693,7 @@ def test_no_remote_connection(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -

# Now we can get a cache-hit
frame = data("birdstrikes")
assert is_polars_dataframe(frame)
assert nw_dep.is_polars_dataframe(frame)
assert len(tuple(tmp_path.iterdir())) == 4

with monkeypatch.context() as mp:
Expand Down

0 comments on commit e7974d9

Please sign in to comment.