b_dec is incorrectly added to topk aux loss #132

@chanind

In sparse_coder.py#L247, the code calls decode() on auxk_acts to get e_hat, but decode() also adds the decoder bias b_dec. This looks like a mistake: e, the residual, is the difference between sae_out and y, and sae_out already includes b_dec, so the bias is already accounted for. As written, the aux loss pulls the dead latents' reconstruction towards e - b_dec rather than towards e, as is likely intended, and it likely also puts an unintended gradient on b_dec.
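To make the concern concrete, here is a minimal pure-Python sketch of the effect. It is not the code from sparse_coder.py: the toy weights, the helper names (`linear_decode`, `decode`, `sq_err`), and the sign convention e = y - sae_out are all assumptions for illustration. It shows that dead-latent activations whose linear reconstruction matches the residual exactly still incur a loss of ||b_dec||^2 when the bias is added.

```python
# Toy illustration only: hypothetical helpers, not the library's implementation.

def linear_decode(acts, W_dec):
    """Compute acts @ W_dec (no bias) for a row-major W_dec of shape [n_latents][d_in]."""
    d_in = len(W_dec[0])
    return [sum(a * row[j] for a, row in zip(acts, W_dec)) for j in range(d_in)]

def decode(acts, W_dec, b_dec):
    """Full decode: linear map plus decoder bias, as decode() is described in the issue."""
    return [v + b for v, b in zip(linear_decode(acts, W_dec), b_dec)]

def sq_err(a, b):
    """Sum of squared differences between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

# Three latents, two input dimensions. Latent 0 is live, latents 1 and 2 are dead.
W_dec = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
b_dec = [0.5, -0.25]
y = [2.0, 1.0]

main_acts = [1.0, 0.0, 0.0]
sae_out = decode(main_acts, W_dec, b_dec)      # already contains b_dec
e = [yi - oi for yi, oi in zip(y, sae_out)]    # residual (sign convention assumed)

# Dead-latent activations whose linear reconstruction equals the residual exactly.
auxk_acts = [0.0, 0.5, 1.25]

loss_with_bias = sq_err(decode(auxk_acts, W_dec, b_dec), e)
loss_without_bias = sq_err(linear_decode(auxk_acts, W_dec), e)

print(loss_with_bias)     # 0.3125, i.e. ||b_dec||^2
print(loss_without_bias)  # 0.0
```

In this toy setup the bias-free variant reaches zero loss when the dead latents explain the residual, while the decode() variant is only minimized when the linear part equals e - b_dec.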
