diff --git a/deepq_mineral_shards.py b/deepq_mineral_shards.py index fe6d6f6..22bf71e 100644 --- a/deepq_mineral_shards.py +++ b/deepq_mineral_shards.py @@ -17,7 +17,7 @@ from pysc2.lib import features from pysc2.lib import actions -import gflags as flags +from absl import flags _PLAYER_RELATIVE = features.SCREEN_FEATURES.player_relative.index _PLAYER_FRIENDLY = 1 diff --git a/defeat_zerglings/dqfd.py b/defeat_zerglings/dqfd.py index b06a36a..a4246b2 100644 --- a/defeat_zerglings/dqfd.py +++ b/defeat_zerglings/dqfd.py @@ -19,7 +19,7 @@ from defeat_zerglings import common -import gflags as flags +from absl import flags _PLAYER_RELATIVE = features.SCREEN_FEATURES.player_relative.index diff --git a/defeat_zerglings/run_demo_agent.py b/defeat_zerglings/run_demo_agent.py index c312dcc..3d06252 100644 --- a/defeat_zerglings/run_demo_agent.py +++ b/defeat_zerglings/run_demo_agent.py @@ -1,6 +1,6 @@ import sys -import gflags as flags +from absl import flags from baselines import deepq from pysc2.env import sc2_env from pysc2.lib import actions @@ -22,9 +22,14 @@ def main(): FLAGS(sys.argv) with sc2_env.SC2Env( - "DefeatZerglingsAndBanelings", + map_name="DefeatZerglingsAndBanelings", step_mul=step_mul, visualize=True, + players=[sc2_env.Agent(sc2_env.Race.terran)], + agent_interface_format=sc2_env.AgentInterfaceFormat( + feature_dimensions=sc2_env.Dimensions( + screen=64, + minimap=64)), game_steps_per_episode=steps * step_mul) as env: demo_replay = [] diff --git a/defeat_zerglings/train.py b/defeat_zerglings/train.py index 731cf9d..d65b777 100644 --- a/defeat_zerglings/train.py +++ b/defeat_zerglings/train.py @@ -1,6 +1,6 @@ import sys -import gflags as flags +from absl import flags from baselines import deepq from pysc2.env import sc2_env from pysc2.lib import actions @@ -20,9 +20,14 @@ def main(): FLAGS(sys.argv) with sc2_env.SC2Env( - "DefeatZerglingsAndBanelings", + map_name="DefeatZerglingsAndBanelings", step_mul=step_mul, visualize=True, + players=[sc2_env.Agent(sc2_env.Race.terran)], + agent_interface_format=sc2_env.AgentInterfaceFormat( + feature_dimensions=sc2_env.Dimensions( + screen=64, + minimap=64)), game_steps_per_episode=steps * step_mul) as env: model = deepq.models.cnn_to_mlp( diff --git a/enjoy_mineral_shards.py b/enjoy_mineral_shards.py index 1d1ad49..0f6f4d5 100644 --- a/enjoy_mineral_shards.py +++ b/enjoy_mineral_shards.py @@ -1,7 +1,7 @@ import sys import baselines.common.tf_util as U -import gflags as flags +from absl import flags import numpy as np from baselines import deepq from pysc2.env import environment @@ -31,9 +31,14 @@ def main(): FLAGS(sys.argv) with sc2_env.SC2Env( - "CollectMineralShards", + map_name="CollectMineralShards", step_mul=step_mul, visualize=True, + players=[sc2_env.Agent(sc2_env.Race.terran)], + agent_interface_format=sc2_env.AgentInterfaceFormat( + feature_dimensions=sc2_env.Dimensions( + screen=64, + minimap=64)), game_steps_per_episode=steps * step_mul) as env: model = deepq.models.cnn_to_mlp( diff --git a/tests/scripted_test.py b/tests/scripted_test.py index 2394500..4474384 100644 --- a/tests/scripted_test.py +++ b/tests/scripted_test.py @@ -12,7 +12,7 @@ from pysc2.lib import features from pysc2.lib import basetest -import gflags as flags +from absl import flags import sys _NO_OP = sc2_actions.FUNCTIONS.no_op.id @@ -28,9 +28,14 @@ def test_defeat_zerglings(self): FLAGS(sys.argv) with sc2_env.SC2Env( - "DefeatZerglingsAndBanelings", + map_name="DefeatZerglingsAndBanelings", step_mul=self.step_mul, visualize=True, + players=[sc2_env.Agent(sc2_env.Race.terran)], + agent_interface_format=sc2_env.AgentInterfaceFormat( + feature_dimensions=sc2_env.Dimensions( + screen=64, + minimap=64)), game_steps_per_episode=self.steps * self.step_mul) as env: obs = env.step(actions=[sc2_actions.FunctionCall(_NO_OP, [])]) player_relative = obs[0].observation["screen"][_PLAYER_RELATIVE] diff --git a/train_mineral_shards.py b/train_mineral_shards.py index 79465a5..1d0eaa5 100644 --- a/train_mineral_shards.py +++ b/train_mineral_shards.py @@ -1,6 +1,6 @@ import sys -import gflags as flags +from absl import flags from baselines import deepq from pysc2.env import sc2_env from pysc2.lib import actions @@ -19,8 +19,13 @@ def main(): FLAGS(sys.argv) with sc2_env.SC2Env( - "CollectMineralShards", + map_name="CollectMineralShards", step_mul=step_mul, + players=[sc2_env.Agent(sc2_env.Race.terran)], + agent_interface_format=sc2_env.AgentInterfaceFormat( + feature_dimensions=sc2_env.Dimensions( + screen=64, + minimap=64)), visualize=True) as env: model = deepq.models.cnn_to_mlp(