Update knn shapley score computation #1142

Merged (11 commits, Jun 19, 2024)
53 changes: 40 additions & 13 deletions cleanlab/data_valuation.py
@@ -27,20 +27,43 @@
 from cleanlab.internal.neighbor.knn_graph import create_knn_graph_and_index
 
 
-def _knn_shapley_score(knn_graph: csr_matrix, labels: np.ndarray, k: int) -> np.ndarray:
-    """Compute the Shapley values of data points based on a knn graph."""
-    N = labels.shape[0]
+def _knn_shapley_score(neighbor_indices: np.ndarray, y: np.ndarray, k: int) -> np.ndarray:
+    """Compute the Data Shapley values of data points using neighbor indices in a K-Nearest Neighbors (KNN) graph.
+
+    This function leverages equations (18) and (19) from the paper available at https://arxiv.org/abs/1908.08619
+    for computational efficiency.
+
+    Parameters
+    ----------
+    neighbor_indices :
+        A 2D array where each row contains the indices of the k-nearest neighbors for each data point.
+    y :
+        A 1D array of target values corresponding to the data points.
+    k :
+        The number of nearest neighbors to consider for each data point.
+
+    Notes
+    -----
+    - The training set is used as its own test set for the KNN-Shapley value computation, meaning y_test is the same as y_train.
+    - `neighbor_indices` are assumed to be pre-sorted by distance, with the nearest neighbors appearing first, and with at least `k` neighbors.
+    - Unlike the referenced paper, this implementation does not account for an upper error bound epsilon.
+      Consequently, K* is treated as equal to K instead of K* = max(K, 1/epsilon).
+    - This simplification implies that the term min(K, j + 1) will always be j + 1, which is offset by the
+      corresponding denominator term in the inner loop.
+    - Dividing by K in the end achieves the same result as dividing by K* in the paper.
+    - The pre-allocated `scores` array incorporates equation (18) for j = k - 1, ensuring efficient computation.
+    """
+    N = y.shape[0]
     scores = np.zeros((N, N))
-    dist = knn_graph.indices.reshape(N, -1)
 
-    for y, s, dist_i in zip(labels, scores, dist):
-        idx = dist_i[::-1]
-        ans = labels[idx]
-        s[idx[k - 1]] = float(ans[k - 1] == y)
-        ans_matches = (ans == y).flatten()
+    for y_alpha, s_alpha, idx in zip(y, scores, neighbor_indices):
+        y_neighbors = y[idx]
+        ans_matches = (y_neighbors == y_alpha).flatten()
         for j in range(k - 2, -1, -1):
-            s[idx[j]] = s[idx[j + 1]] + float(int(ans_matches[j]) - int(ans_matches[j + 1]))
-    return 0.5 * (np.mean(scores / k, axis=0) + 1)
+            s_alpha[idx[j]] = s_alpha[idx[j + 1]] + float(
+                int(ans_matches[j]) - int(ans_matches[j + 1])
+            )
+    return np.mean(scores / k, axis=0)
 
 
 def data_shapley_knn(
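To see what the rewritten recursion computes, here is a minimal sketch of calling the new helper on toy inputs (illustrative only, not part of the PR): the labels, `k`, and `neighbor_indices` values below are made up, and `neighbor_indices` is assumed pre-sorted nearest-first with at least `k` columns, as the docstring requires.

import numpy as np

# Hypothetical toy inputs: 4 points with binary labels, k = 2 neighbors each.
y = np.array([0, 0, 1, 1])
k = 2
neighbor_indices = np.array([
    [1, 2],  # neighbors of point 0, nearest first
    [0, 2],  # neighbors of point 1
    [3, 0],  # neighbors of point 2
    [2, 1],  # neighbors of point 3
])

raw = _knn_shapley_score(neighbor_indices=neighbor_indices, y=y, k=k)
# The caller (data_shapley_knn) rescales the raw per-point means into [0, 1]:
print(0.5 * (raw + 1))

Each inner-loop step shifts neighbor j's credit by +1, 0, or -1 relative to neighbor j + 1, depending on whether their label-match statuses differ; dividing by k and averaging over all query points then yields the raw score.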
Expand Down Expand Up @@ -91,7 +114,7 @@ def data_shapley_knn(
An array of transformed Data Shapley values for each data point, calibrated to indicate their relative importance.
These scores have been adjusted to fall within 0 to 1.
Values closer to 1 indicate data points that are highly influential and positively contribute to a trained ML model's performance.
Conversely, scores below 0.5 indicate data points estimated to negatively impact model performance.
Conversely, scores below 0.5 indicate data points estimated to negatively impact model performance.

Raises
------
@@ -113,4 +136,8 @@
     # Use provided knn_graph or compute it from features
     if knn_graph is None:
         knn_graph, _ = create_knn_graph_and_index(features, n_neighbors=k, metric=metric)
-    return _knn_shapley_score(knn_graph, labels, k)
+
+    num_examples = labels.shape[0]
+    distances = knn_graph.indices.reshape(num_examples, -1)
+    scores = _knn_shapley_score(neighbor_indices=distances, y=labels, k=k)
+    return 0.5 * (scores + 1)
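For trying the new code path end to end, here is a rough usage sketch (again mine, not from the PR); it assumes `data_shapley_knn` is importable from `cleanlab.data_valuation` and accepts `labels`, `features`, and `k` as in the signature above, and the two-cluster feature matrix is synthetic.

import numpy as np
from cleanlab.data_valuation import data_shapley_knn

# Synthetic data (hypothetical): two separated clusters, one flipped label.
rng = np.random.default_rng(0)
features = np.vstack([rng.normal(0, 1, (50, 2)), rng.normal(5, 1, (50, 2))])
labels = np.repeat([0, 1], 50)
labels[0] = 1  # mislabel one point; it should receive a low score

scores = data_shapley_knn(labels, features=features, k=10)
# After the 0.5 * (scores + 1) rescaling, scores lie in [0, 1], and values
# below 0.5 flag points estimated to hurt model performance.
print(scores[0], scores.mean())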
29 changes: 16 additions & 13 deletions docs/source/tutorials/datalab/workflows.ipynb
@@ -325,7 +325,11 @@
    "metadata": {},
    "source": [
     "### 4. (Optional) Visualize Data Valuation Scores\n",
-    "Finally, we will visualize the data valuation scores using a histogram to understand the distribution of scores across different labels."
+    "Finally, we will visualize the data valuation scores using a strip plot to understand the distribution across different labels.\n",
+    "\n",
+    "A score below 0.5 indicates a negative contribution to the model's training performance, while a score above 0.5 indicates a positive contribution.\n",
+    "\n",
+    "By examining the scores across different labels, we can identify whether positive or negative contributions are disproportionately concentrated in a single class. This can help us detect potential biases in the training data."
    ]
   },
   {
@@ -337,29 +341,28 @@
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Prepare the data for plotting a histogram\n",
"# Prepare the data for plotting\n",
"plot_data = (\n",
" data_valuation_issues\n",
" # Optionally, add a 'given_label' column to distinguish between labels in the histogram\n",
" .join(pd.DataFrame({\"given_label\": df_text[\"Label\"]}))\n",
")\n",
"\n",
"# Plot histograms of data valuation scores for each label\n",
"sns.histplot(\n",
"# Plot strip plots of data valuation scores for each label\n",
"sns.stripplot(\n",
" data=plot_data,\n",
" hue=\"given_label\", # Comment out if no labels should be used in the visualization\n",
" x=\"data_valuation_score\",\n",
" bins=15,\n",
" element=\"step\",\n",
" multiple=\"stack\", # Stack histograms for different labels\n",
" hue=\"given_label\", # Comment out if no labels should be used in the visualization\n",
" dodge=True,\n",
" jitter=0.3,\n",
" alpha=0.5,\n",
")\n",
"\n",
"# Set y-axis to a logarithmic scale for better visualization of wide-ranging counts\n",
"plt.yscale(\"log\")\n",
"plt.yscale(\"log\")\n",
"plt.title(\"Data Valuation Scores by Label\")\n",
"plt.axvline(lab.info[\"data_valuation\"][\"threshold\"], color=\"red\", linestyle=\"--\", label=\"Issue Threshold\")\n",
"\n",
"plt.title(\"Strip plot of Data Valuation Scores by Label\")\n",
"plt.xlabel(\"Data Valuation Score\")\n",
"plt.ylabel(\"Count (log scale)\")\n",
"plt.legend()\n",
"plt.show()"
]
},
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -9,7 +9,7 @@ name = "cleanlab"
 # requirements files see:
 # https://packaging.python.org/en/latest/discussions/install-requires-vs-requirements/
 dependencies = [
-    "numpy>=1.22.0",
+    "numpy~=1.22",
"scikit-learn>=1.1",
"tqdm>=4.53.0",
"pandas>=1.4.0",