Kernel Storage V2 #50

Merged: 22 commits, Nov 13, 2024

Commits
eeaa826
generate_compile: Add --test_clustering to figure out optimal cluster…
xinyazhang Oct 16, 2024
c0d29cb
generate_compile: add option --generate_cluster_info
xinyazhang Oct 17, 2024
41abdc5
build: support kernel storage v2
xinyazhang Oct 17, 2024
2a30ad4
Add packed kernel support and defines AKS2 format (Untested).
xinyazhang Oct 18, 2024
d3ff51b
Add the archiving tool aks2.py, also adjust the binary format
xinyazhang Oct 18, 2024
0bcd77e
Fix all compiling errors.
xinyazhang Oct 19, 2024
81ae205
Add some missing checks and cleans in packed_kernel.cc
xinyazhang Oct 21, 2024
0c9a53e
Install aotriton.images
xinyazhang Oct 21, 2024
1d0eceb
Fix in progress. python test/test_backward.py (Note NOT pytest) passed.
xinyazhang Oct 22, 2024
50ef132
Suppress debug output in Release build.
xinyazhang Oct 22, 2024
d07e9b3
Add cmake option AOTRITON_NOIMAGE_MODE
xinyazhang Oct 22, 2024
2f17ff6
Clang-format C++ Source/Header files
xinyazhang Oct 22, 2024
f10ed75
clean up v2src/CMakeLists.txt
xinyazhang Oct 22, 2024
1b0a0c2
Clean up v2src/triton_kernel.cc
xinyazhang Oct 22, 2024
406c64b
Copyright notices of aks2.py
xinyazhang Oct 22, 2024
bfbdea9
Fix the default of no image mode.
xinyazhang Oct 22, 2024
bac1a06
Update Triton to Oct/23/2024
xinyazhang Oct 25, 2024
e992093
Make _common_test compatible with both PyTorch < 2.5 and >= 2.5
xinyazhang Oct 25, 2024
0dc5705
README: Update the Prerequisites
xinyazhang Oct 30, 2024
4ef62e4
Remove AOTRITON_NO_SHARED option. It is incompatible with Kernel Stor…
xinyazhang Nov 8, 2024
536aee4
Add pybind11 to requirements.txt
xinyazhang Nov 12, 2024
b7c33cd
add missing header <mutex> to packed_kernel.cc
xinyazhang Nov 13, 2024
6 changes: 6 additions & 0 deletions .clang-format
@@ -10,3 +10,9 @@ ColumnLimit: 110
AlignArrayOfStructures: Right
# For code in bindings/
NamespaceIndentation: Inner
# Function arguments alignment
AllowAllParametersOfDeclarationOnNextLine: false
AllowAllArgumentsOnNextLine: false
AlignAfterOpenBracket: Align
BinPackArguments: false
BinPackParameters: false
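
For illustration, here is how a long declaration comes out under the new settings. The function and its parameters below are hypothetical, not from this PR: once the argument list exceeds the 110-column limit and bin-packing is disabled, clang-format places one parameter per line, aligned after the opening parenthesis.

```cpp
// Hypothetical declaration (illustrative names, not AOTriton code).
// The single-line form exceeds ColumnLimit: 110; with
// BinPackParameters: false and AlignAfterOpenBracket: Align,
// clang-format emits one parameter per line, aligned to the '(':
#include <cstdint>
#include <vector>

int load_packed_kernel_image(const char* package_path,
                             const char* stem_name,
                             std::vector<uint8_t>& decompressed_payload,
                             int preferred_device_id,
                             bool validate_checksum);
```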
78 changes: 27 additions & 51 deletions CMakeLists.txt
@@ -4,6 +4,9 @@
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)

project(AOTriton CXX C)

set(CMAKE_CXX_STANDARD 20)

set(AOTRITON_MIN_PYTHON 3.8 CACHE STRING "Minimal Python version for find_package.")
if(AOTRITON_MIN_PYTHON VERSION_LESS "3.8")
message(FATAL_ERROR "Do not set AOTRITON_MIN_PYTHON lower than 3.8. The code itself does not support it.")
@@ -13,11 +16,11 @@ find_package(Python3 ${AOTRITON_MIN_PYTHON} COMPONENTS Interpreter REQUIRED)

set(VENV_DIR "${CMAKE_CURRENT_BINARY_DIR}/venv" CACHE STRING "Virtual Environment Directory")
set(AOTRITON_HIPCC_PATH "hipcc" CACHE STRING "Set HIPCC Path")
-option(AOTRITON_NO_SHARED "Build as archive library." OFF)
option(AOTRITON_NO_PYTHON "Disable python binding build" OFF)
option(AOTRITON_ENABLE_ASAN "Enable Address Sanitizer. Implies -g" OFF)
option(AOTRITON_BUILD_FOR_TUNING "Build all GPU kernels and set -DAOTRITON_BUILD_FOR_TUNING=1 (=0 otherwise)" OFF)
option(AOTRITON_ENABLE_FP32_INPUTS "Enable FP32 support." ON)
+option(AOTRITON_NOIMAGE_MODE "Only build C++ Shim part. Kernel image builds are disabled" OFF)
set(AOTRITON_GPU_BUILD_TIMEOUT "8.0" CACHE STRING "GPU kernel compiler times out after X minutes. 0 for indefinite. Highly recommended if AOTRITON_BUILD_FOR_TUNING=On.")
set(TARGET_GPUS "MI200;MI300X;Navi31" CACHE STRING "Target Architecture (Note here uses Trade names)")
@@ -30,10 +33,6 @@ endif()
# Must be after pybind11
set(CMAKE_CXX_COMPILER hipcc)

-# GPU kernel compression related options
-option(AOTRITON_COMPRESS_KERNEL "Enable GPU kernel compression with zstd. Fail when zstd is unavailable. Only effective for AOTriton API V2" ON)
-# option(AOTRITON_COMPRESS_KERNEL_STATIC_ZSTD "Use static zstd library to avoid potential zstd version conflict (e.g. pytorch)" OFF)
-
# Resolve name conflicts with suffix
set(AOTRITON_NAME_SUFFIX "" CACHE STRING "Add suffix to namespace and library file name. This is to resolve name conflicts with PyTorch's AOTriton during testing.")
if(AOTRITON_NAME_SUFFIX)
@@ -45,39 +44,13 @@ configure_file(include/aotriton/config.h.in include/aotriton/config.h)
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/include/aotriton/config.h
        DESTINATION ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR}/aotriton)

-# Note for archive library user:
-# get this property with:
-# get_property(ZSTD_INCLUDE_DIR TARGET zstd::libzstd_shared PROPERTY INTERFACE_INCLUDE_DIRECTORIES)
-# "zstd::libzstd_shared" can be replaced with zstd::libzstd_static
-set(AOTRITON_OVERRIDE_ZSTD_INCLUDE "" CACHE STRING "(For archive library users) override zstd header directory.\
-Caveat: should consider set AOTRITON_NO_SHARED because objects are compiled with this header file,\
-but shared objects will be linked to libzstd found by find_package later.")
-if(AOTRITON_COMPRESS_KERNEL)
-  find_program(ZSTD_EXEC zstd REQUIRED)
-  include(FindPkgConfig)
-  pkg_search_module(ZSTD REQUIRED libzstd)
-  # if (AOTRITON_COMPRESS_KERNEL_STATIC_ZSTD)
-  #   set(ZSTD_TARGET zstd::libzstd_static)
-  # else()
-  #   if(TARGET zstd::libzstd_shared)
-  #     set(ZSTD_TARGET zstd::libzstd_shared)
-  #   else()
-  #     set(ZSTD_TARGET zstd::libzstd_static)
-  #   endif()
-  # endif()
-  # get_property(AOTRITON_ZSTD_INCLUDE TARGET ${ZSTD_TARGET} PROPERTY INTERFACE_INCLUDE_DIRECTORIES)
-  add_library(zstd_interface INTERFACE)
-  # There are other options but these are sufficient on EL8
-  target_link_libraries(zstd_interface INTERFACE ${ZSTD_LIBRARIES})
-  target_link_directories(zstd_interface INTERFACE ${ZSTD_LIBRARY_DIRS})
-  target_include_directories(zstd_interface INTERFACE ${ZSTD_INCLUDE_DIRS})
-  set(AOTRITON_ZSTD_INCLUDE "${ZSTD_INCLUDE_DIRS}")
-  message(STATUS "ZSTD_TARGET ${ZSTD_TARGET}")
-  message(STATUS "get_property AOTRITON_ZSTD_INCLUDE ${AOTRITON_ZSTD_INCLUDE}")
-  if(AOTRITON_OVERRIDE_ZSTD_INCLUDE)
-    set(AOTRITON_ZSTD_INCLUDE ${AOTRITON_OVERRIDE_ZSTD_INCLUDE})
-  endif()
-endif()
+# Kernel Storage V2 uses xz/LZMA for compression
+include(FindPkgConfig)
+pkg_search_module(LZMA REQUIRED liblzma)
+add_library(lzma_interface INTERFACE)
+target_link_libraries(lzma_interface INTERFACE ${LZMA_LIBRARIES})
+target_link_directories(lzma_interface INTERFACE ${LZMA_LIBRARY_DIRS})
+target_include_directories(lzma_interface INTERFACE ${LZMA_INCLUDE_DIRS})

if(AOTRITON_ENABLE_ASAN)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -fsanitize=address -fno-omit-frame-pointer")
@@ -100,19 +73,22 @@ message("VENV_SITE ${VENV_SITE}")

execute_process(COMMAND ${CMAKE_COMMAND} -E env VIRTUAL_ENV=${VENV_DIR} "${VENV_DIR}/bin/python" -m pip install -r "${CMAKE_CURRENT_LIST_DIR}/requirements.txt")

-set(TRITON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/triton_build")
-execute_process(COMMAND ${CMAKE_COMMAND} -E make_directory "${TRITON_BUILD_DIR}")
-set(AOTRITON_TRITON_SO "${CMAKE_CURRENT_LIST_DIR}/third_party/triton/python/triton/_C/libtriton.so")
-set(AOTRITON_TRITON_EGGLINK "${VENV_SITE}/triton.egg-link")
-message("AOTRITON_TRITON_EGGLINK ${AOTRITON_TRITON_EGGLINK}")
-
-add_custom_command(OUTPUT "${AOTRITON_TRITON_EGGLINK}"
-  COMMAND ${CMAKE_COMMAND} -E env VIRTUAL_ENV=${VENV_DIR} TRITON_BUILD_DIR=${TRITON_BUILD_DIR} "${VENV_DIR}/bin/python" setup.py develop
-  # COMMAND ${CMAKE_COMMAND} -E env VIRTUAL_ENV=${VENV_DIR} python -m pip show triton
-  WORKING_DIRECTORY "${CMAKE_CURRENT_LIST_DIR}/third_party/triton/python/"
-  BYPRODUCTS "${AOTRITON_TRITON_SO}"
-)
-add_custom_target(aotriton_venv_triton ALL DEPENDS ${AOTRITON_TRITON_EGGLINK})
+# AOTRITON_NOIMAGE_MODE does not need Triton
+if(NOT AOTRITON_NOIMAGE_MODE)
+  set(TRITON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/triton_build")
+  execute_process(COMMAND ${CMAKE_COMMAND} -E make_directory "${TRITON_BUILD_DIR}")
+  set(AOTRITON_TRITON_SO "${CMAKE_CURRENT_LIST_DIR}/third_party/triton/python/triton/_C/libtriton.so")
+  set(AOTRITON_TRITON_EGGLINK "${VENV_SITE}/triton.egg-link")
+  message("AOTRITON_TRITON_EGGLINK ${AOTRITON_TRITON_EGGLINK}")
+
+  add_custom_command(OUTPUT "${AOTRITON_TRITON_EGGLINK}"
+    COMMAND ${CMAKE_COMMAND} -E env VIRTUAL_ENV=${VENV_DIR} TRITON_BUILD_DIR=${TRITON_BUILD_DIR} "${VENV_DIR}/bin/python" setup.py develop
+    # COMMAND ${CMAKE_COMMAND} -E env VIRTUAL_ENV=${VENV_DIR} python -m pip show triton
+    WORKING_DIRECTORY "${CMAKE_CURRENT_LIST_DIR}/third_party/triton/python/"
+    BYPRODUCTS "${AOTRITON_TRITON_SO}"
+  )
+  add_custom_target(aotriton_venv_triton ALL DEPENDS ${AOTRITON_TRITON_EGGLINK})
+endif(NOT AOTRITON_NOIMAGE_MODE)

add_subdirectory(v2src)

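
Kernel Storage V2 swaps the zstd-based per-kernel compression out for XZ/LZMA-compressed kernel packages, hence the `lzma_interface` target above. As a rough sketch of what consuming such a package involves at runtime (an assumed shape, not code from this PR; the real decoding lives in the v2src sources), liblzma's one-shot buffer API decodes an .xz payload like this:

```cpp
// Sketch: decode an XZ-compressed buffer with liblzma's one-shot API.
// Illustrative only; AOTriton's actual decoding may differ.
#include <lzma.h>
#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<uint8_t> xz_decompress(const uint8_t* in, size_t in_size,
                                   size_t expected_out_size) {
  std::vector<uint8_t> out(expected_out_size);
  uint64_t memlimit = UINT64_MAX;  // no decoder memory cap
  size_t in_pos = 0, out_pos = 0;
  lzma_ret ret = lzma_stream_buffer_decode(&memlimit, /*flags=*/0,
                                           /*allocator=*/nullptr,
                                           in, &in_pos, in_size,
                                           out.data(), &out_pos, out.size());
  if (ret != LZMA_OK)
    throw std::runtime_error("liblzma decode failed");
  out.resize(out_pos);  // shrink to the actual decoded size
  return out;
}
```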
4 changes: 2 additions & 2 deletions README.md
@@ -20,8 +20,8 @@ system, `ninja install` will run the whole build process unconditionally.
* `hipcc` in `/opt/rocm/bin`, as a part of [ROCm](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/)
* `cmake`
* `ninja`
-* `libzstd`
-  - Common names are `libzstd-dev` or `libzstd-devel`.
+* `liblzma`
+  - Common names are `liblzma-dev` or `xz-devel`.

## Generation

7 changes: 7 additions & 0 deletions bindings/module.cc
@@ -9,6 +9,7 @@
#include <aotriton/cpp_tune.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/gil.h>
#include <string>

namespace py = pybind11;
@@ -36,6 +37,7 @@ namespace pyaotriton {
m.def("attn_fwd",
&aotriton::v2::flash::attn_fwd,
"Flash Attention Forward Pass",
py::call_guard<py::gil_scoped_release>(),
py::arg("q"),
py::arg("k"),
py::arg("v"),
@@ -56,6 +58,7 @@ namespace pyaotriton {
m.def("attn_fwd_compact_varlen",
&aotriton::v2::flash::attn_fwd_compact_varlen,
"Flash Attention Forward Pass, Compact Stored Varlen",
py::call_guard<py::gil_scoped_release>(),
py::arg("q"),
py::arg("k"),
py::arg("v"),
@@ -80,6 +83,7 @@ namespace pyaotriton {
m.def("attn_bwd",
&aotriton::v2::flash::attn_bwd,
"Flash Attention Backward Pass",
py::call_guard<py::gil_scoped_release>(),
py::arg("q"),
py::arg("k"),
py::arg("v"),
@@ -103,6 +107,7 @@ namespace pyaotriton {
m.def("attn_bwd_compact_varlen",
&aotriton::v2::flash::attn_bwd_compact_varlen,
"Flash Attention Backward Pass, Compact Stored Varlen",
py::call_guard<py::gil_scoped_release>(),
py::arg("q"),
py::arg("k"),
py::arg("v"),
@@ -130,13 +135,15 @@ namespace pyaotriton {
m.def("debug_fill_dropout_rng",
&aotriton::v2::flash::debug_fill_dropout_rng,
"Flash Attention Debugging Function to get raw RNG numbers used in dropout",
py::call_guard<py::gil_scoped_release>(),
py::arg("q"),
py::arg("philox_seed"),
py::arg("philox_offset"),
py::arg("stream") = nullptr);
m.def("debug_fill_dropout_rng_tensor",
&aotriton::v2::flash::debug_fill_dropout_rng_tensor,
"Flash Attention Debugging Function to get raw RNG numbers used in dropout",
py::call_guard<py::gil_scoped_release>(),
py::arg("q"),
py::arg("philox_seed"),
py::arg("philox_offset"),
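
Every binding in this file gains `py::call_guard<py::gil_scoped_release>()`, which makes pybind11 release the Python GIL before entering the C++ function and reacquire it on return, so other Python threads can run while a kernel executes. A minimal standalone sketch of the same pattern (illustrative module, not AOTriton code):

```cpp
// Minimal pybind11 module showing the call_guard pattern used above.
// The GIL is dropped while heavy_compute runs. Illustrative only.
#include <pybind11/pybind11.h>
#include <chrono>
#include <thread>

namespace py = pybind11;

int heavy_compute(int x) {
  // Stand-in for a long GPU/CPU operation; it must not touch
  // Python state while the GIL is released.
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
  return x * 2;
}

PYBIND11_MODULE(demo, m) {
  m.def("heavy_compute", &heavy_compute,
        py::call_guard<py::gil_scoped_release>(),
        py::arg("x"));
}
```

The usual caveat applies: a function run under `gil_scoped_release` must not create or touch Python objects until the guard restores the GIL.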
48 changes: 48 additions & 0 deletions include/aotriton/_internal/packed_kernel.h
@@ -0,0 +1,48 @@
// Copyright © 2024 Advanced Micro Devices, Inc.
// SPDX-License-Identifier: MIT

#ifndef AOTRITON_V2_API_PACKED_KERNEL_H
#define AOTRITON_V2_API_PACKED_KERNEL_H

#include <aotriton/_internal/triton_kernel.h>
#include <aotriton/config.h>
#include <memory>
#include <shared_mutex>
#include <stdint.h>
#include <string_view>
#include <tuple>
#include <unordered_map>
#include <vector>

namespace AOTRITON_NS {

using PackedKernelPtr = std::shared_ptr<PackedKernel>;
struct AKS2_Metadata;

class PackedKernel {
public:
  static PackedKernelPtr open(const char* package_path);
  PackedKernel(int fd);
  ~PackedKernel();
  hipError_t status() const {
    return final_status_;
  }

  TritonKernel::Essentials filter(const char* stem_name) const;

private:
  static std::shared_mutex registry_mutex_;
  static std::unordered_map<std::string_view, PackedKernelPtr> registry_;
  // Note: do NOT drop decompressed_content_; its bytes back the
  // string_view keys and AKS2_Metadata entries stored in directory_
  std::vector<uint8_t> decompressed_content_;
  hipError_t final_status_;

  const uint8_t* kernel_start_;
  // Note: again, the AKS2_Metadata entries point into decompressed_content_
  std::unordered_map<std::string_view, const AKS2_Metadata*> directory_;
};

};

#endif
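
The statics `registry_` and `registry_mutex_` suggest that `open()` keeps one PackedKernel per package path for the whole process, decompressing each AKS2 package at most once, while `filter()` later resolves an individual kernel image by stem name. Below is a self-contained sketch of that open-once, cache-by-path pattern; all names and internals here are assumptions for illustration, not code from the PR:

```cpp
// Sketch of the "open once, cache by path" pattern that registry_ and
// registry_mutex_ suggest. Names and details are illustrative.
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

struct Package {
  explicit Package(const std::string& path) { /* read + decompress */ }
};

using PackagePtr = std::shared_ptr<Package>;

PackagePtr open_cached(const std::string& path) {
  static std::shared_mutex mutex;
  static std::unordered_map<std::string, PackagePtr> registry;
  {
    std::shared_lock lock(mutex);  // fast path: concurrent readers
    if (auto it = registry.find(path); it != registry.end())
      return it->second;
  }
  std::unique_lock lock(mutex);    // slow path: exclusive writer
  if (auto it = registry.find(path); it != registry.end())
    return it->second;             // another thread won the race
  auto pkg = std::make_shared<Package>(path);
  registry.emplace(path, pkg);
  return pkg;
}
```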
42 changes: 20 additions & 22 deletions include/aotriton/_internal/triton_kernel.h
@@ -4,53 +4,51 @@
#ifndef AOTRITON_V2_API_TRITON_KERNEL_H
#define AOTRITON_V2_API_TRITON_KERNEL_H

-#ifndef AOTRITON_USE_ZSTD
-#error "Must define AOTRITON_USE_ZSTD explicitly. "\
-"Inconsistent definition of this macro causes ABI incompatibility."
-#endif
-
+#include <aotriton/config.h>
#include "../runtime.h"
-#include <vector>
-#include <unordered_map>
-#include <aotriton/config.h>
#include <memory>
#include <shared_mutex>
+#include <tuple>
+#include <unordered_map>
+#include <vector>

namespace AOTRITON_NS {

+class PackedKernel;
+
class TritonKernel {
public:
-  TritonKernel(const void* image, size_t image_size, dim3 block, int shared_memory_size);
+  using Essentials = std::tuple<const void*, int, dim3>;
+
+  TritonKernel(const char* package_path, const char* stem_name);

  hipError_t invoke(const char* kernel_name, dim3 grid, std::vector<void*>& args, hipStream_t stream);

-#if AOTRITON_USE_ZSTD
-  void clear_decompressed_image();
-#endif
-
private:
  std::tuple<hipFunction_t, hipError_t> load_for_device(int device_id, const char* kernel_name);
+  hipFunction_t cfind_function(int device_id) const;

-  const void* kernel_image_ = nullptr;
+  const char* package_path_ = nullptr;
+  const char* stem_name_ = nullptr;
-  size_t image_size_ = 0;
  struct DeviceFunction {
-    DeviceFunction(int device_id_,
-                   hipModule_t mod_,
-                   hipFunction_t func_);
+    DeviceFunction(int device_id_, hipModule_t mod_, hipFunction_t func_);
    ~DeviceFunction();
    int device_id = -1;
    hipModule_t mod = nullptr;
    hipFunction_t func = nullptr;
  };
  std::unordered_map<int, DeviceFunction> funcache_;
-  std::shared_mutex mutex_;
+  std::shared_mutex funcache_mutex_;

-  int shared_memory_size_ = 0;
  dim3 block_ { 256, 1, 1 };
+  int shared_memory_size_;
-#if AOTRITON_USE_ZSTD
-  std::vector<char> decompressed_kernel_image_;
-  void* decompress_kernel();
-#endif
+  const void* kernel_image_ = nullptr;
+  Essentials decompress_kernel();
+  std::shared_ptr<PackedKernel> packed_kernel_ = nullptr;
+  std::shared_mutex packedkernel_mutex_;
};

}
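
With Kernel Storage V2 a TritonKernel is constructed from only a package path and a stem name; the GPU image is fetched lazily through `decompress_kernel()`, which now returns the `Essentials` tuple (image pointer, shared-memory size, launch block), and loaded modules are cached per device in `funcache_`. A hedged sketch of the per-device load step follows; the two HIP calls are real API, while everything around them (names, cache handling) is assumed:

```cpp
// Sketch of lazy per-device module loading in the spirit of
// TritonKernel::load_for_device. Cache insertion and locking are
// omitted; this is an assumed shape, not code from this PR.
#include <hip/hip_runtime.h>

struct LoadedFunction {
  hipModule_t mod = nullptr;
  hipFunction_t func = nullptr;
};

// kernel_image points at an in-memory code object, e.g. the buffer
// PackedKernel::filter() exposes after decompression.
hipError_t load_function(const void* kernel_image,
                         const char* kernel_name,
                         LoadedFunction& out) {
  // hipModuleLoadData parses the image already in memory; no file I/O.
  hipError_t err = hipModuleLoadData(&out.mod, kernel_image);
  if (err != hipSuccess)
    return err;
  return hipModuleGetFunction(&out.func, out.mod, kernel_name);
}
```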
1 change: 1 addition & 0 deletions requirements.txt
@@ -5,3 +5,4 @@ pluggy
numpy
setuptools
wheel
pybind11
4 changes: 2 additions & 2 deletions test/test_backward.py
@@ -261,5 +261,5 @@ def main():
_do_test_op_bwd(BATCH, N_HEADS, D_HEAD, seqlen_q, seqlen_k, causal, sm_scale, dropout_p, dtype, storage_flip, bias_type)

if __name__ == '__main__':
-    # main2()
-    main_npz()
+    main2()
+    # main_npz()
2 changes: 1 addition & 1 deletion third_party/triton
34 changes: 27 additions & 7 deletions tritonsrc/_common_test.py
@@ -5,6 +5,26 @@
import math
import torch



def sdpa_math(query, key, value, attn_mask=None, dropout_p=0.0, dropout_mask=None, is_causal=False, scale=None, enable_gqa=False):
    if torch.__version__ >= '2.5.0':
        return torch.ops.aten._scaled_dot_product_attention_math(query, key, value,
                                                                 dropout_p=dropout_p,
                                                                 is_causal=is_causal,
                                                                 attn_mask=attn_mask,
                                                                 scale=scale,
                                                                 dropout_mask=dropout_mask,
                                                                 enable_gqa=enable_gqa)
    else:
        return torch.ops.aten._scaled_dot_product_attention_math(query, key, value,
                                                                 dropout_p=dropout_p,
                                                                 is_causal=is_causal,
                                                                 attn_mask=attn_mask,
                                                                 scale=scale,
                                                                 dropout_mask=dropout_mask)


def _reference_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
# Efficient implementation equivalent to the following:
L, S = query.size(-2), key.size(-2)
@@ -278,13 +298,13 @@ def _compute_ref_forward(ref_tensors, p : SdpaParams):
enable_gqa = num_head_q != num_head_k
dropout_mask = p.dropout_mask if p.dropout_mask is None else p.dropout_mask.to(device=ref_q.device)
# _scaled_dot_product_attention_math seems also working for nested tensor
-    ref_out, ref_mask = torch.ops.aten._scaled_dot_product_attention_math(ref_q, ref_k, ref_v,
-                                                                          dropout_p=p.dropout_p,
-                                                                          is_causal=p.causal,
-                                                                          attn_mask=ref_b,
-                                                                          scale=p.sm_scale,
-                                                                          dropout_mask=dropout_mask,
-                                                                          enable_gqa=enable_gqa)
+    ref_out, ref_mask = sdpa_math(ref_q, ref_k, ref_v,
+                                  dropout_p=p.dropout_p,
+                                  is_causal=p.causal,
+                                  attn_mask=ref_b,
+                                  scale=p.sm_scale,
+                                  dropout_mask=dropout_mask,
+                                  enable_gqa=enable_gqa)
return (ref_out, ref_mask)

def compute_ref_forward(self, p : SdpaParams):