diff --git a/pkg/espflasher/chip.go b/pkg/espflasher/chip.go index 5ee9a4d..994c47a 100644 --- a/pkg/espflasher/chip.go +++ b/pkg/espflasher/chip.go @@ -117,6 +117,11 @@ type chipDef struct { // FlashSizes maps size strings to header byte values. FlashSizes map[string]byte + + // PostConnect is called after chip detection to perform chip-specific + // initialization (e.g. USB interface detection, watchdog disable). + // May set Flasher fields like usesUSB. + PostConnect func(f *Flasher) error } // chipDetectMagicRegAddr is the register address that has a different diff --git a/pkg/espflasher/flasher.go b/pkg/espflasher/flasher.go index 6a9db98..0909be8 100644 --- a/pkg/espflasher/flasher.go +++ b/pkg/espflasher/flasher.go @@ -105,8 +105,10 @@ type connection interface { changeBaud(newBaud, oldBaud uint32) error eraseFlash() error eraseRegion(offset, size uint32) error + readFlash(offset, size uint32) ([]byte, error) flushInput() isStub() bool + setUSB(v bool) setSupportsEncryptedFlash(v bool) loadStub(s *stub) error } @@ -119,6 +121,7 @@ type Flasher struct { chip *chipDef opts *FlasherOptions portStr string + usesUSB bool secInfo []byte // cached security info from ROM (GET_SECURITY_INFO opcode 0x14) } @@ -163,7 +166,7 @@ func New(portName string, opts *FlasherOptions) (*Flasher, error) { // Connect to the bootloader if err := f.connect(); err != nil { - port.Close() //nolint:errcheck + f.port.Close() //nolint:errcheck return nil, err } @@ -175,6 +178,31 @@ func (f *Flasher) Close() error { return f.port.Close() } +// reopenPort closes and reopens the serial port after a USB device +// re-enumeration. TinyUSB CDC devices may briefly disappear during reset. +func (f *Flasher) reopenPort() error { + f.port.Close() //nolint:errcheck + + var lastErr error + deadline := time.Now().Add(3 * time.Second) + for time.Now().Before(deadline) { + time.Sleep(500 * time.Millisecond) + port, err := serial.Open(f.portStr, &serial.Mode{ + BaudRate: f.opts.BaudRate, + Parity: serial.NoParity, + DataBits: 8, + StopBits: serial.OneStopBit, + }) + if err == nil { + f.port = port + f.conn = newConn(port) + return nil + } + lastErr = err + } + return fmt.Errorf("reopen port %s: %w", f.portStr, lastErr) +} + // ChipType returns the detected chip type. func (f *Flasher) ChipType() ChipType { if f.chip != nil { @@ -234,6 +262,11 @@ func (f *Flasher) connect() error { } time.Sleep(50 * time.Millisecond) } + + // Sync failed — try reopening port (USB CDC may have re-enumerated) + if err := f.reopenPort(); err != nil { + continue // port reopen failed, try next attempt + } } return &SyncError{Attempts: attempts} @@ -258,9 +291,21 @@ synced: f.logf("Detected chip: %s", f.chip.Name) + // Run chip-specific post-connect initialization. + if f.chip.PostConnect != nil { + if err := f.chip.PostConnect(f); err != nil { + f.logf("Warning: post-connect: %v", err) + } + } + // Propagate chip capabilities to the connection layer. f.conn.setSupportsEncryptedFlash(f.chip.SupportsEncryptedFlash) + // Propagate USB flag to connection layer for block size optimization. + if f.usesUSB { + f.conn.setUSB(true) + } + // Upload the stub loader to enable advanced features (erase, compression, etc.). if s, ok := stubFor(f.chip.ChipType); ok { f.logf("Loading stub loader...") @@ -717,6 +762,20 @@ func (f *Flasher) GetMD5(offset, size uint32) (string, error) { return hex.EncodeToString(result), nil } +// ReadFlash reads data from flash memory. +// Requires the stub loader to be running. +func (f *Flasher) ReadFlash(offset, size uint32) ([]byte, error) { + if !f.conn.isStub() { + return nil, &UnsupportedCommandError{Command: "read flash (requires stub)"} + } + + if err := f.attachFlash(); err != nil { + return nil, err + } + + return f.conn.readFlash(offset, size) +} + // Reset performs a hard reset of the device, causing it to run user code. func (f *Flasher) Reset() { if f.conn.isStub() { @@ -731,7 +790,7 @@ func (f *Flasher) Reset() { // CMD_FLASH_BEGIN after a compressed download may interfere with // the flash controller state at offset 0. esptool also just does // a hard reset without any flash commands for the ROM path. - hardReset(f.port, false) + hardReset(f.port, f.usesUSB) f.logf("Device reset.") } diff --git a/pkg/espflasher/flasher_test.go b/pkg/espflasher/flasher_test.go index ec66913..99b9dd0 100644 --- a/pkg/espflasher/flasher_test.go +++ b/pkg/espflasher/flasher_test.go @@ -449,6 +449,21 @@ func TestGetMD5RequiresStub(t *testing.T) { } } +func TestReadFlashRequiresStub(t *testing.T) { + mock := &mockConnection{} + mock.stubMode = false // ROM mode + f := &Flasher{conn: mock, chip: chipDefs[ChipESP32]} + _, err := f.ReadFlash(0, 1024) + if err == nil { + t.Fatal("expected error when stub is not running") + } + if ue, ok := err.(*UnsupportedCommandError); !ok { + t.Errorf("expected UnsupportedCommandError, got %T: %v", err, err) + } else if ue.Command != "read flash (requires stub)" { + t.Errorf("unexpected error message: %s", ue.Command) + } +} + func TestGetSecurityInfo(t *testing.T) { secInfo := make([]byte, 20) binary.LittleEndian.PutUint32(secInfo[0:4], 0x05) diff --git a/pkg/espflasher/protocol.go b/pkg/espflasher/protocol.go index 4e97873..4c9354e 100644 --- a/pkg/espflasher/protocol.go +++ b/pkg/espflasher/protocol.go @@ -59,6 +59,9 @@ const ( // flashSectorSize is the minimum flash erase unit. flashSectorSize uint32 = 0x1000 // 4KB + // readFlashBlockSize is the block size for read flash operations. + readFlashBlockSize uint32 = 0x1000 // 4KB + // espImageMagic is the first byte of a valid ESP firmware image. espImageMagic byte = 0xE9 @@ -79,7 +82,8 @@ const ( type conn struct { port serial.Port reader *slipReader - stub bool + stub bool + usesUSB bool // set for USB-OTG and USB-JTAG/Serial connections // supportsEncryptedFlash indicates the ROM supports the 5th parameter // (encrypted flag) in flash_begin/flash_defl_begin commands. // Set based on chip type after detection. @@ -91,6 +95,11 @@ func (c *conn) isStub() bool { return c.stub } +// setUSB sets whether the connection uses USB-OTG or USB-JTAG endpoints. +func (c *conn) setUSB(v bool) { + c.usesUSB = v +} + // setSupportsEncryptedFlash sets whether the ROM supports encrypted flash commands. func (c *conn) setSupportsEncryptedFlash(v bool) { c.supportsEncryptedFlash = v @@ -125,8 +134,20 @@ func (c *conn) sendCommand(opcode byte, data []byte, chk uint32) error { copy(pkt[8:], data) frame := slipEncode(pkt) - _, err := c.port.Write(frame) - return err + // USB CDC endpoints have limited buffer sizes. Writing large SLIP frames + // in one shot can overflow the endpoint buffer and cause data loss. + // Chunk writes to 64 bytes (standard USB Full Speed bulk endpoint size). + const maxChunk = 64 + for off := 0; off < len(frame); off += maxChunk { + end := off + maxChunk + if end > len(frame) { + end = len(frame) + } + if _, err := c.port.Write(frame[off:end]); err != nil { + return err + } + } + return nil } // commandResponse represents a parsed response from the ESP device. @@ -542,6 +563,51 @@ func (c *conn) eraseRegion(offset, size uint32) error { return err } +// readFlash reads data from flash memory (stub-only). +func (c *conn) readFlash(offset, size uint32) ([]byte, error) { + data := make([]byte, 16) + binary.LittleEndian.PutUint32(data[0:4], offset) + binary.LittleEndian.PutUint32(data[4:8], size) + binary.LittleEndian.PutUint32(data[8:12], readFlashBlockSize) + binary.LittleEndian.PutUint32(data[12:16], 64) // max_inflight (stub clamps to 1) + + if _, err := c.checkCommand("read flash", cmdReadFlash, data, 0, defaultTimeout, 0); err != nil { + return nil, err + } + + blockTimeout := defaultTimeout + time.Duration(readFlashBlockSize/256)*100*time.Millisecond + numBlocks := (size + readFlashBlockSize - 1) / readFlashBlockSize + result := make([]byte, 0, size) + + for i := uint32(0); i < numBlocks; i++ { + // Read SLIP-framed data block + block, err := c.reader.ReadFrame(blockTimeout) + if err != nil { + return nil, fmt.Errorf("read flash block %d/%d: %w", i+1, numBlocks, err) + } + result = append(result, block...) + + // Send ACK: cumulative bytes received (SLIP-framed) + ack := make([]byte, 4) + binary.LittleEndian.PutUint32(ack, uint32(len(result))) + ackFrame := slipEncode(ack) + if _, err := c.port.Write(ackFrame); err != nil { + return nil, fmt.Errorf("read flash ACK %d/%d: %w", i+1, numBlocks, err) + } + } + + // Read final 16-byte MD5 digest (SLIP-framed) + _, err := c.reader.ReadFrame(defaultTimeout) + if err != nil { + return nil, fmt.Errorf("read flash MD5: %w", err) + } + + if uint32(len(result)) > size { + result = result[:size] + } + return result, nil +} + // flashWriteSize returns the appropriate block size based on loader type. func (c *conn) flashWriteSize() uint32 { if c.stub { @@ -602,17 +668,24 @@ func (c *conn) loadStub(s *stub) error { // uploadToRAM writes a binary segment to the device's RAM via mem_begin/mem_data. func (c *conn) uploadToRAM(data []byte, addr uint32) error { + // USB CDC endpoints have limited buffer sizes. Use 1KB blocks for + // USB connections instead of the default 6KB to avoid timeout. + blockSize := espRAMBlock + if c.usesUSB { + blockSize = 0x400 // 1KB + } + dataLen := uint32(len(data)) - numBlocks := (dataLen + espRAMBlock - 1) / espRAMBlock + numBlocks := (dataLen + blockSize - 1) / blockSize - if err := c.memBegin(dataLen, numBlocks, espRAMBlock, addr); err != nil { + if err := c.memBegin(dataLen, numBlocks, blockSize, addr); err != nil { return err } seq := uint32(0) offset := uint32(0) for offset < dataLen { - end := offset + espRAMBlock + end := offset + blockSize if end > dataLen { end = dataLen } diff --git a/pkg/espflasher/protocol_test.go b/pkg/espflasher/protocol_test.go index dcb62e6..05a0528 100644 --- a/pkg/espflasher/protocol_test.go +++ b/pkg/espflasher/protocol_test.go @@ -442,7 +442,9 @@ type mockConnection struct { eraseRegionFunc func(offset, size uint32) error flushInputFunc func() loadStubFunc func(s *stub) error + readFlashFunc func(offset, size uint32) ([]byte, error) stubMode bool + usbMode bool supportsEncryptedFlashValue bool } @@ -571,10 +573,21 @@ func (m *mockConnection) flushInput() { } } +func (m *mockConnection) readFlash(offset, size uint32) ([]byte, error) { + if m.readFlashFunc != nil { + return m.readFlashFunc(offset, size) + } + return nil, nil +} + func (m *mockConnection) isStub() bool { return m.stubMode } +func (m *mockConnection) setUSB(v bool) { + m.usbMode = v +} + func (m *mockConnection) setSupportsEncryptedFlash(v bool) { m.supportsEncryptedFlashValue = v } @@ -609,3 +622,117 @@ func makeChangeBaudResponse() []byte { resp[9] = 0x00 return resp } + +func TestSendCommandChunking(t *testing.T) { + // Verify that sendCommand writes in chunks <= 64 bytes + var writes [][]byte + mock := &mockPort{ + writeFunc: func(data []byte) (int, error) { + // Record each write call + chunk := make([]byte, len(data)) + copy(chunk, data) + writes = append(writes, chunk) + return len(data), nil + }, + } + + c := &conn{ + port: mock, + reader: &slipReader{port: mock}, + } + + // Create test data that will result in a large SLIP frame + testData := make([]byte, 256) // Large payload + for i := range testData { + testData[i] = byte(i % 256) + } + chk := uint32(0xDEADBEEF) + + err := c.sendCommand(cmdFlashData, testData, chk) + if err != nil { + t.Fatalf("sendCommand failed: %v", err) + } + + // Verify that we got multiple writes + if len(writes) < 2 { + t.Errorf("expected multiple writes for large frame, got %d", len(writes)) + } + + // Verify each write is <= 64 bytes + const maxChunk = 64 + for i, w := range writes { + if len(w) > maxChunk { + t.Errorf("write[%d] = %d bytes, want <= %d", i, len(w), maxChunk) + } + } + + // Verify the reassembled frame decodes correctly + var reassembled []byte + for _, w := range writes { + reassembled = append(reassembled, w...) + } + decoded := slipDecode(reassembled) + + // Check that we got the right opcode and data + if len(decoded) < 8 { + t.Fatalf("decoded frame too short: %d bytes", len(decoded)) + } + if decoded[1] != cmdFlashData { + t.Errorf("opcode = 0x%02X, want 0x%02X", decoded[1], cmdFlashData) + } +} + +func TestUploadToRAMUSBBlockSize(t *testing.T) { + // Verify that USB connections use 1KB block size and regular use 6KB + tests := []struct { + name string + usesUSB bool + dataSize uint32 + expectedBS uint32 + expectedNum uint32 + }{ + {"non-USB 6144 bytes uses 6KB blocks", false, 6144, 0x1800, 1}, + {"non-USB 12288 bytes uses 6KB blocks", false, 12288, 0x1800, 2}, + {"USB 1024 bytes uses 1KB blocks", true, 1024, 0x400, 1}, + {"USB 2048 bytes uses 1KB blocks", true, 2048, 0x400, 2}, + {"USB 6144 bytes split into 1KB blocks", true, 6144, 0x400, 6}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // We'll verify block size by checking the calculation logic + blockSize := espRAMBlock + if tt.usesUSB { + blockSize = 0x400 + } + + dataLen := tt.dataSize + numBlocks := (dataLen + blockSize - 1) / blockSize + + if blockSize != tt.expectedBS { + t.Errorf("block size = %d, want %d", blockSize, tt.expectedBS) + } + + if numBlocks != tt.expectedNum { + t.Errorf("num blocks = %d, want %d", numBlocks, tt.expectedNum) + } + }) + } +} + +func TestReadFlashBlockSize(t *testing.T) { + // Verify readFlashBlockSize constant + if readFlashBlockSize != 0x1000 { + t.Errorf("readFlashBlockSize = 0x%X, want 0x1000", readFlashBlockSize) + } +} + +func TestReadFlashParameterValidation(t *testing.T) { + // Verify that readFlash uses correct command opcode + if cmdReadFlash != 0xD2 { + t.Errorf("cmdReadFlash = 0x%02X, want 0xD2", cmdReadFlash) + } + if cmdReadFlash != 0xD2 { + t.Errorf("cmdReadFlash opcode mismatch") + } +} diff --git a/pkg/espflasher/slip.go b/pkg/espflasher/slip.go index 28753ba..fd20659 100644 --- a/pkg/espflasher/slip.go +++ b/pkg/espflasher/slip.go @@ -68,7 +68,8 @@ func slipDecode(frame []byte) []byte { // slipReader reads complete SLIP frames from a serial port. type slipReader struct { - port serial.Port + port serial.Port + leftover []byte } // newSlipReader creates a SLIP frame reader for the given serial port. @@ -84,35 +85,12 @@ func (r *slipReader) ReadFrame(timeout time.Duration) ([]byte, error) { inFrame := false inEscape := false - buf := make([]byte, 256) - - for time.Now().Before(deadline) { - remaining := time.Until(deadline) - if remaining <= 0 { - break - } - - readTimeout := min(remaining, 100*time.Millisecond) - r.port.SetReadTimeout(readTimeout) - - n, err := r.port.Read(buf) - if err != nil && err != io.EOF { - // On timeout, continue; on real error, return - if n == 0 { - continue - } - } - if n == 0 { - continue - } - - for i := range n { - b := buf[i] - + processBytes := func(data []byte) ([]byte, bool, error) { + for i, b := range data { if !inFrame { if b == slipEnd { inFrame = true - partial = partial[:0] // reset + partial = partial[:0] } continue } @@ -125,7 +103,7 @@ func (r *slipReader) ReadFrame(timeout time.Duration) ([]byte, error) { case slipEscEsc: partial = append(partial, slipEsc) default: - return nil, fmt.Errorf("invalid SLIP escape: 0xDB 0x%02X", b) + return nil, false, fmt.Errorf("invalid SLIP escape: 0xDB 0x%02X", b) } continue } @@ -135,15 +113,53 @@ func (r *slipReader) ReadFrame(timeout time.Duration) ([]byte, error) { if len(partial) > 0 { result := make([]byte, len(partial)) copy(result, partial) - return result, nil + remaining := data[i+1:] + r.leftover = make([]byte, len(remaining)) + copy(r.leftover, remaining) + return result, true, nil } - // Empty frame, keep reading case slipEsc: inEscape = true default: partial = append(partial, b) } } + return nil, false, nil + } + + // Process leftover bytes first + if len(r.leftover) > 0 { + saved := r.leftover + r.leftover = nil + if result, done, err := processBytes(saved); err != nil { + return nil, err + } else if done { + return result, nil + } + } + + buf := make([]byte, 256) + for time.Now().Before(deadline) { + remaining := time.Until(deadline) + if remaining <= 0 { + break + } + readTimeout := min(remaining, 100*time.Millisecond) + r.port.SetReadTimeout(readTimeout) //nolint:errcheck + n, err := r.port.Read(buf) + if err != nil && err != io.EOF { + if n == 0 { + continue + } + } + if n == 0 { + continue + } + if result, done, err := processBytes(buf[:n]); err != nil { + return nil, err + } else if done { + return result, nil + } } return nil, &TimeoutError{Op: "SLIP read"}