name: ~CI, single-arch
run-name: CI-${{ inputs.ARCHITECTURE }}
on:
  workflow_call:
    inputs:
      ARCHITECTURE:
        type: string
        required: true
      BUILD_DATE:
        type: string
        description: 'Build date in YYYY-MM-DD format'
        required: false
        default: NOT SPECIFIED
      CUDA_IMAGE:
        type: string
        description: CUDA image to use as base, e.g. nvidia/cuda:X.Y.Z-devel-ubuntu22.04
        default: 'latest'
        required: false
      MANIFEST_ARTIFACT_NAME:
        type: string
        description: 'Artifact name in current run w/ manifest/patches. Leaving empty uses manifest/patches in current branch'
        default: ''
        required: false
      SOURCE_URLREFS:
        type: string
        description: 'A JSON object containing git URLs + refs for the software to be built'
        required: false
        default: '{}'
    outputs:
      DOCKER_TAGS:
        description: 'JSON object containing tags of all docker images built'
        value: ${{ jobs.collect-docker-tags.outputs.TAGS }}
permissions:
  contents: read  # to fetch code
  actions: write  # to cancel previous workflows
  packages: write # to upload container
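# Illustrative caller sketch (not part of this workflow): a parent workflow can
# invoke this reusable workflow once per architecture and read DOCKER_TAGS back.
# The job names and the SOURCE_URLREFS contents below are hypothetical.
#
#   jobs:
#     ci-amd64:
#       uses: ./.github/workflows/_ci.yaml
#       with:
#         ARCHITECTURE: amd64
#         BUILD_DATE: '2025-01-01'
#         SOURCE_URLREFS: '{"JAX": "https://github.com/jax-ml/jax.git#main"}'
#       secrets: inherit
#     use-tags:
#       needs: ci-amd64
#       runs-on: ubuntu-22.04
#       steps:
#         - run: echo '${{ needs.ci-amd64.outputs.DOCKER_TAGS }}' | jq .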
jobs:

  build-base:
    uses: ./.github/workflows/_build_base.yaml
    with:
      ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
      BASE_IMAGE: ${{ inputs.CUDA_IMAGE }}
      BUILD_DATE: ${{ inputs.BUILD_DATE }}
      MANIFEST_ARTIFACT_NAME: ${{ inputs.MANIFEST_ARTIFACT_NAME }}
    secrets: inherit
  build-jax:
    needs: build-base
    uses: ./.github/workflows/_build.yaml
    with:
      ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
      ARTIFACT_NAME: artifact-jax-build
      BADGE_FILENAME: badge-jax-build
      BUILD_DATE: ${{ inputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-base.outputs.DOCKER_TAG }}
      CONTAINER_NAME: jax
      DOCKERFILE: .github/container/Dockerfile.jax
      RUNNER_SIZE: large
      EXTRA_BUILD_ARGS: |
        URLREF_JAX=${{ fromJson(inputs.SOURCE_URLREFS).JAX }}
        URLREF_XLA=${{ fromJson(inputs.SOURCE_URLREFS).XLA }}
        URLREF_FLAX=${{ fromJson(inputs.SOURCE_URLREFS).FLAX }}
        URLREF_TRANSFORMER_ENGINE=${{ fromJson(inputs.SOURCE_URLREFS).TRANSFORMER_ENGINE }}
    secrets: inherit
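  # collect-docker-tags and test-triton below both depend on a build-triton job
  # that is missing from this file (this is the likely cause of the "Invalid
  # workflow file" failure). The following is a minimal sketch that follows the
  # same _build.yaml pattern as the sibling jobs; the Dockerfile path and the
  # URLREF key are assumptions, not confirmed by this file.
  build-triton:
    needs: build-jax
    if: inputs.ARCHITECTURE == 'amd64' # build only amd64
    uses: ./.github/workflows/_build.yaml
    with:
      ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
      ARTIFACT_NAME: artifact-triton-build
      BADGE_FILENAME: badge-triton-build
      BUILD_DATE: ${{ inputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }}
      CONTAINER_NAME: triton
      DOCKERFILE: .github/container/Dockerfile.triton
      EXTRA_BUILD_ARGS: |
        URLREF_JAX_TRITON=${{ fromJson(inputs.SOURCE_URLREFS).JAX_TRITON }}
    secrets: inherit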
  build-equinox:
    needs: build-jax
    uses: ./.github/workflows/_build.yaml
    with:
      ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
      ARTIFACT_NAME: artifact-equinox-build
      BADGE_FILENAME: badge-equinox-build
      BUILD_DATE: ${{ inputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }}
      CONTAINER_NAME: equinox
      DOCKERFILE: .github/container/Dockerfile.equinox
      EXTRA_BUILD_ARGS: |
        URLREF_EQUINOX=${{ fromJson(inputs.SOURCE_URLREFS).EQUINOX }}
    secrets: inherit
  build-maxtext:
    needs: build-jax
    uses: ./.github/workflows/_build.yaml
    with:
      ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
      ARTIFACT_NAME: artifact-maxtext-build
      BADGE_FILENAME: badge-maxtext-build
      BUILD_DATE: ${{ inputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }}
      CONTAINER_NAME: maxtext
      DOCKERFILE: .github/container/Dockerfile.maxtext
      EXTRA_BUILD_ARGS: |
        URLREF_MAXTEXT=${{ fromJson(inputs.SOURCE_URLREFS).MAXTEXT }}
    secrets: inherit
  build-levanter:
    needs: [build-jax]
    uses: ./.github/workflows/_build.yaml
    with:
      ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
      ARTIFACT_NAME: "artifact-levanter-build"
      BADGE_FILENAME: "badge-levanter-build"
      BUILD_DATE: ${{ inputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }}
      CONTAINER_NAME: levanter
      DOCKERFILE: .github/container/Dockerfile.levanter
      EXTRA_BUILD_ARGS: |
        URLREF_LEVANTER=${{ fromJson(inputs.SOURCE_URLREFS).LEVANTER }}
        URLREF_HALIAX=${{ fromJson(inputs.SOURCE_URLREFS).HALIAX }}
    secrets: inherit
  build-upstream-t5x:
    needs: build-jax
    uses: ./.github/workflows/_build.yaml
    with:
      ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
      ARTIFACT_NAME: "artifact-t5x-build"
      BADGE_FILENAME: "badge-t5x-build"
      BUILD_DATE: ${{ inputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }}
      CONTAINER_NAME: upstream-t5x
      DOCKERFILE: .github/container/Dockerfile.t5x
      EXTRA_BUILD_ARGS: |
        URLREF_T5X=${{ fromJson(inputs.SOURCE_URLREFS).T5X }}
        URLREF_AIRIO=${{ fromJson(inputs.SOURCE_URLREFS).AIRIO }}
    secrets: inherit
  build-rosetta-t5x:
    needs: build-upstream-t5x
    uses: ./.github/workflows/_build_rosetta.yaml
    with:
      ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
      BUILD_DATE: ${{ inputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-upstream-t5x.outputs.DOCKER_TAG_MEALKIT }}
      BASE_LIBRARY: t5x
    secrets: inherit
  build-gemma:
    needs: build-jax
    uses: ./.github/workflows/_build.yaml
    if: inputs.ARCHITECTURE == 'amd64' # build only amd64
    with:
      ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
      ARTIFACT_NAME: artifact-gemma-build
      BADGE_FILENAME: badge-gemma-build
      BUILD_DATE: ${{ inputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }}
      CONTAINER_NAME: gemma
      DOCKERFILE: rosetta/Dockerfile.gemma
      DOCKER_CONTEXT: .
      EXTRA_BUILD_ARGS: |
        URLREF_GEMMA=${{ fromJson(inputs.SOURCE_URLREFS).GEMMA }}
        URLREF_BIG_VISION=${{ fromJson(inputs.SOURCE_URLREFS).BIG_VISION }}
        URLREF_COMMON_LOOP_UTILS=${{ fromJson(inputs.SOURCE_URLREFS).COMMON_LOOP_UTILS }}
        URLREF_FLAXFORMER=${{ fromJson(inputs.SOURCE_URLREFS).FLAXFORMER }}
        URLREF_PANOPTICAPI=${{ fromJson(inputs.SOURCE_URLREFS).PANOPTICAPI }}
    secrets: inherit
  build-axlearn:
    needs: build-jax
    uses: ./.github/workflows/_build.yaml
    with:
      ARCHITECTURE: ${{ inputs.ARCHITECTURE }}
      ARTIFACT_NAME: artifact-axlearn-build
      BADGE_FILENAME: badge-axlearn-build
      BUILD_DATE: ${{ inputs.BUILD_DATE }}
      BASE_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }}
      CONTAINER_NAME: axlearn
      DOCKERFILE: .github/container/Dockerfile.axlearn
      RUNNER_SIZE: large
    secrets: inherit
  collect-docker-tags:
    runs-on: ubuntu-22.04
    if: ${{ !cancelled() }}
    needs:
      - build-base
      - build-jax
      - build-triton
      - build-equinox
      - build-maxtext
      - build-levanter
      - build-upstream-t5x
      - build-rosetta-t5x
      - build-gemma
      - build-axlearn
    outputs:
      TAGS: ${{ steps.collect-tags.outputs.TAGS }}
    steps:
      - name: Save docker tags as a JSON object
        id: collect-tags
        run: |
          # In this unquoted heredoc, backslash-newline acts as a line
          # continuation, so jq receives the array as a single line; the
          # trailing {} entry absorbs the comma left by the last real element.
          TAGS=$(cat <<EOF | jq -c
          [\
            {"flavor": "base", "stage": "final", "priority": 800, "tag": "${{ needs.build-base.outputs.DOCKER_TAG }}"},\
            {"flavor": "jax", "stage": "final", "priority": 1000, "tag": "${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }}"},\
            {"flavor": "equinox", "stage": "final", "priority": 900, "tag": "${{ needs.build-equinox.outputs.DOCKER_TAG_FINAL }}"},\
            {"flavor": "maxtext", "stage": "final", "priority": 900, "tag": "${{ needs.build-maxtext.outputs.DOCKER_TAG_FINAL }}"},\
            {"flavor": "levanter", "stage": "final", "priority": 900, "tag": "${{ needs.build-levanter.outputs.DOCKER_TAG_FINAL }}"},\
            {"flavor": "upstream-t5x", "stage": "final", "priority": 900, "tag": "${{ needs.build-upstream-t5x.outputs.DOCKER_TAG_FINAL }}"},\
            {"flavor": "t5x", "stage": "final", "priority": 900, "tag": "${{ needs.build-rosetta-t5x.outputs.DOCKER_TAG_FINAL }}"},\
            {"flavor": "gemma", "stage": "final", "priority": 900, "tag": "${{ needs.build-gemma.outputs.DOCKER_TAG_FINAL }}"},\
            {"flavor": "axlearn", "stage": "final", "priority": 900, "tag": "${{ needs.build-axlearn.outputs.DOCKER_TAG_FINAL }}"},\
            {"flavor": "jax", "stage": "mealkit", "priority": 500, "tag": "${{ needs.build-jax.outputs.DOCKER_TAG_MEALKIT }}"},\
            {"flavor": "equinox", "stage": "mealkit", "priority": 500, "tag": "${{ needs.build-equinox.outputs.DOCKER_TAG_MEALKIT }}"},\
            {"flavor": "maxtext", "stage": "mealkit", "priority": 500, "tag": "${{ needs.build-maxtext.outputs.DOCKER_TAG_MEALKIT }}"},\
            {"flavor": "levanter", "stage": "mealkit", "priority": 500, "tag": "${{ needs.build-levanter.outputs.DOCKER_TAG_MEALKIT }}"},\
            {"flavor": "upstream-t5x", "stage": "mealkit", "priority": 500, "tag": "${{ needs.build-upstream-t5x.outputs.DOCKER_TAG_MEALKIT }}"},\
            {"flavor": "t5x", "stage": "mealkit", "priority": 500, "tag": "${{ needs.build-rosetta-t5x.outputs.DOCKER_TAG_MEALKIT }}"},\
            {"flavor": "gemma", "stage": "mealkit", "priority": 500, "tag": "${{ needs.build-gemma.outputs.DOCKER_TAG_MEALKIT }}"},\
            {"flavor": "axlearn", "stage": "mealkit", "priority": 500, "tag": "${{ needs.build-axlearn.outputs.DOCKER_TAG_MEALKIT }}"},\
            {}\
          ]
          EOF
          )
          echo "TAGS=${TAGS}" >> $GITHUB_OUTPUT
  test-distribution:
    runs-on: ubuntu-22.04
    strategy:
      matrix:
        TEST_SCRIPT:
          - extra-only-distribution.sh
          - mirror-only-distribution.sh
          - upstream-only-distribution.sh
          - local-patch-distribution.sh
      fail-fast: false
    steps:
      - name: Print environment variables
        run: env
      - name: Set git login for tests
        run: |
          git config --global user.email "jax@nvidia.com"
          git config --global user.name "JAX-Toolbox CI"
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
      - name: Run integration test ${{ matrix.TEST_SCRIPT }}
        run: bash rosetta/tests/${{ matrix.TEST_SCRIPT }}
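  # Note on the unit-test jobs below: they share the reusable _test_unit.yaml
  # workflow. As used in this file (see _test_unit.yaml for the authoritative
  # interface), it takes:
  #   TEST_NAME         - short name used for artifact/badge naming
  #   EXECUTE           - shell script that runs the tests and writes log files
  #   STATISTICS_SCRIPT - shell script that parses the logs and appends
  #                       TOTAL_TESTS/ERRORS/PASSED_TESTS/FAILED_TESTS to
  #                       $GITHUB_OUTPUT
  #   ARTIFACTS         - list of files to upload
  #   TIMEOUT_MINUTES   - optional job timeout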
  test-jax:
    needs: build-jax
    if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
    uses: ./.github/workflows/_test_unit.yaml
    with:
      TEST_NAME: jax
      EXECUTE: |
        docker run -i --shm-size=1g --gpus all \
          ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} \
          bash <<"EOF" |& tee test-backend-independent.log
        test-jax.sh -b backend-independent
        EOF
        docker run -i --shm-size=1g --gpus all \
          ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} \
          bash <<"EOF" |& tee test-gpu.log
        nvidia-cuda-mps-control -d
        test-jax.sh -b gpu
        EOF
      STATISTICS_SCRIPT: |
        errors=$(cat test-*.log | grep -c 'ERROR:' || true)
        failed_tests=$(cat test-*.log | grep -c 'FAILED in' || true)
        passed_tests=$(cat test-*.log | grep -c 'PASSED in' || true)
        total_tests=$((failed_tests + passed_tests))
        echo "TOTAL_TESTS=${total_tests}" >> $GITHUB_OUTPUT
        echo "ERRORS=${errors}" >> $GITHUB_OUTPUT
        echo "PASSED_TESTS=${passed_tests}" >> $GITHUB_OUTPUT
        echo "FAILED_TESTS=${failed_tests}" >> $GITHUB_OUTPUT
      ARTIFACTS: |
        test-backend-independent.log
        test-gpu.log
    secrets: inherit
  test-nsys-jax:
    needs: build-jax
    if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
    uses: ./.github/workflows/_test_unit.yaml
    with:
      TEST_NAME: nsys-jax
      EXECUTE: |
        set -o pipefail
        num_tests=0
        num_failures=0
        # Run the pytest-driven tests; failure is explicitly handled below so
        # set +e to avoid an early abort here.
        set +e
        docker run -i --shm-size=1g --gpus all \
          -v $PWD:/opt/output \
          ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} \
          bash <<"EOF" |& tee test-nsys-jax.log
        # nsys-jax is already installed, this is just adding the test dependencies
        pip install pytest-reportlog nsys-jax[test]
        # abuse knowledge that nsys-jax is installed editable, so the tests exist
        test_path=$(python -c 'import importlib.resources; print(importlib.resources.files("nsys_jax").joinpath("..", "tests").resolve())')
        pytest --report-log=/opt/output/pytest-report.jsonl "${test_path}"
        EOF
        set -e
        GPUS_PER_NODE=$(nvidia-smi -L | grep -c '^GPU')
        for mode in 1-process 2-process process-per-gpu; do
          DOCKER="docker run --shm-size=1g --gpus all --env XLA_FLAGS=--xla_gpu_enable_command_buffer= --env XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 -v ${PWD}:/opt/output ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }}"
          if [[ "${mode}" == "1-process" ]]; then
            PROCESS_COUNT=1
            ARGS=""
          elif [[ "${mode}" == "2-process" ]]; then
            # Use two processes with GPUS_PER_NODE/2 GPUs per process in the hope
            # that this will flush out more bugs than process-per-node or
            # process-per-GPU.
            PROCESS_COUNT=2
            ARGS="--process-id RANK --process-count ${PROCESS_COUNT} --coordinator-address 127.0.0.1:12345 --gpus-per-process $((GPUS_PER_NODE/2)) --distributed"
          else
            PROCESS_COUNT=${GPUS_PER_NODE}
            ARGS="--process-id RANK --process-count ${PROCESS_COUNT} --coordinator-address 127.0.0.1:12345 --gpus-per-process 1 --distributed"
          fi
          for collection in full partial; do
            NSYS_JAX="nsys-jax"
            if [[ "${mode}" == "1-process" ]]; then
              # We will not run nsys-jax-combine, so run analyses eagerly
              NSYS_JAX+=" --nsys-jax-analysis communication --nsys-jax-analysis summary"
            fi
            NSYS_JAX+=" --output=/opt/output/${mode}-${collection}-execution-%q{RANK}"
            if [[ "${collection}" == "partial" ]]; then
              NSYS_JAX+=" --capture-range=cudaProfilerApi --capture-range-end=stop"
              # nvbug/4801401
              NSYS_JAX+=" --sample=none"
            fi
            set +e
            ${DOCKER} parallel-launch RANK ${PROCESS_COUNT} ${NSYS_JAX} \
              -- jax-nccl-test ${ARGS} |& tee ${mode}-${collection}-execution.log
            num_failures=$((num_failures + ($? != 0)))
            set -e
            num_tests=$((num_tests + 1))
          done
          if [[ "${mode}" != "1-process" ]]; then
            # Run nsys-jax-combine
            NSYS_JAX_COMBINE="nsys-jax-combine --analysis communication --analysis summary --output=/opt/output/${mode}-${collection}-execution.zip"
            for (( i=0; i<PROCESS_COUNT; i++ )); do
              NSYS_JAX_COMBINE+=" /opt/output/${mode}-${collection}-execution-${i}.zip"
            done
            set +e
            ${DOCKER} ${NSYS_JAX_COMBINE} |& tee ${mode}-${collection}-execution-combine.log
            num_failures=$((num_failures + ($? != 0)))
            set -e
            num_tests=$((num_tests + 1))
          fi
        done
        ls -R .
        echo "NSYS_JAX_TEST_COUNT=${num_tests}" >> $GITHUB_ENV
        echo "NSYS_JAX_FAIL_COUNT=${num_failures}" >> $GITHUB_ENV
        exit $num_failures
      STATISTICS_SCRIPT: |
        summary_line=$(tail -n1 test-nsys-jax.log)
        num_errors=$(echo $summary_line | grep -oE '[0-9]+ error' | awk '{print $1} END { if (!NR) print 0}')
        passed_tests=$(cat pytest-report.jsonl | jq -r 'select(."$report_type" == "TestReport" and .when == "call" and .outcome == "passed") | .outcome' | wc -l)
        failed_tests=$(cat pytest-report.jsonl | jq -r 'select(."$report_type" == "TestReport" and .when == "call" and .outcome == "failed") | .outcome' | wc -l)
        total_tests=$(( NSYS_JAX_TEST_COUNT + passed_tests + failed_tests ))
        num_passed=$(( passed_tests + NSYS_JAX_TEST_COUNT - NSYS_JAX_FAIL_COUNT ))
        num_failed=$(( failed_tests + NSYS_JAX_FAIL_COUNT ))
        echo "TOTAL_TESTS=${total_tests}" >> $GITHUB_OUTPUT
        echo "ERRORS=${num_errors}" >> $GITHUB_OUTPUT
        echo "PASSED_TESTS=${num_passed}" >> $GITHUB_OUTPUT
        echo "FAILED_TESTS=${num_failed}" >> $GITHUB_OUTPUT
      ARTIFACTS: |
        # pytest-driven part
        test-nsys-jax.log
        pytest-report.jsonl
        # nsys-jax logfiles
        *process-*-execution.log
        # nsys-jax output for the case that doesn't use nsys-jax-combine
        1-process-*-execution-0.zip
        # nsys-jax-combine output/logfiles
        *process*-*-execution.zip
        *-execution-combine.log
    secrets: inherit
  # test-nsys-jax generates several fresh .zip archive outputs by running
  # nsys-jax with real GPU hardware; this test runs on a regular GitHub Actions
  # runner and checks that offline post-processing works in an environment that
  # does not already have nsys-jax installed
  test-nsys-jax-archive:
    needs: test-nsys-jax
    if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
    strategy:
      matrix:
        os: [ubuntu-22.04, ubuntu-24.04, macOS-latest]
    runs-on: ${{ matrix.os }}
    steps:
      - name: Download nsys-jax output .zip files
        uses: actions/download-artifact@v4
        with:
          name: nsys-jax-unit-test-A100
      - name: Extract archives and execute install scripts
        run: |
          pip install virtualenv # for install.sh
          for zip in $(ls *.zip); do
            ZIP="${PWD}/${zip}"
            pushd $(mktemp -d)
            unzip "${ZIP}"
            ls -l
            # TODO: verify this chmod isn't needed, or make it unnecessary
            chmod 755 install.sh
            # Run the notebook with IPython, not Jupyter Lab, so it exits and
            # prints something informative to stdout; skip executing Jupyter Lab.
            NSYS_JAX_JUPYTER_EXECUTE_NOT_LAB=1 ./install.sh
            popd
          done
  test-nsys-jax-eks:
    needs: build-jax
    if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
    runs-on: eks
    env:
      JAX_DOCKER_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }}
      JOB_NAME: ${{ github.run_id }}-nsys-jax
      POSTPROCESS_JOB_NAME: ${{ github.run_id }}-nsys-jax-postprocess
    steps:
      - name: Check out the repository
        uses: actions/checkout@v4
      - name: Login to GitHub Container Registry
        uses: docker/login-action@v3
        with:
          registry: ghcr.io
          username: ${{ github.repository_owner }}
          password: ${{ secrets.GITHUB_TOKEN }}
      - name: K8s GHCR store and delete token
        id: store-token
        uses: ./.github/actions/store-delete-k8s-ghcr
      - name: Configure Kubernetes job
        run: |
          yq -i ea 'select(di == 0).spec.selector.job-name = strenv(JOB_NAME)
            | select(di == 1).metadata.name = strenv(JOB_NAME)
            | select(di == 1).spec.template.spec.imagePullSecrets[].name = "${{ steps.store-token.outputs.token-name }}"
            | select(di == 1).spec.template.spec.containers[0].image = strenv(JAX_DOCKER_IMAGE)
            | select(di == 1).spec.template.spec.containers[0].env[0].value = strenv(JOB_NAME)' \
            .github/eks-workflow-files/job.yml
          git diff .github/eks-workflow-files/job.yml
      - name: Submit Kubernetes job
        uses: ./.github/actions/submit-delete-k8s-job
        with:
          job-config-file: .github/eks-workflow-files/job.yml
          job-name: ${{ env.JOB_NAME }}
      - name: Configure post-processing job
        run: |
          export JOB_OUTPUT_PATTERN="${JOB_NAME}-rank*.zip"
          yq -i '.metadata.name = strenv(POSTPROCESS_JOB_NAME)
            | .spec.template.spec.containers[].image = strenv(JAX_DOCKER_IMAGE)
            | .spec.template.spec.imagePullSecrets[].name = "${{ steps.store-token.outputs.token-name }}"
            | .spec.template.spec.initContainers[].command[7] = strenv(JOB_OUTPUT_PATTERN)' \
            .github/eks-workflow-files/post-process-job.yml
          git diff .github/eks-workflow-files/post-process-job.yml
      - name: Submit post process Kubernetes job
        uses: ./.github/actions/submit-delete-k8s-job
        with:
          job-config-file: .github/eks-workflow-files/post-process-job.yml
          job-name: ${{ env.POSTPROCESS_JOB_NAME }}
  # test-equinox:
  #   needs: build-equinox
  #   if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
  #   uses: ./.github/workflows/_test_unit.yaml
  #   with:
  #     IMAGE: ${{ needs.build-equinox.outputs.DOCKER_TAG_FINAL }}
  #     TEST_NAME: equinox
  #     EXECUTE: |
  #       docker run --shm-size=1g --gpus all ${{ needs.build-equinox.outputs.DOCKER_TAG_FINAL }} \
  #         bash -exc -o pipefail \
  #         'pushd /opt/equinox/tests && pip install -r requirements.txt && pytest .' | tee test-equinox.log
  #     STATISTICS_SCRIPT: |
  #       summary_line=$(tail -n1 test-equinox.log)
  #       errors=$(echo $summary_line | grep -oE '[0-9]+ error' | awk '{print $1} END { if (!NR) print 0}')
  #       failed_tests=$(echo $summary_line | grep -oE '[0-9]+ failed' | awk '{print $1} END { if (!NR) print 0}')
  #       passed_tests=$(echo $summary_line | grep -oE '[0-9]+ passed' | awk '{print $1} END { if (!NR) print 0}')
  #       total_tests=$((failed_tests + passed_tests))
  #       echo "TOTAL_TESTS=${total_tests}" >> $GITHUB_OUTPUT
  #       echo "ERRORS=${errors}" >> $GITHUB_OUTPUT
  #       echo "PASSED_TESTS=${passed_tests}" >> $GITHUB_OUTPUT
  #       echo "FAILED_TESTS=${failed_tests}" >> $GITHUB_OUTPUT
  #     ARTIFACTS: |
  #       test-equinox.log
  #   secrets: inherit
  test-te-h100:
    needs: build-jax
    if: inputs.ARCHITECTURE == 'amd64'
    uses: ./.github/workflows/_transformer_engine_eks.yaml
    with:
      JAX_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }}
      JOB_NAME: transformerengine-${{ github.run_id }}
      S3_BUCKET: jax-toolbox-eks-output
      CI_NAME: transformer-engine
    secrets: inherit
  test-te-multigpu-a100:
    needs: build-jax
    if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
    uses: ./.github/workflows/_test_te.yaml
    with:
      JAX_IMAGE: ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }}
    secrets: inherit
  test-te-unit-a100:
    needs: build-jax
    secrets: inherit
    if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
    uses: ./.github/workflows/_test_unit.yaml
    with:
      TEST_NAME: te
      EXECUTE: |
        docker run -i --gpus all --shm-size=1g -v $PWD:/log \
          ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} \
          bash <<"EOF" |& tee test-te.log
        set -x
        pip install pytest-reportlog pytest-xdist
        # Start MPS daemon
        nvidia-cuda-mps-control -d
        # TE's default is slightly different, without the hyphen
        export TE_PATH=${SRC_PATH_TRANSFORMER_ENGINE}
        # 1 GPU per worker, 6 workers per GPU
        pytest-xdist.sh 1 6 /log/pytest-report-L0-unittest.jsonl bash ${TE_PATH}/qa/L0_jax_unittest/test.sh
        ## 8 GPUs per worker, 1 worker per GPU. pytest-xdist.sh allows aggregation
        ## into a single .jsonl file of results from multiple pytest invocations
        ## inside the test.sh script, so it's useful even with a single worker per
        ## device.
        pytest-xdist.sh 8 1 /log/pytest-report-L0-distributed-unittest.jsonl bash ${TE_PATH}/qa/L0_jax_distributed_unittest/test.sh
        pytest-xdist.sh 8 1 /log/pytest-report-L1-distributed-unittest.jsonl bash ${TE_PATH}/qa/L1_jax_distributed_unittest/test.sh
        # merge the log files
        cat \
          /log/pytest-report-L0-unittest.jsonl \
          /log/pytest-report-L0-distributed-unittest.jsonl \
          /log/pytest-report-L1-distributed-unittest.jsonl \
          > /log/pytest-report.jsonl
        EOF
      STATISTICS_SCRIPT: |
        ls
        report_json=pytest-report-L0-unittest.jsonl
        summary_line=$(tail -n1 test-te.log)
        errors=$(echo $summary_line | grep -oE '[0-9]+ error' | awk '{print $1} END { if (!NR) print 0}')
        passed_tests=$(cat $report_json | jq -r 'select(."$report_type" == "TestReport" and .when == "call" and .outcome == "passed") | .outcome' | wc -l)
        failed_tests=$(cat $report_json | jq -r 'select(."$report_type" == "TestReport" and .when == "call" and .outcome == "failed") | .outcome' | wc -l)
        total_tests=$((failed_tests + passed_tests))
        echo "TOTAL_TESTS=${total_tests}" >> $GITHUB_OUTPUT
        echo "ERRORS=${errors}" >> $GITHUB_OUTPUT
        echo "PASSED_TESTS=${passed_tests}" >> $GITHUB_OUTPUT
        echo "FAILED_TESTS=${failed_tests}" >> $GITHUB_OUTPUT
      TIMEOUT_MINUTES: 120
      ARTIFACTS: |
        test-te.log
        pytest-report.jsonl
        pytest-report-L0-unittest.jsonl
        pytest-report-L0-distributed-unittest.jsonl
        pytest-report-L1-distributed-unittest.jsonl
  test-upstream-t5x:
    needs: build-upstream-t5x
    if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
    uses: ./.github/workflows/_test_upstream_t5x.yaml
    with:
      T5X_IMAGE: ${{ needs.build-upstream-t5x.outputs.DOCKER_TAG_FINAL }}
    secrets: inherit
  test-rosetta-t5x:
    needs: build-rosetta-t5x
    if: inputs.ARCHITECTURE == 'amd64' # no images for arm64
    uses: ./.github/workflows/_test_t5x_rosetta.yaml
    with:
      T5X_IMAGE: ${{ needs.build-rosetta-t5x.outputs.DOCKER_TAG_FINAL }}
    secrets: inherit
  test-triton:
    needs: build-triton
    if: inputs.ARCHITECTURE == 'amd64' # no images for arm64
    uses: ./.github/workflows/_test_unit.yaml
    with:
      TEST_NAME: triton
      EXECUTE: |
        docker run -i --shm-size=1g --gpus all --volume $PWD:/output \
          ${{ needs.build-triton.outputs.DOCKER_TAG_FINAL }} \
          bash <<"EOF" |& tee test-triton.log
        # autotuner tests from jax-triton now hit a triton code path that uses
        # utilities from pytorch; this relies on actually having a CUDA backend
        # for pytorch
        pip install --no-deps torch
        python /opt/jax-triton/tests/triton_call_test.py --xml_output_file /output/triton_test.xml
        EOF
      STATISTICS_SCRIPT: |
        curl -L -o yq https://github.com/mikefarah/yq/releases/latest/download/yq_linux_$(dpkg --print-architecture) && chmod 777 yq;
        total_tests=$(./yq '.testsuites."+@tests"' triton_test.xml)
        errors=$(./yq '.testsuites."+@errors"' triton_test.xml)
        failed_tests=$(./yq '.testsuites."+@failures"' triton_test.xml)
        passed_tests=$((total_tests - errors - failed_tests))
        echo "TOTAL_TESTS=${total_tests}" >> $GITHUB_OUTPUT
        echo "ERRORS=${errors}" >> $GITHUB_OUTPUT
        echo "PASSED_TESTS=${passed_tests}" >> $GITHUB_OUTPUT
        echo "FAILED_TESTS=${failed_tests}" >> $GITHUB_OUTPUT
      ARTIFACTS: |
        test-triton.log
    secrets: inherit
  test-levanter:
    needs: build-levanter
    if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
    uses: ./.github/workflows/_test_unit.yaml
    with:
      TEST_NAME: levanter
      EXECUTE: |
        docker run -i --gpus all --shm-size=1g \
          ${{ needs.build-levanter.outputs.DOCKER_TAG_FINAL }} \
          bash <<"EOF" |& tee test-levanter.log
        pip install flake8 pytest soundfile librosa
        PYTHONPATH=/opt/levanter/tests:$PYTHONPATH pytest /opt/levanter/tests -m "not entry and not slow and not ray"
        EOF
      STATISTICS_SCRIPT: |
        summary_line=$(tail -n1 test-levanter.log)
        errors=$(echo $summary_line | grep -oE '[0-9]+ error' | awk '{print $1} END { if (!NR) print 0}')
        failed_tests=$(echo $summary_line | grep -oE '[0-9]+ failed' | awk '{print $1} END { if (!NR) print 0}')
        passed_tests=$(echo $summary_line | grep -oE '[0-9]+ passed' | awk '{print $1} END { if (!NR) print 0}')
        total_tests=$((failed_tests + passed_tests))
        echo "TOTAL_TESTS=${total_tests}" >> $GITHUB_OUTPUT
        echo "ERRORS=${errors}" >> $GITHUB_OUTPUT
        echo "PASSED_TESTS=${passed_tests}" >> $GITHUB_OUTPUT
        echo "FAILED_TESTS=${failed_tests}" >> $GITHUB_OUTPUT
      ARTIFACTS: |
        test-levanter.log
    secrets: inherit
  test-gemma:
    needs: build-gemma
    uses: ./.github/workflows/_test_unit.yaml
    if: inputs.ARCHITECTURE == 'amd64'
    with:
      TEST_NAME: gemma
      EXECUTE: |
        docker run --shm-size=1g --gpus all ${{ needs.build-gemma.outputs.DOCKER_TAG_FINAL }} \
          bash -ec \
          "cd /opt/gemma && pip install -e .[dev] && pytest ." | tee test-gemma.log
      STATISTICS_SCRIPT: |
        summary_line=$(tail -n1 test-gemma.log)
        errors=$(echo $summary_line | grep -oE '[0-9]+ error' | awk '{print $1} END { if (!NR) print 0}')
        failed_tests=$(echo $summary_line | grep -oE '[0-9]+ failed' | awk '{print $1} END { if (!NR) print 0}')
        passed_tests=$(echo $summary_line | grep -oE '[0-9]+ passed' | awk '{print $1} END { if (!NR) print 0}')
        total_tests=$((failed_tests + passed_tests))
        echo "TOTAL_TESTS=${total_tests}" >> $GITHUB_OUTPUT
        echo "ERRORS=${errors}" >> $GITHUB_OUTPUT
        echo "PASSED_TESTS=${passed_tests}" >> $GITHUB_OUTPUT
        echo "FAILED_TESTS=${failed_tests}" >> $GITHUB_OUTPUT
      ARTIFACTS: |
        test-gemma.log
    secrets: inherit
  test-maxtext:
    needs: build-maxtext
    if: inputs.ARCHITECTURE == 'amd64' # no arm64 gpu runners
    uses: ./.github/workflows/_test_maxtext.yaml
    with:
      MAXTEXT_IMAGE: ${{ needs.build-maxtext.outputs.DOCKER_TAG_FINAL }}
    secrets: inherit
  test-axlearn-eks:
    needs: build-axlearn
    if: inputs.ARCHITECTURE == 'amd64'
    runs-on: eks
    env:
      AXLEARN_DOCKER_IMAGE: ${{ needs.build-axlearn.outputs.DOCKER_TAG_FINAL }}
      JOB_NAME: axlearn-${{ github.run_id }}
    steps:
      - name: Check out the repository
        uses: actions/checkout@v4
      - name: Login to GitHub Container Registry
        uses: docker/login-action@v3
        with:
          registry: ghcr.io
          username: ${{ github.repository_owner }}
          password: ${{ secrets.GITHUB_TOKEN }}
      - name: K8s GHCR store and delete token
        id: store-token
        uses: ./.github/actions/store-delete-k8s-ghcr
      - name: Configure axlearn test job
        run: |
          # Replace placeholders in axlearn-job.yml with environment variables
          yq -i ea '
            select(di == 0).metadata.name = strenv(JOB_NAME)
            | select(di == 0).spec.template.spec.containers[0].image = strenv(AXLEARN_DOCKER_IMAGE)
            | select(di == 0).spec.template.spec.containers[1].env[0].value = "${{ github.run_id }}"
            | select(di == 0).spec.template.spec.imagePullSecrets[].name = "${{ steps.store-token.outputs.token-name }}"' \
            .github/eks-workflow-files/axlearn/axlearn-job.yml
          git diff .github/eks-workflow-files/axlearn/axlearn-job.yml
      - name: Submit & delete axlearn test
        uses: ./.github/actions/submit-delete-k8s-job
        with:
          job-config-file: ".github/eks-workflow-files/axlearn/axlearn-job.yml"
          job-name: ${{ env.JOB_NAME }}
      - name: Download logs from S3
        id: log-s3
        run: |
          mkdir -p axlearn-output
          aws s3 cp s3://jax-toolbox-eks-output/axlearn/${{ github.run_id }}/summary.txt axlearn-output/
          aws s3 cp s3://jax-toolbox-eks-output/axlearn/${{ github.run_id }}/ axlearn-output/ --recursive --exclude "*" --include "*.log"
          passed_tests=$(grep -c ": PASSED" axlearn-output/summary.txt || true)
          failed_tests=$(grep -c ": FAILED" axlearn-output/summary.txt || true)
          total_tests=$((failed_tests + passed_tests))
          echo "Passed tests: $passed_tests"
          echo "Failed tests: $failed_tests"
          echo "Total tests: $total_tests"
          echo "PASSED_TESTS=$passed_tests" >> $GITHUB_OUTPUT
          echo "FAILED_TESTS=$failed_tests" >> $GITHUB_OUTPUT
          echo "TOTAL_TESTS=$total_tests" >> $GITHUB_OUTPUT
      - name: Generate sitrep
        id: sitrep
        if: ${{ !cancelled() }}
        shell: bash -x -e {0}
        run: |
          # bring in utility functions
          source .github/workflows/scripts/to_json.sh
          badge_label='Axlearn EKS Unit'
          total_tests=${{ steps.log-s3.outputs.TOTAL_TESTS }}
          failed_tests=${{ steps.log-s3.outputs.FAILED_TESTS }}
          passed_tests=${{ steps.log-s3.outputs.PASSED_TESTS }}
          errors="0"
          summary="All tests: $total_tests. Passed: $passed_tests. Failed: $failed_tests."
          badge_message="Passed $passed_tests out of $total_tests."
          badge_color="brightgreen"
          if [ "$failed_tests" -gt 0 ]; then
            badge_color="red"
          fi
          to_json \
            summary \
            errors total_tests passed_tests failed_tests \
            badge_label badge_color badge_message \
            > sitrep.json
          schemaVersion=1 \
            label="${badge_label}" \
            message="Passed $passed_tests out of $total_tests." \
            color=$badge_color \
            to_json schemaVersion label message color \
            > badge-axlearn-test.json
      - name: Upload artifacts
        if: ${{ !cancelled() }}
        uses: actions/upload-artifact@v4
        with:
          name: "artifact-axlearn-test"
          path: |
            sitrep.json
            badge-axlearn-test.json
            axlearn-output/*
  # As of 2025-02-24 the fuji test runs for only 20 minutes; it is not currently
  # possible to set the `max_steps` value. This will be handled with custom
  # Python code.
  test-axlearn-fuji-models-eks:
    needs: build-axlearn
    if: inputs.ARCHITECTURE == 'amd64'
    runs-on: eks
    env:
      AXLEARN_DOCKER_IMAGE: ${{ needs.build-axlearn.outputs.DOCKER_TAG_FINAL }}
      JOB_NAME: axlearn-fuji-3b-${{ github.run_id }}
    steps:
      - name: Check out the repository
        uses: actions/checkout@v4
      - name: Login to GitHub Container Registry
        uses: docker/login-action@v3
        with:
          registry: ghcr.io
          username: ${{ github.repository_owner }}
          password: ${{ secrets.GITHUB_TOKEN }}
      - name: K8s GHCR store and delete token
        id: store-token
        uses: ./.github/actions/store-delete-k8s-ghcr
      - name: Configure axlearn test job
        run: |
          yq -i ea '
            select(di == 0).metadata.name = strenv(JOB_NAME)
            | select(di == 0).spec.template.spec.containers[0].image = strenv(AXLEARN_DOCKER_IMAGE)
            | select(di == 0).spec.template.spec.imagePullSecrets[].name = "${{ steps.store-token.outputs.token-name }}"' \
            .github/eks-workflow-files/axlearn/axlearn-fuji-model.yml
          git diff .github/eks-workflow-files/axlearn/axlearn-fuji-model.yml
      - name: Submit & delete axlearn test
        uses: ./.github/actions/submit-delete-k8s-job
        with:
          job-config-file: ".github/eks-workflow-files/axlearn/axlearn-fuji-model.yml"
          job-name: ${{ env.JOB_NAME }}