Skip to content

Commit 2271f97

Browse files
authoredDec 19, 2024
Add HBT matching script (#127)
1 parent 3ea0208 commit 2271f97

File tree

3 files changed

+438
-34
lines changed

3 files changed

+438
-34
lines changed
 

‎README.md

+5-31
Original file line numberDiff line numberDiff line change
@@ -194,11 +194,9 @@ pip install snakeviz --user
194194
snakeviz -b "firefox -no-remote %s" ./profile.0.dat
195195
```
196196

197-
## Matching halos between VR outputs
197+
## Matching halos between outputs
198198

199-
Note that this requires the latest version of https://github.com/jchelly/VirgoDC
200-
201-
This repository also contains a program to find halos which contain the same
199+
This repository also contains a script to find halos which contain the same
202200
particle IDs between two outputs. It can be used to find the same halos between
203201
different snapshots or between hydro and dark matter only simulations.
204202

@@ -207,36 +205,12 @@ determine which halo in the second output contains the largest number of these
207205
IDs. This matching process is then repeated in the opposite direction and we
208206
check for cases where we have consistent matches in both directions.
209207

210-
### Running the program
211-
212-
It can be run as follows:
213-
```
214-
vr_basename1="./vr/catalogue_0012/vr_catalogue_0012"
215-
vr_basename2="./vr/catalogue_0013/vr_catalogue_0013"
216-
217-
outfile="halo_matching_0012_to_0013.hdf5"
218-
nr_particles=10
219-
220-
mpirun python3 -u -m mpi4py \
221-
./match_vr_halos.py ${vr_basename1} ${vr_basename2} \
222-
${nr_particles} ${outfile} --use-types 0 1 2 3 4 5
223-
```
224-
Here `nr_particles` is the number of most bound particles to use for matching.
225-
226-
### Matching using only specified particle types
227-
228-
The `--use-types` flag specifies which particle types to use for matching using
229-
the type numbering scheme from Swift. Only the specified types are included in
230-
the most bound particles used to match halos between snapshots. For example,
231-
`--use-types 1` will cause the code to track the `nr_particles` most bound dark
232-
matter particles from each halo.
233-
234208
### Matching to field halos only
235209

236-
The `--to-field-halos-only` flag can be used to match field halos (those with
237-
hostHaloID=-1 in the VR output) between outputs. If it is set we follow the
210+
The `--to-field-halos-only` flag can be used to match central halos
211+
between outputs. If it is set we follow the
238212
first `nr_particles` most bound particles from each halo as usual, but when
239-
locating them in the other output any particles in halos with hostHaloID>=0
213+
locating them in the other output any particles in satellite subhalos
240214
are treated as belonging to the host halo.
241215

242216
In this mode field halos in one catalogue will only ever be matched to field

‎match_hbt_halos.py

+427
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,427 @@
1+
#!/bin/env python
2+
3+
import os
4+
5+
import numpy as np
6+
import h5py
7+
8+
import virgo.mpi.parallel_sort as psort
9+
import virgo.mpi.parallel_hdf5 as phdf5
10+
11+
import lustre
12+
import read_hbtplus
13+
14+
import unyt
15+
import swift_cells
16+
17+
from mpi4py import MPI
18+
19+
comm = MPI.COMM_WORLD
20+
comm_rank = comm.Get_rank()
21+
comm_size = comm.Get_size()
22+
23+
# Maximum number of particle types
24+
NTYPEMAX = 7
25+
26+
27+
def message(s):
    """
    Print s, but only on MPI rank 0 so that the text appears once
    rather than once per rank.
    """
    if comm_rank != 0:
        return
    print(s)
30+
31+
32+
def exchange_array(arr, dest, comm):
    """
    Carry out an alltoallv on the supplied array, given the MPI rank
    to send each element to.

    arr  - local array of elements to send
    dest - integer array with one entry per element of arr, giving the
           rank in comm which should receive that element
    comm - communicator to exchange over

    Returns a new array with the elements received by this rank. The
    ordering of the send buffer depends only on dest, so exchanging
    several arrays with the same dest yields aligned results.
    """
    # Size everything from the communicator we were given rather than
    # from the global communicator, so that sub-communicators work too.
    nr_ranks = comm.Get_size()

    # Group the elements to send by destination rank
    order = np.argsort(dest)
    sendbuf = arr[order]
    send_count = np.bincount(dest, minlength=nr_ranks)
    send_offset = np.cumsum(send_count) - send_count

    # Tell every rank how many elements it will receive from us
    recv_count = np.zeros_like(send_count)
    comm.Alltoall(send_count, recv_count)
    recv_offset = np.cumsum(recv_count) - recv_count

    # Exchange the data
    recvbuf = np.empty(recv_count.sum(), dtype=arr.dtype)
    psort.my_alltoallv(
        sendbuf, send_count, send_offset, recvbuf, recv_count, recv_offset, comm=comm
    )
    return recvbuf
49+
50+
51+
# Define a placeholder unit system, since everything we do is dimensionless.
52+
# However, it would be better to load this from a snapshot.
53+
def define_unit_system():
    """
    Build a unyt unit registry containing placeholder definitions of
    the extra unit names used when reading halo catalogues.

    Each extra unit is simply set equal to one of a base unit; the
    values are not physically meaningful.

    Returns the new registry.
    """
    registry = unyt.unit_registry.UnitRegistry()

    # (unit name, base unit it is defined as 1.0 of)
    placeholder_units = (
        ("swift_mpc", unyt.cm),
        ("swift_msun", unyt.g),
        ("h", unyt.Hz),
    )
    for unit_name, base_unit in placeholder_units:
        unyt.define_unit(unit_name, 1.0 * base_unit, registry=registry)

    return registry
64+
65+
66+
def find_matching_halos(
    base_name1,
    base_name2,
    max_nr_particles,
    min_particle_id,
    max_particle_id,
    field_only,
):
    """
    For each halo in the first catalogue, find the halo in the second
    catalogue which contains the largest number of its particles.

    base_name1       - base name of the first set of HBT files
    base_name2       - base name of the second set of HBT files
    max_nr_particles - only the first max_nr_particles particles of each
                       halo in catalogue 1 (in the order returned by
                       read_hbtplus_groupnr, presumably most bound first)
                       are used for matching
    min_particle_id  - if not None, catalogue 1 particles with
                       ID < min_particle_id are ignored
    max_particle_id  - if not None, catalogue 1 particles with
                       ID >= max_particle_id are ignored
    field_only       - if true, particles in catalogue 2 are treated as
                       belonging to the central halo which shares their
                       halo's HostHaloId

    This is a collective call over the module level communicator. The
    halos of catalogue 1 are split between the ranks in contiguous
    ranges, with any remainder stored on the last rank.

    Returns (result_grnr_in_cat2, result_count): for each catalogue 1
    halo stored on this rank, the index of the matched halo in
    catalogue 2 (-1 if there is none) and the number of particles
    the two halos have in common.
    """

    # We only care about dimensionless quantities here, so
    # define a placeholder unit system
    registry = define_unit_system()
    a_unit = unyt.Unit("cm", registry=registry) ** 0
    boxsize = None

    # Load halo data from the second catalogue
    keep_orphans = True
    halo_data2 = read_hbtplus.read_hbtplus_catalogue(
        comm, base_name2, a_unit, registry, boxsize, keep_orphans
    )

    # Catalogue 2 properties
    track_ids2 = halo_data2["TrackId"]
    host_ids2 = halo_data2["HostHaloId"]
    is_central2 = halo_data2["is_central"]

    # Find the track ID of the central halo that has the same hostHaloID
    central_index2 = psort.parallel_match(
        host_ids2, host_ids2[is_central2 == 1], comm=comm
    )
    host_track_ids2 = psort.fetch_elements(
        track_ids2[is_central2 == 1], central_index2, comm=comm
    )

    # Find the index of that central halo (which may differ from the track ID)
    host_index2 = psort.parallel_match(host_track_ids2, track_ids2, comm=comm)

    # Free the other halo data
    del halo_data2

    # Find group membership for particles in the first catalogue:
    total_nr_halos1, cat1_ids, cat1_grnr_in_cat1, _ = read_hbtplus.read_hbtplus_groupnr(
        base_name1
    )

    # Find group membership for particles in the second catalogue
    _, cat2_ids, cat2_grnr_in_cat2, _ = read_hbtplus.read_hbtplus_groupnr(
        base_name2
    )

    # Decide range of halos in cat1 which we'll store on each rank:
    # This is used to partition the result between MPI ranks.
    nr_cat1_tot = total_nr_halos1
    nr_cat1_per_rank = nr_cat1_tot // comm_size
    if comm_rank < comm_size - 1:
        nr_cat1_local = nr_cat1_per_rank
    else:
        nr_cat1_local = nr_cat1_tot - (comm_size - 1) * nr_cat1_per_rank

    # Clear group membership for particles with invalid IDs
    if min_particle_id is not None:
        discard = cat1_ids < min_particle_id
        cat1_grnr_in_cat1[discard] = -1

    if max_particle_id is not None:
        discard = cat1_ids >= max_particle_id
        cat1_grnr_in_cat1[discard] = -1

    # If we're only matching to field halos, then any particles in the second catalogue which
    # belong to a halo with hostHaloID != -1 need to have their group membership reset to their
    # host halo.
    if field_only:
        # Find particles in halos in cat2
        in_halo = cat2_grnr_in_cat2 >= 0
        # Fetch host halo array index for each particle in cat2, or -1 if not in a halo
        particle_host_index = -np.ones_like(cat2_grnr_in_cat2)
        particle_host_index[in_halo] = psort.fetch_elements(
            host_index2, cat2_grnr_in_cat2[in_halo], comm=comm
        )
        # Where a particle's halo has a host halo, set its group membership to be the host halo
        have_host = particle_host_index >= 0
        cat2_grnr_in_cat2[have_host] = particle_host_index[have_host]

    # Discard particles which are in no halo from each catalogue
    in_group = cat1_grnr_in_cat1 >= 0
    cat1_ids = cat1_ids[in_group]
    cat1_grnr_in_cat1 = cat1_grnr_in_cat1[in_group]
    in_group = cat2_grnr_in_cat2 >= 0
    cat2_ids = cat2_ids[in_group]
    cat2_grnr_in_cat2 = cat2_grnr_in_cat2[in_group]

    # Now we need to identify the first max_nr_particles remaining particles for each
    # halo in catalogue 1. First, find the ranking of each particle within the part of
    # its group which is stored on this MPI rank. First particle in a group has rank 0.
    unique_grnr, unique_index, unique_count = np.unique(
        cat1_grnr_in_cat1, return_index=True, return_counts=True
    )
    cat1_rank_in_group = -np.ones_like(cat1_grnr_in_cat1)
    for ui, uc in zip(unique_index, unique_count):
        cat1_rank_in_group[ui : ui + uc] = np.arange(uc, dtype=int)
    assert np.all(cat1_rank_in_group >= 0)

    # Then for the first group on each rank we'll need to add the total number of particles in
    # the same group on all lower numbered ranks. Since the particles are sorted by group this
    # can only ever be the last group on each lower numbered rank.
    if len(unique_grnr) > 0:
        # This rank has at least one particle in a group. Store indexes of first and last groups
        # and the number of particles from the last group which are stored on this rank.
        assert unique_index[0] == 0
        first_grnr = unique_grnr[0]
        last_grnr = unique_grnr[-1]
        last_grnr_count = unique_count[-1]
    else:
        # This rank has no particles in groups
        first_grnr = -1
        last_grnr = -1
        last_grnr_count = 0
    all_last_grnr = comm.allgather(last_grnr)
    all_last_grnr_count = comm.allgather(last_grnr_count)
    # Loop over lower numbered ranks
    for rank_nr in range(comm_rank):
        if first_grnr >= 0 and all_last_grnr[rank_nr] == first_grnr:
            cat1_rank_in_group[: unique_count[0]] += all_last_grnr_count[rank_nr]

    # Only keep the first max_nr_particles remaining particles in each group in catalogue 1
    keep = cat1_rank_in_group < max_nr_particles
    cat1_ids = cat1_ids[keep]
    cat1_grnr_in_cat1 = cat1_grnr_in_cat1[keep]

    # For each particle ID in catalogue 1, try to find the same particle ID in catalogue 2
    ptr = psort.parallel_match(cat1_ids, cat2_ids, comm=comm)
    matched = ptr >= 0

    # For each particle ID in catalogue 1, fetch the group membership of the matching ID in catalogue 2
    cat1_grnr_in_cat2 = -np.ones_like(cat1_grnr_in_cat1)
    cat1_grnr_in_cat2[matched] = psort.fetch_elements(
        cat2_grnr_in_cat2, ptr[matched], comm=comm
    )

    # Discard unmatched particles
    cat1_grnr_in_cat1 = cat1_grnr_in_cat1[matched]
    cat1_grnr_in_cat2 = cat1_grnr_in_cat2[matched]

    # Get sorted, unique (grnr1, grnr2) combinations and counts of how many instances of each we have.
    # Both indexes are packed into one 64 bit key, so each must fit in 32 bits.
    assert np.all(cat1_grnr_in_cat1 < 2 ** 32)
    assert np.all(cat1_grnr_in_cat1 >= 0)
    assert np.all(cat1_grnr_in_cat2 < 2 ** 32)
    assert np.all(cat1_grnr_in_cat2 >= 0)
    sort_key = (cat1_grnr_in_cat1.astype(np.uint64) << 32) + cat1_grnr_in_cat2.astype(
        np.uint64
    )
    unique_value, cat1_count = psort.parallel_unique(
        sort_key, comm=comm, return_counts=True, repartition_output=True
    )
    cat1_grnr_in_cat1 = (unique_value >> 32).astype(
        int
    )  # Cast to int because mixing signed and unsigned causes numpy to cast to float!
    cat1_grnr_in_cat2 = (unique_value % (1 << 32)).astype(int)

    # Send each (grnr1, grnr2, count) combination to the rank which will store the result for that halo
    if nr_cat1_per_rank > 0:
        dest = (cat1_grnr_in_cat1 // nr_cat1_per_rank).astype(int)
        dest[dest > comm_size - 1] = comm_size - 1
    else:
        dest = np.empty_like(cat1_grnr_in_cat1, dtype=int)
        dest[:] = comm_size - 1
    recv_grnr_in_cat1 = exchange_array(cat1_grnr_in_cat1, dest, comm)
    recv_grnr_in_cat2 = exchange_array(cat1_grnr_in_cat2, dest, comm)
    recv_count = exchange_array(cat1_count, dest, comm)

    # Allocate output arrays:
    # Each rank has nr_cat1_per_rank halos with any extras on the last rank
    first_in_cat1 = comm_rank * nr_cat1_per_rank
    result_grnr_in_cat2 = -np.ones(
        nr_cat1_local, dtype=int
    )  # For each halo in cat1, will store index of match in cat2
    result_count = np.zeros(
        nr_cat1_local, dtype=int
    )  # Will store number of matching particles

    # Update output arrays using the received data.
    for recv_nr in range(len(recv_grnr_in_cat1)):
        # Compute local array index of halo to update
        local_halo_nr = recv_grnr_in_cat1[recv_nr] - first_in_cat1
        assert local_halo_nr >= 0
        assert local_halo_nr < nr_cat1_local
        # Check if the received count is higher than the highest so far
        if recv_count[recv_nr] > result_count[local_halo_nr]:
            # This received combination has the highest count so far
            result_grnr_in_cat2[local_halo_nr] = recv_grnr_in_cat2[recv_nr]
            result_count[local_halo_nr] = recv_count[recv_nr]
        elif recv_count[recv_nr] == result_count[local_halo_nr]:
            # In the event of a tie, go for the lowest group number for reproducibility
            if recv_grnr_in_cat2[recv_nr] < result_grnr_in_cat2[local_halo_nr]:
                result_grnr_in_cat2[local_halo_nr] = recv_grnr_in_cat2[recv_nr]
                result_count[local_halo_nr] = recv_count[recv_nr]

    return result_grnr_in_cat2, result_count
262+
263+
264+
def consistent_match(match_index_12, match_index_21):
    """
    Flag the halos in catalogue 1 whose match in catalogue 2 is in turn
    matched back to the same catalogue 1 halo.

    match_index_12 holds, for each halo in catalogue 1 stored on this
    rank, the index of the matching halo in catalogue 2 (or -1 for no
    match).

    match_index_21 holds, for each halo in catalogue 2 stored on this
    rank, the index of the matching halo in catalogue 1 (or -1 for no
    match).

    Returns an array with one entry per local catalogue 1 halo, set to
    1 where the match is consistent and 0 otherwise.
    """

    nr_local = len(match_index_12)

    # Global index of each local halo: the inclusive prefix sum over ranks,
    # less our own count, is the global index of our first halo.
    first_global_index = comm.scan(nr_local) - nr_local
    own_index = first_global_index + np.arange(nr_local, dtype=int)

    # Follow each match to the other catalogue and look up which halo
    # it was matched to there. Unmatched halos keep the value -1.
    matched = match_index_12 >= 0
    pointed_back_at = np.full(nr_local, -1, dtype=int)
    pointed_back_at[matched] = psort.fetch_elements(
        match_index_21, match_index_12[matched], comm=comm
    )

    # The match is consistent if we arrive back where we started
    return np.where(pointed_back_at == own_index, 1, 0)
294+
295+
296+
def get_match_hbt_halos_args(comm):
    """
    Process command line arguments for halo matching program.

    comm is the MPI communicator handed to the MPIArgumentParser.

    Returns the object produced by parse_args(), with the argument
    values stored as attributes: hbt_basename1, hbt_basename2,
    nr_particles, output_file, min_particle_id, max_particle_id and
    to_field_halos_only. The two particle ID limits are single
    integers, or None if not specified.
    """

    from virgo.mpi.util import MPIArgumentParser

    parser = MPIArgumentParser(
        comm, description="Find matching halos between snapshots"
    )
    parser.add_argument("hbt_basename1", help="Base name of the first set of HBT files")
    parser.add_argument(
        "hbt_basename2", help="Base name of the second set of HBT files"
    )
    parser.add_argument(
        "nr_particles",
        metavar="N",
        type=int,
        help="Number of most bound particles to use.",
    )
    parser.add_argument("output_file", help="Output file name")
    # The ID limits are scalar thresholds, so each takes exactly one value.
    # (Accepting a list here would break the comparison against the
    # particle ID array if zero or several values were supplied.)
    parser.add_argument(
        "--min-particle-id",
        type=int,
        help="Only use particle with ID >= this",
    )
    parser.add_argument(
        "--max-particle-id",
        type=int,
        help="Only use particle with ID < this",
    )
    parser.add_argument(
        "--to-field-halos-only", action="store_true", help="Only match to field halos"
    )
    args = parser.parse_args()

    return args
337+
338+
339+
if __name__ == "__main__":

    # Parse the command line
    args = get_match_hbt_halos_args(comm)

    # One rank creates the output directory (if necessary) while the others wait
    if comm_rank == 0:
        lustre.ensure_output_dir(args.output_file)
    comm.barrier()

    def report_nr_matched(match_index):
        """Print how many halos, summed over all ranks, were given a match."""
        total_nr_halos = comm.allreduce(len(match_index))
        total_nr_matched = comm.allreduce(np.sum(match_index >= 0))
        message(f" Matched {total_nr_matched} of {total_nr_halos} halos")

    # Match each halo in output 1 to a halo in output 2
    message("Matching from first catalogue to second")
    match_index_12, count_12 = find_matching_halos(
        base_name1=args.hbt_basename1,
        base_name2=args.hbt_basename2,
        max_nr_particles=args.nr_particles,
        min_particle_id=args.min_particle_id,
        max_particle_id=args.max_particle_id,
        field_only=args.to_field_halos_only,
    )
    report_nr_matched(match_index_12)

    # Match each halo in output 2 to a halo in output 1
    message("Matching from second catalogue to first")
    match_index_21, count_21 = find_matching_halos(
        base_name1=args.hbt_basename2,
        base_name2=args.hbt_basename1,
        max_nr_particles=args.nr_particles,
        min_particle_id=args.min_particle_id,
        max_particle_id=args.max_particle_id,
        field_only=args.to_field_halos_only,
    )
    report_nr_matched(match_index_21)

    # Determine which matches agree in both directions
    message("Checking for consistent matches")
    consistent_12 = consistent_match(match_index_12, match_index_21)
    consistent_21 = consistent_match(match_index_21, match_index_12)

    # Datasets to write, as (name, data, description), in output order
    output_fields = (
        (
            "MatchIndex1to2",
            match_index_12,
            "For each halo in the first catalogue, index of the matching halo in the second",
        ),
        (
            "MatchCount1to2",
            count_12,
            f"How many of the {args.nr_particles} most bound particles from the halo in the first catalogue are in the matched halo in the second",
        ),
        (
            "Consistent1to2",
            consistent_12,
            "Whether the match from first to second catalogue is consistent with second to first (1) or not (0)",
        ),
        (
            "MatchIndex2to1",
            match_index_21,
            "For each halo in the second catalogue, index of the matching halo in the first",
        ),
        (
            "MatchCount2to1",
            count_21,
            f"How many of the {args.nr_particles} most bound particles from the halo in the second catalogue are in the matched halo in the first",
        ),
        (
            "Consistent2to1",
            consistent_21,
            "Whether the match from second to first catalogue is consistent with first to second (1) or not (0)",
        ),
    )

    message("Writing output")
    with h5py.File(args.output_file, "w", driver="mpio", comm=comm) as outfile:
        # Record the parameters which were actually set as attributes
        params = outfile.create_group("Parameters")
        for name, value in vars(args).items():
            if value is None:
                continue
            params.attrs[name] = value
        # All ranks write their part of each dataset collectively
        for name, data, description in output_fields:
            dataset = phdf5.collective_write(outfile, name, data, comm)
            dataset.attrs["Description"] = description

    comm.barrier()
    message("Done.")

‎read_hbtplus.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def read_hbtplus_groupnr(basename):
117117
return total_nr_halos, ids_bound, grnr_bound, rank_bound
118118

119119

120-
def read_hbtplus_catalogue(comm, basename, a_unit, registry, boxsize):
120+
def read_hbtplus_catalogue(comm, basename, a_unit, registry, boxsize, keep_orphans=False):
121121
"""
122122
Read in the HBTplus halo catalogue, distributed over communicator comm.
123123
@@ -204,8 +204,11 @@ def read_hbtplus_catalogue(comm, basename, a_unit, registry, boxsize):
204204
subhalo["Nbound"], units=unyt.dimensionless, dtype=int, registry=registry
205205
)
206206

207-
# Only process resolved subhalos (HBTplus also outputs unresolved "orphan" subhalos)
208-
keep = nr_bound_part > 1
207+
# Do we only process resolved subhalos? (HBT also outputs unresolved "orphan" subhalos)
208+
if not keep_orphans:
209+
keep = nr_bound_part > 0
210+
else:
211+
keep = np.ones_like(nr_bound_part, dtype=bool)
209212

210213
# Assign indexes to halos: for each halo we're going to process we store the
211214
# position in the input catalogue.

0 commit comments

Comments
 (0)
Please sign in to comment.