
Commit

Migrate Edge calls for Worker to FastAPI part 3 - Jobs routes (apache#44433)

* Migrate Edge calls for Worker to FastAPI 3 - Jobs route

* Review Feedback

* Remove outdated type hints from review feedback

* Update providers/src/airflow/providers/edge/worker_api/routes/jobs.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Add missing filter for free concurrency

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2 people authored and got686-yandex committed Jan 30, 2025
1 parent 34abbd3 commit b0822f0
Showing 11 changed files with 541 additions and 42 deletions.
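
For orientation, the new server-side endpoints (providers/src/airflow/providers/edge/worker_api/routes/jobs.py, whose diff is not expanded on this page) can be pictured as a FastAPI router that mirrors the client calls shown in api_client.py below. This is a hedged sketch only: the route paths and the model names WorkerQueuesBody and EdgeJobFetched come from the diff, while the path-parameter names, model fields, and handler bodies are assumptions, not the commit's actual implementation (Python 3.10+ assumed for the `X | None` syntax).

# Hedged sketch, not the code from this commit. Paths and model names mirror
# the client calls in api_client.py; fields and handler logic are assumed.
from fastapi import APIRouter
from pydantic import BaseModel

jobs_router = APIRouter(prefix="/jobs", tags=["jobs"])


class WorkerQueuesBody(BaseModel):
    """Queues and free slots reported by the worker (assumed fields)."""

    queues: list[str] | None = None
    free_concurrency: int = 1


class EdgeJobFetched(BaseModel):
    """Job handed out to the worker (assumed minimal fields)."""

    dag_id: str
    task_id: str
    run_id: str
    map_index: int
    try_number: int
    command: list[str]
    concurrency_slots: int = 1


@jobs_router.get("/fetch/{worker_name}")
def fetch(worker_name: str, body: WorkerQueuesBody) -> EdgeJobFetched | None:
    # Reserve the next queued job that matches the worker's queues and free
    # concurrency; a null response tells the worker there is nothing to run.
    return None  # reservation logic omitted in this sketch


@jobs_router.patch("/state/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}/{state}")
def set_state(
    dag_id: str, task_id: str, run_id: str, try_number: int, map_index: int, state: str
) -> None:
    # Persist the state reported by the worker for this task instance key.
    ...
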
8 changes: 8 additions & 0 deletions providers/src/airflow/providers/edge/CHANGELOG.rst
@@ -27,6 +27,14 @@
Changelog
---------

0.8.2pre0
.........

Misc
~~~~

* ``Migrate worker job calls to FastAPI.``

0.8.1pre0
.........

2 changes: 1 addition & 1 deletion providers/src/airflow/providers/edge/__init__.py
@@ -29,7 +29,7 @@

__all__ = ["__version__"]

__version__ = "0.8.1pre0"
__version__ = "0.8.2pre0"

if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
"2.10.0"
34 changes: 32 additions & 2 deletions providers/src/airflow/providers/edge/cli/api_client.py
@@ -32,7 +32,13 @@
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.edge.worker_api.auth import jwt_signer
from airflow.providers.edge.worker_api.datamodels import PushLogsBody, WorkerStateBody
from airflow.providers.edge.worker_api.datamodels import (
EdgeJobFetched,
PushLogsBody,
WorkerQueuesBody,
WorkerStateBody,
)
from airflow.utils.state import TaskInstanceState # noqa: TC001

if TYPE_CHECKING:
from airflow.models.taskinstancekey import TaskInstanceKey
@@ -114,6 +120,28 @@ def worker_set_state(
)


def jobs_fetch(hostname: str, queues: list[str] | None, free_concurrency: int) -> EdgeJobFetched | None:
"""Fetch a job to execute on the edge worker."""
result = _make_generic_request(
"GET",
f"jobs/fetch/{quote(hostname)}",
WorkerQueuesBody(queues=queues, free_concurrency=free_concurrency).model_dump_json(
exclude_unset=True
),
)
if result:
return EdgeJobFetched(**result)
return None


def jobs_set_state(key: TaskInstanceKey, state: TaskInstanceState) -> None:
"""Set the state of a job."""
_make_generic_request(
"PATCH",
f"jobs/state/{key.dag_id}/{key.task_id}/{key.run_id}/{key.try_number}/{key.map_index}/{state}",
)


def logs_logfile_path(task: TaskInstanceKey) -> Path:
"""Elaborate the path and filename to expect from task execution."""
result = _make_generic_request(
@@ -133,5 +161,7 @@ def logs_push(
_make_generic_request(
"POST",
f"logs/push/{task.dag_id}/{task.task_id}/{task.run_id}/{task.try_number}/{task.map_index}",
PushLogsBody(log_chunk_time=log_chunk_time, log_chunk_data=log_chunk_data).model_dump_json(),
PushLogsBody(log_chunk_time=log_chunk_time, log_chunk_data=log_chunk_data).model_dump_json(
exclude_unset=True
),
)
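
Taken together, jobs_fetch and jobs_set_state replace the direct EdgeJob model calls with plain HTTP requests against the new FastAPI routes. A minimal usage sketch (hostname, queue name, and slot count are illustrative; the real call sites are in edge_command.py below):

import socket

from airflow.providers.edge.cli.api_client import jobs_fetch, jobs_set_state
from airflow.utils.state import TaskInstanceState

# Ask the API for one job this host may run, given its queues and free slots.
job = jobs_fetch(socket.gethostname(), queues=["default"], free_concurrency=4)
if job:
    # ... launch job.command in a subprocess, then report the outcome ...
    jobs_set_state(job.key, TaskInstanceState.SUCCESS)
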
21 changes: 12 additions & 9 deletions providers/src/airflow/providers/edge/cli/edge_command.py
@@ -26,6 +26,7 @@
from pathlib import Path
from subprocess import Popen
from time import sleep
from typing import TYPE_CHECKING

import psutil
from lockfile.pidlockfile import read_pid_from_pidfile, remove_existing_pidfile, write_pid_to_pidfile
@@ -37,18 +38,22 @@
from airflow.exceptions import AirflowException
from airflow.providers.edge import __version__ as edge_provider_version
from airflow.providers.edge.cli.api_client import (
jobs_fetch,
jobs_set_state,
logs_logfile_path,
logs_push,
worker_register,
worker_set_state,
)
from airflow.providers.edge.models.edge_job import EdgeJob
from airflow.providers.edge.models.edge_worker import EdgeWorkerState, EdgeWorkerVersionException
from airflow.utils import cli as cli_utils, timezone
from airflow.utils.platform import IS_WINDOWS
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from airflow.providers.edge.worker_api.datamodels import EdgeJobFetched

logger = logging.getLogger(__name__)
EDGE_WORKER_PROCESS_NAME = "edge-worker"
EDGE_WORKER_HEADER = "\n".join(
@@ -81,7 +86,7 @@ def force_use_internal_api_on_edge_worker():
if AIRFLOW_V_3_0_PLUS:
# Obvious TODO Make EdgeWorker compatible with Airflow 3 (again)
raise SystemExit(
"Error: EdgeWorker is currently broken on AIrflow 3/main due to removal of AIP-44, rework for AIP-72."
"Error: EdgeWorker is currently broken on Airflow 3/main due to removal of AIP-44, rework for AIP-72."
)

api_url = conf.get("edge", "api_url")
@@ -141,7 +146,7 @@ def _write_pid_to_pidfile(pid_file_path: str):
class _Job:
"""Holds all information for a task/job to be executed as bundle."""

edge_job: EdgeJob
edge_job: EdgeJobFetched
process: Popen
logfile: Path
logsize: int
@@ -240,9 +245,7 @@ def loop(self):
def fetch_job(self) -> bool:
"""Fetch and start a new job from central site."""
logger.debug("Attempting to fetch a new job...")
edge_job = EdgeJob.reserve_task(
worker_name=self.hostname, free_concurrency=self.free_concurrency, queues=self.queues
)
edge_job = jobs_fetch(self.hostname, self.queues, self.free_concurrency)
if edge_job:
logger.info("Received job: %s", edge_job)
env = os.environ.copy()
@@ -252,7 +255,7 @@ def fetch_job(self) -> bool:
process = Popen(edge_job.command, close_fds=True, env=env, start_new_session=True)
logfile = logs_logfile_path(edge_job.key)
self.jobs.append(_Job(edge_job, process, logfile, 0))
EdgeJob.set_state(edge_job.key, TaskInstanceState.RUNNING)
jobs_set_state(edge_job.key, TaskInstanceState.RUNNING)
return True

logger.info("No new job to process%s", f", {len(self.jobs)} still running" if self.jobs else "")
@@ -268,10 +271,10 @@ def check_running_jobs(self) -> None:
self.jobs.remove(job)
if job.process.returncode == 0:
logger.info("Job completed: %s", job.edge_job)
EdgeJob.set_state(job.edge_job.key, TaskInstanceState.SUCCESS)
jobs_set_state(job.edge_job.key, TaskInstanceState.SUCCESS)
else:
logger.error("Job failed: %s", job.edge_job)
EdgeJob.set_state(job.edge_job.key, TaskInstanceState.FAILED)
jobs_set_state(job.edge_job.key, TaskInstanceState.FAILED)
else:
used_concurrency += job.edge_job.concurrency_slots
