Skip to content

Commit

Permalink
Separate orbiter analyze Command, add input file to dag dict post-filter, trim log output (#16)
Browse files Browse the repository at this point in the history

* feat: add trim_dict to reduce log output of large objects

* refactor: separate `analyze` into its own command

* fix: allow adding DAGs together (if parts were in multiple files)

* feat: analyze - write to file & get unknown task type from doc_md

* feat: add __file key to dag dict after filter. Make orbiter_kwargs[file_path] relative to input_dir

* fix: continue (w/ error) for failure to load a file of type

* fix: fix import path
  • Loading branch information
fritz-astronomer authored Sep 3, 2024
1 parent 49f6139 commit 4472792
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 71 deletions.
2 changes: 1 addition & 1 deletion orbiter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
from typing import Any, Tuple

__version__ = "1.2.0"
__version__ = "1.2.1"

version = __version__

Expand Down
137 changes: 90 additions & 47 deletions orbiter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,43 @@ def formatter(r):
**(exceptions_off if LOG_LEVEL != "DEBUG" else exceptions_on),
)

# Shared click parameter definitions, reused by both the `translate` and
# `analyze` commands so their CLI surfaces stay consistent.
INPUT_DIR_ARGS = ("input-dir",)
INPUT_DIR_KWARGS = {
    "type": click.Path(
        exists=True,
        dir_okay=True,
        file_okay=False,
        readable=True,
        resolve_path=True,
        path_type=Path,
    ),
    "default": Path.cwd() / "workflow",
    "required": True,
}
RULESET_ARGS = ("-r", "--ruleset")
RULESET_KWARGS = {
    "help": "Qualified name of a TranslationRuleset",
    "type": str,
    "prompt": "Ruleset to use? (e.g. orbiter_community_translations.dag_factory.translation_ruleset)",
    "required": True,
}


def import_ruleset(ruleset: str) -> TranslationRuleset:
    """Resolve a qualified name (e.g. `pkg.module.translation_ruleset`)
    into a `TranslationRuleset` instance.

    When running as a frozen binary, `.pyz` archives are added to the path
    first so the qualname can resolve. Raises `RuntimeError` if the resolved
    object is not a `TranslationRuleset`.
    """
    if RUNNING_AS_BINARY:
        _add_pyz()

    logger.debug(f"Importing ruleset: {ruleset}")
    _, translation_ruleset = import_from_qualname(ruleset)
    if isinstance(translation_ruleset, TranslationRuleset):
        return translation_ruleset
    raise RuntimeError(
        f"translation_ruleset={translation_ruleset} is not a TranslationRuleset"
    )


def run(cmd: str, **kwargs):
"""Helper method to run a command and log the output"""
Expand Down Expand Up @@ -124,19 +161,7 @@ def orbiter():


@orbiter.command()
@click.argument(
"input-dir",
type=click.Path(
exists=True,
dir_okay=True,
file_okay=False,
readable=True,
resolve_path=True,
path_type=Path,
),
default=Path.cwd() / "workflow",
required=True,
)
@click.argument(*INPUT_DIR_ARGS, **INPUT_DIR_KWARGS)
@click.argument(
"output-dir",
type=click.Path(
Expand All @@ -149,33 +174,19 @@ def orbiter():
default=Path.cwd() / "output",
required=True,
)
@click.option(
"-r",
"--ruleset",
help="Qualified name of a TranslationRuleset",
type=str,
prompt="Ruleset to use? (e.g. orbiter_community_translations.dag_factory.translation_ruleset)",
required=True,
)
@click.option(*RULESET_ARGS, **RULESET_KWARGS)
@click.option(
"--format/--no-format",
"_format",
help="[optional] format the output with Ruff",
default=True,
show_default=True,
)
@click.option(
"--analyze/--no-analyze",
help="[optional] print statistics instead of rendering output",
default=False,
show_default=True,
)
def translate(
input_dir: Path,
output_dir: Path,
ruleset: str | None,
_format: bool,
analyze: bool,
):
"""Translate workflows in an `INPUT_DIR` to an `OUTPUT_DIR` Airflow Project.
Expand All @@ -192,34 +203,63 @@ def translate(
logger.debug(f"Creating output directory {output_dir}")
output_dir.mkdir(parents=True, exist_ok=True)

logger.debug(f"Adding current directory {os.getcwd()} to sys.path")
sys.path.insert(0, os.getcwd())

if RUNNING_AS_BINARY:
_add_pyz()

logger.debug(f"Importing ruleset: {ruleset}")
(_, translation_ruleset) = import_from_qualname(ruleset)
if not isinstance(translation_ruleset, TranslationRuleset):
raise RuntimeError(
f"translation_ruleset={translation_ruleset} is not a TranslationRuleset"
)

translation_ruleset = import_ruleset(ruleset)
try:
project = translation_ruleset.translate_fn(
translation_ruleset.translate_fn(
translation_ruleset=translation_ruleset, input_dir=input_dir
)
if analyze:
project.analyze()
else:
project.render(output_dir)
).render(output_dir)
except RuntimeError as e:
logger.error(f"Error encountered during translation: {e}")
raise click.Abort()
if _format:
run_ruff_formatter(output_dir)


@orbiter.command()
@click.argument(*INPUT_DIR_ARGS, **INPUT_DIR_KWARGS)
@click.option(*RULESET_ARGS, **RULESET_KWARGS)
@click.option(
    "--format",
    "_format",
    type=click.Choice(["json", "csv", "md"]),
    default="md",
    help="[optional] format for analysis output",
    show_default=True,
)
@click.option(
    "-o",
    "--output-file",
    type=click.File("w", lazy=True),
    default="-",
    show_default=True,
    help="File to write to, defaults to '-' (stdout)",
)
def analyze(
    input_dir: Path,
    ruleset: str | None,
    _format: Literal["json", "csv", "md"],
    output_file: Path | None,
):
    """Analyze workflows in an `INPUT_DIR`
    Provide a specific ruleset with the `--ruleset` flag.
    Run `orbiter list-rulesets` to see available rulesets.
    `INPUT_DIR` defaults to `$CWD/workflow`.
    """
    # NOTE(review): click.File yields a (lazy) file object, not a Path, so this
    # branch looks like it only fires when callers invoke analyze() directly
    # with a Path — confirm whether it is still needed.
    if isinstance(output_file, Path):
        output_file = output_file.open("w", newline="")
    translation_ruleset = import_ruleset(ruleset)
    try:
        # Translate the input first, then analyze the resulting project,
        # writing in the chosen format to the chosen file (default: stdout).
        translation_ruleset.translate_fn(
            translation_ruleset=translation_ruleset, input_dir=input_dir
        ).analyze(_format, output_file)
    except RuntimeError as e:
        # Abort the CLI with a logged traceback rather than a raw stack dump
        logger.exception(f"Error encountered during translation: {e}")
        raise click.Abort()


def _pip_install(repo: str, key: str):
"""If we are running via python/pip, we can utilize pip to install translations"""
_exec = f"{sys.executable} -m pip install {repo}"
Expand Down Expand Up @@ -278,6 +318,9 @@ def _get_gh_pyz(


def _add_pyz():
logger.debug(f"Adding current directory {os.getcwd()} to sys.path")
sys.path.insert(0, os.getcwd())

local_pyz = [
str(_path.resolve()) for _path in Path(".").iterdir() if _path.suffix == ".pyz"
]
Expand Down
3 changes: 3 additions & 0 deletions orbiter/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,8 @@
LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO")
"""You can set the log level to DEBUG to see more detailed logs."""

# int() coercion: os.getenv returns a *str* whenever the variable is set in
# the environment; consumers compare this value against len(...) results,
# which requires an int (comparing int > str raises TypeError).
TRIM_LOG_OBJECT_LENGTH = int(os.getenv("TRIM_LOG_OBJECT_LENGTH", 1000))
"""Trim the (str) length of logged objects to avoid long logs, set to -1 to disable trimming and log full objects."""

KG_ACCOUNT_ID = "3b189b4c-c047-4fdb-9b46-408aa2978330"
# True only when packaged by PyInstaller (sys.frozen + sys._MEIPASS present)
RUNNING_AS_BINARY = getattr(sys, "frozen", False) and hasattr(sys, "_MEIPASS")
24 changes: 24 additions & 0 deletions orbiter/objects/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,30 @@ def repr(self):
f"catchup={self.catchup})"
)

# noinspection t
def __add__(self, other):
if other.tasks:
for task in other.tasks.values():
self.add_tasks(task)
if other.orbiter_conns:
for conn in other.orbiter_conns:
self.orbiter_conns.add(conn)
if other.orbiter_vars:
for var in other.orbiter_vars:
self.orbiter_vars.add(var)
if other.orbiter_env_vars:
for env_var in other.orbiter_env_vars:
self.orbiter_env_vars.add(env_var)
if other.orbiter_includes:
for include in other.orbiter_includes:
self.orbiter_includes.add(include)
if other.model_extra:
for key in other.model_extra.keys():
self.model_extra[key] = self.model_extra[key] or other.model_extra[key]
for key in self.render_attributes:
setattr(self, key, getattr(self, key) or getattr(other, key))
return self

def _dag_to_ast(self) -> ast.Expr:
"""
Returns the `DAG(...)` object.
Expand Down
32 changes: 25 additions & 7 deletions orbiter/objects/project.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import sys
import re
from functools import reduce
from pathlib import Path
from typing import Dict, Iterable, Set, Literal
Expand Down Expand Up @@ -611,9 +613,16 @@ def render(self, output_dir: Path) -> None:
logger.debug("No entries for .env")

@validate_call
def analyze(self, output_fmt: Literal["json", "csv", "md"] = "md"):
def analyze(
self, output_fmt: Literal["json", "csv", "md"] = "md", output_file=None
):
"""Print an analysis of the project to the console.
!!! tip
Looks for a specific `[task_type=XYZ]` in the Task's `doc_md` property
or uses `type(task)` to infer the type of task.
```pycon
>>> from orbiter.objects.operators.empty import OrbiterEmptyOperator
>>> OrbiterProject().add_dags([
Expand All @@ -639,13 +648,23 @@ def analyze(self, output_fmt: Literal["json", "csv", "md"] = "md"):
```
"""
import sys
if output_file is None:
output_file = sys.stdout

_task_type = re.compile(r"\[task_type=(?P<task_type>[A-Za-z0-9-_]+)")

def get_task_type(task):
match = _task_type.match(getattr(task, "doc_md", None) or "")
match_or_task_type = (
match.groupdict().get("task_type") if match else None
) or type(task).__name__
return match_or_task_type

dag_analysis = [
{
"file": dag.orbiter_kwargs.get("file_path", dag.file_path),
"dag_id": dag.dag_id,
"task_types": [type(task).__name__ for task in dag.tasks.values()],
"task_types": [get_task_type(task) for task in dag.tasks.values()],
}
for dag in self.dags.values()
]
Expand Down Expand Up @@ -673,20 +692,19 @@ def analyze(self, output_fmt: Literal["json", "csv", "md"] = "md"):
if output_fmt == "json":
import json

json.dump(file_analysis, sys.stdout)
json.dump(file_analysis, output_file, default=str)
elif output_fmt == "csv":
import csv
import sys

writer = csv.DictWriter(sys.stdout, fieldnames=file_analysis[0].keys())
writer = csv.DictWriter(output_file, fieldnames={""} | totals.keys())
writer.writeheader()
writer.writerows(file_analysis)
elif output_fmt == "md":
from rich.console import Console
from rich.markdown import Markdown
from tabulate import tabulate

console = Console()
console = Console(file=output_file)

# DAGs EmptyOp
# file_a 1 1
Expand Down
24 changes: 21 additions & 3 deletions orbiter/rules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,13 @@ def my_rule(val):
import functools
import json
import re
from typing import Callable, Any, Collection, TYPE_CHECKING, List
from typing import Callable, Any, Collection, TYPE_CHECKING, List, Mapping

from pydantic import BaseModel, Field

from loguru import logger

from orbiter.config import TRIM_LOG_OBJECT_LENGTH
from orbiter.objects.task import OrbiterOperator, OrbiterTaskDependency

if TYPE_CHECKING:
Expand All @@ -69,6 +70,17 @@ def my_rule(val):
qualname_validator = re.compile(qualname_validator_regex)


def trim_dict(v):
    """Stringify and trim a dictionary if it's greater than a certain length
    (used to trim down overwhelming log output).

    Mappings longer than `TRIM_LOG_OBJECT_LENGTH` (stringified) are truncated
    with a trailing "..."; lists and tuples are trimmed element-wise; anything
    else (e.g. exceptions) is returned unchanged, so this is safe to call on
    arbitrary values before logging. Set TRIM_LOG_OBJECT_LENGTH=-1 to disable.
    """
    # int(): TRIM_LOG_OBJECT_LENGTH comes from os.getenv and may be a str
    max_len = int(TRIM_LOG_OBJECT_LENGTH)
    if max_len != -1 and isinstance(v, Mapping):
        if len(str(v)) > max_len:
            return json.dumps(v, default=str)[:max_len] + "..."
    # include tuples: callers pass *args tuples, which a list-only check
    # would return untrimmed
    if isinstance(v, (list, tuple)):
        return [trim_dict(_v) for _v in v]
    return v


def rule(
func=None, *, priority=None
) -> (
Expand Down Expand Up @@ -160,7 +172,9 @@ def __call__(self, *args, **kwargs):
setattr(result, "orbiter_kwargs", kwargs)
except Exception as e:
logger.warning(
f"[RULE]: {self.rule.__name__}\n[ERROR]:\n{type(e)} - {e}\n[INPUT]:\n{args}\n{kwargs}"
f"[RULE]: {self.rule.__name__}\n"
f"[ERROR]:\n{type(e)} - {trim_dict(e)}\n"
f"[INPUT]:\n{trim_dict(args)}\n{trim_dict(kwargs)}"
)
result = None
return result
Expand Down Expand Up @@ -189,6 +203,10 @@ def foo(val: dict) -> List[dict]:
class DAGRule(Rule):
"""A `@dag_rule` decorator creates a [`DAGRule`][orbiter.rules.DAGRule]
!!! tip
A `__file` key is added to the original input, which is the file path of the input.
```python
@dag_rule
def foo(val: dict) -> OrbiterDAG | None:
Expand Down Expand Up @@ -314,7 +332,7 @@ def cannot_map_rule(val: dict) -> OrbiterOperator | None:
# noinspection PyArgumentList
return OrbiterEmptyOperator(
task_id="UNKNOWN",
doc_md=f"""Input did not translate: `{json.dumps(val, default=str)}`""",
doc_md=f"""[task_type=UNKNOWN] Input did not translate: `{trim_dict(val)}`""",
)


Expand Down
Loading

0 comments on commit 4472792

Please sign in to comment.