diff --git a/.gitignore b/.gitignore index 0c0eceb115..1fbdad6853 100644 --- a/.gitignore +++ b/.gitignore @@ -17,4 +17,6 @@ test/adata.arrow test/mdata.arrow *.csv *.arrow -tutorial.md \ No newline at end of file +tutorial.md +log +examples/rl/log \ No newline at end of file diff --git a/Project.toml b/Project.toml index 84b6693946..3c7a0e638d 100644 --- a/Project.toml +++ b/Project.toml @@ -30,25 +30,31 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" [weakdeps] Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45" +Crux = "e51cc422-768a-4345-bb8e-2246287ae729" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" OSMMakie = "76b6901f-8821-46bb-9129-841bc9cfe677" +POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" [extensions] AgentsArrow = "Arrow" AgentsGraphVisualizations = ["Makie", "GraphMakie"] AgentsOSMVisualizations = ["Makie", "OSMMakie"] AgentsVisualizations = "Makie" +AgentsRL = ["Crux", "POMDPs", "Flux"] [compat] Arrow = "2" CSV = "0.9.7, 0.10" CommonSolve = "0.2.4" +Crux = "0.1" DataFrames = "0.21, 0.22, 1" DataStructures = "0.18" Distributed = "1" Distributions = "0.25" Downloads = "1" +Flux = "0.14" GraphMakie = "0.5, 0.6" Graphs = "1.4" JLD2 = "0.4, 0.5" @@ -57,8 +63,9 @@ LightOSM = "0.2, 0.3" LightSumTypes = "5" LinearAlgebra = "1" MacroTools = "0.5" -Makie = "0.20, 0.21, 0.22, 0.24" +Makie = "0.20, 0.21, 0.22, 0.23" OSMMakie = "0.0, 0.1" +POMDPs = "0.9, 1" PrecompileTools = "1" ProgressMeter = "1.5" Random = "1" diff --git a/docs/Project.toml b/docs/Project.toml index 095a9a9b66..47ae9813d7 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -6,6 +6,7 @@ BlackBoxOptim = "a134a8b2-14d6-55f6-9291-3336d3ab0209" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" CellListMap = "69e1c6dd-3888-40e6-b3c8-31ac5f578864" ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +Crux = "e51cc422-768a-4345-bb8e-2246287ae729" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DelaunayTriangulation = "927a84f5-c5f4-47a5-9785-b46e178433df" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" @@ -15,6 +16,7 @@ DocumenterTools = "35a29f4d-8980-5a13-9543-d66fff28ecb8" DrWatson = "634d3b9d-ee7a-5ddf-bec9-22491ea816e1" LightSumTypes = "f56206fc-af4c-5561-a72a-43fe2ca5a923" FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a" GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" GraphRecipes = "bd48cda9-67a9-57be-86fa-5b3c104eda73" @@ -28,6 +30,7 @@ MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" OSMMakie = "76b6901f-8821-46bb-9129-841bc9cfe677" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622" diff --git a/docs/make.jl b/docs/make.jl index e03e57e417..57280fda49 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -15,6 +15,7 @@ pages = [ "examples/predator_prey.md", "examples/rabbit_fox_hawk.md", "examples/event_rock_paper_scissors.md", + "examples/rl_boltzmann.md", "examples.md" ], "api.md", diff --git a/docs/src/api.md b/docs/src/api.md index 82c5b5a02a..f6bb846038 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -8,6 +8,7 @@ In this page we list the remaining API functions, which constitute the bulk of A - [`AgentBasedModel`](@ref) 
- [`StandardABM`](@ref) - [`EventQueueABM`](@ref) +- [`ReinforcementLearningABM`](@ref) ```@docs AgentBasedModel @@ -28,6 +29,18 @@ AgentEvent add_event! ``` +### Reinforcement learning models + +```@docs +ReinforcementLearningABM +set_rl_config! +create_policy_network +create_value_network +train_model! +get_trained_policies +copy_trained_policies! +``` + ## Agent types ```@docs @@ -94,6 +107,7 @@ OpenStreetMapSpace ``` ## Adding agents + ```@docs add_agent! add_agent_own_pos! @@ -102,6 +116,7 @@ random_position ``` ## Moving agents + ```@docs move_agent! walk! @@ -110,6 +125,7 @@ get_direction ``` ### Movement with paths + For [`OpenStreetMapSpace`](@ref), and [`GridSpace`](@ref)/[`ContinuousSpace`](@ref) using [`Pathfinding`](@ref), a special movement method is available. @@ -121,6 +137,7 @@ is_stationary ``` ## Removing agents + ```@docs remove_agent! remove_all! @@ -128,12 +145,14 @@ sample! ``` ## Space utility functions + ```@docs normalize_position spacesize ``` ## [`DiscreteSpace` exclusives](@id DiscreteSpace_exclusives) + ```@docs positions npositions @@ -154,6 +173,7 @@ isempty(::Int, ::ABM) ``` ## `GraphSpace` exclusives + ```@docs add_edge! rem_edge! @@ -162,6 +182,7 @@ rem_vertex! ``` ## [`ContinuousSpace` exclusives](@id ContinuosSpace_exclusives) + ```@docs nearest_neighbor get_spatial_property @@ -173,6 +194,7 @@ manhattan_distance ``` ## `OpenStreetMapSpace` exclusives + ```@docs OSM OSM.lonlat @@ -189,6 +211,7 @@ OSM.download_osm_network ``` ## Nearby Agents + ```@docs nearby_ids nearby_agents @@ -204,6 +227,7 @@ Most iteration in Agents.jl is **dynamic** and **lazy**, when possible, for perf **Dynamic** means that when iterating over the result of e.g. the [`ids_in_position`](@ref) function, the iterator will be affected by actions that would alter its contents. Specifically, imagine the scenario + ```@example docs using Agents # We don't need to make a new agent type here, @@ -217,16 +241,20 @@ for id in ids_in_position((1, 1, 1, 1), model) end collect(allids(model)) ``` + You will notice that only 1 agent was removed. This is simply because the final state of the iteration of `ids_in_position` was reached unnaturally, because the length of its output was reduced by 1 _during_ iteration. To avoid problems like these, you need to `collect` the iterator to have a non dynamic version. **Lazy** means that when possible the outputs of the iteration are not collected and instead are generated on the fly. A good example to illustrate this is [`nearby_ids`](@ref), where doing something like + ```julia a = random_agent(model) sort!(nearby_ids(random_agent(model), model)) ``` + leads to error, since you cannot `sort!` the returned iterator. This can be easily solved by adding a `collect` in between: + ```@example docs a = random_agent(model) sort!(collect(nearby_agents(a, model))) @@ -247,13 +275,13 @@ index_mapped_groups ``` ## Data collection and analysis + ```@docs run! ensemblerun! paramscan ``` - ### Manual data collection The central simulation function is [`run!`](@ref). 
@@ -268,6 +296,7 @@ dataname ``` For example, the core loop of `run!` is just + ```julia df_agent = init_agent_dataframe(model, adata) df_model = init_model_dataframe(model, mdata) @@ -286,6 +315,7 @@ while until(t, t0, n, model) end return df_agent, df_model ``` + (here `until` and `should_we_collect` are internal functions) ## [Schedulers](@id Schedulers) @@ -310,15 +340,19 @@ Schedulers.ByKind ``` ### [Advanced scheduling](@id advanced_scheduling) + You can use [Function-like objects](https://docs.julialang.org/en/v1/manual/methods/#Function-like-objects) to make your scheduling possible of arbitrary events. For example, imagine that after the `n`-th step of your simulation you want to fundamentally change the order of agents. To achieve this you can define + ```julia mutable struct MyScheduler n::Int # step number w::Float64 end ``` + and then define a calling method for it like so + ```julia function (ms::MyScheduler)(model::ABM) ms.n += 1 # increment internal counter by 1 each time its called @@ -333,17 +367,20 @@ function (ms::MyScheduler)(model::ABM) end end ``` + and pass it to e.g. `step!` by initializing it + ```julia ms = MyScheduler(100, 0.5) step!(model, agentstep, modelstep, 100; scheduler = ms) ``` - ### How to use `Distributed` + To use the `parallel=true` option of [`ensemblerun!`](@ref) you need to load `Agents` and define your fundamental types at all processors. See the [Performance Tips](@ref) page for parallelization. ## Path-finding + ```@docs Pathfinding Pathfinding.AStar @@ -353,6 +390,7 @@ Pathfinding.random_walkable ``` ### Pathfinding Metrics + ```@docs Pathfinding.DirectDistance Pathfinding.MaxDistance @@ -363,8 +401,10 @@ Building a custom metric is straightforward, if the provided ones do not suit yo See the [Developer Docs](@ref) for details. ## Save, Load, Checkpoints + There may be scenarios where interacting with data in the form of files is necessary. The following functions provide an interface to save/load data to/from files. + ```@docs AgentsIO.save_checkpoint AgentsIO.load_checkpoint @@ -373,6 +413,7 @@ AgentsIO.dump_to_csv ``` It is also possible to write data to file at predefined intervals while running your model, instead of storing it in memory: + ```@docs offline_run! 
``` diff --git a/docs/src/assets/boltzmann_rl_final_state.png b/docs/src/assets/boltzmann_rl_final_state.png new file mode 100644 index 0000000000..677c4eaf51 Binary files /dev/null and b/docs/src/assets/boltzmann_rl_final_state.png differ diff --git a/docs/src/assets/boltzmann_rl_initial_state.png b/docs/src/assets/boltzmann_rl_initial_state.png new file mode 100644 index 0000000000..010ac41a90 Binary files /dev/null and b/docs/src/assets/boltzmann_rl_initial_state.png differ diff --git a/docs/src/assets/rl_boltzmann.mp4 b/docs/src/assets/rl_boltzmann.mp4 new file mode 100644 index 0000000000..26f17e2d28 Binary files /dev/null and b/docs/src/assets/rl_boltzmann.mp4 differ diff --git a/docs/src/assets/rl_boltzmann_learning_curve.png b/docs/src/assets/rl_boltzmann_learning_curve.png new file mode 100644 index 0000000000..9a523358c9 Binary files /dev/null and b/docs/src/assets/rl_boltzmann_learning_curve.png differ diff --git a/examples/rl_boltzmann.jl b/examples/rl_boltzmann.jl new file mode 100644 index 0000000000..ed9de76245 --- /dev/null +++ b/examples/rl_boltzmann.jl @@ -0,0 +1,361 @@ +# # Boltzmann Wealth Model with Reinforcement Learning + +# This example demonstrates how to integrate reinforcement learning (RL) with +# agent-based modeling using the Boltzmann wealth distribution model. In this model, +# agents move around a grid and exchange wealth when they encounter other agents, +# but their movement decisions are learned through reinforcement learning rather +# than being random. + +# The model showcases how RL agents can learn to optimize their behavior to achieve +# specific goals - in this case, reducing wealth inequality as measured by the +# Gini coefficient. + +# ## Model specification + +# The Boltzmann wealth model is a classic example in econophysics where agents +# represent economic actors who exchange wealth. The traditional model uses random +# movement, but here we replace that with learned behavior using reinforcement learning. + +# **Rules:** +# - Agents move on a 2D periodic grid +# - When agents occupy the same position, they may exchange wealth +# - Wealth flows from richer to poorer agents +# - Agent movement is learned through RL to minimize wealth inequality + +# **RL Integration:** +# - **Actions**: Stay, move North, South, East, or West (5 discrete actions) +# - **Observations**: Local neighborhood information and agent's relative wealth +# - **Reward**: Reduction in Gini coefficient (wealth inequality measure) +# - **Goal**: Learn movement patterns that promote wealth redistribution + +# ## Loading packages and defining the agent type + +using Agents, Random, Statistics, Distributions +using POMDPs, Crux, Flux + +@agent struct RLBoltzmannAgent(GridAgent{2}) + wealth::Int +end + +# ## Utility functions + +# First, we define the Gini coefficient calculation, which measures wealth inequality. +# A Gini coefficient of 0 represents perfect equality, while 1 represents maximum inequality. + +function gini(wealths::Vector{Int}) + n, sum_wi = length(wealths), sum(wealths) + (n <= 1 || sum_wi == 0.0) && return 0.0 + num = sum((2i - n - 1) * w for (i, w) in enumerate(sort(wealths))) + den = n * sum_wi + return num / den +end + +# ## Agent stepping function + +# The agent stepping function defines how agents behave in response to RL actions. +# Unlike traditional ABM where this might contain random movement, here the movement +# is determined by the RL policy based on the chosen action. 
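# For comparison, in the classic (non-RL) Boltzmann model the movement choice would be
# a uniform random draw over the same five offsets. A minimal sketch of such a random
# step (a hypothetical helper shown only for contrast, not used in this example):

function boltzmann_random_step!(agent::RLBoltzmannAgent, model) ## hypothetical, for comparison only
    dirs = ((0, 0), (0, 1), (0, -1), (1, 0), (-1, 0))
    walk!(agent, dirs[rand(abmrng(model), 1:5)], model; ifempty=false)
end

# In the RL version below, the learned policy supplies the action index in place of the
# random draw.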
+ +function boltzmann_rl_step!(agent::RLBoltzmannAgent, model, action::Int) + ## Action definitions: 1=stay, 2=north, 3=south, 4=east, 5=west + dirs = ((0, 0), (0, 1), (0, -1), (1, 0), (-1, 0)) + walk!(agent, dirs[action], model; ifempty=false) + + ## Wealth exchange mechanism + other = random_agent_in_position(agent.pos, model, a -> a.id != agent.id) + if !isnothing(other) + ## Transfer wealth from richer to poorer agent + if other.wealth > agent.wealth && other.wealth > 0 + agent.wealth += 1 + other.wealth -= 1 + end + end +end + + +# ## RL-specific functions + +# The following functions define how the RL environment interacts with the ABM: +# - **Observation function**: Extracts relevant state information for the RL agent +# - **Reward function**: Defines what behavior we want to encourage +# - **Terminal function**: Determines when an episode ends + +# ### Observation function + +# The observation function provides agents with local neighborhood information. +# This includes occupancy information and relative wealth of nearby agents. + +function global_to_local(neighbor_pos, center_pos, radius, grid_dims) # helper function + function transform_dim(neighbor_coord, center_coord, dim_size) + local_center = radius + 1 + delta = neighbor_coord - center_coord + delta > radius && return local_center + (delta - dim_size) + delta < -radius && return local_center + (delta + dim_size) + return local_center + delta + end + return ntuple(i -> transform_dim(neighbor_pos[i], center_pos[i], grid_dims[i]), length(grid_dims)) +end + +function get_local_observation_boltzmann(model::ABM, agent_id::Int) + target_agent = model[agent_id] + agent_pos = target_agent.pos + width, height = spacesize(model) + observation_radius = model.rl_config[][:observation_radius] + + grid_size = 2 * observation_radius + 1 + ## 2 channels: occupancy and relative wealth + neighborhood_grid = zeros(Float32, grid_size, grid_size, 2) + + for pos in nearby_positions(target_agent.pos, model, observation_radius) + k = 0 + for neighbor in agents_in_position(pos, model) + lpos = global_to_local(pos, target_agent.pos, observation_radius, spacesize(model)) + neighbor.id == agent_id && continue + neighborhood_grid[lpos..., 1] = 1.0 + wealth_diff = Float32(neighbor.wealth - target_agent.wealth) + wealth_sum = Float32(neighbor.wealth + target_agent.wealth) + if wealth_sum > 0 + k += 1 + neighborhood_grid[lpos..., 2] = wealth_diff / wealth_sum + end + k != 0 && (neighborhood_grid[lpos..., 2] /= k) + end + end + + total_wealth = sum(a.wealth for a in allagents(model)) + normalized_wealth = total_wealth > 0 ? Float32(target_agent.wealth / total_wealth) : 0.0f0 + normalized_pos = (Float32(agent_pos[1] / width), Float32(agent_pos[2] / height)) + + return ( + normalized_wealth=normalized_wealth, + normalized_pos=normalized_pos, + neighborhood_grid=neighborhood_grid + ) +end + +# Define observation function that returns vectors directly + +function boltzmann_get_observation(model::ABM, agent_id::Int) + observation_data = get_local_observation_boltzmann(model, agent_id) + flattened_grid = vec(observation_data.neighborhood_grid) + + ## Combine all normalized features into a single vector + return vcat( + Float32(observation_data.normalized_wealth), + Float32(observation_data.normalized_pos[1]), + Float32(observation_data.normalized_pos[2]), + flattened_grid + ) +end + +# ### Reward function + +# The reward function encourages agents to reduce wealth inequality by rewarding +# decreases in the Gini coefficient. 
This creates an incentive for agents to learn +# movement patterns that promote wealth redistribution. + +function boltzmann_calculate_reward(env, agent, action, initial_model, final_model) + initial_wealths = [a.wealth for a in allagents(initial_model)] + final_wealths = [a.wealth for a in allagents(final_model)] + + initial_gini = gini(initial_wealths) + final_gini = gini(final_wealths) + + ## Reward decrease in Gini coefficient + reward = (initial_gini - final_gini) * 100 + reward > 0 && (reward = reward / (abmtime(env) + 1)) + ## Small penalty for neutral actions + reward <= 0.0 && (reward = -0.1f0) + + return reward +end + +# ### Terminal condition + +# Define when an RL episode should end. Here, episodes terminate when wealth +# inequality (Gini coefficient) drops below a threshold, indicating success. + +function boltzmann_is_terminal_rl(env) + wealths = [a.wealth for a in allagents(env)] + current_gini = gini(wealths) + return current_gini < 0.1 +end + +# ## Model initialization + +# The following functions handle model creation and RL configuration setup. +# Define a separate function for model initialization +function create_fresh_boltzmann_model(num_agents, dims, initial_wealth, seed=rand(Int)) + rng = MersenneTwister(seed) + space = GridSpace(dims; periodic=true) + + properties = Dict{Symbol,Any}( + :gini_coefficient => 0.0, + :step_count => 0 + ) + + model = ReinforcementLearningABM(RLBoltzmannAgent, space; + agent_step=boltzmann_rl_step!, + properties=properties, rng=rng, scheduler=Schedulers.Randomly()) + + ## Add agents with random initial wealth + for _ in 1:num_agents + add_agent_single!(RLBoltzmannAgent, model, rand(rng, 1:initial_wealth)) + end + + ## Calculate initial Gini coefficient + wealths = [a.wealth for a in allagents(model)] + model.gini_coefficient = gini(wealths) + + return model +end + +function initialize_boltzmann_rl_model(; num_agents=10, dims=(10, 10), initial_wealth=10, observation_radius=4) + ## RL configuration specifies the learning environment parameters + rl_config = ( + model_init_fn=() -> create_fresh_boltzmann_model(num_agents, dims, initial_wealth), + observation_fn=boltzmann_get_observation, + reward_fn=boltzmann_calculate_reward, + terminal_fn=boltzmann_is_terminal_rl, + agent_step_fn=boltzmann_rl_step!, + action_spaces=Dict( + RLBoltzmannAgent => Crux.DiscreteSpace(5) ## 5 possible actions + ), + observation_spaces=Dict( + ## Observation space: (2*radius+1)² grid cells * 2 channels + 3 agent features + RLBoltzmannAgent => Crux.ContinuousSpace((((2 * observation_radius + 1)^2 * 2) + 3,), Float32) + ), + training_agent_types=[RLBoltzmannAgent], + max_steps=50, + observation_radius=observation_radius + ) + + ## Create the main model using the initialization function + model = create_fresh_boltzmann_model(num_agents, dims, initial_wealth) + + ## Set the RL configuration + set_rl_config!(model, rl_config) + + return model +end + +# ## Training the RL agents + +# Now we create and train our model. The agents will learn through trial and error +# which movement patterns best achieve the goal of reducing wealth inequality. 
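# Before training, it is worth sanity-checking that the observation vector produced by
# `boltzmann_get_observation` matches the observation space declared in the configuration.
# A quick check (assuming the default `observation_radius = 4`, i.e. a 9×9 grid with
# 2 channels plus 3 scalar features, 165 entries in total):

check_model = initialize_boltzmann_rl_model()
obs = boltzmann_get_observation(check_model, first(allids(check_model)))
length(obs) == (2 * 4 + 1)^2 * 2 + 3 ## should be `true`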
# Create and train the Boltzmann RL model
boltzmann_rl_model = initialize_boltzmann_rl_model()

# Train the Boltzmann agents
train_model!(
    boltzmann_rl_model, RLBoltzmannAgent;
    training_steps=200000,
    solver_params=Dict(
        :ΔN => 200, # Custom batch size for PPO updates
        :log => (period=1000,) # Log every 1000 steps
))

# Plot the learning curve to see how agents improved over training
plot_learning(boltzmann_rl_model.training_history[RLBoltzmannAgent])

# ## Running the trained model
# After training, we create a fresh model instance and apply the learned policies
# to see how well the agents perform.

# First, create a fresh model instance for simulation with the same parameters
fresh_boltzmann_model = initialize_boltzmann_rl_model()

# And copy the trained policies to the fresh model
copy_trained_policies!(fresh_boltzmann_model, boltzmann_rl_model)

# Let's visualize the initial state and run a simulation to see the trained behavior.
using CairoMakie, ColorSchemes

function agent_color(agent) # Custom color function based on wealth
    max_expected_wealth = 10
    clamped_wealth = clamp(agent.wealth, 0, max_expected_wealth)
    normalized_wealth = clamped_wealth / max_expected_wealth
    ## Color scheme: red (poor) to green (rich)
    return ColorSchemes.RdYlGn_4[normalized_wealth]
end
function agent_size(agent) # Custom size function based on wealth
    max_expected_wealth = 10
    clamped_wealth = clamp(agent.wealth, 0, max_expected_wealth)
    size_factor = clamped_wealth / max_expected_wealth
    return 10 + size_factor * 15
end

fig, ax = abmplot(fresh_boltzmann_model;
    agent_color=agent_color,
    agent_size=agent_size,
    agent_marker=:circle
)
ax.title = "Boltzmann Wealth Distribution (Initial State)"
ax.xlabel = "X Position"
ax.ylabel = "Y Position"
fig

# Run simulation with trained agents on the fresh model
initial_wealths = [a.wealth for a in allagents(fresh_boltzmann_model)]
initial_gini = gini(initial_wealths)
"Initial wealth distribution and Gini coefficient: $initial_wealths, $initial_gini"

# Step the model forward to see the trained behavior
Agents.step!(fresh_boltzmann_model, 10)

# Check the results after simulation
final_wealths = [a.wealth for a in allagents(fresh_boltzmann_model)]
final_gini = gini(final_wealths)
"Final wealth distribution and Gini coefficient: $final_wealths, $final_gini"

# Plot the final state
fig, ax = abmplot(fresh_boltzmann_model;
    agent_color=agent_color,
    agent_size=agent_size,
    agent_marker=:circle
)
ax.title = "Boltzmann Wealth Distribution (After 10 RL Steps)"
ax.xlabel = "X Position"
ax.ylabel = "Y Position"
fig

# Finally, let's create a video showing the trained agents in action over multiple steps
# on a bigger scale, and compare visually with a random policy.

# Random policy, because no trained policy has been copied yet
fresh_boltzmann_model = initialize_boltzmann_rl_model(; num_agents=500, dims=(100, 100))
plotkwargs = (;
    agent_color=agent_color,
    agent_size=agent_size,
    agent_marker=:circle
)
abmvideo("boltzmann.mp4", fresh_boltzmann_model; frames=100,
    framerate=20,
    title="Boltzmann Money Model with Random Agents",
    plotkwargs...)

# We now copy the trained policies and the agents are... smarter!
+fresh_boltzmann_model = initialize_boltzmann_rl_model(; num_agents=500, dims=(100, 100)) +copy_trained_policies!(fresh_boltzmann_model, boltzmann_rl_model) +abmvideo("rl_boltzmann.mp4", fresh_boltzmann_model; frames=100, + framerate=20, + title="Boltzmann Money Model with RL Agents", + plotkwargs...) + +# ## Key takeaways + +# This example demonstrates several important concepts: + +# 1. **RL-ABM Integration**: How to seamlessly integrate reinforcement learning +# with agent-based modeling using the `ReinforcementLearningABM` type. + +# 2. **Custom Reward Design**: The reward function encourages behavior that +# reduces wealth inequality, showing how RL can optimize for specific outcomes. + +# 3. **Observation Engineering**: Agents observe their local neighborhood and +# relative wealth position, providing them with relevant information for decision-making. + +# 4. **Policy Transfer**: Trained policies can be copied to fresh model instances, +# enabling evaluation and deployment of learned behaviors. + diff --git a/examples/rl_wolfsheep.jl b/examples/rl_wolfsheep.jl new file mode 100644 index 0000000000..db76b19c2b --- /dev/null +++ b/examples/rl_wolfsheep.jl @@ -0,0 +1,620 @@ +# # Predator-Prey Model with Reinforcement Learning + +# This example demonstrates how to integrate reinforcement learning (RL) with +# the classic predator-prey model. Building on the traditional Wolf-Sheep model, +# this version replaces random movement with learned behavior, where agents use +# reinforcement learning to optimize their survival and reproduction strategies. + +# The model showcases how RL agents can learn complex behaviors in multi-species +# ecosystems, with wolves learning to hunt efficiently and sheep learning to +# avoid predators while foraging for grass. + +# ## Model specification + +# This model extends the classic predator-prey dynamics with reinforcement learning: + +# **Environment:** +# - 2D periodic grid with grass that regrows over time +# - Wolves hunt sheep for energy +# - Sheep eat grass for energy +# - Both species can reproduce when they have sufficient energy + +# **RL Integration:** +# - **Actions**: Stay, move North, South, East, or West (5 discrete actions) +# - **Observations**: Local neighborhood information including other agents, grass, and own energy +# - **Rewards**: Survival, energy maintenance, successful feeding, and reproduction +# - **Goal**: Learn optimal movement and foraging/hunting strategies + +# **Key differences from traditional model:** +# - Movement decisions are learned rather than random +# - Agents can develop sophisticated strategies over time +# - Emergent behaviors arise from individual learning rather than hard-coded rules + +# ## Loading packages and defining agent types + +# ```julia +# using Agents, Random, Statistics, POMDPs, Crux, Flux, Distributions +# +# @agent struct RLSheep(GridAgent{2}) +# energy::Float64 +# reproduction_prob::Float64 +# Δenergy::Float64 +# end +# +# @agent struct RLWolf(GridAgent{2}) +# energy::Float64 +# reproduction_prob::Float64 +# Δenergy::Float64 +# end +# ``` + +# ## Agent stepping functions + +# The stepping functions define how agents behave in response to RL actions. +# Unlike the traditional model with random movement, here movement is determined +# by the RL policy based on the learned strategy. + +# ### Sheep stepping function + +# Sheep must balance energy conservation, grass foraging, and predator avoidance. 
+ +# ```julia +# # Wolf-sheep RL step functions +# function sheepwolf_step_rl!(sheep::RLSheep, model, action::Int) +# # Action definitions: 1=stay, 2=north, 3=south, 4=east, 5=west +# current_x, current_y = sheep.pos +# width, height = getfield(model, :space).extent +# +# dx, dy = 0, 0 +# if action == 2 # North +# dy = 1 +# elseif action == 3 # South +# dy = -1 +# elseif action == 4 # East +# dx = 1 +# elseif action == 5 # West +# dx = -1 +# end +# +# # Apply periodic boundary wrapping and move +# if action != 1 # If not staying +# new_x = mod1(current_x + dx, width) +# new_y = mod1(current_y + dy, height) +# target_pos = (new_x, new_y) +# move_agent!(sheep, target_pos, model) +# end +# +# # Energy decreases with each step (movement cost) +# sheep.energy -= 1 +# if sheep.energy < 0 +# remove_agent!(sheep, model) +# return +# end +# +# # Try to eat grass if available +# if model.fully_grown[sheep.pos...] +# sheep.energy += sheep.Δenergy +# model.fully_grown[sheep.pos...] = false +# model.countdown[sheep.pos...] = model.regrowth_time +# end +# +# # Reproduce if energy is sufficient +# if rand(abmrng(model)) ≤ sheep.reproduction_prob +# sheep.energy /= 2 +# replicate!(sheep, model) +# end +# end +# ``` + +# ### Wolf stepping function + +# Wolves must learn efficient hunting strategies while managing their energy reserves. + +# ```julia +# # WOLF Step +# function sheepwolf_step_rl!(wolf::RLWolf, model, action::Int) +# # Action definitions: 1=stay, 2=north, 3=south, 4=east, 5=west +# current_x, current_y = wolf.pos +# width, height = getfield(model, :space).extent +# +# dx, dy = 0, 0 +# if action == 2 # North +# dy = 1 +# elseif action == 3 # South +# dy = -1 +# elseif action == 4 # East +# dx = 1 +# elseif action == 5 # West +# dx = -1 +# end +# +# # Apply periodic boundary wrapping and move +# if action != 1 # If not staying +# new_x = mod1(current_x + dx, width) +# new_y = mod1(current_y + dy, height) +# move_agent!(wolf, (new_x, new_y), model) +# end +# +# # Energy decreases with each step +# wolf.energy -= 1 +# if wolf.energy < 0 +# remove_agent!(wolf, model) +# return +# end +# +# # Hunt sheep if available at current position +# sheep_ids = [id for id in ids_in_position(wolf.pos, model) if haskey(model.agents, id) && model[id] isa RLSheep] +# if !isempty(sheep_ids) +# dinner = model[sheep_ids[1]] +# remove_agent!(dinner, model) +# wolf.energy += wolf.Δenergy +# end +# +# # Reproduce if energy is sufficient +# if rand(abmrng(model)) ≤ wolf.reproduction_prob +# wolf.energy /= 2 +# replicate!(wolf, model) +# end +# end +# ``` + +# ### Grass dynamics and unified stepping + +# Grass regrows over time, providing a renewable resource for sheep. + +# ```julia +# function grass_step!(model) +# @inbounds for p in positions(model) +# if !(model.fully_grown[p...]) # If grass is not fully grown +# if model.countdown[p...] ≤ 0 +# model.fully_grown[p...] = true # Regrow grass +# else +# model.countdown[p...] 
-= 1 # Countdown to regrowth +# end +# end +# end +# end +# +# # Unified stepping function for both agent types +# function wolfsheep_rl_step!(agent::Union{RLSheep,RLWolf}, model, action::Int) +# if agent isa RLSheep +# sheepwolf_step_rl!(agent, model, action) +# elseif agent isa RLWolf +# sheepwolf_step_rl!(agent, model, action) +# end +# +# # Stochastic grass regrowth +# if rand(abmrng(model)) < 0.6 +# grass_step!(model) +# end +# end +# +# function agent_wolfsheep_rl_step!(agent::Union{RLSheep,RLWolf}, model, action::Int) +# if agent isa RLSheep +# sheepwolf_step_rl!(agent, model, action) +# elseif agent isa RLWolf +# sheepwolf_step_rl!(agent, model, action) +# end +# end +# ``` + +# ## RL-specific functions + +# The following functions define how the RL environment interacts with the ABM: +# - **Observation function**: Provides agents with local environmental information +# - **Reward function**: Shapes learning by rewarding desired behaviors +# - **Terminal function**: Determines when episodes end + +# ### Observation function + +# Agents observe their local neighborhood, including other agents, grass availability, +# and their own status. This information helps them make informed decisions. + +# ```julia +# # Wolf-sheep observation function +# function get_local_observation(model::ABM, agent_id::Int, observation_radius::Int) +# target_agent = model[agent_id] +# agent_pos = target_agent.pos +# width, height = getfield(model, :space).extent +# agent_type = target_agent isa RLSheep ? :sheep : :wolf +# +# grid_size = 2 * observation_radius + 1 +# # 3 channels: sheep, wolves, grass +# neighborhood_grid = zeros(Float32, grid_size, grid_size, 3, 1) +# +# # Get valid neighboring agents +# neighbor_ids = nearby_ids(target_agent, model, observation_radius) +# valid_neighbors = [] +# for id in neighbor_ids +# if haskey(model.agents, id) && id != agent_id +# push!(valid_neighbors, model[id]) +# end +# end +# +# # Map neighbors to observation grid +# for neighbor in valid_neighbors +# dx = neighbor.pos[1] - agent_pos[1] +# dy = neighbor.pos[2] - agent_pos[2] +# +# # Handle periodic boundaries +# if abs(dx) > width / 2 +# dx -= sign(dx) * width +# end +# if abs(dy) > height / 2 +# dy -= sign(dy) * height +# end +# +# grid_x = dx + observation_radius + 1 +# grid_y = dy + observation_radius + 1 +# +# if 1 <= grid_x <= grid_size && 1 <= grid_y <= grid_size +# if neighbor isa RLSheep +# neighborhood_grid[grid_x, grid_y, 1, 1] = 1.0 # Sheep channel +# elseif neighbor isa RLWolf +# neighborhood_grid[grid_x, grid_y, 2, 1] = 1.0 # Wolf channel +# end +# end +# end +# +# # Add grass information to observation +# for dx in -observation_radius:observation_radius +# for dy in -observation_radius:observation_radius +# pos_x = mod1(agent_pos[1] + dx, width) +# pos_y = mod1(agent_pos[2] + dy, height) +# +# grid_x = dx + observation_radius + 1 +# grid_y = dy + observation_radius + 1 +# +# if model.fully_grown[pos_x, pos_y] +# neighborhood_grid[grid_x, grid_y, 3, 1] = 1.0 # Grass channel +# end +# end +# end +# +# # Normalize agent's own information +# normalized_energy = Float32(target_agent.energy / 40.0) +# normalized_pos = (Float32(agent_pos[1] / width), Float32(agent_pos[2] / height)) +# +# return ( +# agent_id=agent_id, +# agent_type=agent_type, +# own_energy=normalized_energy, +# normalized_pos=normalized_pos, +# neighborhood_grid=neighborhood_grid +# ) +# end +# +# # Convert observation to vector format for neural networks +# function wolfsheep_get_observation(model, agent_id, observation_radius) +# 
observation_data = get_local_observation(model, agent_id, observation_radius) +# +# # Flatten spatial information +# flattened_grid = vec(observation_data.neighborhood_grid) +# +# # Combine all features into a single observation vector +# return vcat( +# Float32(observation_data.own_energy), +# Float32(observation_data.normalized_pos[1]), +# Float32(observation_data.normalized_pos[2]), +# Float32(observation_data.agent_type == :sheep ? 1.0 : 0.0), # Agent type indicator +# flattened_grid +# ) +# end +# ``` + +# ### Reward function + +# The reward function shapes agent learning by providing feedback on their actions. +# Different strategies are used for sheep (survival and foraging) vs wolves (hunting). + +# ```julia +# # Define reward function +# function wolfsheep_calculate_reward(env, agent, action, initial_model, final_model) +# # Death penalty - strongest negative reward +# if agent.id ∉ [a.id for a in allagents(final_model)] +# return -50.0 +# end +# +# if agent isa RLSheep +# # Sheep rewards: survival, energy maintenance, successful foraging +# reward = 1.0 # Base survival bonus +# +# # Energy level bonus (normalized) +# energy_ratio = agent.energy / 20.0 +# reward += energy_ratio * 0.5 +# +# # Bonus for successful foraging (energy increase) +# if haskey(initial_model.agents, agent.id) +# initial_energy = initial_model[agent.id].energy +# if agent.energy > initial_energy +# reward += 0.5 # Foraging success bonus +# end +# end +# +# return reward +# +# else # Wolf +# # Wolf rewards: survival, energy maintenance, successful hunting +# reward = 1.0 # Base survival bonus +# +# # Energy level bonus (wolves can have higher energy) +# energy_ratio = agent.energy / 40.0 +# reward += energy_ratio * 0.3 +# +# # Large bonus for successful hunting (significant energy increase) +# if haskey(initial_model.agents, agent.id) +# initial_energy = initial_model[agent.id].energy +# if agent.energy > initial_energy + 10 # Indicates successful hunt +# reward += 0.5 # Hunting success bonus +# end +# end +# +# return reward +# end +# end +# ``` + +# ### Terminal condition + +# Episodes end when either species goes extinct, creating natural stopping points +# for learning episodes while maintaining ecological realism. + +# ```julia +# # Define terminal condition for RL model +# function wolfsheep_is_terminal_rl(env) +# sheep_count = length([a for a in allagents(env) if a isa RLSheep]) +# wolf_count = length([a for a in allagents(env) if a isa RLWolf]) +# return sheep_count == 0 || wolf_count == 0 +# end +# ``` + +# ## Model initialization + +# The following functions handle model creation and RL configuration setup, +# similar to the traditional wolf-sheep model but with RL capabilities. 
+ +# ```julia +# function create_fresh_wolfsheep_model(n_sheeps, n_wolves, dims, regrowth_time, Δenergy_sheep, +# Δenergy_wolf, sheep_reproduce, wolf_reproduce, seed) +# +# rng = MersenneTwister(seed) +# space = GridSpace(dims, periodic=true) +# +# # Model properties for grass dynamics +# properties = Dict{Symbol,Any}( +# :fully_grown => falses(dims), +# :countdown => zeros(Int, dims), +# :regrowth_time => regrowth_time, +# ) +# +# # Create the ReinforcementLearningABM +# model = ReinforcementLearningABM(Union{RLSheep,RLWolf}, space; +# agent_step=agent_wolfsheep_rl_step!, model_step=grass_step!, +# properties=properties, rng=rng, +# scheduler=Schedulers.Randomly()) +# +# # Add sheep agents +# for _ in 1:n_sheeps +# energy = rand(abmrng(model), 1:(Δenergy_sheep*2)) - 1 +# add_agent!(RLSheep, model, energy, sheep_reproduce, Δenergy_sheep) +# end +# +# # Add wolf agents +# for _ in 1:n_wolves +# energy = rand(abmrng(model), 1:(Δenergy_wolf*2)) - 1 +# add_agent!(RLWolf, model, energy, wolf_reproduce, Δenergy_wolf) +# end +# +# # Initialize grass with random growth states +# for p in positions(model) +# fully_grown = rand(abmrng(model), Bool) +# countdown = fully_grown ? regrowth_time : rand(abmrng(model), 1:regrowth_time) - 1 +# model.countdown[p...] = countdown +# model.fully_grown[p...] = fully_grown +# end +# +# return model +# end +# +# # Initialize model function for RL ABM +# function initialize_rl_model(; n_sheeps=30, n_wolves=5, dims=(10, 10), regrowth_time=10, +# Δenergy_sheep=5, Δenergy_wolf=20, sheep_reproduce=0.2, wolf_reproduce=0.05, +# observation_radius=4, seed=1234) +# +# # RL configuration specifying learning environment parameters +# rl_config = ( +# model_init_fn=() -> create_fresh_wolfsheep_model(n_sheeps, n_wolves, dims, regrowth_time, +# Δenergy_sheep, Δenergy_wolf, sheep_reproduce, wolf_reproduce, seed), +# observation_fn=wolfsheep_get_observation, +# reward_fn=wolfsheep_calculate_reward, +# terminal_fn=wolfsheep_is_terminal_rl, +# agent_step_fn=wolfsheep_rl_step!, +# action_spaces=Dict( +# RLSheep => Crux.DiscreteSpace(5), # 5 movement actions +# RLWolf => Crux.DiscreteSpace(5) # 5 movement actions +# ), +# observation_spaces=Dict( +# RLSheep => Crux.ContinuousSpace((((2 * observation_radius + 1)^2 * 3) + 4,), Float32), +# RLWolf => Crux.ContinuousSpace((((2 * observation_radius + 1)^2 * 3) + 4,), Float32) +# ), +# training_agent_types=[RLSheep, RLWolf], +# max_steps=300, +# observation_radius=observation_radius, +# discount_rates=Dict( +# RLSheep => 0.99, # Long-term planning for survival +# RLWolf => 0.99 # Long-term planning for hunting +# ) +# ) +# +# # Create the model and set RL configuration +# model = create_fresh_wolfsheep_model(n_sheeps, n_wolves, dims, regrowth_time, Δenergy_sheep, +# Δenergy_wolf, sheep_reproduce, wolf_reproduce, seed) +# +# set_rl_config!(model, rl_config) +# +# return model +# end +# ``` + +# ## Training the RL agents + +# Now we create the model and train both sheep and wolves simultaneously. +# This creates a co-evolutionary dynamic where both species adapt to each other. 
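# The call below uses `training_mode=:simultaneous`; `train_model!` also supports
# `training_mode=:sequential`, and per-type networks can be supplied via the
# `custom_networks` keyword together with the `create_policy_network` and
# `create_value_network` helpers. A hedged sketch (the input size 247 assumes the
# default `observation_radius = 4`, i.e. (2*4+1)^2 * 3 + 4 features; agent types
# not listed fall back to the default 64x64 networks):

# ```julia
# custom_nets = Dict(
#     RLWolf => Dict(
#         :policy_network => create_policy_network((247,), 5, Crux.DiscreteSpace(5).vals, [128, 128]),
#         :value_network => create_value_network((247,), [128, 128]),
#     ),
# )
# # once `rl_model` has been created as shown below:
# train_model!(rl_model, [RLSheep, RLWolf];
#     training_mode=:simultaneous,
#     n_iterations=5,
#     custom_networks=custom_nets)
# ```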
+ +# ```julia +# # Create the model +# rl_model = initialize_rl_model(n_sheeps=50, n_wolves=10, dims=(20, 20), regrowth_time=30, +# Δenergy_sheep=4, Δenergy_wolf=20, sheep_reproduce=0.04, wolf_reproduce=0.05, seed=1234) +# +# println("Created ReinforcementLearningABM with $(nagents(rl_model)) agents") +# println("Sheep: $(length([a for a in allagents(rl_model) if a isa RLSheep]))") +# println("Wolves: $(length([a for a in allagents(rl_model) if a isa RLWolf]))") +# ``` + +# Train both species simultaneously +# ```julia +# println("\nTraining wolves and sheep with reinforcement learning...") +# try +# train_model!(rl_model, [RLSheep, RLWolf]; +# training_mode=:simultaneous, # Both species learn at the same time +# n_iterations=5, +# batch_size=400 * nagents(rl_model), +# solver_params=Dict( +# :ΔN => 100 * nagents(rl_model), +# :log => (period=100 * nagents(rl_model),), +# :max_steps => 200 * nagents(rl_model) +# )) +# println("Training completed successfully") +# catch e +# println("Training failed with error: $e") +# rethrow(e) +# end +# ``` + + +# ## Running the trained model + +# After training, we create a fresh model instance and apply the learned policies +# to observe how the trained agents behave in the predator-prey ecosystem. + +# ```julia +# # Create a fresh model instance for simulation +# println("\nCreating fresh Wolf-Sheep model for simulation...") +# fresh_ws_model = initialize_rl_model(n_sheeps=50, n_wolves=10, dims=(20, 20), regrowth_time=30, +# Δenergy_sheep=4, Δenergy_wolf=20, sheep_reproduce=0.04, wolf_reproduce=0.05, seed=1234) +# +# # Copy the trained policies to the fresh model +# copy_trained_policies!(fresh_ws_model, rl_model) +# println("Applied trained policies to fresh model") +# ``` + +# ## Visualization + +# Let's visualize the ecosystem and observe the learned behaviors. 
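# In addition to the spatial plots below, population counts over time can be collected
# with the standard Agents.jl data-collection machinery. A small sketch (assuming `run!`
# accepts a `ReinforcementLearningABM` just like `Agents.step!` does above; the
# `(f, count)` aggregation follows the usual predator-prey example):

# ```julia
# sheep(a) = a isa RLSheep
# wolf(a) = a isa RLWolf
# adata = [(sheep, count), (wolf, count)]
# adf, _ = run!(fresh_ws_model, 50; adata)
# adf # one row per step with the current number of sheep and wolves
# ```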
+ +# ```julia +# using CairoMakie, ColorSchemes +# CairoMakie.activate!() +# +# # Define colors and markers for different agent types +# function agent_color(agent) +# if agent isa RLSheep +# return :lightblue # Sheep are light blue +# elseif agent isa RLWolf +# return :red # Wolves are red +# else +# return :black # Fallback color +# end +# end +# +# function agent_marker(agent) +# if agent isa RLSheep +# return :circle # Sheep are circles +# elseif agent isa RLWolf +# return :rect # Wolves are squares +# else +# return :circle # Fallback marker +# end +# end +# +# # Plot the initial state +# fig, ax = abmplot(fresh_ws_model; +# agent_color=agent_color, +# agent_marker=agent_marker +# ) +# display(fig) +# ``` + +# Run simulation with trained agents +# ```julia +# println("\nRunning simulation with trained RL agents...") +# initial_sheep = length([a for a in allagents(fresh_ws_model) if a isa RLSheep]) +# initial_wolves = length([a for a in allagents(fresh_ws_model) if a isa RLWolf]) +# println("Initial populations - Sheep: $initial_sheep, Wolves: $initial_wolves") +# +# # Step the model forward to observe trained behavior +# try +# Agents.step!(fresh_ws_model, 200) +# println("Simulation completed successfully") +# catch e +# println("Simulation failed with error: $e") +# rethrow(e) +# end +# +# # Check final population numbers +# final_sheep = length([a for a in allagents(fresh_ws_model) if a isa RLSheep]) +# final_wolves = length([a for a in allagents(fresh_ws_model) if a isa RLWolf]) +# +# println("Population changes after 200 steps:") +# println("Sheep: $initial_sheep → $final_sheep") +# println("Wolves: $initial_wolves → $final_wolves") +# +# # Analyze the results +# if final_sheep > 0 && final_wolves > 0 +# println("Success! Both species coexist - predator-prey balance maintained") +# elseif final_sheep == 0 +# println("Wolves were too successful - sheep went extinct") +# elseif final_wolves == 0 +# println("Sheep outlasted wolves - predators died out") +# end +# ``` + +# ## Creating an animation + +# Create a video showing the trained ecosystem dynamics over time. +# ```julia +# fresh_ws_model = initialize_rl_model(n_sheeps=50, n_wolves=10, dims=(20, 20), regrowth_time=30, +# Δenergy_sheep=4, Δenergy_wolf=20, sheep_reproduce=0.04, wolf_reproduce=0.05, seed=1234) +# +# # Copy the trained policies to the fresh model +# copy_trained_policies!(fresh_ws_model, rl_model) +# +# plotkwargs = ( +# agent_color=agent_color, +# agent_marker=agent_marker, +# ) +# abmvideo("wolfsheep_rl.mp4", fresh_ws_model; frames=100, +# framerate=2, +# title="Wolf-Sheep Model with RL - Blue=Sheep, Red=Wolves", +# plotkwargs...) +# ``` + + +# ## Key takeaways + +# This example demonstrates several important concepts: + +# 1. **Multi-agent RL**: Both predator and prey species learn simultaneously, +# creating co-evolutionary dynamics where each species adapts to the other. + +# 2. **Complex reward structures**: Different reward functions for different agent types +# (survival for sheep, hunting for wolves) lead to emergent ecological behaviors. + +# 3. **Spatial awareness**: Agents learn to use local environmental information +# (locations of prey/predators, grass availability) to make strategic decisions. + +# 4. **Emergent strategies**: Trained agents may develop sophisticated behaviors like +# flocking (sheep), pursuit strategies (wolves), or territorial behaviors. + +# 5. 
**Ecosystem dynamics**: The learned behaviors can lead to more realistic +# predator-prey cycles compared to purely random movement models. diff --git a/ext/AgentsRL/AgentsRL.jl b/ext/AgentsRL/AgentsRL.jl new file mode 100644 index 0000000000..d8e9804578 --- /dev/null +++ b/ext/AgentsRL/AgentsRL.jl @@ -0,0 +1,10 @@ +module AgentsRL + +using Agents, Crux, POMDPs, Flux, Distributions, Random + +# Import reinforcement learning functions from the extension +include("src/rl_utils.jl") +include("src/rl_training_functions.jl") +include("src/step_reinforcement_learning.jl") + +end \ No newline at end of file diff --git a/ext/AgentsRL/src/rl_training_functions.jl b/ext/AgentsRL/src/rl_training_functions.jl new file mode 100644 index 0000000000..ff6da7b19d --- /dev/null +++ b/ext/AgentsRL/src/rl_training_functions.jl @@ -0,0 +1,309 @@ +function Agents.setup_rl_training(model::ReinforcementLearningABM, agent_type; + training_steps=50_000, + max_steps=nothing, + value_network=nothing, + policy_network=nothing, + solver=nothing, + solver_type=:PPO, + solver_params=Dict() +) + if isnothing(model.rl_config[]) + error("RL configuration not set. Use set_rl_config! first.") + end + + # Set the current training agent type in the model + model.current_training_agent_type[] = agent_type + + # Wrap the model for POMDPs compatibility + env = wrap_for_rl_training(model) + + # If a complete solver is provided, use it directly + if !isnothing(solver) + return env, solver + end + + # Get observation and action spaces + O = POMDPs.observations(env) + as = POMDPs.actions(env).vals + + # Define neural network architecture + if isnothing(value_network) + value_net = ContinuousNetwork(Chain(Dense(Crux.dim(O)..., 64, relu), Dense(64, 64, relu), Dense(64, 1))) + else + value_net = value_network() + end + + if isnothing(policy_network) + policy_net = DiscreteNetwork(Chain(Dense(Crux.dim(O)..., 64, relu), Dense(64, 64, relu), Dense(64, length(as))), as) + else + policy_net = policy_network() + end + + if isnothing(max_steps) + max_steps = model.rl_config[][:max_steps] + end + + # Create solver based on type + if solver_type == :PPO + default_params = Dict( + :π => ActorCritic(policy_net, value_net), + :S => O, + :N => training_steps, + :ΔN => 200, + :max_steps => max_steps, + :log => (period=1000,) + ) + merged_params = merge(default_params, solver_params) + solver = PPO(; merged_params...) + elseif solver_type == :DQN + if isnothing(policy_network) + QS_net = DiscreteNetwork( + Chain( + Dense(Crux.dim(O)[1], 64, relu), + Dense(64, 64, relu), + Dense(64, length(as)) + ), as + ) + else + QS_net = policy_network() + end + default_params = Dict( + :π => QS_net, + :S => O, + :N => training_steps, + :max_steps => max_steps, + :buffer_size => 10000, + :buffer_init => 1000, + :ΔN => 50 + ) + merged_params = merge(default_params, solver_params) + solver = DQN(; merged_params...) + elseif solver_type == :A2C + default_params = Dict( + :π => ActorCritic(policy_net, value_net), + :S => O, + :N => training_steps, + :ΔN => 20, + :max_steps => max_steps, + :log => (period=1000,) + ) + merged_params = merge(default_params, solver_params) + solver = A2C(; merged_params...) 
+ else + error("Unsupported solver type: $solver_type.") + end + + return env, solver +end + +function Agents.train_agent_sequential(model::ReinforcementLearningABM, agent_types; + training_steps=50_000, + custom_networks=Dict(), + custom_solvers=Dict(), + solver_types=Dict(), + solver_params=Dict() +) + println("Training agents sequentially...") + + # Ensure agent_types is a vector + agent_types_vec = agent_types isa Vector ? agent_types : [agent_types] + + policies = Dict{Type,Any}() + solvers = Dict{Type,Any}() + + for (i, agent_type) in enumerate(agent_types_vec) + println("Training $(agent_type) ($(i)/$(length(agent_types_vec)))...") + + # Get custom parameters for this agent type + agent_networks = get(custom_networks, agent_type, Dict()) + value_net = get(agent_networks, :value_network, nothing) + policy_net = get(agent_networks, :policy_network, nothing) + custom_solver = get(custom_solvers, agent_type, nothing) + solver_type = get(solver_types, agent_type, :PPO) + solver_params_agent = Agents.process_solver_params(solver_params, agent_type) + + # Set up training + env, solver = Agents.setup_rl_training( + model, + agent_type; + training_steps=training_steps, + value_network=value_net, + policy_network=policy_net, + solver=custom_solver, + solver_type=solver_type, + solver_params=solver_params_agent + ) + + # Add previously trained policies to the model + for (prev_type, policy) in policies + model.trained_policies[prev_type] = policy + end + + # Train the agent + policy = solve(solver, env) + policies[agent_type] = policy + solvers[agent_type] = solver + + println("Completed training $(agent_type)") + end + + return policies, solvers +end + +function Agents.train_agent_simultaneous(model::ReinforcementLearningABM, agent_types; + n_iterations=5, + batch_size=10_000, + custom_networks=Dict(), + custom_solvers=Dict(), + solver_types=Dict(), + solver_params=Dict() +) + println("Training agents simultaneously...") + + # Ensure agent_types is a vector + agent_types_vec = agent_types isa Vector ? 
agent_types : [agent_types] + + # Initialize solvers for each agent type + solvers = Dict{Type,Any}() + envs = Dict{Type,Any}() + + for agent_type in agent_types_vec + println("Setting up solver for $(agent_type)...") + + # Get custom parameters for this agent type + agent_networks = get(custom_networks, agent_type, Dict()) + value_net = get(agent_networks, :value_network, nothing) + policy_net = get(agent_networks, :policy_network, nothing) + custom_solver = get(custom_solvers, agent_type, nothing) + solver_type = get(solver_types, agent_type, :PPO) + solver_params_agent = Agents.process_solver_params(solver_params, agent_type) + + env, solver = Agents.setup_rl_training( + model, + agent_type; + training_steps=batch_size, + value_network=value_net, + policy_network=policy_net, + solver=custom_solver, + solver_type=solver_type, + solver_params=solver_params_agent + ) + + envs[agent_type] = env + solvers[agent_type] = solver + end + + policies = Dict{Type,Any}() + + # Train in alternating batches + for iter in 1:n_iterations + println("Iteration $(iter)/$(n_iterations)") + + for agent_type in agent_types_vec + println(" Training $(agent_type)...") + + # Update model with current policies + for (other_type, policy) in policies + if other_type != agent_type + model.trained_policies[other_type] = policy + end + end + + # Train the agent + policy = solve(solvers[agent_type], envs[agent_type]) + policies[agent_type] = policy + end + end + + return policies, solvers +end + +function Agents.create_value_network(input_dims, hidden_layers=[64, 64], activation=relu) + layers = [] + + # Input layer + push!(layers, Dense(input_dims..., hidden_layers[1], activation)) + + # Hidden layers + for i in 1:(length(hidden_layers)-1) + push!(layers, Dense(hidden_layers[i], hidden_layers[i+1], activation)) + end + + # Output layer + push!(layers, Dense(hidden_layers[end], 1)) + + return () -> ContinuousNetwork(Chain(layers...)) +end + +function Agents.create_policy_network(input_dims, output_dims, action_space, hidden_layers=[64, 64], activation=relu) + layers = [] + + # Input layer + push!(layers, Dense(input_dims..., hidden_layers[1], activation)) + + # Hidden layers + for i in 1:(length(hidden_layers)-1) + push!(layers, Dense(hidden_layers[i], hidden_layers[i+1], activation)) + end + + # Output layer + push!(layers, Dense(hidden_layers[end], output_dims)) + + return () -> DiscreteNetwork(Chain(layers...), action_space) +end + +function Agents.create_custom_solver(solver_type, π, S; custom_params...) + if solver_type == :PPO + return PPO(π=π, S=S; custom_params...) + elseif solver_type == :DQN + return DQN(π=π, S=S; custom_params...) + elseif solver_type == :A2C + return A2C(π=π, S=S; custom_params...) + else + error("Unsupported solver type: $solver_type") + end +end + + +function Agents.train_model!(model::ReinforcementLearningABM, agent_types; + training_mode::Symbol=:sequential, + kwargs...) + + if isnothing(model.rl_config[]) + error("RL configuration not set. Use set_rl_config! first.") + end + + # Ensure agent_types is a vector + agent_types_vec = agent_types isa Vector ? agent_types : [agent_types] + + # Set training flag + model.is_training[] = true + + try + # Train agents based on mode + if training_mode == :sequential + policies, solvers = Agents.train_agent_sequential(model, agent_types_vec; kwargs...) + elseif training_mode == :simultaneous + policies, solvers = Agents.train_agent_simultaneous(model, agent_types_vec; kwargs...) + else + error("Unknown training mode: $training_mode. 
Use :sequential or :simultaneous.") + end + + # Store trained policies + for (agent_type, policy) in policies + model.trained_policies[agent_type] = policy + end + + # Store training history (solvers) + for (agent_type, solver) in solvers + model.training_history[agent_type] = solver + end + + println("Training completed for agent types: $(join(string.(agent_types_vec), ", "))") + + finally + model.is_training[] = false + end + + return model +end diff --git a/ext/AgentsRL/src/rl_utils.jl b/ext/AgentsRL/src/rl_utils.jl new file mode 100644 index 0000000000..88c962149a --- /dev/null +++ b/ext/AgentsRL/src/rl_utils.jl @@ -0,0 +1,522 @@ +""" + RLEnvironmentWrapper{M} <: POMDPs.POMDP{Vector{Float32}, Int, Vector{Float32}} + +A wrapper around `ReinforcementLearningABM` that implements the POMDPs.POMDP interface +to enable training with RL algorithms that require POMDPs compatibility. + +This wrapper serves as a bridge between Agent-Based Models and Reinforcement Learning +algorithms, translating between ABM concepts and RL concepts: + +- **States**: ABM state → Vector{Float32} representations +- **Actions**: Discrete integer actions → Agent behaviors +- **Observations**: Agent-centric views → Vector{Float32} feature vectors +- **Rewards**: Simulation outcomes → Scalar reward signals + +## Type Parameters +- `M <: ReinforcementLearningABM`: The type of the wrapped ABM + +## Fields +- `model::M`: The wrapped ReinforcementLearningABM instance + +## POMDPs Interface +The wrapper implements the complete POMDPs interface including: +- `actions(env)`: Get available actions +- `observations(env)`: Get observation space +- `observation(env, state)`: Generate observations +- `gen(env, state, action, rng)`: State transitions and rewards +- `initialstate(env)`: Episode initialization +- `isterminal(env, state)`: Termination conditions +- `discount(env)`: Discount factor + +## Example +```julia +# Create wrapper +env = wrap_for_rl_training(model) + +# Use with RL algorithms +solver = PPO(π=policy, S=observations(env), N=10000) +policy = solve(solver, env) +``` +""" +struct RLEnvironmentWrapper{M<:ReinforcementLearningABM} <: POMDPs.POMDP{Vector{Float32},Int,Vector{Float32}} + model::M +end + + +""" + wrap_for_rl_training(model::ReinforcementLearningABM) → RLEnvironmentWrapper + +Wrap a ReinforcementLearningABM in an RLEnvironmentWrapper to make it compatible +with POMDPs-based RL training algorithms. + +## Arguments +- `model::ReinforcementLearningABM`: The ReinforcementLearningABM to wrap + +## Returns +- `RLEnvironmentWrapper`: A wrapper that implements the POMDPs.POMDP interface + +## Notes +This wrapper enables the use of standard RL algorithms (PPO, DQN, A2C) with ABMs by: +- Translating ABM states to RL observations +- Mapping RL actions to agent behaviors +- Computing rewards based on simulation outcomes +- Managing episode termination conditions + +The wrapper automatically handles agent cycling, multi-agent coordination, and +integrates with the configured observation, reward, and terminal functions. +``` +""" +function wrap_for_rl_training(model::ReinforcementLearningABM) + return RLEnvironmentWrapper(model) +end + +""" + POMDPs.actions(wrapper::RLEnvironmentWrapper) → ActionSpace + +Get the action space for the currently training agent type. 
+ +## Arguments +- `wrapper::RLEnvironmentWrapper`: The wrapped RL environment + +## Returns +- `ActionSpace`: The action space (e.g., Crux.DiscreteSpace) for the current agent type + +## Throws +- `ErrorException`: If RL configuration is not set or no action space is defined for the agent type + +## Notes +This function is part of the POMDPs interface and is called automatically during training +to determine what actions are available to the agent. +""" +function POMDPs.actions(wrapper::RLEnvironmentWrapper) + model = wrapper.model + if isnothing(model.rl_config[]) + error("RL configuration not set. Use set_rl_config! first.") + end + + current_agent_type = Agents.get_current_training_agent_type(model) + config = model.rl_config[] + + if haskey(config.action_spaces, current_agent_type) + return config.action_spaces[current_agent_type] + else + error("No action space defined for agent type $current_agent_type") + end +end + +""" + POMDPs.observations(wrapper::RLEnvironmentWrapper) → ObservationSpace + +Get the observation space for the currently training agent type. + +## Arguments +- `wrapper::RLEnvironmentWrapper`: The wrapped RL environment + +## Returns +- `ObservationSpace`: The observation space (e.g., Crux.ContinuousSpace) for the current agent type + +## Notes +This function is part of the POMDPs interface. If no observation space is defined for the +agent type, it returns a default ContinuousSpace with 10 dimensions and issues a warning. + +## Example +```julia +env = wrap_for_rl_training(model) +obs_space = POMDPs.observations(env) +println("Observation dimensions: ", Crux.dim(obs_space)) +``` +""" +function POMDPs.observations(wrapper::RLEnvironmentWrapper) + model = wrapper.model + if isnothing(model.rl_config[]) + error("RL configuration not set. Use set_rl_config! first.") + end + + current_agent_type = Agents.get_current_training_agent_type(model) + config = model.rl_config[] + + if haskey(config.observation_spaces, current_agent_type) + return config.observation_spaces[current_agent_type] + else + # Return default observation space with smaller dimensions + println("WARNING: No observation space found for agent type $current_agent_type, using default") + return Crux.ContinuousSpace((10,), Float32) + end +end + +""" + POMDPs.observation(wrapper::RLEnvironmentWrapper, s::Vector{Float32}) → Vector{Float32} + +Get the observation for the current training agent. + +## Arguments +- `wrapper::RLEnvironmentWrapper`: The wrapped RL environment +- `s::Vector{Float32}`: The current state (typically unused in ABM context) + +## Returns +- `Vector{Float32}`: The observation vector for the current training agent + +## Notes +This function uses the configured observation function to generate observation vectors +for the current training agent. If no agent is currently being trained, it returns a +zero vector with appropriate dimensions. + +## Example +```julia +env = wrap_for_rl_training(model) +state = zeros(Float32, 10) +obs = POMDPs.observation(env, state) +println("Observation: ", obs) +``` +""" +function POMDPs.observation(wrapper::RLEnvironmentWrapper, s::Vector{Float32}) + model = wrapper.model + if isnothing(model.rl_config[]) + error("RL configuration not set. Use set_rl_config! first.") + end + + current_agent = Agents.get_current_training_agent(model) + if isnothing(current_agent) + # Return zero observation with correct dimensions + obs_space = POMDPs.observations(wrapper) + obs_dims = Crux.dim(obs_space) + return zeros(Float32, obs_dims...) 
+ end + config = model.rl_config[] + # Get observation vector directly from the configured function + return config.observation_fn(model, current_agent.id) +end + +""" + POMDPs.initialstate(wrapper::RLEnvironmentWrapper) → Dirac{Vector{Float32}} + +Initialize the state for a new episode. + +## Arguments +- `wrapper::RLEnvironmentWrapper`: The wrapped RL environment + +## Returns +- `Dirac{Vector{Float32}}`: A deterministic distribution over initial states + +## Notes +This function resets the model to its initial state using `reset_model_for_episode!` +and returns a deterministic initial state distribution. The state dimensions are +determined from the RL configuration or default to 10 dimensions. + +## Example +```julia +env = wrap_for_rl_training(model) +initial_state_dist = POMDPs.initialstate(env) +initial_state = rand(initial_state_dist) +``` +""" +function POMDPs.initialstate(wrapper::RLEnvironmentWrapper) + model = wrapper.model + if isnothing(model.rl_config[]) + error("RL configuration not set. Use set_rl_config! first.") + end + + # Reset the model to initial state + Agents.reset_model_for_episode!(model) + + # Return initial state + current_agent_type = Agents.get_current_training_agent_type(model) + config = model.rl_config[] + + if haskey(config, :state_spaces) && haskey(config.state_spaces, current_agent_type) + state_dims = Crux.dim(config.state_spaces[current_agent_type]) + else + state_dims = (10,) # Default state dimensions + end + + return Dirac(zeros(Float32, state_dims...)) +end + +""" + POMDPs.initialobs(wrapper::RLEnvironmentWrapper, initial_state::Vector{Float32}) → Dirac{Vector{Float32}} + +Get the initial observation for a new episode. + +## Arguments +- `wrapper::RLEnvironmentWrapper`: The wrapped RL environment +- `initial_state::Vector{Float32}`: The initial state vector + +## Returns +- `Dirac{Vector{Float32}}`: A deterministic distribution over initial observations + +## Notes +This function generates the initial observation for a new episode by calling +`POMDPs.observation` with the initial state and wrapping the result in a +deterministic distribution. + +## Example +```julia +env = wrap_for_rl_training(model) +initial_state = zeros(Float32, 10) +initial_obs_dist = POMDPs.initialobs(env, initial_state) +initial_obs = rand(initial_obs_dist) +``` +""" +function POMDPs.initialobs(wrapper::RLEnvironmentWrapper, initial_state::Vector{Float32}) + obs = POMDPs.observation(wrapper, initial_state) + return Dirac(obs) +end + +""" + POMDPs.gen(wrapper::RLEnvironmentWrapper, s, action::Int, rng::AbstractRNG) → NamedTuple + +Generate the next state, observation, and reward after taking an action. + +## Arguments +- `wrapper::RLEnvironmentWrapper`: The wrapped RL environment +- `s`: The current state +- `action::Int`: The action to take +- `rng::AbstractRNG`: Random number generator (typically unused) + +## Returns +- `NamedTuple`: A named tuple with fields: + - `sp`: Next state (same as input state in ABM context) + - `o::Vector{Float32}`: Next observation vector + - `r::Float32`: Reward for the action + +## Notes +This is the core POMDPs interface function that: +1. Executes the action using the configured agent stepping function +2. Calculates the reward using the configured reward function +3. Advances the simulation to handle other agents and environment updates +4. Returns the next observation + +If no current training agent exists, returns a terminal state with -10.0 reward. 
+ +## Example +```julia +env = wrap_for_rl_training(model) +state = zeros(Float32, 10) +action = 1 +result = POMDPs.gen(env, state, action, Random.default_rng()) +println("Reward: ", result.r) +``` +""" +function POMDPs.gen(wrapper::RLEnvironmentWrapper, s, action::Int, rng::AbstractRNG) + model = wrapper.model + if isnothing(model.rl_config[]) + error("RL configuration not set. Use set_rl_config! first.") + end + + current_agent = Agents.get_current_training_agent(model) + + if isnothing(current_agent) + # Episode terminated + obs_space = POMDPs.observations(wrapper) + obs_dims = Crux.dim(obs_space) + return (sp=s, o=zeros(Float32, obs_dims...), r=-10.0) + end + + config = model.rl_config[] + + # Record initial state for reward calculation + initial_state = deepcopy(model) + + # Execute the action using the configured stepping function + config.agent_step_fn(current_agent, model, action) + + # Calculate reward using the configured function + reward = config.reward_fn(model, current_agent, action, initial_state, model) + + # Advance simulation + advance_simulation!(model) + + # Return next state and observation + sp = s # Dummy state + o = POMDPs.observation(wrapper, sp) + + return (sp=sp, o=o, r=reward) +end + +""" + POMDPs.isterminal(wrapper::RLEnvironmentWrapper, s) → Bool + +Check if the current state is terminal. + +## Arguments +- `wrapper::RLEnvironmentWrapper`: The wrapped RL environment +- `s`: The current state + +## Returns +- `Bool`: `true` if the episode should terminate, `false` otherwise + +## Notes +An episode terminates if: +1. The configured terminal function returns `true`, OR +2. The model time has reached the maximum steps configured in RL config + +The maximum steps default to 100 if not specified in the configuration. + +## Example +```julia +env = wrap_for_rl_training(model) +state = zeros(Float32, 10) +is_done = POMDPs.isterminal(env, state) +println("Episode terminated: ", is_done) +``` +""" +function POMDPs.isterminal(wrapper::RLEnvironmentWrapper, s) + model = wrapper.model + if isnothing(model.rl_config[]) + error("RL configuration not set. Use set_rl_config! first.") + end + + config = model.rl_config[] + max_steps = get(config, :max_steps, 100) + + return config.terminal_fn(model) || abmtime(model) >= max_steps +end + +""" + POMDPs.discount(wrapper::RLEnvironmentWrapper) → Float64 + +Get the discount factor for the current agent type. + +## Arguments +- `wrapper::RLEnvironmentWrapper`: The wrapped RL environment + +## Returns +- `Float64`: The discount factor (gamma) for the current training agent type + +## Notes +The discount factor is looked up from the RL configuration's `discount_rates` dictionary +using the current training agent type as the key. If not found, defaults to 0.99. + +## Example +```julia +env = wrap_for_rl_training(model) +γ = POMDPs.discount(env) +println("Discount factor: ", γ) +``` +""" +function POMDPs.discount(wrapper::RLEnvironmentWrapper) + model = wrapper.model + if isnothing(model.rl_config[]) + error("RL configuration not set. Use set_rl_config! first.") + end + + current_agent_type = Agents.get_current_training_agent_type(model) + config = model.rl_config[] + + if haskey(config, :discount_rates) && haskey(config.discount_rates, current_agent_type) + return config.discount_rates[current_agent_type] + else + return 0.99 + end +end + +""" + Crux.state_space(wrapper::RLEnvironmentWrapper) → StateSpace + +Get the state space for the current agent type. 
+ +## Arguments +- `wrapper::RLEnvironmentWrapper`: The wrapped RL environment + +## Returns +- `StateSpace`: The state space (e.g., Crux.ContinuousSpace) for the current agent type + +## Notes +This function looks up the state space from the RL configuration's `state_spaces` dictionary. +If not found, defaults to a ContinuousSpace with 10 dimensions. This is part of the +Crux.jl interface extension. + +## Example +```julia +env = wrap_for_rl_training(model) +state_space = Crux.state_space(env) +println("State dimensions: ", Crux.dim(state_space)) +``` +""" +function Crux.state_space(wrapper::RLEnvironmentWrapper) + model = wrapper.model + if isnothing(model.rl_config[]) + error("RL configuration not set. Use set_rl_config! first.") + end + + current_agent_type = Agents.get_current_training_agent_type(model) + config = model.rl_config[] + + if haskey(config, :state_spaces) && haskey(config.state_spaces, current_agent_type) + return config.state_spaces[current_agent_type] + else + return Crux.ContinuousSpace((10,)) + end +end + +""" + advance_simulation!(model::ReinforcementLearningABM) + +Advance the simulation by one step, handling other agents and environment updates. + +## Arguments +- `model::ReinforcementLearningABM`: The RL model to advance + +## Notes +This function implements the core simulation advancement logic: + +1. **Agent Cycling**: Moves to the next agent of the current training type +2. **Multi-Agent Coordination**: When all training agents have acted, runs other agent types +3. **Policy Application**: Uses trained policies for other agents when available, falls back to random actions +4. **Environment Step**: Executes the model stepping function and increments time + +The function handles agent removal and ensures +proper coordination between different agent types during training. +""" +function advance_simulation!(model::ReinforcementLearningABM) + if isnothing(model.rl_config[]) + error("RL configuration not set. Use set_rl_config! 
first.") + end + + config = model.rl_config[] + current_agent_type = Agents.get_current_training_agent_type(model) + + # Move to next agent of the training type + agents_of_type = [a for a in allagents(model) if typeof(a) == current_agent_type] + + if !isempty(agents_of_type) + model.current_training_agent_id[] += 1 + + # If we've cycled through all agents of this type, run other agents and environment step + if model.current_training_agent_id[] > length(agents_of_type) + model.current_training_agent_id[] = 1 + + # Run other agent types with their policies or random behavior + training_agent_types = get(config, :training_agent_types, [current_agent_type]) + for agent_type in training_agent_types + if agent_type != current_agent_type + other_agents = [a for a in allagents(model) if typeof(a) == agent_type] + + for other_agent in other_agents + try + if haskey(model.trained_policies, agent_type) + # Use trained policy + obs_vec = config.observation_fn(model, other_agent.id) + action = Crux.action(model.trained_policies[agent_type], obs_vec) + config.agent_step_fn(other_agent, model, action) + else + # Fall back to random behavior + if haskey(config.action_spaces, agent_type) + action = rand(config.action_spaces[agent_type].vals) + config.agent_step_fn(other_agent, model, action) + end + end + catch e + # Agent might have died during action, continue + continue + end + end + end + end + + # Run model step and increment time + model.model_step(model) + model.time[] += 1 + end + end +end \ No newline at end of file diff --git a/ext/AgentsRL/src/step_reinforcement_learning.jl b/ext/AgentsRL/src/step_reinforcement_learning.jl new file mode 100644 index 0000000000..36c4b96257 --- /dev/null +++ b/ext/AgentsRL/src/step_reinforcement_learning.jl @@ -0,0 +1,65 @@ +function Agents.CommonSolve.step!(model::ReinforcementLearningABM, n::Union{Real,Function}=1) + agent_step! = Agents.agent_step_field(model) + model_step! = Agents.model_step_field(model) + t = getfield(model, :time) + Agents.step_ahead_rl!(model, agent_step!, model_step!, n, t) + return model +end + +""" + rl_agent_step!(agent, model) + +Default agent stepping function for RL agents. This will use trained policies +if available, otherwise fall back to random actions. +""" +function Agents.rl_agent_step!(agent, model) + if model isa ReinforcementLearningABM + agent_type = typeof(agent) + + if haskey(model.trained_policies, agent_type) && !isnothing(model.rl_config[]) + # Use trained policy + config = model.rl_config[] + obs_vec = config.observation_fn(model, agent.id) + action = Crux.action(model.trained_policies[agent_type], obs_vec) + config.agent_step_fn(agent, model, action[1]) + else + # Fall back to random behavior + if !isnothing(model.rl_config[]) && haskey(model.rl_config[].action_spaces, agent_type) + action_space = model.rl_config[].action_spaces[agent_type] + action = rand(abmrng(model), action_space.vals) + model.rl_config[].agent_step_fn(agent, model, action) + else + # Do nothing if no RL configuration available + println("Warning: No trained policy or action space defined for agent type $agent_type. Skipping step.") + return + end + end + else + error("rl_agent_step! 
can only be used with ReinforcementLearningABM models.") + end +end + +function Agents.step_ahead_rl!(model::ReinforcementLearningABM, agent_step!, model_step!, n, t) + agents_first = getfield(model, :agents_first) + t0 = t[] + while Agents.until(t[], t0, n, model) + !agents_first && model_step!(model) + for id in Agents.schedule(model) + # ensure we don't act on agent that doesn't exist + Agents.agent_not_removed(id, model) || continue + + # Use RL-based stepping + agent = model[id] + agent_type = typeof(agent) + if !isnothing(model.rl_config[]) && haskey(model.rl_config[].action_spaces, agent_type) + # Use trained policy for this agent or fallback to random actions + Agents.rl_agent_step!(agent, model) + else + # Use standard agent stepping + agent_step!(agent, model) + end + end + agents_first && model_step!(model) + t[] += 1 + end +end diff --git a/ext/AgentsVisualizations/src/interaction.jl b/ext/AgentsVisualizations/src/interaction.jl index b83e34e061..fa336d2f88 100644 --- a/ext/AgentsVisualizations/src/interaction.jl +++ b/ext/AgentsVisualizations/src/interaction.jl @@ -19,10 +19,10 @@ Agents.add_interaction!(ax) = add_interaction!(ax, first_abmplot_in(ax)) function add_controls!(fig, abmobs, dt) model, adata, mdata, adf, mdf, when = - getfield.(Ref(abmobs), (:model, :adata, :mdata, :adf, :mdf, :when)) + getfield.(Ref(abmobs), (:model, :adata, :mdata, :adf, :mdf, :when)) # Create new layout for control buttons - controllayout = fig[end+1,:][1,1] = GridLayout(tellheight = true) + controllayout = fig[end+1, :][1, 1] = GridLayout(tellheight=true) # Sliders if abmspace(model[]) isa Agents.ContinuousSpace @@ -32,24 +32,26 @@ function add_controls!(fig, abmobs, dt) end dtrange = isnothing(dt) ? _default_dts_from_model(model[]) : dt - sg = SliderGrid(controllayout[1,1], - (label = "dt", range = dtrange, startvalue = 1), - (label = "sleep", range = _sleepr, startvalue = _sleep0), + sg = SliderGrid(controllayout[1, 1], + (label="dt", range=dtrange, startvalue=1), + (label="sleep", range=_sleepr, startvalue=_sleep0), ) dtslider, slep = [s.value for s in sg.sliders] # Step button # We need an additional observable that keep track of the last time data # was collected. Here collection is the same for agent of models so we need 1 variable. - step = Button(fig, label = "step\nmodel") + step = Button(fig, label="step\nmodel") on(step.clicks) do c # notice that stepping the abmobs both steps the model and collects data!!! 
Agents.step!(abmobs, dtslider[]) end # Run button - run = Button(fig, label = "run\nmodel") + run = Button(fig, label="run\nmodel") isrunning = Observable(false) - on(run.clicks) do c; isrunning[] = !isrunning[]; end + on(run.clicks) do c + isrunning[] = !isrunning[] + end on(run.clicks) do c @async while isrunning[] step.clicks[] = step.clicks[] + 1 @@ -58,7 +60,7 @@ function add_controls!(fig, abmobs, dt) end end # Reset button - reset = Button(fig, label = "reset\nmodel") + reset = Button(fig, label="reset\nmodel") model0 = deepcopy(model[]) # backup initial model state on(reset.clicks) do c !isnothing(adf) && update_offsets!(model[], abmobs.offset_time_adf[], adf[]) @@ -69,7 +71,7 @@ function add_controls!(fig, abmobs, dt) abmobs.t_last_collect[] = abmtime(model0) end # Clear button - clear = Button(fig, label = "clear\ndata") + clear = Button(fig, label="clear\ndata") on(clear.clicks) do c timetype = typeof(abmtime(model[])) abmobs.offset_time_adf[] = (Ref(abmobs.offset_time_adf[][1][]), timetype[]) @@ -80,7 +82,7 @@ function add_controls!(fig, abmobs, dt) abmobs.t_last_collect[] = abmtime(model[]) end # Layout buttons - controllayout[2, :] = Makie.hbox!(step, run, reset, clear; tellwidth = false) + controllayout[2, :] = Makie.hbox!(step, run, reset, clear; tellwidth=false) return step.clicks, reset.clicks end @@ -91,7 +93,7 @@ function update_offsets!(model, offsets, df) offsets[1][] += abmtime(model) end -_default_dts_from_model(::StandardABM) = 1:50 +_default_dts_from_model(::AgentBasedModel) = 1:50 _default_dts_from_model(::EventQueueABM) = 0.1:0.1:10.0 "reinitialize agent and model dataframes." @@ -107,23 +109,23 @@ end "Initialize parameter control sliders." function add_param_sliders!(fig, model, params, resetclick) - datalayout = fig[end,:][1,2] = GridLayout(tellheight = true) + datalayout = fig[end, :][1, 2] = GridLayout(tellheight=true) - slidervals = Dict{Symbol, Observable}() + slidervals = Dict{Symbol,Observable}() tuples_for_slidergrid = [] for (i, (k, vals)) in enumerate(params) startvalue = has_key(abmproperties(model[]), k) ? 
- get_value(abmproperties(model[]), k) : vals[1] + get_value(abmproperties(model[]), k) : vals[1] label = string(k) - push!(tuples_for_slidergrid, (;label, range = vals, startvalue)) + push!(tuples_for_slidergrid, (; label, range=vals, startvalue)) end - sg = SliderGrid(datalayout[1,1], tuples_for_slidergrid...; tellheight = true) + sg = SliderGrid(datalayout[1, 1], tuples_for_slidergrid...; tellheight=true) for (i, (l, vals)) in enumerate(params) slidervals[l] = sg.sliders[i].value end # Update button - update = Button(datalayout[end+1, :], label = "update", tellwidth = false) + update = Button(datalayout[end+1, :], label="update", tellwidth=false) on(update.clicks) do c for (k, v) in pairs(slidervals) if has_key(abmproperties(model[]), k) diff --git a/src/Agents.jl b/src/Agents.jl index fb458866f0..28ee53a68a 100644 --- a/src/Agents.jl +++ b/src/Agents.jl @@ -33,6 +33,8 @@ include("core/agent_containers.jl") include("core/model_standard.jl") include("core/model_event_queue.jl") include("core/model_validation.jl") +# reinforcement learning (singleton methods for package extension) +include("reinforcement_learning.jl") include("core/model_accessing_API.jl") include("core/space_interaction_API.jl") include("core/higher_order_iteration.jl") @@ -77,31 +79,31 @@ include("precompile.jl") using Scratch function __init__() -display_update = true -version_number = "6" -update_name = "update_v$(version_number)" -update_message = """ -Update message: Agents v$(version_number) - -This is a new major version of Agents.jl with lots of cool stuff! -However, from this version onwards, we will stop posting update messages -to the REPL console! - -If you want to be updated, follow this discourse post: - -https://discourse.julialang.org/t/agents-jl-v6-releases-announcement-post/111678 - -(and see the CHANGELOG.md file online for a list of changes!) -""" - -if display_update - # Get scratch space for this package - versions_dir = @get_scratch!("versions") - if !isfile(joinpath(versions_dir, update_name)) - printstyled(stdout, "\n"*update_message; color=:light_magenta) - touch(joinpath(versions_dir, update_name)) + display_update = true + version_number = "6" + update_name = "update_v$(version_number)" + update_message = """ + Update message: Agents v$(version_number) + + This is a new major version of Agents.jl with lots of cool stuff! + However, from this version onwards, we will stop posting update messages + to the REPL console! + + If you want to be updated, follow this discourse post: + + https://discourse.julialang.org/t/agents-jl-v6-releases-announcement-post/111678 + + (and see the CHANGELOG.md file online for a list of changes!) + """ + + if display_update + # Get scratch space for this package + versions_dir = @get_scratch!("versions") + if !isfile(joinpath(versions_dir, update_name)) + printstyled(stdout, "\n" * update_message; color=:light_magenta) + touch(joinpath(versions_dir, update_name)) + end end -end end # _init__ function. 
end # module diff --git a/src/core/model_accessing_API.jl b/src/core/model_accessing_API.jl index f83efa3c12..d78e511e0a 100644 --- a/src/core/model_accessing_API.jl +++ b/src/core/model_accessing_API.jl @@ -1,9 +1,12 @@ const DictABM = Union{StandardABM{S,A,<:AbstractDict{<:Integer,A}} where {S,A}, - EventQueueABM{S,A,<:AbstractDict{<:Integer,A}} where {S,A}} + EventQueueABM{S,A,<:AbstractDict{<:Integer,A}} where {S,A}, + ReinforcementLearningABM{S,A,<:AbstractDict{<:Integer,A}} where {S,A}} const VecABM = Union{StandardABM{S,A,<:AbstractVector{A}} where {S,A}, - EventQueueABM{S,A,<:AbstractVector{A}} where {S,A}} + EventQueueABM{S,A,<:AbstractVector{A}} where {S,A}, + ReinforcementLearningABM{S,A,<:AbstractVector{A}} where {S,A}} const StructVecABM = Union{StandardABM{S,A,<:StructVector{A}} where {S,A}, - EventQueueABM{S,A,<:StructVector{A}} where {S,A}} + EventQueueABM{S,A,<:StructVector{A}} where {S,A}, + ReinforcementLearningABM{S,A,<:StructVector{A}} where {S,A}} nextid(model::DictABM) = getfield(model, :maxid)[] + 1 nextid(model::Union{VecABM, StructVecABM}) = nagents(model) + 1 @@ -42,6 +45,7 @@ end function extra_actions_after_add!(agent, model::EventQueueABM{S,A,<:StructVector} where {S,A}) getfield(model, :autogenerate_on_add) && add_event!(model[agent.id], model) end +extra_actions_after_add!(agent, model::ReinforcementLearningABM) = nothing function remove_agent_from_container!(agent::AbstractAgent, model::DictABM) delete!(agent_container(model), agent.id) diff --git a/src/reinforcement_learning.jl b/src/reinforcement_learning.jl new file mode 100644 index 0000000000..66b34dc973 --- /dev/null +++ b/src/reinforcement_learning.jl @@ -0,0 +1,787 @@ +export ReinforcementLearningABM +export get_trained_policies, set_rl_config!, copy_trained_policies! 
+export train_model!, create_policy_network, create_value_network + +struct ReinforcementLearningABM{ + S<:SpaceType, + A<:AbstractAgent, + C<:Union{AbstractDict{Int,A},AbstractVector{A}}, + T,G,K,F,P,R<:AbstractRNG} <: AgentBasedModel{S} + # Standard ABM components + agents::C + agent_step::G + model_step::K + space::S + scheduler::F + properties::P + rng::R + agents_types::T + agents_first::Bool + maxid::Base.RefValue{Int64} + time::Base.RefValue{Int64} + + # RL-specific components + rl_config::Base.RefValue{Any} + trained_policies::Dict{Type,Any} + training_history::Dict{Type,Any} + is_training::Base.RefValue{Bool} + current_training_agent_type::Base.RefValue{Any} + current_training_agent_id::Base.RefValue{Int} # Counter/index for cycling through agents of training type (not actual agent ID) +end + +# Extend mandatory internal API for `AgentBasedModel` +containertype(::ReinforcementLearningABM{S,A,C}) where {S,A,C} = C +agenttype(::ReinforcementLearningABM{S,A}) where {S,A} = A +discretimeabm(::ReinforcementLearningABM) = true + +# Override property access to handle RL-specific fields +function Base.getproperty(m::ReinforcementLearningABM, s::Symbol) + # Handle RL-specific fields directly + if s in (:rl_config, :trained_policies, :training_history, :is_training, :current_training_agent_type, :current_training_agent_id) + return getfield(m, s) + elseif s in (:agents, :agent_step, :model_step, :space, :scheduler, :rng, + :agents_types, :agents_first, :maxid, :time, :properties) + return getfield(m, s) + else + # Delegate to properties for other fields + p = abmproperties(m) + if p isa Dict + return getindex(p, s) + else + return getproperty(p, s) + end + end +end + +function Base.setproperty!(m::ReinforcementLearningABM, s::Symbol, x) + # Handle RL-specific fields directly + if s in (:rl_config, :trained_policies, :training_history, :is_training, :current_training_agent_type, :current_training_agent_id) + return setfield!(m, s, x) + # Handle standard ABM fields directly (except properties which is immutable) + elseif s in (:agents, :agent_step, :model_step, :space, :scheduler, :rng, + :agents_types, :agents_first, :maxid, :time) + return setfield!(m, s, x) + # Special handling for properties - can't setfield! but can modify Dict contents + elseif s == :properties + error("Cannot replace properties field directly. Use model.properties[key] = value to modify properties.") + else + # Delegate to properties for other fields + properties = abmproperties(m) + exception = ErrorException( + "Cannot set property $(s) for model $(nameof(typeof(m))) with " * + "properties container type $(typeof(properties))." + ) + properties === nothing && throw(exception) + if properties isa Dict && haskey(properties, s) + properties[s] = x + elseif hasproperty(properties, s) + setproperty!(properties, s, x) + else + throw(exception) + end + end +end + +""" + ReinforcementLearningABM <: AgentBasedModel + +A concrete implementation of an [`AgentBasedModel`](@ref) that extends [`StandardABM`](@ref) +with built-in reinforcement learning capabilities. This model type integrates RL training +into the ABM framework, allowing agents to learn and adapt their behavior +through interaction with the environment. 
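+
+In brief, a typical workflow is sketched below (`MyAgent` and `config` are
+placeholders; the fields required in `config` are documented under "RL Configuration"
+further down in this docstring):
+
+```julia
+model = ReinforcementLearningABM(MyAgent, GridSpace((10, 10)))
+set_rl_config!(model, config)   # attach observation/reward/terminal/step functions
+train_model!(model, MyAgent)    # requires the AgentsRL extension to be loaded
+policies = get_trained_policies(model)
+```
+
+After training, `step!(model, n)` automatically uses the stored policies for agent
+types that have one, falling back to random actions from the configured action space
+otherwise.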
+ +## Key Features + +- **Integrated RL Training**: Built-in support for training agents using various RL algorithms +- **Multi-Agent Learning**: Support for training multiple agent types simultaneously or sequentially +- **Flexible Observation Models**: Customizable observation functions for different agent types +- **Reward Engineering**: User-defined reward functions for different learning objectives +- **Policy Management**: Automatic management of trained policies and their deployment + +Here is how to construct a `ReinforcementLearningABM`: + + ReinforcementLearningABM(AgentType(s), space [, rl_config]; kwargs...) + +## Arguments + +- `AgentType(s)`: The result of `@agent` or `@multiagent` or a `Union` of agent types. + Any agent type can be used - they don't need to inherit from `RLAgent`. +- `space`: A subtype of `AbstractSpace`. See [Space](@ref available_spaces) for all available spaces. +- `rl_config`: (Optional) A named tuple containing RL configuration. Can be set later with `set_rl_config!`. + +## Keyword Arguments + +Same as [`StandardABM`](@ref): +- `agent_step!`: Function for stepping agents. If not provided, will use RL-based stepping when policies are available. +- `model_step!`: Function for stepping the model. + +## RL Configuration + +The `rl_config` should be a named tuple with the following fields: + +### Required Functions + +- **`observation_fn(model::ReinforcementLearningABM, agent_id::Int) → Vector{Float32}`** + Function to generate observation vectors for agents from the model state. + - `model`: The ReinforcementLearningABM instance + - `agent_id`: ID of the agent for which to generate observation + - **Returns**: `Vector{Float32}` - Flattened feature vector ready for neural network input + +- **`reward_fn(env::ReinforcementLearningABM, agent::AbstractAgent, action::Int, initial_model::ReinforcementLearningABM, final_model::ReinforcementLearningABM) → Float32`** + Function to calculate scalar rewards based on agent actions and state transitions. + - `env`: Current model state (typically same as `final_model`) + - `agent`: The agent that took the action + - `action`: Integer action that was taken + - `initial_model`: Model state before the action + - `final_model`: Model state after the action + - **Returns**: `Float32` - Scalar reward signal for the action + +- **`terminal_fn(env::ReinforcementLearningABM) → Bool`** + Function to determine if the current episode should terminate. + - `env`: The current model state + - **Returns**: `Bool` - `true` if episode should end, `false` to continue + +- **`agent_step_fn(agent::AbstractAgent, model::ReinforcementLearningABM, action::Int) → Nothing`** + Function that executes an agent's action in the model. + - `agent`: The agent taking the action + - `model`: The model containing the agent + - `action`: Integer action to execute + - **Returns**: `Nothing` - Modifies agent and model state in-place + +### Required Spaces + +- **`action_spaces::Dict{Type, ActionSpace}`** + Dictionary mapping agent types to their available actions. + - Keys: Agent types (e.g., `MyAgent`) + - Values: Action spaces (e.g., `Crux.DiscreteSpace(5)` for 5 discrete actions) + +- **`observation_spaces::Dict{Type, ObservationSpace}`** + Dictionary mapping agent types to their observation vector dimensions. 
+ - Keys: Agent types (e.g., `MyAgent`) + - Values: Observation spaces (e.g., `Crux.ContinuousSpace((84,), Float32)` for 84-dim vectors) + +### Required Configuration + +- **`training_agent_types::Vector{Type}`** + Vector of agent types that should undergo RL training. + - Must be a subset of agent types present in the model + - Example: `[MyAgent1, MyAgent2]` + +- **`max_steps::Int`** + Maximum number of simulation steps per training episode. + - Episodes terminate when this limit is reached OR `terminal_fn` returns `true` + - Typical values: 50-500 depending on model complexity + +- **`observation_radius::Int or Dict{Type, Int}`** + Radius for local neighborhood observations in grid-based models. + - Used in `observation_fn` to determine neighborhood size + - Example: `4` creates a 9×9 observation grid around each agent + - **Multi-agent support**: Can specify different radii per agent type by passing a `Dict{Type, Int}` instead of a single `Int` + - **Usage**: `observation_radius=4` (all agents) or `observation_radius=Dict(Wolf => 5, Sheep => 3)` (per-type) + +### Optional Configuration + +- **`discount_rates::Dict{Type, Float64}`** *(Optional)* + Dictionary mapping agent types to their reward discount factors (γ). + - Keys: Agent types + - Values: Discount factors between 0.0 and 1.0 + - **Default**: 0.99 for all agent types if not specified + +- **`model_init_fn() → ReinforcementLearningABM`** *(Optional)* + Function to create fresh model instances for episode resets during training. + - **Returns**: New ReinforcementLearningABM instance with reset state + - If not provided, uses basic model reset without full reinitialization +""" +function ReinforcementLearningABM( + A::Type, + space::S=nothing, + rl_config=nothing; + agent_step!::G=dummystep, + model_step!::K=dummystep, + container::Type=Dict, + scheduler::F=Schedulers.Randomly(), + properties::P=nothing, + rng::R=Random.default_rng(), + agents_first::Bool=true, + warn=true, + kwargs... +) where {S<:SpaceType,G,K,F,P,R<:AbstractRNG} + + # Initialize agent container using proper construction + agents = construct_agent_container(container, A) + agents_types = union_types(A) + T = typeof(agents_types) + C = typeof(agents) + + model = ReinforcementLearningABM{S,A,C,T,G,K,F,P,R}( + agents, + agent_step!, + model_step!, + space, + scheduler, + properties, + rng, + agents_types, + agents_first, + Ref(0), + Ref(0), + Ref{Any}(rl_config), + Dict{Type,Any}(), + Dict{Type,Any}(), + Ref(false), + Ref{Any}(nothing), + Ref(1) + ) + + return model +end + +""" + set_rl_config!(model::ReinforcementLearningABM, config) → ReinforcementLearningABM + +Set the RL configuration for the model. + +## Arguments +- `model::ReinforcementLearningABM`: The model to configure +- `config`: Named tuple containing RL configuration parameters + +## Returns +- `ReinforcementLearningABM`: The configured model + +## Example +```julia +config = ( + observation_fn = my_obs_function, + reward_fn = my_reward_function, + # ... 
other config parameters +) +set_rl_config!(model, config) +``` +""" +function set_rl_config!(model::ReinforcementLearningABM, config) + model.rl_config[] = config + + # Initialize training history for each training agent type + if haskey(config, :training_agent_types) + for agent_type in config.training_agent_types + if !haskey(model.training_history, agent_type) + model.training_history[agent_type] = nothing # Will be set during training + end + end + end +end + +""" + get_trained_policies(model::ReinforcementLearningABM) → Dict{Type, Any} + +Get the dictionary of trained policies for each agent type. + +## Arguments +- `model::ReinforcementLearningABM`: The model containing trained policies + +## Returns +- `Dict{Type, Any}`: Dictionary mapping agent types to their trained policies + +## Example +```julia +policies = get_trained_policies(model) +if haskey(policies, MyAgent) + println("MyAgent has a trained policy") +end +``` +""" +get_trained_policies(model::ReinforcementLearningABM) = model.trained_policies + +""" + copy_trained_policies!(target_model::ReinforcementLearningABM, source_model::ReinforcementLearningABM) → ReinforcementLearningABM + +Copy all trained policies from the source model to the target model. + +## Arguments +- `target_model::ReinforcementLearningABM`: The model to copy policies to +- `source_model::ReinforcementLearningABM`: The model to copy policies from + +## Returns +- `ReinforcementLearningABM`: The target model with copied policies (for chaining) + +## Example +```julia +# Train policies in one model +train_model!(training_model, MyAgent) + +# Copy to a fresh simulation model +fresh_model = initialize_model() +copy_trained_policies!(fresh_model, training_model) +``` +""" +function copy_trained_policies!(target_model::ReinforcementLearningABM, source_model::ReinforcementLearningABM) + for (agent_type, policy) in source_model.trained_policies + target_model.trained_policies[agent_type] = policy + end + return target_model +end + +""" + rl_agent_step!(agent, model) + +Default agent stepping function for RL agents. This will use trained policies +if available, otherwise fall back to random actions. + +## Arguments +- `agent`: The agent to step +- `model::ReinforcementLearningABM`: The model containing the agent + +## Notes +This function automatically selects between trained policies and random actions +based on what's available for the agent's type. It's used internally by the +RL stepping infrastructure. +""" +function rl_agent_step! end + +""" + get_current_training_agent_type(model::ReinforcementLearningABM) → Type + +Get the currently training agent type. + +## Arguments +- `model::ReinforcementLearningABM`: The RL model + +## Returns +- `Type`: The agent type currently being trained +""" +function get_current_training_agent_type(model::ReinforcementLearningABM) + if isnothing(model.rl_config[]) + error("RL configuration not set. Use set_rl_config! 
first.") + end + + # Check if current training agent type is set in the model + if !isnothing(model.current_training_agent_type[]) + return model.current_training_agent_type[] + end + + # Otherwise, fall back to first agent type in training_agent_types + config = model.rl_config[] + if haskey(config, :training_agent_types) && !isempty(config.training_agent_types) + return config.training_agent_types[1] + else + error("No training agent type specified in RL configuration") + end +end + +""" + get_current_training_agent(model::ReinforcementLearningABM) → Union{AbstractAgent, Nothing} + +Get the current agent being trained. + +## Arguments +- `model::ReinforcementLearningABM`: The RL model + +## Returns +- `Union{AbstractAgent, Nothing}`: The current agent being trained, or `nothing` if no agents of the training type exist + +## Notes +The `current_training_agent_id` is a counter/index that cycles through agents of the training type, +not the actual agent ID. +""" +function get_current_training_agent(model::ReinforcementLearningABM) + current_agent_type = get_current_training_agent_type(model) + agents_of_type = [a for a in allagents(model) if typeof(a) == current_agent_type] + + if isempty(agents_of_type) + return nothing + end + + current_agent_id = model.current_training_agent_id[] + + # Cycle through agents of the training type + agent_idx = ((current_agent_id - 1) % length(agents_of_type)) + 1 + return agents_of_type[agent_idx] +end + +""" + reset_model_for_episode!(model::ReinforcementLearningABM) + +Reset the model to initial state for a new training episode. + +## Arguments +- `model::ReinforcementLearningABM`: The model to reset + +## Notes +This function resets the model time, agent positions, and other state based on the +`model_init_fn` in the RL configuration. It's used internally during training +to reset episodes. +""" +function reset_model_for_episode!(model::ReinforcementLearningABM) + if isnothing(model.rl_config[]) + error("RL configuration not set. Use set_rl_config! first.") + end + + # Reset time + model.time[] = 0 + + # Reset current agent ID + model.current_training_agent_id[] = 1 + + config = model.rl_config[] + + # If there's a model initialization function, use it + if haskey(config, :model_init_fn) + new_model = config.model_init_fn() + # Copy agents and properties from new model + remove_all!(model) + for agent in allagents(new_model) + add_agent!(agent, model) + end + # Copy properties if they exist + if !isnothing(abmproperties(new_model)) + for (key, value) in pairs(abmproperties(new_model)) + abmproperties(model)[key] = value + end + end + end +end + +""" + step_ahead_rl!(model::ReinforcementLearningABM, agent_step!, model_step!, n, t) + +Steps the model forward using RL policies for a specified number of steps. + +## Arguments +- `model::ReinforcementLearningABM`: The model to step +- `agent_step!`: Agent stepping function (fallback for non-RL agents) +- `model_step!`: Model stepping function +- `n`: Number of steps or stepping condition +- `t`: Time reference + +## Notes +This function is part of the internal stepping infrastructure and automatically +chooses between RL policies and standard agent stepping based on availability. +It's called internally by the `step!` function. +""" +function step_ahead_rl! 
end + +################ +### TRAINING ### +################ + +""" + setup_rl_training(model::ReinforcementLearningABM, agent_type; + training_steps=50_000, + value_network=nothing, + policy_network=nothing, + solver=nothing, + solver_type=:PPO, + solver_params=Dict() + ) → (env, solver) + +Set up RL training for a specific agent type using the ReinforcementLearningABM directly. + +## Arguments +- `model::ReinforcementLearningABM`: The model to train +- `agent_type::Type`: The agent type to train + +## Keyword Arguments +- `training_steps::Int`: Number of training steps (default: 50_000) +- `value_network`: Custom value network function (default: auto-generated) +- `policy_network`: Custom policy network function (default: auto-generated) +- `solver`: Complete custom solver (default: auto-generated based on solver_type) +- `solver_type::Symbol`: Type of RL solver (:PPO, :DQN, :A2C) (default: :PPO) +- `solver_params::Dict`: Custom parameters for the solver (default: Dict()) + +## Returns +- `(env, solver)`: A tuple containing the wrapped environment and configured solver +""" +function setup_rl_training end + +""" + train_agent_sequential(model::ReinforcementLearningABM, agent_types; + training_steps=50_000, + custom_networks=Dict(), + custom_solvers=Dict(), + solver_types=Dict(), + solver_params=Dict() + ) → (policies, solvers) + +Train multiple agent types sequentially using the ReinforcementLearningABM, where each +subsequent agent is trained against the previously trained agents. + +## Arguments +- `model::ReinforcementLearningABM`: The model to train +- `agent_types`: Agent type or vector of agent types to train sequentially + +## Keyword Arguments +- `training_steps::Int`: Number of training steps per agent (default: 50_000) +- `custom_networks::Dict`: Dict mapping agent types to custom network configurations +- `custom_solvers::Dict`: Dict mapping agent types to custom solvers +- `solver_types::Dict`: Dict mapping agent types to solver types (default: :PPO for all) +- `solver_params::Dict`: Dict mapping agent types to solver parameters + +## Returns +- `(policies, solvers)`: Tuple containing dictionaries of trained policies and solvers by agent type +""" +function train_agent_sequential end + +""" + train_agent_simultaneous(model::ReinforcementLearningABM, agent_types; + n_iterations=5, + batch_size=10_000, + custom_networks=Dict(), + custom_solvers=Dict(), + solver_types=Dict(), + solver_params=Dict() + ) → (policies, solvers) + +Train multiple agent types simultaneously using the ReinforcementLearningABM with +alternating batch updates. 
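+
+For example, a sketch (`Predator` and `Prey` are placeholder agent types that must be
+listed in the RL configuration's `training_agent_types`):
+
+```julia
+policies, solvers = train_agent_simultaneous(model, [Predator, Prey];
+    n_iterations=10, batch_size=5_000)
+```
+
+Most users call this indirectly through `train_model!` with `training_mode=:simultaneous`.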
+ +## Arguments +- `model::ReinforcementLearningABM`: The model to train +- `agent_types`: Agent type or vector of agent types to train simultaneously + +## Keyword Arguments +- `n_iterations::Int`: Number of alternating training iterations (default: 5) +- `batch_size::Int`: Size of training batches for each iteration (default: 10_000) +- `custom_networks::Dict`: Dict mapping agent types to custom network configurations +- `custom_solvers::Dict`: Dict mapping agent types to custom solvers +- `solver_types::Dict`: Dict mapping agent types to solver types (default: :PPO for all) +- `solver_params::Dict`: Dict mapping agent types to solver parameters + +## Returns +- `(policies, solvers)`: Tuple containing dictionaries of trained policies and solvers by agent type +""" +function train_agent_simultaneous end + +## Helper Functions for Custom Neural Networks +""" + process_solver_params(solver_params, agent_type) → Dict + +Process solver parameters that can be either global or per-agent-type. + +## Arguments +- `solver_params::Dict`: Dictionary of solver parameters, either global or per-agent-type +- `agent_type::Type`: The agent type to get parameters for + +## Returns +- `Dict`: Parameters specific to the given agent type +""" +function process_solver_params(solver_params, agent_type) + if isempty(solver_params) + return Dict() + end + + # Check if solver_params contains agent types as keys + if any(k isa Type for k in keys(solver_params)) + # Per-agent-type parameters + return get(solver_params, agent_type, Dict()) + else + # Global parameters + return solver_params + end +end + +""" + create_value_network(input_dims, hidden_layers=[64, 64], activation=relu) → Function + +Create a custom value network with specified architecture. + +## Arguments +- `input_dims`: Tuple specifying the input dimensions +- `hidden_layers::Vector{Int}`: Sizes of hidden layers (default: [64, 64]) +- `activation`: Activation function (default: relu) + +## Returns +- `Function`: A function that creates a ContinuousNetwork when called +""" +function create_value_network end + +""" + create_policy_network(input_dims, output_dims, action_space_values, hidden_layers=[64, 64], activation=relu) → Function + +Create a custom policy network with specified architecture. + +## Arguments +- `input_dims`: Tuple specifying the input dimensions +- `output_dims::Int`: Number of output neurons (action space size) +- `action_space_values`: The action space values for the policy (e.g. Crux.DiscreteSpace(5).vals) +- `hidden_layers::Vector{Int}`: Sizes of hidden layers (default: [64, 64]) +- `activation`: Activation function (default: relu) + +## Returns +- `Function`: A function that creates a DiscreteNetwork when called +""" +function create_policy_network end + +""" + create_custom_solver(solver_type, π, S; custom_params...) → Solver + +Create a custom solver with specified parameters. + +## Arguments +- `solver_type::Symbol`: Type of solver (:PPO, :DQN, :A2C) +- `π`: Policy network +- `S`: State/observation space +- `custom_params...`: Additional parameters for the solver + +## Returns +- `Solver`: The configured RL solver +""" +function create_custom_solver end + +""" + train_model!(model::ReinforcementLearningABM, agent_types; + training_mode=:sequential, kwargs...) → ReinforcementLearningABM + +Train the specified agent types in the model using reinforcement learning. This is the main +function for RL training in Agents.jl, supporting both single-agent and multi-agent +learning scenarios. 
+ +## Training Modes + +### Sequential Training (`:sequential`) +Agents are trained one at a time in sequence. Each subsequent agent type is trained against +the previously trained agents. + +**Process:** +1. Train first agent type against random agents +2. Train second agent type against the trained first agent +3. Continue until all agent types are trained + +### Simultaneous Training (`:simultaneous`) +All agent types are trained at the same time with alternating batch updates. This creates +a co-evolutionary dynamic where agents adapt to each other simultaneously. + +**Process:** +1. Initialize solvers for all agent types +2. Alternate training batches between agent types +3. Each agent learns against the evolving policies of others + +## Arguments +- `model::ReinforcementLearningABM`: The model containing agents to train. Must have RL + configuration set via `set_rl_config!` before training. +- `agent_types`: Single agent type (e.g., `MyAgent`) or vector of agent types (e.g., + `[Predator, Prey]`) to train. All specified types must exist in the model and be + listed in the RL configuration's `training_agent_types`. + +## Keyword Arguments + +### Core Training Parameters + +- **`training_mode::Symbol`**: Training strategy (default: `:sequential`) + - `:sequential` - Train agent types one after another + - `:simultaneous` - Train all agent types together with alternating updates + +#### Sequential Training Steps +This applies only when `training_mode=:sequential`: + +- **`training_steps::Int`**: Number of environment steps for training each agent type + (default: 50,000). In sequential mode, this is per agent type. In simultaneous mode, + this determines the batch size for each alternating update. + +#### Simultaneous Training Steps +These apply only when `training_mode=:simultaneous`: + +- **`n_iterations::Int`**: Number of alternating training rounds (default: 5) +- **`batch_size::Int`**: Size of training batches for each iteration (default: 10,000) + +## Algorithm Configuration + +- **`solver_params::Dict`**: Algorithm-specific hyperparameters. Can be: + - **Global parameters**: Applied to all agent types + ```julia + solver_params = Dict( + :ΔN => 200, + :log => (period=1000,), + ) + ``` + - **Per-agent-type parameters**: Different settings for each agent type + ```julia + solver_params = Dict( + Predator => Dict(:ΔN => 100), + Prey => Dict(:ΔN => 200) + ) + ``` + +- **`solver_types::Dict{Type, Symbol}`**: Different RL algorithms for different agent types. + ```julia + solver_types = Dict( + FastAgent => :DQN, + SmartAgent => :PPO + ) + ``` + +## Network Architecture Customization + +- **`custom_networks::Dict{Type, Dict{Symbol, Function}}`**: Custom neural network + architectures for specific agent types. Each entry maps an agent type to a dictionary + containing `:value_network` and/or `:policy_network` functions. + ```julia + custom_networks = Dict( + MyAgent => Dict( + :value_network => () -> create_value_network((84,), [128, 64]), + :policy_network => () -> create_policy_network((84,), 5, action_space, [128, 64]) + ) + ) + ``` + +- **`custom_solvers::Dict{Type, Any}`**: Pre-configured complete solvers for specific + agent types. Bypasses automatic solver creation. + ```julia + custom_solvers = Dict( + MyAgent => my_preconfigured_ppo_solver + ) + ``` + + +## Returns +- `ReinforcementLearningABM`: The input model with trained policies stored in + `model.trained_policies`. 
The trained policies can be accessed via `get_trained_policies(model)` + or copied to other models using `copy_trained_policies!(target, source)`. + +## Notes +- `max_steps` is read directly from the RL configuration (`model.rl_config[][:max_steps]`) +- Episode termination is controlled by the RL environment wrapper using the config value +- Cannot override `max_steps` during training - it must be set in the RL configuration + +## Examples + +### Basic training with custom solver parameters +```julia +train_model!(model, MyAgent; + training_steps=10000, + solver_params=Dict(:ΔN => 100, :log => (period=500,))) +``` + +### Multi-Agent Sequential Training +```julia +# Train predator and prey sequentially +train_model!(model, [Predator, Prey]; + training_mode=:sequential, + training_steps=20000, + solver_params=Dict( + :ΔN => 100, + :log => (period=500,) + )) +``` + +# Multi-Agent Simultaneous Training +```julia +# Co-evolutionary training +train_model!(model, [PlayerA, PlayerB]; + training_mode=:simultaneous, + n_iterations=10, + batch_size=5000, + solver_params=Dict( + PlayerA => Dict(:ΔN => 100), + PlayerB => Dict(:ΔN => 200) + )) +``` + +## See Also + +- [`ReinforcementLearningABM`](@ref): The model type used for RL training +- [`set_rl_config!`](@ref): Setting up RL configuration +- [`copy_trained_policies!`](@ref): Copying policies between models +- [`setup_rl_training`](@ref): Lower-level training setup +- [Crux.jl documentation](https://github.com/sisl/Crux.jl) for solver details +""" +function train_model! end diff --git a/test/Project.toml b/test/Project.toml index 27aa48da40..6e2e3b3ff6 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,14 +6,17 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +Crux = "e51cc422-768a-4345-bb8e-2246287ae729" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" OSMMakie = "76b6901f-8821-46bb-9129-841bc9cfe677" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/test/reinforcement_learning_tests.jl b/test/reinforcement_learning_tests.jl new file mode 100644 index 0000000000..35294fb143 --- /dev/null +++ b/test/reinforcement_learning_tests.jl @@ -0,0 +1,557 @@ +# Test agents for RL testing +@agent struct RLTestAgent(GridAgent{2}) + wealth::Float64 + last_action::Int = 0 +end + +@agent struct RLTestAgent2(GridAgent{2}) + energy::Float64 + strategy::Int = 1 +end + +# Helper functions for testing +function test_observation_fn(model::ReinforcementLearningABM, agent_id::Int) +end + +function test_reward_fn(env::ReinforcementLearningABM, agent::AbstractAgent, action::Int, + initial_model::ReinforcementLearningABM, final_model::ReinforcementLearningABM) +end + +function test_terminal_fn(env::ReinforcementLearningABM) +end + +function test_agent_step_fn(agent::AbstractAgent, model::ReinforcementLearningABM, action::Int) +end + +function test_model_init_fn() + space = GridSpace((5, 5)) + model = 
ReinforcementLearningABM(RLTestAgent, space; rng=StableRNG(42)) + + # Add some agents + for i in 1:3 + add_agent!(RLTestAgent, model, rand() * 50.0, 0) + end + + return model +end + +# Create basic RL configuration +function create_test_rl_config() + return ( + observation_fn=test_observation_fn, + reward_fn=test_reward_fn, + terminal_fn=test_terminal_fn, + agent_step_fn=test_agent_step_fn, + action_spaces=Dict(RLTestAgent => (; vals=1:5)), + observation_spaces=Dict(RLTestAgent => (; dim=(3,))), + training_agent_types=[RLTestAgent], + max_steps=20, + observation_radius=2, + discount_rates=Dict(RLTestAgent => 0.95), + model_init_fn=test_model_init_fn + ) +end + +@testset "ReinforcementLearningABM Interface Tests" begin + + @testset "Model Construction" begin + space = GridSpace((5, 5)) + + # Test basic construction + model = ReinforcementLearningABM(RLTestAgent, space) + @test model isa ReinforcementLearningABM + @test abmspace(model) isa GridSpace + @test isnothing(model.rl_config[]) + @test isempty(model.trained_policies) + @test isempty(model.training_history) + @test model.is_training[] == false + @test isnothing(model.current_training_agent_type[]) + @test model.current_training_agent_id[] == 1 + + # Test with initial RL config + config = create_test_rl_config() + model2 = ReinforcementLearningABM(RLTestAgent, space, config) + @test !isnothing(model2.rl_config[]) + @test model2.rl_config[].max_steps == 20 + + # Test with other standard parameters + model3 = ReinforcementLearningABM(RLTestAgent, space; + scheduler=Schedulers.ByID(), + properties=Dict(:test => true), + rng=StableRNG(123)) + @test model3.scheduler isa Schedulers.ByID + @test model3.properties[:test] == true + @test model3.rng isa StableRNG + end + + @testset "RL Configuration Management" begin + space = GridSpace((5, 5)) + model = ReinforcementLearningABM(RLTestAgent, space) + + # Test setting RL config + config = create_test_rl_config() + set_rl_config!(model, config) + + @test !isnothing(model.rl_config[]) + @test model.rl_config[].max_steps == 20 + @test haskey(model.rl_config[].action_spaces, RLTestAgent) + @test model.rl_config[].training_agent_types == [RLTestAgent] + + # Test training history initialization + @test haskey(model.training_history, RLTestAgent) + @test isnothing(model.training_history[RLTestAgent]) + + agent = add_agent!(RLTestAgent, model, 10.0, 0) + + @test model.rl_config[] isa NamedTuple + @test model.trained_policies isa Dict + @test Agents.get_current_training_agent_type(model) == RLTestAgent + @test Agents.get_current_training_agent(model) == agent + + end + + @testset "Property Access" begin + space = GridSpace((5, 5)) + props = Dict(:custom_prop => 42, :another_prop => "test") + model = ReinforcementLearningABM(RLTestAgent, space; properties=props) + config = create_test_rl_config() + set_rl_config!(model, config) + + # Test direct RL property access + @test !isnothing(model.rl_config) + @test isempty(model.trained_policies) + @test !isempty(model.training_history) + @test model.is_training[] == false + + # Test standard ABM property access + @test model.agents isa AbstractDict + @test model.space isa GridSpace + @test model.maxid[] == 0 + @test model.time[] == 0 + + # Test custom property access via properties + @test model.custom_prop == 42 + @test model.another_prop == "test" + + # Test property modification + model.custom_prop = 100 + @test model.custom_prop == 100 + @test model.properties[:custom_prop] == 100 + + # Test error on setting properties field directly + @test_throws 
ErrorException model.properties = Dict() + end + + @testset "Agent Management" begin + space = GridSpace((5, 5)) + model = ReinforcementLearningABM(RLTestAgent, space; rng=StableRNG(42)) + config = create_test_rl_config() + set_rl_config!(model, config) + + # Add agents + agent1 = add_agent!(RLTestAgent, model, 25.0, 0) + agent2 = add_agent!(RLTestAgent, model, 30.0, 0) + + @test nagents(model) == 2 + @test model[agent1.id] isa RLTestAgent + @test model[agent2.id] isa RLTestAgent + + # Test agent removal + remove_agent!(agent1, model) + @test nagents(model) == 1 + @test_throws KeyError model[agent1.id] + end + + @testset "Current Training Agent Management" begin + space = GridSpace((5, 5)) + model = ReinforcementLearningABM(RLTestAgent, space; rng=StableRNG(42)) + config = create_test_rl_config() + set_rl_config!(model, config) + + # Test before adding agents + @test Agents.get_current_training_agent_type(model) == RLTestAgent + @test isnothing(Agents.get_current_training_agent(model)) + + # Add agents and test cycling + agent1 = add_agent!(RLTestAgent, model, 25.0, 0) + agent2 = add_agent!(RLTestAgent, model, 30.0, 0) + + # Test first agent + current_agent = Agents.get_current_training_agent(model) + @test model.current_training_agent_id[] == 1 + + model.current_training_agent_id[] += 1 + @test model.current_training_agent_id[] == 2 + end + + @testset "Model Reset" begin + space = GridSpace((5, 5)) + model = ReinforcementLearningABM(RLTestAgent, space; rng=StableRNG(42)) + config = create_test_rl_config() + set_rl_config!(model, config) + + # Add agents and advance time + add_agent!(RLTestAgent, model, 25.0, 0) + add_agent!(RLTestAgent, model, 30.0, 0) + model.time[] = 10 + model.current_training_agent_id[] = 2 + + # Test reset + Agents.reset_model_for_episode!(model) + @test model.time[] == 0 + @test model.current_training_agent_id[] == 1 + + # Since we have model_init_fn, agents should be reset too + @test nagents(model) == 3 # model_init_fn creates 3 agents + end + + @testset "Policy Management" begin + space = GridSpace((5, 5)) + model1 = ReinforcementLearningABM(RLTestAgent, space) + model2 = ReinforcementLearningABM(RLTestAgent, space) + + # Test empty policies initially + @test isempty(Agents.get_trained_policies(model1)) + @test isempty(Agents.get_trained_policies(model2)) + + # Mock a trained policy + mock_policy = "mock_policy_object" + model1.trained_policies[RLTestAgent] = mock_policy + + @test haskey(Agents.get_trained_policies(model1), RLTestAgent) + @test Agents.get_trained_policies(model1)[RLTestAgent] == mock_policy + + # Test policy copying + copy_trained_policies!(model2, model1) + @test haskey(Agents.get_trained_policies(model2), RLTestAgent) + @test Agents.get_trained_policies(model2)[RLTestAgent] == mock_policy + end + + @testset "Multi-Agent Type Configuration" begin + space = GridSpace((5, 5)) + model = ReinforcementLearningABM(Union{RLTestAgent,RLTestAgent2}, space; rng=StableRNG(42)) + + # Create config for multiple agent types + config = ( + observation_fn=(model, agent_id) -> Float32[model[agent_id].pos...], + reward_fn=(env, agent, action, init, final) -> 1.0f0, + terminal_fn=(env) -> false, + agent_step_fn=(agent, model, action) -> nothing, + action_spaces=Dict( + RLTestAgent => (; vals=1:5), + RLTestAgent2 => (; vals=1:3) + ), + observation_spaces=Dict( + RLTestAgent => (; dim=(2,)), + RLTestAgent2 => (; dim=(2,)) + ), + training_agent_types=[RLTestAgent, RLTestAgent2], + max_steps=50, + observation_radius=1, + discount_rates=Dict( + RLTestAgent => 0.95, + 
RLTestAgent2 => 0.99 + ) + ) + + set_rl_config!(model, config) + + # Add different agent types + agent1 = add_agent!(RLTestAgent, model, 25.0, 0) + agent2 = add_agent!(RLTestAgent2, model, 50.0, 0) + + @test nagents(model) == 2 + @test typeof(model[agent1.id]) == RLTestAgent + @test typeof(model[agent2.id]) == RLTestAgent2 + + # Test configuration access for different types + @test haskey(model.rl_config[].action_spaces, RLTestAgent) + @test haskey(model.rl_config[].action_spaces, RLTestAgent2) + @test length(model.rl_config[].action_spaces[RLTestAgent].vals) == 5 + @test length(model.rl_config[].action_spaces[RLTestAgent2].vals) == 3 + end + + @testset "Configuration Validation and Error Handling" begin + space = GridSpace((5, 5)) + model = ReinforcementLearningABM(RLTestAgent, space) + + # 1. Test errors when no RL config is set + @test_throws ErrorException Agents.get_current_training_agent_type(model) + @test_throws ErrorException Agents.get_current_training_agent(model) + @test_throws ErrorException Agents.reset_model_for_episode!(model) + + # 2. Test with a minimal valid config + minimal_config = (training_agent_types=[RLTestAgent],) + set_rl_config!(model, minimal_config) + + @test Agents.get_current_training_agent_type(model) == RLTestAgent + @test isnothing(Agents.get_current_training_agent(model)) # No agents added yet + + # Reset should work with a minimal config (no model_init_fn) + model.time[] = 5 + Agents.reset_model_for_episode!(model) + @test model.time[] == 0 + + # 3. Test error with an invalid config (empty training agent types) + invalid_config = (training_agent_types=[],) + set_rl_config!(model, invalid_config) + @test_throws ErrorException Agents.get_current_training_agent_type(model) + end + + @testset "Standard ABM Interface Compatibility" begin + space = GridSpace((5, 5)) + model = ReinforcementLearningABM(RLTestAgent, space; rng=StableRNG(42)) + config = create_test_rl_config() + set_rl_config!(model, config) + + # Test that standard ABM functions work + agent1 = add_agent!(RLTestAgent, model, 25.0, 0) + agent2 = add_agent!(RLTestAgent, model, 50.0, 0) + + @test nagents(model) == 2 + @test abmtime(model) == 0 + @test abmspace(model) isa GridSpace + @test abmscheduler(model) isa Schedulers.Randomly + @test abmrng(model) isa StableRNG + + # Test agent access + @test model[agent1.id] == agent1 + @test agent1 in allagents(model) + + # Test spatial queries + neighbors = collect(nearby_agents(agent1, model, 5)) + @test agent2 in [n for n in neighbors] + + # Test scheduling + scheduled_ids = collect(abmscheduler(model)(model)) + @test length(scheduled_ids) == 2 + @test agent1.id in scheduled_ids + @test agent2.id in scheduled_ids + end +end + +# Test for RL wrapper functionality +@testset "RL Wrapper Interface Tests" begin + + # Define test components first + @agent struct WrapperTestAgent(GridAgent{2}) + wealth::Float64 + last_action::Int = 0 + end + + function wrapper_test_observation_fn(model::ReinforcementLearningABM, agent_id::Int) + agent = model[agent_id] + # Return position and wealth as observation + return Float32[agent.pos[1], agent.pos[2], agent.wealth/100.0] + end + + function wrapper_test_reward_fn(env::ReinforcementLearningABM, agent::AbstractAgent, action::Int, + initial_model::ReinforcementLearningABM, final_model::ReinforcementLearningABM) + # Reward based on wealth increase + initial_wealth = initial_model[agent.id].wealth + final_wealth = agent.wealth + return Float32(final_wealth - initial_wealth) + end + + function 
wrapper_test_terminal_fn(env::ReinforcementLearningABM) + # Terminal if any agent has wealth > 50 or time > 10 + return any(a.wealth > 50.0 for a in allagents(env)) || abmtime(env) >= 10 + end + + function wrapper_test_agent_step_fn(agent::AbstractAgent, model::ReinforcementLearningABM, action::Int) + agent.last_action = action + + # Actions: 1=up, 2=right, 3=down, 4=left, 5=work(+wealth) + if action == 1 && agent.pos[2] > 1 + move_agent!(agent, (agent.pos[1], agent.pos[2] - 1), model) + elseif action == 2 && agent.pos[1] < abmspace(model).extent[1] + move_agent!(agent, (agent.pos[1] + 1, agent.pos[2]), model) + elseif action == 3 && agent.pos[2] < abmspace(model).extent[2] + move_agent!(agent, (agent.pos[1], agent.pos[2] + 1), model) + elseif action == 4 && agent.pos[1] > 1 + move_agent!(agent, (agent.pos[1] - 1, agent.pos[2]), model) + elseif action == 5 + agent.wealth += 5.0 + end + end + + function wrapper_test_model_init_fn() + space = GridSpace((4, 4)) + model = ReinforcementLearningABM(WrapperTestAgent, space; rng=StableRNG(123)) + add_agent!(WrapperTestAgent, model, 10.0, 0) + return model + end + + function create_wrapper_test_config() + return ( + observation_fn=wrapper_test_observation_fn, + reward_fn=wrapper_test_reward_fn, + terminal_fn=wrapper_test_terminal_fn, + agent_step_fn=wrapper_test_agent_step_fn, + action_spaces=Dict(WrapperTestAgent => (; vals=1:5)), + observation_spaces=Dict(WrapperTestAgent => (; dim=(3,))), + training_agent_types=[WrapperTestAgent], + max_steps=10, + observation_radius=2, + discount_rates=Dict(WrapperTestAgent => 0.9), + model_init_fn=wrapper_test_model_init_fn + ) + end + + @testset "Wrapper Prerequisites" begin + space = GridSpace((4, 4)) + model = ReinforcementLearningABM(WrapperTestAgent, space; rng=StableRNG(42)) + config = create_wrapper_test_config() + set_rl_config!(model, config) + + # Add test agent + agent = add_agent!(WrapperTestAgent, model, 15.0, 0) + + # Test observation function works + obs = config.observation_fn(model, agent.id) + @test obs isa Vector{Float32} + @test length(obs) == 3 + @test obs[1] == 2.0f0 # x position + @test obs[2] == 1.0f0 # y position + @test obs[3] ≈ 0.15f0 # wealth/100 + + # Test reward function works + initial_model = deepcopy(model) + wrapper_test_agent_step_fn(agent, model, 5) # work action + reward = config.reward_fn(model, agent, 5, initial_model, model) + @test reward == 5.0f0 # wealth increased by 5 + + # Test terminal function works + @test config.terminal_fn(model) == false # Not terminal yet + + agent.wealth = 60.0 + @test config.terminal_fn(model) == true # Now terminal due to high wealth + + # Reset wealth and test time terminal + agent.wealth = 20.0 + model.time[] = 15 + @test config.terminal_fn(model) == true # Terminal due to time + + # Test agent step function works + model.time[] = 0 + initial_pos = agent.pos + wrapper_test_agent_step_fn(agent, model, 2) # move right + @test agent.pos == (initial_pos[1] + 1, initial_pos[2]) + @test agent.last_action == 2 + end + + @testset "Model Reset for Training Episodes" begin + space = GridSpace((4, 4)) + model = ReinforcementLearningABM(WrapperTestAgent, space; rng=StableRNG(42)) + config = create_wrapper_test_config() + set_rl_config!(model, config) + + # Add agents and modify state + agent1 = add_agent!(WrapperTestAgent, model, 20.0, 0) + agent2 = add_agent!(WrapperTestAgent, model, 25.0, 0) + model.time[] = 5 + model.current_training_agent_id[] = 2 + + @test nagents(model) == 2 + @test abmtime(model) == 5 + @test model.current_training_agent_id[] 
== 2 + + # Reset using model_init_fn + Agents.reset_model_for_episode!(model) + + @test abmtime(model) == 0 + @test model.current_training_agent_id[] == 1 + @test nagents(model) == 1 # model_init_fn creates only 1 agent + + # Verify the new agent has expected properties + new_agent = collect(allagents(model))[1] + @test new_agent.pos == (1, 4) + @test new_agent.wealth == 10.0 + end + + @testset "Training Configuration Validation" begin + space = GridSpace((3, 3)) + model = ReinforcementLearningABM(WrapperTestAgent, space) + + # Test that wrapper functions expect proper RL config structure + config = create_wrapper_test_config() + + # Verify all required config components are present + @test haskey(config, :observation_fn) + @test haskey(config, :reward_fn) + @test haskey(config, :terminal_fn) + @test haskey(config, :agent_step_fn) + @test haskey(config, :action_spaces) + @test haskey(config, :observation_spaces) + @test haskey(config, :training_agent_types) + @test haskey(config, :max_steps) + @test haskey(config, :observation_radius) + @test haskey(config, :discount_rates) + + # Test setting the config + set_rl_config!(model, config) + @test model.rl_config[] == config + + # Verify action space structure + action_space = config.action_spaces[WrapperTestAgent] + @test haskey(action_space, :vals) + @test action_space.vals == 1:5 + + # Verify observation space structure + obs_space = config.observation_spaces[WrapperTestAgent] + @test haskey(obs_space, :dim) + @test obs_space.dim == (3,) + end + + @testset "Multi-Step Episode Simulation" begin + space = GridSpace((4, 4)) + model = ReinforcementLearningABM(WrapperTestAgent, space; rng=StableRNG(42)) + config = create_wrapper_test_config() + set_rl_config!(model, config) + + # Add an agent + agent = add_agent!(WrapperTestAgent, model, 10.0, 0) + + # Simulate several actions that a wrapper would perform + actions = [5, 2, 3, 5, 1] # work, right, down, work, up + + for action in actions + # Record initial state + initial_model = deepcopy(model) + initial_obs = config.observation_fn(model, agent.id) + + # Execute action + config.agent_step_fn(agent, model, action) + + # Calculate reward + reward = config.reward_fn(model, agent, action, initial_model, model) + + # Get new observation + new_obs = config.observation_fn(model, agent.id) + + # Check terminal condition + is_terminal = config.terminal_fn(model) + + # Verify reasonable behavior + @test agent.last_action == action + @test reward isa Float32 + @test new_obs isa Vector{Float32} + @test length(new_obs) == 3 + + if action == 5 # work action + @test reward > 0.0 # Should get positive reward + end + + if is_terminal + break + end + + # Advance time (simulating environment step) + model.time[] += 1 + end + + # Verify agent has accumulated wealth from work actions + @test agent.wealth > 10.0 # Started with 10, should have increased + @test agent.wealth == 20.0 # Two work actions = +10 wealth + end +end \ No newline at end of file diff --git a/test/rl_extension_tests.jl b/test/rl_extension_tests.jl new file mode 100644 index 0000000000..33e3b01228 --- /dev/null +++ b/test/rl_extension_tests.jl @@ -0,0 +1,553 @@ +using POMDPs, Crux, Flux + +# Test agents for RL extension testing +@agent struct RLExtensionTestAgent(GridAgent{2}) + energy::Float64 + wealth::Float64 + last_action::Int = 0 +end + +@agent struct RLExtensionTestPredator(GridAgent{2}) + energy::Float64 + hunt_success::Int = 0 +end + +@agent struct RLExtensionTestPrey(GridAgent{2}) + energy::Float64 + escape_count::Int = 0 +end + +# Helper 
functions for RL extension testing +function create_simple_rl_model(; n_agents=5, dims=(8, 8), seed=42) + rng = StableRNG(seed) + space = GridSpace(dims; periodic=true) + + model = ReinforcementLearningABM(RLExtensionTestAgent, space; rng=rng) + + for _ in 1:n_agents + add_agent!(RLExtensionTestAgent, model, rand(rng) * 50.0, rand(rng) * 100.0, 0) + end + + return model +end + +function create_multi_agent_rl_model(; n_predators=3, n_prey=7, dims=(10, 10), seed=42) + rng = StableRNG(seed) + space = GridSpace(dims; periodic=true) + + model = ReinforcementLearningABM(Union{RLExtensionTestPredator,RLExtensionTestPrey}, space; rng=rng) + + for _ in 1:n_predators + add_agent!(RLExtensionTestPredator, model, rand(rng) * 30.0, 0) + end + + for _ in 1:n_prey + add_agent!(RLExtensionTestPrey, model, rand(rng) * 20.0, 0) + end + + return model +end + + +function simple_observation_fn(model, agent_id) + observation_radius = model.rl_config[][:observation_radius] + agent = model[agent_id] + # Simple observation: agent position, energy, wealth, and neighbor count + neighbor_count = length([a for a in nearby_agents(agent, model, observation_radius)]) + return Float32[agent.pos[1], agent.pos[2], agent.energy/50.0, agent.wealth/100.0, neighbor_count/10.0] +end + +function simple_reward_fn(env, agent, action, initial_model, final_model) + # Death penalty + if agent.id ∉ [a.id for a in allagents(final_model)] + return -100.0f0 + end + + # Small positive reward for survival plus energy-based bonus + reward = 1.0f0 + agent.energy / 100.0f0 + + # Bonus for wealth increase + if haskey(initial_model.agents, agent.id) + initial_wealth = initial_model[agent.id].wealth + if agent.wealth > initial_wealth + reward += (agent.wealth - initial_wealth) / 10.0f0 + end + end + + return Float32(reward) +end + +function simple_terminal_fn(env) + # Terminal if less than 2 agents remain or time exceeds limit + return length(allagents(env)) < 2 || abmtime(env) >= 50 +end + +function simple_agent_step_fn(agent, model, action) + agent.last_action = action + + # Actions: 1=stay, 2=north, 3=south, 4=east, 5=west + current_x, current_y = agent.pos + width, height = getfield(model, :space).extent + + dx, dy = 0, 0 + if action == 2 # North + dy = 1 + elseif action == 3 # South + dy = -1 + elseif action == 4 # East + dx = 1 + elseif action == 5 # West + dx = -1 + end + + # Apply periodic boundary movement + if action != 1 + new_x = mod1(current_x + dx, width) + new_y = mod1(current_y + dy, height) + move_agent!(agent, (new_x, new_y), model) + end + + # Energy and wealth updates + agent.energy = max(0.0, agent.energy - 0.5) # Movement cost + if action == 1 # Stay action gives small wealth bonus + agent.wealth += 1.0 + agent.energy += 0.5 + end + + # Remove agent if energy depleted + if agent.energy <= 0 + remove_agent!(agent, model) + end +end + +function multi_agent_observation_fn(model, agent_id) + observation_radius = model.rl_config[][:observation_radius] + agent = model[agent_id] + + # Different observations for different agent types + if agent isa RLExtensionTestPredator + # Predators see prey positions and energy + prey_nearby = length([a for a in nearby_agents(agent, model, observation_radius) + if a isa RLExtensionTestPrey]) + return Float32[agent.pos[1], agent.pos[2], agent.energy/30.0, Float32(prey_nearby)] + else # Prey + # Prey see predator positions and escape routes + predators_nearby = length([a for a in nearby_agents(agent, model, observation_radius) + if a isa RLExtensionTestPredator]) + return Float32[agent.pos[1], 
agent.pos[2], agent.energy/20.0, Float32(predators_nearby)] + end +end + +function multi_agent_reward_fn(env, agent, action, initial_model, final_model) + # Death penalty + if agent.id ∉ [a.id for a in allagents(final_model)] + return -50.0f0 + end + + if agent isa RLExtensionTestPredator + # Predator rewards: hunt success + reward = 0.5f0 # Base survival + if haskey(initial_model.agents, agent.id) + if agent.hunt_success > initial_model[agent.id].hunt_success + reward += 10.0f0 # Hunt success bonus + end + end + else # Prey + # Prey rewards: survival and escape + reward = 1.0f0 # Base survival + if haskey(initial_model.agents, agent.id) + if agent.escape_count > initial_model[agent.id].escape_count + reward += 5.0f0 # Escape bonus + end + end + end + + return reward +end + +function multi_agent_agent_step_fn(agent, model, action) + # Basic movement (same as simple_agent_step_fn) + current_x, current_y = agent.pos + width, height = getfield(model, :space).extent + + dx, dy = 0, 0 + if action == 2 # North + dy = 1 + elseif action == 3 # South + dy = -1 + elseif action == 4 # East + dx = 1 + elseif action == 5 # West + dx = -1 + end + + if action != 1 + new_x = mod1(current_x + dx, width) + new_y = mod1(current_y + dy, height) + move_agent!(agent, (new_x, new_y), model) + end + + # Type-specific behavior + if agent isa RLExtensionTestPredator + agent.energy = max(0.0, agent.energy - 1.0) # Higher energy cost + + # Check for prey to hunt + prey_here = [a for a in agents_in_position(agent.pos, model) + if a isa RLExtensionTestPrey] + if !isempty(prey_here) + prey = prey_here[1] + remove_agent!(prey, model) + agent.energy += 15.0 # Energy gain from hunt + agent.hunt_success += 1 + end + else # Prey + agent.energy = max(0.0, agent.energy - 0.5) # Lower energy cost + + # Check if escaping from predator + predators_nearby = [a for a in nearby_agents(agent, model, 1) + if a isa RLExtensionTestPredator] + if !isempty(predators_nearby) && action != 1 # Moving away counts as escape attempt + agent.escape_count += 1 + end + + # Energy recovery when staying + if action == 1 + agent.energy += 1.0 + end + end + + # Remove agent if energy depleted + if agent.energy <= 0 + remove_agent!(agent, model) + end +end + +@testset "RL Extension Training Functions" begin + + @testset "Setup RL Training" begin + model = create_simple_rl_model() + + # Set up RL configuration + rl_config = ( + model_init_fn=() -> create_simple_rl_model(), + observation_fn=simple_observation_fn, + reward_fn=simple_reward_fn, + terminal_fn=simple_terminal_fn, + agent_step_fn=simple_agent_step_fn, + action_spaces=Dict( + RLExtensionTestAgent => Crux.DiscreteSpace(5) + ), + observation_spaces=Dict( + RLExtensionTestAgent => Crux.ContinuousSpace((5,), Float32) + ), + training_agent_types=[RLExtensionTestAgent], + max_steps=20, + observation_radius=2 + ) + + set_rl_config!(model, rl_config) + + # Test setup_rl_training function from extension + env, solver = Agents.setup_rl_training(model, RLExtensionTestAgent; training_steps=1000) + + @test env isa POMDPs.POMDP + @test solver isa OnPolicySolver + + # Test that solver has correct configuration + @test solver.N == 1000 # Training steps + @test solver.ΔN == 200 # Default batch size + @test solver.agent isa PolicyParams + end + + @testset "Custom Network Creation" begin + model = create_simple_rl_model() + + rl_config = ( + model_init_fn=() -> create_simple_rl_model(), + observation_fn=simple_observation_fn, + reward_fn=simple_reward_fn, + terminal_fn=simple_terminal_fn, + 
agent_step_fn=simple_agent_step_fn, + action_spaces=Dict( + RLExtensionTestAgent => Crux.DiscreteSpace(5) + ), + observation_spaces=Dict( + RLExtensionTestAgent => Crux.ContinuousSpace((5,), Float32) + ), + training_agent_types=[RLExtensionTestAgent], + max_steps=20, + observation_radius=2 + ) + + set_rl_config!(model, rl_config) + + # Test custom network creation functions + value_net_fn = Agents.create_value_network((5,), [32, 16]) + policy_net_fn = Agents.create_policy_network((5,), 5, Crux.DiscreteSpace(5).vals, [32, 16]) + + @test value_net_fn isa Function + @test policy_net_fn isa Function + + # Create networks and test structure + value_net = value_net_fn() + policy_net = policy_net_fn() + + @test value_net isa ContinuousNetwork + @test policy_net isa DiscreteNetwork + + # Test with custom networks + env, solver = Agents.setup_rl_training(model, RLExtensionTestAgent; + training_steps=500, + value_network=value_net_fn, + policy_network=policy_net_fn + ) + + @test solver.agent.π.A isa DiscreteNetwork + @test solver.agent.π.C isa ContinuousNetwork + end + + @testset "Sequential Training" begin + model = create_multi_agent_rl_model() + + # Set up multi-agent RL configuration + rl_config = ( + model_init_fn=() -> create_multi_agent_rl_model(), + observation_fn=multi_agent_observation_fn, + reward_fn=multi_agent_reward_fn, + terminal_fn=(env) -> length(allagents(env)) < 3 || abmtime(env) >= 30, + agent_step_fn=multi_agent_agent_step_fn, + action_spaces=Dict( + RLExtensionTestPredator => Crux.DiscreteSpace(5), + RLExtensionTestPrey => Crux.DiscreteSpace(5) + ), + observation_spaces=Dict( + RLExtensionTestPredator => Crux.ContinuousSpace((4,), Float32), + RLExtensionTestPrey => Crux.ContinuousSpace((4,), Float32) + ), + training_agent_types=[RLExtensionTestPredator, RLExtensionTestPrey], + max_steps=30, + observation_radius=3 + ) + + set_rl_config!(model, rl_config) + + # Test sequential training with small parameters for speed + policies, solvers = Agents.train_agent_sequential(model, + [RLExtensionTestPredator, RLExtensionTestPrey]; + training_steps=10, + solver_params=Dict(:ΔN => 5) + ) + + @test length(policies) == 2 + @test length(solvers) == 2 + @test haskey(policies, RLExtensionTestPredator) + @test haskey(policies, RLExtensionTestPrey) + @test haskey(solvers, RLExtensionTestPredator) + @test haskey(solvers, RLExtensionTestPrey) + + # Test that policies are different objects + @test policies[RLExtensionTestPredator] !== policies[RLExtensionTestPrey] + @test solvers[RLExtensionTestPredator] !== solvers[RLExtensionTestPrey] + + # Test that trained policies are stored in model + @test haskey(model.trained_policies, RLExtensionTestPredator) + end + + @testset "Simultaneous Training" begin + model = create_multi_agent_rl_model(n_predators=2, n_prey=3) + + rl_config = ( + model_init_fn=() -> create_multi_agent_rl_model(n_predators=2, n_prey=3), + observation_fn=multi_agent_observation_fn, + reward_fn=multi_agent_reward_fn, + terminal_fn=(env) -> length(allagents(env)) < 2 || abmtime(env) >= 25, + agent_step_fn=multi_agent_agent_step_fn, + action_spaces=Dict( + RLExtensionTestPredator => Crux.DiscreteSpace(5), + RLExtensionTestPrey => Crux.DiscreteSpace(5) + ), + observation_spaces=Dict( + RLExtensionTestPredator => Crux.ContinuousSpace((4,), Float32), + RLExtensionTestPrey => Crux.ContinuousSpace((4,), Float32) + ), + training_agent_types=[RLExtensionTestPredator, RLExtensionTestPrey], + max_steps=25, + observation_radius=2 + ) + + set_rl_config!(model, rl_config) + + # Test simultaneous 
training with small parameters + policies, solvers = Agents.train_agent_simultaneous(model, + [RLExtensionTestPredator, RLExtensionTestPrey]; + n_iterations=2, + batch_size=10, + solver_params=Dict(:ΔN => 5) + ) + + @test length(policies) == 2 + @test length(solvers) == 2 + @test haskey(policies, RLExtensionTestPredator) + @test haskey(policies, RLExtensionTestPrey) + + # Verify policies are different + @test policies[RLExtensionTestPredator] !== policies[RLExtensionTestPrey] + + # Test that model has been updated with trained policies + @test haskey(model.trained_policies, RLExtensionTestPredator) + @test haskey(model.trained_policies, RLExtensionTestPrey) + end + + @testset "Train Model Function Integration" begin + model = create_simple_rl_model(n_agents=3) + + rl_config = ( + model_init_fn=() -> create_simple_rl_model(n_agents=3), + observation_fn=simple_observation_fn, + reward_fn=simple_reward_fn, + terminal_fn=simple_terminal_fn, + agent_step_fn=simple_agent_step_fn, + action_spaces=Dict( + RLExtensionTestAgent => Crux.DiscreteSpace(5) + ), + observation_spaces=Dict( + RLExtensionTestAgent => Crux.ContinuousSpace((5,), Float32) + ), + training_agent_types=[RLExtensionTestAgent], + max_steps=15, + observation_radius=1 + ) + + set_rl_config!(model, rl_config) + + # Test single agent training via train_model! + train_model!(model, RLExtensionTestAgent; + training_steps=10, + solver_params=Dict(:ΔN => 5) + ) + + @test haskey(model.trained_policies, RLExtensionTestAgent) + @test model.trained_policies[RLExtensionTestAgent] isa Crux.ActorCritic + + # Test with different solver types + model2 = create_simple_rl_model(n_agents=3) + set_rl_config!(model2, rl_config) + + train_model!(model2, RLExtensionTestAgent; + solver_types=Dict(RLExtensionTestAgent => :A2C), + training_steps=10, + solver_params=Dict(:ΔN => 5) + ) + + @test haskey(model2.trained_policies, RLExtensionTestAgent) + @test model2.trained_policies[RLExtensionTestAgent] isa Crux.ActorCritic + end + + @testset "Solver Parameter Processing" begin + # Test process_solver_params function + global_params = Dict(:ΔN => 100, :log => (period=500,)) + + # Test with single agent type + processed = Agents.process_solver_params(global_params, RLExtensionTestAgent) + @test processed[:ΔN] == 100 + @test processed[:log] == (period=500,) + + # Test with agent-specific parameters + agent_specific_params = Dict( + RLExtensionTestPredator => Dict(:ΔN => 50, :lr => 0.001), + RLExtensionTestPrey => Dict(:ΔN => 75, :lr => 0.002), + ) + + pred_params = Agents.process_solver_params(agent_specific_params, RLExtensionTestPredator) + @test pred_params[:ΔN] == 50 + @test pred_params[:lr] == 0.001 + + prey_params = Agents.process_solver_params(agent_specific_params, RLExtensionTestPrey) + @test prey_params[:ΔN] == 75 + @test prey_params[:lr] == 0.002 + end + + @testset "Policy Copying and Management" begin + # Create source model with training + source_model = create_simple_rl_model() + + rl_config = ( + model_init_fn=() -> create_simple_rl_model(), + observation_fn=simple_observation_fn, + reward_fn=simple_reward_fn, + terminal_fn=simple_terminal_fn, + agent_step_fn=simple_agent_step_fn, + action_spaces=Dict( + RLExtensionTestAgent => Crux.DiscreteSpace(5) + ), + observation_spaces=Dict( + RLExtensionTestAgent => Crux.ContinuousSpace((5,), Float32) + ), + training_agent_types=[RLExtensionTestAgent], + max_steps=10, + observation_radius=1 + ) + + set_rl_config!(source_model, rl_config) + + # Train source model + train_model!(source_model, 
RLExtensionTestAgent; training_steps=10, + solver_params=Dict(:ΔN => 5)) + + # Create target model and copy policies + target_model = create_simple_rl_model() + set_rl_config!(target_model, rl_config) + + @test isempty(target_model.trained_policies) + + copy_trained_policies!(target_model, source_model) + + @test haskey(target_model.trained_policies, RLExtensionTestAgent) + @test target_model.trained_policies[RLExtensionTestAgent] === + source_model.trained_policies[RLExtensionTestAgent] + + # Test get_trained_policies function + policies = get_trained_policies(target_model) + @test haskey(policies, RLExtensionTestAgent) + @test policies[RLExtensionTestAgent] === source_model.trained_policies[RLExtensionTestAgent] + end + + + @testset "Different Solver Types" begin + model = create_simple_rl_model() + + rl_config = ( + model_init_fn=() -> create_simple_rl_model(), + observation_fn=simple_observation_fn, + reward_fn=simple_reward_fn, + terminal_fn=simple_terminal_fn, + agent_step_fn=simple_agent_step_fn, + action_spaces=Dict( + RLExtensionTestAgent => Crux.DiscreteSpace(5) + ), + observation_spaces=Dict( + RLExtensionTestAgent => Crux.ContinuousSpace((5,), Float32) + ), + training_agent_types=[RLExtensionTestAgent], + max_steps=10, + observation_radius=1 + ) + + set_rl_config!(model, rl_config) + + # Test DQN solver + env, solver = Agents.setup_rl_training(model, RLExtensionTestAgent; + solver_type=:DQN, training_steps=100) + @test solver isa OffPolicySolver + + # Test A2C solver + env, solver = Agents.setup_rl_training(model, RLExtensionTestAgent; + solver_type=:A2C, training_steps=100) + @test solver isa OnPolicySolver + + # Test PPO solver (default) + env, solver = Agents.setup_rl_training(model, RLExtensionTestAgent; + solver_type=:PPO, training_steps=100) + @test solver isa OnPolicySolver + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 9b299822cf..8c3dfb4d73 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -170,4 +170,5 @@ end include("jld2_tests.jl") include("visualization_tests.jl") include("new_space_tests.jl") + include("reinforcement_learning_tests.jl") end
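For readers of this diff, below is a minimal end-to-end sketch of how the new interface fits together, assembled only from the constructors, exported functions, and config fields exercised by the tests above (ReinforcementLearningABM, set_rl_config!, train_model!, get_trained_policies, and the Crux space wrappers); the agent type, helper names, and hyperparameter values are illustrative assumptions, not part of the changeset.

using Agents, Crux, Flux, POMDPs, StableRNGs

# Illustrative agent with the same field pattern as the test agents above.
@agent struct SketchAgent(GridAgent{2})
    wealth::Float64
    last_action::Int = 0
end

# Observation: grid position plus normalised wealth (length must match observation_spaces).
obs_fn(model, id) = Float32[model[id].pos[1], model[id].pos[2], model[id].wealth / 100.0]

# Reward: wealth gained between the pre- and post-step model snapshots.
reward_fn(env, agent, action, m0, m1) = Float32(agent.wealth - m0[agent.id].wealth)

# Episode ends when an agent gets rich enough or the clock runs out.
terminal_fn(env) = any(a.wealth > 50.0 for a in allagents(env)) || abmtime(env) >= 10

# Only action 5 ("work") does anything in this sketch; a real model would also move the agent.
function step_fn(agent, model, action)
    agent.last_action = action
    action == 5 && (agent.wealth += 5.0)
end

# Fresh model for every training episode (used as model_init_fn).
function init_model()
    model = ReinforcementLearningABM(SketchAgent, GridSpace((4, 4)); rng=StableRNG(1))
    add_agent!(SketchAgent, model, 10.0, 0)
    return model
end

model = init_model()
set_rl_config!(model, (
    observation_fn=obs_fn,
    reward_fn=reward_fn,
    terminal_fn=terminal_fn,
    agent_step_fn=step_fn,
    action_spaces=Dict(SketchAgent => Crux.DiscreteSpace(5)),
    observation_spaces=Dict(SketchAgent => Crux.ContinuousSpace((3,), Float32)),
    training_agent_types=[SketchAgent],
    max_steps=10,
    observation_radius=1,
    model_init_fn=init_model,
))
train_model!(model, SketchAgent; training_steps=1_000)   # PPO by default, per the solver-type tests
policies = get_trained_policies(model)

For models with several trainable agent types, the same kind of config can instead be handed to Agents.train_agent_sequential or Agents.train_agent_simultaneous, as the "Sequential Training" and "Simultaneous Training" testsets do.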