
declearn.model.torch.utils.AutoDeviceModule

Bases: torch.nn.Module

Wrapper for a torch.nn.Module that automates device management.

This torch.nn.Module subclass wraps another module and provides:

  • a device attribute (and instantiation parameter) indicating where the wrapped module is placed
  • automatic placement of positional input tensors on that device whenever forward is called (non-tensor inputs are passed through unchanged)
  • a set_device method to change the device and move the wrapped module to it

This internalizes device-management boilerplate code. The wrapped module is stored as the module attribute and can thus be accessed directly.

Source code in declearn/model/torch/utils/_gpu.py
```python
from typing import Any

import torch


class AutoDeviceModule(torch.nn.Module):
    """Wrapper for a `torch.nn.Module`, automating device-management.

    This `torch.nn.Module` subclass enables wrapping another one, and
    provides:

    - a `device` attribute (and instantiation parameter) indicating
      where the wrapped module is placed
    - automatic placement of input tensors on that device as part of
      `forward` calls to the module
    - a `set_device` method to change the device and move the wrapped
      module to it

    This aims at internalizing device-management boilerplate code.
    The wrapped module is assigned to the `module` attribute and thus
    can be accessed directly.
    """

    def __init__(
        self,
        module: torch.nn.Module,
        device: torch.device,  # pylint: disable=no-member
    ) -> None:
        """Wrap a torch Module into an AutoDeviceModule.

        Parameters
        ----------
        module: torch.nn.Module
            Torch module that needs wrapping.
        device: torch.device
            Torch device where to place the wrapped module and computations.
        """
        super().__init__()
        self.device = device
        self.module = module.to(self.device)

    def forward(self, *inputs: Any) -> torch.Tensor:
        """Run the forward computation, automating device-placement of inputs.

        Please refer to `self.module.forward` for details on the wrapped
        module's forward specifications.
        """
        inputs = tuple(
            x.to(self.device) if isinstance(x, torch.Tensor) else x
            for x in inputs
        )
        return self.module(*inputs)

    def set_device(
        self,
        device: torch.device,  # pylint: disable=no-member
    ) -> None:
        """Move the wrapped module to a pre-selected torch device.

        Parameters
        ----------
        device: torch.device
            Torch device where to place the wrapped module and computations.
        """
        self.device = device
        self.module.to(device)
```

__init__(module, device)

Wrap a torch Module into an AutoDeviceModule.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| module | torch.nn.Module | Torch module that needs wrapping. | required |
| device | torch.device | Torch device on which to place the wrapped module and run computations. | required |

forward(*inputs)

Run the forward computation, automating device-placement of inputs.

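Please refer to self.module.forward for details on the wrapped module's forward specifications. Only positional inputs are accepted: those that are torch.Tensor instances are moved to self.device, while any other value (including tensors nested inside lists or dicts) is passed through unchanged.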


set_device(device)

Move the wrapped module to a pre-selected torch device.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| device | torch.device | Torch device on which to place the wrapped module and run computations. | required |