Skip to content

Commit 79540fc

Browse files
committed
Minor bugs fixed
1 parent 26d6d20 commit 79540fc

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

pertbio/pertbio/train.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,15 +98,18 @@ def append_record(filename, contents):
9898
f.write('\n')
9999

100100

101-
def eval_model(sess, eval_iter, obj_fn, eval_dict):
101+
def eval_model(sess, eval_iter, obj_fn, eval_dict, return_avg=True):
102102
sess.run(eval_iter.initializer, feed_dict=eval_dict)
103103
eval_results = []
104104
while True:
105105
try:
106106
eval_results.append(sess.run(obj_fn, feed_dict=eval_dict))
107107
except OutOfRangeError:
108108
break
109-
return np.mean(np.array(eval_results), axis=0)
109+
if return_avg:
110+
return np.mean(np.array(eval_results), axis=0)
111+
else:
112+
return np.vstack(eval_results)
110113

111114

112115
def train_model(model, args):
@@ -226,7 +229,7 @@ def screenshot(self, sess, model, substage_i, node_index, loss_min, args):
226229

227230
if self.export_verbose > 1 or self.export_verbose == -1: # no params but y_hat
228231
sess.run(model.iter_eval.initializer, feed_dict=model.args.feed_dicts['test_set'])
229-
y_hat = sess.run(model.eval_yhat, feed_dict=model.args.feed_dicts['test_set'])
232+
y_hat = eval_model(sess, model.iter_eval, model.eval_yhat, args.feed_dicts['test_set'], return_avg=False)
230233
y_hat = pd.DataFrame(y_hat, columns=node_index[0])
231234
self.update({'y_hat': y_hat})
232235

scripts/main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ def set_seed(in_seed):
2222
def prepare_workdir(in_cfg):
2323
# Read Data
2424
in_cfg.root_dir = os.getcwd()
25-
in_cfg.node_index = pd.read_csv(in_cfg.node_index_file, header=None, names=None)
25+
in_cfg.node_index = pd.read_csv(in_cfg.node_index_file, header=None, names=None) \
26+
if hasattr(in_cfg, 'node_index_file') else pd.DataFrame(np.arange(in_cfg.n_x))
27+
2628
in_cfg.loo = pd.read_csv("data/loo_label.csv", header=None)
2729

2830
# Create Output Folder

0 commit comments

Comments
 (0)