Skip to content

Instantly share code, notes, and snippets.

@kellerza
Last active March 27, 2024 16:32
Show Gist options
  • Save kellerza/8aad3952086b827a9f32516373df1623 to your computer and use it in GitHub Desktop.
Save kellerza/8aad3952086b827a9f32516373df1623 to your computer and use it in GitHub Desktop.
AsyncIO based OAuth Authorization Code Flow using the Microsoft MSAL Python library. Includes an aiohttp server example.
"""AsyncIO based OAuth Authorization Code Flow using the Microsoft MSAL Python library.
The AsyncMSAL class contains more info to perform OAuth & get the required tokens.
Once you have the OAuth tokens stored in the session, you are free to make requests
(typically from inside an aiohttp server's request handler).
For more info on Authorization Code flow, refer to https://auth0.com/docs/flows/authorization-code-flow
"""
import asyncio
import json
from functools import partial, wraps
from aiohttp import web
from aiohttp.client import ClientSession, _RequestContextManager
from msal import ConfidentialClientApplication, SerializableTokenCache
# Optional settings object (e.g. loaded from the environment). When set,
# AsyncMSAL reads SP_APP_ID, SP_APP_PW and SP_AUTHORITY from it.
ENV = None

# HTTP verbs accepted by AsyncMSAL.request()
HTTP_GET, HTTP_POST, HTTP_PUT, HTTP_PATCH, HTTP_DELETE = (
    "get",
    "post",
    "put",
    "patch",
    "delete",
)
HTTP_ALLOWED = list((HTTP_GET, HTTP_POST, HTTP_PUT, HTTP_PATCH, HTTP_DELETE))

# Scopes requested when starting the auth code flow and acquiring tokens.
MY_SCOPE = "User.Read User.Read.All".split()
def async_wrap(func):
    """Wrap a blocking function so that it runs in an executor thread.

    The returned coroutine function accepts the same arguments as ``func``,
    plus two optional keyword arguments that the wrapper consumes itself
    (they are never forwarded to ``func``):

    :param loop: event loop to use; defaults to the currently running loop
    :param executor: executor to run ``func`` in; None means the loop's default
    :return: coroutine function returning whatever ``func`` returns
    """

    @wraps(func)
    async def run(*args, loop=None, executor=None, **kwargs):
        if loop is None:
            # Inside a coroutine a loop is always running; get_running_loop()
            # is the recommended way to fetch it (get_event_loop() is not).
            loop = asyncio.get_running_loop()
        pfunc = partial(func, *args, **kwargs)
        return await loop.run_in_executor(executor, pfunc)

    return run
# Keys under which AsyncMSAL keeps its state inside the aiohttp_session.Session:
#   USER_EMAIL  - the signed-in user's preferred_username claim
#   FLOW_CACHE  - the MSAL auth code flow dict between login and callback
#   TOKEN_CACHE - the serialized MSAL token cache
USER_EMAIL = "mail"
FLOW_CACHE = "flow_cache"
TOKEN_CACHE = "token_cache"
class AsyncMSAL:
    """AsyncIO based OAuth using the Microsoft Authentication Library (MSAL) for Python.

    Blocking MSAL functions are executed in an executor thread (via ``async_wrap``).
    Use until such time as MSAL Python gets a true async version...
    Tested with MSAL Python 1.13.0
    https://github.com/AzureAD/microsoft-authentication-library-for-python

    AsyncMSAL is based on the following example app
    https://github.com/Azure-Samples/ms-identity-python-webapp/blob/master/app.py#L76

    Usage - get the tokens via OAuth:

    1. ``build_auth_code_flow`` wraps ``initiate_auth_code_flow``
       https://msal-python.readthedocs.io/en/latest/#msal.ClientApplication.initiate_auth_code_flow
       The caller is expected to:
         1. store the returned flow, typically inside the current session of the server,
         2. guide the end user (i.e. resource owner) to visit that auth_uri,
            typically with a redirect,
         3. relay this dict and the subsequent auth response to
            acquire_token_by_auth_code_flow().
       Steps 1 and (part of) 3 are handled by this class via the aiohttp_session.
    2. ``async_acquire_token_by_auth_code_flow`` wraps ``acquire_token_by_auth_code_flow``
       https://msal-python.readthedocs.io/en/latest/#msal.ClientApplication.acquire_token_by_auth_code_flow

    Now you are free to make requests (typically from an aiohttp server)::

        session = await get_session(request)
        aiomsal = AsyncMSAL(session)
        async with aiomsal.get("https://graph.microsoft.com/v1.0/me") as res:
            res = await res.json()
    """

    # HTTP client shared by all instances; created lazily by request().
    aiohttp_session: ClientSession = None
    # App registration settings, read from ENV when it is configured.
    client_id = ENV.SP_APP_ID if ENV else None
    client_credential = ENV.SP_APP_PW if ENV else None
    authority = ENV.SP_AUTHORITY if ENV else None

    def __init__(self, session):
        """Create the MSAL application backed by the session's token cache.

        Based on: https://github.com/Azure-Samples/ms-identity-python-webapp/blob/master/app.py#L76

        session: an aiohttp_session.Session object, or None (the token cache
            then starts empty and build_auth_code_flow cannot be used)
        """
        self.session = session
        self._token_cache = SerializableTokenCache()
        # Restore the token cache serialized into the session by an earlier request.
        if session and session.get(TOKEN_CACHE):
            self._token_cache.deserialize(session[TOKEN_CACHE])
        self.app = ConfidentialClientApplication(
            client_id=self.client_id,
            client_credential=self.client_credential,
            authority=self.authority,  # common/oauth2/v2.0/token'
            validate_authority=False,
            token_cache=self._token_cache,
        )

    def _save_token_cache(self):
        """Write the token cache back into the session, only if it changed."""
        if self._token_cache.has_state_changed:
            self.session[TOKEN_CACHE] = self._token_cache.serialize()

    def build_auth_code_flow(self, redirect_uri):
        """First step - start the auth code flow.

        Clears any previous token cache and user in the session, stores the
        MSAL flow dict under FLOW_CACHE and returns the URL the user must visit.

        :param str redirect_uri: where the user is sent back after signing in
        :return: the ``auth_uri`` to redirect the user to
        :raises ValueError: if this instance was created without a session
        """
        # An empty aiohttp_session.Session is falsy (it is a mapping), so a
        # truthiness test would reject a brand-new session; compare with None.
        if self.session is None:
            raise ValueError("session required")
        self.session[TOKEN_CACHE] = None
        self.session[USER_EMAIL] = None
        self.session[FLOW_CACHE] = res = self.app.initiate_auth_code_flow(
            MY_SCOPE,
            redirect_uri=redirect_uri,
        )  # https://msal-python.readthedocs.io/en/latest/#msal.ClientApplication.initiate_auth_code_flow
        return res["auth_uri"]

    @async_wrap
    def async_acquire_token_by_auth_code_flow(self, auth_response):
        """Second step - exchange the auth response for tokens.

        Runs in an executor thread. On success, saves the token cache to the
        session and stores the ``preferred_username`` ID token claim under
        USER_EMAIL (presumably the sign-in name/email - not guaranteed by MSAL).

        :param dict auth_response: query parameters of the redirect back from login
        :raises KeyError: if no flow was started (FLOW_CACHE missing from the session)
        :raises aiohttp.web.HTTPUnauthorized: if MSAL reports an error or the
            result carries no ID token claims
        """
        # Stored by build_auth_code_flow(); popped so the flow is single-use.
        # Raises KeyError if the flow was never started.
        auth_code_flow = self.session.pop(FLOW_CACHE)
        result = self.app.acquire_token_by_auth_code_flow(auth_code_flow, auth_response)
        if "error" in result or "id_token_claims" not in result:
            # `text` must be a str. Only the error fields are exposed: a result
            # without id_token_claims may still hold tokens that must not be
            # echoed back to the browser.
            raise web.HTTPUnauthorized(
                text=json.dumps(
                    {
                        "error": result.get("error"),
                        "error_description": result.get("error_description"),
                    }
                ),
                content_type="application/json",
            )
        self._save_token_cache()
        self.session[USER_EMAIL] = result.get("id_token_claims").get(
            "preferred_username"
        )

    @async_wrap
    def async_get_token(self):
        """Silently acquire a token for the first account in the token cache.

        Runs in an executor thread and saves the token cache afterwards (it is
        only written if MSAL changed it, e.g. after a refresh).

        :return: the MSAL result dict from acquire_token_silent (which may be
            None), or None when the cache holds no account
        """
        accounts = self.app.get_accounts()
        if accounts:
            result = self.app.acquire_token_silent(scopes=MY_SCOPE, account=accounts[0])
            self._save_token_cache()
            return result
        return None

    async def request(self, method, url, **kwargs):
        """Make a request to url, authenticated with the session's OAuth token.

        :param str method: type of request (get/put/post/patch/delete, any case)
        :param str url: url to send request to
        :param kwargs: extra params passed to ``ClientSession.request``; for
            post/put/patch a ``data`` value is JSON-encoded automatically
        :return: Response of the request
        :rtype: aiohttp.ClientResponse
        :raises ValueError: if method is not one of HTTP_ALLOWED
        :raises aiohttp.web.HTTPUnauthorized: if no access token is available
        """
        method = method.lower()
        # Validate before doing any token I/O; an `assert` would be stripped
        # when Python runs with -O.
        if method not in HTTP_ALLOWED:
            raise ValueError(
                f"Method must be one of {HTTP_ALLOWED}, got {method!r}"
            )

        # (Re)create the shared client session if it is missing or was closed.
        if self.aiohttp_session is None or self.aiohttp_session.closed:
            AsyncMSAL.aiohttp_session = ClientSession(trust_env=True)

        token = await self.async_get_token()
        if not token or "access_token" not in token:
            # No cached account, or MSAL could not silently get/refresh a token.
            raise web.HTTPUnauthorized(text="No valid access token, please log in")

        # Copy kwargs and headers so the caller's objects are never mutated;
        # `headers=None` is treated like no headers.
        kwargs = dict(kwargs)
        headers = dict(kwargs.get("headers") or {})
        kwargs["headers"] = headers
        headers["Authorization"] = "Bearer " + token["access_token"]

        if method == HTTP_GET:
            kwargs.setdefault("allow_redirects", True)
        elif method in (HTTP_POST, HTTP_PUT, HTTP_PATCH):
            headers["Content-type"] = "application/json"
            if "data" in kwargs:
                kwargs["data"] = json.dumps(kwargs["data"])  # auto convert to json

        return await self.aiohttp_session.request(method, url, **kwargs)

    def get(self, url, **kwargs):
        """GET request; use as ``async with aiomsal.get(url) as response:``."""
        return _RequestContextManager(self.request(HTTP_GET, url, **kwargs))

    def post(self, url, **kwargs):
        """POST request; ``data`` is sent as JSON."""
        return _RequestContextManager(self.request(HTTP_POST, url, **kwargs))

    def put(self, url, **kwargs):
        """PUT request; ``data`` is sent as JSON."""
        return _RequestContextManager(self.request(HTTP_PUT, url, **kwargs))

    def patch(self, url, **kwargs):
        """PATCH request; ``data`` is sent as JSON."""
        return _RequestContextManager(self.request(HTTP_PATCH, url, **kwargs))

    def delete(self, url, **kwargs):
        """DELETE request."""
        return _RequestContextManager(self.request(HTTP_DELETE, url, **kwargs))
"""async_msal example server."""
from aiohttp import web
from aiohttp_session import get_session, new_session, setup
from aiohttp_session.cookie_storage import EncryptedCookieStorage
from .msal_async import AsyncMSAL
# Session key remembering where to send the user after a successful login.
SESSION_REDIRECT = "session_redirect"
# Collects the route handlers below; main() registers them on the app.
ROUTES = web.RouteTableDef()
@ROUTES.get("/user/info")
async def user_info(request):
    """Example route: return the signed-in user's profile from the MS Graph API."""
    msal_client = AsyncMSAL(await get_session(request))
    async with msal_client.get("https://graph.microsoft.com/v1.0/me") as graph_response:
        profile = await graph_response.json()
    return web.json_response(profile)
@ROUTES.get("/user/login/{redirect:.+$}")
async def user_login(request):
    """Start the user login.

    Remembers the ``{redirect}`` path segment in the session (so the
    post-login handler can return the user there), then redirects the browser
    to the Microsoft sign-in page.
    """
    session = await new_session(request)
    # The route placeholder is named "redirect". SESSION_REDIRECT is only the
    # session key, so looking it up in match_info never found the value.
    session[SESSION_REDIRECT] = request.match_info.get(
        "redirect", session.get(SESSION_REDIRECT, "")
    )
    aiomsal = AsyncMSAL(session)
    redir = aiomsal.build_auth_code_flow(
        redirect_uri="https://mysite.com/user/authorized"
    )
    # Redirect user to sign in. Raise it: returning an HTTPException from a
    # handler is deprecated in aiohttp.
    raise web.HTTPFound(redir)
@ROUTES.get("/user/authorized")
async def user_authorized(request: web.Request):
    """Process the return flow after login.

    Exchanges the auth response in the query string for tokens (stored in the
    session), then redirects to the remembered local path or ``/user/info``.
    """
    session = await get_session(request)
    # build a plain dict from the aiohttp server request's url parameters
    auth_response = dict(request.rel_url.query.items())
    aiomsal = AsyncMSAL(session)
    try:
        await aiomsal.async_acquire_token_by_auth_code_flow(auth_response)
    except Exception as err:  # pylint: disable=broad-except
        print("<b>Could not get token</b> - async_acquire_token_by_auth_code_flow", err)
        raise
    # Redirect user to a local path only. Stripping leading slashes and
    # backslashes prevents "//evil.com"-style protocol-relative open
    # redirects. The default must not start with "/" either, otherwise it
    # becomes "//user/info" (host "user").
    redirect = session.pop(SESSION_REDIRECT, "") or ""
    redirect = redirect.lstrip("/\\") or "user/info"
    # Raise it: returning an HTTPException from a handler is deprecated in aiohttp.
    raise web.HTTPFound(f"/{redirect}")
def main():
    """Run the example web server with encrypted cookie sessions."""
    import secrets  # only needed here, to create the session encryption key

    app = web.Application()
    # EncryptedCookieStorage (Fernet) requires a 32-byte key. The literal used
    # here before was 28 bytes long, and it was a hard-coded shared secret.
    # A random per-process key is fine for this example but logs everyone out
    # on restart - load a persistent secret from configuration in production.
    setup(app, EncryptedCookieStorage(secrets.token_bytes(32)))
    app.add_routes(ROUTES)
    web.run_app(app)


if __name__ == "__main__":
    main()
@svinther
Copy link

I'm also on 3.8.1, but found that the cookie size exceeded 4 KB, which seems to have caused the misbehaviour. Changing the aiohttp-session cookie storage to a key-based cookie, e.g. https://github.com/zhangkaizhao/aiohttp-session-file, fixed this.

@kellerza
Copy link
Author

OK, that is an entirely different issue then. Glad you found a solution. The standard aiohttp_session also has Redis and Memcached storage options that work in a similar fashion (storing only a key in the cookie and the data on the server).

@kellerza
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment