Skip to content

Commit e19ff6f

Browse files
ephraimbuddy authored and pierrejeambrun committed
Cascade update of taskinstance to TaskMap table (#31445)
(cherry picked from commit f6bb474)
1 parent d7ba536 commit e19ff6f

File tree

6 files changed

+105
-7
lines changed

6 files changed

+105
-7
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
19+
"""Add ``onupdate`` cascade to ``task_map`` table
20+
21+
Revision ID: c804e5c76e3e
22+
Revises: 98ae134e6fff
23+
Create Date: 2023-05-19 23:30:57.368617
24+
25+
"""
26+
from __future__ import annotations
27+
28+
from alembic import op
29+
30+
# revision identifiers, used by Alembic.
31+
revision = "c804e5c76e3e"
32+
down_revision = "98ae134e6fff"
33+
branch_labels = None
34+
depends_on = None
35+
airflow_version = "2.6.2"
36+
37+
38+
def upgrade():
    """Recreate the ``task_map`` foreign key to ``task_instance`` with ``ON UPDATE CASCADE``.

    The existing ``task_map_task_instance_fkey`` constraint is dropped and
    created again over the same four key columns, keeping ``ON DELETE CASCADE``
    and additionally setting ``ON UPDATE CASCADE``.
    """
    fk_name = "task_map_task_instance_fkey"
    key_columns = ["dag_id", "task_id", "run_id", "map_index"]

    with op.batch_alter_table("task_map") as batch_op:
        # The constraint cannot be altered in place here; drop it and create it anew.
        batch_op.drop_constraint(fk_name, type_="foreignkey")
        batch_op.create_foreign_key(
            fk_name,
            "task_instance",
            list(key_columns),  # local columns on task_map
            list(key_columns),  # referenced columns on task_instance
            ondelete="CASCADE",
            onupdate="CASCADE",
        )
50+
51+
52+
def downgrade():
    """Recreate the ``task_map`` foreign key to ``task_instance`` without ``ON UPDATE CASCADE``.

    The ``task_map_task_instance_fkey`` constraint is dropped and created again
    over the same four key columns with only ``ON DELETE CASCADE`` set.
    """
    fk_name = "task_map_task_instance_fkey"
    key_columns = ["dag_id", "task_id", "run_id", "map_index"]

    with op.batch_alter_table("task_map") as batch_op:
        # Drop and re-create the constraint, this time omitting ``onupdate``.
        batch_op.drop_constraint(fk_name, type_="foreignkey")
        batch_op.create_foreign_key(
            fk_name,
            "task_instance",
            list(key_columns),  # local columns on task_map
            list(key_columns),  # referenced columns on task_instance
            ondelete="CASCADE",
        )

airflow/models/taskmap.py

+1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class TaskMap(Base):
7272
],
7373
name="task_map_task_instance_fkey",
7474
ondelete="CASCADE",
75+
onupdate="CASCADE",
7576
),
7677
)
7778

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
c9c60f41925503ca504eca8b6688ca7ec2d8d7cbbd711285ae5c711cc2599f0e
1+
b8e715d65aba5965f3ae0134ee7c38dcff539ebf27442db9771145ed7fb18186

docs/apache-airflow/img/airflow_erd.svg

+4-4
Loading

docs/apache-airflow/migrations-ref.rst

+3-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru
3939
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
4040
| Revision ID | Revises ID | Airflow Version | Description |
4141
+=================================+===================+===================+==============================================================+
42-
| ``98ae134e6fff`` (head) | ``6abdffdd4815`` | ``2.6.0`` | Increase length of user identifier columns in ``ab_user`` |
42+
| ``c804e5c76e3e`` (head) | ``98ae134e6fff`` | ``2.6.2`` | Add ``onupdate`` cascade to ``task_map`` table |
43+
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
44+
| ``98ae134e6fff`` | ``6abdffdd4815`` | ``2.6.0`` | Increase length of user identifier columns in ``ab_user`` |
4345
| | | | and ``ab_register_user`` tables |
4446
+---------------------------------+-------------------+-------------------+--------------------------------------------------------------+
4547
| ``6abdffdd4815`` | ``290244fb8b83`` | ``2.6.0`` | add dttm index on log table |

tests/models/test_taskinstance.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
from airflow.operators.python import PythonOperator
7070
from airflow.sensors.base import BaseSensorOperator
7171
from airflow.sensors.python import PythonSensor
72-
from airflow.serialization.serialized_objects import SerializedBaseOperator
72+
from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
7373
from airflow.settings import TIMEZONE
7474
from airflow.stats import Stats
7575
from airflow.ti_deps.dep_context import DepContext
@@ -3484,6 +3484,39 @@ def pull_something(value):
34843484
assert task_map.length == expected_length
34853485
assert task_map.keys == expected_keys
34863486

3487+
def test_no_error_on_changing_from_non_mapped_to_mapped(self, dag_maker, session):
    """If a task changes from non-mapped to mapped, don't fail on integrity error."""
    # NOTE(review): indentation was reconstructed from a flattened diff view;
    # confirm the extents of the two ``with`` blocks against the real file.
    with dag_maker(dag_id="test_no_error_on_changing_from_non_mapped_to_mapped") as dag:

        @dag.task()
        def add_one(x):
            return [x + 1]

        @dag.task()
        def add_two(x):
            return x + 2

        # Initial DAG shape: ``add_one`` is called directly (not expanded),
        # and ``add_two`` is expanded over its result.
        task1 = add_one(2)
        add_two.expand(x=task1)

    # Run the non-expanded ``add_one`` task instance and check it succeeded.
    dr = dag_maker.create_dagrun()
    ti = dr.get_task_instance(task_id="add_one")
    ti.run()
    assert ti.state == TaskInstanceState.SUCCESS

    # Replace ``add_one`` in the same DAG with an expanded version of itself.
    dag._remove_task("add_one")
    with dag:
        task1 = add_one.expand(x=[1, 2, 3]).operator
    # Round-trip the modified DAG through its serialized dict form.
    serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

    # Point the existing dag run at the round-tripped DAG and re-verify it;
    # the previously-run ``add_one`` instance is expected to end up REMOVED.
    dr.dag = serialized_dag
    dr.verify_integrity(session=session)
    ti = dr.get_task_instance(task_id="add_one")
    assert ti.state == TaskInstanceState.REMOVED
    dag.clear()
    ti.refresh_from_task(task1)
    # This should not raise an integrity error
    dr.task_instance_scheduling_decisions()
3519+
34873520

34883521
class TestMappedTaskInstanceReceiveValue:
34893522
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)