
[#860] Adding Spurious Correlation feature #1140

Merged
merged 10 commits into from Jun 27, 2024
54 changes: 54 additions & 0 deletions cleanlab/datalab/datalab.py
@@ -43,6 +43,7 @@
)
from cleanlab.datalab.internal.serialize import _Serializer
from cleanlab.datalab.internal.task import Task
from cleanlab.datalab.internal.spurious_correlation import SpuriousCorrelations

if TYPE_CHECKING: # pragma: no cover
import numpy.typing as npt
@@ -635,3 +636,56 @@ def load(path: str, data: Optional[Dataset] = None) -> "Datalab":
load_message = f"Datalab loaded from folder: {path}"
print(load_message)
return datalab

def _spurious_correlation(self) -> pd.DataFrame:
"""
Assess potential spurious correlations in issue severity scores.

This method calculates scores indicating the likelihood of spurious correlations
for various issue severity scores in the dataset, as estimated by the `find_issues()` method.
Currently, it focuses on severity scores related to image attributes.
If `find_issues()` has not been called, it raises a ValueError.

Returns
-------
`correlations_df` : pandas.DataFrame
A DataFrame containing the calculated correlations for each property, excluding 'class_imbalance_score'.
The DataFrame includes:
- 'property' : str
The name of the property.
- 'score' : float
The spurious correlation score (between 0 and 1) for the property,
where a low score indicates a higher likelihood of spurious correlation,
and a high score indicates a lower likelihood.

Raises
------
ValueError
If the issues have not been identified (i.e., `find_issues()` has not been called).

Notes
-----
This method currently focuses on image-related severity scores, with potential for future expansions.
"""
try:
issues = self.get_issues()
Member:

Please add a validation step that ensures that the issues dataframe has all the relevant (image-specific) scores.
If it doesn't, an error with a helpful message should be raised.

Contributor Author:

Added a validation step here to check that all vision/image issues are present in the correlations dataframe.

Member:
To be clear, the issues dataframe should be validated, not the correlations_df.

except ValueError:
raise ValueError(
"Please call find_issues() before proceeding with finding Spurious Correlations"
)

if not all(
default_cleanvision_issue + "_score" in issues.columns.tolist()
for default_cleanvision_issue in DEFAULT_CLEANVISION_ISSUES.keys()
):
raise ValueError("All vision issue scores are not computed by get_issues() method")

cleanvision_issues_columns = [
default_cleanvision_issue + "_score"
for default_cleanvision_issue in DEFAULT_CLEANVISION_ISSUES.keys()
]
issues_score_data = issues[cleanvision_issues_columns]
property_correlations = SpuriousCorrelations(data=issues_score_data, labels=self.labels)
correlations_df = property_correlations.calculate_correlations()

return correlations_df
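
For reference, a minimal sketch of the intended call pattern (mirroring the tutorial notebook below; assumes an image dataset already loaded as a Hugging Face Dataset with a "label" column and an "image" column):

lab = Datalab(data=dataset, label_name="label", image_key="image")
lab.find_issues()
correlations_df = lab._spurious_correlation()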
113 changes: 113 additions & 0 deletions cleanlab/datalab/internal/spurious_correlation.py
@@ -0,0 +1,113 @@
from dataclasses import dataclass
from typing import List, Optional, Union
import warnings

import numpy as np
import pandas as pd
from sklearn.model_selection import cross_val_score
from sklearn.naive_bayes import GaussianNB

warnings.filterwarnings("ignore")


@dataclass
class SpuriousCorrelations:
data: pd.DataFrame
labels: Union[np.ndarray, list]
properties_of_interest: Optional[List[str]] = None

def __post_init__(self):
# Must have same number of rows
if len(self.data) != len(self.labels):
raise ValueError(
"The number of rows in the data dataframe must be the same as the number of labels."
)

# Set default properties_of_interest if not provided
if self.properties_of_interest is None:
self.properties_of_interest = self.data.columns.tolist()

if not all(isinstance(p, str) for p in self.properties_of_interest):
raise TypeError("properties_of_interest must be a list of strings.")

def calculate_correlations(self) -> pd.DataFrame:
"""Calculates the spurious correlation scores for each property of interest found in the dataset."""
baseline_accuracy = self._get_baseline()
assert (
self.properties_of_interest is not None
), "properties_of_interest must be set, but is None."
property_scores = {
str(property_of_interest): self.calculate_spurious_correlation(
property_of_interest, baseline_accuracy
)
for property_of_interest in self.properties_of_interest
}
data_score = pd.DataFrame(list(property_scores.items()), columns=["property", "score"])
return data_score
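
# Resulting frame (illustrative values only): one row per property score,
# where lower scores flag likely spurious correlations.
#        property  score
# 0    dark_score   0.12
# 1  blurry_score   0.94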

def _get_baseline(self) -> float:
"""Calculates the baseline accuracy of the dataset. The baseline model is predicting the most common label."""
baseline_accuracy = np.bincount(self.labels).max() / len(self.labels)
return float(baseline_accuracy)
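
# Example (illustrative): for labels [0, 0, 0, 1, 1], np.bincount(labels).max()
# is 3, so always predicting the majority class yields accuracy 3 / 5 = 0.6.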

def calculate_spurious_correlation(
self, property_of_interest: str, baseline_accuracy: float
) -> float:
"""Scores the dataset based on a given property of interest.

Parameters
----------
property_of_interest :
The property of interest to score the dataset on.

baseline_accuracy :
The accuracy of the baseline model.

Returns
-------
score :
A correlation score of the dataset's labels to the property of interest.
"""
X = self.data[property_of_interest].values.reshape(-1, 1)
y = self.labels
mean_accuracy = _train_and_eval(X, y)
return relative_room_for_improvement(baseline_accuracy, float(mean_accuracy))


def _train_and_eval(X, y, cv=5) -> float:
classifier = GaussianNB() # TODO: Make this a parameter
cv_accuracies = cross_val_score(classifier, X, y, cv=cv, scoring="accuracy")
mean_accuracy = float(np.mean(cv_accuracies))
return mean_accuracy
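
# Example (illustrative): a property column that separates the labels almost
# perfectly should reach a mean CV accuracy near 1.0:
#   X = np.array([0.1, 0.9, 0.2, 0.8] * 10).reshape(-1, 1)
#   y = np.array([0, 1, 0, 1] * 10)
#   _train_and_eval(X, y)  # ~1.0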


def relative_room_for_improvement(
baseline_accuracy: float, mean_accuracy: float, eps: float = 1e-8
) -> float:
"""
Calculate the relative room for improvement given a baseline and trial accuracy.

This function computes the ratio of the difference between perfect accuracy (1.0)
and the trial accuracy to the difference between perfect accuracy and the baseline accuracy.
If the baseline accuracy is perfect (i.e., 1.0), an epsilon value is added to the denominator
to avoid division by zero.

Parameters
----------
baseline_accuracy :
The accuracy of the baseline model. Must be between 0 and 1.
mean_accuracy :
The accuracy of the trial model being compared. Must be between 0 and 1.
eps :
A small constant to avoid division by zero when baseline accuracy is 1. Defaults to 1e-8.

Returns
-------
score :
The relative room for improvement, bounded between 0 and 1.
"""
numerator = 1 - mean_accuracy
denominator = 1 - baseline_accuracy
if baseline_accuracy == 1:
denominator += eps

return min(1, numerator / denominator)
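
As a quick sanity check, here is a minimal, self-contained sketch exercising SpuriousCorrelations directly. The data below is synthetic and purely illustrative (not from the PR): "dark_score" tracks the labels while "blurry_score" is random noise, so the former should receive a much lower (more spurious) score.

import numpy as np
import pandas as pd
from cleanlab.datalab.internal.spurious_correlation import SpuriousCorrelations

rng = np.random.default_rng(0)
labels = np.array([0, 1] * 50)
data = pd.DataFrame({
    "dark_score": labels * 0.8 + rng.normal(0, 0.05, size=100),  # tracks labels
    "blurry_score": rng.uniform(0, 1, size=100),                 # uninformative
})

sc = SpuriousCorrelations(data=data, labels=labels)
print(sc.calculate_correlations())
# Expect a score near 0 for "dark_score" and near 1 for "blurry_score".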
205 changes: 205 additions & 0 deletions docs/source/tutorials/datalab/workflows.ipynb
@@ -1326,6 +1326,211 @@
"assert all(class_imbalance_issues.query(\"is_class_imbalance_issue\")[\"class_imbalance_score\"] == 0.02), \"Class imbalance issue scores are not as expected\"\n",
"assert all(class_imbalance_issues.query(\"not is_class_imbalance_issue\")[\"class_imbalance_score\"] == 1.0), \"Class imbalance issue scores are not as expected\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Find Spurious Correlation between Vision Dataset features and class labels\n",
"\n",
"In this section, we demonstrate how to identify spurious correlations in a vision dataset using the `cleanlab` library. Spurious correlations are unintended associations in the data that do not reflect the true underlying relationships, potentially leading to misleading model predictions and poor generalization.\n",
"\n",
"We will utilize the `Datalab` class from cleanlab with the `image_key` attribute to pinpoint vision-specific issues such as `dark_score`, `blurry_score`, `odd_aspect_ratio_score`, and more in the dataset. By analyzing these correlations, we can understand their impact on model performance and take steps to enhance the robustness and reliability of our machine learning models."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Load the dataset\n",
"\n",
"We will demonstrate this workflow using the CIFAR-10 dataset by selecting 100 images from two random classes. To illustrate the impact of spurious correlations between image features and class labels, we will showcase how altering all images of a class, such as darkening them, significantly reduces the `dark_score`. This demonstrates the strong correlation detection of darkness within the dataset.\n",
"\n",
"Similarly, we can observe significant reductions in `blurry_score` and `odd_aspect_ratio_score` when one of the classes contains images with corresponding characteristics such as blurriness or an unusual aspect ratio between width and height."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from cleanlab import Datalab\n",
"from torchvision.datasets import CIFAR10\n",
"from datasets import Dataset\n",
"import io\n",
"from PIL import Image, ImageEnhance\n",
"import random\n",
"import numpy as np\n",
"from IPython.display import display, Markdown\n",
"\n",
"# Download the CIFAR-10 test dataset\n",
"data = CIFAR10(root='./data', train=False, download=True)\n",
"\n",
"# Set seed for reproducibility\n",
"np.random.seed(0)\n",
"random.seed(0)\n",
"\n",
"# Randomly select two classes\n",
"classes = list(range(len(data.classes)))\n",
"selected_classes = random.sample(classes, 2)\n",
"\n",
"# Function to convert PIL object to PNG image to be passed to the Datalab object\n",
"def convert_to_png_image(image):\n",
" buffer = io.BytesIO()\n",
" image.save(buffer, format='PNG')\n",
" buffer.seek(0)\n",
" return Image.open(buffer)\n",
"\n",
"# Generating 100 ('max_num_images') images from each of the two chosen classes\n",
"max_num_images = 100\n",
"list_images, list_labels = [], []\n",
"num_images = {selected_classes[0]: 0, selected_classes[1]: 0}\n",
"\n",
"for img, label in data:\n",
" if num_images[selected_classes[0]] == max_num_images and num_images[selected_classes[1]] == max_num_images:\n",
" break\n",
" if label in selected_classes:\n",
" if num_images[label] == max_num_images:\n",
" continue\n",
" list_images.append(convert_to_png_image(img))\n",
" list_labels.append(label)\n",
" num_images[label] += 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Creating `Dataset` object to be passed to the `Datalab` object to find vision-related issues"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a datasets.Dataset object from list of images and their corresponding labels\n",
"dataset_dict = {'image': list_images, 'label': list_labels}\n",
"dataset = Dataset.from_dict(dataset_dict)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3. (Optional) Creating a transformed dataset using `ImageEnhance` to induce darkness"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Function to reduce brightness to 30%\n",
"def apply_dark(image):\n",
" \"\"\"Decreases brightness of the image.\"\"\"\n",
" enhancer = ImageEnhance.Brightness(image)\n",
" return enhancer.enhance(0.3)\n",
"\n",
"# Applying the darkness filter to one of the classes\n",
"transformed_list_images = [\n",
" apply_dark(img) if label == selected_classes[0] else img\n",
" for label, img in zip(list_labels, list_images)\n",
"]\n",
"\n",
"# Creating datasets.Dataset object from the transformed dataset\n",
"transformed_dataset_dict = {'image': transformed_list_images, 'label': list_labels}\n",
"transformed_dataset = Dataset.from_dict(transformed_dataset_dict)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4. (Optional) Visualizing Images in the dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"def plot_images(dataset_dict):\n",
" \"\"\"Plots the first 15 images from the dataset dictionary.\"\"\"\n",
" images = dataset_dict['image']\n",
" labels = dataset_dict['label']\n",
" \n",
" # Define the number of images to plot\n",
" num_images_to_plot = 15\n",
" num_cols = 5 # Number of columns in the plot grid\n",
" num_rows = (num_images_to_plot + num_cols - 1) // num_cols # Calculate rows needed\n",
" \n",
" # Create a figure\n",
" fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 6))\n",
" axes = axes.flatten()\n",
" \n",
" # Plot each image\n",
" for i in range(num_images_to_plot):\n",
" img = images[i]\n",
" label = labels[i]\n",
" axes[i].imshow(img)\n",
" axes[i].set_title(f'Label: {label}')\n",
" axes[i].axis('off')\n",
" \n",
" # Hide any remaining empty subplots\n",
" for i in range(num_images_to_plot, len(axes)):\n",
" axes[i].axis('off')\n",
" \n",
" # Show the plot\n",
" plt.tight_layout()\n",
" plt.show()\n",
"\n",
"plot_images(dataset_dict)\n",
"plot_images(transformed_dataset_dict)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5. Finding image-specific property scores"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Function to find image-specific property scores given the dataset object\n",
"def get_property_scores(dataset):\n",
" lab = Datalab(data=dataset, label_name=\"label\", image_key=\"image\")\n",
" lab.find_issues()\n",
" return lab._spurious_correlation()\n",
"\n",
"# Finds specific property score in the dataframe containing property scores \n",
"def get_specific_property_score(property_scores_df, property_name):\n",
" return property_scores_df[property_scores_df['property'] == property_name]['score'].iloc[0]\n",
"\n",
"# Finding scores in original and transformed dataset\n",
"standard_property_scores = get_property_scores(dataset)\n",
"transformed_property_scores = get_property_scores(transformed_dataset)\n",
"\n",
"# Displaying the scores dataframe\n",
"display(Markdown(\"### Vision-specific property scores in the original dataset\"))\n",
"display(standard_property_scores)\n",
"display(Markdown(\"### Vision-specific property scores in the transformed dataset\"))\n",
"display(transformed_property_scores)\n",
"\n",
"# Smaller 'dark_score' value for modified dataframe shows strong correlation with the class labels in the transformed dataset\n",
"assert get_specific_property_score(standard_property_scores, 'dark_score') > get_specific_property_score(transformed_property_scores, 'dark_score')"
]
}
],
"metadata": {
Expand Down
Loading
Loading