Switch to upstream Triton compiler, and related changes #36
Conversation
The old pattern triggers a compiler bug
* Add new build option AOTRITON_BUILD_FOR_TUNING. Required by all features below.
* KernelDescription.gen_autotune_configs is introduced to specify autotune.Config objects. All kernels will be built by the build system.
* Users can now select a kernel through ExtraArguments::force_kernel_index
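For illustration, a minimal sketch of the kind of generator `gen_autotune_configs` could be. The block-size parameters and values here are assumptions for a flash-attention-style kernel, not AOTriton's actual configs; only the `triton.Config` compatibility is taken from the PR.

```python
import itertools
import triton

def gen_autotune_configs():
    # Hypothetical candidate tile shapes; under AOTRITON_BUILD_FOR_TUNING
    # every yielded config is compiled by the build system instead of being
    # selected at JIT time.
    for block_m, block_n, num_warps in itertools.product((64, 128), (32, 64), (2, 4)):
        yield triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n},
                            num_warps=num_warps, num_stages=2)
```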
… AOTRITON_BUILD_FOR_TUNING=1
This is to address the is_causal=true with seqlen_q != seqlen_k cases. For these cases the inputs are not supported, so this solution makes sense. However, for more complicated cases interpolation will be needed.
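A sketch of the kind of guard this implies, assuming a pytest-based test script; the helper name and message are hypothetical, not the actual AOTriton test code.

```python
import pytest

def maybe_skip(is_causal, seqlen_q, seqlen_k):
    # Skip input combinations the kernels do not accept rather than
    # extrapolating tuning data for them.
    if is_causal and seqlen_q != seqlen_k:
        pytest.skip('causal=True currently requires seqlen_q == seqlen_k')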
…andling accordingly. Without type annotations the compiler behaves differently in JIT mode, and bugs then slip into AOT mode.
Note the database has not been updated yet.
The only remaining failures are bias cases with very irregular shapes (e.g., False-1.2-dtype0-0.0-4-2048-32-1-1), possibly due to tuning database extrapolation problems.
Skip unaccepted inputs (notably causal=True has many requirements)
Add more sequence lengths
… GPUs. Confirmed with amd-smi -w 2
Mainly autotune-related scripts
…erator. Now the whole pipeline is:
TunerManager -(mp.Q)-> [Worker] * N -(mp.Q)-> DbAccessor -(stdio)-> table_tool.py
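A minimal sketch of the queue plumbing this pipeline describes, using `multiprocessing` queues as the `mp.Q` arrows suggest. `run_tuning` is a placeholder, and the real DbAccessor additionally forwards records to table_tool.py over stdio rather than only writing a json file.

```python
import json
import multiprocessing as mp

def run_tuning(task):
    # Placeholder for the real per-kernel benchmark run.
    return {'task': task, 'result': 'ok'}

def worker(task_q, result_q):
    for task in iter(task_q.get, None):      # None is the shutdown sentinel
        result_q.put(run_tuning(task))

def db_accessor(result_q, json_file):
    # Single-writer process: all workers funnel results here, so only one
    # process touches the json file (and the database tooling behind it).
    with open(json_file, 'a') as f:
        for record in iter(result_q.get, None):
            f.write(json.dumps(record) + '\n')
```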
and be less verbose.
However this seriously slowed down the tuning process. Thinking of alternatives…
Performance restored to normal with this trick.
Now it is possible to continue a failed tuning run (interrupted by, e.g., a power failure or a full disk). CAVEAT: this option assumes the new pass uses the same set of samples (defined by the --seqlen_q/k etc. options). Correctness is not guaranteed if different sets of samples were used.
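A sketch of how resuming could work, assuming one json object per line carrying the `result` and `_debug_task_id` fields mentioned in the PR description; the status value checked here is hypothetical.

```python
import json

def load_finished_task_ids(json_file):
    # Collect the _debug_task_id of every sample that already tuned successfully.
    finished = set()
    with open(json_file) as f:
        for line in f:
            record = json.loads(line)
            if record.get('result') == 'ok':     # hypothetical status value
                finished.add(record['_debug_task_id'])
    return finished
```

Tasks whose `_debug_task_id` is already in the set are skipped, which is exactly why a resumed run must enumerate the same sample set: the ids are only meaningful against identical samples.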
Known failures:
The UTs on MI200 give better results:
Looks mostly good. What is the new library size?
```python
o.fill_(float('nan'))
return ipc_func(extargs.force_kernel_index)
# print(f'running attn_fwd with {extargs.force_kernel_index=}')
tuning_result = cpp_autotune(ExtraArguments, func,
```
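For context, a rough sketch of what a `cpp_autotune`-style loop does, reusing the names from the snippet above (`extargs`, `ipc_func`). The body is illustrative, not AOTriton's actual implementation; filling the output with NaN beforehand presumably ensures a kernel that fails to write anything cannot pass validation.

```python
import math
import time

def cpp_autotune_sketch(extargs, ipc_func, num_kernels, validate):
    best_index, best_time = None, math.inf
    for index in range(num_kernels):
        extargs.force_kernel_index = index   # bypass the built-in autotune
        if not validate(index):              # run the UT first; drop faulty kernels
            continue
        start = time.perf_counter()
        ipc_func(extargs.force_kernel_index)
        elapsed = time.perf_counter() - start
        if elapsed < best_time:
            best_index, best_time = index, elapsed
    return best_index, best_time
```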
Any plan to expose this function (ExtraArguments) or is it only for tuning?
This will be part of the public AOTriton API, for possible further extensions.
For example, we could add a second hipStream to the ExtraArguments for the bwd kernel, to run dkdv and dq concurrently.
However, for now cpp tuning is the main use case of the extra arguments.
I don't have the size of the all-architectures + no-zstd version. The MI300X-only + zstd size is 321M.
- Add `aotriton::v2::flash::ExtraArguments` to all `aotriton::v2::flash` APIs.
- Add `AOTRITON_BUILD_FOR_TUNING` to build all possible GPU kernels. The configurations are supplied by `KernelDescription.gen_autotune_configs`, which is compatible with `triton.Config`. `AOTRITON_BUILD_FOR_TUNING` also enables `force_kernel_index` and other fields in `aotriton::v2::flash::ExtraArguments`, so users can manually select a kernel and bypass the autotune mechanism.
- `test/tune_flash.py`: add `cpp_autotune.py` and change `test/attn_torch_function.py` to support AOT autotune (aka cpp autotune). `test/tune_flash.py` will run the UT before testing a `triton.Config`'s performance, to avoid including faulty kernels.
- Add `--use_multigpu` to `test/tune_flash.py`. The script now supports tuning GPU kernels on all GPUs simultaneously, with the following extra features:
  - `tune_flash.py` needs 1*(main) + n*(worker) + n*(minesweeper) + 1*(db access) + 1*(table_tool.py) processes.
  - `--json_file` is also added, since the new architecture has a unified database access process that accepts outputs from all worker processes and can write them to a separate json file. This is the currently recommended way to store the results of the tuning script. Users are supposed to run `v2python.table_tool` later to update the tuning database.
  - `--continue_from_json_file` is introduced. Meanwhile, `result` and `_debug_task_id` fields are attached to each output json object, so that a tuning process can be resumed according to the `_debug_task_id` and its tuning status.
- `v2python.table_tool` is improved to support the new version of the json file.

CAVEAT: The new AOT-based autotune script `test/tune_flash.py` isn't capable of handling the backward pass yet.