Skip to content

Commit 65a119c

Browse files
emlin and facebook-github-bot
authored and committed
Add new interfaces to SSD TBE for checkpoint saving and loading
Summary: These methods are added to integrate with MVAI trainer and torchrec for checkpoint saving and loading. Differential Revision: D74265651
1 parent 45f0851 commit 65a119c

File tree

1 file changed

+74
-1
lines changed

1 file changed

+74
-1
lines changed

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

+74-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import threading
1818
import time
1919
from math import floor, log2
20-
from typing import Any, Callable, List, Optional, Tuple, Type, Union
20+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
2121
import torch # usort:skip
2222

2323
# @manual=//deeplearning/fbgemm/fbgemm_gpu/codegen:split_embedding_codegen_lookup_invokers
@@ -156,6 +156,7 @@ def __init__(
156156
lazy_bulk_init_enabled: bool = False,
157157
backend_type: BackendType = BackendType.SSD,
158158
kv_zch_params: Optional[KVZCHParams] = None,
159+
enable_optimizer_offloading: bool = False, # whether enable optimizer offloading
159160
) -> None:
160161
super(SSDTableBatchedEmbeddingBags, self).__init__()
161162

@@ -435,6 +436,21 @@ def __init__(
435436
self.backend_type = backend_type
436437
if self.kv_zch_params:
437438
self.kv_zch_params.validate()
439+
self.enable_optimizer_offloading: bool = enable_optimizer_offloading
440+
# initial num of embeddings on this rank used for loading
441+
self.local_weight_counts: List[int] = [0] * T_
442+
self.load_state_dict: bool = False
443+
# cache loaded id, weights and optimizer temporarily before write to backend for kvzch
444+
# since checkpoint loading does not guarantee tensor loading order, but we need to make
445+
# sure id tensor is loaded before weight tensor and optimizer tensor can be applied
446+
# to backend. So we cache them here and apply them to backend after all tensors are loaded.
447+
# with this solution, we'll duplicate the memory usage for id, weight and optimizer tensors
448+
# unexpectedly, but it's a tradeoff we can take to make sure the correctness of checkpoint loading
449+
# before the backend support streaming loading with id tensor first.
450+
self._cached_id_tensor_per_table: Optional[List[torch.Tensor]] = None
451+
self._cached_weight_tensor_per_table: Optional[List[torch.Tensor]] = None
452+
self._cached_optimizer_state_per_table: Optional[List[torch.Tensor]] = None
453+
self._cached_bucket_splits: Optional[List[torch.Tensor]] = None
438454

439455
# create tbe unique id using rank index | local tbe idx
440456
if tbe_unique_id == -1:
@@ -1777,6 +1793,46 @@ def debug_split_optimizer_states(self) -> List[Tuple[torch.Tensor, int, int]]:
17771793
for t, row in enumerate(rows)
17781794
]
17791795

1796+
@torch.jit.export
def split_optimizer_states(
    self,
    sorted_id_tensor: Optional[List[torch.Tensor]] = None,
) -> List[torch.Tensor]:
    """
    Return the optimizer states split by table.

    So far only EXACT_ROWWISE_ADAGRAD is supported, so each table yields a
    single momentum1 tensor. Because EXACT_ROWWISE_ADAGRAD states are small,
    a full tensor per table (shard) would be materialized; once other
    optimizer types are supported, this should integrate with
    KVTensorWrapper (ssd_split_table_batched_embeddings.cpp) so callers can
    read the states via `narrow()` in a rolling-window manner.

    Args:
        sorted_id_tensor (Optional[List[torch.Tensor]]): sorted id tensor
            per table, used to query optimizer state from the backend.
            Callers should reuse the id tensor generated for the weight
            state_dict to guarantee id consistency between weight and
            optimizer states.

    Raises:
        NotImplementedError: always — this interface is a placeholder and
            has no backend implementation yet.
    """
    raise NotImplementedError(
        "split_optimizer_states is not implemented for SSDTableBatchedEmbeddingBags"
    )
1819+
1820+
@torch.jit.export
def get_optimizer_state(
    self,
    sorted_id_tensor: Optional[List[torch.Tensor]],
) -> List[Dict[str, torch.Tensor]]:
    """
    Return per-table optimizer state dicts.

    So far only EXACT_ROWWISE_ADAGRAD is supported, so each table's dict
    carries a single "momentum1" entry.

    Args:
        sorted_id_tensor (Optional[List[torch.Tensor]]): sorted id tensor
            per table, forwarded to split_optimizer_states so ids stay
            consistent between weight and optimizer states.
    """
    per_table_states = self.split_optimizer_states(
        sorted_id_tensor=sorted_id_tensor,
    )
    result: List[Dict[str, torch.Tensor]] = []
    for momentum in per_table_states:
        result.append({"momentum1": momentum})
    return result
1835+
17801836
@torch.jit.export
17811837
def debug_split_embedding_weights(self) -> List[torch.Tensor]:
17821838
"""
@@ -1822,6 +1878,11 @@ def debug_split_embedding_weights(self) -> List[torch.Tensor]:
18221878

18231879
return splits
18241880

1881+
def clear_cache(self) -> None:
    """
    Release all tensors cached during checkpoint loading.

    Resets every `_cached_*` attribute initialized in `__init__` so the
    duplicated id/weight/optimizer/bucket tensors can be garbage-collected
    after they have been applied to the backend.
    """
    self._cached_bucket_splits = None
    self._cached_id_tensor_per_table = None
    self._cached_weight_tensor_per_table = None
    # Bug fix: this was the only _cached_* field from __init__ not released,
    # leaking the duplicated optimizer-state tensors after checkpoint load.
    self._cached_optimizer_state_per_table = None
1885+
18251886
@torch.jit.export
18261887
def split_embedding_weights(
18271888
self,
@@ -1926,6 +1987,18 @@ def split_embedding_weights(
19261987
)
19271988
return (pmt_splits, bucket_sorted_id_splits, active_id_cnt_per_bucket_split)
19281989

1990+
@torch.jit.ignore
def apply_state_dict(self) -> None:
    """
    Apply cached checkpoint tensors to the backend.

    After checkpoint loading populates `_cached_bucket_splits`,
    `_cached_id_tensor_per_table` and `_cached_weight_tensor_per_table`,
    the caller invokes this method to push those cached states to the
    backend.

    NOTE(review): currently a no-op placeholder — the backend apply logic
    is not implemented yet.
    """
    pass
1996+
1997+
@torch.jit.ignore
def enable_load_state_dict_mode(self) -> None:
    """
    Enter load-state-dict mode prior to loading a checkpoint.

    NOTE(review): currently a no-op placeholder — the mode switch is not
    implemented yet.
    """
    pass
2001+
19292002
@torch.jit.export
19302003
def set_learning_rate(self, lr: float) -> None:
19312004
"""

0 commit comments

Comments
 (0)