@@ -489,12 +489,41 @@ jobs:
489
489
# test-equinox.log
490
490
# secrets: inherit
491
491
492
- test-te-multigpu :
492
+ te-unittests :
493
+ secrets : inherit
493
494
needs : build-jax
494
- if : inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
495
- uses : ./.github/workflows/_test_te .yaml
495
+ if : inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a
496
+ uses : ./.github/workflows/_test_unit .yaml
496
497
with :
497
- TE_IMAGE : ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }}
498
+ TEST_NAME : te
499
+ EXECUTE : |
500
+ docker run -i --gpus all --shm-size=1g -v $PWD:/log \
501
+ ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} \
502
+ bash <<"EOF" |& tee test-te.log
503
+ pip install pytest-reportlog pytest-xdist
504
+ # Start MPS daemon
505
+ nvidia-cuda-mps-control -d
506
+ # TE's default is slightly different, without the hyphen
507
+ export TE_PATH=${SRC_PATH_TRANSFORMER_ENGINE}
508
+ # 1 GPU per worker, 6 workers per GPU
509
+ pytest-xdist.sh 1 6 pytest-report-L0-unittest.jsonl bash ${TE_PATH}/qa/L0_jax_unittest/test.sh
510
+ EOF
511
+
512
+ STATISTICS_SCRIPT : |
513
+ summary_line=$(tail -n1 test-te.log)
514
+ errors=$(echo $summary_line | grep -oE '[0-9]+ error' | awk '{print $1} END { if (!NR) print 0}')
515
+ passed_tests=$(cat pytest-report.jsonl | jq -r 'select(."$report_type" == "TestReport" and .when == "call" and .outcome == "passed") | .outcome' | wc -l)
516
+ failed_tests=$(cat pytest-report.jsonl | jq -r 'select(."$report_type" == "TestReport" and .when == "call" and .outcome == "failed") | .outcome' | wc -l)
517
+ total_tests=$((failed_tests + passed_tests))
518
+ echo "TOTAL_TESTS=${total_tests}" >> $GITHUB_OUTPUT
519
+ echo "ERRORS=${errors}" >> $GITHUB_OUTPUT
520
+ echo "PASSED_TESTS=${passed_tests}" >> $GITHUB_OUTPUT
521
+ echo "FAILED_TESTS=${failed_tests}" >> $GITHUB_OUTPUT
522
+
523
+ TIMEOUT_MINUTES : 120
524
+ ARTIFACTS : |
525
+ test-te.log
526
+ pytest-report.jsonl
498
527
secrets : inherit
499
528
500
529
# test-upstream-t5x:
0 commit comments