Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 184 additions & 0 deletions src/scribae/init_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
from __future__ import annotations

from pathlib import Path
from typing import cast

import typer
import yaml

from .project import ProjectConfig, default_project


class InitError(Exception):
"""Raised when initialization cannot proceed."""


def _resolve_output_path(project: str | None, file: Path | None) -> Path:
if project and file:
raise typer.BadParameter("Options --project and --file are mutually exclusive.", param_hint="--project/--file")

if project:
project_path = Path(project).expanduser()
try:
project_path.mkdir(parents=True, exist_ok=True)
except OSError as exc:
raise InitError(f"Unable to create project directory {project_path}: {exc}") from exc
return project_path / "scribae.yaml"

if file:
return file.expanduser()

return Path("scribae.yaml")


def _confirm_overwrite(path: Path, *, force: bool) -> None:
if path.exists() and path.is_dir():
raise InitError(f"Target path {path} is a directory, expected a file.")

if path.exists() and not force:
overwrite = typer.confirm(f"{path} already exists. Overwrite?", default=False)
if not overwrite:
typer.secho("Cancelled; existing file preserved.", err=True, fg=typer.colors.YELLOW)
raise typer.Exit(1)


def _prompt_text(label: str, description: str, example: str, *, default: str, show_default: bool = True) -> str:
typer.echo("")
typer.secho(label, fg=typer.colors.CYAN, bold=True)
typer.echo(description)
typer.secho(f"Example: {example}", fg=typer.colors.MAGENTA)
return cast(str, typer.prompt("Value", default=default, show_default=show_default))


def _split_list(value: str) -> list[str]:
return [item.strip() for item in value.split(",") if item.strip()]


def _collect_project_config() -> ProjectConfig:
defaults = default_project()

typer.secho("Scribae init", fg=typer.colors.GREEN, bold=True)
typer.echo("Let's create a scribae.yaml so Scribae can tailor outputs to your project.")

site_name = _prompt_text(
"Site name",
"The publication or brand name used in prompts and metadata.",
"Acme Labs Blog",
default=defaults["site_name"],
)
domain = _prompt_text(
"Domain",
"The canonical site URL used for metadata and link generation (include https://).",
"https://example.com",
default=defaults["domain"],
)
audience = _prompt_text(
"Audience",
"Describe who you are writing for so the AI can match their expectations.",
"Product managers at SaaS startups",
default=defaults["audience"],
)
tone = _prompt_text(
"Tone",
"The voice and style Scribae should aim for when drafting content.",
"Conversational, clear, and friendly",
default=defaults["tone"],
)

keywords_default = ", ".join(defaults["keywords"])
keywords_raw = _prompt_text(
"Focus keywords",
"Optional seed topics Scribae should keep in mind (comma-separated).",
"python, SEO, content strategy",
default=keywords_default,
show_default=bool(keywords_default),
)
language = _prompt_text(
"Language",
"Primary output language code (use ISO 639-1 where possible).",
"en",
default=defaults["language"],
)

allowed_tags_raw = _prompt_text(
"Allowed HTML tags",
"Optional allowlist for HTML tags in generated content; leave blank for no restriction.",
"p, em, strong, a",
default="",
show_default=False,
)

return {
"site_name": site_name.strip(),
"domain": domain.strip(),
"audience": audience.strip(),
"tone": tone.strip(),
"keywords": _split_list(keywords_raw),
"language": language.strip(),
"allowed_tags": _split_list(allowed_tags_raw) or None,
}


def _render_yaml(config: ProjectConfig) -> str:
payload: dict[str, object] = {
"site_name": config["site_name"],
"domain": config["domain"],
"audience": config["audience"],
"tone": config["tone"],
"language": config["language"],
"keywords": config["keywords"],
}
if config["allowed_tags"] is not None:
payload["allowed_tags"] = config["allowed_tags"]
rendered = yaml.safe_dump(payload, sort_keys=False, allow_unicode=True)
return cast(str, rendered).strip() + "\n"


def init_command(
project: str | None = typer.Option( # noqa: B008
None,
"--project",
"-p",
help="Project directory to create and write scribae.yaml into.",
),
file: Path | None = typer.Option( # noqa: B008
None,
"--file",
"-f",
resolve_path=True,
help="Custom path/filename for the scribae.yaml output.",
),
force: bool = typer.Option( # noqa: B008
False,
"--force",
help="Overwrite existing files without prompting.",
),
) -> None:
"""Initialize a Scribae project configuration file."""

try:
output_path = _resolve_output_path(project, file)
except InitError as exc:
typer.secho(str(exc), err=True, fg=typer.colors.RED)
raise typer.Exit(2) from exc

try:
_confirm_overwrite(output_path, force=force)
except InitError as exc:
typer.secho(str(exc), err=True, fg=typer.colors.RED)
raise typer.Exit(2) from exc

config = _collect_project_config()
yaml_body = _render_yaml(config)

try:
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(yaml_body, encoding="utf-8")
except OSError as exc:
typer.secho(f"Unable to write {output_path}: {exc}", err=True, fg=typer.colors.RED)
raise typer.Exit(2) from exc

typer.secho(f"Wrote {output_path}", fg=typer.colors.GREEN)


__all__ = ["init_command"]
2 changes: 2 additions & 0 deletions src/scribae/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .brief_cli import brief_command
from .feedback_cli import feedback_command
from .idea_cli import idea_command
from .init_cli import init_command
from .meta_cli import meta_command
from .translate_cli import translate_command
from .version_cli import version_command
Expand All @@ -27,6 +28,7 @@ def app_callback() -> None:


app.command("idea", help="Brainstorm article ideas from a note with project-aware guidance.")(idea_command)
app.command("init", help="Create a scribae.yaml config via a guided questionnaire.")(init_command)
app.command(
"brief",
help="Generate a validated SEO brief (keywords, outline, FAQ, metadata) from a note.",
Expand Down
Empty file added tests/__init__.py
Empty file.
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import re
from collections.abc import Generator
from typing import Any

import pytest
from faker import Faker


def strip_ansi(text: str) -> str:
"""Remove ANSI escape codes from text."""
return re.sub(r"\x1b\[[0-9;]*m", "", text)


@pytest.fixture(autouse=True)
def stub_mt_pipeline(monkeypatch: pytest.MonkeyPatch) -> None:
def _fake_pipeline(self: object, model_id: str) -> Any: # noqa: ARG001
Expand Down
94 changes: 94 additions & 0 deletions tests/unit/init_cli_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from pathlib import Path

import yaml
from typer.testing import CliRunner

from scribae.main import app
from tests.conftest import strip_ansi

runner = CliRunner()


def _questionnaire_input() -> str:
return "\n".join(
[
"Scribae Blog",
"https://example.com",
"developers and writers",
"friendly and practical",
"seo, content strategy",
"en",
"p, em, a",
]
) + "\n"


def test_init_writes_config_in_current_dir() -> None:
with runner.isolated_filesystem():
result = runner.invoke(app, ["init"], input=_questionnaire_input())

assert result.exit_code == 0
config_path = Path("scribae.yaml")
assert config_path.exists()
payload = yaml.safe_load(config_path.read_text(encoding="utf-8"))
assert payload["site_name"] == "Scribae Blog"
assert payload["domain"] == "https://example.com"
assert payload["audience"] == "developers and writers"
assert payload["tone"] == "friendly and practical"
assert payload["keywords"] == ["seo", "content strategy"]
assert payload["language"] == "en"
assert payload["allowed_tags"] == ["p", "em", "a"]


def test_init_writes_config_in_project_dir() -> None:
with runner.isolated_filesystem():
result = runner.invoke(app, ["init", "--project", "demo"], input=_questionnaire_input())

assert result.exit_code == 0
config_path = Path("demo") / "scribae.yaml"
assert config_path.exists()


def test_init_writes_config_to_custom_file() -> None:
with runner.isolated_filesystem():
result = runner.invoke(
app,
["init", "--file", "config/custom.yaml"],
input=_questionnaire_input(),
)

assert result.exit_code == 0
config_path = Path("config") / "custom.yaml"
assert config_path.exists()


def test_init_prompts_before_overwrite() -> None:
with runner.isolated_filesystem():
config_path = Path("scribae.yaml")
config_path.write_text("site_name: old", encoding="utf-8")

result = runner.invoke(app, ["init"], input="n\n")

assert result.exit_code != 0
assert config_path.read_text(encoding="utf-8") == "site_name: old"


def test_init_force_overwrites_existing_file() -> None:
with runner.isolated_filesystem():
config_path = Path("scribae.yaml")
config_path.write_text("site_name: old", encoding="utf-8")

result = runner.invoke(app, ["init", "--force"], input=_questionnaire_input())

assert result.exit_code == 0
payload = yaml.safe_load(config_path.read_text(encoding="utf-8"))
assert payload["site_name"] == "Scribae Blog"


def test_init_rejects_project_and_file_options() -> None:
result = runner.invoke(app, ["init", "--project", "demo", "--file", "config.yaml"])

assert result.exit_code != 0
stderr = strip_ansi(result.stderr)
assert "--project" in stderr
assert "--file" in stderr
10 changes: 5 additions & 5 deletions tests/unit/translate_cli_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import json
import os
import re
from pathlib import Path
from typing import Any, cast

Expand All @@ -13,6 +12,7 @@
from scribae.main import app
from scribae.translate import TranslationConfig
from scribae.translate.markdown_segmenter import TextBlock
from tests.conftest import strip_ansi

runner = CliRunner()

Expand Down Expand Up @@ -332,7 +332,7 @@ def test_translate_requires_input_without_prefetch_only(
)

assert result.exit_code != 0
ansi_stripped = re.sub(r"\x1b\[[0-9;]*m", "", result.stderr)
ansi_stripped = strip_ansi(result.stderr)
assert "--in is required unless --prefetch-only" in ansi_stripped


Expand Down Expand Up @@ -449,7 +449,7 @@ def _raise(self: object, steps: list[object]) -> None:
)

assert result.exit_code != 0
ansi_stripped = re.sub(r"\x1b\[[0-9;]*m", "", result.stderr)
ansi_stripped = strip_ansi(result.stderr)
assert "prefetch failed" in ansi_stripped


Expand All @@ -474,7 +474,7 @@ def test_translate_warns_on_source_mismatch(
)

assert result.exit_code == 0
ansi_stripped = re.sub(r"\x1b\[[0-9;]*m", "", result.stderr)
ansi_stripped = strip_ansi(result.stderr)
assert "detected source language 'fr' does not match --src 'en'" in ansi_stripped


Expand All @@ -496,7 +496,7 @@ def test_translate_rejects_invalid_language_codes(
)

assert result.exit_code != 0
ansi_stripped = re.sub(r"\x1b\[[0-9;]*m", "", result.stderr)
ansi_stripped = strip_ansi(result.stderr)
assert "must be a language code like en or eng_Latn" in ansi_stripped


Expand Down