Skip to content

declearn.model.haiku.utils.select_device

Select a backing device to use based on inputs and availability.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `gpu` | `bool` | Whether to select a GPU device rather than the CPU one. | *required* |
| `idx` | `Optional[int]` | Optional pre-selected device index. Only used when `gpu=True`. If `idx` is `None` or exceeds the number of available GPU devices, use the first available one. | `None` |

Warns:

| Type | Description |
|------|-------------|
| `RuntimeWarning` | If `gpu=True` but no GPU is available; or if `idx` exceeds the number of available GPU devices. |

Returns:

| Name | Type | Description |
|------|------|-------------|
| `device` | `jaxlib.xla_extension.Device` | Selected device. |

Source code in `declearn/model/haiku/utils/_gpu.py` (lines 28–82):
def select_device(
    gpu: bool,
    idx: Optional[int] = None,
) -> jax.Device:
    """Select a backing device to use based on inputs and availability.

    Parameters
    ----------
    gpu: bool
        Whether to select a GPU device rather than the CPU one.
    idx: int or None, default=None
        Optional pre-selected device index, looked up among the devices
        of the requested type (GPU devices if `gpu=True`, CPU devices
        otherwise). If `idx is None`, use the first available device.
        If `idx` is negative or exceeds the number of available devices,
        warn and use the first available one.

    Warns
    -----
    RuntimeWarning:
        If `gpu=True` but no GPU is available (CPU device n°0 is then
        selected, regardless of `idx`).
        If `idx` is negative or exceeds the number of available devices.

    Returns
    -------
    device: jaxlib.xla_extension.Device
        Selected device.
    """
    idx = 0 if idx is None else idx
    device_type = "gpu" if gpu else "cpu"
    # List devices, handling errors related to the lack of GPU (or CPU error).
    try:
        devices = jax.devices(device_type)
    except RuntimeError as exc:
        # Warn about the lack of GPU (support?), and fall back to CPU.
        if gpu:
            warnings.warn(
                "Cannot use a GPU device: either CUDA is unavailable "
                f"or no GPU is visible to jax: raised {repr(exc)}.",
                RuntimeWarning,
            )
            return select_device(gpu=False, idx=0)
        # Case when no CPU is found: this should never be reached.
        raise RuntimeError(  # pragma: no cover
            "Failed to have jax select a CPU device."
        ) from exc
    # similar code to tensorflow util; pylint: disable=duplicate-code
    # Case when the desired device index is invalid: select another one.
    # Negative indices are rejected as well: otherwise Python's negative
    # indexing would silently pick a device counted from the end of the list.
    if not 0 <= idx < len(devices):
        warnings.warn(
            f"Cannot use {device_type} device n°{idx}: index is out-of-range."
            f"\nUsing {device_type} device n°0 instead.",
            RuntimeWarning,
        )
        idx = 0
    # Return the selected device.
    return devices[idx]