diff --git a/internal/dependencymanager/dependencyinstaller.go b/internal/dependencymanager/dependencyinstaller.go index c7cce1a41..0b5587a71 100644 --- a/internal/dependencymanager/dependencyinstaller.go +++ b/internal/dependencymanager/dependencyinstaller.go @@ -113,6 +113,7 @@ type DependencyInstaller struct { SkipAlias bool logs categorizedLogs dependencies map[string]config.Dependency + accountAliases map[string]map[string]flowsdk.Address // network -> account -> alias } // NewDependencyInstaller creates a new instance of DependencyInstaller @@ -148,6 +149,7 @@ func NewDependencyInstaller(logger output.Logger, state *flowkit.State, saveStat SkipAlias: flags.skipAlias, dependencies: make(map[string]config.Dependency), logs: categorizedLogs{}, + accountAliases: make(map[string]map[string]flowsdk.Address), }, nil } @@ -552,16 +554,38 @@ func (di *DependencyInstaller) updateDependencyAlias(contractName, aliasNetwork } for _, missingNetwork := range missingNetworks { + // Check if we already have an alias for this account on this network + accountAddress := di.getCurrentContractAccountAddress(contractName, aliasNetwork) + if accountAddress != "" { + if existingAlias, exists := di.getAccountAlias(accountAddress, missingNetwork); exists { + // Automatically apply the existing alias + contract, err := di.State.Contracts().ByName(contractName) + if err != nil { + return err + } + contract.Aliases.Add(missingNetwork, existingAlias) + di.Logger.Info(fmt.Sprintf("%s Automatically applied alias %s for %s on %s (from same account)", + util.PrintEmoji("🔄"), existingAlias.String(), contractName, missingNetwork)) + continue + } + } + label := fmt.Sprintf("Enter an alias address for %s on %s if you have one, otherwise leave blank", contractName, missingNetwork) raw := prompt.AddressPromptOrEmpty(label, "Invalid alias address") if raw != "" { + aliasAddress := flowsdk.HexToAddress(raw) + + if accountAddress != "" { + di.setAccountAlias(accountAddress, missingNetwork, aliasAddress) + } + contract, err := di.State.Contracts().ByName(contractName) if err != nil { return err } - contract.Aliases.Add(missingNetwork, flowsdk.HexToAddress(raw)) + contract.Aliases.Add(missingNetwork, aliasAddress) } } @@ -591,3 +615,31 @@ func (di *DependencyInstaller) updateDependencyState(networkName, contractAddres return nil } + +// getCurrentContractAccountAddress returns the account address for the current contract being processed +func (di *DependencyInstaller) getCurrentContractAccountAddress(contractName, networkName string) string { + for _, dep := range di.dependencies { + if dep.Name == contractName && dep.Source.NetworkName == networkName { + return dep.Source.Address.String() + } + } + return "" +} + +// getAccountAlias returns the stored alias for an account on a specific network +func (di *DependencyInstaller) getAccountAlias(accountAddress, networkName string) (flowsdk.Address, bool) { + if networkAliases, exists := di.accountAliases[networkName]; exists { + if alias, exists := networkAliases[accountAddress]; exists { + return alias, true + } + } + return flowsdk.Address{}, false +} + +// setAccountAlias stores an alias for an account on a specific network +func (di *DependencyInstaller) setAccountAlias(accountAddress, networkName string, alias flowsdk.Address) { + if di.accountAliases[networkName] == nil { + di.accountAliases[networkName] = make(map[string]flowsdk.Address) + } + di.accountAliases[networkName][accountAddress] = alias +} diff --git a/internal/dependencymanager/dependencyinstaller_test.go b/internal/dependencymanager/dependencyinstaller_test.go index 571e38743..0de117f7b 100644 --- a/internal/dependencymanager/dependencyinstaller_test.go +++ b/internal/dependencymanager/dependencyinstaller_test.go @@ -23,6 +23,7 @@ import ( "testing" "github.com/onflow/flow-go-sdk" + flowsdk "github.com/onflow/flow-go-sdk" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -290,3 +291,75 @@ func TestDependencyInstallerAddMany(t *testing.T) { } }) } + +func TestDependencyInstallerAliasTracking(t *testing.T) { + logger := output.NewStdoutLogger(output.NoneLog) + _, state, _ := util.TestMocks(t) + + serviceAcc, _ := state.EmulatorServiceAccount() + serviceAddress := serviceAcc.Address + + t.Run("AutoApplyAliasForSameAccount", func(t *testing.T) { + gw := mocks.DefaultMockGateway() + + // Mock the same account for both contracts + gw.GetAccount.Run(func(args mock.Arguments) { + addr := args.Get(1).(flow.Address) + assert.Equal(t, addr.String(), serviceAcc.Address.String()) + acc := tests.NewAccountWithAddress(addr.String()) + acc.Contracts = map[string][]byte{ + "ContractOne": []byte("access(all) contract ContractOne {}"), + "ContractTwo": []byte("access(all) contract ContractTwo {}"), + } + + gw.GetAccount.Return(acc, nil) + }) + + di := &DependencyInstaller{ + Gateways: map[string]gateway.Gateway{ + config.EmulatorNetwork.Name: gw.Mock, + config.TestnetNetwork.Name: gw.Mock, + config.MainnetNetwork.Name: gw.Mock, + }, + Logger: logger, + State: state, + SaveState: true, + TargetDir: "", + SkipDeployments: true, + SkipAlias: false, + dependencies: make(map[string]config.Dependency), + accountAliases: make(map[string]map[string]flowsdk.Address), + } + + dep1 := config.Dependency{ + Name: "ContractOne", + Source: config.Source{ + NetworkName: "mainnet", + Address: flow.HexToAddress(serviceAddress.String()), + ContractName: "ContractOne", + }, + } + di.dependencies["mainnet://"+serviceAddress.String()+".ContractOne"] = dep1 + + aliasAddress := flowsdk.HexToAddress("0x1234567890abcdef") + di.setAccountAlias(serviceAddress.String(), "testnet", aliasAddress) + + // Add second contract - this should automatically use the same alias + dep2 := config.Dependency{ + Name: "ContractTwo", + Source: config.Source{ + NetworkName: "mainnet", + Address: flow.HexToAddress(serviceAddress.String()), + ContractName: "ContractTwo", + }, + } + di.dependencies["mainnet://"+serviceAddress.String()+".ContractTwo"] = dep2 + + existingAlias, exists := di.getAccountAlias(serviceAddress.String(), "testnet") + assert.True(t, exists, "Alias should exist for the account") + assert.Equal(t, aliasAddress, existingAlias, "Alias should match the stored value") + + accountAddr := di.getCurrentContractAccountAddress("ContractOne", "mainnet") + assert.Equal(t, serviceAddress.String(), accountAddr, "Should return correct account address") + }) +}