Skip to content

Commit e531bd9

Browse files
authored
Merge pull request #324 from pnuu/bugfix-1d-bilinear
Fix bilinear resampler for 1D data
2 parents 959c6fb + bf047fe commit e531bd9

File tree

3 files changed

+92
-5
lines changed

3 files changed

+92
-5
lines changed

pyresample/bilinear/_base.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -190,13 +190,17 @@ def _get_target_proj_vectors(self):
190190

191191
def _get_slices(self):
192192
shp = self._source_geo_def.shape
193-
cols, lines = np.meshgrid(np.arange(shp[1]),
194-
np.arange(shp[0]))
193+
try:
194+
cols, lines = np.meshgrid(np.arange(shp[1]),
195+
np.arange(shp[0]))
196+
data = (np.ravel(lines), np.ravel(cols))
197+
except IndexError:
198+
data = (np.zeros(shp[0], dtype=np.uint32), np.arange(shp[0]))
195199

196200
valid_lines_and_columns = array_slice_for_multiple_arrays(
197201
self._valid_input_index,
198-
(np.ravel(lines), np.ravel(cols))
199-
)
202+
data)
203+
200204
self.slices_y, self.slices_x = array_slice_for_multiple_arrays(
201205
self._index_array,
202206
valid_lines_and_columns

pyresample/bilinear/xarr.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,9 @@ def _finalize_output_data(self, data, res, fill_value):
131131
res = self._reshape_to_target_area(res, data.ndim)
132132

133133
self._add_missing_coordinates(data)
134+
dims = self._get_output_dims(data, res)
134135

135-
return DataArray(res, dims=data.dims, coords=self._out_coords)
136+
return DataArray(res, dims=dims, coords=self._out_coords)
136137

137138
def _add_missing_coordinates(self, data):
138139
self._add_x_and_y_coordinates()
@@ -154,10 +155,16 @@ def _adjust_bands_coordinates_to_match_data(self, data_coords):
154155
elif 'bands' in self._out_coords:
155156
del self._out_coords['bands']
156157

158+
def _get_output_dims(self, data, res):
159+
if data.ndim == res.ndim:
160+
return data.dims
161+
return list(self._out_coords.keys())
162+
157163
def _slice_data(self, data, fill_value):
158164
def from_delayed(delayeds, shp):
159165
return [da.from_delayed(d, shp, np.float32) for d in delayeds]
160166

167+
data = _check_data_shape(data, self._valid_input_index)
161168
if data.ndim == 2:
162169
shp = self.bilinear_s.shape
163170
else:
@@ -257,6 +264,20 @@ def _get_valid_input_index(source_geo_def,
257264
return valid_input_index, source_lons, source_lats
258265

259266

267+
def _check_data_shape(data, input_idxs):
268+
"""Check data shape and adjust if necessary."""
269+
# Handle multiple datasets
270+
if data.ndim > 2 and data.shape[0] * data.shape[1] == input_idxs.shape[0]:
271+
# Move the "channel" dimension first
272+
data = da.moveaxis(data, -1, 0)
273+
274+
# Ensure two dimensions
275+
if data.ndim == 1:
276+
data = DataArray(da.map_blocks(np.expand_dims, data.data, 0, new_axis=[0]))
277+
278+
return data
279+
280+
260281
class XArrayResamplerBilinear(XArrayBilinearResampler):
261282
"""Wrapper for the old resampler class."""
262283

pyresample/test/test_bilinear.py

+62
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,13 @@ def setUpClass(cls):
7171
cls.data1 = np.ones((in_shape[0], in_shape[1]))
7272
cls.data2 = 2. * cls.data1
7373
cls.data3 = cls.data1 + 9.5
74+
cls.data3_1d = np.ravel(cls.data3)
75+
7476
lons, lats = np.meshgrid(np.linspace(-25., 40., num=in_shape[0]),
7577
np.linspace(45., 75., num=in_shape[1]))
7678
cls.source_def = geometry.SwathDefinition(lons=lons, lats=lats)
79+
cls.source_def_1d = geometry.SwathDefinition(lons=np.ravel(lons),
80+
lats=np.ravel(lats))
7781

7882
cls.radius = 50e3
7983
cls._neighbours = 32
@@ -331,6 +335,22 @@ def test_get_sample_from_bil_info(self):
331335
input_idxs, idx_arr)
332336
assert not hasattr(res, 'mask')
333337

338+
def test_get_sample_from_bil_info_1d(self):
339+
"""Test resampling using resampling indices for 1D data."""
340+
from pyresample.bilinear import get_bil_info, get_sample_from_bil_info
341+
342+
t__, s__, input_idxs, idx_arr = get_bil_info(self.source_def_1d,
343+
self.target_def,
344+
50e5, neighbours=32,
345+
nprocs=1)
346+
# Sample from 1D data
347+
res = get_sample_from_bil_info(self.data3_1d, t__, s__,
348+
input_idxs, idx_arr)
349+
self.assertAlmostEqual(np.nanmin(res), 10.5)
350+
self.assertAlmostEqual(np.nanmax(res), 10.5)
351+
# Four pixels are outside of the data
352+
self.assertEqual(np.isnan(res).sum(), 4)
353+
334354
def test_resample_bilinear(self):
335355
"""Test whole bilinear resampling."""
336356
from pyresample.bilinear import resample_bilinear
@@ -467,9 +487,12 @@ def setUp(self):
467487
self.data1 = DataArray(da.ones((in_shape[0], in_shape[1])), dims=('y', 'x'))
468488
self.data2 = 2. * self.data1
469489
self.data3 = self.data1 + 9.5
490+
470491
lons, lats = np.meshgrid(np.linspace(-25., 40., num=in_shape[0]),
471492
np.linspace(45., 75., num=in_shape[1]))
472493
self.source_def = geometry.SwathDefinition(lons=lons, lats=lats)
494+
self.source_def_1d = geometry.SwathDefinition(lons=np.ravel(lons),
495+
lats=np.ravel(lats))
473496

474497
self.radius = 50e3
475498
self._neighbours = 32
@@ -714,6 +737,45 @@ def test_slice_data(self):
714737
self.assertTrue(np.all(np.isnan(p_1)) and np.all(np.isnan(p_2)) and
715738
np.all(np.isnan(p_3)) and np.all(np.isnan(p_4)))
716739

740+
def test_slice_data_1d(self):
741+
"""Test slicing 1D data."""
742+
import dask.array as da
743+
from xarray import DataArray
744+
745+
from pyresample.bilinear import XArrayBilinearResampler
746+
747+
resampler = XArrayBilinearResampler(self.source_def_1d, self.target_def,
748+
self.radius)
749+
resampler.get_bil_info()
750+
751+
# 1D data
752+
data = DataArray(da.ones(self.source_def_1d.shape))
753+
p_1, p_2, p_3, p_4 = resampler._slice_data(data, np.nan)
754+
self.assertEqual(p_1.shape, resampler.bilinear_s.shape)
755+
self.assertTrue(p_1.shape == p_2.shape == p_3.shape == p_4.shape)
756+
self.assertTrue(np.all(p_1 == 1.0) and np.all(p_2 == 1.0) and
757+
np.all(p_3 == 1.0) and np.all(p_4 == 1.0))
758+
759+
def test_get_sample_from_bil_info_1d(self):
760+
"""Test resampling using resampling indices for 1D data."""
761+
import dask.array as da
762+
from xarray import DataArray
763+
764+
from pyresample.bilinear import XArrayBilinearResampler
765+
766+
resampler = XArrayBilinearResampler(self.source_def_1d, self.target_def,
767+
50e5)
768+
resampler.get_bil_info()
769+
770+
# Sample from 1D data
771+
data = DataArray(da.ones(self.source_def_1d.shape), dims=('y'))
772+
res = resampler.get_sample_from_bil_info(data) # noqa
773+
assert 'x' in res.dims
774+
assert 'y' in res.dims
775+
776+
# Four pixels are outside of the data
777+
self.assertEqual(np.isnan(res).sum().compute(), 4)
778+
717779
@mock.patch('pyresample.bilinear.xarr.np.meshgrid')
718780
def test_get_slices(self, meshgrid):
719781
"""Test slice array creation."""

0 commit comments

Comments
 (0)