diff --git a/src/cmd/shim/main.go b/src/cmd/shim/main.go index 13c5356..a54121d 100644 --- a/src/cmd/shim/main.go +++ b/src/cmd/shim/main.go @@ -13,6 +13,7 @@ import ( "github.com/dtvem/dtvem/src/internal/config" "github.com/dtvem/dtvem/src/internal/constants" "github.com/dtvem/dtvem/src/internal/runtime" + "github.com/dtvem/dtvem/src/internal/shim" "github.com/dtvem/dtvem/src/internal/ui" // Import runtime providers to register them @@ -193,17 +194,24 @@ func getShimName() string { } // mapShimToRuntime maps a shim name to its runtime -// For example: python3 -> python, pip -> python, npm -> node -// This queries all registered providers for their shims, eliminating the need -// for a central hardcoded mapping. +// For example: python3 -> python, pip -> python, npm -> node, tsc -> node +// First checks the shim map cache (generated by reshim), then falls back +// to querying registered providers. func mapShimToRuntime(shimName string) string { + // First, try the shim map cache for O(1) lookup + // This handles both core shims and dynamically installed packages (tsc, eslint, black, etc.) + if runtimeName, ok := shim.LookupRuntime(shimName); ok { + return runtimeName + } + + // Fall back to provider-based lookup if cache is missing or doesn't have the shim // Get all registered providers (using ShimProvider interface) providers := runtime.GetAllShimProviders() // Check each provider's shims for an exact match first for _, provider := range providers { - for _, shim := range provider.Shims() { - if shim == shimName { + for _, s := range provider.Shims() { + if s == shimName { return provider.Name() } } @@ -211,8 +219,8 @@ func mapShimToRuntime(shimName string) string { // Check for prefix match (e.g., python3 -> python) for _, provider := range providers { - for _, shim := range provider.Shims() { - if strings.HasPrefix(shimName, shim) { + for _, s := range provider.Shims() { + if strings.HasPrefix(shimName, s) { return provider.Name() } } diff --git a/src/internal/config/paths.go b/src/internal/config/paths.go index 6a07dbf..9827962 100644 --- a/src/internal/config/paths.go +++ b/src/internal/config/paths.go @@ -16,6 +16,7 @@ type Paths struct { Shims string // Shims directory (~/.dtvem/shims) Versions string // Versions directory (~/.dtvem/versions) Config string // Config directory (~/.dtvem/config) + Cache string // Cache directory (~/.dtvem/cache) } var ( @@ -40,6 +41,7 @@ func initPaths() *Paths { Shims: filepath.Join(root, "shims"), Versions: filepath.Join(root, "versions"), Config: filepath.Join(root, "config"), + Cache: filepath.Join(root, "cache"), } } @@ -90,6 +92,7 @@ func EnsureDirectories() error { paths.Shims, paths.Versions, paths.Config, + paths.Cache, } for _, dir := range dirs { @@ -120,3 +123,19 @@ const LocalConfigDirName = ".dtvem" // RuntimesFileName is the name of the runtimes configuration file const RuntimesFileName = "runtimes.json" + +// ShimMapFileName is the name of the shim-to-runtime mapping cache file +const ShimMapFileName = "shim-map.json" + +// ShimMapPath returns the path to the shim-to-runtime mapping cache file +func ShimMapPath() string { + paths := DefaultPaths() + return filepath.Join(paths.Cache, ShimMapFileName) +} + +// ResetPathsCache resets the cached paths, forcing reinitialization on next access. +// This is primarily useful for testing. +func ResetPathsCache() { + pathsOnce = sync.Once{} + defaultPaths = nil +} diff --git a/src/internal/shim/cache.go b/src/internal/shim/cache.go new file mode 100644 index 0000000..f4e25ee --- /dev/null +++ b/src/internal/shim/cache.go @@ -0,0 +1,87 @@ +// Package shim manages shim executables that intercept runtime commands +package shim + +import ( + "encoding/json" + "os" + "sync" + + "github.com/dtvem/dtvem/src/internal/config" +) + +// ShimMap represents the shim-to-runtime mapping cache +// The map key is the shim name (e.g., "tsc", "npm", "black") +// The map value is the runtime name (e.g., "node", "python") +type ShimMap map[string]string + +var ( + shimMapCache ShimMap + shimMapCacheOnce sync.Once + shimMapCacheErr error +) + +// LoadShimMap loads the shim-to-runtime mapping from the cache file. +// It uses sync.Once to ensure the cache is only loaded once per process. +// Returns the cached map and any error that occurred during loading. +func LoadShimMap() (ShimMap, error) { + shimMapCacheOnce.Do(func() { + shimMapCache, shimMapCacheErr = loadShimMapFromDisk() + }) + return shimMapCache, shimMapCacheErr +} + +// loadShimMapFromDisk reads the shim map cache file from disk +func loadShimMapFromDisk() (ShimMap, error) { + cachePath := config.ShimMapPath() + + data, err := os.ReadFile(cachePath) + if err != nil { + return nil, err + } + + var shimMap ShimMap + if err := json.Unmarshal(data, &shimMap); err != nil { + return nil, err + } + + return shimMap, nil +} + +// SaveShimMap writes the shim-to-runtime mapping to the cache file. +// This should be called during reshim operations. +func SaveShimMap(shimMap ShimMap) error { + // Ensure cache directory exists + paths := config.DefaultPaths() + if err := os.MkdirAll(paths.Cache, 0755); err != nil { + return err + } + + cachePath := config.ShimMapPath() + + data, err := json.MarshalIndent(shimMap, "", " ") + if err != nil { + return err + } + + return os.WriteFile(cachePath, data, 0644) +} + +// LookupRuntime looks up the runtime for a given shim name using the cache. +// Returns the runtime name and true if found, or empty string and false if not. +func LookupRuntime(shimName string) (string, bool) { + shimMap, err := LoadShimMap() + if err != nil { + return "", false + } + + runtime, ok := shimMap[shimName] + return runtime, ok +} + +// ResetShimMapCache resets the cached shim map, forcing a reload on next access. +// This is primarily useful for testing. +func ResetShimMapCache() { + shimMapCacheOnce = sync.Once{} + shimMapCache = nil + shimMapCacheErr = nil +} diff --git a/src/internal/shim/cache_test.go b/src/internal/shim/cache_test.go new file mode 100644 index 0000000..ca2226a --- /dev/null +++ b/src/internal/shim/cache_test.go @@ -0,0 +1,220 @@ +package shim + +import ( + "os" + "path/filepath" + "testing" + + "github.com/dtvem/dtvem/src/internal/config" +) + +func TestSaveAndLoadShimMap(t *testing.T) { + // Create a temporary directory for the test + tempDir := t.TempDir() + + // Set DTVEM_ROOT to use our temp directory + originalRoot := os.Getenv("DTVEM_ROOT") + _ = os.Setenv("DTVEM_ROOT", tempDir) + defer func() { _ = os.Setenv("DTVEM_ROOT", originalRoot) }() + + // Reset the paths cache to pick up new DTVEM_ROOT + config.ResetPathsCache() + defer config.ResetPathsCache() + + // Reset the shim map cache + ResetShimMapCache() + defer ResetShimMapCache() + + // Create the cache directory + cacheDir := filepath.Join(tempDir, "cache") + if err := os.MkdirAll(cacheDir, 0755); err != nil { + t.Fatalf("Failed to create cache directory: %v", err) + } + + // Create a test shim map + testMap := ShimMap{ + "node": "node", + "npm": "node", + "npx": "node", + "tsc": "node", + "eslint": "node", + "python": "python", + "pip": "python", + "black": "python", + } + + // Save the map + if err := SaveShimMap(testMap); err != nil { + t.Fatalf("Failed to save shim map: %v", err) + } + + // Verify the file was created + cachePath := config.ShimMapPath() + if _, err := os.Stat(cachePath); os.IsNotExist(err) { + t.Fatalf("Shim map cache file was not created at %s", cachePath) + } + + // Load the map + loadedMap, err := LoadShimMap() + if err != nil { + t.Fatalf("Failed to load shim map: %v", err) + } + + // Verify all entries + for shimName, expectedRuntime := range testMap { + if loadedRuntime, ok := loadedMap[shimName]; !ok { + t.Errorf("Shim %q not found in loaded map", shimName) + } else if loadedRuntime != expectedRuntime { + t.Errorf("Shim %q: expected runtime %q, got %q", shimName, expectedRuntime, loadedRuntime) + } + } +} + +func TestLookupRuntime(t *testing.T) { + // Create a temporary directory for the test + tempDir := t.TempDir() + + // Set DTVEM_ROOT to use our temp directory + originalRoot := os.Getenv("DTVEM_ROOT") + _ = os.Setenv("DTVEM_ROOT", tempDir) + defer func() { _ = os.Setenv("DTVEM_ROOT", originalRoot) }() + + // Reset the paths cache to pick up new DTVEM_ROOT + config.ResetPathsCache() + defer config.ResetPathsCache() + + // Reset the shim map cache + ResetShimMapCache() + defer ResetShimMapCache() + + // Create the cache directory + cacheDir := filepath.Join(tempDir, "cache") + if err := os.MkdirAll(cacheDir, 0755); err != nil { + t.Fatalf("Failed to create cache directory: %v", err) + } + + // Create and save a test shim map + testMap := ShimMap{ + "node": "node", + "npm": "node", + "tsc": "node", + "python": "python", + "black": "python", + } + + if err := SaveShimMap(testMap); err != nil { + t.Fatalf("Failed to save shim map: %v", err) + } + + // Test lookups + tests := []struct { + shimName string + expectedRuntime string + expectedFound bool + }{ + {"node", "node", true}, + {"npm", "node", true}, + {"tsc", "node", true}, + {"python", "python", true}, + {"black", "python", true}, + {"unknown", "", false}, + {"", "", false}, + } + + for _, tc := range tests { + t.Run(tc.shimName, func(t *testing.T) { + runtime, found := LookupRuntime(tc.shimName) + if found != tc.expectedFound { + t.Errorf("LookupRuntime(%q): expected found=%v, got found=%v", tc.shimName, tc.expectedFound, found) + } + if runtime != tc.expectedRuntime { + t.Errorf("LookupRuntime(%q): expected runtime=%q, got runtime=%q", tc.shimName, tc.expectedRuntime, runtime) + } + }) + } +} + +func TestLookupRuntimeNoCacheFile(t *testing.T) { + // Create a temporary directory for the test (empty, no cache file) + tempDir := t.TempDir() + + // Set DTVEM_ROOT to use our temp directory + originalRoot := os.Getenv("DTVEM_ROOT") + _ = os.Setenv("DTVEM_ROOT", tempDir) + defer func() { _ = os.Setenv("DTVEM_ROOT", originalRoot) }() + + // Reset the paths cache to pick up new DTVEM_ROOT + config.ResetPathsCache() + defer config.ResetPathsCache() + + // Reset the shim map cache + ResetShimMapCache() + defer ResetShimMapCache() + + // Lookup should return not found when cache doesn't exist + runtime, found := LookupRuntime("node") + if found { + t.Errorf("LookupRuntime should return found=false when cache doesn't exist") + } + if runtime != "" { + t.Errorf("LookupRuntime should return empty runtime when cache doesn't exist, got %q", runtime) + } +} + +func TestShimMapCacheOnlyLoadsOnce(t *testing.T) { + // Create a temporary directory for the test + tempDir := t.TempDir() + + // Set DTVEM_ROOT to use our temp directory + originalRoot := os.Getenv("DTVEM_ROOT") + _ = os.Setenv("DTVEM_ROOT", tempDir) + defer func() { _ = os.Setenv("DTVEM_ROOT", originalRoot) }() + + // Reset the paths cache to pick up new DTVEM_ROOT + config.ResetPathsCache() + defer config.ResetPathsCache() + + // Reset the shim map cache + ResetShimMapCache() + defer ResetShimMapCache() + + // Create the cache directory + cacheDir := filepath.Join(tempDir, "cache") + if err := os.MkdirAll(cacheDir, 0755); err != nil { + t.Fatalf("Failed to create cache directory: %v", err) + } + + // Create and save initial shim map + initialMap := ShimMap{"node": "node"} + if err := SaveShimMap(initialMap); err != nil { + t.Fatalf("Failed to save initial shim map: %v", err) + } + + // Load the map + map1, err := LoadShimMap() + if err != nil { + t.Fatalf("Failed to load shim map: %v", err) + } + + // Modify the file on disk + modifiedMap := ShimMap{"node": "modified", "new": "entry"} + if err := SaveShimMap(modifiedMap); err != nil { + t.Fatalf("Failed to save modified shim map: %v", err) + } + + // Load again - should return cached version (sync.Once) + map2, err := LoadShimMap() + if err != nil { + t.Fatalf("Failed to load shim map second time: %v", err) + } + + // Both should be the same (initial map, not modified) + if map1["node"] != map2["node"] { + t.Errorf("Cache should return same map: map1[node]=%q, map2[node]=%q", map1["node"], map2["node"]) + } + + // The modified entry should not be present (cache wasn't reloaded) + if _, ok := map2["new"]; ok { + t.Errorf("Cache should not have reloaded - 'new' entry should not exist") + } +} diff --git a/src/internal/shim/manager.go b/src/internal/shim/manager.go index 4e80b23..28d7e13 100644 --- a/src/internal/shim/manager.go +++ b/src/internal/shim/manager.go @@ -144,8 +144,8 @@ func (m *Manager) Rehash() error { return fmt.Errorf("failed to read versions directory: %w", err) } - // Collect all shims to create (use map to deduplicate) - shimsToCreate := make(map[string]bool) + // Collect shim-to-runtime mappings (shim name -> runtime name) + shimMap := make(ShimMap) for _, entry := range entries { if !entry.IsDir() { @@ -172,34 +172,47 @@ func (m *Manager) Rehash() error { // First, add core runtime shims (from provider) coreShims := RuntimeShims(runtimeName) for _, shimName := range coreShims { - shimsToCreate[shimName] = true + shimMap[shimName] = runtimeName } // Then, scan bin directory for globally installed packages binDir := filepath.Join(versionDir, "bin") if execs, err := findExecutables(binDir); err == nil { for _, exec := range execs { - shimsToCreate[exec] = true + shimMap[exec] = runtimeName } } // On Windows, also check the root version directory for .cmd/.bat files + // and Scripts directory for Python packages if runtime.GOOS == constants.OSWindows { if execs, err := findExecutables(versionDir); err == nil { for _, exec := range execs { - shimsToCreate[exec] = true + shimMap[exec] = runtimeName + } + } + // Check Scripts directory for Python pip packages + scriptsDir := filepath.Join(versionDir, "Scripts") + if execs, err := findExecutables(scriptsDir); err == nil { + for _, exec := range execs { + shimMap[exec] = runtimeName } } } } } - if len(shimsToCreate) == 0 { + if len(shimMap) == 0 { return fmt.Errorf("no runtimes installed - nothing to reshim") } + // Save the shim map cache + if err := SaveShimMap(shimMap); err != nil { + return fmt.Errorf("failed to save shim map cache: %w", err) + } + // Create all shims - for shimName := range shimsToCreate { + for shimName := range shimMap { if err := m.CreateShim(shimName); err != nil { return err }