diff --git a/pyro/compressible/interface.py b/pyro/compressible/interface.py index c6bca9717..df553661e 100644 --- a/pyro/compressible/interface.py +++ b/pyro/compressible/interface.py @@ -3,7 +3,7 @@ @njit(cache=True) -def states(idir, ng, dx, dt, +def states(idir, ng, dx, dloga, dt, irho, iu, iv, ip, ix, nspec, gamma, qv, dqv): r""" @@ -212,6 +212,27 @@ def states(idir, ng, dx, dt, q_l[i, j + 1, m] = q_l[i, j + 1, m] + sum_l q_r[i, j, m] = q_r[i, j, m] + sum_r + # Geometric Source term from converting conserved-variable to primitive + # It's only there for non Cartesian coord. + + if idir == 1: + rho_source = -0.5 * dt * dloga[i, j] * q[irho] * q[iu] + + q_l[i + 1, j, irho] += rho_source + q_r[i, j, irho] += rho_source + + q_l[i + 1, j, ip] += rho_source * cs * cs + q_r[i, j, ip] += rho_source * cs * cs + + else: + rho_source = -0.5 * dt * dloga[i, j] * q[irho] * q[iv] + + q_l[i, j + 1, irho] += rho_source + q_r[i, j, irho] += rho_source + + q_l[i, j + 1, ip] += rho_source * cs * cs + q_r[i, j, ip] += rho_source * cs * cs + return q_l, q_r diff --git a/pyro/compressible/unsplit_fluxes.py b/pyro/compressible/unsplit_fluxes.py index 637051edb..9ee7b4178 100644 --- a/pyro/compressible/unsplit_fluxes.py +++ b/pyro/compressible/unsplit_fluxes.py @@ -206,11 +206,13 @@ def interface_states(my_data, rp, ivars, tc, dt): tm_states = tc.timer("interfaceStates") tm_states.begin() - V_l, V_r = ifc.states(1, myg.ng, myg.Lx, dt, + _V_l, _V_r = ifc.states(1, myg.ng, myg.Lx, myg.dlogAx, dt, ivars.irho, ivars.iu, ivars.iv, ivars.ip, ivars.ix, ivars.naux, gamma, q, ldx) + V_l = ai.ArrayIndexer(d=_V_l, grid=myg) + V_r = ai.ArrayIndexer(d=_V_r, grid=myg) tm_states.end() @@ -225,7 +227,7 @@ def interface_states(my_data, rp, ivars, tc, dt): # left and right primitive variable states tm_states.begin() - _V_l, _V_r = ifc.states(2, myg.ng, myg.Ly, dt, + _V_l, _V_r = ifc.states(2, myg.ng, myg.Ly, myg.dlogAy, dt, ivars.irho, ivars.iu, ivars.iv, ivars.ip, ivars.ix, ivars.naux, gamma,