Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding spurious_correlation as a new issue type #872

Open
wants to merge 26 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
d9773fc
adding an spurious_correlation as new issue type
01PrathamS Oct 13, 2023
f6af4d4
new issue type spurious_correlation changes added
01PrathamS Oct 17, 2023
8312f7e
make necessary changes to spurious_correlation function and add unit …
01PrathamS Oct 25, 2023
48bf33c
Revert "make necessary changes to spurious_correlation function and a…
01PrathamS Oct 25, 2023
178c079
Merge branch 'cleanlab:master' into spurious_correlations
01PrathamS Oct 25, 2023
3a7a3fb
apply black formatter
elisno Oct 31, 2023
2660c33
define helper function that defines the score
elisno Oct 31, 2023
33444fd
Turn SpuriousCorrelations class into a dataclass
elisno Oct 31, 2023
4f3d9fa
Add a test file for spurious correlation functionality
elisno Oct 31, 2023
fcabb2c
update unit tests for calculate_correlations method
elisno Nov 1, 2023
81d4fcf
add assertion that the properties of interests should be in the dataf…
elisno Nov 1, 2023
ad0374a
add docstrings
elisno Nov 1, 2023
3907bb9
add docs page for internal class for handling spurious correlations c…
elisno Nov 1, 2023
b3b06c0
add a working Datalab example
elisno Nov 1, 2023
0fb83b1
refactor Datalab._spurious_correlations
elisno Nov 1, 2023
5d9698b
apply black formatter
elisno Nov 1, 2023
a185ecc
rename score column name
elisno Nov 1, 2023
e6cf509
remove line that suppresses warnings
elisno Nov 1, 2023
0ea9c9a
cast feature array to numpy array, add more type-hints
elisno Nov 1, 2023
861aa93
add type-hints for return values
elisno Nov 1, 2023
fc42d7e
format docstring
elisno Nov 1, 2023
fb5aa35
adjust condition for checking property 2 of the scoring function
elisno Nov 1, 2023
f35b3c1
Merge branch 'cleanlab:master' into spurious_correlations
01PrathamS Nov 12, 2023
2ae6b7b
fix baseline accuracy
elisno Nov 13, 2023
409c5b7
update expected scores based on fixed baseline accuracy
elisno Nov 13, 2023
1ef0397
hello
01PrathamS Dec 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 82 additions & 33 deletions cleanlab/datalab/datalab.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,19 @@
import pandas as pd

import cleanlab
from cleanlab.datalab.internal.adapter.imagelab import create_imagelab
from cleanlab.datalab.internal.adapter.imagelab import (
create_imagelab,
)
from cleanlab.datalab.internal.data import Data
from cleanlab.datalab.internal.display import _Displayer
from cleanlab.datalab.internal.helper_factory import (
_DataIssuesBuilder,
data_issues_factory,
issue_finder_factory,
report_factory,
)
from cleanlab.datalab.internal.issue_manager_factory import (
list_default_issue_types as _list_default_issue_types,
list_possible_issue_types as _list_possible_issue_types,
)
from cleanlab.datalab.internal.issue_finder import IssueFinder
from cleanlab.datalab.internal.serialize import _Serializer
from cleanlab.datalab.internal.spurious_correlation import SpuriousCorrelations

if TYPE_CHECKING: # pragma: no cover
import numpy.typing as npt
Expand All @@ -49,7 +49,6 @@

DatasetLike = Union[Dataset, pd.DataFrame, Dict[str, Any], List[Dict[str, Any]], str]


__all__ = ["Datalab"]


Expand Down Expand Up @@ -78,10 +77,6 @@ class Datalab:
- path to a local file: Text (.txt), CSV (.csv), JSON (.json)
- or a dataset identifier on the Hugging Face Hub

task : str
The type of machine learning task that the dataset is used for.
By default, this is set to "classification", but you can also set it to "regression" if you are working with a regression dataset.

label_name : str, optional
The name of the label column in the dataset.

Expand All @@ -106,30 +101,20 @@ class Datalab:
def __init__(
self,
data: "DatasetLike",
task: str = "classification",
label_name: Optional[str] = None,
image_key: Optional[str] = None,
verbosity: int = 1,
) -> None:
# Assume continuous values of labels for regression task
# Map labels to integers for classification task
map_labels_to_int = task == "classification" # TODO: handle more generally

self._data = Data(data, label_name, map_to_int=map_labels_to_int)
self._data = Data(data, label_name)
self.data = self._data._data
self.task = task
self._labels = self._data.labels
self._label_map = self._labels.label_map
self.label_name = self._labels.label_name
self._data_hash = self._data._data_hash
self.cleanlab_version = cleanlab.version.__version__
self.verbosity = verbosity
self._imagelab = create_imagelab(dataset=self.data, image_key=image_key)

# Create the builder for DataIssues
builder = _DataIssuesBuilder(self._data)
builder.set_imagelab(self._imagelab).set_task(task)
self.data_issues = builder.build()
self.data_issues = data_issues_factory(self._imagelab)(self._data)

# todo: check displayer methods
def __repr__(self) -> str:
Expand All @@ -145,7 +130,7 @@ def labels(self) -> np.ndarray:

@property
def has_labels(self) -> bool:
"""Whether the dataset has labels, and that they are in a [0, 1, ..., K-1] format."""
"""Whether the dataset has labels."""
return self._labels.is_available

@property
Expand Down Expand Up @@ -306,21 +291,84 @@ def find_issues(
"No issue types were specified so no issues will be found in the dataset. Set `issue_types` as None to consider a default set of issues."
)
return None
issue_finder = issue_finder_factory(self._imagelab)(
datalab=self, task=self.task, verbosity=self.verbosity
)

issue_finder = issue_finder_factory(self._imagelab)(datalab=self, verbosity=self.verbosity)
issue_finder.find_issues(
pred_probs=pred_probs,
features=features,
knn_graph=knn_graph,
issue_types=issue_types,
)

if self.verbosity:
print(
f"\nAudit complete. {self.data_issues.issue_summary['num_issues'].sum()} issues found in the dataset."
)

def _spurious_correlations(self, properties: Optional[List[str]] = None) -> pd.DataFrame:
    """
    Identify potential spurious correlations between image properties and the labels.

    The per-example ``<property>_score`` columns of ``self.issues`` are selected,
    renamed to the bare property name, and handed together with ``self.labels`` to
    ``SpuriousCorrelations.calculate_correlations``.

    Parameters
    ----------
    properties : Optional[List[str]]
        A list of specific image properties (e.g. 'dark', 'grayscale') to be analyzed.
        Every entry must be an issue type listed in ``self.issue_summary``.
        If None, all properties from the Imagelab issue summary that are also
        present in ``self.issue_summary`` are considered.

    Returns
    -------
    pd.DataFrame
        The result of ``SpuriousCorrelations.calculate_correlations()`` for the
        selected properties.

    Raises
    ------
    NotImplementedError
        If this Datalab has no Imagelab instance (only image properties are
        supported so far).
    ValueError
        If no issues have been found yet, or if an explicitly requested property
        is not an issue type in ``self.issue_summary``.

    Note
    ----
    This method is a wrapper around the :py:meth:`SpuriousCorrelations.calculate_correlations <cleanlab.datalab.internal.spurious_correlation.SpuriousCorrelations.calculate_correlations>` method.

    It is still a work in progress and may be subject to change in future versions.

    See Also
    --------
    cleanlab.datalab.internal.spurious_correlation.SpuriousCorrelations
    """
    # TODO: Update this check when support for more properties is added.
    if self._imagelab is None:
        raise NotImplementedError("No ImageLab instance found. Please specify properties.")

    # TODO: Update this check when support for more properties is added.
    if self._imagelab.issue_summary.empty or self.issues.empty:
        raise ValueError("No issues found in ImageLab. Please run find_issues() first.")

    valid_properties = self.issue_summary["issue_type"].values.tolist()

    if properties is None:
        # Default to every Imagelab property that Datalab also knows about.
        # `self._imagelab` is guaranteed to be non-None here (checked above), so no
        # additional guard on it is needed before reading its issue summary.
        imagelab_properties = self._imagelab.issue_summary["issue_type"].values.tolist()
        properties = [p for p in imagelab_properties if p in valid_properties]
    else:
        # Validate explicitly requested properties *before* any filtering, so that a
        # typo is reported to the caller instead of being silently dropped.
        for p in properties:
            if p not in valid_properties:
                raise ValueError(
                    f"{p} is not a valid property. Available options: {valid_properties}"
                )

    # Map score column names to plain property names by stripping only the trailing
    # "_score" suffix (a name that merely contains "_score" must not be truncated).
    suffix = "_score"
    rename_map = {
        column: column[: -len(suffix)]
        for column in self.issues.columns
        if column.endswith(suffix)
    }

    # Keep the requested properties only, preserving the column order of `self.issues`.
    selected_columns = [
        column for column, name in rename_map.items() if name in properties
    ]
    df = self.issues[selected_columns].rename(columns=rename_map)

    return SpuriousCorrelations(data=df, labels=self.labels).calculate_correlations()

def report(
self,
*,
Expand Down Expand Up @@ -357,7 +405,6 @@ def report(

reporter = report_factory(self._imagelab)(
data_issues=self.data_issues,
task=self.task,
verbosity=verbosity,
include_description=include_description,
show_summary_score=show_summary_score,
Expand Down Expand Up @@ -504,7 +551,8 @@ def get_info(self, issue_name: Optional[str] = None) -> Dict[str, Any]:
"""
return self.data_issues.get_info(issue_name)

def list_possible_issue_types(self) -> List[str]:
@staticmethod
def list_possible_issue_types() -> List[str]:
"""Returns a list of all registered issue types.

Any issue type that is not in this list cannot be used in the :py:meth:`find_issues` method.
Expand All @@ -517,9 +565,10 @@ def list_possible_issue_types(self) -> List[str]:
--------
:py:class:`REGISTRY <cleanlab.datalab.internal.issue_manager_factory.REGISTRY>` : All available issue types and their corresponding issue managers can be found here.
"""
return _list_possible_issue_types(task=self.task)
return IssueFinder.list_possible_issue_types()

def list_default_issue_types(self) -> List[str]:
@staticmethod
def list_default_issue_types() -> List[str]:
"""Returns a list of the issue types that are run by default
when :py:meth:`find_issues` is called without specifying `issue_types`.

Expand All @@ -531,7 +580,7 @@ def list_default_issue_types(self) -> List[str]:
--------
:py:class:`REGISTRY <cleanlab.datalab.internal.issue_manager_factory.REGISTRY>` : All available issue types and their corresponding issue managers can be found here.
"""
return _list_default_issue_types(task=self.task)
return IssueFinder.list_default_issue_types()

def save(self, path: str, force: bool = False) -> None:
"""Saves this Datalab object to file (all files are in folder at `path/`).
Expand Down
14 changes: 5 additions & 9 deletions cleanlab/datalab/internal/adapter/imagelab.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
IMAGELAB_ISSUES_MAX_PREVALENCE,
)
from cleanlab.datalab.internal.data import Data
from cleanlab.datalab.internal.data_issues import DataIssues, _InfoStrategy
from cleanlab.datalab.internal.data_issues import DataIssues
from cleanlab.datalab.internal.issue_finder import IssueFinder
from cleanlab.datalab.internal.report import Reporter

Expand Down Expand Up @@ -70,8 +70,6 @@ class ImagelabDataIssuesAdapter(DataIssues):
----------
data :
The data object for which the issues are being collected.
strategy :
Strategy used for processing info dictionaries.

Parameters
----------
Expand All @@ -84,8 +82,8 @@ class ImagelabDataIssuesAdapter(DataIssues):
A dictionary that contains information and statistics about the data and each issue type.
"""

def __init__(self, data: Data, strategy: _InfoStrategy) -> None:
super().__init__(data, strategy)
def __init__(self, data: Data) -> None:
super().__init__(data)

def _update_issues_imagelab(self, imagelab: "Imagelab", overlapping_issues: List[str]) -> None:
overwrite_columns = [f"is_{issue_type}_issue" for issue_type in overlapping_issues]
Expand Down Expand Up @@ -145,14 +143,12 @@ def __init__(
self,
data_issues: "DataIssues",
imagelab: "Imagelab",
task: str,
verbosity: int = 1,
include_description: bool = True,
show_summary_score: bool = False,
):
super().__init__(
data_issues=data_issues,
task=task,
verbosity=verbosity,
include_description=include_description,
show_summary_score=show_summary_score,
Expand All @@ -168,8 +164,8 @@ def report(self, num_examples: int) -> None:


class ImagelabIssueFinderAdapter(IssueFinder):
def __init__(self, datalab, task, verbosity):
super().__init__(datalab, task, verbosity)
def __init__(self, datalab, verbosity):
super().__init__(datalab, verbosity)
self.imagelab = self.datalab._imagelab

def _get_imagelab_issue_types(self, issue_types, **kwargs):
Expand Down
39 changes: 5 additions & 34 deletions cleanlab/datalab/internal/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,6 @@ class Data:
label_name : Union[str, List[str]]
Name of the label column in the dataset.

map_to_int : bool
Whether to map the labels to integers, e.g. [0, 1, ..., K-1] where K is the number of classes.
If False, the labels are not mapped to integers, e.g. for regression tasks.

Warnings
--------
Optional dependencies:
Expand All @@ -130,13 +126,11 @@ class Data:
:py:class:`Datalab <cleanlab.datalab.datalab.Datalab>` to work.
"""

def __init__(
self, data: "DatasetLike", label_name: Optional[str] = None, map_to_int: bool = True
) -> None:
def __init__(self, data: "DatasetLike", label_name: Optional[str] = None) -> None:
self._validate_data(data)
self._data = self._load_data(data)
self._data_hash = hash(self._data)
self.labels = Label(data=self._data, label_name=label_name, map_to_int=map_to_int)
self.labels = Label(data=self._data, label_name=label_name)

def _load_data(self, data: "DatasetLike") -> Dataset:
"""Checks the type of dataset and uses the correct loader method and
Expand Down Expand Up @@ -224,31 +218,17 @@ class Label:
"""
Class to represent labels in a dataset.

It stores the labels as a numpy array and maps them to integers if necessary.
If a mapping is not necessary, e.g. for regression tasks, the mapping will be an empty dictionary.

Parameters
----------
data :
A Hugging Face Dataset object.

label_name : str
Name of the label column in the dataset.

map_to_int : bool
Whether to map the labels to integers, e.g. [0, 1, ..., K-1] where K is the number of classes.
If False, the labels are not mapped to integers, e.g. for regression tasks.
"""

def __init__(
self, *, data: Dataset, label_name: Optional[str] = None, map_to_int: bool = True
) -> None:
def __init__(self, *, data: Dataset, label_name: Optional[str] = None) -> None:
self._data = data
self.label_name = label_name
self.labels = labels_to_array([])
self.label_map: Mapping[str, Any] = {}
if label_name is not None:
self.labels, self.label_map = _extract_labels(data, label_name, map_to_int)
self.labels, self.label_map = _extract_labels(data, label_name)
self._validate_labels()

def __len__(self) -> int:
Expand Down Expand Up @@ -293,7 +273,7 @@ def _validate_labels(self) -> None:
assert len(labels) == len(self._data)


def _extract_labels(data: Dataset, label_name: str, map_to_int: bool) -> Tuple[np.ndarray, Mapping]:
def _extract_labels(data: Dataset, label_name: str) -> Tuple[np.ndarray, Mapping]:
"""
Picks out labels from the dataset and formats them to be [0, 1, ..., K-1]
where K is the number of classes. Also returns a mapping from the formatted
Expand All @@ -305,15 +285,9 @@ def _extract_labels(data: Dataset, label_name: str, map_to_int: bool) -> Tuple[n

Parameters
----------
data : datasets.Dataset
A Hugging Face Dataset object.

label_name : str
Name of the column in the dataset that contains the labels.

map_to_int : bool
Whether to map the labels to integers, e.g. [0, 1, ..., K-1] where K is the number of classes.
If False, the labels are not mapped to integers, e.g. for regression tasks.
Returns
-------
formatted_labels : np.ndarray
Expand All @@ -327,9 +301,6 @@ def _extract_labels(data: Dataset, label_name: str, map_to_int: bool) -> Tuple[n
if labels.ndim != 1:
raise ValueError("labels must be 1D numpy array.")

if not map_to_int:
# Don't map labels to integers, e.g. for regression tasks
return labels, {}
label_name_feature = data.features[label_name]
if isinstance(label_name_feature, ClassLabel):
label_map = {label: label_name_feature.str2int(label) for label in label_name_feature.names}
Expand Down
Loading