# `declearn.model.torch.utils.select_device`

Select a backing device to use based on inputs and availability.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `gpu` | `bool` | Whether to select a GPU device rather than the CPU one. | *required* |
| `idx` | `Optional[int]` | Optional pre-selected GPU device index. Only used when `gpu=True`. If `idx` is `None` or exceeds the number of available GPU devices, use `torch.cuda.current_device()`. | `None` |

Warns:

| Type | Description |
|------|-------------|
| `RuntimeWarning` | If `gpu=True` but no GPU is available. If `idx` exceeds the number of available GPU devices. |

Returns:

| Name | Type | Description |
|------|------|-------------|
| `device` | `torch.device` | Selected torch device, with type `"cpu"` or `"cuda"`. |

Source code in `declearn/model/torch/utils/_gpu.py` (lines 32–79):
def select_device(
    gpu: bool,
    idx: Optional[int] = None,
) -> torch.device:
    """Select a backing device to use based on inputs and availability.

    Parameters
    ----------
    gpu: bool
        Whether to select a GPU device rather than the CPU one.
    idx: int or None, default=None
        Optional pre-selected GPU device index. Only used when `gpu=True`.
        If `idx is None`, is negative, or exceeds the number of available
        GPU devices, use `torch.cuda.current_device()`.

    Warns
    -----
    RuntimeWarning
        If `gpu=True` but no GPU is available.
        If `idx` is negative or exceeds the number of available GPU devices.

    Returns
    -------
    device: torch.device
        Selected torch device, with type "cpu" or "cuda".
    """
    # Case when instructed to use the CPU device.
    if not gpu:
        return torch.device("cpu")
    # Case when no GPU is available: warn and use the CPU instead.
    # The category is explicit, as `warnings.warn` would otherwise emit a
    # UserWarning rather than the documented RuntimeWarning.
    if not torch.cuda.is_available():
        warnings.warn(
            "Cannot use a GPU device: either CUDA is unavailable "
            "or no GPU is visible to torch.",
            RuntimeWarning,
        )
        return torch.device("cpu")
    # Case when the desired GPU is invalid: fall back to the current one.
    # Negative indices are rejected as well, since torch does not accept
    # negative CUDA device indices.
    if idx is not None and not 0 <= idx < torch.cuda.device_count():
        warnings.warn(
            f"Cannot use GPU device n°{idx}: index is out-of-range.\n"
            f"Using GPU device n°{torch.cuda.current_device()} instead.",
            RuntimeWarning,
        )
        idx = None
    # Return the selected or auto-selected GPU device index.
    if idx is None:
        idx = torch.cuda.current_device()
    return torch.device("cuda", index=idx)  # pylint: disable=no-member