diff --git a/CHANGES.rst b/CHANGES.rst index 1fa43b38..ac1bf25d 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -13,12 +13,21 @@ General New Features ------------ +- The ``Regions`` class can now be initialized without any arguments. + [#527] + +- The ``Regions`` ``extend`` method now can accept another ``Regions`` + object as input. [#527] + Bug Fixes --------- API Changes ----------- +- The ``Regions`` class and its ``append`` and ``extend`` methods now + raise a ``TypeError`` for invalid inputs. [#527] + 0.7 (2022-10-27) ================ diff --git a/regions/core/regions.py b/regions/core/regions.py index 886fde9f..17e563bc 100644 --- a/regions/core/regions.py +++ b/regions/core/regions.py @@ -24,7 +24,13 @@ class Regions: The list of region objects. """ - def __init__(self, regions): + def __init__(self, regions=(), /): + if regions == (): + regions = [] + for item in regions: + if not isinstance(item, Region): + raise TypeError('Input regions must be a list of Region ' + 'objects') self.regions = regions def __getitem__(self, index): @@ -55,6 +61,8 @@ def append(self, region): region : `~regions.Region` The region to append. """ + if not isinstance(region, Region): + raise TypeError('Input region must be a Region object') self.regions.append(region) def extend(self, regions): @@ -64,10 +72,17 @@ def extend(self, regions): Parameters ---------- - regions : list of `~regions.Region` - A list of regions to include. + regions : `~regions.Regions` or list of `~regions.Region` + A `~regions.Regions` object or a list of regions to include. """ - self.regions.extend(regions) + if isinstance(regions, Regions): + self.regions.extend(regions.regions) + else: + for item in regions: + if not isinstance(item, Region): + raise TypeError('Input regions must be a list of Region ' + 'objects') + self.regions.extend(regions) def insert(self, index, region): """ diff --git a/regions/core/tests/test_regions.py b/regions/core/tests/test_regions.py new file mode 100644 index 00000000..278c2e70 --- /dev/null +++ b/regions/core/tests/test_regions.py @@ -0,0 +1,112 @@ +# Licensed under a 3-clause BSD style license - see LICENSE.rst +""" +Tests for the regions module. +""" + +import pytest +from astropy.table import Table + +from regions.core import PixCoord, Regions +from regions.shapes import CirclePixelRegion + + +def test_regions_inputs(): + regs = [] + for radius in range(1, 5): + center = PixCoord(14, 21) + regs.append(CirclePixelRegion(center, radius=radius)) + + regions = Regions(regs) + assert len(regions) == 4 + + +def test_regions_no_input(): + regions = Regions() + assert len(regions) == 0 + + regions = Regions([]) + assert len(regions) == 0 + + +def test_regions_invalid_input(): + match = "'int' object is not iterable" + with pytest.raises(TypeError, match=match): + Regions(1) + + match = 'Input regions must be a list of Region objects' + with pytest.raises(TypeError, match=match): + Regions([1]) + with pytest.raises(TypeError, match=match): + Regions([1, 2, 3]) + + +def test_regions_append(): + regions = Regions() + + for radius in range(1, 5): + center = PixCoord(14, 21) + regions.append(CirclePixelRegion(center, radius=radius)) + assert len(regions) == 4 + + match = 'Input region must be a Region object' + with pytest.raises(TypeError, match=match): + regions.append(1) + + +def test_regions_extend(): + regions = Regions() + + regs = [] + for radius in range(1, 5): + center = PixCoord(14, 21) + regs.append(CirclePixelRegion(center, radius=radius)) + + regions.extend(regs) + assert len(regions) == 4 + + regions.extend(regions) + assert len(regions) == 8 + + match = 'Input regions must be a list of Region objects' + with pytest.raises(TypeError, match=match): + regions.extend([1]) + + +def test_regions_methods(): + regions = Regions() + + for radius in range(1, 5): + center = PixCoord(14, 21) + regions.append(CirclePixelRegion(center, radius=radius)) + + reg_slc = regions[0:2] + assert len(reg_slc) == 2 + + reg = CirclePixelRegion(PixCoord(0, 0), radius=1) + regions.insert(0, reg) + assert len(regions) == 5 + assert regions[0] == reg + + regions2 = regions.copy() + regions.reverse() + assert regions.regions == regions2.regions[::-1] + + outreg = regions.pop(-1) + assert outreg == reg + + +def test_regions_get_formats(): + reg = CirclePixelRegion(PixCoord(0, 0), radius=1) + regions = Regions([reg]) + tbl = regions.get_formats() + assert isinstance(tbl, Table) + assert len(tbl) == 3 + + +def test_regions_repr(): + reg = CirclePixelRegion(PixCoord(0, 0), radius=1) + regions = Regions([reg]) + regions_str = '[]' + regions_repr = f'' + assert str(regions) == regions_str + assert repr(regions) == regions_repr