|
| 1 | +import argparse |
| 2 | +import math |
| 3 | +import numpy as np |
| 4 | +import scipy as sp |
| 5 | +import scipy.linalg |
| 6 | +import scipy.sparse |
| 7 | +import scipy.sparse.linalg |
| 8 | +import sys |
| 9 | +import tqdm |
| 10 | +import time |
| 11 | +from schwinger.schwinger import * |
| 12 | + |
| 13 | +def hmc_update(cfg, action, tau, n_leap, verbose=True): |
| 14 | + old_cfg = np.copy(cfg) |
| 15 | + old_S = action.init_traj(old_cfg) |
| 16 | + old_pi = sample_pi(cfg.shape) |
| 17 | + old_K = np.sum(old_pi*old_pi) / 2 |
| 18 | + old_H = old_S + old_K |
| 19 | + |
| 20 | + cfg = np.copy(cfg) |
| 21 | + new_pi = np.copy(old_pi) |
| 22 | + leapfrog_update(cfg, new_pi, action, tau, n_leap, verbose=verbose) |
| 23 | + |
| 24 | + new_S = action.compute_action(cfg) |
| 25 | + new_K = np.sum(new_pi*new_pi) / 2 |
| 26 | + new_H = new_S + new_K |
| 27 | + |
| 28 | + delta_H = new_H - old_H |
| 29 | + if verbose: |
| 30 | + print("Delta H = {:.5g} - {:.5g} = {:.5g}".format(new_H, old_H, delta_H)) |
| 31 | + print("Delta S = {:.5g} - {:.5g} = {:.5g}".format(new_S, old_S, new_S - old_S)) |
| 32 | + print("Delta K = {:.5g} - {:.5g} = {:.5g}".format(new_K, old_K, new_K - old_K)) |
| 33 | + |
| 34 | + # metropolis step |
| 35 | + acc = 0 |
| 36 | + if np.random.random() < np.exp(-delta_H): |
| 37 | + acc = 1 |
| 38 | + S = new_S |
| 39 | + else: |
| 40 | + cfg = old_cfg |
| 41 | + S = old_S |
| 42 | + if verbose: |
| 43 | + print("Acc {:.5g} (changed {})".format(min(1.0, np.exp(-delta_H)), acc)) |
| 44 | + |
| 45 | + return cfg, S, acc |
| 46 | + |
| 47 | +def run_hmc(L, n_step, n_skip, n_therm, tau, n_leap, action, cfg, *, topo_hop_freq=0): |
| 48 | + Nd = len(L) |
| 49 | + V = np.prod(L) |
| 50 | + shape = tuple([Nd] + list(L)) |
| 51 | + pure_gauge_action = PureGaugeAction(beta=action.beta) |
| 52 | + |
| 53 | + # MC updates |
| 54 | + total_acc = 0 |
| 55 | + hop_acc = 0 |
| 56 | + hop_props = 0 |
| 57 | + cfgs = [] |
| 58 | + plaqs = [] |
| 59 | + topos = [] |
| 60 | + with tqdm.tqdm(total = n_therm + n_step, postfix='Acc: ???, Q: ???') as t: |
| 61 | + for i in tqdm.tqdm(range(-n_therm, n_step)): |
| 62 | + print("MC step {} / {}".format(i+1, n_step)) |
| 63 | + cfg, S, acc = hmc_update(cfg, action, tau, n_leap) |
| 64 | + if i >= 0: total_acc += acc |
| 65 | + if topo_hop_freq > 0 and i % topo_hop_freq == 0: |
| 66 | + assert Nd == 2 |
| 67 | + assert L[0] == L[1] |
| 68 | + Nf = 2 |
| 69 | + hop_props += 1 |
| 70 | + dQ = np.random.randint(-1, 2) |
| 71 | + print('Topo hop proposal dQ', dQ) |
| 72 | + dU = make_topo_cfg(L[0], dQ) |
| 73 | + new_cfg = cfg * dU # NOTE: This only works because Abelian |
| 74 | + # new_S = action.compute_action(new_cfg) |
| 75 | + old_S_marginal = pure_gauge_action.compute_action(cfg) |
| 76 | + new_S_marginal = pure_gauge_action.compute_action(new_cfg) |
| 77 | + _, old_logdetD = np.linalg.slogdet(dirac_op(cfg, kappa=action.kappa).toarray()) |
| 78 | + _, new_logdetD = np.linalg.slogdet(dirac_op(new_cfg, kappa=action.kappa).toarray()) |
| 79 | + old_S_marginal -= Nf * old_logdetD |
| 80 | + new_S_marginal -= Nf * new_logdetD |
| 81 | + print('delta S =', new_S_marginal - old_S_marginal) |
| 82 | + print('delta (-log det D) =', Nf*(old_logdetD - new_logdetD)) |
| 83 | + print('old logdetD =', np.real(old_logdetD), 'new logdetD =', np.real(new_logdetD)) |
| 84 | + if np.random.random() < np.exp(-new_S_marginal + old_S_marginal): |
| 85 | + print('accepted') |
| 86 | + cfg = new_cfg |
| 87 | + hop_acc += 1 |
| 88 | + else: |
| 89 | + print('rejected') |
| 90 | + |
| 91 | + # avg plaq |
| 92 | + plaq = np.sum(np.real(ensemble_plaqs(cfg))) / V |
| 93 | + print("Average plaq = {:.6g}".format(plaq)) |
| 94 | + # topo Q |
| 95 | + topo = np.sum(compute_topo(cfg)) |
| 96 | + Q = int(round(topo)) |
| 97 | + print("Topo = {:d}".format(Q)) |
| 98 | + |
| 99 | + # save cfg |
| 100 | + if i >= 0 and i % n_skip == 0: |
| 101 | + print("Saving cfg!") |
| 102 | + cfgs.append(cfg) |
| 103 | + plaqs.append(plaq) |
| 104 | + topos.append(topo) |
| 105 | + t.postfix = 'Acc: {:.3f}, Q: {:d}'.format(total_acc / (i+1), Q) |
| 106 | + t.update() |
| 107 | + |
| 108 | + print("MC finished.") |
| 109 | + print("Total acc {:.4f}".format(total_acc / n_step)) |
| 110 | + if topo_hop_freq > 0: |
| 111 | + print("Total hop acc {:.4f}".format(hop_acc / hop_props)) |
| 112 | + return cfgs, plaqs, topos |
| 113 | + |
| 114 | + |
| 115 | +if __name__ == "__main__": |
| 116 | + parser = argparse.ArgumentParser(description='Run HMC for Schwinger') |
| 117 | + # general params |
| 118 | + parser.add_argument('--seed', type=int) |
| 119 | + parser.add_argument('--Lx', type=int, required=True) |
| 120 | + parser.add_argument('--Lt', type=int, required=True) |
| 121 | + parser.add_argument('--tag', type=str, default="") |
| 122 | + parser.add_argument('--Ncfg', type=int, required=True) |
| 123 | + parser.add_argument('--n_skip', type=int, required=True) |
| 124 | + parser.add_argument('--n_therm', type=int, required=True) |
| 125 | + parser.add_argument('--tau', type=float, default=0.5) |
| 126 | + parser.add_argument('--n_leap', type=int, default=20) |
| 127 | + parser.add_argument('--init_cfg', type=str) |
| 128 | + parser.add_argument('--topo_hop_freq', type=int, default=0) |
| 129 | + # action params |
| 130 | + parser.add_argument('--type', type=str, required=True) |
| 131 | + parser.add_argument('--beta', type=float) |
| 132 | + parser.add_argument('--kappa', type=float) |
| 133 | + parser.add_argument('--reweight_dt', type=int) |
| 134 | + parser.add_argument('--conn_weight', type=float, default=1.0) |
| 135 | + parser.add_argument('--disc_weight', type=float, default=0.0) |
| 136 | + parser.add_argument('--xspace', type=int, default=4) |
| 137 | + parser.add_argument('--eps_reg', type=float) |
| 138 | + parser.add_argument('--polya_x', type=int) |
| 139 | + parser.add_argument('--theta_i', type=float) |
| 140 | + parser.add_argument('--delta', type=float) |
| 141 | + parser.add_argument('--compute_dirac', type=str, default="", |
| 142 | + help="Which Dirac op to compute on cfgs, if any") |
| 143 | + args = parser.parse_args() |
| 144 | + print("args = {}".format(args)) |
| 145 | + |
| 146 | + start = time.time() |
| 147 | + |
| 148 | + # handle params |
| 149 | + if len(args.tag) > 0: |
| 150 | + args.tag = "_" + args.tag |
| 151 | + if args.seed is None: |
| 152 | + args.seed = np.random.randint(np.iinfo('uint32').max) |
| 153 | + print("Generated random seed = {}".format(args.seed)) |
| 154 | + np.random.seed(args.seed) |
| 155 | + print("Using seed = {}.".format(args.seed)) |
| 156 | + L = [args.Lx, args.Lt] |
| 157 | + Nd = len(L) |
| 158 | + shape = tuple([Nd] + list(L)) |
| 159 | + if args.init_cfg is None: |
| 160 | + print('Generating warm init cfg.') |
| 161 | + init_cfg_A = 0.4*np.random.normal(size=shape) |
| 162 | + cfg = np.exp(1j * init_cfg_A) |
| 163 | + elif args.init_cfg == 'hot': |
| 164 | + print('Generating hot init cfg.') |
| 165 | + init_cfg_A = 2*np.pi*np.random.random(size=shape) |
| 166 | + cfg = np.exp(1j * init_cfg_A) |
| 167 | + else: |
| 168 | + print('Loading init cfg from {}.'.format(args.init_cfg)) |
| 169 | + cfg = np.load(args.init_cfg) |
| 170 | + cfg = cfg.reshape(shape) |
| 171 | + tot_steps = args.Ncfg * args.n_skip |
| 172 | + if args.type == "pure_gauge": |
| 173 | + assert(args.beta is not None) |
| 174 | + action = PureGaugeAction(args.beta) |
| 175 | + elif args.type == "two_flavor": |
| 176 | + assert(args.beta is not None) |
| 177 | + assert(args.kappa is not None) |
| 178 | + action = TwoFlavorAction(args.beta, args.kappa) |
| 179 | + elif args.type == "exact_1flav_staggered": |
| 180 | + assert(args.beta is not None) |
| 181 | + assert(args.kappa is not None) |
| 182 | + m0 = m0_from_kappa(args.kappa, Nd) |
| 183 | + action = ExactStaggeredAction(args.beta, m0, Nf=1) |
| 184 | + elif args.type == "exact_2flav_staggered": |
| 185 | + assert(args.beta is not None) |
| 186 | + assert(args.kappa is not None) |
| 187 | + m0 = m0_from_kappa(args.kappa, Nd) |
| 188 | + action = ExactStaggeredAction(args.beta, m0, Nf=2) |
| 189 | + elif args.type == "exact_2flav_wilson": |
| 190 | + assert(args.beta is not None) |
| 191 | + assert(args.kappa is not None) |
| 192 | + action = ExactWilsonAction(args.beta, args.kappa, Nf=2) |
| 193 | + else: |
| 194 | + print("Unknown action type {}".format(args.type)) |
| 195 | + sys.exit(1) |
| 196 | + |
| 197 | + # do the thing! |
| 198 | + cfgs, plaqs, topos = run_hmc(L, tot_steps, args.n_skip, args.n_therm, |
| 199 | + args.tau, args.n_leap, action, cfg, |
| 200 | + topo_hop_freq=args.topo_hop_freq) |
| 201 | + |
| 202 | + # write stuff out |
| 203 | + prefix = 'u1_{:s}_N{:d}_skip{:d}_therm{:d}_{:d}_{:d}{:s}'.format( |
| 204 | + action.make_tag(), args.Ncfg, args.n_skip, args.n_therm, |
| 205 | + args.Lx, args.Lt, args.tag) |
| 206 | + fname = prefix + '.npy' |
| 207 | + np.save(fname, cfgs) |
| 208 | + print("Wrote ensemble to {}".format(fname)) |
| 209 | + fname = prefix + '.plaq.npy' |
| 210 | + np.save(fname, plaqs) |
| 211 | + print("Wrote plaqs to {}".format(fname)) |
| 212 | + fname = prefix + '.topo.npy' |
| 213 | + np.save(fname, topos) |
| 214 | + print("Wrote topos to {}".format(fname)) |
| 215 | + |
| 216 | + # Compute meson correlation functions |
| 217 | + if args.compute_dirac == "": |
| 218 | + print("Skipping propagator calcs") |
| 219 | + else: |
| 220 | + if args.compute_dirac == "wilson": |
| 221 | + assert args.kappa is not None, "kappa required" |
| 222 | + make_D = lambda cfg: dirac_op(cfg, kappa=args.kappa).toarray() |
| 223 | + Cts = [] |
| 224 | + for cfg in cfgs: |
| 225 | + D = make_D(cfg) |
| 226 | + Dinv = np.linalg.inv(D) |
| 227 | + Dinv = np.reshape(Dinv, L + [NS] + L + [NS]) |
| 228 | + Lx, Lt = L |
| 229 | + C_src_avg = 0 |
| 230 | + for x in range(Lx): |
| 231 | + for t in range(Lt): |
| 232 | + prop = np.roll(Dinv[:,:,:,x,t,:], (-x,-t), axis=(0,1)) |
| 233 | + prop_dag = np.einsum( |
| 234 | + 'ab,xybc,cd->xyad', pauli(3), |
| 235 | + np.swapaxes(np.conj(prop), axis1=-1, axis2=-2), pauli(3)) |
| 236 | + meson_corr = lambda p1, p2: np.einsum( |
| 237 | + 'xyab,bc,xycd,da -> xy', prop, p1, prop_dag, p2) |
| 238 | + assert np.allclose( |
| 239 | + meson_corr(pauli(3), pauli(3)), |
| 240 | + np.sum(np.abs(prop**2), axis=(2,3))) |
| 241 | + pion_corr = meson_corr(pauli(3), pauli(3)) |
| 242 | + C_src_avg = C_src_avg + pion_corr / (Lx*Lt) |
| 243 | + Ct = np.mean(C_src_avg, axis=0) |
| 244 | + Cts.append(Ct) |
| 245 | + elif args.compute_dirac == "staggered": |
| 246 | + assert args.kappa is not None, "kappa required" |
| 247 | + m0 = m0_from_kappa(args.kappa, Nd) |
| 248 | + make_D = lambda cfg: make_op_matrix(L, lambda psi: apply_staggered_D(psi, U=cfg, m0=m0)) |
| 249 | + raise NotImplementedError('staggered not implemented') |
| 250 | + else: |
| 251 | + raise RuntimeError(f"Dirac op type {args.compute_dirac} not supported") |
| 252 | + |
| 253 | + fname = prefix + '.meson_Ct.npy' |
| 254 | + np.save(fname, np.array(Cts)) |
| 255 | + print("Wrote Cts to {}".format(fname)) |
| 256 | + print("TIME ensemble gen {:.2f}s".format(time.time()-start)) |
0 commit comments