Skip to content

Commit 102cc7c

Browse files
authored
Merge pull request #123 from washingtonpost/extrapolation-rule
small changes to extrapolation
2 parents 58a134e + a75516b commit 102cc7c

File tree

4 files changed

+26
-9
lines changed

4 files changed

+26
-9
lines changed

src/elexmodel/client.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -328,10 +328,16 @@ def get_estimates(
328328
start_date=model_parameters.get("versioned_start_date", None),
329329
end_date=model_parameters.get("versioned_end_date", None),
330330
)
331-
print(
332-
"Fetching versioned data between ", versioned_data_handler.start_date, versioned_data_handler.end_date
331+
LOG.info(
332+
"Fetching versioned data between %s and %s",
333+
versioned_data_handler.start_date,
334+
versioned_data_handler.end_date,
333335
)
334-
versioned_data_handler.get_versioned_results(model_settings.get("versioned_filepath", None))
336+
versioned_results = versioned_data_handler.get_versioned_results(
337+
model_settings.get("versioned_filepath", None)
338+
)
339+
if versioned_results is None:
340+
versioned_data_handler = None
335341
else:
336342
versioned_data_handler = None
337343

src/elexmodel/handlers/data/VersionedData.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from datetime import datetime
2+
13
import numpy as np
24
import pandas as pd
35
from dateutil import tz
@@ -30,9 +32,8 @@ def __init__(
3032
target_bucket = "elex-models-prod"
3133
else:
3234
target_bucket = TARGET_BUCKET
33-
34-
start_date = start_date.astimezone(tz=tz.gettz("UTC")) if start_date else None
35-
end_date = end_date.astimezone(tz=tz.gettz("UTC")) if start_date else None
35+
start_date = datetime.fromisoformat(start_date).astimezone(tz=tz.gettz("UTC")) if start_date else None
36+
end_date = datetime.fromisoformat(end_date).astimezone(tz=tz.gettz("UTC")) if start_date else None
3637
# versioned results natively are in UTC but we'll convert it back to timezone in tzinfo
3738
self.s3_client = s3.S3VersionUtil(target_bucket, start_date, end_date, tzinfo)
3839

@@ -66,6 +67,9 @@ def get_versioned_results(self, filepath=None):
6667
path = f"{S3_FILE_PATH}/{self.election_id}/results/{self.office_id}/{self.geographic_unit_type}/current.csv"
6768

6869
data = self.s3_client.get(path, self.sample)
70+
if data is None:
71+
self.data = data
72+
return data
6973
estimandizer = Estimandizer()
7074
data, _ = estimandizer.add_estimand_results(data, self.estimands, False)
7175
self.data = data.sort_values("last_modified")

src/elexmodel/handlers/s3.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,11 @@ def list_versions(self, path, **kwargs):
105105
if "Versions" in response:
106106
versions = response["Versions"]
107107

108-
if response["IsTruncated"] and len(versions) > 0 and versions[-1]["LastModified"] >= self.start_date:
108+
if (
109+
response["IsTruncated"]
110+
and len(versions) > 0
111+
and (self.start_date is None or versions[-1]["LastModified"] >= self.start_date)
112+
):
109113
versions += self.list_versions(
110114
path,
111115
KeyMarker=response["NextKeyMarker"],
@@ -145,7 +149,8 @@ def make_request(self, path, *, version=None, **kwargs):
145149
def get(self, path, sample=2):
146150
versions = self.list_versions(path)
147151
if len(versions) == 0:
148-
raise ValueError(f"No versions found for {path}")
152+
LOG.info(f"No versions found for {path}")
153+
return None
149154

150155
# Instead of asking for the results of downloads synchronously, we're
151156
# queuing the futures and then waiting for them to complete.

src/elexmodel/models/BootstrapElectionModel.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations # pylint: disable=too-many-lines
22

33
import logging
4+
from datetime import timedelta
45
from itertools import combinations
56

67
import numpy as np
@@ -108,7 +109,7 @@ def __init__(self, model_settings={}, versioned_data_handler=None):
108109
self.contest_correlations = model_settings.get("contest_correlations", [])
109110

110111
# impose perfect correlation in the national summary aggregation
111-
self.national_summary_correlation = model_settings.get("national_summary_correlation", False)
112+
self.national_summary_correlation = model_settings.get("national_summary_correlation", True)
112113
self.stop_model_call = None
113114
# Assume that we have a baseline normalized margin
114115
# (D^{Y'} - R^{Y'}) / (D^{Y'} + R^{Y'}) is one of the covariates
@@ -795,6 +796,7 @@ def _extrapolate_unit_margin(self, reporting_units: pd.DataFrame, nonreporting_u
795796
all_units = pd.concat([reporting_units, nonreporting_units], axis=0).copy()
796797
missing_columns = list(set(self.versioned_data_handler.data.columns) - set(all_units.columns))
797798
all_units[missing_columns] = self.versioned_data_handler.data[missing_columns].max()
799+
all_units["last_modified"] = self.versioned_data_handler.data["last_modified"].max() + timedelta(seconds=1)
798800

799801
self.versioned_data_handler.data = pd.concat(
800802
[self.versioned_data_handler.data, all_units[self.versioned_data_handler.data.columns]], axis=0

0 commit comments

Comments (0)