declearn.model.torch.utils.AutoDeviceModule
Bases: torch.nn.Module
Wrapper for a torch.nn.Module, automating device-management.
This torch.nn.Module subclass enables wrapping another one, and provides:
- a device attribute (and instantiation parameter) indicating where the wrapped module is placed
- automatic placement of input tensors on that device as part of forward calls to the module
- a set_device method to change the device and move the wrapped module to it
This aims at internalizing device-management boilerplate code.
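To make that boilerplate concrete, here is a minimal sketch of the manual device handling that such a wrapper is meant to absorb. It uses plain torch calls only; the model, batch shapes and loop are illustrative assumptions.

```python
import torch

# Pick a device once: a GPU if one is visible, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Boilerplate step 1: move the model's parameters and buffers to the device.
model = torch.nn.Linear(4, 2).to(device)

# Boilerplate step 2: move every input batch to the same device before each
# forward call; on a GPU, forgetting this raises a device-mismatch RuntimeError.
batches = [torch.randn(8, 4) for _ in range(3)]
outputs = []
for batch in batches:
    batch = batch.to(device)
    outputs.append(model(batch))
```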
The wrapped module is assigned to the module attribute, and can thus be accessed directly.
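The behaviour described above can be approximated by a short torch.nn.Module subclass. The sketch below is a hypothetical re-implementation, named MiniAutoDeviceModule for illustration only; it mirrors the documented features (device attribute, input placement in forward, set_device, direct access through module) but is not declearn's actual source code, and it assumes every input is a tensor.

```python
import torch


class MiniAutoDeviceModule(torch.nn.Module):
    """Hypothetical sketch of the documented behaviour (not declearn's code)."""

    def __init__(self, module: torch.nn.Module, device: torch.device) -> None:
        super().__init__()
        # Assigning an nn.Module attribute registers it as a submodule, so it
        # is reachable as `self.module` and its parameters appear in
        # `self.parameters()`.
        self.module = module
        self.device = device
        self.module.to(device)

    def forward(self, *inputs: torch.Tensor) -> torch.Tensor:
        # Place every (tensor) input on the wrapped module's device, then delegate.
        inputs = tuple(x.to(self.device) for x in inputs)
        return self.module(*inputs)

    def set_device(self, device: torch.device) -> None:
        # Record the new device and move the wrapped module to it.
        self.device = device
        self.module.to(device)


wrapped = MiniAutoDeviceModule(torch.nn.Linear(4, 2), torch.device("cpu"))
output = wrapped(torch.randn(8, 4))  # inputs are moved to `wrapped.device`
print(wrapped.module.out_features)  # prints 2: the wrapped module is directly accessible
```

Because the wrapped module is registered as a submodule, wrapped.parameters() returns its parameters, so the wrapper can be handed to an optimizer like any other module.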
Source code in declearn/model/torch/utils/_gpu.py
__init__(module, device)
Wrap a torch Module into an AutoDeviceModule.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
module | torch.nn.Module | Torch module that needs wrapping. | required |
device | torch.device | Torch device where to place the wrapped module and computations. | required |
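Both parameters are plain torch objects. The sketch below shows a few ways of building the torch.device value expected by the device parameter; the run-time selection rule (first GPU if CUDA is available, else CPU) is an illustrative assumption, not something the wrapper imposes. The resulting object would then be passed as the second argument of AutoDeviceModule(module, device).

```python
import torch

# A torch.device can be built from a string spec or a (type, index) pair;
# creating a "cuda" device object does not require a GPU to be present.
gpu0 = torch.device("cuda:0")
same_gpu0 = torch.device("cuda", 0)

# A common way to choose the `device` argument at run time:
# use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(gpu0.type, gpu0.index)  # cuda 0
```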
forward(*inputs)
Run the forward computation, automating device-placement of inputs.
Please refer to self.module.forward for details on the wrapped module's forward specifications.
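The input-placement step amounts to moving every positional input to the wrapper's device before delegating to the wrapped module. A minimal sketch, using a hypothetical free function forward_on_device (assuming all inputs are tensors) and torch.nn.Bilinear as an example of a module that takes several inputs:

```python
import torch


def forward_on_device(module, device, *inputs):
    """Hypothetical helper mirroring the documented forward behaviour."""
    # Move each positional input to the target device (a no-op when a
    # tensor already lives there), then call the module with them.
    placed = [x.to(device) for x in inputs]
    return module(*placed)


device = torch.device("cpu")
bilinear = torch.nn.Bilinear(3, 5, 2).to(device)  # a module taking two inputs
out = forward_on_device(bilinear, device, torch.randn(4, 3), torch.randn(4, 5))
print(out.shape)  # torch.Size([4, 2])
```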
set_device(device)
Move the wrapped module to a pre-selected torch device.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
device | torch.device | Torch device where to place the wrapped module and computations. | required |
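Moving the wrapped module relies on torch.nn.Module.to, which, unlike torch.Tensor.to, modifies the module in place and returns that same module. The sketch below (plain torch, with an illustrative Linear module and target-device rule) checks both properties that a set_device-style method depends on.

```python
import torch

# Pick the target device: the first GPU if available, else the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
module = torch.nn.Linear(4, 2)

# nn.Module.to() moves parameters and buffers in place and returns the very
# same module object, so a set_device-style method can call it without
# re-binding the wrapped module attribute.
returned = module.to(device)

# Every parameter now lives on the target device.
on_target = all(p.device == device for p in module.parameters())
```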