Source code for genome_kit.data_manager

# Copyright (C) 2016-2023 Deep Genomics Inc. All Rights Reserved.

import base64
import hashlib
import logging
import os
import tempfile
from abc import ABC
from contextlib import contextmanager
from functools import wraps, lru_cache
from pathlib import Path
from typing import Dict

from google.cloud import storage
from tqdm.auto import tqdm
from tqdm.utils import ObjectWrapper

logger = logging.getLogger(__name__)

class ProgressPercentage(object):  # pragma: no cover
    """
    Display a progress bar on a text terminal showing the file being
    transferred, completion percentage, estimated transfer time and
    throughput::

        Progress "myfile.bin" [ 20.0 B/ 20.0 B] 100% |###############| ETA: 0:00:00 147.8 B/s
    """
    def __init__(self, filename, filesize=None, progress_msg='Progress'):
        """
        Parameters
        ----------
        filename : :py:class:`str`
            File name to display; truncated for display if necessary.
        filesize : :py:class:`int`
            File size in bytes.
        progress_msg : :py:class:`str`
            Message to display before the filename, e.g.
            "[MyPackage] downloading". Defaults to "Progress".
        """
        if len(filename) > 34:
            filename = os.path.splitext(filename)[0][-34:]
            progress_msg = ''
        self.progress = tqdm(unit='B',
                             unit_scale=True,
                             unit_divisor=1024,
                             miniters=1,
                             total=filesize,
                             postfix={progress_msg: filename})
    def __call__(self, bytes_amount):
        self.progress.update(bytes_amount)
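
# A hedged usage sketch, not part of the module: ProgressPercentage is a
# callable that can serve as a byte-count transfer callback. The function
# name and the simulated transfer below are hypothetical.
def _example_progress_percentage():
    progress = ProgressPercentage("myfile.bin", filesize=1024)
    for _ in range(4):    # simulate a 1 KiB transfer in 256-byte chunks
        progress(256)     # each call advances the tqdm bar by 256 bytes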
_GCS_BUCKET = os.environ.get("GENOMEKIT_GCS_BUCKET", "genomekit-public-dg")


def _hashfile(afile, hasher, blocksize=65536):
    """Memory-efficient file hashing function.

    Parameters
    ----------
    afile : :py:class:`file`
        An open file handle to read from.
    hasher : :py:class:`_hashlib.HASH`
        A hash object, e.g. ``hashlib.sha256()``.
    blocksize : :py:class:`int`
        The size of each hash update (affects speed only).

    Returns
    -------
    :py:class:`bytes`
        The raw hash digest.
    """
    buf = afile.read(blocksize)
    while len(buf) > 0:
        hasher.update(buf)
        buf = afile.read(blocksize)
    return hasher.digest()
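
# A hedged usage sketch, not part of the module: _hashfile reads the file in
# fixed-size blocks, so memory use stays constant regardless of file size.
# The function name below is hypothetical.
def _example_hashfile(path):
    with open(path, "rb") as f:
        digest = _hashfile(f, hashlib.sha256())    # raw bytes from digest()
    return digest.hex()                            # hex string for display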
class DataManager(ABC):
    def __init__(self, data_dir: str):
        """
        Implementations should accept a data directory.

        Args:
            data_dir: location where local files are cached
        """
        os.makedirs(data_dir, exist_ok=True)
        self.data_dir = data_dir
    def get_file(self, filename: str) -> str:
        """Get the path to a managed data file and download it if necessary.

        Parameters
        ----------
        filename
            A single filename, without a path.

        Returns
        -------
        The path to the local version of the file.
        """
        raise NotImplementedError("Descendant classes must implement get_file")
    def upload_file(self, filepath: str, filename: str, metadata: Dict[str, str] = None):
        """Upload a file to make it available in GenomeKit.

        Parameters
        ----------
        filepath
            Full path to the file to be uploaded.
        filename
            The remote filename.
        metadata
            Extra metadata to attach to the file.
        """
        raise NotImplementedError("Descendant classes must implement upload_file")
    def list_available_genomes(self):
        """List all available genomes in the data manager."""
        raise NotImplementedError("Descendant classes must implement list_available_genomes")
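
# A hedged sketch, not part of the module: the minimal surface a DataManager
# subclass must implement. This hypothetical implementation serves files by
# copying them from a fixed local source directory.
import shutil  # needed only for this sketch


class _LocalDirDataManager(DataManager):
    def __init__(self, data_dir: str, source_dir: str):
        super().__init__(data_dir)
        self.source_dir = source_dir

    def get_file(self, filename: str) -> str:
        local_path = Path(self.data_dir, filename)
        if not local_path.exists():
            shutil.copyfile(Path(self.source_dir, filename), local_path)
        return str(local_path)

    def upload_file(self, filepath: str, filename: str, metadata: Dict[str, str] = None):
        shutil.copyfile(filepath, Path(self.source_dir, filename))

    def list_available_genomes(self):
        return [p.stem for p in Path(self.source_dir).glob("*.2bit")]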
def _remote_equal(blob: storage.Blob, file_path: Path) -> bool:
    """Compare a local file against a GCS blob via MD5 checksums."""
    remote_checksum = base64.b64decode(blob.md5_hash)
    with open(file_path, 'rb') as f:
        local_checksum = _hashfile(f, hashlib.md5())
    return remote_checksum == local_checksum


class CallbackIOWrapper(ObjectWrapper, object):
    """Wrap a stream so every read/write also reports its data to a callback."""

    def __init__(self, callback, stream):
        super(CallbackIOWrapper, self).__init__(stream)

        write = getattr(stream, "write")

        @wraps(write)
        def wrapped_write(data, *args, **kwargs):
            res = write(data, *args, **kwargs)
            callback(data)
            return res

        self.wrapper_setattr("write", wrapped_write)

        read = getattr(stream, "read")

        @wraps(read)
        def wrapped_read(*args, **kwargs):
            data = read(*args, **kwargs)
            callback(data)
            return data

        self.wrapper_setattr("read", wrapped_read)


@contextmanager
def FileIO(filename, desc, mode, size, quiet):
    """Open ``filename`` and yield it wrapped so reads/writes drive a tqdm bar."""
    with open(filename, mode) as fp:
        with tqdm(disable=quiet,
                  total=size,
                  unit="B",
                  unit_scale=True,
                  unit_divisor=1024,
                  miniters=1,
                  desc=desc) as bar:
            decorated = CallbackIOWrapper(lambda x: bar.update(len(x)), fp)
            yield decorated


class DefaultDataManager(DataManager):
    """A minimal data manager implementation backed by the GCS samples bucket."""

    def __init__(self, data_dir: str):
        """
        Args:
            data_dir: location where local files are cached
        """
        super().__init__(data_dir)

    @property
    def bucket(self):
        # Lazily create the GCS client so constructing the manager is cheap.
        if not hasattr(self, "_bucket"):
            gcloud_client = storage.Client()
            self._bucket = gcloud_client.bucket(
                _GCS_BUCKET,
                user_project=os.environ.get("GENOMEKIT_GCS_BILLING_PROJECT", None))
        return self._bucket

    def get_file(self, filename: str) -> str:
        blob = self.bucket.blob(filename)
        if not blob.exists():
            raise FileNotFoundError(f"File '{filename}' not found in the GCS bucket")

        local_path = Path(self.data_dir, filename)
        blob.reload()
        # Reuse the cached copy if its MD5 matches the remote blob.
        if local_path.exists():
            if _remote_equal(blob, local_path):
                return str(local_path)

        # Download into a temporary file to make the download safe: a failed
        # transfer never clobbers an existing copy.
        temp_file = tempfile.NamedTemporaryFile(delete=False,
                                                mode="wb",
                                                dir=self.data_dir,
                                                prefix=filename,
                                                suffix=".part")
        temp_file_path = str(Path(self.data_dir, temp_file.name))
        try:
            temp_file.close()
            with FileIO(temp_file_path, filename, "wb", blob.size, quiet=False) as f:
                blob.download_to_file(f)
        except:
            os.remove(temp_file_path)
            raise

        # Atomically (on POSIX) rename the temporary file to the real one.
        os.rename(temp_file_path, local_path)
        return str(local_path)

    def upload_file(self, filepath: str, filename: str, metadata: Dict[str, str] = None):
        blob = self.bucket.blob(filename)
        if blob.exists():
            blob.reload()
            if _remote_equal(blob, Path(filepath)):
                logger.info(f"File '{filename}' already exists in the GCS bucket "
                            "and is identical, skipping.")
                return
            if "GK_UPLOAD_ALLOW_OVERWRITE" not in os.environ:
                raise ValueError(f"File '{filename}' already exists in the GCS bucket. "
                                 "Set GK_UPLOAD_ALLOW_OVERWRITE=1 to overwrite.")
            else:
                logger.warning(f"Overwriting '{filename}' on the server.")
                blob.metadata.update(metadata or {})
        else:
            blob.metadata = metadata
        with FileIO(filepath, filename, "rb", os.path.getsize(filepath), quiet=False) as f:
            blob.upload_from_file(f)

    @lru_cache
    def list_available_genomes(self):
        # Genomes are identified by their .2bit/.cfg files in the bucket.
        blobs = self.bucket.list_blobs(match_glob="*.{2bit,cfg}")
        return [blob.name.rpartition(".")[0] for blob in blobs]
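
# A hedged usage sketch, not part of the module: FileIO yields a file object
# whose reads/writes update a tqdm bar, so any API that accepts a file handle
# reports progress for free. The function name below is hypothetical.
def _example_fileio_copy(src, dst):
    size = os.path.getsize(src)
    with open(src, "rb") as src_f, \
            FileIO(dst, desc=os.path.basename(dst), mode="wb", size=size, quiet=False) as dst_f:
        for chunk in iter(lambda: src_f.read(65536), b""):
            dst_f.write(chunk)    # each write advances the progress bar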
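
# A hedged end-to-end sketch, not part of the module: the cache directory and
# filename below are hypothetical. The first call downloads with a progress
# bar; later calls return the cached copy after an MD5 check against the
# remote blob.
def _example_default_data_manager():
    dm = DefaultDataManager(os.path.expanduser("~/.genomekit-data"))
    print(dm.list_available_genomes())          # genomes with .2bit/.cfg files
    return dm.get_file("example_genome.cfg")    # hypothetical filename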