Adding type hints for cleanlab/filter #598
base: master
```diff
@@ -221,7 +221,7 @@ class 0, 1, ..., K-1.
         "confident_learning",
         "predicted_neq_given",
     ]  # TODO: change default to confident_learning ?
-    allow_one_class = False
+    allow_one_class: bool = False
     if isinstance(labels, np.ndarray) or all(isinstance(lab, int) for lab in labels):
         if set(labels) == {0}:  # occurs with missing classes in multi-label settings
             allow_one_class = True
```
```diff
@@ -276,13 +276,13 @@ class 0, 1, ..., K-1.
     )

     # Else this is standard multi-class classification
-    K = get_num_classes(
+    K: int = get_num_classes(
         labels=labels, pred_probs=pred_probs, label_matrix=confident_joint, multi_label=multi_label
     )
     # Number of examples in each class of labels
-    label_counts = value_counts_fill_missing_classes(labels, K, multi_label=multi_label)
+    label_counts: npt.NDArray[np.int_] = value_counts_fill_missing_classes(labels, K, multi_label=multi_label)
     # Boolean set to true if dataset is large
-    big_dataset = K * len(labels) > 1e8
+    big_dataset: bool = K * len(labels) > 1e8
     # Ensure labels are of type np.ndarray()
     labels = np.asarray(labels)
     if confident_joint is None or filter_by == "confident_learning":
```
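For context on how annotations like the ones above behave, here is a minimal runnable sketch. The `count_per_class` helper is a made-up stand-in (not cleanlab's API); the point is that `npt.NDArray[np.int_]` and the other hints describe dtypes for static checkers such as mypy and are not enforced at runtime:

```python
import numpy as np
import numpy.typing as npt

def count_per_class(labels: npt.NDArray[np.int_], num_classes: int) -> npt.NDArray[np.int_]:
    """Count how many examples fall into each of the classes 0..num_classes-1."""
    return np.bincount(labels, minlength=num_classes)

labels = np.array([0, 0, 1, 2, 2, 2])
K: int = int(labels.max()) + 1  # number of classes
label_counts: npt.NDArray[np.int_] = count_per_class(labels, K)
big_dataset: bool = K * len(labels) > 1e8  # same size heuristic as in the diff
print(label_counts)   # [2 1 3]
print(big_dataset)    # False
```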
```diff
@@ -298,16 +298,16 @@ class 0, 1, ..., K-1.
     # Create `prune_count_matrix` with the number of examples to remove in each class and
     # leave at least min_examples_per_class examples per class.
     # `prune_count_matrix` is transposed relative to the confident_joint.
-    prune_count_matrix = _keep_at_least_n_per_class(
+    prune_count_matrix: npt.NDArray[np.int_] = _keep_at_least_n_per_class(
         prune_count_matrix=confident_joint.T,
         n=min_examples_per_class,
         frac_noise=frac_noise,
     )

     if num_to_remove_per_class is not None:
         # Estimate joint probability distribution over label issues
-        psy = prune_count_matrix / np.sum(prune_count_matrix, axis=1)
-        noise_per_s = psy.sum(axis=1) - psy.diagonal()
+        psy: npt.NDArray["np.floating[T]"] = prune_count_matrix / np.sum(prune_count_matrix, axis=1)
+        noise_per_s: npt.NDArray["np.floating[T]"] = psy.sum(axis=1) - psy.diagonal()
         # Calibrate labels s.t. noise rates sum to num_to_remove_per_class
         tmp = (psy.T * num_to_remove_per_class / noise_per_s).T
         np.fill_diagonal(tmp, label_counts - num_to_remove_per_class)
```
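A toy illustration of why `psy` ends up annotated as a floating-point array: NumPy's true division promotes integer counts to float. The 2x2 `prune_count_matrix` below is made up for this example, not taken from the PR:

```python
import numpy as np

prune_count_matrix = np.array([[8, 2],
                               [1, 9]])  # integer counts
# True division promotes int -> float64, matching the floating-dtype hint
psy = prune_count_matrix / np.sum(prune_count_matrix, axis=1)
# Off-diagonal mass per class: row sum minus the diagonal entry
noise_per_s = psy.sum(axis=1) - psy.diagonal()
print(psy.dtype)  # float64
print(noise_per_s)
```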
```diff
@@ -427,18 +427,19 @@ class 0, 1, ..., K-1.


 def _find_label_issues_multilabel(
-    labels: list,
-    pred_probs: np.ndarray,
+    labels: LabelLike,
+    pred_probs: npt.NDArray["np.floating[T]"],
     return_indices_ranked_by: Optional[str] = None,
-    rank_by_kwargs={},
+    rank_by_kwargs: Dict = {},
     filter_by: str = "prune_by_noise_rate",
     frac_noise: float = 1.0,
     num_to_remove_per_class: Optional[int] = None,
-    min_examples_per_class=1,
+    min_examples_per_class: int = 1,
     confident_joint: Optional[np.ndarray] = None,
     n_jobs: Optional[int] = None,
     verbose: bool = False,
-) -> np.ndarray:
+) -> npt.NDArray[Union[np.bool_, np.int_]]:
+    # TODO: add docstring in the same format as other functions with args and returns
```
Review thread on the added TODO comment:

> Can I add to-dos wherever the doc doesn't have the same format as the other functions?

> Are you referring to the "shape"-format?

> The format of the doc itself.

> Ohh right, now I see what you mean! Thank you for pointing this out! That's not related to this PR, so you shouldn't add this particular to-do in this PR. We want to address these kinds of docstrings in a separate PR.
```diff
     """
     Finds label issues in multi-label classification data where each example can belong to more than one class.
     This is done via a one-vs-rest reduction for each class and the results are subsequently aggregated across all classes.
```
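The one-vs-rest reduction the docstring describes can be sketched as follows. Both `find_binary_issues` and its thresholds are hypothetical stand-ins for illustration, not cleanlab's actual per-class routine; the structure (binary problem per class, then aggregation across classes) is the point:

```python
import numpy as np

def find_binary_issues(binary_labels, pred_probs_class):
    # Hypothetical per-class rule: flag examples whose predicted probability
    # strongly disagrees with the given binary label (thresholds made up).
    return ((binary_labels == 1) & (pred_probs_class < 0.2)) | (
        (binary_labels == 0) & (pred_probs_class > 0.8)
    )

def find_issues_one_vs_rest(labels_onehot, pred_probs):
    # One-vs-rest: run the binary finder per class, then aggregate the
    # per-class boolean masks with a logical OR across classes.
    num_classes = pred_probs.shape[1]
    per_class = [
        find_binary_issues(labels_onehot[:, k], pred_probs[:, k])
        for k in range(num_classes)
    ]
    return np.logical_or.reduce(per_class)

labels_onehot = np.array([[1, 0], [0, 1], [1, 1]])
pred_probs = np.array([[0.9, 0.1], [0.1, 0.9], [0.1, 0.9]])
issues = find_issues_one_vs_rest(labels_onehot, pred_probs)
print(issues)  # only the last example is flagged
```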
Review thread on the `prune_count_matrix` annotation:

> The prune_count_matrix changes the data type (from an array of int to float). Should I declare it as a float? Although the function here returns int explicitly.

> I'll have to take a better look later today.

> @elisno did you have a chance to look at this?

> I don't see where the datatype should change in this function. I think this annotation is fine for now.
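The dtype point in this thread can be reproduced in a few lines (a standalone demonstration, not cleanlab code): NumPy true division always promotes integer arrays to float, so a function that wants to return integers has to cast back explicitly, as the reviewer notes the helper does:

```python
import numpy as np

counts = np.array([[3, 1], [2, 4]])        # integer dtype
scaled = counts / counts.sum(axis=1)       # true division promotes to float64
restored = scaled.round().astype(np.int_)  # an explicit cast restores integers
print(scaled.dtype)  # float64
print(restored.dtype)
```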