
Outdated ncclResult code #128756

Closed
myungjin opened this issue Jun 14, 2024 · 2 comments
Labels
oncall: distributed · triaged

Comments

@myungjin
Contributor

myungjin commented Jun 14, 2024

🐛 Describe the bug

NCCL introduced a new ncclResult code, ncclRemoteError; see here and here.

This new code is not covered by the enum in https://github.com/pytorch/pytorch/blob/main/torch/csrc/cuda/nccl.h#L40-L49.
As a result, the conversion throws std::runtime_error("Unconvertible NCCL type") at https://github.com/pytorch/pytorch/blob/main/torch/csrc/cuda/nccl.cpp#L82.

This makes it impossible to handle the exception gracefully when an NCCL worker on a remote machine fails.
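For illustration, here is a minimal, self-contained sketch of the failure mode. The enums and function below are simplified stand-ins I made up for ncclResult_t and torch::cuda::nccl::ncclResult, not the real definitions:

#include <iostream>
#include <stdexcept>

// Placeholder stand-ins for ncclResult_t and torch::cuda::nccl::ncclResult;
// the names and values here are simplified, not the real headers.
enum MockNcclResult { MockSuccess = 0, MockSystemError = 2, MockRemoteError = 6 };
enum class TorchNcclResult { Success = 0, SystemError = 2 };  // RemoteError missing

TorchNcclResult from_mock_result(MockNcclResult var) {
  switch (var) {
    case MockSuccess:
      return TorchNcclResult::Success;
    case MockSystemError:
      return TorchNcclResult::SystemError;
    default:
      // The original error code is lost and replaced by a generic exception.
      throw std::runtime_error("Unconvertible NCCL type");
  }
}

int main() {
  try {
    from_mock_result(MockRemoteError);  // e.g. a failure reported by a remote worker
  } catch (const std::runtime_error& e) {
    std::cout << "caught: " << e.what() << "\n";  // no way to tell it was a remote error
  }
  return 0;
}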
The following three code blocks need to be updated accordingly:
https://github.com/pytorch/pytorch/blob/main/torch/csrc/cuda/nccl.h#L40-L49
https://github.com/pytorch/pytorch/blob/main/torch/csrc/cuda/nccl.cpp#L36-L59
https://github.com/pytorch/pytorch/blob/main/torch/csrc/cuda/nccl.cpp#L61-L84

The proposed changes are as follows:
In https://github.com/pytorch/pytorch/blob/main/torch/csrc/cuda/nccl.h,

enum class ncclResult {
  Success = 0,
  UnhandledCudaError = 1,
  SystemError = 2,
  InternalError = 3,
  InvalidArgument = 4,
  InvalidUsage = 5,
  RemoteError = 6,
  InProgress = 7,
  NumResults = 8
};

In https://github.com/pytorch/pytorch/blob/main/torch/csrc/cuda/nccl.cpp,

ncclResult_t to_nccl_result(torch::cuda::nccl::ncclResult var) {
  switch (var) {
    case torch::cuda::nccl::ncclResult::Success:
      return ncclResult_t::ncclSuccess;
    case torch::cuda::nccl::ncclResult::UnhandledCudaError:
      return ncclResult_t::ncclUnhandledCudaError;
    case torch::cuda::nccl::ncclResult::SystemError:
      return ncclResult_t::ncclSystemError;
    case torch::cuda::nccl::ncclResult::InternalError:
      return ncclResult_t::ncclInternalError;
    case torch::cuda::nccl::ncclResult::InvalidArgument:
      return ncclResult_t::ncclInvalidArgument;
    case torch::cuda::nccl::ncclResult::InvalidUsage:
      return ncclResult_t::ncclInvalidUsage;
    case torch::cuda::nccl::ncclResult::RemoteError:
      return ncclResult_t::ncclRemoteError;
#ifdef NCCL_HAS_COMM_NONBLOCKING
    case torch::cuda::nccl::ncclResult::InProgress:
      return ncclResult_t::ncclInProgress;
#endif
    case torch::cuda::nccl::ncclResult::NumResults:
      return ncclResult_t::ncclNumResults;
    default:
      throw std::runtime_error("Unconvertible NCCL type");
  }
}

torch::cuda::nccl::ncclResult from_nccl_result(ncclResult_t var) {
  switch (var) {
    case ncclSuccess:
      return torch::cuda::nccl::ncclResult::Success;
    case ncclUnhandledCudaError:
      return torch::cuda::nccl::ncclResult::UnhandledCudaError;
    case ncclSystemError:
      return torch::cuda::nccl::ncclResult::SystemError;
    case ncclInternalError:
      return torch::cuda::nccl::ncclResult::InternalError;
    case ncclInvalidArgument:
      return torch::cuda::nccl::ncclResult::InvalidArgument;
    case ncclInvalidUsage:
      return torch::cuda::nccl::ncclResult::InvalidUsage;
    case ncclRemoteError:
      return torch::cuda::nccl::ncclResult::RemoteError;
#ifdef NCCL_HAS_COMM_NONBLOCKING
    case ncclInProgress:
      return torch::cuda::nccl::ncclResult::InProgress;
#endif
    case ncclNumResults:
      return torch::cuda::nccl::ncclResult::NumResults;
    default:
      throw std::runtime_error("Unconvertible NCCL type");
  }
}
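One caveat worth noting (my suggestion, not part of the proposal above): ncclRemoteError only exists in sufficiently recent NCCL releases (I believe it was added in NCCL 2.13), so if PyTorch must keep building against older headers, the new case could be wrapped in a version guard, similar to the existing NCCL_HAS_COMM_NONBLOCKING guard. A rough sketch for to_nccl_result, using NCCL's NCCL_VERSION_CODE / NCCL_VERSION macros:

// Hypothetical guard; assumes nccl.h defines NCCL_VERSION_CODE and NCCL_VERSION,
// and that ncclRemoteError first appeared in NCCL 2.13 (an assumption to verify).
#if defined(NCCL_VERSION_CODE) && (NCCL_VERSION_CODE >= NCCL_VERSION(2, 13, 0))
    case torch::cuda::nccl::ncclResult::RemoteError:
      return ncclResult_t::ncclRemoteError;
#endif

The corresponding ncclRemoteError case in from_nccl_result would take the same guard.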

Versions

PyTorch version: 2.2.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Amazon Linux 2 (x86_64)
GCC version: (conda-forge gcc 12.3.0-7) 12.3.0
Clang version: Could not collect
CMake version: version 3.29.4
Libc version: glibc-2.26

Python version: 3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:40:32) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.4.247-162.350.amzn2.x86_64-x86_64-with-glibc2.26
Is CUDA available: True
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: Tesla V100-SXM2-16GB
GPU 1: Tesla V100-SXM2-16GB
GPU 2: Tesla V100-SXM2-16GB
GPU 3: Tesla V100-SXM2-16GB

Nvidia driver version: 535.161.07
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 32
On-line CPU(s) list: 0-31
Thread(s) per core: 2
Core(s) per socket: 16
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 79
Model name: Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz
Stepping: 1
CPU MHz: 2699.958
CPU max MHz: 3000.0000
CPU min MHz: 1200.0000
BogoMIPS: 4600.01
Hypervisor vendor: Xen
Virtualization type: full
L1d cache: 32K
L1i cache: 32K
L2 cache: 256K
L3 cache: 46080K
NUMA node0 CPU(s): 0-31
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single pti fsgsbase bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx xsaveopt

Versions of relevant libraries:
[pip3] flake8==7.0.0
[pip3] mypy==1.7.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.22.4
[pip3] optree==0.9.1
[pip3] torch==2.2.1
[pip3] torchvision==0.17.1
[pip3] triton==2.2.0
[conda] libmagma 2.7.2 h173bb3b_2 conda-forge
[conda] libmagma_sparse 2.7.2 h173bb3b_3 conda-forge
[conda] magma 2.7.2 h51420fd_3 conda-forge
[conda] mkl 2024.1.0 ha957f24_693 conda-forge
[conda] mkl-include 2024.1.0 ha957f24_693 conda-forge
[conda] numpy 1.22.4 pypi_0 pypi
[conda] optree 0.9.1 pypi_0 pypi
[conda] torch 2.2.1 pypi_0 pypi
[conda] torchvision 0.17.1 pypi_0 pypi
[conda] triton 2.2.0 pypi_0 pypi

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k

@myungjin myungjin changed the title oudated ncclResult code Outdated ncclResult code Jun 14, 2024
@ezyang added the oncall: distributed label Jun 15, 2024
@ezyang
Contributor

ezyang commented Jun 15, 2024

send us a patch please!

myungjin added a commit to myungjin/pytorch that referenced this issue Jun 15, 2024
The nccl result codes are outdated. This PR fixes pytorch#128756.
@myungjin
Contributor Author

@ezyang Created a PR

@weifengpy added the triaged label Jun 17, 2024