Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion internal/dependencymanager/dependencyinstaller.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ type DependencyInstaller struct {
SkipAlias bool
logs categorizedLogs
dependencies map[string]config.Dependency
accountAliases map[string]map[string]flowsdk.Address // network -> account -> alias
}

// NewDependencyInstaller creates a new instance of DependencyInstaller
Expand Down Expand Up @@ -148,6 +149,7 @@ func NewDependencyInstaller(logger output.Logger, state *flowkit.State, saveStat
SkipAlias: flags.skipAlias,
dependencies: make(map[string]config.Dependency),
logs: categorizedLogs{},
accountAliases: make(map[string]map[string]flowsdk.Address),
}, nil
}

Expand Down Expand Up @@ -552,16 +554,38 @@ func (di *DependencyInstaller) updateDependencyAlias(contractName, aliasNetwork
}

for _, missingNetwork := range missingNetworks {
// Check if we already have an alias for this account on this network
accountAddress := di.getCurrentContractAccountAddress(contractName, aliasNetwork)
if accountAddress != "" {
if existingAlias, exists := di.getAccountAlias(accountAddress, missingNetwork); exists {
// Automatically apply the existing alias
contract, err := di.State.Contracts().ByName(contractName)
if err != nil {
return err
}
contract.Aliases.Add(missingNetwork, existingAlias)
di.Logger.Info(fmt.Sprintf("%s Automatically applied alias %s for %s on %s (from same account)",
util.PrintEmoji("🔄"), existingAlias.String(), contractName, missingNetwork))
continue
}
}

label := fmt.Sprintf("Enter an alias address for %s on %s if you have one, otherwise leave blank", contractName, missingNetwork)
raw := prompt.AddressPromptOrEmpty(label, "Invalid alias address")

if raw != "" {
aliasAddress := flowsdk.HexToAddress(raw)

if accountAddress != "" {
di.setAccountAlias(accountAddress, missingNetwork, aliasAddress)
}

contract, err := di.State.Contracts().ByName(contractName)
if err != nil {
return err
}

contract.Aliases.Add(missingNetwork, flowsdk.HexToAddress(raw))
contract.Aliases.Add(missingNetwork, aliasAddress)
}
}

Expand Down Expand Up @@ -591,3 +615,31 @@ func (di *DependencyInstaller) updateDependencyState(networkName, contractAddres

return nil
}

// getCurrentContractAccountAddress returns the account address for the current contract being processed
func (di *DependencyInstaller) getCurrentContractAccountAddress(contractName, networkName string) string {
for _, dep := range di.dependencies {
if dep.Name == contractName && dep.Source.NetworkName == networkName {
return dep.Source.Address.String()
}
}
return ""
}

// getAccountAlias returns the stored alias for an account on a specific network
func (di *DependencyInstaller) getAccountAlias(accountAddress, networkName string) (flowsdk.Address, bool) {
if networkAliases, exists := di.accountAliases[networkName]; exists {
if alias, exists := networkAliases[accountAddress]; exists {
return alias, true
}
}
return flowsdk.Address{}, false
}

// setAccountAlias stores an alias for an account on a specific network
func (di *DependencyInstaller) setAccountAlias(accountAddress, networkName string, alias flowsdk.Address) {
if di.accountAliases[networkName] == nil {
di.accountAliases[networkName] = make(map[string]flowsdk.Address)
}
di.accountAliases[networkName][accountAddress] = alias
}
73 changes: 73 additions & 0 deletions internal/dependencymanager/dependencyinstaller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"testing"

"github.com/onflow/flow-go-sdk"
flowsdk "github.com/onflow/flow-go-sdk"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"

Expand Down Expand Up @@ -290,3 +291,75 @@ func TestDependencyInstallerAddMany(t *testing.T) {
}
})
}

func TestDependencyInstallerAliasTracking(t *testing.T) {
logger := output.NewStdoutLogger(output.NoneLog)
_, state, _ := util.TestMocks(t)

serviceAcc, _ := state.EmulatorServiceAccount()
serviceAddress := serviceAcc.Address

t.Run("AutoApplyAliasForSameAccount", func(t *testing.T) {
gw := mocks.DefaultMockGateway()

// Mock the same account for both contracts
gw.GetAccount.Run(func(args mock.Arguments) {
addr := args.Get(1).(flow.Address)
assert.Equal(t, addr.String(), serviceAcc.Address.String())
acc := tests.NewAccountWithAddress(addr.String())
acc.Contracts = map[string][]byte{
"ContractOne": []byte("access(all) contract ContractOne {}"),
"ContractTwo": []byte("access(all) contract ContractTwo {}"),
}

gw.GetAccount.Return(acc, nil)
})

di := &DependencyInstaller{
Gateways: map[string]gateway.Gateway{
config.EmulatorNetwork.Name: gw.Mock,
config.TestnetNetwork.Name: gw.Mock,
config.MainnetNetwork.Name: gw.Mock,
},
Logger: logger,
State: state,
SaveState: true,
TargetDir: "",
SkipDeployments: true,
SkipAlias: false,
dependencies: make(map[string]config.Dependency),
accountAliases: make(map[string]map[string]flowsdk.Address),
}

dep1 := config.Dependency{
Name: "ContractOne",
Source: config.Source{
NetworkName: "mainnet",
Address: flow.HexToAddress(serviceAddress.String()),
ContractName: "ContractOne",
},
}
di.dependencies["mainnet://"+serviceAddress.String()+".ContractOne"] = dep1

aliasAddress := flowsdk.HexToAddress("0x1234567890abcdef")
di.setAccountAlias(serviceAddress.String(), "testnet", aliasAddress)

// Add second contract - this should automatically use the same alias
dep2 := config.Dependency{
Name: "ContractTwo",
Source: config.Source{
NetworkName: "mainnet",
Address: flow.HexToAddress(serviceAddress.String()),
ContractName: "ContractTwo",
},
}
di.dependencies["mainnet://"+serviceAddress.String()+".ContractTwo"] = dep2

existingAlias, exists := di.getAccountAlias(serviceAddress.String(), "testnet")
assert.True(t, exists, "Alias should exist for the account")
assert.Equal(t, aliasAddress, existingAlias, "Alias should match the stored value")

accountAddr := di.getCurrentContractAccountAddress("ContractOne", "mainnet")
assert.Equal(t, serviceAddress.String(), accountAddr, "Should return correct account address")
})
}
Loading