Skip to content

Commit 5d8ca2f

Browse files
kaiyux and vonjackustc authored
Update TensorRT-LLM (#1639)
* Update TensorRT-LLM --------- Co-authored-by: vonjackustc <fga@mail.ustc.edu.cn>
1 parent b189b61 commit 5d8ca2f

File tree

165 files changed

+511601
-517367
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

165 files changed

+511601
-517367
lines changed

benchmarks/python/benchmark.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,8 @@ def parse_arguments():
232232
choices=[
233233
'fp8', 'fp8_gemm', 'fp8_kv_cache', 'int8_sq_per_tensor',
234234
'int8_sq_per_token_channel', 'int8_weight_only', 'int4_weight_only',
235-
'int4_weight_only_awq', 'int4_weight_only_gptq'
235+
'int4_weight_only_awq', 'int4_weight_only_gptq',
236+
'int8_sq_per_channel_ootb'
236237
],
237238
help="Optimize the model with specified quantization recipe")
238239
parser.add_argument(

benchmarks/python/build.py

+2
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,8 @@ def get_quant_config(quantization: str):
220220
elif quantization == "int8_sq_per_token_channel":
221221
return QuantConfig(
222222
quant_algo=QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN)
223+
elif quantization == "int8_sq_per_channel_ootb":
224+
return QuantConfig(quant_algo=QuantAlgo.W8A8_SQ_PER_CHANNEL)
223225
elif quantization == "int8_weight_only":
224226
return QuantConfig(quant_algo=QuantAlgo.W8A16)
225227
elif quantization == "int4_weight_only":

cpp/include/tensorrt_llm/runtime/decodingOutput.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,14 @@ class DecodingOutput
3535
class BeamHypotheses
3636
{
3737
public:
38-
// The same as cpp/tensorrt_llm/kernels/beamSearchKernels.h
38+
// Keep same as cpp/tensorrt_llm/kernels/beamSearchKernels.h
3939
TensorPtr outputIdsCBA; // [BS, BM*2, MSL]
40-
TensorPtr sequenceLengthsCBA; // [BS, BM]
40+
TensorPtr logProbsCBA; // [BS, BM*2, MSL]
41+
TensorPtr sequenceLengthsCBA; // [BS, BM*2]
4142
TensorPtr cumLogProbsCBA; // [BS, BM*2]
4243
TensorPtr normedScoresCBA; // [BS, BM*2]
43-
TensorPtr logProbsCBA; // [BS, BM*2, MSL]
44-
TensorPtr minNormedScoresCBA; // [BS]
4544
TensorPtr numBeamsCBA; // [BS]
45+
TensorPtr minNormedScoresCBA; // [BS]
4646
TensorPtr batchDones; // [BS]
4747

4848
void empty(BufferManager& manager);
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
ed2ee8d73a5d374e800f653169bf293e libtensorrt_llm_batch_manager_static.a
22
ed2ee8d73a5d374e800f653169bf293e libtensorrt_llm_batch_manager_static.pre_cxx11.a
3-
f088526f4bce4b1143c67973b3502734c3491ab9 commit
3+
05aaf1a0fb2f0115af107b00aa839a6601f6a873 commit
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:c2cc6b820b9eb87d3417070b2996966e6147c28ba95e47cb97ae7c5d4375b8aa
3-
size 3210470
2+
oid sha256:00e2d6ee8efd00e27dd8da61be576ba7978d885a055d591c90f600b334356846
3+
size 3211414
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:a0c7e0f8e717637d2686c44814640c25c71703bbdb50465955c35990e45a0399
3-
size 3185534
2+
oid sha256:b6b65183b0aa3f40f68aa13105da9dc00fb75b8bf8892813e46a09e3f0743570
3+
size 3186478
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:5de50dfcea7e67aa4f8b1c5404f9902ceea909d798d0154a44800d3af46ce1b1
3-
size 19838492
2+
oid sha256:7d4a3bc5160666612e529f21c61dbd9d0f1b387662768f76b9351f877108f84b
3+
size 19840380

cpp/tensorrt_llm/common/cudaFp8Utils.cu

+3
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,9 @@ __device__ __nv_bfloat16 atomicMaxExtd(__nv_bfloat16* address, __nv_bfloat16 val
206206
}
207207

208208
return __ushort_as_bfloat16(old);
209+
#else
210+
asm volatile(" brkpt;\n");
211+
return 0;
209212
#endif
210213
}
211214

cpp/tensorrt_llm/common/cudaTypeUtils.cuh

+3
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,9 @@ __device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val)
597597
{
598598
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
599599
return __hmax(val.x, val.y);
600+
#else
601+
asm volatile(" brkpt;\n");
602+
return 0;
600603
#endif
601604
}
602605
#endif
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:8d54459f1db7a6b78b67f2a7f378bf4fab4a24fdf70e08f5b1e2bcbf7ffd538d
3-
size 1251758
2+
oid sha256:8ff03e99e17e64c9f559e4586dec3983d438857c0050a34a417e4d86d56fbe2a
3+
size 1251854
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:8d54459f1db7a6b78b67f2a7f378bf4fab4a24fdf70e08f5b1e2bcbf7ffd538d
3-
size 1251758
2+
oid sha256:8ff03e99e17e64c9f559e4586dec3983d438857c0050a34a417e4d86d56fbe2a
3+
size 1251854
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
3c423e837e67ce86756ea468438c41a8 libtensorrt_llm_executor_static.a
2-
3c423e837e67ce86756ea468438c41a8 libtensorrt_llm_executor_static.pre_cxx11.a
3-
f088526f4bce4b1143c67973b3502734c3491ab9 commit
1+
54670adde093baff8b031869bdeeeb1b libtensorrt_llm_executor_static.a
2+
54670adde093baff8b031869bdeeeb1b libtensorrt_llm_executor_static.pre_cxx11.a
3+
05aaf1a0fb2f0115af107b00aa839a6601f6a873 commit
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:ec73640a8f4a20baf86ec85f02ff673edf209f77d2b57b7c5b9e11a8abcc3dd5
3-
size 1269974
2+
oid sha256:431dc6352dcb332821aab031ccbd887e6a60591a5ea276a9ffd3df1f28463326
3+
size 1271014
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:b5315703c0e4c6f0154ac60b3700187dbf787d50fdfc1cdf33b9806814a39846
3-
size 1226290
2+
oid sha256:86319542d275570a0c66622d4656b88f3d153c6861db0e53f17f29d47e0a30c9
3+
size 1227362
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:0aec01d232889a57fc99133e28a39c82ae47cfcc58894a938d851d9b2c3e64d5
3-
size 12074178
2+
oid sha256:8ed99448579b40e0046eca5c8989151a66579f8fbccef9cda4ee7fc2ffd2245b
3+
size 12076106

cpp/tensorrt_llm/executor_worker/executorWorker.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
* limitations under the License.
1515
*/
1616

17+
#include "tensorrt_llm/common/cudaUtils.h"
1718
#include "tensorrt_llm/common/logger.h"
1819
#include "tensorrt_llm/common/mpiUtils.h"
1920
#include "tensorrt_llm/executor/executor.h"
@@ -46,6 +47,11 @@ int main(int argc, char* argv[])
4647
return -1;
4748
}
4849

50+
// TRT-LLM event synchronization sometimes takes extra time to complete
51+
// after the kernel has finished. Using a yield in the wait helps improve
52+
// performance.
53+
TLLM_CUDA_CHECK(::cudaSetDeviceFlags(cudaDeviceScheduleYield));
54+
4955
// Since parentComm is an intercommunicator, input root
5056
// is the rank of the parent process in his group
5157
// (always 0 as the parent size is checked before)

cpp/tensorrt_llm/kernels/beamSearchKernels.h

+10-9
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,18 @@ struct BeamHypotheses
2929
{
3030
// clang-format off
3131

32-
// BS: batch_size, BM: beam_width, MSL: max_seq_length
33-
// %%: parameter name when dynamic_decoder.forward() / gather_tree() are called in [generation.py] (python workflow)
32+
// MBS: max_batch_size, BS: batch_size, BM: beam_width, MSL: max_seq_length
33+
// %%: parameter name in file generation.py (python workflow)
3434

3535
// Candidate beams: a beam which generates end_id or its sequence length reaches MSL
36-
// Candidate-Beam-Array (CBA): The arrays (size: BM*2) to place the candidate beams and related information
36+
// Candidate-Beam-Array (CBA): The arrays to place the candidate beams and related information
3737

3838
// Scalar values
3939
bool bReturnNormedScore{false}; // return normed_score / cum_log_probs, useless yet
40-
int nBatchSize{0}; //
40+
int nMaxBatchSize{0}; // max batch size by model configuration
41+
int nBatchSize{0}; // batch size by runtime input data
4142
int nBeamWidth{0}; //
4243
int nIte{0}; // index of local_batch, always be 0 when pp_size==1
43-
int nBatchSizeLocal{0}; //
4444
int nMaxSeqLen{0}; //
4545
int nVocabSize{0}; // vocab_size_padded
4646

@@ -54,8 +54,9 @@ struct BeamHypotheses
5454
int const* endIds{nullptr}; // [BS, BM] %% self.end_ids
5555

5656
// Pointers for output
57-
int* outputIds{nullptr}; // [BS, BM, MSL] %% self.output_ids
58-
float* logProbs{nullptr}; // [MSL, BS, BM] %% self.log_probs_tiled
57+
int* outputIds{nullptr}; // [BS, BM, MSL] %% self.output_ids only used in gather_tree
58+
float* logProbs{nullptr}; // [BS, BM, MSL] %% self.log_probs only used in gather_tree
59+
float* logProbsTiled{nullptr}; // [MSL, MBS, BM] %% self.log_probs_tiled
5960
int* sequenceLengths{nullptr}; // [BS, BM] %% self.sequence_length_buffer
6061
float* cumLogProbs{nullptr}; // [BS, BM] %% self.cum_log_probs
6162

@@ -65,8 +66,8 @@ struct BeamHypotheses
6566
int* sequenceLengthsCBA{nullptr}; // [BS, BM*2] %% self.beam_hyps_seq_len_cba
6667
float* cumLogProbsCBA{nullptr}; // [BS, BM*2] %% self.beam_hyps_cum_log_probs_cba
6768
float* normedScoresCBA{nullptr}; // [BS, BM*2] %% self.beam_hyps_normed_scores_cba
68-
int* numBeamsCBA{nullptr}; // [BS] %% self.beam_hyps_num_beams number of beams in CBA
69-
float* minNormedScoresCBA{nullptr}; // [BS] %% self.beam_hyps_min_normed_scores worst score in CBA
69+
int* numBeamsCBA{nullptr}; // [BS] %% self.beam_hyps_num_beams number of beams in CBA
70+
float* minNormedScoresCBA{nullptr}; // [BS] %% self.beam_hyps_min_normed_scores worst score in CBA
7071

7172
// Pointers related to beam search process, they are initialized in those two functions:
7273
// [gptDecoder.cpp] GptDecoder<T>::forward or [dynamicDecodeOp.cpp] FtDynamicDecode<T>::forward

0 commit comments

Comments
 (0)