DDP on multinode [not yet working] #55

Merged: karpathy merged 4 commits into master on Jan 16, 2023

Conversation

@karpathy (Owner)

I'm now experimenting with running nanoGPT on multinode DDP, on this branch. Everything works fine on an individual node, but the multinode run crashes immediately, just before training is about to start. I'm launching manually on two nodes, with the following commands respectively on each one:

$ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
$ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py

I see the following error:

[E ProcessGroupNCCL.cpp:455] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:460] To avoid data inconsistency, we are taking the entire process down.

And the whole thing crashes. I'm currently failing to find a resolution. If anyone has gotten nanoGPT working multinode, please let me know whether you ran into the above and resolved it, or never saw the issue at all, ty.

@chuanli11

Can you try adding NCCL_IB_DISABLE=1 to the training command? It could be caused by NCCL looking for InfiniBand (IB) for inter-node communication when there isn't any available.
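
For example, prepending it to the same launch commands used above (a sketch, reusing the placeholder address/port from the original commands):

$ NCCL_IB_DISABLE=1 torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
$ NCCL_IB_DISABLE=1 torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py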

@karpathy (Owner, Author)

Thank you @chuanli11, that resolved the problem! I'm curious, where did you come upon the solution originally?

As an aside, the numbers I'm seeing atm training the base GPT-2 (124M) model:

  • 8X A100 40GB DDP on 1 node alone: 163ms/iter
  • 2 nodes: 800ms/iter ...

I benchmarked the (clearly non-InfiniBand, haha) connection between these nodes (in the same region) at only ~3 Gbit/s. 124M params in fp32 (as the gradients would be) is ~124e6 × 32 ≈ 4e9 bits, so it's O(~1 s) to transfer, hence the ~800 ms time, i.e. the code is spending essentially all of its time syncing gradients.
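
As a quick sanity check on that back-of-the-envelope math (a sketch: it assumes plain fp32 gradients and the ~3 Gbit/s link measured above, and ignores all-reduce algorithm details and compute/communication overlap, which is why the observed ~800 ms comes in under the naive estimate):

$ python -c "bits = 124e6 * 32; print(bits / 1e9, 'Gbit per gradient sync'); print(bits / 3e9, 's at 3 Gbit/s')"
3.968 Gbit per gradient sync
1.3226666666666667 s at 3 Gbit/s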

Anyway, this verifies the implementation to be good, so I'll merge this PR.

@karpathy merged commit 684800d into master on Jan 16, 2023
@chuanli11

Glad it "helps". I had a similar problem before. I used NCCL_DEBUG=INFO to see what could be causing it, and it gave Got completion with error 12, opcode 0, len 0, vendor err 129. I found this thread, which suggests it could be the NICs not talking to each other.
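
For anyone who wants to reproduce that kind of diagnosis, the debug flag can be prepended to the launch command the same way (a sketch, using the node-0 command from above):

$ NCCL_DEBUG=INFO torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py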

Knowing that the Lambda instances I used (on-demand SXM, in the same region) currently do not support IB, I decided to just explicitly disable it.

Your ~3 Gbit/s inter-node bandwidth is in line with my iperf test. I observed some throughput gain going from 1 node to 2 nodes with GPT-NeoX 13B, but going beyond 2 nodes didn't help. It seems to be even less efficient in your case (could be because I used a heavier model, or because of the extra optimizations, e.g. DeepSpeed, used in NeoX's code).
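
For reference, a minimal version of that kind of inter-node bandwidth check with iperf3 (a sketch; assumes iperf3 is installed on both nodes, and <node0_ip> is a hypothetical placeholder):

$ iperf3 -s              # on node 0: start the server
$ iperf3 -c <node0_ip>   # on node 1: measure throughput to node 0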

klei22 pushed a commit to klei22/nanoGPT that referenced this pull request Jan 20, 2024
Merge in the GPU Quickstart Colab PR, let's utilize this as a template for future colabs.
gkielian added a commit to gkielian/ReaLLMASIC_nanogpt that referenced this pull request Sep 5, 2024
Merge in the GPU Quickstart Colab PR, let's utilize this as a template for future colabs.