diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/ssd_config.py b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/ssd_config.py index ff62e3d5b3..a03461a9a4 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/ssd/ssd_config.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/ssd/ssd_config.py @@ -314,6 +314,7 @@ class BackendType(enum.IntEnum): SSD = 0 DRAM = 1 PS = 2 + DRAM_SSD = 3 @classmethod def from_str(cls, key: str) -> "BackendType": @@ -321,6 +322,7 @@ def from_str(cls, key: str) -> "BackendType": "ssd": BackendType.SSD, "dram": BackendType.DRAM, "ps": BackendType.PS, + "dram_ssd": BackendType.DRAM_SSD, } if key in lookup: return lookup[key] diff --git a/fbgemm_gpu/test/tbe/ssd/ssd_config_test.py b/fbgemm_gpu/test/tbe/ssd/ssd_config_test.py index f25f5cb8a2..382db985c2 100644 --- a/fbgemm_gpu/test/tbe/ssd/ssd_config_test.py +++ b/fbgemm_gpu/test/tbe/ssd/ssd_config_test.py @@ -123,6 +123,17 @@ def test_dram(self) -> None: self.assertEqual(BackendType.from_str("dram"), BackendType.DRAM) + def test_ps(self) -> None: + from fbgemm_gpu.tbe.ssd import BackendType + + self.assertEqual(BackendType.from_str("ps"), BackendType.PS) + + def test_dram_ssd(self) -> None: + from fbgemm_gpu.tbe.ssd import BackendType + + self.assertEqual(BackendType.from_str("dram_ssd"), BackendType.DRAM_SSD) + self.assertEqual(BackendType.DRAM_SSD.value, 3) + def test_invalid_raises(self) -> None: from fbgemm_gpu.tbe.ssd import BackendType