
Support multi-stream allocation for CUDA place #37290

Merged (13 commits) on Nov 25, 2021

Conversation

@From00 From00 commented Nov 17, 2021

PR types

New features

PR changes

Others

Describe

This PR supports multi-stream alloc and free for the CUDA place.
A new StreamSafeCUDAAllocator is implemented, which supports safe and efficient CUDA memory allocation and GC. The core ideas are:

  1. an allocation is associated with a CUDA stream, i.e., the stream that first requests the allocation
  2. the allocation can only be re-allocated to its associated stream
  3. other streams that use the allocation asynchronously must be recorded proactively
  4. when an allocation is freed, CUDA events are created for the recorded streams, while records from the associated stream are ignored
  5. the actual free is delayed until all events from the other streams have completed, to ensure the correctness of asynchronous CUDA kernel execution
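The five steps above can be sketched in plain C++, with ints standing in for CUDA streams and a query callback standing in for cudaEventQuery; all names here are illustrative stand-ins, not the PR's actual classes:

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <set>
#include <vector>

// Hypothetical sketch of the stream-safe allocator idea.
struct Allocation {
  void* ptr;
  int owning_stream;               // (1) stream that first requested the memory
  std::set<int> recorded_streams;  // (3) other streams that used it asynchronously
};

class StreamSafeAllocatorSketch {
 public:
  Allocation* Alloc(std::size_t size, int stream) {
    // (2) memory is handed out per-stream; real code would only reuse a
    // cached block for its owning stream.
    return new Allocation{::operator new(size), stream, {}};
  }

  void RecordStream(Allocation* a, int stream) {
    if (stream != a->owning_stream) a->recorded_streams.insert(stream);
  }

  // (4)+(5) free is deferred until every recorded stream's "event" completes;
  // event_done stands in for cudaEventQuery.
  void Free(Allocation* a, const std::function<bool(int)>& event_done) {
    pending_.push_back(a);
    ProcessEvents(event_done);
  }

  void ProcessEvents(const std::function<bool(int)>& event_done) {
    std::vector<Allocation*> still_pending;
    for (Allocation* a : pending_) {
      bool all_done = true;
      for (int s : a->recorded_streams) {
        if (!event_done(s)) { all_done = false; break; }
      }
      if (all_done) {
        ::operator delete(a->ptr);
        delete a;
      } else {
        still_pending.push_back(a);
      }
    }
    pending_ = std::move(still_pending);
  }

  std::size_t pending_count() const { return pending_.size(); }

 private:
  std::vector<Allocation*> pending_;
};
```

An allocation freed while another recorded stream is still "busy" stays pending and is only destroyed on a later ProcessEvents pass, mirroring the delayed free described in step 5.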

Interface changes

  1. The old interfaces "AllocShared", "Alloc", and "Release" implicitly use the NULL stream.
  2. A new set of interfaces is exposed that allows a stream parameter to be passed in.
  3. A "RecordStream" interface is exposed. When memory is reused from another CUDA stream, "RecordStream" should be called on the host side after the kernel launch.
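The interface split above might look roughly like the following self-contained sketch; the types and signatures are placeholders for illustration, not Paddle's real API:

```cpp
#include <cassert>
#include <cstddef>
#include <memory>

// Stand-in for the opaque CUDA stream handle (cudaStream_t is a pointer type).
using gpuStream_t = void*;

// Placeholder allocation record, not Paddle's Allocation class.
struct AllocationStub {
  std::size_t size;
  gpuStream_t stream;
};

// Legacy interface: the NULL (default) stream is implied.
std::shared_ptr<AllocationStub> AllocShared(std::size_t size) {
  return std::make_shared<AllocationStub>(AllocationStub{size, nullptr});
}

// New stream-aware overload: the caller passes the stream explicitly.
std::shared_ptr<AllocationStub> AllocShared(std::size_t size,
                                            gpuStream_t stream) {
  return std::make_shared<AllocationStub>(AllocationStub{size, stream});
}
```

A reviewer below suggests the alternative of a single signature with `stream` as a defaulted trailing parameter, which collapses the two overloads into one.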

Notes

  1. Only the auto_growth allocator strategy is supported for now
  2. Set "FLAGS_use_stream_safe_cuda_allocator=true" to enable it

@wanghuancoder wanghuancoder (Contributor) left a comment:

Regarding the retry mechanism: when GPU memory is freed on the current stream, there is no need to notify the other streams.

From00 commented Nov 18, 2021

Regarding the retry mechanism: when GPU memory is freed on the current stream, there is no need to notify the other streams.

The newly pushed commit has made this change. When GPU memory is freed, only retries on the current stream are notified in the RetryAllocator; other streams are not notified. In addition, if the RetryAllocator's retries time out and fail, the upper-layer AllocatorFacade will try to release memory from all streams and then attempt the allocation again.

Comment on lines 242 to 245
if (FLAGS_use_stream_safe_cuda_allocator && platform::is_gpu_place(place) &&
size > 0) {
return GetCUDAAllocator(BOOST_GET_CONST(platform::CUDAPlace, place),
default_stream_);
Contributor comment:

The logic here gives FLAGS_use_stream_safe_cuda_allocator the highest priority, above both size == 0 and FLAGS_use_system_allocator. I think it would be more reasonable for FLAGS_use_stream_safe_cuda_allocator to have lower priority than the other two.

Author reply:

Thanks, fixed. A FLAGS_use_system_allocator == false check was added after size > 0, so the multi-stream logic is only taken when FLAGS_use_system_allocator == false:

FLAGS_use_system_allocator == false) {

Comment on lines 29 to 36
void StreamSafeCUDAAllocation::RecordStream(gpuStream_t stream) {
VLOG(8) << "Record stream " << stream << " to " << ptr();
if (stream == owning_stream_) {
return;
}
std::lock_guard<std::mutex> lock(mutex_);
recorded_streams_->insert(stream);
}
Contributor comment:

It would be better to record the event right here. As it stands, we call CreateEventForAllRecordedStream to record events at FreeAllocation time. The earlier an event is recorded, the earlier the memory can be released.

Author reply:

Creating the event directly in RecordStream would reduce the delay before the memory is released, but RecordStream may be called multiple times for the same stream, which would create multiple events for it. Creating the events only later, when a new alloc triggers the free processing, keeps the number of events down.
Since creating and querying events also carries a significant time cost, which implementation is better depends on the memory usage pattern of the upper layers. If they call RecordStream repeatedly and frequently, the current implementation is better; otherwise, creating the event directly in RecordStream is better. This is still undecided and needs to be measured on real models, so one approach was picked for now; if later measurements clearly show that creating events earlier is better, we will improve this.
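The trade-off described above can be illustrated with a toy comparison, where a counter stands in for CUDA event creation; both types below are hypothetical, not the PR's code:

```cpp
#include <cassert>
#include <set>

// Current PR approach: record streams into a set (deduplicating repeats)
// and create events only once, at free time.
struct DedupRecorder {
  std::set<int> streams;
  int events_created = 0;
  void Record(int s) { streams.insert(s); }
  void Free() { events_created += static_cast<int>(streams.size()); }
};

// Alternative discussed by the reviewer: create an event eagerly on every
// Record call, so the free can happen earlier but repeated records on the
// same stream each cost an event.
struct EagerRecorder {
  int events_created = 0;
  void Record(int /*s*/) { ++events_created; }
};
```

Recording the same stream three times costs one event under the deduplicating scheme but three under the eager one, which is exactly why the better choice depends on how often upper layers repeat RecordStream calls.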

bool StreamSafeCUDAAllocator::IsAllocThreadSafe() const { return true; }

Allocation* StreamSafeCUDAAllocator::AllocateImpl(size_t size) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
Contributor comment:

Why use recursive_mutex here? In earlier experiments we found that the allocator's alloc/free lock is contended very frequently. If lock contention puts a thread to sleep, it must be woken up once the other thread releases the lock, and that wake-up is very expensive. That is why we later switched to a spin lock here.

Author reply:

Thanks, changed to a spin lock.
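For reference, a minimal spin lock of the kind discussed here can be built on std::atomic_flag; this is an illustrative sketch, not Paddle's actual SpinLock implementation:

```cpp
#include <atomic>
#include <cassert>
#include <mutex>   // std::lock_guard
#include <thread>

// Contended threads busy-wait instead of being descheduled, avoiding the
// expensive wake-up of a blocking mutex on short critical sections.
class SpinLock {
 public:
  void lock() {
    while (flag_.test_and_set(std::memory_order_acquire)) {
      std::this_thread::yield();  // back off instead of burning the core
    }
  }
  void unlock() { flag_.clear(std::memory_order_release); }

 private:
  std::atomic_flag flag_ = ATOMIC_FLAG_INIT;
};
```

Because it exposes lock()/unlock(), it is a drop-in replacement for the mutex in std::lock_guard<SpinLock>, matching how the allocator code above guards its critical sections.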

AllocationPtr underlying_allocation = underlying_allocator_->Allocate(size);
StreamSafeCUDAAllocation* allocation = new StreamSafeCUDAAllocation(
std::move(underlying_allocation), default_stream_);
allocation_info_map_[allocation] = std::make_shared<AllocationInfo>();
Contributor comment:

I suggest storing the allocated allocations and the recorded allocations separately. Keep in mind that there are many allocations, and many of them are allocated and freed frequently, while only a very small fraction ever get recorded. If they are stored separately, ProcessEventsAndFree only needs to iterate over the recorded allocations; otherwise its performance will suffer.

Author reply:

Thanks, fixed. Nothing is inserted into the map at alloc time; an allocation's info is only added to the map when it is freed. This avoids repeatedly iterating over a large number of still-live allocations on every ProcessEventsAndFree.

Comment on lines +95 to +96
std::deque<gpuEvent_t>& outstanding_events =
outstanding_events_map_[allocation];
Contributor comment:

Could we first get dynamic_cast<StreamSafeCUDAAllocation*>(allocation)->GetRecordedStreams().get(), check whether the set is empty, and only if it is non-empty do outstanding_events_map_[allocation] = CreateEventForAllRecordedStream()?

Author reply:

Thanks, fixed. A check was added in FreeImpl: allocations whose recorded_streams is empty are deleted directly, and only non-empty ones go through the outstanding_events creation and map insertion logic in FreeStreamSafeCUDAAllocation:

void StreamSafeCUDAAllocator::FreeImpl(Allocation* allocation) {
  std::lock_guard<SpinLock> lock_guard(spin_lock_);
  if (dynamic_cast<StreamSafeCUDAAllocation*>(allocation)
          ->GetRecordedStreams()
          ->empty()) {
    delete allocation;
  } else {
    FreeStreamSafeCUDAAllocation(allocation);
  }
}

Performance improvement: directly delete allocation when the recorded_streams is empty in FreeImpl of StreamSafeCUDAAllocator
wanghuancoder
wanghuancoder previously approved these changes Nov 22, 2021
@wanghuancoder wanghuancoder (Contributor) left a comment:

LGTM

return allocation::AllocatorFacade::Instance().Release(place);
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
std::shared_ptr<Allocation> AllocShared(const platform::CUDAPlace& place,
Contributor comment:

You can make stream the last parameter and give it a default value of nullptr.

Author reply:

done, thx

Comment on lines 303 to 320

try {
return cuda_allocator->Allocate(size);
} catch (BadAlloc&) {
VLOG(9) << "Allocation failed when allocating " << size
<< " bytes for stream " << stream;
for (auto pair : cuda_allocators_[place]) {
pair.second->Release(place);
}
try {
return cuda_allocator->Allocate(size);
} catch (...) {
VLOG(9) << "Still allocation failed "
<< "after release memory from all streams";
throw;
}
} catch (...) {
throw;
Contributor comment:

This code can be moved into StreamSafeCUDAAllocator.

Author reply:

done, thx

@wanghuancoder wanghuancoder (Contributor) left a comment:

LGTM

@zhiqiu zhiqiu (Contributor) left a comment:

LGTM

@zhiqiu zhiqiu merged commit b9c464c into PaddlePaddle:develop Nov 25, 2021
@From00 From00 deleted the stream-safe-cuda-allocator branch December 5, 2021 03:17
Zjq9409 pushed a commit to Zjq9409/Paddle that referenced this pull request Dec 10, 2021
* Support multi-stream allocation for CUDA place

* Do not notify the retrying from other streams when free CUDA allocation

* Fix compile error for CPU

* Fix compile error for HIP

* Release memory for StreamSafeCUDAAllocaRetry in malloc_test

* Add FLAGS_use_stream_safe_cuda_allocator

* Fix CI error for 'set_tests_properties'

* Invalidate stream safe CUDA allocator for naive_best_fit and thread_local strategy

* Performance improvement: insert allocation pair to outstanding_events_map when free but not alloc; replace recursive_mutex with SpinLock

* FLAGS priority changes: FLAGS_use_system_allocator > FLAGS_use_stream_safe_cuda_allocator

* Performance improvement: directly delete allocation when the recorded_streams is empty in FreeImpl of StreamSafeCUDAAllocator

* Add UT for alloc interface

* Changes multi-stream interface; move retry code from AllocatorFacadePrivate to StreamSafeCUDAAllocator
4 participants