58 lines
2.2 KiB
Python
58 lines
2.2 KiB
Python
import hashlib
import logging
from functools import wraps

from django.contrib.auth import authenticate
from django.contrib.auth.backends import ModelBackend
from django.contrib.auth.models import User
from django.core.cache import cache
from django.http import JsonResponse
from mozilla_django_oidc.auth import OIDCAuthenticationBackend
from mozilla_django_oidc.contrib.drf import OIDCAuthentication
from rest_framework.exceptions import AuthenticationFailed
|
|
|
|
|
|
class CustomOIDCBackend(OIDCAuthenticationBackend):
    """OIDC backend that authenticates requests with DRF's Bearer-token logic.

    Customisations over ``OIDCAuthenticationBackend``:

    * ``authenticate`` delegates to DRF's ``OIDCAuthentication`` so plain
      Django views (see ``login_required``) and DRF views accept the same
      ``Authorization: Bearer <token>`` header.
    * ``get_userinfo`` caches the userinfo response per access token.
    * ``get_username`` prefers the ``preferred_username`` claim when no
      existing user has already taken that name.
    """

    _logger = logging.getLogger(__name__)

    # How long (seconds) a userinfo response is cached per access token.
    # NOTE(review): if the token is only validated through the userinfo call
    # (which appears to be the case in the Bearer-token flow), a cached entry
    # keeps a token accepted for up to this long after it expires or is
    # revoked upstream. Consider lowering this to the IdP's token lifetime.
    USERINFO_CACHE_TIMEOUT = 60 * 60 * 24

    def get_username(self, claims):
        """Return the username for a user being created from ``claims``.

        Use the ``preferred_username`` claim when it is present, non-empty and
        not already taken by an existing ``User``; otherwise fall back to the
        parent implementation.
        """
        preferred = claims.get('preferred_username')
        # An empty claim must not become an (empty) username.
        if preferred and not User.objects.filter(username=preferred).exists():
            self._logger.debug('Using preferred_username %r as username', preferred)
            return preferred
        return super().get_username(claims)

    def authenticate(self, request, **kwargs):
        """Hack to use the same auth as DRF.

        Instead of the browser authorization-code flow, authenticate
        ``request`` from its ``Authorization: Bearer`` header via DRF's
        ``OIDCAuthentication``. Extra ``kwargs`` (credentials forwarded by
        ``django.contrib.auth.authenticate``) are ignored.

        Returns:
            The authenticated user, or ``None`` when there is no request, no
            usable Bearer token, or DRF rejects the token, so Django can move on
            to the next configured backend.
        """
        if request is None:
            # django.contrib.auth.authenticate() may be called without a request.
            return None
        try:
            result = OIDCAuthentication().authenticate(request)
        except AuthenticationFailed as exc:
            self._logger.debug('Bearer authentication failed: %s', exc)
            return None
        if result is None:
            # DRF authenticators return None (not a tuple) when they make no
            # authentication attempt, e.g. no Bearer header; unpacking it
            # directly raised TypeError.
            return None
        user, _token = result
        return user

    def get_userinfo(self, access_token, id_token, payload):
        """Return the userinfo claims for ``access_token``, cached per token.

        On a cache miss the parent implementation fetches the claims, and a
        truthy result is cached for ``USERINFO_CACHE_TIMEOUT`` seconds.
        """
        cache_key = self._userinfo_cache_key(access_token)
        userinfo = cache.get(cache_key)
        if userinfo is None:
            # Never log the token itself: it is a bearer credential.
            self._logger.debug('No cached userinfo for this access token yet.')
            userinfo = super().get_userinfo(access_token, id_token, payload)
            if userinfo:
                cache.set(cache_key, userinfo, timeout=self.USERINFO_CACHE_TIMEOUT)
        return userinfo

    @staticmethod
    def _userinfo_cache_key(access_token):
        """Build a fixed-length userinfo cache key that omits the raw token."""
        # Hashing keeps the key within memcached's 250-character limit (JWT
        # access tokens often exceed it) and keeps the credential out of the
        # cache's key space.
        digest = hashlib.sha256(str(access_token).encode('utf-8')).hexdigest()
        return f'userinfo-{digest}'

    def update_user(self, user, claims):  # TODO: update groups?
        """Update an existing ``user`` from ``claims``; currently the parent's behavior."""
        return super().update_user(user, claims)

    def create_user(self, claims):  # TODO: add groups?
        """Create a new user from ``claims``; currently the parent's behavior."""
        return super().create_user(claims)
|
|
|
|
|
|
def login_required(func):
    """Decorate a view so requests with a Bearer token get authenticated.

    When the request carries an ``Authorization: Bearer ...`` header and no
    user is authenticated yet, the token is checked through
    ``django.contrib.auth.authenticate`` (i.e. the configured backends). A
    rejected token yields a 401 JSON response. An accepted one is stored on
    ``request.user`` before the view runs.

    NOTE(review): requests *without* a Bearer header are passed to the view
    unchanged, even when anonymous. Despite its name, this decorator does not
    enforce login for them; confirm that is intended.

    Args:
        func: The view callable ``func(request, *args, **kwargs)``.

    Returns:
        The wrapped view.
    """

    @wraps(func)
    def wrapper(request, *args, **kwargs):
        # Parse the scheme as the first whitespace-separated token and compare
        # it case-insensitively (HTTP auth schemes are case-insensitive). This
        # also stops "Bearerxyz ..." from being treated as a Bearer header.
        auth_parts = request.META.get("HTTP_AUTHORIZATION", "").split()
        is_bearer = bool(auth_parts) and auth_parts[0].lower() == "bearer"
        # Short-circuit so request.user is only touched for Bearer requests.
        if is_bearer and (not hasattr(request, "user") or request.user.is_anonymous):
            user = authenticate(request=request)
            if not user:
                return JsonResponse({"error": "Unauthorized"}, status=401)
            # Presumably also primes Django's per-request user cache
            # (_cached_user) so later lazy lookups see this user - confirm.
            request.user = request._cached_user = user
        return func(request, *args, **kwargs)

    return wrapper
|