diff --git a/capellambse/diagram/_diagram.py b/capellambse/diagram/_diagram.py index 8f516068..fed8382a 100644 --- a/capellambse/diagram/_diagram.py +++ b/capellambse/diagram/_diagram.py @@ -16,6 +16,7 @@ ] import collections.abc as cabc +import contextlib import enum import logging import math @@ -272,18 +273,20 @@ def vector_snap( source = diagram.Vector2D(*source) if style is RoutingStyle.OBLIQUE: - return self.__vector_snap_oblique(point) + if point == source: + return self.__vector_snap_closest(point) + return self.__vector_snap_oblique(point, source) if style is RoutingStyle.MANHATTAN: return self.__vector_snap_manhattan(point, point - source) if style is RoutingStyle.TREE: return self.__vector_snap_tree(point, point - source) raise ValueError(f"Unsupported routing style: {style}") - def __vector_snap_oblique( + def __vector_snap_closest( self, source: diagram.Vector2D ) -> diagram.Vector2D: if source == self.center: - return source + return self.pos + self.size @ (0, 0.5) angle = self.size.angleto(source - self.center) alpha = 2 * self.size.angleto((1, 0)) @@ -310,6 +313,61 @@ def __vector_snap_oblique( (self.pos, self.pos + self.size @ (0, 1)), ) + def __vector_snap_oblique( + self, point: diagram.Vector2D, source: diagram.Vector2D + ) -> diagram.Vector2D: + assert point != source + if not ( + self.pos.x <= point.x <= self.pos.x + self.size.x + and self.pos.y <= point.y <= self.pos.y + self.size.y + ): + point = self.center + direction = point - source + edge = (source, point) + assert direction.x or direction.y, f"{edge} doesn't have a direction" + + edges: set[str] = set() + if direction.x > 0: + edges.add("left") + elif direction.x < 0: + edges.add("right") + + if direction.y > 0: + edges.add("top") + elif direction.y < 0: + edges.add("bottom") + assert len(edges) in (1, 2), f"{edge} doesn't have a direction" + + intersections: list[diagram.Vector2D] = [] + if "top" in edges: + border = (self.pos, self.pos + self.size @ (1, 0)) + with contextlib.suppress(ValueError): + intersection = diagram.line_intersect(border, edge) + if border[0].x <= intersection.x <= border[1].x: + intersections.append(intersection) + if "left" in edges: + border = (self.pos, self.pos + self.size @ (0, 1)) + with contextlib.suppress(ValueError): + intersection = diagram.line_intersect(border, edge) + if border[0].y <= intersection.y <= border[1].y: + intersections.append(intersection) + if "right" in edges: + border = (self.pos + self.size @ (1, 0), self.pos + self.size) + with contextlib.suppress(ValueError): + intersection = diagram.line_intersect(border, edge) + if border[0].y <= intersection.y <= border[1].y: + intersections.append(intersection) + if "bottom" in edges: + border = (self.pos + self.size @ (0, 1), self.pos + self.size) + with contextlib.suppress(ValueError): + intersection = diagram.line_intersect(border, edge) + if border[0].x <= intersection.x <= border[1].x: + intersections.append(intersection) + + assert len(intersections) > 0, f"{edge} doesn't intersect {edges}" + assert len(intersections) < 2, f"{edge} intersects multiple {edges}" + return intersections[0] + def __vector_snap_manhattan( self, point: diagram.Vector2D, direction: diagram.Vector2D ) -> diagram.Vector2D: