Skip to content

Commit

Permalink
Fix (global version) (#12)
Browse files Browse the repository at this point in the history
* Global Encoder-Processor-Decoder graph (#9)

* feat: Initial implementation of global graphs + fixes

Co-authored-by: Mario Santa Cruz <[email protected]>
Co-authored-by: Helen Theissen <[email protected]>
Co-authored-by: Sara Hahner <[email protected]>
Co-authored-by: Jesper Dramsch <[email protected]>

* fix: attributes as torch.float32

* new test: attributes must be float32

* fix typo

* Homogeneize base builders

* improve test docstrings

* homogeneize (name as class attribute)

* new input config

* new default

* remove dataclass from attribute classes
  • Loading branch information
JPXKQX authored and JesperDramsch committed Jul 8, 2024
1 parent 9231b56 commit 7bf0d23
Show file tree
Hide file tree
Showing 12 changed files with 194 additions and 114 deletions.
9 changes: 6 additions & 3 deletions src/anemoi/graphs/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,14 @@ def generate_graph(self) -> HeteroData:
HeteroData: The generated graph.
"""
graph = HeteroData()
for name, nodes_cfg in self.config.nodes.items():
graph = instantiate(nodes_cfg.node_builder).update_graph(graph, name, nodes_cfg.get("attributes", {}))

for nodes_cfg in self.config.nodes:
graph = instantiate(nodes_cfg.node_builder, name=nodes_cfg.name).update_graph(
graph, nodes_cfg.get("attributes", {})
)

for edges_cfg in self.config.edges:
graph = instantiate(edges_cfg.edge_builder, **edges_cfg.nodes).update_graph(
graph = instantiate(edges_cfg.edge_builder, edges_cfg.source_name, edges_cfg.target_name).update_graph(
graph, edges_cfg.get("attributes", {})
)

Expand Down
29 changes: 15 additions & 14 deletions src/anemoi/graphs/edges/attributes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from typing import Optional

import numpy as np
Expand All @@ -15,11 +14,11 @@
LOGGER = logging.getLogger(__name__)


@dataclass
class BaseEdgeAttribute(ABC, NormalizerMixin):
"""Base class for edge attributes."""

norm: Optional[str] = None
def __init__(self, norm: Optional[str] = None) -> None:
self.norm = norm

@abstractmethod
def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> np.ndarray: ...
Expand All @@ -29,10 +28,13 @@ def post_process(self, values: np.ndarray) -> torch.Tensor:
if values.ndim == 1:
values = values[:, np.newaxis]

return torch.tensor(values)
normed_values = self.normalize(values)

return torch.tensor(normed_values, dtype=torch.float32)

def compute(self, graph: HeteroData, source_name: str, target_name: str, *args, **kwargs) -> torch.Tensor:
def compute(self, graph: HeteroData, edges_name: tuple[str, str, str], *args, **kwargs) -> torch.Tensor:
"""Compute the edge attributes."""
source_name, _, target_name = edges_name
assert (
source_name in graph.node_types
), f"Node \"{source_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}."
Expand All @@ -41,13 +43,11 @@ def compute(self, graph: HeteroData, source_name: str, target_name: str, *args,
), f"Node \"{target_name}\" not found in graph. Optional nodes are {', '.join(graph.node_types)}."

values = self.get_raw_values(graph, source_name, target_name, *args, **kwargs)
normed_values = self.normalize(values)
return self.post_process(normed_values)
return self.post_process(values)


@dataclass
class EdgeDirection(BaseEdgeAttribute):
"""Compute directional features for edges.
"""Edge direction feature.
If using the rotated features, the direction of the edge is computed
rotating the target nodes to the north pole. If not, it is computed
Expand All @@ -69,8 +69,9 @@ class EdgeDirection(BaseEdgeAttribute):
Compute directional attributes.
"""

norm: Optional[str] = None
luse_rotated_features: bool = True
def __init__(self, norm: Optional[str] = None, luse_rotated_features: bool = True) -> None:
super().__init__(norm)
self.luse_rotated_features = luse_rotated_features

def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) -> np.ndarray:
"""Compute directional features for edges.
Expand All @@ -96,7 +97,6 @@ def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str)
return edge_dirs


@dataclass
class EdgeLength(BaseEdgeAttribute):
"""Edge length feature.
Expand All @@ -115,8 +115,9 @@ class EdgeLength(BaseEdgeAttribute):
Compute edge lengths attributes.
"""

norm: str = "l1"
invert: bool = True
def __init__(self, norm: Optional[str] = None, invert: bool = False) -> None:
super().__init__(norm)
self.invert = invert

def get_raw_values(self, graph: HeteroData, source_name: str, target_name: str) -> np.ndarray:
"""Compute haversine distance (in kilometers) between nodes connected by edges.
Expand Down
81 changes: 61 additions & 20 deletions src/anemoi/graphs/edges/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,61 @@


class BaseEdgeBuilder(ABC):
"""Base class for edge builders."""

def __init__(self, source_name: str, target_name: str):
super().__init__()
self.source_name = source_name
self.target_name = target_name

@property
def name(self) -> tuple[str, str, str]:
"""Name of the edge subgraph."""
return self.source_name, "to", self.target_name

@abstractmethod
def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage): ...

def register_edges(self, graph: HeteroData, source_indices: np.ndarray, target_indices: np.ndarray) -> HeteroData:
def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage]:
"""Prepare nodes information."""
return graph[self.source_name], graph[self.target_name]

def get_edge_index(self, graph: HeteroData) -> torch.Tensor:
"""Get the edge indices of source and target nodes.
Parameters
----------
graph : HeteroData
The graph.
Returns
-------
torch.Tensor of shape (2, num_edges)
The edge indices.
"""
source_nodes, target_nodes = self.prepare_node_data(graph)

adjmat = self.get_adjacency_matrix(source_nodes, target_nodes)

# Get source & target indices of the edges
edge_index = np.stack([adjmat.col, adjmat.row], axis=0)

return torch.from_numpy(edge_index).to(torch.int32)

def register_edges(self, graph: HeteroData) -> HeteroData:
"""Register edges in the graph.
Parameters
----------
graph : HeteroData
The graph to register the edges.
source_indices : np.ndarray of shape (N, )
The indices of the source nodes.
target_indices : np.ndarray of shape (N, )
The indices of the target nodes.
Returns
-------
HeteroData
The graph with the registered edges.
"""
edge_index = np.stack([source_indices, target_indices], axis=0).astype(np.int32)
graph[(self.source_name, "to", self.target_name)].edge_index = torch.from_numpy(edge_index)
graph[self.name].edge_index = self.get_edge_index(graph)
graph[self.name].edge_type = type(self).__name__
return graph

def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData:
Expand All @@ -64,15 +91,9 @@ def register_attributes(self, graph: HeteroData, config: DotDict) -> HeteroData:
The graph with the registered attributes.
"""
for attr_name, attr_config in config.items():
graph[self.source_name, "to", self.target_name][attr_name] = instantiate(attr_config).compute(
graph, self.source_name, self.target_name
)
graph[self.name][attr_name] = instantiate(attr_config).compute(graph, self.name)
return graph

def prepare_node_data(self, graph: HeteroData) -> tuple[NodeStorage, NodeStorage]:
"""Prepare nodes information."""
return graph[self.source_name], graph[self.target_name]

def update_graph(self, graph: HeteroData, attrs_config: Optional[DotDict] = None) -> HeteroData:
"""Update the graph with the edges.
Expand All @@ -88,11 +109,7 @@ def update_graph(self, graph: HeteroData, attrs_config: Optional[DotDict] = None
HeteroData
The graph with the edges.
"""
source_nodes, target_nodes = self.prepare_node_data(graph)

adjmat = self.get_adjacency_matrix(source_nodes, target_nodes)

graph = self.register_edges(graph, adjmat.col, adjmat.row)
graph = self.register_edges(graph)

if attrs_config is None:
return graph
Expand All @@ -113,6 +130,17 @@ class KNNEdges(BaseEdgeBuilder):
The name of the target nodes.
num_nearest_neighbours : int
Number of nearest neighbours.
Methods
-------
get_adjacency_matrix(source_nodes, target_nodes)
Compute the adjacency matrix for the KNN method.
register_edges(graph)
Register the edges in the graph.
register_attributes(graph, config)
Register attributes in the edges of the graph.
update_graph(graph, attrs_config)
Update the graph with the edges.
"""

def __init__(self, source_name: str, target_name: str, num_nearest_neighbours: int):
Expand Down Expand Up @@ -162,6 +190,19 @@ class CutOffEdges(BaseEdgeBuilder):
Factor to multiply the grid reference distance to get the cut-off radius.
radius : float
Cut-off radius.
Methods
-------
get_cutoff_radius(graph, mask_attr)
Compute the cut-off radius.
get_adjacency_matrix(source_nodes, target_nodes)
Get the adjacency matrix for the cut-off method.
register_edges(graph)
Register the edges in the graph.
register_attributes(graph, config)
Register attributes in the edges of the graph.
update_graph(graph, attrs_config)
Update the graph with the edges.
"""

def __init__(self, source_name: str, target_name: str, cutoff_factor: float):
Expand Down
33 changes: 22 additions & 11 deletions src/anemoi/graphs/nodes/attributes.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import logging
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
from scipy.spatial import SphericalVoronoi
from torch_geometric.data import HeteroData
from torch_geometric.data.storage import NodeStorage

from anemoi.graphs.generate.transforms import latlon_rad_to_cartesian
Expand All @@ -15,11 +15,11 @@
LOGGER = logging.getLogger(__name__)


@dataclass
class BaseWeights(ABC, NormalizerMixin):
"""Base class for the weights of the nodes."""

norm: Optional[str] = None
def __init__(self, norm: Optional[str] = None) -> None:
self.norm = norm

@abstractmethod
def get_raw_values(self, nodes: NodeStorage, *args, **kwargs): ...
Expand All @@ -29,19 +29,28 @@ def post_process(self, values: np.ndarray) -> torch.Tensor:
if values.ndim == 1:
values = values[:, np.newaxis]

return torch.tensor(values)
norm_values = self.normalize(values)

def compute(self, nodes: NodeStorage, *args, **kwargs) -> torch.Tensor:
return torch.tensor(norm_values, dtype=torch.float32)

def compute(self, graph: HeteroData, nodes_name: str, *args, **kwargs) -> torch.Tensor:
"""Get the node weights.
Parameters
----------
graph : HeteroData
Graph.
nodes_name : str
Name of the nodes.
Returns
-------
torch.Tensor
Weights associated to the nodes.
"""
nodes = graph[nodes_name]
weights = self.get_raw_values(nodes, *args, **kwargs)
norm_weights = self.normalize(weights)
return self.post_process(norm_weights)
return self.post_process(weights)


class UniformWeights(BaseWeights):
Expand All @@ -63,7 +72,6 @@ def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray:
return np.ones(nodes.num_nodes)


@dataclass
class AreaWeights(BaseWeights):
"""Implements the area of the nodes as the weights.
Expand All @@ -84,9 +92,12 @@ class AreaWeights(BaseWeights):
Compute the area attributes for each node.
"""

norm: Optional[str] = "unit-max"
radius: float = 1.0
centre: np.ndarray = np.array([0, 0, 0])
def __init__(
self, norm: Optional[str] = None, radius: float = 1.0, centre: np.ndarray = np.array([0, 0, 0])
) -> None:
super().__init__(norm)
self.radius = radius
self.centre = centre

def get_raw_values(self, nodes: NodeStorage, *args, **kwargs) -> np.ndarray:
"""Compute the area associated to each node.
Expand Down
Loading

0 comments on commit 7bf0d23

Please sign in to comment.