@@ -8,28 +8,58 @@ import { max, softmax } from '../utils/maths.js';
88 */
99
1010/**
11- * @typedef {Object } TokenClassificationSingle
12- * @property {string } word The token/word classified. This is obtained by decoding the selected tokens.
13- * @property {number } score The corresponding probability for `entity`.
14- * @property {string } entity The entity predicted for that token/word.
15- * @property {number } index The index of the corresponding token in the sentence.
16- * @property {number } [start] The index of the start of the corresponding entity in the sentence.
17- * @property {number } [end] The index of the end of the corresponding entity in the sentence.
18- * @typedef {TokenClassificationSingle[] } TokenClassificationOutput
19- *
20- * @typedef {Object } TokenClassificationPipelineOptions Parameters specific to token classification pipelines.
11+ * Strategy for fusing tokens based on the model prediction.
12+ * - `"none"`: Return raw per-token predictions.
13+ * - `"simple"`: Group entities using BIO / BIOES tags (see pipeline docs for details).
14+ * @typedef {"none" | "simple" } AggregationStrategy
15+ */
16+
17+ /**
18+ * @typedef {Object } TokenClassificationPipelineOptions
2119 * @property {string[] } [ignore_labels] A list of labels to ignore.
20+ * @property {AggregationStrategy } [aggregation_strategy="none"] Token-fusion strategy.
21+ * When set to anything other than `"none"`, results use `entity_group` instead of `entity`.
22+ */
23+
24+ /**
25+ * Single element of a token-classification result, parameterised by the options type `O` so that
26+ * `entity` vs. `entity_group` is known statically based on `aggregation_strategy`.
2227 *
23- * @typedef {TextPipelineConstructorArgs & TokenClassificationPipelineCallback & Disposable } TokenClassificationPipelineType
28+ * - Grouped (present when `O["aggregation_strategy"]` is `"simple"`):
29+ * `{ word, score, entity_group }`
30+ * - Raw (the default — when `aggregation_strategy` is missing, `"none"`, or `undefined`):
31+ * `{ word, score, entity, index }`
32+ * - Both variants also carry optional `start` / `end` character offsets.
33+ *
34+ * When `O` is the untyped `TokenClassificationPipelineOptions`, the element is the union of both shapes,
35+ * narrowable via `if ("entity_group" in item)` / `if (item.entity !== undefined)`.
36+ *
37+ * @template {TokenClassificationPipelineOptions | undefined} [O=TokenClassificationPipelineOptions]
38+ * @typedef {_PickElement<O>[] } TokenClassificationOutput
39+ */
40+
41+ /**
42+ * @template {TokenClassificationPipelineOptions | undefined} O
43+ * @typedef {O extends undefined
44+ * ? _Raw
45+ * : "aggregation_strategy" extends keyof O
46+ * ? (O extends { aggregation_strategy?: infer A }
47+ * ? ([A] extends ["simple"] ? _Grouped
48+ * : [A] extends ["none" | undefined] ? _Raw
49+ * : _Raw | _Grouped)
50+ * : _Raw)
51+ * : _Raw} _PickElement
2452 */
2553
2654/**
27- * @template T
28- * @typedef {T extends string[] ? TokenClassificationOutput[] : TokenClassificationOutput } TokenClassificationPipelineResult
55+ * @typedef { { word: string, score: number, entity: string, index: number, entity_group?: undefined, start?: number, end?: number } } _Raw
56+ * @typedef {{ word: string, score: number, entity_group: string, entity?: undefined, index?: undefined, start?: number, end?: number } } _Grouped
2957 */
3058
3159/**
32- * @typedef {<T extends string | string[]>(texts: T, options?: TokenClassificationPipelineOptions) => Promise<TokenClassificationPipelineResult<T>> } TokenClassificationPipelineCallback
60+ * @typedef {TextPipelineConstructorArgs & TokenClassificationPipelineCallback & Disposable } TokenClassificationPipelineType
61+ *
62+ * @typedef {<Q extends string | string[], const O extends TokenClassificationPipelineOptions = {}>(texts: Q, options?: O) => Promise<Q extends string[] ? TokenClassificationOutput<O>[] : TokenClassificationOutput<O>> } TokenClassificationPipelineCallback
3363 */
3464
3565/**
@@ -64,11 +94,29 @@ import { max, softmax } from '../utils/maths.js';
6494 * // { entity: 'I-LOC', score: 0.9975294470787048, index: 8, word: 'America' }
6595 * // ]
6696 * ```
97+ *
98+ * **Example:** Group adjacent BIO/BIOES tokens into entity spans using `aggregation_strategy: "simple"`.
99+ * ```javascript
100+ * import { pipeline } from '@huggingface/transformers';
101+ *
102+ * const classifier = await pipeline('token-classification', 'Xenova/bert-base-NER');
103+ * const output = await classifier('My name is Sarah and I live in London', { aggregation_strategy: 'simple' });
104+ * // [
105+ * // { entity_group: 'PER', score: 0.9985477924346924, word: 'Sarah' },
106+ * // { entity_group: 'LOC', score: 0.999621570110321, word: 'London' }
107+ * // ]
108+ * ```
67109 */
68110export class TokenClassificationPipeline
69111 extends /** @type {new (options: TextPipelineConstructorArgs) => TokenClassificationPipelineType } */ ( Pipeline )
70112{
71- async _call ( texts , { ignore_labels = [ 'O' ] } = { } ) {
113+ async _call ( texts , { ignore_labels = [ 'O' ] , aggregation_strategy = 'none' } = { } ) {
114+ if ( aggregation_strategy !== 'none' && aggregation_strategy !== 'simple' ) {
115+ throw new Error (
116+ `Invalid aggregation_strategy: "${ aggregation_strategy } ". Must be one of "none" or "simple".` ,
117+ ) ;
118+ }
119+
72120 const isBatched = Array . isArray ( texts ) ;
73121
74122 // Run tokenization
@@ -86,43 +134,94 @@ export class TokenClassificationPipeline
86134
87135 const toReturn = [ ] ;
88136 for ( let i = 0 ; i < logits . dims [ 0 ] ; ++ i ) {
89- const ids = model_inputs . input_ids [ i ] ;
137+ const ids = model_inputs . input_ids [ i ] . tolist ( ) ;
90138 const batch = logits [ i ] ;
91139
92- // List of tokens that aren't ignored
93140 const tokens = [ ] ;
94141 for ( let j = 0 ; j < batch . dims [ 0 ] ; ++ j ) {
95142 const tokenData = batch [ j ] ;
96143 const topScoreIndex = max ( tokenData . data ) [ 1 ] ;
97144
98145 const entity = id2label ? id2label [ topScoreIndex ] : `LABEL_${ topScoreIndex } ` ;
99- if ( ignore_labels . includes ( entity ) ) {
100- // We predicted a token that should be ignored. So, we skip it.
101- continue ;
102- }
146+ if ( ignore_labels . includes ( entity ) ) continue ;
103147
104148 // TODO add option to keep special tokens?
105- const word = this . tokenizer . decode ( [ ids [ j ] . item ( ) ] , { skip_special_tokens : true } ) ;
106- if ( word === '' ) {
107- // Was a special token. So, we skip it.
108- continue ;
109- }
149+ const word = this . tokenizer . decode ( [ ids [ j ] ] , { skip_special_tokens : true } ) ;
150+ if ( word === '' ) continue ; // Was a special token.
110151
111152 const scores = softmax ( tokenData . data ) ;
112-
113153 tokens . push ( {
114- entity : entity ,
154+ entity,
115155 score : scores [ topScoreIndex ] ,
116156 index : j ,
117- word : word ,
118-
157+ word,
119158 // TODO: Add support for start and end
120- // start: null,
121- // end: null,
122159 } ) ;
123160 }
124- toReturn . push ( tokens ) ;
161+
162+ toReturn . push ( aggregation_strategy === 'simple' ? groupEntities ( tokens , ids , this . tokenizer ) : tokens ) ;
125163 }
126164 return isBatched ? toReturn : toReturn [ 0 ] ;
127165 }
128166}
167+
168+ /**
169+ * Split a raw entity label into its BIOES prefix and tag.
170+ *
171+ * @param {string } entity
172+ * @returns {readonly [prefix: 'B'|'I'|'E'|'S', tag: string] }
173+ */
174+ function getTag ( entity ) {
175+ const p = entity [ 0 ] ;
176+ return entity [ 1 ] === '-' && ( p === 'B' || p === 'I' || p === 'E' || p === 'S' )
177+ ? [ p , entity . slice ( 2 ) ]
178+ : [ 'I' , entity ] ;
179+ }
180+
181+ /**
182+ * Group raw per-token predictions into entity spans using the SIMPLE strategy.
183+ *
184+ * The only "continue" predicate is: a non-`B`/non-`S` token whose tag matches
185+ * the open group's tag, when that group hasn't been closed by an `E` / `S`.
186+ * Everything else starts a fresh group.
187+ *
188+ * @param {_Raw[] } tokens
189+ * @param {number[] } ids Full input_ids for the sequence (indexed by `token.index`), used to re-decode
190+ * each group so the joined `word` matches what the tokenizer would produce.
191+ * @param {any } tokenizer
192+ * @returns {_Grouped[] }
193+ */
194+ function groupEntities ( tokens , ids , tokenizer ) {
195+ /** @type {{ tag: string, start: number, end: number }[] } */
196+ const groups = [ ] ; // each entry is a [start, end) slice into `tokens`, plus the shared tag
197+ let openTag = null ; // null = no open group
198+
199+ for ( let i = 0 ; i < tokens . length ; ++ i ) {
200+ const [ prefix , tag ] = getTag ( tokens [ i ] . entity ) ;
201+ const extend = openTag === tag && prefix !== 'B' && prefix !== 'S' ;
202+ if ( extend ) {
203+ groups [ groups . length - 1 ] . end = i + 1 ;
204+ // `E` terminates the group; subsequent `I`/`E`/`S` start fresh.
205+ if ( prefix === 'E' ) openTag = null ;
206+ } else {
207+ groups . push ( { tag, start : i , end : i + 1 } ) ;
208+ // `S` opens and immediately closes; anything else leaves the group open
209+ // (including a leading `E` — best-effort recovery for a malformed sequence).
210+ openTag = prefix === 'S' ? null : tag ;
211+ }
212+ }
213+
214+ return groups . map ( ( { tag, start, end } ) => {
215+ let scoreSum = 0 ;
216+ const groupIds = [ ] ;
217+ for ( let i = start ; i < end ; ++ i ) {
218+ scoreSum += tokens [ i ] . score ;
219+ groupIds . push ( ids [ tokens [ i ] . index ] ) ;
220+ }
221+ return {
222+ entity_group : tag ,
223+ score : scoreSum / ( end - start ) ,
224+ word : tokenizer . decode ( groupIds , { skip_special_tokens : true } ) ,
225+ } ;
226+ } ) ;
227+ }
0 commit comments