|
17 | 17 | import threading
|
18 | 18 | import time
|
19 | 19 | from math import floor, log2
|
20 |
| -from typing import Any, Callable, List, Optional, Tuple, Type, Union |
| 20 | +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union |
21 | 21 | import torch # usort:skip
|
22 | 22 |
|
23 | 23 | # @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
|
@@ -156,6 +156,7 @@ def __init__(
|
156 | 156 | lazy_bulk_init_enabled: bool = False,
|
157 | 157 | backend_type: BackendType = BackendType.SSD,
|
158 | 158 | kv_zch_params: Optional[KVZCHParams] = None,
|
| 159 | + enable_optimizer_offloading: bool = False, # whether to enable optimizer offloading |
159 | 160 | ) -> None:
|
160 | 161 | super(SSDTableBatchedEmbeddingBags, self).__init__()
|
161 | 162 |
|
@@ -435,6 +436,21 @@ def __init__(
|
435 | 436 | self.backend_type = backend_type
|
436 | 437 | if self.kv_zch_params:
|
437 | 438 | self.kv_zch_params.validate()
|
| 439 | + self.enable_optimizer_offloading: bool = enable_optimizer_offloading |
| 440 | + # initial number of embeddings on this rank, used for loading |
| 441 | + self.local_weight_counts: List[int] = [0] * T_ |
| 442 | + self.load_state_dict: bool = False |
| 443 | + # For KV ZCH, temporarily cache the loaded id, weight, and optimizer tensors before |
| 444 | + # writing them to the backend. Checkpoint loading does not guarantee tensor loading |
| 445 | + # order, but the id tensor must be loaded before the weight and optimizer tensors |
| 446 | + # can be applied to the backend, so we cache them here and apply them once all |
| 447 | + # tensors are loaded. This duplicates memory for the id, weight, and optimizer |
| 448 | + # tensors, but it is an acceptable tradeoff for correct checkpoint loading until |
| 449 | + # the backend supports streaming loading with the id tensor first. |
| 450 | + self._cached_id_tensor_per_table: Optional[List[torch.Tensor]] = None |
| 451 | + self._cached_weight_tensor_per_table: Optional[List[torch.Tensor]] = None |
| 452 | + self._cached_optimizer_state_per_table: Optional[List[torch.Tensor]] = None |
| 453 | + self._cached_bucket_splits: Optional[List[torch.Tensor]] = None |
438 | 454 |
|
439 | 455 | # create tbe unique id using rank index | local tbe idx
|
440 | 456 | if tbe_unique_id == -1:
|
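
A minimal sketch of the deferred-apply idea described in the cached-tensor comment above: keep the per-table id, weight, and optimizer tensors cached while the checkpoint is being loaded, and only write them to the backend once all three are present, ids first. The `backend` object and its `set_*` methods are placeholders for illustration, not the actual fbgemm_gpu SSD/DRAM-KV backend API.

```python
# Hypothetical sketch only: `backend` and its `set_*` methods are placeholders,
# not the real backend API used by SSDTableBatchedEmbeddingBags.
from typing import List, Optional

import torch


def apply_cached_checkpoint_state(
    backend,
    cached_ids: Optional[List[torch.Tensor]],
    cached_weights: Optional[List[torch.Tensor]],
    cached_opt_states: Optional[List[torch.Tensor]],
) -> None:
    """Apply cached per-table tensors only after everything has been loaded."""
    if cached_ids is None or cached_weights is None or cached_opt_states is None:
        # Checkpoint loading order is not guaranteed, so wait until all pieces exist.
        return
    for t, ids in enumerate(cached_ids):
        # Ids go first so the backend knows which rows the weight and optimizer
        # values belong to; only then can the per-row payloads be attached.
        backend.set_ids(t, ids)
        backend.set_weights(t, ids, cached_weights[t])
        backend.set_optimizer_states(t, ids, cached_opt_states[t])
```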
@@ -1777,6 +1793,46 @@ def debug_split_optimizer_states(self) -> List[Tuple[torch.Tensor, int, int]]:
|
1777 | 1793 | for t, row in enumerate(rows)
|
1778 | 1794 | ]
|
1779 | 1795 |
|
| 1796 | + @torch.jit.export |
| 1797 | + def split_optimizer_states( |
| 1798 | + self, |
| 1799 | + sorted_id_tensor: Optional[List[torch.Tensor]] = None, |
| 1800 | + ) -> List[torch.Tensor]: |
| 1801 | + """ |
| 1802 | + Returns a list of optimizer states split by table. So far, we only support EXACT_ROWWISE_ADAGRAD, |
| 1803 | + so only the momentum1 state is returned. |
| 1804 | +
|
| 1805 | + Since EXACT_ROWWISE_ADAGRAD has small optimizer states, we generate a full |
| 1806 | + tensor for each table (shard). When other optimizer types are supported, we |
| 1807 | + should integrate with KVTensorWrapper (ssd_split_table_batched_embeddings.cpp) |
| 1808 | + to allow the caller to read the optimizer states using `narrow()` in a rolling-window manner. |
| 1809 | +
|
| 1810 | + Args: |
| 1811 | + sorted_id_tensor (Optional[List[torch.Tensor]]): sorted id tensor by table, used to query optimizer |
| 1812 | + states from the backend. The caller should reuse the id tensors generated for the weight state_dict |
| 1813 | + to guarantee id consistency between the weight and optimizer states. |
| 1814 | +
|
| 1815 | + """ |
| 1816 | + raise NotImplementedError( |
| 1817 | + "split_optimizer_states is not implemented for SSDTableBatchedEmbeddingBags" |
| 1818 | + ) |
| 1819 | + |
| 1820 | + @torch.jit.export |
| 1821 | + def get_optimizer_state( |
| 1822 | + self, |
| 1823 | + sorted_id_tensor: Optional[List[torch.Tensor]], |
| 1824 | + ) -> List[Dict[str, torch.Tensor]]: |
| 1825 | + """ |
| 1826 | + Returns a list of optimizer states split by table. So far, we only support EXACT_ROWWISE_ADAGRAD, |
| 1827 | + so only the momentum1 state is returned. |
| 1828 | + """ |
| 1829 | + return [ |
| 1830 | + ({"momentum1": states}) |
| 1831 | + for states in self.split_optimizer_states( |
| 1832 | + sorted_id_tensor=sorted_id_tensor, |
| 1833 | + ) |
| 1834 | + ] |
| 1835 | + |
1780 | 1836 | @torch.jit.export
|
1781 | 1837 | def debug_split_embedding_weights(self) -> List[torch.Tensor]:
|
1782 | 1838 | """
|
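
A hedged usage sketch for the new optimizer-state API in the hunk above: the bucket-sorted id tensors returned by `split_embedding_weights()` (shown in a later hunk) are reused as `sorted_id_tensor` so that optimizer rows line up with weight rows. In this base class `split_optimizer_states` still raises `NotImplementedError`, so the snippet only shows the intended calling convention; the import path and the no-argument `split_embedding_weights()` call are assumptions.

```python
from typing import Dict, List

import torch

# Import path assumed; it may differ across fbgemm_gpu versions.
from fbgemm_gpu.tbe.ssd import SSDTableBatchedEmbeddingBags


def read_rowwise_adagrad_states(
    tbe: SSDTableBatchedEmbeddingBags,
) -> List[Dict[str, torch.Tensor]]:
    """Fetch per-table optimizer states, keyed "momentum1" for EXACT_ROWWISE_ADAGRAD."""
    # split_embedding_weights() returns (weights, bucket-sorted ids, active id counts);
    # this assumes KV ZCH is enabled so the bucket-sorted id splits are not None.
    _weights, bucket_sorted_ids, _active_counts = tbe.split_embedding_weights()
    # Reuse the generated id tensors to keep weight and optimizer rows consistent.
    return tbe.get_optimizer_state(sorted_id_tensor=bucket_sorted_ids)
```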
@@ -1822,6 +1878,11 @@ def debug_split_embedding_weights(self) -> List[torch.Tensor]:
|
1822 | 1878 |
|
1823 | 1879 | return splits
|
1824 | 1880 |
|
| 1881 | + def clear_cache(self) -> None: |
| 1882 | + self._cached_bucket_splits = None |
| 1883 | + self._cached_id_tensor_per_table = None |
| 1884 | + self._cached_weight_tensor_per_table = None |
| 1885 | + |
1825 | 1886 | @torch.jit.export
|
1826 | 1887 | def split_embedding_weights(
|
1827 | 1888 | self,
|
@@ -1926,6 +1987,18 @@ def split_embedding_weights(
|
1926 | 1987 | )
|
1927 | 1988 | return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)
|
1928 | 1989 |
|
| 1990 | + @torch.jit.ignore |
| 1991 | + def apply_state_dict(self) -> None: |
| 1992 | + # After checkpoint loading, _cached_bucket_splits, _cached_id_tensor_per_table, |
| 1993 | + # and _cached_weight_tensor_per_table have been populated from the checkpoint. |
| 1994 | + # The caller should call this function to apply the cached state to the backend. |
| 1995 | + pass |
| 1996 | + |
| 1997 | + @torch.jit.ignore |
| 1998 | + def enable_load_state_dict_mode(self) -> None: |
| 1999 | + # Enable load-state-dict mode before loading a checkpoint |
| 2000 | + pass |
| 2001 | + |
1929 | 2002 | @torch.jit.export
|
1930 | 2003 | def set_learning_rate(self, lr: float) -> None:
|
1931 | 2004 | """
|
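
A hedged sketch of the checkpoint-restore sequence the new hooks imply, based on the comments in the hunk above: enter load-state-dict mode, run the trainer's checkpoint loading (a placeholder callable here, not an fbgemm_gpu API), apply the cached per-table state to the backend, then release the duplicated copies. Both hooks are still `pass` stubs in this diff, so this only illustrates the intended ordering.

```python
from typing import Any, Callable


def restore_tbe_from_checkpoint(tbe: Any, load_checkpoint_fn: Callable[[Any], None]) -> None:
    """Illustrative ordering only; the hooks are still stubs in this diff."""
    # 1. Enter load-state-dict mode so tensors loaded from the checkpoint are cached
    #    per table instead of being written to the SSD/KV backend immediately.
    tbe.enable_load_state_dict_mode()

    # 2. Run whatever checkpoint-loading mechanism the trainer uses (placeholder).
    #    Loading order of id / weight / optimizer tensors is not guaranteed.
    load_checkpoint_fn(tbe)

    # 3. With ids, weights, and optimizer states all cached, push them to the backend.
    tbe.apply_state_dict()

    # 4. Drop the cached copies to release the duplicated memory.
    tbe.clear_cache()
```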