Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perf: common_label_issues #1069

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 29 additions & 20 deletions cleanlab/segmentation/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@
Methods to display images and their label issues in a semantic segmentation dataset, as well as summarize the overall types of issues identified.
"""

from typing import Any, Dict, List, Optional
from typing import List, Optional

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from cleanlab.internal.segmentation_utils import _get_summary_optional_params

Expand Down Expand Up @@ -97,8 +96,8 @@ def display_issues(
correct_ordering = np.argsort(-np.sum(issues, axis=(1, 2)))[:top]

try:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
except:
raise ImportError('try "pip install matplotlib"')
Expand Down Expand Up @@ -227,37 +226,47 @@ def common_label_issues(
where each row contains information about a particular given/predicted label swap.
Rows are ordered by the number of label issues inferred to exhibit this type of label swap.
"""
try:
N, K, H, W = pred_probs.shape
except:
pred_probs_shape = pred_probs.shape
if len(pred_probs_shape) != 4:
raise ValueError("pred_probs must be of shape (N, K, H, W)")
N, K, H, W = pred_probs_shape

assert labels.shape == (N, H, W), "labels must be of shape (N, H, W)"

class_names, exclude, top = _get_summary_optional_params(class_names, exclude, top)
# Find issues by pixel coordinates
issue_coords = np.column_stack(np.where(issues))
issue_coords = np.where(issues)

label_coord_issues = labels[issue_coords]
preds = pred_probs[issue_coords[0], :, issue_coords[1], issue_coords[2]].argmax(axis=1)

mask = ~np.isin(preds, exclude)
unique_labels = np.unique(labels)

if verbose:
from tqdm.auto import tqdm

pbar = tqdm(desc="Labels processed", total=len(unique_labels))
# Count issues per class (given label)
count: Dict[int, Any] = {}
for i, j, k in tqdm(issue_coords):
label = labels[i, j, k]
pred = pred_probs[i, :, j, k].argmax()
if label not in count:
count[label] = np.zeros(K, dtype=int)
if pred not in exclude:
count[label][pred] += 1
count = {label: np.zeros(K, dtype=int) for label in unique_labels}
for label in unique_labels:
label_mask = mask & (label_coord_issues == label)
label_preds, pred_counts = np.unique(preds[label_mask], return_counts=True)
for i, pred in enumerate(label_preds):
count[label][pred] = pred_counts[i]
gogetron marked this conversation as resolved.
Show resolved Hide resolved

if verbose:
pbar.update(1)

# Prepare output DataFrame
if class_names is None:
class_names = [str(i) for i in range(K)]

info = []
for given_label, class_name in enumerate(class_names):
if given_label in count:
for pred_label, num_issues in enumerate(count[given_label]):
if num_issues > 0:
info.append([class_name, class_names[pred_label], num_issues])
for given_label, num_issues_array in count.items():
for pred_label, num_issues in enumerate(num_issues_array):
if num_issues > 0:
info.append([class_names[given_label], class_names[pred_label], num_issues])

info = sorted(info, key=lambda x: x[2], reverse=True)[:top]
issues_df = pd.DataFrame(info, columns=["given_label", "predicted_label", "num_pixel_issues"])
Expand Down
Loading