-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Added Batch creation for Cloud Storage documents. (#66)
* feat: Added Batch creation for Cloud Storage documents. * Ran Black format on samples * Update noxfile.py * Changed Client to use custom user agent header * Updates to tests and docs * Fixed Test inputs * Add link to send processing request page * Change Import for sample --------- Co-authored-by: Gal Zahavi <38544478+galz10@users.noreply.github.com>
- Loading branch information
1 parent
19edf79
commit c32a371
Showing
10 changed files
with
367 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Document AI Toolbox Utilities | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
.. automodule:: google.cloud.documentai_toolbox.utilities.utilities | ||
:members: | ||
:private-members: | ||
:noindex: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# -*- coding: utf-8 -*- | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# -*- coding: utf-8 -*- | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
"""Document AI utilities.""" | ||
|
||
from typing import List, Optional | ||
|
||
from google.cloud import documentai | ||
|
||
from google.cloud.documentai_toolbox import constants | ||
from google.cloud.documentai_toolbox.wrappers.document import _get_storage_client | ||
|
||
|
||
def create_batches(
    gcs_bucket_name: str,
    gcs_prefix: str,
    batch_size: Optional[int] = constants.BATCH_MAX_FILES,
) -> List[documentai.BatchDocumentsInputConfig]:
    """Create batches of documents in Cloud Storage to process with `batch_process_documents()`.

    Lists the blobs under ``gcs_prefix`` in ``gcs_bucket_name`` and groups them,
    in listing order, into batches of at most ``batch_size`` documents. Blobs
    are skipped (with a message printed to stdout, except for directories) when:

    - the blob name ends with ``/`` (directory placeholder),
    - the content type is not in ``constants.VALID_MIME_TYPES``,
    - the size is greater than ``constants.BATCH_MAX_FILE_SIZE``.

    Args:
        gcs_bucket_name (str):
            Required. The name of the gcs bucket.

            Format: `gs://bucket/optional_folder/target_folder/` where gcs_bucket_name=`bucket`.
        gcs_prefix (str):
            Required. The prefix of the documents in the `target_folder`.

            Format: `gs://bucket/optional_folder/target_folder/` where gcs_prefix=`optional_folder/target_folder`.
        batch_size (Optional[int]):
            Optional. Maximum number of documents in each batch. Must be between
            `1` and `constants.BATCH_MAX_FILES` (inclusive). `None` selects the
            default, `constants.BATCH_MAX_FILES`.

    Returns:
        List[documentai.BatchDocumentsInputConfig]:
            A list of `BatchDocumentsInputConfig`, each corresponding to one batch.
            Empty when no blob qualifies.

    Raises:
        ValueError: If `batch_size` is smaller than `1` or larger than
            `constants.BATCH_MAX_FILES`.
    """
    # The parameter is annotated Optional, so treat an explicit None as "use the default"
    # instead of failing on the comparison below.
    if batch_size is None:
        batch_size = constants.BATCH_MAX_FILES

    # A zero or negative size can never be reached by `len(batch)` in a meaningful way
    # and would silently produce an empty or unbounded batch.
    if batch_size < 1:
        raise ValueError(
            f"Batch size must be at least 1. You provided {batch_size}."
        )

    if batch_size > constants.BATCH_MAX_FILES:
        raise ValueError(
            f"Batch size must be less than or equal to {constants.BATCH_MAX_FILES}. You provided {batch_size}."
        )

    storage_client = _get_storage_client()
    blob_list = storage_client.list_blobs(gcs_bucket_name, prefix=gcs_prefix)
    batches: List[documentai.BatchDocumentsInputConfig] = []
    batch: List[documentai.GcsDocument] = []

    for blob in blob_list:
        # Skip Directories
        if blob.name.endswith("/"):
            continue

        if blob.content_type not in constants.VALID_MIME_TYPES:
            print(f"Skipping file {blob.name}. Invalid Mime Type {blob.content_type}.")
            continue

        if blob.size > constants.BATCH_MAX_FILE_SIZE:
            print(
                f"Skipping file {blob.name}. File size must be less than {constants.BATCH_MAX_FILE_SIZE} bytes. File size is {blob.size} bytes."
            )
            continue

        # The current batch is full: emit it before starting a new one.
        if len(batch) == batch_size:
            batches.append(
                documentai.BatchDocumentsInputConfig(
                    gcs_documents=documentai.GcsDocuments(documents=batch)
                )
            )
            batch = []

        batch.append(
            documentai.GcsDocument(
                gcs_uri=f"gs://{gcs_bucket_name}/{blob.name}",
                mime_type=blob.content_type,
            )
        )

    if batch:
        # Append the last batch, which could be less than `batch_size`
        batches.append(
            documentai.BatchDocumentsInputConfig(
                gcs_documents=documentai.GcsDocuments(documents=batch)
            )
        )

    return batches
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
|
||
# [START documentai_toolbox_create_batches] | ||
|
||
from google.cloud import documentai | ||
from google.cloud.documentai_toolbox import utilities | ||
|
||
# TODO(developer): Uncomment these variables before running the sample. | ||
# Given unprocessed documents in path gs://bucket/path/to/folder | ||
# gcs_bucket_name = "bucket" | ||
# gcs_prefix = "path/to/folder" | ||
# batch_size = 50 | ||
|
||
|
||
def create_batches_sample(
    gcs_bucket_name: str,
    gcs_prefix: str,
    batch_size: int = 50,
) -> None:
    # Group the documents found under the Cloud Storage prefix into batches
    document_batches = utilities.create_batches(
        gcs_bucket_name=gcs_bucket_name,
        gcs_prefix=gcs_prefix,
        batch_size=batch_size,
    )
    print(f"{len(document_batches)} batch(es) created.")

    for input_config in document_batches:
        gcs_documents = input_config.gcs_documents.documents
        print(f"{len(gcs_documents)} files in batch.")
        print(gcs_documents)

        # Each batch is the input of one batch_process_documents() call.
        # See https://cloud.google.com/document-ai/docs/send-request
        # for how to send a batch processing request
        request = documentai.BatchProcessRequest(
            name="processor_name", input_documents=input_config
        )
|
||
|
||
# [END documentai_toolbox_create_batches] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
|
||
import pytest | ||
from samples.snippets import create_batches_sample | ||
|
||
gcs_bucket_name = "cloud-samples-data" | ||
gcs_input_uri = "documentai_toolbox/document_batches/" | ||
batch_size = 50 | ||
|
||
|
||
def test_create_batches_sample(capsys: pytest.CaptureFixture) -> None:
    """Run the sample against the shared bucket and check the printed batch summary."""
    create_batches_sample.create_batches_sample(
        gcs_bucket_name=gcs_bucket_name,
        gcs_prefix=gcs_input_uri,
        batch_size=batch_size,
    )
    stdout = capsys.readouterr().out

    expected_lines = (
        "2 batch(es) created.",
        "50 files in batch.",
        "47 files in batch.",
    )
    for expected in expected_lines:
        assert expected in stdout
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
# pylint: disable=protected-access | ||
# -*- coding: utf-8 -*- | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import pytest | ||
|
||
from google.cloud.documentai_toolbox.utilities import utilities | ||
|
||
# try/except added for compatibility with python < 3.8 | ||
try: | ||
from unittest import mock | ||
except ImportError: # pragma: NO COVER | ||
import mock | ||
|
||
|
||
test_bucket = "test-directory" | ||
test_prefix = "documentai/input" | ||
|
||
|
||
@mock.patch("google.cloud.documentai_toolbox.wrappers.document.storage")
def test_create_batches_with_3_documents(mock_storage, capfd):
    """Three valid PDFs fit in a single batch and nothing is printed."""
    client = mock_storage.Client.return_value
    mock_bucket = mock.Mock()
    client.Bucket.return_value = mock_bucket

    mock_blobs = []
    for i in range(3):
        mock_blob = mock.Mock(content_type="application/pdf", size=1024)
        # `name` passed to the Mock constructor only labels the mock for its
        # repr; the blob's `name` attribute has to be assigned afterwards so
        # that it is a real string.
        mock_blob.name = f"test_file{i}.pdf"
        mock_blobs.append(mock_blob)
    client.list_blobs.return_value = mock_blobs

    actual = utilities.create_batches(
        gcs_bucket_name=test_bucket, gcs_prefix=test_prefix
    )

    mock_storage.Client.assert_called_once()

    out, _ = capfd.readouterr()
    assert out == ""
    assert len(actual) == 1
    assert len(actual[0].gcs_documents.documents) == 3
    assert [doc.gcs_uri for doc in actual[0].gcs_documents.documents] == [
        f"gs://{test_bucket}/test_file{i}.pdf" for i in range(3)
    ]
|
||
|
||
def test_create_batches_with_invalid_batch_size(capfd):
    """A batch size above the maximum is rejected with a descriptive ValueError."""
    # The message is carried by the exception, not printed, so it is checked
    # through `match` rather than through captured output.
    with pytest.raises(ValueError, match="Batch size must be less than"):
        utilities.create_batches(
            gcs_bucket_name=test_bucket, gcs_prefix=test_prefix, batch_size=51
        )

    # Validation fails before any blob is inspected, so nothing is printed.
    out, _ = capfd.readouterr()
    assert out == ""
|
||
|
||
@mock.patch("google.cloud.documentai_toolbox.wrappers.document.storage")
def test_create_batches_with_large_folder(mock_storage, capfd):
    """96 valid PDFs are split into a full batch of 50 and a remainder of 46."""
    client = mock_storage.Client.return_value
    mock_bucket = mock.Mock()
    client.Bucket.return_value = mock_bucket

    mock_blobs = []
    for i in range(96):
        mock_blob = mock.Mock(content_type="application/pdf", size=1024)
        # `name` passed to the Mock constructor only labels the mock for its
        # repr; the blob's `name` attribute has to be assigned afterwards so
        # that it is a real string.
        mock_blob.name = f"test_file{i}.pdf"
        mock_blobs.append(mock_blob)
    client.list_blobs.return_value = mock_blobs

    actual = utilities.create_batches(
        gcs_bucket_name=test_bucket, gcs_prefix=test_prefix
    )

    mock_storage.Client.assert_called_once()

    out, _ = capfd.readouterr()
    assert out == ""
    assert len(actual) == 2
    assert len(actual[0].gcs_documents.documents) == 50
    assert len(actual[1].gcs_documents.documents) == 46
    # Listing order is preserved across the batch boundary.
    assert (
        actual[0].gcs_documents.documents[0].gcs_uri
        == f"gs://{test_bucket}/test_file0.pdf"
    )
    assert (
        actual[1].gcs_documents.documents[0].gcs_uri
        == f"gs://{test_bucket}/test_file50.pdf"
    )
|
||
|
||
@mock.patch("google.cloud.documentai_toolbox.wrappers.document.storage")
def test_create_batches_with_invalid_file_type(mock_storage, capfd):
    """A blob with an unsupported content type is skipped and reported on stdout."""
    client = mock_storage.Client.return_value
    mock_bucket = mock.Mock()
    client.Bucket.return_value = mock_bucket

    mock_blob = mock.Mock(content_type="application/json", size=1024)
    # `name` passed to the Mock constructor only labels the mock for its repr;
    # the blob's `name` attribute has to be assigned afterwards so that it is
    # a real string.
    mock_blob.name = "test_file.json"
    client.list_blobs.return_value = [mock_blob]

    actual = utilities.create_batches(
        gcs_bucket_name=test_bucket, gcs_prefix=test_prefix
    )

    mock_storage.Client.assert_called_once()

    out, _ = capfd.readouterr()
    assert "Invalid Mime Type" in out
    assert "Skipping file test_file.json. Invalid Mime Type application/json." in out
    assert actual == []
|
||
|
||
@mock.patch("google.cloud.documentai_toolbox.wrappers.document.storage")
def test_create_batches_with_large_file(mock_storage, capfd):
    """A blob larger than the per-file size limit is skipped and reported on stdout."""
    client = mock_storage.Client.return_value
    mock_bucket = mock.Mock()
    client.Bucket.return_value = mock_bucket

    mock_blob = mock.Mock(content_type="application/pdf", size=2073741824)
    # `name` passed to the Mock constructor only labels the mock for its repr;
    # the blob's `name` attribute has to be assigned afterwards so that it is
    # a real string.
    mock_blob.name = "test_file.pdf"
    client.list_blobs.return_value = [mock_blob]

    actual = utilities.create_batches(
        gcs_bucket_name=test_bucket, gcs_prefix=test_prefix
    )

    mock_storage.Client.assert_called_once()

    out, _ = capfd.readouterr()
    assert "Skipping file test_file.pdf." in out
    assert "File size must be less than" in out
    assert actual == []