diff --git a/examples/configs/grpo_adk_gemma.yaml b/examples/configs/grpo_adk_gemma.yaml new file mode 100644 index 0000000000..15b2faae08 --- /dev/null +++ b/examples/configs/grpo_adk_gemma.yaml @@ -0,0 +1,51 @@ +# GRPO configuration for unique numbers environment +defaults: "grpo_math_8B.yaml" + +grpo: + num_prompts_per_step: 32 + num_generations_per_prompt: 16 + max_rollout_turns: 20 + max_num_steps: 100 + val_at_start: false + +data: + add_system_prompt: false + shuffle: false + +checkpointing: + enabled: false + checkpoint_dir: "results/grpo-adk" + metric_name: "val_reward" + higher_is_better: true + keep_top_k: 3 + save_period: 10 + +env: + unique_numbers: + cfg: + max_turns: 15 + min_length: 5 + max_length: 10 + max_integer: 15 + +logger: + wandb_enabled: True + wandb: + project: "grpo-simulated-adk" + name: "gemma-4b-__NOW__" + + +policy: + train_global_batch_size: 512 + logprob_batch_size: 1 + model_name: google/gemma-3-4b-it + dynamic_batching: + enabled: True + sequence_packing: + enabled: False + tokenizer: + name: google/gemma-3-4b-it + chat_template: "{%- if add_bos_token|default(false) %}{{ bos_token }}{% endif %}{% for message in messages %}{% set role = 'model' if message['role'] == 'assistant' else message['role'] %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{ 'model\n' }}{% endif %}" + +cluster: + gpus_per_node: 8 diff --git a/examples/configs/grpo_adk_llama8b.yaml b/examples/configs/grpo_adk_llama8b.yaml new file mode 100644 index 0000000000..89198d5251 --- /dev/null +++ b/examples/configs/grpo_adk_llama8b.yaml @@ -0,0 +1,45 @@ +# GRPO configuration for unique numbers environment +defaults: "grpo_math_8B.yaml" + +grpo: + num_prompts_per_step: 32 + num_generations_per_prompt: 16 + max_rollout_turns: 20 + max_num_steps: 100 + val_at_start: false + +data: + add_system_prompt: false + shuffle: false + +checkpointing: + enabled: false + checkpoint_dir: "results/grpo-adk" + metric_name: "val_reward" + higher_is_better: true + keep_top_k: 3 + save_period: 10 + +env: + unique_numbers: + cfg: + max_turns: 15 + min_length: 5 + max_length: 10 + max_integer: 15 + +logger: + wandb_enabled: True + wandb: + project: "grpo-simulated-adk" + name: "llama-8b-__NOW__" + +policy: + train_global_batch_size: 512 + dynamic_batching: + enabled: False + tokenizer: + chat_template: '{%- if add_bos_token|default(false) %}{{ bos_token }}{% endif %}{% for message in messages %}{{ "<|start_header_id|>" + message.role + "<|end_header_id|>\n\n" + message.content | trim + "<|eot_id|>" }}{% endfor %}{% if add_generation_prompt %}{{ "<|start_header_id|>assistant<|end_header_id|>\n\n" }}{% endif %}' + +cluster: + gpus_per_node: 8 diff --git a/examples/run_grpo_unique_numbers_w_adk.py b/examples/run_grpo_unique_numbers_w_adk.py new file mode 100644 index 0000000000..2b7dc0d2a4 --- /dev/null +++ b/examples/run_grpo_unique_numbers_w_adk.py @@ -0,0 +1,273 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run GRPO with the Unique Numbers Simulator using ADK. + +This script sets up and executes the Group Relative Policy Optimization (GRPO) algorithm +in a multi-turn conversational environment powered by the ADK framework. + +### Task Overview +The objective is to train an agent to guess the number of unique integers in a list generated by a simulated user. +The interaction is structured as a turn-based dialogue: +- The user generates a list of integers. +- The agent queries specific positions in the list (by index). +- The user replies with the value at that index (if available). +- The agent continues the interaction until it makes a final guess at the number of unique integers. + +### Environment Details +The environment is a simulated user that: +- Randomly generates a list of integers at setup. +- Responds to the agent's queries using an LLM via the ADK endpoint. +- Optionally evaluates the agent's final guess using an LLM-based grader (included for extensibility, though not essential for this task). + +### Example Usage + uv run --extra adk --extra automodel python examples/run_grpo_unique_numbers_w_adk.py + +### Requirements +- A working ADK environment with access to a compatible LLM endpoint. + For the default Gemini endpoint, the following environment variables must be set: + - `GOOGLE_GENAI_USE_VERTEXAI=1` + - `GOOGLE_CLOUD_PROJECT="your-project-id"` + - `GOOGLE_CLOUD_LOCATION="your-location"` + +- A properly configured GRPO YAML file. + By default, the script uses: + `examples/configs/grpo_adk_llama8b.yaml` +""" + +import argparse +import itertools +import os +import pprint +import random +from datetime import datetime, timedelta +from typing import Iterator + +from omegaconf import OmegaConf +from torch.utils.data import IterableDataset +from transformers import AutoTokenizer + +from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup +from nemo_rl.algorithms.utils import get_tokenizer +from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType +from nemo_rl.distributed.ray_actor_environment_registry import ( + get_actor_python_env, +) +from nemo_rl.distributed.virtual_cluster import init_ray +from nemo_rl.environments.simulated_user.prompt import starting_user_prompt +from nemo_rl.environments.simulated_user.unique_numbers import ( + UniqueNumbersEnv, + UniqueNumbersMetadata, +) +from nemo_rl.models.generation import configure_generation_config +from nemo_rl.utils.config import load_config, parse_hydra_overrides +from nemo_rl.utils.logger import get_next_experiment_dir + +OmegaConf.register_new_resolver("mul", lambda a, b: a * b) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Run GRPO with unique numbers simulator" + ) + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + args, overrides = parser.parse_known_args() + return args, overrides + + +def generate_datum( + tokenizer: AutoTokenizer, + env_cfg: dict, + task_name: str, + idx: int, + add_system_prompt: bool, +) -> DatumSpec: + # please check the specific chat_template in the yaml file + formatted_prompt = tokenizer.apply_chat_template( + [{"role": "user", "content": starting_user_prompt}], + tokenize=False, + # add_system_prompt=add_system_prompt, + add_bos_token=True, + add_generation_prompt=True, + add_special_tokens=False, + ) + token_ids = tokenizer( + formatted_prompt, return_tensors="pt", add_special_tokens=False + )["input_ids"][0] + + def _generate_numbers( + min_length, max_length, max_integer, default_max_turns + ) -> UniqueNumbersMetadata: + length = random.randint(min_length, max_length) + numbers = [random.randint(0, max_integer) for _ in range(length)] + return UniqueNumbersMetadata( + numbers=numbers, + unique_count=len(set(numbers)), + turn=0, + max_turns=default_max_turns, + ) + + metadata = _generate_numbers( + min_length=env_cfg["cfg"]["min_length"], + max_length=env_cfg["cfg"]["max_length"], + max_integer=env_cfg["cfg"]["max_integer"], + default_max_turns=env_cfg["cfg"]["max_turns"], + ) + + message_log: LLMMessageLogType = [ + {"role": "user", "content": formatted_prompt, "token_ids": token_ids} + ] + return { + "message_log": message_log, + "length": len(token_ids), + "extra_env_info": metadata, + "loss_multiplier": 1.0, + "idx": idx, + "task_name": task_name, + } + + +class IterableNumbersDataset(IterableDataset): + def __init__(self, tokenizer, env_cfg, task_name, add_system_prompt, length): + super().__init__() + self.tokenizer = tokenizer + self.env_cfg = env_cfg + self.task_name = task_name + self.add_system_prompt = add_system_prompt + self.length = length + + def __iter__(self) -> Iterator[DatumSpec]: + for i in itertools.count(): + yield generate_datum( + tokenizer=self.tokenizer, + env_cfg=self.env_cfg, + task_name=self.task_name, + idx=i, + add_system_prompt=self.add_system_prompt, + ) + + def __len__(self): + return self.length + + +def setup_data(tokenizer, env_cfg, task_name, length, val_length, add_system_prompt): + env_config = env_cfg[task_name] + env = UniqueNumbersEnv.options( # type: ignore # it's wrapped with ray.remote + num_gpus=0, + runtime_env={ + "py_executable": get_actor_python_env( + "nemo_rl.environments.simulated_user.unique_numbers.UniqueNumbersEnv" + ), + "env_vars": dict(os.environ), # Pass thru all user environment variables + }, + ).remote(cfg=dict(env_config["cfg"])) + + task_to_env = {task_name: env} + + train_ds = IterableNumbersDataset( + tokenizer=tokenizer, + env_cfg=env_config, + task_name=task_name, + add_system_prompt=add_system_prompt, + length=length, + ) + val_ds = IterableNumbersDataset( + tokenizer=tokenizer, + env_cfg=env_config, + task_name=task_name, + add_system_prompt=add_system_prompt, + length=val_length, + ) + val_task_to_env = task_to_env + return train_ds, val_ds, task_to_env, val_task_to_env + + +def main(): + args, overrides = parse_args() + if not args.config: + args.config = os.path.join( + os.path.dirname(__file__), "configs", "grpo_adk_llama8b.yaml" + ) + config = load_config(args.config) + if overrides: + config = parse_hydra_overrides(config, overrides) + config: MasterConfig = OmegaConf.to_container(config, resolve=True) + + now_pst = datetime.utcnow() + timedelta(hours=-7) + config["logger"]["wandb"]["name"] = config["logger"]["wandb"]["name"].replace( + "__NOW__", now_pst.strftime("%m/%d-%H:%M") + ) + + config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) + if config["checkpointing"]["enabled"]: + print( + f"\U0001f4ca Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}" + ) + + pprint.pprint(config) + + init_ray() + + tokenizer = get_tokenizer(config["policy"]["tokenizer"]) + config["policy"]["generation"] = configure_generation_config( + config["policy"]["generation"], tokenizer + ) + + ds_length = ( + config["grpo"]["num_prompts_per_step"] + * config["grpo"]["num_generations_per_prompt"] + * config["grpo"]["max_num_steps"] + ) + dataset, val_dataset, task_to_env, val_task_to_env = setup_data( + tokenizer=tokenizer, + env_cfg=config["env"], + task_name="unique_numbers", + length=ds_length, + val_length=config["grpo"]["max_val_samples"], + add_system_prompt=config["data"]["add_system_prompt"], + ) + + ( + policy, + policy_generation, + cluster, + dataloader, + val_dataloader, + loss_fn, + logger, + checkpointer, + grpo_state, + master_config, + ) = setup(config, tokenizer, dataset, val_dataset) + + grpo_train( + policy, + policy_generation, + dataloader, + val_dataloader, + tokenizer, + loss_fn, + task_to_env, + val_task_to_env, + logger, + checkpointer, + grpo_state, + master_config, + ) + + +if __name__ == "__main__": + main() diff --git a/nemo_rl/distributed/ray_actor_environment_registry.py b/nemo_rl/distributed/ray_actor_environment_registry.py index 6a3529d4a1..95c742e5a4 100644 --- a/nemo_rl/distributed/ray_actor_environment_registry.py +++ b/nemo_rl/distributed/ray_actor_environment_registry.py @@ -37,6 +37,7 @@ "nemo_rl.environments.code_environment.CodeEnvironment": PY_EXECUTABLES.SYSTEM, "nemo_rl.environments.reward_model_environment.RewardModelEnvironment": PY_EXECUTABLES.SYSTEM, "nemo_rl.environments.games.sliding_puzzle.SlidingPuzzleEnv": PY_EXECUTABLES.SYSTEM, + "nemo_rl.environments.simulated_user.unique_numbers.UniqueNumbersEnv": PY_EXECUTABLES.ADK, # AsyncTrajectoryCollector needs vLLM environment to handle exceptions from VllmGenerationWorker "nemo_rl.algorithms.async_utils.AsyncTrajectoryCollector": PY_EXECUTABLES.VLLM, # ReplayBuffer needs vLLM environment to handle trajectory data from VllmGenerationWorker diff --git a/nemo_rl/distributed/virtual_cluster.py b/nemo_rl/distributed/virtual_cluster.py index c28befb541..e4624eb290 100644 --- a/nemo_rl/distributed/virtual_cluster.py +++ b/nemo_rl/distributed/virtual_cluster.py @@ -58,6 +58,8 @@ class PY_EXECUTABLES: # Use Penguin dependencies PENGUIN = "uv run --locked --extra penguin" + ADK = "uv run --locked --extra adk" + @ray.remote # pragma: no cover def _get_node_ip_and_free_port() -> tuple[str, int]: diff --git a/nemo_rl/environments/interfaces.py b/nemo_rl/environments/interfaces.py index b869c32df7..47c7405051 100644 --- a/nemo_rl/environments/interfaces.py +++ b/nemo_rl/environments/interfaces.py @@ -46,7 +46,7 @@ class EnvironmentReturn(NamedTuple, Generic[MetadataT]): next_stop_strings: list[list[str] | None] | list[None] rewards: Tensor terminateds: Tensor - answers: list[str | None] | None + answers: list[str | None] | None = None class EnvironmentInterface(abc.ABC, Generic[MetadataT]): diff --git a/nemo_rl/environments/simulated_user/__init__.py b/nemo_rl/environments/simulated_user/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_rl/environments/simulated_user/adk_utils.py b/nemo_rl/environments/simulated_user/adk_utils.py new file mode 100644 index 0000000000..82ac94d1bc --- /dev/null +++ b/nemo_rl/environments/simulated_user/adk_utils.py @@ -0,0 +1,206 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +import random + +# Initialize logging +logging.basicConfig( + format="[%(asctime)s] [%(levelname)s] %(message)s", + level=logging.WARNING, +) +logger = logging.getLogger(__name__) + + +# Define the agents +def create_agent( + instruction: str | None = None, + name: str = "simulated_user", + model: str = "gemini-2.0-flash", +): + from google.adk.agents import Agent + from google.genai import types + + return Agent( + model=model, + name=name, + description="Agent", + instruction=instruction + or "You are a helpful assistant that help people answer questions.", + generate_content_config=types.GenerateContentConfig( + safety_settings=[ + types.SafetySetting( + category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=types.HarmBlockThreshold.OFF, + ), + ] + ), + ) + + +def get_session_from_runner(runner, user_id: str): + app_session_map = runner.session_service.sessions + assert len(app_session_map) == 1, "Expected exactly one app in session_service" + user_sessions_map = next(iter(app_session_map.values())) + sessions = user_sessions_map[user_id] + assert len(sessions) == 1, "Expected exactly one user in app session" + return next(iter(sessions.values())) + + +def get_agent_instruction_from_runner(runner): + return runner.agent.instruction + + +def extract_conversation_history(runner, user_id: str, silence: bool = True): + session = get_session_from_runner(runner, user_id) + instruction = get_agent_instruction_from_runner(runner) + convo = [{"role": "instruction", "content": instruction}] + for event in session.events: + if event.content.parts and event.content.parts[0].text: + convo.append({"role": event.author, "content": event.content.parts[0].text}) + if not silence: + logger.info(f"[{convo[-1]['role']}]: {convo[-1]['content']}") + return session.id, convo + + +async def run_prompt_async( + runner, + user_id: str, + new_message: str, + silence: bool = True, + max_retries: int = 3, + initial_delay: float = 2, +) -> str: + from google.genai import types + from google.genai.errors import ServerError + + new_message = new_message.strip() + content = types.Content(role="user", parts=[types.Part.from_text(text=new_message)]) + if not silence: + logger.info(f"** [User]->|||{new_message}|||") + + session = get_session_from_runner(runner, user_id) + + retries = 0 + delay = initial_delay + while retries < max_retries: + try: + async for event in runner.run_async( + user_id=session.user_id, + session_id=session.id, + new_message=content, + ): + if event.content.parts and event.content.parts[0].text: + if not silence: + logger.info( + f"** [{event.author}]->|||{event.content.parts[0].text.strip()}|||" + ) + return event.content.parts[0].text.strip() + else: + return "" + except ServerError as e: + retries += 1 + delay_with_jitter = delay + (random.random() * 2 - 1) * (delay * 0.5) + logger.error( + f"Gemini API call (with message {new_message}) failed with ServerError {e} (attempt {retries}/{max_retries}). Retrying in {delay_with_jitter} seconds..." + ) + await asyncio.sleep(delay_with_jitter) + delay *= 2 # Exponential backoff + except Exception as e: + logger.error( + f"Gemini API call (with message {new_message}) failed with an unexpected error: {e}." + ) + return f"" + + logger.error( + f"Gemini API call (with message {new_message}) reached maximum retries ({max_retries}) without success." + ) + return f"" + + +async def setup_runner_async(agent, app_name: str, user_id: str): + from google.adk.runners import Runner + from google.adk.sessions import InMemorySessionService + + runner = Runner( + agent=agent, app_name=app_name, session_service=InMemorySessionService() + ) + await runner.session_service.create_session(app_name=app_name, user_id=user_id) + return runner + + +async def main(): + from google.adk.runners import Runner + from google.adk.sessions import InMemorySessionService + + sample_id_1 = "sample_1" + sample_id_2 = "sample_2" + + # Set up simulated user runner + simulated_user_app_name = "su_app" + simulated_user_runner = Runner( + agent=create_agent(name="simulated_user"), + app_name=simulated_user_app_name, + session_service=InMemorySessionService(), + ) + + await simulated_user_runner.session_service.create_session( + app_name=simulated_user_app_name, user_id=sample_id_1 + ) + await simulated_user_runner.session_service.create_session( + app_name=simulated_user_app_name, user_id=sample_id_2 + ) + + # setup grader runner + grader_app_name = "grader_app" + grader_instruction = "You are a helpful agent that can grade the correctness and coherent of a conversation. Please only give an integer as the score." + grader_runner = await setup_runner_async( + agent=create_agent(name="grader", instruction=grader_instruction), + app_name=grader_app_name, + user_id=sample_id_1, + ) + + # Simulated user interactions + await run_prompt_async( + simulated_user_runner, sample_id_1, "what is 2*3+5?", silence=False + ) + await run_prompt_async(simulated_user_runner, sample_id_2, "what is 2*3-5?") + await run_prompt_async(simulated_user_runner, sample_id_1, "Now add another 10.") + await run_prompt_async(simulated_user_runner, sample_id_2, "Now add another 100.") + + # Print conversation + logger.info("-" * 100) + _, convo1 = extract_conversation_history( + simulated_user_runner, sample_id_1, silence=False + ) + logger.info("-" * 100) + _, convo2 = extract_conversation_history( + simulated_user_runner, sample_id_2, silence=False + ) + logger.info("-" * 100) + + # Grade conversation + await run_prompt_async( + grader_runner, + sample_id_1, + f"Grade the above conversation and give a score between 0-10. \n\n{convo1}", + silence=False, + ) + logger.info("-" * 100) + logger.info("DONE!") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/nemo_rl/environments/simulated_user/prompt.py b/nemo_rl/environments/simulated_user/prompt.py new file mode 100644 index 0000000000..7bb9e8c278 --- /dev/null +++ b/nemo_rl/environments/simulated_user/prompt.py @@ -0,0 +1,44 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +starting_user_prompt = ( + "I will play a game with you. I have a list of integers in mind and can NOT tell you. " + "Your goal is to guess the count of UNIQUE numbers in my list. The only 2 things you can do is the following: " + "You can either ask me 'what is number k?' to get the number at position k in my list, " + "or answer 'there are m unique numbers' whenever you feel you want to make a guess. " + "Please do not say anything else. You cannot ask me to provide the list of integers." +) + + +simulated_user_instruction = """ +You are a simulated user in a game where the assistant must figure out how many unique numbers you have. +You have a list of numbers (which may contain duplicates) that you will not reveal to the assistant. +The assistant can ask you questions of the form "What is number k?" where k is a 1-based index into your list of numbers. +You should respond with the number at that index. +The assistant can also make a guess by saying "There are m unique numbers" where m is their guess for the count of unique numbers. +If the assistant makes a correct guess, you will reward it. If the guess is incorrect, you will penalize it. + +Here is your list of numbers: {numbers}. +""".strip() + +grader_instruction = """ +Your are a strict grader to evaluate whether the assistant has properly guessed the count of unique numbers. +Here is your list of numbers: {numbers}. +You will see a conversation between the assistant and a simulated user who has this list of numbers. +You will need to evaluete in the end whether the assistant has made a correct guess of the count of unique numbers. +If the assistant made a correct guess, give it a score of 1. If the guess is incorrect, give it a score of 0. +If assistant made a correct guess but you feel the assistant has asked too many questions, please give a score between 0 and 1. +If the assistant never made a guess, give it a score of 0. +Please only output an integer score between 0 and 1, and nothing else. +""".strip() diff --git a/nemo_rl/environments/simulated_user/unique_numbers.py b/nemo_rl/environments/simulated_user/unique_numbers.py new file mode 100644 index 0000000000..7aaff98bfb --- /dev/null +++ b/nemo_rl/environments/simulated_user/unique_numbers.py @@ -0,0 +1,310 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simulated user environment for counting unique numbers.""" + +from __future__ import annotations + +import asyncio +import os +import re +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Optional, TypedDict + +import ray +import torch + +from nemo_rl.data.interfaces import LLMMessageLogType +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.environments.interfaces import EnvironmentInterface, EnvironmentReturn +from nemo_rl.environments.simulated_user.prompt import ( + grader_instruction, + simulated_user_instruction, +) + +PENALTY_FOR_NO_GUESS = -0.2 +PENALTY_FOR_INCORRECT_GUESS = 0.0 +PENALTY_FOR_EVERY_ASK = 0.0 +PENALTY_FOR_INCORRECT_FORMAT = 0.0 + +ADK_LOG_FOLDER = None # 'logs/adk' or None to disable logging +if ADK_LOG_FOLDER is not None: + os.makedirs(ADK_LOG_FOLDER, exist_ok=True) + +GEMINI_CALL_MAX_WORKERS = ( + 64 # if 1 then it will be single-threaded, otherwise it will use ThreadPoolExecutor +) + + +class UniqueNumbersConfig(TypedDict, total=False): + """Configuration for :class:`UniqueNumbersEnv`.""" + + min_length: int + max_length: int + max_turns: int + + +class UniqueNumbersMetadata(TypedDict): + """Metadata for a UniqueNumbersEnv episode.""" + + numbers: list[int] + unique_count: int + turn: int + max_turns: int + simulated_user_runner: Optional[Any] + grader_runner: Optional[Any] + + +class _UniqueNumbersRunner: + query_re = re.compile(r"what is number (\d+)\??$", re.IGNORECASE) + guess_re = re.compile(r"there are (\d+) unique number", re.IGNORECASE) + + def _maybe_dump_adk_messages_to_file( + self, + runner: Any, + user_id: str, + log_name_suffix: str = "", + dump_folder: Optional[str] = ADK_LOG_FOLDER, + ): + from nemo_rl.environments.simulated_user.adk_utils import ( + extract_conversation_history, + ) + + if dump_folder is None: + return + session_id, messages = extract_conversation_history( + runner, user_id, silence=True + ) + file_name = f"{user_id}_{session_id}{log_name_suffix}.log" + with open(os.path.join(dump_folder, file_name), "a") as f: + for message in messages: + f.write(f"[{message['role']}]:|||{message['content']}|||\n") + + def process_turn( + self, message_log: LLMMessageLogType, metadata: UniqueNumbersMetadata + ) -> tuple[dict[str, str], float, bool, None, Optional[UniqueNumbersMetadata]]: + from nemo_rl.environments.simulated_user.adk_utils import ( + create_agent, + run_prompt_async, + setup_runner_async, + ) + + turn = metadata["turn"] + max_turns = metadata["max_turns"] + + if ( + "simulated_user_runner" not in metadata + or metadata["simulated_user_runner"] is None + ): + instruction = simulated_user_instruction.replace( + "{numbers}", str(metadata["numbers"]) + ) + simulated_user_agent = create_agent( + name="simulated_user", model="gemini-2.0-flash", instruction=instruction + ) + metadata["simulated_user_runner"] = asyncio.run( + setup_runner_async( + simulated_user_agent, "simulated_user_app", "simulated_user" + ) + ) + + if "grader_runner" not in metadata or metadata["grader_runner"] is None: + instruction = grader_instruction.replace( + "{numbers}", str(metadata["numbers"]) + ) + grader_agent = create_agent( + name="grader", model="gemini-2.0-flash", instruction=instruction + ) + metadata["grader_runner"] = asyncio.run( + setup_runner_async(grader_agent, "grader_app", "grader") + ) + + if turn >= max_turns: + self._maybe_dump_adk_messages_to_file( + metadata["simulated_user_runner"], "simulated_user", "_maxturns" + ) + return ( + {"role": "user", "content": ""}, + PENALTY_FOR_NO_GUESS, + True, + None, + None, + ) + + last_msg = "" + if message_log and message_log[-1]["role"] == "assistant": + last_msg = message_log[-1]["content"].strip() + + if not last_msg: + # no last message from assistant, assuming done + return ( + {"role": "user", "content": ""}, + PENALTY_FOR_NO_GUESS, + True, + None, + None, + ) + + # simulate user utterance via ADK + query_match = self.query_re.search(last_msg) + if query_match: + simulated_content = asyncio.run( + run_prompt_async( + metadata["simulated_user_runner"], + "simulated_user", + last_msg, + silence=True, + ) + ) + next_meta: UniqueNumbersMetadata = { + "numbers": metadata["numbers"], + "unique_count": metadata["unique_count"], + "turn": turn + 1, + "max_turns": max_turns, + "simulated_user_runner": metadata.get("simulated_user_runner", None), + "grader_runner": metadata.get("grader_runner", None), + } + return ( + {"role": "user", "content": simulated_content}, + PENALTY_FOR_EVERY_ASK, + False, + None, + next_meta, + ) + + # calculate reward if the assistant made a guess + guess_match = self.guess_re.search(last_msg) + if guess_match: + m = int(guess_match.group(1)) + reward = ( + 1.0 if m == metadata["unique_count"] else PENALTY_FOR_INCORRECT_GUESS + ) + + # grade the conversation via ADK grader + if metadata["grader_runner"] is not None: + convo_str = "\n".join( + [f"{msg['role']}: {msg['content']}" for msg in message_log] + ) + grading_prompt = f"Here is the converstation \n{convo_str}\nAnd please give the score between 0 and 1." + grading_response = asyncio.run( + run_prompt_async( + metadata["grader_runner"], + "grader", + grading_prompt, + silence=True, + ) + ) + try: + grade = int(re.search(r"(\d+)", grading_response).group(1)) + reward = (reward + grade) / 2.0 + except Exception as e: + print( + f"Failed to parse grade from grader response '{grading_response}': {e}" + ) + + self._maybe_dump_adk_messages_to_file( + metadata["simulated_user_runner"], "simulated_user", "_stop" + ) + self._maybe_dump_adk_messages_to_file(metadata["grader_runner"], "grader") + + return {"role": "user", "content": ""}, reward, True, None, None + + # default response + next_meta: UniqueNumbersMetadata = { + "numbers": metadata["numbers"], + "unique_count": metadata["unique_count"], + "turn": turn + 1, + "max_turns": max_turns, + "simulated_user_runner": metadata.get("simulated_user_runner", None), + "grader_runner": metadata.get("grader_runner", None), + } + help_msg = "Please ask 'what is number k?' or say 'there are m unique numbers'." + return ( + {"role": "user", "content": help_msg}, + PENALTY_FOR_INCORRECT_FORMAT, + False, + None, + next_meta, + ) + + +@ray.remote +class UniqueNumbersEnv(EnvironmentInterface): + """Environment where the LLM must deduce the count of unique numbers.""" + + def __init__(self, cfg: Optional[UniqueNumbersConfig] = None): + cfg = cfg or UniqueNumbersConfig() + self.min_length = cfg.get("min_length", 3) + self.max_length = cfg.get("max_length", 7) + self.default_max_turns = cfg.get("max_turns", 10) + + self.runner = _UniqueNumbersRunner() + + def step( + self, + message_log_batch: list[LLMMessageLogType], + metadata: list[Optional[UniqueNumbersMetadata]], + ) -> EnvironmentReturn: + args = [] + for log, meta in zip(message_log_batch, metadata): + assert meta is not None, "Metadata must not be None for UniqueNumbersEnv." + assert meta["numbers"] is not None, "Numbers must not be None in metadata." + assert meta["unique_count"] > 0, ( + "Unique count must be greater than 0 in metadata." + ) + args.append((log, meta)) + + # Process either serially or in parallel + if GEMINI_CALL_MAX_WORKERS is None or GEMINI_CALL_MAX_WORKERS <= 1: + results = [self.runner.process_turn(log, meta) for log, meta in args] + else: + with ThreadPoolExecutor(max_workers=GEMINI_CALL_MAX_WORKERS) as executor: + results = list( + executor.map(lambda p: self.runner.process_turn(*p), args) + ) + + observations, rewards, terminateds, stop_strings, next_metadata = ( + [], + [], + [], + [], + [], + ) + for obs, rew, term, stops, meta in results: + observations.append(obs) + rewards.append(rew) + terminateds.append(term) + stop_strings.append(stops) + next_metadata.append(meta) + + return EnvironmentReturn( + observations=observations, + metadata=next_metadata, + next_stop_strings=stop_strings, + rewards=torch.tensor(rewards, dtype=torch.float32), + terminateds=torch.tensor(terminateds, dtype=torch.bool), + answers=None, + ) + + def shutdown(self) -> None: # pragma: no cover + pass + + def global_post_process_and_metrics( + self, batch: BatchedDataDict + ) -> tuple[BatchedDataDict, dict]: + final_rewards = batch.get( + "total_reward", torch.tensor([0.0] * len(batch["idx"])) + ) + avg_reward = final_rewards.mean().item() if len(final_rewards) > 0 else 0.0 + return batch, {"unique_numbers_avg_reward": avg_reward} diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py index b8b378542c..4a4cceae26 100644 --- a/nemo_rl/experience/rollouts.py +++ b/nemo_rl/experience/rollouts.py @@ -379,6 +379,11 @@ def run_multi_turn_rollout( if len(active_indices) == 0: break + if max_rollout_turns > 1: + print( + f"▶ ▶ ▶ Running rollout turn {turn + 1} / {max_rollout_turns} with {len(active_indices)} active samples..." + ) + active_samples_per_turn.append(len(active_indices)) # Convert LLMMessageLogType to FlatMessagesType for generation @@ -404,6 +409,7 @@ def run_multi_turn_rollout( "stop_strings": active_stop_strings, } ) + # add the multimodal data to the generation input data multimodal_data = active_flat_messages.get_multimodal_dict(as_tensors=False) generation_input_data.update(multimodal_data) @@ -444,11 +450,30 @@ def run_multi_turn_rollout( truncation_mask = torch.zeros_like(env_output.terminateds, dtype=torch.bool) for i, global_idx in enumerate(active_indices.tolist()): env_obs_content = env_output.observations[i]["content"] - # Tokenize the raw content from the environment - # TODO @sahilj: handle if we want these subsequent messages to have a chat template - tokenized_obs = tokenizer( - env_obs_content, return_tensors="pt", add_special_tokens=False - ).input_ids[0] + # Tokenize the raw content from the environment into chat format if needed + env_role = env_output.observations[i]["role"].lower() + if env_role in {"user", "assistant", "system"}: + formatted_obs = tokenizer.apply_chat_template( + [{"role": env_role, "content": env_obs_content.strip()}], + tokenize=False, + add_generation_prompt=True, + ) + tokenized_obs = tokenizer( + formatted_obs, return_tensors="pt", add_special_tokens=False + ).input_ids[0] + # remove the bos token if added after `apply_chat_template` + if ( + len(formatted_obs) > 0 + and hasattr(tokenizer, "bos_token_id") + and formatted_obs[0] == tokenizer.bos_token_id + ): + formatted_obs = formatted_obs[1:] + else: + formatted_obs = env_obs_content.strip() + tokenized_obs = tokenizer( + formatted_obs, return_tensors="pt", add_special_tokens=False + ).input_ids[0] + # tokenizer returns torch.float32 when env_obs_content is empty tokenized_obs = tokenized_obs.to(dtype=torch.int64) @@ -471,7 +496,7 @@ def run_multi_turn_rollout( tokenized_env_obs_message = { "role": env_output.observations[i]["role"], - "content": env_obs_content, + "content": formatted_obs, "token_ids": tokenized_obs, } current_batch["message_log"][global_idx].append(tokenized_env_obs_message) @@ -713,9 +738,28 @@ async def run_sample_multi_turn_rollout( terminated = env_output.terminateds[0].item() env_obs_content = env_output.observations[0]["content"] # Tokenize environment response - tokenized_obs = tokenizer( - env_obs_content, return_tensors="pt", add_special_tokens=False - ).input_ids[0] + env_role = env_output.observations[0]["role"].lower() + if env_role in {"user", "assistant", "system"}: + formatted_obs = tokenizer.apply_chat_template( + [{"role": env_role, "content": env_obs_content.strip()}], + tokenize=False, + add_generation_prompt=True, + ) + tokenized_obs = tokenizer( + formatted_obs, return_tensors="pt", add_special_tokens=False + ).input_ids[0] + # remove the bos token if added after `apply_chat_template` + if ( + len(formatted_obs) > 0 + and hasattr(tokenizer, "bos_token_id") + and formatted_obs[0] == tokenizer.bos_token_id + ): + formatted_obs = formatted_obs[1:] + else: + formatted_obs = env_obs_content.strip() + tokenized_obs = tokenizer( + formatted_obs, return_tensors="pt", add_special_tokens=False + ).input_ids[0] # Check for sequence length overflow if input_lengths + gen_token_count + len(tokenized_obs) >= max_seq_len: @@ -729,7 +773,7 @@ async def run_sample_multi_turn_rollout( env_message = { "role": env_output.observations[0]["role"], - "content": env_obs_content, + "content": formatted_obs, "token_ids": tokenized_obs, } current_message_log.append(env_message) diff --git a/pyproject.toml b/pyproject.toml index e64a6441f6..73f4ed01fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,6 +99,7 @@ mcore = [ "flash-attn==2.8.1", ] penguin = ["penguin"] +adk = ["google-adk==1.14.1", "google-genai==1.38.0"] [dependency-groups] diff --git a/pyrefly.toml b/pyrefly.toml index a1d64ad6fa..578c320e0c 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -16,6 +16,8 @@ replace-imports-with-any = [ "numpy.*", "sphinx.*", "docutils.*", + "google.adk.*", + "google.genai.*", ] project-includes = [ # TODO: enable these once we have 100 correctness @@ -84,6 +86,10 @@ project-includes = [ "nemo_rl/environments/rewards.py", "nemo_rl/environments/utils.py", "nemo_rl/environments/vlm_environment.py", + "nemo_rl/environments/simulated_user/__init__.py", + "nemo_rl/environments/simulated_user/adk_utils.py", + "nemo_rl/environments/simulated_user/prompt.py", + "nemo_rl/environments/simulated_user/unique_numbers.py", "nemo_rl/evals/__init__.py", "nemo_rl/evals/answer_parsing.py", "nemo_rl/experience/__init__.py", diff --git a/tests/unit/environments/test_simulated_user.py b/tests/unit/environments/test_simulated_user.py new file mode 100644 index 0000000000..6c01253173 --- /dev/null +++ b/tests/unit/environments/test_simulated_user.py @@ -0,0 +1,275 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import re +import sys +import types + +import pytest + +# Stub out external modules that are not available in the test environment +google_module = sys.modules.get("google", types.ModuleType("google")) +google_module.adk = types.ModuleType("google.adk") +google_module.adk.agents = types.ModuleType("google.adk.agents") +google_module.adk.agents.Agent = object +google_module.adk.Agent = object +google_module.adk.runners = types.ModuleType("google.adk.runners") +google_module.adk.runners.Runner = object +google_module.adk.sessions = types.ModuleType("google.adk.sessions") +google_module.adk.sessions.InMemorySessionService = object +google_module.genai = types.ModuleType("google.genai") + + +class _Part: + def __init__(self, text: str): + self.text = text + + @classmethod + def from_text(cls, text: str): + return cls(text) + + +class _Content: + def __init__(self, role: str | None = None, parts: list[_Part] | None = None): + self.role = role + self.parts = parts or [] + + +google_module.genai.types = types.SimpleNamespace( + GenerateContentConfig=object, + SafetySetting=object, + HarmCategory=object, + HarmBlockThreshold=object, + Content=_Content, + Part=_Part, +) +google_module.genai.errors = types.SimpleNamespace(ServerError=Exception) +sys.modules.setdefault("google.adk", google_module.adk) +sys.modules.setdefault("google.adk.agents", google_module.adk.agents) +sys.modules.setdefault("google.adk.runners", google_module.adk.runners) +sys.modules.setdefault("google.adk.sessions", google_module.adk.sessions) +sys.modules.setdefault("google.genai", google_module.genai) +sys.modules.setdefault("google.genai.types", google_module.genai.types) +sys.modules.setdefault("google.genai.errors", google_module.genai.errors) + +ray_module = types.ModuleType("ray") +ray_module.remote = lambda cls=None, **_: cls +sys.modules.setdefault("ray", ray_module) + +torch_module = types.ModuleType("torch") +torch_module.tensor = lambda data, **_: data +torch_module.float32 = "float32" +torch_module.bool = "bool" +torch_module.Tensor = object +torch_module.distributed = types.SimpleNamespace(ProcessGroup=object) +torch_module.device = object +sys.modules.setdefault("torch", torch_module) + +transformers_module = types.ModuleType("transformers") +transformers_module.PreTrainedTokenizerBase = object +sys.modules.setdefault("transformers", transformers_module) + +from nemo_rl.environments.simulated_user import adk_utils, unique_numbers + + +# Dummy runner object used for mocking run_prompt_async behaviour +class DummyRunner: + def __init__(self, numbers): + self.numbers = numbers + + +def _make_metadata(numbers): + return { + "numbers": numbers, + "unique_count": len(set(numbers)), + "turn": 0, + "max_turns": 5, + "simulated_user_runner": DummyRunner(numbers), + "grader_runner": DummyRunner(numbers), + } + + +@pytest.fixture +def patch_unique_numbers(monkeypatch): + async def fake_run_prompt_async(runner, user_id, msg, silence=True): + match = re.search(r"what is number (\d+)", msg, re.IGNORECASE) + if match: + idx = int(match.group(1)) - 1 + if 0 <= idx < len(runner.numbers): + return str(runner.numbers[idx]) + return "" + + monkeypatch.setattr(adk_utils, "run_prompt_async", fake_run_prompt_async) + monkeypatch.setattr(unique_numbers, "run_prompt_async", fake_run_prompt_async) + monkeypatch.setattr( + unique_numbers._UniqueNumbersRunner, + "_maybe_dump_adk_messages_to_file", + lambda *a, **k: None, + ) + + +@pytest.fixture +def patch_adk_utils(monkeypatch): + class FakeAgent: + def __init__(self, instruction: str | None = None, **_): + self.instruction = instruction + + class FakeSession: + def __init__(self, user_id: str, session_id: str = "s1"): + self.user_id = user_id + self.id = session_id + self.events = [] + + class FakeSessionService: + def __init__(self): + self.sessions = {} + + async def create_session(self, app_name: str, user_id: str): + sess = FakeSession(user_id) + self.sessions.setdefault(app_name, {}).setdefault(user_id, {})[sess.id] = ( + sess + ) + return sess + + class FakeRunner: + def __init__( + self, agent=None, app_name="app", session_service=None, responses=None + ): + self.agent = agent + self.app_name = app_name + self.session_service = session_service or FakeSessionService() + self.responses = list(responses or []) + + async def run_async(self, user_id: str, session_id: str, new_message): + session = self.session_service.sessions[self.app_name][user_id][session_id] + session.events.append( + types.SimpleNamespace(author="user", content=new_message) + ) + text = self.responses.pop(0) if self.responses else "ack" + event = types.SimpleNamespace( + author="assistant", + content=adk_utils.types.Content( + parts=[adk_utils.types.Part.from_text(text)] + ), + ) + session.events.append(event) + yield event + + class FakeServerError(Exception): + pass + + monkeypatch.setattr(adk_utils, "Agent", FakeAgent) + monkeypatch.setattr(adk_utils, "Runner", FakeRunner) + monkeypatch.setattr(adk_utils, "InMemorySessionService", FakeSessionService) + monkeypatch.setattr(adk_utils, "ServerError", FakeServerError) + return FakeRunner + + +def test_process_turn_query(patch_unique_numbers): + runner = unique_numbers._UniqueNumbersRunner() + numbers = [1, 2, 3] + metadata = _make_metadata(numbers) + + msg_log = [{"role": "assistant", "content": "What is number 2?"}] + obs, reward, terminated, _, next_meta = runner.process_turn(msg_log, metadata) + + assert obs == {"role": "user", "content": "2"} + assert reward == unique_numbers.PENALTY_FOR_EVERY_ASK + assert not terminated + assert next_meta is not None and next_meta["turn"] == 1 + + +def test_process_turn_correct_guess(patch_unique_numbers): + runner = unique_numbers._UniqueNumbersRunner() + numbers = [1, 2, 1] + metadata = _make_metadata(numbers) + + msg_log = [{"role": "assistant", "content": "There are 2 unique numbers"}] + obs, reward, terminated, _, next_meta = runner.process_turn(msg_log, metadata) + + assert obs == {"role": "user", "content": ""} + assert reward == 1.0 + assert terminated + assert next_meta is None + + +def test_process_turn_incorrect_guess(patch_unique_numbers): + runner = unique_numbers._UniqueNumbersRunner() + numbers = [1, 2, 1] + metadata = _make_metadata(numbers) + + msg_log = [{"role": "assistant", "content": "There are 3 unique numbers"}] + obs, reward, terminated, _, _ = runner.process_turn(msg_log, metadata) + + assert obs["content"] == "" + assert reward == unique_numbers.PENALTY_FOR_INCORRECT_GUESS + assert terminated + + +def test_process_turn_no_message(patch_unique_numbers): + runner = unique_numbers._UniqueNumbersRunner() + numbers = [1, 2, 1] + metadata = _make_metadata(numbers) + + msg_log = [] + obs, reward, terminated, _, _ = runner.process_turn(msg_log, metadata) + + assert obs["content"] == "" + assert reward == unique_numbers.PENALTY_FOR_NO_GUESS + assert terminated + + +def test_run_prompt_async_basic(patch_adk_utils): + agent = adk_utils.Agent(instruction="hi") + runner = adk_utils.Runner(agent=agent, responses=["7"]) + asyncio.run(runner.session_service.create_session("app", "u1")) + out = asyncio.run(adk_utils.run_prompt_async(runner, "u1", "hello", silence=True)) + assert out == "7" + + +def test_run_prompt_async_retry(monkeypatch, patch_adk_utils): + class ErrorRunner(adk_utils.Runner): + def __init__(self): + super().__init__(agent=adk_utils.Agent()) + self.call = 0 + + async def run_async(self, user_id: str, session_id: str, new_message): + self.call += 1 + if self.call == 1: + raise adk_utils.ServerError() + async for e in super().run_async(user_id, session_id, new_message): + yield e + + runner = ErrorRunner() + asyncio.run(runner.session_service.create_session("app", "u2")) + res = asyncio.run( + adk_utils.run_prompt_async( + runner, "u2", "hi", silence=True, max_retries=2, initial_delay=0 + ) + ) + assert res == "ack" + assert runner.call == 2 + + +def test_extract_conversation_history(patch_adk_utils): + agent = adk_utils.Agent(instruction="inst") + runner = adk_utils.Runner(agent=agent, responses=["42"]) + asyncio.run(runner.session_service.create_session("app", "u3")) + asyncio.run(adk_utils.run_prompt_async(runner, "u3", "question", silence=True)) + session_id, convo = adk_utils.extract_conversation_history(runner, "u3") + assert session_id == "s1" + assert convo[0] == {"role": "instruction", "content": "inst"} + assert convo[1]["role"] == "user" + assert convo[2]["content"] == "42" diff --git a/uv.lock b/uv.lock index 7b06abd41f..689ff1e6b6 100644 --- a/uv.lock +++ b/uv.lock @@ -2,17 +2,23 @@ version = 1 revision = 3 requires-python = ">=3.12" resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", - "python_full_version >= '3.13' and sys_platform == 'win32'", - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version >= '3.13' and sys_platform == 'darwin'", - "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'win32'", + "python_full_version >= '3.14' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.14' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", "python_full_version < '3.13' and sys_platform == 'win32'", "python_full_version < '3.13' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", "python_full_version < '3.13' and sys_platform == 'darwin'", "python_full_version < '3.13' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux'", ] @@ -62,6 +68,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8f/aa/ba0014cc4659328dc818a28827be78e6d97312ab0cb98105a770924dc11e/absl_py-2.3.1-py3-none-any.whl", hash = "sha256:eeecf07f0c2a93ace0772c92e596ace6d3d3996c042b2128459aaae2a76de11d", size = 135811, upload-time = "2025-07-03T09:31:42.253Z" }, ] +[[package]] +name = "absolufy-imports" +version = "0.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/74/0f/9da9dc9a12ebf4622ec96d9338d221e0172699e7574929f65ec8fdb30f9c/absolufy_imports-0.3.1.tar.gz", hash = "sha256:c90638a6c0b66826d1fb4880ddc20ef7701af34192c94faf40b95d32b59f9793", size = 4724, upload-time = "2022-01-20T14:48:53.434Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a3/a4/b65c9fbc2c0c09c0ea3008f62d2010fd261e62a4881502f03a6301079182/absolufy_imports-0.3.1-py2.py3-none-any.whl", hash = "sha256:49bf7c753a9282006d553ba99217f48f947e3eef09e18a700f8a82f75dc7fc5c", size = 5937, upload-time = "2022-01-20T14:48:51.718Z" }, +] + [[package]] name = "accelerate" version = "1.10.0" @@ -349,6 +364,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f6/22/91616fe707a5c5510de2cac9b046a30defe7007ba8a0c04f9c08f27df312/audioop_lts-0.2.2-cp314-cp314t-win_arm64.whl", hash = "sha256:b492c3b040153e68b9fdaff5913305aaaba5bb433d8a7f73d5cf6a64ed3cc1dd", size = 25206, upload-time = "2025-08-05T16:43:16.444Z" }, ] +[[package]] +name = "authlib" +version = "1.6.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cd/3f/1d3bbd0bf23bdd99276d4def22f29c27a914067b4cf66f753ff9b8bbd0f3/authlib-1.6.5.tar.gz", hash = "sha256:6aaf9c79b7cc96c900f0b284061691c5d4e61221640a948fe690b556a6d6d10b", size = 164553, upload-time = "2025-10-02T13:36:09.489Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/aa/5082412d1ee302e9e7d80b6949bc4d2a8fa1149aaab610c5fc24709605d6/authlib-1.6.5-py2.py3-none-any.whl", hash = "sha256:3e0e0507807f842b02175507bdee8957a1d5707fd4afb17c32fb43fee90b6e3a", size = 243608, upload-time = "2025-10-02T13:36:07.637Z" }, +] + [[package]] name = "av" version = "15.0.0" @@ -1281,6 +1308,15 @@ version = "0.6.2" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/a2/55/8f8cab2afd404cf578136ef2cc5dfb50baa1761b68c9da1fb1e4eed343c9/docopt-0.6.2.tar.gz", hash = "sha256:49b3a825280bd66b3aa83585ef59c4a8c82f2c8a522dbe754a8bc8d08c85c491", size = 25901, upload-time = "2014-06-16T11:18:57.406Z" } +[[package]] +name = "docstring-parser" +version = "0.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b2/9d/c3b43da9515bd270df0f80548d9944e389870713cc1fe2b8fb35fe2bcefd/docstring_parser-0.17.0.tar.gz", hash = "sha256:583de4a309722b3315439bb31d64ba3eebada841f2e2cee23b99df001434c912", size = 27442, upload-time = "2025-07-21T07:35:01.868Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl", hash = "sha256:cf2569abd23dce8099b300f9b4fa8191e9582dda731fd533daf54c4551658708", size = 36896, upload-time = "2025-07-21T07:35:00.684Z" }, +] + [[package]] name = "docutils" version = "0.21.2" @@ -1654,6 +1690,50 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/01/61/d4b89fec821f72385526e1b9d9a3a0385dda4a72b206d28049e2c7cd39b8/gitpython-3.1.45-py3-none-any.whl", hash = "sha256:8908cb2e02fb3b93b7eb0f2827125cb699869470432cc885f019b8fd0fccff77", size = 208168, upload-time = "2025-07-24T03:45:52.517Z" }, ] +[[package]] +name = "google-adk" +version = "1.14.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "absolufy-imports" }, + { name = "anyio" }, + { name = "authlib" }, + { name = "click" }, + { name = "fastapi" }, + { name = "google-api-python-client" }, + { name = "google-cloud-aiplatform", extra = ["agent-engines"] }, + { name = "google-cloud-bigtable" }, + { name = "google-cloud-secret-manager", version = "2.24.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.14'" }, + { name = "google-cloud-secret-manager", version = "2.25.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.14'" }, + { name = "google-cloud-spanner" }, + { name = "google-cloud-speech" }, + { name = "google-cloud-storage" }, + { name = "google-genai" }, + { name = "graphviz" }, + { name = "mcp" }, + { name = "opentelemetry-api" }, + { name = "opentelemetry-exporter-gcp-trace" }, + { name = "opentelemetry-sdk" }, + { name = "pydantic" }, + { name = "python-dateutil" }, + { name = "python-dotenv" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "sqlalchemy" }, + { name = "sqlalchemy-spanner" }, + { name = "starlette" }, + { name = "tenacity" }, + { name = "typing-extensions" }, + { name = "tzlocal" }, + { name = "uvicorn" }, + { name = "watchdog" }, + { name = "websockets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/35/fe/0efba60d22bfcd7ab18f48d23771f0701664fd93be247eddc42592b9b68f/google_adk-1.14.1.tar.gz", hash = "sha256:06caab4599286123eceb9348e4accb6c3c1476b8d9b2b13f078a975c8ace966f", size = 1681879, upload-time = "2025-09-15T00:06:48.823Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/31/74/0b68fab470f13e80fd135bcf890c13bb1154804c1eaaff60dd1f5995027c/google_adk-1.14.1-py3-none-any.whl", hash = "sha256:acb31ed41d3b05b0d3a65cce76f6ef1289385f49a72164a07dae56190b648d50", size = 1922802, upload-time = "2025-09-15T00:06:47.011Z" }, +] + [[package]] name = "google-api-core" version = "2.25.1" @@ -1670,6 +1750,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/14/4b/ead00905132820b623732b175d66354e9d3e69fcf2a5dcdab780664e7896/google_api_core-2.25.1-py3-none-any.whl", hash = "sha256:8a2a56c1fef82987a524371f99f3bd0143702fecc670c72e600c1cda6bf8dbb7", size = 160807, upload-time = "2025-06-12T20:52:19.334Z" }, ] +[package.optional-dependencies] +grpc = [ + { name = "grpcio" }, + { name = "grpcio-status" }, +] + +[[package]] +name = "google-api-python-client" +version = "2.187.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core" }, + { name = "google-auth" }, + { name = "google-auth-httplib2" }, + { name = "httplib2" }, + { name = "uritemplate" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/75/83/60cdacf139d768dd7f0fcbe8d95b418299810068093fdf8228c6af89bb70/google_api_python_client-2.187.0.tar.gz", hash = "sha256:e98e8e8f49e1b5048c2f8276473d6485febc76c9c47892a8b4d1afa2c9ec8278", size = 14068154, upload-time = "2025-11-06T01:48:53.274Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/96/58/c1e716be1b055b504d80db2c8413f6c6a890a6ae218a65f178b63bc30356/google_api_python_client-2.187.0-py3-none-any.whl", hash = "sha256:d8d0f6d85d7d1d10bdab32e642312ed572bdc98919f72f831b44b9a9cebba32f", size = 14641434, upload-time = "2025-11-06T01:48:50.763Z" }, +] + [[package]] name = "google-auth" version = "2.40.3" @@ -1684,6 +1786,341 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/17/63/b19553b658a1692443c62bd07e5868adaa0ad746a0751ba62c59568cd45b/google_auth-2.40.3-py2.py3-none-any.whl", hash = "sha256:1370d4593e86213563547f97a92752fc658456fe4514c809544f330fed45a7ca", size = 216137, upload-time = "2025-06-04T18:04:55.573Z" }, ] +[[package]] +name = "google-auth-httplib2" +version = "0.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-auth" }, + { name = "httplib2" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e0/83/7ef576d1c7ccea214e7b001e69c006bc75e058a3a1f2ab810167204b698b/google_auth_httplib2-0.2.1.tar.gz", hash = "sha256:5ef03be3927423c87fb69607b42df23a444e434ddb2555b73b3679793187b7de", size = 11086, upload-time = "2025-10-30T21:13:16.569Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/44/a7/ca23dd006255f70e2bc469d3f9f0c82ea455335bfd682ad4d677adc435de/google_auth_httplib2-0.2.1-py3-none-any.whl", hash = "sha256:1be94c611db91c01f9703e7f62b0a59bbd5587a95571c7b6fade510d648bc08b", size = 9525, upload-time = "2025-10-30T21:13:15.758Z" }, +] + +[[package]] +name = "google-cloud-aiplatform" +version = "1.128.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "docstring-parser" }, + { name = "google-api-core", extra = ["grpc"] }, + { name = "google-auth" }, + { name = "google-cloud-bigquery" }, + { name = "google-cloud-resource-manager" }, + { name = "google-cloud-storage" }, + { name = "google-genai" }, + { name = "packaging" }, + { name = "proto-plus" }, + { name = "protobuf" }, + { name = "pydantic" }, + { name = "shapely" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f2/34/6fe43626f4e4508759494d2ac4e06708bd36f820150158dfdba33c207a89/google_cloud_aiplatform-1.128.0.tar.gz", hash = "sha256:2c4f6ac60fd52b12499b84f80d6e66e601763f21114f5ac9b01bb1daacf54cd9", size = 9782505, upload-time = "2025-11-19T01:36:51.324Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/89/3e/977708c8db669d9de9c4bb8cda18a649a6e0dfc1188f3def30c352fc9566/google_cloud_aiplatform-1.128.0-py2.py3-none-any.whl", hash = "sha256:3341e5c688124381833529886fea59351681f4f42c10dcc50000096680d49555", size = 8123088, upload-time = "2025-11-19T01:36:48.544Z" }, +] + +[package.optional-dependencies] +agent-engines = [ + { name = "cloudpickle" }, + { name = "google-cloud-logging" }, + { name = "google-cloud-trace" }, + { name = "opentelemetry-exporter-gcp-logging" }, + { name = "opentelemetry-exporter-gcp-trace" }, + { name = "opentelemetry-exporter-otlp-proto-http" }, + { name = "opentelemetry-sdk" }, + { name = "packaging" }, + { name = "pydantic" }, + { name = "typing-extensions" }, +] + +[[package]] +name = "google-cloud-appengine-logging" +version = "1.6.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core", extra = ["grpc"] }, + { name = "google-auth" }, + { name = "proto-plus" }, + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e7/ea/85da73d4f162b29d24ad591c4ce02688b44094ee5f3d6c0cc533c2b23b23/google_cloud_appengine_logging-1.6.2.tar.gz", hash = "sha256:4890928464c98da9eecc7bf4e0542eba2551512c0265462c10f3a3d2a6424b90", size = 16587, upload-time = "2025-06-11T22:38:53.525Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e4/9e/dc1fd7f838dcaf608c465171b1a25d8ce63f9987e2d5c73bda98792097a9/google_cloud_appengine_logging-1.6.2-py3-none-any.whl", hash = "sha256:2b28ed715e92b67e334c6fcfe1deb523f001919560257b25fc8fcda95fd63938", size = 16889, upload-time = "2025-06-11T22:38:52.26Z" }, +] + +[[package]] +name = "google-cloud-audit-log" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "googleapis-common-protos" }, + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c7/d2/ad96950410f8a05e921a6da2e1a6ba4aeca674bbb5dda8200c3c7296d7ad/google_cloud_audit_log-0.4.0.tar.gz", hash = "sha256:8467d4dcca9f3e6160520c24d71592e49e874838f174762272ec10e7950b6feb", size = 44682, upload-time = "2025-10-17T02:33:44.641Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/25/532886995f11102ad6de290496de5db227bd3a73827702445928ad32edcb/google_cloud_audit_log-0.4.0-py3-none-any.whl", hash = "sha256:6b88e2349df45f8f4cc0993b687109b1388da1571c502dc1417efa4b66ec55e0", size = 44890, upload-time = "2025-10-17T02:30:55.11Z" }, +] + +[[package]] +name = "google-cloud-bigquery" +version = "3.38.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core", extra = ["grpc"] }, + { name = "google-auth" }, + { name = "google-cloud-core" }, + { name = "google-resumable-media" }, + { name = "packaging" }, + { name = "python-dateutil" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/07/b2/a17e40afcf9487e3d17db5e36728ffe75c8d5671c46f419d7b6528a5728a/google_cloud_bigquery-3.38.0.tar.gz", hash = "sha256:8afcb7116f5eac849097a344eb8bfda78b7cfaae128e60e019193dd483873520", size = 503666, upload-time = "2025-09-17T20:33:33.47Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/39/3c/c8cada9ec282b29232ed9aed5a0b5cca6cf5367cb2ffa8ad0d2583d743f1/google_cloud_bigquery-3.38.0-py3-none-any.whl", hash = "sha256:e06e93ff7b245b239945ef59cb59616057598d369edac457ebf292bd61984da6", size = 259257, upload-time = "2025-09-17T20:33:31.404Z" }, +] + +[[package]] +name = "google-cloud-bigtable" +version = "2.34.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core", extra = ["grpc"] }, + { name = "google-auth" }, + { name = "google-cloud-core" }, + { name = "google-crc32c" }, + { name = "grpc-google-iam-v1" }, + { name = "proto-plus" }, + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/29/20/8a29e1d5858ba76f443dc527a223e769347b915cb060a9f19250241aa38a/google_cloud_bigtable-2.34.0.tar.gz", hash = "sha256:773258b00cd3f9a3a35639cc38bd711f4f1418aaa0c8d70cb028978ed98dc2c2", size = 766606, upload-time = "2025-10-22T19:04:53.645Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/55/6d/aa44110504b4b9d125f1cc9715b72a178ebbe5cb79698e7a95893c391e56/google_cloud_bigtable-2.34.0-py3-none-any.whl", hash = "sha256:a4a8db4903840cd3f89fb19c060eea2e7c09c1265cb0538cfc11288dbc6000e4", size = 537041, upload-time = "2025-10-22T19:04:52.014Z" }, +] + +[[package]] +name = "google-cloud-core" +version = "2.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core" }, + { name = "google-auth" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a6/03/ef0bc99d0e0faf4fdbe67ac445e18cdaa74824fd93cd069e7bb6548cb52d/google_cloud_core-2.5.0.tar.gz", hash = "sha256:7c1b7ef5c92311717bd05301aa1a91ffbc565673d3b0b4163a52d8413a186963", size = 36027, upload-time = "2025-10-29T23:17:39.513Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/89/20/bfa472e327c8edee00f04beecc80baeddd2ab33ee0e86fd7654da49d45e9/google_cloud_core-2.5.0-py3-none-any.whl", hash = "sha256:67d977b41ae6c7211ee830c7912e41003ea8194bff15ae7d72fd6f51e57acabc", size = 29469, upload-time = "2025-10-29T23:17:38.548Z" }, +] + +[[package]] +name = "google-cloud-logging" +version = "3.12.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core", extra = ["grpc"] }, + { name = "google-auth" }, + { name = "google-cloud-appengine-logging" }, + { name = "google-cloud-audit-log" }, + { name = "google-cloud-core" }, + { name = "grpc-google-iam-v1" }, + { name = "opentelemetry-api" }, + { name = "proto-plus" }, + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/14/9c/d42ecc94f795a6545930e5f846a7ae59ff685ded8bc086648dd2bee31a1a/google_cloud_logging-3.12.1.tar.gz", hash = "sha256:36efc823985055b203904e83e1c8f9f999b3c64270bcda39d57386ca4effd678", size = 289569, upload-time = "2025-04-22T20:50:24.71Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b1/41/f8a3197d39b773a91f335dee36c92ef26a8ec96efe78d64baad89d367df4/google_cloud_logging-3.12.1-py2.py3-none-any.whl", hash = "sha256:6817878af76ec4e7568976772839ab2c43ddfd18fbbf2ce32b13ef549cd5a862", size = 229466, upload-time = "2025-04-22T20:50:23.294Z" }, +] + +[[package]] +name = "google-cloud-resource-manager" +version = "1.14.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core", extra = ["grpc"] }, + { name = "google-auth" }, + { name = "grpc-google-iam-v1" }, + { name = "proto-plus" }, + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6e/ca/a4648f5038cb94af4b3942815942a03aa9398f9fb0bef55b3f1585b9940d/google_cloud_resource_manager-1.14.2.tar.gz", hash = "sha256:962e2d904c550d7bac48372607904ff7bb3277e3bb4a36d80cc9a37e28e6eb74", size = 446370, upload-time = "2025-03-17T11:35:56.343Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b1/ea/a92631c358da377af34d3a9682c97af83185c2d66363d5939ab4a1169a7f/google_cloud_resource_manager-1.14.2-py3-none-any.whl", hash = "sha256:d0fa954dedd1d2b8e13feae9099c01b8aac515b648e612834f9942d2795a9900", size = 394344, upload-time = "2025-03-17T11:35:54.722Z" }, +] + +[[package]] +name = "google-cloud-secret-manager" +version = "2.24.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version >= '3.14' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.14' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform == 'linux'", +] +dependencies = [ + { name = "google-api-core", extra = ["grpc"], marker = "python_full_version >= '3.14'" }, + { name = "google-auth", marker = "python_full_version >= '3.14'" }, + { name = "grpc-google-iam-v1", marker = "python_full_version >= '3.14'" }, + { name = "proto-plus", marker = "python_full_version >= '3.14'" }, + { name = "protobuf", marker = "python_full_version >= '3.14'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/58/7a/2fa6735ec693d822fe08a76709c4d95d9b5b4c02e83e720497355039d2ee/google_cloud_secret_manager-2.24.0.tar.gz", hash = "sha256:ce573d40ffc2fb7d01719243a94ee17aa243ea642a6ae6c337501e58fbf642b5", size = 269516, upload-time = "2025-06-05T22:22:22.965Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/be/af/db1217cae1809e69a4527ee6293b82a9af2a1fb2313ad110c775e8f3c820/google_cloud_secret_manager-2.24.0-py3-none-any.whl", hash = "sha256:9bea1254827ecc14874bc86c63b899489f8f50bfe1442bfb2517530b30b3a89b", size = 218050, upload-time = "2025-06-10T02:02:19.88Z" }, +] + +[[package]] +name = "google-cloud-secret-manager" +version = "2.25.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version < '3.13' and sys_platform == 'win32'", + "python_full_version < '3.13' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version < '3.13' and sys_platform == 'darwin'", + "python_full_version < '3.13' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux'", +] +dependencies = [ + { name = "google-api-core", extra = ["grpc"], marker = "python_full_version < '3.14'" }, + { name = "google-auth", marker = "python_full_version < '3.14'" }, + { name = "grpc-google-iam-v1", marker = "python_full_version < '3.14'" }, + { name = "grpcio", marker = "python_full_version < '3.14'" }, + { name = "proto-plus", marker = "python_full_version < '3.14'" }, + { name = "protobuf", marker = "python_full_version < '3.14'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c3/7c/be2d11415eec83c400d315cf9876ba29742bc7af90df391d357763463cd2/google_cloud_secret_manager-2.25.0.tar.gz", hash = "sha256:a3792bb1cb307326908297a61536031ac94852c22248f04ae112ff51a853b561", size = 269853, upload-time = "2025-10-14T15:42:59.511Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/74/bf87966a6ee48c98d1b8a6a1839256911e9a2a205be76b21e54f58171615/google_cloud_secret_manager-2.25.0-py3-none-any.whl", hash = "sha256:eaf1adce3ff5dc0f24335709eba3410dc7e9d20aeea3e8df5b758e27080ebf14", size = 218548, upload-time = "2025-10-14T15:42:47.839Z" }, +] + +[[package]] +name = "google-cloud-spanner" +version = "3.59.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core", extra = ["grpc"] }, + { name = "google-cloud-core" }, + { name = "grpc-google-iam-v1" }, + { name = "grpc-interceptor" }, + { name = "proto-plus" }, + { name = "protobuf" }, + { name = "sqlparse" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/df/62/f0e535875e49b34128710342115681fe1a97f45759e1427307ab150a4caa/google_cloud_spanner-3.59.0.tar.gz", hash = "sha256:dec7a78bfe1f94aef508ff9c61dba4196f3c70c83a0f75c271b4652686d08641", size = 705137, upload-time = "2025-10-23T09:35:49.885Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d8/08/1a38139853364b4737e3a0e03a3fd87d60c7545e90a963a8a6457777b5f9/google_cloud_spanner-3.59.0-py3-none-any.whl", hash = "sha256:409ed9746787c9435fd015731a5e3cf6f3ea2995a807c580f4216bb5d464260a", size = 502645, upload-time = "2025-10-23T09:35:47.954Z" }, +] + +[[package]] +name = "google-cloud-speech" +version = "2.33.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core", extra = ["grpc"] }, + { name = "google-auth" }, + { name = "proto-plus" }, + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9a/74/9c5a556f8af19cab461058aa15e1409e7afa453ca2383473a24a12801ef7/google_cloud_speech-2.33.0.tar.gz", hash = "sha256:fd08511b5124fdaa768d71a4054e84a5d8eb02531cb6f84f311c0387ea1314ed", size = 389072, upload-time = "2025-06-11T23:56:37.231Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/12/1d/880342b2541b4bad888ad8ab2ac77d4b5dad25b32a2a1c5f21140c14c8e3/google_cloud_speech-2.33.0-py3-none-any.whl", hash = "sha256:4ba16c8517c24a6abcde877289b0f40b719090504bf06b1adea248198ccd50a5", size = 335681, upload-time = "2025-06-11T23:56:36.026Z" }, +] + +[[package]] +name = "google-cloud-storage" +version = "2.19.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core" }, + { name = "google-auth" }, + { name = "google-cloud-core" }, + { name = "google-crc32c" }, + { name = "google-resumable-media" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/36/76/4d965702e96bb67976e755bed9828fa50306dca003dbee08b67f41dd265e/google_cloud_storage-2.19.0.tar.gz", hash = "sha256:cd05e9e7191ba6cb68934d8eb76054d9be4562aa89dbc4236feee4d7d51342b2", size = 5535488, upload-time = "2024-12-05T01:35:06.49Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/94/6db383d8ee1adf45dc6c73477152b82731fa4c4a46d9c1932cc8757e0fd4/google_cloud_storage-2.19.0-py2.py3-none-any.whl", hash = "sha256:aeb971b5c29cf8ab98445082cbfe7b161a1f48ed275822f59ed3f1524ea54fba", size = 131787, upload-time = "2024-12-05T01:35:04.736Z" }, +] + +[[package]] +name = "google-cloud-trace" +version = "1.16.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core", extra = ["grpc"] }, + { name = "google-auth" }, + { name = "proto-plus" }, + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c5/ea/0e42e2196fb2bc8c7b25f081a0b46b5053d160b34d5322e7eac2d5f7a742/google_cloud_trace-1.16.2.tar.gz", hash = "sha256:89bef223a512465951eb49335be6d60bee0396d576602dbf56368439d303cab4", size = 97826, upload-time = "2025-06-12T00:53:02.12Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/08/96/7a8d271e91effa9ccc2fd7cfd5cf287a2d7900080a475477c2ac0c7a331d/google_cloud_trace-1.16.2-py3-none-any.whl", hash = "sha256:40fb74607752e4ee0f3d7e5fc6b8f6eb1803982254a1507ba918172484131456", size = 103755, upload-time = "2025-06-12T00:53:00.672Z" }, +] + +[[package]] +name = "google-crc32c" +version = "1.7.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/19/ae/87802e6d9f9d69adfaedfcfd599266bf386a54d0be058b532d04c794f76d/google_crc32c-1.7.1.tar.gz", hash = "sha256:2bff2305f98846f3e825dbeec9ee406f89da7962accdb29356e4eadc251bd472", size = 14495, upload-time = "2025-03-26T14:29:13.32Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dd/b7/787e2453cf8639c94b3d06c9d61f512234a82e1d12d13d18584bd3049904/google_crc32c-1.7.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:2d73a68a653c57281401871dd4aeebbb6af3191dcac751a76ce430df4d403194", size = 30470, upload-time = "2025-03-26T14:34:31.655Z" }, + { url = "https://files.pythonhosted.org/packages/ed/b4/6042c2b0cbac3ec3a69bb4c49b28d2f517b7a0f4a0232603c42c58e22b44/google_crc32c-1.7.1-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:22beacf83baaf59f9d3ab2bbb4db0fb018da8e5aebdce07ef9f09fce8220285e", size = 30315, upload-time = "2025-03-26T15:01:54.634Z" }, + { url = "https://files.pythonhosted.org/packages/29/ad/01e7a61a5d059bc57b702d9ff6a18b2585ad97f720bd0a0dbe215df1ab0e/google_crc32c-1.7.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:19eafa0e4af11b0a4eb3974483d55d2d77ad1911e6cf6f832e1574f6781fd337", size = 33180, upload-time = "2025-03-26T14:41:32.168Z" }, + { url = "https://files.pythonhosted.org/packages/3b/a5/7279055cf004561894ed3a7bfdf5bf90a53f28fadd01af7cd166e88ddf16/google_crc32c-1.7.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6d86616faaea68101195c6bdc40c494e4d76f41e07a37ffdef270879c15fb65", size = 32794, upload-time = "2025-03-26T14:41:33.264Z" }, + { url = "https://files.pythonhosted.org/packages/0f/d6/77060dbd140c624e42ae3ece3df53b9d811000729a5c821b9fd671ceaac6/google_crc32c-1.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:b7491bdc0c7564fcf48c0179d2048ab2f7c7ba36b84ccd3a3e1c3f7a72d3bba6", size = 33477, upload-time = "2025-03-26T14:29:10.94Z" }, + { url = "https://files.pythonhosted.org/packages/8b/72/b8d785e9184ba6297a8620c8a37cf6e39b81a8ca01bb0796d7cbb28b3386/google_crc32c-1.7.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:df8b38bdaf1629d62d51be8bdd04888f37c451564c2042d36e5812da9eff3c35", size = 30467, upload-time = "2025-03-26T14:36:06.909Z" }, + { url = "https://files.pythonhosted.org/packages/34/25/5f18076968212067c4e8ea95bf3b69669f9fc698476e5f5eb97d5b37999f/google_crc32c-1.7.1-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:e42e20a83a29aa2709a0cf271c7f8aefaa23b7ab52e53b322585297bb94d4638", size = 30309, upload-time = "2025-03-26T15:06:15.318Z" }, + { url = "https://files.pythonhosted.org/packages/92/83/9228fe65bf70e93e419f38bdf6c5ca5083fc6d32886ee79b450ceefd1dbd/google_crc32c-1.7.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:905a385140bf492ac300026717af339790921f411c0dfd9aa5a9e69a08ed32eb", size = 33133, upload-time = "2025-03-26T14:41:34.388Z" }, + { url = "https://files.pythonhosted.org/packages/c3/ca/1ea2fd13ff9f8955b85e7956872fdb7050c4ace8a2306a6d177edb9cf7fe/google_crc32c-1.7.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b211ddaf20f7ebeec5c333448582c224a7c90a9d98826fbab82c0ddc11348e6", size = 32773, upload-time = "2025-03-26T14:41:35.19Z" }, + { url = "https://files.pythonhosted.org/packages/89/32/a22a281806e3ef21b72db16f948cad22ec68e4bdd384139291e00ff82fe2/google_crc32c-1.7.1-cp313-cp313-win_amd64.whl", hash = "sha256:0f99eaa09a9a7e642a61e06742856eec8b19fc0037832e03f941fe7cf0c8e4db", size = 33475, upload-time = "2025-03-26T14:29:11.771Z" }, + { url = "https://files.pythonhosted.org/packages/b8/c5/002975aff514e57fc084ba155697a049b3f9b52225ec3bc0f542871dd524/google_crc32c-1.7.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32d1da0d74ec5634a05f53ef7df18fc646666a25efaaca9fc7dcfd4caf1d98c3", size = 33243, upload-time = "2025-03-26T14:41:35.975Z" }, + { url = "https://files.pythonhosted.org/packages/61/cb/c585282a03a0cea70fcaa1bf55d5d702d0f2351094d663ec3be1c6c67c52/google_crc32c-1.7.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e10554d4abc5238823112c2ad7e4560f96c7bf3820b202660373d769d9e6e4c9", size = 32870, upload-time = "2025-03-26T14:41:37.08Z" }, +] + +[[package]] +name = "google-genai" +version = "1.38.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "google-auth" }, + { name = "httpx" }, + { name = "pydantic" }, + { name = "requests" }, + { name = "tenacity" }, + { name = "typing-extensions" }, + { name = "websockets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b4/11/108ddd3aca8af6a9e2369e59b9646a3a4c64aefb39d154f6467ab8d79f34/google_genai-1.38.0.tar.gz", hash = "sha256:363272fc4f677d0be6a1aed7ebabe8adf45e1626a7011a7886a587e9464ca9ec", size = 244903, upload-time = "2025-09-16T23:25:42.577Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/53/6c/1de711bab3c118284904c3bedf870519e8c63a7a8e0905ac3833f1db9cbc/google_genai-1.38.0-py3-none-any.whl", hash = "sha256:95407425132d42b3fa11bc92b3f5cf61a0fbd8d9add1f0e89aac52c46fbba090", size = 245558, upload-time = "2025-09-16T23:25:41.141Z" }, +] + +[[package]] +name = "google-resumable-media" +version = "2.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-crc32c" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/64/d7/520b62a35b23038ff005e334dba3ffc75fcf583bee26723f1fd8fd4b6919/google_resumable_media-2.8.0.tar.gz", hash = "sha256:f1157ed8b46994d60a1bc432544db62352043113684d4e030ee02e77ebe9a1ae", size = 2163265, upload-time = "2025-11-17T15:38:06.659Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1f/0b/93afde9cfe012260e9fe1522f35c9b72d6ee222f316586b1f23ecf44d518/google_resumable_media-2.8.0-py3-none-any.whl", hash = "sha256:dd14a116af303845a8d932ddae161a26e86cc229645bc98b39f026f9b1717582", size = 81340, upload-time = "2025-11-17T15:38:05.594Z" }, +] + [[package]] name = "googleapis-common-protos" version = "1.70.0" @@ -1696,6 +2133,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/86/f1/62a193f0227cf15a920390abe675f386dec35f7ae3ffe6da582d3ade42c7/googleapis_common_protos-1.70.0-py3-none-any.whl", hash = "sha256:b8bfcca8c25a2bb253e0e0b0adaf8c00773e5e6af6fd92397576680b807e0fd8", size = 294530, upload-time = "2025-04-14T10:17:01.271Z" }, ] +[package.optional-dependencies] +grpc = [ + { name = "grpcio" }, +] + [[package]] name = "gradio" version = "5.49.1" @@ -1789,6 +2231,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/74/16/a4cf06adbc711bd364a73ce043b0b08d8fa5aae3df11b6ee4248bcdad2e0/graphql_relay-3.2.0-py3-none-any.whl", hash = "sha256:c9b22bd28b170ba1fe674c74384a8ff30a76c8e26f88ac3aa1584dd3179953e5", size = 16940, upload-time = "2022-04-16T11:03:43.895Z" }, ] +[[package]] +name = "graphviz" +version = "0.21" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/b3/3ac91e9be6b761a4b30d66ff165e54439dcd48b83f4e20d644867215f6ca/graphviz-0.21.tar.gz", hash = "sha256:20743e7183be82aaaa8ad6c93f8893c923bd6658a04c32ee115edb3c8a835f78", size = 200434, upload-time = "2025-06-15T09:35:05.824Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/91/4c/e0ce1ef95d4000ebc1c11801f9b944fa5910ecc15b5e351865763d8657f8/graphviz-0.21-py3-none-any.whl", hash = "sha256:54f33de9f4f911d7e84e4191749cac8cc5653f815b06738c54db9a15ab8b1e42", size = 47300, upload-time = "2025-06-15T09:35:04.433Z" }, +] + [[package]] name = "greenlet" version = "3.2.4" @@ -1885,6 +2336,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/28/27/3d6dcadc8a3214d8522c1e7f6a19554e33659be44546d44a2f7572ac7d2a/groovy-0.1.2-py3-none-any.whl", hash = "sha256:7f7975bab18c729a257a8b1ae9dcd70b7cafb1720481beae47719af57c35fa64", size = 14090, upload-time = "2025-02-28T20:24:55.152Z" }, ] +[[package]] +name = "grpc-google-iam-v1" +version = "0.14.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "googleapis-common-protos", extra = ["grpc"] }, + { name = "grpcio" }, + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/76/1e/1011451679a983f2f5c6771a1682542ecb027776762ad031fd0d7129164b/grpc_google_iam_v1-0.14.3.tar.gz", hash = "sha256:879ac4ef33136c5491a6300e27575a9ec760f6cdf9a2518798c1b8977a5dc389", size = 23745, upload-time = "2025-10-15T21:14:53.318Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4a/bd/330a1bbdb1afe0b96311249e699b6dc9cfc17916394fd4503ac5aca2514b/grpc_google_iam_v1-0.14.3-py3-none-any.whl", hash = "sha256:7a7f697e017a067206a3dfef44e4c634a34d3dee135fe7d7a4613fe3e59217e6", size = 32690, upload-time = "2025-10-15T21:14:51.72Z" }, +] + +[[package]] +name = "grpc-interceptor" +version = "0.15.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "grpcio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9f/28/57449d5567adf4c1d3e216aaca545913fbc21a915f2da6790d6734aac76e/grpc-interceptor-0.15.4.tar.gz", hash = "sha256:1f45c0bcb58b6f332f37c637632247c9b02bc6af0fdceb7ba7ce8d2ebbfb0926", size = 19322, upload-time = "2023-11-16T02:05:42.459Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/15/ac/8d53f230a7443401ce81791ec50a3b0e54924bf615ad287654fa4a2f5cdc/grpc_interceptor-0.15.4-py3-none-any.whl", hash = "sha256:0035f33228693ed3767ee49d937bac424318db173fef4d2d0170b3215f254d9d", size = 20848, upload-time = "2023-11-16T02:05:40.913Z" }, +] + [[package]] name = "grpcio" version = "1.74.0" @@ -1913,6 +2390,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/34/80/de3eb55eb581815342d097214bed4c59e806b05f1b3110df03b2280d6dfd/grpcio-1.74.0-cp313-cp313-win_amd64.whl", hash = "sha256:fd3c71aeee838299c5887230b8a1822795325ddfea635edd82954c1eaa831e24", size = 4489214, upload-time = "2025-07-24T18:53:59.771Z" }, ] +[[package]] +name = "grpcio-status" +version = "1.74.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "googleapis-common-protos" }, + { name = "grpcio" }, + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/93/22/238c5f01e6837df54494deb08d5c772bc3f5bf5fb80a15dce254892d1a81/grpcio_status-1.74.0.tar.gz", hash = "sha256:c58c1b24aa454e30f1fc6a7e0dbbc194c54a408143971a94b5f4e40bb5831432", size = 13662, upload-time = "2025-07-24T19:01:56.874Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/28/aa/1b1fe7d8ab699e1ec26d3a36b91d3df9f83a30abc07d4c881d0296b17b67/grpcio_status-1.74.0-py3-none-any.whl", hash = "sha256:52cdbd759a6760fc8f668098a03f208f493dd5c76bf8e02598bbbaf1f6fc2876", size = 14425, upload-time = "2025-07-24T19:01:19.963Z" }, +] + [[package]] name = "gunicorn" version = "23.0.0" @@ -1977,6 +2468,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" }, ] +[[package]] +name = "httplib2" +version = "0.31.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyparsing" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/52/77/6653db69c1f7ecfe5e3f9726fdadc981794656fcd7d98c4209fecfea9993/httplib2-0.31.0.tar.gz", hash = "sha256:ac7ab497c50975147d4f7b1ade44becc7df2f8954d42b38b3d69c515f531135c", size = 250759, upload-time = "2025-09-11T12:16:03.403Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8c/a2/0d269db0f6163be503775dc8b6a6fa15820cc9fdc866f6ba608d86b721f2/httplib2-0.31.0-py3-none-any.whl", hash = "sha256:b9cd78abea9b4e43a7714c6e0f8b6b8561a6fc1e95d5dbd367f5bf0ef35f5d24", size = 91148, upload-time = "2025-09-11T12:16:01.803Z" }, +] + [[package]] name = "httptools" version = "0.6.4" @@ -2014,6 +2517,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, ] +[[package]] +name = "httpx-sse" +version = "0.4.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/4c/751061ffa58615a32c31b2d82e8482be8dd4a89154f003147acee90f2be9/httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d", size = 15943, upload-time = "2025-10-10T21:48:22.271Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/fd/6668e5aec43ab844de6fc74927e155a3b37bf40d7c3790e49fc0406b6578/httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc", size = 8960, upload-time = "2025-10-10T21:48:21.158Z" }, +] + [[package]] name = "huggingface-hub" version = "0.34.4" @@ -2575,6 +3087,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5b/60/3601f8ce6d76a7c81c7f25a0e15fde0d6b66226dd187aa6d2838e6374161/matplotlib-3.10.5-cp314-cp314t-win_arm64.whl", hash = "sha256:2efaf97d72629e74252e0b5e3c46813e9eeaa94e011ecf8084a971a31a97f40b", size = 8153849, upload-time = "2025-07-31T18:09:19.673Z" }, ] +[[package]] +name = "mcp" +version = "1.22.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "httpx" }, + { name = "httpx-sse" }, + { name = "jsonschema" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "pyjwt", extra = ["crypto"] }, + { name = "python-multipart" }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "sse-starlette" }, + { name = "starlette" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, + { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/a2/c5ec0ab38b35ade2ae49a90fada718fbc76811dc5aa1760414c6aaa6b08a/mcp-1.22.0.tar.gz", hash = "sha256:769b9ac90ed42134375b19e777a2858ca300f95f2e800982b3e2be62dfc0ba01", size = 471788, upload-time = "2025-11-20T20:11:28.095Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a9/bb/711099f9c6bb52770f56e56401cdfb10da5b67029f701e0df29362df4c8e/mcp-1.22.0-py3-none-any.whl", hash = "sha256:bed758e24df1ed6846989c909ba4e3df339a27b4f30f1b8b627862a4bade4e98", size = 175489, upload-time = "2025-11-20T20:11:26.542Z" }, +] + [[package]] name = "mdit-py-plugins" version = "0.5.0" @@ -2736,12 +3273,18 @@ name = "ml-dtypes" version = "0.4.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", - "python_full_version >= '3.13' and sys_platform == 'win32'", - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version >= '3.13' and sys_platform == 'darwin'", - "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'win32'", + "python_full_version >= '3.14' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.14' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", ] dependencies = [ { name = "numpy", marker = "python_full_version >= '3.13'" }, @@ -3239,6 +3782,10 @@ dependencies = [ ] [package.optional-dependencies] +adk = [ + { name = "google-adk" }, + { name = "google-genai" }, +] automodel = [ { name = "causal-conv1d" }, { name = "flash-attn" }, @@ -3321,6 +3868,8 @@ requires-dist = [ { name = "flash-attn", marker = "extra == 'automodel'", specifier = "==2.8.1" }, { name = "flash-attn", marker = "extra == 'mcore'", specifier = "==2.8.1" }, { name = "flash-attn", marker = "extra == 'vllm'", specifier = "==2.8.1" }, + { name = "google-adk", marker = "extra == 'adk'", specifier = "==1.14.1" }, + { name = "google-genai", marker = "extra == 'adk'", specifier = "==1.38.0" }, { name = "hydra-core" }, { name = "mamba-ssm", marker = "extra == 'automodel'", git = "https://github.com/state-spaces/mamba.git?rev=2e16fc3062cdcd4ebef27a9aa4442676e1c7edf4" }, { name = "mamba-ssm", marker = "extra == 'vllm'", git = "https://github.com/state-spaces/mamba.git?rev=2e16fc3062cdcd4ebef27a9aa4442676e1c7edf4" }, @@ -3362,7 +3911,7 @@ requires-dist = [ { name = "vllm", marker = "extra == 'vllm'", specifier = "==0.11.0" }, { name = "wandb" }, ] -provides-extras = ["automodel", "vllm", "mcore", "penguin"] +provides-extras = ["automodel", "vllm", "mcore", "penguin", "adk"] [package.metadata.requires-dev] build = [ @@ -3934,6 +4483,66 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ae/a2/d86e01c28300bd41bab8f18afd613676e2bd63515417b77636fc1add426f/opentelemetry_api-1.38.0-py3-none-any.whl", hash = "sha256:2891b0197f47124454ab9f0cf58f3be33faca394457ac3e09daba13ff50aa582", size = 65947, upload-time = "2025-10-16T08:35:30.23Z" }, ] +[[package]] +name = "opentelemetry-exporter-gcp-logging" +version = "1.11.0a0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-cloud-logging" }, + { name = "opentelemetry-api" }, + { name = "opentelemetry-resourcedetector-gcp" }, + { name = "opentelemetry-sdk" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/2d/6aa7063b009768d8f9415b36a29ae9b3eb1e2c5eff70f58ca15e104c245f/opentelemetry_exporter_gcp_logging-1.11.0a0.tar.gz", hash = "sha256:58496f11b930c84570060ffbd4343cd0b597ea13c7bc5c879df01163dd552f14", size = 22400, upload-time = "2025-11-04T19:32:13.812Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/b7/2d3df53fa39bfd52f88c78a60367d45a7b1adbf8a756cce62d6ac149d49a/opentelemetry_exporter_gcp_logging-1.11.0a0-py3-none-any.whl", hash = "sha256:f8357c552947cb9c0101c4575a7702b8d3268e28bdeefdd1405cf838e128c6ef", size = 14168, upload-time = "2025-11-04T19:32:07.073Z" }, +] + +[[package]] +name = "opentelemetry-exporter-gcp-trace" +version = "1.11.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-cloud-trace" }, + { name = "opentelemetry-api" }, + { name = "opentelemetry-resourcedetector-gcp" }, + { name = "opentelemetry-sdk" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/10/9c/4c3b26e5494f8b53c7873732a2317df905abe2b8ab33e9edfcbd5a8ff79b/opentelemetry_exporter_gcp_trace-1.11.0.tar.gz", hash = "sha256:c947ab4ab53e16517ade23d6fe71fe88cf7ca3f57a42c9f0e4162d2b929fecb6", size = 18770, upload-time = "2025-11-04T19:32:15.109Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5f/4a/876703e8c5845198d95cd4006c8d1b2e3b129a9e288558e33133360f8d5d/opentelemetry_exporter_gcp_trace-1.11.0-py3-none-any.whl", hash = "sha256:b3dcb314e1a9985e9185cb7720b693eb393886fde98ae4c095ffc0893de6cefa", size = 14016, upload-time = "2025-11-04T19:32:09.009Z" }, +] + +[[package]] +name = "opentelemetry-exporter-otlp-proto-common" +version = "1.38.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-proto" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/19/83/dd4660f2956ff88ed071e9e0e36e830df14b8c5dc06722dbde1841accbe8/opentelemetry_exporter_otlp_proto_common-1.38.0.tar.gz", hash = "sha256:e333278afab4695aa8114eeb7bf4e44e65c6607d54968271a249c180b2cb605c", size = 20431, upload-time = "2025-10-16T08:35:53.285Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/9e/55a41c9601191e8cd8eb626b54ee6827b9c9d4a46d736f32abc80d8039fc/opentelemetry_exporter_otlp_proto_common-1.38.0-py3-none-any.whl", hash = "sha256:03cb76ab213300fe4f4c62b7d8f17d97fcfd21b89f0b5ce38ea156327ddda74a", size = 18359, upload-time = "2025-10-16T08:35:34.099Z" }, +] + +[[package]] +name = "opentelemetry-exporter-otlp-proto-http" +version = "1.38.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "googleapis-common-protos" }, + { name = "opentelemetry-api" }, + { name = "opentelemetry-exporter-otlp-proto-common" }, + { name = "opentelemetry-proto" }, + { name = "opentelemetry-sdk" }, + { name = "requests" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/81/0a/debcdfb029fbd1ccd1563f7c287b89a6f7bef3b2902ade56797bfd020854/opentelemetry_exporter_otlp_proto_http-1.38.0.tar.gz", hash = "sha256:f16bd44baf15cbe07633c5112ffc68229d0edbeac7b37610be0b2def4e21e90b", size = 17282, upload-time = "2025-10-16T08:35:54.422Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/77/154004c99fb9f291f74aa0822a2f5bbf565a72d8126b3a1b63ed8e5f83c7/opentelemetry_exporter_otlp_proto_http-1.38.0-py3-none-any.whl", hash = "sha256:84b937305edfc563f08ec69b9cb2298be8188371217e867c1854d77198d0825b", size = 19579, upload-time = "2025-10-16T08:35:36.269Z" }, +] + [[package]] name = "opentelemetry-exporter-prometheus" version = "0.59b0" @@ -3960,6 +4569,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b6/6a/82b68b14efca5150b2632f3692d627afa76b77378c4999f2648979409528/opentelemetry_proto-1.38.0-py3-none-any.whl", hash = "sha256:b6ebe54d3217c42e45462e2a1ae28c3e2bf2ec5a5645236a490f55f45f1a0a18", size = 72535, upload-time = "2025-10-16T08:35:45.749Z" }, ] +[[package]] +name = "opentelemetry-resourcedetector-gcp" +version = "1.11.0a0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-sdk" }, + { name = "requests" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c1/5d/2b3240d914b87b6dd9cd5ca2ef1ccaf1d0626b897d4c06877e22c8c10fcf/opentelemetry_resourcedetector_gcp-1.11.0a0.tar.gz", hash = "sha256:915a1d6fd15daca9eedd3fc52b0f705375054f2ef140e2e7a6b4cca95a47cdb1", size = 18796, upload-time = "2025-11-04T19:32:16.59Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/6c/1e13fe142a7ca3dc6489167203a1209d32430cca12775e1df9c9a41c54b2/opentelemetry_resourcedetector_gcp-1.11.0a0-py3-none-any.whl", hash = "sha256:5d65a2a039b1d40c6f41421dbb08d5f441368275ac6de6e76a8fccd1f6acb67e", size = 18798, upload-time = "2025-11-04T19:32:10.915Z" }, +] + [[package]] name = "opentelemetry-sdk" version = "1.38.0" @@ -4761,6 +5385,20 @@ pycountry = [ { name = "pycountry" }, ] +[[package]] +name = "pydantic-settings" +version = "2.12.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/43/4b/ac7e0aae12027748076d72a8764ff1c9d82ca75a7a52622e67ed3f765c54/pydantic_settings-2.12.0.tar.gz", hash = "sha256:005538ef951e3c2a68e1c08b292b5f2e71490def8589d4221b95dab00dafcfd0", size = 194184, upload-time = "2025-11-10T14:25:47.013Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/60/5d4751ba3f4a40a6891f24eec885f51afd78d208498268c734e256fb13c4/pydantic_settings-2.12.0-py3-none-any.whl", hash = "sha256:fddb9fd99a5b18da837b29710391e945b1e30c135477f484084ee513adb93809", size = 51880, upload-time = "2025-11-10T14:25:45.546Z" }, +] + [[package]] name = "pydata-sphinx-theme" version = "0.16.1" @@ -4811,6 +5449,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] +[[package]] +name = "pyjwt" +version = "2.10.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785, upload-time = "2024-11-28T03:43:29.933Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" }, +] + +[package.optional-dependencies] +crypto = [ + { name = "cryptography" }, +] + [[package]] name = "pynvml" version = "12.0.0" @@ -5692,6 +6344,57 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0d/6d/b4752b044bf94cb802d88a888dc7d288baaf77d7910b7dedda74b5ceea0c/setuptools-79.0.1-py3-none-any.whl", hash = "sha256:e147c0549f27767ba362f9da434eab9c5dc0045d5304feb602a0af001089fc51", size = 1256281, upload-time = "2025-04-23T22:20:56.768Z" }, ] +[[package]] +name = "shapely" +version = "2.1.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4d/bc/0989043118a27cccb4e906a46b7565ce36ca7b57f5a18b78f4f1b0f72d9d/shapely-2.1.2.tar.gz", hash = "sha256:2ed4ecb28320a433db18a5bf029986aa8afcfd740745e78847e330d5d94922a9", size = 315489, upload-time = "2025-09-24T13:51:41.432Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/24/c0/f3b6453cf2dfa99adc0ba6675f9aaff9e526d2224cbd7ff9c1a879238693/shapely-2.1.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fe2533caae6a91a543dec62e8360fe86ffcdc42a7c55f9dfd0128a977a896b94", size = 1833550, upload-time = "2025-09-24T13:50:30.019Z" }, + { url = "https://files.pythonhosted.org/packages/86/07/59dee0bc4b913b7ab59ab1086225baca5b8f19865e6101db9ebb7243e132/shapely-2.1.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ba4d1333cc0bc94381d6d4308d2e4e008e0bd128bdcff5573199742ee3634359", size = 1643556, upload-time = "2025-09-24T13:50:32.291Z" }, + { url = "https://files.pythonhosted.org/packages/26/29/a5397e75b435b9895cd53e165083faed5d12fd9626eadec15a83a2411f0f/shapely-2.1.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0bd308103340030feef6c111d3eb98d50dc13feea33affc8a6f9fa549e9458a3", size = 2988308, upload-time = "2025-09-24T13:50:33.862Z" }, + { url = "https://files.pythonhosted.org/packages/b9/37/e781683abac55dde9771e086b790e554811a71ed0b2b8a1e789b7430dd44/shapely-2.1.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1e7d4d7ad262a48bb44277ca12c7c78cb1b0f56b32c10734ec9a1d30c0b0c54b", size = 3099844, upload-time = "2025-09-24T13:50:35.459Z" }, + { url = "https://files.pythonhosted.org/packages/d8/f3/9876b64d4a5a321b9dc482c92bb6f061f2fa42131cba643c699f39317cb9/shapely-2.1.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e9eddfe513096a71896441a7c37db72da0687b34752c4e193577a145c71736fc", size = 3988842, upload-time = "2025-09-24T13:50:37.478Z" }, + { url = "https://files.pythonhosted.org/packages/d1/a0/704c7292f7014c7e74ec84eddb7b109e1fbae74a16deae9c1504b1d15565/shapely-2.1.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:980c777c612514c0cf99bc8a9de6d286f5e186dcaf9091252fcd444e5638193d", size = 4152714, upload-time = "2025-09-24T13:50:39.9Z" }, + { url = "https://files.pythonhosted.org/packages/53/46/319c9dc788884ad0785242543cdffac0e6530e4d0deb6c4862bc4143dcf3/shapely-2.1.2-cp312-cp312-win32.whl", hash = "sha256:9111274b88e4d7b54a95218e243282709b330ef52b7b86bc6aaf4f805306f454", size = 1542745, upload-time = "2025-09-24T13:50:41.414Z" }, + { url = "https://files.pythonhosted.org/packages/ec/bf/cb6c1c505cb31e818e900b9312d514f381fbfa5c4363edfce0fcc4f8c1a4/shapely-2.1.2-cp312-cp312-win_amd64.whl", hash = "sha256:743044b4cfb34f9a67205cee9279feaf60ba7d02e69febc2afc609047cb49179", size = 1722861, upload-time = "2025-09-24T13:50:43.35Z" }, + { url = "https://files.pythonhosted.org/packages/c3/90/98ef257c23c46425dc4d1d31005ad7c8d649fe423a38b917db02c30f1f5a/shapely-2.1.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b510dda1a3672d6879beb319bc7c5fd302c6c354584690973c838f46ec3e0fa8", size = 1832644, upload-time = "2025-09-24T13:50:44.886Z" }, + { url = "https://files.pythonhosted.org/packages/6d/ab/0bee5a830d209adcd3a01f2d4b70e587cdd9fd7380d5198c064091005af8/shapely-2.1.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:8cff473e81017594d20ec55d86b54bc635544897e13a7cfc12e36909c5309a2a", size = 1642887, upload-time = "2025-09-24T13:50:46.735Z" }, + { url = "https://files.pythonhosted.org/packages/2d/5e/7d7f54ba960c13302584c73704d8c4d15404a51024631adb60b126a4ae88/shapely-2.1.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fe7b77dc63d707c09726b7908f575fc04ff1d1ad0f3fb92aec212396bc6cfe5e", size = 2970931, upload-time = "2025-09-24T13:50:48.374Z" }, + { url = "https://files.pythonhosted.org/packages/f2/a2/83fc37e2a58090e3d2ff79175a95493c664bcd0b653dd75cb9134645a4e5/shapely-2.1.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7ed1a5bbfb386ee8332713bf7508bc24e32d24b74fc9a7b9f8529a55db9f4ee6", size = 3082855, upload-time = "2025-09-24T13:50:50.037Z" }, + { url = "https://files.pythonhosted.org/packages/44/2b/578faf235a5b09f16b5f02833c53822294d7f21b242f8e2d0cf03fb64321/shapely-2.1.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a84e0582858d841d54355246ddfcbd1fce3179f185da7470f41ce39d001ee1af", size = 3979960, upload-time = "2025-09-24T13:50:51.74Z" }, + { url = "https://files.pythonhosted.org/packages/4d/04/167f096386120f692cc4ca02f75a17b961858997a95e67a3cb6a7bbd6b53/shapely-2.1.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:dc3487447a43d42adcdf52d7ac73804f2312cbfa5d433a7d2c506dcab0033dfd", size = 4142851, upload-time = "2025-09-24T13:50:53.49Z" }, + { url = "https://files.pythonhosted.org/packages/48/74/fb402c5a6235d1c65a97348b48cdedb75fb19eca2b1d66d04969fc1c6091/shapely-2.1.2-cp313-cp313-win32.whl", hash = "sha256:9c3a3c648aedc9f99c09263b39f2d8252f199cb3ac154fadc173283d7d111350", size = 1541890, upload-time = "2025-09-24T13:50:55.337Z" }, + { url = "https://files.pythonhosted.org/packages/41/47/3647fe7ad990af60ad98b889657a976042c9988c2807cf322a9d6685f462/shapely-2.1.2-cp313-cp313-win_amd64.whl", hash = "sha256:ca2591bff6645c216695bdf1614fca9c82ea1144d4a7591a466fef64f28f0715", size = 1722151, upload-time = "2025-09-24T13:50:57.153Z" }, + { url = "https://files.pythonhosted.org/packages/3c/49/63953754faa51ffe7d8189bfbe9ca34def29f8c0e34c67cbe2a2795f269d/shapely-2.1.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:2d93d23bdd2ed9dc157b46bc2f19b7da143ca8714464249bef6771c679d5ff40", size = 1834130, upload-time = "2025-09-24T13:50:58.49Z" }, + { url = "https://files.pythonhosted.org/packages/7f/ee/dce001c1984052970ff60eb4727164892fb2d08052c575042a47f5a9e88f/shapely-2.1.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:01d0d304b25634d60bd7cf291828119ab55a3bab87dc4af1e44b07fb225f188b", size = 1642802, upload-time = "2025-09-24T13:50:59.871Z" }, + { url = "https://files.pythonhosted.org/packages/da/e7/fc4e9a19929522877fa602f705706b96e78376afb7fad09cad5b9af1553c/shapely-2.1.2-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8d8382dd120d64b03698b7298b89611a6ea6f55ada9d39942838b79c9bc89801", size = 3018460, upload-time = "2025-09-24T13:51:02.08Z" }, + { url = "https://files.pythonhosted.org/packages/a1/18/7519a25db21847b525696883ddc8e6a0ecaa36159ea88e0fef11466384d0/shapely-2.1.2-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:19efa3611eef966e776183e338b2d7ea43569ae99ab34f8d17c2c054d3205cc0", size = 3095223, upload-time = "2025-09-24T13:51:04.472Z" }, + { url = "https://files.pythonhosted.org/packages/48/de/b59a620b1f3a129c3fecc2737104a0a7e04e79335bd3b0a1f1609744cf17/shapely-2.1.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:346ec0c1a0fcd32f57f00e4134d1200e14bf3f5ae12af87ba83ca275c502498c", size = 4030760, upload-time = "2025-09-24T13:51:06.455Z" }, + { url = "https://files.pythonhosted.org/packages/96/b3/c6655ee7232b417562bae192ae0d3ceaadb1cc0ffc2088a2ddf415456cc2/shapely-2.1.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:6305993a35989391bd3476ee538a5c9a845861462327efe00dd11a5c8c709a99", size = 4170078, upload-time = "2025-09-24T13:51:08.584Z" }, + { url = "https://files.pythonhosted.org/packages/a0/8e/605c76808d73503c9333af8f6cbe7e1354d2d238bda5f88eea36bfe0f42a/shapely-2.1.2-cp313-cp313t-win32.whl", hash = "sha256:c8876673449f3401f278c86eb33224c5764582f72b653a415d0e6672fde887bf", size = 1559178, upload-time = "2025-09-24T13:51:10.73Z" }, + { url = "https://files.pythonhosted.org/packages/36/f7/d317eb232352a1f1444d11002d477e54514a4a6045536d49d0c59783c0da/shapely-2.1.2-cp313-cp313t-win_amd64.whl", hash = "sha256:4a44bc62a10d84c11a7a3d7c1c4fe857f7477c3506e24c9062da0db0ae0c449c", size = 1739756, upload-time = "2025-09-24T13:51:12.105Z" }, + { url = "https://files.pythonhosted.org/packages/fc/c4/3ce4c2d9b6aabd27d26ec988f08cb877ba9e6e96086eff81bfea93e688c7/shapely-2.1.2-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:9a522f460d28e2bf4e12396240a5fc1518788b2fcd73535166d748399ef0c223", size = 1831290, upload-time = "2025-09-24T13:51:13.56Z" }, + { url = "https://files.pythonhosted.org/packages/17/b9/f6ab8918fc15429f79cb04afa9f9913546212d7fb5e5196132a2af46676b/shapely-2.1.2-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:1ff629e00818033b8d71139565527ced7d776c269a49bd78c9df84e8f852190c", size = 1641463, upload-time = "2025-09-24T13:51:14.972Z" }, + { url = "https://files.pythonhosted.org/packages/a5/57/91d59ae525ca641e7ac5551c04c9503aee6f29b92b392f31790fcb1a4358/shapely-2.1.2-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f67b34271dedc3c653eba4e3d7111aa421d5be9b4c4c7d38d30907f796cb30df", size = 2970145, upload-time = "2025-09-24T13:51:16.961Z" }, + { url = "https://files.pythonhosted.org/packages/8a/cb/4948be52ee1da6927831ab59e10d4c29baa2a714f599f1f0d1bc747f5777/shapely-2.1.2-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:21952dc00df38a2c28375659b07a3979d22641aeb104751e769c3ee825aadecf", size = 3073806, upload-time = "2025-09-24T13:51:18.712Z" }, + { url = "https://files.pythonhosted.org/packages/03/83/f768a54af775eb41ef2e7bec8a0a0dbe7d2431c3e78c0a8bdba7ab17e446/shapely-2.1.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:1f2f33f486777456586948e333a56ae21f35ae273be99255a191f5c1fa302eb4", size = 3980803, upload-time = "2025-09-24T13:51:20.37Z" }, + { url = "https://files.pythonhosted.org/packages/9f/cb/559c7c195807c91c79d38a1f6901384a2878a76fbdf3f1048893a9b7534d/shapely-2.1.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:cf831a13e0d5a7eb519e96f58ec26e049b1fad411fc6fc23b162a7ce04d9cffc", size = 4133301, upload-time = "2025-09-24T13:51:21.887Z" }, + { url = "https://files.pythonhosted.org/packages/80/cd/60d5ae203241c53ef3abd2ef27c6800e21afd6c94e39db5315ea0cbafb4a/shapely-2.1.2-cp314-cp314-win32.whl", hash = "sha256:61edcd8d0d17dd99075d320a1dd39c0cb9616f7572f10ef91b4b5b00c4aeb566", size = 1583247, upload-time = "2025-09-24T13:51:23.401Z" }, + { url = "https://files.pythonhosted.org/packages/74/d4/135684f342e909330e50d31d441ace06bf83c7dc0777e11043f99167b123/shapely-2.1.2-cp314-cp314-win_amd64.whl", hash = "sha256:a444e7afccdb0999e203b976adb37ea633725333e5b119ad40b1ca291ecf311c", size = 1773019, upload-time = "2025-09-24T13:51:24.873Z" }, + { url = "https://files.pythonhosted.org/packages/a3/05/a44f3f9f695fa3ada22786dc9da33c933da1cbc4bfe876fe3a100bafe263/shapely-2.1.2-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:5ebe3f84c6112ad3d4632b1fd2290665aa75d4cef5f6c5d77c4c95b324527c6a", size = 1834137, upload-time = "2025-09-24T13:51:26.665Z" }, + { url = "https://files.pythonhosted.org/packages/52/7e/4d57db45bf314573427b0a70dfca15d912d108e6023f623947fa69f39b72/shapely-2.1.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:5860eb9f00a1d49ebb14e881f5caf6c2cf472c7fd38bd7f253bbd34f934eb076", size = 1642884, upload-time = "2025-09-24T13:51:28.029Z" }, + { url = "https://files.pythonhosted.org/packages/5a/27/4e29c0a55d6d14ad7422bf86995d7ff3f54af0eba59617eb95caf84b9680/shapely-2.1.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b705c99c76695702656327b819c9660768ec33f5ce01fa32b2af62b56ba400a1", size = 3018320, upload-time = "2025-09-24T13:51:29.903Z" }, + { url = "https://files.pythonhosted.org/packages/9f/bb/992e6a3c463f4d29d4cd6ab8963b75b1b1040199edbd72beada4af46bde5/shapely-2.1.2-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a1fd0ea855b2cf7c9cddaf25543e914dd75af9de08785f20ca3085f2c9ca60b0", size = 3094931, upload-time = "2025-09-24T13:51:32.699Z" }, + { url = "https://files.pythonhosted.org/packages/9c/16/82e65e21070e473f0ed6451224ed9fa0be85033d17e0c6e7213a12f59d12/shapely-2.1.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:df90e2db118c3671a0754f38e36802db75fe0920d211a27481daf50a711fdf26", size = 4030406, upload-time = "2025-09-24T13:51:34.189Z" }, + { url = "https://files.pythonhosted.org/packages/7c/75/c24ed871c576d7e2b64b04b1fe3d075157f6eb54e59670d3f5ffb36e25c7/shapely-2.1.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:361b6d45030b4ac64ddd0a26046906c8202eb60d0f9f53085f5179f1d23021a0", size = 4169511, upload-time = "2025-09-24T13:51:36.297Z" }, + { url = "https://files.pythonhosted.org/packages/b1/f7/b3d1d6d18ebf55236eec1c681ce5e665742aab3c0b7b232720a7d43df7b6/shapely-2.1.2-cp314-cp314t-win32.whl", hash = "sha256:b54df60f1fbdecc8ebc2c5b11870461a6417b3d617f555e5033f1505d36e5735", size = 1602607, upload-time = "2025-09-24T13:51:37.757Z" }, + { url = "https://files.pythonhosted.org/packages/9a/f6/f09272a71976dfc138129b8faf435d064a811ae2f708cb147dccdf7aacdb/shapely-2.1.2-cp314-cp314t-win_amd64.whl", hash = "sha256:0036ac886e0923417932c2e6369b6c52e38e0ff5d9120b90eef5cd9a5fc5cae9", size = 1796682, upload-time = "2025-09-24T13:51:39.233Z" }, +] + [[package]] name = "shellingham" version = "1.5.4" @@ -6006,6 +6709,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b8/d9/13bdde6521f322861fab67473cec4b1cc8999f3871953531cf61945fad92/sqlalchemy-2.0.43-py3-none-any.whl", hash = "sha256:1681c21dd2ccee222c2fe0bef671d1aef7c504087c9c4e800371cfcc8ac966fc", size = 1924759, upload-time = "2025-08-11T15:39:53.024Z" }, ] +[[package]] +name = "sqlalchemy-spanner" +version = "1.17.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "alembic" }, + { name = "google-cloud-spanner" }, + { name = "sqlalchemy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/23/64/74e4d7aebc5210feff9b27e799fa81cc2bdf38f474e304e5c2b3f934f361/sqlalchemy_spanner-1.17.1.tar.gz", hash = "sha256:1542c2e69b1923974d8ad884ffc458f7d135e44af1c475b98decf75d90eccaa3", size = 82630, upload-time = "2025-10-21T14:33:54.183Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/72/187ca1767648d54ada46c074b2b346894712bc56b6c0dab3410bd0996209/sqlalchemy_spanner-1.17.1-py3-none-any.whl", hash = "sha256:8b8444c23e66c84aab5dbab589face8fd75733fa6c1811db368d5202cdfb5f8e", size = 31859, upload-time = "2025-10-21T14:33:52.926Z" }, +] + [[package]] name = "sqlparse" version = "0.5.3" @@ -6015,6 +6732,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a9/5c/bfd6bd0bf979426d405cc6e71eceb8701b148b16c21d2dc3c261efc61c7b/sqlparse-0.5.3-py3-none-any.whl", hash = "sha256:cf2196ed3418f3ba5de6af7e82c694a9fbdbfecccdfc72e281548517081f16ca", size = 44415, upload-time = "2024-12-10T12:05:27.824Z" }, ] +[[package]] +name = "sse-starlette" +version = "3.0.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/db/3c/fa6517610dc641262b77cc7bf994ecd17465812c1b0585fe33e11be758ab/sse_starlette-3.0.3.tar.gz", hash = "sha256:88cfb08747e16200ea990c8ca876b03910a23b547ab3bd764c0d8eb81019b971", size = 21943, upload-time = "2025-10-30T18:44:20.117Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/23/a0/984525d19ca5c8a6c33911a0c164b11490dd0f90ff7fd689f704f84e9a11/sse_starlette-3.0.3-py3-none-any.whl", hash = "sha256:af5bf5a6f3933df1d9c7f8539633dc8444ca6a97ab2e2a7cd3b6e431ac03a431", size = 11765, upload-time = "2025-10-30T18:44:18.834Z" }, +] + [[package]] name = "starlette" version = "0.47.2" @@ -6108,6 +6837,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b4/94/fd3853b98f39d10206b08f2737d2ec2dc6f46a42dc7b7e05f4f0162d13ee/tdigest-0.5.2.2-py3-none-any.whl", hash = "sha256:dd25f8d6e6be002192bba9e4b8c16491d36c10b389f50637818603d1f67c6fb2", size = 9440, upload-time = "2019-05-07T18:57:38.942Z" }, ] +[[package]] +name = "tenacity" +version = "8.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a3/4d/6a19536c50b849338fcbe9290d562b52cbdcf30d8963d3588a68a4107df1/tenacity-8.5.0.tar.gz", hash = "sha256:8bc6c0c8a09b31e6cad13c47afbed1a567518250a9a171418582ed8d9c20ca78", size = 47309, upload-time = "2024-07-05T07:25:31.836Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/3f/8ba87d9e287b9d385a02a7114ddcef61b26f86411e121c9003eb509a1773/tenacity-8.5.0-py3-none-any.whl", hash = "sha256:b594c2a5945830c267ce6b79a166228323ed52718f30302c1359836112346687", size = 28165, upload-time = "2024-07-05T07:25:29.591Z" }, +] + [[package]] name = "tensorboard" version = "2.20.0" @@ -6143,12 +6881,18 @@ name = "tensorstore" version = "0.1.74" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", - "python_full_version >= '3.13' and sys_platform == 'win32'", - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version >= '3.13' and sys_platform == 'darwin'", - "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'win32'", + "python_full_version >= '3.14' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.14' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", ] dependencies = [ { name = "ml-dtypes", version = "0.4.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.13'" }, @@ -6289,7 +7033,8 @@ name = "torch" version = "2.8.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and sys_platform == 'darwin'", "python_full_version < '3.13' and sys_platform == 'darwin'", ] dependencies = [ @@ -6312,15 +7057,20 @@ name = "torch" version = "2.8.0+cu129" source = { registry = "https://download.pytorch.org/whl/cu129" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", - "python_full_version >= '3.13' and sys_platform == 'win32'", - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'win32'", + "python_full_version >= '3.14' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", "python_full_version < '3.13' and sys_platform == 'win32'", "python_full_version < '3.13' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", "python_full_version < '3.13' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux'", ] dependencies = [ @@ -6436,7 +7186,8 @@ name = "torchvision" version = "0.23.0" source = { registry = "https://download.pytorch.org/whl/cu129" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux'", ] dependencies = [ @@ -6455,7 +7206,8 @@ name = "torchvision" version = "0.23.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and sys_platform == 'darwin'", + "python_full_version >= '3.14' and sys_platform == 'darwin'", + "python_full_version == '3.13.*' and sys_platform == 'darwin'", "python_full_version < '3.13' and sys_platform == 'darwin'", ] dependencies = [ @@ -6474,11 +7226,15 @@ name = "torchvision" version = "0.23.0+cu129" source = { registry = "https://download.pytorch.org/whl/cu129" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", - "python_full_version >= '3.13' and sys_platform == 'win32'", - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'win32'", + "python_full_version >= '3.14' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", "python_full_version < '3.13' and sys_platform == 'win32'", "python_full_version < '3.13' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", "python_full_version < '3.13' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", @@ -6578,9 +7334,11 @@ name = "triton" version = "3.4.0" source = { registry = "https://download.pytorch.org/whl/cu129" } resolution-markers = [ - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and platform_machine != 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux'", - "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version < '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux'", ] dependencies = [ @@ -6600,9 +7358,12 @@ name = "triton" version = "3.4.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.13' and sys_platform == 'win32'", - "python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", - "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.14' and sys_platform == 'win32'", + "python_full_version == '3.13.*' and sys_platform == 'win32'", + "python_full_version >= '3.14' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", + "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", "python_full_version < '3.13' and sys_platform == 'win32'", "python_full_version < '3.13' and platform_machine != 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", "python_full_version < '3.13' and platform_machine == 'aarch64' and sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32'", @@ -6695,6 +7456,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5c/23/c7abc0ca0a1526a0774eca151daeb8de62ec457e77262b66b359c3c7679e/tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8", size = 347839, upload-time = "2025-03-23T13:54:41.845Z" }, ] +[[package]] +name = "tzlocal" +version = "5.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "tzdata", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8b/2e/c14812d3d4d9cd1773c6be938f89e5735a1f11a9f184ac3639b93cef35d5/tzlocal-5.3.1.tar.gz", hash = "sha256:cceffc7edecefea1f595541dbd6e990cb1ea3d19bf01b2809f362a03dd7921fd", size = 30761, upload-time = "2025-03-05T21:17:41.549Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/14/e2a54fabd4f08cd7af1c07030603c3356b74da07f7cc056e600436edfa17/tzlocal-5.3.1-py3-none-any.whl", hash = "sha256:eb1a66c3ef5847adf7a834f1be0800581b683b5608e74f86ecbcef8ab91bb85d", size = 18026, upload-time = "2025-03-05T21:17:39.857Z" }, +] + +[[package]] +name = "uritemplate" +version = "4.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/98/60/f174043244c5306c9988380d2cb10009f91563fc4b31293d27e17201af56/uritemplate-4.2.0.tar.gz", hash = "sha256:480c2ed180878955863323eea31b0ede668795de182617fef9c6ca09e6ec9d0e", size = 33267, upload-time = "2025-06-02T15:12:06.318Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a9/99/3ae339466c9183ea5b8ae87b34c0b897eda475d2aec2307cae60e5cd4f29/uritemplate-4.2.0-py3-none-any.whl", hash = "sha256:962201ba1c4edcab02e60f9a0d3821e82dfc5d2d6662a21abd533879bdb8a686", size = 11488, upload-time = "2025-06-02T15:12:03.405Z" }, +] + [[package]] name = "urllib3" version = "1.26.20" @@ -6870,6 +7652,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/51/ba/81c77d5d831fcddb89661c85175fcbb91d2ffecf6b0591972829da3eb42f/wandb-0.21.1-py3-none-win_amd64.whl", hash = "sha256:8be92a7e92b5cb5ce00ec0961f9dbaad7757ffdbc5b5a8f2cc7188e23f653f0a", size = 21569817, upload-time = "2025-08-07T18:52:45.559Z" }, ] +[[package]] +name = "watchdog" +version = "6.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/db/7d/7f3d619e951c88ed75c6037b246ddcf2d322812ee8ea189be89511721d54/watchdog-6.0.0.tar.gz", hash = "sha256:9ddf7c82fda3ae8e24decda1338ede66e1c99883db93711d8fb941eaa2d8c282", size = 131220, upload-time = "2024-11-01T14:07:13.037Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/39/ea/3930d07dafc9e286ed356a679aa02d777c06e9bfd1164fa7c19c288a5483/watchdog-6.0.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:bdd4e6f14b8b18c334febb9c4425a878a2ac20efd1e0b231978e7b150f92a948", size = 96471, upload-time = "2024-11-01T14:06:37.745Z" }, + { url = "https://files.pythonhosted.org/packages/12/87/48361531f70b1f87928b045df868a9fd4e253d9ae087fa4cf3f7113be363/watchdog-6.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c7c15dda13c4eb00d6fb6fc508b3c0ed88b9d5d374056b239c4ad1611125c860", size = 88449, upload-time = "2024-11-01T14:06:39.748Z" }, + { url = "https://files.pythonhosted.org/packages/5b/7e/8f322f5e600812e6f9a31b75d242631068ca8f4ef0582dd3ae6e72daecc8/watchdog-6.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6f10cb2d5902447c7d0da897e2c6768bca89174d0c6e1e30abec5421af97a5b0", size = 89054, upload-time = "2024-11-01T14:06:41.009Z" }, + { url = "https://files.pythonhosted.org/packages/68/98/b0345cabdce2041a01293ba483333582891a3bd5769b08eceb0d406056ef/watchdog-6.0.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:490ab2ef84f11129844c23fb14ecf30ef3d8a6abafd3754a6f75ca1e6654136c", size = 96480, upload-time = "2024-11-01T14:06:42.952Z" }, + { url = "https://files.pythonhosted.org/packages/85/83/cdf13902c626b28eedef7ec4f10745c52aad8a8fe7eb04ed7b1f111ca20e/watchdog-6.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:76aae96b00ae814b181bb25b1b98076d5fc84e8a53cd8885a318b42b6d3a5134", size = 88451, upload-time = "2024-11-01T14:06:45.084Z" }, + { url = "https://files.pythonhosted.org/packages/fe/c4/225c87bae08c8b9ec99030cd48ae9c4eca050a59bf5c2255853e18c87b50/watchdog-6.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a175f755fc2279e0b7312c0035d52e27211a5bc39719dd529625b1930917345b", size = 89057, upload-time = "2024-11-01T14:06:47.324Z" }, + { url = "https://files.pythonhosted.org/packages/a9/c7/ca4bf3e518cb57a686b2feb4f55a1892fd9a3dd13f470fca14e00f80ea36/watchdog-6.0.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7607498efa04a3542ae3e05e64da8202e58159aa1fa4acddf7678d34a35d4f13", size = 79079, upload-time = "2024-11-01T14:06:59.472Z" }, + { url = "https://files.pythonhosted.org/packages/5c/51/d46dc9332f9a647593c947b4b88e2381c8dfc0942d15b8edc0310fa4abb1/watchdog-6.0.0-py3-none-manylinux2014_armv7l.whl", hash = "sha256:9041567ee8953024c83343288ccc458fd0a2d811d6a0fd68c4c22609e3490379", size = 79078, upload-time = "2024-11-01T14:07:01.431Z" }, + { url = "https://files.pythonhosted.org/packages/d4/57/04edbf5e169cd318d5f07b4766fee38e825d64b6913ca157ca32d1a42267/watchdog-6.0.0-py3-none-manylinux2014_i686.whl", hash = "sha256:82dc3e3143c7e38ec49d61af98d6558288c415eac98486a5c581726e0737c00e", size = 79076, upload-time = "2024-11-01T14:07:02.568Z" }, + { url = "https://files.pythonhosted.org/packages/ab/cc/da8422b300e13cb187d2203f20b9253e91058aaf7db65b74142013478e66/watchdog-6.0.0-py3-none-manylinux2014_ppc64.whl", hash = "sha256:212ac9b8bf1161dc91bd09c048048a95ca3a4c4f5e5d4a7d1b1a7d5752a7f96f", size = 79077, upload-time = "2024-11-01T14:07:03.893Z" }, + { url = "https://files.pythonhosted.org/packages/2c/3b/b8964e04ae1a025c44ba8e4291f86e97fac443bca31de8bd98d3263d2fcf/watchdog-6.0.0-py3-none-manylinux2014_ppc64le.whl", hash = "sha256:e3df4cbb9a450c6d49318f6d14f4bbc80d763fa587ba46ec86f99f9e6876bb26", size = 79078, upload-time = "2024-11-01T14:07:05.189Z" }, + { url = "https://files.pythonhosted.org/packages/62/ae/a696eb424bedff7407801c257d4b1afda455fe40821a2be430e173660e81/watchdog-6.0.0-py3-none-manylinux2014_s390x.whl", hash = "sha256:2cce7cfc2008eb51feb6aab51251fd79b85d9894e98ba847408f662b3395ca3c", size = 79077, upload-time = "2024-11-01T14:07:06.376Z" }, + { url = "https://files.pythonhosted.org/packages/b5/e8/dbf020b4d98251a9860752a094d09a65e1b436ad181faf929983f697048f/watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:20ffe5b202af80ab4266dcd3e91aae72bf2da48c0d33bdb15c66658e685e94e2", size = 79078, upload-time = "2024-11-01T14:07:07.547Z" }, + { url = "https://files.pythonhosted.org/packages/07/f6/d0e5b343768e8bcb4cda79f0f2f55051bf26177ecd5651f84c07567461cf/watchdog-6.0.0-py3-none-win32.whl", hash = "sha256:07df1fdd701c5d4c8e55ef6cf55b8f0120fe1aef7ef39a1c6fc6bc2e606d517a", size = 79065, upload-time = "2024-11-01T14:07:09.525Z" }, + { url = "https://files.pythonhosted.org/packages/db/d9/c495884c6e548fce18a8f40568ff120bc3a4b7b99813081c8ac0c936fa64/watchdog-6.0.0-py3-none-win_amd64.whl", hash = "sha256:cbafb470cf848d93b5d013e2ecb245d4aa1c8fd0504e863ccefa32445359d680", size = 79070, upload-time = "2024-11-01T14:07:10.686Z" }, + { url = "https://files.pythonhosted.org/packages/33/e8/e40370e6d74ddba47f002a32919d91310d6074130fe4e17dabcafc15cbf1/watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f", size = 79067, upload-time = "2024-11-01T14:07:11.845Z" }, +] + [[package]] name = "watchfiles" version = "1.1.0"