Skip to content

Latest commit

 

History

History

Folders and files

Name
Last commit message
Last commit date

parent directory

..
 
 
 
 
 
 
 
 
 
 
 
 

Keras MLP


Usage

  • Basic usage
    from keras_cv_attention_models import mlp_family
    # Will download and load `imagenet` pretrained weights.
    # Model weights are loaded with `by_name=True, skip_mismatch=True`.
    mm = mlp_family.MLPMixerB16(num_classes=1000, pretrained="imagenet")
    
    # Run prediction
    import tensorflow as tf
    from tensorflow import keras
    from skimage.data import chelsea # Chelsea the cat
    imm = keras.applications.imagenet_utils.preprocess_input(chelsea(), mode='tf') # mode="tf" or "torch"
    pred = mm(tf.expand_dims(tf.image.resize(imm, mm.input_shape[1:3]), 0)).numpy()
    print(keras.applications.imagenet_utils.decode_predictions(pred)[0])
    # [('n02124075', 'Egyptian_cat', 0.9568315), ('n02123045', 'tabby', 0.017994137), ...]
    For `"imagenet21k"` pre-trained models, the actual `num_classes` is 21843.
  • Exclude model top layers by setting `num_classes=0`.
    from keras_cv_attention_models import mlp_family
    mm = mlp_family.ResMLP_B24(num_classes=0, pretrained="imagenet22k")
    print(mm.output_shape)
    # (None, 784, 768)
    
    mm.save('resmlp_b24_imagenet22k-notop.h5')

MLP mixer

ResMLP

GMLP

WaveMLP

  • PDF 2111.12294 An Image Patch is a Wave: Quantum Inspired Vision MLP

  • Model weights are reloaded from the GitHub repository huawei-noah/wavemlp_pytorch.

  • Models

    Model Params FLOPs Input Top1 Acc Download
    WaveMLP_T 17M 2.47G 224 80.9 wavemlp_t_imagenet.h5
    WaveMLP_S 30M 4.55G 224 82.9 wavemlp_s_imagenet.h5
    WaveMLP_M 44M 7.92G 224 83.3 wavemlp_m_imagenet.h5
    WaveMLP_B 63M 10.26G 224 83.6
  • Dynamic input shape

    from skimage.data import chelsea
    from keras_cv_attention_models import wave_mlp
    mm = wave_mlp.WaveMLP_T(input_shape=(None, None, 3))
    pred = mm(mm.preprocess_input(chelsea(), input_shape=[320, 320, 3]))
    print(mm.decode_predictions(pred)[0])
    # [('n02124075', 'Egyptian_cat', 0.4864809), ('n02123159', 'tiger_cat', 0.14551573), ...]
  • Verification with PyTorch version

    inputs = np.random.uniform(size=(1, 224, 224, 3)).astype("float32")
    
    """ PyTorch WaveMLP_T """
    sys.path.append("../CV-Backbones")
    from wavemlp_pytorch.models import wavemlp as torch_wavemlp
    import torch
    torch_model = torch_wavemlp.WaveMLP_T()
    ww = torch.load('WaveMLP_T.pth.tar', map_location=torch.device('cpu'))
    ww = {kk: vv for kk, vv in ww.items() if not kk.endswith("total_ops") and not kk.endswith("total_params")}
    torch_model.load_state_dict(ww)
    _ = torch_model.eval()
    torch_out = torch_model(torch.from_numpy(inputs).permute(0, 3, 1, 2)).detach().numpy()
    
    """ Keras WaveMLP_T """
    from keras_cv_attention_models import wave_mlp
    mm = wave_mlp.WaveMLP_T(pretrained="imagenet", classifier_activation=None)
    keras_out = mm(inputs).numpy()
    
    """ Verification """
    print(f"{np.allclose(torch_out, keras_out, atol=1e-5) = }")
    # np.allclose(torch_out, keras_out, atol=1e-5) = True