Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tools/import_differ/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ python import_differ.py \
- previous\_data: Path to the previous data (wildcard on local/GCS supported).
- output\_location: Path to the output data folder (local/GCS).
- file\_format: Format of the input data (mcf,tfrecord).
- runner\_mode: Runner mode: local (Python) / cloud (Dataflow in Cloud).
- runner\_mode: Runner mode: native (Python) / direct (Java runner) /cloud (Dataflow in Cloud).
- project\_id: GCP project Id for the dataflow job.
- job\_name: Name of the differ dataflow job.

Expand Down
88 changes: 27 additions & 61 deletions tools/import_differ/differ_utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
import glob
import fnmatch
import json
import os
import pandas as pd
import re

from absl import logging
from google.cloud import storage
from util.file_util import FileIO
from util.file_util import file_get_matching


def load_mcf_file(file: str):
""" Reads an MCF text file and returns mcf nodes."""
mcf_file = open(file, 'r', encoding='utf-8')
mcf_contents = mcf_file.read()
mcf_file.close()
with FileIO(file, 'r', encoding='utf-8') as mcf_file:
mcf_contents = mcf_file.read()
# nodes separated by a blank line
mcf_nodes_text = mcf_contents.split('\n\n')
# lines seprated as property: constraint
Expand All @@ -36,7 +34,7 @@ def load_mcf_files(path: str) -> pd.DataFrame:
""" Loads all sharded mcf files in the given directory and
returns a combined MCF node list."""
node_list = []
filenames = glob.glob(path)
filenames = file_get_matching(path)
logging.info(f'Loading {len(filenames)} files from path {path}')
for filename in filenames:
nodes = load_mcf_file(filename)
Expand All @@ -48,72 +46,44 @@ def load_csv_data(path: str, tmp_dir: str) -> pd.DataFrame:
""" Loads all matched files in the given path and
returns a single combined dataframe."""
df_list = []
pattern = path
if path.startswith('gs://'):
pattern = get_gcs_data(path, tmp_dir)

filenames = glob.glob(pattern)
filenames = file_get_matching(path)
for filename in filenames:
df = pd.read_csv(filename)
df_list.append(df)
with FileIO(filename, mode='r') as in_file:
df = pd.read_csv(in_file)
df_list.append(df)
result = pd.concat(df_list, ignore_index=True)
return result


def write_csv_data(df: pd.DataFrame, dest: str, file: str, tmp_dir: str):
""" Writes a dataframe to a CSV file with the given path."""
if dest.startswith('gs://'):
path = os.path.join(tmp_dir, file)
else:
path = os.path.join(dest, file)
with open(path, mode='w', encoding='utf-8') as out_file:
path = os.path.join(dest, file)
with FileIO(path, mode='w', encoding='utf-8') as out_file:
df.to_csv(out_file, index=False, mode='w', header=True)
if dest.startswith('gs://'):
upload_output_data(path, dest)


def write_json_data(data, dest: str, file: str, tmp_dir: str):
""" Writes data to a JSON file with the given path."""
if dest.startswith('gs://'):
path = os.path.join(tmp_dir, file)
else:
path = os.path.join(dest, file)
with open(path, mode='w', encoding='utf-8') as out_file:
path = os.path.join(dest, file)
with FileIO(path, mode='w', encoding='utf-8') as out_file:
json.dump(data, out_file, indent=4)
if dest.startswith('gs://'):
upload_output_data(path, dest)


def upload_output_data(src: str, dest: str):
client = storage.Client()
bucket_name = dest.split('/')[2]
bucket = client.get_bucket(bucket_name)
for filepath in glob.iglob(src):
filename = os.path.basename(filepath)
logging.info('Uploading %s to %s', filename, dest)
blobname = dest[len('gs://' + bucket_name + '/'):] + '/' + filename
blob = bucket.blob(blobname)
blob.upload_from_filename(filepath)

def write_mcf_nodes(nodes: list, dest: str, file: str, tmp_dir: str):
""" Writes mcf nodes to a file with the given path."""
path = os.path.join(dest, file)
with FileIO(path, mode='w', encoding='utf-8') as out_file:
for node in nodes:
if 'Node' in node:
out_file.write(f'Node: {node["Node"]}\n')
elif 'dcid' in node:
out_file.write(f'dcid: {node["dcid"]}\n')

def get_gcs_data(uri: str, dest_dir: str) -> str:
""" Downloads files from GCS and copies them to local.
Args:
uri: single file path or wildcard format
dest_dir: destination folder
Returns:
path to the output file/folder
"""
client = storage.Client()
bucket = client.get_bucket(uri.split('/')[2])
file_pat = uri.split(bucket.name, 1)[1][1:]
dirname = os.path.dirname(file_pat)
for blob in bucket.list_blobs(prefix=dirname):
if fnmatch.fnmatch(blob.name, file_pat):
dest_file = os.path.join(dest_dir, blob.name)
os.makedirs(os.path.dirname(dest_file), exist_ok=True)
blob.download_to_filename(dest_file)
return os.path.join(dest_dir, file_pat)
for key, value in node.items():
if key in ['Node', 'dcid']:
continue
out_file.write(f'{key}: {value}\n')
out_file.write('\n')


def load_data(path: str, tmp_dir: str) -> list:
Expand All @@ -124,9 +94,5 @@ def load_data(path: str, tmp_dir: str) -> list:
Returns:
combined list of mcf nodes
"""
if path.startswith('gs://'):
os.makedirs(tmp_dir, exist_ok=True)
path = get_gcs_data(path, tmp_dir)

mcf_nodes = load_mcf_files(path)
return mcf_nodes
Loading
Loading