|
| 1 | +// Base-pair scoring for Carbon/HybridDNA models: for a DNA sequence wrapped in |
| 2 | +// <dna>, report P(base | preceding context) at each position, marginalized from |
| 3 | +// the k-mer distribution exactly like the sampler. Mirrors the Python |
| 4 | +// HybridDNATokenizer model's score_sequence(). |
| 5 | + |
#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"

#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>
| 17 | + |
// Maps an uppercase nucleotide letter to its slot in the A, T, C, G ordering
// used when marginalizing k-mer probabilities (A=0, T=1, C=2, G=3).
// Any other character, including lowercase letters, yields -1.
static int dna_base_index(char c) {
    static const char base_order[4] = {'A', 'T', 'C', 'G'};

    for (int slot = 0; slot < 4; ++slot) {
        if (base_order[slot] == c) {
            return slot;
        }
    }
    return -1;
}
| 27 | + |
| 28 | +int main(int argc, char ** argv) { |
| 29 | + common_params params; |
| 30 | + params.prompt = "GGGCTATAAAGGCCATCGATCGATCGATCGATCGATCGATCG"; |
| 31 | + params.n_ctx = 0; |
| 32 | + |
| 33 | + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { |
| 34 | + return 1; |
| 35 | + } |
| 36 | + |
| 37 | + // the prompt is the raw DNA sequence (ACGT), without the <dna> tag |
| 38 | + std::string seq; |
| 39 | + for (char c : params.prompt) { |
| 40 | + seq += (c >= 'a' && c <= 'z') ? char(c - 32) : c; |
| 41 | + } |
| 42 | + |
| 43 | + common_init(); |
| 44 | + llama_backend_init(); |
| 45 | + llama_numa_init(params.numa); |
| 46 | + |
| 47 | + auto llama_init = common_init_from_params(params); |
| 48 | + llama_model * model = llama_init->model(); |
| 49 | + llama_context * ctx = llama_init->context(); |
| 50 | + if (!model || !ctx) { |
| 51 | + LOG_ERR("%s: failed to load model\n", __func__); |
| 52 | + return 1; |
| 53 | + } |
| 54 | + const llama_vocab * vocab = llama_model_get_vocab(model); |
| 55 | + |
| 56 | + // locate the k-mer block: tokenize "<dna>AAAA...A" to recover the first |
| 57 | + // k-mer id and k (its text length), then n_kmers = 4^k |
| 58 | + const llama_token first = [&] { |
| 59 | + const auto t = common_tokenize(vocab, "<dna>AAAAAAAAAAAA", false, true); |
| 60 | + return t.size() >= 2 ? t[1] : LLAMA_TOKEN_NULL; |
| 61 | + }(); |
| 62 | + if (first == LLAMA_TOKEN_NULL) { |
| 63 | + LOG_ERR("%s: could not locate DNA k-mer block (not a HybridDNA vocab?)\n", __func__); |
| 64 | + return 1; |
| 65 | + } |
| 66 | + const int32_t k = (int32_t) std::strlen(llama_vocab_get_text(vocab, first)); |
| 67 | + std::vector<int32_t> pow4(k); |
| 68 | + int32_t n_kmers = 1; |
| 69 | + for (int32_t i = k - 1; i >= 0; --i) { |
| 70 | + pow4[i] = n_kmers; |
| 71 | + n_kmers *= 4; |
| 72 | + } |
| 73 | + |
| 74 | + // right-pad the sequence to a multiple of k with 'A' (training convention) |
| 75 | + std::string padded = seq; |
| 76 | + if (padded.size() % k != 0) { |
| 77 | + padded.append(k - padded.size() % k, 'A'); |
| 78 | + } |
| 79 | + const int32_t n_tok = (int32_t) padded.size() / k; |
| 80 | + |
| 81 | + // tokenize "<dna>" + padded sequence |
| 82 | + const std::vector<llama_token> tokens = common_tokenize(vocab, "<dna>" + padded, false, true); |
| 83 | + |
| 84 | + // decode, requesting logits at every position that predicts a k-mer |
| 85 | + // (position t predicts the token at t+1; position 0 is <dna> -> k-mer 0) |
| 86 | + llama_batch batch = llama_batch_init((int32_t) tokens.size(), 0, 1); |
| 87 | + for (int32_t i = 0; i < (int32_t) tokens.size(); ++i) { |
| 88 | + common_batch_add(batch, tokens[i], i, {0}, i < n_tok); |
| 89 | + } |
| 90 | + if (llama_decode(ctx, batch) != 0) { |
| 91 | + LOG_ERR("%s: llama_decode() failed\n", __func__); |
| 92 | + return 1; |
| 93 | + } |
| 94 | + |
| 95 | + const int32_t n_vocab = llama_vocab_n_tokens(vocab); |
| 96 | + |
| 97 | + printf("# pos base P(base|context)\n"); |
| 98 | + for (int32_t t = 0; t < n_tok; ++t) { |
| 99 | + const float * logits = llama_get_logits_ith(ctx, t); |
| 100 | + |
| 101 | + // softmax over the k-mer block (shifted by its running max) |
| 102 | + float maxl = -INFINITY; |
| 103 | + for (int32_t f = 0; f < n_kmers && first + f < n_vocab; ++f) { |
| 104 | + maxl = std::max(maxl, logits[first + f]); |
| 105 | + } |
| 106 | + std::vector<std::array<float, 4>> bp(k, {0.0f, 0.0f, 0.0f, 0.0f}); |
| 107 | + float sum = 0.0f; |
| 108 | + for (int32_t f = 0; f < n_kmers && first + f < n_vocab; ++f) { |
| 109 | + const float w = expf(logits[first + f] - maxl); |
| 110 | + sum += w; |
| 111 | + for (int32_t pos = 0; pos < k; ++pos) { |
| 112 | + bp[pos][(f / pow4[pos]) % 4] += w; |
| 113 | + } |
| 114 | + } |
| 115 | + |
| 116 | + for (int32_t pos = 0; pos < k; ++pos) { |
| 117 | + const int32_t gi = t * k + pos; |
| 118 | + if (gi >= (int32_t) seq.size()) { |
| 119 | + break; |
| 120 | + } |
| 121 | + const int b = dna_base_index(seq[gi]); |
| 122 | + const float p = (b < 0 || sum == 0.0f) ? 0.0f : bp[pos][b] / sum; |
| 123 | + printf("%4d %c %.4f\n", gi, seq[gi], p); |
| 124 | + } |
| 125 | + } |
| 126 | + |
| 127 | + llama_batch_free(batch); |
| 128 | + llama_backend_free(); |
| 129 | + return 0; |
| 130 | +} |
0 commit comments