Select a backing device to use based on inputs and availability.
Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `gpu` | `bool` | Whether to select a GPU device rather than the CPU one. | *required* |
| `idx` | `Optional[int]` | Optional pre-selected GPU device index. Only used when `gpu=True`. If `idx` is `None` or is out of range (not lower than the number of available GPU devices), use `torch.cuda.current_device()`. | `None` |
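The rule for resolving `idx` can be modelled without torch. The sketch below uses a hypothetical helper, `resolve_gpu_index`, which is not part of declearn. It takes the device count and the current device index as plain integers, standing in for `torch.cuda.device_count()` and `torch.cuda.current_device()`, and mirrors the `(idx or 0) >= device_count` check in the source listing.

```python
from typing import Optional


def resolve_gpu_index(idx: Optional[int], device_count: int, current_index: int) -> int:
    """Return the GPU index that select_device would pick (hypothetical helper).

    `device_count` and `current_index` stand in for
    torch.cuda.device_count() and torch.cuda.current_device().
    """
    # An out-of-range index is discarded (select_device also warns here).
    if (idx or 0) >= device_count:
        idx = None
    # Fall back to the current device when no valid index remains.
    if idx is None:
        idx = current_index
    return idx


print(resolve_gpu_index(1, device_count=2, current_index=0))     # 1
print(resolve_gpu_index(5, device_count=2, current_index=0))     # 0
print(resolve_gpu_index(None, device_count=2, current_index=1))  # 1
```

Note that an index equal to the device count (e.g. `idx=2` with two GPUs) is already out of range, which is why the check uses `>=`.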
Warns:

| Type | Description |
|------|-------------|
| `RuntimeWarning` | If `gpu=True` but no GPU is available, or if `idx` is out of range for the available GPU devices. |
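Because a missing GPU only triggers a warning, a caller that must not silently run on the CPU can escalate the documented `RuntimeWarning` into an exception. The sketch below assumes a hypothetical stand-in, `select_device_stub`, that models only the CPU-fallback branch and takes CUDA availability as an argument so it runs without torch; with the real `select_device`, the same `warnings` filters would wrap the actual call.

```python
import warnings


def select_device_stub(gpu: bool, cuda_available: bool) -> str:
    """Hypothetical stand-in for select_device's CPU-fallback branch."""
    if gpu and not cuda_available:
        warnings.warn("Cannot use a GPU device.", RuntimeWarning)
        return "cpu"
    return "cuda" if gpu else "cpu"


# Record warnings so they can be inspected after the call.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    device = select_device_stub(gpu=True, cuda_available=False)
print(device, [w.category.__name__ for w in caught])  # cpu ['RuntimeWarning']

# Escalate RuntimeWarning to an exception to refuse a silent CPU fallback.
with warnings.catch_warnings():
    warnings.simplefilter("error", RuntimeWarning)
    try:
        select_device_stub(gpu=True, cuda_available=False)
    except RuntimeWarning as exc:
        print("refused:", exc)  # refused: Cannot use a GPU device.
```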
Returns:

| Name | Type | Description |
|------|------|-------------|
| `device` | `torch.device` | Selected torch device, with type `"cpu"` or `"cuda"`. |
Source code in declearn/model/torch/utils/_gpu.py
```python
import warnings
from typing import Optional

import torch


def select_device(
    gpu: bool,
    idx: Optional[int] = None,
) -> torch.device:  # pylint: disable=no-member
    """Select a backing device to use based on inputs and availability.

    Parameters
    ----------
    gpu: bool
        Whether to select a GPU device rather than the CPU one.
    idx: int or None, default=None
        Optional pre-selected GPU device index. Only used when `gpu=True`.
        If `idx is None` or is out of range (not lower than the number of
        available GPU devices), use `torch.cuda.current_device()`.

    Warns
    -----
    RuntimeWarning
        If `gpu=True` but no GPU is available.
        If `idx` is out of range for the available GPU devices.

    Returns
    -------
    device: torch.device
        Selected torch device, with type "cpu" or "cuda".
    """
    # Case when instructed to use the CPU device.
    if not gpu:
        return torch.device("cpu")  # pylint: disable=no-member
    # Case when no GPU is available: warn and use the CPU instead.
    if not torch.cuda.is_available():
        warnings.warn(
            "Cannot use a GPU device: either CUDA is unavailable "
            "or no GPU is visible to torch.",
            RuntimeWarning,  # category documented under "Warns"
        )
        return torch.device("cpu")  # pylint: disable=no-member
    # Case when the desired GPU index is out of range: warn and discard it.
    if (idx or 0) >= torch.cuda.device_count():
        warnings.warn(
            f"Cannot use GPU device n°{idx}: index is out-of-range.\n"
            f"Using GPU device n°{torch.cuda.current_device()} instead.",
            RuntimeWarning,
        )
        idx = None
    # Return the selected or auto-selected GPU device.
    if idx is None:
        idx = torch.cuda.current_device()
    return torch.device("cuda", index=idx)  # pylint: disable=no-member
```