Skip to content

Commit

Permalink
Merge branch 'master' into dp/dommaschk-NFP
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanici authored Mar 28, 2024
2 parents 0004c9b + a26926f commit 7ee3b98
Show file tree
Hide file tree
Showing 13 changed files with 1,031 additions and 193 deletions.
21 changes: 20 additions & 1 deletion desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,14 @@
from jax.experimental.ode import odeint
from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular
from jax.scipy.special import gammaln, logsumexp
from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten
from jax.tree_util import (
register_pytree_node,
tree_flatten,
tree_leaves,
tree_map,
tree_structure,
tree_unflatten,
)

def put(arr, inds, vals):
"""Functional interface for array "fancy indexing".
Expand Down Expand Up @@ -393,6 +400,18 @@ def tree_unflatten(*args, **kwargs):
"""Unflatten pytree for numpy backend."""
raise NotImplementedError

def tree_map(*args, **kwargs):
"""Map pytree for numpy backend."""
raise NotImplementedError

def tree_structure(*args, **kwargs):
"""Get structure of pytree for numpy backend."""
raise NotImplementedError

def tree_leaves(*args, **kwargs):
"""Get leaves of pytree for numpy backend."""
raise NotImplementedError

def register_pytree_node(foo, *args):
"""Dummy decorator for non-jax pytrees."""
return foo
Expand Down
4 changes: 2 additions & 2 deletions desc/coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class _Coil(_MagneticField, Optimizable, ABC):
_io_attrs_ = _MagneticField._io_attrs_ + ["_current"]

def __init__(self, current, *args, **kwargs):
self._current = float(current)
self._current = float(np.squeeze(current))
super().__init__(*args, **kwargs)

@optimizable_parameter
Expand All @@ -145,7 +145,7 @@ def current(self):
@current.setter
def current(self, new):
assert jnp.isscalar(new) or new.size == 1
self._current = float(new)
self._current = float(np.squeeze(new))

def compute_magnetic_field(
self, coords, params=None, basis="rpz", source_grid=None
Expand Down
1 change: 1 addition & 0 deletions desc/objectives/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Classes defining objectives for equilibrium and optimization."""

from ._bootstrap import BootstrapRedlConsistency
from ._coils import CoilCurvature, CoilLength, CoilTorsion
from ._equilibrium import (
CurrentDensity,
Energy,
Expand Down
Loading

0 comments on commit 7ee3b98

Please sign in to comment.