@@ -41,6 +41,11 @@ def __init__(self):
41
41
self .all_conformalization_data_unit_dict = defaultdict (dict )
42
42
self .all_conformalization_data_agg_dict = defaultdict (dict )
43
43
self .model = None
44
+ self .results_handler = None
45
+ self .election_id = None
46
+ self .office = None
47
+ self .geographic_unit_type = None
48
+ self .save_results = None
44
49
45
50
def _check_input_parameters (
46
51
self ,
@@ -170,7 +175,20 @@ def get_aggregate_list(self, office, aggregate):
170
175
return sorted (list (set (raw_aggregate_list )), key = lambda x : AGGREGATE_ORDER .index (x ))
171
176
172
177
def get_national_summary_votes_estimates (self , nat_sum_data_dict = None , base_to_add = 0 , alpha = 0.99 ):
173
- return self .model .get_national_summary_estimates (nat_sum_data_dict , base_to_add , alpha )
178
+ if self .model is None :
179
+ raise ModelClientException (
180
+ "Must call the get_estimands() method before get_national_summary_votes_estimates()."
181
+ )
182
+
183
+ nat_sum_estimates = self .model .get_national_summary_estimates (nat_sum_data_dict , base_to_add , alpha )
184
+ self .results_handler .add_national_summary_estimates (nat_sum_estimates )
185
+
186
+ if APP_ENV != "local" and self .save_results :
187
+ self .results_handler .write_data (
188
+ self .election_id , self .office , self .geographic_unit_type , keys = ["nat_sum_data" ]
189
+ )
190
+
191
+ return nat_sum_estimates
174
192
175
193
def get_estimates (
176
194
self ,
@@ -202,7 +220,7 @@ def get_estimates(
202
220
pi_method = kwargs .get ("pi_method" , "nonparametric" )
203
221
called_contests = kwargs .get ("called_contests" , None )
204
222
save_output = kwargs .get ("save_output" , ["results" ])
205
- save_results = "results" in save_output
223
+ self . save_results = "results" in save_output
206
224
save_data = "data" in save_output
207
225
save_config = "config" in save_output
208
226
# saving conformalization data only makes sense if a ConformalElectionModel is used
@@ -241,15 +259,18 @@ def get_estimates(
241
259
model_parameters ,
242
260
handle_unreporting ,
243
261
)
262
+ self .election_id = election_id
263
+ self .office = office
264
+ self .geographic_unit_type = geographic_unit_type
244
265
245
- states_with_election = config_handler .get_states (office )
246
- estimand_baselines = config_handler .get_estimand_baselines (office , estimands )
266
+ states_with_election = config_handler .get_states (self . office )
267
+ estimand_baselines = config_handler .get_estimand_baselines (self . office , estimands )
247
268
248
- LOG .info ("Getting preprocessed data: %s" , election_id )
269
+ LOG .info ("Getting preprocessed data: %s" , self . election_id )
249
270
preprocessed_data_handler = PreprocessedDataHandler (
250
- election_id ,
251
- office ,
252
- geographic_unit_type ,
271
+ self . election_id ,
272
+ self . office ,
273
+ self . geographic_unit_type ,
253
274
estimands ,
254
275
estimand_baselines ,
255
276
data = preprocessed_data ,
@@ -267,7 +288,7 @@ def get_estimates(
267
288
preprocessed_data ,
268
289
current_data ,
269
290
estimands ,
270
- geographic_unit_type ,
291
+ self . geographic_unit_type ,
271
292
handle_unreporting = handle_unreporting ,
272
293
)
273
294
@@ -307,8 +328,8 @@ def get_estimates(
307
328
if minimum_reporting_units > minimum_reporting_units_max :
308
329
minimum_reporting_units_max = minimum_reporting_units
309
330
310
- if APP_ENV != "local" and save_results :
311
- data .write_data (election_id , office )
331
+ if APP_ENV != "local" and self . save_results :
332
+ data .write_data (self . election_id , self . office )
312
333
313
334
n_reporting_expected_units = reporting_units .shape [0 ]
314
335
n_unexpected_units = unexpected_units .shape [0 ]
@@ -330,44 +351,44 @@ def get_estimates(
330
351
if len (duplicate_units ) > 0 :
331
352
raise ModelClientException (f"At least one unit appears twice: { duplicate_units } " )
332
353
333
- results_handler = ModelResultsHandler (
354
+ self . results_handler = ModelResultsHandler (
334
355
aggregates , prediction_intervals , reporting_units , nonreporting_units , unexpected_units
335
356
)
336
357
337
358
for estimand in estimands :
338
359
unit_predictions , unit_turnout_predictions = self .model .get_unit_predictions (
339
360
reporting_units , nonreporting_units , estimand , unexpected_units = unexpected_units
340
361
)
341
- results_handler .add_unit_predictions (estimand , unit_predictions , unit_turnout_predictions )
362
+ self . results_handler .add_unit_predictions (estimand , unit_predictions , unit_turnout_predictions )
342
363
# gets prediciton intervals for each alpha
343
364
alpha_to_unit_prediction_intervals = {}
344
365
for alpha in prediction_intervals :
345
366
alpha_to_unit_prediction_intervals [alpha ] = self .model .get_unit_prediction_intervals (
346
- results_handler .reporting_units , results_handler .nonreporting_units , alpha , estimand
367
+ self . results_handler .reporting_units , self . results_handler .nonreporting_units , alpha , estimand
347
368
)
348
369
if isinstance (self .model , ConformalElectionModel ):
349
370
self .all_conformalization_data_unit_dict [alpha ][
350
371
estimand
351
372
] = self .model .get_all_conformalization_data_unit ()
352
373
353
- results_handler .add_unit_intervals (estimand , alpha_to_unit_prediction_intervals )
374
+ self . results_handler .add_unit_intervals (estimand , alpha_to_unit_prediction_intervals )
354
375
355
- for aggregate in results_handler .aggregates :
356
- aggregate_list = self .get_aggregate_list (office , aggregate )
376
+ for aggregate in self . results_handler .aggregates :
377
+ aggregate_list = self .get_aggregate_list (self . office , aggregate )
357
378
estimates_df = self .model .get_aggregate_predictions (
358
- results_handler .reporting_units ,
359
- results_handler .nonreporting_units ,
360
- results_handler .unexpected_units ,
379
+ self . results_handler .reporting_units ,
380
+ self . results_handler .nonreporting_units ,
381
+ self . results_handler .unexpected_units ,
361
382
aggregate_list ,
362
383
estimand ,
363
384
called_contests = called_contests ,
364
385
)
365
386
alpha_to_agg_prediction_intervals = {}
366
387
for alpha in prediction_intervals :
367
388
alpha_to_agg_prediction_intervals [alpha ] = self .model .get_aggregate_prediction_intervals (
368
- results_handler .reporting_units ,
369
- results_handler .nonreporting_units ,
370
- results_handler .unexpected_units ,
389
+ self . results_handler .reporting_units ,
390
+ self . results_handler .nonreporting_units ,
391
+ self . results_handler .unexpected_units ,
371
392
aggregate_list ,
372
393
alpha ,
373
394
alpha_to_unit_prediction_intervals [alpha ],
@@ -380,15 +401,16 @@ def get_estimates(
380
401
] = self .model .get_all_conformalization_data_agg ()
381
402
382
403
# get all of the prediction intervals here
383
- results_handler .add_agg_predictions (
404
+ self . results_handler .add_agg_predictions (
384
405
estimand , aggregate , estimates_df , alpha_to_agg_prediction_intervals
385
406
)
386
407
387
- results_handler .process_final_results ()
388
- if APP_ENV != "local" and save_results :
389
- results_handler .write_data (election_id , office , geographic_unit_type )
408
+ self .results_handler .process_final_results ()
409
+
410
+ if APP_ENV != "local" and self .save_results :
411
+ self .results_handler .write_data (self .election_id , self .office , self .geographic_unit_type )
390
412
391
- return results_handler .final_results
413
+ return self . results_handler .final_results
392
414
393
415
394
416
class HistoricalModelClient (ModelClient ):
0 commit comments