Source code for democratic_detrender.plot

""" This module contains functions to plot light curves and detrending results. """

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from matplotlib.widgets import Slider, Button
from democratic_detrender.helper_functions import bin_data

import matplotlib
# Global matplotlib styling.  These calls run at import time, so importing
# this module changes the rc settings for every figure drawn afterwards:
# 27pt tick labels, a serif "Computer Modern" font and LaTeX text rendering.
matplotlib.rc("xtick", labelsize=27)
matplotlib.rc("ytick", labelsize=27)
matplotlib.rc("font", **{"family": "serif", "serif": ["Computer Modern"]})
matplotlib.rc("text", usetex=True)

def plot_transit(
    xs_star,
    ys_star,
    xs_transit,
    ys_transit,
    t0,
    period,
    title,
    bin_window,
    object_id,
    problem_times_input=None,
    dont_bin=False,
):
    """Plot one transit epoch with a time slider and a "save time" button.

    The figure shows the out-of-transit and in-transit points (optionally
    with binned points drawn on top), a dashed line at ``t0`` and a solid
    vertical marker that follows the slider.  Clicking "save time" appends
    the slider's current value to the returned list of problem times.

    Parameters
    ----------
    xs_star, ys_star
        Times and fluxes that are not in transit.
    xs_transit, ys_transit
        Times and fluxes that are in transit.
    t0
        Midtransit time; the x-range is centred on it.
    period
        Planet period, used to set the x-limits to half a period on either
        side of ``t0``.  May be a scalar or an array-like, in which case the
        first element is used.
    title
        Text written inside the axes.
    bin_window
        Passed through to ``bin_data``.
    object_id
        Used as the axes title.
    problem_times_input : list or None
        Existing list of saved times.  When given it is sorted in place and
        the button appends to this same list.
    dont_bin : bool
        When true, the binned overlay is skipped.

    Returns
    -------
    tuple
        ``(slider, button, problem_times)``.  The widgets are returned so the
        caller can hold on to them; the list is filled in as the user clicks.
    """
    global problem_times

    if problem_times_input is None:
        saved_times = []
    else:
        problem_times_input.sort()
        saved_times = problem_times_input
    # The module-level name is still rebound for backward compatibility, but
    # the button callback below uses its own list so that several open
    # figures do not all write into whichever list was created last.
    problem_times = saved_times

    window = 1.0 / 2.0
    fig, ax = plt.subplots(1, 1, figsize=[9, 6])
    plt.subplots_adjust(left=0.2, bottom=0.3, hspace=0.3)

    if not dont_bin:
        xstar_bin, ystar_bin = bin_data(xs_star, ys_star, bin_window)
        xtransit_bin, ytransit_bin = bin_data(xs_transit, ys_transit, bin_window)

    bin_colors = ["#00008B", "#DC143C"]

    # np.ravel lets a plain scalar period work as well as an array.
    half_span = np.ravel(period)[0] * window
    xmin, xmax = t0 - half_span, t0 + half_span

    # NaN-aware on both ends so a NaN cannot turn the upper limit into NaN.
    ymin = np.nanmin([np.nanmin(ys_transit), np.nanmin(ys_star)])
    ymax = np.nanmax([np.nanmax(ys_transit), np.nanmax(ys_star)])

    # Pad the y-range by a factor of 1.2 away from the data on both sides.
    ymin = ymin / 1.2 if ymin > 0 else ymin * 1.2
    ymax = ymax * 1.2 if ymax > 0 else ymax / 1.2

    t_init = 0
    # Two end points are enough for a straight vertical marker; sampling the
    # y-range in 1e-6 steps would allocate an enormous array for large fluxes.
    marker = ax.plot([t_init, t_init], [ymin, ymax], lw=2, color="k")[0]

    ax.plot(xs_star, ys_star, ".", color="grey", alpha=0.3)
    ax.plot(xs_transit, ys_transit, ".", color="black", alpha=0.3)

    if not dont_bin:
        ax.plot(xstar_bin, ystar_bin, "o", color=bin_colors[1], alpha=0.9, markersize=7)
        ax.plot(
            xtransit_bin,
            ytransit_bin,
            "o",
            color=bin_colors[0],
            alpha=0.9,
            markersize=7,
        )

    ax.text(xmin + (xmax - xmin) * 0.05, 0.7 * ymax, title, fontsize=27)
    ax.axvline(t0, linewidth=1, color="k", ls="dashed")
    ax.set_xlabel("time [days]")
    ax.set_ylabel("intensity")
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    ax.ticklabel_format(useOffset=False)  # show absolute tick values, no offset
    ax.set_title(object_id, fontsize=27)

    axtime = plt.axes([0.197, 0.1, 0.702, 0.09])
    stime = Slider(
        axtime, "time", xmin, xmax, valinit=t_init, orientation="horizontal", color="k"
    )

    def update(val):
        # Move the vertical marker to the slider position and redraw.
        marker.set_xdata([val, val])
        fig.canvas.draw_idle()

    stime.on_changed(update)

    save_time = plt.axes([0.8, 0.025, 0.13, 0.04])
    button = Button(save_time, "save time", color="0.97", hovercolor="0.79")

    def save(event):
        # Record the current slider time once, keeping the list ordered.
        if stime.val not in saved_times:
            saved_times.append(stime.val)
            saved_times.sort()

    button.on_clicked(save)

    return (stime, button, saved_times)
def plot_transits(
    x_transits,
    y_transits,
    mask_transits,
    t0s,
    period,
    bin_window,
    object_id,
    problem_times_input=None,
    dont_bin=False,
    data_name=None,
):
    """Open one interactive ``plot_transit`` figure per epoch and collect saved times.

    Parameters
    ----------
    x_transits, y_transits
        Per-epoch sequences of times and fluxes.
    mask_transits
        Per-epoch boolean masks selecting the in-transit points; each mask is
        used both directly and negated to index the epoch's arrays.
    t0s
        Midtransit time of each epoch.
    period
        Planet period, forwarded to ``plot_transit`` to set plotting limits.
    bin_window, object_id, problem_times_input, dont_bin
        Forwarded unchanged to ``plot_transit``.
    data_name : str or None
        When given, each epoch's time/flux/mask is written to
        ``data_name + <epoch number> + ".csv"``.

    Returns
    -------
    tuple
        ``(sliders, buttons, problem_times)`` where ``problem_times`` is the
        sorted list of distinct times saved across all epochs.
    """
    plt.close("all")

    sliders, buttons, per_epoch_times = [], [], []

    if len(t0s) != len(x_transits):
        print("ERROR, length of t0s doesn't match length of x_transits")

    # zip stops at the shortest input, so a length mismatch (reported above)
    # plots the epochs that do line up instead of failing with an IndexError.
    epochs = zip(t0s, x_transits, y_transits, mask_transits)
    for epoch, (t0, xs, ys, mask) in enumerate(epochs, start=1):
        slider, button, epoch_times = plot_transit(
            xs[~mask],
            ys[~mask],
            xs[mask],
            ys[mask],
            t0,
            period,
            "epoch " + str(epoch),
            bin_window,
            object_id,
            problem_times_input=problem_times_input,
            dont_bin=dont_bin,
        )
        sliders.append(slider)
        buttons.append(button)
        per_epoch_times.append(epoch_times)

        if data_name is not None:
            # saving detrend data as csv
            csv_name = data_name + str(epoch) + ".csv"
            detrend_df = pd.DataFrame({"time": xs, "flux": ys, "mask": mask})
            print("saving data to " + csv_name)
            detrend_df.to_csv(csv_name)

    plt.show()

    # Flatten after the figures have been shown so the lists contain whatever
    # was saved.  When problem_times_input is supplied every epoch returns
    # that same list, so duplicates are removed rather than repeating each
    # saved time once per epoch.
    problem_times_flat = sorted(
        {saved for epoch_times in per_epoch_times for saved in epoch_times}
    )

    return sliders, buttons, problem_times_flat
def plot_detrended_lc(
    xs,
    ys,
    detrend_labels,
    t0s_in_data,
    window,
    period,
    colors,
    duration,
    mask_width=1.1,
    depth=None,
    figname=None,
    title=None,
):
    """Plot every detrended light curve around each transit, stacked vertically.

    Each detrending is offset upwards by ``depth`` from the previous one and
    labelled in its own colour.  Dashed vertical lines mark the edges of the
    masked transit window.

    Parameters
    ----------
    xs
        Times, shared by all detrended light curves.
    ys
        List of detrended light curves, one per detrending method.
    detrend_labels
        Label for each entry of ``ys``.
    t0s_in_data
        Midtransit times present in the data.
    window
        Fraction of the period to plot on either side of each transit
        (``window=1/2`` means half a period on either side).
    period
        Planet period, used to set the x-limits.
    colors
        Colour for each entry of ``ys``.
    duration
        Transit duration; divided by 24 before being added to times, so
        presumably given in hours -- confirm against the caller.
    mask_width
        Multiplier applied to ``duration`` for the marked window.
    depth
        Vertical spacing between stacked curves; falls back to 0.01 when
        missing or zero.
    figname
        File name to save the figure to; nothing is saved when falsy.
    title
        Optional figure title.

    Returns
    -------
    None
    """
    import math

    half_mask = mask_width * duration / (2 * 24.0)
    transit_windows = [[t0 - half_mask, t0 + half_mask] for t0 in t0s_in_data]

    n_epochs = len(t0s_in_data)
    n_methods = len(ys)

    if n_epochs > 1:
        if n_methods == 1:
            height = n_epochs * n_methods
        else:
            height = n_epochs * n_methods / 4
        # squeeze=False keeps a 2-D array of axes even when there is a single
        # row (2 or 3 transits); otherwise ax[row][column] indexing fails.
        fig, ax = plt.subplots(
            ncols=3,
            nrows=math.ceil(n_epochs / 3),
            figsize=[27, height],
            sharey=True,
            squeeze=False,
        )
    else:
        fig, ax = plt.subplots(ncols=1, nrows=1, figsize=[18, 13], sharey=True)

    if not depth:
        depth = 0.01

    if n_epochs == 1:
        t0 = t0s_in_data[0]

        detrend_offset = 0
        for detrend_index in range(n_methods):
            ax.plot(
                xs,
                ys[detrend_index] + detrend_offset,
                "o",
                color=colors[detrend_index],
                alpha=0.63,
                markersize=9,
            )
            ax.text(
                t0 - (period * window) + 0.01,
                detrend_offset + 0.013,
                detrend_labels[detrend_index],
                color=colors[detrend_index],
                fontsize=13,
            )
            detrend_offset += depth

        for start, end in transit_windows:
            ax.axvline(start, linewidth=1.8, color="k", alpha=0.79, ls="--")
            ax.axvline(end, linewidth=1.8, color="k", alpha=0.79, ls="--")

        ax.set_xlabel("time [KBJD]", fontsize=27)
        ax.set_ylabel("intensity", fontsize=27)
        ax.set_xlim(t0 - (period * window), t0 + (period * window))
        ax.set_ylim(-1.2 * depth, depth * n_methods)
        ax.tick_params(axis="x", rotation=45)
        ax.ticklabel_format(useOffset=False)

    elif n_epochs > 1:
        panels = ax.ravel()  # row-major: fills each row left to right

        for ii, t0 in enumerate(t0s_in_data):
            ax_ii = panels[ii]

            detrend_offset = 0
            for detrend_index in range(n_methods):
                ax_ii.plot(
                    xs,
                    ys[detrend_index] + detrend_offset,
                    "o",
                    color=colors[detrend_index],
                    alpha=0.63,
                )
                ax_ii.text(
                    t0 - (period * window) + 0.18,
                    detrend_offset + 0.0018,
                    detrend_labels[detrend_index],
                    color=colors[detrend_index],
                    fontsize=18,
                )
                detrend_offset += depth

            ax_ii.axvline(
                transit_windows[ii][0], linewidth=1.8, color="k", alpha=0.79, ls="--"
            )
            ax_ii.axvline(
                transit_windows[ii][1], linewidth=1.8, color="k", alpha=0.79, ls="--"
            )

            ax_ii.set_xlabel("time [KBJD]", fontsize=18)
            ax_ii.set_ylabel("intensity", fontsize=18)
            ax_ii.set_xlim(t0 - (period * window), t0 + (period * window))
            ax_ii.set_ylim(-1.2 * depth, depth * n_methods)
            ax_ii.tick_params(axis="x", rotation=45)

        # Hide grid cells left over when the number of transits is not a
        # multiple of three, so no empty frames are drawn.
        for spare in panels[n_epochs:]:
            spare.set_visible(False)

    if title:
        fig.suptitle(title, fontsize=36, y=1.01)

    fig.tight_layout()

    if figname:
        fig.savefig(figname, bbox_inches="tight")

    return None
def plot_phase_fold_lc(time, lc, period, t0s, xlim, figname):
    """Scatter the light curve phase-folded on ``period`` about ``t0s[0]``.

    Points are coloured by their original time.  The x-axis is limited to
    ``period / xlim`` on either side of zero, and the figure is written to
    ``figname`` when that is truthy.  Always returns ``None``.
    """
    plt.close("all")
    figure = plt.figure(figsize=[18, 6])

    # Wrap times into [-period/2, period/2) relative to the first midtransit.
    half_period = 0.5 * period
    folded_time = (time - t0s[0] + half_period) % period - half_period

    plt.scatter(folded_time, lc, c=time, s=10)
    plt.xlabel("time since transit [days]")
    plt.ylabel("intensity [ppm]")
    plt.colorbar(label="time [days]")

    half_width = period / xlim
    plt.xlim(0.0 - half_width, 0.0 + half_width)

    if figname:
        figure.savefig(figname)

    return None
def plot_outliers(
    time, flux, time_out, flux_out, moving_median, kepler_quarters, figname, object_id
):
    """Plot the light curve with rejected outliers highlighted in red.

    A sample counts as an outlier when its time does not appear in
    ``time_out``.  If the number found differs from
    ``len(flux) - len(flux_out)`` an error message is printed and plotting
    continues.

    Parameters
    ----------
    time, flux
        Arrays of time and flux values.
    time_out, flux_out
        Arrays of time and flux values without outliers.
    moving_median
        Accepted for interface compatibility; not drawn by this function.
    kepler_quarters
        x-positions at which dashed vertical lines are drawn.
    figname
        File name the figure is saved to.
    object_id
        Used as the axes title.

    Returns
    -------
    None
    """
    plt.close("all")

    times = np.asarray(time)
    fluxes = np.asarray(flux)

    # One vectorised membership test instead of scanning time_out once per
    # sample, which is quadratic in the light-curve length.
    is_outlier = ~np.isin(times, time_out)
    outlier_times = times[is_outlier]
    outlier_fluxes = fluxes[is_outlier]

    n_outliers = len(flux) - len(flux_out)
    if np.count_nonzero(is_outlier) != n_outliers:
        print("ERROR, didn't find all outliers")

    fig, ax = plt.subplots(1, 1, figsize=[18, 9])
    ax.plot(time_out, flux_out, "o", color="grey", alpha=0.7)
    ax.plot(outlier_times, outlier_fluxes, "o", color="red", alpha=1.0)

    for boundary in kepler_quarters:
        ax.axvline(boundary, linewidth=1, color="k", ls="--")

    ax.set_xlabel("time [days]", fontsize=27)
    ax.set_ylabel("intensity", fontsize=27)
    ax.set_title(object_id, fontsize=36)

    fig.tight_layout()
    fig.savefig(figname)

    return None
def plot_split_data(x_split, y_split, t0s, figname, object_id):
    """Plot each data segment in its own row, with dashed lines at every ``t0s`` entry.

    Every panel is labelled "quarter <index>" and limited to the segment's
    own time range.  Only the bottom panel gets an x-axis label.  The figure
    is titled with ``object_id`` and saved to ``figname``.  Returns ``None``.
    """
    plt.close("all")

    n_segments = len(x_split)
    fig, ax = plt.subplots(nrows=n_segments, figsize=[18, 3 * n_segments])

    # plt.subplots hands back a bare Axes for one row and an array otherwise.
    several = n_segments > 1
    panels = list(ax) if several else [ax]

    for index, panel in enumerate(panels):
        seg_x = x_split[index]
        seg_y = y_split[index]
        left = np.min(seg_x)
        right = np.max(seg_x)

        panel.plot(seg_x, seg_y, "o", color="grey", alpha=0.7)
        for midtransit in t0s:
            panel.axvline(midtransit, linewidth=1, color="k", ls="--")

        # With several segments the label sits at 70% of the segment's
        # maximum; a lone segment has its label at y = 0.
        label_height = 0.7 * np.max(seg_y) if several else 0
        panel.text(
            left + (right - left) * 0.05,
            label_height,
            "quarter " + str(index),
            fontsize=27,
        )

        panel.set_xlim(left, right)
        panel.set_ylabel("intensity", fontsize=27)

    panels[-1].set_xlabel("time [days]", fontsize=27)

    fig.suptitle(object_id, fontsize=36)
    fig.tight_layout()
    fig.savefig(figname)

    return None
def plot_individual_outliers(
    time, flux, time_out, flux_out, t0s, period, window, depth, figname
):
    """Plot one panel per transit epoch that contains at least one outlier.

    A sample counts as an outlier when its time does not appear in
    ``time_out``.  An epoch is drawn when any outlier lies strictly inside
    ``t0 +/- period * window``.  If no epoch qualifies, no figure is created
    and nothing is saved.

    Parameters
    ----------
    time, flux
        Arrays of time and flux values.
    time_out, flux_out
        Arrays of time and flux values without outliers.
    t0s
        Midtransit times in the data.
    period
        Planet period, used to set the plotting limits.
    window
        Fraction of the period to plot on either side of each transit
        (``window=1/2`` means half a period on either side).
    depth
        Sets the height of the epoch label, drawn at ``-0.2 * depth``.
    figname
        File name the figure is saved to.

    Returns
    -------
    None
    """
    plt.close("all")

    times = np.asarray(time)
    fluxes = np.asarray(flux)

    # One vectorised membership test instead of scanning time_out once per
    # sample, which is quadratic in the light-curve length.
    is_outlier = ~np.isin(times, time_out)
    outlier_times = times[is_outlier]
    outlier_fluxes = fluxes[is_outlier]

    n_outliers = len(flux) - len(flux_out)
    if np.count_nonzero(is_outlier) != n_outliers:
        print("ERROR, didn't find all outliers")

    half_span = period * window

    # Indices of the epochs whose plotting window holds at least one outlier.
    epochs_with_outliers = [
        ii
        for ii, t0 in enumerate(t0s)
        if np.any(
            (outlier_times > t0 - half_span) & (outlier_times < t0 + half_span)
        )
    ]
    nplots = len(epochs_with_outliers)

    if nplots == 0:
        return None

    # squeeze=False always yields an array of axes, so one panel and many
    # panels share the same drawing code.
    fig, ax = plt.subplots(nrows=nplots, figsize=[18, 6 * nplots], squeeze=False)

    # NOTE(review): debug print kept from the original, where it carried the
    # remark "CHANGED TO FIX NOT PERMANENT MARCH 16,2024" -- confirm whether
    # it is still wanted.
    print(figname)

    for panel, ii in zip(ax[:, 0], epochs_with_outliers):
        t0 = t0s[ii]
        xmin, xmax = t0 - half_span, t0 + half_span

        panel.plot(time_out, flux_out, "o", color="grey", alpha=0.7)
        panel.plot(outlier_times, outlier_fluxes, "o", color="red", alpha=1.0)
        panel.text(
            xmin + (xmax - xmin) * 0.05,
            -depth * 0.2,
            "epoch " + str(ii + 1),
            fontsize=27,
        )
        panel.set_xlim(xmin, xmax)
        panel.set_xlabel("time [days]")
        panel.set_ylabel("intensity")

    fig.tight_layout()
    fig.savefig(figname)

    return None