Source code for trap.cli

import argparse
import os
import pdb
import sys

import psutil

# Records which TOML parser could be imported, so that the config loader
# further down can call the matching API.
tomllib_present = False
try:
    import tomllib  # Built-in from Python 3.11

    tomllib_present = True
except ImportError:
    # tomllib could not be imported; rely on the `toml` package instead.
    import toml


def parse_bool(value):
    """Interpret a textual flag as a boolean.

    Parameters
    ----------
    value: :class:`str` or None
        The text to interpret. Surrounding whitespace and letter case are ignored.

    Returns
    -------
    :class:`bool` or None
        True for "1", "true", "yes", "on", "t"; False for "0", "false", "no",
        "off", "f". None is passed through unchanged.

    Raises
    ------
    TypeError
        If ``value`` is neither None nor a string.
    ValueError
        If the string is not one of the recognised spellings.
    """
    if value is None:
        return None
    if not isinstance(value, str):
        raise TypeError(f"Expected string, got {type(value).__name__}: {value}")

    # One table for both polarities: spelling -> resulting boolean.
    spellings = {
        **dict.fromkeys(("1", "true", "yes", "on", "t"), True),
        **dict.fromkeys(("0", "false", "no", "off", "f"), False),
    }

    normalized = value.strip().lower()
    if normalized not in spellings:
        raise ValueError(f"Cannot convert string to boolean: '{normalized}'")
    return spellings[normalized]
def filter_none_values_from_dict(dict_to_filter: dict):
    """Return a new dict holding only the entries whose value is not None.

    Other falsy values (0, False, empty strings) are kept; only None is dropped.
    The input dictionary is left untouched and the key order is preserved.
    """
    kept = {}
    for key, value in dict_to_filter.items():
        if value is None:
            continue
        kept[key] = value
    return kept
def update_params_from_config_file(config_file, params):
    """Fill in parameters missing from the command line using a config file.

    Tables in the TOML file are flattened: only the leaf keys are used, whatever
    table they are nested in. A leaf is copied into ``params`` only when
    ``params`` has no value (missing or None) for that key, so values given on
    the command line take precedence.

    Parameters
    ----------
    config_file: :class:`str`
        The path to the configuration file. Must be a .toml format.
        When None, nothing is read and ``params`` is left untouched.
    params: :class:`dict`
        The parameters as parsed by argparse from the command line.
        This dictionary is updated in-place.

    Returns
    -------
    None
    """
    if config_file is None:
        return

    if tomllib_present:
        # >= python3.11: tomllib requires a binary file handle
        with open(config_file, "rb") as handle:
            config_data = tomllib.load(handle)
    else:
        # < python3.11: the toml package accepts the path directly
        config_data = toml.load(config_file)

    def iter_leaves(mapping: dict):
        # Depth-first walk in file order, yielding every non-table entry.
        for name, entry in mapping.items():
            if isinstance(entry, dict):
                yield from iter_leaves(entry)
            else:
                yield name, entry

    for name, entry in iter_leaves(config_data):
        if params.get(name) is None:
            params[name] = entry
def construct_argument_parser(require_command=False, require_processing_args=True):
    """Build the command line parser for TraP.

    Nearly all options are registered without a default (so they parse to None
    when absent), which lets a configuration file fill in whatever was not
    supplied on the command line. Multi-word options are available in both
    ``--snake_case`` and ``--kebab-case`` spelling.

    Parameters
    ----------
    require_command: :class:`bool`
        If True, add a positional ``command`` argument ('run' or 'view').
    require_processing_args: :class:`bool`
        If True, register the image, extraction and association options.
        If False, those are left out and ``--view_command`` is registered instead.

    Returns
    -------
    :class:`argparse.ArgumentParser`
        The configured parser.
    """
    parser = argparse.ArgumentParser(
        description="Transients Pipeline. Extract sources in radio astronomy data to find transients. Source: https://git.astron.nl/RD/trap"
    )

    if require_command:
        parser.add_argument(
            "command",
            choices=["run", "view"],
            help="""Command to execute: 'run' or 'view'.
            'run' will process the images and store the results in a database.
            'view' will visualize the sources in the database created by the 'run' command.
            For 'view' only the variables that create a database connection are used.
            """,
        )

    # General arguments
    general_group = parser.add_argument_group("General")

    if not require_processing_args:
        parser.add_argument(
            "--view_command",
            "--view-command",
            "-v",
            choices=["all_sources", "interactive", "lightcurves", "frequencies"],
            default="all_sources",
            help="""View command to execute: 'all_sources', 'interactive', 'lightcurves' or 'frequencies'.
            'all_sources' will show all locations of each source.
            'interactive' will run an interactive viewer where you can navigate through the images
            and see what sources were found, also showing new and missed sources.
            'lightcurves' will plot all lightcurves as intensity over time.
            'frequencies' will show what image is in which frequency band.
            """,
        )

    general_group.add_argument(
        "--version",
        action="store_true",
        help="""
        Display the version of the currently installed TraP.
        """,
    )
    general_group.add_argument(
        "--pyse_version",
        "--pyse-version",
        action="store_true",
        help="""
        Display the version of the currently installed radio-pyse.
        """,
    )
    general_group.add_argument(
        "--config_file",
        "--config-file",
        help="""
        TOML file containing default input arguments to TraP.
        Default file name: trap_config.toml
        This is especially convenient when swapping between configurations for the same project.
        """,
    )
    general_group.add_argument(
        "--log_dir",
        "--log-dir",
        help="""
        The directory in which to write the log and the error log file.
        The same information is also printed to the terminal standard output.
        """,
    )
    general_group.add_argument(
        "--nr_threads",
        "--nr-threads",
        "-n",
        type=int,
        help="""Number of threads to spawn. With multiple threads, images can be read and processed in parallel.
        If None, use as many threads as there are cores.
        Using multiple threads speeds up computation significantly.
        Note that there is a point where the association becomes the bottleneck,
        since it is fundamentally a sequential operation.
        That means that adding more processes to pre-load the images has diminishing returns.
        Warning: there is a known bug where the RAM of your device might blow up if you use many processes
        but either have little RAM or very many images to process.
        It can happen that all of the images get loaded in before the program is near completion.
        To work around this, use a low number of processes.
        This bug will be addressed in the future.
        When the distributed scheduler is used, the nr_threads will be divided over the processes and threads.
        For example, if nr_threads=12, we might get 3 processes with 4 workers each.
        If a prime number is chosen this might result in a skewed distribution.
        When nr_threads is set to 13, we might for example get one process with 13 threads,
        which will run fine but is often non-optimal in terms of performance.
        Also see argument: scheduler
        """,
    )
    general_group.add_argument(
        "--max_concurrent_images",
        "--max-concurrent-images",
        type=int,
        help="""Make sure we don't read more than N images at a time.
        A conservative limit protects against RAM overflows.
        """,
    )
    general_group.add_argument(
        "--scheduler",
        type=str,
        help="""The Dask scheduler to use. Options are: ['threads', 'distributed']. Default: 'threads'.
        The 'threads' scheduler uses multithreading to process in parallel.
        This has the least overhead but is limited by Python's GIL.
        The 'distributed' scheduler uses a balance of both processes and threads
        which allows for more versatile parallel computing but can carry more overhead,
        especially when data between processes needs to be communicated.
        The distributed scheduler also provides a real-time diagnostics dashboard
        and allows for running across multiple nodes.
        In general I recommend using the threads scheduler when running on a smaller machine like a laptop
        and use the distributed scheduler when running on a compute node with lots of RAM and many CPU cores.
        See also the argument: nr_threads.
        """,
    )
    general_group.add_argument(
        "--pdb",
        action="store_true",
        help="""
        Enter debug mode when the application crashes.
        Meant to be used for more comprehensive debugging.
        This argument is not exported to the database.
        """,
    )

    # Database parameters
    db_group = parser.add_argument_group("Database parameters")
    db_group.add_argument(
        "--db_backend",
        "--db-backend",
        help="""
        The database solution to use. Options are: ['sqlite', 'postgres']. Default: 'sqlite'.
        If set to 'postgres', the following parameters also need to be provided:
        'db_user', 'db_password', 'db_host', 'db_port'.
        """,
    )
    db_group.add_argument(
        "--db_name",
        "--db-name",
        help="""
        When 'db_backend' is sqlite, 'db_name' represents the path to the file.
        When 'db_backend' is 'postgres', 'db_name' represents the name of the database.
        """,
    )
    db_group.add_argument(
        "--db_host",
        "--db-host",
        help="The name of the host where the database is located (used for 'postgres').",
    )
    db_group.add_argument(
        "--db_port",
        "--db-port",
        help="The port number to go along with 'db_host' (used for 'postgres').",
    )
    db_group.add_argument(
        "--db_user",
        "--db-user",
        help="The username used for accessing the database (used for 'postgres').",
    )
    db_group.add_argument(
        "--db_password",
        "--db-password",
        help="""
        The password used for accessing the database (used for 'postgres').
        This argument is not exported to the database.
        """,
    )
    db_group.add_argument(
        "--db_overwrite",
        "--db-overwrite",
        action="store_true",
        # None (rather than False) marks "not supplied", so a config file value can still apply.
        default=None,
        help="""
        If supplied, clears the database and starts fresh.
        Only use this if you are OK losing the existing database with the supplied 'db_name'.
        When not supplied, the program will error if the database already exists,
        preventing deletion of existing data.
        """,
    )

    if require_processing_args:
        # Image parameters
        image_group = parser.add_argument_group("Input image parameters")
        image_group.add_argument(
            "--input_images",
            "--input-images",
            "-i",
            action="append",
            help="""
            The input images in which to find the sources. Only .fits images are supported.
            This can refer to either a file, directory or glob pattern (e.g. 'images/my_image_*.fits').
            When using a glob pattern, remember to wrap the line in quotes or the terminal might get confused.
            If a directory or glob pattern is used, all fits images found there will be used.
            If a nested directory is supplied, the subdirectories will also be searched for fits files.
            These arguments can be supplied multiple times to refer to multiple files or locations.
            """,
        )
        image_group.add_argument(
            "--rms_min",
            "--rms-min",
            type=float,
            help="Lower bound for the RMS quality check. If an image has a lower RMS value than 'rms_min', it is 'rejected' and not processed.",
        )
        image_group.add_argument(
            "--rms_max",
            "--rms-max",
            type=float,
            help="Upper bound for the RMS quality check. If an image has a larger RMS value than 'rms_max', it is 'rejected' and not processed.",
        )
        image_group.add_argument(
            "--reduction_factor_for_rms",
            # Hyphenated alias added for consistency with every other multi-word option.
            "--reduction-factor-for-rms",
            type=float,
            help="""
            Only the region around the center of the image is used to determine the RMS
            of the image on which the rejection is based.
            The 'reduction_factor_for_rms' determines the size of this region around the center,
            where 'reduction_factor_for_rms' is the fraction of each axis to use.
            To illustrate: if an image is 100x100 pixels, a `reduction_factor_for_rms=2` results in
            a 50x50 slice around the center that is used to calculate the RMS of the image.
            """,
        )
        image_group.add_argument(
            "--frequency_band",
            "--frequency-band",
            "-f",
            type=str,
            action="append",
            # The example uses the "effective_frequency" key, which is the key
            # the run command looks up when registering pre-defined bands.
            help="""Pre-defined frequency band. To be supplied as a Json dictionary like so:
            '{"effective_frequency": 140000000, "bandwidth": 4687500.0}'.
            This can also be a file path to a json file with a list of such dictionaries.
            A mix can be supplied by supplying it multiple times, e.g.:
            trap-run -f '{"effective_frequency": 140000000, "bandwidth": 4687500.0}' -f frequency_bands.json
            """,
        )

        # Extraction parameters
        extraction_group = parser.add_argument_group("Extraction parameters")
        extraction_group.add_argument(
            "--ew_sys_err",
            "--ew-sys-err",
            type=float,
            help="Systematic error in arcseconds along the east-west axis.",
        )
        extraction_group.add_argument(
            "--monitor_loc",
            "--monitor-loc",
            "-m",
            type=str,
            action="append",
            help="""A list of locations to always monitor.
            The coordinates are to be supplied in decimal degrees in the form [Ra, Dec]
            (e.g '[12.3, 45.6]' or '[[123.4,56.7],[359.9,89.9]]')
            It can also be the path to a json file with similar structure.
            A mix of these can be supplied by supplying it multiple times, e.g.:
            trap-run -m '[12.3, 45.6]' -m monitor_locations.json
            """,
        )
        extraction_group.add_argument(
            "--ns_sys_err",
            "--ns-sys-err",
            type=float,
            help="Systematic error in arcseconds along the north-south axis.",
        )
        extraction_group.add_argument(
            "--detection_threshold",
            "--detection-threshold",
            type=float,
            help="The detection threshold, as a multiple of the RMS noise.",
        )
        extraction_group.add_argument(
            "--analysis_threshold",
            "--analysis-threshold",
            type=float,
            help="Analysis threshold, as a multiple of the RMS noise.",
        )
        extraction_group.add_argument(
            "--deblend_nthresh",
            "--deblend-nthresh",
            type=int,
            help="Number of subthresholds to use for deblending. Set to 0 to disable.",
        )
        extraction_group.add_argument(
            "--max_nr_consecutive_force_fits",
            "--max-nr-consecutive-force-fits",
            type=int,
            help="""Stop force fitting if the source has not naturally been found after a specified number of images.
            If the source has been found naturally again this is reset and we will again force fit
            for the specified number of images.
            If the source is naturally detected at a regular interval that is smaller than
            max_nr_consecutive_force_fits, the lightcurve will be continuous.
            If there are periods where the source is not naturally found within the specified number of images,
            there will be gaps in the time axis of the lightcurve.
            Since the number of images is also related to the number of frequency bands,
            we multiply the max_nr_consecutive_force_fits by the number of frequency bands
            in order to make sure we work with the number epochs, not images.
            This assumes each epoch has the same number of frequency bands.
            If some epochs miss frequency bands, we might end up sampling a few more epochs than intended.
            """,
        )
        extraction_group.add_argument(
            "--force_beam",
            "--force-beam",
            action="store_true",
            # None (rather than False) marks "not supplied", so a config file value can still apply.
            default=None,
            help="Force all extractions to have major/minor axes equal to the restoring beam.",
        )
        extraction_group.add_argument(
            "--im_margin",
            "--im-margin",
            type=int,
            help="The number of pixels from the edge of the image within which sources are ignored.",
        )
        extraction_group.add_argument(
            "--im_radius",
            "--im-radius",
            type=int,
            help="The radius in pixels around the center of the image, outside of which sources are ignored.",
        )
        extraction_group.add_argument(
            "--im_back_size_x",
            "--im-back-size-x",
            type=int,
            help="Width of the background boxes as used in SEP.",
        )
        extraction_group.add_argument(
            "--im_back_size_y",
            "--im-back-size-y",
            type=int,
            help="Height of the background boxes as used in SEP.",
        )

        # Association parameters
        association_group = parser.add_argument_group("Association parameters")
        association_group.add_argument(
            "--de_ruiter_r_max",
            "--de-ruiter-r-max",
            type=float,
            help="If the de Ruiter radius is larger than this value, sources are considered different.",
        )

    return parser
def parse_arguments(require_command=False, require_processing_args=True):
    """Parse the command line and merge in values from the config file.

    Both flags are forwarded unchanged to ``construct_argument_parser``.
    Any parameter that was not given on the command line is filled in from the
    file passed via ``--config_file``, if one was supplied.

    Returns
    -------
    :class:`dict`
        The parsed parameters, keyed by argument name.
    """
    namespace = construct_argument_parser(
        require_command=require_command,
        require_processing_args=require_processing_args,
    ).parse_args()

    # vars() exposes the namespace's own attribute dict, which is then updated in-place.
    params = vars(namespace)
    update_params_from_config_file(params["config_file"], params)
    return params
def run_batch(params=None):
    """Run the TraP pipeline on a set of FITS images.

    The steps are: optionally print version information and exit, collect the
    input images, register frequency bands, configure the Dask scheduler,
    initialise the database, store the run configuration and the frequency
    bands in it, and finally hand over to ``trap.run.main``.

    Parameters
    ----------
    params: :class:`dict`, optional
        The parameters as returned by ``parse_arguments``. When None, they are
        parsed from the command line. Note that ``max_concurrent_images`` is
        removed from this dictionary and ``input_images`` is rewritten to
        absolute paths in-place.

    Returns
    -------
    The return value of ``trap.run.main``.

    Raises
    ------
    SystemExit
        After printing the version information when ``version`` or
        ``pyse_version`` was requested.
    Exception
        If no input images were specified, or none were found.
    KeyError
        If a pre-defined frequency band lacks a frequency or bandwidth entry.
    NotImplementedError
        If a processes-only scheduler is requested.
    ValueError
        If the scheduler is not recognised, ``db_name`` is missing, or a
        monitor location has more than two dimensions.
    """
    if params is None:
        params = parse_arguments()

    # Handle the case where version information is requested.
    # The program is meant to exit after printing the information,
    # similar to the --help argument.
    version_info = []
    if params["version"]:
        import trap

        version_info.append(f"TraP version: v{trap.__version__}")
    if params["pyse_version"]:
        import sourcefinder

        version_info.append(f"Radio-PySE version: v{sourcefinder.__version__}")

    from trap.log import logger

    if version_info:
        logger.info("\n".join(version_info))
        sys.exit(0)

    if params["pdb"]:
        # Automatically start the debugger on an unhandled exception
        def excepthook(exc_type, exc_value, exc_traceback):
            pdb.post_mortem(exc_traceback)

        sys.excepthook = excepthook

    # Prepare trap for running
    import json
    from pathlib import Path

    # Start coverage if env. var. COVERAGE_PROCESS_START is set
    # This accounts for subprocesses running trap in the test suite
    import coverage
    import dask
    import numpy as np
    import pandas as pd
    from dask.distributed import Client
    from dask.distributed.deploy.utils import nprocesses_nthreads

    from trap import run
    from trap.frequency_bands import FrequencyBands
    from trap.io import (
        find_fits,
        headers_by_time,
        init_db,
        max_image_size,
        read_fits_headers,
    )
    from trap.log import add_log_file_handler

    coverage.process_startup()

    logger = add_log_file_handler(params["log_dir"] or "./logs")

    if params["input_images"] is None:
        raise Exception(
            "No images were specified. Use '--input_images' or '-i' to specify the location of input images."
        )

    # Parse pre-defined frequency bands
    bands = FrequencyBands()

    def parse_frequency_bands(path_or_array_as_string):
        # If file was supplied, read the bands from there
        if Path(path_or_array_as_string).is_file():
            with open(path_or_array_as_string, "r") as file:
                try:
                    freq_bands = json.load(file)
                except json.JSONDecodeError:
                    logger.error("Could not parse file: " + path_or_array_as_string)
                    raise
            # A file holding one bare dictionary is treated as a list of one band.
            # Without this the caller's extend() would add the dictionary's keys.
            if isinstance(freq_bands, dict):
                freq_bands = [freq_bands]
        else:
            # not a path, assume json style dict
            try:
                # Turn into list with single item to have consistent return type
                freq_bands = [json.loads(path_or_array_as_string)]
            except json.JSONDecodeError:
                logger.error(
                    "Could not parse frequency bands from command line:"
                    "string passed was:\n{}".format(path_or_array_as_string)
                )
                raise
        return freq_bands

    def band_frequency(band):
        # Accept both spellings of the frequency key: "effective_frequency"
        # and "center_frequency".
        for key in ("effective_frequency", "center_frequency"):
            if key in band:
                return band[key]
        raise KeyError(
            f"Frequency band {band} has neither 'effective_frequency' nor 'center_frequency'"
        )

    if params["frequency_band"] is not None:
        freq_bands = []
        for arg in params["frequency_band"]:
            freq_bands.extend(parse_frequency_bands(arg))
        for band in freq_bands:
            _ = bands.get_frequency_band(band_frequency(band), band["bandwidth"])
        logger.debug(f"Found {len(freq_bands)} pre-defined frequency bands")

    # Turn relative paths into absolute paths (in-place, so `params` reflects them too).
    params["input_images"][:] = [
        str(Path(image_path).absolute()) for image_path in params["input_images"]
    ]
    logger.info(f"Gathering .fits files in: {', '.join(params['input_images'])}")
    fits_paths = []
    for path in params["input_images"]:
        fits_paths.extend(find_fits(path))
    # np.unique removes duplicates (and returns the paths sorted).
    fits_paths = np.unique(fits_paths)
    if len(fits_paths) == 0:
        raise Exception(
            "No input images were found in any of: \n - "
            + "\n - ".join(params["input_images"])
        )
    elif len(fits_paths) == 1:
        logger.info("Found exactly one input image")
    else:
        logger.info(f"Found {len(fits_paths)} input images")

    fits_headers = read_fits_headers(fits_paths)
    for header, im_path in zip(fits_headers, fits_paths):
        freq_eff, freq_bw = bands.parse_frequency(header, im_path)
        # Add frequency band if image is not already in a known frequency band.
        _ = bands.get_frequency_band(freq_eff, freq_bw)

    max_concurrent_images = params.pop("max_concurrent_images")
    if not max_concurrent_images:
        max_im_size = max_image_size(fits_headers)
        available_ram = psutil.virtual_memory().available
        # Note: We need to leave room in RAM for copies of the image made by PySE.
        #       The factor of 20 (0.05) has been chosen by running experiments of
        #       two datasets on two different machines, so is by no means definitive.
        #       Also, changes in PySE could affect this scaling factor for better or worse.
        #       For now this runs the workloads well, but may be revisited when new
        #       insights come in.
        max_concurrent_images = int(np.floor(0.05 * (available_ram / max_im_size)))
        # Prevent a 0 for max_concurrent_images if image would not fit in RAM
        if max_concurrent_images < 1:
            logger.warning(
                "The images might be too large to process with available RAM on this machine!"
            )
            max_concurrent_images = 1

    nr_threads = params["nr_threads"]
    if isinstance(nr_threads, str):
        # Parse "None" as None such that Dask can use its own defaults,
        # which is related to the total number of CPU cores on the machine.
        # Any other string is expected to be a number; convert it so that Dask
        # receives an integer.
        if nr_threads.strip().lower() == "none":
            nr_threads = None
        else:
            nr_threads = int(nr_threads)

    if params["scheduler"] is None or params["scheduler"].lower() == "threads":
        dask.config.set(scheduler="threads")
        dask.config.set(num_workers=nr_threads)
        dask.config.set(pool=None)  # Ensure it respects num_workers
    elif params["scheduler"].lower() == "single-threaded":
        dask.config.set(scheduler="single-threaded")
    elif params["scheduler"].lower() == "distributed":
        if nr_threads is None:
            nr_threads = os.cpu_count()
        # NOTE(review): Dask documents the return of nprocesses_nthreads as
        # (nprocesses, nthreads), while the first value is used here as the
        # number of threads per worker. The resulting split does match the
        # 'nr_threads' help text (12 -> 3 processes, 13 -> 1 process), so this
        # is presumably deliberate. Confirm before changing.
        nr_threads_per_worker, nr_workers = nprocesses_nthreads(nr_threads)
        # Note: `max_concurrent_images` assumes it can divide the available RAM over the images.
        #       Since Dask divides the RAM over the available workers, we must therefore make sure
        #       that we do not have more workers than `max_concurrent_images`. Not only would the extra
        #       processes be redundant for there are not enough images to process per batch, they would
        #       be harmful since it would reduce the available RAM per worker, possibly resulting in
        #       Dask killing the workers due to RAM overflows. Therefore we limit the number of workers
        #       to match max_concurrent_images. Multiple threads per worker are fine, they do not
        #       influence the way Dask distributes the RAM.
        if nr_workers > max_concurrent_images:
            nr_workers = max_concurrent_images
        client = Client(n_workers=nr_workers, threads_per_worker=nr_threads_per_worker)
        # Report the per-worker thread count, not the requested total.
        logger.info(
            f"Processing with {nr_workers} processes with {nr_threads_per_worker} threads each."
        )
        logger.info("View progress dashboard at: " + str(client.dashboard_link))
    elif (
        params["scheduler"].lower() == "processes"
        or params["scheduler"].lower() == "multiprocessing"
    ):
        raise NotImplementedError("""
            Note: we deliberately don't support using only processes, because that results in a lot of
            data communication between workers which is very inefficient.
            This communication happens especially when copying the image object for force-fitting after association.
            Re-reading is also not performant, unless maybe if the file is still kept hot by the OS but
            in practice this rarely seems to happen during TraP processing.
            Consider using a distributed scheduler with a combination of processes and threads instead.
            """)
    else:
        raise ValueError(
            f"Unrecognized 'scheduler' argument provided. Expected 'threads', 'distributed' or 'single-threaded', got: {params['scheduler']}"
        )

    if params["db_name"] is None:
        raise ValueError(
            "No '--db_name' was supplied. Please specify the name of the database the TraP data is to be exported to."
        )
    db_kwargs = filter_none_values_from_dict(
        dict(
            db_backend=params["db_backend"] or "sqlite",
            db_name=params["db_name"],
            db_user=params["db_user"],
            db_password=params["db_password"],
            db_host=params["db_host"],
            db_port=params["db_port"],
        )
    )

    # Init db, cleaning if needed
    db_engine = init_db(**db_kwargs, db_overwrite=params["db_overwrite"])

    # order the fits chronologically
    fits_order_ids = headers_by_time(fits_headers)
    fits_paths = np.array(fits_paths)[fits_order_ids]

    for path in fits_paths:
        logger.debug(f"Found image: {path}")

    # Parse fixed monitoring positions
    def parse_monitor_positions(path_or_array_as_string):
        # If file was supplied, read positions from there
        if Path(path_or_array_as_string).is_file():
            with open(path_or_array_as_string, "r") as file:
                try:
                    monitor_coords = json.load(file)
                except json.JSONDecodeError:
                    logger.error(
                        "Could not parse monitor-loc from file: " + path_or_array_as_string
                    )
                    raise
        else:
            # not a path, assume list of coordinates
            try:
                monitor_coords = json.loads(path_or_array_as_string)
            except json.JSONDecodeError:
                logger.error(
                    "Could not parse monitor-loc from command line:"
                    "string passed was:\n{}".format(path_or_array_as_string)
                )
                raise

        # Ensure the shape of the array is always in the form [[ra1, dec1], [ra2, dec2]]
        monitor_coords_arr = np.array(monitor_coords)
        if monitor_coords_arr.ndim == 1:
            return monitor_coords_arr[np.newaxis, :]
        elif monitor_coords_arr.ndim == 2:
            return monitor_coords_arr
        raise ValueError(
            f"Monitor-loc is expected to be in two dimentions, like [[ra1, dec1], [ra2, dec2]], but we got {monitor_coords_arr.ndim} dimentions, parsed from: {path_or_array_as_string}"
        )

    if params["monitor_loc"] is not None:
        monitor_coords = [parse_monitor_positions(arg) for arg in params["monitor_loc"]]
        monitor_coord_arr = np.vstack(monitor_coords)
        logger.debug(f"Found {len(monitor_coord_arr)} monitoring positions")
    else:
        monitor_coord_arr = np.array([])

    # Save the configuration in the database.
    # Make sure any list is turned into a single string.
    # This happens when multiple inputs are allowed in the CLI such as with --input_images
    # If we keep this a list there will be multiple rows in the table. Any non-list values are then duplicated.
    params_for_db = params.copy()
    for key, val in params_for_db.items():
        if hasattr(val, "__len__") and not isinstance(val, str):
            # str() each item: str.join raises TypeError on non-string items.
            params_for_db[key] = "; ".join(str(item) for item in val)
    params_for_db.pop("pdb")  # Uninteresting for export
    params_for_db.pop("db_password")  # Not safe to export
    pd.DataFrame(params_for_db, index=[0]).to_sql("config", db_engine, index=False)

    # Also write frequency bands to database
    # Adding in freq_central to match the schema from the original tkp:
    #    https://tkp.readthedocs.io/en/latest/devref/database/schema.html#frequencyband
    bands_df = bands.bands.copy()
    bands_df["freq_central"] = (
        bands_df["freq_low"] + (bands_df["freq_high"] - bands_df["freq_low"]) / 2
    )
    bands_df.to_sql("frequencybands", db_engine, index=True)

    return run.main(
        fits_paths,
        db_kwargs=db_kwargs,
        freq_bands=bands,
        max_nr_consecutive_force_fits=params["max_nr_consecutive_force_fits"],
        max_concurrent_images=max_concurrent_images,
        monitor_coords=monitor_coord_arr,
        pyse_config=filter_none_values_from_dict(
            dict(
                margin=params["im_margin"],
                radius=params["im_radius"],
                back_size_x=params["im_back_size_x"],
                back_size_y=params["im_back_size_y"],
                force_beam=params["force_beam"],
                ew_sys_err=params["ew_sys_err"],
                ns_sys_err=params["ns_sys_err"],
                detection_thr=params["detection_threshold"],
                analysis_thr=params["analysis_threshold"],
                deblend_nthresh=params["deblend_nthresh"],
            )
        ),
        association_kwargs=filter_none_values_from_dict(
            dict(
                de_ruiter_r_max=params["de_ruiter_r_max"],
            )
        ),
    )
def view(params=None):
    """Visualize the contents of a TraP database.

    Parameters
    ----------
    params: :class:`dict`, optional
        The parameters as returned by ``parse_arguments``. When None, they are
        parsed from the command line with the processing arguments left out.
        Only ``pdb``, the ``db_*`` connection parameters and ``view_command``
        are read. A missing or None ``view_command`` falls back to
        'all_sources', the default of the ``--view_command`` option. This
        matters when the parameters come from a parser that did not register
        that option.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``view_command`` is not one of 'all_sources', 'interactive',
        'lightcurves' or 'frequencies'.
    """
    if params is None:
        params = parse_arguments(require_processing_args=False)

    if params["pdb"]:
        # Automatically start the debugger on an unhandled exception
        def excepthook(exc_type, exc_value, exc_traceback):
            pdb.post_mortem(exc_traceback)

        sys.excepthook = excepthook

    from trap.io import open_db

    # NOTE(review): `logger` is not used below; the import is kept in case
    # importing trap.log has side effects that are relied upon - confirm.
    from trap.log import logger
    from trap.visualize import (
        plot_all_sources,
        plot_frequencies,
        plot_lightcurves,
        visualize,
    )

    viewers = {
        "all_sources": plot_all_sources,
        "interactive": visualize,
        "lightcurves": plot_lightcurves,
        "frequencies": plot_frequencies,
    }

    # .get(): `view_command` is absent when the parser was built with the
    # processing arguments enabled, as main() does.
    view_command = params.get("view_command") or "all_sources"
    # Validate before connecting, so a bad command does not open the database first.
    if view_command not in viewers:
        raise ValueError(f"Unrecognized view_command '{view_command}'")

    db_engine = open_db(
        # Same fallback as the run command and the documented default.
        db_backend=params["db_backend"] or "sqlite",
        db_name=params["db_name"],
        db_user=params["db_user"],
        db_password=params["db_password"],
        db_host=params["db_host"],
        db_port=params["db_port"],
    )

    viewers[view_command](db_engine)
def main():
    """Entry point: parse the command line and dispatch on the ``command`` argument.

    Returns
    -------
    The return value of ``run_batch`` for 'run' or of ``view`` for 'view';
    None for any other command.
    """
    params = parse_arguments(require_command=True)
    command = params["command"]

    if command == "run":
        return run_batch(params)
    if command == "view":
        return view(params)
    return None
# When executed as a script, run the CLI and pass main()'s return value to sys.exit.
if __name__ == "__main__":
    sys.exit(main())