Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/container.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ jobs:
with:
image_namespace: wipacrepo
image_name: iceprod
mode: BUILD
204 changes: 172 additions & 32 deletions iceprod/credentials/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import time
from typing import Any

import jwt
from prometheus_client import Info, start_http_server
import pymongo
import pymongo.errors
Expand All @@ -24,7 +25,8 @@
from iceprod.rest.auth import authorization
from iceprod.rest.base_handler import IceProdRestConfig, APIBase
from iceprod.server.util import nowstr, datetime2str
from .service import RefreshService, get_expiration, is_expired
from .service import RefreshService
from .util import get_expiration, is_expired

logger = logging.getLogger('server')

Expand Down Expand Up @@ -74,6 +76,7 @@ async def create(self, db, base_data):
argo.add_argument('secret_key', type=str, default='', required=False)
argo.add_argument('access_token', type=str, default='', required=False)
argo.add_argument('refresh_token', type=str, default='', required=False)
argo.add_argument('scope', type=str, default=None, required=False)
argo.add_argument('expire_date', type=float, default=now, required=False)
argo.add_argument('last_use', type=float, default=now, required=False)
args = vars(argo.parse_args())
Expand Down Expand Up @@ -103,13 +106,18 @@ async def create(self, db, base_data):
raise HTTPError(400, reason='must specify either access or refresh tokens')
data['access_token'] = args['access_token']
data['refresh_token'] = args['refresh_token']
if args['scope']:
base_data['scope'] = args['scope']
data['scope'] = args['scope']
data['expiration'] = args['expire_date']
data['last_use'] = args['last_use']
if data['access_token'] and data['scope'] is None:
data['scope'] = jwt.decode(data['access_token'], options={"verify_signature": False}).get('scope', None)

if 'refresh_token' in data and not data.get('access_token', ''):
if 'refresh_token' in data and not data['access_token']:
new_cred = await self.refresh_service.refresh_cred(data)
data.update(new_cred)
if (not data.get('access_token', '')) and data.get('expiration') == now:
if (not data['access_token']) and data.get('expiration') == now:
data['expiration'] = get_expiration(data['access_token'])

else:
Expand All @@ -130,13 +138,14 @@ async def patch_cred(self, db, base_data):
argo.add_argument('secret_key', type=str, default='', required=False)
argo.add_argument('access_token', type=str, default='', required=False)
argo.add_argument('refresh_token', type=str, default='', required=False)
argo.add_argument('scope', type=str, default='', required=False)
argo.add_argument('expiration', type=float, default=0, required=False)
argo.add_argument('last_use', type=float, default=0, required=False)
args = vars(argo.parse_args())
base_data['url'] = args['url']

data = {}
for key in ('buckets', 'access_key', 'secret_key', 'access_token', 'refresh_token', 'expiration', 'last_use'):
for key in ('buckets', 'access_key', 'secret_key', 'access_token', 'refresh_token', 'scope', 'expiration', 'last_use'):
if val := args[key]:
data[key] = val

Expand All @@ -157,6 +166,9 @@ async def search_creds(self, db, base_data):
assert self.refresh_service
if url := self.get_argument('url', None):
base_data['url'] = url
if scope := self.get_argument('scope', None):
base_data['scope'] = scope
logger.info('base_data: %r', base_data)

refresh = self.get_argument('norefresh', None) is None

Expand All @@ -166,23 +178,34 @@ async def search_creds(self, db, base_data):
filters['type'] = 'oauth'
await db.update_many(filters, {'$set': update_data})

ret = {}
ret = []
async for row in db.find(base_data, projection={'_id': False}):
ret[row['url']] = row

for key in list(ret):
cred = ret[key]
if refresh and is_expired(cred) and cred['refresh_token']:
if refresh and is_expired(row) and row['refresh_token']:
try:
new_cred = await self.refresh_service.refresh_cred(cred)
new_cred = await self.refresh_service.refresh_cred(row)
filters = base_data.copy()
filters['url'] = key
ret[key] = await db.find_one_and_update(filters, {'$set': new_cred}, projection={'_id': False})
filters['url'] = row['url']
cred = await db.find_one_and_update(filters, {'$set': new_cred}, projection={'_id': False})
ret.append(cred)
except Exception:
del ret[key]
logging.debug('ignore expired token %r', row)
else:
ret.append(row)

return ret

async def delete_cred(self, db, base_data):
argo = ArgumentHandler(ArgumentSource.JSON_BODY_ARGUMENTS, self)
argo.add_argument('url', type=str, default='', required=False)
argo.add_argument('scope', type=str, default=None, required=False)
body_args = argo.parse_args()
if body_args.url:
base_data['url'] = body_args.url
if body_args.scope is not None:
base_data['scope'] = body_args.scope

await db.delete_many(base_data)


class GroupCredentialsHandler(BaseCredentialsHandler):
"""
Expand Down Expand Up @@ -221,6 +244,7 @@ async def post(self, groupname):
OAuth body args:
access_token (str): access token
refresh_token (str): refresh token
scope (str): scope of access token
expire_date (str): access token expiration, ISO date time in UTC (optional)

Args:
Expand All @@ -237,8 +261,9 @@ async def patch(self, groupname):
"""
Update a group credential. Usually used to update a specifc field.

Required body args:
Body args:
url (str): url of controlled resource
scope (str): (optional) scope of access token

Other body args will update a credential.

Expand All @@ -260,19 +285,12 @@ async def delete(self, groupname):
groupname (str): groupname
Body args:
url (str): (optional) url of controlled resource
scope (str): (optional) scope of access token
"""
if self.auth_roles == ['user'] and groupname not in self.auth_groups:
raise HTTPError(403, 'unauthorized')

args = {'groupname': groupname}

argo = ArgumentHandler(ArgumentSource.JSON_BODY_ARGUMENTS, self)
argo.add_argument('url', type=str, default='', required=False)
body_args = argo.parse_args()
if body_args.url:
args['url'] = body_args.url

await self.db.group_creds.delete_many(args)
await self.delete_cred(self.db.group_creds, {'groupname': groupname})
self.write({})


Expand Down Expand Up @@ -313,6 +331,7 @@ async def post(self, username):
OAuth body args:
access_token (str): access token
refresh_token (str): refresh token
scope (str): scope of access token
expire_date (str): access token expiration, ISO date time in UTC (optional)

Args:
Expand All @@ -329,8 +348,9 @@ async def patch(self, username):
"""
Update a user credential. Usually used to update a specifc field.

Required body args:
Body args:
url (str): url of controlled resource
scope (str): (optional) scope of access token

Other body args will update a credential.

Expand All @@ -352,21 +372,129 @@ async def delete(self, username):
username (str): username
Body args:
url (str): (optional) url of controlled resource
scope (str): (optional) scope of access token
Returns:
dict: url: credential dict
"""
if self.auth_roles == ['user'] and username != self.current_user:
raise HTTPError(403, 'unauthorized')

args = {'username': username}
await self.delete_cred(self.db.user_creds, {'username': username})
self.write({})

argo = ArgumentHandler(ArgumentSource.JSON_BODY_ARGUMENTS, self)
argo.add_argument('url', type=str, default='', required=False)
body_args = argo.parse_args()
if body_args.url:
args['url'] = body_args.url

await self.db.user_creds.delete_many(args)
class DatasetCredentialsHandler(BaseCredentialsHandler):
"""
Handle dataset credentials requests.
"""
@authorization(roles=['admin', 'system'])
async def get(self, dataset_id):
"""
Get a datasets's credentials.

Args:
dataset_id (str): dataset_id
Returns:
dict: url: credential dict
"""
ret = await self.search_creds(self.db.dataset_creds, {'dataset_id': dataset_id})
self.write(ret)

@authorization(roles=['admin', 'system'])
async def delete(self, dataset_id):
"""
Delete a dataset's credentials.

Args:
dataset_id (str): dataset_id
Body args:
url (str): (optional) url of controlled resource
scope (str): (optional) scope of access token
Returns:
dict: url: credential dict
"""
await self.delete_cred(self.db.dataset_creds, {'dataset_id': dataset_id})
self.write({})


class DatasetTaskCredentialsHandler(BaseCredentialsHandler):
"""
Handle dataset/task credentials requests.
"""
@authorization(roles=['admin', 'system'])
async def get(self, dataset_id, task_name):
"""
Get a datasets's credentials.

Args:
dataset_id (str): dataset_id
task_name (str): task name
Returns:
dict: url: credential dict
"""
ret = await self.search_creds(self.db.dataset_creds, {'dataset_id': dataset_id, 'task_name': task_name})
self.write(ret)

@authorization(roles=['admin', 'system'])
async def post(self, dataset_id, task_name):
"""
Set a dataset credential. Overwrites an existing credential for the specified url.

Common body args:
url (str): url of controlled resource
type (str): credential type (`s3` or `oauth`)

S3 body args:
buckets (list): list of buckets for this url, or [] if using virtual-hosted buckets in the url
access_key (str): access key
secret_key (str): secret key

OAuth body args:
access_token (str): access token
refresh_token (str): refresh token
scope (str): scope of access token
expire_date (str): access token expiration, ISO date time in UTC (optional)

Args:
dataset_id (str): dataset_id
task_name (str): task name
"""
await self.create(self.db.dataset_creds, {'dataset_id': dataset_id, 'task_name': task_name})
self.write({})

@authorization(roles=['admin', 'system'])
async def patch(self, dataset_id, task_name):
"""
Update a dataset credential. Usually used to update a specifc field.

Body args:
url (str): url of controlled resource
scope (str): (optional) scope of access token

Other body args will update a credential.

Args:
dataset_id (str): dataset_id
task_name (str): task name
"""
await self.patch_cred(self.db.dataset_creds, {'dataset_id': dataset_id, 'task_name': task_name})
self.write({})

@authorization(roles=['admin', 'system'])
async def delete(self, dataset_id, task_name):
"""
Delete a dataset's credentials.

Args:
dataset_id (str): dataset_id
task_name (str): task name
Body args:
url (str): (optional) url of controlled resource
scope (str): (optional) scope of access token
Returns:
dict: url: credential dict
"""
await self.delete_cred(self.db.dataset_creds, {'dataset_id': dataset_id, 'task_name': task_name})
self.write({})


Expand Down Expand Up @@ -483,6 +611,16 @@ def __init__(self):
},
'user_creds': {
'username_index': {'keys': [('username', pymongo.DESCENDING), ('url', pymongo.DESCENDING)], 'unique': True},
},
'dataset_creds': {
'dataset_index': {
'keys': [
('dataset_id', pymongo.DESCENDING),
('task_name', pymongo.DESCENDING),
('url', pymongo.DESCENDING)
],
'unique': True,
},
}
}

Expand Down Expand Up @@ -516,6 +654,8 @@ def __init__(self):

server.add_route(r'/groups/(?P<groupname>\w+)/credentials', GroupCredentialsHandler, kwargs)
server.add_route(r'/users/(?P<username>\w+)/credentials', UserCredentialsHandler, kwargs)
server.add_route(r'/datasets/(?P<dataset_id>\w+)/credentials', DatasetCredentialsHandler, kwargs)
server.add_route(r'/datasets/(?P<dataset_id>\w+)/tasks/(?P<task_name>\w+)/credentials', DatasetTaskCredentialsHandler, kwargs)
server.add_route('/healthz', HealthHandler, kwargs)
server.add_route(r'/(.*)', Error)

Expand Down
Loading
Loading