Select a backing device to use based on inputs and availability.
Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `gpu` | `bool` | Whether to select a GPU device rather than the CPU one. | *required* |
| `idx` | `Optional[int]` | Optional pre-selected device index, looked up among the available devices of the selected type (GPU if `gpu=True`, CPU otherwise). If `idx` is `None` or exceeds the number of available devices, the first available one is used. | `None` |
Warns:

| Type | Description |
|------|-------------|
| `RuntimeWarning` | If `gpu=True` but no GPU is available. If `idx` exceeds the number of available devices. |
Returns:

| Name | Type | Description |
|------|------|-------------|
| `device` | `jaxlib.xla_extension.Device` | Selected device. |
Source code in `declearn/model/haiku/utils/_gpu.py` (lines 28–82):
def select_device(
    gpu: bool,
    idx: Optional[int] = None,
) -> jax.Device:
    """Select a backing device to use based on inputs and availability.

    Parameters
    ----------
    gpu: bool
        Whether to select a GPU device rather than the CPU one.
    idx: int or None, default=None
        Optional pre-selected device index, looked up among the available
        devices of the selected type (GPU if `gpu=True`, CPU otherwise).
        If `idx is None`, use the first available device. If `idx` is
        negative or exceeds the number of available devices, warn and
        use the first available one instead.

    Warns
    -----
    RuntimeWarning:
        If `gpu=True` but no GPU is available (a CPU device is then used).
        If `idx` is negative or exceeds the number of available devices.

    Returns
    -------
    device: jaxlib.xla_extension.Device
        Selected device.

    Raises
    ------
    RuntimeError:
        If jax fails to provide any CPU device (should never happen).
    """
    idx = 0 if idx is None else idx
    device_type = "gpu" if gpu else "cpu"
    # List devices, handling errors related to the lack of GPU (or CPU error).
    try:
        devices = jax.devices(device_type)
    except RuntimeError as exc:
        # Warn about the lack of GPU (support?), and fall back to CPU.
        if gpu:
            warnings.warn(
                "Cannot use a GPU device: either CUDA is unavailable "
                f"or no GPU is visible to jax: raised {repr(exc)}.",
                RuntimeWarning,
            )
            return select_device(gpu=False, idx=0)
        # Case when no CPU is found: this should never be reached.
        raise RuntimeError(  # pragma: no cover
            "Failed to have jax select a CPU device."
        ) from exc
    # Defensive guard: an empty device list would make the index-0 fallback
    # below raise an IndexError, so treat it like a missing backend.
    if not devices:  # pragma: no cover
        if gpu:
            warnings.warn(
                "Cannot use a GPU device: no GPU is visible to jax.",
                RuntimeWarning,
            )
            return select_device(gpu=False, idx=0)
        raise RuntimeError("Failed to have jax select a CPU device.")
    # similar code to tensorflow util; pylint: disable=duplicate-code
    # Case when the desired device index is invalid: select another one.
    # Negative indices are rejected as well: they do not designate a device,
    # and values below -len(devices) would otherwise raise an IndexError.
    if not 0 <= idx < len(devices):
        warnings.warn(
            f"Cannot use {device_type} device n°{idx}: index is out-of-range."
            f"\nUsing {device_type} device n°0 instead.",
            RuntimeWarning,
        )
        idx = 0
    # Return the selected device.
    return devices[idx]
|