Switch to upstream Triton compiler, and related changes #36

Merged
merged 51 commits into from
Jul 26, 2024


xinyazhang
Collaborator

@xinyazhang xinyazhang commented Jul 15, 2024

  1. Switch to the performance kernel for the forward pass. The old Triton kernel does not work with the new compiler.
  2. Support AOT-based autotune, which includes:
    1. Add argument aotriton::v2::flash::ExtraArguments to all aotriton::v2::flash APIs
    2. Add build option AOTRITON_BUILD_FOR_TUNING to build all possible GPU kernels. The configurations are supplied by KernelDescription.gen_autotune_configs, which is compatible with triton.Config.
    3. AOTRITON_BUILD_FOR_TUNING also adds force_kernel_index and other fields to aotriton::v2::flash::ExtraArguments. Users can manually select a kernel and bypass the autotune mechanism.
    4. Add test/tune_flash.py and cpp_autotune.py, and change test/attn_torch_function.py to support AOT autotune (aka cpp autotune).
      • test/tune_flash.py runs the UT before measuring a triton.Config's performance, to avoid including faulty kernels.
  3. Add Navi31/32 compiler options (not added to the default config due to compiler problems).
  4. Add --use_multigpu to test/tune_flash.py. The script now supports tuning GPU kernels on all GPUs simultaneously, with the following extra features:
    • The UT is moved to a separate process (referred to as the minesweeper process here), in case a faulty kernel triggers a segfault and crashes the worker process.
      • Thus tune_flash.py needs 1*(main)+n*(worker)+n*(minesweeper)+1*(db access)+1*(table_tool.py) processes, that is, 2n+3 in total.
      • For better performance, the minesweeper process is reused and only recreated if the previous one hit a segfault (or another failure).
    • --json_file is also added, since the new architecture has a unified database access process that accepts outputs from all worker processes, and this process can write to a separate json file. This is currently the recommended way to store the results of the tuning script. Users are supposed to run v2python.table_tool later to update the tuning database.
    • --continue_from_json_file is introduced. Meanwhile, the result and _debug_task_id fields are also attached to the output json object, so that a tuning process can be resumed according to the _debug_task_id and its tuning status.
    • v2python.table_tool is improved to support the new version of the json file.
  5. Tuning results of the forward kernel are updated for MI200/MI300X + new compiler. Most UTs passed (see comments for known failures on MI300X).

CAVEAT: The new AOT-based autotune script test/tune_flash.py cannot handle the backward pass yet.
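
The minesweeper idea from item 4 can be sketched as follows. This is a minimal illustration and not the code in test/tune_flash.py: the helper name run_isolated and the use of subprocess are assumptions made for this example. It only shows that a crash in the child surfaces as a return code in the parent instead of taking the caller down.

```python
import subprocess
import sys


def run_isolated(snippet: str, timeout: float = 60.0) -> bool:
    """Run `snippet` in a fresh Python interpreter; True only if it exits with 0.

    Hypothetical helper (not part of AOTriton): a crash in the child is
    reported as a non-zero return code instead of killing the caller.
    """
    try:
        proc = subprocess.run([sys.executable, "-c", snippet], timeout=timeout)
    except subprocess.TimeoutExpired:
        # A hung check is treated the same way as a failed one.
        return False
    return proc.returncode == 0


healthy = "assert 1 + 1 == 2"
# os.abort() stands in for a faulty kernel that crashes the whole process.
faulty = "import os; os.abort()"

healthy_ok = run_isolated(healthy)
faulty_ok = run_isolated(faulty)
```

Here healthy_ok is True and faulty_ok is False, and the calling process survives both runs. The PR goes further by keeping the minesweeper process alive and recreating it only after a failure; that reuse is not shown in this sketch.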

The old pattern triggers a compiler bug
* Add new build option AOTRITON_BUILD_FOR_TUNING. Required by all
  features below.
* KernelDescription.gen_autotune_configs is introduced to specify autotune.Config objects.
  All kernels will be built by the build system.
* Users now can select kernel through ExtraArguments::force_kernel_index
This is to address is_causal=true while seqlen_q != seqlen_k cases.
For these cases the inputs are not supported and this solution makes
sense.

However for more complicated cases interpolation will be needed.
…andling accordingly.

Without type annotation the compiler behaves differently in JIT mode and
then bugs will slip into AOT mode.
Note the database has not been updated yet.
The only things left are bias with very irregular shapes (e.g.,
False-1.2-dtype0-0.0-4-2048-32-1-1). Maybe due to tuning database
extrapolation problems.
Skip unaccepted inputs (notably causal=True has lots of requirements)
Add more sequence lengths
…erator

Now the whole pipeline is:
TunerManager -(mp.Q)-> [Worker] * N -(mp.Q)-> DbAccessor -(stdio)-> table_tool.py
However, this seriously slowed down the tuning process. Thinking of
alternatives...
Performance restored to normal with this trick.
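
The topology above (one producer, N workers, one database accessor) can be sketched like this. It is a thread-based stand-in, assuming queue.Queue in place of the multiprocessing queues and a plain list in place of the stdio pipe to table_tool.py; all function names and the record layout are hypothetical and chosen for illustration only.

```python
import json
import queue
import threading

STOP = None  # sentinel marking the end of a stream


def worker(gpu_id, tasks, results):
    """Pull tasks until the sentinel arrives; push one result dict per task."""
    while True:
        task = tasks.get()
        if task is STOP:
            results.put(STOP)  # tell the accessor this worker is done
            return
        # Placeholder for the real tuning work: echo the task back.
        results.put({"task": task, "gpu": gpu_id, "result": "tuned"})


def db_accessor(results, n_workers, sink):
    """Single consumer that serializes every worker's output as JSON lines."""
    finished = 0
    while finished < n_workers:
        item = results.get()
        if item is STOP:
            finished += 1
        else:
            sink.append(json.dumps(item, sort_keys=True))


def run_pipeline(task_list, n_workers=2):
    tasks, results, sink = queue.Queue(), queue.Queue(), []
    threads = [threading.Thread(target=worker, args=(i, tasks, results))
               for i in range(n_workers)]
    threads.append(threading.Thread(target=db_accessor,
                                    args=(results, n_workers, sink)))
    for t in threads:
        t.start()
    for task in task_list:
        tasks.put(task)
    for _ in range(n_workers):
        tasks.put(STOP)  # one sentinel per worker
    for t in threads:
        t.join()
    return sink


lines = run_pipeline(range(5))
```

Funnelling all output through a single accessor means only one writer ever touches the output, at the cost of an extra hop per result.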
Now it is possible to continue a failed tuning run (maybe due to various
limitations like power failure and full disk)

CAVEAT: this option assumes the new pass uses the same set of samples
(defined by --seqlen_q/k, etc. options). The correctness is not
guaranteed if different sets of samples were used.
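
A resume step along these lines might look like the sketch below. Only the field names result and _debug_task_id come from this PR; the one-object-per-line layout, the "tuned" success value, and the function names are assumptions made for the example, not the actual format written by tune_flash.py.

```python
import json


def completed_task_ids(json_lines, done_value="tuned"):
    """Collect the _debug_task_id of every record whose result equals done_value.

    ASSUMPTIONS: one JSON object per line, and "tuned" as the success value.
    """
    done = set()
    for line in json_lines:
        line = line.strip()
        if not line:
            continue
        try:
            record = json.loads(line)
        except json.JSONDecodeError:
            # e.g. a final line truncated by a power failure or a full disk
            continue
        if not isinstance(record, dict):
            continue
        if record.get("result") == done_value:
            done.add(record.get("_debug_task_id"))
    return done


def remaining_tasks(all_task_ids, json_lines):
    """Return the task ids that still need tuning, in their original order."""
    done = completed_task_ids(json_lines)
    return [tid for tid in all_task_ids if tid not in done]


previous_run = [
    '{"_debug_task_id": 0, "result": "tuned"}',
    '{"_debug_task_id": 1, "result": "skipped"}',
    '{"_debug_task_id": 2, "resu',  # truncated record
]
todo = remaining_tasks(range(4), previous_run)
```

With this input, todo is [1, 2, 3]: task 0 is skipped, while the non-successful, truncated, and never-started tasks are retried. Task ids are only comparable across runs when both runs enumerate the same samples, which is the condition stated in the caveat above.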
@xinyazhang
Collaborator Author

xinyazhang commented Jul 15, 2024

Known failures:

test_op_bwd_with_matrix_bias[False-1.2-dtype2-0.0-2048-143-256-4-4]
FAILED ../test/test_backward.py::test_op_bwd[False-1.2-dtype1-0.0-True-4-4-256-1-1] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[False-1.2-dtype1-0.0-True-4-4-256-1-4] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[False-1.2-dtype1-0.0-True-4-4-256-4-1] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[False-1.2-dtype1-0.0-True-4-4-256-4-4] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[False-1.2-dtype1-0.0-True-8-8-256-1-1] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[False-1.2-dtype1-0.0-True-8-8-256-1-4] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[False-1.2-dtype1-0.0-True-8-8-256-4-1] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[False-1.2-dtype1-0.0-True-8-8-256-4-4] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[True-1.2-dtype1-0.0-True-4-4-256-1-1] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[True-1.2-dtype1-0.0-True-4-4-256-1-4] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[True-1.2-dtype1-0.0-True-4-4-256-4-1] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[True-1.2-dtype1-0.0-True-4-4-256-4-4] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[True-1.2-dtype1-0.0-True-8-8-256-1-1] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[True-1.2-dtype1-0.0-True-8-8-256-1-4] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[True-1.2-dtype1-0.0-True-8-8-256-4-1] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True
FAILED ../test/test_backward.py::test_op_bwd[True-1.2-dtype1-0.0-True-8-8-256-4-4] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True

Produced by pytest test/test_backward.py -v -k 1.2 on MI300X

@xinyazhang
Collaborator Author

The UT on MI200 has better results:

FAILED ../test/test_backward.py::test_op_bwd[True-1.2-dtype1-0.0-True-8-8-256-4-4] - AssertionError: dk_allclose=True dv_allclose=False dq_allclose=True db_allclose=True

Tested with pytest test/test_backward.py -v -k 1.2

Contributor

@groenenboomj groenenboomj left a comment

Looks mostly good. What is the new library size?

o.fill_(float('nan'))
return ipc_func(extargs.force_kernel_index)
# print(f'running attn_fwd with {extargs.force_kernel_index=}')
tuning_result = cpp_autotune(ExtraArguments, func,
Contributor

Any plan to expose this function (ExtraArguments) or is it only for tuning?

Collaborator Author

@xinyazhang xinyazhang Jul 25, 2024

This will be part of the public AOTriton API, to allow for possible further extensions.
For example, we could add a second hipStream to ExtraArguments for the bwd kernel, to run dkdv and dq concurrently.

For now, however, cpp tuning is the main use case of the extra argument.

@xinyazhang
Collaborator Author

Looks mostly good. What is the new library size?

I don't have the size of the all-architecture, no-zstd build. The MI300X-only + zstd size is 321M.

@xinyazhang xinyazhang merged commit 85d120c into main Jul 26, 2024