🚀 Feature
I propose adding a `SequenceAccuracy` metric (or extending the existing `Accuracy` metric) to natively support the sequence-to-sequence outputs common in NLP and Transformer-based models.
Motivation:
When working with sequence models, the model output typically has shape (batch, sequence_length, num_classes) and the target has shape (batch, sequence_length). Currently, the standard `ignite.metrics.Accuracy` expects multiclass predictions with the class dimension second, i.e. (batch, num_classes, ...), or already-flattened tensors, and it has no built-in way to ignore padding positions.
As a result, I have to write a custom `output_transform` in every project to flatten the tensors and mask out padding.
Current workaround:
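```python
from ignite import metrics


def get_metrics(loss_fn, pad_idx=0):
    def accuracy_transform(output):
        # preds: (batch, seq_len, num_classes), y: (batch, seq_len)
        preds, y = output
        # Flatten to (batch * seq_len, num_classes) and (batch * seq_len,)
        preds_flat = preds.reshape(-1, preds.size(-1))
        y_flat = y.reshape(-1)
        # Drop padding positions so they do not count toward accuracy
        mask = y_flat != pad_idx
        preds_masked = preds_flat[mask]
        y_masked = y_flat[mask]
        return preds_masked, y_masked

    return {
        "accuracy": metrics.Accuracy(output_transform=accuracy_transform),
        "loss": metrics.Loss(loss_fn),
    }
```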
Since accuracy is one of the most common NLP metrics, I think offering a more plug-and-play option would be valuable.
We could either create a new `SequenceAccuracy` metric that wraps `Accuracy` and handles the masking, or modify `Accuracy` itself to support sequences and masking.