# Source code for disco._particle_history

import warnings

from astropy import units, constants
import cupy as cp
import h5py
from matplotlib import pyplot as plt
import numpy as np

from disco._dimensionalization import (
    undim_time,
    undim_magnetic_field,
    undim_space,
    undim_energy,
    undim_momentum,
    undim_magnetic_moment,
)

# Plotting more particles than this emits a performance warning.
MAX_PARTICLES_BEFORE_WARNING = 1_000

# Units in which each quantity is written to, and read back from, HDF5 files.
# Particle constants
MASS_UNITS = units.kg
CHARGE_UNITS = units.C
# Time-like quantities (the time stamps and the integration step size)
TIME_UNITS = units.s
# Positions are stored as multiples of the Earth radius
SPACE_UNITS = constants.R_earth
# Dynamical quantities
MOMENTUM_UNITS = units.keV * units.s / units.m
MAGNETIC_MOMENT_UNITS = units.MeV / units.nT
MAGFIELD_UNITS = units.nT
ENERGY_UNITS = units.eV


class ParticleHistory:
    """History of trajectory tracing.

    Arrays are in the shape (n_time_steps, n_particles). Mass and charge are
    scalars.

    If some particles completed integration in fewer timesteps than others
    (such as by reaching the integration limit or going out of bounds), their
    last item is duplicated to match the shape of the other arrays. The last
    item before duplication can be found by checking the `stopped` array,
    which is a boolean array of the same shape as the other arrays.

    Output can be stored and read from disk in HDF5 format. See the `load()`
    and `save()` methods.

    Attributes
    ----------
    t : array with units
        Time of the particle state at each step.
    x : array with units
        X position of the particle at each step.
    y : array with units
        Y position of the particle at each step.
    z : array with units
        Z position of the particle at each step.
    ppar : array with units
        Parallel momentum of the particle at each step.
    M : array with units
        Magnetic moment of the particle at each step.
    B : array with units
        Magnetic field at the particle position at each step.
    W : array with units
        Energy of the particle at each step.
    h : array with units
        Adaptive step size used in the integration at each step.
    stopped : array of bool
        Whether the particle had stopped at each step.
    extra_fields : dict
        Additional fields computed during the trajectory tracing. Keys are
        field names and values are arrays with the same shape as the other
        arrays.
    mass : scalar with units
        Mass of the particles (constant).
    charge : scalar with units
        Charge of the particles (constant).

    Notes
    -----
    See `disco.TraceConfig(output_freq=...)` for controlling how many
    iterations elapse between saved particle states. If `output_freq` is set
    to `None` (the default), only the first and last points of the trace
    will be saved.

    Examples
    --------
    Saving output to disk:

    >>> hist = disco.trace_trajectory(config, particle_state, field_model)
    >>> hist.save("particle_history.h5")

    Loading output from disk and plotting:

    >>> hist = disco.ParticleHistory.load("particle_history.h5")
    >>> hist.plot_xz()
    >>> plt.savefig('myplot.png')
    """

    def __init__(
        self, t, x, y, z, ppar, M, B, W, h, stopped, mass, charge, extra_fields=None
    ):
        """Create a history, converting every quantity to its on-disk units.

        Parameters
        ----------
        t, x, y, z, ppar, M, B, W, h : arrays with units
            See the class docstring. Each must be convertible (via ``.to``)
            to the corresponding module-level ``*_UNITS`` constant.
        stopped : array
            Converted to a boolean array.
        mass, charge : scalars with units
            Particle mass and charge.
        extra_fields : dict, optional
            Stored as-is; ``None`` is replaced by an empty dict.
        """
        # Normalizing units once here lets save() and the plot helpers rely
        # on a fixed unit per attribute (e.g. the plot labels assume R_E).
        self.t = t.to(TIME_UNITS)
        self.x = x.to(SPACE_UNITS)
        self.y = y.to(SPACE_UNITS)
        self.z = z.to(SPACE_UNITS)
        self.ppar = ppar.to(MOMENTUM_UNITS)
        self.M = M.to(MAGNETIC_MOMENT_UNITS)
        self.B = B.to(MAGFIELD_UNITS)
        self.W = W.to(ENERGY_UNITS)
        self.h = h.to(TIME_UNITS)
        self.stopped = stopped.astype(bool)

        # Mass and charge are scalars with units
        self.mass = mass.to(MASS_UNITS)
        self.charge = charge.to(CHARGE_UNITS)

        self.extra_fields = {} if extra_fields is None else extra_fields

    def save(self, hdf_path):
        """Save particle history to an HDF5 file.

        Every unit-carrying field is written as a plain numeric dataset in its
        on-disk unit, with the unit string stored in the dataset's ``UNITS``
        attribute. `stopped` is written without units, and any extra fields
        go into an ``extra_fields`` group.

        Parameters
        ----------
        hdf_path: str
            Path to the HDF5 file where the history will be saved. An existing
            file is overwritten.

        Notes
        -----
        See `ParticleHistory.load()` to load particle history from an HDF5
        file.
        """
        # (dataset name, quantity, on-disk unit) for every field with units.
        quantities = (
            ("t", self.t, TIME_UNITS),
            ("x", self.x, SPACE_UNITS),
            ("y", self.y, SPACE_UNITS),
            ("z", self.z, SPACE_UNITS),
            ("ppar", self.ppar, MOMENTUM_UNITS),
            ("M", self.M, MAGNETIC_MOMENT_UNITS),
            ("B", self.B, MAGFIELD_UNITS),
            ("W", self.W, ENERGY_UNITS),
            ("h", self.h, TIME_UNITS),
            ("mass", self.mass, MASS_UNITS),
            ("charge", self.charge, CHARGE_UNITS),
        )

        with h5py.File(hdf_path, "w") as hdf:
            for name, quantity, unit in quantities:
                hdf[name] = quantity.to_value(unit)
                hdf[name].attrs["UNITS"] = unit.to_string()

            hdf["stopped"] = self.stopped

            # Set extra fields if they exist
            if self.extra_fields:
                extra_fields_group = hdf.create_group("extra_fields")
                for key, value in self.extra_fields.items():
                    extra_fields_group[key] = value

    @classmethod
    def load(cls, hdf_path):
        """Load particle history from an HDF5 file.

        The file is assumed to have been written by `save()`. Raw values are
        interpreted in this module's ``*_UNITS`` constants; the ``UNITS``
        attributes stored in the file are not read.

        Parameters
        ----------
        hdf_path: str
            Path to the HDF5 file from which the history will be loaded.

        Returns
        -------
        An instance of `ParticleHistory` containing the loaded data.

        Notes
        -----
        See `ParticleHistory.save()` to save particle history to an HDF5 file.
        """
        with h5py.File(hdf_path, "r") as hdf:
            t = hdf["t"][:] * TIME_UNITS
            x = hdf["x"][:] * SPACE_UNITS
            y = hdf["y"][:] * SPACE_UNITS
            z = hdf["z"][:] * SPACE_UNITS
            ppar = hdf["ppar"][:] * MOMENTUM_UNITS
            M = hdf["M"][:] * MAGNETIC_MOMENT_UNITS
            B = hdf["B"][:] * MAGFIELD_UNITS
            W = hdf["W"][:] * ENERGY_UNITS
            h = hdf["h"][:] * TIME_UNITS
            stopped = hdf["stopped"][:].astype(bool)
            # Scalars are read with [()] rather than [:]
            mass = hdf["mass"][()] * MASS_UNITS
            charge = hdf["charge"][()] * CHARGE_UNITS

            if "extra_fields" in hdf:
                extra_fields = {
                    key: dataset[:] for key, dataset in hdf["extra_fields"].items()
                }
            else:
                extra_fields = None

        return cls(
            t, x, y, z, ppar, M, B, W, h, stopped, mass, charge,
            extra_fields=extra_fields,
        )

    def _plot_trajectory(
        self,
        ax,
        x_vals,
        y_vals,
        inds,
        endpoints,
        sample,
        earth,
        grid,
        title,
        xlabel,
        ylabel,
    ):
        """Plot particle trajectories (or final positions) in a 2D plane.

        Parameters
        ----------
        ax : matplotlib axes or None
            Axis to draw on; a new figure and axis are created if None.
        x_vals, y_vals : array, shape (n_time_steps, n_particles)
            Horizontal and vertical plot coordinates.
        inds : int, sequence of int, array of int, or None
            Particle indices to plot; None selects every particle. The
            caller's object is never modified.
        endpoints : bool
            If True, draw only the final position of each selected particle
            as unconnected markers.
        sample : int or None
            If given, plot a random subset of this many selected particles.
        earth : bool
            If True, draw a unit circle (one Earth radius) at the origin.
        grid : bool
            If True, draw a dashed grid and the coordinate axes.
        title, xlabel, ylabel : str
            Axis annotations.

        Returns
        -------
        The axis with the plotted data.

        Raises
        ------
        IndexError
            If a requested index is >= the number of particles.
        """
        # Create a new figure and axis if none is provided
        if ax is None:
            _, ax = plt.subplots()

        n_particles = x_vals.shape[1]

        # Normalize the selection to a flat 1-D index array. np.ravel handles
        # scalars, lists and arrays alike (the old np.array([inds]) turned a
        # list into a (1, k) array, which silently broke `sample`).
        if inds is None:
            inds = np.arange(n_particles, dtype=int)
        else:
            inds = np.ravel(inds)

        if sample is not None:
            # permutation() shuffles a copy, so a caller-supplied array is not
            # reordered in place (np.random.shuffle mutated it).
            inds = np.random.permutation(inds)[:sample]

        # Issue warning if too many particles are being plotted
        if inds.size > MAX_PARTICLES_BEFORE_WARNING:
            warnings.warn(
                f"Plotting more than {MAX_PARTICLES_BEFORE_WARNING} particles "
                "may be slow. Consider downsampling with "
                f"sample={MAX_PARTICLES_BEFORE_WARNING // 2}.",
                stacklevel=3,  # point at the user's plot_xy/xz/yz call
            )

        # An empty selection (inds=[] or sample=0) just yields empty axes.
        if inds.size > 0:
            # Ensure indices are within bounds
            if inds.max() >= n_particles:
                raise IndexError(
                    f"Index {inds.max()} is out of bounds for the number of "
                    f"particles {n_particles}."
                )

            if endpoints:
                # Final positions of different particles are unrelated, so
                # draw markers only; the default line style would join them.
                ax.plot(
                    x_vals[-1, inds], y_vals[-1, inds], marker=".", linestyle="none"
                )
            else:
                for i in inds:
                    ax.plot(x_vals[:, i], y_vals[:, i])

        # Draw earth if set
        if earth:
            earth_circle = plt.Circle((0, 0), 1, color="k", zorder=100)
            ax.add_patch(earth_circle)

        # Setup axis labels and title
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.set_aspect("equal", adjustable="box")

        # Setup grid if set
        if grid:
            ax.grid(True, which="both", linestyle="--", linewidth=0.5)
            ax.axhline(0, color="black", linewidth=0.5)
            ax.axvline(0, color="black", linewidth=0.5)

        return ax

    def plot_xy(
        self,
        ax=None,
        inds=None,
        endpoints=False,
        sample=None,
        earth=True,
        grid=True,
        title="Particle Trajectory in XY Plane",
    ):
        """Plot the particle trajectory in the XY plane.

        Parameters
        ----------
        ax: matplotlib axes
            Matplotlib axis to plot on. If None, a new figure and axis will be
            created.
        inds: int or list of ints
            Indices of the particles to plot. If None, all particles will be
            plotted.
        endpoints: bool
            If True, plot only the final position of each particle.
        sample: int, optional
            If specified, randomly sample this many particles to plot.
        earth: bool
            If True, draw a circle representing the Earth at the origin.
        grid: bool
            If True, add a grid to the plot.
        title: str
            Title of the plot.

        Returns
        -------
        The axis with the plotted trajectory.

        Examples
        --------
        >>> hist = disco.ParticleHistory.load("particle_history.h5")
        >>> hist.plot_xy()
        >>> plt.savefig('myplot.png')
        """
        return self._plot_trajectory(
            ax,
            self.x.value,
            self.y.value,
            inds,
            endpoints,
            sample,
            earth,
            grid,
            title,
            "X ($R_E$)",
            "Y ($R_E$)",
        )

    def plot_xz(
        self,
        ax=None,
        inds=None,
        endpoints=False,
        sample=None,
        earth=True,
        grid=True,
        title="Particle Trajectory in XZ Plane",
    ):
        """Plot the particle trajectory in the XZ plane.

        Parameters
        ----------
        ax: matplotlib axes
            Matplotlib axis to plot on. If None, a new figure and axis will be
            created.
        inds: int or list of ints
            Indices of the particles to plot. If None, all particles will be
            plotted.
        endpoints: bool
            If True, plot only the final position of each particle.
        sample: int, optional
            If specified, randomly sample this many particles to plot.
        earth: bool
            If True, draw a circle representing the Earth at the origin.
        grid: bool
            If True, add a grid to the plot.
        title: str
            Title of the plot.

        Returns
        -------
        The axis with the plotted trajectory.

        Examples
        --------
        >>> hist = disco.ParticleHistory.load("particle_history.h5")
        >>> hist.plot_xz()
        >>> plt.savefig('myplot.png')
        """
        return self._plot_trajectory(
            ax,
            self.x.value,
            self.z.value,
            inds,
            endpoints,
            sample,
            earth,
            grid,
            title,
            "X ($R_E$)",
            "Z ($R_E$)",
        )

    def plot_yz(
        self,
        ax=None,
        inds=None,
        endpoints=False,
        sample=None,
        earth=True,
        grid=True,
        title="Particle Trajectory in YZ Plane",
    ):
        """Plot the particle trajectory in the YZ plane.

        Parameters
        ----------
        ax: matplotlib axes
            Matplotlib axis to plot on. If None, a new figure and axis will be
            created.
        inds: int or list of ints
            Indices of the particles to plot. If None, all particles will be
            plotted.
        endpoints: bool
            If True, plot only the final position of each particle.
        sample: int, optional
            If specified, randomly sample this many particles to plot.
        earth: bool
            If True, draw a circle representing the Earth at the origin.
        grid: bool
            If True, add a grid to the plot.
        title: str
            Title of the plot.

        Returns
        -------
        The axis with the plotted trajectory.

        Examples
        --------
        >>> hist = disco.ParticleHistory.load("particle_history.h5")
        >>> hist.plot_yz()
        >>> plt.savefig('myplot.png')
        """
        return self._plot_trajectory(
            ax,
            self.y.value,
            self.z.value,
            inds,
            endpoints,
            sample,
            earth,
            grid,
            title,
            "Y ($R_E$)",
            "Z ($R_E$)",
        )
class ParticleHistoryBuffer:
    """Buffer for storing the history of particle trajectories.

    Accumulates snapshots of the particle state, and related tracked
    variables, during the tracing process. Each `append()` copies one
    snapshot to host memory (via ``.get()``). `to_particle_history()` stacks
    the snapshots and attaches physical units.
    """

    def __init__(self):
        """Initialize an empty `ParticleHistoryBuffer`."""
        # Each list holds one host-side array per appended snapshot.
        self.t = []
        self.y = []
        # NOTE(review): never appended to; parallel momentum is read from
        # column 3 of `y` instead. Kept so existing attribute access works.
        self.ppar = []
        self.B = []
        self.W = []
        self.h = []
        self.stopped = []
        self.extra_fields = []

    def append(self, t, y, B, h, stopped, extra_fields, total_reorder=None):
        """Append a new history entry to the buffer.

        Parameters
        ----------
        t, B, h, stopped : cupy array
            Per-particle time, magnetic field strength, step size and stop
            flag, in the tracer's internal (unitless) representation.
        y : cupy array
            Per-particle state vector. Columns 0-2 are position, column 3 is
            parallel momentum and column 4 is magnetic moment.
        extra_fields : cupy array
            Extra per-particle fields, one column per field. Nothing is
            stored when it is empty.
        total_reorder : array of int, optional
            Permutation that was applied to the particle arrays during
            tracing. Its inverse (``argsort``) is applied so that snapshots
            are stored in the original particle order. ``None`` means no
            reordering.
        """
        if total_reorder is None:
            # Identity order: a full slice avoids allocating an index array
            # and doing a gather on every call.
            order = slice(None)
        else:
            order = np.argsort(total_reorder)

        # W is derived from the un-reordered arrays, then reordered with the rest.
        _, W = _calc_gamma_W(B, y)

        self.t.append(t[order].get())
        self.y.append(y[order].get())
        self.B.append(B[order].get())
        self.W.append(W[order].get())
        self.h.append(h[order].get())
        self.stopped.append(stopped[order].get())

        if extra_fields.size > 0:
            self.extra_fields.append(extra_fields[order].get())

    def to_particle_history(self, particle_state, field_model):
        """Convert the accumulated history to a `ParticleHistory` instance.

        Parameters
        ----------
        particle_state: `ParticleState`
            The initial conditions of the particles. It sets mass and charge,
            which are also used to restore physical units.
        field_model
            Field model used for tracing. The keys of its ``extra_fields``
            mapping name the extra-field columns, in column order.

        Returns
        -------
        ParticleHistory
            A `ParticleHistory` instance containing the accumulated history,
            which can be used to save or plot results.
        """
        mass = particle_state.mass
        charge = particle_state.charge

        # Stacked state vector: (n_steps, n_particles, n_state_components)
        hist_y = np.array(self.y)

        hist_extra_fields = None
        extra_names = list(field_model.extra_fields.keys())
        if extra_names:
            # (n_steps, n_particles, n_extra_fields); only stacked when needed.
            raw_extra = np.array(self.extra_fields)
            hist_extra_fields = {
                name: raw_extra[:, :, col] for col, name in enumerate(extra_names)
            }

        return ParticleHistory(
            t=undim_time(np.array(self.t)),
            x=undim_space(hist_y[:, :, 0]),
            y=undim_space(hist_y[:, :, 1]),
            z=undim_space(hist_y[:, :, 2]),
            ppar=undim_momentum(hist_y[:, :, 3], mass),
            M=undim_magnetic_moment(hist_y[:, :, 4], charge),
            B=undim_magnetic_field(np.array(self.B), mass, charge),
            W=undim_energy(np.array(self.W), mass),
            h=undim_time(np.array(self.h)),
            stopped=np.array(self.stopped),
            mass=mass,
            charge=charge,
            extra_fields=hist_extra_fields,
        )


def _calc_gamma_W(B, y):
    """Calculate the Lorentz factor and energy W for saving in history.

    Computes ``gamma = sqrt(1 + 2 * B * M + ppar**2)``, where
    ``ppar = y[:, 3]`` and ``M = y[:, 4]``, and ``W = gamma - 1``.

    Parameters
    ----------
    B : cupy array
        Magnetic field strength, in the tracer's internal (unitless)
        representation.
    y : cupy array
        State vector, in the same internal representation.

    Returns
    -------
    gamma : cupy array
        Relativistic (Lorentz) factor.
    W : cupy array
        ``gamma - 1``, i.e. kinetic energy in units of the rest-mass energy.
    """
    gamma = cp.sqrt(1 + 2 * B * y[:, 4] + y[:, 3] ** 2)
    W = gamma - 1
    return gamma, W