Engineering

Jul 28, 2023

Struggling with bitsandbytes issue

  • Jeongseok Kang

    Researcher

It's the age of the Large Language Model (LLM). Since OpenAI announced ChatGPT in November 2022, it has become synonymous with modern artificial intelligence in a way not seen since AlphaGo. While many companies and research labs are developing their own language models in response to ChatGPT, a growing number are open-sourcing them, such as Meta AI's Llama 2, making them accessible to individuals as well.

Backend.AI is a popular choice for developing LLMs because it makes it easy to run large clusters and distributed processing. In fact, we get a lot of feedback and requests from customers, and today I'd like to share how we solved one of them.

On April 4, 2023, we received a report that certain packages were causing an error when run in a container environment provided by the NGC Catalog1 (NVIDIA GPU Cloud). The NGC Catalog is a collection of containers2 with optimized environments for developing AI/ML, metaverse, and high-performance computing applications. Because it is operated and distributed directly by NVIDIA, it is highly trusted and considered a standard, especially for CUDA environments. An issue in this environment therefore represents a potential risk that many of our customers could face in the future, so we made addressing it a high priority.

Reproducing the issue

We first went through the process of reproducing the issue to determine its exact cause. In this case, the error occurred in a package called bitsandbytes while running ViperGPT3, developed by Columbia University. ViperGPT depends on bitsandbytes as shown below.

requirements.txt

accelerate==0.18.0
backoff==2.2.1
bitsandbytes==0.38.1
cityscapesscripts==2.2.1
git+https://github.com/openai/CLIP.git
decord==0.6.0
dill==0.3.6
...

I was able to reproduce the issue simply by running import bitsandbytes.

Note: The execution environment was based on the image nvcr.io/nvidia/pytorch:22.05-py3.

$ pip install bitsandbytes  # 0.37.1
$ python
>> import bitsandbytes

===================================BUG REPORT===================================
Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
================================================================================
CUDA exception! Error code: OS call failed or operation not supported on this OS
CUDA exception! Error code: initialization error
CUDA SETUP: CUDA runtime path found: /home/work/data/miniconda3/envs/vipergpt/lib/libcudart.so
/home/work/data/miniconda3/envs/vipergpt/lib/python3.10/site-packages/bitsandbytes/cuda_setup/main.py:136: UserWarning: WARNING: No GPU detected! Check your CUDA paths. Proceeding to load CPU-only library...
  warn(msg)
CUDA SETUP: Detected CUDA version 116
CUDA SETUP: Loading binary /home/work/data/miniconda3/envs/vipergpt/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so...
/home/work/data/miniconda3/envs/vipergpt/lib/python3.10/site-packages/bitsandbytes/cextension.py:31: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers and GPU quantization are unavailable.
  warn("The installed version of bitsandbytes was compiled without GPU support. "

bitsandbytes traverses all CUDA devices installed in the execution environment and checks their compute capability4. To do so, it uses libcuda.so to query the number of installed CUDA devices, as shown below. We found that the call to cuDeviceGetCount()5 was the one failing, returning error 304, CUDA_ERROR_OPERATING_SYSTEM.

def get_compute_capabilities(cuda):
    """
    1. find libcuda.so library (GPU driver) (/usr/lib)
       init_device -> init variables -> call function by reference
    2. call extern C function to determine CC
       (https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
    3. Check for CUDA errors
       https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
    # bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
    """
    nGpus = ct.c_int()
    cc_major = ct.c_int()
    cc_minor = ct.c_int()
    device = ct.c_int()

    # this is the call that fails with 304 CUDA_ERROR_OPERATING_SYSTEM
    check_cuda_result(cuda, cuda.cuDeviceGetCount(ct.byref(nGpus)))
    ccs = []
    for i in range(nGpus.value):
        check_cuda_result(cuda, cuda.cuDeviceGet(ct.byref(device), i))
        ref_major = ct.byref(cc_major)
        ref_minor = ct.byref(cc_minor)
        # 2. call extern C function to determine CC
        check_cuda_result(cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device))
        ccs.append(f"{cc_major.value}.{cc_minor.value}")
    return ccs
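The check_cuda_result() helper referenced above lives in bitsandbytes' cuda_setup module; here is a minimal sketch of its behavior (simplified from the original, which logs rather than raises, so details may differ):

import ctypes as ct

def check_cuda_result(cuda, result_val):
    # A nonzero CUresult means the driver call failed; resolve the code
    # to a human-readable message via cuGetErrorString() and report it.
    if result_val != 0:
        error_str = ct.c_char_p()
        cuda.cuGetErrorString(result_val, ct.byref(error_str))
        raise RuntimeError(f"CUDA exception! Error code: {error_str.value}")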

What is bitsandbytes?

Since the introduction of the Transformer, language models have shown great performance improvements, and scaling models up by stacking more Transformer blocks has become the trend. As a result, large amounts of GPU resources are required not only to train models but also to serve them. For example, serving GPT-3 with its 175B parameters requires eight 80 GB A100 GPUs at about $15,000 each, for a total of roughly $120,000. This is a huge burden not only for individuals but also for companies and research institutes, which is why research into making inference models lighter for serving is so active.
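As a rough back-of-the-envelope check (my own numbers, assuming 2 bytes per parameter for fp16 weights and ignoring activations and the KV cache), you can see both why eight A100s are needed and why 8-bit quantization is attractive:

# Rough serving-memory estimate for GPT-3 (175B parameters).
params = 175e9
fp16_gb = params * 2 / 1e9   # ~350 GB of weights in fp16
int8_gb = params * 1 / 1e9   # ~175 GB if quantized to 8-bit
a100_gb = 8 * 80             # eight 80 GB A100s = 640 GB total
print(f"fp16: {fp16_gb:.0f} GB, int8: {int8_gb:.0f} GB, capacity: {a100_gb} GB")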

Source: A Gentle Introduction to 8-bit Matrix Multiplication for transformers at scale using Hugging Face Transformers, Accelerate and bitsandbytes (Hugging Face)

bitsandbytes is the open-source release of LLM.int8()6, work by Tim Dettmers, a PhD candidate at the University of Washington, in collaboration with Facebook AI Research (now Meta AI). It uses a vector-wise quantization method that treats each vector independently when computing matrix products, and it mixes 8-bit and 16-bit representations, keeping important vectors in 16-bit to minimize loss. This reduces the size of the model while maintaining performance. It has been merged into Hugging Face's Transformers implementation and is used in a variety of models, including Llama 2, QLoRA, KoAlpaca, and KULLM.
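To make the idea concrete, here is a toy sketch of vector-wise absmax int8 quantization in PyTorch. This only illustrates the concept, not the actual bitsandbytes kernels, and the function name is my own:

import torch

def vectorwise_int8_quantize(x: torch.Tensor):
    # One absmax scale per row vector, so an outlier in one row
    # does not degrade the precision of the other rows.
    scale = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp((x / scale).round(), -127, 127).to(torch.int8)
    return q, scale

x = torch.randn(4, 8)
q, scale = vectorwise_int8_quantize(x)
x_restored = q.float() * scale        # dequantize
print((x - x_restored).abs().max())   # small quantization error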

Identifying the cause

Now that we had located and reproduced the problem, we needed to figure out its cause. We looked for similar cases but could not find any. Moreover, cuInit() returned successfully, which made it even more difficult to pinpoint the cause.

import ctypes

count = ctypes.c_int()

libcuda = ctypes.CDLL("libcuda.so")
libcuda.cuInit(0)  # 0 (CUDA_SUCCESS)
libcuda.cuDeviceGetCount(ctypes.byref(count))  # 304 (CUDA_ERROR_OPERATING_SYSTEM)

libcudart = ctypes.CDLL("libcudart.so")
libcudart.cudaGetDeviceCount(ctypes.byref(count))  # 304 (CUDA_ERROR_OPERATING_SYSTEM)

I filed an issue (TimDettmers/bitsandbytes#264) on the GitHub repository to ask for advice and was told to update the package to the latest version and try again. I updated to 0.38.0.post1, the latest at the time, and tested again, but the issue was not resolved. I couldn't afford to lose too much time, so I decided to switch gears and remove the offending part.

Source: Greco-Roman Mythology in Comics (Gana Publishing)

Solving the problem

The first approach was to use CUDA-Python7. CUDA-Python is the low-level bindings package officially distributed by NVIDIA. I was already familiar with it, having found it useful in the past, so I decided to install and test it right away.

$ pip install cuda-python
from cuda import cuda
from cuda import cudart

cuda.cuInit(0)  # (<CUresult.CUDA_SUCCESS: 0>,)
cudart.cudaGetDeviceCount()  # (<cudaError_t.cudaSuccess: 0>, 1)

Fortunately, cudart.cudaGetDeviceCount() worked fine, and I proceeded to test integrating it into bitsandbytes. However, when I called cuda.cuInit(0) and then torch.cuda.is_available(), I got an error. This was because torch.cuda.is_available() itself calls cudaGetDeviceCount() internally.

from cuda import cuda, cudart

cuda.cuInit(0)  # (<CUresult.CUDA_SUCCESS: 0>,)
cudart.cudaGetDeviceCount()  # (<cudaError_t.cudaSuccess: 0>, 1)

import bitsandbytes

# ...
# /opt/conda/lib/python3.8/site-packages/torch/cuda/__init__.py:82: UserWarning: CUDA initialization: Unexpected error from cudaGetDeviceCount(). Did you run some cuda functions before calling NumCudaDevices() that might have already set an error? Error 304: OS call failed or operation not supported on this OS (Triggered internally at /opt/pytorch/pytorch/c10/cuda/CUDAFunctions.cpp:109.)
#   return torch._C._cuda_getDeviceCount() > 0
# ...

The problem seemed to be back to square one. I took a breath and calmly re-read the error log above. Then something caught my eye.

torch._C._cuda_getDeviceCount() > 0

bitsandbytes was already using PyTorch internally; in other words, it had a dependency on PyTorch. To be precise, bitsandbytes depends on lion-pytorch, which in turn depends on PyTorch. And PyTorch already provides an interface to the CUDA functions, which is what I decided to use this time.

Fortunately, all of the CUDA functions used by bitsandbytes had equivalents in PyTorch. I changed the functions that were previously called via libcuda.so and libcudart.so as follows.

libcuda/libcudart → torch
libcuda.cuDeviceGetCount() → torch.cuda.device_count()
libcuda.cuDeviceGet() → torch.cuda.device()
libcuda.cuDeviceComputeCapability() → torch.cuda.get_device_capability()
libcudart.cudaRuntimeGetVersion() → torch.version.cuda
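Conceptually, the patched device probe looks like the sketch below (simplified by me; see the PR for the exact change):

import torch

def get_compute_capabilities():
    # Query devices through PyTorch's CUDA interface instead of
    # dlopen-ing libcuda.so and calling the driver API directly.
    ccs = []
    for i in range(torch.cuda.device_count()):           # was cuDeviceGetCount()
        major, minor = torch.cuda.get_device_capability(i)  # was cuDeviceComputeCapability()
        ccs.append(f"{major}.{minor}")
    return ccs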

After confirming that the change worked, I opened a PR (TimDettmers/bitsandbytes#375) against the GitHub repository so the fix would be applied to the distributed package.

Recap

On July 14, 2023, about two months after registering the PR, the patch was merged into the main branch and included in version 0.40.1.

I was able to get some feedback from the author, Tim Dettmers, and even his short comments gave me a sense of his thoughts and philosophy.

This was a great opportunity to learn more about the LLM ecosystem, and it was the first time in a long while that I was able to enjoy open-source work. The ability to collaborate beyond geographic constraints and learn from each other's ideas is what makes open source so appealing. Backend.AI has an open-source version alongside its enterprise version. We will always strive to provide a better user experience and a better developer experience.


1: NVIDIA GPU Cloud

2: The NGC Catalog hosts containers for AI/ML, metaverse, and HPC applications, which are performance-optimized, tested, and ready to deploy on GPU-powered on-prem, cloud, and edge systems.

3: ViperGPT: Visual Inference via Python Execution for Reasoning, March 14, 2023.

4: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capability

5: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE.html#group__CUDA__DEVICE_1g52b5ce05cb8c5fb6831b2c0ff2887c74

6: LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale, November 10, 2022.

7: https://developer.nvidia.com/cuda-python
