Source code for jdaviz.utils

import operator
import os
import time
import threading
import warnings
from collections import deque
from urllib.parse import urlparse
import fnmatch
import re
import hashlib
import multiprocessing as mp
from joblib import Parallel, delayed

import asdf
import numpy as np
from astropy.io import fits
from astropy.utils import minversion
from astropy.utils.data import download_file
from astropy.wcs import WCS
from astropy.wcs.wcsapi import BaseHighLevelWCS
from astroquery.mast import Observations, conf
from gwcs import WCS as gwcs
from gwcs.coordinate_frames import CompositeFrame, SpectralFrame
from matplotlib import colors as mpl_colors
import matplotlib.cm as cm
from photutils.utils import make_random_cmap
from regions import CirclePixelRegion, CircleAnnulusPixelRegion
from specutils.utils.wcs_utils import SpectralGWCS
import stdatamodels

from glue.config import settings
from glue.config import colormaps as glue_colormaps
from glue.core import BaseData
from glue.core.exceptions import IncompatibleAttribute
from glue.core.subset import SubsetState, RangeSubsetState, RoiSubsetState
from glue_astronomy.spectral_coordinates import SpectralCoordinates
from ipyvue import watch


__all__ = ['SnackbarQueue', 'enable_hot_reloading', 'bqplot_clear_figure',
           'standardize_metadata', 'ColorCycler', 'alpha_index',
           'get_subset_type', 'cached_uri', 'download_uri_to_path', 'layer_is_2d',
           'layer_is_2d_or_3d', 'layer_is_image_data', 'layer_is_wcs_only',
           'get_wcs_only_layer_labels', 'get_top_layer_index',
           'get_reference_image_data', 'standardize_roman_metadata',
           'wildcard_match', 'cmap_samples', 'glue_colormaps',
           'att_to_componentid', 'create_data_hash',
           'RA_COMPS', 'DEC_COMPS', 'SPECTRAL_AXIS_COMP_LABELS']

NUMPY_LT_2_0 = not minversion("numpy", "2.0.dev")
STDATAMODELS_LT_402 = not minversion(stdatamodels, "4.0.2.dev")

# For Metadata Viewer plugin internal use only.
PRIHDR_KEY = '_primary_header'
COMMENTCARD_KEY = '_fits_comment_card'

CONFIGS_WITH_LOADERS = ('deconfigged', 'lcviz',
                        'specviz', 'specviz2d',
                        'imviz', 'cubeviz',
                        'rampviz')
SPECTRAL_AXIS_COMP_LABELS = ('Wavelength', 'Wave', 'Frequency', 'Energy',
                             'Velocity', 'Wavenumber',
                             'World 0', 'World 1',
                             'Pixel Axis 0 [x]', 'Pixel Axis 1 [x]')
RA_COMPS = ['rightascension', 'ra', 'radeg', 'radeg',
            'radegrees', 'rightascensiondegrees', 'rightascensiondeg',
            'raobj', 'objra', 'sourcera', 'rasource', 'raj2000', 'ra2000',
            'worldra']
DEC_COMPS = ['declination', 'dec', 'decdeg', 'decdeg',
             'decdegrees', 'declinationdegrees', 'declinationdeg',
             'decobj', 'objdec', 'decsource', 'sourcedec', 'decj2000', 'dec2000',
             'worlddec']


class SnackbarQueue:
    '''
    Class that performs the role of VSnackbarQueue, which is not
    implemented in ipyvuetify.
    '''
    def __init__(self):
        self.queue = deque()
        # track whether we're showing a loading message which won't clear by timeout,
        # but instead requires another message with msg.loading = False to clear
        self.loading = False
        # track whether this is the first message - we'll increase the timeout for that
        # to give time for the app to load.
        self.first = True

    def put(self, state, logger_plg, msg, history=True, popup=True):
        if msg.color not in ['info', 'warning', 'error', 'success', None]:
            raise ValueError(f"color ({msg.color}) must be one of: info, warning, error, success")

        if not msg.loading and history and logger_plg is not None:
            now = time.localtime()
            timestamp = f'{now.tm_hour}:{now.tm_min:02d}:{now.tm_sec:02d}'
            new_history = {'time': timestamp,
                           'text': msg.text,
                           'color': msg.color,
                           'traceback': msg.traceback}
            # for now, we'll hardcode the max length of the stored history
            if len(logger_plg.history) >= 50:
                logger_plg.history = logger_plg.history[1:] + [new_history]
            else:
                logger_plg.history = logger_plg.history + [new_history]

        if not (popup or msg.loading):
            if self.loading:
                # then we still need to clear the existing loading message
                self.loading = False
                self.close_current_message(state)
            return

        if msg.loading:
            # immediately show the loading message indefinitely until cleared by a new message
            # with loading=False (or overwritten by a new indefinite message with loading=True)
            self.loading = True
            self._write_message(state, msg)
        elif self.loading:
            # clear the loading state, immediately show this message, then re-enter the queue
            self.loading = False
            self._write_message(state, msg)
        else:
            warn_and_err = ('warning', 'error')
            if msg.color in warn_and_err:
                if (state.snackbar.get('show') and
                        ((msg.color == 'warning' and state.snackbar.get('color') in warn_and_err) or  # noqa
                         (msg.color == 'error' and state.snackbar.get('color') == 'error'))):
                    # put this NEXT in the queue immediately FOLLOWING all warning/errors
                    non_warning_error = [msg.color not in warn_and_err for msg in self.queue]  # noqa
                    if True in non_warning_error:
                        # insert BEFORE index
                        self.queue.insert(non_warning_error.index(True), msg)
                    else:
                        self.queue.append(msg)
                else:
                    # interrupt the queue IMMEDIATELY
                    # (any currently shown messages will repeat after)
                    self._write_message(state, msg)
            else:
                # put this LAST in the queue
                self.queue.append(msg)
                if len(self.queue) == 1:
                    self._write_message(state, msg)

    def close_current_message(self, state):

        if self.loading:
            # then we've been interrupted, so keep this item in the queue to show after
            # loading is complete
            return

        # turn off snackbar itself
        state.snackbar['show'] = False

        if len(self.queue) > 0:
            # determine if the closed entry came from the queue (not an interrupt)
            # in which case we should remove it from the queue.  We clear here instead
            # of when creating the snackbar so that items that are interrupted
            # (ie by a loading message) will reappear again at the top of the queue
            # so they are not missed
            msg = self.queue[0]
            if msg.text == state.snackbar['text']:
                try:
                    _ = self.queue.popleft()
                except IndexError:
                    # in case the queue has been cleared in the meantime
                    pass

        # in case there are messages in the queue still,
        # display the next.
        if len(self.queue) > 0:
            msg = self.queue[0]
            self._write_message(state, msg)

    def _write_message(self, state, msg):
        state.snackbar['show'] = False
        state.snackbar['text'] = msg.text
        state.snackbar['color'] = msg.color
        # TODO: in vuetify >2.3, timeout should be set to -1 to keep open
        # indefinitely
        state.snackbar['timeout'] = 0  # timeout controlled by thread
        state.snackbar['loading'] = msg.loading
        state.snackbar['show'] = True

        if msg.loading:
            # do not create timeout - the message will be indefinite until
            # cleared by another message
            return

        # timeout of the first message needs to be increased by a
        # few seconds to account for the time spent in page rendering.
        # A more elegant way to address this should be via a callback
        # from a vue hook such as mounted(). It doesn't work though.
        # Since this entire queue effort is temporary anyway (pending
        # the implementation of VSnackbarQueue in ipyvuetify), it's
        # better to keep the solution contained all in one place here.
        timeout = msg.timeout
        if timeout < 500:
            # half-second minimum timeout
            timeout = 500
        if self.first:
            timeout += 5000
            self.first = False

        # create the timeout function which will close this message and
        # show the next message if one has been added to the queue since
        def sleep_function(timeout, text):
            timeout_ = float(timeout) / 1000
            time.sleep(timeout_)
            if state.snackbar['show'] and state.snackbar['text'] == text:
                # don't close the next message if the user manually clicked close!
                self.close_current_message(state)

        x = threading.Thread(target=sleep_function,
                             args=(timeout, msg.text),
                             daemon=True)
        x.start()

def enable_hot_reloading():
    """Use ``watchdog`` to perform hot reloading."""
    try:
        watch(os.path.dirname(__file__))
    except ModuleNotFoundError:
        print((
            'Watchdog module, needed for hot reloading, not found.'
            ' Please install with `pip install watchdog`'))

def bqplot_clear_figure(fig):
    """Clears a given ``bqplot.Figure`` to mimic matplotlib ``clf()``.
    This is necessary when we draw multiple plots across different plugins.
    """
    # Clear bqplot figure (copied from bqplot/pyplot.py)
    fig.marks = []
    fig.axes = []
    setattr(fig, 'axis_registry', {})

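# Illustrative example (added for clarity, not part of the original module): a minimal
# sketch of clearing a bare bqplot Figure; assumes bqplot is installed, as it is for jdaviz.
def _example_bqplot_clear_figure():
    from bqplot import Figure
    fig = Figure()
    bqplot_clear_figure(fig)
    # both lists are empty after clearing
    return fig.marks, fig.axes
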
def alpha_index(index):
    """Converts an index to a label (a-z, aa-zz).

    Parameters
    ----------
    index : int
        Index between 0 and 701, inclusive. Higher numbers are accepted but
        will produce labels with special characters.

    Returns
    -------
    label : str
        String in the range a-z, aa-zz if index is within the 0-701 range,
        inclusive.

    Raises
    ------
    TypeError
        Index is not an integer.

    ValueError
        Index is negative.
    """
    # if we ever want to support more than 702 layers, then we'll need a third
    # "digit" and will need to account for the horizontal space in the legends
    if not isinstance(index, int):
        raise TypeError("index must be an integer")
    if index < 0:
        raise ValueError("index must be non-negative")

    if index <= 25:
        # a-z
        return chr(97 + index)
    else:
        # aa-zz (26-701), then overflow strings like '{a'
        return chr(97 + index // 26 - 1) + chr(97 + index % 26)

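# Illustrative example (added for clarity, not part of the original module): how layer
# indices map to the lowercase legend labels produced above.
def _example_alpha_index_usage():
    labels = [alpha_index(i) for i in (0, 1, 25, 26, 27, 701)]
    # Expected: ['a', 'b', 'z', 'aa', 'ab', 'zz']
    return labels
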
def _try_gwcs_to_fits_sip(gw):
    """
    Try to convert this GWCS to FITS SIP. Some GWCS models
    cannot be converted to FITS SIP. In that case, a warning
    is raised and the GWCS is used, as is.
    """
    if isinstance(gw, gwcs):
        try:
            result = WCS(gw.to_fits_sip(), relax=True)
        except ValueError as err:
            warnings.warn(
                "The GWCS coordinates could not be simplified to "
                "a SIP-based FITS WCS, the following error was "
                f"raised: {err}",
                UserWarning
            )
            result = gw
    else:
        result = gw

    return result


def data_has_valid_wcs(data, ndim=None):
    """Check if given glue Data has WCS that is compatible with APE 14."""
    status = hasattr(data, 'coords') and isinstance(data.coords, BaseHighLevelWCS)
    if ndim is not None:
        status = status and data.coords.world_n_dim == ndim
    return status


def layer_is_table_data(layer):
    return isinstance(layer, BaseData) and layer.ndim == 1


_wcs_only_label = "_WCS_ONLY"


def is_wcs_only(layer):
    # identify WCS-only layers
    if hasattr(layer, 'layer'):
        layer = layer.layer
    return (
        # WCS-only layers have a metadata label:
        getattr(layer, 'meta', {}).get(_wcs_only_label, False)
    )


def is_not_wcs_only(layer):
    return not is_wcs_only(layer)


def layer_is_not_dq(data):
    return '[DQ' not in data.label

def standardize_metadata(metadata):
    """Standardize given metadata so it can be viewed in
    Metadata Viewer plugin. The input can be a plain
    dictionary or FITS header object. Output is just a plain
    dictionary.
    """
    if isinstance(metadata, fits.Header):
        try:
            out_meta = dict(metadata)
            out_meta[COMMENTCARD_KEY] = metadata.comments
        except Exception:  # Invalid FITS header  # pragma: no cover
            out_meta = {}
    elif isinstance(metadata, dict):
        out_meta = metadata.copy()

        # specutils nests it but we do not want nesting
        if 'header' in metadata and isinstance(metadata['header'], fits.Header):
            out_meta.update(standardize_metadata(metadata['header']))
            del out_meta['header']
    else:
        raise TypeError('metadata must be dictionary or FITS header')

    return out_meta

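# Illustrative example (added for clarity, not part of the original module): a minimal
# sketch showing how a FITS header is flattened into a plain dict, with comment cards
# stored under COMMENTCARD_KEY.
def _example_standardize_metadata():
    hdr = fits.Header()
    hdr['TELESCOP'] = ('JWST', 'telescope used to acquire data')
    meta = standardize_metadata(hdr)
    # meta['TELESCOP'] == 'JWST'; meta[COMMENTCARD_KEY] holds the header comment cards
    return meta
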
def standardize_roman_metadata(data_model):
    """
    Metadata standardization for Roman datamodels ``meta`` attributes.

    Converts to a flat dictionary and strips the redundant
    top-level tags ("roman", and "meta").

    Parameters
    ----------
    data_model : `~roman_datamodels.datamodels.DataModel`
        Roman datamodel.

    Returns
    -------
    d : dict
        Flattened dictionary of metadata
    """
    # if the file is a Roman DataModel:
    if hasattr(data_model, 'to_flat_dict'):
        # Roman metadata are in nested dicts that we flatten:
        flat_dict_meta = data_model.to_flat_dict()

        # split off the redundant parts of the metadata:
        return {
            k.split('roman.meta.')[1]: v
            for k, v in flat_dict_meta.items()
            if 'roman.meta' in k
        }
    elif isinstance(data_model, asdf.AsdfFile):
        # otherwise use default standardization
        return standardize_metadata(data_model['roman']['meta'])

class ColorCycler:
    """
    Cycles through matplotlib's default color palette after first
    using the Glue default data color.
    """
    # default color cycle starts with the Glue default data color
    # followed by the matplotlib default color cycle, except for the
    # second color (orange) in the matplotlib cycle, which is too close
    # to the jdaviz accent color (also orange).
    default_dark_gray = settings._defaults['DATA_COLOR']
    default_color_palette = [
        default_dark_gray,
        '#1f77b4',
        '#2ca02c',
        '#d62728',
        '#9467bd',
        '#8c564b',
        '#e377c2',
        '#7f7f7f',
        '#bcbd22',
        '#17becf'
    ]

    def __init__(self, counter=-1):
        self.counter = counter

    def __call__(self):
        self.counter += 1

        cycle_index = self.counter % len(self.default_color_palette)
        color = self.default_color_palette[cycle_index]

        return color

    def reset(self):
        self.counter = -1

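# Illustrative example (added for clarity, not part of the original module): a minimal
# sketch of cycling layer colors, starting from the Glue default data color.
def _example_color_cycler():
    cycler = ColorCycler()
    first, second = cycler(), cycler()   # Glue default gray, then '#1f77b4'
    cycler.reset()
    return first, second, cycler()       # the third call repeats the first color
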
def _chain_regions(regions, ops):
    """
    Combine multiple regions into a compound pixel/sky region based on the
    specified operators. If the operators are valid binary operators recognized
    by both glue and Regions, the function returns a compound region. Otherwise,
    it returns a list of individual regions paired with their respective
    operators, or just returns the region if ``regions`` only contains one region.

    Parameters
    ----------
    regions : list
        A list of region objects.
    ops : list of str
        A list of glue states that map to operator names to describe how to
        combine regions (e.g. 'AndState').

    Returns
    -------
    Compound region or list
        A single compound region if valid operators are provided; otherwise, a
        list of tuples containing individual regions and their associated operators.
    """
    if len(regions) == 1:
        return regions[0]

    valid_operators = {
        'AndState': operator.and_,
        'OrState': operator.or_,
        'XorState': operator.xor
    }

    operators = ops[1:]  # first subset doesn't need an operator

    # if regions can't be combined into a compound region as an annulus or with
    # and/or/xor, return list of tuples of (region, operator)
    annulus = _combine_if_annulus(regions[0], regions[1], operators[0])
    if annulus is None:
        if not np.all(np.isin(operators, list(valid_operators.keys()))):
            return list(zip(regions, [''] + operators))

    r1 = annulus or regions[0]
    for i in range(2 if annulus else 0, len(operators)):
        r1 = valid_operators[operators[i]](r1, regions[i + 1])

    return r1


def _combine_if_annulus(region1, region2, op):
    """
    Determine whether applying ``region2`` to ``region1`` using the specified
    operator results in a circular annulus. If the conditions are met, return a
    `CircleAnnulusPixelRegion`; otherwise, return `None`.
    """
    if (
        isinstance(region1, CirclePixelRegion)
        and isinstance(region2, CirclePixelRegion)
        and op == 'AndNotState'
        and region1.center == region2.center
        and region1.radius > region2.radius
    ):
        return CircleAnnulusPixelRegion(center=region1.center,
                                        inner_radius=region2.radius,
                                        outer_radius=region1.radius)

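# Illustrative example (added for clarity, not part of the original module): a minimal
# sketch showing how two concentric circular subsets combined with 'AndNotState' are
# reduced to a single CircleAnnulusPixelRegion by the helpers above.
def _example_chain_regions_annulus():
    from regions import PixCoord
    center = PixCoord(10, 10)
    outer = CirclePixelRegion(center=center, radius=5)
    inner = CirclePixelRegion(center=center, radius=2)
    # expected: CircleAnnulusPixelRegion with inner_radius=2, outer_radius=5
    return _chain_regions([outer, inner], ops=['', 'AndNotState'])
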
def get_subset_type(subset):
    """
    Determine the subset type of a subset or layer.

    Parameters
    ----------
    subset : glue.core.subset.Subset or glue.core.subset_group.GroupedSubset
        Should have ``subset_state`` as an attribute, otherwise will return ``None``.

    Returns
    -------
    subset_type : str or None
        'spatial', 'spectral', 'temporal', or None
    """
    if not hasattr(subset, 'subset_state'):
        return None

    while hasattr(subset.subset_state, 'state1'):
        # this assumes no mixing between spatial and spectral subsets and just
        # taking the first component (down the hierarchical tree) to determine the type
        subset = subset.subset_state.state1

    if isinstance(subset.subset_state, RoiSubsetState):
        return 'spatial'
    elif isinstance(subset.subset_state, RangeSubsetState):
        # look within a SubsetGroup, or a single Subset
        subset_list = getattr(subset, 'subsets', [subset])
        for ss in subset_list:
            if hasattr(ss, 'data'):
                ss_data = ss.data
            elif hasattr(ss.att, 'parent'):
                # if `ss` is a subset state, it won't have a `data` attr,
                # check the world coordinate's parent data:
                ss_data = ss.att.parent
            else:
                # if we reach this `else`, continue searching
                # through other subsets in the group to identify the
                # subset type:
                continue

            # check for spectral coordinate in GWCS by looking for SpectralFrame
            if isinstance(ss_data.coords, gwcs):
                if isinstance(ss_data.coords, (SpectralFrame, SpectralGWCS)):
                    return 'spectral'
                elif isinstance(ss_data.coords, CompositeFrame):
                    if np.any([isinstance(frame, SpectralFrame)
                               for frame in ss_data.coords.output_frame.frames]):
                        return 'spectral'
                else:
                    continue

            # check for a spectral coordinate in FITS WCS:
            wcs_coords = (
                ss_data.coords.wcs.ctype
                if hasattr(ss_data.coords, 'wcs')
                else []
            )

            has_spectral_coords = (
                any(str(coord).startswith('WAVE') for coord in wcs_coords) or
                # also check for a spectral coordinate from the glue_astronomy translator:
                isinstance(ss_data.coords, (SpectralCoordinates, SpectralGWCS))
            )

            if has_spectral_coords:
                return 'spectral'

        # otherwise, assume temporal:
        return 'temporal'

    else:
        return None

class MultiMaskSubsetState(SubsetState):
    """
    A subset state that can include a different mask for different datasets.
    Adopted from https://github.com/glue-viz/glue/pull/2415

    Parameters
    ----------
    masks : dict
        A dictionary mapping data UUIDs to boolean arrays with the same
        dimensions as the data arrays.
    """

    def __init__(self, masks=None):
        super(MultiMaskSubsetState, self).__init__()
        self._masks = masks

    def to_mask(self, data, view=None):
        if data.uuid in self._masks:
            mask = self._masks[data.uuid]
            if view is not None:
                mask = mask[view]
            return mask
        else:
            raise IncompatibleAttribute()

    def copy(self):
        return MultiMaskSubsetState(masks=self._masks)

    def __gluestate__(self, context):
        serialized = {key: context.do(value) for key, value in self._masks.items()}
        return {'masks': serialized}

    def total_masked_first_data(self):
        first_data = next(iter(self._masks))
        return len(np.where(self._masks[first_data])[0])

    @classmethod
    def __setgluestate__(cls, rec, context):
        masks = {key: context.object(value) for key, value in rec['masks'].items()}
        return cls(masks=masks)


def get_cloud_fits(possible_uri, ext=None):
    """
    Retrieve and open a FITS file from an S3 URI using fsspec.

    If ``possible_uri`` is an S3 URI, the specified extensions from the
    FITS file will be opened remotely using `astropy.io.fits` with ``fsspec``.
    Anonymous access is assumed for S3. If the URI is not S3-based, a
    `ValueError` is raised.

    Parameters
    ----------
    possible_uri : str
        A path or URI to the FITS file. If the URI uses the ``s3://`` scheme,
        the file is accessed via fsspec and returned as an
        `~astropy.io.fits.HDUList`.

    ext : int, str, or list, optional
        Extension(s) to load from the FITS file. Can be an integer index
        (e.g., 0), a string name (e.g., "SCI"), or a list of such values.
        If `None`, all extensions are loaded.

    Returns
    -------
    file_obj : `~astropy.io.fits.HDUList`
        An `HDUList` containing the requested extensions.
    """
    parsed_uri = urlparse(possible_uri)

    # TODO: Add caching logic
    if not parsed_uri.scheme.lower() == 's3':
        raise ValueError("Not an S3 URI: {}".format(possible_uri))

    downloaded_hdus = []

    # this loads the requested extensions into local memory:
    with fits.open(possible_uri, fsspec_kwargs={"anon": True}) as hdul:
        if ext is None:
            ext_list = list(range(len(hdul)))
        elif not isinstance(ext, list):
            ext_list = [ext]
        else:
            ext_list = ext

        for extension in ext_list:
            hdu_obj = hdul[extension]
            downloaded_hdus.append(hdu_obj.copy())

    file_obj = fits.HDUList(downloaded_hdus)

    return file_obj

def cached_uri(uri):
    # return a filename if it exists in the working directory, otherwise return the URI.
    # this is used in CI tests where the MAST files are downloaded by a separate workflow,
    # cached, and restored in the tox working directory to avoid downloading them again
    fname = uri.split(':')[-1].split('/')[-1]
    if os.path.isfile(fname):
        return fname
    return uri

def download_uri_to_path(possible_uri, cache=None, local_path=os.curdir, timeout=None,
                         dryrun=False):
    """
    Retrieve data from a URI (or a URL). Return the input if it
    cannot be parsed as a URI.

    If ``possible_uri`` is a MAST URI, the file will be retrieved via
    astroquery's `~astroquery.mast.ObservationsClass.download_file`.
    If ``possible_uri`` is a URL, it will be retrieved via astropy with
    `~astropy.utils.data.download_file`.

    Parameters
    ----------
    possible_uri : str or other
        This input will be returned without changes if it is not a string,
        or if it is a local file path to an existing file. Otherwise, it
        will be parsed as a URI. Local file URIs beginning with ``file://``
        are not supported by this method – nor are they necessary, since
        string paths without the scheme work fine! Cloud FITS are not yet
        supported.

    cache : None, bool, or ``"update"``, optional
        Cache file after download. If ``possible_uri`` is a URL, ``cache``
        may be a boolean or ``"update"``, see documentation for
        `~astropy.utils.data.download_file` for details. If cache is None,
        the file is cached and a warning is raised suggesting to set
        ``cache`` explicitly in the future.

    local_path : str, optional
        Save the downloaded file to this path. Default is to save the
        file with its remote filename in the current working directory.
        This is only used if data is requested from `astroquery.mast`.

    timeout : float, optional
        If downloading from a remote URI, set the timeout limit for
        remote requests in seconds (passed to
        `~astropy.utils.data.download_file` or
        `~astroquery.mast.Conf.timeout`).

    dryrun : bool
        Set to `True` to skip downloading data from MAST.
        This is only used for debugging.

    Returns
    -------
    possible_uri : str or other
        If ``possible_uri`` cannot be retrieved as a URI, returns the
        input argument unchanged. If ``possible_uri`` can be retrieved
        as a URI, returns the local path to the downloaded file.
    """
    if not isinstance(possible_uri, str):
        # only try to parse strings:
        return possible_uri

    if os.path.exists(possible_uri):
        # don't try to parse file paths:
        return possible_uri

    if os.environ.get("JDAVIZ_START_DIR", ""):
        # avoiding creating local paths in a tmp dir when in standalone:
        local_path = os.path.join(os.environ["JDAVIZ_START_DIR"], local_path)

    timeout = int(timeout) if timeout is not None else timeout

    parsed_uri = urlparse(possible_uri)

    cache_none_msg = (
        "You may be querying for a remote file "
        f"at '{possible_uri}', but the `cache` argument was not used. "
        "Unless you set `cache` explicitly, remote files will be cached "
        "locally and this warning will be raised."
    )

    local_path_msg = (
        f"You requested to cache data to the `local_path`='{local_path}'. This "
        f"keyword argument is supported for downloads of MAST URIs via astroquery, "
        f"but since the remote file at '{possible_uri}' will be downloaded "
        f"using `astropy.utils.data.download_file`, the file will be "
        f"stored in the astropy download cache instead."
    )

    cache_warning = False
    if cache is None:
        cache = True
        cache_warning = True

    if parsed_uri.scheme.lower() == 'mast':
        if cache_warning:
            warnings.warn(cache_none_msg, UserWarning)

        if local_path is not None and os.path.isdir(local_path):
            # if you give a directory, save the file there with default name:
            # os.path.sep does not work because on windows that is a back slash
            # and this web path needs to be split with a forward slash
            local_path = os.path.join(local_path, parsed_uri.path.split('/')[-1])

        if not dryrun:
            with conf.set_temp('timeout', timeout):
                (status, msg, url) = Observations.download_file(
                    possible_uri, cache=cache, local_path=local_path
                )
        else:
            status = "COMPLETE"

        if status != 'COMPLETE':
            # pass along the error message from astroquery if the
            # data were not successfully downloaded:
            raise ValueError(
                f"Failed query for URI '{possible_uri}' at '{url}':\n\n{msg}"
            )

        if local_path is None:
            # if not specified, this is the default location:
            # os.path.sep does not work because on Windows that is a back slash
            # and this web path needs to be split with a forward slash
            local_path = os.path.join(os.getcwd(), parsed_uri.path.split('/')[-1])
        return local_path

    elif parsed_uri.scheme.lower() in ('http', 'https', 'ftp'):
        if cache_warning:
            warnings.warn(cache_none_msg, UserWarning)

        if local_path not in (os.curdir, None):
            warnings.warn(local_path_msg, UserWarning)

        return download_file(possible_uri, cache=cache, timeout=timeout)

    elif parsed_uri.scheme == '':
        raise ValueError(f"The input file '{possible_uri}' cannot be parsed as a "
                         f"URL or URI, and no existing local file is available "
                         f"at this path.")

    else:
        raise ValueError(f"URI {possible_uri} with scheme {parsed_uri.scheme} is not "
                         f"currently supported.")

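# Illustrative example (added for clarity, not part of the original module): a minimal
# sketch of resolving a MAST URI without actually downloading it. The URI below is a
# hypothetical product name and only demonstrates the call pattern; with dryrun=True the
# function returns the local path that would be used.
def _example_download_uri_dryrun():
    uri = 'mast:JWST/product/example_uncal.fits'  # hypothetical product name
    return download_uri_to_path(uri, cache=True, local_path=os.curdir, dryrun=True)
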
def layer_is_2d(layer):
    # returns True for subclasses of BaseData with ndim=2, both for
    # layers that are WCS-only as well as images containing data:
    return isinstance(layer, BaseData) and layer.ndim == 2


def layer_is_3d(layer):
    # returns True for subclasses of BaseData with ndim=3:
    return isinstance(layer, BaseData) and layer.ndim == 3


def layer_is_2d_or_3d(layer):
    return isinstance(layer, BaseData) and layer.ndim in (2, 3)


def layer_is_image_data(layer):
    return layer_is_2d_or_3d(layer) and not layer.meta.get(_wcs_only_label, False)


def layer_is_wcs_only(layer):
    return layer_is_2d(layer) and layer.meta.get(_wcs_only_label, False)


def get_wcs_only_layer_labels(app):
    return [data.label for data in app.data_collection
            if layer_is_wcs_only(data)]


def wcs_is_spectral(wcs):
    if wcs is None:
        return False
    # NOTE: this may need further generalization for the GWCS but non-specutils case
    # or for the spectral cube case
    has_spectral_type = [ctype for ctype in wcs.world_axis_physical_types
                         if ctype is not None and ctype[0:3] == 'em.']
    return (isinstance(wcs, SpectralGWCS) or getattr(wcs, 'has_spectral', False)
            or len(has_spectral_type))


def get_top_layer_index(viewer):
    """Get index of the top visible image layer in a viewer.
    This is because when blinked, first layer might not be top visible layer.
    """
    # exclude children of layer associations
    associations = viewer.jdaviz_app._data_associations

    visible_image_layers = [
        i for i, lyr in enumerate(viewer.state.layers)
        if (
            lyr.visible and
            layer_is_image_data(lyr.layer) and
            # check that this layer is a root, without parents:
            associations[lyr.layer.label]['parent'] is None
        )
    ]

    if len(visible_image_layers):
        return visible_image_layers[-1]
    return None


def get_reference_image_data(app, viewer_id=None):
    """
    Return the current reference data in the given image viewer and its index.
    By default, the first viewer is used.
    """
    if viewer_id is None:
        if len(image_viewers := app.get_viewers_of_cls('ImvizImageView')) > 0:
            refdata = image_viewers[0].state.reference_data
        else:
            refdata = None
    else:
        viewer = app.get_viewer_by_id(viewer_id)
        refdata = viewer.state.reference_data

    if refdata is not None:
        iref = app.data_collection.index(refdata)
        return refdata, iref

    return None, -1


def escape_brackets(s):
    # Replace [ with [[] and ] with []]
    return re.sub(r'([\[\]])', r'[\1]', s)


def has_wildcard(s):
    """Check if the string contains any shell-style wildcards: * or ?."""
    return bool(re.search(r'[\*\?]', s))

def wildcard_match(obj, value, choices=None):
    """
    Wrapper that handles both single string and list/tuple of strings as inputs
    for ``value``. Returns a list of strings from ``obj.choices`` that match the
    wildcard pattern(s) in ``value``. If no matches are found, returns a list
    containing ``value`` itself.

    .. note:: ``fnmatch`` provides support for all Unix style wildcards including
              ``*`` and ``?``. We do not check for ``[seq]`` and ``[!seq]`` because
              image extensions as we handle them contain brackets.

    Parameters
    ----------
    obj : object
        An object with attributes ``choices`` and potentially ``multiselect``.
    value : str or list or tuple
        A string or list/tuple of strings to match against choices. Each string
        may contain Unix shell-style wildcards.
    choices : list of str, optional
        A list of strings to match against. If not provided, ``obj.choices``
        will be used.

    Returns
    -------
    list of str
        A list of matched strings or ``value``/``[value]`` if no matches found.
    """
    def wildcard_match_str(internal_choices, internal_value):
        # Assume we want to escape brackets as in the case of images with
        # multiple extensions
        internal_value = escape_brackets(internal_value)
        matched = fnmatch.filter(internal_choices, internal_value)
        if len(matched) == 0:
            matched = [internal_value]
        return matched

    def wildcard_match_list_of_str(internal_choices, internal_value):
        matched = []
        for v in internal_value:
            if isinstance(v, str) and has_wildcard(v):
                # Check for wildcard matches
                matched.extend(wildcard_match_str(internal_choices, v))
            else:
                # Append as-is
                matched.append(v)
        # Remove duplicates while preserving order
        return list(dict.fromkeys(matched))

    if choices is None:
        choices = getattr(obj, 'choices', None)
        if choices is None:
            return value

    # any works for both str and iterable
    if (getattr(obj, 'allow_multiselect', False)
            and any(has_wildcard(v) for v in value if isinstance(v, str))):
        if isinstance(value, str):
            obj.multiselect = True
            value = wildcard_match_str(choices, value)
        elif isinstance(value, (list, tuple)):
            obj.multiselect = True
            value = wildcard_match_list_of_str(choices, value)

        # If only '*' wildcards are left meaning that nothing matched, return empty selection.
        # Basically, '*' of empty should return empty---we don't want to error out. For other
        # patterns like 'foo*' not matching anything, we use the error to notify the user of
        # no match.
        # e.g. value == ['*'] or ['*', '*'], choices == [] -> match == [] (rather than ['*'])
        if all(vi == '*' for v in value for vi in v):  # List of strings
            value = [] if getattr(obj, 'multiselect', False) else ''

    return value

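# Illustrative example (added for clarity, not part of the original module): a minimal
# sketch of wildcard expansion against a simple stand-in object; the _FakeSelect class
# below is hypothetical and only mimics the `choices`/`multiselect` attributes expected
# by wildcard_match.
def _example_wildcard_match():
    class _FakeSelect:
        choices = ['sci_1', 'sci_2', 'err_1']
        allow_multiselect = True
        multiselect = False

    obj = _FakeSelect()
    # expected: ['sci_1', 'sci_2'], with obj.multiselect switched to True
    return wildcard_match(obj, 'sci*')
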
def att_to_componentid(att_helper, att):
    # get a glue state component id from an attribute helper and an attribute name
    for choice in att_helper.choices:
        if str(choice) == att:
            return choice
    raise ValueError(f"Could not find component ID for attribute '{att}'")

def parallelize_calculation(workers, collect_result_callback, n_cpu=mp.cpu_count() - 1):
    """
    Function to perform parallel processing with joblib. The function takes a
    list of callables (functions with no arguments that return a result) and
    executes them in parallel. The results of each callable are passed to a
    callback function for collection.

    Parameters
    ----------
    workers : list of callable
        The callables to be executed within the parallel backend context.
    collect_result_callback : function
        A callback function to collect the results of each worker.
    n_cpu : int
        The number of CPU cores to use for parallel processing. Defaults to
        the total number of available CPU cores - 1.
    """
    results = Parallel(n_jobs=n_cpu)(delayed(worker)() for worker in workers)
    _ = [collect_result_callback(r) for r in results]


def _clean_data_for_hash(data):
    """
    Extract and return the array from the data object for hashing.

    The function checks for common attributes like 'flux' or 'data' to extract
    the relevant array. If the data is an `astropy.units.Quantity`, it extracts
    the value and records the unit. If the array is a masked array, it separates
    the mask and data for hashing.

    Parameters
    ----------
    data : object
        The data object from which to extract the array for hashing.

    Returns
    -------
    arr, mask_arr, unit_str : array-like, array-like or None, str or None
        The extracted array, optional mask, and optional unit string to be
        used for hashing.
    """
    new_data = data
    if hasattr(data, 'flux'):
        new_data = data.flux
    elif hasattr(data, 'data'):
        new_data = data.data

    unit_str = getattr(data, 'unit', None) or getattr(new_data, 'unit', None)
    unit_str = str(unit_str) if unit_str is not None else None

    data_mask = getattr(data, 'mask', None)
    data_mask = data_mask if data_mask is not None else getattr(new_data, 'mask', None)
    try:
        mask_arr = np.ascontiguousarray(data_mask).astype('uint8') if data_mask is not None else None  # noqa
    except TypeError:
        mask_arr = None

    try:
        arr = np.ascontiguousarray(new_data)
    except ValueError:
        arr = None

    return arr, mask_arr, unit_str

def create_data_hash(input_data):
    """
    Create and return a deterministic hash for the provided data.

    The function supports various input types including numpy arrays, astropy
    Quantities, strings, and specutils Spectrum objects. If the input is `None`
    or of an unsupported type (e.g., a plain number), the function returns `None`.

    Parameters
    ----------
    input_data : array-like, str, `astropy.units.Quantity`, `specutils.Spectrum1D`, or None
        The data to hash. If a list or tuple, it may contain arrays or strings.
        If `astropy.units.Quantity`, the unit is included in the hash. If `None`,
        the function returns `None`.

    Returns
    -------
    str or None
        A hexadecimal string representing the hash of the data, or `None` if
        ``input_data`` is `None` or of an unsupported type (e.g., a plain number).
    """
    # Initialize hasher and include shape/dtype to avoid collisions
    # Use blake2b and shorter digest for speed
    arr, mask_arr, unit_str = _clean_data_for_hash(input_data)

    try:
        valid_arr_check = np.any(arr)
    except TypeError:
        # np.any(arr) may fail on some data types
        valid_arr_check = any([bool(a) for a in arr])
    if not valid_arr_check:
        return None

    hasher = hashlib.blake2b(digest_size=16)
    hasher.update(f'shape:{arr.shape};dtype:{arr.dtype.str}'.encode())
    if unit_str is not None:
        hasher.update(f';unit:{unit_str}'.encode())

    # Hash the main array buffer in chunks via memoryview if possible
    try:
        mv = memoryview(arr).cast('B')
    except TypeError:
        # Fallback - arr.tobytes() will create a copy but should work
        try:
            hasher.update(arr.tobytes())
        except (AttributeError, TypeError, ValueError, MemoryError) as err:
            raise RuntimeError(f'Could not obtain bytes for hashing: {err}')

        # include mask if present
        if mask_arr is not None:
            try:
                hasher.update(b';mask:')
                hasher.update(np.ascontiguousarray(mask_arr).tobytes())
            except (AttributeError, TypeError, ValueError, MemoryError):
                # best effort: ignore mask if it cannot be serialized
                pass

        return hasher.hexdigest()

    chunk = 1024 * 1024
    n = len(mv)
    for i in range(0, n, chunk):
        hasher.update(mv[i:i + chunk])

    # Include mask bytes if present
    if mask_arr is not None:
        try:
            hasher.update(b';mask:')
            mv_mask = memoryview(mask_arr).cast('B')
            nm = len(mv_mask)
            for i in range(0, nm, chunk):
                hasher.update(mv_mask[i:i + chunk])
        except (TypeError, ValueError):
            # ignore mask-related failures; hash already includes data
            pass

    return hasher.hexdigest()

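# Illustrative example (added for clarity, not part of the original module): a minimal
# sketch showing that equal arrays hash identically while a unit change (via an astropy
# Quantity) alters the hash.
def _example_create_data_hash():
    import astropy.units as u
    arr = np.arange(10.0)
    same = create_data_hash(arr) == create_data_hash(np.arange(10.0))        # True
    differs = create_data_hash(arr * u.Jy) != create_data_hash(arr * u.MJy)  # True
    return same, differs
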
# Add new and inverse colormaps to Glue global state. Also see ColormapRegistry in
# https://github.com/glue-viz/glue/blob/main/glue/config.py
new_cms = (['Rainbow', cm.rainbow],
           ['Seismic', cm.seismic],
           ['Reversed: Gray', cm.gray_r],
           ['Reversed: Viridis', cm.viridis_r],
           ['Reversed: Plasma', cm.plasma_r],
           ['Reversed: Inferno', cm.inferno_r],
           ['Reversed: Magma', cm.magma_r],
           ['Reversed: Hot', cm.hot_r],
           ['Reversed: Rainbow', cm.rainbow_r])
for cur_cm in new_cms:
    if cur_cm not in glue_colormaps.members:
        glue_colormaps.add(*cur_cm)


def _register_random_cmap(
    cmap_name,
    bkg_color=[0, 0, 0],
    bkg_alpha=1,
    seed=42,
    ncolors=10_000
):
    """
    Custom random colormap, useful for rendering image segmentation maps.

    The default background for `label==0` is *transparent*. If the
    segmentation map contains more than 10,000 labels, adjust the
    `ncolors` kwarg to ensure uniqueness.
    """
    cmap = make_random_cmap(ncolors=ncolors, seed=seed)
    cmap.colors[0] = bkg_color + [bkg_alpha]
    cmap.name = cmap_name
    glue_colormaps.add(cmap_name, cmap)


_register_random_cmap('Random', bkg_alpha=1)


# give UI access to sampled version of the available colormap choices
def _hex_for_cmap(cmap):
    N = 50
    cm_sampled = cmap.resampled(N)
    return [mpl_colors.to_hex(cm_sampled(i)) for i in range(N)]


cmap_samples = {cmap[1].name: _hex_for_cmap(cmap[1])
                for cmap in glue_colormaps.members}


def _get_celestial_wcs(wcs):
    """
    If ``wcs`` has a celestial component return that, otherwise return None.
    """
    if isinstance(wcs, gwcs) and not type(wcs) is SpectralGWCS:
        data_wcs = WCS(wcs.to_fits_sip())
    elif isinstance(wcs, WCS):
        data_wcs = getattr(wcs, 'celestial', None)
    else:
        return None
    return data_wcs


def closest_point_on_segment(px, py, x1, y1, x2, y2):
    """
    Find the closest point on a line segment to a reference point.

    Parameters
    ----------
    px : float
        X coordinate of the reference point.
    py : float
        Y coordinate of the reference point.
    x1, y1, x2, y2 : array-like
        Coordinates of the line segment endpoints.

    Returns
    -------
    closest_x, closest_y : ndarray
        Coordinates of the closest points on the segments.
    """
    dx = x2 - x1
    dy = y2 - y1
    len_sq = dx**2 + dy**2
    t = np.clip(((px - x1) * dx + (py - y1) * dy) / len_sq, 0, 1)
    closest_x = x1 + t * dx
    closest_y = y1 + t * dy
    return closest_x, closest_y


def find_closest_polygon_mark(px, py, marks):
    """
    Find the closest mark to a click point and return its observation index.

    Parameters
    ----------
    px : float
        X coordinate of the reference point.
    py : float
        Y coordinate of the reference point.
    marks : list of RegionOverlay
        List of mark objects to compare against the given point.

    Returns
    -------
    closest_idx : int or None
        The observation index of the closest mark, or None if no marks.
    """
    min_dist = float('inf')
    closest_idx = None

    for mark in marks:
        x_coords = np.array(mark.x)
        y_coords = np.array(mark.y)
        if len(x_coords) == 0 or len(y_coords) == 0:
            continue

        x1 = x_coords
        x2 = np.roll(x_coords, -1)
        y1 = y_coords
        y2 = np.roll(y_coords, -1)

        closest_xs, closest_ys = closest_point_on_segment(px, py, x1, y1, x2, y2)
        dist = (closest_xs - px)**2 + (closest_ys - py)**2

        min_idx = np.argmin(dist)
        min_dist_for_this_mark = dist[min_idx]

        if min_dist_for_this_mark < min_dist:
            min_dist = min_dist_for_this_mark
            closest_idx = mark.label

    return closest_idx
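

# Illustrative example (added for clarity, not part of the original module): a minimal
# sketch of the point-to-segment projection used above, assuming a unit-square polygon
# whose edges are formed with np.roll as in find_closest_polygon_mark.
def _example_closest_point_on_segment():
    x = np.array([0., 1., 1., 0.])
    y = np.array([0., 0., 1., 1.])
    cx, cy = closest_point_on_segment(0.5, -1.0, x, y, np.roll(x, -1), np.roll(y, -1))
    # the closest point on the bottom edge (from (0, 0) to (1, 0)) is (0.5, 0.0)
    return cx, cy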