From 6c58c07f3cde610a939012471e3f1d6c18587387 Mon Sep 17 00:00:00 2001 From: Yoshitomo Matsubara Date: Fri, 28 Feb 2025 00:05:20 -0800 Subject: [PATCH] Add a new forward proc --- torchdistill/core/interfaces/forward_proc.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/torchdistill/core/interfaces/forward_proc.py b/torchdistill/core/interfaces/forward_proc.py index 56de0c05..029e3aec 100644 --- a/torchdistill/core/interfaces/forward_proc.py +++ b/torchdistill/core/interfaces/forward_proc.py @@ -37,6 +37,25 @@ def forward_batch_only(model, sample_batch, targets=None, supp_dict=None, **kwar return model(sample_batch) +@register_forward_proc_func +def forward_batch_only_as_kwargs(model, sample_batch, targets=None, supp_dict=None): + """ + Performs forward computation using `sample_batch` only. + + :param model: model. + :type model: nn.Module + :param sample_batch: sample batch. + :type sample_batch: dict + :param targets: training targets (won't be passed to forward). + :type targets: Any + :param supp_dict: supplementary dict (won't be passed to forward). + :type supp_dict: dict + :return: model's forward output. + :rtype: Any + """ + return model(**sample_batch) + + @register_forward_proc_func def forward_batch_target(model, sample_batch, targets, supp_dict=None, **kwargs): """