Skip to content

Commit 29e226f

Browse files
committed
Synchronise the Amule Connection to avoid multiple requests to intersect
1 parent ecf86b4 commit 29e226f

3 files changed

Lines changed: 105 additions & 13 deletions

File tree

build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ dependencies {
2020
testImplementation("org.jetbrains.kotlin:kotlin-test-junit5")
2121
testImplementation("org.junit.jupiter:junit-jupiter-engine:5.9.3")
2222
testImplementation("ch.qos.logback:logback-classic:1.4.11")
23+
testImplementation("io.mockk:mockk:1.13.8")
2324
testImplementation("io.kotest:kotest-runner-junit5:5.7.2")
2425
testImplementation("io.kotest:kotest-runner-junit5-jvm:5.7.2")
2526
testImplementation("io.kotest.extensions:kotest-extensions-testcontainers:2.0.2")

src/main/kotlin/jamule/AmuleConnection.kt

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,20 @@ import java.io.IOException
1616
import java.net.Socket
1717

1818
internal class AmuleConnection(
19-
private val host: String,
20-
private val port: Int,
21-
private val timeout: Int,
19+
private var socketBuilder: () -> Socket,
2220
private val password: String,
2321
private val logger: Logger
2422
) {
25-
private var socket = Socket(host, port).apply { soTimeout = timeout }
2623
private var connected = false
24+
private var socket = socketBuilder()
25+
26+
internal constructor(
27+
host: String,
28+
port: Int,
29+
timeout: Int,
30+
password: String,
31+
logger: Logger
32+
) : this({ Socket(host, port).apply { soTimeout = timeout } }, password, logger)
2733

2834
@OptIn(ExperimentalUnsignedTypes::class)
2935
private val tagParser = TagParser(logger)
@@ -42,7 +48,7 @@ internal class AmuleConnection(
4248
logger.info("Reconnecting...")
4349
connected = false
4450
runCatching { socket.close() }
45-
socket = Socket(host, port).apply { soTimeout = timeout }
51+
socket = socketBuilder()
4652
authenticate()
4753
}
4854
}
@@ -59,14 +65,16 @@ internal class AmuleConnection(
5965

6066
@OptIn(ExperimentalUnsignedTypes::class)
6167
fun sendRequestNoAuth(request: Request): Response {
62-
val outputStream = socket.getOutputStream()
63-
val inputStream = socket.getInputStream().buffered()
64-
val packet = request.packet()
65-
packetWriter.write(packet, outputStream)
66-
val responsePacket = packetParser.parse(inputStream)
67-
return ResponseParser.parse(responsePacket).also {
68-
if (it is ErrorResponse) {
69-
throw ServerException(it.serverMessage)
68+
synchronized(socket) {
69+
val outputStream = socket.getOutputStream()
70+
val inputStream = socket.getInputStream().buffered()
71+
val packet = request.packet()
72+
packetWriter.write(packet, outputStream)
73+
val responsePacket = packetParser.parse(inputStream)
74+
return ResponseParser.parse(responsePacket).also {
75+
if (it is ErrorResponse) {
76+
throw ServerException(it.serverMessage)
77+
}
7078
}
7179
}
7280
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package jamule
2+
3+
import io.kotest.core.spec.style.FunSpec
4+
import io.kotest.matchers.shouldBe
5+
import io.mockk.every
6+
import io.mockk.mockk
7+
import io.mockk.verify
8+
import jamule.request.StatsRequest
9+
import org.slf4j.LoggerFactory
10+
import java.io.ByteArrayInputStream
11+
import java.io.OutputStream
12+
import java.net.Socket
13+
import java.util.concurrent.CountDownLatch
14+
15+
@OptIn(ExperimentalStdlibApi::class)
16+
class AmuleConnectionTest : FunSpec({
17+
18+
val socket = mockk<Socket>()
19+
val logger = LoggerFactory.getLogger(this::class.java)
20+
val outputStream = OutputStream.nullOutputStream()
21+
every { socket.getOutputStream() } returns outputStream
22+
every { socket.close() } returns Unit
23+
val authSaltResponse = ByteArrayInputStream("000000220000000d4f0116050855099a4aea510c43".hexToByteArray())
24+
val authOkResponse =
25+
ByteArrayInputStream("000000220000001d0401e0a8960616322e332e31204164756e616e7a4120323031322e3100".hexToByteArray())
26+
val statusResponse = ByteArrayInputStream(
27+
("000000220000008c0c10d08003021664d082020100d484020100d4860302" +
28+
"1664d488020100d48a020100d084020100d086020100d09002010" +
29+
"0d08c020100d092040400017cbbd09402010ad096040402e2740f" +
30+
"d09803020438d0b60201000b023f03e0a881081f01e0a88206124" +
31+
"16b74656f6e20536572766572204e6f3200b07de76247b50c0404" +
32+
"1d4e48541404041d4e485419")
33+
.hexToByteArray()
34+
)
35+
36+
test("single request works ok") {
37+
val amule = AmuleConnection({ socket }, "password", logger)
38+
every { socket.getInputStream() } returnsMany listOf(
39+
authSaltResponse,
40+
authOkResponse,
41+
statusResponse
42+
)
43+
amule.sendRequest(StatsRequest())
44+
// Called 3 times: 1 for salt, 1 for auth, 1 for stats
45+
verify(exactly = 3) { socket.getOutputStream() }
46+
}
47+
48+
test("multiple parallel requests are synchronised") {
49+
val amule = AmuleConnection({ socket }, "password", logger)
50+
val firstRequestArrivedLatch = CountDownLatch(1)
51+
val firstRequestLatch = CountDownLatch(1)
52+
val secondRequestLatch = CountDownLatch(1)
53+
var requestCount = 0
54+
every { socket.getInputStream() } answers {
55+
when (requestCount++) {
56+
0 -> authSaltResponse
57+
1 -> authOkResponse
58+
2 -> {
59+
firstRequestArrivedLatch.countDown()
60+
firstRequestLatch.await()
61+
statusResponse
62+
}
63+
64+
3 -> {
65+
secondRequestLatch.await()
66+
statusResponse
67+
}
68+
69+
else -> throw IllegalStateException("Unexpected request count: $requestCount")
70+
}.also { requestCount++ }
71+
}
72+
// Send two requests from two separate threads
73+
Thread { amule.sendRequest(StatsRequest()) }.start()
74+
Thread { amule.sendRequest(StatsRequest()) }.start()
75+
// Wait for the first request to arrive
76+
firstRequestArrivedLatch.await()
77+
Thread.sleep(50) // Allow for the second request to arrive if it's not synchronised
78+
requestCount shouldBe 3
79+
firstRequestLatch.countDown()
80+
secondRequestLatch.countDown()
81+
}
82+
83+
})

0 commit comments

Comments
 (0)