Multiple dataloaders in `training_step()`, used separately #18543
Hi @thangld201, this is an interesting use case. Before reading the proposition below, you may want to answer a question: why do you need it to iterate each dataset in a separate batch? Why not do what you're suggesting manually: feed dataset 1, compute loss, backprop; feed dataset 2, compute loss, backprop; ...; feed dataset N, compute loss, backprop? If it's more nuanced than that, read on. 🙂 Ultimately, you may not want to use the following, but there could be one or two good ideas in it.

A brute-force approach would be to load all the datasets into a single `Dataset` object that returns a sample from the particular dataset based on the index. For example, suppose that you have three datasets whose batches you want to alternate between.
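A minimal sketch of that idea, assuming map-style datasets; the class name `BlockInterleavedDataset` and the block-wise index mapping are illustrative, not an official API:

```python
from torch.utils.data import Dataset

class BlockInterleavedDataset(Dataset):
    """Dispatch each *batch* to one underlying dataset, round-robin.

    Indices [0, B) go to dataset 0, [B, 2B) to dataset 1, and so on,
    so that a DataLoader with batch_size=B and shuffle=False yields
    batches that alternate between the datasets.
    """

    def __init__(self, datasets, batch_size):
        self.datasets = list(datasets)
        self.batch_size = batch_size

    def __len__(self):
        # one full round of the cycle per batch of the longest dataset;
        # shorter datasets wrap around
        n_batches = max(len(d) for d in self.datasets) // self.batch_size
        return n_batches * len(self.datasets) * self.batch_size

    def __getitem__(self, index):
        k = len(self.datasets)
        batch_idx = index // self.batch_size
        ds_idx = batch_idx % k                       # which dataset owns this batch
        pos_in_batch = index % self.batch_size
        local = (batch_idx // k) * self.batch_size + pos_in_batch
        ds = self.datasets[ds_idx]
        return ds[local % len(ds)], ds_idx           # dataset id travels with the sample
```

Used as `DataLoader(BlockInterleavedDataset([d0, d1, d2], batch_size=32), batch_size=32, shuffle=False)`, each yielded batch comes entirely from one dataset, cycling 0, 1, 2, 0, ... Note that `shuffle` must stay `False`, which leads directly into the caveat below.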
The downside to this approach is that, because you are "repurposing" the index, multiple problems come up: shuffling breaks the batch-to-dataset mapping, and the length bookkeeping gets awkward when the datasets have different sizes.
@changspencer I am having a similar issue, so I want to post here rather than open a new discussion. After following the info above, I don't want to fit my model sequentially, as I feel that would just lead to the model forgetting/overfitting given the class imbalance I believe is present between the two datasets. That's really just a hunch, though, so if sequential fitting is actually the correct way to do it, let me know.
@changspencer To answer why I want to handle each dataloader separately: in a robustness test, I want the result for each corruption to be logged separately. Can we do anything in the `training_step()` to support this?
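For the robustness-test case specifically, if the per-corruption evaluation can live in `trainer.test()` rather than `training_step()`, Lightning already passes a `dataloader_idx` to `test_step()` when you hand it a list of dataloaders, and `self.log` can key each metric by corruption name. A sketch, where the corruption names, model, and loaders are placeholders:

```python
import lightning.pytorch as pl

class RobustnessModule(pl.LightningModule):
    # assumed corruption names, in the same order as the test dataloaders
    CORRUPTIONS = ["gaussian_noise", "motion_blur", "fog"]

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        x, y = batch
        logits = self(x)
        acc = (logits.argmax(dim=-1) == y).float().mean()
        # one separately-logged metric per corruption;
        # add_dataloader_idx=False stops Lightning appending "/dataloader_idx_k"
        self.log(f"test_acc/{self.CORRUPTIONS[dataloader_idx]}", acc,
                 add_dataloader_idx=False)

# usage: trainer.test(model, dataloaders=[gauss_loader, blur_loader, fog_loader])
```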
Hi, I'm figuring out how to use multiple dataloaders in `training_step()` of a LightningModule. Currently, if I pass a list of dataloaders to `trainer.fit()`, it returns a list of batches, one from each dataloader simultaneously. However, my use case differs in that I want to process the batches from each dataset sequentially.

For example, I have three datasets. At step i, I receive a batch from dataset 0 and update my model. At step i+1, I receive a batch from dataset 1 and update my model. At step i+2, I get a batch from dataset 2 and update my model. The process repeats until all samples are iterated.

How can I implement this in PyTorch Lightning? Is this already supported? I would be happy to dive in myself, but I don't know where to start.
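One possible way to get this behavior without a custom dataset, sketched under the assumption of Lightning 2.x with automatic optimization: combine the loaders with `CombinedLoader(mode="max_size_cycle")`, which yields one batch from every loader at each step, then train on only one of them per step by indexing with `batch_idx`. The model, datasets, and loss below are placeholders:

```python
import torch.nn.functional as F
import lightning.pytorch as pl
from lightning.pytorch.utilities import CombinedLoader
from torch.utils.data import DataLoader

class RoundRobinModule(pl.LightningModule):
    def __init__(self, datasets):
        super().__init__()
        self.datasets = datasets  # assumed: a list of three map-style datasets

    def train_dataloader(self):
        loaders = [DataLoader(ds, batch_size=32, shuffle=True)
                   for ds in self.datasets]
        # "max_size_cycle" yields a list with one batch from *every* loader per step
        return CombinedLoader(loaders, mode="max_size_cycle")

    def training_step(self, batch, batch_idx):
        # round-robin: step i trains on dataset i % N, as described above
        x, y = batch[batch_idx % len(batch)]
        loss = F.cross_entropy(self(x), y)
        self.log("train_loss", loss)
        return loss
```

The obvious cost is that the unused batches are still fetched every step, so this trades wasted dataloading for simplicity; the single-Dataset approach earlier in the thread avoids that at the price of the index bookkeeping.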