Note
Click here to download the full example code
MCMC search vs. grid search
An example to compare MCMCSearch and GridSearch on the same data.
8 import os
9
10 import matplotlib.pyplot as plt
11 import numpy as np
12
13 import pyfstat
14
# flip this switch for a more expensive 4D (F0,F1,Alpha,Delta) run
# instead of just (F0,F1)
# (still only a few minutes on current laptops)
sky = False

outdir = os.path.join(
    "PyFstat_example_data", "PyFstat_example_simple_mcmc_vs_grid_comparison"
)
if sky:
    # keep the outputs of the 4D run separate from those of the 2D run
    outdir += "AlphaDelta"
# Set up the logger only once outdir is final, so that the log file of a
# sky=True run ends up next to that run's data and plots instead of in the
# 2D run's directory.
logger = pyfstat.set_up_logger(label="mcmc_vs_grid", outdir=outdir)

# parameters for the data set to generate
tstart = 1000000000  # presumably a GPS start time in seconds
duration = 30 * 86400  # 30 days, in seconds
Tsft = 1800  # SFT length; presumably seconds
detectors = "H1,L1"
sqrtSX = 1e-22  # noise level of the simulated data

# parameters for injected signals
# (also used as the reference "truth" in all output and plots below)
inj = {
    "tref": tstart,
    "F0": 30.0,
    "F1": -1e-10,
    "F2": 0,
    "Alpha": 0.5,
    "Delta": 1,
    "h0": 0.05 * sqrtSX,  # signal amplitude set relative to the noise level
    "cosi": 1.0,
}

# latex-formatted plotting labels
labels = {
    "F0": "$f$ [Hz]",
    "F1": "$\\dot{f}$ [Hz/s]",
    "2F": "$2\\mathcal{F}$",
    "Alpha": "$\\alpha$",
    "Delta": "$\\delta$",
}
labels["max2F"] = "$\\max\\,$" + labels["2F"]
55
56
def plot_grid_vs_samples(grid_res, mcmc_res, xkey, ykey, *, zoom_ranges=None):
    """Overlay GridSearch points and MCMC samples in the (xkey, ykey) plane.

    The injection point and the loudest-2F point of each search are marked.
    The figure is saved as ``grid_vs_mcmc_<xkey><ykey>.png`` in ``outdir``.
    If zoom ranges are available for both axes, an additional
    ``grid_vs_mcmc_<xkey><ykey>_zoom.png`` is saved with those axis limits.

    Parameters
    ----------
    grid_res, mcmc_res: dict-like
        Search results, indexable by parameter name and by ``"twoF"``.
    xkey, ykey: str
        Parameter names for the x and y axes; must be keys of ``labels``
        and ``inj``.
    zoom_ranges: dict or None
        Maps parameter names to ``[min, max]`` axis limits for the zoomed
        plot. Defaults to the module-level ``zoom`` dict, which only exists
        when this file is run as a script.
    """
    if zoom_ranges is None:
        # backward-compatible fallback to the global set up in __main__
        zoom_ranges = zoom
    # start from a clean figure instead of drawing onto whatever is current
    plt.figure()
    plt.plot(grid_res[xkey], grid_res[ykey], ".", label="grid")
    plt.plot(mcmc_res[xkey], mcmc_res[ykey], ".", label="mcmc")
    plt.plot(inj[xkey], inj[ykey], "*k", label="injection")
    grid_maxidx = np.argmax(grid_res["twoF"])
    mcmc_maxidx = np.argmax(mcmc_res["twoF"])
    plt.plot(
        grid_res[xkey][grid_maxidx],
        grid_res[ykey][grid_maxidx],
        "+g",
        label=labels["max2F"] + "(grid)",
    )
    plt.plot(
        mcmc_res[xkey][mcmc_maxidx],
        mcmc_res[ykey][mcmc_maxidx],
        "xm",
        label=labels["max2F"] + "(mcmc)",
    )
    plt.xlabel(labels[xkey])
    plt.ylabel(labels[ykey])
    plt.legend()
    plotfilename_base = os.path.join(outdir, "grid_vs_mcmc_{:s}{:s}".format(xkey, ykey))
    plt.savefig(plotfilename_base + ".png")
    # only zoom where limits are known for both axes
    # (with the default ranges: just the F0-F1 plane)
    if xkey in zoom_ranges and ykey in zoom_ranges:
        plt.xlim(zoom_ranges[xkey])
        plt.ylim(zoom_ranges[ykey])
        plt.savefig(plotfilename_base + "_zoom.png")
    plt.close()
86
87
def plot_2F_scatter(res, label, xkey, ykey):
    """Scatter-plot one search's results in the (xkey, ykey) plane, coloured by 2F.

    The injection point and the loudest-2F point are marked. The axes are
    clipped to the range of the plotted results. The figure is saved as
    ``<label>_<xkey><ykey>_2F.png`` in ``outdir``.

    Parameters
    ----------
    res: dict-like
        Search results, indexable by parameter name and by ``"twoF"``.
    label: str
        Name of the search (``"grid"`` or ``"mcmc"``). Used for the title and
        filename, and to pick the marker size.
    xkey, ykey: str
        Parameter names for the x and y axes; must be keys of ``labels``
        and ``inj``.
    """
    # start from a clean figure instead of drawing onto whatever is current
    plt.figure()
    markersize = 3 if label == "grid" else 1
    sc = plt.scatter(res[xkey], res[ykey], c=res["twoF"], s=markersize)
    cb = plt.colorbar(sc)
    plt.xlabel(labels[xkey])
    plt.ylabel(labels[ykey])
    cb.set_label(labels["2F"])
    plt.title(label)
    plt.plot(inj[xkey], inj[ykey], "*k", label="injection")
    maxidx = np.argmax(res["twoF"])
    plt.plot(
        res[xkey][maxidx],
        res[ykey][maxidx],
        "+r",
        label=labels["max2F"],
    )
    plt.legend()
    plotfilename_base = os.path.join(
        outdir, "{:s}_{:s}{:s}_2F".format(label, xkey, ykey)
    )
    # vectorized extrema instead of Python-level iteration over the arrays
    plt.xlim([np.min(res[xkey]), np.max(res[xkey])])
    plt.ylim([np.min(res[ykey]), np.max(res[ykey])])
    plt.savefig(plotfilename_base + ".png")
    plt.close()
113
114
# Script entry point. The steps are:
#   1. simulate SFT data containing the injected signal `inj`;
#   2. run a GridSearch over (F0, F1) (plus Alpha, Delta if `sky`);
#   3. run an MCMCSearch with priors covering at least the grid ranges;
#   4. log how far each search's estimates are from the injection and make
#      custom comparison plots.
# Everything is written to `outdir`.
if __name__ == "__main__":

    # --- Step 1: simulate data ---
    logger.info("Generating SFTs with injected signal...")
    writer = pyfstat.Writer(
        label="simulated_signal",
        outdir=outdir,
        tstart=tstart,
        duration=duration,
        detectors=detectors,
        sqrtSX=sqrtSX,
        Tsft=Tsft,
        **inj,
        Band=1,  # default band estimation would be too narrow for a wide grid/prior
    )
    writer.make_data()

    # --- Step 2: grid search ---
    # set up square search grid with fixed (F0,F1) mismatch
    # and (optionally) some ad-hoc sky coverage
    m = 0.001
    # Step sizes in F0 and F1 derived from the mismatch m and the observation
    # duration; presumably the standard coherent phase-metric spacings -- verify.
    dF0 = np.sqrt(12 * m) / (np.pi * duration)
    dF1 = np.sqrt(180 * m) / (np.pi * duration**2)
    # total grid widths, centred on the injection below
    DeltaF0 = 500 * dF0
    DeltaF1 = 200 * dF1
    if sky:
        # cover less range to keep runtime down
        DeltaF0 /= 10
        DeltaF1 /= 10
    # Searched parameters are given as [lower, upper, step] lists; a
    # single-element list holds that parameter fixed (see the
    # "0-width" comment on search_keys).
    F0s = [inj["F0"] - DeltaF0 / 2.0, inj["F0"] + DeltaF0 / 2.0, dF0]
    F1s = [inj["F1"] - DeltaF1 / 2.0, inj["F1"] + DeltaF1 / 2.0, dF1]
    F2s = [inj["F2"]]
    search_keys = ["F0", "F1"]  # only the ones that aren't 0-width
    if sky:
        dSky = 0.01  # rather coarse to keep runtime down
        DeltaSky = 10 * dSky
        Alphas = [inj["Alpha"] - DeltaSky / 2.0, inj["Alpha"] + DeltaSky / 2.0, dSky]
        Deltas = [inj["Delta"] - DeltaSky / 2.0, inj["Delta"] + DeltaSky / 2.0, dSky]
        search_keys += ["Alpha", "Delta"]
    else:
        Alphas = [inj["Alpha"]]
        Deltas = [inj["Delta"]]
    # e.g. "F0F1" or "F0F1AlphaDelta"; used in the labels of both searches
    search_keys_label = "".join(search_keys)

    logger.info("Performing GridSearch...")
    gridsearch = pyfstat.GridSearch(
        label="grid_search_" + search_keys_label,
        outdir=outdir,
        sftfilepattern=os.path.join(outdir, "*simulated_signal*sft"),
        F0s=F0s,
        F1s=F1s,
        F2s=F2s,
        Alphas=Alphas,
        Deltas=Deltas,
        tref=inj["tref"],
    )
    gridsearch.run()
    gridsearch.print_max_twoF()
    gridsearch.generate_loudest()

    # do some plots of the GridSearch results
    if not sky:  # this plotter can't currently deal with too large result arrays
        logger.info("Plotting 1D 2F distributions...")
        for key in search_keys:
            gridsearch.plot_1D(xkey=key, xlabel=labels[key], ylabel=labels["2F"])

    logger.info("Making GridSearch {:s} corner plot...".format("-".join(search_keys)))
    # grid axes, as offsets from the injected values
    vals = [np.unique(gridsearch.data[key]) - inj[key] for key in search_keys]
    # NOTE(review): this reshape assumes gridsearch.data is a full Cartesian
    # product ordered with the search_keys axes in this order (C-order) --
    # confirm against GridSearch's output ordering.
    twoF = gridsearch.data["twoF"].reshape([len(kval) for kval in vals])
    corner_labels = [
        "$f - f_0$ [Hz]",
        "$\\dot{f} - \\dot{f}_0$ [Hz/s]",
    ]
    if sky:
        corner_labels.append("$\\alpha - \\alpha_0$")
        corner_labels.append("$\\delta - \\delta_0$")
    # the last label is for the 2F statistic itself
    corner_labels.append(labels["2F"])
    gridcorner_fig, gridcorner_axes = pyfstat.gridcorner(
        twoF, vals, projection="log_mean", labels=corner_labels, whspace=0.1, factor=1.8
    )
    gridcorner_fig.savefig(os.path.join(outdir, gridsearch.label + "_corner.png"))
    plt.close(gridcorner_fig)

    # --- Step 3: MCMC search ---
    logger.info("Performing MCMCSearch...")
    # set up priors in F0 and F1 (over)covering the grid ranges
    if sky:  # MCMC will still be fast in 4D with wider range than grid
        # DeltaF0/DeltaF1 were divided by 10 above for the grid, so the prior
        # becomes 50x the (reduced) grid width, i.e. 5x the non-sky width.
        DeltaF0 *= 50
        DeltaF1 *= 50
    # Dict entries are either a prior specification (dict) or a plain number;
    # plain numbers appear to be held fixed (not sampled).
    theta_prior = {
        "F0": {
            "type": "unif",
            "lower": inj["F0"] - DeltaF0 / 2.0,
            "upper": inj["F0"] + DeltaF0 / 2.0,
        },
        "F1": {
            "type": "unif",
            "lower": inj["F1"] - DeltaF1 / 2.0,
            "upper": inj["F1"] + DeltaF1 / 2.0,
        },
        "F2": inj["F2"],
    }
    if sky:
        # also implicitly covering twice the grid range here
        theta_prior["Alpha"] = {
            "type": "unif",
            "lower": inj["Alpha"] - DeltaSky,
            "upper": inj["Alpha"] + DeltaSky,
        }
        theta_prior["Delta"] = {
            "type": "unif",
            "lower": inj["Delta"] - DeltaSky,
            "upper": inj["Delta"] + DeltaSky,
        }
    else:
        theta_prior["Alpha"] = inj["Alpha"]
        theta_prior["Delta"] = inj["Delta"]
    # ptemcee sampler settings - in a real application we might want higher values
    ntemps = 2
    log10beta_min = -1
    nwalkers = 100
    nsteps = [200, 200]  # [burnin,production]

    mcmcsearch = pyfstat.MCMCSearch(
        label="mcmc_search_" + search_keys_label,
        outdir=outdir,
        sftfilepattern=os.path.join(outdir, "*simulated_signal*sft"),
        theta_prior=theta_prior,
        tref=inj["tref"],
        nsteps=nsteps,
        nwalkers=nwalkers,
        ntemps=ntemps,
        log10beta_min=log10beta_min,
    )
    # walker plot is generated during main run of the search class
    mcmcsearch.run(
        walker_plot_args={"plot_det_stat": True, "injection_parameters": inj}
    )
    mcmcsearch.print_summary()

    # call some built-in plotting methods
    # these can all highlight the injection parameters, too
    logger.info("Making MCMCSearch {:s} corner plot...".format("-".join(search_keys)))
    mcmcsearch.plot_corner(truths=inj)
    logger.info("Making MCMCSearch prior-posterior comparison plot...")
    mcmcsearch.plot_prior_posterior(injection_parameters=inj)

    # --- Step 4: compare results ---
    # NOTE: everything below here is just custom commandline output and plotting
    # for this particular example, which uses the PyFstat outputs,
    # but isn't very instructive if you just want to learn the main usage of the package.

    # some informative command-line output comparing search results and injection
    # get max of GridSearch, contains twoF and all Doppler parameters in the dict
    max_dict_grid = gridsearch.get_max_twoF()
    # same for MCMCSearch, here twoF is separate, and non-sampled parameters are not included either
    max_dict_mcmc, max_2F_mcmc = mcmcsearch.get_max_twoF()
    logger.info(
        "max2F={:.4f} from GridSearch, offsets from injection: {:s}.".format(
            max_dict_grid["twoF"],
            ", ".join(
                [
                    "{:.4e} in {:s}".format(max_dict_grid[key] - inj[key], key)
                    for key in search_keys
                ]
            ),
        )
    )
    logger.info(
        "max2F={:.4f} from MCMCSearch, offsets from injection: {:s}.".format(
            max_2F_mcmc,
            ", ".join(
                [
                    "{:.4e} in {:s}".format(max_dict_mcmc[key] - inj[key], key)
                    for key in search_keys
                ]
            ),
        )
    )
    # get additional point and interval estimators
    stats_dict_mcmc = mcmcsearch.get_summary_stats()
    # NOTE(review): the two messages below normalize differently. Here the
    # offset is divided by 2*std (the half-width of a +-2sigma interval); the
    # median message divides by the full 90% interval width (upper90 - lower90).
    # Both print percentages although the text says "fractions".
    logger.info(
        "mean from MCMCSearch: offset from injection by {:s},"
        " or in fractions of 2sigma intervals: {:s}.".format(
            ", ".join(
                [
                    "{:.4e} in {:s}".format(
                        stats_dict_mcmc[key]["mean"] - inj[key], key
                    )
                    for key in search_keys
                ]
            ),
            ", ".join(
                [
                    "{:.2f}% in {:s}".format(
                        100
                        * np.abs(stats_dict_mcmc[key]["mean"] - inj[key])
                        / (2 * stats_dict_mcmc[key]["std"]),
                        key,
                    )
                    for key in search_keys
                ]
            ),
        )
    )
    logger.info(
        "median from MCMCSearch: offset from injection by {:s},"
        " or in fractions of 90% confidence intervals: {:s}.".format(
            ", ".join(
                [
                    "{:.4e} in {:s}".format(
                        stats_dict_mcmc[key]["median"] - inj[key], key
                    )
                    for key in search_keys
                ]
            ),
            ", ".join(
                [
                    "{:.2f}% in {:s}".format(
                        100
                        * np.abs(stats_dict_mcmc[key]["median"] - inj[key])
                        / (
                            stats_dict_mcmc[key]["upper90"]
                            - stats_dict_mcmc[key]["lower90"]
                        ),
                        key,
                    )
                    for key in search_keys
                ]
            ),
        )
    )

    # do additional custom plotting
    logger.info("Loading grid and MCMC search results for custom comparison plots...")
    # Result file names are built from each search's label; the suffixes below
    # must match what GridSearch / MCMCSearch write into outdir.
    gridfile = os.path.join(outdir, gridsearch.label + "_NA_GridSearch.txt")
    if not os.path.isfile(gridfile):
        raise RuntimeError(
            "Failed to load GridSearch results from file '{:s}',"
            " something must have gone wrong!".format(gridfile)
        )
    grid_res = pyfstat.utils.read_txt_file_with_header(gridfile)
    mcmc_file = os.path.join(outdir, mcmcsearch.label + "_samples.dat")
    if not os.path.isfile(mcmc_file):
        raise RuntimeError(
            "Failed to load MCMCSearch results from file '{:s}',"
            " something must have gone wrong!".format(mcmc_file)
        )
    mcmc_res = pyfstat.utils.read_txt_file_with_header(mcmc_file)

    # Axis limits for the zoomed F0-F1 comparison plot, in units of the grid
    # steps around the injection. This `if` block runs at module scope, so
    # `zoom` is a module global, and plot_grid_vs_samples reads it from there.
    zoom = {
        "F0": [inj["F0"] - 10 * dF0, inj["F0"] + 10 * dF0],
        "F1": [inj["F1"] - 5 * dF1, inj["F1"] + 5 * dF1],
    }

    # we'll use the two local plotting functions defined above
    # to avoid code duplication in the sky case
    logger.info("Creating MCMC-grid comparison plots...")
    plot_grid_vs_samples(grid_res, mcmc_res, "F0", "F1")
    plot_2F_scatter(grid_res, "grid", "F0", "F1")
    plot_2F_scatter(mcmc_res, "mcmc", "F0", "F1")
    if sky:
        plot_grid_vs_samples(grid_res, mcmc_res, "Alpha", "Delta")
        plot_2F_scatter(grid_res, "grid", "Alpha", "Delta")
        plot_2F_scatter(mcmc_res, "mcmc", "Alpha", "Delta")
Total running time of the script: ( 0 minutes 0.000 seconds)