Source code for jdaviz.configs.specviz.plugins.viewers

import warnings

import numpy as np
from astropy import table
from astropy import units as u
from functools import cached_property
from matplotlib.colors import cnames
from specutils import Spectrum
from scipy.interpolate import interp1d
from glue.core import BaseData
from glue_jupyter.bqplot.image import BqplotImageView

from jdaviz.core.events import (AddDataToViewerMessage,
                                RemoveDataFromViewerMessage,
                                SpectralMarksChangedMessage,
                                LineIdentifyMessage)
from jdaviz.core.freezable_state import FreezableBqplotImageViewerState
from jdaviz.core.registries import viewer_registry
from jdaviz.core.marks import SpectralLine
from jdaviz.core.linelists import load_preset_linelist, get_available_linelists
from jdaviz.core.unit_conversion_utils import (spectral_axis_conversion,
                                               flux_conversion_general,
                                               all_flux_unit_conversion_equivs)
from jdaviz.utils import SPECTRAL_AXIS_COMP_LABELS
from jdaviz.core.freezable_state import FreezableProfileViewerState
from jdaviz.configs.default.plugins.viewers import JdavizViewerMixin, JdavizProfileView
from jdaviz.configs.cubeviz.plugins.viewers import WithSliceIndicator

__all__ = ['Spectrum1DViewer', 'Spectrum2DViewer']


@viewer_registry("spectrum-1d-viewer", label="1D Spectrum")
class Spectrum1DViewer(JdavizProfileView, WithSliceIndicator):
    """
    Profile viewer for 1D spectra that can also display spectral-line marks.

    Loaded line lists are kept in ``spectral_lines`` (an astropy table, or
    ``None`` until the first list is loaded) and are drawn on the figure as
    `~jdaviz.core.marks.SpectralLine` marks.
    """
    # categories: zoom resets, zoom, pan, subset, select tools, shortcuts
    # NOTE(review): 'jdaviz:yrangezoom' and 'jdaviz:panzoom_y' are each listed
    # twice within their group; kept unchanged here, confirm whether intended.
    tools_nested = [
                    ['jdaviz:homezoom_matchx', 'jdaviz:homezoom', 'jdaviz:prevzoom'],
                    ['jdaviz:boxzoom_matchx', 'jdaviz:xrangezoom_matchx',
                     'jdaviz:boxzoom', 'jdaviz:yrangezoom',
                     'jdaviz:xrangezoom', 'jdaviz:yrangezoom'],  # noqa
                    ['jdaviz:panzoom_matchx', 'jdaviz:panzoomx_matchx', 'jdaviz:panzoom_y',
                     'jdaviz:panzoom', 'jdaviz:panzoom_x', 'jdaviz:panzoom_y'],  # noqa
                    ['bqplot:xrange'],
                    ['jdaviz:selectslice'],
                    ['jdaviz:viewer_clone', 'jdaviz:sidebar_plot', 'jdaviz:sidebar_export']
                ]
    default_class = Spectrum
    spectral_lines = None
    _state_cls = FreezableProfileViewerState
    _default_profile_subset_type = 'spectral'

    def __init__(self, *args, **kwargs):
        """
        Create the viewer and register the data-menu filters.

        ``default_tool_priority`` defaults to ``['jdaviz:selectslice']`` unless
        the caller provides it; all arguments are forwarded to the parent class.
        """
        kwargs.setdefault('default_tool_priority', ['jdaviz:selectslice'])
        super().__init__(*args, **kwargs)

        def compatible_units(data):
            """Return False only if ``data`` cannot be converted to the viewer's units."""
            if not len(self.layers):
                return True

            # If we load, e.g., one spectrum in Frequency and one in Wavelength,
            # the viewer state x_att won't match the component of the second spectrum since
            # (as of Specutils 2.X) they're no longer both the non-specific "World 0"
            # Start from None so that data without any known spectral-axis
            # component is treated like data with unknown units (see the
            # ``None in (...)`` check below) instead of raising UnboundLocalError.
            data_xunit = None
            if str(self.state.x_att) in data.component_ids():
                data_xunit = data.get_component(str(self.state.x_att)).units
            else:
                for comp in SPECTRAL_AXIS_COMP_LABELS:
                    if comp in data.component_ids():
                        data_xunit = data.get_component(comp).units
                        break
            data_yunit = data.get_component('flux').units
            viewer_xunit = self.state.x_display_unit
            viewer_yunit = self.state.y_display_unit
            if None in (data_xunit, data_yunit, viewer_xunit, viewer_yunit):
                return True

            try:
                spectral_axis_conversion([1], data_xunit, viewer_xunit)
            except u.UnitConversionError:
                return False

            equivs = all_flux_unit_conversion_equivs(cube_wave=[1]*u.Unit(viewer_xunit))
            try:
                flux_conversion_general([1], data_yunit, viewer_yunit, equivalencies=equivs)
            except u.UnitConversionError:
                return False
            return True

        self.data_menu._obj.dataset.add_filter('is_spectrum', compatible_units)
        self.data_menu.layer.add_filter('not_trace', 'not_spatial_subset_in_profile_viewer')

    @property
    def redshift(self):
        """Redshift currently stored on the jdaviz helper (``jdaviz_helper._redshift``)."""
        return self.jdaviz_helper._redshift

    @staticmethod
    def _to_hexa_color(color):
        """Return ``color`` as a hex string with an alpha channel where possible.

        ``"#RRGGBB"`` values get ``"FF"`` appended, other ``#``-prefixed values
        are returned unchanged, and anything else is looked up as a matplotlib
        color name (raising `KeyError` if unknown) and given an ``"FF"`` alpha.
        """
        if color[0] == "#":
            # "#" plus six hex digits is 7 characters long.
            return color + "FF" if len(color) == 7 else color
        return cnames[color] + "FF"

    def load_line_list(self, line_table, replace=False, return_table=False, show=True):
        """
        Load a list of spectral lines into ``spectral_lines``.

        Parameters
        ----------
        line_table : str or `~astropy.table.QTable`
            Either the name of a preset line list (loaded through
            ``load_preset_linelist`` and always loaded with ``show=False``), or
            a table with at least ``linename`` and ``rest`` columns.  The table
            is modified in place: ``show``, ``name_rest``, ``listname`` and
            ``colors`` columns are added or normalized.
        replace : bool
            If `True`, replace any previously loaded lines instead of stacking
            the new table onto them.
        return_table : bool
            If `True`, return the processed table.
        show : bool
            Initial value of the ``show`` column (stored as an integer).

        Returns
        -------
        `~astropy.table.QTable` or None
            The processed table if ``return_table`` is `True`, otherwise `None`.

        Raises
        ------
        TypeError
            If ``line_table`` is neither a string nor a ``QTable``.
        ValueError
            If a required column is missing or any rest value is not positive.
        """
        # If string, load the named preset list and don't show by default
        # since there might be too many lines
        if isinstance(line_table, str):
            # Propagate the result so that return_table is honored for presets.
            return self.load_line_list(load_preset_linelist(line_table),
                                       replace=replace, return_table=return_table,
                                       show=False)
        elif not isinstance(line_table, table.QTable):
            raise TypeError("Line list must be an astropy QTable with "
                            "(minimally) 'linename' and 'rest' columns")

        if "linename" not in line_table.columns:
            raise ValueError("Line table must have a 'linename' column'")
        if "rest" not in line_table.columns:
            raise ValueError("Line table must have a 'rest' column'")
        if np.any(line_table['rest'] <= 0):
            raise ValueError("all rest values must be positive")

        # Use the redshift of the displayed spectrum if no redshifts are specified
        if "redshift" in line_table.colnames:
            warnings.warn("per line/list redshifts not supported, use viz.set_redshift")

        # Set whether to show all of the lines on the plot by default on load
        # We convert bool to int to work around ipywidgets json serialization
        line_table["show"] = int(show)

        # If there is already a loaded table, convert units to match. This
        # attempts to do some sane rounding after the unit conversion.
        # TODO: Fix this so that things don't get rounded to 0 in some cases
        """
        if self.spectral_lines is not None:
            sig_figs = []
            for row in line_table:
                rest_str = str(row["rest"].value).replace(".", "").split("e")[0]
                sig_figs.append(len(rest_str))
            line_table["rest"] = line_table["rest"].to(self.spectral_lines["rest"].unit)
            line_table["sig_figs"] = sig_figs
            for row in line_table:
                row["rest"] = row["rest"].round(row["sig_figs"])
            del line_table["sig_figs"]
        """

        # Combine name and rest value for indexing
        if "name_rest" not in line_table.colnames:
            line_table["name_rest"] = None
            for row in line_table:
                row["name_rest"] = "{} {}".format(row["linename"], row["rest"].value)

        # If no name was given to this list, consider it part of the "Custom" list
        if "listname" not in line_table.colnames:
            line_table["listname"] = "Custom"
        else:
            for row in line_table:
                if row["listname"] is None:
                    row["listname"] = "Custom"

        # Convert colors to hexa values, or set to default (red)
        if "colors" not in line_table.colnames:
            line_table["colors"] = "#FF0000FF"
        else:
            # Rebuild the whole column rather than assigning row by row, so a
            # converted value longer than the original entries is not cut to
            # the width of a fixed-width string column.
            line_table["colors"] = [self._to_hexa_color(color)
                                    for color in line_table["colors"]]

        # Create or update the main spectral_lines astropy table
        if self.spectral_lines is None or replace:
            self.spectral_lines = line_table
        else:
            self.spectral_lines = table.vstack([self.spectral_lines, line_table])
            self.spectral_lines = table.unique(self.spectral_lines, keys='name_rest')

        # It seems that we need to recreate this index after v-stacking.
        self.spectral_lines.add_index("name_rest")
        self.spectral_lines.add_index("linename")
        self.spectral_lines.add_index("listname")

        self._broadcast_plotted_lines()

        if return_table:
            return line_table

    def _broadcast_plotted_lines(self, marks=None):
        """
        Broadcast the currently plotted spectral-line marks on the session hub.

        Parameters
        ----------
        marks : list or None
            Marks to report.  If `None`, all `SpectralLine` marks currently on
            the figure are used.
        """
        if marks is None:
            marks = [x for x in self.figure.marks if isinstance(x, SpectralLine)]
        msg = SpectralMarksChangedMessage(marks, sender=self)
        self.session.hub.broadcast(msg)

        if not np.any([mark.identify for mark in marks]):
            # then clear the identified entry
            msg = LineIdentifyMessage(name_rest='', sender=self)
            self.session.hub.broadcast(msg)

    def erase_spectral_lines(self, name=None, name_rest=None, show_none=True):
        """
        Erase either all spectral lines, all spectral lines sharing the same
        name (e.g. 'He II') or a specific name-rest value combination (e.g.
        'HE II 1640.5', stored in SpectralLine as 'table_index').

        Parameters
        ----------
        name : str or None
            Line name whose marks should be removed.  Takes precedence over
            ``name_rest`` when deciding which marks to remove.
        name_rest : str, list of str, or None
            One or several name-rest identifiers whose marks should be removed.
        show_none : bool
            Only used when erasing all lines: if `True`, also set the ``show``
            column of the loaded line table to `False`.
        """
        fig = self.figure
        if name is None and name_rest is None:
            fig.marks = [x for x in fig.marks if not isinstance(x, SpectralLine)]
            # Nothing to update in the table if no line list was ever loaded.
            if show_none and self.spectral_lines is not None:
                self.spectral_lines["show"] = False
            self._broadcast_plotted_lines([])
        else:
            temp_marks = []
            # Toggle "show" value in main astropy table. The astropy table
            # machinery only allows updating a single row at a time.
            if name_rest is not None:
                if isinstance(name_rest, str):
                    self.spectral_lines.loc[name_rest]["show"] = False
                elif isinstance(name_rest, list):
                    for nr in name_rest:
                        self.spectral_lines.loc[nr]["show"] = False
            if name is not None and self.spectral_lines is not None:
                # ``.loc[name]`` would search the primary ("name_rest") index
                # and raise KeyError for a plain line name, so match on the
                # "linename" column instead; done once rather than per mark.
                for row in self.spectral_lines:
                    if row["linename"] == name:
                        row["show"] = False
            # Get rid of the marks we no longer want
            for x in fig.marks:
                if isinstance(x, SpectralLine):
                    if name is not None:
                        if x.name == name:
                            continue
                    else:
                        if isinstance(name_rest, str):
                            if x.table_index == name_rest:
                                continue
                        elif isinstance(name_rest, list):
                            if x.table_index in name_rest:
                                continue
                temp_marks.append(x)
            fig.marks = temp_marks
            self._broadcast_plotted_lines()

    def plot_spectral_line(self, line, global_redshift=None, plot_units=None, **kwargs):
        """
        Plot a single spectral line from the loaded line table.

        Parameters
        ----------
        line : str or table row
            Either a row of ``spectral_lines`` or a string.  A string is first
            looked up as a ``name_rest`` identifier and, if that fails, as a
            ``linename``.
        global_redshift : float or None
            Redshift applied to the mark; defaults to ``self.redshift``.
        plot_units : unit or None
            Units the rest value is converted to; defaults to the spectral-axis
            unit of the first data entry in the viewer.
        **kwargs
            Forwarded to `SpectralLine`.
        """
        if isinstance(line, str):
            # Try the full index first (for backend calls), otherwise name only
            try:
                line = self.spectral_lines.loc[line]
            except KeyError:
                line = self.spectral_lines.loc["linename", line]
        if plot_units is None:
            plot_units = self.data()[0].spectral_axis.unit

        if global_redshift is None:
            redshift = self.redshift
        else:
            redshift = global_redshift

        line_mark = SpectralLine(self,
                                 line['rest'].to_value(plot_units),
                                 redshift,
                                 name=line["linename"],
                                 table_index=line["name_rest"],
                                 colors=[line["colors"]], **kwargs)

        # Erase this line if it already existed, to avoid duplication
        self.erase_spectral_lines(name_rest=line["name_rest"])

        self.figure.marks = self.figure.marks + [line_mark]
        line["show"] = True
        self._broadcast_plotted_lines()

    # NOTE: the list default for ``colors`` is never mutated below (a new list
    # is built when it has to be repeated), so it is safe to share.
    def plot_spectral_lines(self, colors=["blue"], global_redshift=None, **kwargs):
        """
        Plots a user-provided astropy table of spectral lines in the viewer.

        All existing line marks are removed first, then one mark is added for
        every row of ``spectral_lines`` whose ``show`` value is truthy.

        Parameters
        ----------
        colors : list
            Fallback colors, used only if the line table has no ``colors``
            column.
        global_redshift : float or None
            Redshift applied to the marks; defaults to ``self.redshift``.
        **kwargs
            Forwarded to `SpectralLine`.
        """
        fig = self.figure
        self.erase_spectral_lines(show_none=False)

        # Check to see if colors were defined for each line
        if "colors" in self.spectral_lines.columns:
            colors = self.spectral_lines["colors"]
        elif len(colors) != len(self.spectral_lines):
            colors = colors*len(self.spectral_lines)

        lines = self.spectral_lines
        plot_units = self.data()[0].spectral_axis.unit

        if global_redshift is None:
            redshift = self.redshift
        else:
            redshift = global_redshift

        marks = []
        for line, color in zip(lines, colors):
            if not line["show"]:
                continue
            mark = SpectralLine(self,
                                line['rest'].to_value(plot_units),
                                redshift,
                                name=line["linename"],
                                table_index=line["name_rest"],
                                colors=[color], **kwargs)
            marks.append(mark)
        fig.marks = fig.marks + marks
        self._broadcast_plotted_lines()

    def available_linelists(self):
        """Return the result of ``get_available_linelists()``."""
        return get_available_linelists()

    def set_plot_axes(self):
        """Apply the parent axis setup, then use 5 ticks on the second axis."""
        super().set_plot_axes()
        self.figure.axes[1].num_ticks = 5
@viewer_registry('spectrum-2d-viewer', label="2D Spectrum")
class Spectrum2DViewer(JdavizViewerMixin, BqplotImageView):
    """
    Image viewer for 2D spectra.

    Provides conversions between x-axis pixel limits and spectral-axis values
    of the reference data, and keeps the x-axis direction consistent with the
    orientation of that spectral axis.
    """
    # Due to limitations in CCDData and 2D data that has spectral and spatial
    # axes, the default conversion class must handle cubes
    default_class = Spectrum

    # categories: zoom resets, zoom, pan, subset, select tools, shortcuts
    tools_nested = [
                    ['jdaviz:homezoom_matchx', 'jdaviz:homezoom'],
                    ['jdaviz:boxzoom_matchx', 'jdaviz:xrangezoom_matchx',
                     'jdaviz:boxzoom', 'jdaviz:xrangezoom', 'jdaviz:yrangezoom'],
                    ['jdaviz:panzoom_matchx', 'jdaviz:panzoomx_matchx',
                     'jdaviz:panzoom', 'jdaviz:panzoom_x', 'jdaviz:panzoom_y'],
                    ['bqplot:xrange'],
                    ['jdaviz:viewer_clone', 'jdaviz:sidebar_plot', 'jdaviz:sidebar_export']
                ]

    _state_cls = FreezableBqplotImageViewerState

    def __init__(self, *args, **kwargs):
        """Create the viewer and hook up hub subscriptions and state callbacks."""
        super().__init__(*args, **kwargs)
        self._subscribe_to_layers_update()
        # Setup viewer option defaults
        self.state.aspect = 'auto'

        # Data being added to or removed from a viewer invalidates the cached
        # reference-axis properties (see _on_viewer_data_changed).
        self.session.hub.subscribe(self, AddDataToViewerMessage,
                                   handler=self._on_viewer_data_changed)
        self.session.hub.subscribe(self, RemoveDataFromViewerMessage,
                                   handler=self._on_viewer_data_changed)

        for k in ('x_min', 'x_max'):
            self.state.add_callback(k, self._handle_x_axis_orientation)

        self.data_menu._obj.dataset.add_filter('is_2d_spectrum_or_trace')

    @cached_property
    def reference_spectral_axis(self):
        """Spectral-axis values (unitless) of the viewer's reference data."""
        return self.state.reference_data.get_object().spectral_axis.value

    @cached_property
    def pixel_to_world_interp(self):
        """Interpolator from pixel index to reference spectral-axis value."""
        pixels = range(len(self.reference_spectral_axis))
        return interp1d(pixels, self.reference_spectral_axis)

    def pixel_to_world_limits(self, limits):
        """
        Convert a pair of x-axis pixel limits to spectral-axis values.

        Values inside the pixel range of the reference spectral axis are
        interpolated; values outside are extrapolated with a straight line
        fitted through the outermost in-range points.

        Parameters
        ----------
        limits : sequence of length 2
            Pixel limits.

        Returns
        -------
        list
            The two converted limits.  Their order is reversed when exactly one
            of "the spectral axis is inverted" and "``limits`` is given in
            descending order" holds.

        Raises
        ------
        ValueError
            If ``limits`` does not have length 2.
        """
        if not len(limits) == 2:
            raise ValueError("limits must be length 2")

        pixels = np.arange(0, len(self.reference_spectral_axis))
        # Compute the bounds once, vectorized, rather than iterating the
        # array with the builtin min/max every time they are needed.
        pix_lo, pix_hi = np.min(pixels), np.max(pixels)

        # we'll use interpolation when possible, but also want to fit a line between
        # the outermost edge of the data within the limits
        line_edges_pix = np.array([max((pix_lo, min(limits))),
                                   min((pix_hi, max(limits)))])
        if line_edges_pix[0] > line_edges_pix[1]:
            # then the limits are entirely out of range, so use the whole range
            # when fitting the linear approximation
            line_edges_pix = np.array([pix_lo, pix_hi])
        line_edges_world = self.pixel_to_world_interp(line_edges_pix)
        line_coeffs = np.polyfit(line_edges_pix, line_edges_world, deg=1)

        def pixel_to_world_line(pixel):
            return line_coeffs[0] * pixel + line_coeffs[1]

        def map_pixel_to_world(pixel):
            if pix_lo <= pixel <= pix_hi:
                # interpolate directly
                return float(self.pixel_to_world_interp(pixel))
            else:
                # use the line model to extrapolate
                return pixel_to_world_line(pixel)

        # -1 (reverse the output) when exactly one of the two flags is set
        invert = (-1) ** sum((self.inverted_x_axis, limits[0] > limits[1]))
        out_lims = list(map(map_pixel_to_world, limits))[::invert]
        return out_lims

    @cached_property
    def world_to_pixel_interp(self):
        """Interpolator from reference spectral-axis value to pixel index."""
        pixels = range(len(self.reference_spectral_axis))
        return interp1d(self.reference_spectral_axis, pixels)

    def world_to_pixel_limits(self, limits):
        """
        Convert a pair of spectral-axis limits to x-axis pixel values.

        Values inside the range of the reference spectral axis are
        interpolated; values outside are extrapolated with a straight line
        fitted through the outermost in-range points.

        Parameters
        ----------
        limits : sequence of length 2
            Limits in the (unitless) values of the reference spectral axis.

        Returns
        -------
        list
            The two converted limits.  Their order is reversed when exactly one
            of "the spectral axis is inverted" and "``limits`` is given in
            descending order" holds.

        Raises
        ------
        ValueError
            If ``limits`` does not have length 2.
        """
        if not len(limits) == 2:
            raise ValueError("limits must be length 2")

        # Compute the axis bounds once, vectorized, instead of iterating the
        # whole array with the builtin min/max for every limit that is mapped.
        world_lo = np.min(self.reference_spectral_axis)
        world_hi = np.max(self.reference_spectral_axis)

        # we'll use interpolation when possible, but also want to fit a line between
        # the outermost edge of the data within the limits
        line_edges_world = np.array([max((world_lo, min(limits))),
                                     min((world_hi, max(limits)))])
        if line_edges_world[0] > line_edges_world[1]:
            # then the limits are entirely out of range, so use the whole range
            # when fitting the linear approximation
            line_edges_world = np.array([world_lo, world_hi])
        line_edges_pixels = self.world_to_pixel_interp(line_edges_world)
        line_coeffs = np.polyfit(line_edges_world, line_edges_pixels, deg=1)

        def world_to_pixel_line(world):
            return line_coeffs[0] * world + line_coeffs[1]

        def map_world_to_pixel(world):
            if world_lo <= world <= world_hi:
                # interpolate directly
                return float(self.world_to_pixel_interp(world))
            else:
                # use the line model to extrapolate
                return world_to_pixel_line(world)

        # -1 (reverse the output) when exactly one of the two flags is set
        invert = (-1) ** sum((self.inverted_x_axis, limits[0] > limits[1]))
        out_lims = list(map(map_world_to_pixel, limits))[::invert]
        return out_lims

    def _on_viewer_data_changed(self, msg):
        """
        Drop cached reference-axis properties when this viewer's data changes.

        Messages whose ``viewer_reference`` differs from this viewer's
        ``reference`` are ignored.
        """
        if msg.viewer_reference != self.reference:
            return

        # clear cached properties that are based on reference data - this is probably
        # overly-conservative and we might be able to limit the clearing for only when
        # reference data is changed (perhaps with a callback on the state for reference_data)
        for attr in ('reference_spectral_axis',
                     'inverted_x_axis',
                     'pixel_to_world_interp',
                     'world_to_pixel_interp'):
            # cached_property stores its value in the instance __dict__;
            # removing the entry forces a recompute on next access.
            self.__dict__.pop(attr, None)

        if len(self.data()):
            self._handle_x_axis_orientation()

    @cached_property
    def inverted_x_axis(self):
        """Whether the reference spectral axis starts higher than it ends."""
        return self.reference_spectral_axis[0] > self.reference_spectral_axis[-1]

    def _handle_x_axis_orientation(self, *args):
        """
        Flip the x-scale limits if their direction disagrees with the
        orientation of the reference spectral axis.

        Extra positional arguments (passed by state callbacks) are ignored.
        """
        x_scales = self.scales['x']

        limits = [x_scales.min, x_scales.max]
        limits_inverted = limits[0] > limits[1]

        if limits_inverted == self.inverted_x_axis:
            return

        # Update both ends under a single hold_sync() block.
        with x_scales.hold_sync():
            x_scales.min = max(limits) if self.inverted_x_axis else min(limits)
            x_scales.max = min(limits) if self.inverted_x_axis else max(limits)

    def data(self, cls=None, statistic=None):
        """
        Return the objects of all data layers in the viewer.

        Parameters
        ----------
        cls : class or None
            Class passed to ``get_object``; defaults to ``self.default_class``.
        statistic : optional
            Passed through to ``get_object``.

        Returns
        -------
        list
            One object per layer state whose ``layer`` is a glue ``BaseData``.
        """
        return [layer_state.layer.get_object(statistic=statistic,
                                             cls=cls or self.default_class)
                for layer_state in self.state.layers
                if hasattr(layer_state, 'layer') and
                isinstance(layer_state.layer, BaseData)]

    def set_plot_axes(self):
        """Set axis labels, figure margins, label offsets and tick styling."""
        self.figure.axes[0].label = "x: pixels"
        self.figure.axes[1].label = "y: pixels"
        self.figure.axes[1].tick_format = None
        self.figure.axes[1].label_location = "middle"

        # Sync with Spectrum viewer (that is also used by other viz).
        # Make it so y axis label is not covering tick numbers.
        self.figure.fig_margin["left"] = 95
        self.figure.fig_margin["bottom"] = 60
        self.figure.send_state('fig_margin')  # Force update
        self.figure.axes[0].label_offset = "40"
        self.figure.axes[1].label_offset = "-70"
        # NOTE: with tick_style changed below, the default responsive ticks in bqplot result
        # in overlapping tick labels.  For now we'll hardcode at 8, but this could be removed
        # (default to None) if/when bqplot auto ticks react to styling options.
        self.figure.axes[1].num_ticks = 8

        for i in (0, 1):
            self.figure.axes[i].tick_style = {'font-size': 15, 'font-weight': 600}