Any interesting results? #1
👋 hello Romain! no not yet, i still need to build out the modular forgetting system. didn't you start your new job? |
Ok great, I'll follow up on the progress :) Indeed I started the new job, pretty interesting! |
I hope you don't mind me stalking this project, but I tried this out on enwik8 (https://github.com/igor0/memorizing-transformers-pytorch/commit/d302feee0c3d9655a92c392850c4ec5d86bff77c). I basically just ported the enwik8 training loop from another one of @lucidrains's projects.

The initial finding is that with KNN memories, the training loop is pretty slow, so often (but not always) I'll sample 0% GPU usage. Disabling the KNN memories makes the training loop go much faster (>10x compared to with KNN). So the KNN code may need some optimization, but I don't understand it well enough yet to suggest something constructive.

Edit: Ah, that was with knn_use_gpu=False, I missed that. With knn_use_gpu=True, I seem to get a hang. On the GPU, it's faiss.index_cpu_to_all_gpus(self.index) that's hanging for me, endlessly chewing up CPU cycles. Just FYI. |
@igor0 ah thanks for trying it! i'll hook up the enwik8 training myself next week and start profiling and see what's going on :) still want to polish up the pluggable memory expiring strategy (account for memory creation time as well as last retrieved time) |
I ended up having two issues with KNN on the GPU. Here are the findings so far.
1. The wheel package faiss-gpu hangs for me on A100 and A10. Using conda rather than pip to install faiss-gpu seems to work for me.
2. "remove_ids not implemented for this type of index" |
there is little benefit to using faiss GPU. however, if knn operations are slow here, it's likely because a flat index is used |
What type of index to use then? One problem is that we don't know the distribution of the keys upfront, and the clustering approaches require that. Furthermore, the distribution of keys changes over time. So, you could keep recomputing the clusters periodically. I'm sure that's doable, but another thing to sort out and tune. IMO a flat index is a reasonable place to start. And a flat index on a GPU would perform much better than a flat index on a CPU. If faiss doesn't let us implement a flat index with the ability to replace entries, then we could implement our own sliding window mechanism, or just avoid faiss for now and simply implement the memory directly as a PyTorch tensor. That could be one straightforward solution. |
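(For illustration, a minimal sketch of the "just represent the memory as a PyTorch tensor" option mentioned above: a preallocated flat buffer with sliding-window overwrite and exact inner-product top-k search. The class and argument names are made up for this example, and it ignores the per-batch-row bookkeeping the paper needs.)

import torch

class TensorKNNMemory:
    """Flat memory as a plain tensor: a ring buffer with exact top-k search."""
    def __init__(self, capacity, dim, device='cpu'):
        self.keys = torch.zeros(capacity, dim, device=device)
        self.values = torch.zeros(capacity, dim, device=device)
        self.capacity = capacity
        self.ptr = 0        # next slot to overwrite
        self.filled = 0     # how many slots hold real data

    def add(self, k, v):
        # overwrite the oldest entries (sliding window), so no remove_ids is needed
        n = k.shape[0]
        idx = (self.ptr + torch.arange(n, device=k.device)) % self.capacity
        self.keys[idx] = k.detach()
        self.values[idx] = v.detach()
        self.ptr = (self.ptr + n) % self.capacity
        self.filled = min(self.filled + n, self.capacity)

    def search(self, queries, topk=32):
        # exact (flat) inner-product search, runs on whatever device the tensors live on
        sims = queries @ self.keys[:self.filled].t()              # (num_queries, filled)
        scores, idx = sims.topk(min(topk, self.filled), dim=-1)   # (num_queries, topk)
        return self.keys[idx], self.values[idx], scores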
Hmm, yeah, maybe this won't be smooth sailing. There is also another library out there called TorchPQ that may fit the bill of running on GPU and supporting removal of ids. But it is a relatively young library still, so prob not without a few rough edges. I'll take a closer look next week, thanks for trying this out early! |
FlatContainer in TorchPQ looks promising as a potential flat GPU index (to avoid the challenges with clustering): https://github.com/DeMoriarty/TorchPQ/blob/main/torchpq/container/FlatContainer.py It seems like |
Could you describe how often these operations are done within memorizing transformers:
- adding embeddings
- searching
- removing embeddings
- (re)training the index
|
An example of an index that works without training (although it's not obvious that's a good property at this point) is IndexHNSW. What I meant by "there is little benefit in using faiss GPU" is that faiss indices are usually very fast on CPU (search is done in less than 1ms), and it's not faster on GPU. But the choice of index should be made based on how often you need to search/add/train/remove and how many vectors you have. So if you give more information on that, I can advise |
@rom1504 thanks for offering your expertise! so basically this paper is adding embeddings at a rate of 512 tokens per training step. To compound the problem, they keep separate indexed memories per batch row, which is why I have to instantiate a number of faiss indices equal to the batch size. Searching is also done every training step (after the first), with a top k of 32, and removal of embeddings starts after the memory hits some capacity limit (in the paper they started with 2048, then scaled the memory size up to 16000). A capacity of 2048 would mean the removal of ids starts on the 5th step. So basically high rates of adding, removing, and searching. The author told me that within google they run each batch row on 1 TPU core, and thus each row can be assigned its own index. Flat would be fine, but it also negates the paper's main selling point, which is that fetching from approximate knn memories should greatly benefit the attention net. Hopefully it isn't the case that it "does not work in practice" due to engineering obstacles |
even in the worst case, I think the lessons from this paper can be carried over to some other architecture (say if one were to generalize https://github.com/lucidrains/HTM-pytorch):
1. storing l2-normed key / values (cosine sim attention) as memories for stability
2. memories need not be differentiable
3. approximate knn is fine
4. one only needs one or two layers of long term memory fetching at most (placed at the middle of the attention net) |
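(A tiny sketch of points 1 and 2 above, purely for illustration: detach and l2-normalize keys/values before they go into the memory, so inner-product search behaves like cosine similarity. The helper name is made up.)

import torch
import torch.nn.functional as F

def to_memory(k: torch.Tensor, v: torch.Tensor):
    # store l2-normed, detached (non-differentiable) keys / values
    return F.normalize(k.detach(), dim=-1), F.normalize(v.detach(), dim=-1)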
ok interesting. And maybe you rebuild the index from scratch every 1000 steps to save on memory if needed. About add/search, I would start by simply trying IndexHNSWFlat, I believe it will work well enough.

import faiss
import numpy as np

dimension = 512  # whatever the key dimension is
index = faiss.IndexHNSWFlat(dimension, 15, faiss.METRIC_INNER_PRODUCT)
index.add(np.random.rand(512, dimension).astype('float32')) |
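(To make the suggested usage pattern concrete, here is a rough sketch of a per-step loop with batched search/add and the periodic rebuild mentioned above. All sizes are toy values for illustration, not settings from the repo.)

import faiss
import numpy as np

dim, topk, rebuild_every = 64, 32, 1000   # toy sizes; the thread mentions 512 new keys per step
index = faiss.IndexHNSWFlat(dim, 15, faiss.METRIC_INNER_PRODUCT)

for step in range(10):                     # stand-in for the training loop
    queries = np.random.rand(512, dim).astype('float32')
    new_keys = np.random.rand(512, dim).astype('float32')

    if index.ntotal > 0:                   # batched search: one call per step, not per query
        scores, ids = index.search(queries, min(topk, index.ntotal))
    index.add(new_keys)                    # batched add

    # HNSW cannot remove individual ids, so bound memory by rebuilding from scratch
    if (step + 1) % rebuild_every == 0:
        index = faiss.IndexHNSWFlat(dim, 15, faiss.METRIC_INNER_PRODUCT)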
@rom1504 thanks for the suggestion 👍 on trying it out, seeing |
The problem is that for each training sample, we need to:
- search the memory with the current queries,
- add the new key/value embeddings, and
- remove old entries to stay within the capacity limit.

So, adds : removals : searches are 1 : 1 : 1, or alternatively you can think of it as searches : replacements being 1 : 1. We are doing the removals in order to create space for the adds; masking out the elements doesn't really solve the problem for us because it doesn't open up new space. One solution is to have two indexes and alternate between them, so that one can be cleared while the other is still filling up. Another solution is to use a single flat index that supports replacement: i.e., we can efficiently replace some entries with other entries. Faiss doesn't seem to support this, but you can either implement something from scratch (just represent the memory as a tensor), or use the other library that @lucidrains mentioned. |
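(One possible reading of the two-index idea, sketched below with faiss flat indices: keep an "old" and a "young" index, add to the young one, and drop the old one once the young one fills up, so searches cover at most two generations of memories. This is an assumption about the intended design, not a confirmed description of it, and the merging of hits from the two indices is left out.)

import faiss
import numpy as np

class TwoIndexMemory:
    """Bounded memory without remove_ids: fill `young`; when full, drop `old` and promote `young`."""
    def __init__(self, dim, capacity):
        self.dim, self.cap = dim, capacity
        self.old = faiss.IndexFlatIP(dim)
        self.young = faiss.IndexFlatIP(dim)

    def add(self, keys: np.ndarray):
        self.young.add(keys)
        if self.young.ntotal >= self.cap:
            self.old, self.young = self.young, faiss.IndexFlatIP(self.dim)

    def search(self, queries: np.ndarray, topk: int = 32):
        hits = []
        for index in (self.old, self.young):
            if index.ntotal > 0:
                hits.append(index.search(queries, min(topk, index.ntotal)))
        return hits   # merging the two hit lists / mapping ids back to values is omitted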
@igor0 yea, i feel like if we were to go with flat for faiss, there would be no benefit. the whole point is to sell the approach of using approximate knn for long term, non differentiable memory - at least, that was what excited me about the paper initially |
maybe it would be best to forget about faiss and scann, and just try to roll something with deepmind's HTM (although i think more thought needs to be put into how to generalize HTM to more than just a depth of 1 hierarchy) - or the alternative is to just forget about this repository and focus on https://github.com/lucidrains/routing-transformer and make sure it supports recurrence and that the routing attention can act on a set of non differentiable memories |
I don't understand why add_with_ids is needed. Custom ids don't do anything particularly interesting. You can either use consecutive ids, or maintain a consecutive-to-custom id mapping (as a python dict, or as a numpy array). You could decide to use faiss.IndexIDMap if you really want add_with_ids: https://github.com/facebookresearch/faiss/wiki/Pre--and-post-processing#the-indexidmap |
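(For reference, a minimal example of the IndexIDMap wrapper mentioned above, wrapping a flat inner-product index so entries can be added and removed by custom id. The sizes are arbitrary.)

import faiss
import numpy as np

dim = 64
index = faiss.IndexIDMap(faiss.IndexFlatIP(dim))   # wraps a flat index to allow custom ids

keys = np.random.rand(512, dim).astype('float32')
ids = np.arange(512).astype('int64')
index.add_with_ids(keys, ids)

# evict the oldest 128 entries by id
index.remove_ids(np.arange(128).astype('int64'))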
why is opening new spaces needed? |
ohh ok, maybe it could still work then, since faiss can support a ridiculous number of vectors - realistically, each document isn't going to exceed 32k tokens before the next document comes along and the index needs to be cleared and retrained. i guess the other issue is frequent retraining of indices, since each batch row will be maintaining its own index, and every time a new document comes along, it needs to be cleared and retrained. screenshot below for clarity, just imagine the batch dimension being around 32 - 64 |
OK, you can add elements to the index until it reaches
I don't think HNSW indices need to be trained, so training wouldn't be an issue. That's also why we shouldn't use clustering-based indices, at least at training time. (The constraints are a bit different at inference time, but let's focus on training time for now.) |
If we don't need to support very large documents, then all the smart forgetting work becomes unnecessary. We can just clear the memory at the end of each document, so no need to forget individual entries at all. Maybe that's good enough to get past enwik8? |
indeed hnsw requires no training, that's why I was suggesting it! |
if your total number of vectors is below 1M, hnsw is quite fine. beyond 10M vectors, you'll have to think a bit more, but with some smart eviction and retraining only every N (like 1000) steps, it should be ok |
TIL! |
yea that's true, but it would be nice if the work is extended (RL for example) |
just to give an idea of how much faiss can scale, https://rom1504.github.io/clip-retrieval/ 's backend is currently holding an index of 5B embeddings of dimension 768. It uses 200MB of ram for the index thanks to faiss's memmapping feature. so for millions of embeddings, everything should be fine |
ok, let me meditate on this for an hour or two (before reckless execution), since it would require a big refactor of the KNN Memory class. thank you both, and hope you are having a great weekend :) |
You are giving batches to index.search and index.add, right? |
A thousand queries should take something like 10ms with hnsw |
OK, that's fair, it should be batched. Different elements will have to take different paths in the HNSW, but at least you have the loop in C++ and not in Python. Nevertheless, HNSW is still the bottleneck. I guess it must be the add() that's the bottleneck, not the search. Because we also have to add thousands of entries into the HNSW graph for each training iteration. And add in HNSW is a lot slower than search. |
it was fast enough for me to complete one experiment, and the results don't look that great tbh. i'll keep playing around with it this week. i just need to see some spark before i dedicate more engineering effort to this idea |
perhaps https://github.com/lucidrains/HTM-pytorch would be a better bet, with the memories kept on CPU, and then the compressed memories kept in GPU (and i can reach for https://github.com/lucidrains/memory-efficient-attention-pytorch if need be), but that memory transformer variant rightfully belongs in another repository, as a Deepmind contribution |
on further reflection, i don't even think the attention layer for the long term memories should mix local with distant memories (using all the gating logic). one should simply place a long term memory module smack dab in the middle of the network and cross attend to it efficiently once |
A few comments:
Regarding your implementation:
On my end, I am experimenting with these ideas, but taking a different path than you. Instead of trying to implement the whole paper, I'm starting from a pre-trained model and trying to implement trivial memorization in the smallest incremental steps I can think of. Hopefully at least one of us gets to some kind of interesting result. |
@igor0 ahh yes, i do like the idea of having separate heads for local vs distant. re: null key / values, the reason is that in the paper, one can have documents spanning different lengths across different batch rows, and it is just a cleaner solution than to unbind the batch dimension and do all the conditional logic. the current way it is built is still not great, since ideally the number of memories present is also accounted for in the gating (so the network attends locally when no distant memories are present for a certain batch row) |
@igor0 do keep us posted about results from your approach :) |
@igor0 i switched it to your splitting of attention heads idea and it looks a lot better now!! https://wandb.ai/lucidrains/memorizing-transformers/reports/memorizing-transformers--VmlldzoxODE4ODc1?accessToken=2ae4y9l100i3kfj3oa0udcsct0r495evo661lkk71w1mve1itlmphn20eiix1aul 🎉 🎉 thank you! |
@rom1504 also thank you for all your help with faiss and the hnsw suggestion! |
Interesting. I wonder whether the baseline got an unlucky initialization and so it had a poor start. I'm not sure I'd expect memories to be of that much help early on during the training. If the benefits of memorization show up later in the training, that may make the approach harder to evaluate. |
On further thought, I'm not sure how exactly to make a dedicated distant attention head work. Since the memory access is non-differentiable, the backward pass won't reach the K and V transforms. So, the K and V transforms for the distant heads won't be trained. So, at least the K/V transform must be tied across the local-remote modes. Did you solve that in some way? |
@igor0 i think that's the appeal of this paper. it seems the claim is that even with non differentiable memory, in addition to approximate nearest neighbors, the network would still perform well. i'm fairly confident the non-differentiable memories would do fine, as it lines up with transformer-xl. XL memories are also non-differentiable (they are detached for the next iteration). however, one could augment the network with another projection of the non-differentiable memory key / values, perhaps conditioned by their age (which we can keep track of) |
Well in the paper, each KV projection is used for both local and distant attention. So, even though the backward pass for distant attention doesn't reach the KV projection, that is still OK because the backward passes from local attention seem to be sufficient. But, if you have a distant-only head, there will literally be no backward passes reaching the KV projections. So, their weights will forever keep their initial random values. That can't possibly work. Anyways, it's well possible that I'm missing something, but that's why it seems to me that distant-only attention heads won't work trivially. Using a hybrid local-distant head seems to be one solution. |
ahhh got it, yes, we can try the hybrid approach too (which i assume is just concatting the fetched distant key / values to the local ones, prior to softmax) |
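(A rough single-head sketch of that hybrid variant, for illustration only: the retrieved memory keys/values are detached and concatenated with the local ones before one shared softmax. Masking, multiple heads, batching, and position bias are all omitted, and the function name is made up.)

import torch

def hybrid_attention(q, k_local, v_local, k_mem, v_mem, scale=None):
    # q, k_local, v_local: (n, d); k_mem, v_mem: (m, d) retrieved, non-differentiable memories
    scale = scale if scale is not None else q.shape[-1] ** -0.5
    k = torch.cat((k_mem.detach(), k_local), dim=0)
    v = torch.cat((v_mem.detach(), v_local), dim=0)
    sim = q @ k.t() * scale
    attn = sim.softmax(dim=-1)          # one softmax over distant + local together
    return attn @ v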
ultimately, i think we should just let empiricism speak. hard to know which option is best without running the experiments |
I'll have some free time Sunday and will run some experiments comparing the two ways |
@igor0 Hey Igor, I finally got some time to attend to memorizing transformers. Mainly I tried out the hybrid attention as you suggested and it is working the best! I've checked it in for 0.3.0 |
the reprojection of the memories did very little, possibly a little worse, just to share a negative result |
finally seeing liftoff with 3 modifications to the KNN attention module
|
I can suggest giving dynamic position bias a try instead of T5 rel pos, as the latter is usually slow and DPB can give similar performance with better speed. |
@gurvindersingh Oh hey Gurvinder! How is it faster than T5 rel pos? The way I see it, T5 rel pos is a zero-layered DPB (usually DPB is a 2-3 layered MLP) |
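(For context, a minimal version of what a dynamic position bias module typically looks like: a small MLP mapping relative distance to a per-head bias that gets added to the attention logits. The layer sizes here are guesses for illustration, not the settings either project uses.)

import torch
from torch import nn

class DynamicPositionBias(nn.Module):
    def __init__(self, heads, hidden=64):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, heads),
        )

    def forward(self, n, device):
        # relative distances -(n-1) .. (n-1), passed through the MLP to get a bias per head
        rel = torch.arange(-(n - 1), n, device=device, dtype=torch.float32).unsqueeze(-1)
        bias = self.mlp(rel)                                    # (2n-1, heads)
        i = torch.arange(n, device=device)
        rel_idx = (i.unsqueeze(0) - i.unsqueeze(1)) + (n - 1)   # (n, n) indices into bias
        return bias[rel_idx].permute(2, 0, 1)                   # (heads, n, n) bias for the logits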
Very cool! Yeah, a full cosine-similarity attention makes sense. Normalizing key and values like in the paper sounds... odd. I'm sure the researchers behind the paper had their rationale.

On my end, I've been trying to add memorization to an existing pretrained model. The positive result is that if I add memory to one or more layers in a pretrained model as hybrid attention, I get a modest but measurable performance benefit. I've been trying to take such a model (with one KNN layer) and train it (e.g., with attentions unfrozen, or everything unfrozen) to better adapt the model to exchange info in the KNN layer. I haven't had much success with the retraining so far. The retraining tends to be unstable and quickly diverge. I've tried a number of the things you are trying here (project the memorized keys/values, add a trained bias to memorized keys, try it with positional biases for memory or without, etc). I can normalize keys/queries or keys/values in the KNN layer, but then that degrades the pretrained model substantially. It seems to stabilize the retraining somewhat, but then I lose some of the benefit of starting from a pretrained model.

The only thing I found that sort of worked was that training a bias for memorized keys gives a modest but measurable benefit. But I haven't been able to stabilize the training sufficiently to get really compelling results. Still working on it and trying things, though! A couple of ideas I'd still like to try include suppressing the KNN attention early in the training (e.g., by using a trained penalty coefficient applied to the distant attention scores before the distant+local softmax) or perhaps adding a dedicated layer just for distant attention. |
@igor0 ahh ok, yes, i did have to do a lot of trial and error before i started to see something with the formula outlined above i'll keep chipping away at it when i find some time |
@lucidrains I haven't looked into the code part, but when I was testing T5 rel pos embedding and DPB, the model was taking less time per step with DPB. Hence my above statement :) |
@gurvindersingh ohh interesting! i'll have to revisit the T5 rel pos bias class and make sure i'm caching some intermediates correctly |
I'm gonna piggyback on this thread, because it is quite active. A memory_list is only needed when I have multiple layers that contain a KNNAttention block, right? Thanks in advance! |
Hey!
Cool repo. I like all the knn+lm methods
Did you do some runs yet? Anything interesting to report?