Skip to content

Commit ba6f6ac

Browse files
committed
upgrading requirements to tensorflow2 and adding tf.compat.v1 code so the tf1 code still runs against the tf2 binary
1 parent 087b882 commit ba6f6ac

File tree

8 files changed

+17
-13
lines changed

8 files changed

+17
-13
lines changed

cellbox/cellbox/dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
import os
88
import numpy as np
99
import pandas as pd
10-
import tensorflow as tf
10+
import tensorflow.compat.v1 as tf
1111
from scipy import sparse
12+
tf.disable_v2_behavior()
1213

1314

1415
def factory(cfg):

cellbox/cellbox/kernel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
degree of ODEs, and the envelope forms
44
"""
55

6-
import tensorflow as tf
6+
import tensorflow.compat.v1 as tf
7+
tf.disable_v2_behavior()
78

89

910
def get_envelope(args):

cellbox/cellbox/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
"""
55

66
import numpy as np
7-
import tensorflow as tf
7+
import tensorflow.compat.v1 as tf
88
import cellbox.kernel
99
from cellbox.utils import loss, optimize
1010
# import tensorflow_probability as tfp
11+
tf.disable_v2_behavior()
1112

1213

1314
def factory(args):

cellbox/cellbox/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
import time
88
import numpy as np
99
import pandas as pd
10-
import tensorflow as tf
10+
import tensorflow.compat.v1 as tf
1111
from tensorflow.compat.v1.errors import OutOfRangeError
1212
import cellbox
1313
from cellbox.utils import TimeLogger
14+
tf.disable_v2_behavior()
1415

1516

1617
def train_substage(model, sess, lr_val, l1_lambda, l2_lambda, n_epoch, n_iter, n_iter_buffer, n_iter_patience, args):

cellbox/cellbox/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66
import time
77
import hashlib
8-
import tensorflow as tf
8+
import tensorflow.compat.v1 as tf
99
import json
10-
10+
tf.disable_v2_behavior()
1111

1212
def loss(x_gold, x_hat, W, l1=0, l2=0, weight=1.):
1313
"""evaluate loss"""

cellbox/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
url="https://github.com/dfci/CellBox",
1818
packages=['cellbox'],
1919
python_requires='>=3.6',
20-
install_requires=['tensorflow==1.15.0', 'numpy==1.16.0', 'pandas==0.24.2', 'scipy==1.3.0'],
20+
install_requires=['tensorflow==2.6.2', 'numpy==1.19.5', 'pandas==0.24.2', 'scipy==1.3.0'],
2121
tests_require=['pytest', 'pandas', 'numpy', 'scipy'],
2222
setup_requires=['pytest-runner', "pytest"],
2323
zip_safe=True,

requirements.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@ importlib-metadata==1.7.0
88
Keras-Applications==1.0.8
99
Keras-Preprocessing==1.1.2
1010
Markdown==3.2.2
11-
numpy==1.16.0
11+
numpy==1.19.5
1212
opt-einsum==3.2.1
1313
pandas==0.24.2
1414
protobuf==3.12.2
1515
python-dateutil==2.8.1
1616
pytz==2020.1
1717
six==1.15.0
18-
tensorboard==1.15.0
19-
tensorflow==1.15.0
20-
tensorflow-estimator==1.15.1
18+
tensorboard==2.6.0
19+
tensorflow==2.6.2
20+
tensorflow-estimator==2.6.0
2121
termcolor==1.1.0
2222
Werkzeug==1.0.1
2323
wrapt==1.12.1

scripts/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
import os
66
import numpy as np
77
import pandas as pd
8-
import tensorflow as tf
8+
import tensorflow.compat.v1 as tf
99
import shutil
1010
import argparse
1111
import json
12-
12+
tf.disable_v2_behavior()
1313
parser = argparse.ArgumentParser(description='CellBox main script')
1414
parser.add_argument('-config', '--experiment_config_path', required=True, type=str, help="Path of experiment config")
1515
parser.add_argument('-i', '--working_index', default=0, type=int)

0 commit comments

Comments
 (0)