Skip to content

Commit cfd200e

Browse files
committed
examples : add dna-bp-score for HybridDNA base-pair scoring
Signed-off-by: Kashif Rasul <kashif.rasul@gmail.com>
1 parent d739eab commit cfd200e

3 files changed

Lines changed: 136 additions & 0 deletions

File tree

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ if (EMSCRIPTEN)
1616
else()
1717
add_subdirectory(batched)
1818
add_subdirectory(debug)
19+
add_subdirectory(dna-bp-score)
1920
add_subdirectory(embedding)
2021
add_subdirectory(eval-callback)
2122

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
set(TARGET llama-dna-bp-score)
2+
add_executable(${TARGET} dna-bp-score.cpp)
3+
install(TARGETS ${TARGET} RUNTIME)
4+
target_link_libraries(${TARGET} PRIVATE llama-common llama ${CMAKE_THREAD_LIBS_INIT})
5+
target_compile_features(${TARGET} PRIVATE cxx_std_17)
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
// Base-pair scoring for Carbon/HybridDNA models: for a DNA sequence wrapped in
2+
// <dna>, report P(base | preceding context) at each position, marginalized from
3+
// the k-mer distribution exactly like the sampler. Mirrors the Python
4+
// HybridDNATokenizer model's score_sequence().
5+
6+
#include "arg.h"
7+
#include "common.h"
8+
#include "log.h"
9+
#include "llama.h"
10+
11+
#include <array>
12+
#include <cstring>
13+
#include <cmath>
14+
#include <cstdio>
15+
#include <string>
16+
#include <vector>
17+
18+
static int dna_base_index(char c) {
19+
switch (c) {
20+
case 'A': return 0;
21+
case 'T': return 1;
22+
case 'C': return 2;
23+
case 'G': return 3;
24+
default: return -1;
25+
}
26+
}
27+
28+
int main(int argc, char ** argv) {
29+
common_params params;
30+
params.prompt = "GGGCTATAAAGGCCATCGATCGATCGATCGATCGATCGATCG";
31+
params.n_ctx = 0;
32+
33+
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
34+
return 1;
35+
}
36+
37+
// the prompt is the raw DNA sequence (ACGT), without the <dna> tag
38+
std::string seq;
39+
for (char c : params.prompt) {
40+
seq += (c >= 'a' && c <= 'z') ? char(c - 32) : c;
41+
}
42+
43+
common_init();
44+
llama_backend_init();
45+
llama_numa_init(params.numa);
46+
47+
auto llama_init = common_init_from_params(params);
48+
llama_model * model = llama_init->model();
49+
llama_context * ctx = llama_init->context();
50+
if (!model || !ctx) {
51+
LOG_ERR("%s: failed to load model\n", __func__);
52+
return 1;
53+
}
54+
const llama_vocab * vocab = llama_model_get_vocab(model);
55+
56+
// locate the k-mer block: tokenize "<dna>AAAA...A" to recover the first
57+
// k-mer id and k (its text length), then n_kmers = 4^k
58+
const llama_token first = [&] {
59+
const auto t = common_tokenize(vocab, "<dna>AAAAAAAAAAAA", false, true);
60+
return t.size() >= 2 ? t[1] : LLAMA_TOKEN_NULL;
61+
}();
62+
if (first == LLAMA_TOKEN_NULL) {
63+
LOG_ERR("%s: could not locate DNA k-mer block (not a HybridDNA vocab?)\n", __func__);
64+
return 1;
65+
}
66+
const int32_t k = (int32_t) std::strlen(llama_vocab_get_text(vocab, first));
67+
std::vector<int32_t> pow4(k);
68+
int32_t n_kmers = 1;
69+
for (int32_t i = k - 1; i >= 0; --i) {
70+
pow4[i] = n_kmers;
71+
n_kmers *= 4;
72+
}
73+
74+
// right-pad the sequence to a multiple of k with 'A' (training convention)
75+
std::string padded = seq;
76+
if (padded.size() % k != 0) {
77+
padded.append(k - padded.size() % k, 'A');
78+
}
79+
const int32_t n_tok = (int32_t) padded.size() / k;
80+
81+
// tokenize "<dna>" + padded sequence
82+
const std::vector<llama_token> tokens = common_tokenize(vocab, "<dna>" + padded, false, true);
83+
84+
// decode, requesting logits at every position that predicts a k-mer
85+
// (position t predicts the token at t+1; position 0 is <dna> -> k-mer 0)
86+
llama_batch batch = llama_batch_init((int32_t) tokens.size(), 0, 1);
87+
for (int32_t i = 0; i < (int32_t) tokens.size(); ++i) {
88+
common_batch_add(batch, tokens[i], i, {0}, i < n_tok);
89+
}
90+
if (llama_decode(ctx, batch) != 0) {
91+
LOG_ERR("%s: llama_decode() failed\n", __func__);
92+
return 1;
93+
}
94+
95+
const int32_t n_vocab = llama_vocab_n_tokens(vocab);
96+
97+
printf("# pos base P(base|context)\n");
98+
for (int32_t t = 0; t < n_tok; ++t) {
99+
const float * logits = llama_get_logits_ith(ctx, t);
100+
101+
// softmax over the k-mer block (shifted by its running max)
102+
float maxl = -INFINITY;
103+
for (int32_t f = 0; f < n_kmers && first + f < n_vocab; ++f) {
104+
maxl = std::max(maxl, logits[first + f]);
105+
}
106+
std::vector<std::array<float, 4>> bp(k, {0.0f, 0.0f, 0.0f, 0.0f});
107+
float sum = 0.0f;
108+
for (int32_t f = 0; f < n_kmers && first + f < n_vocab; ++f) {
109+
const float w = expf(logits[first + f] - maxl);
110+
sum += w;
111+
for (int32_t pos = 0; pos < k; ++pos) {
112+
bp[pos][(f / pow4[pos]) % 4] += w;
113+
}
114+
}
115+
116+
for (int32_t pos = 0; pos < k; ++pos) {
117+
const int32_t gi = t * k + pos;
118+
if (gi >= (int32_t) seq.size()) {
119+
break;
120+
}
121+
const int b = dna_base_index(seq[gi]);
122+
const float p = (b < 0 || sum == 0.0f) ? 0.0f : bp[pos][b] / sum;
123+
printf("%4d %c %.4f\n", gi, seq[gi], p);
124+
}
125+
}
126+
127+
llama_batch_free(batch);
128+
llama_backend_free();
129+
return 0;
130+
}

0 commit comments

Comments
 (0)