Investigate possible memory leak with sample packing #1097
Processing 50k samples (~2.2% of the entire dataset) consumes 41 GB of memory; processing 100k samples (~4.3% of the entire dataset) consumes 81 GB. Memory appears to scale linearly at roughly 0.8 MB per sample, so packing the full ~2.3M-sample dataset would consume ~1886 GB of memory.
Using the following code, I'm calculating the size of the packs directly:
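A rough sketch of that kind of measurement, assuming each pack is a dict of tensors (the key handling here is an assumption, not the exact code used):

```python
import torch


def pack_size_bytes(pack: dict) -> dict:
    """Return the in-memory size, in bytes, of each tensor in a pack (sketch)."""
    return {
        key: value.numel() * value.element_size()
        for key, value in pack.items()
        if torch.is_tensor(value)
    }
```

For scale: at an assumed max_seq_len of 8192, a single [8192 x 8192] float32 mask is 8192^2 * 4 B = 256 MiB, which dwarfs the tokens and labels in the same pack.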
Here's how it plays out for every pack: the mask is HUGE - we need to change how we calculate it.
See comment - the solution seems to be to calculate the mask in each batch: https://github.com/pytorch/torchtune/pull/1083/files#diff-d069dcb7c967a8f3aaaca697349854143ff7b0631f885a6e7bd62dda19f6eee1R169 We could also technically do this for …
Thought I might chip in if it's helpful. I have a little helper function for doing this for each batch in my PPO fork:

```python
import torch


def get_causal_mask(
    tokens: torch.Tensor,
    padding_mask: torch.Tensor,
    *,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Generates a causal attention mask for the given tokens and padding mask suitable for
    consumption by :func:`~torch.nn.functional.scaled_dot_product_attention` where the
    mask is added to the attention score.

    HF uses a similar implementation internally, see
    https://github.com/huggingface/transformers/blob/a564d10afe1a78c31934f0492422700f61a0ffc0/src/transformers/models/mistral/modeling_mistral.py#L1096

    Args:
        tokens (torch.Tensor): tensor of token IDs with shape [bsz x seq_length]
        padding_mask (torch.Tensor): tensor of padding ID masks with shape [bsz x seq_length]
        dtype (torch.dtype): dtype to infer fill value for masking

    Returns:
        torch.Tensor: Causal mask with shape [bsz x seq_length x seq_length]
    """
    fill_value = torch.finfo(dtype).min
    # additive causal mask: fill everything above the diagonal with -inf
    mask = torch.triu(
        torch.full((tokens.shape[-1], tokens.shape[-1]), fill_value), diagonal=1
    ).to(tokens.device, dtype=dtype)
    # also mask out padded positions for every query position
    mask = mask.masked_fill(
        padding_mask.unsqueeze(1).expand(-1, tokens.shape[1], tokens.shape[1]),
        fill_value,
    )
    return mask
```
This is inspired by how HF does things to keep memory usage low: only creating the custom mask when it's actually needed.

```python
# create position IDs and causal masks for the current trajectory
padding_masks = query_responses == self._tokenizer.pad_id
# we only need custom causal masks for sequences with left-padding
if padding_masks.any():
    masks = ppo_utils.get_causal_mask(
        query_responses,
        padding_mask=padding_masks,
        dtype=self._dtype,
    )
    position_ids = (~padding_masks).cumsum(-1) - (~padding_masks).long()
    position_ids = position_ids.to(device=self._device, dtype=torch.int)
else:
    # defer to SDPA to handle causal masks
    masks, position_ids = None, None
```
After discussing on Discord: this needs to be updated so the attn masks are bool, not floats, but the idea is the same.
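A minimal sketch of that boolean variant (the helper name is hypothetical; SDPA treats a bool mask as True = "attend"):

```python
import torch


def get_causal_mask_bool(padding_mask: torch.Tensor) -> torch.Tensor:
    """Boolean causal mask with shape [bsz, seq_len, seq_len]: True where attention is allowed.

    Assumes padding_mask is [bsz, seq_len] and True at padded positions.
    """
    bsz, seq_len = padding_mask.shape
    # lower-triangular causal structure: position i may attend to positions j <= i
    causal = torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=padding_mask.device)
    )
    # additionally disallow attending to padded positions
    return causal.unsqueeze(0) & ~padding_mask.unsqueeze(1)
```

A bool mask costs 1 byte per element instead of 4 for float32, a 4x saving on top of only building the mask per batch.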
Following the gist below, a dataset of size ~3 GB with 2M entries takes up more than 900 GB of CPU memory while doing offline sample packing.
https://gist.github.com/joecummings/642eaa9ce539ad93360ee3f999dbcfa3
Need to get to the bottom of this.
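One way to start narrowing it down (a sketch, not part of the gist; assumes psutil is installed) is to log resident memory periodically inside the packing loop and check whether growth tracks the number of packs:

```python
import os

import psutil

_proc = psutil.Process(os.getpid())


def log_rss(step: int) -> None:
    """Print the current resident set size; call every N packed samples."""
    rss_gb = _proc.memory_info().rss / 1024**3
    print(f"step {step}: RSS = {rss_gb:.2f} GB")
```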