Skip to content

Commit 97be207

Browse files
authored
Rollback for all retry exceptions (#40882) (#40883)
In #19856, we added `DBAPIError` besides `OperationalError` to the retry exception types, but did not change the `retry_db_transaction` decorator to rollback transaction after failures and before a retry. In certain cases (see #40882), this is needed as otherwise all retries will fail when the current session/transaction was "poisened" by the initial error.
1 parent bef82d6 commit 97be207

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

airflow/utils/retries.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from inspect import signature
2222
from typing import Callable, TypeVar, overload
2323

24-
from sqlalchemy.exc import DBAPIError, OperationalError
24+
from sqlalchemy.exc import DBAPIError
2525

2626
from airflow.configuration import conf
2727

@@ -36,7 +36,7 @@ def run_with_db_retries(max_retries: int = MAX_DB_RETRIES, logger: logging.Logge
3636

3737
# Default kwargs
3838
retry_kwargs = dict(
39-
retry=tenacity.retry_if_exception_type(exception_types=(OperationalError, DBAPIError)),
39+
retry=tenacity.retry_if_exception_type(exception_types=(DBAPIError)),
4040
wait=tenacity.wait_random_exponential(multiplier=0.5, max=5),
4141
stop=tenacity.stop_after_attempt(max_retries),
4242
reraise=True,
@@ -58,7 +58,7 @@ def retry_db_transaction(_func: F) -> F: ...
5858

5959
def retry_db_transaction(_func: Callable | None = None, *, retries: int = MAX_DB_RETRIES, **retry_kwargs):
6060
"""
61-
Retry functions in case of ``OperationalError`` from DB.
61+
Retry functions in case of ``DBAPIError`` from DB.
6262
6363
It should not be used with ``@provide_session``.
6464
"""
@@ -96,7 +96,7 @@ def wrapped_function(*args, **kwargs):
9696
)
9797
try:
9898
return func(*args, **kwargs)
99-
except OperationalError:
99+
except DBAPIError:
100100
session.rollback()
101101
raise
102102

tests/utils/test_retries.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,17 @@
1818
from __future__ import annotations
1919

2020
import logging
21+
from typing import TYPE_CHECKING
2122
from unittest import mock
2223

2324
import pytest
24-
from sqlalchemy.exc import OperationalError
25+
from sqlalchemy.exc import InternalError, OperationalError
2526

2627
from airflow.utils.retries import retry_db_transaction
2728

29+
if TYPE_CHECKING:
30+
from sqlalchemy.exc import DBAPIError
31+
2832

2933
class TestRetries:
3034
def test_retry_db_transaction_with_passing_retries(self):
@@ -45,23 +49,24 @@ def test_function(session):
4549
assert mock_obj.call_count == 2
4650

4751
@pytest.mark.db_test
48-
def test_retry_db_transaction_with_default_retries(self, caplog):
52+
@pytest.mark.parametrize("excection_type", [OperationalError, InternalError])
53+
def test_retry_db_transaction_with_default_retries(self, caplog, excection_type: type[DBAPIError]):
4954
"""Test that by default 3 retries will be carried out"""
5055
mock_obj = mock.MagicMock()
5156
mock_session = mock.MagicMock()
5257
mock_rollback = mock.MagicMock()
5358
mock_session.rollback = mock_rollback
54-
op_error = OperationalError(statement=mock.ANY, params=mock.ANY, orig=mock.ANY)
59+
db_error = excection_type(statement=mock.ANY, params=mock.ANY, orig=mock.ANY)
5560

5661
@retry_db_transaction
5762
def test_function(session):
5863
session.execute("select 1")
5964
mock_obj(2)
60-
raise op_error
65+
raise db_error
6166

6267
caplog.set_level(logging.DEBUG, logger=self.__module__)
6368
caplog.clear()
64-
with pytest.raises(OperationalError):
69+
with pytest.raises(excection_type):
6570
test_function(session=mock_session)
6671

6772
for try_no in range(1, 4):

0 commit comments

Comments
 (0)