
declearn.model.tensorflow.utils.select_device

Select a backing device to use based on inputs and availability.

Parameters:

gpu (bool, required)
    Whether to select a GPU device rather than the CPU one.
idx (Optional[int], default: None)
    Optional pre-selected device index, only used when gpu=True. If idx is None or out of range (i.e. not lower than the number of available GPU devices), the first available GPU device is used.
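The rules for idx can be summarised by a small pure-Python function. This is a minimal sketch: `resolve_index` is a hypothetical helper, not part of declearn, and it mirrors the documented behaviour rather than calling TensorFlow. `n_devices` stands for the number of available devices of the selected type.

```python
from typing import Optional


def resolve_index(gpu: bool, idx: Optional[int], n_devices: int) -> int:
    """Hypothetical helper mirroring the documented rules for `idx`."""
    if not gpu or idx is None:
        # `idx` is only used when gpu=True; None means "first device".
        return 0
    if idx >= n_devices:
        # Out of range: select_device warns and falls back to device 0.
        return 0
    return idx
```

For instance, `resolve_index(True, 1, 2)` keeps index 1, while `resolve_index(True, 2, 2)` falls back to 0 because two GPUs are indexed 0 and 1.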

Warns:

RuntimeWarning
    If gpu=True but no GPU is available; the first CPU device is used instead.
    If idx is out of range for the available GPU devices; device 0 is used instead.
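A caller that needs to know about the fallback can record or escalate these warnings with the standard warnings module. The sketch below runs without TensorFlow: `pick_device_type` is a hypothetical stand-in for the GPU-to-CPU fallback branch, and `cpu_only` is a fake device lister simulating a machine with no visible GPU. It assumes the warnings are emitted as RuntimeWarning, as documented; a category-specific filter will not catch a warning issued under another category.

```python
import warnings
from typing import Callable, List


def pick_device_type(gpu: bool, list_devices: Callable[[str], List[str]]) -> str:
    """Hypothetical stand-in for select_device's GPU-to-CPU fallback."""
    if gpu and not list_devices("GPU"):
        # Same warning category as documented for select_device.
        warnings.warn("Cannot use a GPU device.", RuntimeWarning)
        return "CPU"
    return "GPU" if gpu else "CPU"


def cpu_only(device_type: str) -> List[str]:
    """Fake device lister simulating a machine without any visible GPU."""
    return ["/device:CPU:0"] if device_type == "CPU" else []


# Option 1: record the fallback warning so the caller can log or inspect it.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    chosen = pick_device_type(True, cpu_only)


# Option 2: escalate the fallback to an error, e.g. for GPU-only jobs.
def require_gpu(list_devices: Callable[[str], List[str]]) -> str:
    with warnings.catch_warnings():
        warnings.simplefilter("error", RuntimeWarning)
        return pick_device_type(True, list_devices)
```

Escalating is useful when silently running on a CPU would only make a job much slower rather than fail outright.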

Returns:

device (tf.config.LogicalDevice)
    Selected device, usable as a tf.device argument.
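Because tf.config.LogicalDevice is a named tuple with `name` and `device_type` fields, a caller can check whether the GPU fallback happened. The sketch below uses a stdlib stand-in with the same two fields so that it runs without TensorFlow; `fell_back_to_cpu` is a hypothetical helper, not part of declearn.

```python
from collections import namedtuple

# Hypothetical stand-in with the same two fields as tf.config.LogicalDevice.
LogicalDevice = namedtuple("LogicalDevice", ["name", "device_type"])


def fell_back_to_cpu(requested_gpu: bool, device: LogicalDevice) -> bool:
    """Tell whether a GPU was requested but a CPU device was returned."""
    return requested_gpu and device.device_type == "CPU"


# What a CPU-only machine would yield when a GPU is requested.
device = LogicalDevice(name="/device:CPU:0", device_type="CPU")
```

With TensorFlow installed, the real return value can be used as `with tf.device(device.name): ...`, which is the form shown in the tf.config.list_logical_devices documentation.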

Source code in declearn/model/tensorflow/utils/_gpu.py
import warnings
from typing import Optional

import tensorflow as tf  # type: ignore


def select_device(
    gpu: bool,
    idx: Optional[int] = None,
) -> tf.config.LogicalDevice:
    """Select a backing device to use based on inputs and availability.

    Parameters
    ----------
    gpu: bool
        Whether to select a GPU device rather than the CPU one.
    idx: int or None, default=None
        Optional pre-selected device index. Only used when `gpu=True`.
        If `idx is None` or is out of range (i.e. `idx` is not lower than
        the number of available GPU devices), use the first available one.

    Warns
    -----
    RuntimeWarning
        If `gpu=True` but no GPU is available (a CPU device is used instead).
        If `idx` is out of range for the available GPU devices.

    Returns
    -------
    device: tf.config.LogicalDevice
        Selected device, usable as `tf.device` argument.
    """
    # `idx` is only meaningful for GPU selection; default to the first device.
    idx = 0 if (idx is None or not gpu) else idx
    # List available CPU or GPU devices.
    device_type = "GPU" if gpu else "CPU"
    devices = tf.config.list_logical_devices(device_type)
    # Case when no GPU is available: warn and use a CPU instead.
    if gpu and not devices:
        warnings.warn(
            "Cannot use a GPU device: either CUDA is unavailable "
            "or no GPU is visible to tensorflow.",
            RuntimeWarning,
        )
        device_type, idx = "CPU", 0
        devices = tf.config.list_logical_devices("CPU")
    # Case when the desired device index is invalid: select another one.
    if idx >= len(devices):
        warnings.warn(
            f"Cannot use {device_type} device n°{idx}: index is out-of-range."
            f"\nUsing {device_type} device n°0 instead.",
            RuntimeWarning,
        )
        idx = 0
    # Return the selected device.
    return devices[idx]