Source code for jdaviz.core.marks

import numpy as np

from astropy import units as u
from bqplot import LinearScale
from bqplot.marks import Lines, Label, Scatter
from copy import deepcopy
from glue.core import HubListener
from specutils import Spectrum1D

from jdaviz.core.events import (SliceToolStateMessage, LineIdentifyMessage,
                                SpectralMarksChangedMessage,
                                RedshiftMessage)

__all__ = ['OffscreenLinesMarks', 'BaseSpectrumVerticalLine', 'SpectralLine',
           'SliceIndicatorMarks', 'ShadowMixin', 'ShadowLine', 'ShadowLabelFixedY',
           'PluginMark', 'PluginLine', 'PluginScatter',
           'LineAnalysisContinuum', 'LineAnalysisContinuumCenter',
           'LineAnalysisContinuumLeft', 'LineAnalysisContinuumRight',
           'LineUncertainties', 'ScatterMask', 'SelectedSpaxel', 'MarkersMark']


[docs]class OffscreenLinesMarks(HubListener): def __init__(self, viewer): self.viewer = viewer viewer.state.add_callback("x_min", lambda x_min: self._update_counts()) viewer.state.add_callback("x_max", lambda x_max: self._update_counts()) viewer.session.hub.subscribe(self, RedshiftMessage, handler=self._update_counts) viewer.session.hub.subscribe(self, SpectralMarksChangedMessage, handler=self._update_counts) self.left = Label(text=[''], x=[0.02], y=[0.8], scales={'x': LinearScale(min=0, max=1), 'y': LinearScale(min=0, max=1)}, colors=['gray'], default_size=12, align='start') self.right = Label(text=[''], x=[0.98], y=[0.8], scales={'x': LinearScale(min=0, max=1), 'y': LinearScale(min=0, max=1)}, colors=['gray'], default_size=12, align='end') self._update_counts() @property def marks(self): return [self.left, self.right] def _update_counts(self, *args): oob_left, oob_right = 0, 0 for m in self.viewer.figure.marks: if isinstance(m, SpectralLine): if m.x[0] < self.viewer.state.x_min: oob_left += 1 elif m.x[0] > self.viewer.state.x_max: oob_right += 1 self.left.text = [f'\u25c0 {oob_left}' if oob_left > 0 else ''] self.right.text = [f'{oob_right} \u25b6' if oob_right > 0 else '']
[docs]class BaseSpectrumVerticalLine(Lines, HubListener): def __init__(self, viewer, x, **kwargs): # we'll store the current units so that we can automatically update the # positioning on a change to the x-units self._x_unit = viewer.state.reference_data.get_object(cls=Spectrum1D).spectral_axis.unit # the location of the marker will need to update automatically if the # underlying data changes (through a unit conversion, for example) viewer.state.add_callback("reference_data", self._update_reference_data) scales = viewer.scales # Lines.__init__ will set self.x super().__init__(x=[x, x], y=[0, 1], scales={'x': scales['x'], 'y': LinearScale(min=0, max=1)}, **kwargs) def _update_reference_data(self, reference_data): if reference_data is None: return self._update_data(reference_data.get_object(cls=Spectrum1D).spectral_axis) def _update_data(self, x_all): # the x-units may have changed. We want to convert the internal self.x # from self._x_unit to the new units (x_all.unit) new_unit = x_all.unit if new_unit == self._x_unit: return old_quant = self.x[0]*self._x_unit x = old_quant.to_value(x_all.unit, equivalencies=u.spectral()) self.x = [x, x] self._x_unit = new_unit
[docs]class SpectralLine(BaseSpectrumVerticalLine): """ Subclass on bqplot Lines, mostly so that we can erase spectral lines by eliminating any SpectralLines objects from a figures marks list. Also lets us do wavelength redshifting here on mark creation. """ def __init__(self, viewer, rest_value, redshift=0, name=None, **kwargs): self._rest_value = rest_value self._identify = False self.name = name # table_index is same as name_rest elsewhere self.table_index = kwargs.pop("table_index", None) # setting redshift will set self.x and enable the obs_value property, # but to do that we need x_unit set first (would normally be assigned # in the super init) self._x_unit = viewer.state.reference_data.get_object(cls=Spectrum1D).spectral_axis.unit self.redshift = redshift viewer.session.hub.subscribe(self, LineIdentifyMessage, handler=self._process_identify_change) super().__init__(viewer=viewer, x=self.obs_value, stroke_width=1, fill='none', close_path=False, **kwargs) @property def name_rest(self): return self.table_index @property def rest_value(self): return self._rest_value @property def obs_value(self): return self.x[0] @property def redshift(self): return self._redshift @redshift.setter def redshift(self, redshift): self._redshift = redshift if str(self._x_unit.physical_type) == 'length': obs_value = self._rest_value*(1+redshift) elif str(self._x_unit.physical_type) == 'frequency': obs_value = self._rest_value/(1+redshift) else: # catch all for anything else (wavenumber, energy, etc) rest_angstrom = (self._rest_value*self._x_unit).to_value(u.Angstrom, equivalencies=u.spectral()) obs_angstrom = rest_angstrom*(1+redshift) obs_value = (obs_angstrom*u.Angstrom).to_value(self._x_unit, equivalencies=u.spectral()) self.x = [obs_value, obs_value] @property def identify(self): return self._identify @identify.setter def identify(self, identify): if not isinstance(identify, bool): # pragma: no cover raise TypeError("identify must be of type bool") self._identify = identify self.stroke_width = 3 if identify else 1 def _process_identify_change(self, msg): self.identify = msg.name_rest == self.table_index def _update_data(self, x_all): new_unit = x_all.unit if new_unit == self._x_unit: return old_quant = self._rest_value*self._x_unit self._rest_value = old_quant.to_value(new_unit, equivalencies=u.spectral()) # re-compute self.x from current redshift (instead of converting that as well) self.redshift = self._redshift self._x_unit = new_unit
[docs]class SliceIndicatorMarks(BaseSpectrumVerticalLine, HubListener): """Subclass on bqplot Lines to handle slice/wavelength indicator. """ def __init__(self, viewer, slice=0, **kwargs): self._viewer = viewer self._oob = False # out-of-bounds, either False, 'left', or 'right' self._active = False self._show_if_inactive = True self._show_wavelength = True self.slice = slice x_all = viewer.data()[0].spectral_axis # _update_data will set self._x_all, self._x_unit, self.x self._update_data(x_all) viewer.state.add_callback("x_min", lambda x_min: self._handle_oob(update_label=True)) viewer.state.add_callback("x_max", lambda x_max: self._handle_oob(update_label=True)) viewer.session.hub.subscribe(self, SliceToolStateMessage, handler=self._on_change_state) super().__init__(viewer=viewer, x=self.x[0], stroke_width=2, marker='diamond', fill='none', close_path=False, labels=['slice'], labels_visibility='none', **kwargs) self._handle_oob() # instead of using the Lines label which is limited, we'll use a Label object which # will follow the x-coordinate of the slice indicator line, with a fixed y-value # (in axes-units) and will flip its alignment depending on whether the line is on the # left or right side of the axes. self.label = ShadowLabelFixedY(viewer, self, shadow_traits=[], default_size=12, y=0.95) # default to the initial state of the tool since we can't control if this will # happen before or after the initialization of the tool self._on_change_state({'active': True}) @property def marks(self): return [self, self.label] def _handle_oob(self, x_coord=None, update_label=False): if x_coord is None: x_coord = self._slice_to_x(self.slice) x_min, x_max = self._viewer.state.x_min, self._viewer.state.x_max if x_min is None or x_max is None: self.x = [x_coord, x_coord] return x_range = x_max - x_min padding_fig = 0.01 padding = padding_fig * x_range x_min += padding x_max -= padding if x_coord < x_min: self.x = [padding_fig, padding_fig] self.scales = {**self.scales, 'x': LinearScale(min=0, max=1)} self.line_style = 'dashed' self._oob = 'left' elif x_coord > x_max: self.x = [1-padding_fig, 1-padding_fig] self.scales = {**self.scales, 'x': LinearScale(min=0, max=1)} self.line_style = 'dashed' self._oob = 'right' else: self.x = [x_coord, x_coord] self.scales = {**self.scales, 'x': self._viewer.scales['x']} self.line_style = 'solid' self._oob = False if update_label: self._update_label() def _slice_to_x(self, slice=0): if not isinstance(slice, int): raise TypeError(f"slice must be of type int, not {type(slice)}") return self._x_all[slice] def _update_colors_opacities(self): # orange (accent) if active, import button blue otherwise (see css in app.vue) if not self._show_if_inactive and not self._active: self.label.visible = False self.visible = False return self.visible = True self.label.visible = self._show_wavelength self.colors = ["#c75109" if self._active else "#007BA1"] self.opacities = [1.0 if self._active else 0.9] def _on_change_state(self, msg): if isinstance(msg, dict): changes = msg else: changes = msg.change for k, v in changes.items(): if k == 'active': self._active = v elif k == 'show_indicator': self._show_if_inactive = v elif k == 'show_wavelength': self._show_wavelength = v self._update_colors_opacities() def _update_label(self): # U+00A0 is a blank space, U+25C0 a left arrow triangle, and U+25B6 a right arrow triangle if self._oob == 'left': self.labels = [f'\u00A0 \u25c0 {self._slice_to_x(self.slice):0.4e} {self._x_unit} \u00A0'] # noqa elif self._oob == 'right': self.labels = [f'{self._slice_to_x(self.slice):0.4e} {self._x_unit} \u25b6 \u00A0'] else: self.labels = [f'\u00A0 {self._slice_to_x(self.slice):0.4e} {self._x_unit} \u00A0'] @property def slice(self): return self._slice @slice.setter def slice(self, slice): self._slice = slice # if this is within the init, the data may not have been set yet, # in which case we'll just set self._slice for the first time, but # do not need to update self.x or label (yet) if hasattr(self, '_x_all'): x_coord = self._slice_to_x(slice) self._handle_oob(x_coord) self._update_label() def _update_data(self, x_all): # we want to preserve slice number, so we'll do a bit more than the # default unit-conversion in the base class self._x_all = x_all.value self._x_unit = str(x_all.unit) x_coord = self._slice_to_x(self.slice) self._handle_oob(x_coord) if self.labels_visibility == 'label': # update label with new value/unit self._update_label()
[docs]class ShadowMixin: """Mixin class to propagate traits from one mark object to another. Anything in ``sync_traits`` will be mirrored directly from ``shadowing`` to the shadowed object. Can manually override ``_on_shadowing_changed`` for more advanced logic cases. """ def _get_id(self, mark): return getattr(mark, '_model_id', None) def _setup_shadowing(self, shadowing, sync_traits=[], other_traits=[]): """ sync_traits: traits to set now, and mirror any changes to shadowing in the future other_trait: traits to set now, but not mirror in the future """ if not hasattr(self, '_shadowing'): self._shadowing = {} self._sync_traits = {} shadowing_id = self._get_id(shadowing) self._shadowing[shadowing_id] = shadowing self._sync_traits[shadowing_id] = sync_traits # sync initial values for attr in sync_traits + other_traits: self._on_shadowing_changed({'name': attr, 'new': getattr(shadowing, attr), 'owner': shadowing}) # subscribe to future changes shadowing.observe(self._on_shadowing_changed) def _on_shadowing_changed(self, change): if change['name'] in self._sync_traits.get(self._get_id(change.get('owner')), []): setattr(self, change['name'], change['new']) return
[docs]class ShadowLine(Lines, HubListener, ShadowMixin): """Create a white shadow line around another line to help make it standout on top of other lines. """ def __init__(self, shadowing, shadow_width=1, **kwargs): self._shadow_width = shadow_width super().__init__(scales=shadowing.scales, stroke_width=shadowing.stroke_width+shadow_width if shadowing.stroke_width else 0, # noqa marker_size=shadowing.marker_size+shadow_width if shadowing.marker_size else 0, # noqa colors=[kwargs.pop('color', 'white')], **kwargs) self._setup_shadowing(shadowing, ['scales', 'x', 'y', 'visible', 'line_style', 'marker'], ['stroke_width', 'marker_size'])
class ShadowSpatialSpectral(Lines, HubListener, ShadowMixin): """ Shadow the mark of a spatial subset collapsed spectrum, with the mask from a spectral subset, and the styling from the spatial subset. """ def __init__(self, spatial_spectrum_mark, spectral_subset_mark): # spatial_spectrum_mark: Lines mark corresponding to the spatially-collapsed spectrum # from a spatial subset # spectral_subset_mark: Lines mark on the FULL cube corresponding to the glue-highlight # of the spectral subset super().__init__(scales=spatial_spectrum_mark.scales, marker=None) self._spatial_mark_id = self._get_id(spatial_spectrum_mark) self._setup_shadowing(spatial_spectrum_mark, ['scales', 'y', 'visible', 'line_style'], ['x']) self._spectral_mark_id = self._get_id(spectral_subset_mark) self._setup_shadowing(spectral_subset_mark, ['stroke_width', 'x', 'y', 'visible', 'opacities', 'colors']) @property def spatial_spectrum_mark(self): return self._shadowing[self._spatial_mark_id] @property def spectral_subset_mark(self): return self._shadowing[self._spectral_mark_id] def _on_shadowing_changed(self, change): if hasattr(self, '_spectral_mark_id'): if change['name'] == 'y': # at initial setup, the arrays may not be populated yet if self.spatial_spectrum_mark.y.shape == self.spectral_subset_mark.y.shape: # force a copy or else we'll overwrite the mask to the spatial mark! change['new'] = deepcopy(self.spatial_spectrum_mark.y) change['new'][np.isnan(self.spectral_subset_mark.y)] = np.nan elif change['name'] == 'visible': # only show if BOTH shadowing marks are set to visible change['new'] = self.spectral_subset_mark.visible and self.spatial_spectrum_mark.visible # noqa return super()._on_shadowing_changed(change)
[docs]class ShadowLabelFixedY(Label, ShadowMixin): """Label whose position shadows that of a parent ``shadowing`` line and will flip alignment based on whether it is left or right of the center of the viewer. """ def __init__(self, viewer, shadowing, shadow_traits=['visible'], y=0.95, point_index=0, **kwargs): super().__init__(**kwargs) self._viewer = viewer self.y = [y] self.scales['y'] = LinearScale(min=0, max=1) self._point_index = point_index self._setup_shadowing(shadowing, shadow_traits, ['x', 'scales', 'labels', 'colors']) viewer.state.add_callback("x_min", lambda x_min: self._update_align()) viewer.state.add_callback("x_max", lambda x_max: self._update_align()) def _force_redraw(self): # TODO: bug in bqplot that change in align/colors traitlet doesn't update immediately, # we'll get around it in the meantime by just forcing the Label to see a change to the # text traitlet text = self.text self.text = [''] self.text = text def _update_align(self): if not isinstance(self.scales.get('x'), LinearScale): return # determine alignment automatically if self.scales['x'].min == 0 and self.scales['x'].max == 1: # then we're in axes units, so just check position compared to 0.5 is_to_right = self.x[0] > 0.5 else: # then we're in data units, so check position compared to the median of the axes limits is_to_right = self.x[0] > (self._viewer.state.x_min + self._viewer.state.x_max) / 2. if is_to_right and self.align != 'end': self.align = 'end' # force redraw by re-updating label self._force_redraw() if not is_to_right and self.align != 'start': self.align = 'start' # force redraw by re-updating label self._force_redraw() def _on_shadowing_changed(self, change): super()._on_shadowing_changed(change) if change['name'] == 'labels': self.text = [change['new'][self._point_index]] elif change['name'] in ('x', 'colors'): setattr(self, change['name'], [change['new'][self._point_index]]) if change['name'] == 'colors': # bqplot bug that won't notice change to colors, manually force re-draw self._force_redraw() elif change['name'] == 'scales': self.scales = {**self.scales, 'x': change['new']['x']} if change['name'] in ('x', 'scales'): # then the position of the label on the plot has changed, so re-determine whether # it should be aligned to the left or right self._update_align()
[docs]class PluginMark():
[docs] def update_xy(self, x, y): self.x = np.asarray(x) self.y = np.asarray(y)
[docs] def append_xy(self, x, y): self.x = np.append(self.x, x) self.y = np.append(self.y, y)
[docs] def clear(self): self.update_xy([], [])
[docs]class PluginLine(Lines, PluginMark, HubListener): def __init__(self, viewer, x=[], y=[], **kwargs): # color is same blue as import button super().__init__(x=x, y=y, colors=["#007BA1"], scales=viewer.scales, **kwargs)
[docs]class PluginScatter(Scatter, PluginMark, HubListener): def __init__(self, viewer, x=[], y=[], **kwargs): # color is same blue as import button super().__init__(x=x, y=y, colors=["#007BA1"], scales=viewer.scales, **kwargs)
[docs]class LineAnalysisContinuum(PluginLine): pass
[docs]class LineAnalysisContinuumCenter(LineAnalysisContinuum): def __init__(self, viewer, x=[], y=[], **kwargs): super().__init__(viewer, x, y, **kwargs) self.stroke_width = 1
[docs]class LineAnalysisContinuumLeft(LineAnalysisContinuum): def __init__(self, viewer, x=[], y=[], **kwargs): super().__init__(viewer, x, y, **kwargs) self.stroke_width = 5
[docs]class LineAnalysisContinuumRight(LineAnalysisContinuumLeft): pass
[docs]class LineUncertainties(Lines): def __init__(self, **kwargs): super().__init__(**kwargs)
[docs]class ScatterMask(Scatter): def __init__(self, **kwargs): super().__init__(**kwargs)
[docs]class SelectedSpaxel(Lines): def __init__(self, **kwargs): super().__init__(**kwargs)
[docs]class MarkersMark(PluginScatter): def __init__(self, viewer, **kwargs): kwargs.setdefault('marker', 'circle') super().__init__(viewer, **kwargs)