From 6a26031d348336117e04c8505c461c2cb784d99e Mon Sep 17 00:00:00 2001 From: David Schultz Date: Wed, 29 Oct 2025 17:41:34 -0500 Subject: [PATCH 01/12] initial cred client --- iceprod/website/server.py | 93 +++++++++++---------------------------- 1 file changed, 26 insertions(+), 67 deletions(-) diff --git a/iceprod/website/server.py b/iceprod/website/server.py index 9fc1800c..94d497e0 100644 --- a/iceprod/website/server.py +++ b/iceprod/website/server.py @@ -27,6 +27,7 @@ import tornado.gen import jwt import tornado.concurrent +import requests.exceptions from rest_tools.client import RestClient, ClientCredentialsAuth from rest_tools.server import catch_error, RestServer, RestHandlerSetup, RestHandler, OpenIDCookieHandlerMixin, OpenIDLoginHandler from rest_tools.server.session import SessionMixin, Session @@ -219,108 +220,66 @@ def clear_tokens(self): self._session_mgr.delete_session(username) -class TokenStorageMixin(OpenIDCookieHandlerMixin, RestHandler): +class TokenStorageMixin(RestHandler): """ - Store/load current user's `OpenIDLoginHandler` tokens in iceprod credentials API. + Store/load current user's tokens in iceprod credentials API. """ - def initialize(self, cred_rest_client: RestClient, full_url: str, **kwargs): # type: ignore + def initialize(self, *args, cred_rest_client, **kwargs): super().initialize(**kwargs) self.cred_rest_client = cred_rest_client - self.full_url = full_url - def get_current_user(self): - return None - - async def get_current_user_async(self): - """Get the current user, and set auth-related attributes.""" + @authenticated + async def get_cred_tokens(self, url): + """Get selected tokens from the credential service.""" try: assert self.auth - username = self.get_secure_cookie('iceprod_username') - if not username: - return None - if isinstance(username, bytes): - username = username.decode('utf-8') - creds = await self.cred_rest_client.request('GET', f'/users/{username}/credentials', {'url': self.full_url}) - cred = creds[self.full_url] - access_token = cred['access_token'] - try: - data = self.auth.validate(access_token) - except jwt.ExpiredSignatureError: - logger.debug('user access_token expired') - return None - self.auth_data = data - - # lookup groups - auth_groups = set() - try: - for name in GROUPS: - for expression in GROUPS[name]: - ret = eval_expression(data, expression) - auth_groups.update(match.expand(name) for match in ret) - except Exception: - logger.info('cannot determine groups', exc_info=True) - self.auth_groups = sorted(auth_groups) - - self.auth_access_token = access_token - self.auth_refresh_token = cred.get('refresh_token', '') - return username - - except Exception: - logger.debug('failed auth', exc_info=True) + username = self.current_user + creds = await self.cred_rest_client.request('GET', f'/users/{username}/credentials', {'url': url}) + return creds[url] + except requests.exceptions.RequestException: + logger.warning('failed to get credentials', exc_info=True) return None - def store_tokens( + @authenticated + async def put_cred_tokens( self, + url, access_token, - access_token_exp, refresh_token=None, - refresh_token_exp=None, - user_info=None, - user_info_exp=None, ): """ - Store jwt tokens and user info from OpenID-compliant auth source. + Store jwt tokens from OpenID-compliant auth source. Args: + url (str): site url access_token (str): jwt access token - access_token_exp (int): access token expiration in seconds refresh_token (str): jwt refresh token - refresh_token_exp (int): refresh token expiration in seconds - user_info (dict): user info (from id token or user info lookup) - user_info_exp (int): user info expiration in seconds """ assert self.auth - if not user_info: - user_info = self.auth.validate(access_token) - username = user_info.get('preferred_username') - if not username: - username = user_info.get('upn') - if not username: - raise tornado.web.HTTPError(400, reason='no username in token') + username = self.current_user args = { - 'url': self.full_url, + 'url': url, 'type': 'oauth', 'access_token': access_token, } if refresh_token: args['refresh_token'] = refresh_token - self.cred_rest_client.request_seq('POST', f'/users/{username}/credentials', args) - - self.set_secure_cookie('iceprod_username', username, expires_days=30) + await self.cred_rest_client.request('POST', f'/users/{username}/credentials', args) - def clear_tokens(self): + async def clear_cred_tokens(self): """ - Clear token data, usually on logout. + Clear all token data. """ - self.clear_cookie('iceprod_username') + username = self.current_user + await self.cred_rest_client.request('DELETE', f'/users/{username}/credentials', {}) class Login(LoginMixin, PromRequestMixin, OpenIDLoginHandler): # type: ignore pass -class PublicHandler(LoginMixin, PromRequestMixin, RestHandler): +class PublicHandler(LoginMixin, TokenStorageMixin, PromRequestMixin, RestHandler): """Default Handler""" def initialize(self, rest_api, cred_rest_client, system_rest_client, **kwargs): # type: ignore """ @@ -869,7 +828,6 @@ def __init__(self): raise RuntimeError('ICEPROD_CRED_CLIENT_ID or ICEPROD_CRED_CLIENT_SECRET not specified, and CI_TESTING not enabled!') handler_args = RestHandlerSetup(rest_config) - handler_args['cred_rest_client'] = cred_client if config.CI_TESTING: self.session = Session() else: @@ -906,6 +864,7 @@ def __init__(self): handler_args.update({ 'rest_api': rest_address, + 'cred_rest_client': cred_client, 'system_rest_client': rest_client, }) if config.COOKIE_SECRET: From edb49cc43ea1a1f2cc8ac8ec9a4485dfa6da27dc Mon Sep 17 00:00:00 2001 From: David Schultz Date: Wed, 29 Oct 2025 18:00:46 -0500 Subject: [PATCH 02/12] remove dup cred_rest_client --- iceprod/website/server.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/iceprod/website/server.py b/iceprod/website/server.py index 94d497e0..a22fbd4e 100644 --- a/iceprod/website/server.py +++ b/iceprod/website/server.py @@ -23,10 +23,7 @@ from cachetools.func import ttl_cache from prometheus_client import Info, start_http_server import tornado.web -import tornado.httpserver -import tornado.gen import jwt -import tornado.concurrent import requests.exceptions from rest_tools.client import RestClient, ClientCredentialsAuth from rest_tools.server import catch_error, RestServer, RestHandlerSetup, RestHandler, OpenIDCookieHandlerMixin, OpenIDLoginHandler @@ -38,7 +35,6 @@ from iceprod.roles_groups import GROUPS from iceprod.core.config import CONFIG_SCHEMA as DATASET_SCHEMA from iceprod.server.config import CONFIG_SCHEMA as SERVER_SCHEMA -import iceprod.core.functions from iceprod.server import documentation import iceprod.server.states from iceprod.server.util import datetime2str, nowstr @@ -281,17 +277,15 @@ class Login(LoginMixin, PromRequestMixin, OpenIDLoginHandler): # type: ignore class PublicHandler(LoginMixin, TokenStorageMixin, PromRequestMixin, RestHandler): """Default Handler""" - def initialize(self, rest_api, cred_rest_client, system_rest_client, **kwargs): # type: ignore + def initialize(self, rest_api, system_rest_client, **kwargs): # type: ignore """ Get some params from the website module :param rest_api: the rest api url - :param cred_rest_client: the rest api url for the cred service :param system_rest_client: the rest client for the system role """ super().initialize(**kwargs) self.rest_api = rest_api - self.cred_rest_client = cred_rest_client self.system_rest_client = system_rest_client self.rest_client: RestClient | None = None From 5a5853bb7a2cd4b48ba722e115ac9618ab7ffdf0 Mon Sep 17 00:00:00 2001 From: David Schultz Date: Wed, 29 Oct 2025 18:10:00 -0500 Subject: [PATCH 03/12] fix container build --- .github/workflows/container.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/container.yml b/.github/workflows/container.yml index b736b22e..65f452cd 100644 --- a/.github/workflows/container.yml +++ b/.github/workflows/container.yml @@ -11,3 +11,4 @@ jobs: with: image_namespace: wipacrepo image_name: iceprod + mode: BUILD From 84520dd064dbe0aae098be8e99f2ecc3592a0af7 Mon Sep 17 00:00:00 2001 From: David Schultz Date: Tue, 11 Nov 2025 16:20:43 -0600 Subject: [PATCH 04/12] bump to condor 25 --- iceprod/server/plugins/condor.py | 2 +- tests/server/plugins/condor_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/iceprod/server/plugins/condor.py b/iceprod/server/plugins/condor.py index 8fbdba97..27b755ca 100644 --- a/iceprod/server/plugins/condor.py +++ b/iceprod/server/plugins/condor.py @@ -19,7 +19,7 @@ import time from typing import Generator, NamedTuple -import htcondor # type: ignore +import htcondor2 as htcondor from wipac_dev_tools.prometheus_tools import GlobalLabels, AsyncPromWrapper, PromWrapper, AsyncPromTimer, PromTimer from iceprod.core.config import Task diff --git a/tests/server/plugins/condor_test.py b/tests/server/plugins/condor_test.py index bd7546e4..26780533 100644 --- a/tests/server/plugins/condor_test.py +++ b/tests/server/plugins/condor_test.py @@ -6,7 +6,7 @@ import time from unittest.mock import MagicMock, AsyncMock -import htcondor +import htcondor2 as htcondor import pytest from iceprod.core.config import Job, Task From dbf2d294032ae4bc5966bfa2f80ebdd9457d5ecf Mon Sep 17 00:00:00 2001 From: David Schultz Date: Tue, 11 Nov 2025 16:21:15 -0600 Subject: [PATCH 05/12] complete bump to condor 25 --- iceprod/credentials/server.py | 204 ++++++++++--- iceprod/credentials/service.py | 95 +++--- iceprod/rest/base_handler.py | 6 + .../website/data/www_templates/profile.html | 4 - iceprod/website/server.py | 271 +++++++++++++----- pyproject.toml | 4 +- tests/credentials/test_server.py | 226 +++++++++++++-- tests/credentials/test_service.py | 98 +++++-- 8 files changed, 710 insertions(+), 198 deletions(-) diff --git a/iceprod/credentials/server.py b/iceprod/credentials/server.py index 8592b5b8..735a89c4 100644 --- a/iceprod/credentials/server.py +++ b/iceprod/credentials/server.py @@ -8,6 +8,7 @@ import time from typing import Any +import jwt from prometheus_client import Info, start_http_server import pymongo import pymongo.errors @@ -24,7 +25,8 @@ from iceprod.rest.auth import authorization from iceprod.rest.base_handler import IceProdRestConfig, APIBase from iceprod.server.util import nowstr, datetime2str -from .service import RefreshService, get_expiration, is_expired +from .service import RefreshService +from .util import get_expiration, is_expired logger = logging.getLogger('server') @@ -74,6 +76,7 @@ async def create(self, db, base_data): argo.add_argument('secret_key', type=str, default='', required=False) argo.add_argument('access_token', type=str, default='', required=False) argo.add_argument('refresh_token', type=str, default='', required=False) + argo.add_argument('scope', type=str, default=None, required=False) argo.add_argument('expire_date', type=float, default=now, required=False) argo.add_argument('last_use', type=float, default=now, required=False) args = vars(argo.parse_args()) @@ -103,13 +106,18 @@ async def create(self, db, base_data): raise HTTPError(400, reason='must specify either access or refresh tokens') data['access_token'] = args['access_token'] data['refresh_token'] = args['refresh_token'] + if args['scope']: + base_data['scope'] = args['scope'] + data['scope'] = args['scope'] data['expiration'] = args['expire_date'] data['last_use'] = args['last_use'] + if data['access_token'] and data['scope'] is None: + data['scope'] = jwt.decode(data['access_token'], options={"verify_signature": False}).get('scope', None) - if 'refresh_token' in data and not data.get('access_token', ''): + if 'refresh_token' in data and not data['access_token']: new_cred = await self.refresh_service.refresh_cred(data) data.update(new_cred) - if (not data.get('access_token', '')) and data.get('expiration') == now: + if (not data['access_token']) and data.get('expiration') == now: data['expiration'] = get_expiration(data['access_token']) else: @@ -130,13 +138,14 @@ async def patch_cred(self, db, base_data): argo.add_argument('secret_key', type=str, default='', required=False) argo.add_argument('access_token', type=str, default='', required=False) argo.add_argument('refresh_token', type=str, default='', required=False) + argo.add_argument('scope', type=str, default='', required=False) argo.add_argument('expiration', type=float, default=0, required=False) argo.add_argument('last_use', type=float, default=0, required=False) args = vars(argo.parse_args()) base_data['url'] = args['url'] data = {} - for key in ('buckets', 'access_key', 'secret_key', 'access_token', 'refresh_token', 'expiration', 'last_use'): + for key in ('buckets', 'access_key', 'secret_key', 'access_token', 'refresh_token', 'scope', 'expiration', 'last_use'): if val := args[key]: data[key] = val @@ -157,6 +166,9 @@ async def search_creds(self, db, base_data): assert self.refresh_service if url := self.get_argument('url', None): base_data['url'] = url + if scope := self.get_argument('scope', None): + base_data['scope'] = scope + logger.info('base_data: %r', base_data) refresh = self.get_argument('norefresh', None) is None @@ -166,23 +178,34 @@ async def search_creds(self, db, base_data): filters['type'] = 'oauth' await db.update_many(filters, {'$set': update_data}) - ret = {} + ret = [] async for row in db.find(base_data, projection={'_id': False}): - ret[row['url']] = row - - for key in list(ret): - cred = ret[key] - if refresh and is_expired(cred) and cred['refresh_token']: + if refresh and is_expired(row) and row['refresh_token']: try: - new_cred = await self.refresh_service.refresh_cred(cred) + new_cred = await self.refresh_service.refresh_cred(row) filters = base_data.copy() - filters['url'] = key - ret[key] = await db.find_one_and_update(filters, {'$set': new_cred}, projection={'_id': False}) + filters['url'] = row['url'] + cred = await db.find_one_and_update(filters, {'$set': new_cred}, projection={'_id': False}) + ret.append(cred) except Exception: - del ret[key] + logging.debug('ignore expired token %r', row) + else: + ret.append(row) return ret + async def delete_cred(self, db, base_data): + argo = ArgumentHandler(ArgumentSource.JSON_BODY_ARGUMENTS, self) + argo.add_argument('url', type=str, default='', required=False) + argo.add_argument('scope', type=str, default=None, required=False) + body_args = argo.parse_args() + if body_args.url: + base_data['url'] = body_args.url + if body_args.scope is not None: + base_data['scope'] = body_args.scope + + await db.delete_many(base_data) + class GroupCredentialsHandler(BaseCredentialsHandler): """ @@ -221,6 +244,7 @@ async def post(self, groupname): OAuth body args: access_token (str): access token refresh_token (str): refresh token + scope (str): scope of access token expire_date (str): access token expiration, ISO date time in UTC (optional) Args: @@ -237,8 +261,9 @@ async def patch(self, groupname): """ Update a group credential. Usually used to update a specifc field. - Required body args: + Body args: url (str): url of controlled resource + scope (str): (optional) scope of access token Other body args will update a credential. @@ -260,19 +285,12 @@ async def delete(self, groupname): groupname (str): groupname Body args: url (str): (optional) url of controlled resource + scope (str): (optional) scope of access token """ if self.auth_roles == ['user'] and groupname not in self.auth_groups: raise HTTPError(403, 'unauthorized') - args = {'groupname': groupname} - - argo = ArgumentHandler(ArgumentSource.JSON_BODY_ARGUMENTS, self) - argo.add_argument('url', type=str, default='', required=False) - body_args = argo.parse_args() - if body_args.url: - args['url'] = body_args.url - - await self.db.group_creds.delete_many(args) + await self.delete_cred(self.db.group_creds, {'groupname': groupname}) self.write({}) @@ -313,6 +331,7 @@ async def post(self, username): OAuth body args: access_token (str): access token refresh_token (str): refresh token + scope (str): scope of access token expire_date (str): access token expiration, ISO date time in UTC (optional) Args: @@ -329,8 +348,9 @@ async def patch(self, username): """ Update a user credential. Usually used to update a specifc field. - Required body args: + Body args: url (str): url of controlled resource + scope (str): (optional) scope of access token Other body args will update a credential. @@ -352,21 +372,129 @@ async def delete(self, username): username (str): username Body args: url (str): (optional) url of controlled resource + scope (str): (optional) scope of access token Returns: dict: url: credential dict """ if self.auth_roles == ['user'] and username != self.current_user: raise HTTPError(403, 'unauthorized') - args = {'username': username} + await self.delete_cred(self.db.user_creds, {'username': username}) + self.write({}) - argo = ArgumentHandler(ArgumentSource.JSON_BODY_ARGUMENTS, self) - argo.add_argument('url', type=str, default='', required=False) - body_args = argo.parse_args() - if body_args.url: - args['url'] = body_args.url - await self.db.user_creds.delete_many(args) +class DatasetCredentialsHandler(BaseCredentialsHandler): + """ + Handle dataset credentials requests. + """ + @authorization(roles=['admin', 'system']) + async def get(self, dataset_id): + """ + Get a datasets's credentials. + + Args: + dataset_id (str): dataset_id + Returns: + dict: url: credential dict + """ + ret = await self.search_creds(self.db.dataset_creds, {'dataset_id': dataset_id}) + self.write(ret) + + @authorization(roles=['admin', 'system']) + async def delete(self, dataset_id): + """ + Delete a dataset's credentials. + + Args: + dataset_id (str): dataset_id + Body args: + url (str): (optional) url of controlled resource + scope (str): (optional) scope of access token + Returns: + dict: url: credential dict + """ + await self.delete_cred(self.db.dataset_creds, {'dataset_id': dataset_id}) + self.write({}) + + +class DatasetTaskCredentialsHandler(BaseCredentialsHandler): + """ + Handle dataset/task credentials requests. + """ + @authorization(roles=['admin', 'system']) + async def get(self, dataset_id, task_name): + """ + Get a datasets's credentials. + + Args: + dataset_id (str): dataset_id + task_name (str): task name + Returns: + dict: url: credential dict + """ + ret = await self.search_creds(self.db.dataset_creds, {'dataset_id': dataset_id, 'task_name': task_name}) + self.write(ret) + + @authorization(roles=['admin', 'system']) + async def post(self, dataset_id, task_name): + """ + Set a dataset credential. Overwrites an existing credential for the specified url. + + Common body args: + url (str): url of controlled resource + type (str): credential type (`s3` or `oauth`) + + S3 body args: + buckets (list): list of buckets for this url, or [] if using virtual-hosted buckets in the url + access_key (str): access key + secret_key (str): secret key + + OAuth body args: + access_token (str): access token + refresh_token (str): refresh token + scope (str): scope of access token + expire_date (str): access token expiration, ISO date time in UTC (optional) + + Args: + dataset_id (str): dataset_id + task_name (str): task name + """ + await self.create(self.db.dataset_creds, {'dataset_id': dataset_id, 'task_name': task_name}) + self.write({}) + + @authorization(roles=['admin', 'system']) + async def patch(self, dataset_id, task_name): + """ + Update a dataset credential. Usually used to update a specifc field. + + Body args: + url (str): url of controlled resource + scope (str): (optional) scope of access token + + Other body args will update a credential. + + Args: + dataset_id (str): dataset_id + task_name (str): task name + """ + await self.patch_cred(self.db.dataset_creds, {'dataset_id': dataset_id, 'task_name': task_name}) + self.write({}) + + @authorization(roles=['admin', 'system']) + async def delete(self, dataset_id, task_name): + """ + Delete a dataset's credentials. + + Args: + dataset_id (str): dataset_id + task_name (str): task name + Body args: + url (str): (optional) url of controlled resource + scope (str): (optional) scope of access token + Returns: + dict: url: credential dict + """ + await self.delete_cred(self.db.dataset_creds, {'dataset_id': dataset_id, 'task_name': task_name}) self.write({}) @@ -483,6 +611,16 @@ def __init__(self): }, 'user_creds': { 'username_index': {'keys': [('username', pymongo.DESCENDING), ('url', pymongo.DESCENDING)], 'unique': True}, + }, + 'dataset_creds': { + 'dataset_index': { + 'keys': [ + ('dataset_id', pymongo.DESCENDING), + ('task_name', pymongo.DESCENDING), + ('url', pymongo.DESCENDING) + ], + 'unique': True, + }, } } @@ -516,6 +654,8 @@ def __init__(self): server.add_route(r'/groups/(?P\w+)/credentials', GroupCredentialsHandler, kwargs) server.add_route(r'/users/(?P\w+)/credentials', UserCredentialsHandler, kwargs) + server.add_route(r'/datasets/(?P\w+)/credentials', DatasetCredentialsHandler, kwargs) + server.add_route(r'/datasets/(?P\w+)/tasks/(?P\w+)/credentials', DatasetTaskCredentialsHandler, kwargs) server.add_route('/healthz', HealthHandler, kwargs) server.add_route(r'/(.*)', Error) diff --git a/iceprod/credentials/service.py b/iceprod/credentials/service.py index bd2906f6..024fd06f 100644 --- a/iceprod/credentials/service.py +++ b/iceprod/credentials/service.py @@ -1,47 +1,14 @@ import asyncio -import json import logging import time -from cachetools.func import ttl_cache import httpx import jwt -from rest_tools.utils.auth import OpenIDAuth - -logger = logging.getLogger('refresh_service') - - -@ttl_cache(maxsize=256, ttl=3600) -def get_auth(url): - return OpenIDAuth(url) +from .util import ClientCreds, get_expiration -def get_expiration(token): - """ - Find a token's expiration time. - - Args: - token (str): jwt token - Returns: - float: expiration unix time - """ - return jwt.decode(token, options={"verify_signature": False})['exp'] - - -def is_expired(cred): - """ - Check if an OAuth credential is expired. - - Will mark credential as expired if the access token has less than 5 seconds left. +logger = logging.getLogger('refresh_service') - Args: - cred (dict): credential dict - Returns: - bool: True if expired - """ - if cred['type'] != 'oauth': - return False - return cred['expiration'] < (time.time() + 5) class RefreshService: @@ -57,7 +24,8 @@ class RefreshService: """ def __init__(self, database, clients, refresh_window, expire_buffer, service_run_interval): self.db = database - self.clients = json.loads(clients) + self.clients = ClientCreds(clients) + self.clients.validate() self.refresh_window = refresh_window * 3600 self.expire_buffer = expire_buffer * 3600 @@ -71,26 +39,31 @@ async def refresh_cred(self, cred): raise Exception('cred does not have a refresh token') openid_url = jwt.decode(cred['refresh_token'], options={"verify_signature": False})['iss'] - if openid_url not in self.clients: + try: + client = self.clients.get_client(openid_url) + except KeyError: raise Exception('jwt issuer not registered') - auth = get_auth(openid_url) # try the refresh token args = { 'grant_type': 'refresh_token', 'refresh_token': cred['refresh_token'], - 'client_id': self.clients[openid_url][0], + 'client_id': client.client_id, } - if len(self.clients[openid_url]) > 1: - args['client_secret'] = self.clients[openid_url][1] + if client.client_secret: + args['client_secret'] = client.client_secret + if cred.get('scope', None) is not None: + args['scope'] = cred['scope'] + + logging.warning('refreshing on %s with args %r', client.auth.token_url, args) new_cred = {} try: - async with httpx.AsyncClient() as client: - r = await client.post(auth.token_url, data=args) + async with httpx.AsyncClient() as http_client: + r = await http_client.post(client.auth.token_url, data=args) r.raise_for_status() req = r.json() - except httpx.HTTPError as exc: + except httpx.HTTPStatusError as exc: logger.debug('%r', exc.response.text) try: req = exc.response.json() @@ -103,11 +76,13 @@ async def refresh_cred(self, cred): new_cred['access_token'] = req['access_token'] new_cred['refresh_token'] = req['refresh_token'] new_cred['expiration'] = get_expiration(req['access_token']) + new_cred['scope'] = req.get('scope', cred.get('scope', None)) logger.debug('%r', new_cred) return new_cred def should_refresh(self, cred): + logger.info('should_refresh for cred %r', cred) now = time.time() refresh_exp = now + self.expire_buffer last_use_date = now - self.refresh_window @@ -133,6 +108,7 @@ async def _run_once(self): 'type': 'oauth', 'last_use': {'$gt': last_use_check}, 'expiration': {'$lt': exp_check}, + 'scope': {'$exists': True}, } user_creds = {} @@ -145,16 +121,15 @@ async def _run_once(self): if not cred['refresh_token']: logger.info('skipping non-refresh token for user %s, url %s', cred['username'], cred['url']) continue - logger.debug('cred: %r', cred) try: if self.should_refresh(cred): args = await self.refresh_cred(cred) - await self.db.user_creds.update_one({'username': cred['username'], 'url': cred['url']}, {'$set': args}) - logger.info('refreshed token for user %s, url %s', cred['username'], cred['url']) + await self.db.user_creds.update_one({'username': cred['username'], 'url': cred['url'], 'scope': cred['scope']}, {'$set': args}) + logger.info('refreshed token for user %s, url %s, scope %s', cred['username'], cred['url'], cred['scope']) else: - logger.info('not yet time to refresh token for user %s, url %s', cred['username'], cred['url']) + logger.info('not yet time to refresh token for user %s, url %s, scope %s', cred['username'], cred['url'], cred['scope']) except Exception: - logger.error('error refreshing token for user %s, url: %s', cred['username'], cred['url'], exc_info=True) + logger.error('error refreshing token for user %s, url: %s, scope %s', cred['username'], cred['url'], cred['scope'], exc_info=True) group_creds = {} async for row in self.db.group_creds.find(filters, {'_id': False}): @@ -176,6 +151,28 @@ async def _run_once(self): except Exception: logger.error('error refreshing token for group %s, url: %s', cred['groupname'], cred['url'], exc_info=True) + + dataset_creds = [] + async for row in self.db.dataset_creds.find(filters, {'_id': False}): + dataset_creds.append(row) + + for cred in dataset_creds: + if cred['type'] != 'oauth': + continue + if not cred['refresh_token']: + logger.info('skipping non-refresh token for dataset %s, task, %s, url %s', cred['dataset_id'], cred['task_name'], cred['url']) + continue + try: + if self.should_refresh(cred): + args = await self.refresh_cred(cred) + await self.db.dataset_creds.update_one({'dataset_id': cred['dataset_id'], 'task_name': cred['task_name'], 'scope': cred['scope'], 'url': cred['url']}, {'$set': args}) + logger.info('refreshed token for dataset %s, task, %s, url %s', cred['dataset_id'], cred['task_name'], cred['url']) + else: + logger.info('not yet time to refresh token for dataset %s, task, %s, url %s', cred['dataset_id'], cred['task_name'], cred['url']) + except Exception: + logger.error('error refreshing token for dataset %s, task, %s, url %s', cred['dataset_id'], cred['task_name'], cred['url'], exc_info=True) + + self.last_success_time = time.time() except Exception: logger.error('error running refresh', exc_info=True) diff --git a/iceprod/rest/base_handler.py b/iceprod/rest/base_handler.py index dd47141f..8f24f89f 100644 --- a/iceprod/rest/base_handler.py +++ b/iceprod/rest/base_handler.py @@ -3,6 +3,7 @@ import motor.motor_asyncio from rest_tools.server import RestHandlerSetup, RestHandler +from tornado.escape import json_encode from iceprod.util import VERSION_STRING from iceprod.prom_utils import PromRequestMixin @@ -50,3 +51,8 @@ def get_current_user(self): logger.info('could not find auth username') return username + + def write(self, chunk: dict | list) -> None: # type: ignore[override] + """Write dict or list to json""" + self.set_header("Content-Type", "application/json; charset=UTF-8") + super().write(json_encode(chunk)) diff --git a/iceprod/website/data/www_templates/profile.html b/iceprod/website/data/www_templates/profile.html index 671be8bf..7e3d26b8 100644 --- a/iceprod/website/data/www_templates/profile.html +++ b/iceprod/website/data/www_templates/profile.html @@ -63,10 +63,6 @@

User Credentials

TYPE: {{ user_creds[url]["type"]}}

{% end %} -
- {% module xsrf_form_html() %} - -