
Investigate possible memory leak with sample packing #1097

Closed
joecummings opened this issue Jun 17, 2024 · 5 comments · Fixed by #1109
@joecummings
Contributor

Following the gist below, a dataset of size ~3 GB with 2M entries takes up more than 900 GB of CPU memory while doing offline sample packing.

https://gist.github.com/joecummings/642eaa9ce539ad93360ee3f999dbcfa3

Need to get to the bottom of this.

@joecummings joecummings self-assigned this Jun 17, 2024
@joecummings joecummings added the bug Something isn't working label Jun 17, 2024
@joecummings
Contributor Author

Processing 50k samples (~2.2% of the entire dataset) consumes 41 GB of memory.

Processing 100k samples (~4.3% of the entire dataset) consumes 81 GB of memory.

Memory appears to scale linearly, so processing the entire dataset would consume roughly 1886 GB.
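For reference, here's a quick back-of-the-envelope check of that extrapolation (pure arithmetic; the ~2.33M total sample count is assumed from 100k being ~4.3% of the dataset):

measured_gb = {50_000: 41, 100_000: 81}           # samples processed -> GB of CPU memory
total_samples = 100_000 / 0.043                   # assumed ~2.33M samples in the full dataset

gb_per_sample = measured_gb[100_000] / 100_000    # ~0.81 MB per sample
print(f"{gb_per_sample * total_samples:.0f} GB")  # ~1884 GB, consistent with the estimate above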

@joecummings
Contributor Author

joecummings commented Jun 17, 2024

Using the following code, I'm calculating the size of the packs directly:

import sys

packs = [
    {"tokens": ..., "mask": ..., "labels": ..., "input_pos": ...},
    {"tokens": ..., "mask": ..., "labels": ..., "input_pos": ...},
    ...
]


def mem_packs(packs):
    total_size = 0
    for pack in packs:
        for elem in pack.values():
            # Python object overhead (same for every tensor) + bytes of the tensor's data
            total_size += sys.getsizeof(elem) + elem.element_size() * elem.nelement()
    return f"{round(total_size / (1024**3), 2)}G"

Here's how it plays out for every pack:

tokens: 88 (fixed size of Python object) + 8 (size of torch.int64) * 3072 (max seq len)
labels: 88 + 8  * 3072
mask: 88 + 1 (size of torch.bool) * 3072 * 3072 (max seq len x max seq len)
input_pos: 88 + 8 * 3072
------------------------------
9511264 bytes ~= 9.5 MB

The mask is huge - we need to change how we calculate it.
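As a sanity check of that per-pack math, here's a hypothetical dummy pack at max_seq_len = 3072 run through mem_packs from above (dummy_pack and its contents are illustrative, not from the gist):

import sys
import torch

max_seq_len = 3072
dummy_pack = {
    "tokens": torch.zeros(max_seq_len, dtype=torch.int64),
    "labels": torch.zeros(max_seq_len, dtype=torch.int64),
    "mask": torch.zeros(max_seq_len, max_seq_len, dtype=torch.bool),
    "input_pos": torch.zeros(max_seq_len, dtype=torch.int64),
}

per_pack_bytes = sum(
    sys.getsizeof(t) + t.element_size() * t.nelement() for t in dummy_pack.values()
)
print(per_pack_bytes)  # ~9.5e6 bytes, almost all of it the bool mask

# The list below just repeats references to one pack, so it's cheap to build,
# but mem_packs reports what 1000 real packs would occupy (~9 GB).
print(mem_packs([dummy_pack] * 1000))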

@joecummings
Contributor Author

See the linked comment - the solution seems to be to calculate the mask in each batch instead: https://github.com/pytorch/torchtune/pull/1083/files#diff-d069dcb7c967a8f3aaaca697349854143ff7b0631f885a6e7bd62dda19f6eee1R169

We could technically do the same for input_pos. This would slow down processing but save memory:

tokens: 88 (fixed size of Python object) + 8 (size of torch.int64) * 3072 (max seq len)
labels: 88 + 8  * 3072
----------------------
49328 bytes ~= 0.05 MB
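As a rough sketch of that idea (not the actual PR), each pack could store just the per-sample sequence lengths, and the block-causal mask plus input_pos could be rebuilt at batch time. build_mask_and_pos below is a hypothetical helper, assuming each packed sample should only attend within itself:

import torch


def build_mask_and_pos(seq_lens, max_seq_len):
    """Hypothetical helper: rebuild a pack's block-causal mask and input_pos
    from its per-sample sequence lengths at batch time."""
    mask = torch.zeros(max_seq_len, max_seq_len, dtype=torch.bool)
    input_pos = torch.zeros(max_seq_len, dtype=torch.long)
    offset = 0
    for n in seq_lens:
        # Each sample gets a causal (lower-triangular) block on the diagonal: True = attend.
        mask[offset : offset + n, offset : offset + n] = torch.tril(
            torch.ones(n, n, dtype=torch.bool)
        )
        # Positions restart at 0 for every sample in the pack.
        input_pos[offset : offset + n] = torch.arange(n)
        offset += n
    return mask, input_pos


mask, input_pos = build_mask_and_pos([5, 3], max_seq_len=16)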

@joecummings joecummings removed the bug Something isn't working label Jun 18, 2024
@SalmanMohammadi
Contributor

Thought I might chip in if it's helpful. I have a little helper function for doing this for each batch in my PPO fork:

import torch


def get_causal_mask(
    tokens: torch.Tensor,
    padding_mask: torch.Tensor,
    *,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Generates a causal attention mask for the given tokens and padding mask suitable for
    consumption by :func:`~torch.nn.functional.scaled_dot_product_attention` where the
    mask is added to the attention score.

    HF uses a similar implementation internally, see
    https://github.com/huggingface/transformers/blob/a564d10afe1a78c31934f0492422700f61a0ffc0/src/transformers/models/mistral/modeling_mistral.py#L1096

    Args:
        tokens (torch.Tensor): tensor of token IDs with shape [bsz x seq_length]
        padding_mask (torch.Tensor): tensor of padding ID masks with shape [bsz x seq_length]
        dtype (torch.dtype): dtype to infer fill value for masking

    Returns:
        torch.Tensor: Causal mask with shape [bsz x seq_length x seq_length]
    """
    fill_value = torch.finfo(dtype).min
    # Positions strictly above the diagonal get the large negative fill value.
    mask = torch.triu(
        torch.full((tokens.shape[-1], tokens.shape[-1]), fill_value), diagonal=1
    ).to(tokens.device, dtype=dtype)
    # Additionally mask out padded key positions for every query position.
    mask = mask.masked_fill(
        padding_mask.unsqueeze(1).expand(-1, tokens.shape[1], tokens.shape[1]),
        fill_value,
    )
    return mask

This is inspired by how HF keeps memory usage low: only the padding_mask (shape [bsz, seq_len]) is created during dataset sampling, and it is expanded on the fly to [bsz, seq_len, seq_len] right before being passed to SDPA. Here's an example in context which also generates position IDs.

# create position IDs and causal masks for the current trajectory
padding_masks = query_responses == self._tokenizer.pad_id
# we only need custom causal masks for sequences with left-padding
if padding_masks.any():
    masks = ppo_utils.get_causal_mask(
        query_responses,
        padding_mask=padding_masks,
        dtype=self._dtype,
    )
    position_ids = (~padding_masks).cumsum(-1) - (~padding_masks).long()
    position_ids = position_ids.to(device=self._device, dtype=torch.int)
else:
    # defer SDPA to handle causal masks
    masks, position_ids = None, None
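A quick shape check of the helper with a dummy left-padded batch, just to show where the cheap [bsz, seq_len] padding mask becomes a [bsz, seq_len, seq_len] attention mask (the values here are made up):

bsz, seq_len, pad_id = 2, 8, 0
tokens = torch.randint(1, 100, (bsz, seq_len))
tokens[0, :3] = pad_id            # left-pad the first sequence
padding_masks = tokens == pad_id  # [bsz, seq_len] - this is all the dataset needs to store

masks = get_causal_mask(tokens, padding_mask=padding_masks)
print(masks.shape)                # torch.Size([2, 8, 8]), built only at batch time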

@SalmanMohammadi
Contributor

SalmanMohammadi commented Jun 18, 2024

After discussing on Discord: this needs to be updated so the attention masks are bool, not floats, but the idea is the same.
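For reference, a minimal sketch of what the boolean variant could look like, assuming SDPA's convention that True means the position may be attended to (get_causal_mask_bool is a hypothetical name, not the final helper):

import torch


def get_causal_mask_bool(tokens: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
    """Bool variant of the helper above: True = attend, False = masked out."""
    seq_len = tokens.shape[-1]
    # Lower-triangular causal mask, shared across the batch.
    causal = torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=tokens.device)
    )
    # Also block out padded key positions for every query position.
    return causal.unsqueeze(0) & ~padding_mask.unsqueeze(1)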
