Skip to content

Commit 7cbf632

Browse files
Cache node attribute fields
1 parent d4a19fc commit 7cbf632

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

gerrychain/updaters/tally.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ class DataTally:
99
"""An updater for tallying numerical data that is not necessarily stored as
1010
node attributes
1111
"""
12+
__slots__ = [
13+
"data",
14+
"alias",
15+
"_call"
16+
]
1217

1318
def __init__(self, data, alias):
1419
"""
@@ -54,6 +59,11 @@ def __call__(self, partition, previous=None):
5459
class Tally:
5560
"""An updater for keeping a tally of one or more node attributes.
5661
"""
62+
__slots__ = [
63+
"fields",
64+
"alias",
65+
"dtype"
66+
]
5767

5868
def __init__(self, fields, alias=None, dtype=int):
5969
"""
@@ -117,12 +127,12 @@ def _update_tally(self, partition):
117127
return new_tally
118128

119129
def _get_tally_from_node(self, partition, node):
120-
return sum(partition.graph.nodes[node][field] for field in self.fields)
130+
return sum(partition.graph.cached_node_data_lookup[node][field] for field in self.fields)
121131

122132

123133
def compute_out_flow(graph, fields, flow):
124-
return sum(graph.nodes[node][field] for node in flow["out"] for field in fields)
134+
return sum(graph.cached_node_data_lookup[node][field] for node in flow["out"] for field in fields)
125135

126136

127137
def compute_in_flow(graph, fields, flow):
128-
return sum(graph.nodes[node][field] for node in flow["in"] for field in fields)
138+
return sum(graph.cached_node_data_lookup[node][field] for node in flow["in"] for field in fields)

tests/test_tree.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from gerrychain import MarkovChain
77
from gerrychain.constraints import contiguous, within_percent_of_ideal_population
8+
from gerrychain.graph import Graph
89
from gerrychain.partition import Partition
910
from gerrychain.proposals import recom
1011
from gerrychain.tree import (
@@ -22,7 +23,7 @@
2223
def graph_with_pop(three_by_three_grid):
2324
for node in three_by_three_grid:
2425
three_by_three_grid.nodes[node]["pop"] = 1
25-
return three_by_three_grid
26+
return Graph.from_networkx(three_by_three_grid)
2627

2728

2829
@pytest.fixture
@@ -41,7 +42,7 @@ def twelve_by_twelve_with_pop():
4142
grid = networkx.relabel_nodes(xy_grid, nodes)
4243
for node in grid:
4344
grid.nodes[node]["pop"] = 1
44-
return grid
45+
return Graph.from_networkx(grid)
4546

4647

4748
def test_bipartition_tree_returns_a_subset_of_nodes(graph_with_pop):

0 commit comments

Comments
 (0)