From dc5ce613f5084417b31a23b3812d2915b2a9bb5e Mon Sep 17 00:00:00 2001 From: chulutao Date: Tue, 29 Mar 2022 17:13:44 +0800 Subject: [PATCH] [BugFix] Fix SemanticConnectivityLoss bug on cpu about F.one_hot --- .../models/losses/semantic_connectivity_loss.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/paddleseg/models/losses/semantic_connectivity_loss.py b/paddleseg/models/losses/semantic_connectivity_loss.py index 6451be656d..b54d545d09 100644 --- a/paddleseg/models/losses/semantic_connectivity_loss.py +++ b/paddleseg/models/losses/semantic_connectivity_loss.py @@ -92,6 +92,7 @@ def forward(self, logits, labels): label_num_conn, label_conn = cv2.connectedComponents( labels_np_class.astype(np.uint8)) + origin_pred_num_conn = pred_num_conn if pred_num_conn > 2 * label_num_conn: pred_num_conn = min(pred_num_conn, self.max_pred_num_conn) real_pred_num = pred_num_conn - 1 @@ -100,8 +101,9 @@ def forward(self, logits, labels): # Connected Components Matching and SC Loss Calculation if real_label_num > 0 and real_pred_num > 0: img_connectivity = compute_class_connectiveity( - pred_conn, label_conn, pred_num_conn, label_num_conn, - pred_i, real_label_num, real_pred_num, zero) + pred_conn, label_conn, pred_num_conn, + origin_pred_num_conn, label_num_conn, pred_i, + real_label_num, real_pred_num, zero) sc_loss += 1 - img_connectivity elif real_label_num == 0 and real_pred_num == 0: # if no connected component, SC Loss = 0, so pass @@ -122,12 +124,12 @@ def forward(self, logits, labels): def compute_class_connectiveity(pred_conn, label_conn, pred_num_conn, - label_num_conn, pred, real_label_num, - real_pred_num, zero): + origin_pred_num_conn, label_num_conn, pred, + real_label_num, real_pred_num, zero): pred_conn = paddle.to_tensor(pred_conn) label_conn = paddle.to_tensor(label_conn) - pred_conn = F.one_hot(pred_conn, pred_num_conn) + pred_conn = F.one_hot(pred_conn, origin_pred_num_conn) label_conn = F.one_hot(label_conn, label_num_conn) ious = paddle.zeros((real_label_num, real_pred_num))