Skip to content

Commit ea88f89

Browse files
authored
Add support for OpenAI privacy filter model (#1658)
* Update test_pipelines_token_classification.js * Update token-classification pipeline output type tests * Update types.test.js * use tsconfig.json for type tests * Add support for openai_privacy_filter * add support for aggregation strategy * fix docs build
1 parent b93766d commit ea88f89

8 files changed

Lines changed: 462 additions & 51 deletions

File tree

packages/transformers/docs/plugins/preprocess.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ function transformType(expr) {
7474
.replace(/keyof\s+typeof\s+\w+/g, "string") // keyof typeof X -> string
7575
.replace(/typeof\s+\w+/g, "Object") // typeof X -> Object
7676
.replace(/\binfer\s+\w+/g, "any") // infer K -> any
77+
.replace(/\breadonly\s+/g, "") // readonly T -> T
7778
.replace(/\(\s*\w[\w<>, ]*\s+extends\s+\w+\s*\?[^)]*\)/g, "any") // (X extends Y ? A : B) -> any
7879
.replace(/(?<!\w)\(\s*(\w+)\s*\)/g, "$1") // (any) -> any (unwrap parens around simple types, not after words like "function")
7980
.replace(/(\w+)\?\s*:/g, "$1:") // x?: T -> x: T

packages/transformers/src/models/models.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ export * from './olmo/modeling_olmo.js';
121121
export * from './olmo2/modeling_olmo2.js';
122122
export * from './olmo3/modeling_olmo3.js';
123123
export * from './olmo_hybrid/modeling_olmo_hybrid.js';
124+
export * from './openai_privacy_filter/modeling_openai_privacy_filter.js';
124125
export * from './openelm/modeling_openelm.js';
125126
export * from './opt/modeling_opt.js';
126127
export * from './owlv2/modeling_owlv2.js';
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import { PreTrainedModel } from '../modeling_utils.js';
2+
import { SequenceClassifierOutput } from '../modeling_outputs.js';
3+
4+
/** Common base class for the OpenAI privacy filter model classes declared in this file. */
export class OpenAIPrivacyFilterPreTrainedModel extends PreTrainedModel {}
5+
/** OpenAI privacy filter model with no methods of its own; all behavior is inherited from its base class. */
export class OpenAIPrivacyFilterModel extends OpenAIPrivacyFilterPreTrainedModel {}
6+
7+
export class OpenAIPrivacyFilterForTokenClassification extends OpenAIPrivacyFilterPreTrainedModel {
    /**
     * Runs the model on the given inputs and wraps the raw result in an output object.
     *
     * @param {Object} model_inputs The inputs to the model.
     * @returns {Promise<SequenceClassifierOutput>} The wrapped output of the underlying model call.
     */
    async _call(model_inputs) {
        // NOTE(review): this is a token-classification head, yet the result is wrapped in
        // `SequenceClassifierOutput`. Confirm whether a token-classification output type
        // was intended here.
        const rawOutputs = await super._call(model_inputs);
        return new SequenceClassifierOutput(rawOutputs);
    }
}

packages/transformers/src/models/registry.js

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
9797
['mgp-str', 'MgpstrForSceneTextRecognition'],
9898

9999
['style_text_to_speech_2', 'StyleTextToSpeech2Model'],
100+
['openai_privacy_filter', 'OpenAIPrivacyFilterModel'],
100101
]);
101102

102103
const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
@@ -238,6 +239,7 @@ const MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = new Map([
238239
['roberta', 'RobertaForTokenClassification'],
239240
['xlm', 'XLMForTokenClassification'],
240241
['xlm-roberta', 'XLMRobertaForTokenClassification'],
242+
['openai_privacy_filter', 'OpenAIPrivacyFilterForTokenClassification'],
241243
]);
242244

243245
export const MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = new Map([

packages/transformers/src/pipelines/token-classification.js

Lines changed: 132 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,28 +8,58 @@ import { max, softmax } from '../utils/maths.js';
88
*/
99

1010
/**
11-
* @typedef {Object} TokenClassificationSingle
12-
* @property {string} word The token/word classified. This is obtained by decoding the selected tokens.
13-
* @property {number} score The corresponding probability for `entity`.
14-
* @property {string} entity The entity predicted for that token/word.
15-
* @property {number} index The index of the corresponding token in the sentence.
16-
* @property {number} [start] The index of the start of the corresponding entity in the sentence.
17-
* @property {number} [end] The index of the end of the corresponding entity in the sentence.
18-
* @typedef {TokenClassificationSingle[]} TokenClassificationOutput
19-
*
20-
* @typedef {Object} TokenClassificationPipelineOptions Parameters specific to token classification pipelines.
11+
* Strategy for fusing tokens based on the model prediction.
12+
* - `"none"`: Return raw per-token predictions.
13+
* - `"simple"`: Group entities using BIO / BIOES tags (see pipeline docs for details).
14+
* @typedef {"none" | "simple"} AggregationStrategy
15+
*/
16+
17+
/**
18+
* @typedef {Object} TokenClassificationPipelineOptions
2119
* @property {string[]} [ignore_labels] A list of labels to ignore.
20+
* @property {AggregationStrategy} [aggregation_strategy="none"] Token-fusion strategy.
21+
* When set to anything other than `"none"`, results use `entity_group` instead of `entity`.
22+
*/
23+
24+
/**
25+
* Single element of a token-classification result, parameterised by the options type `O` so that
26+
* `entity` vs. `entity_group` is known statically based on `aggregation_strategy`.
2227
*
23-
* @typedef {TextPipelineConstructorArgs & TokenClassificationPipelineCallback & Disposable} TokenClassificationPipelineType
28+
* - Grouped (present when `O["aggregation_strategy"]` is `"simple"`):
29+
* `{ word, score, entity_group }`
30+
* - Raw (the default — when `aggregation_strategy` is missing, `"none"`, or `undefined`):
31+
* `{ word, score, entity, index }`
32+
* - Both variants also carry optional `start` / `end` character offsets.
33+
*
34+
* When `O` is the untyped `TokenClassificationPipelineOptions`, the element is the union of both shapes,
35+
* narrowable via `if ("entity_group" in item)` / `if (item.entity !== undefined)`.
36+
*
37+
* @template {TokenClassificationPipelineOptions | undefined} [O=TokenClassificationPipelineOptions]
38+
* @typedef {_PickElement<O>[]} TokenClassificationOutput
39+
*/
40+
41+
/**
42+
* @template {TokenClassificationPipelineOptions | undefined} O
43+
* @typedef {O extends undefined
44+
* ? _Raw
45+
* : "aggregation_strategy" extends keyof O
46+
* ? (O extends { aggregation_strategy?: infer A }
47+
* ? ([A] extends ["simple"] ? _Grouped
48+
* : [A] extends ["none" | undefined] ? _Raw
49+
* : _Raw | _Grouped)
50+
* : _Raw)
51+
* : _Raw} _PickElement
2452
*/
2553

2654
/**
27-
* @template T
28-
* @typedef {T extends string[] ? TokenClassificationOutput[] : TokenClassificationOutput} TokenClassificationPipelineResult
55+
* @typedef {{ word: string, score: number, entity: string, index: number, entity_group?: undefined, start?: number, end?: number }} _Raw
56+
* @typedef {{ word: string, score: number, entity_group: string, entity?: undefined, index?: undefined, start?: number, end?: number }} _Grouped
2957
*/
3058

3159
/**
32-
* @typedef {<T extends string | string[]>(texts: T, options?: TokenClassificationPipelineOptions) => Promise<TokenClassificationPipelineResult<T>>} TokenClassificationPipelineCallback
60+
* @typedef {TextPipelineConstructorArgs & TokenClassificationPipelineCallback & Disposable} TokenClassificationPipelineType
61+
*
62+
* @typedef {<Q extends string | string[], const O extends TokenClassificationPipelineOptions = {}>(texts: Q, options?: O) => Promise<Q extends string[] ? TokenClassificationOutput<O>[] : TokenClassificationOutput<O>>} TokenClassificationPipelineCallback
3363
*/
3464

3565
/**
@@ -64,11 +94,29 @@ import { max, softmax } from '../utils/maths.js';
6494
* // { entity: 'I-LOC', score: 0.9975294470787048, index: 8, word: 'America' }
6595
* // ]
6696
* ```
97+
*
98+
* **Example:** Group adjacent BIO/BIOES tokens into entity spans using `aggregation_strategy: "simple"`.
99+
* ```javascript
100+
* import { pipeline } from '@huggingface/transformers';
101+
*
102+
* const classifier = await pipeline('token-classification', 'Xenova/bert-base-NER');
103+
* const output = await classifier('My name is Sarah and I live in London', { aggregation_strategy: 'simple' });
104+
* // [
105+
* // { entity_group: 'PER', score: 0.9985477924346924, word: 'Sarah' },
106+
* // { entity_group: 'LOC', score: 0.999621570110321, word: 'London' }
107+
* // ]
108+
* ```
67109
*/
68110
export class TokenClassificationPipeline
69111
extends /** @type {new (options: TextPipelineConstructorArgs) => TokenClassificationPipelineType} */ (Pipeline)
70112
{
71-
async _call(texts, { ignore_labels = ['O'] } = {}) {
113+
async _call(texts, { ignore_labels = ['O'], aggregation_strategy = 'none' } = {}) {
114+
if (aggregation_strategy !== 'none' && aggregation_strategy !== 'simple') {
115+
throw new Error(
116+
`Invalid aggregation_strategy: "${aggregation_strategy}". Must be one of "none" or "simple".`,
117+
);
118+
}
119+
72120
const isBatched = Array.isArray(texts);
73121

74122
// Run tokenization
@@ -86,43 +134,94 @@ export class TokenClassificationPipeline
86134

87135
const toReturn = [];
88136
for (let i = 0; i < logits.dims[0]; ++i) {
89-
const ids = model_inputs.input_ids[i];
137+
const ids = model_inputs.input_ids[i].tolist();
90138
const batch = logits[i];
91139

92-
// List of tokens that aren't ignored
93140
const tokens = [];
94141
for (let j = 0; j < batch.dims[0]; ++j) {
95142
const tokenData = batch[j];
96143
const topScoreIndex = max(tokenData.data)[1];
97144

98145
const entity = id2label ? id2label[topScoreIndex] : `LABEL_${topScoreIndex}`;
99-
if (ignore_labels.includes(entity)) {
100-
// We predicted a token that should be ignored. So, we skip it.
101-
continue;
102-
}
146+
if (ignore_labels.includes(entity)) continue;
103147

104148
// TODO add option to keep special tokens?
105-
const word = this.tokenizer.decode([ids[j].item()], { skip_special_tokens: true });
106-
if (word === '') {
107-
// Was a special token. So, we skip it.
108-
continue;
109-
}
149+
const word = this.tokenizer.decode([ids[j]], { skip_special_tokens: true });
150+
if (word === '') continue; // Was a special token.
110151

111152
const scores = softmax(tokenData.data);
112-
113153
tokens.push({
114-
entity: entity,
154+
entity,
115155
score: scores[topScoreIndex],
116156
index: j,
117-
word: word,
118-
157+
word,
119158
// TODO: Add support for start and end
120-
// start: null,
121-
// end: null,
122159
});
123160
}
124-
toReturn.push(tokens);
161+
162+
toReturn.push(aggregation_strategy === 'simple' ? groupEntities(tokens, ids, this.tokenizer) : tokens);
125163
}
126164
return isBatched ? toReturn : toReturn[0];
127165
}
128166
}
167+
168+
/**
169+
* Split a raw entity label into its BIOES prefix and tag.
170+
*
171+
* @param {string} entity
172+
* @returns {readonly [prefix: 'B'|'I'|'E'|'S', tag: string]}
173+
*/
174+
function getTag(entity) {
175+
const p = entity[0];
176+
return entity[1] === '-' && (p === 'B' || p === 'I' || p === 'E' || p === 'S')
177+
? [p, entity.slice(2)]
178+
: ['I', entity];
179+
}
180+
181+
/**
182+
* Group raw per-token predictions into entity spans using the SIMPLE strategy.
183+
*
184+
* The only "continue" predicate is: a non-`B`/non-`S` token whose tag matches
185+
* the open group's tag, when that group hasn't been closed by an `E` / `S`.
186+
* Everything else starts a fresh group.
187+
*
188+
* @param {_Raw[]} tokens
189+
* @param {number[]} ids Full input_ids for the sequence (indexed by `token.index`), used to re-decode
190+
* each group so the joined `word` matches what the tokenizer would produce.
191+
* @param {any} tokenizer
192+
* @returns {_Grouped[]}
193+
*/
194+
function groupEntities(tokens, ids, tokenizer) {
195+
/** @type {{ tag: string, start: number, end: number }[]} */
196+
const groups = []; // each entry is a [start, end) slice into `tokens`, plus the shared tag
197+
let openTag = null; // null = no open group
198+
199+
for (let i = 0; i < tokens.length; ++i) {
200+
const [prefix, tag] = getTag(tokens[i].entity);
201+
const extend = openTag === tag && prefix !== 'B' && prefix !== 'S';
202+
if (extend) {
203+
groups[groups.length - 1].end = i + 1;
204+
// `E` terminates the group; subsequent `I`/`E`/`S` start fresh.
205+
if (prefix === 'E') openTag = null;
206+
} else {
207+
groups.push({ tag, start: i, end: i + 1 });
208+
// `S` opens and immediately closes; anything else leaves the group open
209+
// (including a leading `E` — best-effort recovery for a malformed sequence).
210+
openTag = prefix === 'S' ? null : tag;
211+
}
212+
}
213+
214+
return groups.map(({ tag, start, end }) => {
215+
let scoreSum = 0;
216+
const groupIds = [];
217+
for (let i = start; i < end; ++i) {
218+
scoreSum += tokens[i].score;
219+
groupIds.push(ids[tokens[i].index]);
220+
}
221+
return {
222+
entity_group: tag,
223+
score: scoreSum / (end - start),
224+
word: tokenizer.decode(groupIds, { skip_special_tokens: true }),
225+
};
226+
});
227+
}

0 commit comments

Comments
 (0)