Skip to content

Commit d766bcd

Browse files
authored
Merge pull request #100 from washingtonpost/ELEX-3469-save-aggregate-predictions-to-s3
ELEX-3469 save aggregate predictions to s3
2 parents 63e8aa1 + f5af27d commit d766bcd

File tree

7 files changed

+119
-32
lines changed

7 files changed

+119
-32
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ Parameters for the CLI tool:
7272
| called_contests | dict | a dictionary of called contests. specific to Bootstrap model for now. e.g. `--called_contests='{"VA": -1}'` |
7373
| save_output | list | `results`, `data`, `config` |
7474
| unexpected_units | int | number of unexpected units to simulate; only used for testing and does not work with historical run |
75+
| national_summary | flag | When not running a historical election, specify this flag to output national summary (aggregate model) estimates. |
7576

7677
Note: When running the model with multiple fixed effects, make sure they are not linearly dependent. For example, `county_fips` and `county_classification` are linearly dependent when run together: every county belongs to exactly one county class, so the fixed-effect columns of all the counties in a county class sum to the fixed-effect column of that county class.
7778

src/elexmodel/cli.py

+11
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ def type_cast_value(self, ctx, value):
7979
help="options: results, data, config",
8080
)
8181
@click.option("--handle_unreporting", "handle_unreporting", default="drop", type=click.Choice(["drop", "zero"]))
82+
@click.option(
83+
"--national_summary",
84+
"national_summary",
85+
is_flag=True,
86+
help="When not running a historical election, output results aggregated to the national level.",
87+
)
8288
def cli(
8389
election_id, estimands, office_id, prediction_intervals, percent_reporting_threshold, geographic_unit_type, **kwargs
8490
):
@@ -159,5 +165,10 @@ def cli(
159165
geographic_unit_type,
160166
**kwargs
161167
)
168+
169+
if kwargs.get("national_summary", False):
170+
# TODO: accept the get_national_summary_votes_estimates() arguments via the CLI instead of hard-coding them
171+
model_client.get_national_summary_votes_estimates(None, 0, 0.99)
172+
162173
for aggregate_level, estimates in result.items():
163174
print(aggregate_level, "\n", estimates, "\n")

src/elexmodel/client.py

+50-28
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ def __init__(self):
4141
self.all_conformalization_data_unit_dict = defaultdict(dict)
4242
self.all_conformalization_data_agg_dict = defaultdict(dict)
4343
self.model = None
44+
self.results_handler = None
45+
self.election_id = None
46+
self.office = None
47+
self.geographic_unit_type = None
48+
self.save_results = None
4449

4550
def _check_input_parameters(
4651
self,
@@ -170,7 +175,20 @@ def get_aggregate_list(self, office, aggregate):
170175
return sorted(list(set(raw_aggregate_list)), key=lambda x: AGGREGATE_ORDER.index(x))
171176

172177
def get_national_summary_votes_estimates(self, nat_sum_data_dict=None, base_to_add=0, alpha=0.99):
173-
return self.model.get_national_summary_estimates(nat_sum_data_dict, base_to_add, alpha)
178+
if self.model is None:
179+
raise ModelClientException(
180+
"Must call the get_estimands() method before get_national_summary_votes_estimates()."
181+
)
182+
183+
nat_sum_estimates = self.model.get_national_summary_estimates(nat_sum_data_dict, base_to_add, alpha)
184+
self.results_handler.add_national_summary_estimates(nat_sum_estimates)
185+
186+
if APP_ENV != "local" and self.save_results:
187+
self.results_handler.write_data(
188+
self.election_id, self.office, self.geographic_unit_type, keys=["nat_sum_data"]
189+
)
190+
191+
return nat_sum_estimates
174192

175193
def get_estimates(
176194
self,
@@ -202,7 +220,7 @@ def get_estimates(
202220
pi_method = kwargs.get("pi_method", "nonparametric")
203221
called_contests = kwargs.get("called_contests", None)
204222
save_output = kwargs.get("save_output", ["results"])
205-
save_results = "results" in save_output
223+
self.save_results = "results" in save_output
206224
save_data = "data" in save_output
207225
save_config = "config" in save_output
208226
# saving conformalization data only makes sense if a ConformalElectionModel is used
@@ -241,15 +259,18 @@ def get_estimates(
241259
model_parameters,
242260
handle_unreporting,
243261
)
262+
self.election_id = election_id
263+
self.office = office
264+
self.geographic_unit_type = geographic_unit_type
244265

245-
states_with_election = config_handler.get_states(office)
246-
estimand_baselines = config_handler.get_estimand_baselines(office, estimands)
266+
states_with_election = config_handler.get_states(self.office)
267+
estimand_baselines = config_handler.get_estimand_baselines(self.office, estimands)
247268

248-
LOG.info("Getting preprocessed data: %s", election_id)
269+
LOG.info("Getting preprocessed data: %s", self.election_id)
249270
preprocessed_data_handler = PreprocessedDataHandler(
250-
election_id,
251-
office,
252-
geographic_unit_type,
271+
self.election_id,
272+
self.office,
273+
self.geographic_unit_type,
253274
estimands,
254275
estimand_baselines,
255276
data=preprocessed_data,
@@ -267,7 +288,7 @@ def get_estimates(
267288
preprocessed_data,
268289
current_data,
269290
estimands,
270-
geographic_unit_type,
291+
self.geographic_unit_type,
271292
handle_unreporting=handle_unreporting,
272293
)
273294

@@ -307,8 +328,8 @@ def get_estimates(
307328
if minimum_reporting_units > minimum_reporting_units_max:
308329
minimum_reporting_units_max = minimum_reporting_units
309330

310-
if APP_ENV != "local" and save_results:
311-
data.write_data(election_id, office)
331+
if APP_ENV != "local" and self.save_results:
332+
data.write_data(self.election_id, self.office)
312333

313334
n_reporting_expected_units = reporting_units.shape[0]
314335
n_unexpected_units = unexpected_units.shape[0]
@@ -330,44 +351,44 @@ def get_estimates(
330351
if len(duplicate_units) > 0:
331352
raise ModelClientException(f"At least one unit appears twice: {duplicate_units}")
332353

333-
results_handler = ModelResultsHandler(
354+
self.results_handler = ModelResultsHandler(
334355
aggregates, prediction_intervals, reporting_units, nonreporting_units, unexpected_units
335356
)
336357

337358
for estimand in estimands:
338359
unit_predictions, unit_turnout_predictions = self.model.get_unit_predictions(
339360
reporting_units, nonreporting_units, estimand, unexpected_units=unexpected_units
340361
)
341-
results_handler.add_unit_predictions(estimand, unit_predictions, unit_turnout_predictions)
362+
self.results_handler.add_unit_predictions(estimand, unit_predictions, unit_turnout_predictions)
342363
# gets prediction intervals for each alpha
343364
alpha_to_unit_prediction_intervals = {}
344365
for alpha in prediction_intervals:
345366
alpha_to_unit_prediction_intervals[alpha] = self.model.get_unit_prediction_intervals(
346-
results_handler.reporting_units, results_handler.nonreporting_units, alpha, estimand
367+
self.results_handler.reporting_units, self.results_handler.nonreporting_units, alpha, estimand
347368
)
348369
if isinstance(self.model, ConformalElectionModel):
349370
self.all_conformalization_data_unit_dict[alpha][
350371
estimand
351372
] = self.model.get_all_conformalization_data_unit()
352373

353-
results_handler.add_unit_intervals(estimand, alpha_to_unit_prediction_intervals)
374+
self.results_handler.add_unit_intervals(estimand, alpha_to_unit_prediction_intervals)
354375

355-
for aggregate in results_handler.aggregates:
356-
aggregate_list = self.get_aggregate_list(office, aggregate)
376+
for aggregate in self.results_handler.aggregates:
377+
aggregate_list = self.get_aggregate_list(self.office, aggregate)
357378
estimates_df = self.model.get_aggregate_predictions(
358-
results_handler.reporting_units,
359-
results_handler.nonreporting_units,
360-
results_handler.unexpected_units,
379+
self.results_handler.reporting_units,
380+
self.results_handler.nonreporting_units,
381+
self.results_handler.unexpected_units,
361382
aggregate_list,
362383
estimand,
363384
called_contests=called_contests,
364385
)
365386
alpha_to_agg_prediction_intervals = {}
366387
for alpha in prediction_intervals:
367388
alpha_to_agg_prediction_intervals[alpha] = self.model.get_aggregate_prediction_intervals(
368-
results_handler.reporting_units,
369-
results_handler.nonreporting_units,
370-
results_handler.unexpected_units,
389+
self.results_handler.reporting_units,
390+
self.results_handler.nonreporting_units,
391+
self.results_handler.unexpected_units,
371392
aggregate_list,
372393
alpha,
373394
alpha_to_unit_prediction_intervals[alpha],
@@ -380,15 +401,16 @@ def get_estimates(
380401
] = self.model.get_all_conformalization_data_agg()
381402

382403
# get all of the prediction intervals here
383-
results_handler.add_agg_predictions(
404+
self.results_handler.add_agg_predictions(
384405
estimand, aggregate, estimates_df, alpha_to_agg_prediction_intervals
385406
)
386407

387-
results_handler.process_final_results()
388-
if APP_ENV != "local" and save_results:
389-
results_handler.write_data(election_id, office, geographic_unit_type)
408+
self.results_handler.process_final_results()
409+
410+
if APP_ENV != "local" and self.save_results:
411+
self.results_handler.write_data(self.election_id, self.office, self.geographic_unit_type)
390412

391-
return results_handler.final_results
413+
return self.results_handler.final_results
392414

393415

394416
class HistoricalModelClient(ModelClient):

src/elexmodel/handlers/data/ModelResults.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def __init__(
2525
self.aggregates = [agg for agg in aggregates if agg != "unit"]
2626
self.estimates = {agg: [] for agg in self.aggregates}
2727
self.unit_data = {}
28+
self.final_results = {}
2829

2930
self.reporting_units = reporting_units
3031
self.nonreporting_units = nonreporting_units
@@ -93,7 +94,6 @@ def process_final_results(self):
9394
"""
9495
Create final data frames of results
9596
"""
96-
self.final_results = {}
9797
for agg in self.aggregates:
9898
merge_on = ["postal_code", "reporting", agg]
9999
# joins together dfs of the same level of aggregation (different estimands)
@@ -106,7 +106,14 @@ def process_final_results(self):
106106
lambda x, y: pd.merge(x, y, how="inner", on=merge_on), self.unit_data.values()
107107
)
108108

109-
def write_data(self, election_id, office, geographic_unit_type):
109+
def add_national_summary_estimates(self, national_summary_dict):
110+
df = pd.DataFrame.from_dict(
111+
national_summary_dict, orient="index", columns=["agg_pred", "agg_lower", "agg_upper"]
112+
)
113+
df.index.name = "estimand"
114+
self.final_results["nat_sum_data"] = df.reset_index()
115+
116+
def write_data(self, election_id, office, geographic_unit_type, keys=None):
110117
"""
111118
Saves dataframe of estimates for all estimands to S3
112119
Different file by aggregate level
@@ -115,6 +122,8 @@ def write_data(self, election_id, office, geographic_unit_type):
115122
self.process_final_results()
116123
s3_client = s3.S3CsvUtil(TARGET_BUCKET)
117124
for key, value in self.final_results.items():
125+
if keys is not None and key not in keys:
126+
continue
118127
path = f"{S3_FILE_PATH}/{election_id}/predictions/{office}/{geographic_unit_type}/{key}/current.csv"
119128
# convert df to csv
120129
csv_data = convert_df_to_csv(value)

src/elexmodel/models/BaseElectionModel.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -172,5 +172,5 @@ def get_coefficients(self) -> dict:
172172
"""
173173
return self.features_to_coefficients
174174

175-
def get_national_summary_estimates(self, nat_sum_data_dict, called_states, base_to_add):
175+
def get_national_summary_estimates(self, nat_sum_data_dict, called_states, base_to_add, alpha):
176176
raise NotImplementedError()

src/elexmodel/models/ConformalElectionModel.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -206,5 +206,5 @@ def get_all_conformalization_data_agg(cls):
206206
"""
207207
raise NotImplementedError
208208

209-
def get_national_summary_estimates(self, nat_sum_data_dict, called_states, base_to_add):
209+
def get_national_summary_estimates(self, nat_sum_data_dict, called_states, base_to_add, alpha):
210210
raise NotImplementedError()

tests/test_client.py

+44
Original file line numberDiff line numberDiff line change
@@ -832,3 +832,47 @@ def test_estimandizer_input(model_client, va_governor_county_data, va_config):
832832
)
833833
except KeyError:
834834
pytest.raises("Error with client input for estimandizer")
835+
836+
837+
def test_get_national_summary_votes_estimates(model_client, va_governor_county_data, va_config):
838+
expected = {"margin": [1.0, 1.0, 1.0]}
839+
expected_df = pd.DataFrame.from_dict(expected, orient="index", columns=["agg_pred", "agg_lower", "agg_upper"])
840+
expected_df.index.name = "estimand"
841+
expected_df = expected_df.reset_index()
842+
843+
election_id = "2017-11-07_VA_G"
844+
office_id = "G"
845+
geographic_unit_type = "county"
846+
estimands = ["margin"]
847+
prediction_intervals = [0.9]
848+
percent_reporting_threshold = 100
849+
kwargs = {"pi_method": "bootstrap", "features": ["baseline_normalized_margin"], "national_summary": True}
850+
851+
data_handler = MockLiveDataHandler(
852+
election_id, office_id, geographic_unit_type, estimands, data=va_governor_county_data
853+
)
854+
855+
data_handler.shuffle()
856+
data = data_handler.get_percent_fully_reported(100)
857+
858+
preprocessed_data = va_governor_county_data.copy()
859+
preprocessed_data["last_election_results_turnout"] = preprocessed_data["baseline_turnout"].copy() + 1
860+
861+
model_client.get_estimates(
862+
data,
863+
election_id,
864+
office_id,
865+
estimands,
866+
prediction_intervals,
867+
percent_reporting_threshold,
868+
geographic_unit_type,
869+
raw_config=va_config,
870+
preprocessed_data=preprocessed_data,
871+
save_output=[],
872+
**kwargs,
873+
)
874+
875+
current = model_client.get_national_summary_votes_estimates(None, 0, 0.99)
876+
877+
assert expected == current
878+
pd.testing.assert_frame_equal(expected_df, model_client.results_handler.final_results["nat_sum_data"])

0 commit comments

Comments
 (0)