-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathindex.js
More file actions
147 lines (128 loc) · 4.33 KB
/
Copy pathindex.js
File metadata and controls
147 lines (128 loc) · 4.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
const { pipeline } = require('@huggingface/transformers');
const { HierarchicalNSW } = require('hnswlib-node');
/**
 * One-to-one (bijective) map between keys and values, with O(1) lookup in
 * both directions.
 *
 * Setting a key or a value that is already present removes the previous
 * pairing, so each key maps to exactly one value and vice versa.
 */
class BiMap {
  /**
   * @param {Iterable<[any, any]>} [entries=[]] - Initial key/value pairs.
   *   Later pairs win over earlier ones that share a key or a value, exactly as
   *   if `set()` had been called for each pair in order.
   */
  constructor(entries = []) {
    this._forward = new Map();
    this._backward = new Map();
    // Route every pair through set() so duplicate keys/values in `entries`
    // cannot leave the forward and backward maps out of sync.
    for (const [key, value] of entries) {
      this.set(key, value);
    }
  }

  /**
   * Associate `key` with `value`, dropping any existing pairing of either.
   * @param {any} key
   * @param {any} value
   */
  set(key, value) {
    // Remove existing forward/reverse mappings to keep the relation one-to-one.
    if (this._forward.has(key)) {
      const oldVal = this._forward.get(key);
      this._backward.delete(oldVal);
    }
    if (this._backward.has(value)) {
      const oldKey = this._backward.get(value);
      this._forward.delete(oldKey);
    }
    this._forward.set(key, value);
    this._backward.set(value, key);
  }

  /** @returns {any} The value for `key`, or undefined. */
  get(key) { return this._forward.get(key); }

  /** @returns {any} The key for `value`, or undefined. */
  getKey(value) { return this._backward.get(value); }

  /** @returns {boolean} Whether `key` is present. */
  has(key) { return this._forward.has(key); }

  /** @returns {boolean} Whether `value` is present. */
  hasValue(value) { return this._backward.has(value); }

  /**
   * Remove `key` and its value.
   * @returns {boolean} true if a pair was removed, false if `key` was absent.
   */
  delete(key) {
    if (!this._forward.has(key)) return false;
    const value = this._forward.get(key);
    this._forward.delete(key);
    this._backward.delete(value);
    return true;
  }

  /** Number of key/value pairs. */
  get length() { return this._forward.size; }
}
// Embedding model used to vectorise post content.
const modelName = 'Xenova/all-MiniLM-L6-v2';
// Length of each vector stored in the index; presumably must match the
// embedding size produced by `modelName` — keep the two in sync.
const numDimensions = 384;
// Capacity the index is initialised with (number of data points).
const maxElements = 1024;

// Declare and initialise the approximate nearest-neighbour index (L2 distance).
const index = new HierarchicalNSW('l2', numDimensions);
index.initIndex(maxElements);

// Feature-extraction pipeline; assigned by the `after_init` filter.
let extractor;
// Numeric index label <-> post path.
const labelMapping = new BiMap();
/**
 * Render a byte count as a short human-readable string, e.g. "512 B",
 * "1.5 KB", "20 MB". Scaling stops at GB.
 *
 * Byte values and any value >= 10 are shown without decimals; smaller scaled
 * values get one decimal place.
 *
 * @param {number} bytes - Byte count.
 * @returns {string} Formatted size, or '' when `bytes` is not a finite number.
 */
function formatBytes(bytes) {
  if (!Number.isFinite(bytes)) return '';
  const UNITS = ['B', 'KB', 'MB', 'GB'];
  let scaled = bytes;
  let unitIndex = 0;
  for (; scaled >= 1024 && unitIndex < UNITS.length - 1; unitIndex += 1) {
    scaled /= 1024;
  }
  const unit = UNITS[unitIndex];
  const decimals = unit === 'B' || scaled >= 10 ? 0 : 1;
  return `${scaled.toFixed(decimals)} ${unit}`;
}
/**
 * Build a `progress_callback` for transformers.js `pipeline()` that reports
 * model-file loading through the given logger.
 *
 * - `initiate` events log the file being loaded and reset its throttle state.
 * - `download` / `done` events track the file currently downloading, so
 *   progress events without a `file` field are attributed to it.
 * - Progress events are throttled to one line per 5 percentage points per
 *   file. 100% is logged once, and each percentage at most once.
 *
 * @param {{ info: (message: string) => void }} log - Logger (Hexo-style).
 * @returns {(data: object) => void} Callback to pass as `progress_callback`.
 */
function createProgressLogger(log) {
  // Last percentage logged per file key, used for throttling.
  const lastLogged = new Map();
  // Key of the file whose `download` event was seen and has not finished yet.
  let activeDownload;

  return data => {
    if (!data) return;
    const key = data.file
      ? `${data.name || modelName}/${data.file}`
      : activeDownload;

    if (data.status === 'initiate') {
      // A (re)load of this file starts: forget what was logged for it before.
      lastLogged.delete(key);
      log.info(`Loading embedding model file: ${key}`);
      return;
    }
    if (data.status === 'download') {
      activeDownload = key;
      return;
    }
    if (data.status === 'done') {
      if (activeDownload === key) activeDownload = null;
      return;
    }

    if (!key) return;
    // Also rejects NaN/Infinity, which would otherwise print "NaN%"/"Infinity%".
    if (!Number.isFinite(data.progress)) return;
    if (data.status && data.status !== 'progress') return;

    const progress = Math.min(Math.max(Math.floor(data.progress), 0), 100);
    const previous = lastLogged.get(key) || 0;
    // Never report the same (or a lower) percentage twice, e.g. repeated 100%.
    if (lastLogged.has(key) && progress <= previous) return;
    if (progress < 100 && progress - previous < 5) return;
    lastLogged.set(key, progress);

    const width = 20;
    const filled = Math.round((progress / 100) * width);
    const bar = `${'#'.repeat(filled)}${'-'.repeat(width - filled)}`;
    const size = data.total
      ? ` (${formatBytes(data.loaded)} / ${formatBytes(data.total)})`
      : '';
    log.info(`Downloading ${key} [${bar}] ${progress}%${size}`);
  };
}
// Load the feature-extraction pipeline once Hexo has initialised, logging
// per-file download progress while the model is fetched.
hexo.extend.filter.register('after_init', async function loadEmbeddingModel() {
  // Prefer the logger on the filter's `this`, falling back to the global instance.
  const log = this.log || hexo.log;
  const progressCallback = createProgressLogger(log);

  log.info(`Loading embedding model: ${modelName}`);
  extractor = await pipeline('feature-extraction', modelName, {
    progress_callback: progressCallback
  });
  log.info(`Embedding model ready: ${modelName}`);
});
// Embed each post's raw content before rendering and insert (or update) its
// vector in the nearest-neighbour index under a stable numeric label.
hexo.extend.filter.register('before_post_render', async function(data) {
  if (!extractor) {
    throw new Error(`Embedding model ${modelName} is not loaded; the after_init filter must finish before posts are rendered`);
  }

  // Mean-pooled, normalised embedding of the post source.
  const embeddings = await extractor([data._content], { pooling: 'mean', normalize: true });
  data.embedding_vector = embeddings.tolist()[0];

  // Reuse the label already assigned to this path; otherwise allocate the next one.
  let id = labelMapping.getKey(data.path);
  if (id === undefined) {
    id = labelMapping.length;
    labelMapping.set(id, data.path);
    // initIndex() fixed a capacity; grow it before inserting a new point so
    // sites with more posts than that capacity do not make addPoint throw.
    const capacity = index.getMaxElements();
    if (index.getCurrentCount() >= capacity) {
      index.resizeIndex(Math.max(capacity * 2, 1));
    }
  }
  // For an existing label this replaces the stored vector.
  index.addPoint(data.embedding_vector, id);
  return data;
});
/**
 * Template helper: paths of the posts whose embeddings are nearest to `post`.
 *
 * Returns at most `numNeighbors - 1` paths (the query itself occupies one
 * neighbour slot). The result is also stored on `post.related_posts`.
 * Posts without an embedding get an empty list.
 */
hexo.extend.helper.register('related_posts', function(post) {
  const result = [];
  if (!post.embedding_vector) {
    post.related_posts = result;
    return result;
  }

  const numNeighbors = 5;
  // searchKnn rejects k larger than the number of indexed points, so clamp it
  // for small sites.
  const k = Math.min(numNeighbors, index.getCurrentCount());
  if (k > 0) {
    const { neighbors } = index.searchKnn(post.embedding_vector, k);
    const selfId = labelMapping.getKey(post.path);
    // Exclude the query post by its label rather than assuming it ranks first:
    // another post with an identical embedding can tie at distance 0.
    // Without a known label, fall back to skipping the top hit.
    const others = selfId === undefined
      ? neighbors.slice(1)
      : neighbors.filter(neighbor => neighbor !== selfId);
    for (const neighbor of others.slice(0, numNeighbors - 1)) {
      result.push(labelMapping.get(neighbor));
    }
  }

  post.related_posts = result;
  return result;
});