Skip to content

Commit c479989

Browse files
Add fail_if_symlink to fns.io functions (#3150)
### Changes - Add `fail_if_symlink` check to check symbolic links before load statistics files - Remove try block for `fns.io.save_file` in dump_statistics ### Reason for changes Prevent problems with symlinks --------- Co-authored-by: Alexander Dokuchaev <alexander.dokuchaev@intel.com>
1 parent 5a55a7d commit c479989

File tree

3 files changed

+19
-5
lines changed

3 files changed

+19
-5
lines changed

nncf/common/tensor_statistics/statistics_serializer.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,7 @@ def dump_statistics(
139139

140140
# Update the mapping
141141
metadata["mapping"][unique_sanitized_name] = original_name
142-
143-
try:
144-
fns.io.save_file(statistics_value, file_path)
145-
except Exception as e:
146-
raise nncf.InternalError(f"Failed to write data to file {file_path}: {e}")
142+
fns.io.save_file(statistics_value, file_path)
147143

148144
if additional_metadata:
149145
metadata |= additional_metadata

nncf/tensor/functions/io.py

+3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from pathlib import Path
1414
from typing import Dict, Optional
1515

16+
from nncf.common.utils.os import fail_if_symlink
1617
from nncf.tensor import Tensor
1718
from nncf.tensor.definitions import TensorBackend
1819
from nncf.tensor.definitions import TensorDeviceType
@@ -35,6 +36,7 @@ def load_file(
3536
then the default device is determined by backend.
3637
:return: A dictionary where the keys are tensor names and the values are Tensor objects.
3738
"""
39+
fail_if_symlink(file_path)
3840
loaded_dict = get_io_backend_fn("load_file", backend)(file_path, device=device)
3941
return {key: Tensor(val) for key, val in loaded_dict.items()}
4042

@@ -50,6 +52,7 @@ def save_file(
5052
:param data: A dictionary where the keys are tensor names and the values are Tensor objects.
5153
:param file_path: The path to the file where the tensor data will be saved.
5254
"""
55+
fail_if_symlink(file_path)
5356
if isinstance(data, dict):
5457
return dispatch_dict(save_file, data, file_path)
5558
raise NotImplementedError(f"Function `save_file` is not implemented for {type(data)}")

tests/cross_fw/test_templates/template_test_nncf_tensor.py

+15
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import numpy as np
2020
import pytest
2121

22+
import nncf
2223
import nncf.tensor.functions as fns
2324
from nncf.experimental.common.tensor_statistics import statistical_functions as s_fns
2425
from nncf.tensor import Tensor
@@ -1709,6 +1710,20 @@ def test_save_load_file(self, tmp_path, data):
17091710
assert loaded_stat[tensor_key].device == tensor.device
17101711
assert loaded_stat[tensor_key].dtype == tensor.dtype
17111712

1713+
def test_save_load_symlink_error(self, tmp_path):
1714+
file_path = tmp_path / "test_tensor"
1715+
symlink_path = tmp_path / "symlink_test_tensor"
1716+
symlink_path.symlink_to(file_path)
1717+
1718+
tensor_key = "tensor_key"
1719+
tensor = Tensor(self.to_tensor([1, 2]))
1720+
stat = {tensor_key: tensor}
1721+
1722+
with pytest.raises(nncf.ValidationError, match="symbolic link"):
1723+
fns.io.save_file(stat, symlink_path)
1724+
with pytest.raises(nncf.ValidationError, match="symbolic link"):
1725+
fns.io.load_file(symlink_path, backend=tensor.backend)
1726+
17121727
@pytest.mark.parametrize("data", [[3.0, 2.0, 2.0], [1, 2, 3]])
17131728
@pytest.mark.parametrize("dtype", [TensorDataType.float32, TensorDataType.int32, TensorDataType.uint8, None])
17141729
def test_fn_tensor(self, data, dtype):

0 commit comments

Comments
 (0)