
Any interesting results? #1

Open

rom1504 opened this issue Mar 31, 2022 · 79 comments

@rom1504 commented Mar 31, 2022

Hey!
Cool repo. I like all the knn+lm methods
Did you do some runs yet? Anything interesting to report?

@lucidrains (Owner)

👋 hello Romain! no not yet, i still need to build out the modular forgetting system

didn't you start your new job?

@rom1504 (Author) commented Mar 31, 2022

Ok great, I'll follow up on the progress :)

Indeed I started the new job, pretty interesting!

@igor0 commented Apr 1, 2022

I hope you don't mind me stalking this project, but I tried this out on enwik8 (https://github.com/igor0/memorizing-transformers-pytorch/commit/d302feee0c3d9655a92c392850c4ec5d86bff77c). I basically just ported the enwik8 training loop from another one of @lucidrains's projects.

The initial finding is that with KNN memories, the training loop is pretty slow, so often (but not always) I'll sample 0% GPU usage. Disabling the KNN memories makes the training loop go much faster (>10x compared to with KNN). So the KNN code may need some optimization, but I don't understand it well enough yet to suggest something constructive.

Edit: Ah, that was with knn_use_gpu=False, I missed that. With knn_use_gpu=True, I seem to get a hang. On the GPU, it's faiss.index_cpu_to_all_gpus(self.index) that's hanging for me, endlessly chewing up CPU cycles. Just FYI.

@lucidrains (Owner)

@igor0 ah thanks for trying it! i'll hook up the enwik8 training myself next week and start profiling and see what's going on :) still want to polish up the pluggable memory expiring strategy (account for memory creation time as well as last retrieved time)

@igor0 commented Apr 2, 2022

I ended up having two issues with KNN on the GPU. Here are the findings so far.

1. the wheel package faiss-gpu hangs for me on A100 and A10
With the faiss-gpu package installed by pip, I always get a hang in index_cpu_to_all_gpus(). I opened an issue here: kyamagu/faiss-wheels#54. I would guess that the faiss-gpu wheel isn't compatible with CUDA 11, so the A100/A10 GPUs don't work.

Using conda rather than pip to install faiss-gpu seems to work for me.

2. " remove_ids not implemented for this type of index"
As far as I can tell, remove_ids is not supported for any GPU indexes. One possible solution may be to simulate a sliding window with two GPU indexes, so that we always completely clear an index, instead of removing entries one-by-one. The fancier expiry will get more complicated and will require some type of manual compaction, at least if you want to run faiss on the GPU.

@rom1504 (Author) commented Apr 3, 2022

there is little benefit to using faiss gpu

however if knn operations are slow here, it's likely because a flat index is used

@igor0 commented Apr 3, 2022

What type of index to use then? One problem is that we don't know the distribution of the keys upfront, and the clustering approaches require that. Furthermore, the distribution of keys changes over time. So, you could keep recomputing the clusters periodically. I'm sure that's doable, but another thing to sort out and tune.

IMO a flat index is a reasonable place to start. And a flat index on a GPU would perform much better than a flat index on a CPU.

If faiss doesn't let us implement a flat index with the ability to replace entries, then we could implement our own sliding window mechanism, or just avoid faiss for now and simply implement the memory directly as a PyTorch tensor. That could be one straightforward solution.
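To make that concrete, here is a minimal sketch of the "memory as a PyTorch tensor" idea (the class name, the fixed capacity, and the ring-buffer eviction are my own assumptions, not anything from this repo):

import torch

class TensorKNNMemory:
    # fixed-capacity ring buffer of (key, value) pairs with exact top-k search by inner product
    def __init__(self, capacity, dim, device='cpu'):
        self.keys = torch.zeros(capacity, dim, device=device)
        self.values = torch.zeros(capacity, dim, device=device)
        self.capacity = capacity
        self.ptr = 0       # next slot to overwrite
        self.filled = 0    # how many slots hold real data

    def add(self, k, v):   # k, v: (n, dim)
        n = k.shape[0]
        idx = (self.ptr + torch.arange(n, device=k.device)) % self.capacity
        self.keys[idx] = k
        self.values[idx] = v
        self.ptr = (self.ptr + n) % self.capacity
        self.filled = min(self.filled + n, self.capacity)

    def search(self, queries, topk=32):   # queries: (m, dim)
        sims = queries @ self.keys[:self.filled].t()            # (m, filled)
        scores, idx = sims.topk(min(topk, self.filled), dim=-1)
        return self.keys[idx], self.values[idx]

mem = TensorKNNMemory(capacity=2048, dim=64)
mem.add(torch.randn(512, 64), torch.randn(512, 64))
k, v = mem.search(torch.randn(512, 64), topk=32)   # each of shape (512, 32, 64)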

@lucidrains (Owner)

Hmm, yeah, maybe this won't be smooth sailing.

There is also another library out there called TorchPQ that may fit the bill of running on the GPU and supporting removal of ids. But it is still a relatively young library, so probably not without a few rough edges. I'll take a closer look next week, thanks for trying this out so early!

@igor0 commented Apr 3, 2022

FlatContainer in TorchPQ looks promising as a potential flat GPU index (to avoid the challenges with clustering): https://github.com/DeMoriarty/TorchPQ/blob/main/torchpq/container/FlatContainer.py

It seems like FlatContainer::set_data_by_address() can arbitrarily overwrite records in the flat container. That would be more efficient than FlatContainer::remove() because remove() needs to copy a lot of data around. Not sure how much that will matter in the end, but always good to avoid copying when possible.

@rom1504 (Author) commented Apr 3, 2022

Could you describe how often these operations are done within memorizing transformers:

  • Adding an embedding (and how many at once)
  • searching
  • removing an embedding

@rom1504 (Author) commented Apr 3, 2022

An example of an index that works without training (although it's not obvious that's a good property at this point) is IndexHNSW.

What I meant by "there is little benefit in using faiss GPU" is that faiss indices are usually very fast on CPU (search is done in less than 1 ms), and they are not faster on GPU.
The only time the GPU is the better choice is if you need to query with a huge batch of embeddings (let's say 1M).

But the choice of index should be based on how often you need to search/add/train/remove and how many vectors you have. So if you give more information on that, I can advise.

@lucidrains (Owner) commented Apr 3, 2022

@rom1504 thanks for offering your expertise! So basically this paper adds embeddings at a rate of 512 tokens per training step. To compound the problem, they keep separately indexed memories per batch row, which is why I have to instantiate a number of faiss indices equal to the batch size. Searching is also done every training step (after the first), with a top k of 32, and removal of embeddings starts once the memory hits some capacity limit (in the paper they used 2048, later scaling the memory size up to 16000); a capacity of 2048 means removal of ids starts on the 5th step. So basically: high rates of adding, removing, and searching.

The author told me that what they were doing within Google is running each batch row on 1 TPU core, and thus being able to assign it its own index.

Flat would be fine, but it also negates the paper's main selling point, which is that fetching from approximate knn memories should greatly benefit the attention net. Hopefully it isn't the case that it "does not work in practice" due to engineering obstacles.

@lucidrains (Owner) commented Apr 3, 2022

even in the worst case, I think the lessons from this paper can be carried over to some other architecture (say, if one were to generalize https://github.com/lucidrains/HTM-pytorch):

  1. store l2-normed keys / values (cosine sim attention) as memories, for stability
  2. memories need not be differentiable
  3. approximate knn is fine
  4. one only needs one or two layers of long term memory fetching at most (placed in the middle of the attention net)

@rom1504 (Author) commented Apr 3, 2022

ok interesting
Removing from an index is usually slow, so I would not remove.
Instead I would replace the remove operation with adding the removed indices to a mask (when doing the search, you search with a higher k value and apply the mask; a higher k affects search speed minimally).

And maybe you rebuild the index from scratch every 1000 steps to save on memory if needed.

About add/search, I would start by simply trying IndexHNSWFlat; I believe it will work well enough.
It's a little slow at adding (maybe 10 ms for a batch of 512), but search is basically instant (0.1 ms):

import faiss
import numpy as np

dimension = 512  # example embedding dimension
index = faiss.IndexHNSWFlat(dimension, 15, faiss.METRIC_INNER_PRODUCT)
index.add(np.random.rand(512, dimension).astype('float32'))
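And to illustrate the masking idea from above, a rough sketch (the alive list, the over-fetch factor, and the helper functions are my own choices rather than an existing API):

import faiss
import numpy as np

dim = 512
index = faiss.IndexHNSWFlat(dim, 15, faiss.METRIC_INNER_PRODUCT)
alive = []                               # alive[i] == False means id i is "removed"

def add(vectors):
    index.add(vectors)
    alive.extend([True] * len(vectors))

def remove(ids):
    for i in ids:
        alive[i] = False                 # no actual faiss removal, just masked out

def search(queries, k=32, overfetch=4):
    scores, ids = index.search(queries, k * overfetch)   # over-fetch, then drop masked ids
    results = []
    for row_scores, row_ids in zip(scores, ids):
        kept = [(s, i) for s, i in zip(row_scores, row_ids) if i >= 0 and alive[i]]
        results.append(kept[:k])
    return results

add(np.random.rand(1024, dim).astype('float32'))
remove(range(100))
hits = search(np.random.rand(4, dim).astype('float32'))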

@lucidrains (Owner)

@rom1504 thanks for the suggestion 👍

on trying it out, seeing add_with_ids not implemented for this type of index, which would make custom memory expiration strategies unapproachable

@igor0 commented Apr 3, 2022

The problem is that for each training sample, we need to:

  • search seq_len entries
  • remove seq_len entries
  • add seq_len entries

So, adds : removals : searches are 1 : 1 : 1, or alternatively you can think of it as searches : replacements being 1 : 1. So, we are doing the removals in order to create space for the add. Masking out the elements doesn't really solve the problem for us because it doesn't open up new space.

One solution is to have two indexes: current and previous. We add to current, and once current fills up, clear previous, current becomes previous, and previous becomes current. So basically, we are only adding. In this scenario, current and previous don't need to support any fancy operations beyond add() and clear(), so they can probably be either flat or HNSW.

Another solution is to use a single flat index that supports replacement: i.e., we can efficiently replace some entries with other entries. Faiss doesn't seem to support this, but you can either implement something from scratch (just represent the memory as a tensor), or use the other library that @lucidrains mentioned.
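A rough sketch of the current/previous idea, in case it helps (the names and parameters are my own; note that ids restart whenever an index is rotated, so the caller has to merge the two result sets itself):

import faiss
import numpy as np

class RotatingMemory:
    # "current" fills up, then becomes "previous"; the old "previous" is dropped wholesale
    def __init__(self, dim, max_per_index):
        self.dim = dim
        self.max_per_index = max_per_index
        self.current = faiss.IndexHNSWFlat(dim, 15, faiss.METRIC_INNER_PRODUCT)
        self.previous = None

    def add(self, vectors):
        if self.current.ntotal + len(vectors) > self.max_per_index:
            self.previous = self.current
            self.current = faiss.IndexHNSWFlat(self.dim, 15, faiss.METRIC_INNER_PRODUCT)
        self.current.add(vectors)

    def search(self, queries, k=32):
        results = [self.current.search(queries, k)]
        if self.previous is not None and self.previous.ntotal > 0:
            results.append(self.previous.search(queries, k))
        return results   # list of (scores, ids) pairs, one per index

mem = RotatingMemory(dim=512, max_per_index=2048)
mem.add(np.random.rand(512, 512).astype('float32'))
hits = mem.search(np.random.rand(4, 512).astype('float32'))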

@lucidrains (Owner) commented Apr 3, 2022

@igor0 yea, i feel like if we were to go with flat for faiss, there would be no benefit. the whole point is to sell the approach of using approximate knn for long term, non differentiable memory - at least, that was what excited me about the paper initially

@lucidrains (Owner) commented Apr 3, 2022

maybe it would be best to forget about faiss and scann, and just try to roll something with deepmind's HTM (although i think more thought needs to be put into how to generalize HTM to more than just a depth of 1 hierarchy) - or the alternative is to just forget about this repository and focus on https://github.com/lucidrains/routing-transformer and make sure it supports recurrence and that the routing attention can act on a set of non differentiable memories

@rom1504 (Author) commented Apr 3, 2022

> on trying it out, seeing add_with_ids not implemented for this type of index, which would make custom memory expiration strategies unapproachable

I don't understand why add_with_ids is needed. custom ids don't do anything particularly interesting. You can either use consecutive ids, or maintain a consecutive/custom ids mapping (as a python dict, or as a numpy array)

you could decide to use faiss.IndexIDMap if you really want add with ids https://github.com/facebookresearch/faiss/wiki/Pre--and-post-processing#the-indexidmap
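For what it's worth, the IndexIDMap wrapper looks something like this in use (a small sketch; removing ids is still a separate problem):

import faiss
import numpy as np

dim = 512
base = faiss.IndexHNSWFlat(dim, 15, faiss.METRIC_INNER_PRODUCT)
index = faiss.IndexIDMap(base)

vectors = np.random.rand(512, dim).astype('float32')
ids = np.arange(10_000, 10_512).astype('int64')     # arbitrary custom ids
index.add_with_ids(vectors, ids)

scores, found = index.search(vectors[:4], 32)        # `found` contains the custom ids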

@rom1504 (Author) commented Apr 3, 2022

> Masking out the elements doesn't really solve the problem for us because it doesn't open up new space

why is opening up new space needed?
The number of embeddings you add at every iteration is pretty small, so memory use will be limited until you've done a few thousand steps. At that point you can rebuild the index

@lucidrains (Owner)

> on trying it out, seeing add_with_ids not implemented for this type of index, which would make custom memory expiration strategies unapproachable
>
> I don't understand why add_with_ids is needed. custom ids don't do anything particularly interesting. You can either use consecutive ids, or maintain a consecutive/custom ids mapping (as a python dict, or as a numpy array)
>
> you could decide to use faiss.IndexIDMap if you really want add with ids https://github.com/facebookresearch/faiss/wiki/Pre--and-post-processing#the-indexidmap

ohh ok, maybe it could still work then, since faiss can support a ridiculous number of vectors - realistically, each document isn't going to exceed 32k tokens, before the next document comes along and the index needs to be cleared and retrained

i guess the other issue is frequent retraining of indices, since each batch row will be maintaining its own index, and every time a new document comes along, it needs to be cleared and retrained

screenshot below for clarity, just imagine the batch dimension being around 32 - 64

[screenshot omitted]

@igor0 commented Apr 3, 2022

> The number of embeddings you add at every iteration is pretty small, so memory use will be limited until you've done a few thousand steps. At that point you can rebuild the index

OK, so you can add elements to the index until it reaches 2 x intended_size and then compact it down to intended_size. During the compaction you can apply whatever eviction criteria you like, and simply create a new index. That's an alternative to having two indices (previous and current).

> i guess the other issue is frequent retraining of indices

I don't think HNSW indices need to be trained, so training wouldn't be an issue. That's also why we shouldn't use clustering-based indices, at least at training time. (The constraints are a bit different at inference time, but let's focus on training time for now.)
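Going back to the 2 x intended_size idea above, a sketch of what that could look like (the class and the recency-based eviction criterion are just my own illustration):

import faiss
import numpy as np

class CompactingMemory:
    def __init__(self, dim, intended_size):
        self.dim = dim
        self.intended_size = intended_size
        self.vectors = np.zeros((0, dim), dtype='float32')   # keep the raw vectors around for rebuilds
        self.index = self._build(self.vectors)

    def _build(self, vectors):
        index = faiss.IndexHNSWFlat(self.dim, 15, faiss.METRIC_INNER_PRODUCT)
        if len(vectors):
            index.add(vectors)
        return index

    def add(self, new_vectors):
        self.vectors = np.concatenate([self.vectors, new_vectors])
        self.index.add(new_vectors)
        if len(self.vectors) > 2 * self.intended_size:
            # evict the oldest entries (or apply any other criterion) and rebuild from scratch
            self.vectors = self.vectors[-self.intended_size:]
            self.index = self._build(self.vectors)

    def search(self, queries, k=32):
        return self.index.search(queries, k)

mem = CompactingMemory(dim=512, intended_size=2048)
mem.add(np.random.rand(512, 512).astype('float32'))
scores, ids = mem.search(np.random.rand(4, 512).astype('float32'))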

@igor0 commented Apr 3, 2022

> realistically, each document isn't going to exceed 32k tokens, before the next document comes along and the index needs to be cleared and retrained

If we don't need to support very large documents, then all the smart forgetting work becomes unnecessary. We can just clear the memory at the end of each document, so no need to forget individual entries at all. Maybe that's good enough to get past enwik8?

@rom1504 (Author) commented Apr 3, 2022

indeed hnsw requires no training, that's why I was suggesting it!

@rom1504 (Author) commented Apr 3, 2022

if your total number of vectors is below 1M, hnsw is quite fine
if your dimension is 1024, 1M vectors means 4GB of ram, which is quite reasonable

beyond 10M vectors you'll have to think a bit more, but with some smart eviction and retraining only every N (like 1000) steps, it should be ok

@lucidrains (Owner)

> indeed hnsw requires no training, that's why I was suggesting it!

TIL!

@lucidrains (Owner)

> realistically, each document isn't going to exceed 32k tokens, before the next document comes along and the index needs to be cleared and retrained
>
> If we don't need to support very large documents, then all the smart forgetting work becomes unnecessary. We can just clear the memory at the end of each document, so no need to forget individual entries at all. Maybe that's good enough to get past enwik8?

yea that's true, but it would be nice if the work is extended (RL for example)

@rom1504 (Author) commented Apr 3, 2022

just to give an idea of how much faiss can scale, https://rom1504.github.io/clip-retrieval/ 's backend is currently holding an index of 5B embeddings of dimension 768. It uses 200MB of ram for the index thanks to faiss's memmapping feature.
The index is 800GB. The search time is 400ms because it's big and on disk.
(if using sharding and in-memory, the search time could be < 1ms)

so for millions of embeddings, everything should be fine

@lucidrains (Owner)

ok, let me meditate on this for an hour or two (before reckless execution), since it would require a big refactor of the KNN Memory class

thank you both, and hope you are having a great weekend :)

@rom1504 (Author) commented Apr 11, 2022

You are giving batches to index.search and index.add right?
Faiss has good built-in support for parallelism (implemented in C++)

@rom1504 (Author) commented Apr 11, 2022

A thousand queries should take something like 10 ms with hnsw

@igor0 commented Apr 11, 2022

OK, that's fair, it should be batched. Different elements will have to take different paths in the HNSW, but at least you have the loop in C++ and not in Python.

Nevertheless, HNSW is still the bottleneck. I guess it must be the add() that's the bottleneck, not the search. Because we also have to add thousands of entries into the HNSW graph for each training iteration. And add in HNSW is a lot slower than search.
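A quick way to sanity-check that claim locally (a throwaway timing sketch, not code from the repo):

import time
import faiss
import numpy as np

dim, n, k = 512, 512, 32
index = faiss.IndexHNSWFlat(dim, 15, faiss.METRIC_INNER_PRODUCT)
batch = np.random.rand(n, dim).astype('float32')

t0 = time.time()
index.add(batch)               # one training step's worth of new memories
t1 = time.time()
index.search(batch, k)         # one training step's worth of queries
t2 = time.time()
print(f"add: {(t1 - t0) * 1e3:.1f} ms, search: {(t2 - t1) * 1e3:.1f} ms")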

@lucidrains (Owner)

it was fast enough for me to complete one experiment, and the results don't look that great tbh

i'll keep playing around with it this week. i just need to see some spark before i dedicate more engineering effort to this idea

@lucidrains (Owner)

perhaps https://github.com/lucidrains/HTM-pytorch would be a better bet, with the memories kept on CPU, and then the compressed memories kept in GPU (and i can reach for https://github.com/lucidrains/memory-efficient-attention-pytorch if need be), but that memory transformer variant rightfully belongs in another repository, as a Deepmind contribution

@lucidrains (Owner)

on further reflection, i don't even think the attention layer for the long term memories should mix local with distant memories (using all the gating logic)

one should simply place a long term memory module smack dab in the middle of the network, and cross attend to it efficiently once

@igor0 commented Apr 11, 2022

A few comments:

  • Yeah, the sigmoid gating looks pretty wonky to me. Probably takes lots of tweaking to get it to behave reasonably. Assuming 8 attention heads and 1 layer, the gate will literally be parameterized using 8 floats. Unless you have a very high learning rate, the gate parameters will move really slowly, so it will take forever to converge. So maybe the gate needs scaling by some magical factor?
  • Leaving that aside, summing up the local-memory values with distant-memory values looks ... not optimal? It seems like there are a number of ways that would be much more robust: e.g., use dedicated attention heads for distant memories, or a dedicated layer for distant access, or linearly transform the keys associated with distant memories (to compensate for the lack of positional encoding) and concatenate them with the local ones.

Regarding your implementation:

  • I don't completely understand why you need the null_k and null_v parameters. Since PyTorch supports dynamic execution flow, can't you simply have an if-branch... if there are no memories, just use the local memories?

On my end, I am experimenting with these ideas, but taking a different path than you. Instead of trying to implement the whole paper, I'm starting from a pre-trained model and trying to implement trivial memorization in the smallest incremental steps I can think of. Hopefully at least one of us gets to some kind of an interesting results.

@lucidrains (Owner)

@igor0 ahh yes, i do like the idea of having separate heads for local vs distant

re: null key / values, the reason is that in the paper, one can have documents spanning different lengths across different batch rows

[figure omitted]

it is just a cleaner solution than unbinding the batch dimension and doing all the conditional logic. the current way it is built is still not great, since ideally the number of memories present is also accounted for in the gating (so the network attends locally when no distant memories are present for a certain batch row)
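Roughly, what the null key / values buy (a simplified sketch with assumed shapes, not the repo's actual code): every batch row always has at least one valid memory slot to attend to, so rows that fetched nothing need no special casing.

import torch

batch, heads, topk, dim_head = 4, 8, 32, 64

# learned null key / value, shared across the batch (shapes here are an assumption)
null_k = torch.randn(heads, 1, dim_head)
null_v = torch.randn(heads, 1, dim_head)

# retrieved memories; some batch rows may not have fetched anything yet
mem_k = torch.randn(batch, heads, topk, dim_head)
mem_v = torch.randn(batch, heads, topk, dim_head)
mem_mask = torch.zeros(batch, topk, dtype=torch.bool)   # True where a real memory exists
mem_mask[2:] = True                                     # e.g. only rows 2 and 3 have memories

# prepend the null slot so every row has something valid to attend to
k = torch.cat([null_k.expand(batch, -1, -1, -1), mem_k], dim=2)   # (batch, heads, topk + 1, dim_head)
v = torch.cat([null_v.expand(batch, -1, -1, -1), mem_v], dim=2)
mask = torch.cat([torch.ones(batch, 1, dtype=torch.bool), mem_mask], dim=1)

# the attention logits over memory slots would then be masked with `mask` before the softmax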

@lucidrains (Owner)

@igor0 do keep us posted about results from your approach :)

@lucidrains (Owner) commented Apr 11, 2022

@igor0 i switched it to your splitting of attention heads idea and it looks a lot better now!! https://wandb.ai/lucidrains/memorizing-transformers/reports/memorizing-transformers--VmlldzoxODE4ODc1?accessToken=2ae4y9l100i3kfj3oa0udcsct0r495evo661lkk71w1mve1itlmphn20eiix1aul 🎉 🎉 thank you!

@lucidrains (Owner) commented Apr 11, 2022

@rom1504 also thank you for all your help with faiss and the hnsw suggestion!

@lucidrains (Owner) commented Apr 12, 2022

[screenshot omitted]

the effects are more pronounced when i turn off XL memories - will investigate for bugs in the XL memories code later this week

@igor0 commented Apr 12, 2022

Interesting. I wonder whether the baseline got an unlucky initialization and so it had a poor start. I'm not sure I'd expect memories to be of that much help early on during the training. If the benefits of memorization show up later in the training, that may make the approach harder to evaluate.

@igor0 commented Apr 14, 2022

On further thought, I'm not sure how exactly to make a dedicated distant attention head work. Since the memory access is non-differentiable, the backward pass won't reach the K and V transforms. So, the K and V transforms for the distant heads won't be trained. So, at least the K/V transform must be tied across the local-remote modes. Did you solve that in some way?

@lucidrains (Owner)

@igor0 i think that's the appeal of this paper. it seems the claim is that even with non differentiable memory, in addition to approximate nearest neighbors, the network would still perform well. i'm fairly confident the non-differentiable memories would do fine, as it lines up with transformer-xl. XL memories are also non-differentiable (they are detached for the next iteration)

however, one could augment the network with another projection of the non-differentiable memory key / values, perhaps conditioned by their age (which we can keep track of)

@igor0 commented Apr 15, 2022

Well in the paper, each KV projection is used for both local and distant attention. So, even though the backward pass for distant attention doesn't reach the KV projection, that is still OK because the backward passes from local attention seem to be sufficient.

But, if you have a distant-only head, there will literally be no backward passes reaching the KV projections. So, their weights will forever keep their initial random values. That can't possibly work.

Anyways, it's well possible that I'm missing something, but that's why it seems to me that distant-only attention heads won't work trivially. Using a hybrid local-distant head seems to be one solution.

@lucidrains (Owner)

ahhh got it, yes, we can try the hybrid approach too (which i assume is just concatting the fetched distant key / values to the local ones, prior to the softmax)
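My reading of that hybrid variant, as a sketch with assumed shapes (not the repo's implementation; causal masking of the local logits is omitted): the per-query retrieved memory keys / values get concatenated with the local ones along the key dimension, and a single softmax runs over both.

import torch

batch, heads, seq, topk, dim_head = 2, 8, 512, 32, 64
scale = dim_head ** -0.5

q       = torch.randn(batch, heads, seq, dim_head)
k_local = torch.randn(batch, heads, seq, dim_head)
v_local = torch.randn(batch, heads, seq, dim_head)
k_mem   = torch.randn(batch, heads, seq, topk, dim_head)   # top-k retrieved keys, per query position
v_mem   = torch.randn(batch, heads, seq, topk, dim_head)

local_logits = torch.einsum('b h i d, b h j d -> b h i j', q, k_local) * scale
mem_logits   = torch.einsum('b h i d, b h i j d -> b h i j', q, k_mem) * scale

logits = torch.cat([mem_logits, local_logits], dim=-1)   # one softmax over distant + local
attn = logits.softmax(dim=-1)
attn_mem, attn_local = attn[..., :topk], attn[..., topk:]

out = torch.einsum('b h i j, b h i j d -> b h i d', attn_mem, v_mem) \
    + torch.einsum('b h i j, b h j d -> b h i d', attn_local, v_local)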

@lucidrains (Owner)

ultimately, i think we should just let empiricism speak. hard to know which option is best without running the experiments

@lucidrains (Owner)

I'll have some free time Sunday and will run some experiments comparing the two ways

@lucidrains (Owner)

@igor0 Hey Igor, I finally got some time to attend to memorizing transformers. Mainly I tried out the hybrid attention as you suggested and it is working the best! I've checked it in for 0.3.0

@lucidrains (Owner)

the reprojection of the memories did very little, possibly a little worse, just to share a negative result

@lucidrains (Owner) commented Apr 23, 2022

[screenshot omitted]

finally seeing liftoff with 3 modifications to the KNN attention module

  1. use cosine-sim attention with a low initial temperature (see the sketch after this list)
  2. hybrid attention (thanks for the suggestion 🙏 )
  3. use a separate relative positional bias (as opposed to the one in the regular attention layers) - this should also allow the network to turn off attending locally later in training, if need be
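For reference, a bare-bones version of modification 1 (cosine-sim attention with a learned per-head temperature; the initial value here is just a placeholder of my own):

import torch
import torch.nn.functional as F

batch, heads, seq, dim_head = 2, 8, 512, 64

q = torch.randn(batch, heads, seq, dim_head)
k = torch.randn(batch, heads, seq, dim_head)

# l2-normalize queries and keys so similarities live in [-1, 1]
q, k = map(lambda t: F.normalize(t, dim=-1), (q, k))

# learned temperature per head, initialized low
temperature = torch.nn.Parameter(torch.full((heads, 1, 1), 0.1))

logits = torch.einsum('b h i d, b h j d -> b h i j', q, k) / temperature
attn = logits.softmax(dim=-1)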

@gurvindersingh

I can suggest giving dynamic position bias a try instead of T5 rel pos, as the latter is usually slow and DPB can give similar performance with better speed.

@lucidrains (Owner)

@gurvindersingh Oh hey Gurvinder! How is it faster than T5 rel pos? The way I see it, T5 rel pos is a zero-layered DPB (usually DPB is a 2-3 layered MLP)
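For context, a dynamic position bias is roughly a small MLP from relative distance to a per-head bias that gets added to the attention logits (a sketch of one common formulation, not this repo's code):

import torch
import torch.nn as nn

class DynamicPositionBias(nn.Module):
    # maps a relative distance (a single scalar) to one bias value per attention head
    def __init__(self, heads, hidden_dim=64):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, heads),
        )

    def forward(self, seq_len):
        pos = torch.arange(seq_len)
        rel = (pos[None, :] - pos[:, None]).float()[..., None]   # (i, j, 1) relative distances
        bias = self.mlp(rel)                                     # (i, j, heads)
        return bias.permute(2, 0, 1)                             # (heads, i, j), added to the logits

dpb = DynamicPositionBias(heads=8)
bias = dpb(128)   # broadcast-add to attention logits of shape (batch, heads, 128, 128)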

@igor0 commented Apr 25, 2022

Very cool!

Yeah, a full cosine-similarity attention makes sense. Normalizing keys and values like in the paper sounds... odd. I'm sure the researchers behind the paper had their rationale.

On my end, I've been trying to add memorization to an existing pretrained model. The positive result is that if I add memory to one or more layers of a pretrained model as hybrid attention, I get a modest but measurable performance benefit. I've been trying to take such a model (with one KNN layer) and train it (e.g., with the attention layers unfrozen, or everything unfrozen) to better adapt the model to exchanging info in the KNN layer.

I haven't had much success with the retraining so far. The retraining tends to be unstable and quickly diverge. I've tried a number of the things you are trying here (projecting the memorized keys/values, adding a trained bias to memorized keys, trying it with and without positional biases for memory, etc.). I can normalize keys/queries or keys/values in the KNN layer, but then that degrades the pretrained model substantially. It seems to stabilize the retraining somewhat, but then I lose some of the benefit of starting from a pretrained model.

The only thing I found that sort of worked was training a bias for the memorized keys, which gives a modest but measurable benefit. But I haven't been able to stabilize the training sufficiently to get really compelling results. Still working on it and trying things, though! A couple of ideas I'd still like to try: suppressing the KNN attention early in the training (e.g., with a trained penalty coefficient applied to the distant attention scores before the distant+local softmax), or perhaps adding a dedicated layer just for distant attention.

@lucidrains (Owner)

@igor0 ahh ok, yes, i did have to do a lot of trial and error before i started to see something with the formula outlined above

i'll keep chipping away at it when i find some time

@gurvindersingh

@lucidrains I haven't looked into the code part, but when I was testing T5 rel pos embedding and DPB, the model was taking less time per step with DPB. Hence my statement above :)

@lucidrains (Owner)

@gurvindersingh ohh interesting! i'll have to revisit the T5 rel pos bias class and make sure i'm caching some intermediates correctly

@kevinpl07

I'm gonna piggyback on this thread, because it is quite active.

A memory_list is only needed when I have multiple layers that contain a KNNAttention block, right?
So if I have only one layer, I can just use one instance of KNNMemory and don't need to organize them in a list.

Thanks in advance!
