Comparing changes

base repository: songlab-cal/tape
base: v0.4
head repository: songlab-cal/tape
compare: master
  • 19 commits
  • 16 files changed
  • 5 contributors

Commits on May 22, 2020

  1. Added force re-download

    Pravin Ravishanker authored and Pravin Ravishanker committed May 22, 2020
    905ba6c

Commits on May 29, 2020

  1. hugging face adoption

    Pravin Ravishanker authored and Pravin Ravishanker committed May 29, 2020
    54c4b8f

Commits on Jun 12, 2020

  1. Force Download Bug Solved

    Pravin Ravishanker authored and Pravin Ravishanker committed Jun 12, 2020
    0950c8c
  2. requirements update

    Pravin Ravishanker authored and Pravin Ravishanker committed Jun 12, 2020
    5c2486e
  3. file utils changed

    Pravin Ravishanker authored and Pravin Ravishanker committed Jun 12, 2020
    b7cb990

Commits on Jun 26, 2020

  1. force download

    Pravin Ravishanker authored and Pravin Ravishanker committed Jun 26, 2020
    6f447f3
  2. changed tests

    Pravin Ravishanker authored and Pravin Ravishanker committed Jun 26, 2020
    6e204f5
  3. Merge pull request #69 from pravinrav/ForceDownloadInstall

    add force download to repository
    
    * Fix #6 
    * add test for model download + force download
    rmrao authored Jun 26, 2020

    Verified

    8495202

Commits on Sep 27, 2020

  1. Update README.md

    rmrao authored Sep 27, 2020

    Verified

    7b13620
  2. Merge pull request #87 from songlab-cal/rmrao-patch-1

    Update README.md
    rmrao authored Sep 27, 2020

    Verified

    8a6928c

Commits on Apr 14, 2021

  1. Update requirements.txt

    rmrao authored Apr 14, 2021

    Verified

    f89977f

Commits on May 1, 2021

  1. 4ecd075

Commits on May 12, 2021

  1. Merge pull request #107 from chrislengerich/spelling_fix

    Fix spelling (wiht -> with).
    rmrao authored May 12, 2021

    Verified

    b38e317

Commits on Aug 2, 2021

  1. f259fd6

Commits on Aug 4, 2021

  1. Merge pull request #110 from songlab-cal/new-s3-bucket

    Update data and model paths to new s3 bucket
    rmrao authored Aug 4, 2021

    Verified

    02a7fda

Commits on Sep 7, 2021

  1. update workflow to install pytorch

    Roshan Rao committed Sep 7, 2021
    82ef434
  2. bump version and add TestPyPi workflow

    Roshan Rao committed Sep 7, 2021
    6d411b5
  3. don't test pypi on push

    Roshan Rao committed Sep 7, 2021
    7219ca8
  4. add workflow to upload to pypi

    Roshan Rao committed Sep 7, 2021
    6d345c2
41 changes: 41 additions & 0 deletions .github/workflows/pypi.yml
@@ -0,0 +1,41 @@
name: Upload to PyPI

# Controls when the action will run.
on:
  # Triggers the workflow when a release is created
  release:
    types: [created]

  # Allows you to run this workflow manually from the Actions tab
  workflow_dispatch:

# A workflow run is made up of one or more jobs that can run sequentially or in parallel
jobs:
  # This workflow contains a single job called "testpypi"
  testpypi:
    # The type of runner that the job will run on
    runs-on: ubuntu-latest

    # Steps represent a sequence of tasks that will be executed as part of the job
    steps:
      # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
      - uses: actions/checkout@v2

      # Sets up python3
      - uses: actions/setup-python@v2
        with:
          python-version: 3.8

      - name: "Installs dependencies"
        run: |
          python3 -m pip install --upgrade pip
          python3 -m pip install setuptools wheel twine
      # Upload to PyPI
      - name: Build and upload to PyPI
        run: |
          python3 setup.py sdist bdist_wheel
          python3 -m twine upload dist/*
        env:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.TWINE_TOKEN }}
2 changes: 1 addition & 1 deletion .github/workflows/pythonapp.yml
@@ -20,7 +20,7 @@ jobs:
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install flake8 pytest
        pip install flake8 pytest torch
        pip install .
    - name: Lint with flake8
      run: |
38 changes: 38 additions & 0 deletions .github/workflows/testpypi.yml
@@ -0,0 +1,38 @@
name: Upload to TestPyPI

# Controls when the action will run.
on:
  # Allows you to run this workflow manually from the Actions tab
  workflow_dispatch:

# A workflow run is made up of one or more jobs that can run sequentially or in parallel
jobs:
  # This workflow contains a single job called "testpypi"
  testpypi:
    # The type of runner that the job will run on
    runs-on: ubuntu-latest

    # Steps represent a sequence of tasks that will be executed as part of the job
    steps:
      # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
      - uses: actions/checkout@v2

      # Sets up python3
      - uses: actions/setup-python@v2
        with:
          python-version: 3.8

      - name: "Installs dependencies"
        run: |
          python3 -m pip install --upgrade pip
          python3 -m pip install setuptools wheel twine
      # Upload to TestPyPI
      - name: Build and upload to TestPyPI
        run: |
          python3 setup.py sdist bdist_wheel
          python3 -m twine upload dist/*
        env:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.TWINE_TEST_TOKEN }}
          TWINE_REPOSITORY: testpypi
8 changes: 5 additions & 3 deletions README.md
@@ -12,6 +12,8 @@ Our paper is available at [https://arxiv.org/abs/1906.08230](https://arxiv.org/a

Some documentation is incomplete. We will try to fill it in over time, but if there is something you would like an explanation for, please open an issue so we know where to focus our effort!

**Update 09/26/2020:** We no longer recommend trying to train directly with TAPE's training code. It will likely still work for some time, but will not be updated for future pytorch versions. Internally, we have been working with different frameworks for training (specifically Pytorch Lightning and Fairseq). We strongly recommend using a framework like these, as it offloads the requirement of maintaining compatibility with Pytorch versions. TAPE models will continue to be available, and if the code is working for you, feel free to use it. However, we will not be fixing issues regarding multi-GPU errors, OOM errors, etc. during training.

## Contents

* [Installation](#installation)
@@ -187,7 +189,7 @@ This will report the overall accuracy, and will also dump a `results.pkl` file i

### trRosetta

We have recently re-implemented the trRosetta model from Yang et. al. (2020). A link to the original repository, which was used as a basis for this re-implementation, can be found [here](https://github.com/gjoni/trRosetta). We provide a pytorch implementation and dataset to allow you to play around with the model. Data is available [here](http://s3.amazonaws.com/proteindata/data_pytorch/trrosetta.tar.gz). This is the same as the data in the original paper, however we've added train / val split files to allow you to train your own model reproducibly. To use this model
We have recently re-implemented the trRosetta model from Yang et. al. (2020). A link to the original repository, which was used as a basis for this re-implementation, can be found [here](https://github.com/gjoni/trRosetta). We provide a pytorch implementation and dataset to allow you to play around with the model. Data is available [here](http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/trrosetta.tar.gz). This is the same as the data in the original paper, however we've added train / val split files to allow you to train your own model reproducibly. To use this model

```python
from tape import TRRosetta
@@ -243,13 +245,13 @@ The unsupervised Pfam dataset is around 7GB compressed and 19GB uncompressed. Th

### LMDB Data

[Pretraining Corpus (Pfam)](http://s3.amazonaws.com/proteindata/data_pytorch/pfam.tar.gz) __|__ [Secondary Structure](http://s3.amazonaws.com/proteindata/data_pytorch/secondary_structure.tar.gz) __|__ [Contact (ProteinNet)](http://s3.amazonaws.com/proteindata/data_pytorch/proteinnet.tar.gz) __|__ [Remote Homology](http://s3.amazonaws.com/proteindata/data_pytorch/remote_homology.tar.gz) __|__ [Fluorescence](http://s3.amazonaws.com/proteindata/data_pytorch/fluorescence.tar.gz) __|__ [Stability](http://s3.amazonaws.com/proteindata/data_pytorch/stability.tar.gz)
[Pretraining Corpus (Pfam)](http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/pfam.tar.gz) __|__ [Secondary Structure](http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/secondary_structure.tar.gz) __|__ [Contact (ProteinNet)](http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/proteinnet.tar.gz) __|__ [Remote Homology](http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/remote_homology.tar.gz) __|__ [Fluorescence](http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/fluorescence.tar.gz) __|__ [Stability](http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/stability.tar.gz)

### Raw Data

Raw data files are stored in JSON format for maximum portability. This data is JSON-ified, which removes certain constructs (in particular numpy arrays). As a result they cannot be directly loaded into the provided pytorch datasets (although the conversion should be quite easy by simply adding calls to `np.array`).

[Pretraining Corpus (Pfam)](http://s3.amazonaws.com/proteindata/data_raw_pytorch/pfam.tar.gz) __|__ [Secondary Structure](http://s3.amazonaws.com/proteindata/data_raw_pytorch/secondary_structure.tar.gz) __|__ [Contact (ProteinNet)](http://s3.amazonaws.com/proteindata/data_raw_pytorch/proteinnet.tar.gz) __|__ [Remote Homology](http://s3.amazonaws.com/proteindata/data_raw_pytorch/remote_homology.tar.gz) __|__ [Fluorescence](http://s3.amazonaws.com/proteindata/data_raw_pytorch/fluorescence.tar.gz) __|__ [Stability](http://s3.amazonaws.com/proteindata/data_raw_pytorch/stability.tar.gz)
[Pretraining Corpus (Pfam)](http://s3.amazonaws.com/songlabdata/proteindata/data_raw_pytorch/pfam.tar.gz) __|__ [Secondary Structure](http://s3.amazonaws.com/songlabdata/proteindata/data_raw_pytorch/secondary_structure.tar.gz) __|__ [Contact (ProteinNet)](http://s3.amazonaws.com/songlabdata/proteindata/data_raw_pytorch/proteinnet.tar.gz) __|__ [Remote Homology](http://s3.amazonaws.com/songlabdata/proteindata/data_raw_pytorch/remote_homology.tar.gz) __|__ [Fluorescence](http://s3.amazonaws.com/songlabdata/proteindata/data_raw_pytorch/fluorescence.tar.gz) __|__ [Stability](http://s3.amazonaws.com/songlabdata/proteindata/data_raw_pytorch/stability.tar.gz)


## Leaderboard
16 changes: 8 additions & 8 deletions download_data.sh
@@ -4,25 +4,25 @@ mkdir -p ./data
while true; do
read -p "Do you wish to download and unzip the pretraining corpus? It is 7.7GB compressed and 19GB uncompressed? [y/n]" yn
case $yn in
[Yy]* ) wget http://s3.amazonaws.com/proteindata/data_pytorch/pfam.tar.gz; tar -xzf pfam.tar.gz -C ./data; rm pfam.tar.gz; break;;
[Yy]* ) wget http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/pfam.tar.gz; tar -xzf pfam.tar.gz -C ./data; rm pfam.tar.gz; break;;
[Nn]* ) exit;;
* ) echo "Please answer yes (Y/y) or no (N/n).";;
esac
done

# Download Vocab/Model files
wget http://s3.amazonaws.com/proteindata/data_pytorch/pfam.model
wget http://s3.amazonaws.com/proteindata/data_pytorch/pfam.vocab
wget http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/pfam.model
wget http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/pfam.vocab

mv pfam.model data
mv pfam.vocab data

# Download Data Files
wget http://s3.amazonaws.com/proteindata/data_pytorch/secondary_structure.tar.gz
wget http://s3.amazonaws.com/proteindata/data_pytorch/proteinnet.tar.gz
wget http://s3.amazonaws.com/proteindata/data_pytorch/remote_homology.tar.gz
wget http://s3.amazonaws.com/proteindata/data_pytorch/fluorescence.tar.gz
wget http://s3.amazonaws.com/proteindata/data_pytorch/stability.tar.gz
wget http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/secondary_structure.tar.gz
wget http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/proteinnet.tar.gz
wget http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/remote_homology.tar.gz
wget http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/fluorescence.tar.gz
wget http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/stability.tar.gz

tar -xzf secondary_structure.tar.gz -C ./data
tar -xzf proteinnet.tar.gz -C ./data
16 changes: 8 additions & 8 deletions download_data_aws.sh
@@ -4,30 +4,30 @@ mkdir -p ./data
while true; do
read -p "Do you wish to download and unzip the pretraining corpus? It is 7.7GB compressed and 19GB uncompressed. [y/n]" yn
case $yn in
[Yy]* ) aws s3 cp s3://proteindata/data_pytorch/pfam.tar.gz .; tar -xzf pfam.tar.gz -C ./data; rm pfam.tar.gz; break;;
[Yy]* ) aws s3 cp s3://songlabdata/proteindata/data_pytorch/pfam.tar.gz .; tar -xzf pfam.tar.gz -C ./data; rm pfam.tar.gz; break;;
[Nn]* ) exit;;
* ) echo "Please answer yes (Y/y) or no (N/n).";;
esac
done

echo "Downloading BPE Vocab/Model files"
aws s3 cp s3://proteindata/data_pytorch/pfam.model . && mv pfam.model data
aws s3 cp s3://proteindata/data_pytorch/pfam.vocab . && mv pfam.vocab data
aws s3 cp s3://songlabdata/proteindata/data_pytorch/pfam.model . && mv pfam.model data
aws s3 cp s3://songlabdata/proteindata/data_pytorch/pfam.vocab . && mv pfam.vocab data

# Download Data Files
echo "Download TAPE task datasets"
aws s3 cp s3://proteindata/data_pytorch/secondary_structure.tar.gz . \
aws s3 cp s3://songlabdata/proteindata/data_pytorch/secondary_structure.tar.gz . \
&& tar -xzf secondary_structure.tar.gz -C ./data \
&& rm secondary_structure.tar.gz
aws s3 cp s3://proteindata/data_pytorch/proteinnet.tar.gz . \
aws s3 cp s3://songlabdata/proteindata/data_pytorch/proteinnet.tar.gz . \
&& tar -xzf proteinnet.tar.gz -C ./data \
&& rm proteinnet.tar.gz
aws s3 cp s3://proteindata/data_pytorch/remote_homology.tar.gz . \
aws s3 cp s3://songlabdata/proteindata/data_pytorch/remote_homology.tar.gz . \
&& tar -xzf remote_homology.tar.gz -C ./data \
&& rm remote_homology.tar.gz
aws s3 cp s3://proteindata/data_pytorch/fluorescence.tar.gz . \
aws s3 cp s3://songlabdata/proteindata/data_pytorch/fluorescence.tar.gz . \
&& tar -xzf fluorescence.tar.gz -C ./data \
&& rm fluorescence.tar.gz
aws s3 cp s3://proteindata/data_pytorch/stability.tar.gz . \
aws s3 cp s3://songlabdata/proteindata/data_pytorch/stability.tar.gz . \
&& tar -xzf stability.tar.gz -C ./data \
&& rm stability.tar.gz
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,8 +1,8 @@
torch>=1.0,<1.5
tqdm
tensorboardX
scipy
lmdb
boto3
requests
biopython
filelock
2 changes: 1 addition & 1 deletion tape/__init__.py
@@ -9,7 +9,7 @@
import importlib
import pkgutil

__version__ = '0.4'
__version__ = '0.5'


# Import all the models and configs
2 changes: 1 addition & 1 deletion tape/main.py
@@ -117,7 +117,7 @@ def create_eval_parser(base_parser: argparse.ArgumentParser) -> argparse.Argumen

def create_embed_parser(base_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description='Embed a set of proteins wiht a pretrained model',
description='Embed a set of proteins with a pretrained model',
parents=[base_parser])
parser.add_argument('data_file', type=str,
help='File containing set of proteins to embed')
101 changes: 92 additions & 9 deletions tape/models/file_utils.py
@@ -4,7 +4,6 @@
https://github.com/huggingface/transformers, which in turn is adapted from the AllenNLP
library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
Note - this file goes to effort to support Python 2, but the rest of this repository does not.
"""
from __future__ import (absolute_import, division, print_function, unicode_literals)
@@ -14,18 +13,25 @@
import json
import logging
import os
import shutil
import tempfile
import fnmatch
from functools import wraps
from hashlib import sha256
from io import open

import boto3
import requests
from botocore.exceptions import ClientError
from tqdm import tqdm

from contextlib import contextmanager
from functools import partial, wraps
from hashlib import sha256

from filelock import FileLock
# from tqdm.auto import tqdm

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


try:
from torch.hub import _get_torch_home
torch_cache_home = _get_torch_home()
@@ -55,6 +61,30 @@
logger = logging.getLogger(__name__) # pylint: disable=invalid-name


def get_cache():
    return PROTEIN_MODELS_CACHE


def get_etag(url):
    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url)
    else:
        try:
            response = requests.head(url, allow_redirects=True)
            if response.status_code != 200:
                etag = None
            else:
                etag = response.headers.get("ETag")
        except EnvironmentError:
            etag = None

    if sys.version_info[0] == 2 and etag is not None:
        etag = etag.decode('utf-8')

    return etag


def url_to_filename(url, etag=None):
"""
Convert `url` into a hashed filename in a repeatable way.
@@ -99,12 +129,18 @@ def filename_to_url(filename, cache_dir=None):
return url, etag


def cached_path(url_or_filename, cache_dir=None):
def cached_path(url_or_filename, force_download=False, cache_dir=None):
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
return the path to the cached file. If it's already a local path,
make sure the file exists and then return the path.
Args:
cache_dir: specify a cache directory to save the file to
(overwrite the default cache dir).
force_download: if True, re-download the file even if it's
already cached in the cache dir.
"""
if cache_dir is None:
cache_dir = PROTEIN_MODELS_CACHE
@@ -117,10 +153,10 @@ def cached_path(url_or_filename, cache_dir=None):

if parsed.scheme in ('http', 'https', 's3'):
# URL, so get it from the cache (downloading if necessary)
return get_from_cache(url_or_filename, cache_dir)
output_path = get_from_cache(url_or_filename, cache_dir, force_download)
elif os.path.exists(url_or_filename):
# File, and it exists.
return url_or_filename
output_path = url_or_filename
elif parsed.scheme == '':
# File, but it doesn't exist.
raise EnvironmentError("file {} not found".format(url_or_filename))
@@ -129,6 +165,8 @@ def cached_path(url_or_filename, cache_dir=None):
raise ValueError("unable to parse {} as a URL or as a local path".format(
url_or_filename))

return output_path


def split_s3_path(url):
"""Split a full s3 path into the bucket name and path."""
@@ -191,7 +229,7 @@ def http_get(url, temp_file):
progress.close()


def get_from_cache(url, cache_dir=None):
def get_from_cache(url, cache_dir=None, force_download=False, resume_download=False):
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
@@ -226,6 +264,9 @@ def get_from_cache(url, cache_dir=None):
# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)

if os.path.exists(cache_path) and etag is None:
return cache_path

# If we don't have a connection (etag is None) and can't identify the file
# try to get the last downloaded one
if not os.path.exists(cache_path) and etag is None:
@@ -234,6 +275,48 @@ def get_from_cache(url, cache_dir=None):
if matching_files:
cache_path = os.path.join(cache_dir, matching_files[-1])

    # From now on, etag is not None
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        # If the download just completed while the lock was activated.
        if os.path.exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager():
                with open(incomplete_path, "a+b") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir,
                                        delete=False)
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            logger.info("%s not in cache or force_download=True, download to %s",
                        url, temp_file.name)

            http_get(url, temp_file)

        logger.info("storing %s in cache at %s", url, cache_path)
        os.replace(temp_file.name, cache_path)

        logger.info("creating metadata file for %s", cache_path)
        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)
'''
if not os.path.exists(cache_path):
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
@@ -266,5 +349,5 @@ def get_from_cache(url, cache_dir=None):
meta_file.write(output_string)
logger.info("removing temp file %s", temp_file.name)

'''
return cache_path
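
For orientation, here is a minimal sketch (not part of the diff) of how the reworked caching helpers can be exercised directly. It assumes network access and uses the bert-base weight URL defined in modeling_bert.py further down; the exact call pattern is an illustration, not documented API.

```python
from tape.models.file_utils import cached_path

# bert-base weight URL, as defined in modeling_bert.py below
url = ("https://s3.amazonaws.com/songlabdata/proteindata/pytorch-models/"
       "bert-base-pytorch_model.bin")

path = cached_path(url)                       # downloads (or reuses) the cached copy
path = cached_path(url, force_download=True)  # re-downloads even if already cached
print(path)
```

With the new FileLock-based logic, concurrent calls to `cached_path` for the same URL serialize on a `.lock` file, and downloads land in a temporary file before being moved into the cache, so an interrupted download cannot leave a corrupt cache entry.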
2 changes: 1 addition & 1 deletion tape/models/modeling_bert.py
@@ -39,7 +39,7 @@

logger = logging.getLogger(__name__)

URL_PREFIX = "https://s3.amazonaws.com/proteindata/pytorch-models/"
URL_PREFIX = "https://s3.amazonaws.com/songlabdata/proteindata/pytorch-models/"
BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base': URL_PREFIX + "bert-base-pytorch_model.bin",
}
2 changes: 1 addition & 1 deletion tape/models/modeling_lstm.py
@@ -15,7 +15,7 @@
logger = logging.getLogger(__name__)


URL_PREFIX = "https://s3.amazonaws.com/proteindata/pytorch-models/"
URL_PREFIX = "https://s3.amazonaws.com/songlabdata/proteindata/pytorch-models/"
LSTM_PRETRAINED_CONFIG_ARCHIVE_MAP: typing.Dict[str, str] = {}
LSTM_PRETRAINED_MODEL_ARCHIVE_MAP: typing.Dict[str, str] = {}

2 changes: 1 addition & 1 deletion tape/models/modeling_trrosetta.py
@@ -5,7 +5,7 @@
from .modeling_utils import ProteinConfig
from .modeling_utils import ProteinModel

URL_PREFIX = "https://s3.amazonaws.com/proteindata/pytorch-models/"
URL_PREFIX = "https://s3.amazonaws.com/songlabdata/proteindata/pytorch-models/"
TRROSETTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
'xaa': URL_PREFIX + "trRosetta-xaa-pytorch_model.bin",
'xab': URL_PREFIX + "trRosetta-xab-pytorch_model.bin",
2 changes: 1 addition & 1 deletion tape/models/modeling_unirep.py
@@ -15,7 +15,7 @@
logger = logging.getLogger(__name__)


URL_PREFIX = "https://s3.amazonaws.com/proteindata/pytorch-models/"
URL_PREFIX = "https://s3.amazonaws.com/songlabdata/proteindata/pytorch-models/"
UNIREP_PRETRAINED_CONFIG_ARCHIVE_MAP: typing.Dict[str, str] = {
'babbler-1900': URL_PREFIX + 'unirep-base-config.json'}
UNIREP_PRETRAINED_MODEL_ARCHIVE_MAP: typing.Dict[str, str] = {
20 changes: 17 additions & 3 deletions tape/models/modeling_utils.py
@@ -39,7 +39,7 @@


class ProteinConfig(object):
r""" Base class for all configuration classes.
""" Base class for all configuration classes.
Handles a few parameters common to all models' configurations as well as methods
for loading/downloading/saving configurations.
@@ -426,6 +426,14 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
Path to a directory in which a downloaded pre-trained model
configuration should be cached if the standard cache should not be used.
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override
the cached versions if they exists.
resume_download: (`optional`) boolean, default False:
Do not delete incompletely received file. Attempt to resume the download if
such a file exists.
output_loading_info: (`optional`) boolean:
Set to ``True`` to also return a dictionary containing missing keys,
unexpected keys and error messages.
@@ -435,7 +443,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
initiate the model. (e.g. ``output_attention=True``). Behave differently
depending on whether a `config` is provided or automatically loaded:
- If a configuration is provided with ``config``, ``**kwargs`` will be
- If a configuration is provided with ``config``, ``**kwarg
directly passed to the underlying model's ``__init__`` method (we assume
all relevant updates to the configuration have already been done)
- If a configuration is not provided, ``kwargs`` will be first passed to the
@@ -462,11 +470,16 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
cache_dir = kwargs.pop('cache_dir', None)
output_loading_info = kwargs.pop('output_loading_info', False)

force_download = kwargs.pop("force_download", False)
kwargs.pop("resume_download", False)

# Load config
if config is None:
config, model_kwargs = cls.config_class.from_pretrained(
pretrained_model_name_or_path, *model_args,
cache_dir=cache_dir, return_unused_kwargs=True,
# force_download=force_download,
# resume_download=resume_download,
**kwargs
)
else:
@@ -481,7 +494,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
archive_file = pretrained_model_name_or_path
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir,
force_download=force_download)
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
logger.error(
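
At the user-facing level, the net effect of the `modeling_utils.py` change is that `force_download` can be passed straight to `from_pretrained` and is forwarded to `cached_path`. A minimal usage sketch (mirroring the new test below, not additional documented behaviour):

```python
from tape import ProteinBertModel

# Re-fetch the bert-base weights even if a cached copy already exists.
model = ProteinBertModel.from_pretrained('bert-base', force_download=True)
```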
25 changes: 25 additions & 0 deletions tests/test_forceDownload.py
@@ -0,0 +1,25 @@
import time
import os
from tape.models.file_utils import url_to_filename, get_cache, get_etag
from tape import ProteinBertModel
from tape import TAPETokenizer
from tape.models.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
import torch


def test_forcedownload():
    model = ProteinBertModel.from_pretrained('bert-base')
    url = BERT_PRETRAINED_MODEL_ARCHIVE_MAP['bert-base']
    filename = url_to_filename(url, get_etag(url))
    wholepath = get_cache()/filename
    oldtime = time.ctime(os.path.getmtime(wholepath))
    model = ProteinBertModel.from_pretrained('bert-base', force_download=True)
    newtime = time.ctime(os.path.getmtime(wholepath))
    assert(newtime != oldtime)
    # Deploy model
    # iupac is the vocab for TAPE models, use unirep for the UniRep model
    tokenizer = TAPETokenizer(vocab='iupac')
    # Pfam Family: Hexapep, Clan: CL0536
    sequence = 'GCTVEDRCLIGMGAILLNGCVIGSGSLVAAGALITQ'
    token_ids = torch.tensor([tokenizer.encode(sequence)])
    model(token_ids)
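
Since the updated `pythonapp.yml` workflow now installs `torch` alongside `flake8` and `pytest`, this test should run under the existing CI test step; locally it can be invoked with `pytest tests/test_forceDownload.py` (note that it downloads the bert-base weights twice, so it requires network access).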