Batch matrix multiply leads to vulkan error on WGPU #1865

Open
jungerm2 opened this issue Jun 7, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@jungerm2

jungerm2 commented Jun 7, 2024

I wasn't sure whether batch matmul was supported, since it doesn't seem to be documented anywhere except in PyTorch's documentation. It works fine with small tensors but breaks down past a certain size:

use burn::backend::wgpu::{AutoGraphicsApi, WgpuRuntime};
use burn::tensor::{Distribution, Tensor};

type B = burn::backend::wgpu::JitBackend<WgpuRuntime<AutoGraphicsApi, f32, i32>>;

// 250,000 (500 x 500) batched matmuls of shape [4, 5] x [5, 6] -> [4, 6]
let a: Tensor<B, 4> = Tensor::random([500, 500, 4, 5], Distribution::Normal(-1.0, 1.0), &Default::default());
let b: Tensor<B, 4> = Tensor::random([500, 500, 5, 6], Distribution::Normal(-1.0, 1.0), &Default::default());
let out = a.matmul(b);
println!("{:?}", out);

I'd expect an output tensor of shape [500, 500, 4, 6], but instead I get the following error:

wgpu error: Validation Error

Caused by:
    In a ComputePass
      note: encoder = `<CommandBuffer-(1, 1, Vulkan)>`
    In a dispatch command, indirect:false
      note: compute pipeline = `<ComputePipeline-(4, 1, Vulkan)>`
    Each current dispatch group size dimension ([1, 1, 250000]) must be less or equal to 65535

So it seems there's a maximum dispatch dimension of 65535 for bmm. I would expect this backend-specific limitation to be abstracted away, i.e. the backend should split the bmm into chunks and recombine them automatically. Is there a current workaround for this?

I'm using burn 0.13.2 with Vulkan 1.3.283 on Fedora 40.
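For reference, the 65535 here matches wgpu's default max_compute_workgroups_per_dimension limit. A minimal sketch for checking what the local adapter reports, assuming the wgpu and pollster crates are added as direct dependencies:

use wgpu::{Instance, RequestAdapterOptions};

fn main() {
    // Request any adapter and print its per-dimension dispatch limit (65535 by default).
    let instance = Instance::default();
    let adapter = pollster::block_on(instance.request_adapter(&RequestAdapterOptions::default()))
        .expect("no compatible adapter found");
    println!("{}", adapter.limits().max_compute_workgroups_per_dimension);
}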

@laggui laggui added the bug Something isn't working label Jun 7, 2024
@jungerm2
Author

jungerm2 commented Jun 7, 2024

For the time being, this seems to work, but it's much slower than I'd like:

use burn::tensor::{backend::Backend, Tensor};

/// Perform batch matrix multiplication by splitting the batch dimension
/// into chunks of at most 65535 and recombining the results.
pub fn bmm<B: Backend>(a: Tensor<B, 3>, b: Tensor<B, 3>) -> Tensor<B, 3> {
    let batch_size = 65535;
    let [n1, i, j1] = a.dims();
    let [n2, j2, k] = b.dims();
    assert_eq!(n1, n2);
    assert_eq!(j1, j2);

    // Small enough to dispatch in one go.
    if n1 <= batch_size {
        return a.matmul(b);
    }

    // Chunk the batch dimension into ranges of at most `batch_size`.
    let ranges: Vec<_> = (0..n1.div_ceil(batch_size))
        .map(|chunk| (batch_size * chunk)..(batch_size * (chunk + 1)).min(n1))
        .collect();

    // Multiply each chunk separately and concatenate along the batch dimension.
    let result_parts: Vec<_> = ranges
        .into_iter()
        .map(|r| {
            let a_part = a.clone().slice([r.clone(), 0..i, 0..j1]);
            let b_part = b.clone().slice([r, 0..j2, 0..k]);
            a_part.matmul(b_part)
        })
        .collect();
    Tensor::cat(result_parts, 0)
}
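A possible usage sketch for the 4-D tensors a and b from the original snippet: flatten the two leading batch dimensions before the call and restore them afterwards (this assumes Tensor::reshape with a fixed-size dims array, which is how burn exposes reshapes as far as I can tell):

// Flatten [500, 500, 4, 5] -> [250000, 4, 5], multiply, then restore the batch dims.
let a3 = a.reshape([500 * 500, 4, 5]);
let b3 = b.reshape([500 * 500, 5, 6]);
let out: Tensor<B, 4> = bmm::<B>(a3, b3).reshape([500, 500, 4, 6]);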

It could be sped up using rayon's into_par_iter, but to my great surprise, FloatTensorPrimitive is Send but not Sync...

EDIT: Fixed some indexing in the above code, now it actually works.

@nathanielsimard
Member

Well, it's a limitation in the launching paradigm used to compute matrix multiplication. Maybe we could use another dimension to handle the batch part.
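A rough sketch of that idea, purely illustrative and not burn's actual launch code: spread the batch count over two dispatch dimensions so that neither exceeds the 65535 per-dimension limit, and have the kernel reconstruct the flat batch index.

// Illustrative only: how a launcher could split `num_batches` across the
// y and z dispatch dimensions instead of putting it all on z.
const MAX_DISPATCH_DIM: usize = 65_535;

fn batch_dispatch_dims(num_batches: usize) -> (u32, u32) {
    let z = num_batches.min(MAX_DISPATCH_DIM);
    let y = num_batches.div_ceil(z);
    assert!(y <= MAX_DISPATCH_DIM, "batch count exceeds a 2-D split");
    // The kernel would recover the batch index as `y_id * z + z_id`
    // and return early when it is >= num_batches.
    (y as u32, z as u32)
}

For the [500, 500, 4, 5] case this would dispatch (y, z) = (4, 65535) workgroup columns covering the 250,000 batches.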
