Skip to content

Commit f687116

Browse files
Merge pull request #169 from 11michalis11/state_tracking_steady_state
State tracking steady state
2 parents 4367cda + 1760522 commit f687116

2 files changed

Lines changed: 399 additions & 1 deletion

File tree

ciw/tests/test_state_tracker.py

Lines changed: 350 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -865,4 +865,353 @@ def test_track_history_two_node_two_class(self):
865865
]
866866
self.assertEqual(len(observed_history), len(expected_history))
867867
for obs, exp in zip(observed_history, expected_history):
868-
self.assertEqual([round(obs[0], 2), obs[1]], exp)
868+
self.assertEqual([round(obs[0], 2), obs[1]], exp)
869+
870+
871+
class TestStateProbabilities(unittest.TestCase):
872+
def test_prob_one_node_deterministic_naiveblocking(self):
873+
N = ciw.create_network(
874+
arrival_distributions=[ciw.dists.Sequential([1.5, 0.3, 2.4, 1.1])],
875+
service_distributions=[ciw.dists.Sequential([1.8, 2.2, 0.2, 0.2, 0.2, 0.2])],
876+
number_of_servers=[1]
877+
)
878+
B = ciw.trackers.NaiveBlocking()
879+
Q = ciw.Simulation(N, tracker=B, exact=26)
880+
Q.simulate_until_max_time(15.5)
881+
expected_probabilities = {
882+
((0, 0),): Decimal('0.38157894736842105263157895'),
883+
((1, 0),): Decimal('0.26973684210526315789473684'),
884+
((2, 0),): Decimal('0.26315789473684210526315789'),
885+
((3, 0),): Decimal('0.085526315789473684210526316')
886+
}
887+
888+
expected_probabilities_with_time_period = {
889+
((2, 0),): Decimal('0.1'),
890+
((3, 0),): Decimal('0.04'),
891+
((1, 0),): Decimal('0.22'),
892+
((0, 0),): Decimal('0.64')
893+
}
894+
895+
for state in expected_probabilities:
896+
self.assertEqual(
897+
Q.statetracker.state_probabilities()[state],
898+
expected_probabilities[state]
899+
)
900+
901+
for state in expected_probabilities_with_time_period:
902+
self.assertEqual(
903+
Q.statetracker.state_probabilities(observation_period=(5,10))[state],
904+
expected_probabilities_with_time_period[state]
905+
)
906+
907+
def test_prob_one_node_deterministic_systempopulation(self):
908+
N = ciw.create_network(
909+
arrival_distributions=[ciw.dists.Sequential([1.5, 0.3, 2.4, 1.1])],
910+
service_distributions=[ciw.dists.Sequential([1.8, 2.2, 0.2, 0.2, 0.2, 0.2])],
911+
number_of_servers=[1]
912+
)
913+
B = ciw.trackers.SystemPopulation()
914+
Q = ciw.Simulation(N, tracker=B, exact=26)
915+
Q.simulate_until_max_time(15.5)
916+
expected_probabilities = {
917+
0: Decimal('0.38157894736842105263157895'),
918+
1: Decimal('0.26973684210526315789473684'),
919+
2: Decimal('0.26315789473684210526315789'),
920+
3: Decimal('0.085526315789473684210526316')
921+
}
922+
expected_probabilities_with_time_period = {
923+
2: Decimal('0.1'),
924+
3: Decimal('0.04'),
925+
1: Decimal('0.22'),
926+
0: Decimal('0.64')
927+
}
928+
929+
for state in expected_probabilities:
930+
self.assertEqual(
931+
Q.statetracker.state_probabilities()[state],
932+
expected_probabilities[state]
933+
)
934+
935+
for state in expected_probabilities_with_time_period:
936+
self.assertEqual(
937+
Q.statetracker.state_probabilities(observation_period=(5,10))[state],
938+
expected_probabilities_with_time_period[state]
939+
)
940+
941+
def test_prob_one_node_deterministic_nodepopulation(self):
942+
N = ciw.create_network(
943+
arrival_distributions=[ciw.dists.Sequential([1.5, 0.3, 2.4, 1.1])],
944+
service_distributions=[ciw.dists.Sequential([1.8, 2.2, 0.2, 0.2, 0.2, 0.2])],
945+
number_of_servers=[1]
946+
)
947+
B = ciw.trackers.NodePopulation()
948+
Q = ciw.Simulation(N, tracker=B, exact=26)
949+
Q.simulate_until_max_time(15.5)
950+
expected_probabilities = {
951+
(0,): Decimal('0.38157894736842105263157895'),
952+
(1,): Decimal('0.26973684210526315789473684'),
953+
(2,): Decimal('0.26315789473684210526315789'),
954+
(3,): Decimal('0.085526315789473684210526316')
955+
}
956+
expected_probabilities_with_time_period = {
957+
(2,): Decimal('0.1'),
958+
(3,): Decimal('0.04'),
959+
(1,): Decimal('0.22'),
960+
(0,): Decimal('0.64')
961+
}
962+
for state in expected_probabilities:
963+
self.assertEqual(
964+
Q.statetracker.state_probabilities()[state],
965+
expected_probabilities[state]
966+
)
967+
968+
for state in expected_probabilities_with_time_period:
969+
self.assertEqual(
970+
Q.statetracker.state_probabilities(observation_period=(5,10))[state],
971+
expected_probabilities_with_time_period[state]
972+
)
973+
974+
def test_prob_one_node_deterministic_nodeclassmatrix(self):
975+
N = ciw.create_network(
976+
arrival_distributions=[ciw.dists.Sequential([1.5, 0.3, 2.4, 1.1])],
977+
service_distributions=[ciw.dists.Sequential([1.8, 2.2, 0.2, 0.2, 0.2, 0.2])],
978+
number_of_servers=[1]
979+
)
980+
B = ciw.trackers.NodeClassMatrix()
981+
Q = ciw.Simulation(N, tracker=B, exact=26)
982+
Q.simulate_until_max_time(15.5)
983+
expected_probabilities = {
984+
((0,),): Decimal('0.38157894736842105263157895'),
985+
((1,),): Decimal('0.26973684210526315789473684'),
986+
((2,),): Decimal('0.26315789473684210526315789'),
987+
((3,),): Decimal('0.085526315789473684210526316')
988+
}
989+
expected_probabilities_with_time_period = {
990+
((2,),): Decimal('0.1'),
991+
((3,),): Decimal('0.04'),
992+
((1,),): Decimal('0.22'),
993+
((0,),): Decimal('0.64')
994+
}
995+
for state in expected_probabilities:
996+
self.assertEqual(
997+
Q.statetracker.state_probabilities()[state],
998+
expected_probabilities[state]
999+
)
1000+
1001+
for state in expected_probabilities_with_time_period:
1002+
self.assertEqual(
1003+
Q.statetracker.state_probabilities(observation_period=(5,10))[state],
1004+
expected_probabilities_with_time_period[state]
1005+
)
1006+
1007+
def test_prob_track_history_two_node_two_class(self):
1008+
N = ciw.create_network(
1009+
arrival_distributions={
1010+
'Class 0': [ciw.dists.Exponential(0.5), ciw.dists.Exponential(0.5)],
1011+
'Class 1': [ciw.dists.Exponential(0.5), ciw.dists.Exponential(0.5)]},
1012+
service_distributions={
1013+
'Class 0': [ciw.dists.Exponential(1), ciw.dists.Exponential(1)],
1014+
'Class 1': [ciw.dists.Exponential(1), ciw.dists.Exponential(1)]},
1015+
number_of_servers=[1, 1],
1016+
routing={
1017+
'Class 0': [[0.2, 0.2], [0.2, 0.2]],
1018+
'Class 1': [[0.2, 0.2], [0.2, 0.2]]}
1019+
)
1020+
1021+
# System Population
1022+
ciw.seed(0)
1023+
Q = ciw.Simulation(N, tracker=ciw.trackers.SystemPopulation())
1024+
Q.simulate_until_max_time(5)
1025+
expected_probabilities = {
1026+
0: 0.1366944915680613,
1027+
1: 0.11225559969470539,
1028+
2: 0.3240424959070235,
1029+
3: 0.22414907894503974,
1030+
4: 0.04813145483931205,
1031+
5: 0.13662855193885537,
1032+
6: 0.018098327107002654
1033+
}
1034+
expected_probabilities_with_time_period = {
1035+
1: 0.030475430361061855,
1036+
2: 0.47354672264790626,
1037+
3: 0.2921852718539804,
1038+
4: 0.008450291962042685,
1039+
5: 0.16889388953780524,
1040+
6: 0.026448393637203527
1041+
}
1042+
1043+
for state in expected_probabilities:
1044+
self.assertEqual(
1045+
Q.statetracker.state_probabilities()[state],
1046+
expected_probabilities[state]
1047+
)
1048+
1049+
for state in expected_probabilities_with_time_period:
1050+
self.assertEqual(
1051+
Q.statetracker.state_probabilities(observation_period=(1, 4))[state],
1052+
expected_probabilities_with_time_period[state]
1053+
)
1054+
1055+
# Node Population
1056+
ciw.seed(0)
1057+
Q = ciw.Simulation(N, tracker=ciw.trackers.NodePopulation())
1058+
Q.simulate_until_max_time(5)
1059+
expected_probabilities = {
1060+
(0, 0): 0.1366944915680613,
1061+
(0, 1): 0.11225559969470539,
1062+
(0, 2): 0.3240424959070235,
1063+
(0, 3): 0.0983851025458379,
1064+
(1, 2): 0.12576397639920184,
1065+
(1, 3): 0.02999254345852149,
1066+
(2, 3): 0.13662855193885537,
1067+
(3, 3): 0.018098327107002654,
1068+
(2, 2): 0.018138911380790552
1069+
}
1070+
expected_probabilities_with_time_period = {
1071+
(0, 1): 0.030475430361061855,
1072+
(0, 2): 0.47354672264790626,
1073+
(0, 3): 0.10839728230553976,
1074+
(1, 2): 0.18378798954844067,
1075+
(1, 3): 0.008450291962042685,
1076+
(2, 3): 0.16889388953780524,
1077+
(3, 3): 0.026448393637203527
1078+
}
1079+
1080+
for state in expected_probabilities:
1081+
self.assertEqual(
1082+
Q.statetracker.state_probabilities()[state],
1083+
expected_probabilities[state]
1084+
)
1085+
1086+
for state in expected_probabilities_with_time_period:
1087+
self.assertEqual(
1088+
Q.statetracker.state_probabilities(observation_period=(1,4))[state],
1089+
expected_probabilities_with_time_period[state]
1090+
)
1091+
1092+
# Node Class Matrix
1093+
ciw.seed(0)
1094+
Q = ciw.Simulation(N, tracker=ciw.trackers.NodeClassMatrix())
1095+
Q.simulate_until_max_time(5)
1096+
1097+
expected_probabilities = {((0, 0), (0, 0)): 0.1366944915680613,
1098+
((0, 0), (0, 1)): 0.11225559969470539,
1099+
((0, 0), (1, 1)): 0.12454611130253443,
1100+
((0, 0), (1, 2)): 0.0983851025458379,
1101+
((0, 0), (0, 2)): 0.19949638460448907,
1102+
((0, 1), (0, 2)): 0.12576397639920184,
1103+
((0, 1), (0, 3)): 0.005782436172743465,
1104+
((0, 2), (0, 3)): 0.07008049325290676,
1105+
((1, 2), (0, 3)): 0.018098327107002654,
1106+
((1, 1), (0, 3)): 0.06259720833539882,
1107+
((1, 1), (0, 2)): 0.018138911380790552,
1108+
((1, 1), (1, 2)): 0.003950850350549779,
1109+
((1, 0), (1, 2)): 0.024210107285778028
1110+
}
1111+
expected_probabilities_with_time_period = {
1112+
((0, 0), (0, 1)): 0.030475430361061855,
1113+
((0, 0), (1, 1)): 0.18200823524942553,
1114+
((0, 0), (1, 2)): 0.10839728230553976,
1115+
((0, 0), (0, 2)): 0.29153848739848076,
1116+
((0, 1), (0, 2)): 0.18378798954844067,
1117+
((0, 1), (0, 3)): 0.008450291962042685,
1118+
((0, 2), (0, 3)): 0.10241369055182432,
1119+
((1, 2), (0, 3)): 0.026448393637203527,
1120+
((1, 1), (0, 3)): 0.0664801989859809
1121+
}
1122+
1123+
for state in expected_probabilities:
1124+
self.assertEqual(
1125+
Q.statetracker.state_probabilities()[state],
1126+
expected_probabilities[state]
1127+
)
1128+
1129+
for state in expected_probabilities_with_time_period:
1130+
self.assertEqual(
1131+
Q.statetracker.state_probabilities(observation_period=(1,4))[state],
1132+
expected_probabilities_with_time_period[state]
1133+
)
1134+
1135+
def test_compare_state_probabilities_to_analytical(self):
1136+
#Example: λ = 1, μ = 3
1137+
lamda = 1
1138+
mu = 3
1139+
ciw.seed(0)
1140+
N = ciw.create_network(
1141+
arrival_distributions=[ciw.dists.Exponential(lamda)],
1142+
service_distributions=[ciw.dists.Exponential(mu)],
1143+
number_of_servers=[1]
1144+
)
1145+
Q = ciw.Simulation(N, tracker=ciw.trackers.SystemPopulation())
1146+
Q.simulate_until_max_time(20000)
1147+
state_probs = Q.statetracker.state_probabilities(observation_period=(500, 20000))
1148+
1149+
vec = [(lamda/mu)**i for i in sorted(state_probs.keys())]
1150+
expected_probs = [v / sum(vec) for v in vec]
1151+
1152+
for state in state_probs:
1153+
self.assertEqual(round(state_probs[state], 2), round(expected_probs[state], 2))
1154+
1155+
error_squared = sum([(state_probs[i] - expected_probs[i])**2 for i in sorted(state_probs.keys())])
1156+
self.assertEqual(round(error_squared, 4), 0)
1157+
1158+
1159+
#Example: λ = 1, μ = 4
1160+
lamda = 1
1161+
mu = 4
1162+
ciw.seed(0)
1163+
N = ciw.create_network(
1164+
arrival_distributions=[ciw.dists.Exponential(lamda)],
1165+
service_distributions=[ciw.dists.Exponential(mu)],
1166+
number_of_servers=[1]
1167+
)
1168+
Q = ciw.Simulation(N, tracker=ciw.trackers.SystemPopulation())
1169+
Q.simulate_until_max_time(20000)
1170+
state_probs = Q.statetracker.state_probabilities(observation_period=(500, 20000))
1171+
1172+
vec = [(lamda/mu)**i for i in sorted(state_probs.keys())]
1173+
expected_probs = [v / sum(vec) for v in vec]
1174+
1175+
for state in state_probs:
1176+
self.assertEqual(round(state_probs[state], 2), round(expected_probs[state], 2))
1177+
1178+
error_squared = sum([(state_probs[i] - expected_probs[i])**2 for i in sorted(state_probs.keys())])
1179+
self.assertEqual(round(error_squared, 4), 0)
1180+
1181+
1182+
#Example: λ = 1, μ = 5
1183+
lamda = 1
1184+
mu = 5
1185+
ciw.seed(0)
1186+
N = ciw.create_network(
1187+
arrival_distributions=[ciw.dists.Exponential(lamda)],
1188+
service_distributions=[ciw.dists.Exponential(mu)],
1189+
number_of_servers=[1]
1190+
)
1191+
Q = ciw.Simulation(N, tracker=ciw.trackers.SystemPopulation())
1192+
Q.simulate_until_max_time(20000)
1193+
state_probs = Q.statetracker.state_probabilities(observation_period=(500, 20000))
1194+
1195+
vec = [(lamda/mu)**i for i in sorted(state_probs.keys())]
1196+
expected_probs = [v / sum(vec) for v in vec]
1197+
1198+
for state in state_probs:
1199+
self.assertEqual(round(state_probs[state], 2), round(expected_probs[state], 2))
1200+
1201+
error_squared = sum([(state_probs[i] - expected_probs[i])**2 for i in sorted(state_probs.keys())])
1202+
self.assertEqual(round(error_squared, 4), 0)
1203+
1204+
def test_error_checking_for_state_probabilities(self):
1205+
ciw.seed(0)
1206+
N = ciw.create_network(
1207+
arrival_distributions=[ciw.dists.Exponential(1)],
1208+
service_distributions=[ciw.dists.Exponential(2)],
1209+
number_of_servers=[1]
1210+
)
1211+
Q = ciw.Simulation(N, tracker=ciw.trackers.SystemPopulation())
1212+
Q.simulate_until_max_time(10)
1213+
1214+
self.assertRaises(ValueError, Q.statetracker.state_probabilities, (-1, 5))
1215+
self.assertRaises(ValueError, Q.statetracker.state_probabilities, (4, 2))
1216+
self.assertRaises(ValueError, Q.statetracker.state_probabilities, (-1, -4))
1217+
self.assertRaises(ValueError, Q.statetracker.state_probabilities, (3, 3))

0 commit comments

Comments
 (0)