Please install the dependencies and download the pre-trained models according to the instructions in README.md.
| Task | Huggingface Dataset Link |
|---|---|
| pure text | qiaojin/PubMedQA |
| text-to-image | jasonhuang23/artwork |
| VQA | rbojja/medical-vqa |
We will use the above datasets for Janus-Pro fine-tuning. Download them with:
huggingface-cli download qiaojin/PubMedQA --repo-type dataset --local-dir datasets/PubMedQA
huggingface-cli download jasonhuang23/artwork --repo-type dataset --local-dir datasets/artwork
huggingface-cli download rbojja/medical-vqa --repo-type dataset --local-dir datasets/medical-vqa

Before launching SFT training with the scripts under ../scripts/, set the environment variables YOUR_DATA_PATH and YOUR_DOWNLOADED_JANUS_CKPT_PATH in each script.
After setting up paths as above, you are good to go.
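A missing path is a common cause of a failed launch, so a quick pre-flight check may help. The sketch below is not part of the repository: `missing_paths` is a hypothetical helper, and the example values (in particular `ckpts/Janus-Pro-1B`) are placeholders to replace with the locations you actually used.

```python
import os


def missing_paths(paths):
    """Return the names whose path does not exist on disk."""
    return [name for name, path in paths.items() if not os.path.exists(path)]


# Hypothetical values: substitute the paths you set in the training scripts.
paths = {
    "YOUR_DATA_PATH": "datasets/medical-vqa",
    "YOUR_DOWNLOADED_JANUS_CKPT_PATH": "ckpts/Janus-Pro-1B",
}

missing = missing_paths(paths)
if missing:
    print("Missing:", ", ".join(missing))
else:
    print("All paths found.")
```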
- Multimodal Understanding Task (VQA)
bash scripts/run_sft_vqa.sh

- Text Generation Task
bash scripts/run_sft_text.sh # if modeling_vlm.py is not manually patched, the mode should be switched to pynative

Patch janus/models/modeling_vlm.py for the single pure-text task:
# @ L431
-- def construct(
++ # def construct( # just comment the whole function out
# @ L479
-- def construct_graph_single_task(
++ def construct(

- Text-to-Image Generation Task (T2I)
bash scripts/run_sft_t2i.sh

The default training stage is stage 3, in which all modules are trainable except the VQ16 module used for image token decoding. To switch to another stage, modify the --stage argument in the training script.
For more detailed arguments, please run python train.py -h.
bash scripts/run_sft_mixed_graph.sh

We also implemented a stage-3 SFT on medical data, aiming to build a radiology expert model. The datasets can be retrieved from the following Hugging Face repos.
| Task | # Data Samples | Hugging Face Source |
|---|---|---|
| VQA | 100 | rbojja/medical-vqa |
| pure-text | 20 | qiaojin/PubMedQA |
| T2I | 80 | mdwiratathya/ROCO-radiology |
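To see how the three tasks are weighted in this mix, the sample counts from the table can be turned into shares. This is plain arithmetic on the counts above; it assumes nothing about how the training script actually samples from each dataset, and `task_shares` is a hypothetical helper name.

```python
def task_shares(counts):
    """Return each task's fraction of the total sample count."""
    total = sum(counts.values())
    return {task: n / total for task, n in counts.items()}


# Sample counts taken from the table above.
counts = {"VQA": 100, "pure-text": 20, "T2I": 80}

for task, share in task_shares(counts).items():
    print(f"{task}: {share:.0%}")
```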
Note
We achieve higher training throughput by enabling graph-mode computation. To do that, we must predefine a compute graph for the VLM for each of the three tasks, because the VLM takes a different set of input arguments for each task. This feature works with MindSpore 2.5.0 only; it is no longer supported in MindSpore 2.7.0.
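Because each task has its own construct_*() variant, it can be useful to check which variants a file defines, and on which lines, before and after editing it. The sketch below uses only the standard-library `ast` module. `list_construct_methods` is a hypothetical helper, and `toy_source` is a toy stand-in for janus/models/modeling_vlm.py with made-up signatures, not the real code.

```python
import ast


def list_construct_methods(source):
    """Return (name, line number) for every function whose name starts with 'construct'."""
    tree = ast.parse(source)
    found = [
        (node.name, node.lineno)
        for node in ast.walk(tree)
        if isinstance(node, ast.FunctionDef) and node.name.startswith("construct")
    ]
    return sorted(found, key=lambda item: item[1])


# Toy stand-in for the model file; the signatures here are illustrative only.
toy_source = '''
class ToyVLM:
    def construct(self, input_ids):
        return input_ids

    def construct_graph_mixed_task(self, input_ids, pixel_values):
        return input_ids
'''

for name, lineno in list_construct_methods(toy_source):
    print(f"{name} @ L{lineno}")
```

For the real file, read its text first (for example with `pathlib.Path(...).read_text()`) and pass that string to the helper. A commented-out function does not appear in the output, since comments are not part of the parsed tree.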
To run scripts/run_sft_mixed_graph.sh, open janus/models/modeling_vlm.py and patch the relevant construct_*() method into construct() as follows.
# @ L431
-- def construct(
++ # def construct( # just comment the whole function out
# @ L573
-- def construct_graph_mixed_task(
++ def construct(

For pynative mode, patch construct_pynative_mixed_task() instead:

# @ L431
-- def construct(
++ # def construct( # just comment the whole function out
# @ L519
-- def construct_pynative_mixed_task(
++ def construct(

Experiments were run on Ascend Atlas 800T A2 machines with MindSpore 2.7.0 in pynative mode:
| model | task | # card(s) | image size | max_length | batch size | step time (s/step) |
|---|---|---|---|---|---|---|
| Janus-Pro-1B | T2I | 1 | 384x384 | 1024 | 8 | 0.60 |
| Janus-Pro-1B | VQA | 1 | 384x384 | 1024 | 4 | 0.42 |
| Janus-Pro-1B | Text | 1 | n.a. | 512 | 8 | 0.46 |
| Janus-Pro-7B | T2I | 1 | 384x384 | 1024 | 1 | 0.51 |
| Janus-Pro-7B | VQA | 1 | 384x384 | 1024 | 1 | 0.46 |
| Janus-Pro-7B | Text | 1 | n.a. | 512 | 1 | 0.57 |
For mixed-SFT:
| model | task | ms_mode | # card(s) | image size | max_length | batch size | step time (s/step) |
|---|---|---|---|---|---|---|---|
| Janus-Pro-1B | mixed | pynative | 1 | 384x384 | 1024 | 6 | 2.30 |
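The step times above can be converted into throughput for easier comparison across batch sizes. The sketch below assumes that the batch size column is the number of samples processed per step on the single card, so throughput is batch size divided by step time. `samples_per_second` is a hypothetical helper, and the rows are copied from the two tables.

```python
def samples_per_second(batch_size, step_time_s):
    """Throughput in samples/s, assuming one step processes one full batch."""
    return batch_size / step_time_s


# (model, task, batch size, step time in s/step), copied from the tables above.
rows = [
    ("Janus-Pro-1B", "T2I", 8, 0.60),
    ("Janus-Pro-1B", "VQA", 4, 0.42),
    ("Janus-Pro-1B", "Text", 8, 0.46),
    ("Janus-Pro-7B", "T2I", 1, 0.51),
    ("Janus-Pro-7B", "VQA", 1, 0.46),
    ("Janus-Pro-7B", "Text", 1, 0.57),
    ("Janus-Pro-1B", "mixed", 6, 2.30),
]

for model, task, batch_size, step_time in rows:
    rate = samples_per_second(batch_size, step_time)
    print(f"{model} {task}: {rate:.2f} samples/s")
```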