Commit 9e9e85a

sum aggregation works now; refactoring some stuff
1 parent b541703 commit 9e9e85a

File tree

4 files changed: +667 −224 lines changed

attribution.py

+24 −71
@@ -2,58 +2,14 @@
 import torch as t
 from tqdm import tqdm
 from numpy import ndindex
+from loading_utils import Submodule
 from activation_utils import SparseAct
-from dataclasses import dataclass
 from nnsight.envoy import Envoy
 from dictionary_learning.dictionary import Dictionary
 from typing import Callable
 
-DEBUGGING = False
-
-if DEBUGGING:
-    tracer_kwargs = {'validate' : True, 'scan' : True}
-else:
-    tracer_kwargs = {'validate' : False, 'scan' : False}
-
 EffectOut = namedtuple('EffectOut', ['effects', 'deltas', 'grads', 'total_effect'])
 
-@dataclass(frozen=True)
-class Submodule:
-    name: str
-    submodule: Envoy
-    use_input: bool = False
-    is_tuple: bool = False
-
-    def __hash__(self):
-        return hash(self.name)
-
-    def get_activation(self):
-        if self.use_input:
-            return self.submodule.input  # TODO make sure I didn't break for pythia
-        else:
-            if self.is_tuple:
-                return self.submodule.output[0]
-            else:
-                return self.submodule.output
-
-    def set_activation(self, x):
-        if self.use_input:
-            self.submodule.input[:] = x
-        else:
-            if self.is_tuple:
-                self.submodule.output[0][:] = x
-            else:
-                self.submodule.output = x
-
-    def stop_grad(self):
-        if self.use_input:
-            self.submodule.input.grad = t.zeros_like(self.submodule.input)
-        else:
-            if self.is_tuple:
-                self.submodule.output[0].grad = t.zeros_like(self.submodule.output[0])
-            else:
-                self.submodule.output.grad = t.zeros_like(self.submodule.output)
 
 def _pe_attrib(
     clean,
@@ -66,7 +22,7 @@ def _pe_attrib(
 ):
     hidden_states_clean = {}
     grads = {}
-    with model.trace(clean, **tracer_kwargs):
+    with model.trace(clean):
         for submodule in submodules:
             dictionary = dictionaries[submodule]
             x = submodule.get_activation()
@@ -90,7 +46,7 @@ def _pe_attrib(
         total_effect = None
     else:
         hidden_states_patch = {}
-        with model.trace(patch, **tracer_kwargs), t.inference_mode():
+        with t.no_grad(), model.trace(patch):
             for submodule in submodules:
                 dictionary = dictionaries[submodule]
                 x = submodule.get_activation()
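
Aside: one plausible reason for the t.inference_mode() → t.no_grad() swaps in this file (an inference on our part; the commit message does not say) is that tensors created under inference_mode can never re-enter autograd, while no_grad outputs remain ordinary tensors. A minimal standalone illustration:

    import torch

    x = torch.ones(3)
    with torch.inference_mode():
        y = x * 2          # y is an inference tensor
    # y.requires_grad_()   # would raise: inference tensors cannot be used in autograd

    with torch.no_grad():
        z = x * 2          # grad tracking is off here, but z is a normal tensor
    z.requires_grad_()     # fine: z can re-enter autograd later
    (z * z).sum().backward()
    print(z.grad)          # tensor([4., 4., 4.])
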
@@ -126,7 +82,7 @@ def _pe_ig(
     metric_kwargs=dict(),
 ):
     hidden_states_clean = {}
-    with t.no_grad(), model.trace(clean, **tracer_kwargs):
+    with t.no_grad(), model.trace(clean):
         for submodule in submodules:
             dictionary = dictionaries[submodule]
             x = submodule.get_activation()
@@ -144,7 +100,7 @@ def _pe_ig(
         total_effect = None
     else:
         hidden_states_patch = {}
-        with t.no_grad(), model.trace(patch, **tracer_kwargs):
+        with t.no_grad(), model.trace(patch):
             for submodule in submodules:
                 dictionary = dictionaries[submodule]
                 x = submodule.get_activation()
@@ -163,7 +119,7 @@ def _pe_ig(
         dictionary = dictionaries[submodule]
         clean_state = hidden_states_clean[submodule]
         patch_state = hidden_states_patch[submodule]
-        with model.trace(**tracer_kwargs) as tracer:
+        with model.trace() as tracer:
            metrics = []
            fs = []
            for step in range(steps):
for step in range(steps):
@@ -172,7 +128,7 @@ def _pe_ig(
172128
f.act.requires_grad_().retain_grad()
173129
f.res.requires_grad_().retain_grad()
174130
fs.append(f)
175-
with tracer.invoke(clean, scan=tracer_kwargs['scan']):
131+
with tracer.invoke(clean):
176132
submodule.set_activation(dictionary.decode(f.act) + f.res)
177133
metrics.append(metric_fn(model, **metric_kwargs))
178134
metric = sum([m for m in metrics])
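
For context, _pe_ig appears to estimate effects with an integrated-gradients-style average over interpolation steps between the clean and patch states. A self-contained toy version of that estimator (toy metric and tensors; none of the repo's classes):

    import torch

    def metric(x):  # toy scalar metric standing in for metric_fn
        return (x ** 2).sum()

    clean = torch.zeros(3)
    patch = torch.tensor([1.0, 0.0, 2.0])
    steps = 10

    grads = torch.zeros_like(clean)
    for step in range(steps):
        alpha = step / steps
        x = ((1 - alpha) * clean + alpha * patch).requires_grad_()  # interpolated state
        metric(x).backward()
        grads += x.grad
    effect = (patch - clean) * grads / steps  # per-coordinate effect estimate
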
@@ -200,7 +156,7 @@ def _pe_exact(
     metric_fn,
 ):
     hidden_states_clean = {}
-    with model.trace(clean, **tracer_kwargs), t.inference_mode():
+    with t.no_grad(), model.trace(clean):
         for submodule in submodules:
             dictionary = dictionaries[submodule]
             x = submodule.get_activation()
@@ -218,7 +174,7 @@
         total_effect = None
     else:
         hidden_states_patch = {}
-        with model.trace(patch, **tracer_kwargs), t.inference_mode():
+        with t.no_grad(), model.trace(patch):
             for submodule in submodules:
                 dictionary = dictionaries[submodule]
                 x = submodule.get_activation()
@@ -241,24 +197,22 @@ def _pe_exact(
         # iterate over positions and features for which clean and patch differ
         idxs = t.nonzero(patch_state.act - clean_state.act)
         for idx in tqdm(idxs):
-            with t.inference_mode():
-                with model.trace(clean, **tracer_kwargs):
-                    f = clean_state.act.clone()
-                    f[tuple(idx)] = patch_state.act[tuple(idx)]
-                    x_hat = dictionary.decode(f)
-                    submodule.set_activation(x_hat + clean_state.res)
-                    metric = metric_fn(model).save()
-                effect.act[tuple(idx)] = (metric.value - metric_clean.value).sum()
+            with t.no_grad(), model.trace(clean):
+                f = clean_state.act.clone()
+                f[tuple(idx)] = patch_state.act[tuple(idx)]
+                x_hat = dictionary.decode(f)
+                submodule.set_activation(x_hat + clean_state.res)
+                metric = metric_fn(model).save()
+            effect.act[tuple(idx)] = (metric.value - metric_clean.value).sum()
 
         for idx in list(ndindex(effect.resc.shape)):  # type: ignore
-            with t.inference_mode():
-                with model.trace(clean, **tracer_kwargs):
-                    res = clean_state.res.clone()
-                    res[tuple(idx)] = patch_state.res[tuple(idx)]  # type: ignore
-                    x_hat = dictionary.decode(clean_state.act)
-                    submodule.set_activation(x_hat + res)
-                    metric = metric_fn(model).save()
-                effect.resc[tuple(idx)] = (metric.value - metric_clean.value).sum()  # type: ignore
+            with t.no_grad(), model.trace(clean):
+                res = clean_state.res.clone()
+                res[tuple(idx)] = patch_state.res[tuple(idx)]  # type: ignore
+                x_hat = dictionary.decode(clean_state.act)
+                submodule.set_activation(x_hat + res)
+                metric = metric_fn(model).save()
+            effect.resc[tuple(idx)] = (metric.value - metric_clean.value).sum()  # type: ignore
 
         effects[submodule] = effect
         deltas[submodule] = patch_state - clean_state
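
The exact-patching loops above flip one position/feature at a time and record the resulting metric change. A toy standalone version of the pattern (toy metric; no model, dictionaries, or nnsight):

    import torch

    def metric(x):  # toy scalar metric standing in for metric_fn
        return x.sum()

    clean = torch.zeros(3)
    patch = torch.tensor([1.0, 0.0, 2.0])
    metric_clean = metric(clean)

    effect = torch.zeros_like(clean)
    for idx in torch.nonzero(patch - clean):  # only where clean and patch differ
        f = clean.clone()
        f[tuple(idx)] = patch[tuple(idx)]     # patch a single coordinate
        effect[tuple(idx)] = metric(f) - metric_clean  # its exact effect on the metric
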
@@ -311,8 +265,7 @@ def jvp(
     vjv_values = {}
 
     downstream_feature_idxs = downstream_features.to_tensor().nonzero()
-    with model.trace(input, **tracer_kwargs):
-
+    with model.trace(input):
         # forward pass modifications
         x = upstream_submod.get_activation()
         x_hat, f = upstream_dict.hacked_forward_for_sfc(x)  # hacking around an nnsight bug
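
For reference, the Submodule wrapper deleted at the top of this file is now imported from loading_utils. A sketch of what that module presumably provides, taken from the deleted lines (whether loading_utils carries the class verbatim is an assumption):

    # loading_utils.py (assumed location after the refactor)
    from dataclasses import dataclass
    import torch as t
    from nnsight.envoy import Envoy

    @dataclass(frozen=True)
    class Submodule:
        name: str
        submodule: Envoy
        use_input: bool = False
        is_tuple: bool = False

        def __hash__(self):
            return hash(self.name)

        def get_activation(self):
            if self.use_input:
                return self.submodule.input
            return self.submodule.output[0] if self.is_tuple else self.submodule.output

        # set_activation and stop_grad as in the deleted block above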

circuit.py

+23 −33
@@ -11,19 +11,11 @@
 from attribution import patching_effect, jvp
 from circuit_plotting import plot_circuit, plot_circuit_posaligned
 from dictionary_learning import AutoEncoder
-from loading_utils import load_examples, load_examples_nopair
+from data_loading_utils import load_examples, load_examples_nopair
 from dictionary_loading_utils import load_saes_and_submodules
 from nnsight import LanguageModel
 from coo_utils import sparse_reshape
 
-DEBUGGING = True
-
-if DEBUGGING:
-    tracer_kwargs = {"validate": True, "scan": True}
-else:
-    tracer_kwargs = {"validate": False, "scan": False}
-
-
 def get_circuit(
     clean,
     patch,
@@ -35,11 +27,10 @@ def get_circuit(
     dictionaries,
     metric_fn,
     metric_kwargs=dict(),
-    aggregation="sum",  # or 'none' for not aggregating across sequence position
+    aggregation="sum",  # or "none" for not aggregating across sequence position
     nodes_only=False,
     parallel_attn=False,
     node_threshold=0.1,
-    edge_threshold=0.01,
 ):
     all_submods = ([embed] if embed is not None else []) + [
         submod for layer_submods in zip(attns, mlps, resids) for submod in layer_submods
@@ -109,7 +100,7 @@ def N(upstream, downstream, midstream=[]):
         edges[f"mlp_{layer}"][f"resid_{layer}"] = MR_effect
         edges[f"attn_{layer}"][f"resid_{layer}"] = AR_effect
 
-        if parallel_attn:
+        if not parallel_attn:
             AM_effect = N(attn, mlp)
             edges[f"attn_{layer}"][f"mlp_{layer}"] = AM_effect
 
@@ -160,13 +151,13 @@ def N(upstream, downstream, midstream=[]):
 
     # aggregate across batch dimension
     for child in edges:
-        bc, fc = nodes[child].act.shape
+        bc, _ = nodes[child].act.shape
        for parent in edges[child]:
            weight_matrix = edges[child][parent]
            if parent == "y":
                weight_matrix = weight_matrix.sum(dim=0) / bc
            else:
-               bp, fp = nodes[parent].act.shape
+               bp, _ = nodes[parent].act.shape
                assert bp == bc
                weight_matrix = weight_matrix.sum(dim=(0, 2)) / bc
            edges[child][parent] = weight_matrix
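
A shape-level sketch of the batch averaging above; the 4-D (batch, f_child, batch, f_parent) edge layout is an assumption inferred from the sum over dims (0, 2):

    import torch

    bc = 4                                       # shared batch size (bp == bc is asserted)
    weight_matrix = torch.randn(bc, 10, bc, 20)  # hypothetical edge tensor
    avg = weight_matrix.sum(dim=(0, 2)) / bc     # -> (10, 20): sum over both batch dims, scaled by 1/bc
    assert avg.shape == (10, 20)
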
@@ -465,8 +456,8 @@ def metric_fn(model):
         "google/gemma-2-2b": 26,
     }[args.model]
     parallel_attn = {
-        "EleutherAI/pythia-70m-deduped": False,
-        "google/gemma-2-2b": True,
+        "EleutherAI/pythia-70m-deduped": True,
+        "google/gemma-2-2b": False,
     }[args.model]
     include_embed = {
         "EleutherAI/pythia-70m-deduped": True,
@@ -478,7 +469,7 @@ def metric_fn(model):
     }[args.model]
 
     if args.model == "EleutherAI/pythia-70m-deduped":
-        model = LanguageModel(args.model, device_map=device, dispatch=True)
+        model = LanguageModel(args.model, device_map=device, dispatch=True, torch_dtype=dtype)
     elif args.model == "google/gemma-2-2b":
         model = LanguageModel(
             args.model,
@@ -498,17 +489,17 @@ def metric_fn(model):
         )
     else:
         data_path = f"data/{args.dataset}.json"
-        if args.aggregation == "sum":
-            raise NotImplementedError(
-                "Sum aggregation is not yet implemented for new data loading."
-            )
-            examples = load_examples(
-                data_path, args.num_examples, model, pad_to_length=args.example_length
-            )
-        else:
-            examples = load_examples(
-                data_path, args.num_examples, model, use_min_length_only=True
-            )
+        # if args.aggregation == "sum":
+        #     raise NotImplementedError(
+        #         "Sum aggregation is not yet implemented for new data loading."
+        #     )
+        #     examples = load_examples(
+        #         data_path, args.num_examples, model, pad_to_length=args.example_length
+        #     )
+        # else:
+        examples = load_examples(
+            data_path, args.num_examples, model, use_min_length_only=True
+        )
 
     num_examples = min([args.num_examples, len(examples)])
     if num_examples < args.num_examples:  # warn the user
@@ -616,7 +607,6 @@ def metric_fn(model):
         nodes_only=args.nodes_only,
         aggregation=args.aggregation,
         node_threshold=args.node_threshold,
-        edge_threshold=args.edge_threshold,
         parallel_attn=parallel_attn,
     )

@@ -688,11 +678,9 @@ def metric_fn(model):
             annotations=annotations,
             save_dir=f"{args.plot_dir}/{save_base}_node{args.node_threshold}_edge{args.edge_threshold}",
             gemma_mode=(args.model == "google/gemma-2-2b"),
+            parallel_attn=parallel_attn,
         )
     else:
-        raise NotImplementedError(
-            "Sum aggregation is not yet implemented for new data loading."
-        )
         plot_circuit(
             nodes,
             edges,
@@ -701,5 +689,7 @@ def metric_fn(model):
             edge_threshold=args.edge_threshold,
             pen_thickness=args.pen_thickness,
             annotations=annotations,
-            save_dir=f"{args.plot_dir}/{save_basename}_dict{args.dict_id}_node{args.node_threshold}_edge{args.edge_threshold}_n{num_examples}_agg{args.aggregation}",
+            save_dir=f"{args.plot_dir}/{save_base}_node{args.node_threshold}_edge{args.edge_threshold}_n{num_examples}_agg{args.aggregation}",
+            gemma_mode=(args.model == "google/gemma-2-2b"),
+            parallel_attn=parallel_attn,
         )
