diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index 71e7568a9..92b1675e9 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -3,6 +3,7 @@ import logging import re import subprocess +import platform from typing import Optional import torch @@ -83,8 +84,13 @@ def get_rocm_gpu_arch() -> str: logger = logging.getLogger(__name__) try: if torch.version.hip: - result = subprocess.run(["rocminfo"], capture_output=True, text=True) - match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout) + rocminfo_process_name = "rocminfo" + search_pattern = r"Name:\s+gfx([a-zA-Z\d]+)" + if platform.system() == "Windows": + rocminfo_process_name = "hipinfo" + search_pattern = r"Name:\s*gfx([a-zA-Z\d]+)" + result = subprocess.run([rocminfo_process_name], capture_output=True, text=True) + match = re.search(search_pattern, result.stdout) if match: return "gfx" + match.group(1) else: @@ -107,8 +113,13 @@ def get_rocm_warpsize() -> int: logger = logging.getLogger(__name__) try: if torch.version.hip: - result = subprocess.run(["rocminfo"], capture_output=True, text=True) - match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", result.stdout) + rocminfo_process_name = "rocminfo" + search_pattern = r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)" + if platform.system() == "Windows": + rocminfo_process_name = "hipinfo" + search_pattern = r"warpSize:\s*(\d+)" + result = subprocess.run([rocminfo_process_name], capture_output=True, text=True) + match = re.search(search_pattern, result.stdout) if match: return int(match.group(1)) else: diff --git a/csrc/ops_hip.cuh b/csrc/ops_hip.cuh index 4eb446206..06db0298f 100644 --- a/csrc/ops_hip.cuh +++ b/csrc/ops_hip.cuh @@ -11,7 +11,9 @@ #include #include #include +#ifndef _WIN32 #include +#endif #include #include