
Multi-Class Span Classification Support #1034

Open · wants to merge 56 commits into master

Commits (56)
0730c25
added dataset summary functions
Steven-Yiran Sep 3, 2023
d1500bc
Added docstring and tests
Steven-Yiran Sep 4, 2023
f5abd21
Added auxiliary_inputs support and related tests
Steven-Yiran Sep 6, 2023
c44b3e9
modified return object type
Steven-Yiran Sep 8, 2023
f6d0493
type annotation
Steven-Yiran Sep 9, 2023
daec74b
Update dataset.py
Steven-Yiran Sep 9, 2023
255619e
changed type annotation
Steven-Yiran Sep 9, 2023
c401d87
Merge branch 'cleanlab:master' into master
Steven-Yiran Sep 9, 2023
de1597e
Update cleanlab/object_detection/dataset.py
Steven-Yiran Sep 19, 2023
730eb57
Update cleanlab/object_detection/dataset.py
Steven-Yiran Sep 19, 2023
74195cd
Apply suggestions from code review
Steven-Yiran Sep 19, 2023
e0ef1bf
added class_names as optional argument
Steven-Yiran Sep 20, 2023
ae6fa1b
added format editions
Steven-Yiran Sep 20, 2023
3d8f96a
Plots and class accuracy
Steven-Yiran Sep 27, 2023
b99641a
edited styling
Steven-Yiran Sep 27, 2023
2aefda5
modified type
Steven-Yiran Sep 27, 2023
9063235
added type annotaion
Steven-Yiran Sep 27, 2023
4f77c8d
Updated plot title text
Steven-Yiran Sep 27, 2023
d8b111c
updated function behavior and style
Steven-Yiran Oct 8, 2023
f98f67e
Update summary.py
Steven-Yiran Oct 8, 2023
0a472ae
Merge branch 'master' into steven-yiran-od-metrics
ulya-tkch Oct 13, 2023
6a4b14b
Removed metrics functions for new PR
ulya-tkch Oct 16, 2023
d3521bd
Remove unused function and import
ulya-tkch Oct 16, 2023
f8ba31b
black
ulya-tkch Oct 16, 2023
a369032
losen mypy typing
ulya-tkch Oct 16, 2023
9d487ed
added get_class_metrics function for review
Steven-Yiran Oct 24, 2023
d408c1a
black
Steven-Yiran Oct 24, 2023
fcc6c38
Merge branch 'master' of https://github.com/Steven-Yiran/cleanlab
Steven-Yiran Oct 24, 2023
d8f594f
changing code to use existing functions
aditya1503 Nov 22, 2023
4684972
Merge branch 'master' into master
aditya1503 Nov 22, 2023
070f62f
add test cases
aditya1503 Nov 22, 2023
7c38466
add test cases
aditya1503 Nov 22, 2023
a1b58b2
linting for previous python versions
aditya1503 Nov 22, 2023
6ad6ef9
pool procs fix
aditya1503 Nov 22, 2023
851831e
public function
aditya1503 Nov 22, 2023
840bef2
Docstring and function naming
Steven-Yiran Nov 28, 2023
c5bca68
Update cleanlab/object_detection/summary.py
aditya1503 Dec 13, 2023
940467d
request changes
aditya1503 Dec 13, 2023
8fb478e
linting
aditya1503 Dec 13, 2023
c454346
add docstring
aditya1503 Dec 13, 2023
3df4ef3
type hinting
aditya1503 Dec 13, 2023
bb9a567
fix linting
aditya1503 Dec 13, 2023
e3d67b2
code coverage
aditya1503 Dec 13, 2023
43509d8
Merge pull request #2 from cleanlab/master
Steven-Yiran Dec 13, 2023
0165de5
change class_names to dict
aditya1503 Dec 15, 2023
e3aa7cd
Merge branch 'master' of github.com:Steven-Yiran/cleanlab
aditya1503 Dec 15, 2023
cf5db58
linting
aditya1503 Dec 15, 2023
1995207
linting;
aditya1503 Dec 15, 2023
1fd06f0
add average over IOU comment
aditya1503 Dec 18, 2023
2dee2d7
Update cleanlab/object_detection/summary.py
Steven-Yiran Dec 18, 2023
5c76b77
Update cleanlab/object_detection/summary.py
Steven-Yiran Dec 18, 2023
f3e094c
lint
Steven-Yiran Dec 18, 2023
60a0f3d
Merge branch 'master' of https://github.com/Steven-Yiran/cleanlab
Steven-Yiran Feb 28, 2024
dcef177
add support for multi-class span classification
Steven-Yiran Mar 2, 2024
187ae60
Update test_span_classification.py
Steven-Yiran Mar 3, 2024
dbd9b06
type fixes
Steven-Yiran Mar 3, 2024
290 changes: 257 additions & 33 deletions cleanlab/experimental/span_classification.py
@@ -5,19 +5,22 @@
"""

import numpy as np
from typing import List, Tuple, Optional, Dict, Union

from cleanlab.token_classification.filter import find_label_issues as find_label_issues_token
from cleanlab.token_classification.summary import display_issues as display_issues_token
from cleanlab.token_classification.rank import (
get_label_quality_scores as get_label_quality_scores_token,
)
from cleanlab.internal.util import get_num_classes
from cleanlab.internal.token_classification_utils import color_sentence, get_sentence


def find_label_issues(
    labels: list,
    pred_probs: list,
    **kwargs,
) -> Union[Dict[int, List[Tuple[int, int]]], List[Tuple[int, int]]]:
"""Identifies tokens with label issues in a span classification dataset.

Tokens identified with issues will be ranked by their individual label quality score.
@@ -27,28 +30,39 @@ def find_label_issues(
Parameters
----------
labels:
For a single-class span classification dataset, `labels` is a nested list of given labels for all tokens, such that `labels[i]` is a list of labels, one for each token in the `i`-th sentence.

For a multi-class span classification dataset, `labels` must be a doubly nested list, such that `labels[i][j]` is a list of integers, each representing a span class label for the `j`-th token in the `i`-th sentence.

For a dataset with K classes, each label must be in 0, 1, ..., K-1.
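For instance, ``labels[i][j] == [1, 3]`` indicates that the `j`-th token of the `i`-th sentence lies inside spans of classes 1 and 3, while ``[0]`` marks a token outside all spans.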

pred_probs:
For a single span class, `pred_probs` is a list of np arrays, such that `pred_probs[i]` is a 1-D array of length T of model-predicted probabilities that each of the T tokens in the `i`-th sentence lies inside the span.

For multiple span classes, `pred_probs[i]` instead has shape ``(T, K - 1)`` for a dataset with K classes. Row `t` contains the model-predicted probabilities that the `t`-th token lies inside the span of each class 1, ..., K-1, so the probability for class `k` is in column ``k - 1`` (the 'O' class 0 has no column of its own).

See documentation for :py:func:`token_classification.filter.find_label_issues <cleanlab.token_classification.filter.find_label_issues>` for descriptions of optional parameters.

Returns
-------
issues: list or dict
For a single span class, the return value is a list of label issues identified by cleanlab, such that each element is a tuple ``(i, j)``, which
indicates that the `j`-th token of the `i`-th sentence has a label issue.

For multiple span classes, the return value is a dictionary with span classes as keys, where each value is the list of label issues for that class.

Within each list, the tuples are ordered by the likelihood that the corresponding token is mislabeled.

Use :py:func:`experimental.span_classification.display_issues <cleanlab.experimental.span_classification.display_issues>`
to view these issues within the original sentences.

Examples
--------
For a binary span classification task:

>>> import numpy as np
>>> from cleanlab.experimental.span_classification import find_label_issues
>>> labels = [[0, 0, 1, 1], [1, 1, 0]]
@@ -57,50 +71,260 @@
... np.array([0.1, 0.1, 0.9]),
... ]
>>> find_label_issues(labels, pred_probs)
[(0, 3)]

For a multi-class span classification task:

>>> labels = [
... [[0], [1, 2], [1, 3], [0]],
... [[1], [2, 3], [3]],
... ]
>>> pred_probs = [
... np.array([[0.9, 0.2, 0.3], [0.9, 0.9, 0.2], [0.9, 0.1, 0.7], [0.1, 0.1, 0.1]]),
... np.array([[0.1, 0.9, 0.1], [0.1, 0.9, 0.9], [0.1, 0.9, 0.9]]),
... ]
>>> find_label_issues(labels, pred_probs)
{1: [(0, 0), (1, 0)], 2: [(1, 0), (1, 2)], 3: []}
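
The returned dictionary can be iterated per span class (a small illustrative continuation of the example above; outputs assume the same inputs):

>>> issues = find_label_issues(labels, pred_probs)
>>> for cl, cl_issues in issues.items():
...     print(cl, cl_issues)
1 [(0, 0), (1, 0)]
2 [(1, 0), (1, 2)]
3 []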
"""
    if not labels or not pred_probs:
        raise ValueError("labels/pred_probs cannot be empty.")

    labels_flat = [l for sentence in labels for l in sentence]
    num_span_class = get_num_classes(labels_flat)

    if num_span_class <= 2:
        # binary case: expand each per-token span probability p into [1 - p, p]
        # so the token-classification utilities can be reused directly
        pred_probs_token = [np.stack([1 - probs, probs], axis=1) for probs in pred_probs]
        return find_label_issues_token(labels, pred_probs_token, **kwargs)

    # type checks for multi-class span classification
    if not isinstance(labels_flat[0], list):
        raise ValueError("labels must be a nested list of lists, one for each sentence.")

    cls_label_issues = {}
    # iterate over each span class, excluding the 'O' (outside) class 0
    for cl in range(1, num_span_class):
        # reduce to a binary task for class `cl`: a token is inside (1) if `cl`
        # appears among its labels, and outside (0) otherwise
        cls_labels = [[1 if cl in label else 0 for label in sentence] for sentence in labels]
        # the probability column for class `cl` sits at index cl - 1 of pred_probs
        cls_pred_probs = [np.array([pred[cl - 1] for pred in sentence]) for sentence in pred_probs]
        pred_probs_token = [np.stack([1 - probs, probs], axis=1) for probs in cls_pred_probs]

        cls_label_issues[cl] = find_label_issues_token(cls_labels, pred_probs_token, **kwargs)

    return cls_label_issues


def display_issues(
    issues: Union[dict, list],
    tokens: List[List[str]],
    *,
    labels: Optional[list] = None,
    pred_probs: Optional[list] = None,
    exclude: List[Tuple[int, int]] = [],
    class_names: Optional[List[str]] = None,
    top: int = 20,  # number of issues to display per class
    threshold: float = 0.5,
) -> None:
"""
Display span classification label issues, showing each sentence with its problematic tokens highlighted. Can also display auxiliary information
such as given labels and predicted probabilities when available.

This method is useful for visualizing the label issues identified in each span class.

Parameters
----------
issues:
For a single span class, a list of tuples ``(i, j)``, each representing a label issue for the `j`-th token of the `i`-th sentence.

For multiple span classes, a dictionary with span classes as keys, where each value is the list of label issues for that class.

tokens:
Nested list such that `tokens[i]` is a list of tokens (strings/words) that comprise the `i`-th sentence.

labels:
For a single-class span classification dataset, `labels` is a nested list of given labels for all tokens, such that `labels[i]` is a list of labels, one for each token in the `i`-th sentence.

For a multi-class span classification dataset, `labels` must be a doubly nested list, such that `labels[i][j]` is a list of integers, each representing a span class label for the `j`-th token in the `i`-th sentence.

For a dataset with K classes, each label must be in 0, 1, ..., K-1.

pred_probs:
List of np arrays of model-predicted probabilities, in the same format expected by :py:func:`find_label_issues <cleanlab.experimental.span_classification.find_label_issues>`.

For multiple span classes, `pred_probs[i]` has shape ``(T, K - 1)`` if the `i`-th sentence contains T tokens, with the probability for span class `k` in column ``k - 1`` (the 'O' class 0 has no column of its own).

exclude:
List of tuples ``(cl, pred_res)`` such that a token's issue is excluded from display if the token is predicted as `pred_res` for span class `cl`.
`pred_res` is 1 if the token is predicted as inside the span, and 0 if predicted as outside.
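For example, ``exclude=[(1, 0)]`` suppresses issues whose token is predicted as outside the span for class 1.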

class_names:
Optional length K list of names of each class, such that `class_names[i]` is the string name of the class corresponding to `labels` with value `i`.

If `class_names` is provided, these string names are displayed for predicted and given labels; otherwise integer class indices are displayed.

top: int, default=20
Maximum number of issues to display per span class.

threshold: float, default=0.5
Probability threshold used to decide whether a token is predicted as inside a span when evaluating `exclude`. Only used when `exclude` is provided.
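
Examples
--------
A minimal sketch of multi-class usage, reusing the example data from `find_label_issues` (the `tokens` and `class_names` values here are purely illustrative):

>>> import numpy as np
>>> from cleanlab.experimental.span_classification import find_label_issues, display_issues
>>> tokens = [["I", "love", "New", "York"], ["Hello", "wonderful", "world"]]
>>> labels = [
...     [[0], [1, 2], [1, 3], [0]],
...     [[1], [2, 3], [3]],
... ]
>>> pred_probs = [
...     np.array([[0.9, 0.2, 0.3], [0.9, 0.9, 0.2], [0.9, 0.1, 0.7], [0.1, 0.1, 0.1]]),
...     np.array([[0.1, 0.9, 0.1], [0.1, 0.9, 0.9], [0.1, 0.9, 0.9]]),
... ]
>>> issues = find_label_issues(labels, pred_probs)
>>> display_issues(issues, tokens, labels=labels, pred_probs=pred_probs, class_names=["O", "PER", "LOC", "MISC"])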
"""
    if not issues or not tokens:
        raise ValueError("issues/tokens cannot be empty.")

    if isinstance(issues, list):
        # single span class: delegate directly to the token-classification display
        display_issues_token(
            issues,
            tokens,
            labels=labels,
            pred_probs=pred_probs,
            exclude=exclude,
            class_names=class_names,
            top=top,
        )
        return

    for cl, cl_issues in issues.items():
        # sentence-level issues (not token tuples): delegate to the token-classification display
        if cl_issues and not isinstance(cl_issues[0], tuple):
            display_issues_token(
                cl_issues,
                tokens,
                labels=labels,
                pred_probs=pred_probs,
                exclude=exclude,
                class_names=class_names,
                top=top,
            )
            continue

        shown = min(top, len(cl_issues))
        for issue in cl_issues:
            i, j = issue
            sentence = get_sentence(tokens[i])
            word = tokens[i][j]

            if exclude and pred_probs:
                # determine whether the token is predicted inside (1) or outside (0)
                # span class `cl` via `threshold`, and skip excluded combinations
                pred_res = 1 if pred_probs[i][j][cl - 1] > threshold else 0
                if (cl, pred_res) in exclude:
                    continue

            shown -= 1
            # build issue message for display
            issue_message = ""
            if class_names:
                issue_message += f"Span Class: {class_names[cl]}\n"
            else:
                issue_message += f"Span Class: {cl}\n"
            issue_message += f"Sentence index: {i}, Token index: {j}\n"
            issue_message += f"Token: {word}\n"
            if labels or pred_probs:
                issue_message += "According to provided labels/pred_probs, "
            if labels:
                label_str = "inside" if cl in labels[i][j] else "outside"
                issue_message += f"token marked as {label_str} span "
            if pred_probs:
                if labels:
                    issue_message += "but "
                else:
                    issue_message += "token "
                # the probability column for class `cl` sits at index cl - 1
                probs = pred_probs[i][j][cl - 1]
                issue_message += f"predicted inside span with probability: {probs}"
            issue_message += "\n----"
            print(issue_message)
            print(color_sentence(sentence, word))

            if shown == 0:
                break
        print("\n")


def get_label_quality_scores(
    labels: list,
    pred_probs: list,
    **kwargs,
) -> Union[Tuple[np.ndarray, list], Tuple[dict, dict]]:
"""
Compute label quality scores for the labels in each sentence, as well as for the individual tokens in each sentence.

Each score is between 0 and 1.

Lower scores indicate token labels that are less likely to be correct, or sentences that are more likely to contain a mislabeled token.

Parameters
----------
labels:
For a single-class span classification dataset, `labels` is a nested list of given labels for all tokens, such that `labels[i]` is a list of labels, one for each token in the `i`-th sentence.

For a multi-class span classification dataset, `labels` must be a doubly nested list, such that `labels[i][j]` is a list of integers, each representing a span class label for the `j`-th token in the `i`-th sentence.

For a dataset with K classes, each label must be in 0, 1, ..., K-1.

pred_probs:
For a single span class, `pred_probs` is a list of np arrays, such that `pred_probs[i]` is a 1-D array of length T of model-predicted probabilities that each of the T tokens in the `i`-th sentence lies inside the span.

For multiple span classes, `pred_probs[i]` instead has shape ``(T, K - 1)`` for a dataset with K classes. Row `t` contains the model-predicted probabilities that the `t`-th token lies inside the span of each class 1, ..., K-1, so the probability for class `k` is in column ``k - 1`` (the 'O' class 0 has no column of its own).

See documentation of :py:meth:`token_classification.rank.get_label_quality_scores<cleanlab.token_classification.rank.get_label_quality_scores>` for descriptions of optional parameters.

Returns
-------
sentence_scores:
For a single span class, an array of shape ``(N,)``, where `N` is the number of sentences; each element is a score between 0 and 1 indicating the overall label quality of the sentence.

For multiple span classes, a dictionary with span classes as keys, where each value is such an array for that class.

token_scores:
For a single span class, a list of ``pd.Series``, such that the `i`-th element contains the label quality scores for individual tokens in the `i`-th sentence.

For multiple span classes, a dictionary with span classes as keys, where each value is such a list for that class.

If `tokens` strings were provided, they are used as the index for each ``Series``.

Examples
--------
For a multi-class span classification task:

>>> import numpy as np
>>> from cleanlab.experimental.span_classification import get_label_quality_scores
>>> labels = [
... [[0], [1, 2], [1, 3], [0]],
... [[1], [2, 3], [3]],
... ]
>>> pred_probs = [
... np.array([[0.9, 0.2, 0.3], [0.9, 0.9, 0.2], [0.9, 0.1, 0.7], [0.1, 0.1, 0.1]]),
... np.array([[0.1, 0.9, 0.1], [0.1, 0.9, 0.9], [0.1, 0.9, 0.9]]),
... ]
>>> sentence_scores, token_scores = get_label_quality_scores(labels, pred_probs)
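
For multiple span classes, both return values are dictionaries keyed by span class (the 'O' class 0 gets no entry):

>>> sorted(sentence_scores.keys())
[1, 2, 3]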
"""
    if not labels or not pred_probs:
        raise ValueError("labels/pred_probs cannot be empty.")

    labels_flat = [l for sentence in labels for l in sentence]
    num_span_class = get_num_classes(labels_flat)

    if num_span_class <= 2:
        # binary case: expand each per-token span probability p into [1 - p, p]
        pred_probs_token = [np.stack([1 - probs, probs], axis=1) for probs in pred_probs]
        return get_label_quality_scores_token(labels, pred_probs_token, **kwargs)

    # type checks for multi-class span classification
    if not isinstance(labels_flat[0], list):
        raise ValueError("labels must be a nested list of lists, one for each sentence.")
    if not isinstance(pred_probs[0][0], np.ndarray) and not isinstance(pred_probs[0][0], list):
        raise ValueError("pred_probs must be a list of np arrays, one for each sentence.")

    sentence_scores = {}
    label_scores = {}
    # iterate over each span class, excluding the 'O' (outside) class 0
    for cl in range(1, num_span_class):
        # reduce to a binary task for class `cl`, as in `find_label_issues`
        cls_labels = [[1 if cl in label else 0 for label in sentence] for sentence in labels]
        cls_pred_probs = [np.array([pred[cl - 1] for pred in sentence]) for sentence in pred_probs]
        pred_probs_token = [np.stack([1 - probs, probs], axis=1) for probs in cls_pred_probs]

        sentence_scores[cl], label_scores[cl] = get_label_quality_scores_token(
            cls_labels, pred_probs_token, **kwargs
        )

    return sentence_scores, label_scores