Add exempt

This commit is contained in:
wcorrales
2021-03-07 12:17:58 -06:00
parent 9e9894dd8d
commit 17b3d3a2ce
3 changed files with 45 additions and 3 deletions

View File

@@ -4,7 +4,7 @@ import os
from urllib.parse import urlparse
from itsdangerous import BadData, SignatureExpired, URLSafeTimedSerializer
from quart import current_app, g, request, session
from quart import Blueprint, current_app, g, request, session
from werkzeug.exceptions import BadRequest
from werkzeug.security import safe_str_cmp
from wtforms import ValidationError
@@ -143,6 +143,9 @@ class CSRFProtect:
"""
def __init__(self, app=None):
self._exempt_views = set()
self._exempt_blueprints = set()
if app:
self.init_app(app)
@@ -178,6 +181,15 @@ class CSRFProtect:
if not request.endpoint:
return
if request.blueprint in self._exempt_blueprints:
return
view = app.view_functions.get(request.endpoint)
dest = f'{view.__module__}.{view.__name__}'
if dest in self._exempt_views:
return
await self.protect()
async def _get_csrf_token(self):
@@ -227,6 +239,30 @@ class CSRFProtect:
g.csrf_valid = True # mark this request as CSRF valid
def exempt(self, view):
"""Mark a view or blueprint to be excluded from CSRF protection.
::
@app.route('/some-view', methods=['POST'])
@csrf.exempt
def some_view():
...
::
bp = Blueprint(...)
csrf.exempt(bp)
"""
if isinstance(view, Blueprint):
self._exempt_blueprints.add(view.name)
return view
if isinstance(view, str):
view_location = view
else:
view_location = '.'.join((view.__module__, view.__name__))
self._exempt_views.add(view_location)
return view
def _error_response(self, reason):
raise CSRFError(reason)