Skip to content

Commit ad47ad6

Browse files
committed
Fix the sizes
1 parent 8bd8f03 commit ad47ad6

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(self, inputs=None, outputs=None,
6565
Bidirectional(GRU(units=H))
6666
]))
6767

68-
char_embedding_layer.build(input_shape=(None, None, C))
68+
# char_embedding_layer.build(input_shape=(None, None, C))
6969

7070
P_char_embeddings = char_embedding_layer(P_str)
7171
Q_char_embeddings = char_embedding_layer(Q_str)

train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
args = parser.parse_args()
3636

3737
print('Creating the model...', end='')
38-
model = RNet(hdim=args.hdim, dropout_rate=args.dropout, N=None, M=None,
38+
model = RNet(hdim=args.hdim, dropout_rate=args.dropout, N=300, M=30,
3939
char_level_embeddings=args.char_level_embeddings)
4040
print('Done!')
4141

@@ -53,7 +53,7 @@
5353
print('Done!')
5454

5555
print('Preparing generators...', end='')
56-
maxlen = [200, 200, 30, 30] if args.char_level_embeddings else [300, 30]
56+
maxlen = [300, 300, 30, 30] if args.char_level_embeddings else [300, 30]
5757

5858
train_data_gen = BatchGen(*train_data, batch_size=args.batch_size, shuffle=False, group=True, maxlen=maxlen)
5959
valid_data_gen = BatchGen(*valid_data, batch_size=args.batch_size, shuffle=False, group=True, maxlen=maxlen)

0 commit comments

Comments
 (0)