Skip to content

Commit c4acc6c

Browse files
authored
Merge pull request #10173 from rakhmets/topic/gtest-cuda-async
GTEST/COMMON: Added test for cuda async allocated buffers.
2 parents 0efde8e + 6a0c6c1 commit c4acc6c

File tree

3 files changed

+121
-22
lines changed

3 files changed

+121
-22
lines changed

test/gtest/common/mem_buffer.cc

+36-3
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ void mem_buffer::get_bar1_free_size_nvml()
221221
#endif
222222
}
223223

224-
void *mem_buffer::allocate(size_t size, ucs_memory_type_t mem_type)
224+
void *mem_buffer::allocate(size_t size, ucs_memory_type_t mem_type, bool async)
225225
{
226226
void *ptr;
227227

@@ -238,7 +238,18 @@ void *mem_buffer::allocate(size_t size, ucs_memory_type_t mem_type)
238238
return ptr;
239239
#if HAVE_CUDA
240240
case UCS_MEMORY_TYPE_CUDA:
241-
CUDA_CALL(cudaMalloc(&ptr, size), ": size=" << size);
241+
if (async) {
242+
#if CUDART_VERSION >= 11020
243+
CUDA_CALL(cudaMallocAsync(&ptr, size, 0), ": size=" << size);
244+
cudaStreamSynchronize(0);
245+
#else
246+
UCS_TEST_ABORT("asynchronous allocation for " +
247+
std::string(ucs_memory_type_names[mem_type]) +
248+
" memory type is not supported");
249+
#endif
250+
} else {
251+
CUDA_CALL(cudaMalloc(&ptr, size), ": size=" << size);
252+
}
242253
return ptr;
243254
case UCS_MEMORY_TYPE_CUDA_MANAGED:
244255
CUDA_CALL(cudaMallocManaged(&ptr, size), ": size=" << size);
@@ -258,7 +269,7 @@ void *mem_buffer::allocate(size_t size, ucs_memory_type_t mem_type)
258269
}
259270
}
260271

261-
void mem_buffer::release(void *ptr, ucs_memory_type_t mem_type)
272+
void mem_buffer::release(void *ptr, ucs_memory_type_t mem_type, bool async)
262273
{
263274
try {
264275
switch (mem_type) {
@@ -267,6 +278,19 @@ void mem_buffer::release(void *ptr, ucs_memory_type_t mem_type)
267278
break;
268279
#if HAVE_CUDA
269280
case UCS_MEMORY_TYPE_CUDA:
281+
if (async) {
282+
#if CUDART_VERSION >= 11020
283+
cudaStreamSynchronize(0);
284+
CUDA_CALL(cudaFreeAsync(ptr, 0), ": ptr=" << ptr);
285+
#else
286+
UCS_TEST_ABORT("asynchronous release for " +
287+
std::string(ucs_memory_type_names[mem_type]) +
288+
" memory type is not supported");
289+
#endif
290+
} else {
291+
CUDA_CALL(cudaFree(ptr), ": ptr=" << ptr);
292+
}
293+
break;
270294
case UCS_MEMORY_TYPE_CUDA_MANAGED:
271295
CUDA_CALL(cudaFree(ptr), ": ptr=" << ptr);
272296
break;
@@ -515,6 +539,15 @@ std::string mem_buffer::mem_type_name(ucs_memory_type_t mem_type)
515539
return ucs_memory_type_names[mem_type];
516540
}
517541

542+
bool mem_buffer::is_async_supported(ucs_memory_type_t mem_type)
543+
{
544+
#if CUDART_VERSION >= 11020
545+
return mem_type == UCS_MEMORY_TYPE_CUDA;
546+
#else
547+
return false;
548+
#endif
549+
}
550+
518551
mem_buffer::mem_buffer(size_t size, ucs_memory_type_t mem_type) :
519552
m_mem_type(mem_type), m_ptr(allocate(size, mem_type)), m_size(size) {
520553
}

test/gtest/common/mem_buffer.h

+9-3
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,11 @@ class mem_buffer {
2828
static bool is_mem_type_supported(ucs_memory_type_t mem_type);
2929

3030
/* allocate buffer of a given memory type */
31-
static void *allocate(size_t size, ucs_memory_type_t mem_type);
32-
31+
static void *
32+
allocate(size_t size, ucs_memory_type_t mem_type, bool async = false);
3333
/* release buffer of a given memory type */
34-
static void release(void *ptr, ucs_memory_type_t mem_type);
34+
static void
35+
release(void *ptr, ucs_memory_type_t mem_type, bool async = false);
3536

3637
/* fill pattern in a host-accessible buffer */
3738
static void pattern_fill(void *buffer, size_t length, uint64_t seed);
@@ -103,6 +104,11 @@ class mem_buffer {
103104
return m_bar1_free_size;
104105
}
105106

107+
/**
108+
* Check whether asynchronous operations are supported for the memory type
109+
*/
110+
static bool is_async_supported(ucs_memory_type_t mem_type);
111+
106112
mem_buffer(size_t size, ucs_memory_type_t mem_type);
107113
mem_buffer(size_t size, ucs_memory_type_t mem_type, uint64_t seed);
108114
virtual ~mem_buffer();

test/gtest/ucp/test_ucp_am.cc

+76-16
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,16 @@ class test_ucp_am_nbx : public test_ucp_am_base {
366366
return UCS_MEMORY_TYPE_HOST;
367367
}
368368

369+
virtual bool tx_memtype_async() const
370+
{
371+
return false;
372+
}
373+
374+
virtual bool rx_memtype_async() const
375+
{
376+
return false;
377+
}
378+
369379
void reset_counters()
370380
{
371381
m_send_counter = 0;
@@ -495,8 +505,9 @@ class test_ucp_am_nbx : public test_ucp_am_base {
495505
unsigned flags = 0, unsigned data_cb_flags = 0,
496506
uint32_t op_attr_mask = 0)
497507
{
498-
mem_buffer sbuf(size, tx_memtype());
499-
sbuf.pattern_fill(SEED);
508+
auto sbuf = mem_buffer::allocate(size, tx_memtype(),
509+
tx_memtype_async());
510+
mem_buffer::pattern_fill(sbuf, size, SEED, tx_memtype());
500511
m_hdr.resize(header_size);
501512
ucs::fill_random(m_hdr);
502513
reset_counters();
@@ -505,10 +516,10 @@ class test_ucp_am_nbx : public test_ucp_am_base {
505516
set_am_data_handler(receiver(), TEST_AM_NBX_ID, am_data_cb, this,
506517
data_cb_flags);
507518

508-
ucp::data_type_desc_t sdt_desc(m_dt, sbuf.ptr(), size);
519+
ucp::data_type_desc_t sdt_desc(m_dt, sbuf, size);
509520

510521
if (prereg()) {
511-
memh = sender().mem_map(sbuf.ptr(), size);
522+
memh = sender().mem_map(sbuf, size);
512523
}
513524

514525
ucs_status_ptr_t sptr = send_am(sdt_desc, get_send_flag() | flags,
@@ -522,6 +533,7 @@ class test_ucp_am_nbx : public test_ucp_am_base {
522533
sender().mem_unmap(memh);
523534
}
524535

536+
mem_buffer::release(sbuf, tx_memtype(), tx_memtype_async());
525537
EXPECT_EQ(m_recv_counter, m_send_counter);
526538
}
527539

@@ -562,7 +574,8 @@ class test_ucp_am_nbx : public test_ucp_am_base {
562574
{
563575
ucs_status_t status;
564576

565-
m_rx_buf = mem_buffer::allocate(length, rx_memtype());
577+
m_rx_buf = mem_buffer::allocate(length, rx_memtype(),
578+
rx_memtype_async());
566579
mem_buffer::pattern_fill(m_rx_buf, length, 0ul, rx_memtype());
567580

568581
m_rx_dt_desc.make(m_rx_dt, m_rx_buf, length);
@@ -638,7 +651,7 @@ class test_ucp_am_nbx : public test_ucp_am_base {
638651
if (m_rx_memh != NULL) {
639652
receiver().mem_unmap(m_rx_memh);
640653
}
641-
mem_buffer::release(m_rx_buf, rx_memtype());
654+
mem_buffer::release(m_rx_buf, rx_memtype(), rx_memtype_async());
642655
}
643656

644657
static ucs_status_t am_data_cb(void *arg, const void *header,
@@ -1358,10 +1371,7 @@ class test_ucp_am_nbx_eager_memtype : public test_ucp_am_nbx_prereg {
13581371
private:
13591372
static void base_test_generator(variant_vec_t &variants)
13601373
{
1361-
// 1. Do not instantiate test case if no GPU memtypes supported.
1362-
// 2. Do not exclude host memory type, because this generator is used by
1363-
// test_ucp_am_nbx_rndv_memtype class to generate combinations like
1364-
// host<->cuda, cuda-managed<->host, etc.
1374+
// Do not instantiate test case if no GPU memtypes supported.
13651375
if (!mem_buffer::is_gpu_supported()) {
13661376
return;
13671377
}
@@ -1890,10 +1900,7 @@ class test_ucp_am_nbx_rndv_memtype : public test_ucp_am_nbx_rndv {
18901900
public:
18911901
static void get_test_variants(variant_vec_t &variants)
18921902
{
1893-
// Test will not be instantiated if no GPU memtypes supported, because
1894-
// of the check for supported memory types in
1895-
// test_ucp_am_nbx_eager_memtype::get_test_variants
1896-
return test_ucp_am_nbx_eager_memtype::get_test_variants(variants);
1903+
add_variant_memtypes(variants, base_test_generator);
18971904
}
18981905

18991906
void init() override
@@ -1902,20 +1909,73 @@ class test_ucp_am_nbx_rndv_memtype : public test_ucp_am_nbx_rndv {
19021909
}
19031910

19041911
private:
1912+
static void base_test_generator(variant_vec_t &variants)
1913+
{
1914+
// Do not instantiate test case if no GPU memtypes supported.
1915+
if (!mem_buffer::is_gpu_supported()) {
1916+
return;
1917+
}
1918+
1919+
add_variant_memtypes(variants,
1920+
test_ucp_am_nbx_prereg::get_test_variants);
1921+
}
1922+
1923+
static void
1924+
add_variant_memtypes(variant_vec_t &variants, get_variants_func_t generator)
1925+
{
1926+
ucp_test::add_variant_memtypes(variants, generator);
1927+
1928+
if (mem_buffer::is_mem_type_supported(UCS_MEMORY_TYPE_CUDA) &&
1929+
mem_buffer::is_async_supported(UCS_MEMORY_TYPE_CUDA)) {
1930+
add_variant_values(variants, generator, MEMORY_TYPE_CUDA_ASYNC,
1931+
"cuda-async");
1932+
}
1933+
}
1934+
19051935
unsigned get_send_flag() const override
19061936
{
19071937
return test_ucp_am_nbx_rndv::get_send_flag() | UCP_AM_SEND_FLAG_RNDV;
19081938
}
19091939

19101940
ucs_memory_type_t tx_memtype() const override
19111941
{
1912-
return static_cast<ucs_memory_type_t>(get_variant_value(2));
1942+
return variant_index_to_mem_type(2);
19131943
}
19141944

19151945
ucs_memory_type_t rx_memtype() const override
19161946
{
1917-
return static_cast<ucs_memory_type_t>(get_variant_value(3));
1947+
return variant_index_to_mem_type(3);
19181948
}
1949+
1950+
bool tx_memtype_async() const override
1951+
{
1952+
return get_variant_value(2) == MEMORY_TYPE_CUDA_ASYNC;
1953+
}
1954+
1955+
bool rx_memtype_async() const override
1956+
{
1957+
return get_variant_value(3) == MEMORY_TYPE_CUDA_ASYNC;
1958+
}
1959+
1960+
ucs_memory_type_t variant_index_to_mem_type(unsigned index) const
1961+
{
1962+
auto variant_value = get_variant_value(index);
1963+
switch (variant_value) {
1964+
case UCS_MEMORY_TYPE_HOST:
1965+
case UCS_MEMORY_TYPE_CUDA:
1966+
case UCS_MEMORY_TYPE_CUDA_MANAGED:
1967+
case UCS_MEMORY_TYPE_ROCM:
1968+
case UCS_MEMORY_TYPE_ROCM_MANAGED:
1969+
return static_cast<ucs_memory_type_t>(variant_value);
1970+
case MEMORY_TYPE_CUDA_ASYNC:
1971+
return UCS_MEMORY_TYPE_CUDA;
1972+
default:
1973+
UCS_TEST_ABORT("invalid memory type: " << variant_value);
1974+
return UCS_MEMORY_TYPE_HOST;
1975+
}
1976+
}
1977+
1978+
static const int MEMORY_TYPE_CUDA_ASYNC = UCS_MEMORY_TYPE_LAST + 1;
19191979
};
19201980

19211981
UCS_TEST_P(test_ucp_am_nbx_rndv_memtype, rndv)

0 commit comments

Comments
 (0)