423 changes: 419 additions & 4 deletions salt/client/ssh/__init__.py

Large diffs are not rendered by default.

86 changes: 71 additions & 15 deletions salt/states/grains.py
@@ -178,30 +178,44 @@ def list_present(name, value, delimiter=DEFAULT_TARGET_DELIM):
name = re.sub(delimiter, DEFAULT_TARGET_DELIM, name)
ret = {"name": name, "changes": {}, "result": True, "comment": ""}
grain = __salt__["grains.get"](name)

# Check pending_grains first to avoid duplicates within the same state run
pending_grains = __context__.get("pending_grains", {})
if name in pending_grains:
# Combine current grain with pending grains for duplicate checking
if grain and isinstance(grain, list):
combined_grain = set(grain) | pending_grains[name]
else:
combined_grain = pending_grains[name]
else:
combined_grain = set(grain) if grain and isinstance(grain, list) else set()

if grain:
# check whether grain is a list
if not isinstance(grain, list):
ret["result"] = False
ret["comment"] = f"Grain {name} is not a valid list"
return ret
if isinstance(value, list):
if make_hashable(value).issubset(
make_hashable(__salt__["grains.get"](name))
):
ret["comment"] = f"Value {value} is already in grain {name}"
return ret
elif name in __context__.get("pending_grains", {}):
# elements common to both
intersection = set(value).intersection(
__context__.get("pending_grains", {})[name]
# Check against combined grain (actual + pending) to avoid duplicates
if make_hashable(value).issubset(make_hashable(list(combined_grain))):
ret["comment"] = (
f"Value {value} is already in grain {name} (or pending)"
)
return ret
# Check for intersection with pending grains
if name in pending_grains:
intersection = set(value).intersection(pending_grains[name])
if intersection:
value = list(
set(value).difference(__context__["pending_grains"][name])
)
value = list(set(value).difference(pending_grains[name]))
if not value:
ret["comment"] = (
f'All values already pending in grain "{name}".\n'
)
return ret
ret["comment"] = (
'Removed value {} from update due to context found in "{}".\n'.format(
value, name
intersection, name
)
)
if "pending_grains" not in __context__:
@@ -210,9 +224,18 @@ def list_present(name, value, delimiter=DEFAULT_TARGET_DELIM):
__context__["pending_grains"][name] = set()
__context__["pending_grains"][name].update(value)
else:
if value in grain:
ret["comment"] = f"Value {value} is already in grain {name}"
# For single value, check against combined grain
if value in combined_grain:
ret["comment"] = (
f"Value {value} is already in grain {name} (or pending)"
)
return ret
# Add single value to pending_grains to avoid duplicates
if "pending_grains" not in __context__:
__context__["pending_grains"] = {}
if name not in __context__["pending_grains"]:
__context__["pending_grains"][name] = set()
__context__["pending_grains"][name].add(value)
if __opts__["test"]:
ret["result"] = None
ret["comment"] = "Value {1} is set to be appended to grain {0}".format(
@@ -221,6 +244,39 @@ def list_present(name, value, delimiter=DEFAULT_TARGET_DELIM):
ret["changes"] = {"new": grain}
return ret

# Handle case where grain doesn't exist yet
if not grain:
# Check if values are already pending
if name in pending_grains:
if isinstance(value, list):
if make_hashable(value).issubset(
make_hashable(list(pending_grains[name]))
):
ret["comment"] = f"Value {value} is already pending in grain {name}"
return ret
# Remove already pending values
intersection = set(value).intersection(pending_grains[name])
if intersection:
value = list(set(value).difference(pending_grains[name]))
if not value:
ret["comment"] = (
f'All values already pending in grain "{name}".\n'
)
return ret
else:
if value in pending_grains[name]:
ret["comment"] = f"Value {value} is already pending in grain {name}"
return ret
# Initialize pending_grains if needed
if "pending_grains" not in __context__:
__context__["pending_grains"] = {}
if name not in __context__["pending_grains"]:
__context__["pending_grains"][name] = set()
if isinstance(value, list):
__context__["pending_grains"][name].update(value)
else:
__context__["pending_grains"][name].add(value)

if __opts__["test"]:
ret["result"] = None
ret["comment"] = f"Grain {name} is set to be added"
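
The change above avoids duplicate appends by tracking values queued for a grain in a shared pending set and checking new requests against the union of the grain's current contents and anything already pending. A minimal sketch of that idea, assuming a plain dict in place of Salt's __context__ (the helper name is hypothetical, not part of the diff):

context = {}  # stand-in for __context__ in this sketch

def queue_list_values(name, current, values):
    """Return only the values that are neither in the grain nor already pending."""
    pending = context.setdefault("pending_grains", {}).setdefault(name, set())
    combined = set(current or []) | pending  # actual grain contents + pending additions
    new_values = [v for v in values if v not in combined]
    pending.update(new_values)
    return new_values

print(queue_list_values("roles", ["web"], ["web", "db"]))    # ['db']
print(queue_list_values("roles", ["web"], ["db", "cache"]))  # ['cache'] ('db' is already pending)
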
1 change: 1 addition & 0 deletions salt/utils/extmods.py
@@ -153,4 +153,5 @@ def sync(
shutil.rmtree(emptydir, ignore_errors=True)
except Exception as exc: # pylint: disable=broad-except
log.error("Failed to sync %s module: %s", form, exc)

return ret, touched
20 changes: 17 additions & 3 deletions salt/utils/relenv.py
@@ -31,13 +31,27 @@ def gen_relenv(
if not os.path.isdir(relenv_dir):
os.makedirs(relenv_dir)

relenv_url = get_tarball(kernel, os_arch)
tarball_path = os.path.join(relenv_dir, "salt-relenv.tar.xz")

# Download the tarball if it doesn't exist or overwrite is True
if overwrite or not os.path.exists(tarball_path):
if not download(cachedir, relenv_url, tarball_path):
return False
# Check for shared test cache first (for integration tests)
import shutil
import tempfile

shared_cache = os.path.join(tempfile.gettempdir(), "salt_ssh_test_relenv_cache")
shared_tarball = os.path.join(
shared_cache, "relenv", kernel, os_arch, "salt-relenv.tar.xz"
)

if os.path.exists(shared_tarball):
log.info("Copying tarball from shared test cache: %s", shared_tarball)
shutil.copy(shared_tarball, tarball_path)
else:
# Download from repository
relenv_url = get_tarball(kernel, os_arch)
if not download(cachedir, relenv_url, tarball_path):
return False

return tarball_path
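
The new flow prefers a tarball staged in a shared test cache and only falls back to downloading from the repository. A rough sketch of that cache-first lookup, using illustrative paths and a caller-supplied download callable rather than Salt's actual helpers:

import os
import shutil
import tempfile

def fetch_tarball(dest, kernel, arch, download):
    # Reuse a pre-staged tarball from the shared test cache when present.
    shared = os.path.join(
        tempfile.gettempdir(),
        "salt_ssh_test_relenv_cache",
        "relenv",
        kernel,
        arch,
        "salt-relenv.tar.xz",
    )
    if os.path.exists(shared):
        shutil.copy(shared, dest)
        return dest
    # Otherwise fall back to downloading from the repository.
    return dest if download(dest) else False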

36 changes: 16 additions & 20 deletions tests/pytests/integration/netapi/test_ssh_client.py
@@ -165,29 +165,25 @@ def test_shell_inject_ssh_priv(
"""
# ZDI-CAN-11143
path = tmp_path / "test-11143"
tgts = ["packages.broadcom.com", "www.zerodayinitiative.com"]
ret = None
for tgt in tgts:
low = {
"roster": "cache",
"client": "ssh",
"tgt": tgt,
"ssh_priv": f"aaa|id>{path} #",
"fun": "test.ping",
"eauth": "auto",
"username": salt_auto_account.username,
"password": salt_auto_account.password,
"roster_file": str(salt_ssh_roster_file),
"rosters": [rosters_dir],
}
ret = client.run(low)
if ret:
break
low = {
"roster": "cache",
"client": "ssh",
"tgt": "127.0.0.1",
"ssh_priv": f"aaa|id>{path} #",
"fun": "test.ping",
"eauth": "auto",
"username": salt_auto_account.username,
"password": salt_auto_account.password,
"roster_file": str(salt_ssh_roster_file),
"rosters": "/",
"ignore_host_keys": True,
}
ret = client.run(low)

assert path.exists() is False
assert ret
assert not ret[tgt]["stdout"]
assert ret[tgt]["stderr"]
assert not ret["127.0.0.1"]["stdout"]
assert ret["127.0.0.1"]["stderr"]


def test_shell_inject_tgt(client, salt_ssh_roster_file, tmp_path, salt_auto_account):
151 changes: 147 additions & 4 deletions tests/pytests/integration/ssh/conftest.py
@@ -1,8 +1,148 @@
import logging
import os
import platform

import pytest

from tests.support.helpers import system_python_version
from tests.support.pytest.helpers import reap_stray_processes

log = logging.getLogger(__name__)


@pytest.fixture(scope="session", autouse=True)
def relenv_tarball_cached(tmp_path_factory):
"""
Pre-cache the relenv tarball once for the entire test session to a shared location.
This avoids downloading it multiple times across different test modules.
Runs automatically at session start (autouse=True).
"""
# Import here to avoid issues if salt is not installed
import tempfile

import salt.utils.relenv

# Use a shared system temp directory that persists across test master instances
# This allows all tests in the session to share the same cached tarball
shared_cache = os.path.join(tempfile.gettempdir(), "salt_ssh_test_relenv_cache")
os.makedirs(shared_cache, exist_ok=True)

# Detect OS and architecture
kernel = platform.system().lower()
if kernel == "darwin":
kernel = "darwin"
elif kernel == "windows":
kernel = "windows"
else:
kernel = "linux"

machine = platform.machine().lower()
if machine in ("amd64", "x86_64"):
os_arch = "x86_64"
elif machine in ("aarch64", "arm64"):
os_arch = "arm64"
else:
os_arch = machine

log.info(
"Pre-caching relenv tarball for %s/%s in shared cache: %s",
kernel,
os_arch,
shared_cache,
)

# Try to copy from local artifacts first (for CI/test environments)
import glob
import shutil

artifacts_tarball = f"/salt/artifacts/salt-*-onedir-{kernel}-{os_arch}.tar.xz"
matching_files = glob.glob(artifacts_tarball)
if matching_files:
source_tarball = matching_files[0]
dest_dir = os.path.join(shared_cache, "relenv", kernel, os_arch)
os.makedirs(dest_dir, exist_ok=True)
dest_tarball = os.path.join(dest_dir, "salt-relenv.tar.xz")

try:
shutil.copy(source_tarball, dest_tarball)
file_size = os.path.getsize(dest_tarball) / (1024 * 1024) # Size in MB
log.info(
"Copied local tarball from %s to %s (%.2f MB)",
source_tarball,
dest_tarball,
file_size,
)
return dest_tarball
except Exception as e: # pylint: disable=broad-exception-caught
log.warning("Failed to copy local tarball: %s", e)

# Fall back to downloading if local tarball not available
try:
tarball_path = salt.utils.relenv.gen_relenv(shared_cache, kernel, os_arch)
log.info("Relenv tarball cached at: %s", tarball_path)

if os.path.exists(tarball_path):
file_size = os.path.getsize(tarball_path) / (1024 * 1024) # Size in MB
log.info("Cached tarball size: %.2f MB", file_size)
return tarball_path
else:
log.warning(
"Tarball download completed but file not found at: %s", tarball_path
)
return None
except Exception as e: # pylint: disable=broad-exception-caught
# Broad exception is intentional - we don't want relenv caching failures to break test setup
log.warning("Failed to pre-cache relenv tarball: %s", e)
return None


@pytest.fixture(scope="module", params=["thin", "relenv"], ids=["thin", "relenv"])
def ssh_deployment_type(request):
"""
Fixture to parameterize tests with both thin and relenv deployments.
The relenv_tarball_cached autouse fixture pre-caches the tarball at session start.
"""
return request.param


@pytest.fixture(scope="function")
def salt_ssh_cli_parameterized(
ssh_deployment_type,
salt_master,
salt_ssh_roster_file,
sshd_config_dir,
known_hosts_file,
):
"""
Parameterized salt-ssh CLI fixture that tests with both thin and relenv deployments.

Note: This uses function scope (not module scope) to ensure each test gets a fresh
SSH instance. This is necessary because the SSH class conditionally initializes
self.thin based on opts['relenv'], and with parametrized tests, we need a new
instance for each deployment type to avoid shared state issues.
"""
assert salt_master.is_running()
cli = salt_master.salt_ssh_cli(
timeout=180,
roster_file=salt_ssh_roster_file,
target_host="localhost",
client_key=str(sshd_config_dir / "client_key"),
)

# Wrap the run method to inject --relenv flag when needed
original_run = cli.run

def run_with_deployment(*args, **kwargs):
if ssh_deployment_type == "relenv":
# Filter out -t/--thin flags which are incompatible with --relenv
filtered_args = tuple(arg for arg in args if arg not in ("-t", "--thin"))
# Insert --relenv flag at the beginning
args = ("--relenv",) + filtered_args
return original_run(*args, **kwargs)

cli.run = run_with_deployment
return cli
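
A hypothetical test using the fixture above would run once per deployment type, since ssh_deployment_type is parametrized at module scope and the wrapper injects --relenv for the relenv case; the assertions below assume the usual result shape returned by the salt-ssh CLI helper in Salt's test suite:

def test_ping_with_both_deployments(salt_ssh_cli_parameterized):
    # Runs twice: once with the thin deployment, once with --relenv injected.
    ret = salt_ssh_cli_parameterized.run("test.ping")
    assert ret.returncode == 0
    assert ret.data is True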


@pytest.fixture(scope="package", autouse=True)
def _auto_skip_on_system_python_too_recent(grains):
@@ -109,17 +249,19 @@ def state_tree_dir(base_env_state_tree_root_dir):
State tree with files to test salt-ssh
when the map.jinja file is in another directory
"""
# Remove unused import from top file to avoid salt-ssh file sync issues
# Use "testdir" instead of "test" to avoid conflicts with state_tree fixture
top_file = """
{%- from "test/map.jinja" import abc with context %}
base:
'localhost':
- test
- testdir
'127.0.0.1':
- test
- testdir
"""
map_file = """
{%- set abc = "def" %}
"""
# State file imports from subdirectory - this is what we're testing
state_file = """
{%- from "test/map.jinja" import abc with context %}

@@ -132,8 +274,9 @@ def state_tree_dir(base_env_state_tree_root_dir):
map_tempfile = pytest.helpers.temp_file(
"test/map.jinja", map_file, base_env_state_tree_root_dir
)
# Use testdir.sls to avoid collision with state_tree's test.sls
state_tempfile = pytest.helpers.temp_file(
"test.sls", state_file, base_env_state_tree_root_dir
"testdir.sls", state_file, base_env_state_tree_root_dir
)

with top_tempfile, map_tempfile, state_tempfile: