SAEs on Individual Attention Head Outputs

SAEs seem to work well on individual attention head outputs. I trained SAEs on the concatenated outputs of the attention heads, with block-diagonal encoder and decoder weight matrices so that each block reads from and writes to a single head. The code to train these models is in this branch of my fork of SAELens.
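
For concreteness, here is a minimal PyTorch sketch of what such a block-diagonal SAE looks like. It is illustrative only (the class name, initialization, and pre-encoder bias subtraction are my assumptions), not the implementation in the SAELens fork:

import torch
import torch.nn as nn

class BlockDiagonalSAE(nn.Module):
    """Sketch of a block-diagonal SAE over concatenated attention head outputs.

    Each head's d_head-dimensional output gets its own encoder/decoder block,
    so no feature mixes information across heads.
    """

    def __init__(self, n_heads: int, d_head: int, expansion_factor: int):
        super().__init__()
        self.n_heads = n_heads
        self.d_sae_per_head = d_head * expansion_factor
        # One (d_head -> d_sae_per_head) encoder block and decoder block per head.
        self.W_enc = nn.Parameter(torch.randn(n_heads, d_head, self.d_sae_per_head) * 0.02)
        self.b_enc = nn.Parameter(torch.zeros(n_heads, self.d_sae_per_head))
        self.W_dec = nn.Parameter(torch.randn(n_heads, self.d_sae_per_head, d_head) * 0.02)
        self.b_dec = nn.Parameter(torch.zeros(n_heads, d_head))

    def forward(self, x: torch.Tensor):
        # x: [batch, n_heads * d_head] -> split into per-head blocks.
        x = x.view(x.shape[0], self.n_heads, -1)                      # [batch, n_heads, d_head]
        acts = torch.relu(torch.einsum("bhd,hdf->bhf", x - self.b_dec, self.W_enc) + self.b_enc)
        recon = torch.einsum("bhf,hfd->bhd", acts, self.W_dec) + self.b_dec
        return acts.flatten(1), recon.flatten(1)                      # feature acts, reconstruction

Because the einsums never mix the head dimension, every feature reads from and reconstructs exactly one head's output.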

I trained SAEs on the attention layer outputs of gelu-2l (layer 1) and gpt2-small (layers 4 and 5). In all cases, the SAEs learned interpretable features.

Training block-diagonal SAEs has several advantages.

  • With the same expansion factor, a block-diagonal SAE contains $\sim 1/n_{\text{heads}}$ times the number of trainable parameters of a fully dense SAE, which lowers training and inference costs (see the sketch after this list).
  • Each feature is, by design, attributed to a single head.
  • Given a task / prompt, we hope to be able to tell what role each head plays by looking at which single-head features activate on that task / prompt.
  • We can check claims that a head is "monosemantic" with more certainty: if we find a feature that does not have the same interpretation as all of the other features in a head, the head should be considered polysemantic.
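
To make the parameter-count bullet concrete, here is a quick back-of-the-envelope calculation using gpt2-small's standard dimensions (12 heads, d_head = 64) and the expansion factor of 32 used below; the numbers are standard model dimensions, not measurements from this repo:

# Back-of-the-envelope parameter count for gpt2-small attention head outputs.
n_heads, d_head, expansion_factor = 12, 64, 32
d_in = n_heads * d_head                      # 768, the concatenated head outputs
d_sae = d_in * expansion_factor              # 24576 features in total

dense_params = 2 * d_in * d_sae                                          # encoder + decoder of a fully dense SAE
block_diag_params = n_heads * 2 * d_head * (d_head * expansion_factor)   # one block per head

print(dense_params, block_diag_params, dense_params / block_diag_params)
# 37748736 3145728 12.0  -> the block-diagonal SAE has ~1/n_heads the weights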

I include W&B logs and feature dashboards of the two gpt2-small SAEs.

A dashboard file contains feature dashboards for the first 10 features from each head. To look through them, click on the number in the top-left corner and change it. Features 0-9 belong to head 0, features 2048-2057 belong to head 1, and so on.
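
Since each head owns a contiguous block of 2048 features (d_head 64 × expansion factor 32), mapping a global feature index to its head is a one-liner. A small helper (hypothetical, not code from the repo):

# Map a global feature index to its head and its index within that head.
# The stride of 2048 = d_head (64) * expansion_factor (32) features per head.
D_SAE_PER_HEAD = 2048

def feature_to_head(feature_idx: int) -> tuple[int, int]:
    head = feature_idx // D_SAE_PER_HEAD
    idx_within_head = feature_idx % D_SAE_PER_HEAD
    return head, idx_within_head

assert feature_to_head(5) == (0, 5)        # features 0-9 belong to head 0
assert feature_to_head(2050) == (1, 2)     # features 2048-2057 belong to head 1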

Reproduction

pip install the branch of SAELens mentioned above. Then do

python -u train_gpt_block_diag.py --l1_coefficients=[2] --hook_layer=4 --expansion_factor=32 --total_training_steps=50000
python -u train_gpt_block_diag.py --l1_coefficients=[3] --hook_layer=5 --expansion_factor=32 --total_training_steps=100000

Next, use gpt_block_diag_dashboards.ipynb in this repo to load the trained SAEs and analyze their feature dashboards.

TODO

  • The two SAEs have higher L0-norm values (460 and 278) than what is usually reported. It is possible that they are undertrained: the overall loss and the L0-norm were still decreasing when training stopped. It is also possible that, with the block-diagonal architecture, it is the L0-norm per head (~38 and 22, respectively; see the sketch after this list) and not the full L0-norm that needs to be small.

One possible explanation is that features are redundant across heads. For example, if two different heads activate the same induction feature in a context, both features will activate in my SAEs. In Kissane et al.'s SAE, there would perhaps be only one such feature, which is a concatenation of multiple features from my SAEs (plus some arbitrary directions from other heads).

  • Include "direct feature attribution by source position" in the feature dashboards, following Kissane et al. This will be needed to fully interpret features in attention head outputs.

  • Kissane et al. (paper) (code) seem to have gotten a better L0-MSE tradeoff than I did. They followed Bricken et al., while I followed the April update. Why, then, could I not get a better tradeoff? (Am I comparing apples to oranges because of the difference in parameter counts?) They also say in their paper that they targeted an L0 of 20 and an 80% CE loss score. Should I consider sticking to this rule? As it is, my L0 values are much higher.

  • I did not pay attention to weight initialization. A better initialization could easily improve the L0-MSE tradeoff.

  • Speed up training using torch.compile, torch.amp.autocast, etc.
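
For the per-head L0 point in the first TODO item, here is a small sketch of how per-head L0 can be computed from a batch of the block-diagonal SAE's feature activations. The [batch, n_heads × 2048] layout is an assumption matching the dashboards above:

import torch

# feature_acts: [batch, n_heads * d_sae_per_head] activations from the block-diagonal SAE.
# Assumed layout: features for head h occupy columns [h * 2048, (h + 1) * 2048).
def l0_per_head(feature_acts: torch.Tensor, n_heads: int = 12) -> torch.Tensor:
    acts = feature_acts.view(feature_acts.shape[0], n_heads, -1)
    # Number of active (non-zero) features per token, averaged over the batch, for each head.
    return (acts > 0).float().sum(dim=-1).mean(dim=0)

# The full L0 is just the sum of the per-head L0s:
# full_l0 = l0_per_head(feature_acts).sum()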