11
11
#include <string.h>
12
12
13
13
#include "iree/base/api.h"
14
+ #include "iree/base/internal/synchronization.h"
14
15
#include "iree/base/tracing.h"
15
16
16
17
typedef struct iree_hal_hip_buffer_t {
@@ -19,6 +20,9 @@ typedef struct iree_hal_hip_buffer_t {
19
20
void * host_ptr ;
20
21
hipDeviceptr_t device_ptr ;
21
22
iree_hal_buffer_release_callback_t release_callback ;
23
+ iree_slim_mutex_t device_ptr_lock ;
24
+ iree_notification_t device_ptr_notification ;
25
+ bool empty ;
22
26
} iree_hal_hip_buffer_t ;
23
27
24
28
static const iree_hal_buffer_vtable_t iree_hal_hip_buffer_vtable ;
@@ -65,13 +69,36 @@ iree_status_t iree_hal_hip_buffer_wrap(
65
69
buffer -> host_ptr = host_ptr ;
66
70
buffer -> device_ptr = device_ptr ;
67
71
buffer -> release_callback = release_callback ;
72
+ buffer -> empty = false;
73
+ iree_slim_mutex_initialize (& buffer -> device_ptr_lock );
74
+ iree_notification_initialize (& buffer -> device_ptr_notification );
68
75
* out_buffer = & buffer -> base ;
69
76
}
70
77
71
78
IREE_TRACE_ZONE_END (z0 );
72
79
return status ;
73
80
}
74
81
82
+ void iree_hal_hip_buffer_set_device_pointer (iree_hal_buffer_t * base_buffer ,
83
+ hipDeviceptr_t pointer ) {
84
+ iree_hal_hip_buffer_t * buffer = iree_hal_hip_buffer_cast (base_buffer );
85
+ IREE_ASSERT (buffer -> device_ptr == NULL ,
86
+ "Cannot set a device_ptr to a buffer that already has one" );
87
+ iree_slim_mutex_lock (& buffer -> device_ptr_lock );
88
+ buffer -> device_ptr = pointer ;
89
+ iree_slim_mutex_unlock (& buffer -> device_ptr_lock );
90
+ iree_notification_post (& buffer -> device_ptr_notification , IREE_ALL_WAITERS );
91
+ }
92
+
93
+ void iree_hal_hip_buffer_set_allocation_empty (iree_hal_buffer_t * base_buffer ) {
94
+ iree_hal_hip_buffer_t * buffer = iree_hal_hip_buffer_cast (base_buffer );
95
+ iree_slim_mutex_lock (& buffer -> device_ptr_lock );
96
+ buffer -> empty = true;
97
+ buffer -> device_ptr = NULL ;
98
+ iree_slim_mutex_unlock (& buffer -> device_ptr_lock );
99
+ iree_notification_post (& buffer -> device_ptr_notification , IREE_ALL_WAITERS );
100
+ }
101
+
75
102
static void iree_hal_hip_buffer_destroy (iree_hal_buffer_t * base_buffer ) {
76
103
iree_hal_hip_buffer_t * buffer = iree_hal_hip_buffer_cast (base_buffer );
77
104
iree_allocator_t host_allocator = base_buffer -> host_allocator ;
@@ -80,6 +107,8 @@ static void iree_hal_hip_buffer_destroy(iree_hal_buffer_t* base_buffer) {
80
107
buffer -> release_callback .fn (buffer -> release_callback .user_data ,
81
108
base_buffer );
82
109
}
110
+ iree_slim_mutex_deinitialize (& buffer -> device_ptr_lock );
111
+ iree_notification_deinitialize (& buffer -> device_ptr_notification );
83
112
iree_allocator_free (host_allocator , buffer );
84
113
IREE_TRACE_ZONE_END (z0 );
85
114
}
@@ -143,10 +172,20 @@ iree_hal_hip_buffer_type_t iree_hal_hip_buffer_type(
143
172
return buffer -> type ;
144
173
}
145
174
175
+ static bool iree_hal_hip_buffer_has_device_ptr (void * arg ) {
176
+ iree_hal_hip_buffer_t * buffer = (iree_hal_hip_buffer_t * )arg ;
177
+ iree_slim_mutex_lock (& buffer -> device_ptr_lock );
178
+ bool has_ptr_or_error = buffer -> device_ptr || buffer -> empty ;
179
+ iree_slim_mutex_unlock (& buffer -> device_ptr_lock );
180
+ return has_ptr_or_error ;
181
+ }
182
+
146
183
hipDeviceptr_t iree_hal_hip_buffer_device_pointer (
147
- const iree_hal_buffer_t * base_buffer ) {
148
- const iree_hal_hip_buffer_t * buffer =
149
- iree_hal_hip_buffer_const_cast (base_buffer );
184
+ iree_hal_buffer_t * base_buffer ) {
185
+ iree_hal_hip_buffer_t * buffer = iree_hal_hip_buffer_cast (base_buffer );
186
+ iree_notification_await (& buffer -> device_ptr_notification ,
187
+ iree_hal_hip_buffer_has_device_ptr , buffer ,
188
+ iree_infinite_timeout ());
150
189
return buffer -> device_ptr ;
151
190
}
152
191
0 commit comments