Skip to content

Commit 8ec9a95

Browse files
committed
fix(mtp): distinguish explicit --mtp-source none from default
Per cubic P1 on PR #237: --mtp-source=none was silently overridden by the legacy-flag inference because MtpSource::None served as both the default and the explicit 'none' value. This commit adds an internal MtpSource::Unset sentinel as the new default. Legacy-flag inference fires only when the field is still Unset. After inference (or if no legacy flag matched), Unset is resolved to None before any backend code sees it. The user-visible CLI surface is unchanged: --mtp-source still accepts exactly {none, native, external, auto}. Unset is internal-only and never escapes arg parsing. A defensive assert in create_backend() enforces this in builds where assertions are enabled (the standard assert macro is compiled out when NDEBUG is defined).
1 parent c2e7a53 commit 8ec9a95

17 files changed

Lines changed: 134 additions & 120 deletions

dflash/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -528,8 +528,10 @@ if(DFLASH27B_TESTS)
528528
endif()
529529
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_common_mtp_orchestrator.cpp")
530530
add_executable(test_common_mtp_orchestrator test/test_common_mtp_orchestrator.cpp)
531-
target_include_directories(test_common_mtp_orchestrator PRIVATE ${DFLASH27B_SRC_INCLUDE_DIRS})
532-
target_link_libraries(test_common_mtp_orchestrator PRIVATE dflash27b)
531+
target_include_directories(test_common_mtp_orchestrator PRIVATE
532+
${DFLASH27B_SRC_INCLUDE_DIRS}
533+
${CMAKE_CURRENT_SOURCE_DIR}/deps/llama.cpp/ggml/include)
534+
target_link_libraries(test_common_mtp_orchestrator PRIVATE dflash_common ggml-base)
533535
endif()
534536
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_draft_vs_reference.cpp")
535537
add_executable(test_draft_vs_reference test/test_draft_vs_reference.cpp)

dflash/src/common/backend_factory.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "gguf.h"
1212

13+
#include <cassert>
1314
#include <cstdio>
1415

1516
namespace dflash::common {
@@ -54,6 +55,10 @@ std::unique_ptr<ModelBackend> create_backend(const BackendArgs & args) {
5455

5556
std::fprintf(stderr, "[backend_factory] detected arch=%s\n", arch.c_str());
5657

58+
// Unset must have been resolved to None by arg parsing before reaching here.
59+
assert(args.mtp_source != MtpSource::Unset &&
60+
"MtpSource::Unset must be resolved by arg parsing before reaching the backend factory");
61+
5762
// Resolve MtpSource::Auto before constructing the backend.
5863
MtpSource resolved_source = args.mtp_source;
5964
if (resolved_source == MtpSource::Auto) {
@@ -99,7 +104,8 @@ std::unique_ptr<ModelBackend> create_backend(const BackendArgs & args) {
99104
cfg.mtp_gguf_path = args.mtp_gguf_path;
100105
break;
101106
case MtpSource::None:
102-
case MtpSource::Auto: // Auto is fully resolved above; this arm is unreachable.
107+
case MtpSource::Auto: // fully resolved above; arm is unreachable.
108+
case MtpSource::Unset: // guarded by assert above; arm is unreachable.
103109
default:
104110
cfg.mtp_gguf_path = nullptr;
105111
break;

dflash/src/common/backend_factory.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ namespace dflash::common {
2121
// ─── MTP source selection ────────────────────────────────────────────────
2222
// Replaces the old free-form mtp_draft_source string (@howard0su #237, line 59).
2323
enum class MtpSource {
24+
Unset, // internal sentinel: --mtp-source not provided (never escapes arg parsing)
2425
None, // no MTP speculator
2526
Native, // MTP heads co-located in the target GGUF (e.g. unsloth single-file)
2627
ExternalDrafter, // separate MTP-head GGUF supplied via mtp_gguf_path
@@ -63,13 +64,15 @@ struct BackendArgs {
6364

6465
// MTP (Multi-Token Prediction) speculator — mutually exclusive with --draft.
6566
// mtp_source drives which loading path is taken:
67+
// Unset → internal default; --mtp-source not provided; resolved to None after
68+
// legacy-flag inference (never reaches the backend factory as Unset).
6669
// None → MTP disabled; mtp_gguf_path ignored.
6770
// Native → MTP heads embedded in model_path GGUF (single-file, e.g. unsloth).
6871
// mtp_gguf_path is left nullptr; the factory sets it to model_path.
6972
// ExternalDrafter→ Separate MTP-head GGUF at mtp_gguf_path (required).
7073
// Auto → factory calls gguf_contains_mtp_tensors(model_path): if true,
7174
// resolves to Native; otherwise resolves to None.
72-
MtpSource mtp_source = MtpSource::None;
75+
MtpSource mtp_source = MtpSource::Unset;
7376
const char * mtp_gguf_path = nullptr; // required only for ExternalDrafter
7477
int mtp_gamma = 0; // 0 = MTP loaded but not active; >0 = chain depth
7578
bool mtp_use_topk = false; // false = chain (default), true = mtp_topk strategy

dflash/src/common/mtp_chain_runner.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include <cstdio>
1010
#include <cstring>
1111

12-
namespace dflash27b::mtp {
12+
namespace dflash::common::mtp {
1313

1414
MtpChainRunner::MtpChainRunner(IMtpModule & mtp,
1515
DFlashTarget & target,
@@ -294,4 +294,4 @@ GenerateResult MtpChainRunner::run(const GenerateRequest & req,
294294
return result;
295295
}
296296

297-
} // namespace dflash27b::mtp
297+
} // namespace dflash::common::mtp

dflash/src/common/mtp_chain_runner.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
#include <vector>
1919

20-
namespace dflash27b {
20+
namespace dflash::common {
2121

2222
struct DFlashTarget; // forward — see common/dflash_target.h
2323

@@ -84,4 +84,4 @@ class MtpChainRunner {
8484
};
8585

8686
} // namespace mtp
87-
} // namespace dflash27b
87+
} // namespace dflash::common

dflash/src/common/mtp_interface.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
#include <cstdint>
2323
#include <vector>
2424

25-
namespace dflash27b {
25+
namespace dflash::common {
2626

2727
struct DFlashTarget; // forward — see common/dflash_target.h
2828

@@ -212,4 +212,4 @@ struct INativeMtp : IMtpModule {
212212
};
213213

214214
} // namespace mtp
215-
} // namespace dflash27b
215+
} // namespace dflash::common

dflash/src/common/mtp_orchestrator.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#include <cstring>
88
#include <vector>
99

10-
namespace dflash27b {
10+
namespace dflash::common {
1111
namespace mtp {
1212

1313
namespace {
@@ -39,7 +39,7 @@ GenerateResult warm_and_decode(ModelBackend * backend,
3939
return result;
4040
}
4141

42-
dflash27b::mtp::IMtpModule * module = backend->mtp();
42+
dflash::common::mtp::IMtpModule * module = backend->mtp();
4343
DFlashTarget * target = backend->dflash_target();
4444
if (!module || !target) {
4545
result.error = "warm_and_decode: backend missing mtp() or dflash_target()";
@@ -102,10 +102,10 @@ GenerateResult warm_and_decode(ModelBackend * backend,
102102
if (target->last_hidden() != nullptr) {
103103
module->set_initial_hidden(target->last_hidden(), hidden);
104104
}
105-
if (module->flavor() == dflash27b::mtp::MtpFlavor::NativeHeads
105+
if (module->flavor() == dflash::common::mtp::MtpFlavor::NativeHeads
106106
&& !all_prefill_hidden.empty()) {
107107
// flavor() guarantees the concrete type; static_cast is safe.
108-
auto * native = static_cast<dflash27b::mtp::INativeMtp *>(module);
108+
auto * native = static_cast<dflash::common::mtp::INativeMtp *>(module);
109109
if (native && !native->warm_head_kv(req.prompt.data(), prompt_len,
110110
last_tok, all_prefill_hidden.data())) {
111111
result.error = "warm_and_decode: warm_head_kv failed";
@@ -142,7 +142,7 @@ GenerateResult warm_and_decode(ModelBackend * backend,
142142
}
143143

144144
auto t_decode0 = std::chrono::steady_clock::now();
145-
dflash27b::mtp::MtpChainRunner runner(*module, *target, sampler);
145+
dflash::common::mtp::MtpChainRunner runner(*module, *target, sampler);
146146
GenerateResult inner_res = runner.run(inner, io, last_tok, prompt_len, gamma);
147147
result.decode_s = std::chrono::duration<double>(
148148
std::chrono::steady_clock::now() - t_decode0).count();
@@ -169,4 +169,4 @@ GenerateResult warm_and_decode(ModelBackend * backend,
169169
}
170170

171171
} // namespace mtp
172-
} // namespace dflash27b
172+
} // namespace dflash::common

dflash/src/common/mtp_orchestrator.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#include "model_backend.h"
1111
#include "mtp_interface.h"
1212

13-
namespace dflash27b {
13+
namespace dflash::common {
1414

1515
class DFlashTarget;
1616

@@ -41,4 +41,4 @@ GenerateResult warm_and_decode(ModelBackend * backend,
4141
const DaemonIO & io);
4242

4343
} // namespace mtp
44-
} // namespace dflash27b
44+
} // namespace dflash::common

dflash/src/qwen35/qwen35_daemon.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ int run_qwen35_daemon(const Qwen35DaemonArgs & args) {
4343
cfg.mtp_gguf_path = args.mtp_gguf_path;
4444
break;
4545
case MtpSource::Auto:
46-
cfg.mtp_gguf_path = dflash27b::gguf_contains_mtp_tensors(args.target_path)
46+
cfg.mtp_gguf_path = dflash::common::gguf_contains_mtp_tensors(args.target_path)
4747
? args.target_path
4848
: nullptr;
4949
break;

dflash/src/qwen35/qwen35_mtp.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
#include <utility>
5050
#include <vector>
5151

52-
namespace dflash27b::mtp {
52+
namespace dflash::common::mtp {
5353

5454
#ifdef DFLASH_MTP_PROFILE
5555
// Per-iter profiler for the step_chain_gpu_ loop. Enabled by -DDFLASH_MTP_PROFILE=1.
@@ -2096,4 +2096,4 @@ bool Qwen35MtpModule::step_chain(int32_t current_token,
20962096
return true;
20972097
}
20982098

2099-
} // namespace dflash27b::mtp
2099+
} // namespace dflash::common::mtp

0 commit comments

Comments
 (0)