Skip to content

Commit d341d41

Browse files
committed
security(CASA-28): API Key security enhancements
1 parent f9ce566 commit d341d41

File tree

21 files changed

+1036
-377
lines changed

21 files changed

+1036
-377
lines changed

backend/airweave/api/deps.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,33 @@ async def _authenticate_auth0_user(
5959
return user_context, AuthMethod.AUTH0, {"auth0_id": auth0_user.id}
6060

6161

62+
def _extract_client_ip(request: Request) -> str:
63+
"""Extract client IP from request headers.
64+
65+
Checks X-Forwarded-For header first (for proxied requests),
66+
then falls back to direct client IP.
67+
68+
Args:
69+
----
70+
request (Request): FastAPI request object
71+
72+
Returns:
73+
-------
74+
str: Client IP address or "unknown" if not available
75+
76+
"""
77+
# Check X-Forwarded-For first (for proxied requests)
78+
forwarded_for = request.headers.get("X-Forwarded-For")
79+
if forwarded_for:
80+
# X-Forwarded-For can be a comma-separated list, take the first one (original client)
81+
return forwarded_for.split(",")[0].strip()
82+
83+
# Fallback to direct client IP
84+
return request.client.host if request.client else "unknown"
85+
86+
6287
async def _authenticate_api_key(
63-
db: AsyncSession, api_key: str
88+
db: AsyncSession, api_key: str, request: Request
6489
) -> Tuple[None, AuthMethod, dict, str]:
6590
"""Authenticate API key and return organization ID.
6691
@@ -87,6 +112,14 @@ async def _authenticate_api_key(
87112
api_key_obj = await crud.api_key.get_by_key(db, key=api_key)
88113
org_id = api_key_obj.organization_id
89114

115+
# Log API key usage with structured dimensions (flows to Azure LAW)
116+
client_ip = _extract_client_ip(request)
117+
audit_logger = logger.with_context(event_type="api_key_usage")
118+
audit_logger.info(
119+
f"API key usage: key={api_key_obj.id} org={org_id} ip={client_ip} "
120+
f"endpoint={request.url.path} created_by={api_key_obj.created_by_email}"
121+
)
122+
90123
# Cache the mapping for next time
91124
await context_cache.set_api_key_org_id(api_key, org_id)
92125

@@ -328,7 +361,7 @@ async def get_context(
328361
user_context, auth_method, auth_metadata = await _get_or_fetch_user_context(db, auth0_user)
329362
elif x_api_key:
330363
user_context, auth_method, auth_metadata, api_key_org_id = await _authenticate_api_key(
331-
db, x_api_key
364+
db, x_api_key, request
332365
)
333366

334367
if not auth_method:

backend/airweave/api/v1/endpoints/api_keys.py

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from airweave.api.context import ApiContext
1111
from airweave.api.router import TrailingSlashRouter
1212
from airweave.core import credentials
13+
from airweave.core.datetime_utils import utc_now_naive
1314

1415
router = TrailingSlashRouter()
1516

@@ -18,7 +19,7 @@
1819
async def create_api_key(
1920
*,
2021
db: AsyncSession = Depends(deps.get_db),
21-
api_key_in: schemas.APIKeyCreate = Body({}), # Default to empty dict if not provided
22+
api_key_in: schemas.APIKeyCreate = Body(default_factory=lambda: schemas.APIKeyCreate()),
2223
ctx: ApiContext = Depends(deps.get_context),
2324
) -> schemas.APIKey:
2425
"""Create a new API key for the current user.
@@ -39,13 +40,22 @@ async def create_api_key(
3940
"""
4041
api_key_obj = await crud.api_key.create(db=db, obj_in=api_key_in, ctx=ctx)
4142

43+
# Audit log: API key creation (flows to Azure LAW)
44+
expiration_days = (api_key_obj.expiration_date - api_key_obj.created_at).days
45+
audit_logger = ctx.logger.with_context(event_type="api_key_created")
46+
audit_logger.info(
47+
f"API key created: {api_key_obj.id} by {api_key_obj.created_by_email} "
48+
f"for org {ctx.organization.id}, expires in {expiration_days} days "
49+
f"({api_key_obj.expiration_date.isoformat()})"
50+
)
51+
4252
# Decrypt the key for the response
4353
decrypted_data = credentials.decrypt(api_key_obj.encrypted_key)
4454
decrypted_key = decrypted_data["key"]
4555

4656
api_key_data = {
4757
"id": api_key_obj.id,
48-
"organization": ctx.organization.id, # Use the user's organization_id
58+
"organization_id": ctx.organization.id,
4959
"created_at": api_key_obj.created_at,
5060
"modified_at": api_key_obj.modified_at,
5161
"last_used_date": None, # New key has no last used date
@@ -88,7 +98,7 @@ async def read_api_key(
8898

8999
api_key_data = {
90100
"id": api_key.id,
91-
"organization": ctx.organization.id,
101+
"organization_id": ctx.organization.id,
92102
"created_at": api_key.created_at,
93103
"modified_at": api_key.modified_at,
94104
"last_used_date": api_key.last_used_date if hasattr(api_key, "last_used_date") else None,
@@ -132,7 +142,7 @@ async def read_api_keys(
132142

133143
api_key_data = {
134144
"id": api_key.id,
135-
"organization": ctx.organization.id,
145+
"organization_id": ctx.organization.id,
136146
"created_at": api_key.created_at,
137147
"modified_at": api_key.modified_at,
138148
"last_used_date": (
@@ -148,6 +158,63 @@ async def read_api_keys(
148158
return result
149159

150160

161+
@router.post("/{id}/rotate", response_model=schemas.APIKey)
162+
async def rotate_api_key(
163+
*,
164+
db: AsyncSession = Depends(deps.get_db),
165+
id: UUID,
166+
ctx: ApiContext = Depends(deps.get_context),
167+
) -> schemas.APIKey:
168+
"""Rotate an API key by creating a new one.
169+
170+
This endpoint creates a new API key with a fresh 90-day expiration.
171+
The old key remains active until its original expiration date.
172+
Users can manage multiple keys or delete the old one manually if desired.
173+
174+
Args:
175+
----
176+
db (AsyncSession): The database session.
177+
id (UUID): The ID of the API key to rotate.
178+
ctx (ApiContext): The current authentication context.
179+
180+
Returns:
181+
-------
182+
schemas.APIKey: The newly created API key with decrypted key value.
183+
184+
Raises:
185+
------
186+
HTTPException: If the API key is not found or user doesn't have access.
187+
188+
"""
189+
# Verify old key exists and user has access
190+
old_key = await crud.api_key.get(db=db, id=id, ctx=ctx)
191+
old_key_schema = schemas.APIKey.model_validate(old_key, from_attributes=True)
192+
193+
# Create new key with default 90-day expiration
194+
new_key_obj = await crud.api_key.create(
195+
db=db,
196+
obj_in=schemas.APIKeyCreate(), # Uses default 90 days
197+
ctx=ctx,
198+
)
199+
200+
# Decrypt the new key for the response
201+
decrypted_data = credentials.decrypt(new_key_obj.encrypted_key)
202+
decrypted_key = decrypted_data["key"]
203+
204+
new_key_schema = schemas.APIKey.model_validate(new_key_obj, from_attributes=True)
205+
new_key_schema.decrypted_key = decrypted_key
206+
207+
# Audit log: API key rotation (flows to Azure LAW)
208+
audit_logger = ctx.logger.with_context(event_type="api_key_rotated")
209+
audit_logger.info(
210+
f"API key rotated: old={old_key_schema.id}, new={new_key_schema.id} "
211+
f"by {new_key_schema.created_by_email} for org {ctx.organization.id}, "
212+
f"new key expires {new_key_schema.expiration_date.isoformat()}"
213+
)
214+
215+
return new_key_schema
216+
217+
151218
@router.delete("/", response_model=schemas.APIKey)
152219
async def delete_api_key(
153220
*,
@@ -181,7 +248,7 @@ async def delete_api_key(
181248
# Create a copy of the data before deletion
182249
api_key_data = {
183250
"id": api_key.id,
184-
"organization": ctx.organization.id,
251+
"organization_id": ctx.organization.id,
185252
"created_at": api_key.created_at,
186253
"modified_at": api_key.modified_at,
187254
"last_used_date": api_key.last_used_date if hasattr(api_key, "last_used_date") else None,
@@ -191,6 +258,14 @@ async def delete_api_key(
191258
"decrypted_key": decrypted_key,
192259
}
193260

261+
# Audit log: API key deletion (flows to Azure LAW)
262+
was_expired = api_key.expiration_date < utc_now_naive()
263+
audit_logger = ctx.logger.with_context(event_type="api_key_deleted")
264+
audit_logger.info(
265+
f"API key deleted: {api_key.id} by {ctx.tracking_email} for org {ctx.organization.id} "
266+
f"(was_expired={was_expired})"
267+
)
268+
194269
# Now delete the API key
195270
await crud.api_key.remove(db=db, id=id, ctx=ctx)
196271

backend/airweave/api/v1/endpoints/users.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
from airweave.api.auth import auth0
1818
from airweave.api.context import ApiContext
1919
from airweave.api.router import TrailingSlashRouter
20-
from airweave.core.email_service import send_welcome_email
2120
from airweave.core.exceptions import NotFoundException
2221
from airweave.core.logging import logger
2322
from airweave.core.shared_models import AuthMethod
2423
from airweave.db.unit_of_work import UnitOfWork
24+
from airweave.email.services import send_welcome_email
2525
from airweave.schemas import OrganizationWithRole, User
2626

2727
router = TrailingSlashRouter()

backend/airweave/crud/crud_api_key.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""CRUD operations for the APIKey model."""
22

33
import secrets
4-
from datetime import timedelta
4+
from datetime import datetime, timedelta
55
from typing import Optional
66
from uuid import UUID
77

@@ -45,9 +45,9 @@ async def create(
4545
key = secrets.token_urlsafe(32)
4646
encrypted_key = credentials.encrypt({"key": key})
4747

48-
expiration_date = obj_in.expiration_date or (
49-
utc_now_naive() + timedelta(days=180) # Default to 180 days
50-
)
48+
# Calculate expiration date from days (defaults to 90)
49+
expiration_days = obj_in.expiration_days if obj_in.expiration_days is not None else 90
50+
expiration_date = utc_now_naive() + timedelta(days=expiration_days)
5151

5252
# Create a dictionary with the data instead of using the schema
5353
api_key_data = {
@@ -143,5 +143,36 @@ async def get_by_key(self, db: AsyncSession, *, key: str) -> Optional[APIKey]:
143143

144144
raise NotFoundException("API key not found")
145145

146+
async def get_keys_expiring_in_range(
147+
self,
148+
db: AsyncSession,
149+
start_date: datetime,
150+
end_date: datetime,
151+
) -> list[APIKey]:
152+
"""Get API keys expiring within a date range.
153+
154+
Args:
155+
----
156+
db (AsyncSession): The database session.
157+
start_date (datetime): Start of the date range (inclusive).
158+
end_date (datetime): End of the date range (exclusive).
159+
160+
Returns:
161+
-------
162+
list[APIKey]: List of API keys expiring in the range.
163+
164+
"""
165+
from sqlalchemy import and_, select
166+
167+
query = select(self.model).where(
168+
and_(
169+
self.model.expiration_date >= start_date,
170+
self.model.expiration_date < end_date,
171+
)
172+
)
173+
174+
result = await db.execute(query)
175+
return list(result.scalars().all())
176+
146177

147178
api_key = CRUDAPIKey(APIKey)

backend/airweave/email/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""Email module for Airweave."""
2+
3+
from airweave.email.services import send_email_via_resend, send_welcome_email
4+
from airweave.email.templates import get_api_key_expiration_email
5+
6+
__all__ = [
7+
"send_email_via_resend",
8+
"send_welcome_email",
9+
"get_api_key_expiration_email",
10+
]

backend/airweave/core/email_service.py renamed to backend/airweave/email/services.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,91 @@
1-
"""Simple email service for sending welcome emails via Resend."""
1+
"""Email service for sending emails via Resend."""
22

33
import asyncio
44
import random
55
from datetime import datetime, timedelta, timezone
6+
from typing import Optional
67

78
import resend
89

910
from airweave.core.config import settings
1011
from airweave.core.logging import logger
1112

1213

14+
def _send_email_via_resend_sync(
15+
to_email: str,
16+
subject: str,
17+
html_body: str,
18+
from_email: Optional[str] = None,
19+
scheduled_at: Optional[str] = None,
20+
) -> None:
21+
"""Synchronous email sending function via Resend to be run in a thread pool.
22+
23+
Args:
24+
----
25+
to_email (str): Recipient email address
26+
subject (str): Email subject line
27+
html_body (str): HTML email body
28+
from_email (Optional[str]): Sender email (defaults to settings.RESEND_FROM_EMAIL)
29+
scheduled_at (Optional[str]): ISO 8601 timestamp for scheduled delivery
30+
31+
"""
32+
resend.api_key = settings.RESEND_API_KEY
33+
34+
email_data = {
35+
"from": from_email or settings.RESEND_FROM_EMAIL,
36+
"to": [to_email],
37+
"subject": subject,
38+
"html": html_body,
39+
}
40+
41+
if scheduled_at:
42+
email_data["scheduled_at"] = scheduled_at
43+
44+
resend.Emails.send(email_data)
45+
46+
47+
async def send_email_via_resend(
48+
to_email: str,
49+
subject: str,
50+
html_body: str,
51+
from_email: Optional[str] = None,
52+
scheduled_at: Optional[str] = None,
53+
) -> bool:
54+
"""Send an email via Resend asynchronously.
55+
56+
Args:
57+
----
58+
to_email (str): Recipient email address
59+
subject (str): Email subject line
60+
html_body (str): HTML email body
61+
from_email (Optional[str]): Sender email (defaults to settings.RESEND_FROM_EMAIL)
62+
scheduled_at (Optional[str]): ISO 8601 timestamp for scheduled delivery
63+
64+
Returns:
65+
-------
66+
bool: True if email was sent successfully, False otherwise
67+
68+
"""
69+
if not settings.RESEND_API_KEY or not settings.RESEND_FROM_EMAIL:
70+
logger.debug("RESEND_API_KEY or RESEND_FROM_EMAIL not configured - skipping email")
71+
return False
72+
73+
try:
74+
await asyncio.to_thread(
75+
_send_email_via_resend_sync,
76+
to_email,
77+
subject,
78+
html_body,
79+
from_email,
80+
scheduled_at,
81+
)
82+
logger.info(f"Email sent to {to_email}: {subject}")
83+
return True
84+
except Exception as e:
85+
logger.error(f"Failed to send email to {to_email}: {e}")
86+
return False
87+
88+
1389
def _send_welcome_email_sync(to_email: str, user_name: str) -> None:
1490
"""Synchronous email sending function to be run in a thread pool."""
1591
resend.api_key = settings.RESEND_API_KEY

0 commit comments

Comments
 (0)