build_cls_mask() in CoCa TextTransformer #549

Open
yiren-jian opened this issue Jun 24, 2023 · 2 comments

Comments

@yiren-jian

TL;DR: the current implementation of build_cls_mask() produces a cls_mask that treats [CLS] as the first token, but in CoCa the [CLS] token is appended at the end of the sequence.

In issue #312, build_cls_mask() was introduced by @gpucce in the CoCa TextTransformer for "preventing the CLS token at the end of the sequence from attending to padded tokens".

# https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py#L587
def build_cls_mask(self, text, cast_dtype: torch.dtype):
    cls_mask = (text != self.pad_id).unsqueeze(1)
    cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
    additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
    additive_mask.fill_(0)
    additive_mask.masked_fill_(~cls_mask, float("-inf"))
    additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
    return additive_mask
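
For reference, F.pad's 4-element padding tuple applies to the last two dimensions as (left, right, top, bottom), so (1, 0, cls_mask.shape[2], 0) prepends the extra always-attend column at index 0 and stacks the extra all-True rows on top. A quick standalone check (plain torch, illustrative only):

import torch
import torch.nn.functional as F

m = torch.ones(1, 1, 3, dtype=torch.bool)         # one row of three True values
print(F.pad(m, (1, 0, 2, 0), value=True).shape)   # torch.Size([1, 3, 4]): 2 rows added on top, 1 column on the left
print(F.pad(m, (1, 0), value=False))              # tensor([[[False, True, True, True]]]): the new column lands at index 0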

Taking text = torch.tensor([[1,2,3,4,0,0,0]]) as an example:

import torch
import torch.nn.functional as F

text = torch.tensor([[1,2,3,4,0,0,0]])  ### batch size 1, 4 real tokens followed by 3 pad tokens (pad_id=0)

pad_id = 0
cls_mask = (text != pad_id).unsqueeze(1)
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
additive_mask = torch.empty(cls_mask.shape)
additive_mask.fill_(0)
additive_mask.masked_fill_(~cls_mask, float("-inf"))
print(additive_mask)

This outputs:

tensor([[[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., -inf, -inf, -inf]]])
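
Reading that last row against the actual token layout (after the [CLS] embedding is concatenated, the sequence is [tok1, tok2, tok3, tok4, pad, pad, pad, CLS]): key index 7 is the [CLS] position but gets -inf, while key index 4 is a pad token but stays attendable. A quick standalone check of those two entries:

import torch
import torch.nn.functional as F

text = torch.tensor([[1, 2, 3, 4, 0, 0, 0]])
pad_id = 0

cls_mask = (text != pad_id).unsqueeze(1)
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
additive_mask = torch.zeros(cls_mask.shape)
additive_mask.masked_fill_(~cls_mask, float("-inf"))

row = additive_mask[0, -1]   # the row added to the [CLS] query's attention scores
print(row[7], row[4])        # tensor(-inf) tensor(0.) -> [CLS] key masked, pad key allowed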

In @lucidrains' implementation:

# https://github.com/lucidrains/CoCa-pytorch/blob/main/coca_pytorch/coca_pytorch.py#L384-L385
cls_mask = rearrange(text!=self.pad_id, 'b j -> b 1 j')
attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)

Taking the same text as an example:

import torch
import torch.nn.functional as F
from einops import rearrange

text = torch.tensor([[1,2,3,4,0,0,0]])

pad_id = 0
seq = text.shape[1]

cls_mask = rearrange(text!=pad_id, 'b j -> b 1 j')
attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)
print(attn_mask)

it produces what I believe is the desired outcome:

tensor([[[ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True, False, False, False,  True]]])

Since the [CLS] token is appended at the end of the sequence,

# https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py#L607
x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)

I feel that the current implementation in open_clip is wrong. Am I missing anything?
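
For illustration only, here is a minimal standalone sketch of the kind of adjustment I would expect (appending the extra always-attendable column on the right, where [CLS] actually sits). This is just to show the expected mask and is not taken from the open_clip code:

import torch
import torch.nn.functional as F

text = torch.tensor([[1, 2, 3, 4, 0, 0, 0]])
pad_id = 0

cls_mask = (text != pad_id).unsqueeze(1)
# append the always-attend column on the right (index seq), matching the appended [CLS] position
cls_mask = F.pad(cls_mask, (0, 1, cls_mask.shape[2], 0), value=True)

additive_mask = torch.zeros(cls_mask.shape)
additive_mask.masked_fill_(~cls_mask, float("-inf"))
print(additive_mask[0, -1])
# tensor([0., 0., 0., 0., -inf, -inf, -inf, 0.]) -> pads masked, [CLS] slot kept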

@Mypathissional

Yes, I also feel that the current implementation is wrong. Can someone give an update on this?

@gpucce
Contributor

gpucce commented Sep 20, 2023

Hi, there is PR #551 to fix this, but I think nobody has time to review it.
