Skip to content

Commit e818a5a

Browse files
Change flows_from_changes to expect partitions as arguments to enable caching (#384)
1 parent 61ca3f3 commit e818a5a

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

gerrychain/partition/partition.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def _from_parent(self, parent, flips):
7878
self.graph = parent.graph
7979
self.updaters = parent.updaters
8080

81-
self.flows = flows_from_changes(parent.assignment, flips)
81+
self.flows = flows_from_changes(parent, self) # careful
8282

8383
self.assignment = parent.assignment.copy()
8484
self.assignment.update_flows(self.flows)

gerrychain/updaters/flows.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@ def create_flow():
1616
return {'in': set(), 'out': set()}
1717

1818

19-
def flows_from_changes(old_assignment, flips):
19+
@functools.lru_cache(maxsize=2)
20+
def flows_from_changes(old_partition, new_partition):
2021
flows = collections.defaultdict(create_flow)
21-
for node, target in flips.items():
22-
source = old_assignment.mapping[node]
22+
for node, target in new_partition.flips.items():
23+
source = old_partition.assignment.mapping[node]
2324
if source != target:
2425
flows[target]['in'].add(node)
2526
flows[source]['out'].add(node)

gerrychain/updaters/tally.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,13 @@ def _update_tally(self, partition):
102102
:param partition: :class:`Partition` class.
103103
"""
104104
parent = partition.parent
105-
flips = partition.flips
106105

107106
old_tally = parent[self.alias]
108107
new_tally = dict(old_tally)
109108

110109
graph = partition.graph
111110

112-
for part, flow in flows_from_changes(parent.assignment, flips).items():
111+
for part, flow in flows_from_changes(parent, partition).items():
113112
out_flow = compute_out_flow(graph, self.fields, flow)
114113
in_flow = compute_in_flow(graph, self.fields, flow)
115114
new_tally[part] = old_tally[part] - out_flow + in_flow

0 commit comments

Comments
 (0)