Skip to content

Commit d0b7f1a

Browse files
committed
[#18362][relax.frontend.torch] change to temporary solution
1 parent b28db36 commit d0b7f1a

File tree

4 files changed

+9
-8
lines changed

4 files changed

+9
-8
lines changed

3rdparty/dlpack

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit e2bdd3bee8cb6501558042633fa59144cc8b7f5f

3rdparty/vta-hw

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit 36a91576edf633479c78649e050f18dd2ddc8103

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,10 @@ def _randn(self, node: fx.Node) -> relax.Var:
7575

7676
shape = relax.ShapeExpr(size)
7777

78-
mean = relax.const(0.0, dtype)
79-
std = relax.const(1.0, dtype)
80-
81-
return self.block_builder.emit(relax.op.nn.normal(mean, std, shape, dtype))
78+
# TODO: This is a temporary solution that returns zeros instead of random values
79+
# since random initialization is mainly used during training, not inference.
80+
# This should be updated once Relax adds proper random number generation support.
81+
return self.block_builder.emit(relax.op.zeros(shape, dtype))
8282

8383
########## Neural Network ##########
8484

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6156,9 +6156,9 @@ def main(
61566156
with R.dataflow():
61576157
lv0 = R.zeros((5, 5), dtype="float32")
61586158

6159-
lv1 = R.random.normal(
6160-
R.const(0, "float32"), R.const(1, "float32"), R.shape([5]), dtype="float32"
6161-
)
6159+
# Use zeros instead of random normal distribution
6160+
lv1 = R.zeros((5,), dtype="float32")
6161+
61626162
lv2 = R.nn.elu(lv1)
61636163
lv3 = R.add(lv2, R.const(1.0, "float32"))
61646164
v = R.add(lv3, R.const(1e-8, "float32"))
@@ -6168,7 +6168,6 @@ def main(
61686168
)
61696169

61706170
L = R.tensor_update(lv0, (idx, idx), v)
6171-
61726171
y = R.add(x, R.const(1, "float32"))
61736172

61746173
gv = R.tuple(y, L)

0 commit comments

Comments
 (0)