Skip to content

Commit f3ec4a9

Browse files
authored
[SYCL][CUDA] Fix context clearing in PiCuda tests (#4483)
`cuCtxSetCurrent(nullptr)` will only discard the top of the context stack so the current context may still not be `nullptr` after this. To fix this, this patch introduces a small utility function to pop the entire context stack when we're trying to reset it in the tests.
1 parent 49e1e74 commit f3ec4a9

File tree

4 files changed

+27
-4
lines changed

4 files changed

+27
-4
lines changed

sycl/unittests/pi/cuda/CudaUtils.hpp

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
// See https://llvm.org/LICENSE.txt for license information.
3+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
#pragma once
6+
7+
#include <cuda.h>
8+
9+
namespace pi {
10+
11+
// utility function to clear the CUDA context stack
12+
inline void clearCudaContext() {
13+
CUcontext ctxt = nullptr;
14+
do {
15+
cuCtxSetCurrent(nullptr);
16+
cuCtxGetCurrent(&ctxt);
17+
} while (ctxt != nullptr);
18+
}
19+
20+
} // namespace pi

sycl/unittests/pi/cuda/test_commands.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <cuda.h>
1212

13+
#include "CudaUtils.hpp"
1314
#include "TestGetPlugin.hpp"
1415
#include <CL/sycl.hpp>
1516
#include <CL/sycl/detail/pi.hpp>
@@ -34,7 +35,7 @@ struct CudaCommandsTest : public ::testing::Test {
3435
GTEST_SKIP();
3536
}
3637

37-
cuCtxSetCurrent(nullptr);
38+
pi::clearCudaContext();
3839
pi_uint32 numPlatforms = 0;
3940
ASSERT_EQ(plugin->getBackend(), backend::cuda);
4041

sycl/unittests/pi/cuda/test_contexts.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include <cuda.h>
1616

17+
#include "CudaUtils.hpp"
1718
#include "TestGetPlugin.hpp"
1819
#include <CL/sycl.hpp>
1920
#include <CL/sycl/detail/pi.hpp>
@@ -63,7 +64,7 @@ struct CudaContextsTest : public ::testing::Test {
6364

6465
TEST_F(CudaContextsTest, ContextLifetime) {
6566
// start with no active context
66-
cuCtxSetCurrent(nullptr);
67+
pi::clearCudaContext();
6768

6869
// create a context
6970
pi_context context;
@@ -149,7 +150,7 @@ TEST_F(CudaContextsTest, ContextLifetimeExisting) {
149150
// still able to work correctly in that thread.
150151
TEST_F(CudaContextsTest, ContextThread) {
151152
// start with no active context
152-
cuCtxSetCurrent(nullptr);
153+
pi::clearCudaContext();
153154

154155
// create two PI contexts
155156
pi_context context1;

sycl/unittests/pi/cuda/test_mem_obj.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <cuda.h>
1212

13+
#include "CudaUtils.hpp"
1314
#include "TestGetPlugin.hpp"
1415
#include <CL/sycl.hpp>
1516
#include <CL/sycl/detail/cuda_definitions.hpp>
@@ -34,7 +35,7 @@ struct CudaTestMemObj : public ::testing::Test {
3435
GTEST_SKIP();
3536
}
3637

37-
cuCtxSetCurrent(nullptr);
38+
pi::clearCudaContext();
3839
pi_uint32 numPlatforms = 0;
3940
ASSERT_EQ(plugin->getBackend(), backend::cuda);
4041

0 commit comments

Comments
 (0)