Skip to content

Commit

Permalink
Clean nodes after building the graph (#23)
Browse files Browse the repository at this point in the history
* feat: clean graph of unneeded attributes after creation

Co-authored-by: Mario Santa Cruz <[email protected]>
Co-authored-by: Helen Theissen <[email protected]>
Co-authored-by: Jesper Dramsch <[email protected]>
  • Loading branch information
4 people authored Jul 26, 2024
1 parent d510eb2 commit 7fd9c74
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 124 deletions.
17 changes: 17 additions & 0 deletions src/anemoi/graphs/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@ def generate_graph(self) -> HeteroData:

return graph

def clean(self, graph: HeteroData) -> HeteroData:
    """Remove the hidden attributes of the nodes and edges.

    An attribute is considered hidden when its name starts with an
    underscore. Hidden attributes are deleted from every node type and
    every edge type of the graph.

    Parameters
    ----------
    graph : HeteroData
        Graph whose node and edge stores are cleaned. It is modified in place.

    Returns
    -------
    HeteroData
        The same graph object that was passed in, without hidden attributes.
    """
    # Node types are processed first, then edge types.
    for store_key in (*graph.node_types, *graph.edge_types):
        store = graph[store_key]
        # Snapshot the attribute names before deleting: removing entries while
        # iterating a live keys view can fail with "changed size during iteration".
        for attr_name in list(store.keys()):
            if attr_name.startswith("_"):
                del store[attr_name]

    return graph

def save(self, graph: HeteroData) -> None:
"""Save the graph to the output path."""
if not os.path.exists(self.path) or self.overwrite:
Expand All @@ -69,6 +85,7 @@ def create(self) -> HeteroData:
"""Create the graph and save it to the output path."""
self.init()
graph = self.generate_graph()
graph = self.clean(graph)
self.save(graph)
return graph

Expand Down
18 changes: 9 additions & 9 deletions src/anemoi/graphs/edges/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,26 +276,26 @@ def __init__(self, source_name: str, target_name: str, x_hops: int):
self.x_hops = x_hops

def adjacency_from_tri_nodes(self, source_nodes: NodeStorage):
source_nodes["nx_graph"] = icosahedral.add_edges_to_nx_graph(
source_nodes["nx_graph"],
resolutions=source_nodes["resolutions"],
source_nodes["_nx_graph"] = icosahedral.add_edges_to_nx_graph(
source_nodes["_nx_graph"],
resolutions=source_nodes["_resolutions"],
x_hops=self.x_hops,
) # HeteroData refuses to accept None

adjmat = nx.to_scipy_sparse_array(
source_nodes["nx_graph"], nodelist=list(range(len(source_nodes["nx_graph"]))), format="coo"
source_nodes["_nx_graph"], nodelist=list(range(len(source_nodes["_nx_graph"]))), format="coo"
)
return adjmat

def adjacency_from_hex_nodes(self, source_nodes: NodeStorage):

source_nodes["nx_graph"] = hexagonal.add_edges_to_nx_graph(
source_nodes["nx_graph"],
resolutions=source_nodes["resolutions"],
source_nodes["_nx_graph"] = hexagonal.add_edges_to_nx_graph(
source_nodes["_nx_graph"],
resolutions=source_nodes["_resolutions"],
x_hops=self.x_hops,
)

adjmat = nx.to_scipy_sparse_array(source_nodes["nx_graph"], format="coo")
adjmat = nx.to_scipy_sparse_array(source_nodes["_nx_graph"], format="coo")
return adjmat

def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage):
Expand All @@ -311,7 +311,7 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor
return adjmat

def post_process_adjmat(self, nodes: NodeStorage, adjmat):
graph_sorted = {node_pos: i for i, node_pos in enumerate(nodes["node_ordering"])}
graph_sorted = {node_pos: i for i, node_pos in enumerate(nodes["_node_ordering"])}
sort_func = np.vectorize(graph_sorted.get)
adjmat.row = sort_func(adjmat.row)
adjmat.col = sort_func(adjmat.col)
Expand Down
6 changes: 3 additions & 3 deletions src/anemoi/graphs/nodes/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,9 @@ def get_coordinates(self) -> torch.Tensor:
def create_nodes(self) -> np.ndarray: ...

def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData:
graph[self.name]["resolutions"] = self.resolutions
graph[self.name]["nx_graph"] = self.nx_graph
graph[self.name]["node_ordering"] = self.node_ordering
graph[self.name]["_resolutions"] = self.resolutions
graph[self.name]["_nx_graph"] = self.nx_graph
graph[self.name]["_node_ordering"] = self.node_ordering
return super().register_attributes(graph, config)


Expand Down
64 changes: 0 additions & 64 deletions tests/edges/test_icosahedral_edges.py

This file was deleted.

4 changes: 2 additions & 2 deletions tests/edges/test_multiscale_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from anemoi.graphs.nodes import TriNodes


class TestIcosahedralEdgesInit:
class TestMultiScaleEdgesInit:
def test_init(self):
"""Test MultiScaleEdges initialization."""
assert isinstance(MultiScaleEdges("test_nodes", "test_nodes", 1), MultiScaleEdges)
Expand All @@ -23,7 +23,7 @@ def test_fail_init_diff_nodes(self):
MultiScaleEdges("test_nodes", "test_nodes2", 0)


class TestIcosahedralEdgesTransform:
class TestMultiScaleEdgesTransform:

@pytest.fixture()
def tri_ico_graph(self) -> HeteroData:
Expand Down
8 changes: 4 additions & 4 deletions tests/nodes/test_hex_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_update_graph():
node_builder = HexNodes(0, "test_nodes")
graph = HeteroData()
graph = node_builder.update_graph(graph, {})
assert "resolutions" in graph["test_nodes"]
assert "nx_graph" in graph["test_nodes"]
assert "node_ordering" in graph["test_nodes"]
assert len(graph["test_nodes"]["node_ordering"]) == graph["test_nodes"].num_nodes
assert "_resolutions" in graph["test_nodes"]
assert "_nx_graph" in graph["test_nodes"]
assert "_node_ordering" in graph["test_nodes"]
assert len(graph["test_nodes"]["_node_ordering"]) == graph["test_nodes"].num_nodes
8 changes: 4 additions & 4 deletions tests/nodes/test_tri_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_update_graph():
node_builder = TriNodes(1, "test_nodes")
graph = HeteroData()
graph = node_builder.update_graph(graph, {})
assert "resolutions" in graph["test_nodes"]
assert "nx_graph" in graph["test_nodes"]
assert "node_ordering" in graph["test_nodes"]
assert len(graph["test_nodes"]["node_ordering"]) == graph["test_nodes"].num_nodes
assert "_resolutions" in graph["test_nodes"]
assert "_nx_graph" in graph["test_nodes"]
assert "_node_ordering" in graph["test_nodes"]
assert len(graph["test_nodes"]["_node_ordering"]) == graph["test_nodes"].num_nodes
47 changes: 47 additions & 0 deletions tests/test_create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


from pathlib import Path

import torch
from torch_geometric.data import HeteroData

from anemoi.graphs.create import GraphCreator


class TestGraphCreator:
    """Checks for the ``GraphCreator`` create-and-save workflow."""

    def test_generate_graph(self, config_file: tuple[Path, str], mock_grids_path: tuple[str, int]):
        """Test GraphCreator workflow.

        Creates a graph from the configuration fixture, reloads it from disk
        and checks that:

        * the loaded object is a ``HeteroData`` with the expected node and
          edge types,
        * every node and edge attribute is a tensor of dtype int32 or float32,
        * no node or edge attribute name starts with an underscore.

        ``mock_grids_path`` is not used in the body; presumably it is requested
        only for the fixture's side effects -- confirm against conftest.
        """
        tmp_path, config_name = config_file
        graph_path = tmp_path / "graph.pt"
        config_path = tmp_path / config_name

        GraphCreator(graph_path, config_path).create()

        # The file was written by this very test, so it is trusted. A full
        # unpickle is requested explicitly because HeteroData is an arbitrary
        # Python object, which the restricted weights-only loader rejects.
        graph = torch.load(graph_path, weights_only=False)
        assert isinstance(graph, HeteroData)
        assert "test_nodes" in graph.node_types
        assert ("test_nodes", "to", "test_nodes") in graph.edge_types

        allowed_dtypes = [torch.int32, torch.float32]

        for nodes in graph.node_stores:
            for node_attr in nodes.node_attrs():
                assert not node_attr.startswith("_")
                assert isinstance(nodes[node_attr], torch.Tensor)
                assert nodes[node_attr].dtype in allowed_dtypes

        for edges in graph.edge_stores:
            for edge_attr in edges.edge_attrs():
                assert not edge_attr.startswith("_")
                assert isinstance(edges[edge_attr], torch.Tensor)
                assert edges[edge_attr].dtype in allowed_dtypes
38 changes: 0 additions & 38 deletions tests/test_graphs.py

This file was deleted.

0 comments on commit 7fd9c74

Please sign in to comment.