attn_saes
Commits

Commits on branch main (all by sshehper, 6 months ago):

  • 3727390539d16d9de76c4727c768084b88e29192: added dashboards for layer 5
  • 0d8f1740022ac53f3eb8020baf7930a1ec2c0b3f: uploading dashboards for layer 4
  • b01bdb541c694f568afb08fea08644c25951e402: generate dashboards for block diagonal SAEs of gpt2-small
  • c25473bb6088ab5411312377c51ae39992f6b056: added README
  • b638fb46ca0e9785f4a8f9338ccb0105def6f5b2: a copy of joseph bloom's notebook that works with my standard architecture hook_z saes
  • 6e1a6b2e668efa4d8816c3fd4bfffcec5a55e25f: train with block diagonal architecture on gpt attn layers

README

SAEs on Individual Attention Head Outputs

SAEs seem to work well on individual attention head outputs. I trained SAEs on the concatenated head outputs, with block-diagonal encoder and decoder weight matrices so that each block acts on a single head's output. The code to train these models is in this branch of my fork of SAELens.
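The block-diagonal layout can be sketched as one small linear map per head, which is equivalent to a single encoder/decoder whose weight matrices are block-diagonal over the concatenated head outputs. A minimal PyTorch sketch follows; the module name and the ReLU activation are illustrative assumptions, not the exact SAELens implementation:

```python
import torch
import torch.nn as nn


class BlockDiagonalSAE(nn.Module):
    """Sketch: one encoder/decoder pair per head, equivalent to a single
    block-diagonal weight matrix over the concatenated head outputs."""

    def __init__(self, n_heads: int, d_head: int, expansion_factor: int):
        super().__init__()
        d_sae_per_head = expansion_factor * d_head
        self.encoders = nn.ModuleList(
            nn.Linear(d_head, d_sae_per_head) for _ in range(n_heads)
        )
        self.decoders = nn.ModuleList(
            nn.Linear(d_sae_per_head, d_head) for _ in range(n_heads)
        )
        self.n_heads, self.d_head = n_heads, d_head

    def forward(self, x: torch.Tensor):
        # x: (batch, n_heads * d_head), i.e. concatenated hook_z outputs
        chunks = x.split(self.d_head, dim=-1)
        feats = [torch.relu(enc(c)) for enc, c in zip(self.encoders, chunks)]
        recon = torch.cat(
            [dec(f) for dec, f in zip(self.decoders, feats)], dim=-1
        )
        return recon, torch.cat(feats, dim=-1)


# gpt2-small attention shapes: 12 heads of dimension 64
sae = BlockDiagonalSAE(n_heads=12, d_head=64, expansion_factor=32)
recon, feats = sae(torch.randn(8, 12 * 64))
```

Because each head's features depend only on that head's slice of the input, cross-head weights are exactly zero by construction rather than merely regularized toward zero.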

I trained SAEs on the attention layer outputs of gelu-2l layer 1 and gpt2-small layers 4 and 5. In all cases, the SAEs learned interpretable features.

Training block-diagonal SAEs has several advantages.

  • With the same expansion factor, a block-diagonal SAE contains $\sim 1/{n_{\text{heads}}}$ times the number of trainable parameters of a fully dense SAE, which lowers training and inference costs.
  • Each feature is, by design, attributed to a single head.
  • Given a task / prompt, we hope to be able to tell what role each head plays by looking at the single-head features that activate on that task / prompt.
  • We can check claims that a head is "monosemantic" with more certainty: if we find a feature whose interpretation differs from that of every other feature in the head, the head is polysemantic.
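The parameter-count claim in the first bullet can be checked with quick arithmetic (gpt2-small shapes assumed; biases ignored):

```python
# Weight-only parameter counts for an SAE on gpt2-small attention output:
# n_heads = 12, d_head = 64, expansion factor 32.
n_heads, d_head, expansion = 12, 64, 32
d_in = n_heads * d_head      # 768, the concatenated head dimension
d_sae = expansion * d_in     # 24576 features in total

# Dense SAE: full (d_in x d_sae) encoder plus (d_sae x d_in) decoder.
dense = 2 * d_in * d_sae

# Block-diagonal SAE: n_heads blocks of shape (d_head x expansion * d_head).
block_diag = 2 * n_heads * d_head * (expansion * d_head)

ratio = dense / block_diag   # equals n_heads == 12
```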

I include W&B logs and feature dashboards of the two gpt2-small SAEs.

A dashboard file contains feature dashboards for the first 10 features from each head. To look through them all, change the feature number in the top-left corner. Features 0-9 belong to head 0, features 2048-2057 belong to head 1, and so on.
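Given that numbering, a global feature index maps back to a head with a one-line helper (the 2048 features-per-head count is inferred from the dashboard layout above; the function name is mine):

```python
# Each head owns a contiguous block of 2048 = 32 * 64 features,
# per the dashboard numbering (0-9 in head 0, 2048-2057 in head 1, ...).
FEATURES_PER_HEAD = 2048

def feature_to_head(feature_idx: int) -> tuple[int, int]:
    """Return (head index, feature index within that head)."""
    return divmod(feature_idx, FEATURES_PER_HEAD)

feature_to_head(2057)  # head 1, its 9th feature
```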

Reproduction

pip install the branch of SAELens mentioned above. Then run:

python -u train_gpt_block_diag.py --l1_coefficients=[2] --hook_layer=4 --expansion_factor=32 --total_training_steps=50000
python -u train_gpt_block_diag.py --l1_coefficients=[3] --hook_layer=5 --expansion_factor=32 --total_training_steps=100000

Next, use gpt_block_diag_dashboards.ipynb in this repo to load and analyze dashboards.

TODO

  • The two SAEs have higher L0-norm values (460 and 278) than what is usually reported. They may be undertrained, as the overall loss and the L0-norm were still decreasing when training stopped. It is also possible that, with the block-diagonal architecture, it is the per-head L0-norm (~38 and ~22, respectively), and not the full L0-norm, that needs to be small.

One possible explanation is that features are redundant across heads. For example, if two different heads activate the same induction feature in a context, both features will activate in my SAEs. In Kissane et al.'s SAE, there would perhaps be only one feature, which is a concatenation of multiple features from my SAEs (plus some arbitrary directions from other heads).

  • Include "direct feature attribution by source position" in the feature dashboards, following Kissane et al. This will be needed to fully interpret features in attention head outputs.

  • Kissane et al. (paper) (code) seem to have achieved a better L0-MSE tradeoff than I did. They followed Bricken et al., while I followed the April update. So why could I not get a better tradeoff? (Am I comparing apples to oranges because of the difference in the number of parameters?) They also say in their paper that they targeted an L0 of 20 and an 80% CE Loss Score. Should I consider sticking to this rule? As it is, my L0 values are much higher.

  • I did not pay attention to weight initialization. Doing so could easily affect the L0-MSE tradeoff.

  • Speed up training using torch.compile, torch.amp.autocast, etc.
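The per-head L0 referenced in the first TODO item can be computed with a small helper. This is a sketch under the feature layout described above (each head owns a contiguous block of 2048 features); the function name and shapes are illustrative:

```python
import torch


def l0_per_head(feature_acts: torch.Tensor, n_heads: int) -> torch.Tensor:
    """Mean number of active features per head, averaged over the batch.

    feature_acts: (batch, n_heads * d_sae_per_head) post-activation values;
    with a block-diagonal SAE the full L0 splits cleanly into per-head L0s.
    """
    batch = feature_acts.shape[0]
    per_head = feature_acts.view(batch, n_heads, -1)
    return (per_head > 0).float().sum(dim=-1).mean(dim=0)  # shape (n_heads,)


# Toy check: 5 active features, all belonging to head 0.
acts = torch.zeros(4, 12 * 2048)
acts[:, :5] = 1.0
l0 = l0_per_head(acts, n_heads=12)
```

Summing the per-head values recovers the full L0, so the two quantities can be reported side by side.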