Distributions from space bounds #1875

Open
will-maclean opened this issue Jun 10, 2024 · 0 comments

Feature description

Create a distribution from either a scalar bound or a tensor bound. This would allow sampling from distributions whose parameters differ per element.

Feature motivation

In Python with NumPy, we can do the following:

import numpy as np

# Define 2D low and high arrays
lows = np.array([[1, 2, 3], [4, 5, 6]])
highs = np.array([[7, 8, 9], [10, 11, 12]])

# Ensure that the low and high arrays have the same shape
assert lows.shape == highs.shape, "Lows and highs must have the same shape"

# Generate the random 2D array
random_array = np.random.uniform(lows, highs)

This allows us to create samples where each element is sampled from a differently parametrised distribution. This is currently not possible in burn, where we can only sample from scalar bounds:

// possible
let shape = Shape::new([2, 3]);
let low = 0.0;
let high = 1.0;
Tensor::random(shape, Distribution::Uniform(low, high), &Default::default());

// not possible
let shape = Shape::new([2, 3]);
let low = Tensor::from_data([0.0, 0.5, 1.0], &Default::default());
let high = Tensor::from_data([1.0, 0.7, 1.01], &Default::default());
Tensor::random(shape, Distribution::Uniform(low, high), &Default::default());

This is because the Distribution enum only supports f64 bounds:

// crates/burn-tensor/src/tensor/data.rs

// Distribution for random value of a tensor.
#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum Distribution {
    /// Uniform distribution from 0 (inclusive) to 1 (exclusive).
    Default,

    /// Bernoulli distribution with the given probability.
    Bernoulli(f64),

    /// Uniform distribution. The range is inclusive.
    Uniform(f64, f64),

    /// Normal distribution with the given mean and standard deviation.
    Normal(f64, f64),
}
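For reference, the per-element behaviour requested here amounts to rescaling unit-uniform samples as low + u * (high - low). Below is a minimal plain-Rust sketch of that arithmetic. It does not use burn: `Vec<f64>` stands in for a tensor, and `XorShift64` and `uniform_elementwise` are hypothetical names made up for illustration, with a toy generator used only to keep the example dependency-free.

```rust
/// Toy xorshift64 generator (hypothetical helper, not a burn type).
/// It only exists to keep this sketch free of external crates.
struct XorShift64(u64);

impl XorShift64 {
    /// Returns a value in [0, 1). The seed must be non-zero.
    fn next_unit(&mut self) -> f64 {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        // Keep the top 53 bits so the value is exactly representable as f64.
        (self.0 >> 11) as f64 / (1u64 << 53) as f64
    }
}

/// Draws one sample per element, each between lows[i] and highs[i].
fn uniform_elementwise(lows: &[f64], highs: &[f64], rng: &mut XorShift64) -> Vec<f64> {
    assert_eq!(lows.len(), highs.len(), "lows and highs must have the same length");
    lows.iter()
        .zip(highs)
        // Rescale a unit-uniform draw into this element's own range.
        .map(|(&lo, &hi)| lo + rng.next_unit() * (hi - lo))
        .collect()
}
```

Calling `uniform_elementwise(&[0.0, 0.5, 1.0], &[1.0, 0.7, 1.01], &mut XorShift64(1))` returns three values, each within its own pair of bounds. The same arithmetic could presumably be written with element-wise tensor operations on a `Distribution::Default` sample as a stopgap, although that is an assumption I have not verified against every backend.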

Suggest a Solution

I think it would be very useful to define a SpaceBound enum that allows a bound to be given in either scalar or tensor form. Something like the following would then be possible:

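#[derive(Debug, Clone)]
pub enum SpaceBound<B: Backend, const D: usize> {
    Scalar(f64),
    Space(Tensor<B, D>),
}

// Note: `Copy` is dropped because a variant holding a tensor cannot be `Copy`.
// Whether `PartialEq` and the serde derives can stay depends on what `Tensor` implements.
/// Distribution for random value of a tensor.
#[derive(Debug, Clone)]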
pub enum Distribution<B: Backend, const D: usize> {
    /// Uniform distribution from 0 (inclusive) to 1 (exclusive).
    Default,

    /// Bernoulli distribution with the given probability.
    Bernoulli(SpaceBound<B, D>),

    /// Uniform distribution. The range is inclusive.
    Uniform(SpaceBound<B, D>, SpaceBound<B, D>),

    /// Normal distribution with the given mean and standard deviation.
    Normal(SpaceBound<B, D>, SpaceBound<B, D>),
}

When sampling, we could then match on the SpaceBound enum. Very roughly, something like:

// somewhere in the sampling code
match distribution {
    Distribution::Uniform(low, high) => match (low, high) {
        // same as current implementation
        (SpaceBound::Scalar(low), SpaceBound::Scalar(high)) => todo!(),
        // new implementations
        (SpaceBound::Space(low), SpaceBound::Space(high)) => todo!(),
        (SpaceBound::Scalar(low), SpaceBound::Space(high)) => todo!(),
        (SpaceBound::Space(low), SpaceBound::Scalar(high)) => todo!(),
    },
    _ => {}
}

Instead of matching on the different combinations, we could also just try broadcasting the bounds up to the specified shape.
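To make the broadcasting idea concrete, here is a rough plain-Rust sketch. It is not burn code: `Vec<f64>` stands in for `Tensor<B, D>`, the shape is flattened to a single length, and the `broadcast` and `scale_uniform` names are hypothetical. The point is only that once both bounds are expanded to the target size, the four match arms collapse into one code path.

```rust
/// Stand-in for the proposed `SpaceBound`; `Vec<f64>` replaces `Tensor<B, D>`.
#[derive(Debug, Clone, PartialEq)]
enum SpaceBound {
    Scalar(f64),
    Space(Vec<f64>),
}

impl SpaceBound {
    /// Expands the bound to `len` elements, or reports a size mismatch.
    fn broadcast(&self, len: usize) -> Result<Vec<f64>, String> {
        match self {
            // A scalar bound is repeated for every element.
            SpaceBound::Scalar(v) => Ok(vec![*v; len]),
            // A per-element bound must already match the requested size.
            SpaceBound::Space(vs) if vs.len() == len => Ok(vs.clone()),
            SpaceBound::Space(vs) => Err(format!(
                "bound has {} elements, expected {}",
                vs.len(),
                len
            )),
        }
    }
}

/// Rescales unit-uniform samples `unit[i]` into the range given by the
/// broadcast bounds, whatever mix of scalar and per-element bounds is passed.
fn scale_uniform(
    low: &SpaceBound,
    high: &SpaceBound,
    unit: &[f64],
) -> Result<Vec<f64>, String> {
    let lows = low.broadcast(unit.len())?;
    let highs = high.broadcast(unit.len())?;
    Ok(unit
        .iter()
        .zip(lows.iter().zip(&highs))
        .map(|(&u, (&lo, &hi))| lo + u * (hi - lo))
        .collect())
}
```

A real implementation would need full shape broadcasting across dimensions rather than a length check, and would have to decide whether a mismatch is an error or a panic.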

Either way, note the flexibility of being able to have different combinations of scalars and spaces. I also know that I am going to need something describing space boundaries in my own project, so having this implemented as a concept in burn would mean I don't need to jump between my own types and burn's types as much.

Keen to hear people's thoughts. I'm happy to have a go at implementing this myself if that's useful.
