from datetime import datetime
import numpy as np
import pandas as pd
from trap.log import logger
[docs]
class FrequencyBands:
    """Registry of the frequency bands that images are grouped into.

    Bands are stored in ``self.bands``, a DataFrame indexed by ``band_id`` with
    ``freq_low`` and ``freq_high`` columns. Band ids are assigned sequentially
    from 0 by :meth:`get_frequency_band`, in the order bands are first seen.
    """

    def __init__(self) -> None:
        # Explicit dtypes keep the empty frame consistent with the rows added
        # later. Without them the empty columns get an inferred dtype (object
        # in recent pandas), and concatenating float rows onto that triggers a
        # pandas FutureWarning about empty entries in ``concat``.
        self.bands = pd.DataFrame(
            {
                "band_id": pd.Series(dtype="int64"),
                "freq_low": pd.Series(dtype="float64"),
                "freq_high": pd.Series(dtype="float64"),
            }
        ).set_index("band_id")

    @staticmethod
    def parse_frequency(header, im_path):
        """Obtain frequency parameters from a FITS file header.

        The keywords are tried in this order:

        1. ``RESTFRQ``, with ``RESTBW`` as bandwidth (1 MHz if ``RESTBW`` is absent);
        2. ``CRVAL3``/``CDELT3`` when ``CTYPE3`` is ``FREQ`` or ``VOPT``;
        3. ``CRVAL4``/``CDELT4`` when ``CTYPE4`` is ``FREQ`` or ``VOPT``;
        4. ``RESTFREQ``, with a bandwidth of 1 MHz.

        Parameters
        ----------
        header: `astropy.io.fits.header.Header`
            The FITS header from which to read the frequency related parameters.
            Any mapping that supports ``in`` and raises ``KeyError`` for
            missing keys works.
        im_path: `str`
            The path to the image, only used for log messages.

        Returns
        -------
        freq_eff : float
            The effective frequency extracted from the FITS header.
        freq_bw : float
            The bandwidth extracted from the FITS header. ``CDELT3``/``CDELT4``
            are returned as-is, so this may be negative.

        Raises
        ------
        TypeError
            If a keyword required by the selected branch is missing, e.g. when
            the header carries no frequency information at all.
        """
        # Adapted after PySE: https://github.com/transientskp/pyse/blob/c27418e3c1005e9eb77130146c46e7ee6961f6b3/sourcefinder/accessors/fitsimage.py#L192
        try:
            if "RESTFRQ" in header:
                freq_eff = header["RESTFRQ"]
                if "RESTBW" in header:
                    freq_bw = header["RESTBW"]
                else:
                    logger.warning(
                        "Bandwidth header missing in image {}, setting to 1 MHz".format(
                            im_path
                        )
                    )
                    freq_bw = 1e6
            elif ("CTYPE3" in header) and (header["CTYPE3"] in ("FREQ", "VOPT")):
                freq_eff = header["CRVAL3"]
                freq_bw = header["CDELT3"]
            elif ("CTYPE4" in header) and (header["CTYPE4"] in ("FREQ", "VOPT")):
                freq_eff = header["CRVAL4"]
                freq_bw = header["CDELT4"]
            else:
                # Last resort; raises KeyError (-> TypeError) when absent too.
                freq_eff = header["RESTFREQ"]
                freq_bw = 1e6
        except KeyError as err:
            msg = "Frequency not specified in headers for {}".format(im_path)
            logger.error(msg)
            raise TypeError(msg) from err
        return freq_eff, freq_bw

    def get_frequency_band(self, im_freq_eff, im_freq_bw):
        """Get the ID of the frequency band matching the given frequency range.

        The image covers ``im_freq_eff +/- |im_freq_bw| / 2``. It is assigned to
        the first stored band it overlaps with. Bands that only touch at an edge
        do not count as overlapping. If no stored band overlaps, a new band
        covering exactly that range is added and its ID is returned.

        Parameters
        ----------
        im_freq_eff : float
            The effective (central) frequency.
        im_freq_bw : float
            The frequency bandwidth. Its absolute value is used, so a negative
            bandwidth (e.g. from a descending ``CDELT`` axis) is handled.

        Returns
        -------
        `int`
            The ID of the band, matching the index of the ``self.bands`` dataframe.
        """
        # abs(): a negative bandwidth would give low > high and a negative
        # width, so the overlap test below could never match and every such
        # image would create a new band.
        bw_half = abs(im_freq_bw) / 2
        low = im_freq_eff - bw_half
        high = im_freq_eff + bw_half

        # Band selection as was implemented in original tkp:
        # https://github.com/transientskp/tkp/blob/b34582712b82b888a5a7b51b3ee371e682b8c349/tkp/db/alchemy/image.py#L10
        # Two intervals overlap iff the width of their union is strictly
        # smaller than the sum of their individual widths.
        width_new = high - low
        width_bands = self.bands["freq_high"] - self.bands["freq_low"]
        union_width = np.maximum(high, self.bands["freq_high"]) - np.minimum(
            low, self.bands["freq_low"]
        )
        mask = (union_width < (width_new + width_bands)).to_numpy()

        # In case of multiple overlaps, take the first like was done in the original tkp
        # Alternatively we could pick the band with most overlap
        matches = self.bands.index[mask]
        if len(matches) > 0:
            return int(matches[0])

        # No overlapping band: register a new one under the next sequential id.
        band_id = len(self.bands.index)
        new_band = pd.DataFrame(
            {"freq_low": [low], "freq_high": [high]},
            index=pd.Index([band_id], name="band_id"),
        )
        self.bands = pd.concat([self.bands, new_band])
        return band_id

    def plot_freq_bands(self, effective_frequencies, bandwidths):
        """Plot the stored frequency bands together with the given image frequencies.

        Every image is first assigned to a band via :meth:`get_frequency_band`,
        which registers new bands where needed. Each image is then drawn as a
        marker at its effective frequency, with an error bar of half its
        bandwidth, coloured by band. Each stored band is drawn as two horizontal
        lines at its lower and upper edge, in the same colour.

        The plot is shown interactively. If showing it raises an exception, the
        figure is saved as ``frequency_bands_<timestamp>.png`` in the current
        working directory instead.

        Parameters
        ----------
        effective_frequencies: List
            A list or array of effective frequencies, one per image.
        bandwidths: List
            A list or array of bandwidths, in the same order as
            ``effective_frequencies``.

        Raises
        ------
        ModuleNotFoundError
            If matplotlib (an optional dependency) is not installed.
        """
        try:
            import matplotlib.pyplot as plt
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(
                "Unable to import all required visualization dependencies. Please install TraP with the optional visualization dependencies: `pip install trap[view]`"
            ) from err

        # Assign all images to bands before picking colours. This call can
        # create new bands, so any colour mapping sized beforehand would miss
        # them.
        band_ids = [
            self.get_frequency_band(freq_eff, freq_bw)
            for freq_eff, freq_bw in zip(effective_frequencies, bandwidths)
        ]

        # Distinct colour for each band, wrapping around after tab10's 10 colours.
        # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed.
        cmap = plt.get_cmap("tab10")

        def band_color(band_id):
            return cmap(int(band_id) % cmap.N)

        fig, ax = plt.subplots()
        im_ids = range(len(effective_frequencies))

        # Plot markers with individual colors, include error bar for bandwidth.
        # abs(): matplotlib rejects negative error-bar sizes.
        for im_id, (freq_eff, freq_bw, band_id) in enumerate(
            zip(effective_frequencies, bandwidths, band_ids)
        ):
            ax.errorbar(
                im_id,
                freq_eff,
                yerr=abs(freq_bw) / 2,
                fmt="o",
                markersize=5,
                capsize=5,
                linestyle="",
                color=band_color(band_id),
            )
        ax.set_ylabel("Effective frequency")
        ax.set_xticks(im_ids)
        ax.set_xlabel("Image id")

        # Plot frequency bands as two horizontal lines
        for band_id, band in self.bands.iterrows():
            color = band_color(band_id)
            ax.axhline(y=band.freq_low, color=color, linestyle="solid", linewidth=1)
            ax.axhline(y=band.freq_high, color=color, linestyle="solid", linewidth=1)

        # Plot vertical line for each image
        for im_id in im_ids:
            ax.axvline(x=im_id, color="black", linestyle="dotted", linewidth=0.5)

        try:
            plt.show()
        except Exception:
            # Displaying failed: fall back to writing this figure to disk.
            # The timestamp avoids ':' because it is invalid in Windows file names.
            time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            filename = f"frequency_bands_{time}.png"
            fig.savefig(filename, bbox_inches="tight")
            logger.warning(
                f"Unable to display interactive plot, saved image to '{filename}'"
            )