Skip to content

Commit a1f1d20

Browse files
authored
Merge pull request #41 from sanderlab/tf2_refactor
TF2 Refactor
2 parents 1dfa9fb + 216afbd commit a1f1d20

File tree

12 files changed

+2400
-46
lines changed

12 files changed

+2400
-46
lines changed

binder/requirements.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,16 @@ importlib-metadata==1.7.0
1010
Keras-Applications==1.0.8
1111
Keras-Preprocessing==1.1.2
1212
Markdown==3.2.2
13-
numpy==1.16.0
13+
numpy==1.19.5
1414
opt-einsum==3.2.1
1515
pandas==0.24.2
1616
protobuf==3.12.2
1717
python-dateutil==2.8.1
1818
pytz==2020.1
1919
six==1.15.0
20-
tensorboard==1.15.0
21-
tensorflow==1.15.0
22-
tensorflow-estimator==1.15.1
20+
tensorboard==2.6.0
21+
tensorflow==2.6.2
22+
tensorflow-estimator==2.6.0
2323
termcolor==1.1.0
2424
Werkzeug==1.0.1
2525
wrapt==1.12.1

cellbox/cellbox/dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
import os
88
import numpy as np
99
import pandas as pd
10-
import tensorflow as tf
10+
import tensorflow.compat.v1 as tf
1111
from scipy import sparse
12+
tf.disable_v2_behavior()
1213

1314

1415
def factory(cfg):

cellbox/cellbox/kernel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
degree of ODEs, and the envelope forms
44
"""
55

6-
import tensorflow as tf
6+
import tensorflow.compat.v1 as tf
7+
tf.disable_v2_behavior()
78

89

910
def get_envelope(args):

cellbox/cellbox/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
"""
55

66
import numpy as np
7-
import tensorflow as tf
7+
import tensorflow.compat.v1 as tf
88
import cellbox.kernel
99
from cellbox.utils import loss, optimize
1010
# import tensorflow_probability as tfp
11+
tf.disable_v2_behavior()
1112

1213

1314
def factory(args):

cellbox/cellbox/train.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
import time
88
import numpy as np
99
import pandas as pd
10-
import tensorflow as tf
10+
import tensorflow.compat.v1 as tf
1111
from tensorflow.compat.v1.errors import OutOfRangeError
1212
import cellbox
1313
from cellbox.utils import TimeLogger
14+
tf.disable_v2_behavior()
1415

1516

1617
def train_substage(model, sess, lr_val, l1_lambda, l2_lambda, n_epoch, n_iter, n_iter_buffer, n_iter_patience, args):
@@ -59,7 +60,7 @@ def train_substage(model, sess, lr_val, l1_lambda, l2_lambda, n_epoch, n_iter, n
5960
while True:
6061
if idx_iter > n_iter or n_unchanged > n_iter_patience:
6162
break
62-
t0 = time.clock()
63+
t0 = time.perf_counter()
6364
try:
6465
_, loss_train_i, loss_train_mse_i = sess.run(
6566
(model.op_optimize, model.train_loss, model.train_mse_loss), feed_dict=args.feed_dicts['train_set'])
@@ -78,7 +79,7 @@ def train_substage(model, sess, lr_val, l1_lambda, l2_lambda, n_epoch, n_iter, n
7879
n_iter_patience))
7980
append_record("record_eval.csv",
8081
[idx_epoch, idx_iter, loss_train_i, loss_valid_i, loss_train_mse_i,
81-
loss_valid_mse_i, None, time.clock() - t0])
82+
loss_valid_mse_i, None, time.perf_counter() - t0])
8283
# early stopping
8384
idx_iter += 1
8485
if new_loss < best_params.loss_min:
@@ -89,18 +90,18 @@ def train_substage(model, sess, lr_val, l1_lambda, l2_lambda, n_epoch, n_iter, n
8990
n_unchanged += 1
9091

9192
# Evaluation on valid set
92-
t0 = time.clock()
93+
t0 = time.perf_counter()
9394
sess.run(model.iter_eval.initializer, feed_dict=args.feed_dicts['valid_set'])
9495
loss_valid_i, loss_valid_mse_i = eval_model(sess, model.iter_eval, (model.eval_loss, model.eval_mse_loss),
9596
args.feed_dicts['valid_set'], n_batches_eval=args.n_batches_eval)
96-
append_record("record_eval.csv", [-1, None, None, loss_valid_i, None, loss_valid_mse_i, None, time.clock() - t0])
97+
append_record("record_eval.csv", [-1, None, None, loss_valid_i, None, loss_valid_mse_i, None, time.perf_counter() - t0])
9798

9899
# Evaluation on test set
99-
t0 = time.clock()
100+
t0 = time.perf_counter()
100101
sess.run(model.iter_eval.initializer, feed_dict=args.feed_dicts['test_set'])
101102
loss_test_mse = eval_model(sess, model.iter_eval, model.eval_mse_loss,
102103
args.feed_dicts['test_set'], n_batches_eval=args.n_batches_eval)
103-
append_record("record_eval.csv", [-1, None, None, None, None, None, loss_test_mse, time.clock() - t0])
104+
append_record("record_eval.csv", [-1, None, None, None, None, None, loss_test_mse, time.perf_counter() - t0])
104105

105106
best_params.save()
106107
args.logger.log("------------------ Substage {} finished!-------------------".format(substage_i))

cellbox/cellbox/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66
import time
77
import hashlib
8-
import tensorflow as tf
8+
import tensorflow.compat.v1 as tf
99
import json
10-
10+
tf.disable_v2_behavior()
1111

1212
def loss(x_gold, x_hat, W, l1=0, l2=0, weight=1.):
1313
"""evaluate loss"""

cellbox/cellbox/version.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
This module defines the version of the package
33
"""
44

5-
__version__ = '0.3.1'
5+
__version__ = '0.3.2'
66
VERSION = __version__
77

88

@@ -104,14 +104,14 @@ def get_msg():
104104

105105
"""
106106
version 0.2.3
107-
-- June 8, 2020 --
107+
-- Jun 8, 2020 --
108108
* Add support to L2 loss (alone or together with L1, i.e. elastic net)
109109
* Clean the example configs folder
110110
""",
111111

112112
"""
113113
version 0.3.0
114-
-- June 8, 2020 --
114+
-- Jun 8, 2020 --
115115
Add support for alternative form of perturbation
116116
* Previous: add u on activity nodes
117117
* New: fix activity nodes directly
@@ -123,10 +123,16 @@ def get_msg():
123123

124124
"""
125125
version 0.3.1
126-
-- Sept 25, 2020 --
126+
-- Sep 25, 2020 --
127127
* Release version for publication
128128
* Add documentation
129129
* Rename package to 'cellbox'
130+
""",
131+
132+
"""
133+
version 0.3.2
134+
-- Feb 10, 2023 --
135+
* Modify CellBox to support TF2
130136
"""
131137
]
132138
print(
@@ -138,12 +144,12 @@ def get_msg():
138144
" | |___| __/ | | |_) | (_) > < \n"
139145
" \_____\___|_|_|____/ \___/_/\_\ \n"
140146
"Running CellBox scripts developed in Sander lab\n"
141-
"Maintained by Bo Yuan, Judy Shen, and Augustin Luna"
147+
"Maintained by Bo Yuan, Judy Shen, and Augustin Luna; contributions by Daniel Ritter"
142148
)
143149

144150
print(changelog[-1])
145151
print(
146-
"Tutorials and documentations are available at https://github.com/dfci/CellBox\n"
152+
"Tutorials and documentations are available at https://github.com/sanderlab/CellBox\n"
147153
"If you want to discuss the usage or to report a bug, please use the 'Issues' function at GitHub.\n"
148154
"If you find CellBox useful for your research, please consider citing the corresponding publication.\n"
149155
"For more information, please email us at [email protected] and [email protected], "

cellbox/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
url="https://github.com/dfci/CellBox",
1818
packages=['cellbox'],
1919
python_requires='>=3.6',
20-
install_requires=['tensorflow==1.15.0', 'numpy==1.16.0', 'pandas==0.24.2', 'scipy==1.3.0'],
20+
install_requires=['tensorflow==2.11.0', 'numpy==1.24.1', 'pandas==1.5.3', 'scipy==1.10.0'],
2121
tests_require=['pytest', 'pandas', 'numpy', 'scipy'],
2222
setup_requires=['pytest-runner', "pytest"],
2323
zip_safe=True,

0 commit comments

Comments
 (0)