diff --git a/tests/pyathena/sqlalchemy/test_base.py b/tests/pyathena/sqlalchemy/test_base.py index cb81c0a0..70920244 100644 --- a/tests/pyathena/sqlalchemy/test_base.py +++ b/tests/pyathena/sqlalchemy/test_base.py @@ -46,7 +46,7 @@ def test_reflect_no_such_table(self, engine): def test_reflect_table(self, engine): engine, conn = engine - one_row = Table("one_row", MetaData(), autoload_with=conn) + one_row = Table("one_row", MetaData(schema=ENV.schema), autoload_with=conn) assert len(one_row.c) == 1 assert one_row.c.number_of_rows is not None assert one_row.comment == "table comment" @@ -104,7 +104,7 @@ def test_reflect_table_with_schema(self, engine): def test_reflect_table_include_columns(self, engine): engine, conn = engine - one_row_complex = Table("one_row_complex", MetaData()) + one_row_complex = Table("one_row_complex", MetaData(schema=ENV.schema)) version = float(re.search(r"^([\d]+\.[\d]+)\..+", sqlalchemy.__version__).group(1)) if version <= 1.2: engine.dialect.reflecttable( @@ -135,7 +135,7 @@ def test_reflect_table_include_columns(self, engine): def test_partition_table_columns(self, engine): engine, conn = engine - partition_table = Table("partition_table", MetaData(), autoload_with=conn) + partition_table = Table("partition_table", MetaData(schema=ENV.schema), autoload_with=conn) assert len(partition_table.columns) == 2 assert "a" in partition_table.columns assert "b" in partition_table.columns @@ -157,22 +157,30 @@ def test_reflect_schemas(self, engine): def test_get_table_names(self, engine): engine, conn = engine - meta = MetaData() + meta = MetaData(schema=ENV.schema) meta.reflect(bind=engine) - assert "one_row" in meta.tables - assert "one_row_complex" in meta.tables - assert "view_one_row" not in meta.tables + # With schema specified, table names are schema-qualified + schema_qualified_one_row = f"{ENV.schema}.one_row" + schema_qualified_one_row_complex = f"{ENV.schema}.one_row_complex" + schema_qualified_view_one_row = f"{ENV.schema}.view_one_row" + assert schema_qualified_one_row in meta.tables + assert schema_qualified_one_row_complex in meta.tables + assert schema_qualified_view_one_row not in meta.tables insp = sqlalchemy.inspect(engine) assert "many_rows" in insp.get_table_names(schema=ENV.schema) def test_get_view_names(self, engine): engine, conn = engine - meta = MetaData() + meta = MetaData(schema=ENV.schema) meta.reflect(bind=engine, views=True) - assert "one_row" in meta.tables - assert "one_row_complex" in meta.tables - assert "view_one_row" in meta.tables + # With schema specified, table names are schema-qualified + schema_qualified_one_row = f"{ENV.schema}.one_row" + schema_qualified_one_row_complex = f"{ENV.schema}.one_row_complex" + schema_qualified_view_one_row = f"{ENV.schema}.view_one_row" + assert schema_qualified_one_row in meta.tables + assert schema_qualified_one_row_complex in meta.tables + assert schema_qualified_view_one_row in meta.tables insp = sqlalchemy.inspect(engine) actual = insp.get_view_names(schema=ENV.schema) @@ -680,7 +688,7 @@ def test_create_table(self, engine): column_name = "col" table = Table( table_name, - MetaData(), + MetaData(schema=ENV.schema), Column(column_name, types.String(10)), schema=ENV.schema, awsathena_location=f"{ENV.s3_staging_dir}{ENV.schema}/{table_name}",