Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 62 additions & 27 deletions tests/e2e/devstack/docker/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os/exec"
"strings"
"time"

"golang.org/x/sync/errgroup"

Expand All @@ -25,6 +28,7 @@ const (
indexerPort = 8082
mockOAuth2Port = 8088
postgresPort = 5432
maxErrorBody = 4096
)

// ServiceDiscovery resolves running container ports and reads the bootstrap
Expand Down Expand Up @@ -57,10 +61,9 @@ type deployManifest struct {
// are issued concurrently via errgroup to minimize wall-clock time at test-suite
// startup.
//
// DSNs are read directly from each service's own environment variables
// (RELAYER_DATABASE_URL, API_SERVER_DATABASE_URL, INDEXER_DATABASE_URL) via
// docker inspect, with the internal hostname replaced by the published
// localhost:PORT. This avoids hardcoding credentials or database names.
// The api-server DSN is read from the API_SERVER_DATABASE_URL environment
// variable via docker inspect, with the internal hostname replaced by the
// published localhost:PORT. This avoids hardcoding credentials or database names.
func (d *ServiceDiscovery) Manifest(ctx context.Context) (*stack.ServiceManifest, error) {
// Phase 1: resolve all endpoints and the postgres host in parallel.
var (
Expand Down Expand Up @@ -89,27 +92,13 @@ func (d *ServiceDiscovery) Manifest(ctx context.Context) (*stack.ServiceManifest
return nil, err
}

// Phase 2: fetch DSNs in parallel (depend on postgresHost from phase 1).
var (
apiDSN string
relayerDSN string
indexerDSN string
)
apiDSN, err := d.serviceDSN(ctx, "api-server", "API_SERVER_DATABASE_URL", postgresHost)
if err != nil {
return nil, err
}

g2, gctx2 := errgroup.WithContext(ctx)
g2.Go(func() (err error) {
apiDSN, err = d.serviceDSN(gctx2, "api-server", "API_SERVER_DATABASE_URL", postgresHost)
return
})
g2.Go(func() (err error) {
relayerDSN, err = d.serviceDSN(gctx2, "relayer", "RELAYER_DATABASE_URL", postgresHost)
return
})
g2.Go(func() (err error) {
indexerDSN, err = d.serviceDSN(gctx2, "indexer", "INDEXER_DATABASE_URL", postgresHost)
return
})
if err := g2.Wait(); err != nil {
cantonDomainID, err := d.readCantonDomainID(ctx, cantonHTTP)
if err != nil {
return nil, err
}

Expand All @@ -122,17 +111,54 @@ func (d *ServiceDiscovery) Manifest(ctx context.Context) (*stack.ServiceManifest
IndexerHTTP: indexerHTTP,
OAuthHTTP: oauthHTTP,
APIDatabaseDSN: apiDSN,
RelayerDatabaseDSN: relayerDSN,
IndexerDatabaseDSN: indexerDSN,
PromptTokenAddr: dm.PromptToken,
BridgeAddr: dm.CantonBridge,
PromptInstrumentAdmin: dm.PromptInstrumentAdmin,
PromptInstrumentID: dm.PromptInstrumentID,
DemoInstrumentAdmin: dm.DemoInstrumentAdmin,
DemoInstrumentID: dm.DemoInstrumentID,
CantonDomainID: cantonDomainID,
DemoTokenAddr: stack.DemoTokenVirtualAddr,
}, nil
}

// readCantonDomainID calls the Canton HTTP JSON API to retrieve the synchronizer ID.
func (ServiceDiscovery) readCantonDomainID(ctx context.Context, cantonHTTP string) (string, error) {
type synchronizer struct {
SynchronizerID string `json:"synchronizerId"`
}
type response struct {
ConnectedSynchronizers []synchronizer `json:"connectedSynchronizers"`
}

client := &http.Client{Timeout: 10 * time.Second}
req, err := http.NewRequestWithContext(ctx, http.MethodGet,
cantonHTTP+"/v2/state/connected-synchronizers", nil)
if err != nil {
return "", fmt.Errorf("build canton domain-id request: %w", err)
}

resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("get canton connected-synchronizers: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
b, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBody))
return "", fmt.Errorf("canton connected-synchronizers: status %d: %s", resp.StatusCode, b)
}

var r response
if err := json.NewDecoder(resp.Body).Decode(&r); err != nil {
return "", fmt.Errorf("decode canton connected-synchronizers: %w", err)
}
if len(r.ConnectedSynchronizers) == 0 {
return "", fmt.Errorf("canton has no connected synchronizers")
}
return r.ConnectedSynchronizers[0].SynchronizerID, nil
}

// serviceDSN reads the named environment variable from the running service
// container via docker inspect, then rewrites the host to the published
// postgresHost (localhost:PORT) so the DSN is usable from outside Docker.
Expand All @@ -149,6 +175,11 @@ func (d *ServiceDiscovery) serviceDSN(ctx context.Context, service, envVar, post
return "", fmt.Errorf("parsing %s from %s: %w", envVar, service, err)
}
u.Host = postgresHost
// The devnet postgres container has no SSL; force sslmode=disable so the
// lib/pq driver can connect from outside Docker.
q := u.Query()
q.Set("sslmode", "disable")
u.RawQuery = q.Encode()
return u.String(), nil
}

Expand Down Expand Up @@ -245,11 +276,15 @@ func (d *ServiceDiscovery) publishedPort(ctx context.Context, service string, co
//
// docker compose -p <project> run --rm bootstrap cat /tmp/e2e-deploy.json
func (d *ServiceDiscovery) readDeployManifest(ctx context.Context) (*deployManifest, error) {
// Use --entrypoint cat to bypass the bootstrap container's own entrypoint
// (docker-bootstrap.sh), which writes status text to stdout and would
// corrupt the JSON before we can parse it.
cmd := dockerComposeCommand(ctx,
"-p", d.projectName,
"run", "--rm",
"--entrypoint", "cat",
"bootstrap",
"cat", "/tmp/e2e-deploy.json",
"/tmp/e2e-deploy.json",
)
var out, errBuf bytes.Buffer
cmd.Stdout = &out
Expand Down
208 changes: 208 additions & 0 deletions tests/e2e/devstack/shim/anvil.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
//go:build e2e

// Package shim provides concrete implementations of the stack service
// interfaces. Each shim wraps a real network client (go-ethereum, HTTP, SQL)
// and is initialized from a ServiceManifest produced by ServiceDiscovery.
package shim

import (
"context"
"crypto/ecdsa"
"encoding/hex"
"errors"
"fmt"
"math/big"
"strings"
"time"

ethereum "github.com/ethereum/go-ethereum"
"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethclient"

"github.com/chainsafe/canton-middleware/pkg/auth"
"github.com/chainsafe/canton-middleware/pkg/ethereum/contracts"
"github.com/chainsafe/canton-middleware/tests/e2e/devstack/stack"
)

var _ stack.Anvil = (*AnvilShim)(nil)

// txGasLimit is a fixed gas ceiling for approve and depositToCanton transactions
// on the local Anvil devnet. Anvil's instant mining makes estimation unnecessary.
const (
txGasLimit = 300_000
txWaitTimeout = 30 * time.Second
bytes32Len = 32
)

// AnvilShim implements stack.Anvil against a local Anvil node.
type AnvilShim struct {
endpoint string
rpc *ethclient.Client
chainID *big.Int
tokenAddr common.Address
bridgeAddr common.Address
}

// NewAnvil dials the Anvil RPC endpoint from the manifest and returns a ready
// shim. It resolves chainID eagerly so callers do not need a context.
func NewAnvil(ctx context.Context, manifest *stack.ServiceManifest) (*AnvilShim, error) {
client, err := ethclient.DialContext(ctx, manifest.AnvilRPC)
if err != nil {
return nil, fmt.Errorf("dial anvil: %w", err)
}
chainID, err := client.ChainID(ctx)
if err != nil {
return nil, fmt.Errorf("get anvil chain ID: %w", err)
}
return &AnvilShim{
endpoint: manifest.AnvilRPC,
rpc: client,
chainID: chainID,
tokenAddr: common.HexToAddress(manifest.PromptTokenAddr),
bridgeAddr: common.HexToAddress(manifest.BridgeAddr),
}, nil
}

func (a *AnvilShim) Endpoint() string { return a.endpoint }
func (a *AnvilShim) RPC() *ethclient.Client { return a.rpc }
func (a *AnvilShim) ChainID() *big.Int { return a.chainID }
func (a *AnvilShim) Close() { a.rpc.Close() }

// ERC20Balance returns the on-chain ERC-20 balance of owner for tokenAddr.
func (a *AnvilShim) ERC20Balance(ctx context.Context, tokenAddr, owner common.Address) (*big.Int, error) {
token, err := contracts.NewPromptToken(tokenAddr, a.rpc)
if err != nil {
return nil, fmt.Errorf("bind erc20: %w", err)
}
bal, err := token.BalanceOf(&bind.CallOpts{Context: ctx}, owner)
if err != nil {
return nil, fmt.Errorf("balanceOf: %w", err)
}
return bal, nil
}

// ApproveAndDeposit approves the bridge contract and submits a depositToCanton
// transaction for account. The canton recipient bytes32 is derived from the
// account's EVM address fingerprint via auth.ComputeFingerprint.
func (a *AnvilShim) ApproveAndDeposit(ctx context.Context, account *stack.Account, amount *big.Int) (common.Hash, error) {
key, err := parseKey(account.PrivateKey)
if err != nil {
return common.Hash{}, err
}

fingerprint := auth.ComputeFingerprint(account.Address.Hex())
recipient, err := fingerprintToBytes32(fingerprint)
if err != nil {
return common.Hash{}, err
}

token, err := contracts.NewPromptToken(a.tokenAddr, a.rpc)
if err != nil {
return common.Hash{}, fmt.Errorf("bind prompt token: %w", err)
}
bridge, err := contracts.NewCantonBridge(a.bridgeAddr, a.rpc)
if err != nil {
return common.Hash{}, fmt.Errorf("bind canton bridge: %w", err)
}

// Step 1: approve.
auth, err := newTransactor(ctx, a.rpc, key, a.chainID)
if err != nil {
return common.Hash{}, err
}
approveTx, err := token.Approve(auth, a.bridgeAddr, amount)
if err != nil {
return common.Hash{}, fmt.Errorf("approve: %w", err)
}
if waitErr := waitForTx(ctx, a.rpc, approveTx.Hash(), txWaitTimeout); waitErr != nil {
return common.Hash{}, fmt.Errorf("wait approve tx: %w", waitErr)
}

// Step 2: deposit.
auth, err = newTransactor(ctx, a.rpc, key, a.chainID)
if err != nil {
return common.Hash{}, err
}
depositTx, err := bridge.DepositToCanton(auth, a.tokenAddr, amount, recipient)
if err != nil {
return common.Hash{}, fmt.Errorf("depositToCanton: %w", err)
}
if waitErr := waitForTx(ctx, a.rpc, depositTx.Hash(), txWaitTimeout); waitErr != nil {
return common.Hash{}, fmt.Errorf("wait deposit tx: %w", waitErr)
}

return depositTx.Hash(), nil
}

// newTransactor creates a TransactOpts with current nonce and suggested gas price.
func newTransactor(ctx context.Context, client *ethclient.Client, key *ecdsa.PrivateKey, chainID *big.Int) (*bind.TransactOpts, error) {
auth, err := bind.NewKeyedTransactorWithChainID(key, chainID)
if err != nil {
return nil, fmt.Errorf("keyed transactor: %w", err)
}
nonce, err := client.PendingNonceAt(ctx, crypto.PubkeyToAddress(key.PublicKey))
if err != nil {
return nil, fmt.Errorf("pending nonce: %w", err)
}
auth.Nonce = new(big.Int).SetUint64(nonce)
gasPrice, err := client.SuggestGasPrice(ctx)
if err != nil {
return nil, fmt.Errorf("suggest gas price: %w", err)
}
auth.GasPrice = gasPrice
auth.GasLimit = txGasLimit
return auth, nil
}

// waitForTx polls until the transaction is mined or the timeout is reached.
// It returns immediately on any RPC error other than ethereum.NotFound (tx not
// yet visible) to avoid masking genuine node failures.
func waitForTx(ctx context.Context, client *ethclient.Client, hash common.Hash, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
for {
receipt, err := client.TransactionReceipt(ctx, hash)
if err == nil {
if receipt.Status == 1 {
return nil
}
return fmt.Errorf("transaction %s reverted", hash.Hex())
}
if !errors.Is(err, ethereum.NotFound) {
return fmt.Errorf("receipt query for %s: %w", hash.Hex(), err)
}
select {
case <-ctx.Done():
return fmt.Errorf("timeout waiting for tx %s: %w", hash.Hex(), ctx.Err())
case <-time.After(time.Second):
}
}
}

// parseKey decodes a hex-encoded ECDSA private key (without 0x prefix).
func parseKey(hexKey string) (*ecdsa.PrivateKey, error) {
key, err := crypto.HexToECDSA(hexKey)
if err != nil {
return nil, fmt.Errorf("parse private key: %w", err)
}
return key, nil
}

// fingerprintToBytes32 converts a hex fingerprint string to a [32]byte.
// auth.ComputeFingerprint always returns a keccak256 hash (exactly 32 bytes),
// so copy fills the full array with no trailing zeros.
func fingerprintToBytes32(fingerprint string) ([32]byte, error) {
var result [32]byte
fingerprint = strings.TrimPrefix(fingerprint, "0x")
data, err := hex.DecodeString(fingerprint)
if err != nil {
return result, fmt.Errorf("decode fingerprint: %w", err)
}
if len(data) > bytes32Len {
return result, fmt.Errorf("fingerprint too long: %d bytes", len(data))
}
copy(result[:], data)
return result, nil
}
Comment thread
sadiq1971 marked this conversation as resolved.
Loading
Loading