|
| 1 | +#!/bin/bash |
| 2 | +# ============================================================================ |
| 3 | +# Common Configuration Script - for attentive_probe_codec evaluation |
| 4 | +# Usage: Source this file in model scripts, then call run_attentive_probe_codec |
| 5 | +# ============================================================================ |
| 6 | + |
| 7 | +# Environment setup |
| 8 | +export PYTHONPATH=../ |
| 9 | + |
| 10 | +# Default dataset list (can be overridden in calling script) |
| 11 | +DEFAULT_DATASETS=( |
| 12 | + "ssv2" |
| 13 | + "diving48" |
| 14 | + "perception_test" |
| 15 | + "epic_verb" |
| 16 | + "epic_noun" |
| 17 | + "hmdb51" |
| 18 | + "k400" |
| 19 | + "charadesego" |
| 20 | +) |
| 21 | + |
| 22 | +# ============================================================================ |
| 23 | +# Get batch size based on dataset name |
| 24 | +# Args: $1 - dataset name |
| 25 | +# ============================================================================ |
| 26 | +get_batch_size() { |
| 27 | + local dataset="$1" |
| 28 | + if [[ "$dataset" == "ssv2" || "$dataset" == "diving48" || "$dataset" == "perception_test" ]]; then |
| 29 | + echo 4 |
| 30 | + elif [[ "$dataset" == "hmdb51" ]]; then |
| 31 | + echo 2 |
| 32 | + else |
| 33 | + echo 16 |
| 34 | + fi |
| 35 | +} |
| 36 | + |
| 37 | +# ============================================================================ |
| 38 | +# Get epochs based on dataset name |
| 39 | +# Args: $1 - dataset name |
| 40 | +# ============================================================================ |
| 41 | +get_epochs() { |
| 42 | + local dataset="$1" |
| 43 | + if [[ "$dataset" == "hmdb51" ]]; then |
| 44 | + echo 30 |
| 45 | + elif [[ "$dataset" == "diving48" ]]; then |
| 46 | + echo 30 |
| 47 | + elif [[ "$dataset" == "perception_test" ]]; then |
| 48 | + echo 30 |
| 49 | + else |
| 50 | + echo 10 |
| 51 | + fi |
| 52 | +} |
| 53 | + |
| 54 | +# ============================================================================ |
| 55 | +# Get codec parameters based on dataset name |
| 56 | +# Args: $1 - dataset name |
| 57 | +# Returns: Sets CODEC_MV_COMPENSATE, CODEC_STATIC_ABS_THRESH, CODEC_STATIC_REL_THRESH |
| 58 | +# ============================================================================ |
| 59 | +get_codec_params() { |
| 60 | + local dataset="$1" |
| 61 | + if [[ "$dataset" == "diving48" || "$dataset" == "perception_test" ]]; then |
| 62 | + # Parameters for diving48 and perception_test |
| 63 | + CODEC_MV_COMPENSATE="similarity" |
| 64 | + CODEC_STATIC_ABS_THRESH="126" |
| 65 | + CODEC_STATIC_REL_THRESH="0.38" |
| 66 | + else |
| 67 | + # Parameters for other datasets |
| 68 | + CODEC_MV_COMPENSATE="similarity" |
| 69 | + CODEC_STATIC_ABS_THRESH="116" |
| 70 | + CODEC_STATIC_REL_THRESH="0.55" |
| 71 | + fi |
| 72 | +} |
| 73 | + |
| 74 | +# ============================================================================ |
| 75 | +# Run attentive_probe_codec evaluation |
| 76 | +# Required variables to set before calling: |
| 77 | +# - MODEL_FAMILY: model family (required) |
| 78 | +# - MODEL_NAME: model name (required) |
| 79 | +# - MODEL_WEIGHT: model weight path (optional, default "NULL") |
| 80 | +# - FRAMES_TOKEN_NUM: token count (optional, default 196) |
| 81 | +# - EMBEDDING_SIZE: embedding dimension (optional, default 768) |
| 82 | +# - INPUT_SIZE: input size (optional, not passed if unset) |
| 83 | +# - NUM_FRAMES: number of frames (optional, not passed if unset) |
| 84 | +# - DATASETS: dataset array (optional, uses DEFAULT_DATASETS if unset/empty) |
| 85 | +# - REPORT_DIR_SUFFIX: report directory suffix (optional, e.g. "_64frames_codec") |
| 86 | +# ============================================================================ |
| 87 | +run_attentive_probe_codec() { |
| 88 | + # Set default values |
| 89 | + MODEL_WEIGHT="${MODEL_WEIGHT:-NULL}" |
| 90 | + FRAMES_TOKEN_NUM="${FRAMES_TOKEN_NUM:-196}" |
| 91 | + EMBEDDING_SIZE="${EMBEDDING_SIZE:-768}" |
| 92 | + REPORT_DIR_SUFFIX="${REPORT_DIR_SUFFIX:-}" |
| 93 | + |
| 94 | + # Use custom datasets or default datasets |
| 95 | + if [[ -z "${DATASETS+x}" ]] || [[ ${#DATASETS[@]} -eq 0 ]]; then |
| 96 | + DATASETS=("${DEFAULT_DATASETS[@]}") |
| 97 | + fi |
| 98 | + |
| 99 | + # Build report directory |
| 100 | + BASE_REPORT_DIR="result_attentive_probe/${MODEL_FAMILY}/${MODEL_NAME}${REPORT_DIR_SUFFIX}" |
| 101 | + |
| 102 | + # Loop through each dataset for testing |
| 103 | + for DATASET in "${DATASETS[@]}"; do |
| 104 | + BATCH_SIZE=$(get_batch_size "$DATASET") |
| 105 | + EPOCHS=$(get_epochs "$DATASET") |
| 106 | + |
| 107 | + # Get codec-specific parameters for this dataset |
| 108 | + get_codec_params "$DATASET" |
| 109 | + |
| 110 | + echo "DATASET=$DATASET, BATCH_SIZE=$BATCH_SIZE" |
| 111 | + echo "Codec params: mv_compensate=${CODEC_MV_COMPENSATE}, static_abs_thresh=${CODEC_STATIC_ABS_THRESH}, static_rel_thresh=${CODEC_STATIC_REL_THRESH}" |
| 112 | + |
| 113 | + echo "========================================================" |
| 114 | + echo "Start testing dataset: ${DATASET}" |
| 115 | + echo "Model: ${MODEL_NAME}" |
| 116 | + echo "Batch Size: ${BATCH_SIZE}" |
| 117 | + echo "Report Dir: ${BASE_REPORT_DIR}/${DATASET}" |
| 118 | + echo "========================================================" |
| 119 | + |
| 120 | + # Build output directory |
| 121 | + SAVE_DIR="${BASE_REPORT_DIR}/${DATASET}" |
| 122 | + mkdir -p "$SAVE_DIR" |
| 123 | + |
| 124 | + # Build extra arguments |
| 125 | + EXTRA_ARGS="" |
| 126 | + if [[ -n "${INPUT_SIZE}" ]]; then |
| 127 | + EXTRA_ARGS="${EXTRA_ARGS} --input_size ${INPUT_SIZE}" |
| 128 | + fi |
| 129 | + if [[ -n "${NUM_FRAMES}" ]]; then |
| 130 | + EXTRA_ARGS="${EXTRA_ARGS} --num_frames ${NUM_FRAMES}" |
| 131 | + fi |
| 132 | + |
| 133 | + torchrun --nproc_per_node 8 --master_port 15555 \ |
| 134 | + attentive_probe_codec.py \ |
| 135 | + --eval_freq 1 \ |
| 136 | + --default_lr_list 0.0001 \ |
| 137 | + --default_epoch "${EPOCHS}" \ |
| 138 | + --batch_size ${BATCH_SIZE} \ |
| 139 | + --default_weight_decay 0 \ |
| 140 | + --dali_py_num_workers 8 \ |
| 141 | + --model_family "${MODEL_FAMILY}" \ |
| 142 | + --model_name "${MODEL_NAME}" \ |
| 143 | + --model_weight "${MODEL_WEIGHT}" \ |
| 144 | + --dataset "${DATASET}" \ |
| 145 | + --save_report "${SAVE_DIR}" \ |
| 146 | + --frames_token_num ${FRAMES_TOKEN_NUM} \ |
| 147 | + --embedding_size ${EMBEDDING_SIZE} \ |
| 148 | + --mv_compensate ${CODEC_MV_COMPENSATE} \ |
| 149 | + --static_abs_thresh ${CODEC_STATIC_ABS_THRESH} \ |
| 150 | + --static_rel_thresh ${CODEC_STATIC_REL_THRESH} \ |
| 151 | + --static_fallback 1 \ |
| 152 | + ${EXTRA_ARGS} |
| 153 | + |
| 154 | + echo "Finished testing ${DATASET}" |
| 155 | + echo "" |
| 156 | + done |
| 157 | +} |
0 commit comments