Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add anomaly handler #1780

Merged
merged 5 commits into from
Jun 17, 2024
Merged

Add anomaly handler #1780

merged 5 commits into from
Jun 17, 2024

Conversation

lzhangzz
Copy link
Collaborator

@lzhangzz lzhangzz commented Jun 14, 2024

Detecting and suppressing NaN/INF for debugging and robustness

  1. Detect, suppress and report INF/NaN in the tensors
  2. Fix invalid logits and report errors

USAGE

opt-in by setting environment variable TM_ANOMALY_HANDLER=args...

ARGS

  • level - default: 0
    • 0 - off
    • 1 - embedding/lm_head/logits
    • 2 - plus rmsnorm/residual/ffn_block/attn_block
    • 3 - plus all other kernel outputs
  • nan - value used to replace NaNs, default: NaN
  • inf - value used to replace INF, default: INF
  • fallback - fallback token when there are INFs or NaNs in the logits, default: eos_id

For example TM_ANOMALY_HANDLER=level=3,nan=0,inf=0 will

  1. Flush all NaN/INF to 0 for all kernel outputs and count the numer of anomalies
  2. When NaN/INF detected in the logits, all logits for the sample will be set to 0 with the exception that the fallback token will be set to MAX_HALF, an error will be set for the request
  3. Summary of detected anomalies will be logged at WARNING level after each iteration

NOTE

  • Level 1 is enough to suppress crashes caused by NaN/INF but cannot save the corrupted token
  • Level 2/3 with proper NaN/INF replacement may suppress sporadic INFs and allow the generation to continue smoothly
  • Level 2/3 hurts performance as the launched kernels are doubled and the kernel for handling anomalies is not optimized
  • Try level 3 first then pick the suitable level based on the printed summary

auto x = static_cast<float>(data[i]);
if (isinf(x)) {
++inf_count;
data[i] = x > 0.f ? pinf_val : ninf_val;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering what pinf_val, ninf_val and nan_val are appropriate.

@lvhan028 lvhan028 merged commit 5cbefe2 into InternLM:main Jun 17, 2024
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants