Skip to content

Commit

Permalink
linter updates
Browse files Browse the repository at this point in the history
  • Loading branch information
mattdeitke committed May 30, 2022
1 parent a9f3fa3 commit 4b8de31
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions prior/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,24 +101,23 @@ def load_dataset(
repo = g.get_repo(f"{entity}/{dataset}")

# main sha
sha: str
if revision is None:
# get the latest commit
f_revision: str = repo.get_branch("main").commit.sha
sha = repo.get_branch("main").commit.sha
elif any(revision == branch.name for branch in repo.get_branches()):
# if revision is a branch name, get the commit_id of the branch
f_revision: str = repo.get_branch(revision).commit.sha
sha = repo.get_branch(revision).commit.sha
elif any(revision == tag.name for tag in repo.get_tags()):
# if revision is a tag, get the commit_id of the tag
f_revision: str = repo.get_tag(revision).commit.sha
sha = repo.get_tag(revision).commit.sha
else:
# if revision is a commit_id, use it
f_revision: str = revision

revision: str = f_revision
sha = revision

# make sure the commit_id is valid
try:
repo.get_commit(revision)
repo.get_commit(sha)
except GithubException:
raise GithubException(
f"Could not find revision={revision} in dataset={entity}/{dataset}."
Expand All @@ -127,7 +126,7 @@ def load_dataset(

# download the dataset
dataset_dir = f"{os.environ['HOME']}/.prior/datasets/{dataset}"
dataset_path = f"{dataset_dir}/{revision}"
dataset_path = f"{dataset_dir}/{sha}"
if os.path.exists(dataset_path):
logging.info(f"Found dataset {dataset} at revision {revision} in {dataset_path}.")
else:
Expand All @@ -141,7 +140,7 @@ def load_dataset(
# change the subprocess working directory to the dataset directory
os.chdir(dataset_dir)
subprocess.run(
args=["git", "checkout", revision],
args=["git", "checkout", sha],
stderr=subprocess.DEVNULL,
stdout=subprocess.DEVNULL,
)
Expand All @@ -153,4 +152,4 @@ def load_dataset(
dataset: DatasetDict = out["load_dataset"]()
os.chdir(start_dir)
return dataset
raise NotImplemented("Dataset not .")
raise NotImplementedError("Dataset not .")

0 comments on commit 4b8de31

Please sign in to comment.