From 414af042daabb40b780add7ab8dc9e7b940001b9 Mon Sep 17 00:00:00 2001 From: Larry Bradley Date: Tue, 5 Sep 2023 18:02:31 -0400 Subject: [PATCH 1/4] Return single group if only one input source --- photutils/psf/groupers.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/photutils/psf/groupers.py b/photutils/psf/groupers.py index fab3caa24..a07df71f6 100644 --- a/photutils/psf/groupers.py +++ b/photutils/psf/groupers.py @@ -67,6 +67,15 @@ def _group_sources(self, x, y): """ from scipy.cluster.hierarchy import fclusterdata + x = np.atleast_1d(x) + y = np.atleast_1d(y) + if x.shape != y.shape: + raise ValueError('x and y must have the same shape') + if x.shape == (0,): # no sources + raise ValueError('x and y must not be empty') + if x.shape == (1,): # single source -> single group + return np.array([1]) + xypos = np.transpose((x, y)) group_id = fclusterdata(xypos, t=self.min_separation, criterion='distance') From 9a642ffbab23c4b1d7828af86b0f295cc9932435 Mon Sep 17 00:00:00 2001 From: Larry Bradley Date: Tue, 5 Sep 2023 18:36:17 -0400 Subject: [PATCH 2/4] Add SourceGrouper tests --- photutils/psf/tests/test_groupers.py | 282 +++++++++++++++++++++++++++ 1 file changed, 282 insertions(+) create mode 100644 photutils/psf/tests/test_groupers.py diff --git a/photutils/psf/tests/test_groupers.py b/photutils/psf/tests/test_groupers.py new file mode 100644 index 000000000..05e7ec895 --- /dev/null +++ b/photutils/psf/tests/test_groupers.py @@ -0,0 +1,282 @@ +# Licensed under a 3-clause BSD style license - see LICENSE.rst +""" +Tests for the grouper module. +""" + +import numpy as np +import pytest +from numpy.testing import assert_equal + +from photutils.psf.groupers import SourceGrouper + + +def test_grouper_empty(): + """ + Test case when there are no sources. + """ + xx = np.array([]) + yy = np.array([]) + grouper = SourceGrouper(min_separation=10) + match = 'x and y must not be empty' + with pytest.raises(ValueError, match=match): + grouper(xx, yy) + + +def test_grouper_one_source(): + """ + Test case when there is only one source. + """ + xx = np.array([0]) + yy = np.array([0]) + gg = np.array([1]) + grouper = SourceGrouper(min_separation=10) + groups = grouper(xx, yy) + assert_equal(groups, gg) + + +def test_grouper_inputs(): + xx = np.array([1, 2, 3, 4]) + yy = np.array([1, 2]) + grouper = SourceGrouper(min_separation=10) + match = 'x and y must have the same shape' + with pytest.raises(ValueError, match=match): + grouper(xx, yy) + + +def test_isolated_sources(): + """ + Test case when all sources are isolated. + """ + xx = np.array([0, np.sqrt(2) / 4, np.sqrt(2) / 4, -np.sqrt(2) / 4, + -np.sqrt(2) / 4]) + yy = np.array([0, np.sqrt(2) / 4, -np.sqrt(2) / 4, np.sqrt(2) / 4, + -np.sqrt(2) / 4]) + gg = np.arange(len(xx), dtype=int) + 1 + grouper = SourceGrouper(min_separation=0.01) + groups = grouper(xx, yy) + assert_equal(groups, gg) + + +def test_grouper_one(): + """ + +---------+--------+---------+---------+--------+---------+ + | * * * * | + | | + 0.2 + + + | | + | | + | | + 0 + * * + + | | + | | + | | + -0.2 + + + | | + | * * * * | + +---------+--------+---------+---------+--------+---------+ + 0 0.5 1 1.5 2 + + x and y axis are in pixel coordinates. Each asterisk represents + the centroid of a star. + """ + x1 = np.array([0, np.sqrt(2) / 4, np.sqrt(2) / 4, -np.sqrt(2) / 4, + -np.sqrt(2) / 4]) + y1 = np.array([0, np.sqrt(2) / 4, -np.sqrt(2) / 4, np.sqrt(2) / 4, + -np.sqrt(2) / 4]) + g1 = np.ones(len(x1), dtype=int) + x2 = x1 + 2.0 + y2 = y1 + g2 = np.ones(len(x1), dtype=int) + 1 + + xx = np.hstack([x1, x2]) + yy = np.hstack([y1, y2]) + gg = np.hstack([g1, g2]) + grouper = SourceGrouper(min_separation=0.6) + groups = grouper(xx, yy) + assert_equal(groups, gg) + + +def test_grouper_two(): + """ + +--------------+--------------+-------------+--------------+ + 3 + * + + | * | + 2.5 + * + + | * | + 2 + * + + | | + 1.5 + + + | | + 1 + * + + | * | + 0.5 + * + + | * | + 0 + * + + +--------------+--------------+-------------+--------------+ + -1 -0.5 0 0.5 1 + """ + x1 = np.zeros(5) + y1 = np.linspace(0, 1, 5) + g1 = np.ones(5, dtype=int) + x2 = np.zeros(5) + y2 = np.linspace(2, 3, 5) + g2 = np.ones(5, dtype=int) + 1 + + xx = np.hstack([x1, x2]) + yy = np.hstack([y1, y2]) + gg = np.hstack([g1, g2]) + grouper = SourceGrouper(min_separation=0.3) + groups = grouper(xx, yy) + assert_equal(groups, gg) + + +def test_grouper_three(): + """ + 1 +--+-------+--------+--------+--------+-------+--------+--+ + | | + | | + | | + 0.5 + + + | | + | | + 0 + * * * * * * * * * * + + | | + | | + -0.5 + + + | | + | | + | | + -1 +--+-------+--------+--------+--------+-------+--------+--+ + 0 0.5 1 1.5 2 2.5 3 + """ + x1 = np.linspace(0, 1, 5) + y1 = np.zeros(5) + g1 = np.ones(5, dtype=int) + x2 = np.linspace(2, 3, 5) + y2 = np.zeros(5) + g2 = np.ones(5, dtype=int) + 1 + + xx = np.hstack([x1, x2]) + yy = np.hstack([y1, y2]) + gg = np.hstack([g1, g2]) + grouper = SourceGrouper(min_separation=0.3) + groups = grouper(xx, yy) + assert_equal(groups, gg) + + +def test_grouper_four(): + """ + +-+---------+---------+---------+---------+-+ + 1 + * + + | * * | + | | + | | + 0.5 + + + | | + | | + | | + 0 + * * + + | | + | | + -0.5 + + + | | + | | + | * * | + -1 + * + + +-+---------+---------+---------+---------+-+ + -1 -0.5 0 0.5 1 + """ + x = np.linspace(-1.0, 1.0, 5) + y = np.sqrt(1.0 - x**2) + xx = np.hstack((x, x)) + yy = np.hstack((y, -y)) + gg = np.ones(len(xx), dtype=int) + + grouper = SourceGrouper(min_separation=2.5) + groups = grouper(xx, yy) + assert_equal(groups, gg) + + +def test_grouper_five(): + """ + +--+--------+--------+-------+--------+--------+--------+--+ + 3 + * + + | * | + 2.5 + * + + | * | + 2 + * + + | | + 1.5 + * * * * * * * * * * + + | | + 1 + * + + | * | + 0.5 + * + + | * | + 0 + * + + +--+--------+--------+-------+--------+--------+--------+--+ + 0 0.5 1 1.5 2 2.5 3 + """ + x1 = 1.5 * np.ones(5) + y1 = np.linspace(0, 1, 5) + g1 = np.ones(5, dtype=int) + + x2 = 1.5 * np.ones(5) + y2 = np.linspace(2, 3, 5) + g2 = np.ones(5, dtype=int) + 1 + + x3 = np.linspace(0, 1, 5) + y3 = 1.5 * np.ones(5) + g3 = np.ones(5, dtype=int) + 2 + + x4 = np.linspace(2, 3, 5) + y4 = 1.5 * np.ones(5) + g4 = np.ones(5, dtype=int) + 3 + + xx = np.hstack([x1, x2, x3, x4]) + yy = np.hstack([y1, y2, y3, y4]) + gg = np.hstack([g1, g2, g3, g4]) + + grouper = SourceGrouper(min_separation=0.3) + groups = grouper(xx, yy) + assert_equal(groups, gg) + + +def test_grouper_six(): + """ + +------+----------+----------+----------+----------+------+ + | * * * * * * | + | | + 0.2 + + + | | + | | + | | + 0 + * * * + + | | + | | + | | + -0.2 + + + | | + | * * * * * * | + +------+----------+----------+----------+----------+------+ + 0 1 2 3 4 + """ + x1 = np.array([0, np.sqrt(2) / 4, np.sqrt(2) / 4, -np.sqrt(2) / 4, + -np.sqrt(2) / 4]) + y1 = np.array([0, np.sqrt(2) / 4, -np.sqrt(2) / 4, np.sqrt(2) / 4, + -np.sqrt(2) / 4]) + g1 = np.ones(len(x1), dtype=int) + + x2 = x1 + 2.0 + y2 = y1 + g2 = np.ones(len(x1), dtype=int) + 1 + + x3 = x1 + 4.0 + y3 = y1 + g3 = np.ones(len(x1), dtype=int) + 2 + + xx = np.hstack([x1, x2, x3]) + yy = np.hstack([y1, y2, y3]) + gg = np.hstack([g1, g2, g3]) + grouper = SourceGrouper(min_separation=0.6) + groups = grouper(xx, yy) + assert_equal(groups, gg) From d5219356696a55a3156dc1926dac46877706b039 Mon Sep 17 00:00:00 2001 From: Larry Bradley Date: Tue, 5 Sep 2023 18:39:37 -0400 Subject: [PATCH 3/4] Add changelog entry --- CHANGES.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGES.rst b/CHANGES.rst index b6454af1e..0569125c2 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -10,6 +10,12 @@ New Features Bug Fixes ^^^^^^^^^ +- ``photutils.psf`` + + - Fixed a bug where ``SourceGrouper`` would fail if only one source + was input. [#1617] + + API Changes ^^^^^^^^^^^ From a3fd98f1cba3c6994af817374805b1599a078f1b Mon Sep 17 00:00:00 2001 From: Larry Bradley Date: Tue, 5 Sep 2023 18:47:42 -0400 Subject: [PATCH 4/4] Skip tests if scipy is not installed --- photutils/psf/tests/test_groupers.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/photutils/psf/tests/test_groupers.py b/photutils/psf/tests/test_groupers.py index 05e7ec895..5796faf25 100644 --- a/photutils/psf/tests/test_groupers.py +++ b/photutils/psf/tests/test_groupers.py @@ -8,8 +8,10 @@ from numpy.testing import assert_equal from photutils.psf.groupers import SourceGrouper +from photutils.utils._optional_deps import HAS_SCIPY +@pytest.mark.skipif(not HAS_SCIPY, reason='scipy is required') def test_grouper_empty(): """ Test case when there are no sources. @@ -22,6 +24,7 @@ def test_grouper_empty(): grouper(xx, yy) +@pytest.mark.skipif(not HAS_SCIPY, reason='scipy is required') def test_grouper_one_source(): """ Test case when there is only one source. @@ -34,6 +37,7 @@ def test_grouper_one_source(): assert_equal(groups, gg) +@pytest.mark.skipif(not HAS_SCIPY, reason='scipy is required') def test_grouper_inputs(): xx = np.array([1, 2, 3, 4]) yy = np.array([1, 2]) @@ -43,6 +47,7 @@ def test_grouper_inputs(): grouper(xx, yy) +@pytest.mark.skipif(not HAS_SCIPY, reason='scipy is required') def test_isolated_sources(): """ Test case when all sources are isolated. @@ -57,6 +62,7 @@ def test_isolated_sources(): assert_equal(groups, gg) +@pytest.mark.skipif(not HAS_SCIPY, reason='scipy is required') def test_grouper_one(): """ +---------+--------+---------+---------+--------+---------+ @@ -96,6 +102,7 @@ def test_grouper_one(): assert_equal(groups, gg) +@pytest.mark.skipif(not HAS_SCIPY, reason='scipy is required') def test_grouper_two(): """ +--------------+--------------+-------------+--------------+ @@ -130,6 +137,7 @@ def test_grouper_two(): assert_equal(groups, gg) +@pytest.mark.skipif(not HAS_SCIPY, reason='scipy is required') def test_grouper_three(): """ 1 +--+-------+--------+--------+--------+-------+--------+--+ @@ -164,6 +172,7 @@ def test_grouper_three(): assert_equal(groups, gg) +@pytest.mark.skipif(not HAS_SCIPY, reason='scipy is required') def test_grouper_four(): """ +-+---------+---------+---------+---------+-+ @@ -197,6 +206,7 @@ def test_grouper_four(): assert_equal(groups, gg) +@pytest.mark.skipif(not HAS_SCIPY, reason='scipy is required') def test_grouper_five(): """ +--+--------+--------+-------+--------+--------+--------+--+ @@ -241,6 +251,7 @@ def test_grouper_five(): assert_equal(groups, gg) +@pytest.mark.skipif(not HAS_SCIPY, reason='scipy is required') def test_grouper_six(): """ +------+----------+----------+----------+----------+------+