
Enable extraction of gene embeddings from geneformer (averaging of gene embeddings across all cells) #452

@jstjohn

Description


A potential design:

  1. Add an argparse option --num-layers-override to infer_geneformer.py (https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/infer_geneformer.py#L235), with a default of None.
  2. Add logic at https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/infer_geneformer.py#L34 so that, if the override is unset, behavior is unchanged.
  3. If the override is set, we need to do the following to make it take effect in the model:
    1. Import OVERRIDE_BIOBERT_CONFIG_DEFAULTS, defined here: https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py#L105
    2. Pass override_parent_fields=['num_layers'] + OVERRIDE_BIOBERT_CONFIG_DEFAULTS to the config_class (around here https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/infer_geneformer.py#L116), but only if the user set num_layers_override (i.e. it is not None). This tells the checkpoint loader not to pull this field from the trained model config stored in the checkpoint, and to use the user-supplied value instead.
    3. Also pass num_layers=num_layers_override to the config at that point, again only if it is not None.
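
The steps above can be sketched in isolation as follows. This is a minimal sketch, not the actual infer_geneformer.py change: OVERRIDE_BIOBERT_CONFIG_DEFAULTS is replaced by a placeholder list (the real one would be imported from bionemo.llm.model.biobert.model as in step 3.1), and config_overrides is a hypothetical helper name.

```python
import argparse
from typing import Any, Dict, List, Optional

# Placeholder standing in for bionemo.llm.model.biobert.model.OVERRIDE_BIOBERT_CONFIG_DEFAULTS.
# The real list lives in bionemo-llm; these entries are made up for the sketch.
OVERRIDE_BIOBERT_CONFIG_DEFAULTS: List[str] = ["placeholder_field_a", "placeholder_field_b"]


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Geneformer inference (sketch)")
    # Step 1: optional flag; None means "use num_layers from the checkpoint".
    parser.add_argument(
        "--num-layers-override",
        type=int,
        default=None,
        help="If set, build the model with this many layers instead of the trained value.",
    )
    return parser


def config_overrides(num_layers_override: Optional[int]) -> Dict[str, Any]:
    """Extra kwargs for config_class(...) (hypothetical helper)."""
    # Step 2: no override means no extra kwargs, so behavior is unchanged.
    if num_layers_override is None:
        return {}
    # Step 3: keep num_layers from the user rather than the checkpoint,
    # and supply that value.
    return {
        "override_parent_fields": ["num_layers"] + OVERRIDE_BIOBERT_CONFIG_DEFAULTS,
        "num_layers": num_layers_override,
    }


if __name__ == "__main__":
    args = build_parser().parse_args(["--num-layers-override", "5"])
    print(config_overrides(args.num_layers_override))
```

Assuming the existing config_class(...) call accepts these as keyword arguments, as step 3 describes, the result could be splatted in with config_class(..., **config_overrides(num_layers_override)). When no override is given the dict is empty, so the call is identical to today's.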

The model will then be initialized with the user-requested number of layers rather than the num_layers it was originally trained with. For example, if the model was trained with 6 layers and you want to drop the last layer to get inference results from the second-to-last layer, set --num-layers-override 5 and you get back a 5-layer model with the last layer left off.
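
A toy illustration, in plain Python with no bionemo dependency, of why loading a 6-layer checkpoint into a 5-layer model yields the second-to-last layer's output. ToyStack, CHECKPOINT and the key names are made up, and it assumes, as the paragraph above implies, that the loader only reads the keys the smaller model asks for.

```python
from typing import Dict, List

# Toy "checkpoint" for a model trained with 6 layers: one scalar weight per layer.
CHECKPOINT: Dict[str, float] = {f"layers.{i}.weight": float(i + 1) for i in range(6)}


class ToyStack:
    """Stack of layers where layer i computes x * w_i + 1 (illustrative only)."""

    def __init__(self, num_layers: int) -> None:
        self.num_layers = num_layers
        self.weights: List[float] = []

    def load(self, checkpoint: Dict[str, float]) -> None:
        # Only the keys this model needs are read, so a 5-layer model never
        # touches "layers.5.weight" from the 6-layer checkpoint.
        self.weights = [checkpoint[f"layers.{i}.weight"] for i in range(self.num_layers)]

    def hidden_states(self, x: float) -> List[float]:
        # Output of every layer, in order.
        states = []
        for w in self.weights:
            x = x * w + 1.0
            states.append(x)
        return states

    def forward(self, x: float) -> float:
        return self.hidden_states(x)[-1]


full = ToyStack(num_layers=6)
full.load(CHECKPOINT)
truncated = ToyStack(num_layers=5)  # analogue of --num-layers-override 5
truncated.load(CHECKPOINT)
print(truncated.forward(2.0) == full.hidden_states(2.0)[-2])  # True
print(truncated.forward(2.0))  # 446.0
```

The toy has nothing after its last layer. In a real transformer, any post-stack modules (for example a final LayerNorm) may still be applied to the truncated model's output, so it need not match the full model's raw intermediate activations exactly.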

Side note: these steps are the general recipe for overriding any setting of a loaded model. The same pattern works for fine-tuning as well as inference whenever you want to change something about the model at load time. In the fine-tuning case (not relevant here), if you add a new layer you also need to tell the checkpoint loader not to look for that layer in the checkpoint; otherwise you get a confusing error at load time saying the layer was not found.
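
A generic sketch of what "telling the checkpoint loader not to look for the new layer" means. load_checkpoint, skip_prefixes and the key names are hypothetical and are not the bionemo API; the actual mechanism lives in the bionemo-llm config classes and should be checked there.

```python
from typing import Dict, Iterable, List


def load_checkpoint(
    model_keys: Iterable[str],
    checkpoint: Dict[str, float],
    skip_prefixes: Iterable[str] = (),
) -> Dict[str, float]:
    """Copy weights for model_keys out of checkpoint (hypothetical loader).

    Keys starting with any prefix in skip_prefixes are not looked up; they keep
    their fresh initialization, represented here as 0.0.
    """
    prefixes = tuple(skip_prefixes)  # str.startswith accepts a tuple; () matches nothing
    loaded: Dict[str, float] = {}
    missing: List[str] = []
    for key in model_keys:
        if key.startswith(prefixes):
            loaded[key] = 0.0  # new layer: left at its initial value
        elif key in checkpoint:
            loaded[key] = checkpoint[key]
        else:
            missing.append(key)
    if missing:
        # The "layer not found" failure you hit if the new layer is not skipped.
        raise KeyError(f"keys not found in checkpoint: {missing}")
    return loaded


ckpt = {"encoder.layers.0.weight": 1.0, "encoder.layers.1.weight": 2.0}
keys = list(ckpt) + ["head.new_layer.weight"]  # fine-tuning adds a new layer
try:
    load_checkpoint(keys, ckpt)
except KeyError as err:
    print("without skip:", err)
print(load_checkpoint(keys, ckpt, skip_prefixes=["head.new_layer"]))
```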

Metadata

Assignees: No one assigned
Labels: No labels
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests