#!/usr/bin/env python
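"""
Match halos between two HBTplus catalogues by comparing the IDs of their
most bound particles. Run under MPI, e.g. (invocation sketch only: the
script file name and launcher flags here are assumptions):

    mpirun python3 match_hbt_halos.py BASENAME1 BASENAME2 NR_PARTICLES OUTPUT_FILE

For each halo in one catalogue, the halo in the other catalogue containing
the largest number of its NR_PARTICLES most bound particles is taken as the
match; ties are broken towards the lowest halo index.
"""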

import os

import numpy as np
import h5py

import virgo.mpi.parallel_sort as psort
import virgo.mpi.parallel_hdf5 as phdf5

import lustre
import read_hbtplus

import unyt
import swift_cells

from mpi4py import MPI

comm = MPI.COMM_WORLD
comm_rank = comm.Get_rank()
comm_size = comm.Get_size()

# Maximum number of particle types
NTYPEMAX = 7


def message(s):
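    """Print a message on rank 0 only, to avoid duplicated output"""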
    if comm_rank == 0:
        print(s)


def exchange_array(arr, dest, comm):
    """
    Carry out an alltoallv on the supplied array, given the MPI rank
    to send each element to.
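
    arr  -- array of local elements to redistribute
    dest -- destination MPI rank for each element of arr
    comm -- communicator used for the exchange

    Returns the array of elements received by this rank. For example, on two
    ranks, arr=[10, 11, 12] with dest=[1, 0, 1] sends element 11 to rank 0
    and elements 10 and 12 to rank 1.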
    """
    order = np.argsort(dest)
    sendbuf = arr[order]
    send_count = np.bincount(dest, minlength=comm_size)
    send_offset = np.cumsum(send_count) - send_count
    recv_count = np.zeros_like(send_count)
    comm.Alltoall(send_count, recv_count)
    recv_offset = np.cumsum(recv_count) - recv_count
    recvbuf = np.ndarray(recv_count.sum(), dtype=arr.dtype)
    psort.my_alltoallv(
        sendbuf, send_count, send_offset, recvbuf, recv_count, recv_offset, comm=comm
    )
    return recvbuf


# Define a placeholder unit system, since everything we do is dimensionless.
# However, it would be better to load this from a snapshot.
def define_unit_system():

    # Create a registry using this base unit system
    reg = unyt.unit_registry.UnitRegistry()

    # Add some units which might be useful for dealing with input halo catalogues
    unyt.define_unit("swift_mpc", 1.0 * unyt.cm, registry=reg)
    unyt.define_unit("swift_msun", 1.0 * unyt.g, registry=reg)
    unyt.define_unit("h", 1.0 * unyt.Hz, registry=reg)

    return reg


def find_matching_halos(
    base_name1,
    base_name2,
    max_nr_particles,
    min_particle_id,
    max_particle_id,
    field_only,
):
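    """
    For each halo in the first catalogue, find the halo in the second
    catalogue which contains the largest number of the first halo's
    max_nr_particles most bound particles.

    base_name1       -- base name of the first set of HBTplus files
    base_name2       -- base name of the second set of HBTplus files
    max_nr_particles -- number of most bound particles to use per halo
    min_particle_id  -- only use particles with ID >= this (may be None)
    max_particle_id  -- only use particles with ID < this (may be None)
    field_only       -- only match to field halos

    Returns (result_grnr_in_cat2, result_count): for each halo in the first
    catalogue, the index of the matched halo in the second catalogue (-1 for
    no match) and the number of matching particles.
    """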

    # We only care about dimensionless quantities here, so
    # define a placeholder unit system
    registry = define_unit_system()
    a_unit = unyt.Unit("cm", registry=registry) ** 0
    boxsize = None

    # Load halo data from the second catalogue
    keep_orphans = True
    halo_data2 = read_hbtplus.read_hbtplus_catalogue(
        comm, base_name2, a_unit, registry, boxsize, keep_orphans
    )

    # Catalogue 2 properties
    track_ids2 = halo_data2["TrackId"]
    host_ids2 = halo_data2["HostHaloId"]
    is_central2 = halo_data2["is_central"]

    # Find the track ID of the central halo that has the same HostHaloId
    central_index2 = psort.parallel_match(
        host_ids2, host_ids2[is_central2 == 1], comm=comm
    )
    host_track_ids2 = psort.fetch_elements(
        track_ids2[is_central2 == 1], central_index2, comm=comm
    )

    # Find the index of that central halo (which may differ from the track ID)
    host_index2 = psort.parallel_match(host_track_ids2, track_ids2, comm=comm)

    # Free the other halo data
    del halo_data2

    # Find group membership for particles in the first catalogue
    total_nr_halos1, cat1_ids, cat1_grnr_in_cat1, rank_bound1 = read_hbtplus.read_hbtplus_groupnr(
        base_name1
    )

    # Find group membership for particles in the second catalogue
    total_nr_halos2, cat2_ids, cat2_grnr_in_cat2, rank_bound2 = read_hbtplus.read_hbtplus_groupnr(
        base_name2
    )

    # Decide the range of halos in cat1 which we'll store on each rank:
    # this is used to partition the result between MPI ranks.
    nr_cat1_tot = total_nr_halos1
    nr_cat1_per_rank = nr_cat1_tot // comm_size
    if comm_rank < comm_size - 1:
        nr_cat1_local = nr_cat1_per_rank
    else:
        nr_cat1_local = nr_cat1_tot - (comm_size - 1) * nr_cat1_per_rank
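
    # For example, 10 halos on 3 ranks gives nr_cat1_per_rank=3: ranks 0 and 1
    # store halos [0,3) and [3,6), and the last rank stores the remainder, [6,10).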

    # Clear group membership for particles with invalid IDs
    if min_particle_id is not None:
        discard = cat1_ids < min_particle_id
        cat1_grnr_in_cat1[discard] = -1

    if max_particle_id is not None:
        discard = cat1_ids >= max_particle_id
        cat1_grnr_in_cat1[discard] = -1

    # If we're only matching to field halos, then any particles in the second catalogue which
    # belong to a halo with HostHaloId != -1 need to have their group membership reset to their
    # host halo.
    if field_only:
        # Find particles in halos in cat2
        in_halo = cat2_grnr_in_cat2 >= 0
        # Fetch host halo array index for each particle in cat2, or -1 if not in a halo
        particle_host_index = -np.ones_like(cat2_grnr_in_cat2)
        particle_host_index[in_halo] = psort.fetch_elements(
            host_index2, cat2_grnr_in_cat2[in_halo], comm=comm
        )
        # Where a particle's halo has a host halo, set its group membership to be the host halo
        have_host = particle_host_index >= 0
        cat2_grnr_in_cat2[have_host] = particle_host_index[have_host]

    # Discard particles which are in no halo from each catalogue
    in_group = cat1_grnr_in_cat1 >= 0
    cat1_ids = cat1_ids[in_group]
    cat1_grnr_in_cat1 = cat1_grnr_in_cat1[in_group]
    in_group = cat2_grnr_in_cat2 >= 0
    cat2_ids = cat2_ids[in_group]
    cat2_grnr_in_cat2 = cat2_grnr_in_cat2[in_group]

    # Now we need to identify the first max_nr_particles remaining particles for each
    # halo in catalogue 1. First, find the ranking of each particle within the part of
    # its group which is stored on this MPI rank. First particle in a group has rank 0.
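    # For example, local group numbers [5, 5, 5, 7, 7] give ranks [0, 1, 2, 0, 1].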
    unique_grnr, unique_index, unique_count = np.unique(
        cat1_grnr_in_cat1, return_index=True, return_counts=True
    )
    cat1_rank_in_group = -np.ones_like(cat1_grnr_in_cat1)
    for ui, uc in zip(unique_index, unique_count):
        cat1_rank_in_group[ui : ui + uc] = np.arange(uc, dtype=int)
    assert np.all(cat1_rank_in_group >= 0)

    # Then for the first group on each rank we'll need to add the total number of particles in
    # the same group on all lower numbered ranks. Since the particles are sorted by group this
    # can only ever be the last group on each lower numbered rank.
    if len(unique_grnr) > 0:
        # This rank has at least one particle in a group. Store indexes of the first and last
        # groups and the number of particles from the last group which are stored on this rank.
        assert unique_index[0] == 0
        first_grnr = unique_grnr[0]
        last_grnr = unique_grnr[-1]
        last_grnr_count = unique_count[-1]
    else:
        # This rank has no particles in groups
        first_grnr = -1
        last_grnr = -1
        last_grnr_count = 0
    all_last_grnr = comm.allgather(last_grnr)
    all_last_grnr_count = comm.allgather(last_grnr_count)
    # Loop over lower numbered ranks
    for rank_nr in range(comm_rank):
        if first_grnr >= 0 and all_last_grnr[rank_nr] == first_grnr:
            cat1_rank_in_group[: unique_count[0]] += all_last_grnr_count[rank_nr]

    # Only keep the first max_nr_particles remaining particles in each group in catalogue 1
    keep = cat1_rank_in_group < max_nr_particles
    cat1_ids = cat1_ids[keep]
    cat1_grnr_in_cat1 = cat1_grnr_in_cat1[keep]

    # For each particle ID in catalogue 1, try to find the same particle ID in catalogue 2
    ptr = psort.parallel_match(cat1_ids, cat2_ids, comm=comm)
    matched = ptr >= 0

    # For each particle ID in catalogue 1, fetch the group membership of the matching ID in catalogue 2
    cat1_grnr_in_cat2 = -np.ones_like(cat1_grnr_in_cat1)
    cat1_grnr_in_cat2[matched] = psort.fetch_elements(
        cat2_grnr_in_cat2, ptr[matched], comm=comm
    )

    # Discard unmatched particles
    cat1_grnr_in_cat1 = cat1_grnr_in_cat1[matched]
    cat1_grnr_in_cat2 = cat1_grnr_in_cat2[matched]

    # Get sorted, unique (grnr1, grnr2) combinations and counts of how many instances of each we have
    assert np.all(cat1_grnr_in_cat1 < 2 ** 32)
    assert np.all(cat1_grnr_in_cat1 >= 0)
    assert np.all(cat1_grnr_in_cat2 < 2 ** 32)
    assert np.all(cat1_grnr_in_cat2 >= 0)
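    # Pack each (grnr1, grnr2) pair into a single 64-bit key, with grnr1 in the
    # high 32 bits and grnr2 in the low 32 bits: e.g. (3, 5) -> (3 << 32) + 5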
    sort_key = (cat1_grnr_in_cat1.astype(np.uint64) << 32) + cat1_grnr_in_cat2.astype(
        np.uint64
    )
    unique_value, cat1_count = psort.parallel_unique(
        sort_key, comm=comm, return_counts=True, repartition_output=True
    )
    # Cast to int because mixing signed and unsigned causes numpy to cast to float!
    cat1_grnr_in_cat1 = (unique_value >> 32).astype(int)
    cat1_grnr_in_cat2 = (unique_value % (1 << 32)).astype(int)

    # Send each (grnr1, grnr2, count) combination to the rank which will store the result for that halo
    if nr_cat1_per_rank > 0:
        dest = (cat1_grnr_in_cat1 // nr_cat1_per_rank).astype(int)
        dest[dest > comm_size - 1] = comm_size - 1
    else:
        dest = np.empty_like(cat1_grnr_in_cat1, dtype=int)
        dest[:] = comm_size - 1
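    # For example, with nr_cat1_per_rank=3 a combination with grnr1=7 goes to
    # rank 7 // 3 = 2; halos past comm_size * nr_cat1_per_rank go to the last rank.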
    recv_grnr_in_cat1 = exchange_array(cat1_grnr_in_cat1, dest, comm)
    recv_grnr_in_cat2 = exchange_array(cat1_grnr_in_cat2, dest, comm)
    recv_count = exchange_array(cat1_count, dest, comm)

    # Allocate output arrays:
    # each rank has nr_cat1_per_rank halos with any extras on the last rank
    first_in_cat1 = comm_rank * nr_cat1_per_rank
    # For each halo in cat1, will store index of match in cat2
    result_grnr_in_cat2 = -np.ones(nr_cat1_local, dtype=int)
    # Will store number of matching particles
    result_count = np.zeros(nr_cat1_local, dtype=int)

    # Update output arrays using the received data
    for recv_nr in range(len(recv_grnr_in_cat1)):
        # Compute local array index of halo to update
        local_halo_nr = recv_grnr_in_cat1[recv_nr] - first_in_cat1
        assert local_halo_nr >= 0
        assert local_halo_nr < nr_cat1_local
        # Check if the received count is higher than the highest so far
        if recv_count[recv_nr] > result_count[local_halo_nr]:
            # This received combination has the highest count so far
            result_grnr_in_cat2[local_halo_nr] = recv_grnr_in_cat2[recv_nr]
            result_count[local_halo_nr] = recv_count[recv_nr]
        elif recv_count[recv_nr] == result_count[local_halo_nr]:
            # In the event of a tie, go for the lowest group number for reproducibility
            if recv_grnr_in_cat2[recv_nr] < result_grnr_in_cat2[local_halo_nr]:
                result_grnr_in_cat2[local_halo_nr] = recv_grnr_in_cat2[recv_nr]
                result_count[local_halo_nr] = recv_count[recv_nr]

    return result_grnr_in_cat2, result_count


def consistent_match(match_index_12, match_index_21):
    """
    For each halo in catalogue 1, determine if its match in catalogue 2
    points back at it.

    match_index_12 has one entry for each halo in catalogue 1 and
    specifies the matching halo in catalogue 2 (or -1 for no match)

    match_index_21 has one entry for each halo in catalogue 2 and
    specifies the matching halo in catalogue 1 (or -1 for no match)

    Returns an array with 1 for a consistent match and 0 otherwise.
    """

    # Find the global array indexes of halos stored on this rank
    nr_local_halos = len(match_index_12)
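    # comm.scan returns an inclusive prefix sum, so subtracting the local count
    # gives this rank's global offset, e.g. counts [3, 2, 4] -> offsets [0, 3, 5]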
    local_halo_offset = comm.scan(nr_local_halos) - nr_local_halos
    local_halo_index = np.arange(
        local_halo_offset, local_halo_offset + nr_local_halos, dtype=int
    )

    # For each halo, find the halo that its match in the other catalogue was matched with
    match_back = -np.ones(nr_local_halos, dtype=int)
    has_match = match_index_12 >= 0
    match_back[has_match] = psort.fetch_elements(
        match_index_21, match_index_12[has_match], comm=comm
    )

    # If we retrieved our own halo index, we have a match
    return np.where(match_back == local_halo_index, 1, 0)


def get_match_hbt_halos_args(comm):
    """
    Process command line arguments for the halo matching program.

    Returns the parsed arguments, with values accessible as attributes.
    """

    from virgo.mpi.util import MPIArgumentParser

    parser = MPIArgumentParser(
        comm, description="Find matching halos between snapshots"
    )
    parser.add_argument("hbt_basename1", help="Base name of the first set of HBT files")
    parser.add_argument(
        "hbt_basename2", help="Base name of the second set of HBT files"
    )
    parser.add_argument(
        "nr_particles",
        metavar="N",
        type=int,
        help="Number of most bound particles to use.",
    )
    parser.add_argument("output_file", help="Output file name")
    parser.add_argument(
        "--min-particle-id",
        nargs="*",
        type=int,
        help="Only use particles with ID >= this",
    )
    parser.add_argument(
        "--max-particle-id",
        nargs="*",
        type=int,
        help="Only use particles with ID < this",
    )
    parser.add_argument(
        "--to-field-halos-only", action="store_true", help="Only match to field halos"
    )
    args = parser.parse_args()

    return args


if __name__ == "__main__":

    # Read command line parameters
    args = get_match_hbt_halos_args(comm)

    # Ensure the output directory exists
    if comm_rank == 0:
        lustre.ensure_output_dir(args.output_file)
    comm.barrier()

    # For each halo in output 1, find the matching halo in output 2
    message("Matching from first catalogue to second")
    match_index_12, count_12 = find_matching_halos(
        args.hbt_basename1,
        args.hbt_basename2,
        args.nr_particles,
        args.min_particle_id,
        args.max_particle_id,
        args.to_field_halos_only,
    )
    total_nr_halos = comm.allreduce(len(match_index_12))
    total_nr_matched = comm.allreduce(np.sum(match_index_12 >= 0))
    message(f"  Matched {total_nr_matched} of {total_nr_halos} halos")

    # For each halo in output 2, find the matching halo in output 1
    message("Matching from second catalogue to first")
    match_index_21, count_21 = find_matching_halos(
        args.hbt_basename2,
        args.hbt_basename1,
        args.nr_particles,
        args.min_particle_id,
        args.max_particle_id,
        args.to_field_halos_only,
    )
    total_nr_halos = comm.allreduce(len(match_index_21))
    total_nr_matched = comm.allreduce(np.sum(match_index_21 >= 0))
    message(f"  Matched {total_nr_matched} of {total_nr_halos} halos")

    # Check for consistent matches in both directions
    message("Checking for consistent matches")
    consistent_12 = consistent_match(match_index_12, match_index_21)
    consistent_21 = consistent_match(match_index_21, match_index_12)

    # Write the output
    def write_output_field(name, data, description):
        dataset = phdf5.collective_write(outfile, name, data, comm)
        dataset.attrs["Description"] = description

    message("Writing output")
    with h5py.File(args.output_file, "w", driver="mpio", comm=comm) as outfile:
        # Write input parameters
        params = outfile.create_group("Parameters")
        for name, value in vars(args).items():
            if value is not None:
                params.attrs[name] = value
        # Matching from first catalogue to second
        write_output_field(
            "MatchIndex1to2",
            match_index_12,
            "For each halo in the first catalogue, index of the matching halo in the second",
        )
        write_output_field(
            "MatchCount1to2",
            count_12,
            f"How many of the {args.nr_particles} most bound particles from the halo in the first catalogue are in the matched halo in the second",
        )
        write_output_field(
            "Consistent1to2",
            consistent_12,
            "Whether the match from first to second catalogue is consistent with second to first (1) or not (0)",
        )
        # Matching from second catalogue to first
        write_output_field(
            "MatchIndex2to1",
            match_index_21,
            "For each halo in the second catalogue, index of the matching halo in the first",
        )
        write_output_field(
            "MatchCount2to1",
            count_21,
            f"How many of the {args.nr_particles} most bound particles from the halo in the second catalogue are in the matched halo in the first",
        )
        write_output_field(
            "Consistent2to1",
            consistent_21,
            "Whether the match from second to first catalogue is consistent with first to second (1) or not (0)",
        )
    comm.barrier()
    message("Done.")