Skip to content

Commit d9b2d8a

Browse files
authored
Merge pull request #366 from donglihe-hub/beam_width
add beam_width
2 parents 3682408 + adae624 commit d9b2d8a

2 files changed

Lines changed: 7 additions & 1 deletion

File tree

main.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,12 @@ def add_all_arguments(parser):
223223
parser.add_argument(
224224
"--tree_max_depth", type=int, default=10, help="Maximum depth of the tree (default: %(default)s)"
225225
)
226+
parser.add_argument(
227+
"--beam_width",
228+
type=int,
229+
default=10,
230+
help="The width of the beam search (default: %(default)s)",
231+
)
226232
parser.add_argument(
227233
"-h",
228234
"--help",

tests/basic.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ find example_config -name "*.yml" ${filters[@]} -type f -print0 |
119119
echo "Running $config"
120120
stderr=$(python $script --config "$config" --epochs 1 \
121121
--result_dir "$result_dir" --embed_cache_dir data \
122-
--val_size 0.2 --save_k_predictions 2 \
122+
--val_size 0.2 --beam_width 1 \
123123
--cpu 2>&1 > /dev/null)
124124
if [[ $? -ne 0 ]]; then
125125
echo "$stderr" >&2

0 commit comments

Comments
 (0)