[Bug]: not the same results on CPU and GPU #24779

Open
3 tasks done
lebence opened this issue May 30, 2024 · 1 comment
Labels
bug Something isn't working category: GPU OpenVINO GPU plugin support_request

lebence commented May 30, 2024

OpenVINO Version

2024.1.0

Operating System

Ubuntu 20.04 (LTS)

Device used for inference

GPU

Framework

Keras (TensorFlow 2)

Model used

Custom classifier model

Issue description

After upgrading from 2023.* to 2024.*, our inference results on CPU and GPU no longer match; in fact, they are completely different. What matters is the OpenVINO version used for inference, not the version the model was converted with: if I convert the network with 2023.*.* but run inference with 2024.*.*, the error persists, while inference with 2023.*.* works fine. I wrote a few lines of code to reproduce the problem with a dummy network.

Step-by-step reproduction

First, create the network with this code snippet. I used TensorFlow 2.10.1.

import os

#os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import tensorflow as tf
from tensorflow import keras

img_inputs = keras.Input(shape=(192, 192, 1))

x = tf.keras.layers.Conv2D(32, (3,3), padding="same", activation="relu")(img_inputs)
x = tf.keras.layers.Conv2D(32, (3,3), activation="relu")(x)
x = tf.keras.layers.MaxPooling2D(pool_size=(2,2))(x)
x = tf.keras.layers.Dropout(0.30)(x)
x = tf.keras.layers.Conv2D(64, (3,3), padding="same", activation="relu")(x)
x = tf.keras.layers.Conv2D(64, (3,3), activation="relu")(x)
x = tf.keras.layers.MaxPooling2D(pool_size=(2,2))(x)
x = tf.keras.layers.Dropout(0.30)(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
head_d_1 = tf.keras.layers.Dense(256, activation='relu')(x)
head_d_2 = tf.keras.layers.Dense(128, activation='relu')(head_d_1)
head_out = tf.keras.layers.Dense(18, activation='softmax')(head_d_2)

model = keras.Model(inputs=img_inputs, outputs=head_out)

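model.summary()  # summary() prints directly and returns None

# Make sure the output directory exists before writing into it
os.makedirs('vino_bug_2024_inference', exist_ok=True)

def myprint(s):
    with open('vino_bug_2024_inference/dummy_model_summary.txt','a') as f:
        print(s, file=f)

model.summary(print_fn=myprint)

model.compile(
    # The output layer already applies softmax, so the loss receives probabilities, not logits
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    optimizer="adam",
    metrics=["accuracy"],
)

model.save('vino_bug_2024_inference/dummy_model')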

This produces the following model (attached as dummy_model.zip):

Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_3 (InputLayer)        [(None, 192, 192, 1)]     0         
                                                                 
 conv2d_8 (Conv2D)           (None, 192, 192, 32)      320       
                                                                 
 conv2d_9 (Conv2D)           (None, 190, 190, 32)      9248      
                                                                 
 max_pooling2d_4 (MaxPooling  (None, 95, 95, 32)       0         
 2D)                                                             
                                                                 
 dropout_4 (Dropout)         (None, 95, 95, 32)        0         
                                                                 
 conv2d_10 (Conv2D)          (None, 95, 95, 64)        18496     
                                                                 
 conv2d_11 (Conv2D)          (None, 93, 93, 64)        36928     
                                                                 
 max_pooling2d_5 (MaxPooling  (None, 46, 46, 64)       0         
 2D)                                                             
                                                                 
 dropout_5 (Dropout)         (None, 46, 46, 64)        0         
                                                                 
 batch_normalization_2 (Batc  (None, 46, 46, 64)       256       
 hNormalization)                                                 
                                                                 
 global_average_pooling2d_2   (None, 64)               0         
 (GlobalAveragePooling2D)                                        
                                                                 
 dense_6 (Dense)             (None, 256)               16640     
                                                                 
 dense_7 (Dense)             (None, 128)               32896     
                                                                 
 dense_8 (Dense)             (None, 18)                2322      
                                                                 
=================================================================
Total params: 117,106
Trainable params: 116,978
Non-trainable params: 128
_________________________________________________________________
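As a sanity check on the summary above, the parameter counts and spatial sizes follow from the standard Conv2D and Dense formulas. The following is a small hand-check sketch, assuming the Keras defaults (`use_bias=True`, stride-1 convolutions, pooling stride equal to the pool size); it does not use TensorFlow or OpenVINO.

```python
# Hand-check of the Keras summary (sketch; assumes default use_bias=True,
# stride-1 convolutions and pool_size == strides for max pooling).

def conv2d_params(k: int, c_in: int, c_out: int) -> int:
    """Weights k*k*c_in*c_out plus one bias per output channel."""
    return k * k * c_in * c_out + c_out

def dense_params(n_in: int, n_out: int) -> int:
    """Weight matrix n_in*n_out plus one bias per output unit."""
    return n_in * n_out + n_out

def conv_out(size: int, k: int, padding: str) -> int:
    """Spatial size after a stride-1 convolution."""
    return size if padding == "same" else size - k + 1

def pool_out(size: int, p: int) -> int:
    """Spatial size after 'valid' max pooling with pool_size == strides == p."""
    return size // p

# Spatial sizes through the feature extractor
s = 192
s = conv_out(s, 3, "same")    # conv2d_8  -> 192
s = conv_out(s, 3, "valid")   # conv2d_9  -> 190
s = pool_out(s, 2)            # pooling   -> 95
s = conv_out(s, 3, "same")    # conv2d_10 -> 95
s = conv_out(s, 3, "valid")   # conv2d_11 -> 93
s = pool_out(s, 2)            # pooling   -> 46

params = {
    "conv2d_8": conv2d_params(3, 1, 32),
    "conv2d_9": conv2d_params(3, 32, 32),
    "conv2d_10": conv2d_params(3, 32, 64),
    "conv2d_11": conv2d_params(3, 64, 64),
    "batch_normalization_2": 4 * 64,  # gamma, beta, moving mean, moving variance
    "dense_6": dense_params(64, 256),
    "dense_7": dense_params(256, 128),
    "dense_8": dense_params(128, 18),
}
non_trainable = 2 * 64  # BatchNorm moving mean and moving variance
total = sum(params.values())

print("final spatial size:", s)
print("total:", total, "trainable:", total - non_trainable, "non-trainable:", non_trainable)
```

The numbers match the summary (117,106 total, 128 non-trainable), so the SavedModel being converted is the architecture described above.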

I converted it with Model Optimizer. Here I used 2024.1.0 for the conversion, but networks converted with 2023.0.0 or 2023.3.0 show the same error.

mo --saved_model_dir vino_bug_2024_inference/dummy_model/ --input_shape [?,192,192,1] --layout NHWC --output_dir vino_bug_2024_inference/vino_model --use_legacy_frontend

The resulting OpenVINO IR: vino_model.zip
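The inference scripts in this issue read the input shape back from the IR with a `_get_input_shape_from_xml` helper, which returns the raw `shape` attribute string of the first `<layer>/<data>` element it finds. A more explicit variant, sketched below, selects the `Parameter` layer by type and turns the string into integers. It assumes the IR stores the input shape as a comma-separated `shape` attribute on `<layer type="Parameter"><data .../></layer>` with `?` for dynamic dimensions; the function name `input_shape_from_ir` and the sample XML are hypothetical.

```python
import xml.etree.ElementTree as ET
from typing import List

def input_shape_from_ir(xml_text: str) -> List[int]:
    """Return the first Parameter layer's shape, using -1 for dynamic dims.

    Assumption: the shape is a comma-separated 'shape' attribute on the
    Parameter layer's <data> element; '?' (or a bounded range such as
    '1..8') marks a dynamic dimension and is mapped to -1.
    """
    root = ET.fromstring(xml_text)
    # Select the Parameter layer explicitly instead of taking the first <layer>
    data = root.find(".//layer[@type='Parameter']/data")
    if data is None or not data.get("shape"):
        raise ValueError("no Parameter layer with a shape attribute found")
    dims = []
    for dim in data.get("shape").split(","):
        dim = dim.strip()
        if dim in ("?", "-1") or ".." in dim:
            dims.append(-1)
        else:
            dims.append(int(dim))
    return dims

# Hypothetical minimal IR fragment, for illustration only
sample_ir = """<net name="saved_model" version="11">
  <layers>
    <layer id="0" name="input_3" type="Parameter" version="opset1">
      <data shape="?,192,192,1" element_type="f32"/>
    </layer>
  </layers>
</net>"""

print(input_shape_from_ir(sample_ir))  # [-1, 192, 192, 1]
```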

I ran inference on raw_image.txt on both CPU and GPU with the following script (change the device passed to compile_model to switch between them):

import os
import numpy as np
from openvino.runtime import Core
import xml.etree.ElementTree as ET

DEVICE = "CPU"  # <-- change to "GPU" for the second run

def _parse_xml(model_path: str) -> ET.Element:
    tree = ET.parse(model_path)
    return tree.getroot()

def _get_input_shape_from_xml(xml_root: ET.Element) -> str:
    # Raw "shape" attribute of the first <layer>/<data> element (normally the input Parameter), e.g. "?,192,192,1"
    data = xml_root.find(".//layer/data")
    input_shape = data.get("shape") if data is not None else ""
    return input_shape

model_path = os.path.join("vino_bug_2024_inference/vino_model", "saved_model.xml")
ov_core = Core()
print("Available devices:", ov_core.available_devices)
model = ov_core.read_model(model_path)
root_vino = _parse_xml(model_path)

input_shape = _get_input_shape_from_xml(root_vino)
input_port = next(iter(model.inputs))
model.reshape({input_port: input_shape})
compiled_model = ov_core.compile_model(model, DEVICE)

raw_image = "vino_bug_2024_inference/raw_image.txt"
# 192x192 grayscale image -> NHWC tensor of shape (1, 192, 192, 1); np.loadtxt returns float64
input_data = np.loadtxt(raw_image, comments="#", delimiter=",", unpack=False)
input_data = np.expand_dims(input_data, axis=2)
input_data = np.expand_dims(input_data, axis=0)
print("Input shape: ", input_data.shape)

infer_request = compiled_model.create_infer_request()
infer_request.infer({input_port.get_any_name(): input_data})
output = infer_request.get_output_tensor(0)
print("Output: \n", output.data[0])

# One file per device so the CPU and GPU runs do not overwrite each other
np.savetxt(f"vino_bug_2024_inference/full_model_result_{DEVICE.lower()}.txt", output.data[0], delimiter=",", fmt="%s")

I got these results.
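For a classifier, "completely different" can be made concrete by checking whether both outputs are valid softmax vectors and whether the top-1 class agrees. Below is a minimal pure-Python sketch, assuming each result file was written by `np.savetxt(..., fmt="%s")` as in the script above, i.e. one value per line. The helper names are hypothetical, and the GPU file name in the usage note is an assumption.

```python
import math
from typing import List, Sequence

def read_vector(path: str) -> List[float]:
    """Read a 1-D result file written by np.savetxt(..., fmt="%s"): one value per line."""
    with open(path) as f:
        return [float(line) for line in f if line.strip()]

def argmax(values: Sequence[float]) -> int:
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])

def summarize_softmax(cpu: Sequence[float], gpu: Sequence[float]) -> dict:
    """Summarize how far two softmax outputs for the same input are apart."""
    if len(cpu) != len(gpu):
        raise ValueError(f"length mismatch: {len(cpu)} vs {len(gpu)}")
    top_cpu, top_gpu = argmax(cpu), argmax(gpu)
    return {
        "sum_cpu": math.fsum(cpu),  # should be close to 1.0 for a softmax output
        "sum_gpu": math.fsum(gpu),
        "top1_cpu": top_cpu,
        "top1_gpu": top_gpu,
        "top1_agree": top_cpu == top_gpu,
        "max_abs_diff": max(abs(a - b) for a, b in zip(cpu, gpu)),
    }
```

Usage: `summarize_softmax(read_vector("vino_bug_2024_inference/full_model_result_cpu.txt"), read_vector("vino_bug_2024_inference/full_model_result_gpu.txt"))`, where the second path stands for whatever file the GPU run was saved to.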

The results are completely different; this is not an epsilon-level difference. To find which layer causes the discrepancy, we removed layers from the network one at a time, working backwards from the output. We continued until only the very first convolution layer remained; after converting this truncated model with Model Optimizer and running inference, there was still a huge difference between the CPU and GPU results. (We added a Flatten layer to make the output easier to inspect.)
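Each step of the backward search above costs a conversion plus one inference run per device. As an alternative to the linear walk, the same search can be done as a bisection over the ordered layer names, sketched below. It assumes divergence is monotonic (once a cut point diverges, every later cut point diverges too); `diverges` is a hypothetical user-supplied check that would truncate the Keras model at a layer, convert it, run it on CPU and GPU, and compare the outputs.

```python
from typing import Callable, Sequence

def first_diverging_layer(layers: Sequence[str], diverges: Callable[[str], bool]) -> str:
    """Binary search for the earliest cut point whose CPU and GPU outputs differ.

    `layers` is ordered from input to output. Assumes monotonic divergence,
    so about log2(len(layers)) + 1 calls to `diverges` are needed.
    """
    if not layers:
        raise ValueError("no layers given")
    lo, hi = 0, len(layers) - 1
    if not diverges(layers[hi]):
        raise ValueError("the full model does not diverge")
    # Invariant: layers[hi] diverges, and the first diverging layer is in [lo, hi]
    while lo < hi:
        mid = (lo + hi) // 2
        if diverges(layers[mid]):
            hi = mid
        else:
            lo = mid + 1
    return layers[lo]

# Toy example with a fake check (hypothetical): divergence starts at conv2d_9
order = ["conv2d_8", "conv2d_9", "max_pooling2d_4", "dense_8"]
print(first_diverging_layer(order, lambda name: order.index(name) >= 1))  # conv2d_9
```

For this issue the search would end at the first layer, `conv2d_8`, which is what the manual walk found; the bisection only reduces the number of conversions needed to get there.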

import os
import tensorflow as tf
from keras.layers import Flatten
from keras.models import Model
print(tf.version.VERSION)

model = tf.keras.models.load_model('vino_bug_2024_inference/dummy_model', compile=False)
model.summary()

input_model = Model(inputs=model.input, outputs=model.get_layer("conv2d_8").output) #the name of my first conv layer
input_model_output = input_model.output

# Add a Flatten layer
flatten_output = Flatten()(input_model_output)

# Create a new model with the Flatten layer as its output
input_conv_flatten_model = Model(inputs=input_model.input, outputs=flatten_output)
input_conv_flatten_model.summary()

def myprint(s):
    with open('vino_bug_2024_inference/input_conv_flatten_model_summary.txt','a') as f:
        print(s, file=f)

input_conv_flatten_model.summary(print_fn=myprint)


input_conv_flatten_model.save("vino_bug_2024_inference/input_conv_flatten_model")

Model: "model_11"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_3 (InputLayer)        [(None, 192, 192, 1)]     0         
                                                                 
 conv2d_8 (Conv2D)           (None, 192, 192, 32)      320       
                                                                 
 flatten_5 (Flatten)         (None, 1179648)           0         
                                                                 
=================================================================
Total params: 320
Trainable params: 320
Non-trainable params: 0
_________________________________________________________________
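The Flatten output has H × W × C = 192 · 192 · 32 = 1,179,648 values. To see where the CPU and GPU outputs disagree (for example, whether differences cluster at the image border, where the `padding="same"` window touches zero padding, or in particular channels), a flat index can be mapped back to a pixel and channel. This is a hypothetical helper, assuming the flattened output keeps TensorFlow's row-major NHWC order (channel fastest) after conversion.

```python
from typing import Tuple

H, W, C = 192, 192, 32  # conv2d_8 output (NHWC), from the summary above

def unflatten_nhwc(i: int, h: int = H, w: int = W, c: int = C) -> Tuple[int, int, int]:
    """Map an index into the Flatten output back to (row, col, channel).

    Assumes row-major NHWC order, so the channel index varies fastest.
    """
    if not 0 <= i < h * w * c:
        raise IndexError(i)
    row, rem = divmod(i, w * c)
    col, ch = divmod(rem, c)
    return row, col, ch

def is_border(row: int, col: int, h: int = H, w: int = W) -> bool:
    """True for pixels whose 3x3 'same' window overlaps the zero padding."""
    return row in (0, h - 1) or col in (0, w - 1)

print(H * W * C)             # 1179648
print(unflatten_nhwc(6144))  # (1, 0, 0): start of the second image row
```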

Keras SavedModel: input_conv_flatten_model.zip
OpenVINO IR converted from this SavedModel (converted the same way as above): vino_model_input_conv_flatten.zip

Running inference again (change the device passed to compile_model to switch between CPU and GPU):

import os
import numpy as np
from openvino.runtime import Core
import xml.etree.ElementTree as ET

DEVICE = "CPU"  # <-- change to "GPU" for the second run

def _parse_xml(model_path: str) -> ET.Element:
    tree = ET.parse(model_path)
    return tree.getroot()

def _get_input_shape_from_xml(xml_root: ET.Element) -> str:
    # Raw "shape" attribute of the first <layer>/<data> element (normally the input Parameter), e.g. "?,192,192,1"
    data = xml_root.find(".//layer/data")
    input_shape = data.get("shape") if data is not None else ""
    return input_shape

model_path = os.path.join("vino_bug_2024_inference/vino_model_input_conv_flatten", "saved_model.xml")
ov_core = Core()
print("Available devices:", ov_core.available_devices)
model = ov_core.read_model(model_path)
root_vino = _parse_xml(model_path)

input_shape = _get_input_shape_from_xml(root_vino)
input_port = next(iter(model.inputs))
model.reshape({input_port: input_shape})
compiled_model = ov_core.compile_model(model, DEVICE)

raw_image = "vino_bug_2024_inference/raw_image.txt"
# 192x192 grayscale image -> NHWC tensor of shape (1, 192, 192, 1); np.loadtxt returns float64
input_data = np.loadtxt(raw_image, comments="#", delimiter=",", unpack=False)
input_data = np.expand_dims(input_data, axis=2)
input_data = np.expand_dims(input_data, axis=0)
print("Input shape: ", input_data.shape)

infer_request = compiled_model.create_infer_request()
infer_request.infer({input_port.get_any_name(): input_data})
output = infer_request.get_output_tensor(0)
print("Output: \n", output.data[0])

# One file per device so the CPU and GPU runs do not overwrite each other
np.savetxt(f"vino_bug_2024_inference/input_conv_flatten_model_result_raw{DEVICE.lower()}.txt", output.data[0], delimiter=",", fmt="%s")

And I got these results on CPU and GPU.

In summary, a single convolution layer gives completely different results on CPU and GPU. My example is a classification problem, but the difference between CPU and GPU results is even more noticeable in tasks such as detection.

We also tested inference with 2023.0.0 and 2023.3.0. With those versions, the CPU and GPU outputs differ only by an epsilon.
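The line between "only an epsilon" and "completely different" can be drawn with an element-wise tolerance check, using the same criterion as `numpy.allclose` (|a − b| ≤ atol + rtol · |b|). Below is a pure-Python sketch; the default `rtol`/`atol` values are illustrative assumptions (looser than NumPy's defaults, since two devices are not expected to be bit-identical), not thresholds defined by OpenVINO. As in `numpy.allclose` without `equal_nan`, NaN is counted as a mismatch.

```python
from typing import Sequence

def compare_outputs(a: Sequence[float], b: Sequence[float],
                    rtol: float = 1e-3, atol: float = 1e-4) -> dict:
    """Count elements violating |a - b| <= atol + rtol * |b| and locate the worst one.

    worst_index is -1 when every element is exactly equal.
    """
    if len(a) != len(b):
        raise ValueError(f"length mismatch: {len(a)} vs {len(b)}")
    mismatches = 0
    worst_i, worst = -1, 0.0
    for i, (x, y) in enumerate(zip(a, b)):
        d = abs(x - y)
        # "not <=" so that NaN differences are counted as mismatches
        if not d <= atol + rtol * abs(y):
            mismatches += 1
        if d > worst:
            worst_i, worst = i, d
    return {"n": len(a), "mismatches": mismatches,
            "max_abs_diff": worst, "worst_index": worst_i}
```

Applied to the two flattened convolution outputs, an epsilon-level difference would give zero (or very few) mismatches, while the behaviour reported for 2024.* would show up as a large mismatch count and a large `max_abs_diff`.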

Relevant log output

No response

Issue submission checklist

  • I'm reporting an issue. It's not a question.
  • I checked the problem with the documentation, FAQ, open issues, Stack Overflow, etc., and have not found a solution.
  • There is reproducer code and related data files such as images, videos, models, etc.
@lebence lebence added bug Something isn't working support_request labels May 30, 2024
@andrei-kochin andrei-kochin added the category: GPU OpenVINO GPU plugin label May 31, 2024
lebence commented Jun 18, 2024

We checked the newest release (2024.2.0) and the problem still exists. Were you able to reproduce the bug? @p-durandin
