Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes and updating version number #81

Merged
merged 15 commits into from
Dec 20, 2024
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ on:
jobs:
build:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11"]
os: [ubuntu-latest, windows-latest, macos-latest]
Expand Down
2 changes: 1 addition & 1 deletion bpx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .schema import BPX, check_sto_limits
from .utilities import get_electrode_concentrations, get_electrode_stoichiometries

__version__ = "0.4.2"
__version__ = "0.5.0"

__all__ = [
"BPX",
Expand Down
19 changes: 1 addition & 18 deletions bpx/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import copy
import tempfile
from importlib import util
from pathlib import Path
from typing import TYPE_CHECKING, Any

from pydantic_core import CoreSchema, core_schema
Expand Down Expand Up @@ -87,26 +86,10 @@ def to_python_function(self, preamble: str | None = None) -> Callable:
source_code = preamble + function_def + function_body

with tempfile.NamedTemporaryFile(suffix=f"{function_name}.py", delete=False) as tmp:
# write to a tempory file so we can
# get the source later on using inspect.getsource
# (as long as the file still exists)
tmp.write((source_code).encode())
tmp.write(source_code.encode())
tmp.flush()

# Now load that file as a module
spec = util.spec_from_file_location("tmp", tmp.name)
module = util.module_from_spec(spec)
spec.loader.exec_module(module)

# Delete
tmp.close()
Path(tmp.name).unlink(missing_ok=True)
if module.__cached__:
cached_file = Path(module.__cached__)
cached_path = cached_file.parent
cached_file.unlink(missing_ok=True)
if not any(cached_path.iterdir()):
cached_path.rmdir()

# return the new function object
Comment on lines -101 to -111
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was causing failures due to cached files not being found. It did not seem necessary

return getattr(module, function_name)
19 changes: 8 additions & 11 deletions bpx/parsers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from __future__ import annotations

import json
from pathlib import Path

from .schema import BPX


Expand Down Expand Up @@ -26,13 +31,13 @@ def parse_bpx_obj(bpx: dict, v_tol: float = 0.001) -> BPX:
return BPX.model_validate(bpx)


def parse_bpx_file(filename: str, v_tol: float = 0.001) -> BPX:
def parse_bpx_file(filename: str | Path, v_tol: float = 0.001) -> BPX:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes the file name more flexible

"""
A convenience function to parse a bpx file into a BPX model.

Parameters
----------
filename: str
filename: str or Path
a filepath to a bpx file
v_tol: float
absolute tolerance in [V] to validate the voltage limits, 1 mV by default
Expand All @@ -42,18 +47,12 @@ def parse_bpx_file(filename: str, v_tol: float = 0.001) -> BPX:
BPX: :class:`bpx.BPX`
a parsed BPX model
"""

from pathlib import Path

bpx = ""
if filename.endswith((".yml", ".yaml")):
if str(filename).endswith((".yml", ".yaml")):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related to the path check

import yaml

with Path(filename).open(encoding="utf-8") as f:
bpx = yaml.safe_load(f)
else:
import orjson as json

with Path(filename).open(encoding="utf-8") as f:
bpx = json.loads(f.read())

Expand All @@ -77,7 +76,5 @@ def parse_bpx_str(bpx: str, v_tol: float = 0.001) -> BPX:
BPX:
a parsed BPX model
"""
import orjson as json

bpx = json.loads(bpx)
return parse_bpx_obj(bpx, v_tol)
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ dependencies = [
"pydantic >= 2.6",
"pyparsing",
"pyyaml",
"orjson",
]

[project.urls]
Expand Down
51 changes: 26 additions & 25 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import unittest
import warnings
from typing import Any

import pytest
from pydantic import TypeAdapter, ValidationError
Expand All @@ -12,7 +13,7 @@

class TestSchema(unittest.TestCase):
def setUp(self) -> None:
self.base = {
self.base : dict[str, Any] = {
"Header": {
"BPX": 1.0,
"Model": "DFN",
Expand Down Expand Up @@ -200,26 +201,26 @@ def setUp(self) -> None:
}

def test_simple(self) -> None:
test = copy.copy(self.base)
test = copy.deepcopy(self.base)
adapter.validate_python(test)

def test_simple_spme(self) -> None:
test = copy.copy(self.base)
test = copy.deepcopy(self.base)
test["Header"]["Model"] = "SPMe"
adapter.validate_python(test)

def test_simple_spm(self) -> None:
test = copy.copy(self.base_spm)
test = copy.deepcopy(self.base_spm)
adapter.validate_python(test)

def test_bad_model(self) -> None:
test = copy.copy(self.base)
test = copy.deepcopy(self.base)
test["Header"]["Model"] = "Wrong model type"
with pytest.raises(ValidationError):
adapter.validate_python(test)

def test_bad_dfn(self) -> None:
test = copy.copy(self.base_spm)
test = copy.deepcopy(self.base_spm)
test["Header"]["Model"] = "DFN"
with pytest.warns(
UserWarning,
Expand All @@ -228,7 +229,7 @@ def test_bad_dfn(self) -> None:
adapter.validate_python(test)

def test_bad_spme(self) -> None:
test = copy.copy(self.base_spm)
test = copy.deepcopy(self.base_spm)
test["Header"]["Model"] = "SPMe"
with pytest.warns(
UserWarning,
Expand All @@ -237,7 +238,7 @@ def test_bad_spme(self) -> None:
adapter.validate_python(test)

def test_bad_spm(self) -> None:
test = copy.copy(self.base)
test = copy.deepcopy(self.base)
test["Header"]["Model"] = "SPM"
with pytest.warns(
UserWarning,
Expand All @@ -246,15 +247,15 @@ def test_bad_spm(self) -> None:
adapter.validate_python(test)

def test_table(self) -> None:
test = copy.copy(self.base)
test = copy.deepcopy(self.base)
test["Parameterisation"]["Electrolyte"]["Conductivity [S.m-1]"] = {
"x": [1.0, 2.0],
"y": [2.3, 4.5],
}
adapter.validate_python(test)

def test_bad_table(self) -> None:
test = copy.copy(self.base)
test = copy.deepcopy(self.base)
test["Parameterisation"]["Electrolyte"]["Conductivity [S.m-1]"] = {
"x": [1.0, 2.0],
"y": [2.3],
Expand All @@ -266,37 +267,37 @@ def test_bad_table(self) -> None:
adapter.validate_python(test)

def test_function(self) -> None:
test = copy.copy(self.base)
test = copy.deepcopy(self.base)
test["Parameterisation"]["Electrolyte"]["Conductivity [S.m-1]"] = "1.0 * x + 3"
adapter.validate_python(test)

def test_function_with_exp(self) -> None:
test = copy.copy(self.base)
test = copy.deepcopy(self.base)
test["Parameterisation"]["Electrolyte"]["Conductivity [S.m-1]"] = "1.0 * exp(x) + 3"
adapter.validate_python(test)

def test_bad_function(self) -> None:
test = copy.copy(self.base)
test = copy.deepcopy(self.base)
test["Parameterisation"]["Electrolyte"]["Conductivity [S.m-1]"] = "this is not a function"
with pytest.raises(ValidationError):
adapter.validate_python(test)

def test_to_python_function(self) -> None:
test = copy.copy(self.base)
test = copy.deepcopy(self.base)
test["Parameterisation"]["Electrolyte"]["Conductivity [S.m-1]"] = "2.0 * x"
obj = adapter.validate_python(test)
funct = obj.parameterisation.electrolyte.conductivity
pyfunct = funct.to_python_function()
assert pyfunct(2.0) == 4.0

def test_bad_input(self) -> None:
test = copy.copy(self.base)
test = copy.deepcopy(self.base)
test["Parameterisation"]["Electrolyte"]["bad"] = "this shouldn't be here"
with pytest.raises(ValidationError):
adapter.validate_python(test)

def test_validation_data(self) -> None:
test = copy.copy(self.base)
test = copy.deepcopy(self.base)
test["Validation"] = {
"Experiment 1": {
"Time [s]": [0, 1000, 2000],
Expand All @@ -314,39 +315,39 @@ def test_validation_data(self) -> None:

def test_check_sto_limits_validator(self) -> None:
warnings.filterwarnings("error") # Treat warnings as errors
test = copy.copy(self.base_non_blended)
test = copy.deepcopy(self.base_non_blended)
test["Parameterisation"]["Cell"]["Upper voltage cut-off [V]"] = 4.3
test["Parameterisation"]["Cell"]["Lower voltage cut-off [V]"] = 2.5
adapter.validate_python(test)

def test_check_sto_limits_validator_high_voltage(self) -> None:
test = copy.copy(self.base_non_blended)
test = copy.deepcopy(self.base_non_blended)
test["Parameterisation"]["Cell"]["Upper voltage cut-off [V]"] = 4.0
with pytest.warns(UserWarning):
adapter.validate_python(test)

def test_check_sto_limits_validator_high_voltage_tolerance(self) -> None:
warnings.filterwarnings("error") # Treat warnings as errors
test = copy.copy(self.base_non_blended)
test = copy.deepcopy(self.base_non_blended)
test["Parameterisation"]["Cell"]["Upper voltage cut-off [V]"] = 4.0
BPX.Settings.tolerances["Voltage [V]"] = 0.25
adapter.validate_python(test)

def test_check_sto_limits_validator_low_voltage(self) -> None:
test = copy.copy(self.base_non_blended)
test = copy.deepcopy(self.base_non_blended)
test["Parameterisation"]["Cell"]["Lower voltage cut-off [V]"] = 3.0
with pytest.warns(UserWarning):
adapter.validate_python(test)

def test_check_sto_limits_validator_low_voltage_tolerance(self) -> None:
warnings.filterwarnings("error") # Treat warnings as errors
test = copy.copy(self.base_non_blended)
test = copy.deepcopy(self.base_non_blended)
test["Parameterisation"]["Cell"]["Lower voltage cut-off [V]"] = 3.0
BPX.Settings.tolerances["Voltage [V]"] = 0.35
adapter.validate_python(test)

def test_user_defined(self) -> None:
test = copy.copy(self.base)
test = copy.deepcopy(self.base)
test["Parameterisation"]["User-defined"] = {
"a": 1.0,
"b": 2.0,
Expand All @@ -358,7 +359,7 @@ def test_user_defined(self) -> None:
assert obj.parameterisation.user_defined.c == 3

def test_user_defined_table(self) -> None:
test = copy.copy(self.base)
test = copy.deepcopy(self.base)
test["Parameterisation"]["User-defined"] = {
"a": {
"x": [1.0, 2.0],
Expand All @@ -368,12 +369,12 @@ def test_user_defined_table(self) -> None:
adapter.validate_python(test)

def test_user_defined_function(self) -> None:
test = copy.copy(self.base)
test = copy.deepcopy(self.base)
test["Parameterisation"]["User-defined"] = {"a": "2.0 * x"}
adapter.validate_python(test)

def test_bad_user_defined(self) -> None:
test = copy.copy(self.base)
test = copy.deepcopy(self.base)
# bool not allowed type
test["Parameterisation"]["User-defined"] = {
"bad": True,
Expand Down
Loading