| 3 | | {{{#!diff |
| 4 | | diff --git a/mediagoblin/plugins/oauth/__init__.py b/mediagoblin/plugins/oauth/__init__.py |
| 5 | | index 4714d95..5762379 100644 |
| 6 | | --- a/mediagoblin/plugins/oauth/__init__.py |
| 7 | | +++ b/mediagoblin/plugins/oauth/__init__.py |
| 8 | | @@ -34,7 +34,7 @@ def setup_plugin(): |
| 9 | | _log.debug('OAuth config: {0}'.format(config)) |
| 10 | | |
| 11 | | routes = [ |
| 12 | | - ('mediagoblin.plugins.oauth.authorize', |
| 13 | | + ('mediagoblin.plugins.oauth.authorize', |
| 14 | | '/oauth/authorize', |
| 15 | | 'mediagoblin.plugins.oauth.views:authorize'), |
| 16 | | ('mediagoblin.plugins.oauth.authorize_client', |
| 17 | | diff --git a/mediagoblin/plugins/oauth/migrations.py b/mediagoblin/plugins/oauth/migrations.py |
| 18 | | index 6aa0d7c..f70a2e8 100644 |
| 19 | | --- a/mediagoblin/plugins/oauth/migrations.py |
| 20 | | +++ b/mediagoblin/plugins/oauth/migrations.py |
| 21 | | @@ -102,6 +102,22 @@ class OAuthCode_v0(declarative_base()): |
| 22 | | client_id = Column(Integer, ForeignKey(OAuthClient_v0.id), nullable=False) |
| 23 | | |
| 24 | | |
| 25 | | +class OAuthRefreshToken_v0(declarative_base()): |
| 26 | | + __tablename__ = 'oauth__refresh_tokens' |
| 27 | | + |
| 28 | | + id = Column(Integer, primary_key=True) |
| 29 | | + created = Column(DateTime, nullable=False, |
| 30 | | + default=datetime.now) |
| 31 | | + |
| 32 | | + token = Column(Unicode, index=True) |
| 33 | | + |
| 34 | | + user_id = Column(Integer, ForeignKey(User.id), nullable=False, |
| 35 | | + index=True) |
| 36 | | + |
| 37 | | + # XXX: Is it OK to use OAuthClient_v0.id in this way? |
| 38 | | + client_id = Column(Integer, ForeignKey(OAuthClient_v0.id), nullable=False) |
| 39 | | + |
| 40 | | + |
| 41 | | @RegisterMigration(1, MIGRATIONS) |
| 42 | | def remove_and_replace_token_and_code(db): |
| 43 | | metadata = MetaData(bind=db.bind) |
| 44 | | @@ -122,3 +138,22 @@ def remove_and_replace_token_and_code(db): |
| 45 | | OAuthCode_v0.__table__.create(db.bind) |
| 46 | | |
| 47 | | db.commit() |
| 48 | | + |
| 49 | | + |
| 50 | | +@RegisterMigration(2, MIGRATIONS) |
| 51 | | +def remove_refresh_token_field(db): |
| 52 | | + metadata = MetaData(bind=db.bind) |
| 53 | | + |
| 54 | | + token_table = Table('oauth__tokens', metadata, autoload=True, |
| 55 | | + autoload_with=db.bind) |
| 56 | | + |
| 57 | | + refresh_token = token_table.columns['refresh_token'] |
| 58 | | + |
| 59 | | + refresh_token.drop() |
| 60 | | + db.commit() |
| 61 | | + |
| 62 | | +@RegisterMigration(3, MIGRATIONS) |
| 63 | | +def create_refresh_token_table(db): |
| 64 | | + OAuthRefreshToken_v0.__table__.create(db.bind) |
| 65 | | + |
| 66 | | + db.commit() |
| 67 | | diff --git a/mediagoblin/plugins/oauth/models.py b/mediagoblin/plugins/oauth/models.py |
| 68 | | index 695dad3..28735dd 100644 |
| 69 | | --- a/mediagoblin/plugins/oauth/models.py |
| 70 | | +++ b/mediagoblin/plugins/oauth/models.py |
| 71 | | @@ -19,12 +19,14 @@ import bcrypt |
| 72 | | |
| 73 | | from datetime import datetime, timedelta |
| 74 | | |
| 75 | | -from mediagoblin.db.base import Base |
| 76 | | -from mediagoblin.db.models import User |
| 77 | | |
| 78 | | from sqlalchemy import ( |
| 79 | | Column, Unicode, Integer, DateTime, ForeignKey, Enum) |
| 80 | | from sqlalchemy.orm import relationship |
| 81 | | +from mediagoblin.db.base import Base |
| 82 | | +from mediagoblin.db.models import User |
| 83 | | +from mediagoblin.plugins.oauth.tools import generate_identifier, \ |
| 84 | | + generate_secret, generate_token, generate_code, generate_refresh_token |
| 85 | | |
| 86 | | # Don't remove this, I *think* it applies sqlalchemy-migrate functionality onto |
| 87 | | # the models. |
| 88 | | @@ -41,8 +43,9 @@ class OAuthClient(Base): |
| 89 | | name = Column(Unicode) |
| 90 | | description = Column(Unicode) |
| 91 | | |
| 92 | | - identifier = Column(Unicode, unique=True, index=True) |
| 93 | | - secret = Column(Unicode, index=True) |
| 94 | | + identifier = Column(Unicode, unique=True, index=True, |
| 95 | | + default=generate_identifier) |
| 96 | | + secret = Column(Unicode, index=True, default=generate_secret) |
| 97 | | |
| 98 | | owner_id = Column(Integer, ForeignKey(User.id)) |
| 99 | | owner = relationship(User, backref='registered_clients') |
| 100 | | @@ -54,14 +57,8 @@ class OAuthClient(Base): |
| 101 | | u'public', |
| 102 | | name=u'oauth__client_type')) |
| 103 | | |
| 104 | | - def generate_identifier(self): |
| 105 | | - self.identifier = unicode(uuid.uuid4()) |
| 106 | | - |
| 107 | | - def generate_secret(self): |
| 108 | | - self.secret = unicode( |
| 109 | | - bcrypt.hashpw( |
| 110 | | - unicode(uuid.uuid4()), |
| 111 | | - bcrypt.gensalt())) |
| 112 | | + def update_secret(self): |
| 113 | | + self.secret = generate_secret() |
| 114 | | |
| 115 | | def __repr__(self): |
| 116 | | return '<{0} {1}:{2} ({3})>'.format( |
| 117 | | @@ -76,10 +73,10 @@ class OAuthUserClient(Base): |
| 118 | | id = Column(Integer, primary_key=True) |
| 119 | | |
| 120 | | user_id = Column(Integer, ForeignKey(User.id)) |
| 121 | | - user = relationship(User, backref='oauth_clients') |
| 122 | | + user = relationship(User, backref='oauth_client_relations') |
| 123 | | |
| 124 | | client_id = Column(Integer, ForeignKey(OAuthClient.id)) |
| 125 | | - client = relationship(OAuthClient, backref='users') |
| 126 | | + client = relationship(OAuthClient, backref='oauth_user_relations') |
| 127 | | |
| 128 | | state = Column(Enum( |
| 129 | | u'approved', |
| 130 | | @@ -103,8 +100,7 @@ class OAuthToken(Base): |
| 131 | | default=datetime.now) |
| 132 | | expires = Column(DateTime, nullable=False, |
| 133 | | default=lambda: datetime.now() + timedelta(days=30)) |
| 134 | | - token = Column(Unicode, index=True) |
| 135 | | - refresh_token = Column(Unicode, index=True) |
| 136 | | + token = Column(Unicode, index=True, default=generate_token) |
| 137 | | |
| 138 | | user_id = Column(Integer, ForeignKey(User.id), nullable=False, |
| 139 | | index=True) |
| 140 | | @@ -121,6 +117,31 @@ class OAuthToken(Base): |
| 141 | | self.user, |
| 142 | | self.client) |
| 143 | | |
| 144 | | +class OAuthRefreshToken(Base): |
| 145 | | + __tablename__ = 'oauth__refresh_tokens' |
| 146 | | + |
| 147 | | + id = Column(Integer, primary_key=True) |
| 148 | | + created = Column(DateTime, nullable=False, |
| 149 | | + default=datetime.now) |
| 150 | | + |
| 151 | | + token = Column(Unicode, index=True, |
| 152 | | + default=generate_refresh_token) |
| 153 | | + |
| 154 | | + user_id = Column(Integer, ForeignKey(User.id), nullable=False, |
| 155 | | + index=True) |
| 156 | | + |
| 157 | | + user = relationship(User) |
| 158 | | + |
| 159 | | + client_id = Column(Integer, ForeignKey(OAuthClient.id), nullable=False) |
| 160 | | + client = relationship(OAuthClient) |
| 161 | | + |
| 162 | | + def __repr__(self): |
| 163 | | + return '<{0} #{1} [{3}, {4}]>'.format( |
| 164 | | + self.__class__.__name__, |
| 165 | | + self.id, |
| 166 | | + self.user, |
| 167 | | + self.client) |
| 168 | | + |
| 169 | | |
| 170 | | class OAuthCode(Base): |
| 171 | | __tablename__ = 'oauth__codes' |
| 172 | | @@ -130,7 +151,7 @@ class OAuthCode(Base): |
| 173 | | default=datetime.now) |
| 174 | | expires = Column(DateTime, nullable=False, |
| 175 | | default=lambda: datetime.now() + timedelta(minutes=5)) |
| 176 | | - code = Column(Unicode, index=True) |
| 177 | | + code = Column(Unicode, index=True, default=generate_code) |
| 178 | | |
| 179 | | user_id = Column(Integer, ForeignKey(User.id), nullable=False, |
| 180 | | index=True) |
| 181 | | @@ -150,6 +171,7 @@ class OAuthCode(Base): |
| 182 | | |
| 183 | | MODELS = [ |
| 184 | | OAuthToken, |
| 185 | | + OAuthRefreshToken, |
| 186 | | OAuthCode, |
| 187 | | OAuthClient, |
| 188 | | OAuthUserClient] |
| 189 | | diff --git a/mediagoblin/plugins/oauth/tools.py b/mediagoblin/plugins/oauth/tools.py |
| 190 | | index d21c8a5..25d0977 100644 |
| 191 | | --- a/mediagoblin/plugins/oauth/tools.py |
| 192 | | +++ b/mediagoblin/plugins/oauth/tools.py |
| 193 | | @@ -1,3 +1,4 @@ |
| 194 | | +# -*- coding: utf-8 -*- |
| 195 | | # GNU MediaGoblin -- federated, autonomous media hosting |
| 196 | | # Copyright (C) 2011, 2012 MediaGoblin contributors. See AUTHORS. |
| 197 | | # |
| 198 | | @@ -14,13 +15,25 @@ |
| 199 | | # You should have received a copy of the GNU Affero General Public License |
| 200 | | # along with this program. If not, see <http://www.gnu.org/licenses/>. |
| 201 | | |
| 202 | | +import uuid |
| 203 | | +import bcrypt |
| 204 | | + |
| 205 | | +from datetime import datetime |
| 206 | | + |
| 207 | | from functools import wraps |
| 208 | | |
| 209 | | -from mediagoblin.plugins.oauth.models import OAuthClient |
| 210 | | from mediagoblin.plugins.api.tools import json_response |
| 211 | | |
| 212 | | |
| 213 | | def require_client_auth(controller): |
| 214 | | + ''' |
| 215 | | + View decorator |
| 216 | | + |
| 217 | | + - Requires the presence of ``?client_id`` |
| 218 | | + ''' |
| 219 | | + # Avoid circular import |
| 220 | | + from mediagoblin.plugins.oauth.models import OAuthClient |
| 221 | | + |
| 222 | | @wraps(controller) |
| 223 | | def wrapper(request, *args, **kw): |
| 224 | | if not request.GET.get('client_id'): |
| 225 | | @@ -41,3 +54,63 @@ def require_client_auth(controller): |
| 226 | | return controller(request, client) |
| 227 | | |
| 228 | | return wrapper |
| 229 | | + |
| 230 | | + |
| 231 | | +def create_token(client, user): |
| 232 | | + ''' |
| 233 | | + Create an OAuthToken and an OAuthRefreshToken entry in the database |
| 234 | | + |
| 235 | | + Returns the data structure expected by the OAuth clients. |
| 236 | | + ''' |
| 237 | | + from mediagoblin.plugins.oauth.models import OAuthToken, OAuthRefreshToken |
| 238 | | + |
| 239 | | + token = OAuthToken() |
| 240 | | + token.user = user |
| 241 | | + token.client = client |
| 242 | | + token.save() |
| 243 | | + |
| 244 | | + refresh_token = OAuthRefreshToken() |
| 245 | | + refresh_token.user = user |
| 246 | | + refresh_token.client = client |
| 247 | | + refresh_token.save() |
| 248 | | + |
| 249 | | + # expire time of token in full seconds |
| 250 | | + # timedelta.total_seconds is python >= 2.7 or we would use that |
| 251 | | + td = token.expires - datetime.now() |
| 252 | | + exp_in = 86400*td.days + td.seconds # just ignore µsec |
| 253 | | + |
| 254 | | + return {'access_token': token.token, 'token_type': 'bearer', |
| 255 | | + 'refresh_token': refresh_token.token, 'expires_in': exp_in} |
| 256 | | + |
| 257 | | + |
| 258 | | +def generate_identifier(): |
| 259 | | + ''' Generates a ``uuid.uuid4()`` ''' |
| 260 | | + return unicode(uuid.uuid4()) |
| 261 | | + |
| 262 | | + |
| 263 | | +def generate_token(): |
| 264 | | + ''' Uses generate_identifier ''' |
| 265 | | + return generate_identifier() |
| 266 | | + |
| 267 | | + |
| 268 | | +def generate_refresh_token(): |
| 269 | | + ''' Uses generate_identifier ''' |
| 270 | | + return generate_identifier() |
| 271 | | + |
| 272 | | + |
| 273 | | +def generate_code(): |
| 274 | | + ''' Uses generate_identifier ''' |
| 275 | | + return generate_identifier() |
| 276 | | + |
| 277 | | + |
| 278 | | +def generate_secret(): |
| 279 | | + ''' |
| 280 | | + Generate a long string of pseudo-random characters |
| 281 | | + ''' |
| 282 | | + # XXX: We might not want it to use bcrypt, since bcrypt takes its time to |
| 283 | | + # generate the result. |
| 284 | | + return unicode( |
| 285 | | + bcrypt.hashpw( |
| 286 | | + unicode(uuid.uuid4()), |
| 287 | | + bcrypt.gensalt())) |
| 288 | | + |
| 289 | | diff --git a/mediagoblin/plugins/oauth/views.py b/mediagoblin/plugins/oauth/views.py |
| 290 | | index c7b2a33..ad8ea8f 100644 |
| 291 | | --- a/mediagoblin/plugins/oauth/views.py |
| 292 | | +++ b/mediagoblin/plugins/oauth/views.py |
| 293 | | @@ -16,21 +16,21 @@ |
| 294 | | # along with this program. If not, see <http://www.gnu.org/licenses/>. |
| 295 | | |
| 296 | | import logging |
| 297 | | -import json |
| 298 | | |
| 299 | | from urllib import urlencode |
| 300 | | -from uuid import uuid4 |
| 301 | | -from datetime import datetime |
| 302 | | + |
| 303 | | +from werkzeug.exceptions import BadRequest |
| 304 | | |
| 305 | | from mediagoblin.tools.response import render_to_response, redirect |
| 306 | | from mediagoblin.decorators import require_active_login |
| 307 | | -from mediagoblin.messages import add_message, SUCCESS, ERROR |
| 308 | | +from mediagoblin.messages import add_message, SUCCESS |
| 309 | | from mediagoblin.tools.translate import pass_to_ugettext as _ |
| 310 | | -from mediagoblin.plugins.oauth.models import OAuthCode, OAuthToken, \ |
| 311 | | - OAuthClient, OAuthUserClient |
| 312 | | +from mediagoblin.plugins.oauth.models import OAuthCode, OAuthClient, \ |
| 313 | | + OAuthUserClient, OAuthRefreshToken |
| 314 | | from mediagoblin.plugins.oauth.forms import ClientRegistrationForm, \ |
| 315 | | AuthorizationForm |
| 316 | | -from mediagoblin.plugins.oauth.tools import require_client_auth |
| 317 | | +from mediagoblin.plugins.oauth.tools import require_client_auth, \ |
| 318 | | + create_token |
| 319 | | from mediagoblin.plugins.api.tools import json_response |
| 320 | | |
| 321 | | _log = logging.getLogger(__name__) |
| 322 | | @@ -51,9 +51,6 @@ def register_client(request): |
| 323 | | client.owner_id = request.user.id |
| 324 | | client.redirect_uri = unicode(request.form['redirect_uri']) |
| 325 | | |
| 326 | | - client.generate_identifier() |
| 327 | | - client.generate_secret() |
| 328 | | - |
| 329 | | client.save() |
| 330 | | |
| 331 | | add_message(request, SUCCESS, _('The client {0} has been registered!')\ |
| 332 | | @@ -92,8 +89,8 @@ def authorize_client(request): |
| 333 | | form.client_id.data).first() |
| 334 | | |
| 335 | | if not client: |
| 336 | | - _log.error('''No such client id as received from client authorization |
| 337 | | - form.''') |
| 338 | | + _log.error('No such client id as received from client authorization \ |
| 339 | | +form.') |
| 340 | | return BadRequest() |
| 341 | | |
| 342 | | if form.validate(): |
| 343 | | @@ -105,7 +102,7 @@ def authorize_client(request): |
| 344 | | elif form.deny.data: |
| 345 | | relation.state = u'rejected' |
| 346 | | else: |
| 347 | | - return BadRequest |
| 348 | | + return BadRequest() |
| 349 | | |
| 350 | | relation.save() |
| 351 | | |
| 352 | | @@ -136,7 +133,7 @@ def authorize(request, client): |
| 353 | | return json_response({ |
| 354 | | 'status': 400, |
| 355 | | 'errors': |
| 356 | | - [u'Public clients MUST have a redirect_uri pre-set']}, |
| 357 | | + [u'Public clients should have a redirect_uri pre-set.']}, |
| 358 | | _disable_cors=True) |
| 359 | | |
| 360 | | redirect_uri = client.redirect_uri |
| 361 | | @@ -146,11 +143,10 @@ def authorize(request, client): |
| 362 | | if not redirect_uri: |
| 363 | | return json_response({ |
| 364 | | 'status': 400, |
| 365 | | - 'errors': [u'Can not find a redirect_uri for client: {0}'\ |
| 366 | | - .format(client.name)]}, _disable_cors=True) |
| 367 | | + 'errors': [u'No redirect_uri supplied!']}, |
| 368 | | + _disable_cors=True) |
| 369 | | |
| 370 | | code = OAuthCode() |
| 371 | | - code.code = unicode(uuid4()) |
| 372 | | code.user = request.user |
| 373 | | code.client = client |
| 374 | | code.save() |
| 375 | | @@ -180,59 +176,79 @@ def authorize(request, client): |
| 376 | | |
| 377 | | |
| 378 | | def access_token(request): |
| 379 | | + ''' |
| 380 | | + Access token endpoint provides access tokens to any clients that have the |
| 381 | | + right grants/credentials |
| 382 | | + ''' |
| 383 | | + |
| 384 | | + client = None |
| 385 | | + user = None |
| 386 | | + |
| 387 | | if request.GET.get('code'): |
| 388 | | + # Validate the code arg, then get the client object from the db. |
| 389 | | code = OAuthCode.query.filter(OAuthCode.code == |
| 390 | | request.GET.get('code')).first() |
| 391 | | |
| 392 | | - if code: |
| 393 | | - if code.client.type == u'confidential': |
| 394 | | - client_identifier = request.GET.get('client_id') |
| 395 | | - |
| 396 | | - if not client_identifier: |
| 397 | | - return json_response({ |
| 398 | | - 'error': 'invalid_request', |
| 399 | | - 'error_description': |
| 400 | | - 'Missing client_id in request'}) |
| 401 | | - |
| 402 | | - client_secret = request.GET.get('client_secret') |
| 403 | | - |
| 404 | | - if not client_secret: |
| 405 | | - return json_response({ |
| 406 | | - 'error': 'invalid_request', |
| 407 | | - 'error_description': |
| 408 | | - 'Missing client_secret in request'}) |
| 409 | | - |
| 410 | | - if not client_secret == code.client.secret or \ |
| 411 | | - not client_identifier == code.client.identifier: |
| 412 | | - return json_response({ |
| 413 | | - 'error': 'invalid_client', |
| 414 | | - 'error_description': |
| 415 | | - 'The client_id or client_secret does not match the' |
| 416 | | - ' code'}) |
| 417 | | - |
| 418 | | - token = OAuthToken() |
| 419 | | - token.token = unicode(uuid4()) |
| 420 | | - token.user = code.user |
| 421 | | - token.client = code.client |
| 422 | | - token.save() |
| 423 | | - |
| 424 | | - # expire time of token in full seconds |
| 425 | | - # timedelta.total_seconds is python >= 2.7 or we would use that |
| 426 | | - td = token.expires - datetime.now() |
| 427 | | - exp_in = 86400*td.days + td.seconds # just ignore µsec |
| 428 | | - |
| 429 | | - access_token_data = { |
| 430 | | - 'access_token': token.token, |
| 431 | | - 'token_type': 'bearer', |
| 432 | | - 'expires_in': exp_in} |
| 433 | | - return json_response(access_token_data, _disable_cors=True) |
| 434 | | - else: |
| 435 | | + if not code: |
| 436 | | return json_response({ |
| 437 | | 'error': 'invalid_request', |
| 438 | | 'error_description': |
| 439 | | - 'Invalid code'}) |
| 440 | | - else: |
| 441 | | - return json_response({ |
| 442 | | - 'error': 'invalid_request', |
| 443 | | - 'error_descriptin': |
| 444 | | - 'Missing `code` parameter in request'}) |
| 445 | | + 'Invalid code.'}) |
| 446 | | + |
| 447 | | + client = code.client |
| 448 | | + user = code.user |
| 449 | | + |
| 450 | | + elif request.args.get('refresh_token'): |
| 451 | | + # Validate a refresh token, then get the client object from the db. |
| 452 | | + refresh_token = OAuthRefreshToken.query.filter( |
| 453 | | + OAuthRefreshToken.token == |
| 454 | | + request.args.get('refresh_token')).first() |
| 455 | | + |
| 456 | | + if not refresh_token: |
| 457 | | + return json_response({ |
| 458 | | + 'error': 'invalid_request', |
| 459 | | + 'error_description': |
| 460 | | + 'Invalid refresh token.'}) |
| 461 | | + |
| 462 | | + client = refresh_token.client |
| 463 | | + user = refresh_token.user |
| 464 | | + |
| 465 | | + if client: |
| 466 | | + client_identifier = request.GET.get('client_id') |
| 467 | | + |
| 468 | | + if not client_identifier: |
| 469 | | + return json_response({ |
| 470 | | + 'error': 'invalid_request', |
| 471 | | + 'error_description': |
| 472 | | + 'Missing client_id in request.'}) |
| 473 | | + |
| 474 | | + if not client_identifier == client.identifier: |
| 475 | | + return json_response({ |
| 476 | | + 'error': 'invalid_client', |
| 477 | | + 'error_description': |
| 478 | | + 'Mismatching client credentials.'}) |
| 479 | | + |
| 480 | | + if client.type == u'confidential': |
| 481 | | + client_secret = request.GET.get('client_secret') |
| 482 | | + |
| 483 | | + if not client_secret: |
| 484 | | + return json_response({ |
| 485 | | + 'error': 'invalid_request', |
| 486 | | + 'error_description': |
| 487 | | + 'Missing client_secret in request.'}) |
| 488 | | + |
| 489 | | + if not client_secret == client.secret: |
| 490 | | + return json_response({ |
| 491 | | + 'error': 'invalid_client', |
| 492 | | + 'error_description': |
| 493 | | + 'Mismatching client credentials.'}) |
| 494 | | + |
| 495 | | + |
| 496 | | + access_token_data = create_token(client, user) |
| 497 | | + |
| 498 | | + return json_response(access_token_data, _disable_cors=True) |
| 499 | | + |
| 500 | | + return json_response({ |
| 501 | | + 'error': 'invalid_request', |
| 502 | | + 'error_description': |
| 503 | | + 'Missing `code` or `refresh_token` parameter in request.'}) |
| 504 | | diff --git a/mediagoblin/tests/test_oauth.py b/mediagoblin/tests/test_oauth.py |
| 505 | | index 94ba5da..f036569 100644 |
| 506 | | --- a/mediagoblin/tests/test_oauth.py |
| 507 | | +++ b/mediagoblin/tests/test_oauth.py |
| 508 | | @@ -16,12 +16,13 @@ |
| 509 | | |
| 510 | | import json |
| 511 | | import logging |
| 512 | | +import urllib |
| 513 | | |
| 514 | | from urlparse import parse_qs, urlparse |
| 515 | | |
| 516 | | from mediagoblin import mg_globals |
| 517 | | from mediagoblin.tools import template, pluginapi |
| 518 | | -from mediagoblin.tests.tools import get_app, fixture_add_user |
| 519 | | +from mediagoblin.tests.tools import get_app, fixture_add_user, expect_failure |
| 520 | | |
| 521 | | |
| 522 | | _log = logging.getLogger(__name__) |
| 523 | | @@ -70,7 +71,7 @@ class TestOAuth(object): |
| 524 | | assert response.status_int == 200 |
| 525 | | |
| 526 | | # Should display an error |
| 527 | | - assert ctx['form'].redirect_uri.errors |
| 528 | | + assert len(ctx['form'].redirect_uri.errors) |
| 529 | | |
| 530 | | # Should not pass through |
| 531 | | assert not client |
| 532 | | @@ -78,12 +79,16 @@ class TestOAuth(object): |
| 533 | | def test_2_successful_public_client_registration(self): |
| 534 | | ''' Successfully register a public client ''' |
| 535 | | self.login() |
| 536 | | + uri = 'http://foo.example' |
| 537 | | self.register_client(u'OMGOMG', 'public', 'OMG!', |
| 538 | | - 'http://foo.example') |
| 539 | | + uri) |
| 540 | | |
| 541 | | client = self.db.OAuthClient.query.filter( |
| 542 | | self.db.OAuthClient.name == u'OMGOMG').first() |
| 543 | | |
| 544 | | + # redirect_uri should be set |
| 545 | | + assert client.redirect_uri == uri |
| 546 | | + |
| 547 | | # Client should have been registered |
| 548 | | assert client |
| 549 | | |
| 550 | | @@ -111,7 +116,7 @@ class TestOAuth(object): |
| 551 | | redirect_uri = 'https://foo.example' |
| 552 | | response = self.app.get('/oauth/authorize', { |
| 553 | | 'client_id': client.identifier, |
| 554 | | - 'scope': 'admin', |
| 555 | | + 'scope': 'all', |
| 556 | | 'redirect_uri': redirect_uri}) |
| 557 | | |
| 558 | | # User-agent should NOT be redirected |
| 559 | | @@ -137,6 +142,7 @@ class TestOAuth(object): |
| 560 | | return authorization_response, client_identifier |
| 561 | | |
| 562 | | def get_code_from_redirect_uri(self, uri): |
| 563 | | + ''' Get the value of ?code= from an URI ''' |
| 564 | | return parse_qs(urlparse(uri).query)['code'][0] |
| 565 | | |
| 566 | | def test_token_endpoint_successful_confidential_request(self): |
| 567 | | @@ -162,6 +168,11 @@ code={1}&client_secret={2}'.format(client_id, code, client.secret)) |
| 568 | | assert type(token_data['expires_in']) == int |
| 569 | | assert token_data['expires_in'] > 0 |
| 570 | | |
| 571 | | + # There should be a refresh token provided in the token data |
| 572 | | + assert len(token_data['refresh_token']) |
| 573 | | + |
| 574 | | + return client_id, token_data |
| 575 | | + |
| 576 | | def test_token_endpont_missing_id_confidential_request(self): |
| 577 | | ''' Unsuccessful request against token endpoint, missing client_id ''' |
| 578 | | code_redirect, client_id = self.test_4_authorize_confidential_client() |
| 579 | | @@ -181,4 +192,30 @@ code={0}&client_secret={1}'.format(code, client.secret)) |
| 580 | | assert 'error' in token_data |
| 581 | | assert not 'access_token' in token_data |
| 582 | | assert token_data['error'] == 'invalid_request' |
| 583 | | - assert token_data['error_description'] == 'Missing client_id in request' |
| 584 | | + assert len(token_data['error_description']) |
| 585 | | + |
| 586 | | + def test_refresh_token(self): |
| 587 | | + ''' Try to get a new access token using the refresh token ''' |
| 588 | | + # Get an access token and a refresh token |
| 589 | | + client_id, token_data =\ |
| 590 | | + self.test_token_endpoint_successful_confidential_request() |
| 591 | | + |
| 592 | | + client = self.db.OAuthClient.query.filter( |
| 593 | | + self.db.OAuthClient.identifier == client_id).first() |
| 594 | | + |
| 595 | | + token_res = self.app.get('/oauth/access_token', |
| 596 | | + {'refresh_token': token_data['refresh_token'], |
| 597 | | + 'client_id': client_id, |
| 598 | | + 'client_secret': client.secret |
| 599 | | + }) |
| 600 | | + |
| 601 | | + assert token_res.status_int == 200 |
| 602 | | + |
| 603 | | + new_token_data = json.loads(token_res.body) |
| 604 | | + |
| 605 | | + assert not 'error' in new_token_data |
| 606 | | + assert 'access_token' in new_token_data |
| 607 | | + assert 'token_type' in new_token_data |
| 608 | | + assert 'expires_in' in new_token_data |
| 609 | | + assert type(new_token_data['expires_in']) == int |
| 610 | | + assert new_token_data['expires_in'] > 0 |
| 611 | | diff --git a/mediagoblin/tests/tools.py b/mediagoblin/tests/tools.py |
| 612 | | index cc4a7ad..71954e0 100644 |
| 613 | | --- a/mediagoblin/tests/tools.py |
| 614 | | +++ b/mediagoblin/tests/tools.py |
| 615 | | @@ -19,6 +19,8 @@ import os |
| 616 | | import pkg_resources |
| 617 | | import shutil |
| 618 | | |
| 619 | | +import nose |
| 620 | | + |
| 621 | | from functools import wraps |
| 622 | | |
| 623 | | from paste.deploy import loadapp |
| 624 | | @@ -235,7 +237,7 @@ def fixture_media_entry(title=u"Some title", slug=None, |
| 625 | | entry.slug = slug |
| 626 | | entry.uploader = uploader or fixture_add_user().id |
| 627 | | entry.media_type = u'image' |
| 628 | | - |
| 629 | | + |
| 630 | | if gen_slug: |
| 631 | | entry.generate_slug() |
| 632 | | if save: |
| 633 | | @@ -263,3 +265,16 @@ def fixture_add_collection(name=u"My first Collection", user=None): |
| 634 | | Session.expunge(coll) |
| 635 | | |
| 636 | | return coll |
| 637 | | + |
| 638 | | + |
| 639 | | +def expect_failure(test): |
| 640 | | + @wraps(test) |
| 641 | | + def wrapper(*args, **kwargs): |
| 642 | | + try: |
| 643 | | + test(*args, **kwargs) |
| 644 | | + except Exception: |
| 645 | | + raise nose.SkipTest |
| 646 | | + else: |
| 647 | | + raise AssertionError( |
| 648 | | + 'Test is expected to fail, did you add a feature?') |
| 649 | | + return wrapper |
| 650 | | }}} |
| | 3 | ''Edited to remove really long diff.'' |