import numpy as np
from traitlets import Bool, observe
from astropy import units as u
from bqplot import LinearScale
from bqplot.marks import Lines, Label, Scatter
from glue.core import HubListener
from specutils import Spectrum
from jdaviz.core.events import AddDataMessage, GlobalDisplayUnitChanged, RemoveDataMessage
from jdaviz.core.events import (SliceToolStateMessage, LineIdentifyMessage,
SpectralMarksChangedMessage,
RedshiftMessage)
from jdaviz.core.unit_conversion_utils import (all_flux_unit_conversion_equivs,
check_if_unit_is_per_solid_angle,
flux_conversion_general)
__all__ = ['OffscreenLinesMarks', 'BaseSpectrumVerticalLine', 'SpectralLine',
'SliceIndicatorMarks', 'ShadowMixin', 'ShadowLine', 'ShadowLabelFixedY',
'PluginMark', 'LinesAutoUnit', 'PluginLine', 'PluginScatter',
'LineAnalysisContinuum', 'LineAnalysisContinuumCenter',
'LineAnalysisContinuumLeft', 'LineAnalysisContinuumRight',
'LineUncertainties', 'ScatterMask', 'SelectedSpaxel', 'MarkersMark',
'CatalogMark', 'FootprintOverlay', 'ApertureMark', 'DistanceMeasurement',
'DistanceLabel']
accent_color = "#c75d2c"
# distinct from the jdaviz colors for backwards compat with MAST Portal:
mast_footprint_default_color = '#FF6600'
mast_footprint_selected_color = '#00ff00'
[docs]
class OffscreenLinesMarks(HubListener):
def __init__(self, viewer):
self.viewer = viewer
viewer.state.add_callback("x_min", lambda x_min: self._update_counts())
viewer.state.add_callback("x_max", lambda x_max: self._update_counts())
viewer.session.hub.subscribe(self, RedshiftMessage,
handler=self._update_counts)
viewer.session.hub.subscribe(self, SpectralMarksChangedMessage,
handler=self._update_counts)
self.left = Label(text=[''], x=[0.02], y=[0.8],
scales={'x': LinearScale(min=0, max=1),
'y': LinearScale(min=0, max=1)},
colors=['gray'], default_size=12,
align='start')
self.right = Label(text=[''], x=[0.98], y=[0.8],
scales={'x': LinearScale(min=0, max=1),
'y': LinearScale(min=0, max=1)},
colors=['gray'], default_size=12,
align='end')
self._update_counts()
@property
def marks(self):
return [self.left, self.right]
def _update_counts(self, *args):
oob_left, oob_right = 0, 0
for m in self.viewer.figure.marks:
if isinstance(m, SpectralLine):
if m.x[0] < self.viewer.state.x_min:
oob_left += 1
elif m.x[0] > self.viewer.state.x_max:
oob_right += 1
self.left.text = [f'\u25c0 {oob_left}' if oob_left > 0 else '']
self.right.text = [f'{oob_right} \u25b6' if oob_right > 0 else '']
class PluginMarkCollection:
# This class allows for the creation of plugin preview marks across viewers
def __init__(self, mark_cls, shadow_cls=None, shadow_kwargs={}, **mark_kwargs):
self.mark_cls = mark_cls
self.shadow_cls = shadow_cls
self.shadow_kwargs = shadow_kwargs
self.mark_kwargs = mark_kwargs
self.shadow_marks = {}
self.marks = {}
@property
def marks_list(self):
return list(self.marks.values())
def marks_for_viewers(self, viewers):
for viewer in viewers:
if viewer not in self.marks:
# Create a new mark for the given viewer
mark = self.mark_cls(viewer=viewer, **self.mark_kwargs)
self.marks[viewer] = mark
if self.shadow_cls is not None:
shadow_kwargs = {k: v.marks_for_viewers([viewer])[0]
if isinstance(v, PluginMarkCollection) else v
for k, v in self.shadow_kwargs.items()}
shadow = self.shadow_cls(shadowing=mark, **shadow_kwargs)
self.shadow_marks[viewer] = shadow
viewer.figure.marks += [shadow]
viewer.figure.marks += [mark]
viewer.figure.send_state('marks')
return [self.marks[viewer] for viewer in viewers]
def clear_if_not_in_viewers(self, viewers):
for viewer, mark in self.marks.items():
if viewer not in viewers:
# Clear the mark if the viewer is not in the list
mark.clear()
mark.visible = False
def clear(self):
for mark in self.marks.values():
mark.clear()
def update_xy(self, x, y, viewers):
for mark in self.marks_for_viewers(viewers):
mark.update_xy(x, y)
mark.visible = True
self.clear_if_not_in_viewers(viewers)
def set_for_viewers(self, name, value, viewers):
for mark in self.marks_for_viewers(viewers):
setattr(mark, name, value)
def __setattr__(self, name, value):
if name in ('plugin', 'mark_cls', 'shadow_cls',
'shadow_kwargs', 'mark_kwargs',
'shadow_marks', 'marks'):
super().__setattr__(name, value)
else:
# pass on to all stored marks
for mark in self.marks.values():
setattr(mark, name, value)
[docs]
class PluginMark:
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.xunit = None
self.yunit = None
# whether to update existing marks when global display units are changed
self.auto_update_units = True
self.hub.subscribe(self, GlobalDisplayUnitChanged,
handler=self._on_global_display_unit_changed)
if self.xunit is None:
self.set_x_unit()
if self.yunit is None:
self.set_y_unit()
@property
def hub(self):
return self.viewer.hub
[docs]
def update_xy(self, x, y):
# If x and y are not in the previous units, they should be provided as quantities
if hasattr(x, 'value'):
xunit = x.unit
x = x.value
else:
xunit = None
self.x = np.asarray(x)
if xunit is not None:
self.xunit = u.Unit(xunit)
if hasattr(y, 'value'):
yunit = y.unit
y = y.value
else:
yunit = None
self.y = np.asarray(y)
if yunit is not None:
self.yunit = u.Unit(yunit)
[docs]
def append_xy(self, x, y):
self.x = np.append(self.x, x)
self.y = np.append(self.y, y)
[docs]
def set_x_unit(self, unit=None):
if unit is None:
if not hasattr(self.viewer.state, 'x_display_unit'):
if isinstance(self.viewer, ('Spectrum2DViewer', 'MosvizProfile2DView')):
# x-unit of 2d spectrum viewers are always in pixels
unit = u.pix
elif self.viewer.data() and hasattr(self.viewer.data()[0], 'spectral_axis'):
unit = self.viewer.data()[0].spectral_axis.unit
else:
return
else:
unit = self.viewer.state.x_display_unit
unit = u.Unit(unit)
if self.xunit is not None and not np.all([s == 0 for s in self.x.shape]):
x = (self.x * self.xunit).to_value(unit, u.spectral())
self.xunit = unit
self.x = x
self.xunit = unit
[docs]
def set_y_unit(self, unit=None):
if unit is None:
if not hasattr(self.viewer.state, 'y_display_unit'):
if isinstance(self.viewer, ('Spectrum2DViewer', 'MosvizProfile2DView')):
# y-unit of 2d spectrum viewers are always in pixels
unit = u.pix
elif self.viewer.data() and hasattr(self.viewer.data()[0], 'flux'):
unit = self.viewer.data()[0].flux.unit
else:
return
else:
unit = self.viewer.state.y_display_unit
unit = u.Unit(unit)
# spectrum y-values in viewer have already been converted, don't convert again
# if a spectral_y_type is changed, just update the unit
if self.yunit is not None and check_if_unit_is_per_solid_angle(self.yunit) != check_if_unit_is_per_solid_angle(unit): # noqa
self.yunit = unit
return
if self.yunit is not None and not np.all([s == 0 for s in self.y.shape]): # noqa
if self.viewer.default_class is Spectrum:
if self.xunit is None:
return
spec = self.viewer.state.reference_data.get_object(cls=Spectrum)
pixar_sr = spec.meta.get('PIXAR_SR', None)
# if x is all the same value, then we either have a vertical line mark or
# a flat spectrum, in either case we can use a single value for the spectral
# density equivalency.
if len(np.unique(self.x)) == 1 and (len(self.x) != len(self.y)):
wave = self.x[0] * self.xunit
else:
wave = self.x * self.xunit
equivs = all_flux_unit_conversion_equivs(pixar_sr, wave)
y = flux_conversion_general(self.y, self.yunit, unit,
equivs, with_unit=False)
self.y = y
self.yunit = unit
def _on_global_display_unit_changed(self, msg):
if not self.auto_update_units:
return
unit = msg.unit
if (msg.axis in ('spectral', 'spectral_y') and
self.viewer.__class__.__name__ in ('Spectrum2DViewer',
'MosvizProfile2DView')):
# then we want to ignore the change to spectral unit as these viewers
# are always in pixel units on the x-axis
unit = u.pix
if self.viewer.__class__.__name__ in ('Spectrum1DViewer',
'Spectrum2DViewer',
'CubevizProfileView',
'MosvizProfileView',
'MosvizProfile2DView'):
axis_map = {'spectral': 'x', 'spectral_y': 'y'}
else:
return
axis = axis_map.get(msg.axis, None)
if axis is not None:
scale = self.scales.get(axis, None)
# if PluginMark mark is LinearScale prevent it from entering unit conversion
# machinery so it maintains it's position in viewer.
if isinstance(scale, LinearScale) and (scale.min, scale.max) == (0, 1):
return
getattr(self, f'set_{axis}_unit')(unit)
[docs]
def clear(self):
self.update_xy([], [])
[docs]
class BaseSpectrumVerticalLine(Lines, PluginMark, HubListener):
def __init__(self, viewer, x, **kwargs):
self.viewer = viewer
# the location of the marker will need to update automatically if the
# underlying data changes (through a unit conversion, for example)
if hasattr(viewer.state, 'reference_data'):
viewer.state.add_callback("reference_data",
self._update_reference_data)
scales = viewer.scales
# Lines.__init__ will set self.x
super().__init__(x=[x, x], y=[0, 1],
scales={'x': scales['x'], 'y': LinearScale(min=0, max=1)},
**kwargs)
def _update_reference_data(self, reference_data):
# don't update x units before initialization or in rampviz
if reference_data is None or 'Rampviz' in self.viewer.__class__.__name__:
return
self._update_unit(reference_data.get_object(cls=Spectrum).spectral_axis.unit)
def _update_unit(self, new_unit):
# the x-units may have changed. We want to convert the internal self.x
# from self.xunit to the new units (x_all.unit)
if self.xunit is None:
self.xunit = new_unit
return
if new_unit == self.xunit:
return
old_quant = self.x[0]*self.xunit
x = old_quant.to_value(new_unit, equivalencies=u.spectral())
self.x = [x, x]
self.xunit = new_unit
[docs]
class SpectralLine(BaseSpectrumVerticalLine):
"""
Subclass on bqplot Lines, mostly so that we can erase spectral lines
by eliminating any SpectralLines objects from a figures marks list. Also
lets us do wavelength redshifting here on mark creation.
"""
def __init__(self, viewer, rest_value, redshift=0, name=None, **kwargs):
self._rest_value = rest_value
self._identify = False
self.name = name
# table_index is same as name_rest elsewhere
self.table_index = kwargs.pop("table_index", None)
# setting redshift will set self.x and enable the obs_value property,
# but to do that we need x_unit set first (would normally be assigned
# in the super init)
self.xunit = u.Unit(viewer.state.x_display_unit)
self.redshift = redshift
viewer.session.hub.subscribe(self, LineIdentifyMessage,
handler=self._process_identify_change)
super().__init__(viewer=viewer, x=self.obs_value, stroke_width=1,
fill='none', close_path=False, **kwargs)
@property
def name_rest(self):
return self.table_index
@property
def rest_value(self):
return self._rest_value
@property
def obs_value(self):
return self.x[0]
[docs]
def set_x_unit(self, unit=None):
prev_unit = self.xunit
super().set_x_unit(unit=unit)
self._rest_value = (self._rest_value * prev_unit).to_value(unit, u.spectral())
@property
def redshift(self):
return self._redshift
@redshift.setter
def redshift(self, redshift):
self._redshift = redshift
if str(self.xunit.physical_type) == 'length':
obs_value = self._rest_value*(1+redshift)
elif str(self.xunit.physical_type) == 'frequency':
obs_value = self._rest_value/(1+redshift)
else:
# catch all for anything else (wavenumber, energy, etc)
rest_angstrom = (self._rest_value*self.xunit).to_value(u.Angstrom,
equivalencies=u.spectral())
obs_angstrom = rest_angstrom*(1+redshift)
obs_value = (obs_angstrom*u.Angstrom).to_value(self.xunit,
equivalencies=u.spectral())
self.x = [obs_value, obs_value]
@property
def identify(self):
return self._identify
@identify.setter
def identify(self, identify):
if not isinstance(identify, bool): # pragma: no cover
raise TypeError("identify must be of type bool")
self._identify = identify
self.stroke_width = 3 if identify else 1
def _process_identify_change(self, msg):
self.identify = msg.name_rest == self.table_index
def _update_unit(self, new_unit):
if self.xunit is None:
self.xunit = new_unit
return
if new_unit == self.xunit:
return
old_quant = self._rest_value*self.xunit
self._rest_value = old_quant.to_value(new_unit, equivalencies=u.spectral())
# re-compute self.x from current redshift (instead of converting that as well)
self.redshift = self._redshift
self.xunit = new_unit
[docs]
class SliceIndicatorMarks(BaseSpectrumVerticalLine, HubListener):
"""Subclass on bqplot Lines to handle slice/wavelength indicator.
"""
def __init__(self, viewer, value=0, **kwargs):
self._viewer = viewer
self._value = None
self._oob = False # out-of-bounds, either False, 'left', or 'right'
self._active = False
# TODO: new viewers need to respect plugin settings
self._show_if_inactive = True
self._show_value = True
viewer.state.add_callback("x_min", lambda x_min: self._value_handle_oob(update_label=True))
viewer.state.add_callback("x_max", lambda x_max: self._value_handle_oob(update_label=True))
viewer.session.hub.subscribe(self, SliceToolStateMessage,
handler=self._on_change_state)
for msg in (AddDataMessage, RemoveDataMessage):
viewer.session.hub.subscribe(self, msg,
handler=lambda x: self._set_visibility())
super().__init__(viewer=viewer,
x=[value, value],
stroke_width=2,
marker='diamond',
fill='none', close_path=False,
labels=['slice'], labels_visibility='none', **kwargs)
self.value = value
# instead of using the Lines label which is limited, we'll use a Label object which
# will follow the x-coordinate of the slice indicator line, with a fixed y-value
# (in axes-units) and will flip its alignment depending on whether the line is on the
# left or right side of the axes.
self.label = ShadowLabelFixedY(viewer, self, shadow_traits=[], default_size=12, y=0.95)
# default to the initial state of the tool since we can't control if this will
# happen before or after the initialization of the tool
tool_active = self.viewer.toolbar.active_tool_id == 'jdaviz:selectslice'
self._on_change_state({'active': tool_active})
@property
def marks(self):
return [self, self.label]
def _set_visibility(self):
for dc in self._viewer.jdaviz_app.data_collection:
if len(dc.shape) == 3:
self.visible = True
self.label.visible = self._show_value
return
self.visible = False
self.label.visible = False
def _on_global_display_unit_changed(self, msg):
# Updating the value is handled by the plugin itself, need to update unit string.
if msg.axis in ["spectral", "x"]:
self.xunit = msg.unit
self._update_label()
def _value_handle_oob(self, x=None, update_label=False):
if x is None:
x = self.value
else:
self._value = x
x_min, x_max = self._viewer.state.x_min, self._viewer.state.x_max
if x_min is None or x_max is None:
self.x = [x, x]
return
x_range = x_max - x_min
padding_fig = 0.01
padding = padding_fig * x_range
x_min += padding
x_max -= padding
# ensure y-scale has been set (we'll only be overriding x, but scatter viewers complain
# if y-scale is not set)
self.scales.setdefault('y', LinearScale(min=0, max=1))
if x < x_min:
self.x = [padding_fig, padding_fig]
self.scales = {**self.scales, 'x': LinearScale(min=0, max=1)}
self.line_style = 'dashed'
self._oob = 'left'
elif x > x_max:
self.x = [1-padding_fig, 1-padding_fig]
self.scales = {**self.scales, 'x': LinearScale(min=0, max=1)}
self.line_style = 'dashed'
self._oob = 'right'
else:
self.x = [x, x]
self.scales = {**self.scales, 'x': self._viewer.scales['x']}
self.line_style = 'solid'
self._oob = False
if update_label:
self._update_label()
def _update_colors_opacities(self):
# orange (accent) if active, import button blue otherwise (see css in main_styles.vue)
if not self._show_if_inactive and not self._active:
self.label.visible = False
self.visible = False
return
self.visible = True
self.label.visible = self._show_value
self.colors = ["#c75109" if self._active else "#007BA1"]
self.opacities = [1.0 if self._active else 0.9]
def _on_change_state(self, msg={}):
if isinstance(msg, dict):
changes = msg
else:
if msg.viewer is not None and msg.viewer != self.viewer:
return
changes = msg.change
for k, v in changes.items():
if k == 'active':
self._active = v
elif k == 'show_indicator':
self._show_if_inactive = v
elif k == 'show_value':
self._show_value = v
self._update_colors_opacities()
def _update_label(self):
def _formatted_value(value):
power = abs(np.log10(value))
if power >= 3:
# use scientific notation
return f'{value:0.4e}'
else:
return f'{value:0.4f}'
valuestr = _formatted_value(self.value)
xunit = str(self.xunit) if self.xunit is not None else ''
# U+00A0 is a blank space, U+25C0 a left arrow triangle, and U+25B6 a right arrow triangle
if self._oob == 'left':
self.labels = [f'\u00A0 \u25c0 {valuestr} {xunit} \u00A0'] # noqa
elif self._oob == 'right':
self.labels = [f'{valuestr} {xunit} \u25b6 \u00A0']
else:
self.labels = [f'\u00A0 {valuestr} {xunit} \u00A0']
@property
def value(self):
return self._value
@value.setter
def value(self, value):
self._value_handle_oob(value, update_label=True)
[docs]
class ShadowMixin:
"""Mixin class to propagate traits from one mark object to another.
Anything in ``sync_traits`` will be mirrored directly from
``shadowing`` to the shadowed object.
Can manually override ``_on_shadowing_changed`` for more advanced logic cases.
"""
def _get_id(self, mark):
return getattr(mark, '_model_id', None)
def _setup_shadowing(self, shadowing, sync_traits=[], other_traits=[]):
"""
sync_traits: traits to set now, and mirror any changes to shadowing in the future
other_trait: traits to set now, but not mirror in the future
"""
if not hasattr(self, '_shadowing'):
self._shadowing = {}
self._sync_traits = {}
shadowing_id = self._get_id(shadowing)
self._shadowing[shadowing_id] = shadowing
self._sync_traits[shadowing_id] = sync_traits
# sync initial values
for attr in sync_traits + other_traits:
self._on_shadowing_changed({'name': attr,
'new': getattr(shadowing, attr),
'owner': shadowing})
# subscribe to future changes
shadowing.observe(self._on_shadowing_changed)
def _on_shadowing_changed(self, change):
if change['name'] in self._sync_traits.get(self._get_id(change.get('owner')), []):
setattr(self, change['name'], change['new'])
return
[docs]
class ShadowLine(Lines, HubListener, ShadowMixin):
"""Create a white shadow line around another line
to help make it standout on top of other lines.
"""
def __init__(self, shadowing, shadow_width=1, **kwargs):
self._shadow_width = shadow_width
super().__init__(scales=shadowing.scales,
stroke_width=shadowing.stroke_width+shadow_width if shadowing.stroke_width else 0, # noqa
marker_size=shadowing.marker_size+shadow_width if shadowing.marker_size else 0, # noqa
colors=[kwargs.pop('color', 'white')],
**kwargs)
self._setup_shadowing(shadowing,
['scales', 'x', 'y', 'visible', 'line_style', 'marker'],
['stroke_width', 'marker_size'])
[docs]
class ShadowLabelFixedY(Label, ShadowMixin):
"""Label whose position shadows that of a parent ``shadowing``
line and will flip alignment based on whether it is left or
right of the center of the viewer.
"""
def __init__(self, viewer, shadowing, shadow_traits=['visible'],
y=0.95, point_index=0, **kwargs):
super().__init__(**kwargs)
self._viewer = viewer
self.y = [y]
self.scales['y'] = LinearScale(min=0, max=1)
self._point_index = point_index
self._setup_shadowing(shadowing,
shadow_traits,
['x', 'scales', 'labels', 'colors'])
viewer.state.add_callback("x_min", lambda x_min: self._update_align())
viewer.state.add_callback("x_max", lambda x_max: self._update_align())
def _force_redraw(self):
# TODO: bug in bqplot that change in align/colors traitlet doesn't update immediately,
# we'll get around it in the meantime by just forcing the Label to see a change to the
# text traitlet
text = self.text
self.text = ['']
self.text = text
def _update_align(self):
if not isinstance(self.scales.get('x'), LinearScale):
return
# determine alignment automatically
if self.scales['x'].min == 0 and self.scales['x'].max == 1:
# then we're in axes units, so just check position compared to 0.5
is_to_right = self.x[0] > 0.5
else:
# then we're in data units, so check position compared to the median of the axes limits
is_to_right = self.x[0] > (self._viewer.state.x_min + self._viewer.state.x_max) / 2.
if is_to_right and self.align != 'end':
self.align = 'end'
# force redraw by re-updating label
self._force_redraw()
if not is_to_right and self.align != 'start':
self.align = 'start'
# force redraw by re-updating label
self._force_redraw()
def _on_shadowing_changed(self, change):
super()._on_shadowing_changed(change)
if change['name'] == 'labels':
self.text = [change['new'][self._point_index]]
elif change['name'] in ('x', 'colors'):
setattr(self, change['name'], [change['new'][self._point_index]])
if change['name'] == 'colors':
# bqplot bug that won't notice change to colors, manually force re-draw
self._force_redraw()
elif change['name'] == 'scales':
self.scales = {**self.scales, 'x': change['new']['x']}
if change['name'] in ('x', 'scales'):
# then the position of the label on the plot has changed, so re-determine whether
# it should be aligned to the left or right
self._update_align()
[docs]
class LinesAutoUnit(PluginMark, Lines, HubListener):
def __init__(self, viewer, *args, **kwargs):
self.viewer = viewer
super().__init__(*args, **kwargs)
[docs]
class PluginLine(Lines, PluginMark, HubListener):
def __init__(self, viewer, x=[], y=[], **kwargs):
self.viewer = viewer
# color is same blue as import button
kwargs.setdefault('colors', [accent_color])
self.label = kwargs.get('label')
# default to viewer scales, overriding any keys sent through scales kwarg
scales = {**viewer.scales, **kwargs.pop('scales', {})}
super().__init__(x=x, y=y, scales=scales, **kwargs)
[docs]
class PluginScatter(Scatter, PluginMark, HubListener):
def __init__(self, viewer, x=[], y=[], **kwargs):
self.viewer = viewer
# default color is same blue as import button
kwargs.setdefault('colors', [accent_color])
# default to viewer scales, overriding any keys sent through scales kwarg
scales = {**viewer.scales, **kwargs.pop('scales', {})}
super().__init__(x=x, y=y, scales=scales, **kwargs)
[docs]
class LineAnalysisContinuum(PluginLine):
def __init__(self, *args, **kwargs):
# units do not need to be updated because the plugin itself reruns
# the computation and automatically changes the arrays themselves
self.auto_update_units = kwargs.pop('auto_update_units', False)
super().__init__(*args, **kwargs)
[docs]
class LineAnalysisContinuumCenter(LineAnalysisContinuum):
def __init__(self, viewer, x=[], y=[], **kwargs):
super().__init__(viewer, x, y, **kwargs)
self.stroke_width = 1
[docs]
class LineAnalysisContinuumLeft(LineAnalysisContinuum):
def __init__(self, viewer, x=[], y=[], **kwargs):
super().__init__(viewer, x, y, **kwargs)
self.stroke_width = 5
[docs]
class LineAnalysisContinuumRight(LineAnalysisContinuumLeft):
pass
[docs]
class LineUncertainties(LinesAutoUnit):
def __init__(self, viewer, *args, **kwargs):
super().__init__(viewer, *args, **kwargs)
[docs]
class ScatterMask(Scatter):
def __init__(self, **kwargs):
super().__init__(**kwargs)
[docs]
class SelectedSpaxel(Lines):
def __init__(self, **kwargs):
super().__init__(**kwargs)
[docs]
class MarkersMark(PluginScatter):
def __init__(self, viewer, **kwargs):
kwargs.setdefault('marker', 'circle')
super().__init__(viewer, **kwargs)
[docs]
class CatalogMark(PluginScatter):
def __init__(self, viewer, **kwargs):
kwargs.setdefault('marker', 'circle')
super().__init__(viewer, **kwargs)
class RegionOverlay(PluginLine):
"""
Overlay mark representing a sky region in an image viewer,
specifically for use outside of the Footprints plugin.
Distinct from `FootprintOverlay`, which is meant for the marks
made and controlled by the Footprints plugin.
"""
default_style = dict(
stroke_width=2,
colors=[mast_footprint_default_color],
)
selected_style = dict(
stroke_width=4,
colors=[mast_footprint_selected_color],
)
selected = Bool(False).tag(sync=True)
def __init__(self, viewer, overlay, selected=False, label=None, **kwargs):
self._overlay = overlay
self.selected = selected
style = self.selected_style if selected else self.default_style
for attr, value in style.items():
kwargs.setdefault(attr, value)
kwargs.setdefault('stroke_width', 1)
kwargs.setdefault('close_path', True)
kwargs.setdefault('opacities', [1.0])
kwargs.setdefault('fill', 'inside')
kwargs.setdefault('fill_opacities', [0])
super().__init__(viewer, **kwargs)
# must be called after super()
self.label = label
@property
def overlay(self):
return self._overlay
@observe("selected")
def on_selection_change(self, msg):
style = self.selected_style if self.selected else self.default_style
self.update_style(style)
def update_style(self, style):
if not len(style):
return
for attr, value in style.items():
setattr(self, attr, value)
class HistogramMark(Lines):
def __init__(self, min_max_value, scales, **kwargs):
# Vertical line in LinearScale
y = [0, 1]
colors = [accent_color]
line_style = "solid"
super().__init__(x=min_max_value, y=y, scales=scales, colors=colors, line_style=line_style,
**kwargs)
[docs]
class DistanceLabel(Label, PluginMark, HubListener):
"""A specialized bqplot Label for use with the DistanceMeasurement tool.
This class inherits from the Jdaviz PluginMark to integrate with the
application's hub and event system, but overrides the unit conversion
methods to prevent the label's pixel-based coordinates from being
erroneously converted.
"""
def __init__(self, viewer, x, y, text, **kwargs):
self.viewer = viewer
kwargs.setdefault('align', 'middle')
kwargs.setdefault('baseline', 'middle')
kwargs.setdefault('x_offset', 0)
kwargs.setdefault('y_offset', 0)
super().__init__(x=x, y=y, text=[text], scales=viewer.scales, **kwargs)
self.auto_update_units = True
[docs]
def update_position(self, x, y, text):
self.x = x
self.y = y
self.text = [text]
class DistanceLine(PluginLine):
# We just want to be able to check for this specific line
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
[docs]
class DistanceMeasurement:
"""A composite mark that displays a line between two points with a
dynamically rotated and offset label showing the distance.
This class manages a collection of bqplot marks (a Line and two Labels
for text and its halo) to create a single, cohesive measurement indicator
in a viewer. The core logic in the `update_points` method handles the
positioning, rotation, and offsetting of the label to ensure it remains
parallel to the line while avoiding intersection.
"""
def __init__(self, viewer, x=[], y=[], text=""):
self.viewer = viewer
self.endpoints = None
self.auto_update_units = True
self.line = DistanceLine(
viewer,
x=x,
y=y,
scales=viewer.scales,
colors=[accent_color],
stroke_width=2
)
anchor_x, anchor_y = (x[0] + x[1]) / 2, (y[0] + y[1]) / 2
self.label_shadow = DistanceLabel(
viewer, [anchor_x], [anchor_y], text,
colors=['white'],
stroke='white',
stroke_width=5,
font_weight='bold',
default_size=15
)
self.label_text = DistanceLabel(
viewer, [anchor_x], [anchor_y], text,
colors=[accent_color],
font_weight='bold',
default_size=15
)
self.visible = True
self.update_points(x[0], y[0], x[1], y[1], text)
self.auto_update_units = True
@property
def marks(self):
# The shadow must be listed before the text to be drawn underneath.
return [self.line, self.label_shadow, self.label_text]
@property
def visible(self):
return self.line.visible
@visible.setter
def visible(self, visible):
# Update visibility for all marks.
self.line.visible = visible
self.label_shadow.visible = visible
self.label_text.visible = visible
[docs]
def update_points(self, x0, y0, x1, y1, text=""):
x0_v = getattr(x0, 'value', x0)
y0_v = getattr(y0, 'value', y0)
x1_v = getattr(x1, 'value', x1)
y1_v = getattr(y1, 'value', y1)
self.line.x = [x0_v, x1_v]
self.line.y = [y0_v, y1_v]
mid_x = (x0_v + x1_v) / 2
mid_y = (y0_v + y1_v) / 2
dx_data = x1_v - x0_v
dy_data = y1_v - y0_v
angle_rad = np.arctan2(-dy_data, dx_data)
angle_deg = np.rad2deg(angle_rad)
# Check for the special case of a perfect diagonal line.
# A small tolerance is used for floating-point comparison.
if abs(abs(dx_data) - abs(dy_data)) < 1e-9:
x_offset_float = 0
y_offset_float = 30
else:
offset_distance = 10
line_length_data = np.sqrt(dx_data**2 + dy_data**2)
if line_length_data > 0:
perp_dx = -dy_data
perp_dy = dx_data
x_offset_float = (perp_dx / line_length_data) * offset_distance
y_offset_float = (perp_dy / line_length_data) * offset_distance
else:
x_offset_float = 0
y_offset_float = -offset_distance
# Determine text offset direction based on the center of the viewer
x_scale = self.viewer.scales.get('x')
y_scale = self.viewer.scales.get('y')
# Check if the viewer scales and their min/max values are available
if (x_scale is not None and y_scale is not None and
x_scale.min is not None and x_scale.max is not None and
y_scale.min is not None and y_scale.max is not None):
viewer_center_x = (x_scale.min + x_scale.max) / 2
viewer_center_y = (y_scale.min + y_scale.max) / 2
pos_vec_x = mid_x - viewer_center_x
pos_vec_y = mid_y - viewer_center_y
if (pos_vec_x * perp_dx + pos_vec_y * perp_dy) > 0:
x_offset_float *= -1
y_offset_float *= -1
x_offset_int = int(round(x_offset_float))
y_offset_int = int(round(y_offset_float))
# Normalize the display angle for readability
if abs(angle_deg) > 90:
angle_deg -= 180 * np.sign(angle_deg)
for label in (self.label_shadow, self.label_text):
label.update_position([mid_x], [mid_y], text)
label.rotate_angle = angle_deg
label.x_offset = x_offset_int
label.y_offset = -y_offset_int