
Port 20250128 main perf kernel #70

Merged: 24 commits into main on Feb 4, 2025

Conversation

xinyazhang (Collaborator) commented on Feb 4, 2025

Major Changes

  • [kernel] Backport the 2025/01/28 main_perf kernel
  • [shim] Remove non-power-of-two (NPOT) head dim 72, which triggers compiler bugs on bf16
  • [db] Remove the attn_fwd table from the tuning database, since the old entries are no longer valid.
  • [db] Set all entries to num_stages=1, since num_stages=2 consistently triggers compiler bugs
  • [test] Add new head dimensions, now categorized into three groups:
    • Power-of-two head dimensions
    • Optimized NPOT head dimensions
    • Prime-number head dimensions to cover all gaps between neighboring POT and NPOT head dims (see the sketch after this list).
  • [shim] Add env var AOTRITON_SKIP_LUT_CHECK to skip LUT sanity check on certain kernels
    • As of this PR, AOTriton must be built with AOTRITON_SKIP_LUT_CHECK=flash.attn_fwd ninja install
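
The grouping below is a minimal sketch of how these three test groups could be enumerated; the concrete dimension values and the helper names are illustrative assumptions, not the actual test parametrization.

    # Illustrative sketch of the three head-dim test groups described above.
    # The concrete values below are assumptions, not the actual test parameters.
    POT_HEADDIMS = [16, 32, 64, 128, 256]           # power-of-two head dims
    NPOT_HEADDIMS = [48, 80, 96, 160, 192, 224]     # optimized non-power-of-two head dims

    def _is_prime(n):
        return n >= 2 and all(n % d for d in range(2, int(n ** 0.5) + 1))

    def prime_gap_headdims(pot, npot):
        """Pick one prime inside every gap between neighboring POT/NPOT head dims."""
        anchors = sorted(set(pot) | set(npot))
        primes = []
        for lo, hi in zip(anchors, anchors[1:]):
            gap_primes = [n for n in range(lo + 1, hi) if _is_prime(n)]
            if gap_primes:
                primes.append(gap_primes[len(gap_primes) // 2])  # roughly mid-gap
        return primes

    PRIME_HEADDIMS = prime_gap_headdims(POT_HEADDIMS, NPOT_HEADDIMS)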

Minor Changes

  • [build] Bump the version number to 0.9.0. (This should have been done at the beginning of 0.9 development.)
  • [API] In the API, move bias tensor to the position immediately after v tensor, matching the kernel argument order
  • [shim] Add TensorView<0>::get_null_tensor
  • [test] Change AttentionExtraArgs from namedtuple to dataclass for easier-to-read default values (see the sketch after this list).
  • [mptune] Change output json format to match kernel argument changes.
  • [test] Use the CPU reference implementation when seqlen_k == 579 (used by test_gqa tests); the GPU reference triggers a segfault.
  • [test] Change the default value_fudge_factor to 36.0 (should be 40.0 if GQA tests are considered)
  • [shim] Fix the code path when the tuning database is not available
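
The following is a minimal sketch of the namedtuple-to-dataclass change; the field names and defaults are illustrative assumptions rather than the actual AttentionExtraArgs definition.

    from collections import namedtuple
    from dataclasses import dataclass

    # Before (illustrative): defaults live in a separate tuple, far from the field names.
    AttentionExtraArgsTuple = namedtuple(
        'AttentionExtraArgs',
        ['return_encoded_softmax', 'autotune', 'return_autotune'],
        defaults=(False, False, False))

    # After (illustrative): each default sits next to its field, which reads more clearly.
    @dataclass
    class AttentionExtraArgs:
        return_encoded_softmax: bool = False
        autotune: bool = False
        return_autotune: bool = False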

Known Problems

  • The tuning database for the flash.attn_fwd kernel has been cleared, and there is no plan to rebuild it at the moment due to imminent additional changes to the forward kernel.

pytest tritonsrc/test_backward -k '1.2 and 4-4]' passed

Only POT head dims have been tested; NPOT optimization will be added later.
NaN problems occur for head dims in [68, 72) when running on MI200 with bf16.
Notable changes for users
1. attn_fwd is untuned for now (there is no point in using the old database
   due to its outdated arguments).
2. Only one combination of perf/copts is built for untuned kernels.
   Previously this relied on there being only one combination of perf
   options (PERF_CHOICES) in KernelDescription. Now all possible options
   can be listed there, with the first choice serving as the default.
3. Replace num_stages=2 with num_stages=1 in the tuning database for all
   remaining kernels. num_stages=2 triggers compiler bugs: "error: operation
   scheduled before its operands" (see the illustrative sketch below).
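
The snippet below is a purely illustrative sketch of that num_stages rewrite, assuming the tuning entries can be treated as a list of dicts; the actual tuning database format is not described here.

    # Purely illustrative: rewrite num_stages=2 -> num_stages=1 across tuning entries,
    # assuming each entry is a dict with a 'num_stages' key. The real database layout
    # used by AOTriton may differ.
    def downgrade_num_stages(entries):
        for entry in entries:
            if entry.get('num_stages') == 2:
                # num_stages=2 triggers "operation scheduled before its operands"
                entry['num_stages'] = 1
        return entries

    example_entries = [
        {'kernel': 'flash.bwd_kernel_dk_dv', 'num_stages': 2, 'num_warps': 4},
        {'kernel': 'flash.bwd_kernel_dq', 'num_stages': 1, 'num_warps': 4},
    ]
    print(downgrade_num_stages(example_entries))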

Also explicitly disable the persistent option, and fix a Debug build bug in attn_fwd.
xinyazhang (Collaborator, Author) commented:

All unit tests passed when compiled with the default options provided in v2python/rules/flash/attn_fwd.py.
Note: the reference implementation must be run on the CPU, otherwise it triggers a GPU segfault.
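
A minimal sketch of that pattern, computing the reference attention on the CPU while the kernel under test runs on the GPU; the tensor names and helper below are assumptions, not the actual test code.

    import torch

    def cpu_reference_attention(q, k, v):
        # Naive softmax(Q K^T / sqrt(d)) V reference, computed entirely on the CPU
        # to avoid the GPU segfault mentioned above.
        q, k, v = (t.detach().float().cpu() for t in (q, k, v))
        scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
        return torch.softmax(scores, dim=-1) @ v

    # q, k, v are hypothetical GPU tensors produced by the test fixture:
    # out_gpu = triton_attn_fwd(q, k, v)                       # kernel under test
    # out_ref = cpu_reference_attention(q, k, v).to(q.device)  # reference on CPU
    # torch.testing.assert_close(out_gpu, out_ref, atol=2e-2, rtol=2e-2)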

xinyazhang marked this pull request as ready for review on February 4, 2025 at 17:03
xinyazhang merged commit c70eab6 into main on Feb 4, 2025