Add video support #430

Draft: iejMac wants to merge 28 commits into main from vivit

Commits (28)
3ec0ed8
Add video support
iejMac Feb 15, 2023
0403d1d
Merge branch 'main' into vivit
iejMac Feb 16, 2023
1cd33e4
data loading: correct shapes in training loop (crappy code)
iejMac Feb 18, 2023
107292f
Merge branch 'vivit' of https://github.com/iejMac/open_clip into vivit
iejMac Feb 18, 2023
b995750
update model progress
iejMac Feb 18, 2023
be04c06
rename file + create_model loads something
iejMac Feb 19, 2023
0ad7168
update
iejMac Feb 19, 2023
f9dfd02
update
iejMac Feb 19, 2023
df1c698
embeddings get to loss, time to implement video encoding
iejMac Feb 19, 2023
7643ce3
update, set num_samples
iejMac Feb 19, 2023
cb12acd
more filling in
iejMac Feb 19, 2023
67a2d33
slightly improved preprocessing
iejMac Feb 19, 2023
f0615f9
update
iejMac Feb 19, 2023
7bd848e
update weird lag
iejMac Feb 19, 2023
80f41a0
simpler dataloader same results
iejMac Feb 20, 2023
f5af600
properly normalize frames
iejMac Feb 21, 2023
982ee88
update no temporal
iejMac Feb 21, 2023
3f4fde7
filter no mp4 samples
iejMac Feb 26, 2023
7176d95
update
iejMac Feb 26, 2023
3f64b62
adding projection removes weird const loss bug but training doesn't g…
iejMac Feb 26, 2023
8d3cc48
some updates
iejMac Feb 27, 2023
4230428
save changes
iejMac Mar 11, 2023
9413c5e
update dataloader to use video2dataset
iejMac Mar 26, 2023
5c65a52
update
iejMac Mar 27, 2023
f2fa5bd
repeat is bad
iejMac Mar 28, 2023
3125171
enable loading CLIP weights to spatial and text encoders
iejMac Apr 4, 2023
b093a4c
update
iejMac Apr 15, 2023
9db7425
update
Jun 9, 2023
77 changes: 59 additions & 18 deletions src/open_clip/factory.py
@@ -8,15 +8,17 @@
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn.functional as F

from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
resize_pos_embed, get_cast_dtype
from .coca_model import CoCa
from .video_model import VideoCLIP # TODO: change once full model is implemented
from .loss import ClipLoss, DistillClipLoss, CoCaLoss
from .openai import load_openai_model
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model, download_pretrained_from_hf
from .transform import image_transform, AugmentationCfg
from .transform import image_transform, video_transform, AugmentationCfg
from .tokenizer import HFTokenizer, tokenize


@@ -100,7 +102,18 @@ def load_checkpoint(model, checkpoint_path, strict=True):
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
state_dict = convert_to_custom_text_state_dict(state_dict)
resize_pos_embed(state_dict, model)
incompatible_keys = model.load_state_dict(state_dict, strict=strict)

incompatible_keys = []
# TODO: find better way of doing this
if isinstance(model, VideoCLIP):
text_state_dict = {k[len("text."):]: v for k, v in state_dict.items() if k.startswith("text.")}
visual_state_dict = {k[len("visual."):]: v for k, v in state_dict.items() if k.startswith("visual.")}

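# NOTE: Module.load_state_dict returns an _IncompatibleKeys namedtuple of
# (missing_keys, unexpected_keys), so `+=` extends the list with those two
# sublists rather than collecting one result object per call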
incompatible_keys += model.text.load_state_dict(text_state_dict, strict=strict)
incompatible_keys += model.visual.spatial.load_state_dict(visual_state_dict, strict=strict)
else:
incompatible_keys = model.load_state_dict(state_dict, strict=strict)

return incompatible_keys


@@ -191,11 +204,20 @@ def create_model(
else:
model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
if "ViViT" in model_name: # TODO better way of detecting video configs
model = VideoCLIP(**model_cfg)
else:
model = CLIP(**model_cfg, cast_dtype=cast_dtype)

pretrained_loaded = False
if pretrained:
checkpoint_path = ''

# TODO: not sure how to initialize components nicely
# idea for now: model_name:pretrained
if ":" in pretrained:
model_name, pretrained = pretrained.split(":")

pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
if pretrained_cfg:
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
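With the colon convention sketched above, initializing the spatial and text towers from an existing CLIP checkpoint would look roughly like this (the tag is illustrative, not a tested combination):

model = create_model("ViViT-B-32", pretrained="ViT-B-32:laion2b_s34b_b79k")
# splits into model_name="ViT-B-32", pretrained="laion2b_s34b_b79k" for the
# pretrained-config lookup; load_checkpoint then routes the CLIP weights into
# model.text and model.visual.spatial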
@@ -305,21 +327,40 @@ def create_model_and_transforms(
output_dict=output_dict,
)

image_mean = image_mean or getattr(model.visual, 'image_mean', None)
image_std = image_std or getattr(model.visual, 'image_std', None)
preprocess_train = image_transform(
model.visual.image_size,
is_train=True,
mean=image_mean,
std=image_std,
aug_cfg=aug_cfg,
)
preprocess_val = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std,
)
# TODO: better way of getting modality specific transforms
if "ViViT" in model_name:
preprocess_train = video_transform(
frame_size=model.visual.spatial.image_size,
n_frames=model.visual.context_length,
take_every_nth=5,
is_train=False, # TODO: figure out if frame augmentations make sense
frame_mean=None,
frame_std=None,
)
preprocess_val = video_transform(
frame_size=model.visual.spatial.image_size,
n_frames=model.visual.context_length,
take_every_nth=5,
is_train=False,
frame_mean=None,
frame_std=None,
)
else:
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
image_std = image_std or getattr(model.visual, 'image_std', None)
preprocess_train = image_transform(
model.visual.image_size,
is_train=True,
mean=image_mean,
std=image_std,
aug_cfg=aug_cfg,
)
preprocess_val = image_transform(
model.visual.image_size,
is_train=False,
mean=image_mean,
std=image_std,
)

return model, preprocess_train, preprocess_val
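For the video path, a minimal usage sketch (any model name containing "ViViT" takes the video branch above):

model, preprocess_train, preprocess_val = create_model_and_transforms("ViViT-B-32")
# preprocess_val is the closure returned by video_transform (see transform.py
# below): it takes a (video, audio, meta) sample and returns a
# (n_frames, 3, frame_size, frame_size) tensor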

24 changes: 24 additions & 0 deletions src/open_clip/model_configs/ViViT-B-32.json
@@ -0,0 +1,24 @@
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
},
"temporal_cfg": {
"context_length": 32,
"width": 512,
"heads": 8,
"layers": 12,
"mlp_ratio": 4,
"pooler_type": "cls_pooler"
}
}
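The temporal context_length (read in factory.py as model.visual.context_length) is the number of frames fed to the temporal transformer. With the take_every_nth=5 hard-coded in factory.py, one sample spans context_length * 5 raw frames; a quick back-of-the-envelope check, assuming 30 fps source video:

fps = 30                                        # assumed source frame rate
n_frames, take_every_nth = 32, 5                # from this config and factory.py
clip_seconds = n_frames * take_every_nth / fps  # ~5.3 s of video per sample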
24 changes: 24 additions & 0 deletions src/open_clip/model_configs/ViViT-B-32_short.json
@@ -0,0 +1,24 @@
{
"embed_dim": 512,
"vision_cfg": {
"image_size": 224,
"layers": 12,
"width": 768,
"patch_size": 32
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 512,
"heads": 8,
"layers": 12
},
"temporal_cfg": {
"context_length": 8,
"width": 512,
"heads": 8,
"layers": 12,
"mlp_ratio": 4,
"pooler_type": "cls_pooler"
}
}
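ViViT-B-32_short is identical to ViViT-B-32 except that the temporal context_length drops from 32 frames to 8 (about 1.3 s per clip under the same assumptions).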
24 changes: 24 additions & 0 deletions src/open_clip/model_configs/ViViT-L-14_short.json
@@ -0,0 +1,24 @@
{
"embed_dim": 768,
"vision_cfg": {
"image_size": 224,
"layers": 24,
"width": 1024,
"patch_size": 14
},
"text_cfg": {
"context_length": 77,
"vocab_size": 49408,
"width": 768,
"heads": 12,
"layers": 12
},
"temporal_cfg": {
"context_length": 8,
"width": 768,
"heads": 12,
"layers": 12,
"mlp_ratio": 4,
"pooler_type": "cls_pooler"
}
}
69 changes: 68 additions & 1 deletion src/open_clip/transform.py
@@ -7,7 +7,7 @@
import torchvision.transforms.functional as F

from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
CenterCrop
CenterCrop, ToPILImage

from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD

@@ -131,3 +131,70 @@ def image_transform(
normalize,
])
return Compose(transforms)


# TODO: needs improvement
def video_transform(
frame_size: int,
n_frames: int,
take_every_nth: int,
is_train: bool,
frame_mean: Optional[Tuple[float, ...]] = None,
frame_std: Optional[Tuple[float, ...]] = None,
):

frame_mean = frame_mean or OPENAI_DATASET_MEAN
if not isinstance(frame_mean, (list, tuple)):
frame_mean = (frame_mean,) * 3

frame_std = frame_std or OPENAI_DATASET_STD
if not isinstance(frame_std, (list, tuple)):
frame_std = (frame_std,) * 3

normalize = Normalize(mean=frame_mean, std=frame_std)

if is_train:
transforms = [
ToPILImage(),
RandomResizedCrop(
frame_size,
scale=(0.9, 1.0),
interpolation=InterpolationMode.BICUBIC,
),
_convert_to_rgb,
ToTensor(),
normalize,
]
else:
transforms = [
ToPILImage(),
Resize(frame_size, interpolation=InterpolationMode.BICUBIC),
CenterCrop(frame_size),
_convert_to_rgb,
ToTensor(),
normalize,
]

frame_transform = Compose(transforms)
def apply_frame_transform(sample):
video, audio, video_meta = sample
video = video.permute(0, 3, 1, 2)

video = video[::take_every_nth]
video = video[:n_frames] # TODO: maybe make this middle n frames

# TODO: maybe padding isn't the way to go
# TODO: also F.pad is acting up for some reason
# isn't letting me input a len 8 tuple for 4d tensor???
# video = F.pad(video, tuple([0, 0]*len(video.shape[-3:]) + [0, n_frames - video.shape[0]]))

if video.shape[0] < n_frames:
padded_video = torch.zeros(n_frames, *video.shape[1:])
padded_video[:video.shape[0]] = video
video = padded_video

# TODO: this .float() is weird, look at how this is done in other places
return torch.cat([frame_transform(frame.float())[None, ...] for frame in video])


return apply_frame_transform
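A minimal sketch of the transform in isolation, assuming the dataloader yields decoded video as a (T, H, W, C) uint8 tensor alongside audio and metadata:

transform = video_transform(frame_size=224, n_frames=32, take_every_nth=5, is_train=False)
sample = (torch.randint(0, 256, (300, 360, 640, 3), dtype=torch.uint8), None, {})
frames = transform(sample)  # -> (32, 3, 224, 224); clips shorter than n_frames are zero-padded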
2 changes: 1 addition & 1 deletion src/open_clip/transformer.py
@@ -497,7 +497,7 @@ def forward(self, x: torch.Tensor):

if self.output_tokens:
return pooled, tokens

return pooled

