Skip to content

Commit d533c63

Browse files
committed
Fix #492
1 parent c0d1a6e commit d533c63

File tree

3 files changed

+26
-1
lines changed

3 files changed

+26
-1
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2626

2727
### Fixed
2828

29-
--
29+
- Rainbow multi-band scaler didn't work with list inputs https://github.com/light-curve/light-curve-python/issues/492 https://github.com/light-curve/light-curve-python/pull/493
3030

3131
### Security
3232

light-curve/light_curve/light_curve_py/features/rainbow/_scaler.py

+13
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ class Scaler:
2323
Either a single value or an array of the same shape as the input array
2424
"""
2525

26+
def __eq__(self, other):
27+
if not isinstance(other, Scaler):
28+
return False
29+
return np.array_equal(self.shift, other.shift) and np.array_equal(self.scale, other.scale)
30+
2631
@classmethod
2732
def from_time(cls, t) -> "Scaler":
2833
"""Create a Scaler from a time array
@@ -55,13 +60,21 @@ class MultiBandScaler(Scaler):
5560
per_band_shift: Dict[str, float]
5661
"""Shift to apply to each band"""
5762

63+
def __eq__(self, other):
64+
if not isinstance(other, MultiBandScaler):
65+
return False
66+
return super().__eq__(other) and self.per_band_shift == other.per_band_shift
67+
5868
@classmethod
5969
def from_flux(cls, flux, band, *, with_baseline: bool) -> "MultiBandScaler":
6070
"""Create a Scaler from a flux array.
6171
6272
It uses standard deviation for the scale. For the shift, it is either
6373
zero (`with_baseline=False`) or the mean of each band otherwise.
6474
"""
75+
flux = np.asarray(flux)
76+
band = np.asarray(band)
77+
6578
uniq_bands = np.unique(band)
6679
per_band_shift = dict.fromkeys(uniq_bands, 0.0)
6780
shift_array = np.zeros(len(flux))

light-curve/tests/light_curve_py/features/test_rainbow.py

+12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22

33
from light_curve.light_curve_py import RainbowFit
4+
from light_curve.light_curve_py.features.rainbow._scaler import MultiBandScaler
45

56

67
def test_noisy_with_baseline():
@@ -113,3 +114,14 @@ def test_noisy_all_functions_combination():
113114
# plt.show()
114115

115116
np.testing.assert_allclose(actual[:-1], expected[:-1], rtol=0.1)
117+
118+
119+
def test_scaler_from_flux_list_input():
120+
"https://github.com/light-curve/light-curve-python/issues/492"
121+
# Was failing
122+
scaler1 = MultiBandScaler.from_flux(
123+
flux=[1.0, 2.0, 3.0, 4.0], band=np.array(["g", "r", "g", "r"]), with_baseline=True
124+
)
125+
# Was not failing, but was wrong
126+
scaler2 = MultiBandScaler.from_flux(flux=[1.0, 2.0, 3.0, 4.0], band=["g", "r", "g", "r"], with_baseline=True)
127+
assert scaler1 == scaler2

0 commit comments

Comments
 (0)