diff --git a/.github/workflows/container.yml b/.github/workflows/container.yml index b736b22e..65f452cd 100644 --- a/.github/workflows/container.yml +++ b/.github/workflows/container.yml @@ -11,3 +11,4 @@ jobs: with: image_namespace: wipacrepo image_name: iceprod + mode: BUILD diff --git a/iceprod/credentials/server.py b/iceprod/credentials/server.py index 8592b5b8..735a89c4 100644 --- a/iceprod/credentials/server.py +++ b/iceprod/credentials/server.py @@ -8,6 +8,7 @@ import time from typing import Any +import jwt from prometheus_client import Info, start_http_server import pymongo import pymongo.errors @@ -24,7 +25,8 @@ from iceprod.rest.auth import authorization from iceprod.rest.base_handler import IceProdRestConfig, APIBase from iceprod.server.util import nowstr, datetime2str -from .service import RefreshService, get_expiration, is_expired +from .service import RefreshService +from .util import get_expiration, is_expired logger = logging.getLogger('server') @@ -74,6 +76,7 @@ async def create(self, db, base_data): argo.add_argument('secret_key', type=str, default='', required=False) argo.add_argument('access_token', type=str, default='', required=False) argo.add_argument('refresh_token', type=str, default='', required=False) + argo.add_argument('scope', type=str, default=None, required=False) argo.add_argument('expire_date', type=float, default=now, required=False) argo.add_argument('last_use', type=float, default=now, required=False) args = vars(argo.parse_args()) @@ -103,13 +106,18 @@ async def create(self, db, base_data): raise HTTPError(400, reason='must specify either access or refresh tokens') data['access_token'] = args['access_token'] data['refresh_token'] = args['refresh_token'] + if args['scope']: + base_data['scope'] = args['scope'] + data['scope'] = args['scope'] data['expiration'] = args['expire_date'] data['last_use'] = args['last_use'] + if data['access_token'] and data['scope'] is None: + data['scope'] = jwt.decode(data['access_token'], options={"verify_signature": False}).get('scope', None) - if 'refresh_token' in data and not data.get('access_token', ''): + if 'refresh_token' in data and not data['access_token']: new_cred = await self.refresh_service.refresh_cred(data) data.update(new_cred) - if (not data.get('access_token', '')) and data.get('expiration') == now: + if (not data['access_token']) and data.get('expiration') == now: data['expiration'] = get_expiration(data['access_token']) else: @@ -130,13 +138,14 @@ async def patch_cred(self, db, base_data): argo.add_argument('secret_key', type=str, default='', required=False) argo.add_argument('access_token', type=str, default='', required=False) argo.add_argument('refresh_token', type=str, default='', required=False) + argo.add_argument('scope', type=str, default='', required=False) argo.add_argument('expiration', type=float, default=0, required=False) argo.add_argument('last_use', type=float, default=0, required=False) args = vars(argo.parse_args()) base_data['url'] = args['url'] data = {} - for key in ('buckets', 'access_key', 'secret_key', 'access_token', 'refresh_token', 'expiration', 'last_use'): + for key in ('buckets', 'access_key', 'secret_key', 'access_token', 'refresh_token', 'scope', 'expiration', 'last_use'): if val := args[key]: data[key] = val @@ -157,6 +166,9 @@ async def search_creds(self, db, base_data): assert self.refresh_service if url := self.get_argument('url', None): base_data['url'] = url + if scope := self.get_argument('scope', None): + base_data['scope'] = scope + logger.info('base_data: %r', base_data) refresh = self.get_argument('norefresh', None) is None @@ -166,23 +178,34 @@ async def search_creds(self, db, base_data): filters['type'] = 'oauth' await db.update_many(filters, {'$set': update_data}) - ret = {} + ret = [] async for row in db.find(base_data, projection={'_id': False}): - ret[row['url']] = row - - for key in list(ret): - cred = ret[key] - if refresh and is_expired(cred) and cred['refresh_token']: + if refresh and is_expired(row) and row['refresh_token']: try: - new_cred = await self.refresh_service.refresh_cred(cred) + new_cred = await self.refresh_service.refresh_cred(row) filters = base_data.copy() - filters['url'] = key - ret[key] = await db.find_one_and_update(filters, {'$set': new_cred}, projection={'_id': False}) + filters['url'] = row['url'] + cred = await db.find_one_and_update(filters, {'$set': new_cred}, projection={'_id': False}) + ret.append(cred) except Exception: - del ret[key] + logging.debug('ignore expired token %r', row) + else: + ret.append(row) return ret + async def delete_cred(self, db, base_data): + argo = ArgumentHandler(ArgumentSource.JSON_BODY_ARGUMENTS, self) + argo.add_argument('url', type=str, default='', required=False) + argo.add_argument('scope', type=str, default=None, required=False) + body_args = argo.parse_args() + if body_args.url: + base_data['url'] = body_args.url + if body_args.scope is not None: + base_data['scope'] = body_args.scope + + await db.delete_many(base_data) + class GroupCredentialsHandler(BaseCredentialsHandler): """ @@ -221,6 +244,7 @@ async def post(self, groupname): OAuth body args: access_token (str): access token refresh_token (str): refresh token + scope (str): scope of access token expire_date (str): access token expiration, ISO date time in UTC (optional) Args: @@ -237,8 +261,9 @@ async def patch(self, groupname): """ Update a group credential. Usually used to update a specifc field. - Required body args: + Body args: url (str): url of controlled resource + scope (str): (optional) scope of access token Other body args will update a credential. @@ -260,19 +285,12 @@ async def delete(self, groupname): groupname (str): groupname Body args: url (str): (optional) url of controlled resource + scope (str): (optional) scope of access token """ if self.auth_roles == ['user'] and groupname not in self.auth_groups: raise HTTPError(403, 'unauthorized') - args = {'groupname': groupname} - - argo = ArgumentHandler(ArgumentSource.JSON_BODY_ARGUMENTS, self) - argo.add_argument('url', type=str, default='', required=False) - body_args = argo.parse_args() - if body_args.url: - args['url'] = body_args.url - - await self.db.group_creds.delete_many(args) + await self.delete_cred(self.db.group_creds, {'groupname': groupname}) self.write({}) @@ -313,6 +331,7 @@ async def post(self, username): OAuth body args: access_token (str): access token refresh_token (str): refresh token + scope (str): scope of access token expire_date (str): access token expiration, ISO date time in UTC (optional) Args: @@ -329,8 +348,9 @@ async def patch(self, username): """ Update a user credential. Usually used to update a specifc field. - Required body args: + Body args: url (str): url of controlled resource + scope (str): (optional) scope of access token Other body args will update a credential. @@ -352,21 +372,129 @@ async def delete(self, username): username (str): username Body args: url (str): (optional) url of controlled resource + scope (str): (optional) scope of access token Returns: dict: url: credential dict """ if self.auth_roles == ['user'] and username != self.current_user: raise HTTPError(403, 'unauthorized') - args = {'username': username} + await self.delete_cred(self.db.user_creds, {'username': username}) + self.write({}) - argo = ArgumentHandler(ArgumentSource.JSON_BODY_ARGUMENTS, self) - argo.add_argument('url', type=str, default='', required=False) - body_args = argo.parse_args() - if body_args.url: - args['url'] = body_args.url - await self.db.user_creds.delete_many(args) +class DatasetCredentialsHandler(BaseCredentialsHandler): + """ + Handle dataset credentials requests. + """ + @authorization(roles=['admin', 'system']) + async def get(self, dataset_id): + """ + Get a datasets's credentials. + + Args: + dataset_id (str): dataset_id + Returns: + dict: url: credential dict + """ + ret = await self.search_creds(self.db.dataset_creds, {'dataset_id': dataset_id}) + self.write(ret) + + @authorization(roles=['admin', 'system']) + async def delete(self, dataset_id): + """ + Delete a dataset's credentials. + + Args: + dataset_id (str): dataset_id + Body args: + url (str): (optional) url of controlled resource + scope (str): (optional) scope of access token + Returns: + dict: url: credential dict + """ + await self.delete_cred(self.db.dataset_creds, {'dataset_id': dataset_id}) + self.write({}) + + +class DatasetTaskCredentialsHandler(BaseCredentialsHandler): + """ + Handle dataset/task credentials requests. + """ + @authorization(roles=['admin', 'system']) + async def get(self, dataset_id, task_name): + """ + Get a datasets's credentials. + + Args: + dataset_id (str): dataset_id + task_name (str): task name + Returns: + dict: url: credential dict + """ + ret = await self.search_creds(self.db.dataset_creds, {'dataset_id': dataset_id, 'task_name': task_name}) + self.write(ret) + + @authorization(roles=['admin', 'system']) + async def post(self, dataset_id, task_name): + """ + Set a dataset credential. Overwrites an existing credential for the specified url. + + Common body args: + url (str): url of controlled resource + type (str): credential type (`s3` or `oauth`) + + S3 body args: + buckets (list): list of buckets for this url, or [] if using virtual-hosted buckets in the url + access_key (str): access key + secret_key (str): secret key + + OAuth body args: + access_token (str): access token + refresh_token (str): refresh token + scope (str): scope of access token + expire_date (str): access token expiration, ISO date time in UTC (optional) + + Args: + dataset_id (str): dataset_id + task_name (str): task name + """ + await self.create(self.db.dataset_creds, {'dataset_id': dataset_id, 'task_name': task_name}) + self.write({}) + + @authorization(roles=['admin', 'system']) + async def patch(self, dataset_id, task_name): + """ + Update a dataset credential. Usually used to update a specifc field. + + Body args: + url (str): url of controlled resource + scope (str): (optional) scope of access token + + Other body args will update a credential. + + Args: + dataset_id (str): dataset_id + task_name (str): task name + """ + await self.patch_cred(self.db.dataset_creds, {'dataset_id': dataset_id, 'task_name': task_name}) + self.write({}) + + @authorization(roles=['admin', 'system']) + async def delete(self, dataset_id, task_name): + """ + Delete a dataset's credentials. + + Args: + dataset_id (str): dataset_id + task_name (str): task name + Body args: + url (str): (optional) url of controlled resource + scope (str): (optional) scope of access token + Returns: + dict: url: credential dict + """ + await self.delete_cred(self.db.dataset_creds, {'dataset_id': dataset_id, 'task_name': task_name}) self.write({}) @@ -483,6 +611,16 @@ def __init__(self): }, 'user_creds': { 'username_index': {'keys': [('username', pymongo.DESCENDING), ('url', pymongo.DESCENDING)], 'unique': True}, + }, + 'dataset_creds': { + 'dataset_index': { + 'keys': [ + ('dataset_id', pymongo.DESCENDING), + ('task_name', pymongo.DESCENDING), + ('url', pymongo.DESCENDING) + ], + 'unique': True, + }, } } @@ -516,6 +654,8 @@ def __init__(self): server.add_route(r'/groups/(?P\w+)/credentials', GroupCredentialsHandler, kwargs) server.add_route(r'/users/(?P\w+)/credentials', UserCredentialsHandler, kwargs) + server.add_route(r'/datasets/(?P\w+)/credentials', DatasetCredentialsHandler, kwargs) + server.add_route(r'/datasets/(?P\w+)/tasks/(?P\w+)/credentials', DatasetTaskCredentialsHandler, kwargs) server.add_route('/healthz', HealthHandler, kwargs) server.add_route(r'/(.*)', Error) diff --git a/iceprod/credentials/service.py b/iceprod/credentials/service.py index bd2906f6..024fd06f 100644 --- a/iceprod/credentials/service.py +++ b/iceprod/credentials/service.py @@ -1,47 +1,14 @@ import asyncio -import json import logging import time -from cachetools.func import ttl_cache import httpx import jwt -from rest_tools.utils.auth import OpenIDAuth - -logger = logging.getLogger('refresh_service') - - -@ttl_cache(maxsize=256, ttl=3600) -def get_auth(url): - return OpenIDAuth(url) +from .util import ClientCreds, get_expiration -def get_expiration(token): - """ - Find a token's expiration time. - - Args: - token (str): jwt token - Returns: - float: expiration unix time - """ - return jwt.decode(token, options={"verify_signature": False})['exp'] - - -def is_expired(cred): - """ - Check if an OAuth credential is expired. - - Will mark credential as expired if the access token has less than 5 seconds left. +logger = logging.getLogger('refresh_service') - Args: - cred (dict): credential dict - Returns: - bool: True if expired - """ - if cred['type'] != 'oauth': - return False - return cred['expiration'] < (time.time() + 5) class RefreshService: @@ -57,7 +24,8 @@ class RefreshService: """ def __init__(self, database, clients, refresh_window, expire_buffer, service_run_interval): self.db = database - self.clients = json.loads(clients) + self.clients = ClientCreds(clients) + self.clients.validate() self.refresh_window = refresh_window * 3600 self.expire_buffer = expire_buffer * 3600 @@ -71,26 +39,31 @@ async def refresh_cred(self, cred): raise Exception('cred does not have a refresh token') openid_url = jwt.decode(cred['refresh_token'], options={"verify_signature": False})['iss'] - if openid_url not in self.clients: + try: + client = self.clients.get_client(openid_url) + except KeyError: raise Exception('jwt issuer not registered') - auth = get_auth(openid_url) # try the refresh token args = { 'grant_type': 'refresh_token', 'refresh_token': cred['refresh_token'], - 'client_id': self.clients[openid_url][0], + 'client_id': client.client_id, } - if len(self.clients[openid_url]) > 1: - args['client_secret'] = self.clients[openid_url][1] + if client.client_secret: + args['client_secret'] = client.client_secret + if cred.get('scope', None) is not None: + args['scope'] = cred['scope'] + + logging.warning('refreshing on %s with args %r', client.auth.token_url, args) new_cred = {} try: - async with httpx.AsyncClient() as client: - r = await client.post(auth.token_url, data=args) + async with httpx.AsyncClient() as http_client: + r = await http_client.post(client.auth.token_url, data=args) r.raise_for_status() req = r.json() - except httpx.HTTPError as exc: + except httpx.HTTPStatusError as exc: logger.debug('%r', exc.response.text) try: req = exc.response.json() @@ -103,11 +76,13 @@ async def refresh_cred(self, cred): new_cred['access_token'] = req['access_token'] new_cred['refresh_token'] = req['refresh_token'] new_cred['expiration'] = get_expiration(req['access_token']) + new_cred['scope'] = req.get('scope', cred.get('scope', None)) logger.debug('%r', new_cred) return new_cred def should_refresh(self, cred): + logger.info('should_refresh for cred %r', cred) now = time.time() refresh_exp = now + self.expire_buffer last_use_date = now - self.refresh_window @@ -133,6 +108,7 @@ async def _run_once(self): 'type': 'oauth', 'last_use': {'$gt': last_use_check}, 'expiration': {'$lt': exp_check}, + 'scope': {'$exists': True}, } user_creds = {} @@ -145,16 +121,15 @@ async def _run_once(self): if not cred['refresh_token']: logger.info('skipping non-refresh token for user %s, url %s', cred['username'], cred['url']) continue - logger.debug('cred: %r', cred) try: if self.should_refresh(cred): args = await self.refresh_cred(cred) - await self.db.user_creds.update_one({'username': cred['username'], 'url': cred['url']}, {'$set': args}) - logger.info('refreshed token for user %s, url %s', cred['username'], cred['url']) + await self.db.user_creds.update_one({'username': cred['username'], 'url': cred['url'], 'scope': cred['scope']}, {'$set': args}) + logger.info('refreshed token for user %s, url %s, scope %s', cred['username'], cred['url'], cred['scope']) else: - logger.info('not yet time to refresh token for user %s, url %s', cred['username'], cred['url']) + logger.info('not yet time to refresh token for user %s, url %s, scope %s', cred['username'], cred['url'], cred['scope']) except Exception: - logger.error('error refreshing token for user %s, url: %s', cred['username'], cred['url'], exc_info=True) + logger.error('error refreshing token for user %s, url: %s, scope %s', cred['username'], cred['url'], cred['scope'], exc_info=True) group_creds = {} async for row in self.db.group_creds.find(filters, {'_id': False}): @@ -176,6 +151,28 @@ async def _run_once(self): except Exception: logger.error('error refreshing token for group %s, url: %s', cred['groupname'], cred['url'], exc_info=True) + + dataset_creds = [] + async for row in self.db.dataset_creds.find(filters, {'_id': False}): + dataset_creds.append(row) + + for cred in dataset_creds: + if cred['type'] != 'oauth': + continue + if not cred['refresh_token']: + logger.info('skipping non-refresh token for dataset %s, task, %s, url %s', cred['dataset_id'], cred['task_name'], cred['url']) + continue + try: + if self.should_refresh(cred): + args = await self.refresh_cred(cred) + await self.db.dataset_creds.update_one({'dataset_id': cred['dataset_id'], 'task_name': cred['task_name'], 'scope': cred['scope'], 'url': cred['url']}, {'$set': args}) + logger.info('refreshed token for dataset %s, task, %s, url %s', cred['dataset_id'], cred['task_name'], cred['url']) + else: + logger.info('not yet time to refresh token for dataset %s, task, %s, url %s', cred['dataset_id'], cred['task_name'], cred['url']) + except Exception: + logger.error('error refreshing token for dataset %s, task, %s, url %s', cred['dataset_id'], cred['task_name'], cred['url'], exc_info=True) + + self.last_success_time = time.time() except Exception: logger.error('error running refresh', exc_info=True) diff --git a/iceprod/rest/base_handler.py b/iceprod/rest/base_handler.py index dd47141f..8f24f89f 100644 --- a/iceprod/rest/base_handler.py +++ b/iceprod/rest/base_handler.py @@ -3,6 +3,7 @@ import motor.motor_asyncio from rest_tools.server import RestHandlerSetup, RestHandler +from tornado.escape import json_encode from iceprod.util import VERSION_STRING from iceprod.prom_utils import PromRequestMixin @@ -50,3 +51,8 @@ def get_current_user(self): logger.info('could not find auth username') return username + + def write(self, chunk: dict | list) -> None: # type: ignore[override] + """Write dict or list to json""" + self.set_header("Content-Type", "application/json; charset=UTF-8") + super().write(json_encode(chunk)) diff --git a/iceprod/server/plugins/condor.py b/iceprod/server/plugins/condor.py index 8fbdba97..27b755ca 100644 --- a/iceprod/server/plugins/condor.py +++ b/iceprod/server/plugins/condor.py @@ -19,7 +19,7 @@ import time from typing import Generator, NamedTuple -import htcondor # type: ignore +import htcondor2 as htcondor from wipac_dev_tools.prometheus_tools import GlobalLabels, AsyncPromWrapper, PromWrapper, AsyncPromTimer, PromTimer from iceprod.core.config import Task diff --git a/iceprod/website/data/www_templates/profile.html b/iceprod/website/data/www_templates/profile.html index 671be8bf..7e3d26b8 100644 --- a/iceprod/website/data/www_templates/profile.html +++ b/iceprod/website/data/www_templates/profile.html @@ -63,10 +63,6 @@

User Credentials

TYPE: {{ user_creds[url]["type"]}}

{% end %} -
- {% module xsrf_form_html() %} - -