Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix PythonVirtualenvOperator not working with Airflow context #9394

Merged
merged 1 commit into from
Jul 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions airflow/example_dags/example_python_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def my_sleeping_function(random_base):
# [END howto_operator_python_kwargs]


# [START howto_operator_python_venv]
def callable_virtualenv():
"""
Example function that will be performed in a virtual environment.
Expand Down Expand Up @@ -101,3 +102,4 @@ def callable_virtualenv():
system_site_packages=False,
dag=dag,
)
# [END howto_operator_python_venv]
213 changes: 106 additions & 107 deletions airflow/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from itertools import islice
from tempfile import TemporaryDirectory
from textwrap import dedent
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, cast
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union, cast

import dill

Expand All @@ -38,7 +38,7 @@
from airflow.models.xcom_arg import XComArg
from airflow.utils.decorators import apply_defaults
from airflow.utils.process_utils import execute_in_subprocess
from airflow.utils.python_virtualenv import prepare_virtualenv
from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script


class PythonOperator(BaseOperator):
Expand Down Expand Up @@ -363,14 +363,18 @@ class PythonVirtualenvOperator(PythonOperator):
Note that if your virtualenv runs in a different Python major version than Airflow,
you cannot use return values, op_args, or op_kwargs. You can use string_args though.

.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:PythonVirtualenvOperator`

:param python_callable: A python function with no references to outside variables,
defined with def, which will be run in a virtualenv
:type python_callable: function
:param requirements: A list of requirements as specified in a pip install command
:type requirements: list[str]
:param python_version: The Python version to run the virtualenv with. Note that
both 2 and 2.7 are acceptable forms.
:type python_version: str
:type python_version: Optional[Union[str, int, float]]
:param use_dill: Whether to use dill to serialize
the args and result (pickle is default). This allows more complex types
but requires you to include dill in your requirements.
Expand All @@ -397,13 +401,48 @@ class PythonVirtualenvOperator(PythonOperator):
:type templates_exts: list[str]
"""

BASE_SERIALIZABLE_CONTEXT_KEYS = {
'ds_nodash',
'inlets',
'next_ds',
'next_ds_nodash',
'outlets',
'params',
'prev_ds',
'prev_ds_nodash',
'run_id',
'task_instance_key_str',
'test_mode',
'tomorrow_ds',
'tomorrow_ds_nodash',
'ts',
'ts_nodash',
'ts_nodash_with_tz',
'yesterday_ds',
'yesterday_ds_nodash'
}
PENDULUM_SERIALIZABLE_CONTEXT_KEYS = {
'execution_date',
'next_execution_date',
'prev_execution_date',
'prev_execution_date_success',
'prev_start_date_success'
}
AIRFLOW_SERIALIZABLE_CONTEXT_KEYS = {
'macros',
'conf',
'dag',
'dag_run',
'task'
}

@apply_defaults
def __init__( # pylint: disable=too-many-arguments
self,
*,
python_callable: Callable,
requirements: Optional[Iterable[str]] = None,
python_version: Optional[str] = None,
python_version: Optional[Union[str, int, float]] = None,
use_dill: bool = False,
system_site_packages: bool = True,
op_args: Optional[List] = None,
Expand All @@ -413,151 +452,111 @@ def __init__( # pylint: disable=too-many-arguments
templates_exts: Optional[List[str]] = None,
**kwargs
):
if (
not isinstance(python_callable, types.FunctionType) or
isinstance(python_callable, types.LambdaType) and python_callable.__name__ == "<lambda>"
):
raise AirflowException('PythonVirtualenvOperator only supports functions for python_callable arg')
if (
python_version and str(python_version)[0] != str(sys.version_info.major) and
(op_args or op_kwargs)
):
raise AirflowException("Passing op_args or op_kwargs is not supported across different Python "
"major versions for PythonVirtualenvOperator. Please use string_args.")
super().__init__(
python_callable=python_callable,
op_args=op_args,
op_kwargs=op_kwargs,
templates_dict=templates_dict,
templates_exts=templates_exts,
**kwargs)
self.requirements = requirements or []
self.requirements = list(requirements or [])
self.string_args = string_args or []
self.python_version = python_version
self.use_dill = use_dill
self.system_site_packages = system_site_packages
# check that dill is present if needed
dill_in_requirements = map(lambda x: x.lower().startswith('dill'),
self.requirements)
if (not system_site_packages) and use_dill and not any(dill_in_requirements):
raise AirflowException('If using dill, dill must be in the environment ' +
'either via system_site_packages or requirements')
# check that a function is passed, and that it is not a lambda
if (not isinstance(self.python_callable,
types.FunctionType) or (self.python_callable.__name__ ==
(lambda x: 0).__name__)):
raise AirflowException('{} only supports functions for python_callable arg'.format(
self.__class__.__name__))
# check that args are passed iff python major version matches
if (python_version is not None and
str(python_version)[0] != str(sys.version_info[0]) and
self._pass_op_args()):
raise AirflowException("Passing op_args or op_kwargs is not supported across "
"different Python major versions "
"for PythonVirtualenvOperator. "
"Please use string_args.")
if not self.system_site_packages and self.use_dill and 'dill' not in self.requirements:
self.requirements.append('dill')
self.pickling_library = dill if self.use_dill else pickle

def execute(self, context: Dict):
    """
    Execute the operator with a context reduced to its serializable keys.

    Only the keys reported by ``_get_serializable_context_keys`` are copied
    from ``context`` before delegating to the parent ``execute``.

    :param context: The task context to take the serializable entries from.
    :type context: dict
    :return: The value returned by the parent ``execute``.
    """
    serializable_context = {key: context[key] for key in self._get_serializable_context_keys()}
    # Return the parent's result; without ``return`` it is silently dropped
    # and callers of ``execute`` always receive ``None``.
    return super().execute(context=serializable_context)

def execute_callable(self):
with TemporaryDirectory(prefix='venv') as tmp_dir:
if self.templates_dict:
self.op_kwargs['templates_dict'] = self.templates_dict
# generate filenames

input_filename = os.path.join(tmp_dir, 'script.in')
output_filename = os.path.join(tmp_dir, 'script.out')
string_args_filename = os.path.join(tmp_dir, 'string_args.txt')
script_filename = os.path.join(tmp_dir, 'script.py')

# set up virtualenv
python_bin = 'python' + str(self.python_version) if self.python_version else None
prepare_virtualenv(
venv_directory=tmp_dir,
python_bin=python_bin,
python_bin=f'python{self.python_version}' if self.python_version else None,
system_site_packages=self.system_site_packages,
requirements=self.requirements,
requirements=self.requirements
)

self._write_args(input_filename)
self._write_script(script_filename)
self._write_string_args(string_args_filename)
write_python_script(
jinja_context=dict(
op_args=self.op_args,
op_kwargs=self.op_kwargs,
pickling_library=self.pickling_library.__name__,
python_callable=self.python_callable.__name__,
python_callable_source=dedent(inspect.getsource(self.python_callable))
),
filename=script_filename
)

execute_in_subprocess(cmd=[
f'{tmp_dir}/bin/python',
script_filename,
input_filename,
output_filename,
string_args_filename
])

# execute command in virtualenv
execute_in_subprocess(
self._generate_python_cmd(tmp_dir,
script_filename,
input_filename,
output_filename,
string_args_filename))
return self._read_result(output_filename)

def _pass_op_args(self):
# we should only pass op_args if any are given to us
return len(self.op_args) + len(self.op_kwargs) > 0
def _write_args(self, filename):
    """
    Serialize ``op_args`` and ``op_kwargs`` into ``filename``.

    Nothing is written when both ``op_args`` and ``op_kwargs`` are empty.
    """
    if not (self.op_args or self.op_kwargs):
        return
    payload = {'args': self.op_args, 'kwargs': self.op_kwargs}
    with open(filename, 'wb') as handle:
        self.pickling_library.dump(payload, handle)

def _get_serializable_context_keys(self):
    """
    Work out which context keys should be kept for this operator.

    The base keys are always included. The Airflow keys are added when
    ``system_site_packages`` is set or ``'apache-airflow'`` is listed in
    ``requirements``. The pendulum keys are added in that same case, or when
    both ``'pendulum'`` and ``'lazy_object_proxy'`` are listed in
    ``requirements``.

    :return: A new set of context key names.
    """
    has_airflow = self.system_site_packages or 'apache-airflow' in self.requirements
    has_pendulum = 'pendulum' in self.requirements and 'lazy_object_proxy' in self.requirements

    # Build a fresh set so the class-level constants are never mutated.
    keys = set(self.BASE_SERIALIZABLE_CONTEXT_KEYS)
    if has_airflow:
        keys |= self.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS
    if has_pendulum or has_airflow:
        keys |= self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS
    return keys

def _write_string_args(self, filename):
    """Write ``string_args`` to ``filename``, one stringified entry per line."""
    with open(filename, 'w') as handle:
        handle.write('\n'.join(str(arg) for arg in self.string_args))

def _write_args(self, input_filename):
# serialize args to file
if self._pass_op_args():
with open(input_filename, 'wb') as file:
arg_dict = ({'args': self.op_args, 'kwargs': self.op_kwargs})
if self.use_dill:
dill.dump(arg_dict, file)
else:
pickle.dump(arg_dict, file)

def _read_result(self, output_filename):
if os.stat(output_filename).st_size == 0:
def _read_result(self, filename):
if os.stat(filename).st_size == 0:
return None
with open(output_filename, 'rb') as file:
with open(filename, 'rb') as file:
try:
if self.use_dill:
return dill.load(file)
else:
return pickle.load(file)
return self.pickling_library.load(file)
except ValueError:
self.log.error("Error deserializing result. "
"Note that result deserialization "
self.log.error("Error deserializing result. Note that result deserialization "
"is not supported across major Python versions.")
raise

def _write_script(self, script_filename):
with open(script_filename, 'w') as file:
python_code = self._generate_python_code()
self.log.debug('Writing code to file\n %s', python_code)
file.write(python_code)

@staticmethod
def _generate_python_cmd(tmp_dir, script_filename,
input_filename, output_filename, string_args_filename):
# direct path alleviates need to activate
return ['{}/bin/python'.format(tmp_dir), script_filename,
input_filename, output_filename, string_args_filename]

def _generate_python_code(self):
if self.use_dill:
pickling_library = 'dill'
else:
pickling_library = 'pickle'

# dont try to read pickle if we didnt pass anything
if self._pass_op_args():
load_args_line = 'with open(sys.argv[1], "rb") as file: arg_dict = {}.load(file)' \
.format(pickling_library)
else:
load_args_line = 'arg_dict = {"args": [], "kwargs": {}}'

# no indents in original code so we can accept
# any type of indents in the original function
# we deserialize args, call function, serialize result if necessary
return dedent("""\
import {pickling_library}
import sys
{load_args_code}
args = arg_dict["args"]
kwargs = arg_dict["kwargs"]
with open(sys.argv[3], 'r') as file:
virtualenv_string_args = list(map(lambda x: x.strip(), list(file)))
{python_callable_lines}
res = {python_callable_name}(*args, **kwargs)
with open(sys.argv[2], 'wb') as file:
res is not None and {pickling_library}.dump(res, file)
""").format(load_args_code=load_args_line,
python_callable_lines=dedent(inspect.getsource(self.python_callable)),
python_callable_name=self.python_callable.__name__,
pickling_library=pickling_library)


def get_current_context() -> Dict[str, Any]:
"""
Expand Down
22 changes: 22 additions & 0 deletions airflow/utils/python_virtualenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@
"""
Utilities for creating a virtual environment
"""
import os
from typing import List, Optional

import jinja2

from airflow.utils.process_utils import execute_in_subprocess


Expand Down Expand Up @@ -69,3 +72,22 @@ def prepare_virtualenv(
execute_in_subprocess(pip_cmd)

return '{}/bin/python'.format(venv_directory)


def write_python_script(jinja_context: dict, filename: str):
    """
    Render the virtualenv script template and save the result to a file.

    The template ``python_virtualenv_script.jinja2`` is looked up in the
    directory that contains this module. Undefined template variables are
    handled with ``jinja2.StrictUndefined``.

    :param jinja_context: Values passed to the template as keyword arguments.
    :type jinja_context: dict
    :param filename: Path of the file the rendered script is written to.
    :type filename: str
    """
    environment = jinja2.Environment(
        loader=jinja2.FileSystemLoader(searchpath=os.path.dirname(__file__)),
        undefined=jinja2.StrictUndefined
    )
    script_template = environment.get_template('python_virtualenv_script.jinja2')
    rendered = script_template.stream(**jinja_context)
    rendered.dump(filename)
42 changes: 42 additions & 0 deletions airflow/utils/python_virtualenv_script.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
{#
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
#}

import {{ pickling_library }}
import sys

# Read args
{% if op_args or op_kwargs %}
with open(sys.argv[1], "rb") as file:
    arg_dict = {{ pickling_library }}.load(file)
{% else %}
arg_dict = {"args": [], "kwargs": {}}
{% endif %}

# Read string args
with open(sys.argv[3], "r") as file:
    virtualenv_string_args = list(map(lambda x: x.strip(), list(file)))

# Script
{{ python_callable_source }}
res = {{ python_callable }}(*arg_dict["args"], **arg_dict["kwargs"])

# Write output
# The output file is always created. Compare against None explicitly so that
# falsy results (0, False, "", []) are still serialized instead of being lost.
with open(sys.argv[2], "wb") as file:
    if res is not None:
        {{ pickling_library }}.dump(res, file)
Loading