Note
Click here to download the full example code
MCMC search vs. grid search
An example to compare MCMCSearch and GridSearch on the same data.
import os

import pyfstat
import numpy as np
import matplotlib.pyplot as plt

# Set this switch to True for a more expensive 4D (F0,F1,Alpha,Delta) run
# instead of just (F0,F1)
# (still only a few minutes on current laptops)
sky = False

# All outputs of this example go below this directory; the 4D run gets a
# distinct directory name so both variants can coexist.
outdir = os.path.join(
    "PyFstat_example_data", "PyFstat_example_simple_mcmc_vs_grid_comparison"
)
if sky:
    outdir += "AlphaDelta"

# parameters for the data set to generate
tstart = 1000000000
duration = 30 * 86400
Tsft = 1800
detectors = "H1,L1"
sqrtSX = 1e-22

# parameters for injected signals
# (the amplitude h0 is set relative to the noise level sqrtSX)
inj = {
    "tref": tstart,
    "F0": 30.0,
    "F1": -1e-10,
    "F2": 0,
    "Alpha": 0.5,
    "Delta": 1,
    "h0": 0.05 * sqrtSX,
    "cosi": 1.0,
}

# latex-formatted plotting labels, keyed by parameter name
labels = {
    "F0": "$f$ [Hz]",
    "F1": "$\\dot{f}$ [Hz/s]",
    "2F": "$2\\mathcal{F}$",
    "Alpha": "$\\alpha$",
    "Delta": "$\\delta$",
}
# derived label for the maximum of the 2F statistic
labels["max2F"] = "$\\max\\,$" + labels["2F"]


def plot_grid_vs_samples(grid_res, mcmc_res, xkey, ykey, zoom_ranges=None):
    """Overlay grid points and MCMC samples in the (xkey, ykey) plane.

    Local plotting function to avoid code duplication in the 4D case.
    Draws both result sets, the injection point and the maximum-2F point of
    each search, and saves the figure as
    ``grid_vs_mcmc_<xkey><ykey>.png`` inside the module-level ``outdir``.
    For the (F0, F1) plane a second, zoomed-in version is saved with the
    suffix ``_zoom.png``.

    Parameters
    ----------
    grid_res, mcmc_res:
        Search results indexable by column name; the columns ``xkey``,
        ``ykey`` and ``"twoF"`` are read.
    xkey, ykey:
        Parameter names; both must be keys of the module-level ``inj`` and
        ``labels`` dicts.
    zoom_ranges:
        Optional dict mapping parameter names to axis limits for the zoomed
        (F0, F1) plot. If omitted, the module-level ``zoom`` dict is used,
        which must then have been defined before this function is called.
    """
    # Draw on a dedicated figure rather than pyplot's implicit current one,
    # so that a figure left open elsewhere cannot end up in this plot.
    fig, ax = plt.subplots()
    ax.plot(grid_res[xkey], grid_res[ykey], ".", label="grid")
    ax.plot(mcmc_res[xkey], mcmc_res[ykey], ".", label="mcmc")
    ax.plot(inj[xkey], inj[ykey], "*k", label="injection")
    # mark the loudest point of each search
    grid_maxidx = np.argmax(grid_res["twoF"])
    mcmc_maxidx = np.argmax(mcmc_res["twoF"])
    ax.plot(
        grid_res[xkey][grid_maxidx],
        grid_res[ykey][grid_maxidx],
        "+g",
        label=labels["max2F"] + "(grid)",
    )
    ax.plot(
        mcmc_res[xkey][mcmc_maxidx],
        mcmc_res[ykey][mcmc_maxidx],
        "xm",
        label=labels["max2F"] + "(mcmc)",
    )
    ax.set_xlabel(labels[xkey])
    ax.set_ylabel(labels[ykey])
    ax.legend()
    plotfilename_base = os.path.join(
        outdir, "grid_vs_mcmc_{:s}{:s}".format(xkey, ykey)
    )
    fig.savefig(plotfilename_base + ".png")
    if xkey == "F0" and ykey == "F1":
        # Only the frequency/spindown plane gets an additional zoomed copy.
        limits = zoom if zoom_ranges is None else zoom_ranges
        ax.set_xlim(limits[xkey])
        ax.set_ylim(limits[ykey])
        fig.savefig(plotfilename_base + "_zoom.png")
    plt.close(fig)


def plot_2F_scatter(res, label, xkey, ykey):
    """Scatter-plot one search's results in the (xkey, ykey) plane, coloured by 2F.

    Local plotting function to avoid code duplication in the 4D case.
    Also marks the injection point and the maximum-2F point, limits the axes
    to the range covered by ``res``, and saves the figure as
    ``<label>_<xkey><ykey>_2F.png`` inside the module-level ``outdir``.

    Parameters
    ----------
    res:
        Search results indexable by column name; the columns ``xkey``,
        ``ykey`` and ``"twoF"`` are read.
    label:
        Name of the search; used as plot title and file name prefix.
        The value ``"grid"`` selects a larger marker size.
    xkey, ykey:
        Parameter names; both must be keys of the module-level ``inj`` and
        ``labels`` dicts.
    """
    markersize = 3 if label == "grid" else 1
    # Draw on a dedicated figure rather than pyplot's implicit current one,
    # so that a figure left open elsewhere cannot end up in this plot.
    fig, ax = plt.subplots()
    sc = ax.scatter(res[xkey], res[ykey], c=res["twoF"], s=markersize)
    cb = fig.colorbar(sc, ax=ax)
    ax.set_xlabel(labels[xkey])
    ax.set_ylabel(labels[ykey])
    cb.set_label(labels["2F"])
    ax.set_title(label)
    ax.plot(inj[xkey], inj[ykey], "*k", label="injection")
    # mark the loudest point of this search
    maxidx = np.argmax(res["twoF"])
    ax.plot(
        res[xkey][maxidx],
        res[ykey][maxidx],
        "+r",
        label=labels["max2F"],
    )
    ax.legend()
    plotfilename_base = os.path.join(
        outdir, "{:s}_{:s}{:s}_2F".format(label, xkey, ykey)
    )
    # clip the axes to the sampled range, overriding any autoscale margins
    ax.set_xlim([min(res[xkey]), max(res[xkey])])
    ax.set_ylim([min(res[ykey]), max(res[ykey])])
    fig.savefig(plotfilename_base + ".png")
    plt.close(fig)


if __name__ == "__main__":

    def _per_search_key(fmt, value_of):
        """Return ``fmt.format(value_of(key), key)`` for each search key, comma-joined.

        ``search_keys`` is looked up at call time, so this helper may only be
        called after the search setup below has defined it.
        """
        return ", ".join(fmt.format(value_of(key), key) for key in search_keys)

    # formats shared by all the per-parameter summary lines printed below
    offset_fmt = "{:.4e} in {:s}"
    percent_fmt = "{:.2f}% in {:s}"

    # Both searches are pointed at the same SFT file pattern,
    # which contains the label passed to the Writer.
    sftfilepattern = os.path.join(outdir, "*simulated_signal*sft")

    print("Generating SFTs with injected signal...")
    writer = pyfstat.Writer(
        label="simulated_signal",
        outdir=outdir,
        tstart=tstart,
        duration=duration,
        detectors=detectors,
        sqrtSX=sqrtSX,
        Tsft=Tsft,
        **inj,
        Band=1,  # default band estimation would be too narrow for a wide grid/prior
    )
    writer.make_data()
    print("")

    # set up square search grid with fixed (F0,F1) mismatch
    # and (optionally) some ad-hoc sky coverage
    m = 0.001
    dF0 = np.sqrt(12 * m) / (np.pi * duration)
    dF1 = np.sqrt(180 * m) / (np.pi * duration ** 2)
    DeltaF0 = 500 * dF0
    DeltaF1 = 200 * dF1
    if sky:
        # cover less range to keep runtime down
        DeltaF0 /= 10
        DeltaF1 /= 10
    # each grid axis is given as [lower, upper, spacing], centred on the injection
    F0s = [inj["F0"] - DeltaF0 / 2.0, inj["F0"] + DeltaF0 / 2.0, dF0]
    F1s = [inj["F1"] - DeltaF1 / 2.0, inj["F1"] + DeltaF1 / 2.0, dF1]
    F2s = [inj["F2"]]
    search_keys = ["F0", "F1"]  # only the ones that aren't 0-width
    if sky:
        dSky = 0.01  # rather coarse to keep runtime down
        DeltaSky = 10 * dSky
        Alphas = [inj["Alpha"] - DeltaSky / 2.0, inj["Alpha"] + DeltaSky / 2.0, dSky]
        Deltas = [inj["Delta"] - DeltaSky / 2.0, inj["Delta"] + DeltaSky / 2.0, dSky]
        search_keys += ["Alpha", "Delta"]
    else:
        Alphas = [inj["Alpha"]]
        Deltas = [inj["Delta"]]
    search_keys_label = "".join(search_keys)

    print("Performing GridSearch...")
    gridsearch = pyfstat.GridSearch(
        label="grid_search_" + search_keys_label,
        outdir=outdir,
        sftfilepattern=sftfilepattern,
        F0s=F0s,
        F1s=F1s,
        F2s=F2s,
        Alphas=Alphas,
        Deltas=Deltas,
        tref=inj["tref"],
    )
    gridsearch.run()
    gridsearch.print_max_twoF()

    # do some plots of the GridSearch results
    if not sky:  # this plotter can't currently deal with too large result arrays
        print("Plotting 1D 2F distributions...")
        for key in search_keys:
            gridsearch.plot_1D(xkey=key, xlabel=labels[key], ylabel=labels["2F"])

    print("Making GridSearch {:s} corner plot...".format("-".join(search_keys)))
    # axes of the corner plot are offsets from the injected values
    vals = [np.unique(gridsearch.data[key]) - inj[key] for key in search_keys]
    # one array dimension per searched parameter
    twoF = gridsearch.data["twoF"].reshape([len(kval) for kval in vals])
    corner_labels = [
        "$f - f_0$ [Hz]",
        "$\\dot{f} - \\dot{f}_0$ [Hz/s]",
    ]
    if sky:
        corner_labels.append("$\\alpha - \\alpha_0$")
        corner_labels.append("$\\delta - \\delta_0$")
    corner_labels.append(labels["2F"])
    gridcorner_fig, gridcorner_axes = pyfstat.gridcorner(
        twoF, vals, projection="log_mean", labels=corner_labels, whspace=0.1, factor=1.8
    )
    gridcorner_fig.savefig(os.path.join(outdir, gridsearch.label + "_corner.png"))
    plt.close(gridcorner_fig)
    print("")

    print("Performing MCMCSearch...")
    # set up priors in F0 and F1 (over)covering the grid ranges
    if sky:  # MCMC will still be fast in 4D with wider range than grid
        DeltaF0 *= 50
        DeltaF1 *= 50
    theta_prior = {
        "F0": {
            "type": "unif",
            "lower": inj["F0"] - DeltaF0 / 2.0,
            "upper": inj["F0"] + DeltaF0 / 2.0,
        },
        "F1": {
            "type": "unif",
            "lower": inj["F1"] - DeltaF1 / 2.0,
            "upper": inj["F1"] + DeltaF1 / 2.0,
        },
        "F2": inj["F2"],
    }
    if sky:
        # also implicitly covering twice the grid range here
        theta_prior["Alpha"] = {
            "type": "unif",
            "lower": inj["Alpha"] - DeltaSky,
            "upper": inj["Alpha"] + DeltaSky,
        }
        theta_prior["Delta"] = {
            "type": "unif",
            "lower": inj["Delta"] - DeltaSky,
            "upper": inj["Delta"] + DeltaSky,
        }
    else:
        theta_prior["Alpha"] = inj["Alpha"]
        theta_prior["Delta"] = inj["Delta"]
    # ptemcee sampler settings - in a real application we might want higher values
    ntemps = 2
    log10beta_min = -1
    nwalkers = 100
    nsteps = [200, 200]  # [burnin,production]

    mcmcsearch = pyfstat.MCMCSearch(
        label="mcmc_search_" + search_keys_label,
        outdir=outdir,
        sftfilepattern=sftfilepattern,
        theta_prior=theta_prior,
        tref=inj["tref"],
        nsteps=nsteps,
        nwalkers=nwalkers,
        ntemps=ntemps,
        log10beta_min=log10beta_min,
    )
    # walker plot is generated during main run of the search class
    mcmcsearch.run(
        walker_plot_args={"plot_det_stat": True, "injection_parameters": inj}
    )
    mcmcsearch.print_summary()

    # call some built-in plotting methods
    # these can all highlight the injection parameters, too
    print("Making MCMCSearch {:s} corner plot...".format("-".join(search_keys)))
    mcmcsearch.plot_corner(truths=inj)
    print("Making MCMCSearch prior-posterior comparison plot...")
    mcmcsearch.plot_prior_posterior(injection_parameters=inj)
    print("")

    # NOTE: everything below here is just custom commandline output and plotting
    # for this particular example, which uses the PyFstat outputs,
    # but isn't very instructive if you just want to learn the main usage of the package.

    # some informative command-line output comparing search results and injection
    # get max of GridSearch, contains twoF and all Doppler parameters in the dict
    max_dict_grid = gridsearch.get_max_twoF()
    # same for MCMCSearch, here twoF is separate, and non-sampled parameters are not included either
    max_dict_mcmc, max_2F_mcmc = mcmcsearch.get_max_twoF()
    print(
        "max2F={:.4f} from GridSearch, offsets from injection: {:s}.".format(
            max_dict_grid["twoF"],
            _per_search_key(offset_fmt, lambda key: max_dict_grid[key] - inj[key]),
        )
    )
    print(
        "max2F={:.4f} from MCMCSearch, offsets from injection: {:s}.".format(
            max_2F_mcmc,
            _per_search_key(offset_fmt, lambda key: max_dict_mcmc[key] - inj[key]),
        )
    )
    # get additional point and interval estimators
    stats_dict_mcmc = mcmcsearch.get_summary_stats()
    print(
        "mean from MCMCSearch: offset from injection by {:s},"
        " or in fractions of 2sigma intervals: {:s}.".format(
            _per_search_key(
                offset_fmt,
                lambda key: stats_dict_mcmc[key]["mean"] - inj[key],
            ),
            _per_search_key(
                percent_fmt,
                lambda key: 100
                * np.abs(stats_dict_mcmc[key]["mean"] - inj[key])
                / (2 * stats_dict_mcmc[key]["std"]),
            ),
        )
    )
    print(
        "median from MCMCSearch: offset from injection by {:s},"
        " or in fractions of 90% confidence intervals: {:s}.".format(
            _per_search_key(
                offset_fmt,
                lambda key: stats_dict_mcmc[key]["median"] - inj[key],
            ),
            _per_search_key(
                percent_fmt,
                lambda key: 100
                * np.abs(stats_dict_mcmc[key]["median"] - inj[key])
                / (
                    stats_dict_mcmc[key]["upper90"]
                    - stats_dict_mcmc[key]["lower90"]
                ),
            ),
        )
    )
    print()

    # do additional custom plotting
    print("Loading grid and MCMC search results for custom comparison plots...")
    gridfile = os.path.join(outdir, gridsearch.label + "_NA_GridSearch.txt")
    if not os.path.isfile(gridfile):
        raise RuntimeError(
            "Failed to load GridSearch results from file '{:s}',"
            " something must have gone wrong!".format(gridfile)
        )
    grid_res = pyfstat.helper_functions.read_txt_file_with_header(gridfile)
    mcmc_file = os.path.join(outdir, mcmcsearch.label + "_samples.dat")
    if not os.path.isfile(mcmc_file):
        raise RuntimeError(
            "Failed to load MCMCSearch results from file '{:s}',"
            " something must have gone wrong!".format(mcmc_file)
        )
    mcmc_res = pyfstat.helper_functions.read_txt_file_with_header(mcmc_file)

    # axis limits for the zoomed (F0,F1) comparison plot;
    # read as a module-level name by plot_grid_vs_samples()
    zoom = {
        "F0": [inj["F0"] - 10 * dF0, inj["F0"] + 10 * dF0],
        "F1": [inj["F1"] - 5 * dF1, inj["F1"] + 5 * dF1],
    }

    # we'll use the two local plotting functions defined above
    # to avoid code duplication in the sky case
    print("Creating MCMC-grid comparison plots...")
    plot_grid_vs_samples(grid_res, mcmc_res, "F0", "F1")
    plot_2F_scatter(grid_res, "grid", "F0", "F1")
    plot_2F_scatter(mcmc_res, "mcmc", "F0", "F1")
    if sky:
        plot_grid_vs_samples(grid_res, mcmc_res, "Alpha", "Delta")
        plot_2F_scatter(grid_res, "grid", "Alpha", "Delta")
        plot_2F_scatter(mcmc_res, "mcmc", "Alpha", "Delta")
Total running time of the script: ( 0 minutes 0.000 seconds)