Skip to content

Commit

Permalink
feat: Added gcs_uri parameter to Document.from_gcs() to allow importing of a single Document JSON (#261)
Browse files Browse the repository at this point in the history

* feat: Added `gcs_uri` parameter to `Document.from_gcs()` to allow importing of a single Document JSON
  * Change `get_blob()` to use `storage.Blob.from_string()`
  • Loading branch information
holtskinner authored Feb 26, 2024
1 parent 4c2a5d9 commit f654a5d
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 9 deletions.
12 changes: 4 additions & 8 deletions google/cloud/documentai_toolbox/utilities/gcs_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,17 +136,13 @@ def get_blob(
module (Optional[str]):
Optional. The module for a custom user agent header.
Returns:
List[storage.blob.Blob]:
A list of the blobs in the Cloud Storage path.
storage.blob.Blob:
The blob in the Cloud Storage path.
"""
gcs_bucket_name, gcs_file_name = split_gcs_uri(gcs_uri)

if not re.match(constants.FILE_CHECK_REGEX, gcs_file_name):
if not re.match(constants.FILE_CHECK_REGEX, gcs_uri):
raise ValueError("gcs_uri must link to a single file.")

storage_client = _get_storage_client(module=module)
bucket = storage_client.bucket(bucket_name=gcs_bucket_name)
return bucket.get_blob(gcs_file_name)
return storage.Blob.from_string(gcs_uri, _get_storage_client(module=module))


def split_gcs_uri(gcs_uri: str) -> Tuple[str, str]:
Expand Down
37 changes: 36 additions & 1 deletion google/cloud/documentai_toolbox/wrappers/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ class Document:
shards: List[documentai.Document] = dataclasses.field(repr=False)
gcs_bucket_name: Optional[str] = dataclasses.field(default=None, repr=False)
gcs_prefix: Optional[str] = dataclasses.field(default=None, repr=False)
gcs_uri: Optional[str] = dataclasses.field(default=None, repr=False)
gcs_input_uri: Optional[str] = dataclasses.field(default=None, repr=False)

_pages: Optional[List[Page]] = dataclasses.field(
Expand Down Expand Up @@ -463,7 +464,7 @@ def from_gcs(
gcs_prefix: str,
gcs_input_uri: Optional[str] = None,
) -> "Document":
r"""Loads Document from Cloud Storage.
r"""Loads a Document from a Cloud Storage directory.
Args:
gcs_bucket_name (str):
Expand All @@ -490,6 +491,40 @@ def from_gcs(
gcs_input_uri=gcs_input_uri,
)

@classmethod
def from_gcs_uri(
    cls: Type["Document"],
    gcs_uri: str,
    gcs_input_uri: Optional[str] = None,
) -> "Document":
    r"""Loads a Document from a Cloud Storage uri.
    Args:
        gcs_uri (str):
            Required. The full GCS uri to a Document JSON file.
            Example: `gs://{bucket_name}/{optional_folder}/{target_file}.json`.
        gcs_input_uri (str):
            Optional. The gcs uri to the original input file.
            Format: `gs://{bucket_name}/{optional_folder}/{target_folder}/{file_name}.pdf`
    Returns:
        Document:
            A document from gcs.
    """
    # Fetch the single JSON blob, then parse it as the document's only shard.
    document_blob = gcs_utilities.get_blob(gcs_uri=gcs_uri, module="get-document")
    shard = documentai.Document.from_json(
        document_blob.download_as_bytes(),
        ignore_unknown_fields=True,
    )
    return cls(shards=[shard], gcs_uri=gcs_uri, gcs_input_uri=gcs_input_uri)

@classmethod
def from_batch_process_metadata(
cls: Type["Document"], metadata: documentai.BatchProcessMetadata
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,17 @@ def get_bytes_missing_shard_mock():
yield byte_factory


@pytest.fixture
def get_blob_mock():
    # Stand-in blob whose bytes come from the local test resource shard.
    fake_blob = mock.Mock()
    fake_blob.download_as_bytes.return_value = get_bytes(
        "tests/unit/resources/0"
    )[0]
    with mock.patch.object(
        gcs_utilities, "get_blob", return_value=fake_blob
    ) as blob_factory:
        yield blob_factory


def create_document_with_images_without_bbox(get_bytes_images_mock):
doc = document.Document.from_gcs(
gcs_bucket_name="test-directory", gcs_prefix="documentai/output/123456789/0"
Expand Down Expand Up @@ -394,6 +405,25 @@ def test_document_from_gcs_with_unordered_shards(get_bytes_unordered_files_mock)
assert page.page_number == page_index + 1


def test_document_from_gcs_uri(get_blob_mock):
    uri = "gs://test-directory/documentai/output/123456789/0/document.json"
    actual = document.Document.from_gcs_uri(gcs_uri=uri)

    get_blob_mock.assert_called_once()

    assert actual.gcs_uri == uri

    # Repeated access exercises the cached-property path as well.
    assert len(actual.pages) == 1
    assert len(actual.pages) == 1

    assert len(actual.text) > 0
    assert len(actual.text) > 0

def test_document_from_batch_process_metadata_with_multiple_input_files(
get_bytes_multiple_directories_mock,
):
Expand Down

0 comments on commit f654a5d

Please sign in to comment.