Skip to content

Commit adb281f

Browse files
zdevitofacebook-github-bot
authored andcommitted
Change torch.jit.trace to no longer be a decorator (#11069)
Summary: This was done because it surprising for a decorator to run a function rather than wrap it, and not simplify the syntax for tracing modules. Pull Request resolved: pytorch/pytorch#11069 Reviewed By: jamesr66a Differential Revision: D9583192 Pulled By: zdevito fbshipit-source-id: b914b7ab4c73c255086465a6576eef3a22de1e13
1 parent 9055520 commit adb281f

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

pytorch_translate/ensemble_export.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ def __init__(
512512

513513
encoder_ens = EncoderEnsemble(self.models)
514514
example_encoder_outs = encoder_ens(src_tokens, src_lengths)
515-
self.encoder_ens = torch.jit.trace(src_tokens, src_lengths)(encoder_ens)
515+
self.encoder_ens = torch.jit.trace(encoder_ens, (src_tokens, src_lengths))
516516
decoder_ens = DecoderBatchedStepEnsemble(
517517
self.models,
518518
tgt_dict,
@@ -536,14 +536,17 @@ def __init__(
536536
prev_token, prev_scores, ts, *example_encoder_outs
537537
)
538538
self.decoder_ens_tile = torch.jit.trace(
539-
prev_token, prev_scores, ts, *example_encoder_outs
540-
)(decoder_ens_tile)
539+
decoder_ens_tile, (prev_token, prev_scores, ts, *example_encoder_outs)
540+
)
541541
self.decoder_ens = torch.jit.trace(
542-
prev_token.repeat(self.beam_size),
543-
prev_scores.repeat(self.beam_size),
544-
ts,
545-
*tiled_states,
546-
)(decoder_ens)
542+
decoder_ens,
543+
(
544+
prev_token.repeat(self.beam_size),
545+
prev_scores.repeat(self.beam_size),
546+
ts,
547+
*tiled_states,
548+
),
549+
)
547550

548551
self.input_names = [
549552
"src_tokens",
@@ -858,7 +861,7 @@ def __init__(self, model_list, tgt_dict, word_reward=0, unk_reward=0):
858861

859862
encoder_ens = EncoderEnsemble(self.models)
860863
example_encoder_outs = encoder_ens(source_tokens, source_length)
861-
self.encoder_ens = torch.jit.trace(source_tokens, source_length)(encoder_ens)
864+
self.encoder_ens = torch.jit.trace(encoder_ens, (source_tokens, source_length))
862865
decoder_ens = KnownOutputDecoderStepEnsemble(
863866
self.models, tgt_dict, word_reward, unk_reward
864867
)
@@ -867,8 +870,8 @@ def __init__(self, model_list, tgt_dict, word_reward=0, unk_reward=0):
867870
ts = torch.LongTensor([0])
868871
_, *states = decoder_ens(prev_token, target_token, ts, *example_encoder_outs)
869872
self.decoder_ens = torch.jit.trace(
870-
prev_token, target_token, ts, *example_encoder_outs
871-
)(decoder_ens)
873+
decoder_ens, (prev_token, target_token, ts, *example_encoder_outs)
874+
)
872875

873876
self.input_names = [
874877
"source_tokens",

0 commit comments

Comments
 (0)