@@ -366,6 +366,16 @@ class test_ucp_am_nbx : public test_ucp_am_base {
366
366
return UCS_MEMORY_TYPE_HOST;
367
367
}
368
368
369
// Async flag for the send buffer. The default is false. The value is passed as
// the third argument to mem_buffer::allocate() and mem_buffer::release() for
// the buffer whose memory type is tx_memtype(). Subclasses may override it
// (test_ucp_am_nbx_rndv_memtype ties it to its MEMORY_TYPE_CUDA_ASYNC variant).
+ virtual bool tx_memtype_async () const
370
+ {
371
+ return false ;
372
+ }
373
+
374
// Async flag for the receive buffer. The default is false. The value is passed
// as the third argument to mem_buffer::allocate() and mem_buffer::release()
// for m_rx_buf, whose memory type is rx_memtype(). Subclasses may override it
// (test_ucp_am_nbx_rndv_memtype ties it to its MEMORY_TYPE_CUDA_ASYNC variant).
+ virtual bool rx_memtype_async () const
375
+ {
376
+ return false ;
377
+ }
378
+
369
379
void reset_counters ()
370
380
{
371
381
m_send_counter = 0 ;
@@ -495,8 +505,9 @@ class test_ucp_am_nbx : public test_ucp_am_base {
495
505
unsigned flags = 0 , unsigned data_cb_flags = 0 ,
496
506
uint32_t op_attr_mask = 0 )
497
507
{
498
- mem_buffer sbuf (size, tx_memtype ());
499
- sbuf.pattern_fill (SEED);
508
+ auto sbuf = mem_buffer::allocate (size, tx_memtype (),
509
+ tx_memtype_async ());
510
+ mem_buffer::pattern_fill (sbuf, size, SEED, tx_memtype ());
500
511
m_hdr.resize (header_size);
501
512
ucs::fill_random (m_hdr);
502
513
reset_counters ();
@@ -505,10 +516,10 @@ class test_ucp_am_nbx : public test_ucp_am_base {
505
516
set_am_data_handler (receiver (), TEST_AM_NBX_ID, am_data_cb, this ,
506
517
data_cb_flags);
507
518
508
- ucp::data_type_desc_t sdt_desc (m_dt, sbuf. ptr () , size);
519
+ ucp::data_type_desc_t sdt_desc (m_dt, sbuf, size);
509
520
510
521
if (prereg ()) {
511
- memh = sender ().mem_map (sbuf. ptr () , size);
522
+ memh = sender ().mem_map (sbuf, size);
512
523
}
513
524
514
525
ucs_status_ptr_t sptr = send_am (sdt_desc, get_send_flag () | flags,
@@ -522,6 +533,7 @@ class test_ucp_am_nbx : public test_ucp_am_base {
522
533
sender ().mem_unmap (memh);
523
534
}
524
535
536
+ mem_buffer::release (sbuf, tx_memtype (), tx_memtype_async ());
525
537
EXPECT_EQ (m_recv_counter, m_send_counter);
526
538
}
527
539
@@ -562,7 +574,8 @@ class test_ucp_am_nbx : public test_ucp_am_base {
562
574
{
563
575
ucs_status_t status;
564
576
565
- m_rx_buf = mem_buffer::allocate (length, rx_memtype ());
577
+ m_rx_buf = mem_buffer::allocate (length, rx_memtype (),
578
+ rx_memtype_async ());
566
579
mem_buffer::pattern_fill (m_rx_buf, length, 0ul , rx_memtype ());
567
580
568
581
m_rx_dt_desc.make (m_rx_dt, m_rx_buf, length);
@@ -638,7 +651,7 @@ class test_ucp_am_nbx : public test_ucp_am_base {
638
651
if (m_rx_memh != NULL ) {
639
652
receiver ().mem_unmap (m_rx_memh);
640
653
}
641
- mem_buffer::release (m_rx_buf, rx_memtype ());
654
+ mem_buffer::release (m_rx_buf, rx_memtype (), rx_memtype_async () );
642
655
}
643
656
644
657
static ucs_status_t am_data_cb (void *arg, const void *header,
@@ -1358,10 +1371,7 @@ class test_ucp_am_nbx_eager_memtype : public test_ucp_am_nbx_prereg {
1358
1371
private:
1359
1372
static void base_test_generator (variant_vec_t &variants)
1360
1373
{
1361
- // 1. Do not instantiate test case if no GPU memtypes supported.
1362
- // 2. Do not exclude host memory type, because this generator is used by
1363
- // test_ucp_am_nbx_rndv_memtype class to generate combinations like
1364
- // host<->cuda, cuda-managed<->host, etc.
1374
+ // Do not instantiate test case if no GPU memtypes supported.
1365
1375
if (!mem_buffer::is_gpu_supported ()) {
1366
1376
return ;
1367
1377
}
@@ -1890,10 +1900,7 @@ class test_ucp_am_nbx_rndv_memtype : public test_ucp_am_nbx_rndv {
1890
1900
public:
1891
1901
// Populates the variant list for this fixture. The added line applies this
// class's add_variant_memtypes() on top of base_test_generator. The removed
// lines delegated to test_ucp_am_nbx_eager_memtype::get_test_variants instead.
static void get_test_variants (variant_vec_t &variants)
1892
1902
{
1893
- // Test will not be instantiated if no GPU memtypes supported, because
1894
- // of the check for supported memory types in
1895
- // test_ucp_am_nbx_eager_memtype::get_test_variants
1896
- return test_ucp_am_nbx_eager_memtype::get_test_variants (variants);
1903
+ add_variant_memtypes (variants, base_test_generator);
1897
1904
}
1898
1905
1899
1906
void init () override
@@ -1902,20 +1909,73 @@ class test_ucp_am_nbx_rndv_memtype : public test_ucp_am_nbx_rndv {
1902
1909
}
1903
1910
1904
1911
private:
1912
// Inner variant generator used by get_test_variants(). Leaves `variants`
// untouched when mem_buffer::is_gpu_supported() is false. Otherwise it applies
// add_variant_memtypes() on top of test_ucp_am_nbx_prereg::get_test_variants.
+ static void base_test_generator (variant_vec_t &variants)
1913
+ {
1914
+ // Do not instantiate test case if no GPU memtypes supported.
1915
+ if (!mem_buffer::is_gpu_supported ()) {
1916
+ return ;
1917
+ }
1918
+
1919
+ add_variant_memtypes (variants,
1920
+ test_ucp_am_nbx_prereg::get_test_variants);
1921
+ }
1922
+
1923
// Adds the memory-type variants produced by ucp_test::add_variant_memtypes()
// for `generator`. It then adds one extra variant valued
// MEMORY_TYPE_CUDA_ASYNC, but only when mem_buffer reports
// UCS_MEMORY_TYPE_CUDA as both supported and async-capable.
+ static void
1924
+ add_variant_memtypes (variant_vec_t &variants, get_variants_func_t generator)
1925
+ {
1926
+ ucp_test::add_variant_memtypes (variants, generator);
1927
+
1928
+ if (mem_buffer::is_mem_type_supported (UCS_MEMORY_TYPE_CUDA) &&
1929
+ mem_buffer::is_async_supported (UCS_MEMORY_TYPE_CUDA)) {
1930
+ add_variant_values (variants, generator, MEMORY_TYPE_CUDA_ASYNC,
1931
+ " cuda-async" );
1932
+ }
1933
+ }
1934
+
1905
1935
// Returns the send flags of test_ucp_am_nbx_rndv with UCP_AM_SEND_FLAG_RNDV
// OR-ed in.
unsigned get_send_flag () const override
1906
1936
{
1907
1937
return test_ucp_am_nbx_rndv::get_send_flag () | UCP_AM_SEND_FLAG_RNDV;
1908
1938
}
1909
1939
1910
1940
// Send-side memory type, derived from variant value 2. The added line goes
// through variant_index_to_mem_type(), which maps MEMORY_TYPE_CUDA_ASYNC to
// UCS_MEMORY_TYPE_CUDA. The removed line cast the raw variant value directly.
ucs_memory_type_t tx_memtype () const override
1911
1941
{
1912
- return static_cast < ucs_memory_type_t >( get_variant_value ( 2 ) );
1942
+ return variant_index_to_mem_type ( 2 );
1913
1943
}
1914
1944
1915
1945
// Receive-side memory type, derived from variant value 3. The added line goes
// through variant_index_to_mem_type(), which maps MEMORY_TYPE_CUDA_ASYNC to
// UCS_MEMORY_TYPE_CUDA. The removed line cast the raw variant value directly.
ucs_memory_type_t rx_memtype () const override
1915
1945
{
1916
- return static_cast < ucs_memory_type_t >( get_variant_value ( 3 ) );
1946
+ return variant_index_to_mem_type ( 3 );
1917
1947
}
1949
+
1950
// True when variant value 2 (the one tx_memtype() is derived from) equals
// MEMORY_TYPE_CUDA_ASYNC.
+ bool tx_memtype_async () const override
1951
+ {
1952
+ return get_variant_value (2 ) == MEMORY_TYPE_CUDA_ASYNC;
1953
+ }
1954
+
1955
// True when variant value 3 (the one rx_memtype() is derived from) equals
// MEMORY_TYPE_CUDA_ASYNC.
+ bool rx_memtype_async () const override
1956
+ {
1957
+ return get_variant_value (3 ) == MEMORY_TYPE_CUDA_ASYNC;
1958
+ }
1959
+
1960
// Translates the variant value stored at `index` into a ucs_memory_type_t.
// The five listed UCS memory types are returned unchanged;
// MEMORY_TYPE_CUDA_ASYNC is reported as UCS_MEMORY_TYPE_CUDA. Any other value
// goes through UCS_TEST_ABORT.
+ ucs_memory_type_t variant_index_to_mem_type (unsigned index) const
1961
+ {
1962
+ auto variant_value = get_variant_value (index );
1963
+ switch (variant_value) {
1964
+ case UCS_MEMORY_TYPE_HOST:
1965
+ case UCS_MEMORY_TYPE_CUDA:
1966
+ case UCS_MEMORY_TYPE_CUDA_MANAGED:
1967
+ case UCS_MEMORY_TYPE_ROCM:
1968
+ case UCS_MEMORY_TYPE_ROCM_MANAGED:
1969
+ return static_cast <ucs_memory_type_t >(variant_value);
1970
// The async pseudo-type is not a ucs_memory_type_t value (it is defined as
// UCS_MEMORY_TYPE_LAST + 1), so it is reported as plain CUDA memory.
+ case MEMORY_TYPE_CUDA_ASYNC:
1971
+ return UCS_MEMORY_TYPE_CUDA;
1972
+ default :
1973
+ UCS_TEST_ABORT (" invalid memory type: " << variant_value);
1974
// NOTE(review): presumably unreachable if UCS_TEST_ABORT never returns, and
// kept only so every path returns a value. Confirm against the macro.
+ return UCS_MEMORY_TYPE_HOST;
1975
+ }
1976
+ }
1977
+
1978
// Extra variant value, one past UCS_MEMORY_TYPE_LAST. add_variant_memtypes()
// registers it, the *_memtype_async() overrides test for it, and
// variant_index_to_mem_type() maps it to UCS_MEMORY_TYPE_CUDA.
+ static const int MEMORY_TYPE_CUDA_ASYNC = UCS_MEMORY_TYPE_LAST + 1 ;
1919
1979
};
1920
1980
1921
1981
UCS_TEST_P (test_ucp_am_nbx_rndv_memtype, rndv)
0 commit comments