Question about create_compressed_model(): RuntimeError: CUDA error: device-side assert triggered #2688

zbnlala opened this issue May 21, 2024 · 3 comments
Labels
bug Something isn't working


@zbnlala

zbnlala commented May 21, 2024

🐛 Describe the bug

Hi, I have a question about quantizing the non-neural-network parts of a neural network.
I want to quantize a PointNet-based network for a classification task.

The other parts of the network are just some conv1d and ReLU layers. I modified the training script to use a different dataset and evaluation method, but then the error below occurred. Can you help me fix this problem?
This is the error output:

Traceback (most recent call last):
  File "/home/lalala/examples/torch/classification/quant_train_cls.py", line 874, in <module>
    main(sys.argv[1:])
  File "/home/lalala/examples/torch/classification/quant_train_cls.py", line 151, in main
    start_worker(main_worker, config)
  File "/home/lalala/examples/torch/common/execution.py", line 114, in start_worker
    mp.spawn(main_worker, nprocs=config.ngpus_per_node, args=(config,))
  File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 241, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
  File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 197, in start_processes
    while not context.join():
  File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 158, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 68, in _wrap
    fn(i, *args)
  File "/home/lalala/examples/torch/classification/quant_train_cls.py", line 267, in main_worker
    compression_ctrl, model = create_compressed_model(
  File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/telemetry/decorator.py", line 72, in wrapped
    retval = fn(*args, **kwargs)
  File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/model_creation.py", line 134, in create_compressed_model
    compressed_model = builder.apply_to(nncf_network)
  File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/compression_method_api.py", line 124, in apply_to
    transformed_model = transformer.transform(transformation_layout)
  File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/model_transformer.py", line 78, in transform
    model.nncf.rebuild_graph()
  File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/nncf_network.py", line 555, in rebuild_graph
    compressed_traced_graph = builder.build_dynamic_graph(
  File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/graph/graph_builder.py", line 53, in build_dynamic_graph
    return tracer.trace_graph(model, context_to_use, as_eval, trace_parameters)
  File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/dynamic_graph/graph_tracer.py", line 53, in trace_graph
    self.custom_forward_fn(model)
  File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/dynamic_graph/graph_tracer.py", line 96, in default_dummy_forward_fn
    retval = model(*args, **kwargs)
  File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/nncf_network.py", line 1004, in __call__
    return ORIGINAL_CALL(self, *args, **kwargs)
  File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/nncf_network.py", line 1036, in forward
    retval = wrap_module_call(self.nncf._original_unbound_forward)(self, *args, **kwargs)
  File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/dynamic_graph/wrappers.py", line 154, in wrapped
    retval = module_call(self, *args, **kwargs)
  File "/home/lalala/examples/torch/classification/models/pointnet2_cls.py", line 39, in forward
    l1_xyz, l1_points = self.sa1(xyz, norm)
  File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/dynamic_graph/wrappers.py", line 154, in wrapped
    retval = module_call(self, *args, **kwargs)
  File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lalala/examples/torch/classification/models/pointnet2_utils.py", line 197, in forward
    new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
  File "/home/lalala/examples/torch/classification/models/pointnet2_utils.py", line 131, in sample_and_group
    grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
  File "/home/lalala/examples/torch/classification/models/pointnet2_utils.py", line 59, in index_points
    new_points = points[batch_indices, idx, :]
  File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/dynamic_graph/wrappers.py", line 98, in wrapped
    result = _execute_op(op_address, operator_info, operator, ctx, *args, **kwargs)
  File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/dynamic_graph/wrappers.py", line 179, in _execute_op
    result = operator(*args, **kwargs)
RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Environment

nncf==2.10.0
pytorch==2.2.2
python==3.9

Minimal Reproducible Example

The non-neural-network parts of the model are as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F


def square_distance(src, dst):
    """
    Calculate the squared Euclidean distance between every pair of points.
    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    # ||src - dst||^2 = ||src||^2 + ||dst||^2 - 2 * src . dst
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist
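square_distance relies on the expansion |a - b|^2 = |a|^2 + |b|^2 - 2 a·b, which lets the whole [N, M] distance matrix come from one matmul plus two broadcast sums. A minimal pure-Python sketch of the same expansion for a single batch element (no torch; the helper names are illustrative only), checked against the direct formula:

```python
from typing import List

Point = List[float]


def square_distance_expanded(src: List[Point], dst: List[Point]) -> List[List[float]]:
    """Squared distances via |a|^2 + |b|^2 - 2 a.b, mirroring square_distance for one batch."""

    def dot(a: Point, b: Point) -> float:
        return sum(x * y for x, y in zip(a, b))

    src_sq = [dot(p, p) for p in src]  # like torch.sum(src ** 2, -1)
    dst_sq = [dot(q, q) for q in dst]  # like torch.sum(dst ** 2, -1)
    return [
        [-2 * dot(p, q) + src_sq[i] + dst_sq[j] for j, q in enumerate(dst)]
        for i, p in enumerate(src)
    ]


def square_distance_direct(src: List[Point], dst: List[Point]) -> List[List[float]]:
    """Reference: sum of squared coordinate differences."""
    return [[sum((x - y) ** 2 for x, y in zip(p, q)) for q in dst] for p in src]


src = [[0.0, 0.0, 0.0], [1.0, 2.0, 2.0]]
dst = [[1.0, 0.0, 0.0], [1.0, 2.0, 2.0], [3.0, 4.0, 0.0]]
expanded = square_distance_expanded(src, dst)
direct = square_distance_direct(src, dst)
```

Because the expanded form subtracts nearly equal terms, in floating point it can return small negative values or small errors for almost coincident points, whereas the direct formula never goes below zero.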

 
def index_points(points, idx):
    """
    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (or [B, S, K])
    Return:
        new_points: indexed points data, [B, S, C] (or [B, S, K, C])
    """
    device = points.device
    B = points.shape[0]
    # Build batch indices that broadcast against idx, e.g. [B, 1] repeated to [B, S]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
    # Advanced indexing: every value in idx must be a valid index into dimension N
    new_points = points[batch_indices, idx, :]
    return new_points
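`points[batch_indices, idx, :]` is a batched gather: output element [b, s] is `points[b][idx[b][s]]`. On CPU, PyTorch reports an out-of-range index as an `IndexError`; on CUDA the same condition is detected inside the kernel and typically surfaces as `RuntimeError: CUDA error: device-side assert triggered`. A minimal pure-Python sketch of the gather for a 2-D idx (the helper `gather_points` is hypothetical, no torch) with an explicit bounds check, which turns a bad index into a readable error:

```python
from typing import List, Sequence

Points = List[List[float]]  # [N, C] for one batch element


def gather_points(points: Sequence[Points], idx: Sequence[Sequence[int]]) -> List[Points]:
    """Pure-Python analogue of points[batch_indices, idx, :] for idx of shape [B, S]."""
    out = []
    for b, (cloud, row) in enumerate(zip(points, idx)):
        n = len(cloud)
        for s, i in enumerate(row):
            # Reject anything outside [0, N); PyTorch would also accept negatives down to -N
            if not 0 <= i < n:
                raise IndexError(f"idx[{b}][{s}] = {i} is out of range for N = {n}")
        out.append([cloud[i] for i in row])
    return out


points = [[[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]]  # B = 1, N = 3, C = 2
picked = gather_points(points, [[2, 0]])
```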

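def farthest_point_sample(xyz, npoint):
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 1e10
    # Assumed initialization: start from a random point in each batch element
    # (`farthest` is used below but never defined in the snippet as pasted)
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        # Keep, for each point, its distance to the nearest selected centroid
        mask = dist < distance
        distance[mask] = dist[mask]
        # Next centroid: the point farthest from all selected centroids
        farthest = torch.max(distance, -1)[1]
    return centroids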
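Farthest point sampling greedily picks, at each step, the point whose distance to its nearest already-chosen centroid is largest; `distance` holds that running minimum. A minimal pure-Python sketch for one batch element (the helper `farthest_point_sample_single` is hypothetical and takes the start index as an argument instead of drawing it at random, so the result is deterministic):

```python
from typing import List

Point = List[float]


def farthest_point_sample_single(xyz: List[Point], npoint: int, start: int = 0) -> List[int]:
    """Greedy FPS over one point cloud, mirroring the loop in farthest_point_sample."""
    n = len(xyz)
    distance = [1e10] * n  # running min squared distance to the chosen centroids
    farthest = start
    centroids = []
    for _ in range(npoint):
        centroids.append(farthest)
        cx = xyz[farthest]
        for j, p in enumerate(xyz):
            d = sum((a - b) ** 2 for a, b in zip(p, cx))
            if d < distance[j]:
                distance[j] = d
        # argmax of the running minimum; Python's max() resolves ties to the first index
        farthest = max(range(n), key=lambda j: distance[j])
    return centroids


line = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [10.0, 0.0, 0.0]]
picked_idx = farthest_point_sample_single(line, 3)
```

Starting from index 0, the far outlier at x = 10 is chosen next, then the point at x = 2, which is farthest from both.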


def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    # Mark points outside the ball with the out-of-range sentinel N
    group_idx[sqrdists > radius ** 2] = N
    # Sorting moves in-ball indices to the front; keep the first nsample
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    # Pad sentinel slots with the first index of each group
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx
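query_ball_point marks every point outside the radius with the sentinel `N`, sorts so the in-ball indices come first, keeps `nsample` of them, and pads the remaining slots with the first one. A minimal pure-Python sketch for a single query point (the helper `query_ball_single` is hypothetical, no torch). It also makes one property of this scheme visible: if no point lies within the radius, the "first" index is itself the sentinel `N`, one past the last valid index, so a later `index_points` call with these indices would go out of range. In the real pipeline each query point is one of the input points, so this normally should not happen; the sketch only shows what the indexing would receive if it did.

```python
from typing import List

Point = List[float]


def query_ball_single(radius: float, nsample: int, xyz: List[Point], query: Point) -> List[int]:
    """Ball query for one query point, mirroring query_ball_point's sentinel-and-sort scheme."""
    n = len(xyz)
    sqrdists = [sum((a - b) ** 2 for a, b in zip(p, query)) for p in xyz]
    # Out-of-ball points get the sentinel n, as in group_idx[sqrdists > radius ** 2] = N
    group_idx = [j if d <= radius ** 2 else n for j, d in enumerate(sqrdists)]
    group_idx = sorted(group_idx)[:nsample]
    # Pad sentinel slots with the first entry (which is still n if the ball is empty)
    first = group_idx[0]
    return [first if j == n else j for j in group_idx]


cloud = [[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [3.0, 0.0, 0.0]]
near = query_ball_single(1.0, 3, cloud, [0.0, 0.0, 0.0])    # two points in the ball
empty = query_ball_single(1.0, 3, cloud, [10.0, 0.0, 0.0])  # no point in the ball
```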


def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
    Input:
        npoint: number of centroids to sample with farthest point sampling
        radius: ball query radius
        nsample: max sample number in each local region
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, npoint, 3]
        new_points: sampled points data, [B, npoint, nsample, 3+D]
    """
    B, N, C = xyz.shape
    S = npoint
    fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
    new_xyz = index_points(xyz, fps_idx)  # [B, npoint, 3]
    idx = query_ball_point(radius, nsample, xyz, new_xyz)  # [B, npoint, nsample]
    grouped_xyz = index_points(xyz, idx)  # [B, npoint, nsample, C]
    # Express each local region relative to its centroid
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)

    if points is not None:
        grouped_points = index_points(points, idx)
        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)  # [B, npoint, nsample, C+D]
    else:
        new_points = grouped_xyz_norm
    if returnfps:
        return new_xyz, new_points, grouped_xyz, fps_idx
    else:
        return new_xyz, new_points
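After grouping, each local region is expressed relative to its centroid (`grouped_xyz - new_xyz.view(B, S, 1, C)`), and when point features exist they are appended along the last axis, giving C + D channels per grouped point. A minimal pure-Python sketch of that step for one group (the helper `center_and_concat` is hypothetical, no torch):

```python
from typing import List, Optional


def center_and_concat(
    grouped_xyz: List[List[float]],
    centroid: List[float],
    grouped_points: Optional[List[List[float]]] = None,
) -> List[List[float]]:
    """One [nsample, C] group -> [nsample, C] or [nsample, C + D] rows."""
    # Relative coordinates: subtract the centroid from every grouped point
    rel = [[a - c for a, c in zip(p, centroid)] for p in grouped_xyz]
    if grouped_points is None:
        return rel
    # Like torch.cat([...], dim=-1): xyz channels first, then the D feature channels
    return [r + f for r, f in zip(rel, grouped_points)]


out = center_and_concat([[1.0, 1.0, 1.0], [2.0, 1.0, 1.0]], [1.0, 1.0, 1.0], [[0.5], [0.7]])
```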


def sample_and_group_all(xyz, points):
    """
    Input:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, 1, 3]
        new_points: sampled points data, [B, 1, N, 3+D]
    """
    device = xyz.device
    B, N, C = xyz.shape
    # A single group containing all N points, with its "centroid" at the origin
    new_xyz = torch.zeros(B, 1, C).to(device)
    grouped_xyz = xyz.view(B, 1, N, C)
    if points is not None:
        new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
    else:
        new_points = grouped_xyz
    return new_xyz, new_points


class PointNetSetAbstraction(nn.Module):
    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel
        self.group_all = group_all

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        # Switch to channels-last layout for the grouping helpers
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        if self.group_all:
            new_xyz, new_points = sample_and_group_all(xyz, points)
        else:
            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # new_xyz: sampled points position data, [B, npoint, C]
        # new_points: sampled points data, [B, npoint, nsample, C+D]

        new_points = new_points.permute(0, 3, 2, 1)  # [B, C+D, nsample, npoint]
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))

        new_points = torch.max(new_points, 2)[0]  # max over the nsample dimension
        new_xyz = new_xyz.permute(0, 2, 1)
        return new_xyz, new_points
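In `forward`, the grouped tensor is permuted to [B, C+D, nsample, npoint] so that the 1×1 `Conv2d`/`BatchNorm2d`/ReLU stack acts as the same per-point MLP at every (nsample, npoint) position, and `torch.max(new_points, 2)[0]` then max-pools over the `nsample` axis (`torch.max` with a `dim` argument returns values and indices; `[0]` keeps the values). A minimal pure-Python sketch of that pooling step for one batch element (the helper `max_over_nsample` is hypothetical, no torch):

```python
from typing import List


def max_over_nsample(features: List[List[List[float]]]) -> List[List[float]]:
    """[channels, nsample, npoint] -> [channels, npoint] by taking the max over nsample."""
    pooled = []
    for channel in features:  # channel: [nsample][npoint]
        npoint = len(channel[0])
        pooled.append([max(row[s] for row in channel) for s in range(npoint)])
    return pooled


# One channel, nsample = 2, npoint = 2
pooled = max_over_nsample([[[1.0, 5.0], [3.0, 2.0]]])
```

Each centroid keeps, per channel, the strongest response among its grouped neighbors, which makes the result independent of the order in which the neighbors were grouped.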
zbnlala added the bug (Something isn't working) label on May 21, 2024
@AlexanderDokuchaev
Collaborator

Hi @zbnlala,
Could you please provide a full script to reproduce the issue?

@zbnlala
Author

zbnlala commented May 22, 2024

Hi @AlexanderDokuchaev, sure!
First, my code is a modified version of examples/torch/classification, and I put train.py into that directory.
Note that in train.py you only need to focus on the function main_worker().
Moreover, the modified dataset code is ModelNetDataLoader.py, placed in examples/torch/classification/data_utils (a directory I created myself). For now the code has to run on the ModelNet dataset, because I do not know how to run create_compressed_model() without the data loader. Instructions for getting the dataset are here.

The model architecture is defined in pointnet2_cls.py and pointnet2_utils.py, placed in examples/torch/classification/models.
The config file is pointnet_v2_classification_int8.json, placed in examples/torch/classification/configs/quantization.

Finally, running the following command triggers the bug:
CUDA_VISIBLE_DEVICES="0,1" NCCL_P2P_DISABLE=1 python train.py -m test --config configs/quantization/pointnet_v2_classification_int8.json --log-dir=../../results/quantization/pointnet_v2_int8/ --multiprocessing-distributed

However, if I run the script with the extra --cpu-only flag, the bug does not occur. (CUDA version 12.3)
train.py

# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#      http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
# Make CUDA kernel launches synchronous so errors are reported at the failing call
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import os.path as osp
import sys
import time
import warnings
from copy import deepcopy
from functools import partial
from pathlib import Path
from shutil import copyfile
from typing import Any
from tqdm import tqdm
import numpy as np

import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from torch import nn
from torch.backends import cudnn
from torch.cuda.amp.autocast_mode import autocast
from torch.nn.modules.loss import _Loss
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import datasets
from torchvision import models
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.datasets import CIFAR100
from torchvision.models import InceptionOutputs
from examples.common.paths import configure_paths
from examples.common.sample_config import SampleConfig
from examples.common.sample_config import create_sample_config
from examples.torch.common.argparser import get_common_argument_parser
from examples.torch.common.argparser import parse_args
from examples.torch.common.example_logger import logger
from examples.torch.common.execution import ExecutionMode
from examples.torch.common.execution import get_execution_mode
from examples.torch.common.execution import prepare_model_for_execution
from examples.torch.common.execution import set_seed
from examples.torch.common.execution import start_worker
from examples.torch.common.export import export_model
from examples.torch.common.model_loader import COMPRESSION_STATE_ATTR
from examples.torch.common.model_loader import MODEL_STATE_ATTR
from examples.torch.common.model_loader import extract_model_and_compression_states
from examples.torch.common.model_loader import load_model
from examples.torch.common.model_loader import load_resuming_checkpoint
from examples.torch.common.optimizer import get_parameter_groups
from examples.torch.common.optimizer import make_optimizer
from examples.torch.common.utils import MockDataset
from examples.torch.common.utils import NullContextManager
from examples.torch.common.utils import configure_device
from examples.torch.common.utils import configure_logging
from examples.torch.common.utils import create_code_snapshot
from examples.torch.common.utils import get_run_name
from examples.torch.common.utils import is_pretrained_model_requested
from examples.torch.common.utils import is_staged_quantization
from examples.torch.common.utils import make_additional_checkpoints
from examples.torch.common.utils import print_args
from examples.torch.common.utils import write_metrics
from nncf.api.compression import CompressionStage
from nncf.common.accuracy_aware_training import create_accuracy_aware_training_loop
from nncf.common.utils.tensorboard import prepare_for_tensorboard
from nncf.config.utils import is_accuracy_aware_training
from nncf.torch import create_compressed_model
from nncf.torch.checkpoint_loading import load_state
from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
from nncf.torch.initialization import default_criterion_fn
from nncf.torch.initialization import register_default_init_args
from nncf.torch.structures import ExecutionParameters
from nncf.torch.utils import is_main_process
from nncf.torch.utils import safe_thread_call

from examples.torch.classification.models.pointnet2_cls import pointnetv2_cls
from examples.torch.common import restricted_pickle_module

from data_utils.ModelNetDataLoader import ModelNetDataLoader

model_names = sorted(
    name
    for name, val in models.__dict__.items()
    if name.islower() and not name.startswith("__") and callable(val)
)


def get_argument_parser():
    parser = get_common_argument_parser()
    parser.add_argument(
        "--dataset",
        help="Dataset to use.",
        choices=["imagenet", "cifar100", "cifar10"],
        default=None,
    )
    parser.add_argument(
        "--local_rank",
        default=None,
    )
    parser.add_argument(
        "--test-every-n-epochs",
        default=1,
        type=int,
        help="Enables running validation every given number of epochs",
    )
    parser.add_argument(
        "--mixed-precision",
        dest="mixed_precision",
        help="Enables torch.cuda.amp autocasting during training and validation steps",
        action="store_true",
    )
    return parser


def main(argv):
    parser = get_argument_parser()
    args = parse_args(parser, argv)
    config = create_sample_config(args, parser)

    if config.dist_url == "env://":
        config.update_from_env()

    configure_paths(config, get_run_name(config))
    copyfile(args.config, osp.join(config.log_dir, "config.json"))
    source_root = Path(__file__).absolute().parents[2]  # nncf root
    create_code_snapshot(source_root, osp.join(config.log_dir, "snapshot.tar.gz"))

    if config.seed is not None:
        warnings.warn(
            "You have chosen to seed training. "
            "This will turn on the CUDNN deterministic setting, "
            "which can slow down your training considerably! "
            "You may see unexpected behavior when restarting "
            "from checkpoints."
        )

    config.execution_mode = get_execution_mode(config)

    if config.metrics_dump is not None:
        write_metrics(0, config.metrics_dump)

    if not is_staged_quantization(config):
        start_worker(main_worker, config)
    else:
        from examples.torch.classification.staged_quantization_worker import (
            staged_quantization_main_worker,
        )

        start_worker(staged_quantization_main_worker, config)


def inception_criterion_fn(
    model_outputs: Any, target: Any, criterion: _Loss
) -> torch.Tensor:
    # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
    output, aux_outputs = model_outputs
    loss1 = criterion(output, target)
    loss2 = criterion(aux_outputs, target)
    return loss1 + 0.4 * loss2


def main_worker(current_gpu, config: SampleConfig):
    configure_device(current_gpu, config)
    if is_main_process():
        configure_logging(logger, config)
        print_args(config)
    else:
        config.tb = None

    set_seed(config)

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(config.device)

    model_name = config["model"]
    train_criterion_fn = (
        inception_criterion_fn if "inception" in model_name else default_criterion_fn
    )

    train_loader = train_sampler = val_loader = None
    resuming_checkpoint_path = config.resuming_checkpoint_path
    nncf_config = config.nncf_config
    pretrained = is_pretrained_model_requested(config)
    is_export_only = "export" in config.mode and (
        "train" not in config.mode and "test" not in config.mode
    )

    if is_export_only:
        assert pretrained or (resuming_checkpoint_path is not None)
    else:
        # Data loading code
        # train_dataset, val_dataset = create_datasets(config)
        train_dataset, val_dataset = pointnet_dataset(config)
        train_loader, train_sampler, val_loader, init_loader = create_data_loaders(
            config, train_dataset, val_dataset
        )

        def train_steps_fn(loader, model, optimizer, compression_ctrl, train_steps):
            train_epoch(
                loader,
                model,
                criterion,
                train_criterion_fn,
                optimizer,
                compression_ctrl,
                0,
                config,
                train_iters=train_steps,
                log_training_info=False,
            )

        def validate_model_fn(model, eval_loader):
            instance_acc,class_acc  = validate(
                eval_loader, model, criterion, config
            )
            return instance_acc,class_acc

        def model_eval_fn(model):
            acc,_ = validate(val_loader, model, criterion, config)
            return acc

        execution_params = ExecutionParameters(config.cpu_only, config.current_gpu)

        nncf_config = register_default_init_args(
            nncf_config,
            init_loader,
            criterion=criterion,
            criterion_fn=train_criterion_fn,
            train_steps_fn=train_steps_fn,
            validate_fn=lambda *x: validate_model_fn(*x)[::2],
            autoq_eval_fn=lambda *x: validate_model_fn(*x)[1],
            val_loader=val_loader,
            model_eval_fn=model_eval_fn,
            device=config.device,
            execution_parameters=execution_params,
        )

    # create model
    num_classes = config.get("num_classes", 1000)
    model_params = config.get("model_params")
    weights_path = config.get("weights")
    load_model_fn = partial(
        pointnetv2_cls, num_class=num_classes, pretrained=pretrained, load_path=weights_path
    )
    model = safe_thread_call(load_model_fn)
    if not pretrained and weights_path is not None:
        # Load the checkpoint from a local path (no URL download happens here)
        sd = torch.load(weights_path, map_location="cpu", pickle_module=restricted_pickle_module)
        sd = sd["model_state_dict"]
        if MODEL_STATE_ATTR in sd:
            sd = sd[MODEL_STATE_ATTR]
        load_state(model, sd, is_resume=False)

    model.to(config.device)

    if "train" in config.mode and is_accuracy_aware_training(config):
        uncompressed_model_accuracy = model_eval_fn(model)

    resuming_checkpoint = None
    if resuming_checkpoint_path is not None:
        resuming_checkpoint = load_resuming_checkpoint(resuming_checkpoint_path)
    model_state_dict, compression_state = extract_model_and_compression_states(
        resuming_checkpoint
    )
    compression_ctrl, model = create_compressed_model(
        model, nncf_config, compression_state
    )
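The `validate_fn` and `autoq_eval_fn` lambdas in `main_worker` slice the tuple returned by `validate_model_fn`. Since that function here returns two values, `[::2]` yields a one-element tuple and `[1]` picks the class accuracy. A stand-in sketch with made-up accuracy values (the stub name `validate_model_fn_stub` is hypothetical) makes the slicing explicit:

```python
from typing import Tuple


def validate_model_fn_stub() -> Tuple[float, float]:
    """Stand-in for validate_model_fn: returns (instance_acc, class_acc)."""
    return 0.9, 0.8  # illustrative values only


result = validate_model_fn_stub()
validate_view = result[::2]  # what the validate_fn lambda passes on
autoq_view = result[1]       # what the autoq_eval_fn lambda passes on
```

Whether a one-element tuple is what NNCF expects from `validate_fn` is not stated in this issue; it is worth checking against the NNCF documentation for `register_default_init_args`.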


def train(
    config,
    compression_ctrl,
    model,
    criterion,
    criterion_fn,
    lr_scheduler,
    model_name,
    optimizer,
    train_loader,
    train_sampler,
    val_loader,
    best_acc1=0,
):
    best_compression_stage = CompressionStage.UNCOMPRESSED
    # Track the best accuracies across all epochs (not reset every epoch)
    best_instance_acc = 0.0
    best_class_acc = 0.0

    for epoch in range(config.start_epoch, config.epochs):
        # update compression scheduler state at the beginning of the epoch
        compression_ctrl.scheduler.epoch_step()

        if config.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train_epoch(
            train_loader,
            model,
            criterion,
            criterion_fn,
            optimizer,
            compression_ctrl,
            epoch,
            config,
        )

        # Learning rate scheduling should be applied after optimizer’s update
        lr_scheduler.step(
            epoch if not isinstance(lr_scheduler, ReduceLROnPlateau) else best_acc1
        )

        # compute compression algo statistics
        statistics = compression_ctrl.statistics()

        acc1 = best_acc1

        if epoch % config.test_every_n_epochs == 0:
            # evaluate on validation set; compare and log only when validation actually ran
            instance_acc, class_acc = validate(val_loader, model, criterion, config, epoch=epoch)
            if instance_acc >= best_instance_acc:
                best_instance_acc = instance_acc
                best_epoch = epoch + 1
            if class_acc >= best_class_acc:
                best_class_acc = class_acc
            logger.info('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))
            logger.info('Best Instance Accuracy: %f, Class Accuracy: %f' % (best_instance_acc, best_class_acc))
        compression_stage = compression_ctrl.compression_stage()
        # remember best acc@1, considering compression stage. If current acc@1 is less than the best acc@1, checkpoint
        # still can be best if current compression stage is larger than the best one. Compression stages in ascending
        # order: UNCOMPRESSED, PARTIALLY_COMPRESSED, FULLY_COMPRESSED.
        is_best_by_accuracy = (
            acc1 > best_acc1 and compression_stage == best_compression_stage
        )
        is_best = is_best_by_accuracy or compression_stage > best_compression_stage
        if is_best:
            best_acc1 = acc1
        best_compression_stage = max(compression_stage, best_compression_stage)
        if is_main_process():
            logger.info(statistics.to_str())

            if config.metrics_dump is not None:
                acc = best_acc1 / 100
                write_metrics(acc, config.metrics_dump)

            checkpoint_path = osp.join(
                config.checkpoint_save_dir, get_run_name(config) + "_last.pth"
            )
            checkpoint = {
                "epoch": epoch + 1,
                "arch": model_name,
                MODEL_STATE_ATTR: model.state_dict(),
                COMPRESSION_STATE_ATTR: compression_ctrl.get_compression_state(),
                "best_acc1": best_acc1,
                "acc1": acc1,
                "optimizer": optimizer.state_dict(),
            }

            torch.save(checkpoint, checkpoint_path)
            make_additional_checkpoints(checkpoint_path, is_best, epoch + 1, config)

            for key, value in prepare_for_tensorboard(statistics).items():
                config.tb.add_scalar(
                    "compression/statistics/{0}".format(key),
                    value,
                    len(train_loader) * epoch,
                )


def get_dataset(dataset_config, config, transform, is_train):
    if dataset_config == "imagenet":
        prefix = "train" if is_train else "val"
        return datasets.ImageFolder(osp.join(config.dataset_dir, prefix), transform)
    # For testing purposes
    num_images = config.get("num_mock_images", 1000)
    if dataset_config == "mock_32x32":
        return MockDataset(
            img_size=(32, 32), transform=transform, num_images=num_images
        )
    if dataset_config == "mock_299x299":
        return MockDataset(
            img_size=(299, 299), transform=transform, num_images=num_images
        )
    return create_cifar(config, dataset_config, is_train, transform)


def create_cifar(config, dataset_config, is_train, transform):
    create_cifar_fn = None
    if dataset_config in ["cifar100", "cifar100_224x224"]:
        create_cifar_fn = partial(
            CIFAR100, config.dataset_dir, train=is_train, transform=transform
        )
    if dataset_config == "cifar10":
        create_cifar_fn = partial(
            CIFAR10, config.dataset_dir, train=is_train, transform=transform
        )
    if create_cifar_fn:
        return safe_thread_call(
            partial(create_cifar_fn, download=True),
            partial(create_cifar_fn, download=False),
        )
    return None


def create_datasets(config):
    dataset_config = config.dataset if config.dataset is not None else "imagenet"
    dataset_config = dataset_config.lower()
    assert dataset_config in [
        "imagenet",
        "cifar100",
        "cifar10",
        "cifar100_224x224",
        "mock_32x32",
        "mock_299x299",
    ], "Unknown dataset option"

    if dataset_config == "imagenet":
        normalize = transforms.Normalize(
            mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
        )
    elif dataset_config in ["cifar100", "cifar100_224x224"]:
        normalize = transforms.Normalize(
            mean=(0.5071, 0.4865, 0.4409), std=(0.2673, 0.2564, 0.2761)
        )
    elif dataset_config == "cifar10":
        normalize = transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)
        )
    elif dataset_config in ["mock_32x32", "mock_299x299"]:
        normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

    input_info = FillerInputInfo.from_nncf_config(config)
    image_size = input_info.elements[0].shape[-1]
    size = int(image_size / 0.875)
    if dataset_config in ["cifar10", "cifar100_224x224", "cifar100"]:
        list_val_transforms = [transforms.ToTensor(), normalize]
        if dataset_config == "cifar100_224x224":
            list_val_transforms.insert(0, transforms.Resize(image_size))
        val_transform = transforms.Compose(list_val_transforms)

        list_train_transforms = [
            transforms.RandomCrop(image_size, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
        if dataset_config == "cifar100_224x224":
            list_train_transforms.insert(0, transforms.Resize(image_size))
        train_transforms = transforms.Compose(list_train_transforms)
    elif dataset_config in ["mock_32x32", "mock_299x299"]:
        val_transform = transforms.Compose(
            [
                transforms.Resize(size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                normalize,
            ]
        )
        train_transforms = transforms.Compose(
            [
                transforms.Resize(size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                normalize,
            ]
        )
    else:
        val_transform = transforms.Compose(
            [
                transforms.Resize(size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                normalize,
            ]
        )
        train_transforms = transforms.Compose(
            [
                transforms.RandomResizedCrop(image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        )

    val_dataset = get_dataset(dataset_config, config, val_transform, is_train=False)
    train_dataset = get_dataset(dataset_config, config, train_transforms, is_train=True)

    return train_dataset, val_dataset

def pointnet_dataset(config):
    data_path = config.get("dataset_dir")
    # The train and test splits share every option except `split`
    loader_kwargs = dict(
        process_data=config.get("process_data"),
        num_point=config.get("num_point"),
        use_uniform_sample=config.get("use_uniform_sample"),
        use_normals=config.get("use_normals"),
        num_category=config.get("num_classes"),
    )
    train_dataset = ModelNetDataLoader(root=data_path, args=None, split='train', **loader_kwargs)
    val_dataset = ModelNetDataLoader(root=data_path, args=None, split='test', **loader_kwargs)
    return train_dataset, val_dataset

def create_data_loaders(config, train_dataset, val_dataset):
    pin_memory = config.execution_mode != ExecutionMode.CPU_ONLY
    # When using a single GPU per process and per
    # DistributedDataParallel, we need to divide the batch size
    # ourselves based on the total number of GPUs we have
    batch_size = int(config.batch_size)
    workers = int(config.workers)
    batch_size_val = (
        int(config.batch_size_val)
        if config.batch_size_val is not None
        else int(config.batch_size)
    )
    if config.execution_mode == ExecutionMode.MULTIPROCESSING_DISTRIBUTED:
        batch_size //= config.ngpus_per_node
        batch_size_val //= config.ngpus_per_node
        workers //= config.ngpus_per_node

    val_sampler = torch.utils.data.SequentialSampler(val_dataset)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size_val,
        shuffle=False,
        num_workers=workers,
        pin_memory=pin_memory,
        sampler=val_sampler,
        drop_last=False,
    )

    train_sampler = None
    if config.distributed:
        sampler_seed = 0 if config.seed is None else config.seed
        dist_sampler_shuffle = config.seed is None
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, seed=sampler_seed, shuffle=dist_sampler_shuffle
        )

    train_shuffle = train_sampler is None and config.seed is None

    def create_train_data_loader(batch_size_):
        return torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size_,
            shuffle=train_shuffle,
            num_workers=workers,
            pin_memory=pin_memory,
            sampler=train_sampler,
            drop_last=True,
        )

    train_loader = create_train_data_loader(batch_size)

    if config.batch_size_init:
        init_loader = create_train_data_loader(config.batch_size_init)
    else:
        init_loader = deepcopy(train_loader)
    return train_loader, train_sampler, val_loader, init_loader


def train_epoch(
    train_loader,
    model,
    criterion,
    criterion_fn,
    optimizer,
    compression_ctrl,
    epoch,
    config,
    train_iters=None,
    log_training_info=True,
):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    compression_losses = AverageMeter()
    criterion_losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    if train_iters is None:
        train_iters = len(train_loader)

    compression_scheduler = compression_ctrl.scheduler
    casting = autocast if config.mixed_precision else NullContextManager

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input_, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        compression_scheduler.step()

        input_ = input_.to(config.device)
        target = target.to(config.device)

        # compute output
        with casting():
            output = model(input_)
            criterion_loss = criterion_fn(output, target, criterion)

            # compute compression loss
            compression_loss = compression_ctrl.loss()
            loss = criterion_loss + compression_loss

        if isinstance(output, InceptionOutputs):
            output = output.logits
        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input_.size(0))
        comp_loss_val = (
            compression_loss.item()
            if isinstance(compression_loss, torch.Tensor)
            else compression_loss
        )
        compression_losses.update(comp_loss_val, input_.size(0))
        criterion_losses.update(criterion_loss.item(), input_.size(0))
        top1.update(acc1, input_.size(0))
        top5.update(acc5, input_.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % config.print_freq == 0 and log_training_info:
            logger.info(
                "{rank}: "
                "Epoch: [{0}][{1}/{2}] "
                "Lr: {3:.3} "
                "Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) "
                "Data: {data_time.val:.3f} ({data_time.avg:.3f}) "
                "CE_loss: {ce_loss.val:.4f} ({ce_loss.avg:.4f}) "
                "CR_loss: {cr_loss.val:.4f} ({cr_loss.avg:.4f}) "
                "Loss: {loss.val:.4f} ({loss.avg:.4f}) "
                "Acc@1: {top1.val:.3f} ({top1.avg:.3f}) "
                "Acc@5: {top5.val:.3f} ({top5.avg:.3f})".format(
                    epoch,
                    i,
                    len(train_loader),
                    get_lr(optimizer),
                    batch_time=batch_time,
                    data_time=data_time,
                    ce_loss=criterion_losses,
                    cr_loss=compression_losses,
                    loss=losses,
                    top1=top1,
                    top5=top5,
                    rank=(
                        "{}:".format(config.rank)
                        if config.multiprocessing_distributed
                        else ""
                    ),
                )
            )

        if is_main_process() and log_training_info:
            global_step = train_iters * epoch
            config.tb.add_scalar(
                "train/learning_rate", get_lr(optimizer), i + global_step
            )
            config.tb.add_scalar(
                "train/criterion_loss", criterion_losses.val, i + global_step
            )
            config.tb.add_scalar(
                "train/compression_loss", compression_losses.val, i + global_step
            )
            config.tb.add_scalar("train/loss", losses.val, i + global_step)
            config.tb.add_scalar("train/top1", top1.val, i + global_step)
            config.tb.add_scalar("train/top5", top5.val, i + global_step)

            statistics = compression_ctrl.statistics(quickly_collected_only=True)
            for stat_name, stat_value in prepare_for_tensorboard(statistics).items():
                config.tb.add_scalar(
                    "train/statistics/{}".format(stat_name), stat_value, i + global_step
                )

        if i >= train_iters:
            break


def validate(val_loader, model, criterion, config, epoch=0, log_validation_info=True):
    mean_correct = []
    num_classes = config.get("num_classes")
    # Per class: [sum of per-batch accuracies, number of batches containing the class, mean accuracy]
    class_acc = np.zeros((num_classes, 3))
    # switch to evaluate mode
    model.eval()

    casting = autocast if config.mixed_precision else NullContextManager
    with torch.no_grad():
        for j, (points, target) in tqdm(enumerate(val_loader), total=len(val_loader)):
            points, target = points.to(config.device), target.to(config.device)
            with casting():
                pred, _ = model(points)  # the model returns (logits, l3_points)
            pred_choice = pred.data.max(1)[1]

            for cat in np.unique(target.cpu()):
                classacc = pred_choice[target == cat].eq(target[target == cat].long().data).cpu().sum()
                class_acc[cat, 0] += classacc.item() / float(points[target == cat].size()[0])
                class_acc[cat, 1] += 1

            correct = pred_choice.eq(target.long().data).cpu().sum()
            mean_correct.append(correct.item() / float(points.size()[0]))
    class_acc[:, 2] = class_acc[:, 0] / class_acc[:, 1]
    class_acc = np.mean(class_acc[:, 2])
    instance_acc = np.mean(mean_correct)
    return instance_acc, class_acc


class AverageMeter:
    """Computes and stores the average and current value"""

    def __init__(self):
        self.val = None
        self.avg = None
        self.sum = None
        self.count = None
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).sum(0, keepdim=True)
            res.append(correct_k.float().mul_(100.0 / batch_size).item())
        return res


def get_lr(optimizer):
    return optimizer.param_groups[0]["lr"]


if __name__ == "__main__":
    main(sys.argv[1:])
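In `train_epoch`, `output = model(input_)` is passed unchanged to `criterion_fn` and `accuracy`, whereas `get_model.forward` in `pointnet2_cls.py` returns a `(logits, l3_points)` tuple (which `validate` already unpacks with `pred, _ = model(points)`). `accuracy` calls `output.topk(...)`, which a tuple does not have. A minimal sketch of unwrapping the logits first, assuming the first tuple element is always the logits; `unwrap_logits` is a hypothetical helper and `accuracy` is copied from the script:

```python
import torch


def unwrap_logits(output):
    """Hypothetical helper: return the logits tensor from a model output.

    Assumes tuple/list outputs carry the logits first, as get_model.forward
    does with (x, l3_points).
    """
    if isinstance(output, (tuple, list)):
        return output[0]
    return output


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy in percent (same logic as accuracy() in the training script)."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)  # [B, maxk] class indices
        pred = pred.t()  # [maxk, B]
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).sum(0, keepdim=True)
            res.append(correct_k.float().mul_(100.0 / batch_size).item())
        return res


# Toy batch: 4 samples, 6 classes; the "model output" is a (logits, features) tuple.
logits = torch.tensor([
    [0.1, 2.0, 0.3, 0.0, 0.0, 0.0],  # top-1: class 1
    [3.0, 0.2, 0.1, 0.0, 0.0, 0.0],  # top-1: class 0
    [0.0, 0.1, 0.2, 5.0, 0.0, 0.0],  # top-1: class 3, class 2 is in the top-5
    [0.0, 0.0, 0.0, 0.0, 1.0, 0.9],  # top-1: class 4, class 5 is second
])
model_output = (logits, torch.zeros(4, 1024, 1))
target = torch.tensor([1, 0, 2, 5])

acc1, acc5 = accuracy(unwrap_logits(model_output), target, topk=(1, 5))
print(acc1, acc5)  # 50.0 100.0
```

Note that `topk=(1, 5)` only works when the model has at least 5 classes, which holds for ModelNet40.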

ModelNetDataLoader.py

import os
import numpy as np
import warnings
import pickle

from tqdm import tqdm
from torch.utils.data import Dataset

warnings.filterwarnings('ignore')


def pc_normalize(pc):
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
    pc = pc / m
    return pc


def farthest_point_sample(point, npoint):
    """
    Input:
        point: pointcloud data, [N, D]; the first 3 columns are xyz
        npoint: number of samples
    Return:
        point: sampled pointcloud data, [npoint, D]
    """
    N, D = point.shape
    xyz = point[:,:3]
    centroids = np.zeros((npoint,))
    distance = np.ones((N,)) * 1e10
    farthest = np.random.randint(0, N)
    for i in range(npoint):
        centroids[i] = farthest
        centroid = xyz[farthest, :]
        dist = np.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = np.argmax(distance, -1)
    point = point[centroids.astype(np.int32)]
    return point


class ModelNetDataLoader(Dataset):
    def __init__(self, root, args, split='train', process_data=False,num_point=1024,use_uniform_sample=None,use_normals=False,num_category=40):
        self.root = root
        if args is not None:
            self.npoints = args.num_point
            self.uniform = args.use_uniform_sample
            self.use_normals = args.use_normals
            self.num_category = args.num_category
        else:
            self.npoints = num_point
            self.uniform = use_uniform_sample
            self.use_normals = use_normals
            self.num_category = num_category
        self.process_data = process_data

        if self.num_category == 10:
            self.catfile = os.path.join(self.root, 'modelnet10_shape_names.txt')
        else:
            self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')

        self.cat = [line.rstrip() for line in open(self.catfile)]
        self.classes = dict(zip(self.cat, range(len(self.cat))))

        shape_ids = {}
        if self.num_category == 10:
            shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_train.txt'))]
            shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_test.txt'))]
        else:
            shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
            shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]

        assert (split == 'train' or split == 'test')
        shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
        self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
                         in range(len(shape_ids[split]))]
        print('The size of %s data is %d' % (split, len(self.datapath)))

        if self.uniform:
            self.save_path = os.path.join(root, 'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, split, self.npoints))
        else:
            self.save_path = os.path.join(root, 'modelnet%d_%s_%dpts.dat' % (self.num_category, split, self.npoints))

        if self.process_data:
            if not os.path.exists(self.save_path):
                print('Processing data %s (only running in the first time)...' % self.save_path)
                self.list_of_points = [None] * len(self.datapath)
                self.list_of_labels = [None] * len(self.datapath)

                for index in tqdm(range(len(self.datapath)), total=len(self.datapath)):
                    fn = self.datapath[index]
                    cls = self.classes[self.datapath[index][0]]
                    cls = np.array([cls]).astype(np.int32)
                    point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)

                    if self.uniform:
                        point_set = farthest_point_sample(point_set, self.npoints)
                    else:
                        point_set = point_set[0:self.npoints, :]

                    self.list_of_points[index] = point_set
                    self.list_of_labels[index] = cls

                with open(self.save_path, 'wb') as f:
                    pickle.dump([self.list_of_points, self.list_of_labels], f)
            else:
                print('Load processed data from %s...' % self.save_path)
                with open(self.save_path, 'rb') as f:
                    self.list_of_points, self.list_of_labels = pickle.load(f)

    def __len__(self):
        return len(self.datapath)

    def _get_item(self, index):
        if self.process_data:
            point_set, label = self.list_of_points[index], self.list_of_labels[index]
        else:
            fn = self.datapath[index]
            cls = self.classes[self.datapath[index][0]]
            label = np.array([cls]).astype(np.int32)
            point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)

            if self.uniform:
                point_set = farthest_point_sample(point_set, self.npoints)
            else:
                point_set = point_set[0:self.npoints, :]
                
        point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
        if not self.use_normals:
            point_set = point_set[:, 0:3]

        return point_set, label[0]

    def __getitem__(self, index):
        return self._get_item(index)


if __name__ == '__main__':
    import torch

    data = ModelNetDataLoader('/data/modelnet40_normal_resampled/', args=None, split='train')
    loader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True)
    for point, label in loader:
        print(point.shape)
        print(label.shape)
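For reference, the preprocessing in `ModelNetDataLoader` comes down to two steps: `farthest_point_sample` (used when `use_uniform_sample` is set; otherwise the first `npoint` rows are kept) picks `npoint` well-spread points, and `pc_normalize` recenters the xyz columns and scales them into the unit sphere. A self-contained NumPy sketch on a random cloud; the random data and the `seed` argument are assumptions for illustration, and `np.minimum` replaces the original mask-based update with the same effect:

```python
import numpy as np


def pc_normalize(pc):
    """Center a point cloud at its centroid and scale it into the unit sphere."""
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))  # largest distance from the centroid
    return pc / m


def farthest_point_sample(point, npoint, seed=None):
    """Pick `npoint` rows of `point` ([N, D]) by farthest point sampling on the xyz columns."""
    rng = np.random.default_rng(seed)
    n = point.shape[0]
    xyz = point[:, :3]
    centroids = np.zeros((npoint,), dtype=np.int64)
    distance = np.full((n,), 1e10)
    farthest = int(rng.integers(0, n))  # random starting point
    for i in range(npoint):
        centroids[i] = farthest
        dist = np.sum((xyz - xyz[farthest]) ** 2, axis=-1)
        distance = np.minimum(distance, dist)  # distance to the nearest selected point
        farthest = int(np.argmax(distance))    # next pick: the point farthest from all picks
    return point[centroids]


rng = np.random.default_rng(0)
cloud = rng.normal(size=(2000, 6)).astype(np.float32)  # xyz + normals
sampled = farthest_point_sample(cloud, 128, seed=0)
sampled[:, :3] = pc_normalize(sampled[:, :3])
```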

pointnet2_utils.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np

def timeit(tag, t):
    print("{}: {}s".format(tag, time() - t))
    return time()

def pc_normalize(pc):
    l = pc.shape[0]
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
    pc = pc / m
    return pc

def square_distance(src, dst):
    """
    Calculate the squared Euclidean distance between every pair of points.

    src^T * dst = xn * xm + yn * ym + zn * zm;
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist


def index_points(points, idx):
    """

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] or [B, S, K]
    Return:
        new_points: indexed points data, [B, S, C] or [B, S, K, C]
    """
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
    new_points = points[batch_indices, idx, :]
    return new_points

#1
def farthest_point_sample(xyz, npoint):
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 1e10
    
    # Modified: start from the point farthest from the origin (0, 0, 0) instead of a random point
    # farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    farthest = torch.argmax(torch.square(xyz).sum(dim=2), dim=1)    

    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        cent = xyz[batch_indices, farthest, :]
        if cent.shape[-1]!=3:
            print("cent.shape",cent.shape)
            print("xyz.shape",xyz.shape)
        cent = cent.view(B, 1, 3)
        dist = torch.sum((xyz - cent) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids

#2
def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    group_idx[sqrdists > radius ** 2] = N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx

#3
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
    Input:
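        npoint: number of points sampled with farthest point sampling
        radius: ball-query radius of each local region
        nsample: max number of points in each local region
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, npoint, 3]
        new_points: sampled points data, [B, npoint, nsample, 3+D]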
    """
    B, N, C = xyz.shape
    S = npoint
    fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint]
    new_xyz = index_points(xyz, fps_idx) # [B, npoint, 3]
    idx = query_ball_point(radius, nsample, xyz, new_xyz)
    grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)


    if points is not None:
        grouped_points = index_points(points, idx)
        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
    else:
        new_points = grouped_xyz_norm
    if returnfps:
        return new_xyz, new_points, grouped_xyz, fps_idx
    else:
        return new_xyz, new_points

# simply concatenate the xyz coordinates and the point features
def sample_and_group_all(xyz, points):
    """
    Input:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, 1, 3]
        new_points: sampled points data, [B, 1, N, 3+D]
    """
    device = xyz.device
    B, N, C = xyz.shape
    new_xyz = torch.zeros(B, 1, C).to(device)
    grouped_xyz = xyz.view(B, 1, N, C)
    if points is not None:
        new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
    else:
        new_points = grouped_xyz
    return new_xyz, new_points

#4
class PointNetSetAbstraction(nn.Module):
    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel
        self.group_all = group_all

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        if self.group_all:
            new_xyz, new_points = sample_and_group_all(xyz, points)
        else:
            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # new_xyz: sampled points position data, [B, npoint, C]
        # new_points: sampled points data, [B, npoint, nsample, C+D]
        # print(new_points.shape)

        new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points =  F.relu(bn(conv(new_points)))

        new_points = torch.max(new_points, 2)[0] # max at dimension of nsample 
        new_xyz = new_xyz.permute(0, 2, 1)
        return new_xyz, new_points


class PointNetSetAbstractionMsg(nn.Module):
    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
        super(PointNetSetAbstractionMsg, self).__init__()
        self.npoint = npoint
        self.radius_list = radius_list
        self.nsample_list = nsample_list
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        for i in range(len(mlp_list)):
            convs = nn.ModuleList()
            bns = nn.ModuleList()
            last_channel = in_channel + 3
            for out_channel in mlp_list[i]:
                convs.append(nn.Conv2d(last_channel, out_channel, 1))
                bns.append(nn.BatchNorm2d(out_channel))
                last_channel = out_channel
            self.conv_blocks.append(convs)
            self.bn_blocks.append(bns)

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        B, N, C = xyz.shape
        S = self.npoint
        new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
        new_points_list = []
        for i, radius in enumerate(self.radius_list):
            K = self.nsample_list[i]
            group_idx = query_ball_point(radius, K, xyz, new_xyz)
            grouped_xyz = index_points(xyz, group_idx)
            grouped_xyz -= new_xyz.view(B, S, 1, C)
            if points is not None:
                grouped_points = index_points(points, group_idx)
                grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
            else:
                grouped_points = grouped_xyz

            grouped_points = grouped_points.permute(0, 3, 2, 1)  # [B, D, K, S]
            for j in range(len(self.conv_blocks[i])):
                conv = self.conv_blocks[i][j]
                bn = self.bn_blocks[i][j]
                grouped_points =  F.relu(bn(conv(grouped_points)))
            new_points = torch.max(grouped_points, 2)[0]  # [B, D', S]
            new_points_list.append(new_points)

        new_xyz = new_xyz.permute(0, 2, 1)
        new_points_concat = torch.cat(new_points_list, dim=1)
        return new_xyz, new_points_concat


class PointNetFeaturePropagation(nn.Module):
    def __init__(self, in_channel, mlp):
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, points1, points2):
        """
        Input:
            xyz1: input points position data, [B, C, N]
            xyz2: sampled input points position data, [B, C, S]
            points1: input points data, [B, D, N]
            points2: input points data, [B, D, S]
        Return:
            new_points: upsampled points data, [B, D', N]
        """
        xyz1 = xyz1.permute(0, 2, 1)
        xyz2 = xyz2.permute(0, 2, 1)

        points2 = points2.permute(0, 2, 1)
        B, N, C = xyz1.shape
        _, S, _ = xyz2.shape

        if S == 1:
            interpolated_points = points2.repeat(1, N, 1)
        else:
            dists = square_distance(xyz1, xyz2)
            dists, idx = dists.sort(dim=-1)
            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]

            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)

        if points1 is not None:
            points1 = points1.permute(0, 2, 1)
            new_points = torch.cat([points1, interpolated_points], dim=-1)
        else:
            new_points = interpolated_points

        new_points = new_points.permute(0, 2, 1)
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))
        return new_points
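`query_ball_point` marks every point outside the radius with the sentinel index `N`, then pads each group with the first neighbor it found. If a query point has no neighbor at all within `radius`, that first neighbor is itself `N`, and `index_points` then indexes one past the end of the point dimension. On CPU this raises an `IndexError`; on CUDA an out-of-bounds index inside a kernel is reported as `device-side assert triggered`. Because `sample_and_group` draws the query points from `xyz` itself, each query normally finds at least itself (distance 0); whether quantization or mixed-precision rounding in `square_distance` could push that self-distance above `radius ** 2` is only a hypothesis worth checking, not something established in this thread. A self-contained sketch reproducing the edge case with a query point deliberately placed outside the cloud (synthetic data):

```python
import torch


def square_distance(src, dst):
    """Pairwise squared Euclidean distance: [B, N, C] x [B, M, C] -> [B, N, M]."""
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).unsqueeze(-1)  # ||src_n||^2, broadcast over M
    dist += torch.sum(dst ** 2, -1).unsqueeze(1)   # ||dst_m||^2, broadcast over N
    return dist


def query_ball_point(radius, nsample, xyz, new_xyz):
    """Same logic as query_ball_point in pointnet2_utils.py (CPU only here)."""
    B, N, _ = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    group_idx[sqrdists > radius ** 2] = N  # sentinel N means "outside the ball"
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]  # pad with the first neighbor found
    return group_idx  # still N wherever a query had no neighbor at all


torch.manual_seed(0)
xyz = torch.rand(1, 16, 3)             # 16 points in the unit cube
inside = xyz[:, :2, :].clone()         # queries taken from the cloud itself
outside = torch.full((1, 1, 3), 10.0)  # a query far away from every point

idx_ok = query_ball_point(0.2, 4, xyz, inside)    # every index < 16
idx_bad = query_ball_point(0.2, 4, xyz, outside)  # every index == 16 (out of range)
```

Checking `int(idx.max()) < N` right after `query_ball_point` in the real model is one cheap way to test the hypothesis on actual data.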

pointnet2_cls.py

import torch.nn as nn
import torch
import torch.nn.functional as F
from .pointnet2_utils import PointNetSetAbstraction
import logging

class get_model(nn.Module):
    def __init__(self,num_class,normal_channel=True,pretrained=False,load_path=None):
        super(get_model, self).__init__()
        in_channel = 6 if normal_channel else 3
        self.normal_channel = normal_channel
        self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=64, in_channel=in_channel, mlp=[64, 64, 128], group_all=False)
        self.sa2 = PointNetSetAbstraction(npoint=256, radius=0.4, nsample=64, in_channel=128+3, mlp=[128, 128, 256], group_all=False)
        self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256+3, mlp=[256, 512, 1024], group_all=True)
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.4)
        self.fc3 = nn.Linear(256, num_class)

        if pretrained:
            if load_path is not None:
                self.load_state_dict(torch.load(load_path)['model_state_dict'])
                logging.info("=> done loading ")
            else:
                logging.info("=> no provided checkpoint path ")
            
    def forward(self, xyz):
        
        B, _, _ = xyz.shape
        xyz = xyz.transpose(2, 1)
        print(xyz.shape)
        if self.normal_channel:
            norm = xyz[:, 3:, :]
            xyz = xyz[:, :3, :]
        else:
            norm = None
            
        l1_xyz, l1_points = self.sa1(xyz, norm)
        # print(l1_xyz.shape, l1_points.shape)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        # print(l2_xyz.shape, l2_points.shape)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        # print(l3_xyz.shape, l3_points.shape)
        x = l3_points.view(B, 1024)
        # print(x)
        x = self.drop1(F.relu(self.bn1(self.fc1(x))))
        x = self.drop2(F.relu(self.bn2(self.fc2(x))))
        x = self.fc3(x)
        # x = F.log_softmax(x, -1)


        return x, l3_points

def pointnetv2_cls(num_class,normal_channel=False,pretrained=False,load_path=None):
    return get_model(num_class,normal_channel=normal_channel,pretrained=pretrained,load_path=load_path)

class get_cls_loss(nn.Module):
    def __init__(self):
        super(get_cls_loss, self).__init__()

    def forward(self, pred, target, trans_feat):
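        # pred holds raw logits (log_softmax is commented out in get_model.forward),
        # so use cross_entropy, which applies log_softmax internally; targets must be int64
        total_loss = F.cross_entropy(pred, target.long())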

        return total_loss
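`get_model.forward` returns raw logits (its `F.log_softmax` line is commented out), while `F.nll_loss` expects log-probabilities; applied directly to logits it only averages the negated target logit, which is unbounded below. The labels are also created with `astype(np.int32)` in `ModelNetDataLoader`, whereas these losses expect int64 class indices. A small check of the relationship, with 40 classes and a batch of 8 chosen arbitrarily:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, batch = 40, 8
logits = torch.randn(batch, num_classes)  # raw scores, like x returned by get_model.forward
target = torch.randint(0, num_classes, (batch,), dtype=torch.int32).long()  # int32 labels cast to int64

loss_ce = F.cross_entropy(logits, target)                     # log_softmax + nll_loss in one call
loss_nll = F.nll_loss(F.log_softmax(logits, dim=-1), target)  # the explicit two-step version
loss_raw = F.nll_loss(logits, target)                         # raw logits: -mean(logits[i, target[i]])
```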

pointnet_v2_classification_int8.json

{
    "model": "pointnetv2",
    "pretrained": false,
    "dataset_dir":"/data1/modelnet/modelnet40_normal_resampled",
    "input_info": {
      "sample_size": [1, 1024, 3]
    },
    "num_classes": 40,
    "batch_size" : 64,
    "epochs": 80,
    "optimizer": {
        "type": "Adam",
        "base_lr": 0.00001,
        "schedule_type": "multistep",
        "steps": [
            5
        ]
    },
    "compression": {
        "algorithm": "quantization",
        "initializer": {
            "range": {
                "num_init_samples": 2560
            }
        }
    }
}
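`pointnet_dataset` reads six keys with `config.get(...)`, but only `dataset_dir` and `num_classes` appear in `pointnet_v2_classification_int8.json` as posted. They may be supplied some other way (for example through command-line arguments), which the issue does not show. A small consistency check, assuming the config parses as a plain dict; `missing_dataset_keys` is a hypothetical helper, and the `optimizer` and `compression` sections are omitted for brevity:

```python
import json

# Keys read by pointnet_dataset() through config.get(...)
DATASET_KEYS = ("dataset_dir", "num_point", "use_uniform_sample",
                "use_normals", "num_classes", "process_data")


def missing_dataset_keys(config: dict) -> list:
    """Return the dataset keys that are absent from a parsed config dict."""
    return [key for key in DATASET_KEYS if key not in config]


config = json.loads("""
{
    "model": "pointnetv2",
    "pretrained": false,
    "dataset_dir": "/data1/modelnet/modelnet40_normal_resampled",
    "input_info": {"sample_size": [1, 1024, 3]},
    "num_classes": 40,
    "batch_size": 64,
    "epochs": 80
}
""")
missing = missing_dataset_keys(config)
print(missing)  # ['num_point', 'use_uniform_sample', 'use_normals', 'process_data']
```

With a dict-like config, a missing key comes back from `config.get` as `None`; `num_point=None`, for example, would make the `%d` formatting of `save_path` in `ModelNetDataLoader` raise a `TypeError`.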

@AlexanderDokuchaev
Collaborator

@zbnlala Unfortunately, this issue could not be reproduced.

Could you check whether the issue also occurs without NNCF?
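One way to run that check is a single plain forward and loss pass on CPU with the unwrapped model, where an out-of-range index raises a readable Python exception instead of the asynchronous CUDA assert. A minimal sketch; `check_without_nncf` and `TinyStandIn` are hypothetical names, and the stand-in only mimics the `(logits, features)` return signature of `get_model`. On GPU, setting the environment variable `CUDA_LAUNCH_BLOCKING=1` is another option: it makes kernel launches synchronous, so the traceback points at the operation that actually failed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def check_without_nncf(model, points, target, num_classes):
    """Hypothetical helper: one plain forward + loss pass on CPU, no NNCF wrapping."""
    bad = (target < 0) | (target >= num_classes)
    if bool(bad.any()):
        raise ValueError(f"labels outside [0, {num_classes}): {target[bad].tolist()}")
    model = model.cpu().eval()
    with torch.no_grad():
        output = model(points.cpu())
    # get_model.forward returns (logits, l3_points); plain models return logits.
    logits = output[0] if isinstance(output, (tuple, list)) else output
    return F.cross_entropy(logits, target.long().cpu()).item()


class TinyStandIn(nn.Module):
    """Hypothetical stand-in with get_model's signature: [B, N, 3] -> (logits, features)."""

    def __init__(self, num_classes):
        super().__init__()
        self.fc = nn.Linear(3, num_classes)

    def forward(self, xyz):
        feat = xyz.mean(dim=1)  # [B, 3]
        return self.fc(feat), feat


torch.manual_seed(0)
loss = check_without_nncf(TinyStandIn(40), torch.rand(2, 1024, 3), torch.tensor([0, 39]), 40)
```

With the real network, the call would look like `check_without_nncf(pointnetv2_cls(40), points, target, 40)` using a `(points, target)` batch taken from the validation loader.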
