Skip to content

Commit a8c84c5

Browse files
authored
Add ability to (in/ex)clude providers by ID within client (#1412)
* Add ability to in(ex)clude providers by ID within client - Using `--include-providers`, `--exclude-providers` and `--exclude_databases` at the CLI or the corresponding Python options * Add test case * Add docs notes
1 parent 5243791 commit a8c84c5

File tree

5 files changed

+188
-7
lines changed

5 files changed

+188
-7
lines changed

docs/getting_started/client.md

+56
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,62 @@ We can refine the search by manually specifying some URLs:
7272
client.get()
7373
```
7474

75+
or by including/excluding some providers by their registered IDs in the [Providers list](https://providers.optimade.org).
76+
77+
Query only a list of included providers (after a lookup of the providers list):
78+
79+
=== "Command line"
80+
```shell
81+
# Only query databases served by the example providers
82+
optimade-get --include-providers exmpl,optimade
83+
```
84+
85+
=== "Python"
86+
```python
87+
# Only query databases served by the example providers
88+
from optimade.client import OptimadeClient
89+
client = OptimadeClient(
90+
include_providers={"exmpl", "optimade"},
91+
)
92+
client.get()
93+
```
94+
95+
Exclude certain providers:
96+
97+
=== "Command line"
98+
```shell
99+
# Exclude example providers from global list
100+
optimade-get --exclude-providers exmpl,optimade
101+
```
102+
103+
=== "Python"
104+
```python
105+
# Exclude example providers from global list
106+
from optimade.client import OptimadeClient
107+
client = OptimadeClient(
108+
exclude_providers={"exmpl", "optimade"},
109+
)
110+
client.get()
111+
```
112+
113+
Exclude particular databases by URL:
114+
115+
=== "Command line"
116+
```shell
117+
# Exclude specific example databases
118+
optimade-get --exclude-databases https://example.org/optimade,https://optimade.org/example
119+
```
120+
121+
=== "Python"
122+
```python
123+
# Exclude specific example databases
124+
from optimade.client import OptimadeClient
125+
client = OptimadeClient(
126+
exclude_databases={"https://example.org/optimade", "https://optimade.org/example"}
127+
)
128+
client.get()
129+
```
130+
75131
### Filtering
76132

77133
By default, an empty filter will be used (which will return all entries in a database).

optimade/client/cli.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,26 @@
5353
is_flag=True,
5454
help="Pretty print the JSON results.",
5555
)
56-
@click.argument("base-url", default=None, nargs=-1)
56+
@click.option(
57+
"--include-providers",
58+
default=None,
59+
help="A string of comma-separated provider IDs to query.",
60+
)
61+
@click.option(
62+
"--exclude-providers",
63+
default=None,
64+
help="A string of comma-separated provider IDs to exclude from queries.",
65+
)
66+
@click.option(
67+
"--exclude-databases",
68+
default=None,
69+
help="A string of comma-separated database URLs to exclude from queries.",
70+
)
71+
@click.argument(
72+
"base-url",
73+
default=None,
74+
nargs=-1,
75+
)
5776
def get(
5877
use_async,
5978
filter,
@@ -65,6 +84,9 @@ def get(
6584
sort,
6685
endpoint,
6786
pretty_print,
87+
include_providers,
88+
exclude_providers,
89+
exclude_databases,
6890
):
6991
return _get(
7092
use_async,
@@ -77,6 +99,9 @@ def get(
7799
sort,
78100
endpoint,
79101
pretty_print,
102+
include_providers,
103+
exclude_providers,
104+
exclude_databases,
80105
)
81106

82107

@@ -91,6 +116,9 @@ def _get(
91116
sort,
92117
endpoint,
93118
pretty_print,
119+
include_providers,
120+
exclude_providers,
121+
exclude_databases,
94122
):
95123

96124
if output_file:
@@ -106,6 +134,15 @@ def _get(
106134
base_urls=base_url,
107135
use_async=use_async,
108136
max_results_per_provider=max_results_per_provider,
137+
include_providers=set(_.strip() for _ in include_providers.split(","))
138+
if include_providers
139+
else None,
140+
exclude_providers=set(_.strip() for _ in exclude_providers.split(","))
141+
if exclude_providers
142+
else None,
143+
exclude_databases=set(_.strip() for _ in exclude_databases.split(","))
144+
if exclude_databases
145+
else None,
109146
)
110147
if response_fields:
111148
response_fields = response_fields.split(",")

optimade/client/client.py

+30-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import json
1111
import time
1212
from collections import defaultdict
13-
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
13+
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
1414
from urllib.parse import urlparse
1515

1616
# External deps that are only used in the client code
@@ -85,6 +85,15 @@ class OptimadeClient:
8585
use_async: bool
8686
"""Whether or not to make all requests asynchronously using asyncio."""
8787

88+
_excluded_providers: Optional[Set[str]] = None
89+
"""A set of providers IDs excluded from future queries."""
90+
91+
_included_providers: Optional[Set[str]] = None
92+
"""A set of providers IDs included from future queries."""
93+
94+
_excluded_databases: Optional[Set[str]] = None
95+
"""A set of child database URLs excluded from future queries."""
96+
8897
__current_endpoint: Optional[str] = None
8998
"""Used internally when querying via `client.structures.get()` to set the
9099
chosen endpoint. Should be reset to `None` outside of all `get()` calls."""
@@ -97,6 +106,9 @@ def __init__(
97106
http_timeout: int = 10,
98107
max_attempts: int = 5,
99108
use_async: bool = True,
109+
exclude_providers: Optional[List[str]] = None,
110+
include_providers: Optional[List[str]] = None,
111+
exclude_databases: Optional[List[str]] = None,
100112
):
101113
"""Create the OPTIMADE client object.
102114
@@ -108,16 +120,32 @@ def __init__(
108120
http_timeout: The HTTP timeout to use per request.
109121
max_attempts: The maximum number of times to repeat a failing query.
110122
use_async: Whether or not to make all requests asynchronously.
123+
exclude_providers: A set or collection of provider IDs to exclude from queries.
124+
include_providers: A set or collection of provider IDs to include in queries.
125+
exclude_databases: A set or collection of child database URLs to exclude from queries.
111126
112127
"""
113128

114129
self.max_results_per_provider = max_results_per_provider
115130
if self.max_results_per_provider in (-1, 0):
116131
self.max_results_per_provider = None
117132

133+
self._excluded_providers = set(exclude_providers) if exclude_providers else None
134+
self._included_providers = set(include_providers) if include_providers else None
135+
self._excluded_databases = set(exclude_databases) if exclude_databases else None
136+
118137
if not base_urls:
119-
self.base_urls = get_all_databases()
138+
self.base_urls = get_all_databases(
139+
exclude_providers=self._excluded_providers,
140+
include_providers=self._included_providers,
141+
exclude_databases=self._excluded_databases,
142+
)
120143
else:
144+
if exclude_providers or include_providers or exclude_databases:
145+
raise RuntimeError(
146+
"Cannot provide both a list of base URLs and included/excluded databases."
147+
)
148+
121149
self.base_urls = base_urls
122150

123151
if isinstance(self.base_urls, str):

optimade/utils.py

+28-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
import json
7-
from typing import Iterable, List
7+
from typing import Container, Iterable, List, Optional
88

99
from pydantic import ValidationError
1010

@@ -101,7 +101,7 @@ def get_providers(add_mongo_id: bool = False) -> list:
101101

102102

103103
def get_child_database_links(
104-
provider: LinksResource, obey_aggregate=True
104+
provider: LinksResource, obey_aggregate: bool = True
105105
) -> List[LinksResource]:
106106
"""For a provider, return a list of available child databases.
107107
@@ -155,13 +155,37 @@ def get_child_database_links(
155155
) from exc
156156

157157

158-
def get_all_databases() -> Iterable[str]:
159-
"""Iterate through all databases reported by registered OPTIMADE providers."""
158+
def get_all_databases(
159+
include_providers: Optional[Container[str]] = None,
160+
exclude_providers: Optional[Container[str]] = None,
161+
exclude_databases: Optional[Container[str]] = None,
162+
) -> Iterable[str]:
163+
"""Iterate through all databases reported by registered OPTIMADE providers.
164+
165+
Parameters:
166+
include_providers: A set/container of provider IDs to include child databases for.
167+
exclude_providers: A set/container of provider IDs to exclude child databases for.
168+
exclude_databases: A set/container of specific database URLs to exclude.
169+
170+
Returns:
171+
A generator of child database links that obey the given parameters.
172+
173+
"""
160174
for provider in get_providers():
175+
if exclude_providers and provider["id"] in exclude_providers:
176+
continue
177+
if include_providers and provider["id"] not in include_providers:
178+
continue
179+
161180
try:
162181
links = get_child_database_links(provider)
163182
for link in links:
164183
if link.attributes.base_url:
184+
if (
185+
exclude_databases
186+
and link.attributes.base_url in exclude_databases
187+
):
188+
continue
165189
yield str(link.attributes.base_url)
166190
except RuntimeError:
167191
pass

tests/server/test_client.py

+36
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,39 @@ def test_multiple_base_urls(httpx_mocked_response, use_async):
116116
)
117117

118118

119+
@pytest.mark.parametrize("use_async", [False])
120+
def test_include_exclude_providers(use_async):
121+
with pytest.raises(
122+
SystemExit,
123+
match="Unable to access any OPTIMADE base URLs. If you believe this is an error, try manually specifying some base URLs.",
124+
):
125+
OptimadeClient(
126+
include_providers={"exmpl"},
127+
exclude_providers={"exmpl"},
128+
use_async=use_async,
129+
)
130+
131+
with pytest.raises(
132+
RuntimeError,
133+
match="Cannot provide both a list of base URLs and included/excluded databases.",
134+
):
135+
OptimadeClient(
136+
base_urls=TEST_URLS,
137+
include_providers={"exmpl"},
138+
use_async=use_async,
139+
)
140+
141+
with pytest.raises(
142+
SystemExit,
143+
match="Unable to access any OPTIMADE base URLs. If you believe this is an error, try manually specifying some base URLs.",
144+
):
145+
OptimadeClient(
146+
include_providers={"exmpl"},
147+
exclude_databases={"https://example.org/optimade"},
148+
use_async=use_async,
149+
)
150+
151+
119152
@pytest.mark.parametrize("use_async", [False])
120153
def test_client_sort(httpx_mocked_response, use_async):
121154
cli = OptimadeClient(base_urls=[TEST_URL], use_async=use_async)
@@ -138,6 +171,9 @@ def test_command_line_client(httpx_mocked_response, use_async, capsys):
138171
sort=None,
139172
endpoint="structures",
140173
pretty_print=False,
174+
include_providers=None,
175+
exclude_providers=None,
176+
exclude_databases=None,
141177
)
142178

143179
# Test multi-provider query

0 commit comments

Comments
 (0)