import torch as t
from tqdm import tqdm
from numpy import ndindex
+ from loading_utils import Submodule
from activation_utils import SparseAct
- from dataclasses import dataclass
from nnsight.envoy import Envoy
from dictionary_learning.dictionary import Dictionary
from typing import Callable

- DEBUGGING = False
-
- if DEBUGGING:
-     tracer_kwargs = {'validate': True, 'scan': True}
- else:
-     tracer_kwargs = {'validate': False, 'scan': False}
-
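Note: the deleted `tracer_kwargs` switch threaded nnsight's tracing debug options into every `model.trace(...)` call below; after this commit the library defaults are used throughout. A minimal sketch of opting back in at a single call site, assuming nnsight's `trace` still accepts these keyword arguments:

    # hypothetical one-off debugging trace with scanning and validation enabled
    with model.trace(clean, scan=True, validate=True):
        ...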
EffectOut = namedtuple('EffectOut', ['effects', 'deltas', 'grads', 'total_effect'])

- @dataclass(frozen=True)
- class Submodule:
-     name: str
-     submodule: Envoy
-     use_input: bool = False
-     is_tuple: bool = False
-
-     def __hash__(self):
-         return hash(self.name)
-
-     def get_activation(self):
-         if self.use_input:
-             return self.submodule.input  # TODO make sure I didn't break for pythia
-         else:
-             if self.is_tuple:
-                 return self.submodule.output[0]
-             else:
-                 return self.submodule.output
-
-     def set_activation(self, x):
-         if self.use_input:
-             self.submodule.input[:] = x
-         else:
-             if self.is_tuple:
-                 self.submodule.output[0][:] = x
-             else:
-                 self.submodule.output = x
-
-     def stop_grad(self):
-         if self.use_input:
-             self.submodule.input.grad = t.zeros_like(self.submodule.input)
-         else:
-             if self.is_tuple:
-                 self.submodule.output[0].grad = t.zeros_like(self.submodule.output[0])
-             else:
-                 self.submodule.output.grad = t.zeros_like(self.submodule.output)
-
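The `Submodule` wrapper removed here is not deleted outright: it moves to `loading_utils` (see the new import at the top of the file), and the unchanged call sites below keep using the same interface. It wraps an nnsight `Envoy` and normalizes reading and writing the module's input or (possibly tuple-valued) output. A usage sketch, with a hypothetical layer path that depends on the model being traced:

    # hypothetical: wrap the output of block 3 of a Pythia-style model
    resid_3 = Submodule(
        name="resid_3",
        submodule=model.gpt_neox.layers[3],
        is_tuple=True,  # the block returns a (hidden_states, ...) tuple
    )
    # inside a trace: resid_3.get_activation() reads the activation,
    # resid_3.set_activation(x) overwrites it, resid_3.stop_grad() zeroes its grad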

def _pe_attrib(
    clean,
@@ -66,7 +22,7 @@ def _pe_attrib(
):
    hidden_states_clean = {}
    grads = {}
-     with model.trace(clean, **tracer_kwargs):
+     with model.trace(clean):
        for submodule in submodules:
            dictionary = dictionaries[submodule]
            x = submodule.get_activation()
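Downstream of the traces in this function, `_pe_attrib` combines the saved gradients and activations into the attribution-patching estimate: each feature's effect is approximated to first order by its gradient times its clean-to-patch delta. A hedged sketch of that step (the actual SparseAct arithmetic sits outside the hunks shown here):

    # first-order approximation of the patching effect, per submodule:
    #   effect ~ grad(metric) * (patch - clean)
    delta = hidden_states_patch[submodule] - hidden_states_clean[submodule]
    effect = delta * grads[submodule]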
@@ -90,7 +46,7 @@ def _pe_attrib(
        total_effect = None
    else:
        hidden_states_patch = {}
-         with model.trace(patch, **tracer_kwargs), t.inference_mode():
+         with t.no_grad(), model.trace(patch):
            for submodule in submodules:
                dictionary = dictionaries[submodule]
                x = submodule.get_activation()
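The swap from `t.inference_mode()` to `t.no_grad()` here (and in the matching hunks below) is more than cosmetic: tensors created under inference mode are permanently barred from autograd, while tensors created under `no_grad` merely carry no history and can still participate in later differentiable computation. A self-contained illustration of the difference:

    import torch as t
    with t.inference_mode():
        a = t.ones(3)
    with t.no_grad():
        b = t.ones(3)
    w = t.ones(3, requires_grad=True)
    (b * w).sum().backward()    # fine: b is an ordinary history-free tensor
    # (a * w).sum().backward()  # RuntimeError: inference tensors cannot be saved for backward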
@@ -126,7 +82,7 @@ def _pe_ig(
    metric_kwargs=dict(),
):
    hidden_states_clean = {}
-     with t.no_grad(), model.trace(clean, **tracer_kwargs):
+     with t.no_grad(), model.trace(clean):
        for submodule in submodules:
            dictionary = dictionaries[submodule]
            x = submodule.get_activation()
@@ -144,7 +100,7 @@ def _pe_ig(
        total_effect = None
    else:
        hidden_states_patch = {}
-         with t.no_grad(), model.trace(patch, **tracer_kwargs):
+         with t.no_grad(), model.trace(patch):
            for submodule in submodules:
                dictionary = dictionaries[submodule]
                x = submodule.get_activation()
@@ -163,7 +119,7 @@ def _pe_ig(
        dictionary = dictionaries[submodule]
        clean_state = hidden_states_clean[submodule]
        patch_state = hidden_states_patch[submodule]
-         with model.trace(**tracer_kwargs) as tracer:
+         with model.trace() as tracer:
            metrics = []
            fs = []
            for step in range(steps):
@@ -172,7 +128,7 @@ def _pe_ig(
                f.act.requires_grad_().retain_grad()
                f.res.requires_grad_().retain_grad()
                fs.append(f)
-                 with tracer.invoke(clean, scan=tracer_kwargs['scan']):
+                 with tracer.invoke(clean):
                    submodule.set_activation(dictionary.decode(f.act) + f.res)
                    metrics.append(metric_fn(model, **metric_kwargs))
            metric = sum([m for m in metrics])
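For orientation, the integrated-gradients interpolation happens in lines elided just above this hunk, between `for step in range(steps):` and the `requires_grad_` calls: each step evaluates the metric at a point on the straight line between the clean and patch feature activations, and the per-step gradients are averaged afterwards. A hedged sketch of that elided step:

    # linear interpolation between clean and patch states (SparseAct arithmetic)
    alpha = step / steps
    f = (1 - alpha) * clean_state + alpha * patch_state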
@@ -200,7 +156,7 @@ def _pe_exact(
    metric_fn,
):
    hidden_states_clean = {}
-     with model.trace(clean, **tracer_kwargs), t.inference_mode():
+     with t.no_grad(), model.trace(clean):
        for submodule in submodules:
            dictionary = dictionaries[submodule]
            x = submodule.get_activation()
@@ -218,7 +174,7 @@ def _pe_exact(
        total_effect = None
    else:
        hidden_states_patch = {}
-         with model.trace(patch, **tracer_kwargs), t.inference_mode():
+         with t.no_grad(), model.trace(patch):
            for submodule in submodules:
                dictionary = dictionaries[submodule]
                x = submodule.get_activation()
@@ -241,24 +197,22 @@ def _pe_exact(
        # iterate over positions and features for which clean and patch differ
        idxs = t.nonzero(patch_state.act - clean_state.act)
        for idx in tqdm(idxs):
-             with t.inference_mode():
-                 with model.trace(clean, **tracer_kwargs):
-                     f = clean_state.act.clone()
-                     f[tuple(idx)] = patch_state.act[tuple(idx)]
-                     x_hat = dictionary.decode(f)
-                     submodule.set_activation(x_hat + clean_state.res)
-                     metric = metric_fn(model).save()
-                 effect.act[tuple(idx)] = (metric.value - metric_clean.value).sum()
+             with t.no_grad(), model.trace(clean):
+                 f = clean_state.act.clone()
+                 f[tuple(idx)] = patch_state.act[tuple(idx)]
+                 x_hat = dictionary.decode(f)
+                 submodule.set_activation(x_hat + clean_state.res)
+                 metric = metric_fn(model).save()
+             effect.act[tuple(idx)] = (metric.value - metric_clean.value).sum()

        for idx in list(ndindex(effect.resc.shape)):  # type: ignore
-             with t.inference_mode():
-                 with model.trace(clean, **tracer_kwargs):
-                     res = clean_state.res.clone()
-                     res[tuple(idx)] = patch_state.res[tuple(idx)]  # type: ignore
-                     x_hat = dictionary.decode(clean_state.act)
-                     submodule.set_activation(x_hat + res)
-                     metric = metric_fn(model).save()
-                 effect.resc[tuple(idx)] = (metric.value - metric_clean.value).sum()  # type: ignore
+             with t.no_grad(), model.trace(clean):
+                 res = clean_state.res.clone()
+                 res[tuple(idx)] = patch_state.res[tuple(idx)]  # type: ignore
+                 x_hat = dictionary.decode(clean_state.act)
+                 submodule.set_activation(x_hat + res)
+                 metric = metric_fn(model).save()
+             effect.resc[tuple(idx)] = (metric.value - metric_clean.value).sum()  # type: ignore

        effects[submodule] = effect
        deltas[submodule] = patch_state - clean_state
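These two loops are what makes `_pe_exact` the ground truth that `_pe_attrib` and `_pe_ig` approximate: it runs one full forward pass per (position, feature) index at which clean and patch activations differ, plus one per entry of the residual term, so its cost grows linearly with the number of differing features.

    # exact effect of flipping a single feature, as computed in the loops above:
    #   effect[idx] = metric(clean run with feature idx patched) - metric(clean run)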
@@ -311,8 +265,7 @@ def jvp(
    vjv_values = {}

    downstream_feature_idxs = downstream_features.to_tensor().nonzero()
-     with model.trace(input, **tracer_kwargs):
-
+     with model.trace(input):
        # forward pass modifications
        x = upstream_submod.get_activation()
        x_hat, f = upstream_dict.hacked_forward_for_sfc(x)  # hacking around an nnsight bug
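Context for this final hunk, inferred from the surrounding names rather than stated in the diff: `jvp` appears to score upstream-to-downstream edges, with each entry of `vjv_values` a vector-Jacobian-vector product, roughly:

    # hedged reading of the names above, not shown in this diff:
    #   vjv ~ grad_downstream * J * delta_upstream
    # i.e. how strongly a change in one upstream feature moves the metric
    # through each active downstream feature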