Skip to content

Commit a2a9c36

Browse files
Add release call into wrapped_callback.
1 parent 0db9154 commit a2a9c36

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

python/pyarrow/tensorflow/plasma_op.cc

+10-2
Original file line numberDiff line numberDiff line change
@@ -297,11 +297,19 @@ class PlasmaToTensorOp : public tf::AsyncOpKernel {
297297
OP_REQUIRES_OK_ASYNC(context, context->allocate_output(0, shape, &output_tensor),
298298
done);
299299

300+
auto wrapped_callback = [this, context, done, object_id]() {
301+
{
302+
tf::mutex_lock lock(mu_);
303+
ARROW_CHECK_OK(client_.Release(object_id));
304+
}
305+
done();
306+
};
307+
300308
if (std::is_same<Device, CPUDevice>::value) {
301309
std::memcpy(
302310
reinterpret_cast<void*>(const_cast<char*>(output_tensor->tensor_data().data())),
303311
plasma_data, size_in_bytes);
304-
done();
312+
wrapped_callback();
305313
} else {
306314
#ifdef GOOGLE_CUDA
307315
auto orig_stream = context->op_device_context()->stream();
@@ -340,7 +348,7 @@ class PlasmaToTensorOp : public tf::AsyncOpKernel {
340348
CHECK(orig_stream->ThenWaitFor(h2d_stream).ok());
341349

342350
context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
343-
h2d_stream, std::move(done));
351+
h2d_stream, std::move(wrapped_callback));
344352
#endif
345353
}
346354
}

0 commit comments

Comments
 (0)