Skip to content

Commit

Permalink
Support TransferConfig for s3 uploads and downloads (#150)
Browse files Browse the repository at this point in the history
* add boto3_transfer_config

* add config to MockBoto3Object

* use default transfer config if none is passed

* expect instantiated TransferConfig

* add local tests

* longer sleep for more reliable tests

* only run asserts for mocked session

* Add TransferConfig test for live server

* lint

* Make upload assert test more reliable

Co-authored-by: Peter Bull <pjbull@gmail.com>
  • Loading branch information
ejm714 and pjbull authored Jul 5, 2021
1 parent 25ff783 commit 80f7afd
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 20 deletions.
9 changes: 7 additions & 2 deletions cloudpathlib/s3/s3client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
try:
from boto3.session import Session
import botocore.session
from boto3.s3.transfer import TransferConfig
except ModuleNotFoundError:
implementation_registry["s3"].dependencies_loaded = False

Expand All @@ -29,6 +30,7 @@ def __init__(
boto3_session: Optional["Session"] = None,
local_cache_dir: Optional[Union[str, os.PathLike]] = None,
endpoint_url: Optional[str] = None,
boto3_transfer_config: Optional["TransferConfig"] = None,
):
"""Class constructor. Sets up a boto3 [`Session`](
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html).
Expand All @@ -52,6 +54,8 @@ def __init__(
for downloaded files. If None, will use a temporary directory.
endpoint_url (Optional[str]): S3 server endpoint URL to use for the constructed boto3 S3 resource and client.
Parameterize it to access a customly deployed S3-compatible object store such as MinIO, Ceph or any other.
boto3_transfer_config (Optional[TransferConfig]): Instantiated TransferConfig for managing s3 transfers.
(https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#boto3.s3.transfer.TransferConfig)
"""
if boto3_session is not None:
self.sess = boto3_session
Expand All @@ -65,6 +69,7 @@ def __init__(
)
self.s3 = self.sess.resource("s3", endpoint_url=endpoint_url)
self.client = self.sess.client("s3", endpoint_url=endpoint_url)
self.boto3_transfer_config = boto3_transfer_config

super().__init__(local_cache_dir=local_cache_dir)

Expand All @@ -83,7 +88,7 @@ def _download_file(self, cloud_path: S3Path, local_path: Union[str, os.PathLike]
local_path = Path(local_path)
obj = self.s3.Object(cloud_path.bucket, cloud_path.key)

obj.download_file(str(local_path))
obj.download_file(str(local_path), Config=self.boto3_transfer_config)
return local_path

def _is_file_or_dir(self, cloud_path: S3Path) -> Optional[str]:
Expand Down Expand Up @@ -199,7 +204,7 @@ def _remove(self, cloud_path: S3Path) -> None:
def _upload_file(self, local_path: Union[str, os.PathLike], cloud_path: S3Path) -> S3Path:
obj = self.s3.Object(cloud_path.bucket, cloud_path.key)

obj.upload_file(str(local_path))
obj.upload_file(str(local_path), Config=self.boto3_transfer_config)
return cloud_path


Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ mkdocstrings>=0.15
mypy
pandas
pillow
psutil
pydantic
pytest
pytest-cases
Expand Down
51 changes: 39 additions & 12 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,12 @@ class CloudProviderTestRig:
"""Class that holds together the components needed to test a cloud implementation."""

def __init__(
self, path_class: type, client_class: type, drive: str = "drive", test_dir: str = ""
self,
path_class: type,
client_class: type,
drive: str = "drive",
test_dir: str = "",
live_server: bool = False,
):
"""
Args:
Expand All @@ -63,6 +68,7 @@ def __init__(
self.client_class = client_class
self.drive = drive
self.test_dir = test_dir
self.live_server = live_server  # True when tests target a live cloud backend (USE_LIVE_CLOUD=1)

@property
def cloud_prefix(self):
Expand Down Expand Up @@ -90,7 +96,9 @@ def azure_rig(request, monkeypatch, assets_dir):
drive = os.getenv("LIVE_AZURE_CONTAINER", "container")
test_dir = create_test_dir_name(request)

if os.getenv("USE_LIVE_CLOUD") == "1":
live_server = os.getenv("USE_LIVE_CLOUD") == "1"

if live_server:
# Set up test assets
blob_service_client = BlobServiceClient.from_connection_string(
os.getenv("AZURE_STORAGE_CONNECTION_STRING")
Expand Down Expand Up @@ -118,6 +126,7 @@ def azure_rig(request, monkeypatch, assets_dir):
client_class=AzureBlobClient,
drive=drive,
test_dir=test_dir,
live_server=live_server,
)

rig.client_class().set_as_default_client() # set default client
Expand All @@ -126,7 +135,7 @@ def azure_rig(request, monkeypatch, assets_dir):

rig.client_class._default_client = None # reset default client

if os.getenv("USE_LIVE_CLOUD") == "1":
if live_server:
# Clean up test dir
container_client = blob_service_client.get_container_client(drive)
to_delete = container_client.list_blobs(name_starts_with=test_dir)
Expand All @@ -138,7 +147,9 @@ def gs_rig(request, monkeypatch, assets_dir):
drive = os.getenv("LIVE_GS_BUCKET", "bucket")
test_dir = create_test_dir_name(request)

if os.getenv("USE_LIVE_CLOUD") == "1":
live_server = os.getenv("USE_LIVE_CLOUD") == "1"

if live_server:
# Set up test assets
bucket = google_storage.Client().bucket(drive)
test_files = [
Expand All @@ -159,7 +170,11 @@ def gs_rig(request, monkeypatch, assets_dir):
)

rig = CloudProviderTestRig(
path_class=GSPath, client_class=GSClient, drive=drive, test_dir=test_dir
path_class=GSPath,
client_class=GSClient,
drive=drive,
test_dir=test_dir,
live_server=live_server,
)

rig.client_class().set_as_default_client() # set default client
Expand All @@ -168,7 +183,7 @@ def gs_rig(request, monkeypatch, assets_dir):

rig.client_class._default_client = None # reset default client

if os.getenv("USE_LIVE_CLOUD") == "1":
if live_server:
# Clean up test dir
for blob in bucket.list_blobs(prefix=test_dir):
blob.delete()
Expand All @@ -179,7 +194,9 @@ def s3_rig(request, monkeypatch, assets_dir):
drive = os.getenv("LIVE_S3_BUCKET", "bucket")
test_dir = create_test_dir_name(request)

if os.getenv("USE_LIVE_CLOUD") == "1":
live_server = os.getenv("USE_LIVE_CLOUD") == "1"

if live_server:
# Set up test assets
session = boto3.Session() # Fresh session to ensure isolation
bucket = session.resource("s3").Bucket(drive)
Expand All @@ -200,7 +217,11 @@ def s3_rig(request, monkeypatch, assets_dir):
)

rig = CloudProviderTestRig(
path_class=S3Path, client_class=S3Client, drive=drive, test_dir=test_dir
path_class=S3Path,
client_class=S3Client,
drive=drive,
test_dir=test_dir,
live_server=live_server,
)

rig.client_class().set_as_default_client() # set default client
Expand All @@ -209,7 +230,7 @@ def s3_rig(request, monkeypatch, assets_dir):

rig.client_class._default_client = None # reset default client

if os.getenv("USE_LIVE_CLOUD") == "1":
if live_server:
# Clean up test dir
bucket.objects.filter(Prefix=test_dir).delete()

Expand All @@ -226,7 +247,9 @@ def custom_s3_rig(request, monkeypatch, assets_dir):
test_dir = create_test_dir_name(request)
custom_endpoint_url = os.getenv("CUSTOM_S3_ENDPOINT", "https://s3.us-west-1.drivendatabws.com")

if os.getenv("USE_LIVE_CLOUD") == "1":
live_server = os.getenv("USE_LIVE_CLOUD") == "1"

if live_server:
monkeypatch.setenv("AWS_ACCESS_KEY_ID", os.getenv("CUSTOM_S3_KEY_ID"))
monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", os.getenv("CUSTOM_S3_SECRET_KEY"))

Expand Down Expand Up @@ -260,7 +283,11 @@ def custom_s3_rig(request, monkeypatch, assets_dir):
)

rig = CloudProviderTestRig(
path_class=S3Path, client_class=S3Client, drive=drive, test_dir=test_dir
path_class=S3Path,
client_class=S3Client,
drive=drive,
test_dir=test_dir,
live_server=live_server,
)

rig.client_class(
Expand All @@ -274,7 +301,7 @@ def custom_s3_rig(request, monkeypatch, assets_dir):

rig.client_class._default_client = None # reset default client

if os.getenv("USE_LIVE_CLOUD") == "1":
if live_server:
bucket.objects.filter(Prefix=test_dir).delete()


Expand Down
14 changes: 10 additions & 4 deletions tests/mock_clients/mock_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def client(self, item, endpoint_url):
class MockBoto3Resource:
def __init__(self, root):
self.root = root
self.download_config = None
self.upload_config = None

def Bucket(self, bucket):
return MockBoto3Bucket(self.root)
Expand All @@ -49,13 +51,14 @@ def ObjectSummary(self, bucket, key):
return MockBoto3ObjectSummary(self.root, key)

def Object(self, bucket, key):
return MockBoto3Object(self.root, key)
return MockBoto3Object(self.root, key, self)


class MockBoto3Object:
def __init__(self, root, path):
def __init__(self, root, path, resource):
self.root = root
self.path = root / path
self.resource = resource

def get(self):
if not self.path.exists() or self.path.is_dir():
Expand All @@ -70,13 +73,16 @@ def copy_from(self, CopySource=None, Metadata=None, MetadataDirective=None):
else:
self.path.write_bytes((self.root / Path(CopySource["Key"])).read_bytes())

def download_file(self, to_path):
def download_file(self, to_path, Config=None):
to_path = Path(to_path)
to_path.write_bytes(self.path.read_bytes())
# track config to make sure it's used in tests
self.resource.download_config = Config

def upload_file(self, from_path):
def upload_file(self, from_path, Config=None):
self.path.parent.mkdir(parents=True, exist_ok=True)
self.path.write_bytes(Path(from_path).read_bytes())
self.resource.upload_config = Config

def delete(self):
self.path.unlink()
Expand Down
5 changes: 3 additions & 2 deletions tests/test_cloudpath_upload_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_upload_from_file(rig, upload_assets_dir):

# to file, file exists
to_upload_2 = upload_assets_dir / "upload_2.txt"
sleep(0.5)
sleep(1.5)
to_upload_2.touch() # make sure local is newer
p.upload_from(to_upload_2)
assert p.exists()
Expand All @@ -69,7 +69,7 @@ def test_upload_from_file(rig, upload_assets_dir):

# to file, file exists and is newer; overwrite
p.touch()
sleep(0.5)
sleep(1.5)
p.upload_from(upload_assets_dir / "upload_1.txt", force_overwrite_to_cloud=True)
assert p.exists()
assert p.read_text() == "Hello from 1"
Expand Down Expand Up @@ -98,6 +98,7 @@ def test_upload_from_dir(rig, upload_assets_dir):
assert assert_mirrored(p2, upload_assets_dir, check_no_extra=False)

# a newer file exists on cloud
sleep(1)
(p / "upload_1.txt").touch()
with pytest.raises(OverwriteNewerCloudError):
p.upload_from(upload_assets_dir)
Expand Down
Loading

0 comments on commit 80f7afd

Please sign in to comment.