Skip to content
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ Manifest.toml
benchmarks/benchmarks_output.json

.ipynb_checkpoints
*.ipynb
.devcontainer/*
2 changes: 2 additions & 0 deletions src/qobj/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,12 @@ Converts a sparse QuantumObject to a dense QuantumObject.
to_dense(A::QuantumObject) = QuantumObject(to_dense(A.data), A.type, A.dimensions)
to_dense(A::MT) where {MT<:AbstractSparseArray} = Array(A)
to_dense(A::MT) where {MT<:AbstractArray} = A
to_dense(A::Diagonal) = diagm(A.diag)

to_dense(::Type{T}, A::AbstractSparseArray) where {T<:Number} = Array{T}(A)
to_dense(::Type{T1}, A::AbstractArray{T2}) where {T1<:Number,T2<:Number} = Array{T1}(A)
to_dense(::Type{T}, A::AbstractArray{T}) where {T<:Number} = A
to_dense(::Type{T}, A::Diagonal{T}) where {T<:Number} = diagm(A.diag)

function to_dense(::Type{M}) where {M<:Union{Diagonal,SparseMatrixCSC}}
T = M
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ struct LindbladJump{
T2,
RNGType<:AbstractRNG,
RandT,
CT<:AbstractVector,
CT<:AbstractArray,
WT<:AbstractVector,
JTT<:AbstractVector,
JWT<:AbstractVector,
Expand Down
23 changes: 14 additions & 9 deletions src/time_evolution/mcsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ function _mcsolve_output_func(sol, i)
return (sol, false)
end

function _normalize_state!(u, dims, normalize_states)
function _normalize_state!(u, dims, normalize_states, type)
getVal(normalize_states) && normalize!(u)
return QuantumObject(u, Ket(), dims)
return QuantumObject(u, type(), dims)
end

function _mcsolve_make_Heff_QobjEvo(H::QuantumObject, c_ops)
Expand Down Expand Up @@ -110,15 +110,15 @@ If the environmental measurements register a quantum jump, the wave function und
"""
function mcsolveProblem(
H::Union{AbstractQuantumObject{Operator},Tuple},
ψ0::QuantumObject{Ket},
ψ0::QuantumObject{X},
tlist::AbstractVector,
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
params = NullParameters(),
rng::AbstractRNG = default_rng(),
jump_callback::TJC = ContinuousLindbladJumpCallback(),
kwargs...,
) where {TJC<:LindbladJumpCallbackType}
) where {TJC<:LindbladJumpCallbackType,X<:Union{Ket,Operator}}
haskey(kwargs, :save_idxs) &&
throw(ArgumentError("The keyword argument \"save_idxs\" is not supported in QuantumToolbox."))

Expand Down Expand Up @@ -221,7 +221,7 @@ If the environmental measurements register a quantum jump, the wave function und
"""
function mcsolveEnsembleProblem(
H::Union{AbstractQuantumObject{Operator},Tuple},
ψ0::QuantumObject{Ket},
ψ0::QuantumObject{X},
tlist::AbstractVector,
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
e_ops::Union{Nothing,AbstractVector,Tuple} = nothing,
Expand All @@ -234,7 +234,7 @@ function mcsolveEnsembleProblem(
prob_func::Union{Function,Nothing} = nothing,
output_func::Union{Tuple,Nothing} = nothing,
kwargs...,
) where {TJC<:LindbladJumpCallbackType}
) where {TJC<:LindbladJumpCallbackType,X<:Union{Ket,Operator}}
_prob_func = isnothing(prob_func) ? _ensemble_dispatch_prob_func(rng, ntraj, tlist, _mcsolve_prob_func) : prob_func
_output_func =
output_func isa Nothing ?
Expand All @@ -261,6 +261,7 @@ function mcsolveEnsembleProblem(
ensemble_prob = TimeEvolutionProblem(
EnsembleProblem(prob_mc.prob, prob_func = _prob_func, output_func = _output_func[1], safetycopy = false),
prob_mc.times,
X,
prob_mc.dimensions,
(progr = _output_func[2], channel = _output_func[3]),
)
Expand Down Expand Up @@ -358,7 +359,7 @@ If the environmental measurements register a quantum jump, the wave function und
"""
function mcsolve(
H::Union{AbstractQuantumObject{Operator},Tuple},
ψ0::QuantumObject{Ket},
ψ0::QuantumObject{X},
tlist::AbstractVector,
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
alg::AbstractODEAlgorithm = DP5(),
Expand All @@ -374,7 +375,7 @@ function mcsolve(
keep_runs_results::Union{Val,Bool} = Val(false),
normalize_states::Union{Val,Bool} = Val(true),
kwargs...,
) where {TJC<:LindbladJumpCallbackType}
) where {TJC<:LindbladJumpCallbackType} where {X<:Union{Ket,Operator}}
ens_prob_mc = mcsolveEnsembleProblem(
H,
ψ0,
Expand Down Expand Up @@ -415,7 +416,11 @@ function mcsolve(
expvals_all = _expvals_all isa Nothing ? nothing : stack(_expvals_all, dims = 2) # Stack on dimension 2 to align with QuTiP

# stack to transform Vector{Vector{QuantumObject}} -> Matrix{QuantumObject}
states_all = stack(map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol)), dims = 1)
# states_all = stack(map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states), eachindex(sol)), dims = 1)
states_all = stack(
map(i -> _normalize_state!.(sol[:, i].u, Ref(dims), normalize_states, ens_prob_mc.states_type), eachindex(sol)),
dims = 1,
)

col_times = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_times, eachindex(sol))
col_which = map(i -> _mc_get_jump_callback(sol[:, i]).affect!.col_which, eachindex(sol))
Expand Down
72 changes: 47 additions & 25 deletions src/time_evolution/mesolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@ _mesolve_make_L_QobjEvo(H::Union{QuantumObjectEvolution,Tuple}, c_ops) = liouvil
_mesolve_make_L_QobjEvo(H::Nothing, c_ops::Nothing) = throw(ArgumentError("Both H and
c_ops are Nothing. You are probably running the wrong function."))

function _gen_mesolve_solution(sol, times, dimensions, isoperket::Val)
if getVal(isoperket)
ρt = map(ϕ -> QuantumObject(ϕ, type = OperatorKet(), dims = dimensions), sol.u)
function _gen_mesolve_solution(sol, prob::TimeEvolutionProblem{X}) where {X<:Union{Operator,OperatorKet,SuperOperator}}
if X() == Operator()
ρt = map(ϕ -> QuantumObject(vec2mat(ϕ), type = X(), dims = prob.dimensions), sol.u)
else
ρt = map(ϕ -> QuantumObject(vec2mat(ϕ), type = Operator(), dims = dimensions), sol.u)
ρt = map(ϕ -> QuantumObject(ϕ, type = X(), dims = prob.dimensions), sol.u)
end

kwargs = NamedTuple(sol.prob.kwargs) # Convert to NamedTuple for Zygote.jl compatibility

return TimeEvolutionSol(
times,
prob.times,
sol.t,
ρt,
_get_expvals(sol, SaveFuncMESolve),
Expand Down Expand Up @@ -86,8 +86,8 @@ function mesolveProblem(
progress_bar::Union{Val,Bool} = Val(true),
inplace::Union{Val,Bool} = Val(true),
kwargs...,
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}}
(isoper(H) && isket(ψ0) && isnothing(c_ops)) && return sesolveProblem(
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet,SuperOperator}}
(isoper(H) && (isket(ψ0) || isoper(ψ0)) && isnothing(c_ops)) && return sesolveProblem(
H,
ψ0,
tlist;
Expand All @@ -107,11 +107,27 @@ function mesolveProblem(
check_dimensions(L_evo, ψ0)

T = Base.promote_eltype(L_evo, ψ0)
ρ0 = if isoperket(ψ0) # Convert it to dense vector with complex element type
to_dense(_complex_float_type(T), copy(ψ0.data))
# ρ0 = if isoperket(ψ0) # Convert it to dense vector with complex element type
# to_dense(_complex_float_type(T), copy(ψ0.data))
# else
# to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data))
# end
if isoper(ψ0)
ρ0 = to_dense(_complex_float_type(T), mat2vec(ψ0.data))
state_type = Operator()
elseif isoperket(ψ0)
ρ0 = to_dense(_complex_float_type(T), copy(ψ0.data))
state_type = OperatorKet()
elseif isket(ψ0)
ρ0 = to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data))
state_type = Operator()
elseif issuper(ψ0)
ρ0 = to_dense(_complex_float_type(T), copy(ψ0.data))
state_type = SuperOperator()
else
to_dense(_complex_float_type(T), mat2vec(ket2dm(ψ0).data))
throw(ArgumentError("Unsupported state type for ψ0 in mesolveProblem."))
end

L = cache_operator(L_evo.data, ρ0)

kwargs2 = _merge_saveat(tlist, e_ops, DEFAULT_ODE_SOLVER_OPTIONS; kwargs...)
Expand All @@ -122,7 +138,7 @@ function mesolveProblem(

prob = ODEProblem{getVal(inplace),FullSpecialize}(L, ρ0, tspan, params; kwargs4...)

return TimeEvolutionProblem(prob, tlist, L_evo.dimensions, (isoperket = Val(isoperket(ψ0)),))
return TimeEvolutionProblem(prob, tlist, state_type, L_evo.dimensions)#, (isoperket = Val(isoperket(ψ0)),))
end

@doc raw"""
Expand Down Expand Up @@ -188,8 +204,8 @@ function mesolve(
progress_bar::Union{Val,Bool} = Val(true),
inplace::Union{Val,Bool} = Val(true),
kwargs...,
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}}
(isoper(H) && isket(ψ0) && isnothing(c_ops)) && return sesolve(
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet,SuperOperator}}
(isoper(H) && (isket(ψ0) || isoper(ψ0)) && isnothing(c_ops)) && return sesolve(
H,
ψ0,
tlist;
Expand Down Expand Up @@ -230,7 +246,7 @@ end
function mesolve(prob::TimeEvolutionProblem, alg::AbstractODEAlgorithm = DP5(); kwargs...)
sol = solve(prob.prob, alg; kwargs...)

return _gen_mesolve_solution(sol, prob.times, prob.dimensions, prob.kwargs.isoperket)
return _gen_mesolve_solution(sol, prob)#, prob.kwargs.isoperket)
end

@doc raw"""
Expand Down Expand Up @@ -298,8 +314,8 @@ function mesolve_map(
params::Union{NullParameters,Tuple} = NullParameters(),
progress_bar::Union{Val,Bool} = Val(true),
kwargs...,
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}}
(isoper(H) && all(isket, ψ0) && isnothing(c_ops)) && return sesolve_map(
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet,SuperOperator}}
(isoper(H) && (all(isket, ψ0) || all(isoper, ψ0)) && isnothing(c_ops)) && return sesolve_map(
H,
ψ0,
tlist;
Expand All @@ -315,10 +331,16 @@ function mesolve_map(
# Convert to appropriate format based on state type
ψ0_iter = map(ψ0) do state
T = _complex_float_type(eltype(state))
if isoperket(state)
to_dense(T, copy(state.data))
if isoper(state)
to_dense(_complex_float_type(T), mat2vec(state.data))
elseif isoperket(state)
to_dense(_complex_float_type(T), copy(state.data))
elseif isket(state)
to_dense(_complex_float_type(T), mat2vec(ket2dm(state).data))
elseif issuper(state)
to_dense(_complex_float_type(T), copy(state.data))
else
to_dense(T, mat2vec(ket2dm(state).data))
throw(ArgumentError("Unsupported state type for ψ0 in mesolveProblem."))
end
end
if params isa NullParameters
Expand Down Expand Up @@ -347,7 +369,7 @@ mesolve_map(
tlist::AbstractVector,
c_ops::Union{Nothing,AbstractVector,Tuple} = nothing;
kwargs...,
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet}} =
) where {HOpType<:Union{Operator,SuperOperator},StateOpType<:Union{Ket,Operator,OperatorKet,SuperOperator}} =
mesolve_map(H, [ψ0], tlist, c_ops; kwargs...)

# this method is for advanced usage
Expand All @@ -357,14 +379,14 @@ mesolve_map(
#
# Return: An array of TimeEvolutionSol objects with the size same as the given iter.
function mesolve_map(
prob::TimeEvolutionProblem{<:ODEProblem},
prob::TimeEvolutionProblem{StateOpType,<:ODEProblem},
iter::AbstractArray,
alg::AbstractODEAlgorithm = DP5(),
ensemblealg::EnsembleAlgorithm = EnsembleThreads();
prob_func::Union{Function,Nothing} = nothing,
output_func::Union{Tuple,Nothing} = nothing,
progress_bar::Union{Val,Bool} = Val(true),
)
) where {StateOpType<:Union{Ket,Operator,OperatorKet,SuperOperator}}
# generate ensemble problem
ntraj = length(iter)
_prob_func = isnothing(prob_func) ? (prob, i, repeat) -> _se_me_map_prob_func(prob, i, repeat, iter) : prob_func
Expand All @@ -380,14 +402,14 @@ function mesolve_map(
ens_prob = TimeEvolutionProblem(
EnsembleProblem(prob.prob, prob_func = _prob_func, output_func = _output_func[1], safetycopy = false),
prob.times,
StateOpType(),
prob.dimensions,
(progr = _output_func[2], channel = _output_func[3], isoperket = prob.kwargs.isoperket),
(progr = _output_func[2], channel = _output_func[3]),
)

sol = _ensemble_dispatch_solve(ens_prob, alg, ensemblealg, ntraj)

# handle solution and make it become an Array of TimeEvolutionSol
sol_vec =
[_gen_mesolve_solution(sol[:, i], prob.times, prob.dimensions, prob.kwargs.isoperket) for i in eachindex(sol)] # map is type unstable
sol_vec = [_gen_mesolve_solution(sol[:, i], prob) for i in eachindex(sol)] # map is type unstable
return reshape(sol_vec, size(iter))
end
Loading