Source code for admit.util.Segments

""" .. _Segments-api:

    **Segments** --- Manages groups of line segments.
    -------------------------------------------------

    This module defines the Segments class.
"""
# system imports
import numpy as np
import copy

# ADMIT imports
from admit.util.AdmitLogging import AdmitLogging as logging


[docs]class Segments(object): """ Class to hold segments and convert them between different types. Segments are defined by a beginning and ending channel (both inclusive). ADMIT gives special meaning to a segment, for example a line can be found within that segment, or a continuum has to be fitted in that segment (or group of segments). Parameters ---------- st : array like Array like object containing either the full segment list (array of two element arrays containing the start and end channel numbers for each segment), or an array of starting channel numbers. Default: None. en : array like An array of the ending channel number corresponding to the starting channel numbers given in st. Leave as None if st contains the full listing. Default: None. nchan : int The number of channels in the spectrum that the segments refer to. This is used to construct the bit mask for tracking all segments. If left as None, no bitmask will be made. Default: None. startchan : int The starting channel number of the spectrum that the segments refer to. Must be >= 0. Default: 0. """ def __init__(self, st=None, en=None, nchan=None, startchan=0): # initialize everything self._segments = [] self._nchan = nchan # error chack the starting channel number self._startchan = int(startchan) if self._startchan < 0: raise Exception("Start channel must be 0 or greater.") # if nchan was specified create the bitmask if nchan: self._chans = np.array([0] * nchan) # determine the maxchan self._maxchan = nchan - 1 + self._startchan else: self._chans = None self._maxchan = 0 if st is None: return if type(st) != type(en) and en is not None: raise Exception("Channel start and end points must be the same type.") # build the list of segments # if en is not given peak = 0 if en is None: # st must be array like if not hasattr(st, '__iter__'): raise Exception("A list must be given for parameter st.") for seg in st: # each one must have length of 2 if len(seg) != 2: raise Exception("Each segment must have a size of 2.") # make sure they are in the right order tempseg = [int(seg[0]), int(seg[1])] tempseg.sort() peak = max(peak, tempseg[1]) self._segments.append(tempseg) else: # if both en and st are given and are ints if not hasattr(st, '__iter__'): tempseg = [int(st), int(en)] tempseg.sort() peak = max(peak, tempseg[1]) self._segments.append(tempseg) else: # if both en ans st are given and both array like # create iterators stit = iter(st) enit = iter(en) if len(st) != len(en): logging.warning("Starting and ending channel ranges do not have the same length, truncating the longer.") # iterate through the lists while True: try: tempseg = [int(stit.next()), int(enit.next())] tempseg.sort() peak = max(peak, tempseg[1]) self._segments.append(tempseg) except StopIteration: break if self._chans is None: self._chans = np.array([0] * (peak + 1)) # determine the maxchan self._maxchan = peak + self._startchan # build the bit mask for seg in self._segments: if seg[1] > self._maxchan or seg[0] < self._startchan: raise Exception("All or part of a segment is beyond the given spectrum. Segment: %s, bounds: [%i, %i]" % (seg, self._startchan, self._maxchan)) self._chans[seg[0] - self._startchan: seg[1] - self._startchan + 1] = 1 def __len__(self): """ Returns the number of segments Parameters ---------- None Returns ------- int containing the number of segments in the class """ return len(self._segments) def __iter__(self): """ Rurns an iterator to the list of segments Parameters ---------- None Returns ------- Iterator to the list of segments """ return iter(self._segments) def __getitem__(self, index): """ Returns the segment at the given index Parameters ---------- index : int The index of the segment to return Returns ------- Two element list (segment) of the starting and ending channel numbers """ if index >= len(self._segments): raise Exception("Index %i is beyond the range of indices (%i)" % (index, len(self._segments))) return self._segments[index] def __add__(self, other): """ Method to add two Segments classes together, without merging he semgents. The bitmask is recalculated. Parameters ---------- other : Segments class instance or array like If a Segments instance is given then the internal segments are added to the current segment list. If an array is given then the items of the array are added to the current segment list. Returns ------- The instance of this class with the new segments incorporated """ new = copy.deepcopy(self) if type(other) == type(new): if new._startchan != other._startchan: raise Exception("Starting channels do not match.") if len(new._chans) != len(other._chans): raise Exception("Number of channels do not match.") for seg in other: new._segments.append(seg) new.recalcmask() elif hasattr(other, "__iter__"): for seg in other: tempseg = [int(seg[0]), int(seg[1])] if tempseg[1] > new._maxchan or tempseg[0] < new._startchan: raise Exception("All or part of a segment is beyond the given spectrum. Segment: %s, bounds: [%i, %i]" % (seg, self._startchan, self._maxchan)) new._segments.append(tempseg) new.recalcmask() return new def __setitem__(self, index, item): """ Method to set the segment at index to a new value Parameters ---------- index : int The location in the segment array to replace item : two element array The new segment to replace the indicated one with Returns ------- None """ if not hasattr(item, "__iter__"): raise Exception("Segments must be ginven as an iteratable object (list, np.array, etc.") if len(item) != 2: raise Exception("Segments must have length 2.") tempseg = [int(item[0]), int(item[1])] tempseg.sort() if tempseg[1] > self._maxchan or tempseg[0] < self._startchan: raise Exception("All or part of a segment is beyond the given spectrum. Segment: %s, bounds: [%i, %i]" % (tempseg, self._startchan, self._maxchan)) self._segments[index] = tempseg self._chans[tempseg[0] - self._startchan: tempseg[1] - self._startchan + 1] = 1 def __contains__(self, chan): """ Method to determine if a given channel is in a segment. This requires the bit mask to be available Parameters ---------- chan : int The channel number to test Returns ------- bool, True if the channel is in a segment, False otherwise """ if self._chans is None: raise Exception("No bitmask has been built, call setnchan to build it.") return bool(self._chans[chan - self._startchan])
[docs] def append(self, item): """ Method to append a new segment to the current list Parameters ---------- item : two element array The new segment to append to the list Returns ------- None """ if not hasattr(item, '__iter__'): raise Exception("Segments must be ginven as an iteratable object (list, np.array, etc.") else: if len(item) != 2: raise Exception("Segments must have length 2.") tempseg = [int(item[0]), int(item[1])] tempseg.sort() if tempseg[0] < self._startchan or tempseg[1] > self._maxchan: raise Exception("All or part of a segment is beyond the given spectrum. Segment: %s, bounds: [%i, %i]" % (tempseg, self._startchan, self._maxchan)) self._segments.append(tempseg) self._chans[tempseg[0] - self._startchan: tempseg[1] - self._startchan + 1] = 1
[docs] def getmask(self): """ Method to return the current bitmask Parameters ---------- None Returns ------- Array like object containing the bit current bit mask, 1 = in segment, 0 = not in segment """ return self._chans
[docs] def getchannels(self, invert=False): """ Method to return the current list of channel numbers in the bitmask Parameters ---------- invert : boolean Return the list of channels outside the bitmask instead Returns ------- Array like object containing the (zero based) channel numbers that are in the segments """ if invert: return (np.where(self._chans == 0)[0] + self._startchan).tolist() else: return (np.where(self._chans == 1)[0] + self._startchan).tolist()
[docs] def remove(self, index): """ Method to remove a segment from the segment list Parameters ---------- index : int The location of the segment to remove Returns ------- None """ del self._segments[index] self.recalcmask()
[docs] def pop(self): """ Method to pop, or remove and return, the last segment in the list Parameters ---------- None Returns ------- Array like 2 element list of the segment starting and ending channels of the last segment in the list """ seg = self._segments.pop() self.recalcmask() return seg
[docs] def limits(self): """ Method to return the channel range of the internal channel bit mask Parameters ---------- None Returns ------- Array like 2 element list of the segment starting and ending channels """ return [self._startchan, self._maxchan]
[docs] def recalcmask(self, test=False): """ Method to recalculate the bit mask based on the current segment list. 1 = in segment 0 = not in segment Parameters ---------- test : bool If True then test each segment to be sure it is in the current allowed channel range. If False the do not test. Returns ------- None """ self._chans = np.array([0] * (self._maxchan + 1 - self._startchan)) for seg in self._segments: if test and (seg[0] < self._startchan or seg[1] >= self._maxchan): raise Exception("All or part of a segment is beyond the given spectrum. Segment: %s, bounds: [%i, %i]" % (seg, self._startchan, self._maxchan)) self._chans[seg[0] - self._startchan: seg[1] - self._startchan + 1] = 1
[docs] def setnchan(self, nchan): """ Method to set the number of channels in the internal channel bit mask Parameters ---------- nchan : int The number of channels in the bit mask Returns ------- None """ if nchan - 1 == self._nchan: return self._nchan = nchan self._maxchan = self._startchan + nchan - 1 self.recalcmask(True)
[docs] def getnchan(self): """ Method to return the number of channels in the current bit mask Parameters ---------- None Returns ------- int giving the number of channels """ return self._maxchan - self._startchan + 1
[docs] def setstartchan(self, chan): """ Method to set the starting channels number for the internal bit mask Parameters ---------- chan : int The starting channel number Returns ------- None """ if chan == self._startchan: return self._startchan = chan self._maxchan = self._startchan + self._nchan - 1 self.recalcmask(True)
[docs] def getstartchan(self): """ Method to get the starting channel number of the current bit mask Parameters ---------- None Returns ------- int containing the starting channel number """ return self._startchan
[docs] def chans(self, invert=False): """ Method to convert the bit mask into a string of channel ranges in CASA format. e.g. [3,10],[25,50] => "3~10;25~50" Parameters ---------- None Returns ------- string containing the formatted channel ranges """ output = "" if invert: basechan = np.append(1-self._chans, 0) shiftchan = np.insert(1-self._chans, 0, 0) else: basechan = np.append(self._chans, 0) shiftchan = np.insert(self._chans, 0, 0) diff = basechan - shiftchan st = np.where(diff == 1)[0] en = np.where(diff == -1)[0] first = True for seg in zip(st, en): if not first: output += ";" else: first = False output += str(seg[0] + self._startchan) + "~" + str(seg[1] - 1 + self._startchan) return output
[docs] def merge(self, other=None): """ Method to merge overlapping segments into one segment. This operates on the current instance or takes a second instance or list of segments as input. Parameters ---------- other : Segments instance or list of segments If given then these segments are added to the current list and then merged, to give a single list of non-overlapping segments. Returns ------- None """ if other is None: basechan = np.append(self._chans, 0) shiftchan = np.insert(self._chans, 0, 0) diff = basechan - shiftchan st = np.where(diff == 1)[0] en = np.where(diff == -1)[0] tempchans = [] for seg in zip(st, en): tempchans.append([seg[0], seg[1]]) self._segments = tempchans self.recalcmask() elif type(other) == type(self): if self._startchan != other._startchan: raise "Starting channel numbers do not match." if len(self._chans) != len(other._chans): raise "Number of channels do not match." for seg in other: self._segments.append(seg) self.recalcmask() self.merge() elif hasattr(other, "__iter__"): for seg in other: tempseg = [int(seg[0]), int(seg[1])] if tempseg[1] > self._maxchan or tempseg[0] < self._startchan: raise Exception("All or part of a segment is beyond the given spectrum. Segment: %s, bounds: [%i, %i]" % (seg, self._startchan, self._maxchan)) self._segments.append(tempseg) self.recalcmask() self.merge() else: raise Exception("Improper data type given as input. It must be an iteratable (list, np.array, etc.) or Segments object.")
[docs] def getsegments(self): """ Method to get the list of segments Parameters ---------- None Returns ------- list of the segment end points [start, end] """ return self._segments
[docs] def getsegmentsaslists(self): """ Method to get the list of segments Parameters ---------- None Returns ------- list of the segment end points [start, end] """ return self._segments
[docs] def getsegmentsastuples(self): """ Method to get the list of segments as tuples Parameters ---------- None Returns ------- list of the segment end points as tuples (start, end) """ out = [] for seg in self._segments: out.append(tuple(seg)) return out