# Source code for pyhdtoolkit.plotting.layout

"""
.. _plotting-layout:

Layout Plotters
---------------

Module with functions used to represent a machine's
elements in an `~matplotlib.axes.Axes` object, mostly
used in different `~pyhdtoolkit.plotting` modules.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from loguru import logger
from matplotlib import patches

from pyhdtoolkit.plotting.utils import (
    _get_twiss_table_with_offsets_and_limits,
    make_elements_groups,
    maybe_get_ax,
)

if TYPE_CHECKING:
    from cpymad.madx import Madx
    from matplotlib.axes import Axes
    from pandas import DataFrame


def plot_machine_layout(
    madx: Madx,
    /,
    title: str | None = None,
    xoffset: float = 0,
    xlimits: tuple[float, float] | None = None,
    plot_dipoles: bool = True,
    plot_dipole_k1: bool = False,
    plot_quadrupoles: bool = True,
    plot_bpms: bool = False,
    k0l_lim: tuple[float, float] | float | None = None,
    k1l_lim: tuple[float, float] | float | None = None,
    k2l_lim: tuple[float, float] | float | None = None,
    k3l_lim: tuple[float, float] | float | None = None,
    **kwargs,
) -> None:
    """
    .. versionadded:: 1.0.0

    Draws patch elements representing the lattice layout on the given *axis*.
    This is the function that takes care of the machine layout axis in
    `~.plotting.lattice.plot_latwiss` and `~.plotting.aperture.plot_aperture`.
    Its results can be seen in the :ref:`machine lattice <demo-accelerator-lattice>`
    and :ref:`machine aperture <demo-accelerator-aperture>` example galleries.

    Note
    ----
        The current implementation can plot dipoles, quadrupoles, sextupoles,
        octupoles and BPMs.

    Important
    ---------
        If not provided, the limits for ``k0l_lim`` and ``k1l_lim`` will be
        auto-determined, which might not be the best choice for the plot. When
        providing these limits (also for ``k2l_lim`` and ``k3l_lim``), make sure
        to provide values symmetric around 0 (so ``[-x, x]``), otherwise the
        element patches will show up vertically displaced from the axis' center line.

    Warning
    -------
        Currently the function tries to plot legends for the different layout
        patches. The positions of the different legends have been hardcoded in
        corners of the `~matplotlib.axes.Axes` and might require users to tweak
        the axis limits (through ``k0l_lim``, ``k1l_lim``, ``k2l_lim`` and
        ``k3l_lim``) to ensure legend labels and plotted elements don't overlap.

    Parameters
    ----------
    madx : cpymad.madx.Madx
        An instantiated `~cpymad.madx.Madx` object. Positional only.
    title : str, optional
        If provided, is set as title of the plot.
    xoffset : float
        An offset applied to the ``S`` coordinate before plotting. This is
        useful if you want to center a plot around a specific point or element,
        which would then become located at :math:`s = 0`. Beware this offset is
        applied before applying the *xlimits*. Defaults to 0.
    xlimits : tuple[float, float], optional
        If given, will be used for the xlim (for the ``s`` coordinate), using
        the tuple passed.
    plot_dipoles : bool
        If `True`, dipole patches will be plotted on the layout subplot of the
        figure. Defaults to `True`. Dipoles are plotted in blue.
    plot_dipole_k1 : bool
        If `True`, dipole elements with a quadrupolar gradient will have this
        gradient plotted as a quadrupole patch. Defaults to `False`.
    plot_quadrupoles : bool
        If `True`, quadrupole patches will be plotted on the layout subplot of
        the figure. Defaults to `True`. Quadrupoles are plotted in red.
    plot_bpms : bool
        If `True`, additional patches will be plotted on the layout subplot to
        represent Beam Position Monitors. BPMs are plotted in dark grey.
        Defaults to `False`.
    k0l_lim : tuple[float, float] | float, optional
        If given, will be used as vertical axis limits for the ``k0l`` values
        used for the height of dipole patches. Can be given as a single value
        (float, int) or a tuple (in which case it should be symmetric). If
        `None` is given, then the limits will be determined automatically based
        on the ``k0l`` (or ``angle``) values of the dipoles.
    k1l_lim : tuple[float, float] | float, optional
        If given, will be used as vertical axis limits for the ``k1l`` values
        used for the height of quadrupole patches. Can be given as a single
        value (float, int) or a tuple (in which case it should be symmetric).
        If `None` is given, then the limits will be determined automatically
        based on the ``k1l`` values of the quadrupoles.
    k2l_lim : tuple[float, float] | float, optional
        If given, will be used as vertical axis limits for the ``k2l`` values
        used for the height of sextupole patches. Can be given as a single value
        (float, int) or a tuple (in which case it should be symmetric). If
        `None` is given (the default), sextupoles are not plotted.
    k3l_lim : tuple[float, float] | float, optional
        If given, will be used as vertical axis limits for the ``k3l`` values
        used for the height of octupole patches. Can be given as a single value
        (float, int) or a tuple (in which case it should be symmetric). If
        `None` is given (the default), octupoles are not plotted.
    **kwargs
        Any keyword argument will be transmitted to
        `~.plotting.layout._plot_lattice_series`, and then to
        `~matplotlib.patches.Rectangle`, such as ``lw`` etc. If either `ax` or
        `axis` is found in the kwargs, the corresponding value is used as the
        axis object to plot on. By definition, the quadrupole elements will be
        drawn on said axis, and for each new element type to plot a call to
        `~matplotlib.axes.Axes.twinx` is made and the new elements will be
        drawn on the newly created twin `~matplotlib.axes.Axes`. If
        ``bpms_legend`` is given as `False` and BPMs are plotted, the BPM
        legend will not be plotted on the layout axis.

    Example
    -------
        .. code-block:: python

            fig, ax = plt.subplots(figsize=(6, 2))
            plot_machine_layout(madx, title="Machine Elements", lw=3)
    """
    # pylint: disable=too-many-arguments
    axis, kwargs = maybe_get_ax(**kwargs)
    bpms_legend = kwargs.pop("bpms_legend", True)
    twiss_df = _get_twiss_table_with_offsets_and_limits(madx, xoffset, xlimits)

    logger.trace("Extracting element-specific dataframes")
    element_dfs = make_elements_groups(madx, xoffset, xlimits)
    dipoles_df = element_dfs["dipoles"]
    quadrupoles_df = element_dfs["quadrupoles"]
    sextupoles_df = element_dfs["sextupoles"]
    octupoles_df = element_dfs["octupoles"]
    bpms_df = element_dfs["bpms"]

    logger.trace("Determining the ylimits for k0l and k1l patches")
    # Assume the lattice doesn't mix 'k0l' and 'angle' for dipole powering
    dipoles_power_column = "k0l" if dipoles_df.k0l.any() else "angle"
    k0l_lim = (
        _ylim_from_input(k0l_lim, "k0l_lim")
        if k0l_lim is not None
        else _determine_default_knl_lim(dipoles_df, col=dipoles_power_column, coeff=2)
    )
    k1l_lim = (
        _ylim_from_input(k1l_lim, "k1l_lim")
        if k1l_lim is not None
        else _determine_default_knl_lim(quadrupoles_df, col="k1l", coeff=1.3)
    )

    logger.debug("Plotting machine layout")
    logger.trace(f"Plotting from axis '{axis}'")
    axis.set_ylabel("$1/f=K_{1}L$ $[m^{-1}]$", color="red")  # quadrupoles in red
    axis.tick_params(axis="y", labelcolor="red")
    axis.set_ylim(k1l_lim)
    if xlimits is not None:
        axis.set_xlim(xlimits)
    axis.set_title(title)
    axis.plot(twiss_df.s, 0 * twiss_df.s, "k")  # 0-level line
    axis.grid(False)

    dipole_patches_axis = axis.twinx()
    dipole_patches_axis.set_ylabel("$\\theta=K_{0}L$ $[rad]$", color="royalblue")  # dipoles in blue
    dipole_patches_axis.tick_params(axis="y", labelcolor="royalblue")
    # Auto-determined limits are NaN when there are no dipoles, and NaN != NaN so
    # a membership test against np.nan cannot catch them: check finiteness instead
    if np.isfinite(k0l_lim).all():
        dipole_patches_axis.set_ylim(k0l_lim)
    dipole_patches_axis.grid(False)

    if plot_dipoles:  # beware 'sbend' and 'rbend' have an 'angle' value and not a 'k0l'
        logger.trace("Plotting dipole patches")
        plotted_elements = 0  # will help us not declare a label for legend at every patch
        for dipole_name, dipole in dipoles_df.iterrows():
            logger.trace(f"Plotting dipole element '{dipole_name}'")
            bend_value = dipole.k0l if dipole.k0l != 0 else dipole.angle  # check for each element
            _plot_lattice_series(
                dipole_patches_axis,
                dipole,
                height=bend_value,
                v_offset=bend_value / 2,
                color="royalblue",
                label="MB" if plotted_elements == 0 else None,  # avoid duplicating legend labels
                **kwargs,
            )
            if dipole.k1l != 0 and plot_dipole_k1:  # plot the dipole's quadrupolar gradient as a quadrupole patch
                logger.trace(f"Plotting quadrupolar gradient of dipole element '{dipole_name}'")
                _plot_lattice_series(
                    axis,
                    dipole,
                    height=dipole.k1l,
                    v_offset=dipole.k1l / 2,
                    color="r",
                    **kwargs,
                )
            plotted_elements += 1
        logger.debug(f"Plotted {plotted_elements} dipole elements")
        if plotted_elements > 0:  # If we plotted at least one dipole, we need to plot the legend
            dipole_patches_axis.legend(loc=1)

    if plot_quadrupoles:
        logger.trace("Plotting quadrupole patches")
        plotted_elements = 0
        for quadrupole_name, quadrupole in quadrupoles_df.iterrows():
            logger.trace(f"Plotting quadrupole element '{quadrupole_name}'")
            element_k = quadrupole.k1l if quadrupole.k1l != 0 else quadrupole.k1sl  # can be skew quadrupole
            _plot_lattice_series(
                axis,
                quadrupole,
                height=element_k,
                v_offset=element_k / 2,
                color="r",
                hatch=None if quadrupole.k1l != 0 else "///",  # hatch skew quadrupoles
                label="MQ" if plotted_elements == 0 else None,  # avoid duplicating legend labels
                **kwargs,
            )
            plotted_elements += 1
        logger.debug(f"Plotted {plotted_elements} quadrupole elements")
        if plotted_elements > 0:  # If we plotted at least one quadrupole, we need to plot the legend
            axis.legend(loc=2)

    if k2l_lim:
        logger.trace("Plotting sextupole patches")
        sextupoles_patches_axis = axis.twinx()
        sextupoles_patches_axis.set_ylabel("$K_{2}L$ $[m^{-2}]$", color="darkgoldenrod")
        sextupoles_patches_axis.tick_params(axis="y", labelcolor="darkgoldenrod")
        sextupoles_patches_axis.spines["right"].set_position(("axes", 1.12))
        k2l_lim = _ylim_from_input(k2l_lim, "k2l_lim")
        sextupoles_patches_axis.set_ylim(k2l_lim)
        plotted_elements = 0
        for sextupole_name, sextupole in sextupoles_df.iterrows():
            logger.trace(f"Plotting sextupole element '{sextupole_name}'")
            element_k = sextupole.k2l if sextupole.k2l != 0 else sextupole.k2sl  # can be skew sextupole
            _plot_lattice_series(
                sextupoles_patches_axis,
                sextupole,
                height=element_k,
                v_offset=element_k / 2,
                color="goldenrod",
                hatch=None if sextupole.k2l != 0 else "\\\\\\",  # hatch skew sextupoles
                label="MS" if plotted_elements == 0 else None,  # avoid duplicating legend labels
                **kwargs,
            )
            plotted_elements += 1
        logger.debug(f"Plotted {plotted_elements} sextupole elements")
        sextupoles_patches_axis.grid(False)
        if plotted_elements > 0:  # If we plotted at least one sextupole, we need to plot the legend
            sextupoles_patches_axis.legend(loc=3)

    if k3l_lim:
        logger.trace("Plotting octupole patches")
        octupoles_patches_axis = axis.twinx()
        octupoles_patches_axis.set_ylabel("$K_{3}L$ $[m^{-3}]$", color="forestgreen")
        octupoles_patches_axis.tick_params(axis="y", labelcolor="forestgreen")
        octupoles_patches_axis.yaxis.set_label_position("left")
        octupoles_patches_axis.yaxis.tick_left()
        octupoles_patches_axis.spines["left"].set_position(("axes", -0.14))
        k3l_lim = _ylim_from_input(k3l_lim, "k3l_lim")
        octupoles_patches_axis.set_ylim(k3l_lim)
        plotted_elements = 0
        for octupole_name, octupole in octupoles_df.iterrows():
            logger.trace(f"Plotting octupole element '{octupole_name}'")
            element_k = octupole.k3l if octupole.k3l != 0 else octupole.k3sl  # can be skew octupole
            _plot_lattice_series(
                octupoles_patches_axis,
                octupole,
                height=element_k,  # use element_k so skew octupoles get a visible patch
                v_offset=element_k / 2,
                color="forestgreen",
                hatch=None if octupole.k3l != 0 else "xxx",  # hatch skew octupoles
                label="MO" if plotted_elements == 0 else None,  # avoid duplicating legend labels
                **kwargs,
            )
            plotted_elements += 1
        logger.debug(f"Plotted {plotted_elements} octupole elements")
        octupoles_patches_axis.grid(False)
        if plotted_elements > 0:  # If we plotted at least one octupole, we need to plot the legend
            octupoles_patches_axis.legend(loc=4)

    if plot_bpms:
        logger.trace("Plotting BPM patches")
        bpm_patches_axis = axis.twinx()
        bpm_patches_axis.set_axis_off()  # hide yticks, labels etc
        bpm_patches_axis.set_ylim(-1.6, 1.6)
        plotted_elements = 0
        for bpm_name, bpm in bpms_df.iterrows():
            logger.trace(f"Plotting BPM element '{bpm_name}'")
            _plot_lattice_series(
                bpm_patches_axis,
                bpm,
                height=2,
                v_offset=0,
                color="dimgrey",
                label="BPM" if plotted_elements == 0 else None,  # avoid duplicating legend labels
                **kwargs,
            )
            plotted_elements += 1
        logger.debug(f"Plotted {plotted_elements} BPMs")

        logger.trace("Determining BPM legend location")
        if bpms_legend is True:
            if k2l_lim is not None and k3l_lim is not None:
                bpm_legend_loc = 8  # all corners are taken, we go bottom center
            elif k2l_lim is not None:
                bpm_legend_loc = 4  # sextupoles are here (bottom left) but not octupoles, we go bottom right
            elif k3l_lim is not None:  # pragma: no cover
                bpm_legend_loc = 3  # octupoles are here (bottom right) but not sextupoles, we go bottom left
            else:
                bpm_legend_loc = "best"  # can't easily determine the best position, go automatic and leave to the user

            if plotted_elements > 0:  # If we plotted at least one BPM, we need to plot the legend
                bpm_patches_axis.legend(loc=bpm_legend_loc)
        bpm_patches_axis.grid(False)
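

# Illustrative sketches, not used by the functions in this module. They assume
# only the behaviour visible in ``plot_machine_layout`` above, and the helper
# names ``_sketch_patch_extent`` and ``_sketch_bpm_legend_loc`` are hypothetical.
#
# 1. Patch geometry: each element becomes a rectangle anchored at
#    ``(s - l, v_offset - height / 2)``. The layout passes ``height=k`` and
#    ``v_offset=k / 2``, so a patch spans from the 0-level line to ``k``:
#    upwards for positive strengths and downwards for negative ones. Symmetric
#    ylimits therefore keep the 0-level line at the vertical center of the axis.
#
# 2. Legend placement: dipoles and quadrupoles take the upper corners (matplotlib
#    legend codes 1 and 2), sextupoles and octupoles the lower ones (3 and 4). The
#    BPM legend goes to whichever lower corner is free, to the lower center
#    (code 8) when both are taken, or to "best" when neither is.


def _sketch_patch_extent(s: float, length: float, k: float) -> tuple[tuple[float, float], tuple[float, float]]:
    """Hypothetical helper: ``(x_min, x_max), (y_min, y_max)`` covered by one layout patch."""
    height, v_offset = k, k / 2  # as passed by plot_machine_layout
    x_anchor, y_anchor = s - length, v_offset - height / 2  # as computed in _plot_lattice_series
    # A Rectangle with a negative height extends downwards from its anchor
    y_min, y_max = sorted((y_anchor, y_anchor + height))
    return (x_anchor, x_anchor + length), (y_min, y_max)


def _sketch_bpm_legend_loc(has_sextupoles: bool, has_octupoles: bool) -> int | str:
    """Hypothetical helper: BPM legend location given which lower corners are taken."""
    if has_sextupoles and has_octupoles:
        return 8  # lower center
    if has_sextupoles:
        return 4  # lower right, as lower left (3) holds the sextupole legend
    if has_octupoles:
        return 3  # lower left, as lower right (4) holds the octupole legend
    return "best"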


def scale_patches(scale: float, ylabel: str, **kwargs) -> None:
    """
    .. versionadded:: 1.3.0

    This is a convenience function to update the scale of the element layout
    patches as well as the corresponding y-axis label. Only the patch heights
    are rescaled. The axis limits are left untouched and may need to be
    adjusted separately.

    Parameters
    ----------
    scale : float
        The scale factor to apply to the patches. The new height of the patches
        will be ``scale * original_height``.
    ylabel : str
        The new label for the y-axis.
    **kwargs
        If either `ax` or `axis` is found in the kwargs, the corresponding value
        is used as the axis object to plot on, otherwise the current axis is used.

    Example
    -------
        .. code-block:: python

            fig, ax = plt.subplots(figsize=(6, 2))
            plot_machine_layout(madx, title="Machine Elements", lw=3)
            scale_patches(ax=fig.axes[0], scale=100, ylabel=r"$K_{1}L$ $[10^{-2} m^{-1}]$")
    """
    axis, kwargs = maybe_get_ax(**kwargs)
    axis.set_ylabel(ylabel)
    for patch in axis.patches:
        h = patch.get_height()
        patch.set_height(scale * h)
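

# Illustrative sketch, not part of the module: the effect of ``scale_patches`` on
# layout patches, with plain ``(y_anchor, height)`` pairs standing in for
# matplotlib Rectangles. The name ``_sketch_scale_heights`` is hypothetical.
# Only heights are multiplied while anchors stay put, so patches anchored on the
# 0-level line keep growing from it. With ``scale=100``, a patch of height
# ``K1L = 0.02 m^-1`` reaches 2, which is its value in units of ``1e-2 m^-1`` as
# in the ``scale_patches`` example. ``scale_patches`` itself does not change the
# axis limits.


def _sketch_scale_heights(rectangles: list[tuple[float, float]], scale: float) -> list[tuple[float, float]]:
    """Hypothetical helper: scale the heights of ``(y_anchor, height)`` pairs, keeping the anchors."""
    return [(y_anchor, scale * height) for y_anchor, height in rectangles]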


# ----- Helpers ----- #


def _plot_lattice_series(
    ax: Axes,
    series: DataFrame,
    height: float = 1.0,
    v_offset: float = 0.0,
    color: str = "r",
    alpha: float = 0.5,
    **kwargs,
) -> None:
    """
    .. versionadded:: 1.0.0

    Plots a `~matplotlib.patches.Rectangle` element on the provided
    `~matplotlib.axes.Axes` to represent an element of the machine.
    Original code from :user:`Guido Sterbini <sterbini>`.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        An existing `~matplotlib.axes.Axes` object to draw on.
    series : pandas.Series
        A single element's row of the elements' `pandas.DataFrame` (as yielded
        by `~pandas.DataFrame.iterrows`). Its ``s`` and ``l`` entries are used
        to position the patch from ``s - l`` to ``s``.
    height : float
        Height of the patch along the y axis. Defaults to 1.
    v_offset : float
        Vertical position of the patch's center. Defaults to 0, which centers
        the patch on the 0-level line. The layout functions pass ``height / 2``
        so that the patch spans from 0 to ``height``. Should not be changed
        unless you know exactly what you're doing.
    color : str
        Color kwarg to transmit to `~matplotlib.patches.Rectangle`. Defaults to
        ``"r"``, for red.
    alpha : float
        Alpha kwarg to transmit to `~matplotlib.patches.Rectangle`. Defaults to 0.5.
    **kwargs
        Any keyword argument will be transmitted to `~matplotlib.patches.Rectangle`,
        for instance ``lw`` for the edge line width.
    """
    ax.add_patch(
        patches.Rectangle(
            (series.s - series.l, v_offset - height / 2.0),  # anchor point
            series.l,  # width
            height,  # height
            color=color,
            alpha=alpha,
            **kwargs,
        )
    )


def _ylim_from_input(ylim: tuple[float, float] | float | int, name_for_error: str = "knl_lim") -> tuple[float, float]:
    """
    .. versionadded:: 1.2.0

    Determines the ylimits for a given axis from the input provided by the user.
    This is used in `~.plotting.layout.plot_machine_layout` and handles
    different inputs from the user, such as a tuple, a float and an int.

    Parameters
    ----------
    ylim : tuple[float, float] | float | int
        The input provided by the user.
    name_for_error : str
        The name of the variable to use in the error message.

    Returns
    -------
    tuple[float, float]
        A `tuple` for the ylimits from the input. A tuple input is returned as
        is, while a single value ``x`` gives the symmetric ``(-|x|, |x|)``.

    Raises
    ------
    TypeError
        If the input is not a `tuple`, a `float` or an `int`.
    """
    if not isinstance(ylim, tuple | float | int):
        msg = f"Invalid type for '{name_for_error}': {type(ylim)}. Expected a tuple, a float or an int."
        raise TypeError(msg)
    if isinstance(ylim, tuple):
        return ylim
    # otherwise we have float | int
    if ylim >= 0:
        return (-ylim, ylim)
    return (ylim, -ylim)


def _determine_default_knl_lim(df: DataFrame, col: str, coeff: float) -> tuple[float, float]:
    """
    .. versionadded:: 1.0.0

    Determines the default limits for the ``knl`` axis when plotting the machine
    layout. This is used when `None` is provided by the user, to make sure the
    plot still looks coherent and symmetric in
    `~.plotting.layout.plot_machine_layout`. The limits are symmetric around 0:
    the maximum absolute value of the ``col`` column in the provided dataframe,
    scaled by ``coeff``.

    Parameters
    ----------
    df : pandas.DataFrame
        A `pandas.DataFrame` with the multipoles' data. The ``col`` column is
        used to determine the limits.
    col : str
        The ``knl`` column to query in the dataframe.
    coeff : float
        A scaling factor to apply to the max absolute value when determining
        the limits.

    Returns
    -------
    tuple[float, float]
        A `tuple` with the ylimits for the knl axis. If the dataframe is empty,
        both limits are NaN.
    """
    logger.debug(f"Determining '{col}_lim' based on plotted data")
    max_val = df[col].abs().max()
    max_val_scaled = coeff * max_val
    logger.debug(f"Determined '{col}_lim' are: (-{max_val_scaled}, {max_val_scaled})")
    return (-max_val_scaled, max_val_scaled)
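

# Illustrative sketch, not part of the module: the symmetric-limit logic of
# ``_ylim_from_input`` and ``_determine_default_knl_lim`` combined and written with
# plain Python numbers instead of a pandas column. The name
# ``_sketch_symmetric_limits`` is hypothetical. A user-given tuple is kept as is,
# a user-given scalar ``x`` becomes ``(-|x|, |x|)``, and otherwise the limits are
# ``coeff`` times the largest absolute strength. An empty list gives NaN limits,
# mirroring pandas returning NaN for the maximum of an empty column (for example,
# a lattice section without dipoles).


def _sketch_symmetric_limits(
    values: list[float],
    coeff: float,
    user_lim: float | tuple[float, float] | None = None,
) -> tuple[float, float]:
    """Hypothetical helper: symmetric ylimits from user input or from the data."""
    if user_lim is not None:
        if isinstance(user_lim, tuple):
            return user_lim  # tuples are trusted as given, like in _ylim_from_input
        return (-abs(user_lim), abs(user_lim))
    max_val = max((abs(v) for v in values), default=float("nan"))
    return (-coeff * max_val, coeff * max_val)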