Skip to content

Commit 8acd44c

Browse files
Add option to disable cut_edges updater with use_cut_edges flag (#375)
1 parent b686e25 commit 8acd44c

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

gerrychain/partition/partition.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ class Partition:
1616
:ivar dict parts: Maps district IDs to the set of nodes in that district.
1717
:ivar dict subgraphs: Maps district IDs to the induced subgraph of that district.
1818
"""
19-
default_updaters = {"cut_edges": cut_edges}
2019
__slots__ = (
2120
'graph',
2221
'subgraphs',
@@ -30,26 +29,28 @@ class Partition:
3029
)
3130

3231
def __init__(
33-
self, graph=None, assignment=None, updaters=None, parent=None, flips=None
32+
self, graph=None, assignment=None, updaters=None, parent=None, flips=None,
33+
use_cut_edges=True
3434
):
3535
"""
3636
:param graph: Underlying graph.
3737
:param assignment: Dictionary assigning nodes to districts.
3838
:param updaters: Dictionary of functions to track data about the partition.
3939
The keys are stored as attributes on the partition class,
4040
which the functions compute.
41+
:param use_cut_edges: If `False`, do not include `cut_edges` updater by default
42+
and do not calculate edge flows.
4143
"""
4244
if parent is None:
43-
self._first_time(graph, assignment, updaters)
45+
self._first_time(graph, assignment, updaters, use_cut_edges)
4446
else:
4547
self._from_parent(parent, flips)
4648

4749
self._cache = dict()
4850
self.subgraphs = SubgraphView(self.graph, self.parts)
4951

50-
def _first_time(self, graph, assignment, updaters):
52+
def _first_time(self, graph, assignment, updaters, use_cut_edges):
5153
self.graph = graph
52-
5354
self.assignment = get_assignment(assignment, graph)
5455

5556
if set(self.assignment) != set(graph):
@@ -58,7 +59,11 @@ def _first_time(self, graph, assignment, updaters):
5859
if updaters is None:
5960
updaters = dict()
6061

61-
self.updaters = self.default_updaters.copy()
62+
if use_cut_edges:
63+
self.updaters = {"cut_edges": cut_edges}
64+
else:
65+
self.updaters = {}
66+
6267
self.updaters.update(updaters)
6368

6469
self.parent = None
@@ -77,7 +82,9 @@ def _from_parent(self, parent, flips):
7782
self.updaters = parent.updaters
7883

7984
self.flows = flows_from_changes(parent.assignment, flips)
80-
self.edge_flows = compute_edge_flows(self)
85+
86+
if "cut_edges" in self.updaters:
87+
self.edge_flows = compute_edge_flows(self)
8188

8289
def __repr__(self):
8390
number_of_parts = len(self)

tests/partition/test_partition.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def example_geographic_partition():
5656

5757
def test_geographic_partition_can_be_instantiated(example_geographic_partition):
5858
partition = example_geographic_partition
59-
assert partition.updaters == GeographicPartition.default_updaters
59+
assert isinstance(partition, GeographicPartition)
6060

6161

6262
def test_Partition_parts_is_a_dictionary_of_parts_to_nodes(example_partition):
@@ -144,11 +144,9 @@ def test_repr(example_partition):
144144

145145
def test_partition_has_default_updaters(example_partition):
146146
partition = example_partition
147-
default_updaters = partition.default_updaters
148147
should_have_updaters = {"cut_edges": cut_edges}
149148

150149
for updater in should_have_updaters:
151-
assert default_updaters.get(updater, None) is not None
152150
assert should_have_updaters[updater](partition) == partition[updater]
153151

154152

0 commit comments

Comments
 (0)