diff --git a/desc/backend.py b/desc/backend.py index 1433abe73e..721920190c 100644 --- a/desc/backend.py +++ b/desc/backend.py @@ -77,7 +77,14 @@ from jax.experimental.ode import odeint from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular from jax.scipy.special import gammaln, logsumexp - from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten + from jax.tree_util import ( + register_pytree_node, + tree_flatten, + tree_leaves, + tree_map, + tree_structure, + tree_unflatten, + ) def put(arr, inds, vals): """Functional interface for array "fancy indexing". @@ -393,6 +400,18 @@ def tree_unflatten(*args, **kwargs): """Unflatten pytree for numpy backend.""" raise NotImplementedError + def tree_map(*args, **kwargs): + """Map pytree for numpy backend.""" + raise NotImplementedError + + def tree_structure(*args, **kwargs): + """Get structure of pytree for numpy backend.""" + raise NotImplementedError + + def tree_leaves(*args, **kwargs): + """Get leaves of pytree for numpy backend.""" + raise NotImplementedError + def register_pytree_node(foo, *args): """Dummy decorator for non-jax pytrees.""" return foo diff --git a/desc/coils.py b/desc/coils.py index b81a7ac24f..6430a8b458 100644 --- a/desc/coils.py +++ b/desc/coils.py @@ -133,7 +133,7 @@ class _Coil(_MagneticField, Optimizable, ABC): _io_attrs_ = _MagneticField._io_attrs_ + ["_current"] def __init__(self, current, *args, **kwargs): - self._current = float(current) + self._current = float(np.squeeze(current)) super().__init__(*args, **kwargs) @optimizable_parameter @@ -145,7 +145,7 @@ def current(self): @current.setter def current(self, new): assert jnp.isscalar(new) or new.size == 1 - self._current = float(new) + self._current = float(np.squeeze(new)) def compute_magnetic_field( self, coords, params=None, basis="rpz", source_grid=None diff --git a/desc/objectives/__init__.py b/desc/objectives/__init__.py index 25d51ce089..7c6c95a68d 100644 --- a/desc/objectives/__init__.py +++ b/desc/objectives/__init__.py @@ -1,6 +1,7 @@ """Classes defining objectives for equilibrium and optimization.""" from ._bootstrap import BootstrapRedlConsistency +from ._coils import CoilCurvature, CoilLength, CoilTorsion from ._equilibrium import ( CurrentDensity, Energy, diff --git a/desc/objectives/_coils.py b/desc/objectives/_coils.py new file mode 100644 index 0000000000..38cb32ad29 --- /dev/null +++ b/desc/objectives/_coils.py @@ -0,0 +1,599 @@ +import numbers + +import numpy as np + +from desc.backend import ( + jnp, + tree_flatten, + tree_leaves, + tree_map, + tree_structure, + tree_unflatten, +) +from desc.compute import get_transforms +from desc.grid import LinearGrid, _Grid +from desc.utils import Timer, errorif + +from .normalization import compute_scaling_factors +from .objective_funs import _Objective + + +class _CoilObjective(_Objective): + """Base class for calculating coil objectives. + + Parameters + ---------- + coil : CoilSet or Coil + Coil for which the data keys will be optimized. + data_keys : list of str + data keys that will be optimized when this class is inherited. + target : float, ndarray, optional + Target value(s) of the objective. Only used if bounds is None. + Must be broadcastable to Objective.dim_f. + bounds : tuple of float, ndarray, optional + Lower and upper bounds on the objective. Overrides target. + Both bounds must be broadcastable to to Objective.dim_f + weight : float, ndarray, optional + Weighting to apply to the Objective, relative to other Objectives. + Must be broadcastable to to Objective.dim_f + normalize : bool, optional + Whether to compute the error in physical units or non-dimensionalize. + normalize_target : bool, optional + Whether target and bounds should be normalized before comparing to computed + values. If `normalize` is `True` and the target is in physical units, + this should also be set to True. + be set to True. + loss_function : {None, 'mean', 'min', 'max'}, optional + Loss function to apply to the objective values once computed. This loss function + is called on the raw compute value, before any shifting, scaling, or + normalization. Operates over all coils, not each individial coil. + deriv_mode : {"auto", "fwd", "rev"} + Specify how to compute jacobian matrix, either forward mode or reverse mode AD. + "auto" selects forward or reverse mode based on the size of the input and output + of the objective. Has no effect on self.grad or self.hess which always use + reverse mode and forward over reverse mode respectively. + grid : Grid, list, optional + Collocation grid containing the nodes to evaluate at. If list, has to adhere to + Objective.dim_f + name : str, optional + Name of the objective function. + + """ + + def __init__( + self, + coil, + data_keys, + target=None, + bounds=None, + weight=1, + normalize=True, + normalize_target=True, + loss_function=None, + deriv_mode="auto", + grid=None, + name=None, + ): + self._grid = grid + self._data_keys = data_keys + self._normalize = normalize + super().__init__( + things=[coil], + target=target, + bounds=bounds, + weight=weight, + normalize=normalize, + normalize_target=normalize_target, + loss_function=loss_function, + deriv_mode=deriv_mode, + name=name, + ) + + def build(self, use_jit=True, verbose=1): # noqa:C901 + """Build constant arrays. + + Parameters + ---------- + use_jit : bool, optional + Whether to just-in-time compile the objective and derivatives. + verbose : int, optional + Level of output. + + """ + # local import to avoid circular import + from desc.coils import CoilSet, MixedCoilSet, _Coil + + self._dim_f = 0 + self._quad_weights = jnp.array([]) + + def get_dim_f_and_weights(coilset): + """Get dim_f and quad_weights from grid.""" + if isinstance(coilset, list): + [get_dim_f_and_weights(x) for x in coilset] + elif isinstance(coilset, MixedCoilSet): + [get_dim_f_and_weights(x) for x in coilset] + elif isinstance(coilset, CoilSet): + get_dim_f_and_weights(coilset.coils) + elif isinstance(coilset, _Grid): + self._dim_f += coilset.num_zeta + self._quad_weights = jnp.concatenate( + (self._quad_weights, coilset.spacing[:, 2]) + ) + + def to_list(coilset): + """Turn a MixedCoilSet container into a list of what it's containing.""" + if isinstance(coilset, list): + return [to_list(x) for x in coilset] + elif isinstance(coilset, MixedCoilSet): + return [to_list(x) for x in coilset] + elif isinstance(coilset, CoilSet): + # use the same grid/transform for CoilSet + return to_list(coilset.coils[0]) + else: + return [coilset] + + is_single_coil = lambda x: isinstance(x, _Coil) and not isinstance(x, CoilSet) + # gives structure of coils, e.g. MixedCoilSet(coils, coils) would give a + # a structure of [[*, *], [*, *]] if n = 2 coils + coil_structure = tree_structure( + self.things[0], + is_leaf=lambda x: is_single_coil(x), + ) + coil_leaves = tree_leaves(self.things[0], is_leaf=lambda x: is_single_coil(x)) + + # check type + if isinstance(self._grid, numbers.Integral): + self._grid = LinearGrid(N=self._grid, endpoint=False) + # all of these cases return a container MixedCoilSet that contains + # LinearGrids. i.e. MixedCoilSet.coils = list of LinearGrid + if self._grid is None: + # map default grid to structure of inputted coils + self._grid = tree_map( + lambda x: LinearGrid( + N=2 * x.N + 5, NFP=getattr(x, "NFP", 1), endpoint=False + ), + self.things[0], + is_leaf=lambda x: is_single_coil(x), + ) + elif isinstance(self._grid, _Grid): + # map inputted single LinearGrid to structure of inputted coils + self._grid = [self._grid] * len(coil_leaves) + self._grid = tree_unflatten(coil_structure, self._grid) + else: + # this case covers an inputted list of grids that matches the size + # of the inputted coils. Can be a 1D list or nested list. + flattened_grid = tree_flatten( + self._grid, is_leaf=lambda x: isinstance(x, _Grid) + )[0] + self._grid = tree_unflatten(coil_structure, flattened_grid) + + timer = Timer() + if verbose > 0: + print("Precomputing transforms") + timer.start("Precomputing transforms") + + transforms = tree_map( + lambda x, y: get_transforms(self._data_keys, obj=x, grid=y), + self.things[0], + self._grid, + is_leaf=lambda x: is_single_coil(x), + ) + + get_dim_f_and_weights(self._grid) + # get only needed grids (1 per CoilSet) and flatten that list + self._grid = tree_leaves( + to_list(self._grid), is_leaf=lambda x: isinstance(x, _Grid) + ) + transforms = tree_leaves( + to_list(transforms), is_leaf=lambda x: isinstance(x, dict) + ) + + errorif( + np.any([grid.num_rho > 1 or grid.num_theta > 1 for grid in self._grid]), + ValueError, + "Only use toroidal resolution for coil grids.", + ) + + # CoilSet and _Coil have one grid/transform + if not isinstance(self.things[0], MixedCoilSet): + self._grid = self._grid[0] + transforms = transforms[0] + + self._constants = { + "transforms": transforms, + "quad_weights": self._quad_weights, + } + + timer.stop("Precomputing transforms") + if verbose > 1: + timer.disp("Precomputing transforms") + + if self._normalize: + self._scales = compute_scaling_factors(coil_leaves[0]) + + super().build(use_jit=use_jit, verbose=verbose) + + def compute(self, params, constants=None): + """Compute data of coil for given data key. + + Parameters + ---------- + params : dict + Dictionary of the coil's degrees of freedom. + constants : dict + Dictionary of constant data, eg transforms, profiles etc. Defaults to + self._constants. + + Returns + ------- + f : float or array of floats + Coil length. + """ + if constants is None: + constants = self._constants + + coils = self.things[0] + data = coils.compute( + self._data_keys, + params=params, + transforms=constants["transforms"], + grid=self._grid, + ) + + return data + + +class CoilLength(_CoilObjective): + """Coil length. + + Parameters + ---------- + coil : CoilSet or Coil + Coil(s) that are to be optimized + target : float, ndarray, optional + Target value(s) of the objective. Only used if bounds is None. + Must be broadcastable to Objective.dim_f. If array, it has to + be flattened according to the number of inputs. + bounds : tuple of float, ndarray, optional + Lower and upper bounds on the objective. Overrides target. + Both bounds must be broadcastable to to Objective.dim_f + weight : float, ndarray, optional + Weighting to apply to the Objective, relative to other Objectives. + Must be broadcastable to to Objective.dim_f + normalize : bool, optional + Whether to compute the error in physical units or non-dimensionalize. + normalize_target : bool, optional + Whether target and bounds should be normalized before comparing to computed + values. If `normalize` is `True` and the target is in physical units, + this should also be set to True. + be set to True. + loss_function : {None, 'mean', 'min', 'max'}, optional + Loss function to apply to the objective values once computed. This loss function + is called on the raw compute value, before any shifting, scaling, or + normalization. Operates over all coils, not each individial coil. + deriv_mode : {"auto", "fwd", "rev"} + Specify how to compute jacobian matrix, either forward mode or reverse mode AD. + "auto" selects forward or reverse mode based on the size of the input and output + of the objective. Has no effect on self.grad or self.hess which always use + reverse mode and forward over reverse mode respectively. + grid : Grid, optional + Collocation grid containing the nodes to evaluate at. + name : str, optional + Name of the objective function. + """ + + _scalar = False # Not always a scalar, if a coilset is passed in + _units = "(m)" + _print_value_fmt = "Coil length: {:10.3e} " + + def __init__( + self, + coils, + target=None, + bounds=None, + weight=1, + normalize=True, + normalize_target=True, + loss_function=None, + deriv_mode="auto", + grid=None, + name="coil length", + ): + self._coils = coils + if target is None and bounds is None: + target = 2 * np.pi + + super().__init__( + coils, + ["length"], + target=target, + bounds=bounds, + weight=weight, + normalize=normalize, + normalize_target=normalize_target, + loss_function=loss_function, + deriv_mode=deriv_mode, + grid=grid, + name=name, + ) + + def build(self, use_jit=True, verbose=1): + """Build constant arrays. + + Parameters + ---------- + use_jit : bool, optional + Whether to just-in-time compile the objective and derivatives. + verbose : int, optional + Level of output. + + """ + from desc.coils import CoilSet, _Coil + + super().build(use_jit=use_jit, verbose=verbose) + + if self._normalize: + self._normalization = self._scales["a"] + + # TODO: repeated code but maybe it's fine + flattened_coils = tree_flatten( + self._coils, + is_leaf=lambda x: isinstance(x, _Coil) and not isinstance(x, CoilSet), + )[0] + flattened_coils = ( + [flattened_coils[0]] + if not isinstance(self._coils, CoilSet) + else flattened_coils + ) + self._dim_f = len(flattened_coils) + self._constants["quad_weights"] = 1 + + def compute(self, params, constants=None): + """Compute coil length. + + Parameters + ---------- + params : dict + Dictionary of the coil's degrees of freedom. + constants : dict + Dictionary of constant data, eg transforms, profiles etc. Defaults to + self._constants. + + Returns + ------- + f : float or array of floats + Coil length. + """ + data = super().compute(params, constants=constants) + data = tree_flatten(data, is_leaf=lambda x: isinstance(x, dict))[0] + out = jnp.array([dat["length"] for dat in data]) + return out + + +class CoilCurvature(_CoilObjective): + """Coil curvature. + + Targets the local curvature value per grid node for each coil. A smaller curvature + value indicates straighter coils. All curvature values are positive. + + Parameters + ---------- + coil : CoilSet or Coil + Coil(s) that are to be optimized + target : float, ndarray, optional + Target value(s) of the objective. Only used if bounds is None. + Must be broadcastable to Objective.dim_f. If array, it has to + be flattened according to the number of inputs. + bounds : tuple of float, ndarray, optional + Lower and upper bounds on the objective. Overrides target. + Both bounds must be broadcastable to to Objective.dim_f + weight : float, ndarray, optional + Weighting to apply to the Objective, relative to other Objectives. + Must be broadcastable to to Objective.dim_f + normalize : bool, optional + Whether to compute the error in physical units or non-dimensionalize. + normalize_target : bool, optional + Whether target and bounds should be normalized before comparing to computed + values. If `normalize` is `True` and the target is in physical units, + this should also be set to True. + be set to True. + loss_function : {None, 'mean', 'min', 'max'}, optional + Loss function to apply to the objective values once computed. This loss function + is called on the raw compute value, before any shifting, scaling, or + normalization. Operates over all coils, not each individial coil. + deriv_mode : {"auto", "fwd", "rev"} + Specify how to compute jacobian matrix, either forward mode or reverse mode AD. + "auto" selects forward or reverse mode based on the size of the input and output + of the objective. Has no effect on self.grad or self.hess which always use + reverse mode and forward over reverse mode respectively. + grid : Grid, optional + Collocation grid containing the nodes to evaluate at. + name : str, optional + Name of the objective function. + """ + + _scalar = False + _units = "(m^-1)" + _print_value_fmt = "Coil curvature: {:10.3e} " + + def __init__( + self, + coil, + target=None, + bounds=None, + weight=1, + normalize=True, + normalize_target=True, + loss_function=None, + deriv_mode="auto", + grid=None, + name="coil curvature", + ): + if target is None and bounds is None: + bounds = (0, 1) + + super().__init__( + coil, + ["curvature"], + target=target, + bounds=bounds, + weight=weight, + normalize=normalize, + normalize_target=normalize_target, + loss_function=loss_function, + deriv_mode=deriv_mode, + grid=grid, + name=name, + ) + + def build(self, use_jit=True, verbose=1): + """Build constant arrays. + + Parameters + ---------- + use_jit : bool, optional + Whether to just-in-time compile the objective and derivatives. + verbose : int, optional + Level of output. + + """ + super().build(use_jit=use_jit, verbose=verbose) + + if self._normalize: + self._normalization = 1 / self._scales["a"] + + def compute(self, params, constants=None): + """Compute coil curvature. + + Parameters + ---------- + params : dict + Dictionary of the coil's degrees of freedom. + constants : dict + Dictionary of constant data, eg transforms, profiles etc. Defaults to + self._constants. + + Returns + ------- + f : array of floats + 1D array of coil curvature values. + """ + data = super().compute(params, constants=constants) + data = tree_flatten(data, is_leaf=lambda x: isinstance(x, dict))[0] + out = jnp.concatenate([dat["curvature"] for dat in data]) + return out + + +class CoilTorsion(_CoilObjective): + """Coil torsion. + + Targets the local torsion value per grid node for each coil. Indicative + of how much the coil goes out of the poloidal plane. e.g. a torsion + value of 0 means the coil is completely planar. + + Parameters + ---------- + coil : CoilSet or Coil + Coil(s) that are to be optimized + target : float, ndarray, optional + Target value(s) of the objective. Only used if bounds is None. + Must be broadcastable to Objective.dim_f. If array, it has to + be flattened according to the number of inputs. + bounds : tuple of float, ndarray, optional + Lower and upper bounds on the objective. Overrides target. + Both bounds must be broadcastable to to Objective.dim_f + weight : float, ndarray, optional + Weighting to apply to the Objective, relative to other Objectives. + Must be broadcastable to to Objective.dim_f + normalize : bool, optional + Whether to compute the error in physical units or non-dimensionalize. + normalize_target : bool, optional + Whether target and bounds should be normalized before comparing to computed + values. If `normalize` is `True` and the target is in physical units, + this should also be set to True. + be set to True. + loss_function : {None, 'mean', 'min', 'max'}, optional + Loss function to apply to the objective values once computed. This loss function + is called on the raw compute value, before any shifting, scaling, or + normalization. Operates over all coils, not each individial coil. + deriv_mode : {"auto", "fwd", "rev"} + Specify how to compute jacobian matrix, either forward mode or reverse mode AD. + "auto" selects forward or reverse mode based on the size of the input and output + of the objective. Has no effect on self.grad or self.hess which always use + reverse mode and forward over reverse mode respectively. + grid : Grid, optional + Collocation grid containing the nodes to evaluate at. + name : str, optional + Name of the objective function. + """ + + _scalar = False + _units = "(m^-1)" + _print_value_fmt = "Coil torsion: {:10.3e} " + + def __init__( + self, + coil, + target=None, + bounds=None, + weight=1, + normalize=True, + normalize_target=True, + loss_function=None, + deriv_mode="auto", + grid=None, + name="coil torsion", + ): + if target is None and bounds is None: + target = 0 + + super().__init__( + coil, + ["torsion"], + target=target, + bounds=bounds, + weight=weight, + normalize=normalize, + normalize_target=normalize_target, + loss_function=loss_function, + deriv_mode=deriv_mode, + grid=grid, + name=name, + ) + + def build(self, use_jit=True, verbose=1): + """Build constant arrays. + + Parameters + ---------- + use_jit : bool, optional + Whether to just-in-time compile the objective and derivatives. + verbose : int, optional + Level of output. + + """ + super().build(use_jit=use_jit, verbose=verbose) + + if self._normalize: + self._normalization = 1 / self._scales["a"] + + def compute(self, params, constants=None): + """Compute coil torsion. + + Parameters + ---------- + params : dict + Dictionary of the coil's degrees of freedom. + constants : dict + Dictionary of constant data, eg transforms, profiles etc. Defaults to + self._constants. + + Returns + ------- + f : float or array of floats + Coil torsion. + """ + data = super().compute(params, constants=constants) + data = tree_flatten(data, is_leaf=lambda x: isinstance(x, dict))[0] + out = jnp.concatenate([dat["torsion"] for dat in data]) + return out diff --git a/desc/objectives/normalization.py b/desc/objectives/normalization.py index 8b9743d86a..9ab44cad16 100644 --- a/desc/objectives/normalization.py +++ b/desc/objectives/normalization.py @@ -3,6 +3,8 @@ import numpy as np from scipy.constants import elementary_charge, mu_0 +from desc.geometry import Curve + def compute_scaling_factors(thing): """Compute dimensional quantities for normalizations.""" @@ -65,6 +67,8 @@ def get_lowest_mode(basis, coeffs): scales["A"] = np.pi * scales["a"] ** 2 scales["V"] = 2 * np.pi * scales["R0"] * scales["A"] + elif isinstance(thing, Curve): + scales["a"] = thing.compute("length")["length"] / (2 * np.pi) # replace 0 scales to avoid normalizing by zero for scale in scales.keys(): if np.isclose(scales[scale], 0): diff --git a/desc/objectives/objective_funs.py b/desc/objectives/objective_funs.py index 62e9eeea44..48ee1ef0f2 100644 --- a/desc/objectives/objective_funs.py +++ b/desc/objectives/objective_funs.py @@ -68,9 +68,13 @@ def _set_derivatives(self): self._hess = Derivative(self.compute_scalar, mode="hess") if self._deriv_mode == "batched": self._jac_scaled = Derivative(self.compute_scaled, mode="fwd") + self._jac_scaled_error = Derivative(self.compute_scaled_error, mode="fwd") self._jac_unscaled = Derivative(self.compute_unscaled, mode="fwd") if self._deriv_mode == "looped": self._jac_scaled = Derivative(self.compute_scaled, mode="looped") + self._jac_scaled_error = Derivative( + self.compute_scaled_error, mode="looped" + ) self._jac_unscaled = Derivative(self.compute_unscaled, mode="looped") if self._deriv_mode == "blocked": # could also do something similar for grad and hess, but probably not @@ -101,6 +105,7 @@ def jac_(op, x, constants=None): return jnp.vstack(J) self._jac_scaled = partial(jac_, "jac_scaled") + self._jac_scaled_error = partial(jac_, "jac_scaled_error") self._jac_unscaled = partial(jac_, "jac_unscaled") def jit(self): # noqa: C901 @@ -120,12 +125,15 @@ def jit(self): # noqa: C901 "compute_unscaled", "compute_scalar", "jac_scaled", + "jac_scaled_error", "jac_unscaled", "hess", "grad", "jvp_scaled", + "jvp_scaled_error", "jvp_unscaled", "vjp_scaled", + "vjp_scaled_error", "vjp_unscaled", ] @@ -159,9 +167,10 @@ def build(self, use_jit=None, verbose=1): # build objectives self._dim_f = 0 for objective in self.objectives: - if verbose > 0: - print("Building objective: " + objective.name) - objective.build(use_jit=self.use_jit, verbose=verbose) + if not objective.built: + if verbose > 0: + print("Building objective: " + objective.name) + objective.build(use_jit=self.use_jit, verbose=verbose) self._dim_f += objective.dim_f if self._dim_f == 1: self._scalar = True @@ -400,33 +409,55 @@ def x(self, *things): return jnp.concatenate(xs) def grad(self, x, constants=None): - """Compute gradient vector of scalar form of the objective wrt x.""" + """Compute gradient vector of self.compute_scalar wrt x.""" if constants is None: constants = self.constants return jnp.atleast_1d(self._grad(x, constants).squeeze()) def hess(self, x, constants=None): - """Compute Hessian matrix of scalar form of the objective wrt x.""" + """Compute Hessian matrix of self.compute_scalar wrt x.""" if constants is None: constants = self.constants return jnp.atleast_2d(self._hess(x, constants).squeeze()) def jac_scaled(self, x, constants=None): - """Compute Jacobian matrix of vector form of the objective wrt x.""" + """Compute Jacobian matrix of self.compute_scaled wrt x.""" if constants is None: constants = self.constants return jnp.atleast_2d(self._jac_scaled(x, constants).squeeze()) + def jac_scaled_error(self, x, constants=None): + """Compute Jacobian matrix of self.compute_scaled_error wrt x.""" + if constants is None: + constants = self.constants + return jnp.atleast_2d(self._jac_scaled_error(x, constants).squeeze()) + def jac_unscaled(self, x, constants=None): - """Compute Jacobian matrix of vector form of the objective wrt x, unweighted.""" + """Compute Jacobian matrix of self.compute_unscaled wrt x.""" if constants is None: constants = self.constants return jnp.atleast_2d(self._jac_unscaled(x, constants).squeeze()) - def jvp_scaled(self, v, x, constants=None): - """Compute Jacobian-vector product of the objective function. + def _jvp(self, v, x, constants=None, op="compute_scaled"): + v = v if isinstance(v, (tuple, list)) else (v,) + + fun = lambda x: getattr(self, op)(x, constants) + if len(v) == 1: + jvpfun = lambda dx: Derivative.compute_jvp(fun, 0, dx, x) + return jnp.vectorize(jvpfun, signature="(n)->(k)")(v[0]) + elif len(v) == 2: + jvpfun = lambda dx1, dx2: Derivative.compute_jvp2(fun, 0, 0, dx1, dx2, x) + return jnp.vectorize(jvpfun, signature="(n),(n)->(k)")(v[0], v[1]) + elif len(v) == 3: + jvpfun = lambda dx1, dx2, dx3: Derivative.compute_jvp3( + fun, 0, 0, 0, dx1, dx2, dx3, x + ) + return jnp.vectorize(jvpfun, signature="(n),(n),(n)->(k)")(v[0], v[1], v[2]) + else: + raise NotImplementedError("Cannot compute JVP higher than 3rd order.") - Uses the scaled form of the objective. + def jvp_scaled(self, v, x, constants=None): + """Compute Jacobian-vector product of self.compute_scaled. Parameters ---------- @@ -439,29 +470,26 @@ def jvp_scaled(self, v, x, constants=None): Constant parameters passed to sub-objectives. """ - v = v if isinstance(v, (tuple, list)) else (v,) + return self._jvp(v, x, constants, "compute_scaled") - compute_scaled = lambda x: self.compute_scaled(x, constants) - if len(v) == 1: - jvpfun = lambda dx: Derivative.compute_jvp(compute_scaled, 0, dx, x) - return jnp.vectorize(jvpfun, signature="(n)->(k)")(v[0]) - elif len(v) == 2: - jvpfun = lambda dx1, dx2: Derivative.compute_jvp2( - compute_scaled, 0, 0, dx1, dx2, x - ) - return jnp.vectorize(jvpfun, signature="(n),(n)->(k)")(v[0], v[1]) - elif len(v) == 3: - jvpfun = lambda dx1, dx2, dx3: Derivative.compute_jvp3( - compute_scaled, 0, 0, 0, dx1, dx2, dx3, x - ) - return jnp.vectorize(jvpfun, signature="(n),(n),(n)->(k)")(v[0], v[1], v[2]) - else: - raise NotImplementedError("Cannot compute JVP higher than 3rd order.") + def jvp_scaled_error(self, v, x, constants=None): + """Compute Jacobian-vector product of self.compute_scaled_error. - def jvp_unscaled(self, v, x, constants=None): - """Compute Jacobian-vector product of the objective function. + Parameters + ---------- + v : tuple of ndarray + Vectors to right-multiply the Jacobian by. + The number of vectors given determines the order of derivative taken. + x : ndarray + Optimization variables. + constants : list + Constant parameters passed to sub-objectives. + + """ + return self._jvp(v, x, constants, "compute_scaled_error") - Uses the unscaled form of the objective. + def jvp_unscaled(self, v, x, constants=None): + """Compute Jacobian-vector product of self.compute_unscaled. Parameters ---------- @@ -474,29 +502,29 @@ def jvp_unscaled(self, v, x, constants=None): Constant parameters passed to sub-objectives. """ - v = v if isinstance(v, (tuple, list)) else (v,) + return self._jvp(v, x, constants, "compute_unscaled") - compute_unscaled = lambda x: self.compute_unscaled(x, constants) - if len(v) == 1: - jvpfun = lambda dx: Derivative.compute_jvp(compute_unscaled, 0, dx, x) - return jnp.vectorize(jvpfun, signature="(n)->(k)")(v[0]) - elif len(v) == 2: - jvpfun = lambda dx1, dx2: Derivative.compute_jvp2( - compute_unscaled, 0, 0, dx1, dx2, x - ) - return jnp.vectorize(jvpfun, signature="(n),(n)->(k)")(v[0], v[1]) - elif len(v) == 3: - jvpfun = lambda dx1, dx2, dx3: Derivative.compute_jvp3( - compute_unscaled, 0, 0, 0, dx1, dx2, dx3, x - ) - return jnp.vectorize(jvpfun, signature="(n),(n),(n)->(k)")(v[0], v[1], v[2]) - else: - raise NotImplementedError("Cannot compute JVP higher than 3rd order.") + def _vjp(self, v, x, constants=None, op="compute_scaled"): + fun = lambda x: getattr(self, op)(x, constants) + return Derivative.compute_vjp(fun, 0, v, x) def vjp_scaled(self, v, x, constants=None): - """Compute vector-Jacobian product of the objective function. + """Compute vector-Jacobian product of self.compute_scaled. + + Parameters + ---------- + v : ndarray + Vector to left-multiply the Jacobian by. + x : ndarray + Optimization variables. + constants : list + Constant parameters passed to sub-objectives. - Uses the scaled form of the objective. + """ + return self._vjp(v, x, constants, "compute_scaled") + + def vjp_scaled_error(self, v, x, constants=None): + """Compute vector-Jacobian product of self.compute_scaled_error. Parameters ---------- @@ -508,13 +536,10 @@ def vjp_scaled(self, v, x, constants=None): Constant parameters passed to sub-objectives. """ - compute_scaled = lambda x: self.compute_scaled(x, constants) - return Derivative.compute_vjp(compute_scaled, 0, v, x) + return self._vjp(v, x, constants, "compute_scaled_error") def vjp_unscaled(self, v, x, constants=None): - """Compute vector-Jacobian product of the objective function. - - Uses the unscaled form of the objective. + """Compute vector-Jacobian product of self.compute_unscaled. Parameters ---------- @@ -526,8 +551,7 @@ def vjp_unscaled(self, v, x, constants=None): Constant parameters passed to sub-objectives. """ - compute_unscaled = lambda x: self.compute_unscaled(x, constants) - return Derivative.compute_vjp(compute_unscaled, 0, v, x) + return self._vjp(v, x, constants, "compute_unscaled") def compile(self, mode="auto", verbose=1): """Call the necessary functions to ensure the function is compiled. @@ -806,6 +830,9 @@ def _set_derivatives(self): self._jac_scaled = Derivative( self.compute_scaled, argnums, mode=self._deriv_mode ) + self._jac_scaled_error = Derivative( + self.compute_scaled_error, argnums, mode=self._deriv_mode + ) self._jac_unscaled = Derivative( self.compute_unscaled, argnums, mode=self._deriv_mode ) @@ -820,6 +847,7 @@ def jit(self): # noqa: C901 "compute_unscaled", "compute_scalar", "jac_scaled", + "jac_scaled_error", "jac_unscaled", "hess", "grad", @@ -963,25 +991,37 @@ def compute_scalar(self, *args, **kwargs): return f.squeeze() def grad(self, *args, **kwargs): - """Compute gradient vector of scalar form of the objective wrt x.""" + """Compute gradient vector of self.compute_scalar wrt x.""" return self._grad(*args, **kwargs) def hess(self, *args, **kwargs): - """Compute Hessian matrix of scalar form of the objective wrt x.""" + """Compute Hessian matrix of self.compute_scalar wrt x.""" return self._hess(*args, **kwargs) def jac_scaled(self, *args, **kwargs): - """Compute Jacobian matrix of vector form of the objective wrt x.""" + """Compute Jacobian matrix of self.compute_scaled wrt x.""" return self._jac_scaled(*args, **kwargs) + def jac_scaled_error(self, *args, **kwargs): + """Compute Jacobian matrix of self.compute_scaled_error wrt x.""" + return self._jac_scaled_error(*args, **kwargs) + def jac_unscaled(self, *args, **kwargs): - """Compute Jacobian matrix of vector form of the objective wrt x, unweighted.""" + """Compute Jacobian matrix of self.compute_unscaled wrt x.""" return self._jac_unscaled(*args, **kwargs) - def jvp_scaled(self, v, x, constants=None): - """Compute Jacobian-vector product of the objective function. + def _jvp(self, v, x, constants=None, op="compute_scaled"): + v = v if isinstance(v, (tuple, list)) else (v,) + x = x if isinstance(x, (tuple, list)) else (x,) + assert len(x) == len(v) + + fun = lambda *x: getattr(self, op)(*x, constants=constants) + jvpfun = lambda *dx: Derivative.compute_jvp(fun, tuple(range(len(x))), dx, *x) + sig = ",".join(f"(n{i})" for i in range(len(x))) + "->(k)" + return jnp.vectorize(jvpfun, signature=sig)(*v) - Uses the scaled form of the objective. + def jvp_scaled(self, v, x, constants=None): + """Compute Jacobian-vector product of self.compute_scaled. Parameters ---------- @@ -993,21 +1033,25 @@ def jvp_scaled(self, v, x, constants=None): Constant parameters passed to sub-objectives. """ - v = v if isinstance(v, (tuple, list)) else (v,) - x = x if isinstance(x, (tuple, list)) else (x,) - assert len(x) == len(v) + return self._jvp(v, x, constants, "compute_scaled") - compute_scaled = lambda *x: self.compute_scaled(*x, constants=constants) - jvpfun = lambda *dx: Derivative.compute_jvp( - compute_scaled, tuple(range(len(x))), dx, *x - ) - sig = ",".join(f"(n{i})" for i in range(len(x))) + "->(k)" - return jnp.vectorize(jvpfun, signature=sig)(*v) + def jvp_scaled_error(self, v, x, constants=None): + """Compute Jacobian-vector product of self.compute_scaled_error. - def jvp_unscaled(self, v, x, constants=None): - """Compute Jacobian-vector product of the objective function. + Parameters + ---------- + v : tuple of ndarray + Vectors to right-multiply the Jacobian by. + x : tuple of ndarray + Optimization variables. + constants : list + Constant parameters passed to sub-objectives. + + """ + return self._jvp(v, x, constants, "compute_scaled_error") - Uses the unscaled form of the objective. + def jvp_unscaled(self, v, x, constants=None): + """Compute Jacobian-vector product of self.compute_unscaled. Parameters ---------- @@ -1019,16 +1063,7 @@ def jvp_unscaled(self, v, x, constants=None): Constant parameters passed to sub-objectives. """ - v = v if isinstance(v, (tuple, list)) else (v,) - x = x if isinstance(x, (tuple, list)) else (x,) - assert len(x) == len(v) - - compute_unscaled = lambda *x: self.compute_unscaled(*x, constants=constants) - jvpfun = lambda *dx: Derivative.compute_jvp( - compute_unscaled, tuple(range(len(x))), dx, *x - ) - sig = ",".join(f"(n{i})" for i in range(len(x))) + "->(k)" - return jnp.vectorize(jvpfun, signature=sig)(*v) + return self._jvp(v, x, constants, "compute_unscaled") def print_value(self, *args, **kwargs): """Print the value of the objective.""" diff --git a/desc/optimize/_constraint_wrappers.py b/desc/optimize/_constraint_wrappers.py index 5d3fc33a25..1c6453d1a2 100644 --- a/desc/optimize/_constraint_wrappers.py +++ b/desc/optimize/_constraint_wrappers.py @@ -244,7 +244,7 @@ def compute_scalar(self, x_reduced, constants=None): return self._objective.compute_scalar(x, constants) def grad(self, x_reduced, constants=None): - """Compute gradient of the sum of squares of residuals. + """Compute gradient of self.compute_scalar. Parameters ---------- @@ -264,7 +264,7 @@ def grad(self, x_reduced, constants=None): return df[self._unfixed_idx] @ self._Z def hess(self, x_reduced, constants=None): - """Compute Hessian of the sum of squares of residuals. + """Compute Hessian of self.compute_scalar. Parameters ---------- @@ -283,8 +283,18 @@ def hess(self, x_reduced, constants=None): df = self._objective.hess(x, constants) return self._Z.T @ df[self._unfixed_idx, :][:, self._unfixed_idx] @ self._Z - def jac_unscaled(self, x_reduced, constants=None): - """Compute Jacobian of the vector objective function without weighting / bounds. + def _jac(self, x_reduced, constants=None, op="scaled"): + x = self.recover(x_reduced) + if self._objective._deriv_mode == "blocked": + fun = getattr(self._objective, "jac_" + op) + return fun(x, constants)[:, self._unfixed_idx] @ self._Z + + v = self._unfixed_idx_mat + df = getattr(self._objective, "jvp_" + op)(v.T, x, constants) + return df.T + + def jac_scaled(self, x_reduced, constants=None): + """Compute Jacobian of self.compute_scaled. Parameters ---------- @@ -299,18 +309,10 @@ def jac_unscaled(self, x_reduced, constants=None): Jacobian matrix. """ - x = self.recover(x_reduced) - if self._objective._deriv_mode == "blocked": - return ( - self._objective.jac_unscaled(x, constants)[:, self._unfixed_idx] - @ self._Z - ) - v = self._unfixed_idx_mat - df = self._objective.jvp_unscaled(v.T, x, constants) - return df.T + return self._jac(x_reduced, constants, "scaled") - def jac_scaled(self, x_reduced, constants=None): - """Compute Jacobian of the vector objective function with weighting / bounds. + def jac_scaled_error(self, x_reduced, constants=None): + """Compute Jacobian of self.compute_scaled_error. Parameters ---------- @@ -323,20 +325,51 @@ def jac_scaled(self, x_reduced, constants=None): ------- J : ndarray Jacobian matrix. + """ + return self._jac(x_reduced, constants, "scaled_error") + + def jac_unscaled(self, x_reduced, constants=None): + """Compute Jacobian of self.compute_unscaled. + + Parameters + ---------- + x_reduced : ndarray + Reduced state vector that satisfies linear constraints. + constants : list + Constant parameters passed to sub-objectives. + + Returns + ------- + J : ndarray + Jacobian matrix. + + """ + return self._jac(x_reduced, constants, "unscaled") + + def _jvp(self, v, x_reduced, constants=None, op="jvp_scaled"): x = self.recover(x_reduced) - if self._objective._deriv_mode == "blocked": - return ( - self._objective.jac_scaled(x, constants)[:, self._unfixed_idx] @ self._Z - ) - v = self._unfixed_idx_mat - df = self._objective.jvp_scaled(v.T, x, constants) - return df.T + v = self._unfixed_idx_mat @ v + df = getattr(self._objective, op)(v, x, constants) + return df def jvp_scaled(self, v, x_reduced, constants=None): - """Compute Jacobian-vector product of the objective function. + """Compute Jacobian-vector product of self.compute_scaled. - Uses the scaled form of the objective. + Parameters + ---------- + v : tuple of ndarray + Vectors to right-multiply the Jacobian by. + x_reduced : ndarray + Optimization variables with linear constraints removed. + constants : list + Constant parameters passed to sub-objectives. + + """ + return self._jvp(v, x_reduced, constants, "jvp_scaled") + + def jvp_scaled_error(self, v, x_reduced, constants=None): + """Compute Jacobian-vector product of self.compute_scaled_error. Parameters ---------- @@ -348,15 +381,10 @@ def jvp_scaled(self, v, x_reduced, constants=None): Constant parameters passed to sub-objectives. """ - x = self.recover(x_reduced) - v = self._unfixed_idx_mat @ v - df = self._objective.jvp_scaled(v, x, constants) - return df + return self._jvp(v, x_reduced, constants, "jvp_scaled_error") def jvp_unscaled(self, v, x_reduced, constants=None): - """Compute Jacobian-vector product of the objective function. - - Uses the unscaled form of the objective. + """Compute Jacobian-vector product of self.compute_unscaled. Parameters ---------- @@ -368,15 +396,30 @@ def jvp_unscaled(self, v, x_reduced, constants=None): Constant parameters passed to sub-objectives. """ + return self._jvp(v, x_reduced, constants, "jvp_unscaled") + + def _vjp(self, v, x_reduced, constants=None, op="vjp_scaled"): x = self.recover(x_reduced) - v = self._unfixed_idx_mat @ v - df = self._objective.jvp_unscaled(v, x, constants) - return df + df = getattr(self._objective, op)(v, x, constants) + return df[self._unfixed_idx] @ self._Z def vjp_scaled(self, v, x_reduced, constants=None): - """Compute vector-Jacobian product of the objective function. + """Compute vector-Jacobian product of self.compute_scaled. - Uses the scaled form of the objective. + Parameters + ---------- + v : ndarray + Vector to left-multiply the Jacobian by. + x_reduced : ndarray + Optimization variables with linear constraints removed. + constants : list + Constant parameters passed to sub-objectives. + + """ + return self._vjp(v, x_reduced, constants, "vjp_scaled") + + def vjp_scaled_error(self, v, x_reduced, constants=None): + """Compute vector-Jacobian product of self.compute_scaled_error. Parameters ---------- @@ -388,14 +431,10 @@ def vjp_scaled(self, v, x_reduced, constants=None): Constant parameters passed to sub-objectives. """ - x = self.recover(x_reduced) - df = self._objective.vjp_scaled(v, x, constants) - return df[self._unfixed_idx] @ self._Z + return self._vjp(v, x_reduced, constants, "vjp_scaled_error") def vjp_unscaled(self, v, x_reduced, constants=None): - """Compute vector-Jacobian product of the objective function. - - Uses the unscaled form of the objective. + """Compute vector-Jacobian product of self.compute_unscaled. Parameters ---------- @@ -407,9 +446,7 @@ def vjp_unscaled(self, v, x_reduced, constants=None): Constant parameters passed to sub-objectives. """ - x = self.recover(x_reduced) - df = self._objective.vjp_unscaled(v, x, constants) - return df[self._unfixed_idx] @ self._Z + return self._vjp(v, x_reduced, constants, "vjp_unscaled") def __getattr__(self, name): """For other attributes we defer to the base objective.""" @@ -463,7 +500,15 @@ def __init__( not con._equilibrium, ValueError, "ProximalProjection method cannot handle general " - + "nonlinear constraint {}.".format(con), + + f"nonlinear constraint {con}.", + ) + # can't have bounds on constraint bc if constraint is satisfied then + # Fx == 0, and that messes with Gx @ Fx^-1 Fc etc. + errorif( + con.bounds is not None, + ValueError, + "ProximalProjection can only handle equality constraints, " + + f"got bounds for constraint {con}", ) self._objective = objective self._constraint = constraint @@ -805,7 +850,7 @@ def compute_unscaled(self, x, constants=None): return self._objective.compute_unscaled(xopt, constants[0]) def grad(self, x, constants=None): - """Compute gradient of the sum of squares of residuals. + """Compute gradient of self.compute_scalar. Parameters ---------- @@ -820,12 +865,16 @@ def grad(self, x, constants=None): gradient vector. """ + # TODO: figure out projected vjp to make this better f = jnp.atleast_1d(self.compute_scaled_error(x, constants)) - J = self.jac_scaled(x, constants) + J = self.jac_scaled_error(x, constants) return f.T @ J - def jac_unscaled(self, x, constants=None): - """Compute Jacobian of the vector objective function without weights / bounds. + def hess(self, x, constants=None): + """Compute Hessian of self.compute_scalar. + + Uses the "small residual approximation" where the Hessian is replaced by + the square of the Jacobian: H = J.T @ J Parameters ---------- @@ -836,14 +885,15 @@ def jac_unscaled(self, x, constants=None): Returns ------- - J : ndarray - Jacobian matrix. + H : ndarray + Hessian matrix. + """ - v = jnp.eye(x.shape[0]) - return self.jvp_unscaled(v, x, constants).T + J = self.jac_scaled_error(x, constants) + return J.T @ J def jac_scaled(self, x, constants=None): - """Compute Jacobian of the vector objective function with weights / bounds. + """Compute Jacobian of self.compute_scaled. Parameters ---------- @@ -861,11 +911,8 @@ def jac_scaled(self, x, constants=None): v = jnp.eye(x.shape[0]) return self.jvp_scaled(v, x, constants).T - def hess(self, x, constants=None): - """Compute Hessian of the sum of squares of residuals. - - Uses the "small residual approximation" where the Hessian is replaced by - the square of the Jacobian: H = J.T @ J + def jac_scaled_error(self, x, constants=None): + """Compute Jacobian of self.compute_scaled_error. Parameters ---------- @@ -876,17 +923,33 @@ def hess(self, x, constants=None): Returns ------- - H : ndarray - Hessian matrix. + J : ndarray + Jacobian matrix. """ - J = self.jac_scaled(x, constants) - return J.T @ J + v = jnp.eye(x.shape[0]) + return self.jvp_scaled_error(v, x, constants).T - def jvp_scaled(self, v, x, constants=None): - """Compute Jacobian-vector product of the objective function. + def jac_unscaled(self, x, constants=None): + """Compute Jacobian of self.compute_unscaled. - Uses the scaled form of the objective. + Parameters + ---------- + x : ndarray + State vector. + constants : list + Constant parameters passed to sub-objectives. + + Returns + ------- + J : ndarray + Jacobian matrix. + """ + v = jnp.eye(x.shape[0]) + return self.jvp_unscaled(v, x, constants).T + + def jvp_scaled(self, v, x, constants=None): + """Compute Jacobian-vector product of self.compute_scaled. Parameters ---------- @@ -905,10 +968,28 @@ def jvp_scaled(self, v, x, constants=None): jvpfun = lambda u: self._jvp(u, xf, xg, constants, op="scaled") return jnp.vectorize(jvpfun, signature="(n)->(k)")(v) - def jvp_unscaled(self, v, x, constants=None): - """Compute Jacobian-vector product of the objective function. + def jvp_scaled_error(self, v, x, constants=None): + """Compute Jacobian-vector product of self.compute_scaled_error. - Uses the unscaled form of the objective. + Parameters + ---------- + v : ndarray or tuple of ndarray + Vectors to right-multiply the Jacobian by. + This method only works for first order jvps. + x : ndarray + Optimization variables. + constants : list + Constant parameters passed to sub-objectives. + + """ + v = v[0] if isinstance(v, (tuple, list)) else v + constants = setdefault(constants, self.constants) + xg, xf = self._update_equilibrium(x, store=True) + jvpfun = lambda u: self._jvp(u, xf, xg, constants, op="scaled_error") + return jnp.vectorize(jvpfun, signature="(n)->(k)")(v) + + def jvp_unscaled(self, v, x, constants=None): + """Compute Jacobian-vector product of self.compute_unscaled. Parameters ---------- diff --git a/desc/optimize/_desc_wrappers.py b/desc/optimize/_desc_wrappers.py index 1ffb4480b8..cf5addc86b 100644 --- a/desc/optimize/_desc_wrappers.py +++ b/desc/optimize/_desc_wrappers.py @@ -187,7 +187,7 @@ def _optimize_desc_aug_lagrangian_least_squares( result = lsq_auglag( lambda x, *c: objective.compute_scaled_error(x, c[0]), x0=x0, - jac=lambda x, *c: objective.jac_scaled(x, c[0]), + jac=lambda x, *c: objective.jac_scaled_error(x, c[0]), bounds=(-jnp.inf, jnp.inf), constraint=constraint_wrapped, args=(objective.constants, constraint.constants if constraint else None), @@ -270,7 +270,7 @@ def _optimize_desc_least_squares( result = lsqtr( objective.compute_scaled_error, x0=x0, - jac=objective.jac_scaled, + jac=objective.jac_scaled_error, args=(objective.constants,), x_scale=x_scale, ftol=stoptol["ftol"], diff --git a/desc/optimize/_scipy_wrappers.py b/desc/optimize/_scipy_wrappers.py index 77289816be..d62cca4ca0 100644 --- a/desc/optimize/_scipy_wrappers.py +++ b/desc/optimize/_scipy_wrappers.py @@ -353,7 +353,7 @@ def _optimize_scipy_least_squares( # noqa: C901 - FIXME: simplify this assert constraint is None, f"method {method} doesn't support constraints" options = {} if options is None else options x_scale = "jac" if x_scale == "auto" else x_scale - fun, jac = objective.compute_scaled_error, objective.jac_scaled + fun, jac = objective.compute_scaled_error, objective.jac_scaled_error # need to use some "global" variables here fun_allx = [] fun_allf = [] diff --git a/desc/perturbations.py b/desc/perturbations.py index 8758b73cd9..3fa622b248 100644 --- a/desc/perturbations.py +++ b/desc/perturbations.py @@ -290,7 +290,7 @@ def perturb( # noqa: C901 - FIXME: break this up into simpler pieces if verbose > 0: print("Computing df") timer.start("df computation") - Jx = objective.jac_scaled(x) + Jx = objective.jac_scaled_error(x) Jx_reduced = Jx[:, unfixed_idx] @ Z @ scale RHS1 = objective.jvp_scaled(tangents, x) if include_f: @@ -597,7 +597,7 @@ def optimal_perturb( # noqa: C901 - FIXME: break this up into simpler pieces if verbose > 0: print("Computing df") timer.start("df computation") - Fx = objective_f.jac_scaled(xf) + Fx = objective_f.jac_scaled_error(xf) timer.stop("df computation") if verbose > 1: timer.disp("df computation") @@ -606,7 +606,7 @@ def optimal_perturb( # noqa: C901 - FIXME: break this up into simpler pieces if verbose > 0: print("Computing dg") timer.start("dg computation") - Gx = objective_g.jac_scaled(xg) + Gx = objective_g.jac_scaled_error(xg) timer.stop("dg computation") if verbose > 1: timer.disp("dg computation") diff --git a/tests/test_examples.py b/tests/test_examples.py index 75a66552c3..9acdbe597c 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -10,6 +10,7 @@ from qsc import Qsc from desc.backend import jnp +from desc.coils import FourierRZCoil from desc.continuation import _solve_axisym, solve_continuation_automatic from desc.equilibrium import EquilibriaFamily, Equilibrium from desc.examples import get @@ -20,6 +21,9 @@ from desc.objectives import ( AspectRatio, BoundaryError, + CoilCurvature, + CoilLength, + CoilTorsion, CurrentDensity, FixBoundaryR, FixBoundaryZ, @@ -781,7 +785,7 @@ def test_multiobject_optimization_prox(): ) surf.change_resolution(M=4, N=0) constraints = ( - ForceBalance(eq=eq, bounds=(-1e-4, 1e-4), normalize_target=False), + ForceBalance(eq=eq), FixPressure(eq=eq), FixParameter(surf, ["Z_lmn", "R_lmn"], [[-1], [0]]), FixParameter(eq, ["Psi", "i_l"]), @@ -1301,3 +1305,42 @@ def test_example_get_current(self): -1.36284423e07, ], ) + + +@pytest.mark.unit +def test_single_coil_optimization(): + """Test that single coil (not coilset) optimization works.""" + # testing that the objectives work and that the optimization framework + # works when a single coil is passed in. + + opt = Optimizer("fmintr") + coil = FourierRZCoil() + coil.change_resolution(N=1) + target_R = 9 + # length and curvature + target_length = 2 * np.pi * target_R + target_curvature = 1 / target_R + grid = LinearGrid(N=2) + obj = ObjectiveFunction( + ( + CoilLength(coil, target=target_length), + CoilCurvature(coil, target=target_curvature, grid=grid), + ), + ) + opt.optimize([coil], obj, maxiter=200) + np.testing.assert_allclose( + coil.compute("length")["length"], target_length, rtol=1e-4 + ) + np.testing.assert_allclose( + coil.compute("curvature", grid=grid)["curvature"], target_curvature, rtol=1e-4 + ) + + # torsion + # initialize with some torsion + coil.Z_n = coil.Z_n.at[0].set(0.1) + target = 0 + obj = ObjectiveFunction(CoilTorsion(coil, target=target)) + opt.optimize([coil], obj, maxiter=200, ftol=0) + np.testing.assert_allclose( + coil.compute("torsion", grid=grid)["torsion"], target, atol=1e-5 + ) diff --git a/tests/test_objective_funs.py b/tests/test_objective_funs.py index 91ab45f75a..3741a3f4f5 100644 --- a/tests/test_objective_funs.py +++ b/tests/test_objective_funs.py @@ -12,7 +12,7 @@ import desc.examples from desc.backend import jnp -from desc.coils import CoilSet, FourierXYZCoil +from desc.coils import CoilSet, FourierPlanarCoil, FourierXYZCoil, MixedCoilSet from desc.compute import get_transforms from desc.equilibrium import Equilibrium from desc.examples import get @@ -28,6 +28,9 @@ BootstrapRedlConsistency, BoundaryError, BScaleLength, + CoilCurvature, + CoilLength, + CoilTorsion, CurrentDensity, Elongation, Energy, @@ -574,6 +577,83 @@ def test(eq): test(get("DSHAPE")) test(get("HELIOTRON")) + @pytest.mark.unit + def test_coil_length(self): + """Tests coil length.""" + + def test(coil, grid=None): + obj = CoilLength(coil, grid=grid) + obj.build() + f = obj.compute(params=coil.params_dict) + np.testing.assert_allclose(f, 2 * np.pi, rtol=1e-8) + assert len(f) == obj.dim_f + + coil = FourierPlanarCoil(r_n=1) + coils = CoilSet.linspaced_linear(coil, n=2) + mixed_coils = MixedCoilSet.linspaced_linear(coil, n=2) + nested_coils = MixedCoilSet(coils, coils) + + nested_grids = [ + [LinearGrid(N=5), LinearGrid(N=5)], + [LinearGrid(N=5), LinearGrid(N=5)], + ] + test(coil, grid=LinearGrid(N=5)) + test(coils) + test(mixed_coils, grid=[LinearGrid(N=5)] * len(mixed_coils.coils)) + test(nested_coils, grid=nested_grids) + + @pytest.mark.unit + def test_coil_curvature(self): + """Tests coil curvature.""" + + def test(coil, grid=None): + obj = CoilCurvature(coil, grid=grid) + obj.build() + f = obj.compute(params=coil.params_dict) + np.testing.assert_allclose(f, 1 / 2, rtol=1e-8) + assert len(f) == obj.dim_f + + coil = FourierPlanarCoil() + coils = CoilSet.linspaced_linear(coil, n=2) + mixed_coils = MixedCoilSet.linspaced_linear(coil, n=2) + nested_coils = MixedCoilSet(coils, coils) + + nested_grids = [ + [LinearGrid(N=5), LinearGrid(N=5)], + [LinearGrid(N=5), LinearGrid(N=5)], + ] + + test(coil, grid=LinearGrid(N=5)) + test(coils) + test(mixed_coils, grid=[LinearGrid(N=5)] * len(mixed_coils.coils)) + test(nested_coils, grid=nested_grids) + + @pytest.mark.unit + def test_coil_torsion(self): + """Tests coil torsion.""" + + def test(coil, grid=None): + obj = CoilTorsion(coil, grid=grid) + obj.build() + f = obj.compute(params=coil.params_dict) + np.testing.assert_allclose(f, 0, atol=1e-8) + assert len(f) == obj.dim_f + + coil = FourierPlanarCoil() + coils = CoilSet.linspaced_linear(coil, n=2) + mixed_coils = MixedCoilSet.linspaced_linear(coil, n=2) + nested_coils = MixedCoilSet(coils, coils) + + nested_grids = [ + [LinearGrid(N=5), LinearGrid(N=5)], + [LinearGrid(N=5), LinearGrid(N=5)], + ] + + test(coil, grid=LinearGrid(N=5)) + test(coils) + test(mixed_coils, grid=[LinearGrid(N=5)] * len(mixed_coils.coils)) + test(nested_coils, grid=nested_grids) + @pytest.mark.unit def test_derivative_modes(): @@ -1365,30 +1445,6 @@ def test_boundary_error_print(capsys): assert out.out == corr_out -@pytest.mark.unit -def test_rebuild(): - """Test that the objective is rebuilt correctly when needed.""" - eq = Equilibrium(L=3, M=3) - f_obj = ForceBalance(eq=eq) - obj = ObjectiveFunction(f_obj) - eq.solve(maxiter=2, objective=obj) - - # this would fail before v0.8.2 when trying to get objective.x - eq.change_resolution(L=5, M=5) - obj.build(eq) - eq.solve(maxiter=2, objective=obj) - - eq = Equilibrium(L=3, M=3) - f_obj = ForceBalance(eq=eq) - obj = ObjectiveFunction(f_obj) - eq.solve(maxiter=2, objective=obj) - eq.change_resolution(L=5, M=5) - # this would fail at objective.compile - obj = ObjectiveFunction(f_obj) - obj.build(eq) - eq.solve(maxiter=2, objective=obj) - - @pytest.mark.unit def test_objective_fun_things(): """Test that the objective things logic works correctly.""" diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 4a703b900e..c066b642f2 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -355,9 +355,9 @@ def compute(self, params, constants=None): np.random.seed(0) objective = ObjectiveFunction(DummyObjective(things=eq), use_jit=False) # make gradient super noisy so it stalls - objective.jac_scaled = lambda x, *args: objective._jac_scaled(x) + 1e2 * ( - np.random.random((objective._dim_f, x.size)) - 0.5 - ) + objective.jac_scaled_error = lambda x, *args: objective._jac_scaled_error( + x + ) + 1e2 * (np.random.random((objective._dim_f, x.size)) - 0.5) n = 10 R_modes = np.vstack(