diff --git a/torchdistill/core/interfaces/forward_proc.py b/torchdistill/core/interfaces/forward_proc.py index 56de0c05..029e3aec 100644 --- a/torchdistill/core/interfaces/forward_proc.py +++ b/torchdistill/core/interfaces/forward_proc.py @@ -37,6 +37,25 @@ def forward_batch_only(model, sample_batch, targets=None, supp_dict=None, **kwar return model(sample_batch) +@register_forward_proc_func +def forward_batch_only_as_kwargs(model, sample_batch, targets=None, supp_dict=None): + """ + Performs forward computation using `sample_batch` only. + + :param model: model. + :type model: nn.Module + :param sample_batch: sample batch. + :type sample_batch: dict + :param targets: training targets (won't be passed to forward). + :type targets: Any + :param supp_dict: supplementary dict (won't be passed to forward). + :type supp_dict: dict + :return: model's forward output. + :rtype: Any + """ + return model(**sample_batch) + + @register_forward_proc_func def forward_batch_target(model, sample_batch, targets, supp_dict=None, **kwargs): """