Skip to content

Commit

Permalink
fix: Reduce API polling for Document.from_batch_process_operation() (
Browse files Browse the repository at this point in the history
…#249)

* fix: Add polling interval and timeout parameters for `Document.from_batch_process_operation()`

fixes #246

Also added `ClientInfo` to the Document AI clients to collect API usage metrics.

* Change default polling interval to `0`.

* Fix docstring formatting

* Change operation polling to use the default `api_core` retry.

- Required significant refactoring to convert the `operations_pb` into a Google API Core `Operation` that can be polled.

* test: Fix unit tests for `_get_batch_process_metadata()`
  • Loading branch information
holtskinner authored Feb 2, 2024
1 parent b741498 commit 0677299
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 31 deletions.
3 changes: 2 additions & 1 deletion google/cloud/documentai_toolbox/converters/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def _get_base_ocr(
client = documentai.DocumentProcessorServiceClient(
client_options=ClientOptions(
api_endpoint=f"{location}-documentai.googleapis.com"
)
),
client_info=gcs_utilities._get_client_info(),
)

name = (
Expand Down
63 changes: 41 additions & 22 deletions google/cloud/documentai_toolbox/wrappers/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
import re
from typing import Dict, List, Optional, Type, Union

from google.api_core.client_options import ClientOptions
from google.api_core.operation import from_gapic as operation_from_gapic
from google.cloud.vision import AnnotateFileResponse
from google.longrunning.operations_pb2 import GetOperationRequest, Operation
from google.longrunning.operations_pb2 import GetOperationRequest

from jinja2 import Environment, PackageLoader
from pikepdf import Pdf
Expand Down Expand Up @@ -137,48 +137,57 @@ def _get_shards(gcs_bucket_name: str, gcs_prefix: str) -> List[documentai.Docume


def _get_batch_process_metadata(
location: str, operation_name: str
operation_name: str,
timeout: Optional[float] = None,
) -> documentai.BatchProcessMetadata:
r"""Get `BatchProcessMetadata` from a `batch_process_documents()` long-running operation.
Args:
location (str):
Required. The location of the processor used for `batch_process_documents()`.
operation_name (str):
Required. The fully qualified operation name for a `batch_process_documents()` operation.
timeout (float):
Optional. Default None. Time in seconds to wait for operation to complete.
If None, will wait indefinitely.
Returns:
documentai.BatchProcessMetadata:
Metadata from batch process.
"""
client = documentai.DocumentProcessorServiceClient(
client_options=ClientOptions(
api_endpoint=f"{location}-documentai.googleapis.com"
)
client_info=gcs_utilities._get_client_info(module="get_batch_process_metadata"),
)

while True:
operation: Operation = client.get_operation(
request=GetOperationRequest(name=operation_name)
)
# Poll Operation until complete.
operation = operation_from_gapic(
operation=client.get_operation(
request=GetOperationRequest(name=operation_name),
metadata=documentai.BatchProcessMetadata(),
),
operations_client=client,
result_type=documentai.BatchProcessResponse,
)
operation.result(timeout=timeout)

if operation.done:
break
operation_pb = operation.operation

if not operation.metadata:
# Get Operation metadata.
if not operation_pb.metadata:
raise ValueError(f"Operation does not contain metadata: {operation}")

metadata_type = (
"type.googleapis.com/google.cloud.documentai.v1.BatchProcessMetadata"
)

if not operation.metadata.type_url or operation.metadata.type_url != metadata_type:
if (
not operation_pb.metadata.type_url
or operation_pb.metadata.type_url != metadata_type
):
raise ValueError(
f"Operation metadata type is not `{metadata_type}`. Type is `{operation.metadata.type_url}`."
f"Operation metadata type is not `{metadata_type}`. Type is `{operation_pb.metadata.type_url}`."
)

metadata: documentai.BatchProcessMetadata = (
documentai.BatchProcessMetadata.deserialize(operation.metadata.value)
documentai.BatchProcessMetadata.deserialize(operation_pb.metadata.value)
)

return metadata
Expand Down Expand Up @@ -518,7 +527,10 @@ def from_batch_process_metadata(

@classmethod
def from_batch_process_operation(
cls: Type["Document"], location: str, operation_name: str
cls: Type["Document"],
location: str, # pylint: disable=unused-argument
operation_name: str,
timeout: Optional[float] = None,
) -> List["Document"]:
r"""Loads Documents from Cloud Storage, using the operation name returned from `batch_process_documents()`.
Expand All @@ -533,19 +545,26 @@ def from_batch_process_operation(
Args:
location (str):
Required. The location of the processor used for `batch_process_documents()`.
Optional. The location of the processor used for `batch_process_documents()`.
Deprecated. Maintained for backwards compatibility.
operation_name (str):
Required. The fully qualified operation name for a `batch_process_documents()` operation.
Format: `projects/{project}/locations/{location}/operations/{operation}`
timeout (float):
Optional. Default None. Time in seconds to wait for operation to complete.
If None, will wait indefinitely.
Returns:
List[Document]:
A list of wrapped documents from gcs. Each document corresponds to an input file.
"""
return cls.from_batch_process_metadata(
metadata=_get_batch_process_metadata(
location=location, operation_name=operation_name
operation_name=operation_name,
timeout=timeout,
)
)

Expand Down
13 changes: 5 additions & 8 deletions tests/unit/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,9 @@ def test_get_batch_process_metadata_with_valid_operation(

mock_client.get_operation.return_value = mock_operation

location = "us"
operation_name = "projects/123456/locations/us/operations/7890123"
document._get_batch_process_metadata(location, operation_name)
timeout = 1
document._get_batch_process_metadata(operation_name, timeout=timeout)

mock_client.get_operation.assert_called()
mock_docai.BatchProcessMetadata.deserialize.assert_called()
Expand Down Expand Up @@ -264,9 +264,8 @@ def test_get_batch_process_metadata_with_running_operation(
mock_operation_finished,
]

location = "us"
operation_name = "projects/123456/locations/us/operations/7890123"
document._get_batch_process_metadata(location, operation_name)
document._get_batch_process_metadata(operation_name)

mock_client.get_operation.assert_called()
mock_docai.BatchProcessMetadata.deserialize.assert_called()
Expand All @@ -280,12 +279,11 @@ def test_get_batch_process_metadata_with_no_metadata(mock_docai):
):
mock_client = mock_docai.DocumentProcessorServiceClient.return_value

location = "us"
operation_name = "projects/123456/locations/us/operations/7890123"
mock_operation = mock.Mock(done=True, metadata=None)
mock_client.get_operation.return_value = mock_operation

document._get_batch_process_metadata(location, operation_name)
document._get_batch_process_metadata(operation_name)


@mock.patch("google.cloud.documentai_toolbox.wrappers.document.documentai")
Expand All @@ -296,7 +294,6 @@ def test_get_batch_process_metadata_with_invalid_metadata_type(mock_docai):
):
mock_client = mock_docai.DocumentProcessorServiceClient.return_value

location = "us"
operation_name = "projects/123456/locations/us/operations/7890123"
mock_operation = mock.Mock(
done=True,
Expand All @@ -306,7 +303,7 @@ def test_get_batch_process_metadata_with_invalid_metadata_type(mock_docai):
)
mock_client.get_operation.return_value = mock_operation

document._get_batch_process_metadata(location, operation_name)
document._get_batch_process_metadata(operation_name)


def test_bigquery_column_name():
Expand Down

0 comments on commit 0677299

Please sign in to comment.